IO

Functions for handling the serialization of models.

metatrain.utils.io.check_file_extension(filename: str | Path, extension: str) → str | Path

Check the file extension of a file name and add it if it is not present.

If filename does not end with extension, the extension is appended and a warning is issued.

Parameters:
  • filename (str | Path) – Name of the file to be checked.

  • extension (str) – Expected file extension, e.g. .txt.

Returns:

Checked and possibly extended file name.

Return type:

str | Path
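
A minimal usage sketch of the documented behavior; the file names are illustrative:

  from metatrain.utils.io import check_file_extension

  # Extension is missing: it is appended and a warning is issued.
  check_file_extension("model", ".txt")      # -> "model.txt"
  # Extension is already present: the name is returned unchanged.
  check_file_extension("model.txt", ".txt")  # -> "model.txt"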

metatrain.utils.io.is_exported_file(path: str) → bool

Check if a saved model file has been exported to a metatomic AtomisticModel.

The function uses metatomic.torch.check_atomistic_model() to verify.

Parameters:

path (str) – model path

Returns:

True if the model has been exported, False otherwise.

Return type:

bool

See also

metatomic.torch.is_atomistic_model() to verify if an already loaded model is exported.
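
A short sketch of how this check might be used; the file name is hypothetical:

  from metatrain.utils.io import is_exported_file

  if is_exported_file("model.pt"):
      print("already exported to a metatomic AtomisticModel")
  else:
      print("not exported (e.g. a training checkpoint)")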

metatrain.utils.io.load_model(path: str | Path, extensions_directory: str | Path | None = None, hf_token: str | None = None) → Any

Load checkpoints and exported models from a URL or a local file for inference.

Remote models from Hugging Face are downloaded to a local cache directory.

If the exported model to be loaded requires compiled extensions, pass their location using the extensions_directory parameter.

After reading a checkpoint, the returned model can be exported with the model’s own export() method.

Note

This function is intended to load models only for inference in Python. To continue training or to finetune a model, use metatrain's command-line interface.

Parameters:
  • path (str | Path) – local or remote path to a model. For supported URL schemes see urllib.request.

  • extensions_directory (str | Path | None) – path to a directory containing all extensions required by an exported model

  • hf_token (str | None) – HuggingFace API token to download (private) models from HuggingFace

Raises:

ValueError – if path points to a YAML options file rather than a model

Returns:

the loaded model

Return type:

Any
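
A usage sketch combining the documented options; all paths, the URL, and the token below are placeholders:

  from metatrain.utils.io import load_model

  # Local checkpoint: the returned model can then be exported with its
  # own export() method.
  model = load_model("model.ckpt")

  # Exported model that requires compiled extensions.
  model = load_model("model.pt", extensions_directory="extensions/")

  # Remote model from Hugging Face (downloaded to a local cache);
  # hf_token is only needed for private repositories.
  model = load_model(
      "https://huggingface.co/user/repo/resolve/main/model.ckpt",
      hf_token="hf_...",
  )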

metatrain.utils.io.model_from_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export']) → Module

Create the model instance corresponding to the given checkpoint dictionary. The model architecture is determined from information stored inside the checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) – checkpoint dictionary as returned by torch.load(path).

  • context (Literal['restart', 'finetune', 'export']) – context in which the model is loaded, one of:

    • "restart": the model is loaded to restart training from a previous checkpoint;

    • "finetune": the model is loaded to finetune a pretrained model;

    • "export": the model is loaded to export a trained model.

Returns:

the loaded model instance.

Return type:

Module
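
A sketch of the intended call pattern; the checkpoint path is hypothetical, and depending on the PyTorch version torch.load() may need weights_only=False to deserialize a full checkpoint dictionary:

  import torch

  from metatrain.utils.io import model_from_checkpoint

  checkpoint = torch.load("model.ckpt", weights_only=False)
  model = model_from_checkpoint(checkpoint, context="export")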

metatrain.utils.io.trainer_from_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export'], hypers: Dict[str, Any]) → Any

Create the trainer instance corresponding to the given checkpoint dictionary. The architecture is determined from information stored inside the checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) – checkpoint dictionary as returned by torch.load(path).

  • context (Literal['restart', 'finetune', 'export']) – context in which the trainer is loaded, one of:

    • "restart": the trainer is loaded to restart training from a previous checkpoint;

    • "finetune": the trainer is loaded to finetune a pretrained model;

    • "export": the trainer is loaded to export a trained model.

  • hypers (Dict[str, Any]) – hyperparameters to be used by the trainer. Required if context="finetune", ignored otherwise.

Returns:

the loaded trainer instance.

Return type:

Any
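
A sketch for the finetuning case, where hypers is required; the checkpoint path and the contents of the hypers dictionary are purely illustrative:

  import torch

  from metatrain.utils.io import trainer_from_checkpoint

  checkpoint = torch.load("model.ckpt", weights_only=False)
  trainer = trainer_from_checkpoint(
      checkpoint,
      context="finetune",
      hypers={"num_epochs": 10},  # required for context="finetune"
  )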