Dataset
- class metatrain.utils.data.dataset.DatasetInfo(length_unit: str | None, atomic_types: List[int], targets: Dict[str, TargetInfo], extra_data: Dict[str, TargetInfo] | None = None)
Bases: object
A class that contains information about datasets.
This class is used to communicate additional dataset details to the training functions of the individual models.
- Parameters:
length_unit (str | None) – Unit of length used in the dataset. Examples are "angstrom" or "nanometer". If None, the unit will be set to the empty string.
atomic_types (List[int]) – List containing all integer atomic types present in the dataset. atomic_types will be stored as a sorted list of unique atomic types.
targets (Dict[str, TargetInfo]) – Information about targets in the dataset.
extra_data (Dict[str, TargetInfo] | None) – Optional dictionary containing additional data that is not used as a target, but is still relevant to the dataset.
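The normalization described in the parameter list can be illustrated with a minimal sketch. It does not use metatrain itself: `normalize_dataset_info_args` is a hypothetical helper that only mimics the two documented rules (a `None` length unit becomes the empty string, and atomic types are stored sorted and unique).

```python
from typing import List, Optional, Tuple


def normalize_dataset_info_args(
    length_unit: Optional[str], atomic_types: List[int]
) -> Tuple[str, List[int]]:
    """Hypothetical helper, not part of metatrain.

    Reproduces only the normalization rules stated in the parameter
    descriptions of DatasetInfo.
    """
    # A missing length unit is stored as the empty string.
    unit = "" if length_unit is None else length_unit
    # Atomic types are stored as a sorted list of unique integers.
    types = sorted(set(atomic_types))
    return unit, types


unit, types = normalize_dataset_info_args(None, [8, 1, 6, 1, 8])
print(repr(unit), types)
```

The real class additionally stores `targets` and `extra_data`, which this sketch leaves out.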
- property device: device | None
Return the device where the tensors of DatasetInfo are located.
This property only checks the device of the first target and assumes that all targets and extra data are on the same device. This is guaranteed if the to() method has been used to move the DatasetInfo to a specific device.
- to(device: device | None = None, dtype: dtype | None = None) → DatasetInfo
Return a copy with all tensors moved to the device and dtype.
- Parameters:
device (device | None) – The device to move the tensors to.
dtype (dtype | None) – The dtype to convert the tensors to.
- Returns:
A copy of the DatasetInfo with all tensors moved to the device and dtype.
- Return type:
DatasetInfo
- copy() → DatasetInfo
- Returns:
A shallow copy of the DatasetInfo.
- Return type:
DatasetInfo
- update(other: DatasetInfo) → None
Update this instance with the union of itself and other.
- Parameters:
other (DatasetInfo) – Another DatasetInfo instance to update this one with.
- Raises:
ValueError – If the length_units are different.
- Return type:
None
- union(other: DatasetInfo) → DatasetInfo
Return the union of this instance with other.
- Parameters:
other (DatasetInfo) – Another DatasetInfo instance to combine with this one.
- Returns:
A new DatasetInfo instance containing the union of this instance and other.
- Return type:
DatasetInfo
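To make the relationship between update() and union() concrete, here is a small sketch built on a hypothetical stand-in class, `SketchInfo`, with targets reduced to plain strings. It assumes that union() behaves like a copy followed by update(), and that a target present in both instances is taken from `other`; neither detail is stated in this reference, so treat both as assumptions.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SketchInfo:
    """Hypothetical stand-in for DatasetInfo, not the metatrain class."""

    length_unit: str
    atomic_types: List[int]
    targets: Dict[str, str] = field(default_factory=dict)

    def update(self, other: "SketchInfo") -> None:
        # Documented behaviour: differing length units raise ValueError.
        if self.length_unit != other.length_unit:
            raise ValueError(
                f"length units differ: {self.length_unit!r} vs {other.length_unit!r}"
            )
        # Keep atomic types sorted and unique after merging.
        self.atomic_types = sorted(set(self.atomic_types) | set(other.atomic_types))
        # Assumption: on a name clash, the entry from `other` wins.
        self.targets = {**self.targets, **other.targets}

    def union(self, other: "SketchInfo") -> "SketchInfo":
        # Assumption: union is a copy followed by an in-place update.
        new = SketchInfo(self.length_unit, list(self.atomic_types), dict(self.targets))
        new.update(other)
        return new


a = SketchInfo("angstrom", [1, 6], {"energy": "eV"})
b = SketchInfo("angstrom", [6, 8], {"forces": "eV/angstrom"})
combined = a.union(b)
print(combined.atomic_types, sorted(combined.targets))
```

In this sketch union() leaves `a` untouched, while update() modifies the instance it is called on.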
- metatrain.utils.data.dataset.get_stats(dataset: Dataset | Subset, dataset_info: DatasetInfo) → str
Returns the statistics of a dataset or subset as a string.
- Parameters:
dataset (Dataset | Subset) – The dataset or subset to analyze.
dataset_info (DatasetInfo) – The DatasetInfo associated with the dataset.
- Returns:
A string containing the computed statistics for the dataset.
- Return type:
str
- metatrain.utils.data.dataset.get_atomic_types(datasets: Dataset | List[Dataset]) → List[int]
List of all atomic types present in a dataset or list of datasets.
- metatrain.utils.data.dataset.get_all_targets(datasets: Dataset | List[Dataset]) → List[str]
Sorted list of all unique targets present in a dataset or list of datasets.
- class metatrain.utils.data.dataset.CollateFn(target_keys: List[str], callables: List[Callable] | None = None, join_kwargs: Dict[str, Any] | None = None)
Bases: object
- metatrain.utils.data.dataset.unpack_batch(batch: Any) → Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]
Unpacks a batch into its constituent parts.
- metatrain.utils.data.dataset.check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]) → None
Check that the training and validation sets are compatible with one another.
Although these checks will not fit all use cases, most models are expected to be able to use this function.
- Parameters:
train_datasets (List[Dataset]) – The training datasets to check.
val_datasets (List[Dataset]) – The validation datasets to check.
- Raises:
TypeError – If the dtype within the datasets is inconsistent.
ValueError – If val_datasets has a target that is not present in train_datasets.
ValueError – If the validation set contains chemical species or targets that are not present in the training set.
- Return type:
None
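The two ValueError conditions amount to subset checks, which the following sketch illustrates. `sketch_check_subsets` is a hypothetical function operating on plain collections of target names and atomic types; it is not the metatrain implementation and it omits the dtype consistency check.

```python
from typing import Iterable


def sketch_check_subsets(
    train_targets: Iterable[str],
    val_targets: Iterable[str],
    train_types: Iterable[int],
    val_types: Iterable[int],
) -> None:
    """Hypothetical illustration of the documented ValueError conditions."""
    # Every validation target must also be a training target.
    missing_targets = set(val_targets) - set(train_targets)
    if missing_targets:
        raise ValueError(
            f"validation targets missing from training set: {sorted(missing_targets)}"
        )
    # Every atomic type seen in validation must also appear in training.
    missing_types = set(val_types) - set(train_types)
    if missing_types:
        raise ValueError(
            f"validation atomic types missing from training set: {sorted(missing_types)}"
        )


# Compatible: validation uses a subset of the training targets and types.
sketch_check_subsets(["energy", "forces"], ["energy"], [1, 6, 8], [1, 8])
```

Passing a validation set with an extra target or atomic type makes the sketch raise ValueError, mirroring the conditions listed above.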
- class metatrain.utils.data.dataset.DiskDataset(path: str | Path, fields: List[str] | None = None)
Bases: Dataset
A class representing a dataset stored on disk.
The dataset is stored in a zip file, where each sample is stored in a separate directory. The directory’s name is the index of the sample (e.g. 0/), and the files in the directory are the system (system.mta) and the targets (each named <target_name>.mts). These are metatomic.torch.System and metatensor.torch.TensorMap objects, respectively.
Such a dataset can be created conveniently using the DiskDatasetWriter class.
- Parameters:
path (str | Path) – Path to the zip file containing the dataset.
fields (List[str] | None)
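The archive layout described above can be visualized with the standard library alone. The sketch below builds an in-memory zip with the same directory structure and then recovers the sample indices and target names from the entry names. The file contents are placeholder bytes and `energy` is a hypothetical target name, so the result is not a dataset that DiskDataset could load; real archives should be produced with DiskDatasetWriter.

```python
import io
import zipfile
from pathlib import PurePosixPath

# Build an in-memory archive that mirrors the documented layout.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    for index in range(2):
        # Placeholder bytes only: not real serialized System/TensorMap data.
        archive.writestr(f"{index}/system.mta", b"placeholder")
        archive.writestr(f"{index}/energy.mts", b"placeholder")

# Read the entry names back and inspect the structure.
with zipfile.ZipFile(buffer) as archive:
    names = sorted(archive.namelist())

# The directory name is the sample index.
sample_indices = sorted({int(PurePosixPath(name).parts[0]) for name in names})
# Each <target_name>.mts file corresponds to one target.
target_names = sorted(
    {PurePosixPath(name).stem for name in names if name.endswith(".mts")}
)

print(names)
print(sample_indices, target_names)
```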
- metatrain.utils.data.dataset.get_num_workers() → int
Gets a good number of workers for data loading.
- Returns:
A good number of workers for data loading.
- Return type:
int
- metatrain.utils.data.dataset.validate_num_workers(num_workers: int) → None
Validates the number of workers for data loading.
- Parameters:
num_workers (int) – The number of workers to validate.
- Raises:
ValueError – If the number of workers is greater than 0 and the multiprocessing start method is not “fork”.
- Return type:
None
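The documented condition can be expressed with the standard multiprocessing module. The sketch below is a hypothetical re-implementation of that single rule, not the metatrain source; the optional `start_method` argument is an addition that makes the check easy to try on any platform.

```python
import multiprocessing
from typing import Optional


def sketch_validate_num_workers(
    num_workers: int, start_method: Optional[str] = None
) -> None:
    """Hypothetical illustration of the documented ValueError condition."""
    if start_method is None:
        # Ask Python which start method is in effect for this process.
        start_method = multiprocessing.get_start_method()
    # Worker processes are only accepted when the start method is "fork".
    if num_workers > 0 and start_method != "fork":
        raise ValueError(
            f"num_workers={num_workers} requires the 'fork' start method, "
            f"but the current one is {start_method!r}"
        )


# No workers: always valid, whatever the start method.
sketch_validate_num_workers(0, start_method="spawn")
# Workers with "fork": valid.
sketch_validate_num_workers(4, start_method="fork")
```

Calling the sketch with a positive worker count and a start method such as "spawn" raises ValueError.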