API
module data.dataset
function collate_graphs
collate_graphs(
batch_data: 'list'
) → tuple[list[CrystalGraph], dict[str, Tensor]]
Collate a list of (graph, target) pairs into batch data.
Args:
batch_data
(list): list of (graph, target(dict))
Returns:
graphs
(List): a list of graphs
targets
(Dict): dictionary of targets, where keys and values are:
e
(Tensor): energies of the structures [batch_size]
f
(Tensor): forces of the structures [n_batch_atoms, 3]
s
(Tensor): stresses of the structures [3*batch_size, 3]
m
(Tensor): magmoms of the structures [n_batch_atoms]
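The shapes above imply that energies are stacked once per structure while forces, stresses and magmoms are concatenated along the first axis. A minimal plain-Python sketch of that batching convention, assuming each target dict uses the keys e, f, s, m and using nested lists instead of torch Tensors; `collate_sketch` is a hypothetical name, not part of chgnet.

```python
def collate_sketch(batch_data):
    """Hypothetical stand-in for collate_graphs that uses plain lists, not Tensors."""
    graphs = []
    targets = {}
    for graph, target in batch_data:
        graphs.append(graph)
        for key, value in target.items():
            bucket = targets.setdefault(key, [])
            if key == "e":
                # One scalar energy per structure -> [batch_size]
                bucket.append(value)
            else:
                # f and m have one row per atom, s has three rows per structure,
                # so all three are concatenated along the first axis
                bucket.extend(value)
    return graphs, targets


batch = [
    ("graph_0", {"e": -1.0, "f": [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0]],
                 "s": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], "m": [0.0, 1.0]}),
    ("graph_1", {"e": -2.0, "f": [[0.0, 0.0, 0.0]] * 3,
                 "s": [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]], "m": [0.5] * 3}),
]
graphs, targets = collate_sketch(batch)
```

Concatenating rather than nesting is what lets one batch hold structures with different atom counts (here 2 + 3 = 5 force rows).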
function get_train_val_test_loader
get_train_val_test_loader(
dataset: 'Dataset',
batch_size: 'int' = 64,
train_ratio: 'float' = 0.8,
val_ratio: 'float' = 0.1,
return_test: 'bool' = True,
num_workers: 'int' = 0,
pin_memory: 'bool' = True
) → tuple[DataLoader, DataLoader, DataLoader]
Randomly partition a dataset into train, val, test loaders.
Args:
dataset
(Dataset): The dataset to partition.
batch_size
(int): The batch size for the data loaders. Default = 64
train_ratio
(float): The ratio of the dataset to use for training. Default = 0.8
val_ratio
(float): The ratio of the dataset to use for validation. Default = 0.1
return_test
(bool): Whether to return a test data loader. Default = True
num_workers
(int): The number of worker processes for loading the data; see the torch DataLoader documentation for more info. Default = 0
pin_memory
(bool): Whether to pin the memory of the data loaders. Default = True
Returns: train_loader, val_loader and optionally test_loader
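The random partition can be pictured as shuffling the dataset indices and cutting them by train_ratio and val_ratio, with the remainder forming the test split. A minimal sketch, assuming split sizes are truncated with int(); the library's exact rounding and its handling of return_test=False may differ, and `split_indices` is a hypothetical helper.

```python
import random


def split_indices(n, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Hypothetical helper: shuffle 0..n-1 and cut the list by the given ratios."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_train = int(train_ratio * n)  # truncation is an assumption
    n_val = int(val_ratio * n)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains becomes the test split
    return train, val, test


train, val, test = split_indices(100, seed=0)
```

In a torch setting, index lists like these are commonly handed to torch.utils.data.SubsetRandomSampler to build one DataLoader per split.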
function get_loader
get_loader(
dataset,
batch_size: 'int' = 64,
num_workers: 'int' = 0,
pin_memory: 'bool' = True
) → DataLoader
Get a dataloader from a dataset.
Args:
dataset
(Dataset): The dataset to load.
batch_size
(int): The batch size for the data loaders. Default = 64
num_workers
(int): The number of worker processes for loading the data; see the torch DataLoader documentation for more info. Default = 0
pin_memory
(bool): Whether to pin the memory of the data loaders. Default = True
Returns: data_loader
class StructureData
A simple torch Dataset of structures.
method __init__
__init__(
structures: 'list[Structure]',
energies: 'list[float]',
forces: 'list[Sequence[Sequence[float]]]',
stresses: 'list[Sequence[Sequence[float]]] | None' = None,
magmoms: 'list[Sequence[Sequence[float]]] | None' = None,
structure_ids: 'list | None' = None,
graph_converter: 'CrystalGraphConverter | None' = None,
shuffle: 'bool' = True
) → None
Initialize the dataset.
Args:
structures
(list[Structure]): pymatgen Structure objects.
energies
(list[float]): [data_size, 1]
forces
(list[list[float]]): [data_size, n_atoms, 3]
stresses
(list[list[float]], optional): [data_size, 3, 3]. Default = None
magmoms
(list[list[float]], optional): [data_size, n_atoms, 1]. Default = None
structure_ids
(list, optional): a list of ids to track the structures. Default = None
graph_converter
(CrystalGraphConverter, optional): Converts the structures to graphs. If None, it will be set to CHGNet 0.3.0 converter with AtomGraph cutoff = 6A.
shuffle
(bool): whether to shuffle the sequence of dataset. Default = True
Raises:
RuntimeError
: if the length of structures and labels (energies, forces, stresses, magmoms) are not equal.
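The documented RuntimeError amounts to a length check across the label lists. A minimal sketch of such a check, assuming optional labels left as None are simply skipped; `check_label_lengths` is a hypothetical helper, not a chgnet function.

```python
def check_label_lengths(structures, energies, forces, stresses=None, magmoms=None):
    """Hypothetical check: raise RuntimeError if any given label list has the wrong length."""
    n = len(structures)
    labels = {"energies": energies, "forces": forces, "stresses": stresses, "magmoms": magmoms}
    for name, values in labels.items():
        # Optional labels left as None are not checked
        if values is not None and len(values) != n:
            raise RuntimeError(f"{n} structures but {len(values)} {name}")
```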
classmethod from_vasp
from_vasp(
file_root: 'str',
check_electronic_convergence: 'bool' = True,
save_path: 'str | None' = None,
graph_converter: 'CrystalGraphConverter | None' = None,
shuffle: 'bool' = True
) → Self
Parse VASP output files into structures and labels and feed into the dataset.
Args:
file_root
(str): the directory of the VASP calculation outputs
check_electronic_convergence
(bool): if set to True, this function will raise an Exception for VASP calculations that did not achieve electronic convergence. Default = True
save_path
(str): path to save the parsed VASP labels. Default = None
graph_converter
(CrystalGraphConverter, optional): Converts the structures to graphs. If None, it will be set to CHGNet 0.3.0 converter with AtomGraph cutoff = 6A.
shuffle
(bool): whether to shuffle the sequence of dataset. Default = True
class CIFData
A dataset from CIFs.
method __init__
__init__(
cif_path: 'str',
labels: 'str | dict' = 'labels.json',
targets: 'TrainTask' = 'efsm',
graph_converter: 'CrystalGraphConverter | None' = None,
energy_key: 'str' = 'energy_per_atom',
force_key: 'str' = 'force',
stress_key: 'str' = 'stress',
magmom_key: 'str' = 'magmom',
shuffle: 'bool' = True
) → None
Initialize the dataset from a directory containing CIFs.
Args:
cif_path
(str): path that contains all the CIFs and labels.json
labels
(str, dict): the path or dictionary of labels. Default = “labels.json”
targets
(“ef” | “efs” | “efm” | “efsm”): The training targets. Default = “efsm”
graph_converter
(CrystalGraphConverter, optional): Converts the structures to graphs. If None, it will be set to CHGNet 0.3.0 converter with AtomGraph cutoff = 6A.
energy_key
(str, optional): the key of energy in the labels. Default = “energy_per_atom”
force_key
(str, optional): the key of force in the labels. Default = “force”
stress_key
(str, optional): the key of stress in the labels. Default = “stress”
magmom_key
(str, optional): the key of magmom in the labels. Default = “magmom”
shuffle
(bool): whether to shuffle the sequence of dataset. Default = True
class GraphData
A dataset of graphs. This is compatible with the graph.pt documents made by make_graphs.py. We recommend using this dataset to avoid graph conversion steps.
method __init__
__init__(
graph_path: 'str',
labels: 'str | dict' = 'labels.json',
targets: 'TrainTask' = 'efsm',
exclude: 'str | list | None' = None,
energy_key: 'str' = 'energy_per_atom',
force_key: 'str' = 'force',
stress_key: 'str' = 'stress',
magmom_key: 'str' = 'magmom',
shuffle: 'bool' = True
) → None
Initialize the dataset from a directory containing saved crystal graphs.
Args:
graph_path
(str): path that contains all the graphs and labels.json
labels
(str, dict): the path or dictionary of labels. Default = “labels.json”
targets
(“ef” | “efs” | “efm” | “efsm”): The training targets. Default = “efsm”
exclude
(str, list | None): the path or list of excluded graphs. Default = None
energy_key
(str, optional): the key of energy in the labels. Default = “energy_per_atom”
force_key
(str, optional): the key of force in the labels. Default = “force”
stress_key
(str, optional): the key of stress in the labels. Default = “stress”
magmom_key
(str, optional): the key of magmom in the labels. Default = “magmom”
shuffle
(bool): whether to shuffle the sequence of dataset. Default = True
method get_train_val_test_loader
get_train_val_test_loader(
train_ratio: 'float' = 0.8,
val_ratio: 'float' = 0.1,
train_key: 'list[str] | None' = None,
val_key: 'list[str] | None' = None,
test_key: 'list[str] | None' = None,
batch_size=32,
num_workers=0,
pin_memory=True
) → tuple[DataLoader, DataLoader, DataLoader]
Partition the GraphData by materials id: either randomly select the train_keys, val_keys, and test_keys according to the train/val/test ratios, or use pre-defined train_keys, val_keys, and test_keys, to create train, val, and test loaders.
Args:
train_ratio
(float): The ratio of the dataset to use for training. Default = 0.8
val_ratio
(float): The ratio of the dataset to use for validation. Default = 0.1
train_key
(List(str), optional): a list of mp_ids for the train set
val_key
(List(str), optional): a list of mp_ids for the val set
test_key
(List(str), optional): a list of mp_ids for the test set
batch_size
(int): batch size. Default = 32
num_workers
(int): The number of worker processes for loading the data; see the torch DataLoader documentation for more info. Default = 0
pin_memory
(bool): Whether to pin the memory of the data loaders. Default = True
Returns: train_loader, val_loader, test_loader
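One reading of "partition using materials id" is that all graphs sharing an mp_id land in the same split, which keeps near-duplicate frames of one material out of both the training and test sets. A minimal sketch under that assumption; the mapping from mp_id to graph ids and both helper names are hypothetical.

```python
import random


def split_by_material(graph_ids_by_mp, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Hypothetical helper: split whole materials so that no mp_id spans two sets."""
    mp_ids = sorted(graph_ids_by_mp)
    random.Random(seed).shuffle(mp_ids)
    n_train = int(train_ratio * len(mp_ids))
    n_val = int(val_ratio * len(mp_ids))
    train_key = mp_ids[:n_train]
    val_key = mp_ids[n_train:n_train + n_val]
    test_key = mp_ids[n_train + n_val:]
    return train_key, val_key, test_key


def expand_keys(keys, graph_ids_by_mp):
    """List every (mp_id, graph_id) pair that belongs to the selected materials."""
    return [(mp_id, graph_id) for mp_id in keys for graph_id in graph_ids_by_mp[mp_id]]


# Ten hypothetical materials, each with three relaxation frames
graph_ids_by_mp = {f"mp-{i}": [f"mp-{i}-{step}" for step in range(3)] for i in range(10)}
train_key, val_key, test_key = split_by_material(graph_ids_by_mp, seed=0)
train_graphs = expand_keys(train_key, graph_ids_by_mp)
```

Passing pre-defined train_key, val_key and test_key lists corresponds to skipping the random step and going straight to the expansion.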
class StructureJsonData
Read structure and targets from a JSON file. This class is used to load the MPtrj dataset.
method __init__
__init__(
data: 'str | dict',
graph_converter: 'CrystalGraphConverter',
targets: 'TrainTask' = 'efsm',
energy_key: 'str' = 'energy_per_atom',
force_key: 'str' = 'force',
stress_key: 'str' = 'stress',
magmom_key: 'str' = 'magmom',
shuffle: 'bool' = True
) → None
Initialize the dataset by reading JSON files.
Args:
data
(str | dict): file path or dir name that contains all the JSONs
graph_converter
(CrystalGraphConverter): Converts pymatgen.core.Structure to CrystalGraph object.
targets
(“ef” | “efs” | “efm” | “efsm”): The training targets. Default = “efsm”
energy_key
(str, optional): the key of energy in the labels. Default = “energy_per_atom”
force_key
(str, optional): the key of force in the labels. Default = “force”
stress_key
(str, optional): the key of stress in the labels. Default = “stress”
magmom_key
(str, optional): the key of magmom in the labels. Default = “magmom”
shuffle
(bool): whether to shuffle the sequence of dataset. Default = True
method get_train_val_test_loader
get_train_val_test_loader(
train_ratio: 'float' = 0.8,
val_ratio: 'float' = 0.1,
train_key: 'list[str] | None' = None,
val_key: 'list[str] | None' = None,
test_key: 'list[str] | None' = None,
batch_size=32,
num_workers=0,
pin_memory=True
) → tuple[DataLoader, DataLoader, DataLoader]
Partition the Dataset by materials id: either randomly select the train_keys, val_keys, and test_keys according to the train/val/test ratios, or use pre-defined train_keys, val_keys, and test_keys, to create train, val, and test loaders.
Args:
train_ratio
(float): The ratio of the dataset to use for training. Default = 0.8
val_ratio
(float): The ratio of the dataset to use for validation. Default = 0.1
train_key
(List(str), optional): a list of mp_ids for the train set
val_key
(List(str), optional): a list of mp_ids for the val set
test_key
(List(str), optional): a list of mp_ids for the test set
batch_size
(int): batch size. Default = 32
num_workers
(int): The number of worker processes for loading the data; see the torch DataLoader documentation for more info. Default = 0
pin_memory
(bool): Whether to pin the memory of the data loaders. Default = True
Returns: train_loader, val_loader, test_loader
module graph.converter
class CrystalGraphConverter
Convert a pymatgen.core.Structure to a CrystalGraph. The CrystalGraph dataclass stores the essential fields to make sure that gradients like force and stress can be calculated through back-propagation later.
method __init__
__init__(
atom_graph_cutoff: 'float' = 6,
bond_graph_cutoff: 'float' = 3,
algorithm: "Literal['legacy', 'fast']" = 'fast',
on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error',
verbose: 'bool' = False
) → None
Initialize the Crystal Graph Converter.
Args:
atom_graph_cutoff
(float): cutoff radius to search for neighboring atoms in atom_graph. Default = 6
bond_graph_cutoff
(float): bond length threshold to include a bond in bond_graph. Default = 3
algorithm
(‘legacy’ | ‘fast’): algorithm to use for converting graphs. Default = ‘fast’
'legacy'
: Python implementation of graph creation
'fast'
: C implementation of graph creation; this is faster, but needs the cygraph.c file correctly compiled from pip install
on_isolated_atoms
(‘ignore’ | ‘warn’ | ‘error’): how to handle Structures with isolated atoms. Default = ‘error’
verbose
(bool): whether to print the CrystalGraphConverter initialization message. Default = False.
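To make the two cutoffs concrete: the atom graph keeps every periodic neighbor within atom_graph_cutoff, and the bond graph only considers the subset of those edges shorter than bond_graph_cutoff. A naive O(n² · 27) sketch of that idea, assuming fractional coordinates in [0, 1) and a cutoff smaller than every perpendicular cell width (so only images in {-1, 0, 1}³ matter); `find_edges` is hypothetical and is neither the 'legacy' nor the 'fast' algorithm.

```python
import itertools
import math


def find_edges(frac_coords, lattice, cutoff):
    """Hypothetical naive periodic neighbor search returning directed edges.

    Assumes the cutoff is smaller than every perpendicular width of the cell,
    so only periodic images in {-1, 0, 1}^3 need to be checked.
    Each edge is (center_index, neighbor_index, image, distance).
    """
    def to_cart(frac):
        # Cartesian position = sum over k of frac[k] * lattice vector k
        return [sum(frac[k] * lattice[k][d] for k in range(3)) for d in range(3)]

    edges = []
    for i, frac_i in enumerate(frac_coords):
        cart_i = to_cart(frac_i)
        for j, frac_j in enumerate(frac_coords):
            for image in itertools.product((-1, 0, 1), repeat=3):
                if i == j and image == (0, 0, 0):
                    continue  # an atom is not its own neighbor in the home cell
                cart_j = to_cart([frac_j[k] + image[k] for k in range(3)])
                distance = math.dist(cart_i, cart_j)
                if distance < cutoff:
                    edges.append((i, j, image, distance))
    return edges


# Hypothetical bcc cell with a = 4 A: nearest neighbors sit at sqrt(3) * 2 ~ 3.46 A
lattice = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
frac_coords = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]
atom_edges = find_edges(frac_coords, lattice, cutoff=3.9)
# The bond graph only keeps the shorter edges of the atom graph
bond_edges = [edge for edge in atom_edges if edge[3] < 3.0]
```

With these numbers each atom sees its 8 nearest neighbors (16 directed edges), but none is shorter than 3 A, so the bond subset is empty. Whether the library compares with `<` or `<=` is not stated here.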
method as_dict
as_dict() → dict[str, str | float]
Save the args of the graph converter.
method forward
forward(structure: 'Structure', graph_id=None, mp_id=None) → CrystalGraph
Convert a structure, return a CrystalGraph.
Args:
structure
(pymatgen.core.Structure): structure to convert
graph_id
(str): an id to keep track of this crystal graph. Default = None
mp_id
(str): Materials Project id of this structure. Default = None
Return: CrystalGraph that is ready to use by CHGNet
classmethod from_dict
from_dict(dct: 'dict') → Self
Create converter from dictionary.
method set_isolated_atom_response
set_isolated_atom_response(
on_isolated_atoms: "Literal['ignore', 'warn', 'error']"
) → None
Set the graph converter’s response to isolated atom graphs.
Args:
on_isolated_atoms
(‘ignore’ | ‘warn’ | ‘error’): how to handle Structures with isolated atoms. Default = ‘error’.
Returns: None
module graph.crystalgraph
class CrystalGraph
A data class for crystal graph.
method __init__
__init__(
atomic_number: 'Tensor',
atom_frac_coord: 'Tensor',
atom_graph: 'Tensor',
atom_graph_cutoff: 'float',
neighbor_image: 'Tensor',
directed2undirected: 'Tensor',
undirected2directed: 'Tensor',
bond_graph: 'Tensor',
bond_graph_cutoff: 'float',
lattice: 'Tensor',
graph_id: 'str | None' = None,
mp_id: 'str | None' = None,
composition: 'str | None' = None
) → None
Initialize the crystal graph.
Attention! This data class is not intended to be created manually. CrystalGraph should be returned by a CrystalGraphConverter.
Args:
atomic_number
(Tensor): the atomic numbers of atoms in the structure [n_atom]
atom_frac_coord
(Tensor): the fractional coordinates of the atoms [n_atom, 3]
atom_graph
(Tensor): a directed graph adjacency list, (center atom indices, neighbor atom indices, undirected bond index) for bonds in bond_fea [num_directed_bonds, 2]
atom_graph_cutoff
(float): the cutoff radius to draw edges in atom_graph
neighbor_image
(Tensor): the periodic image specifying the location of neighboring atoms, see https://github.com/materialsproject/pymatgen/blob/ca2175c762e37ea7c9f3950ef249bc540e683da1/pymatgen/core/structure.py#L1485-L1541 [num_directed_bonds, 3]
directed2undirected
(Tensor): the mapping from directed edge index to undirected edge index for the atom graph [num_directed_bonds]
undirected2directed
(Tensor): the mapping from undirected edge index to one of its directed edge indices. This is essentially the inverse mapping of directed2undirected; this tensor is needed for computational efficiency. Note that num_directed_bonds = 2 * num_undirected_bonds [num_undirected_bonds]
bond_graph
(Tensor): a directed graph adjacency list, (atom indices, 1st undirected bond idx, 1st directed bond idx, 2nd undirected bond idx, 2nd directed bond idx) for angles in angle_fea [n_angle, 5]
bond_graph_cutoff
(float): the cutoff bond length to include bonds as nodes in bond_graph
lattice
(Tensor): lattice of the input structure [3, 3]
graph_id
(str | None): an id to keep track of this crystal graph. Default = None
mp_id
(str | None): Materials Project id of this structure. Default = None
composition
: Chemical composition of the compound, used just for better tracking of the graph. Default = None.
Raises:
ValueError
: if len(directed2undirected) != 2 * len(undirected2directed)
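Since every undirected bond appears as exactly two directed bonds, undirected2directed can be derived from directed2undirected, and the ValueError above is a consistency check on the 2:1 ratio. A minimal sketch, assuming "one of its directed edge indices" means the first one encountered; both helper names are hypothetical.

```python
def undirected2directed_from(directed2undirected):
    """Hypothetical helper: map each undirected edge to the first directed edge that uses it."""
    n_undirected = max(directed2undirected, default=-1) + 1
    mapping = [-1] * n_undirected
    for directed_idx, undirected_idx in enumerate(directed2undirected):
        if mapping[undirected_idx] == -1:
            mapping[undirected_idx] = directed_idx
    return mapping


def check_edge_mappings(directed2undirected, undirected2directed):
    """Raise ValueError when the documented 2:1 directed/undirected ratio is violated."""
    if len(directed2undirected) != 2 * len(undirected2directed):
        raise ValueError(
            f"{len(directed2undirected)} directed edges vs "
            f"{len(undirected2directed)} undirected edges"
        )


directed2undirected = [0, 0, 1, 1, 2, 2]
undirected2directed = undirected2directed_from(directed2undirected)
check_edge_mappings(directed2undirected, undirected2directed)  # passes silently
```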
property num_isolated_atoms
Number of isolated atoms given the atom graph cutoff. Isolated atoms are disconnected nodes in the atom graph that will not get updated in CHGNet. These atoms will always have calculated force equal to zero.
With the default CHGNet atom graph cutoff radius, only ~0.1% of MPtrj dataset structures have isolated atoms.
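Because each bond is stored in both directions in the directed atom graph, an atom with no neighbor inside the cutoff never shows up as a center index. A minimal sketch of counting isolated atoms that way, assuming atom_graph rows of the form [center, neighbor]; `count_isolated_atoms` is a hypothetical helper, not the property's implementation.

```python
def count_isolated_atoms(n_atoms, atom_graph):
    """Hypothetical helper: atoms that never appear as a center have no neighbors."""
    centers = {center for center, _neighbor in atom_graph}
    return n_atoms - len(centers)


# Hypothetical 4-atom graph where atom 3 has no neighbor within the cutoff
atom_graph = [[0, 1], [1, 0], [1, 2], [2, 1]]
```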
classmethod from_dict
from_dict(dic: 'dict[str, Any]') → Self
Load a CrystalGraph from a dictionary.
classmethod from_file
from_file(file_name: 'str') → Self
Load a crystal graph from a file.
Args:
file_name
(str): The path to the file.
Returns:
CrystalGraph
: The loaded graph.
method save
save(fname: 'str | None' = None, save_dir: 'str' = '.') → str
Save the graph to a file.
Args:
fname
(str, optional): File name. Defaults to None.
save_dir
(str, optional): Directory to save the file. Defaults to ".".
Returns:
str
: The path to the saved file.
method to
to(device: 'str' = 'cpu') → CrystalGraph
Move the graph to a device. Default = ‘cpu’.
method to_dict
to_dict() → dict[str, Any]
Convert the graph to a dictionary.
module graph.graph
class Node
A node in a graph.
method __init__
__init__(index: 'int', info: 'dict | None' = None) → None
Initialize a Node.
Args:
index
(int): the index of this node
info
(dict, optional): any additional information about this node.
method add_neighbor
add_neighbor(index, edge) → None
Draw a directed edge between self and the node specified by index.
Args:
index
(int): the index of the neighboring node
edge
(DirectedEdge): a DirectedEdge object pointing from self to the node.
class Edge
Abstract base class for edges in a graph.
method __init__
__init__(
nodes: 'list',
index: 'int | None' = None,
info: 'dict | None' = None
) → None
Initialize an Edge.
class UndirectedEdge
An undirected/bi-directed edge in a graph.
method __init__
__init__(
nodes: 'list',
index: 'int | None' = None,
info: 'dict | None' = None
) → None
Initialize an Edge.
class DirectedEdge
A directed edge in a graph.
method __init__
__init__(
nodes: 'list',
index: 'int | None' = None,
info: 'dict | None' = None
) → None
Initialize an Edge.
method make_undirected
make_undirected(index: 'int', info: 'dict | None' = None) → UndirectedEdge
Make a directed edge undirected.
class Graph
A graph for storing the neighbor information of atoms.
method __init__
__init__(nodes: 'list[Node]') → None
Initialize a Graph from a list of nodes.
method add_edge
add_edge(
center_index,
neighbor_index,
image,
distance,
dist_tol: 'float' = 1e-06
) → None
Add a directed edge to the graph.
Args:
center_index
(int): center node index
neighbor_index
(int): neighbor node index
image
(np.array): the periodic cell image the neighbor is from
distance
(float): distance between center and neighbor.
dist_tol
(float): tolerance for distance comparison between edges. Default = 1e-6
method adjacency_list
adjacency_list() → tuple[list[list[int]], list[int]]
Get the adjacency list.
Return:
graph: the adjacency list [[0, 1], [0, 2], … [5, 2] … ]; the first column specifies the center/source node, the second column specifies the neighbor/destination node.
directed2undirected: [0, 1, …] a list of length = num_directed_edge that specifies the undirected edge index corresponding to the directed edges represented in each row in the graph adjacency list.
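The add_edge and adjacency_list descriptions suggest that a directed edge i→j is paired with its reverse j→i, which comes from the opposite periodic image with the same length, and dist_tol plausibly guards that length comparison. The sketch below assumes exactly that pairing rule; `build_adjacency` is a hypothetical linear-scan illustration, not the Graph class's implementation.

```python
def build_adjacency(edges, dist_tol=1e-6):
    """Hypothetical sketch: group directed edges into undirected bonds.

    edges: (center, neighbor, image, distance) tuples, image as a 3-tuple of ints.
    Returns (graph, directed2undirected) shaped like the description above.
    """
    graph = []
    directed2undirected = []
    undirected = []  # (center, neighbor, image, distance) of the first direction seen
    for center, neighbor, image, distance in edges:
        reverse_image = tuple(-x for x in image)
        match = None
        for u_idx, (i, j, img, dist) in enumerate(undirected):
            # The reverse of a stored edge swaps endpoints, negates the image,
            # and has the same length up to dist_tol
            if (i, j, img) == (neighbor, center, reverse_image) and abs(dist - distance) < dist_tol:
                match = u_idx
                break
        if match is None:
            match = len(undirected)
            undirected.append((center, neighbor, tuple(image), distance))
        graph.append([center, neighbor])
        directed2undirected.append(match)
    return graph, directed2undirected


edges = [
    (0, 1, (0, 0, 0), 1.5), (1, 0, (0, 0, 0), 1.5),
    (0, 1, (1, 0, 0), 2.5), (1, 0, (-1, 0, 0), 2.5),
]
graph, directed2undirected = build_adjacency(edges)
```

The two 0–1 bonds (home cell and through the periodic boundary) stay distinct undirected edges because their images differ, even though they join the same pair of atoms.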
method as_dict
as_dict() → dict
Return dictionary serialization of a Graph.
method line_graph_adjacency_list
line_graph_adjacency_list(cutoff) → tuple[list[list[int]], list[int]]
Get the line graph adjacency list.
Args:
cutoff
(float): the maximum edge length to be included in constructing the line graph; this is used to decrease computational complexity
Return:
line_graph: [[0, 1, 1, 2, 2], [0, 1, 1, 4, 23], [1, 4, 23, 5, 66], … … ] where the first column specifies the node (atom) index at this angle, the second column specifies the 1st undirected edge (left bond) index, the third column specifies the 1st directed edge (left bond) index, the fourth column specifies the 2nd undirected edge (right bond) index, and the fifth column specifies the 2nd directed edge (right bond) index.
undirected2directed: [32, 45, …] a list of length = num_undirected_edge that maps the undirected edge index to one of its directed edge indices
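Each line-graph row describes an angle: two bonds shorter than the cutoff that share a center atom. A minimal sketch producing rows in the five-column layout above, assuming atom_graph rows [center, neighbor] and ordered pairs of distinct bonds; whether the library emits ordered pairs, unordered pairs, or a different ordering is not stated here, and `line_graph_rows` is hypothetical.

```python
from itertools import permutations


def line_graph_rows(atom_graph, directed2undirected, distances, cutoff):
    """Hypothetical sketch: one row per pair of bonds that share a center atom.

    atom_graph rows are [center, neighbor]; distances[k] is the length of directed edge k.
    Each row is [atom, 1st undirected, 1st directed, 2nd undirected, 2nd directed].
    """
    outgoing = {}
    for d_idx, (center, _neighbor) in enumerate(atom_graph):
        # Only edges shorter than the cutoff become line-graph nodes
        if distances[d_idx] < cutoff:
            outgoing.setdefault(center, []).append(d_idx)
    rows = []
    for center, d_edges in outgoing.items():
        # Ordered pairs of distinct bonds around the same atom (ordering is an assumption)
        for d1, d2 in permutations(d_edges, 2):
            rows.append([center, directed2undirected[d1], d1, directed2undirected[d2], d2])
    return rows


# Atom 0 bonded to atoms 1 and 2; each bond listed in both directions
atom_graph = [[0, 1], [1, 0], [0, 2], [2, 0]]
directed2undirected = [0, 0, 1, 1]
distances = [1.0, 1.0, 2.0, 2.0]
rows = line_graph_rows(atom_graph, directed2undirected, distances, cutoff=3.0)
```

Lowering the cutoff below 2.0 drops the 0–2 bond, leaving atom 0 with a single bond and therefore no angle.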
method to
to(filename='graph.json') → None
Save graph dictionary to file.
method undirected2directed
undirected2directed() → list[int]
The index map from undirected_edge index to one of its directed_edge index.
module model.basis
class Fourier
Fourier Expansion for angle features.
method __init__
__init__(order: 'int' = 5, learnable: 'bool' = False) → None
Initialize the Fourier expansion.
Args:
order
(int): the maximum order, refer to the N in eq 1 in the CHGNet paper. Default = 5
learnable
(bool): whether to set the frequencies as learnable parameters. Default = False
method forward
forward(x: 'Tensor') → Tensor
Apply Fourier expansion to a feature Tensor.
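A Fourier expansion of maximum order N turns an angle θ into a fixed-length feature vector built from cos(kθ) and sin(kθ). A minimal sketch, assuming the layout [constant, cos θ, sin θ, …, cos Nθ, sin Nθ] with 2N + 1 entries; normalization constants and learnable frequencies are omitted, so consult eq 1 of the CHGNet paper for the exact form.

```python
import math


def fourier_features(theta, order=5):
    """Hypothetical Fourier expansion of an angle into 2 * order + 1 features.

    Layout assumed: [constant, cos(theta), sin(theta), ..., cos(N theta), sin(N theta)];
    normalization constants and learnable frequencies are omitted.
    """
    features = [1.0]
    for k in range(1, order + 1):
        features.append(math.cos(k * theta))
        features.append(math.sin(k * theta))
    return features


features = fourier_features(math.pi / 3)
```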
class RadialBessel
1D Bessel Basis from: https://github.com/TUM-DAML/gemnet_pytorch/.
method __init__
__init__(
num_radial: 'int' = 9,
cutoff: 'float' = 5,
learnable: 'bool' = False,
smooth_cutoff: 'int' = 5
) → None
Initialize the SmoothRBF function.
Args:
num_radial
(int): Controls maximum frequency. Default = 9
cutoff
(float): Cutoff distance in Angstrom. Default = 5
learnable
(bool): whether to set the frequencies learnable. Default = False
smooth_cutoff
(int): smooth cutoff strength. Default = 5
method forward
forward(
dist: 'Tensor',
return_smooth_factor: 'bool' = False
) → Tensor | tuple[Tensor, Tensor]
Apply Bessel expansion to a feature Tensor.
Args:
dist
(Tensor): tensor of distances [n, 1]
return_smooth_factor
(bool): whether to return the smooth factor. Default = False
Returns:
out
(Tensor): tensor of Bessel distances [n, dim] where the expanded dimension will be num_radial
smooth_factor
(Tensor): tensor of smooth factors [n, 1]
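The 1D Bessel basis used in the linked GemNet/DimeNet code has the form sqrt(2/c) · sin(nπr/c) / r for n = 1 … num_radial, which vanishes at r = c. A minimal sketch of that basis for a single distance, assuming the standard form without the smooth cutoff factor or learnable frequencies.

```python
import math


def bessel_basis(r, num_radial=9, cutoff=5.0):
    """1D Bessel basis sqrt(2 / c) * sin(n * pi * r / c) / r for n = 1..num_radial.

    r must be positive; the smooth cutoff factor and learnable frequencies are omitted.
    """
    prefactor = math.sqrt(2.0 / cutoff)
    return [prefactor * math.sin(n * math.pi * r / cutoff) / r for n in range(1, num_radial + 1)]


values = bessel_basis(2.0)
```

Each distance maps to num_radial values, matching the [n, num_radial] output shape listed above.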
class GaussianExpansion
Expands the distance by Gaussian basis. Unit: angstrom.
method __init__
__init__(
min: 'float' = 0,
max: 'float' = 5,
step: 'float' = 0.5,
var: 'float | None' = None
) → None
Gaussian Expansion expands a scalar feature to a soft-one-hot feature vector.
Args:
min
(float): minimum Gaussian center value
max
(float): maximum Gaussian center value
step
(float): step size between the Gaussian centers
var
(float): variance in Gaussian filter. Defaults to step
method expand
expand(features: 'Tensor') → Tensor
Apply Gaussian filter to a feature Tensor.
Args:
features
(Tensor): tensor of features [n]
Returns:
expanded features (Tensor)
: tensor of Gaussian distances [n, dim] where the expanded dimension will be (max - min) / step + 1
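With the defaults (min = 0, max = 5, step = 0.5) there are (5 - 0) / 0.5 + 1 = 11 Gaussian centers, and a value lying exactly on a center gives a peak of 1 there. A minimal sketch for a single scalar, assuming the exponent form exp(-(x - μ)² / var²); the parameters are renamed min_ and max_ to avoid shadowing Python builtins.

```python
import math


def gaussian_expand(x, min_=0.0, max_=5.0, step=0.5, var=None):
    """Soft one-hot encoding of a scalar with Gaussians centered min_, min_ + step, ..., max_.

    The exp(-(x - mu)^2 / var^2) form is an assumption; var defaults to step as documented.
    """
    var = step if var is None else var
    n_centers = round((max_ - min_) / step) + 1  # (max - min) / step + 1 centers
    centers = [min_ + i * step for i in range(n_centers)]
    return [math.exp(-((x - mu) ** 2) / var**2) for mu in centers]


features = gaussian_expand(2.5)
```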
class CutoffPolynomial
Polynomial soft-cutoff function for atom graph. Ref: https://github.com/TUM-DAML/gemnet_pytorch/blob/-/gemnet/model/layers/envelope.py.
method __init__
__init__(cutoff: 'float' = 5, cutoff_coeff: 'float' = 5) → None
Initialize the polynomial cutoff function.
Args:
cutoff
(float): cutoff radius (A) in atom graph construction. Default = 5
cutoff_coeff
(float): the strength of the soft cutoff. 0 disables the cutoff, returning 1 at every r. For positive values, the smaller cutoff_coeff is, the faster this function decays. Default = 5.
method forward
forward(r: 'Tensor') → Tensor
Polynomial cutoff function.
Args:
r
(Tensor): radius distance tensor
Returns:
polynomial cutoff functions
: decaying from 1 at r=0 to 0 at r=cutoff
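A sketch of the envelope polynomial, assuming it is the DimeNet/GemNet-style form with exponent p = cutoff_coeff, f(x) = 1 - (p+1)(p+2)/2 · x^p + p(p+2) · x^(p+1) - p(p+1)/2 · x^(p+2) with x = r / cutoff, plus the documented special case that 0 disables the cutoff. The coefficients make f(0) = 1 and f(1) = 0.

```python
def cutoff_polynomial(r, cutoff=5.0, cutoff_coeff=5.0):
    """Envelope polynomial that decays from 1 at r = 0 to 0 at r = cutoff."""
    p = cutoff_coeff
    if p == 0:
        return 1.0  # documented behavior: cutoff_coeff = 0 disables the cutoff
    x = r / cutoff
    return (
        1
        - 0.5 * (p + 1) * (p + 2) * x**p
        + p * (p + 2) * x ** (p + 1)
        - 0.5 * p * (p + 1) * x ** (p + 2)
    )
```

For p = 5 at half the cutoff, the terms are 1 - 21/32 + 35/64 - 15/128 = 0.7734375.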
module model.composition_model
class CompositionModel
A simple FC model that takes in a chemical composition (no structure info) and outputs energy.
method __init__
__init__(
atom_fea_dim: 'int' = 64,
activation: 'str' = 'silu',
is_intensive: 'bool' = True,
max_num_elements: 'int' = 94
) → None
Initialize a CompositionModel.
method forward
forward(graphs: 'list[CrystalGraph]') → Tensor
Get the energy of a list of CrystalGraphs as Tensor.
class AtomRef
A linear regression for elemental energy. From: https://github.com/materialsvirtuallab/m3gnet/.
method __init__
__init__(is_intensive: 'bool' = True, max_num_elements: 'int' = 94) → None
Initialize an AtomRef model.
method fit
fit(
structures_or_graphs: 'Sequence[Structure | CrystalGraph]',
energies: 'Sequence[float]'
) → None
Fit the model to a list of crystals and energies.
Args:
structures_or_graphs
(list[Structure | CrystalGraph]): Any iterable of pymatgen structures and/or graphs.
energies
(list[float]): Target energies.
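AtomRef is documented as a linear regression for elemental energy: each structure's energy is modeled as a weighted sum over its elements. A minimal pure-Python least-squares sketch, assuming element fractions as features when is_intensive=True (per-atom energies) and element counts otherwise; the library's actual feature layout and solver may differ, and `fit_atom_ref` is hypothetical.

```python
def fit_atom_ref(compositions, energies, elements, is_intensive=True):
    """Hypothetical least-squares fit of one reference energy per element.

    compositions: list of {element: count}; energies: per-atom if is_intensive, else total.
    """
    # Features: element fractions (intensive) or raw element counts (extensive)
    rows = []
    for comp in compositions:
        scale = 1.0 / sum(comp.values()) if is_intensive else 1.0
        rows.append([comp.get(el, 0) * scale for el in elements])
    k = len(elements)
    # Normal equations (A^T A) w = A^T y
    ata = [[sum(row[i] * row[j] for row in rows) for j in range(k)] for i in range(k)]
    aty = [sum(row[i] * y for row, y in zip(rows, energies)) for i in range(k)]
    # Gaussian elimination with partial pivoting
    for col in range(k):
        pivot = max(range(col, k), key=lambda idx: abs(ata[idx][col]))
        ata[col], ata[pivot] = ata[pivot], ata[col]
        aty[col], aty[pivot] = aty[pivot], aty[col]
        for r in range(col + 1, k):
            factor = ata[r][col] / ata[col][col]
            for c in range(col, k):
                ata[r][c] -= factor * ata[col][c]
            aty[r] -= factor * aty[col]
    # Back substitution
    weights = [0.0] * k
    for i in reversed(range(k)):
        residual = aty[i] - sum(ata[i][j] * weights[j] for j in range(i + 1, k))
        weights[i] = residual / ata[i][i]
    return dict(zip(elements, weights))


# Hypothetical per-atom energies generated from E_A = -1 eV and E_B = -2 eV
compositions = [{"A": 2, "B": 1}, {"A": 1, "B": 1}, {"A": 1, "B": 3}]
energies = [-4 / 3, -1.5, -1.75]
ref = fit_atom_ref(compositions, energies, elements=["A", "B"])
```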
method forward
forward(graphs: 'list[CrystalGraph]') → Tensor
Get the energy of a list of CrystalGraphs.
Args:
graphs
(List(CrystalGraph)): a list of CrystalGraphs to compute
Returns: energy (tensor)
method get_site_energies
get_site_energies(graphs: 'list[CrystalGraph]') → list[Tensor]
Predict the site energies given a list of CrystalGraphs.
Args:
graphs
(List(CrystalGraph)): a list of CrystalGraphs to compute
Returns: a list of tensors corresponding to site energies of each graph [batchsize].
method initialize_from
initialize_from(dataset: 'str') → None
Initialize pre-fitted weights from a dataset.
method initialize_from_MPF
initialize_from_MPF() → None
Initialize pre-fitted weights from MPF dataset.
method initialize_from_MPtrj
initialize_from_MPtrj() → None
Initialize pre-fitted weights from MPtrj dataset.
method initialize_from_numpy
initialize_from_numpy(file_name: 'str | Path') → None
Initialize pre-fitted weights from numpy file.
module model.dynamics
Global Variables
- TYPE_CHECKING
- all_changes
- all_properties
- OPTIMIZERS
class CHGNetCalculator
CHGNet Calculator for ASE applications.
method __init__
__init__(
model: 'CHGNet | None' = None,
use_device: 'str | None' = None,
check_cuda_mem: 'bool' = False,
stress_weight: 'float | None' = 0.006241509125883258,
on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn',
**kwargs
) → None
Provide a CHGNet instance to calculate various atomic properties using ASE.
Args:
model
(CHGNet): instance of a CHGNet model. If set to None, the pretrained CHGNet is loaded. Default = None
use_device
(str, optional): The device to be used for predictions, either “cpu”, “cuda”, or “mps”. If not specified, the default device is automatically selected based on the available options. Default = None
check_cuda_mem
(bool): Whether to use the CUDA device with the most available memory. Default = False
stress_weight
(float): the conversion factor to convert GPa to eV/A^3. Default = 1/160.21
on_isolated_atoms
(‘ignore’ | ‘warn’ | ‘error’): how to handle Structures with isolated atoms. Default = ‘warn’
**kwargs
: Passed to the Calculator parent class.
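The default stress_weight is the reciprocal of the eV/A^3 to GPa factor: 1 eV/A^3 = 1.602e-19 J / 1e-30 m^3 ≈ 160.2 GPa. A small sketch of the conversion, assuming the constant 160.21766208 (it reproduces the documented default to well within 1e-9); the same factor converts the MolecularDynamics default pressure of 1 atm.

```python
EV_PER_A3_IN_GPA = 160.21766208  # 1 eV/A^3 expressed in GPa
stress_weight = 1 / EV_PER_A3_IN_GPA  # converts GPa to eV/A^3


def gpa_to_ev_per_a3(value_gpa):
    """Convert a stress or pressure from GPa to eV/A^3."""
    return value_gpa * stress_weight


one_atm = gpa_to_ev_per_a3(1.01325e-4)  # the MolecularDynamics default pressure
```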
property directory
property label
property n_params
The number of parameters in CHGNet.
property name
property version
The version of CHGNet.
method calculate
calculate(
atoms: 'Atoms | None' = None,
properties: 'list | None' = None,
system_changes: 'list | None' = None
) → None
Calculate various properties of the atoms using CHGNet.
Args:
atoms
(Atoms | None): The atoms object to calculate properties for.
properties
(list | None): The properties to calculate. Default is all properties.
system_changes
(list | None): The changes made to the system. Default is all changes.
classmethod from_file
from_file(path: 'str', use_device: 'str | None' = None, **kwargs) → Self
Load a user’s CHGNet model and initialize the Calculator.
class StructOptimizer
Wrapper class for structural relaxation.
method __init__
__init__(
model: 'CHGNet | CHGNetCalculator | None' = None,
optimizer_class: 'Optimizer | str | None' = 'FIRE',
use_device: 'str | None' = None,
stress_weight: 'float' = 0.006241509125883258,
on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None
Provide a trained CHGNet model and an optimizer to relax crystal structures.
Args:
model
(CHGNet): instance of a CHGNet model or CHGNetCalculator. If set to None, the pretrained CHGNet is loaded. Default = None
optimizer_class
(Optimizer, str): choose optimizer from ASE. Default = “FIRE”
use_device
(str, optional): The device to be used for predictions, either “cpu”, “cuda”, or “mps”. If not specified, the default device is automatically selected based on the available options. Default = None
stress_weight
(float): the conversion factor to convert GPa to eV/A^3. Default = 1/160.21
on_isolated_atoms
(‘ignore’ | ‘warn’ | ‘error’): how to handle Structures with isolated atoms. Default = ‘warn’
property n_params
The number of parameters in CHGNet.
property version
The version of CHGNet.
method relax
relax(
atoms: 'Structure | Atoms',
fmax: 'float | None' = 0.1,
steps: 'int | None' = 500,
relax_cell: 'bool | None' = True,
ase_filter: 'str | None' = 'FrechetCellFilter',
save_path: 'str | None' = None,
loginterval: 'int | None' = 1,
crystal_feas_save_path: 'str | None' = None,
verbose: 'bool' = True,
assign_magmoms: 'bool' = True,
**kwargs
) → dict[str, Structure | TrajectoryObserver]
Relax the Structure/Atoms until maximum force is smaller than fmax.
Args:
atoms
(Structure | Atoms): A Structure or Atoms object to relax.
fmax
(float | None): The maximum force tolerance for relaxation. Default = 0.1
steps
(int | None): The maximum number of steps for relaxation. Default = 500
relax_cell
(bool | None): Whether to relax the cell as well. Default = True
ase_filter
(str | ase.filters.Filter): The filter to apply to the atoms object for relaxation. Default = FrechetCellFilter. The default used to be ExpCellFilter, which was removed due to a bug reported in https://gitlab.com/ase/ase/-/issues/1321 and fixed in https://gitlab.com/ase/ase/-/merge_requests/3024.
save_path
(str | None): The path to save the trajectory. Default = None
loginterval
(int | None): Interval for logging trajectory and crystal features. Default = 1
crystal_feas_save_path
(str | None): Path to save crystal feature vectors, which are logged at a loginterval rate. Default = None
verbose
(bool): Whether to print the output of the ASE optimizer. Default = True
assign_magmoms
(bool): Whether to assign magnetic moments to the final structure. Default = True
**kwargs
: Additional parameters for the optimizer.
Returns: dict[str, Structure | TrajectoryObserver]: A dictionary with ‘final_structure’ and ‘trajectory’.
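"Until maximum force is smaller than fmax" refers to the usual ASE criterion: the largest per-atom force norm must drop below fmax. With relax_cell=True, the cell filter also exposes stress-derived generalized forces for the cell degrees of freedom, which enter the same comparison; this sketch covers only the atomic part, and `forces_converged` is a hypothetical helper.

```python
import math


def forces_converged(forces, fmax=0.1):
    """True when the largest per-atom force norm (eV/A) is below fmax."""
    max_force = max(math.sqrt(fx * fx + fy * fy + fz * fz) for fx, fy, fz in forces)
    return max_force < fmax
```

A single atom with force (0.3, 0.4, 0) eV/A has norm 0.5 and keeps the relaxation going, regardless of how small the other forces are.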
class TrajectoryObserver
Trajectory observer is a hook in the relaxation process that saves the intermediate structures.
method __init__
__init__(atoms: 'Atoms') → None
Create a TrajectoryObserver from an Atoms object.
Args:
atoms
(Atoms): the structure to observe.
method compute_energy
compute_energy() → float
Calculate the potential energy.
Returns:
energy
(float): the potential energy.
method save
save(filename: 'str') → None
Save the trajectory to file.
Args:
filename
(str): filename to save the trajectory
class CrystalFeasObserver
CrystalFeasObserver is a hook in the relaxation and MD process that saves the intermediate crystal feature structures.
method __init__
__init__(atoms: 'Atoms') → None
Create a CrystalFeasObserver from an Atoms object.
method save
save(filename: 'str') → None
Save the crystal feature vectors to filename in pickle format.
class MolecularDynamics
Molecular dynamics class.
method __init__
__init__(
atoms: 'Atoms | Structure',
model: 'CHGNet | CHGNetCalculator | None' = None,
ensemble: 'str' = 'nvt',
thermostat: 'str' = 'Berendsen_inhomogeneous',
temperature: 'int' = 300,
starting_temperature: 'int | None' = None,
timestep: 'float' = 2.0,
pressure: 'float' = 0.000101325,
taut: 'float | None' = None,
taup: 'float | None' = None,
bulk_modulus: 'float | None' = None,
trajectory: 'str | Trajectory | None' = None,
logfile: 'str | None' = None,
loginterval: 'int' = 1,
crystal_feas_logfile: 'str | None' = None,
append_trajectory: 'bool' = False,
on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn',
use_device: 'str | None' = None
) → None
Initialize the MD class.
Args:
atoms
(Atoms): atoms to run the MDmodel
(CHGNet): instance of a CHGNet model or CHGNetCalculator. If set to None, the pretrained CHGNet is loaded. Default = Noneensemble
(str): choose from ‘nve’, ‘nvt’, ‘npt’ Default = “nvt”thermostat
(str): Thermostat to use choose from “Nose-Hoover”, “Berendsen”, “Berendsen_inhomogeneous” Default = “Berendsen_inhomogeneous”temperature
(float): temperature for MD simulation, in K Default = 300starting_temperature
(float): starting temperature of MD simulation, in K if set as None, the MD starts with the momentum carried by ase.Atoms if input is a pymatgen.core.Structure, the MD starts at 0K Default = Nonetimestep
(float): time step in fs Default = 2pressure
(float): pressure in GPa Can be 3x3 or 6 np.array if thermostat is “Nose-Hoover” Default = 1.01325e-4 GPa = 1 atmtaut
(float): time constant for temperature coupling in fs. The temperature will be raised to target temperature in approximate 10 taut time. Default = 100 timesteptaup
(float): time constant for pressure coupling in fs Default = 1000 * timestepbulk_modulus
(float): bulk modulus of the material in GPa. Used in NPT ensemble for the barostat pressure coupling. The DFT bulk modulus can be found for most materials athttps
: //next-gen.materialsproject.org/
In NPT ensemble, the effective damping time for pressure is multiplied by compressibility. In LAMMPS, Bulk modulus is defaulted to 10
see
: https://docs.lammps.org/fix_press_berendsen.htmland
: https://gitlab.com/ase/ase/-/blob/master/ase/md/nptberendsen.py
If bulk modulus is not provided here, it will be calculated by CHGNet through Birch Murnaghan equation of state (EOS). Note the EOS fitting can fail because of non-parabolic potential energy surface, which is common with soft system like liquid and gas. In such case, user should provide an input bulk modulus for better barostat coupling, otherwise a guessed bulk modulus = 2 GPa will be used (water’s bulk modulus)
Default = None
trajectory
(str or Trajectory): attach a trajectory object. Default = None
logfile
(str): open this file for recording MD outputs. Default = None
loginterval
(int): write to the log file every loginterval steps. Default = 1
crystal_feas_logfile
(str): open this file for recording crystal features during MD. Default = None
append_trajectory
(bool): whether to append to the previous trajectory. If False, the previous trajectory is overwritten. Default = False
on_isolated_atoms
(‘ignore’ | ‘warn’ | ‘error’): how to handle Structures with isolated atoms. Default = ‘warn’
use_device
(str): the device for the MD run. Default = None
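The taut description above (temperature reaching its target in roughly 10 × taut) follows from Berendsen weak coupling. The stdlib sketch below is an idealized illustration, assuming velocities are rescaled every step by lambda = sqrt(1 + (dt / taut) * (T0 / T - 1)) and ignoring how forces change the kinetic energy between steps; it is not the ASE or CHGNet integrator, and berendsen_temperature_trace is a hypothetical helper.

```python
import math


def berendsen_temperature_trace(t_start, t_target, timestep, taut, n_steps):
    """Idealized Berendsen weak coupling (hypothetical helper).

    Kinetic temperature scales with the square of the velocity scaling
    factor; no force-driven heating or cooling is modeled between steps.
    """
    temps = [t_start]
    t = t_start
    for _ in range(n_steps):
        # Berendsen velocity scaling factor
        lam = math.sqrt(1.0 + (timestep / taut) * (t_target / t - 1.0))
        t = t * lam**2  # temperature is proportional to velocity squared
        temps.append(t)
    return temps


timestep = 2.0  # fs, the default time step listed above
taut = 100 * timestep  # fs, the default coupling time listed above
# 1000 steps of 2 fs = 2000 fs = 10 * taut
trace = berendsen_temperature_trace(50.0, 300.0, timestep, taut, n_steps=1000)
# the 250 K gap shrinks by (1 - dt/taut) per step, i.e. by about e**-10 overall
print(trace[-1])
```

In this idealization the gap to the target shrinks by a factor (1 - timestep / taut) per step, which is why about 10 × taut is enough to reach the target.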
method run
run(steps: 'int') → None
Thin wrapper around the ASE MD run.
Args:
steps
(int): number of MD steps
method set_atoms
set_atoms(atoms: 'Atoms') → None
Set new atoms to run MD.
Args:
atoms
(Atoms): new atoms for running MD
method upper_triangular_cell
upper_triangular_cell(verbose: 'bool | None' = False) → None
Transform to an upper-triangular cell. The ASE Nose-Hoover implementation only supports upper-triangular cells, while ASE’s canonical description is a lower-triangular cell.
Args:
verbose
(bool): Whether to notify user about upper-triangular cell transformation. Default = False
class EquationOfState
Class to calculate equation of state.
method __init__
__init__(
model: 'CHGNet | CHGNetCalculator | None' = None,
optimizer_class: 'Optimizer | str | None' = 'FIRE',
use_device: 'str | None' = None,
stress_weight: 'float' = 0.006241509125883258,
on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error'
) → None
Initialize a structure optimizer object for calculation of bulk modulus.
Args:
model
(CHGNet): instance of a CHGNet model or CHGNetCalculator. If set to None, the pretrained CHGNet is loaded. Default = None
optimizer_class
(Optimizer, str): choose an optimizer from ASE. Default = “FIRE”
use_device
(str, optional): The device to be used for predictions, either “cpu”, “cuda”, or “mps”. If not specified, the default device is automatically selected based on the available options. Default = None
stress_weight
(float): the conversion factor from GPa to eV/A^3. Default = 1/160.21766208 ≈ 0.0062415
on_isolated_atoms
(‘ignore’ | ‘warn’ | ‘error’): how to handle Structures with isolated atoms. Default = ‘error’
method fit
fit(
atoms: 'Structure | Atoms',
n_points: 'int' = 11,
fmax: 'float | None' = 0.1,
steps: 'int | None' = 500,
verbose: 'bool | None' = False,
**kwargs
) → None
Relax the Structure/Atoms and fit the Birch-Murnaghan equation of state.
Args:
atoms
(Structure | Atoms): A Structure or Atoms object to relax.
n_points
(int): Number of structures used in fitting the equation of state. Default = 11
fmax
(float | None): The maximum force tolerance for relaxation. Default = 0.1
steps
(int | None): The maximum number of steps for relaxation. Default = 500
verbose
(bool): Whether to print the output of the ASE optimizer. Default = False
**kwargs
: Additional parameters for the optimizer.
Returns: None. The fitted bulk modulus is retrieved with get_bulk_modulus().
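Where the bulk modulus comes from: the third-order Birch-Murnaghan energy is a cubic polynomial in x = V^(-2/3), so a least-squares cubic fit in x gives the equilibrium volume V0 (where dE/dx = 0) and B0 = V0 * d2E/dV2 at V0. The NumPy sketch below applies that linearization to synthetic data; it is an illustration under these assumptions, not necessarily how EquationOfState.fit fits internally, and both helper names are hypothetical.

```python
import numpy as np


def birch_murnaghan_energy(volume, e0, v0, b0, b0_prime):
    """Third-order Birch-Murnaghan E(V); volume in A^3, b0 in eV/A^3."""
    eta = (v0 / np.asarray(volume, dtype=float)) ** (2.0 / 3.0)
    return e0 + 9.0 * v0 * b0 / 16.0 * (
        (eta - 1.0) ** 3 * b0_prime + (eta - 1.0) ** 2 * (6.0 - 4.0 * eta)
    )


def fit_bulk_modulus(volumes, energies):
    """Return (V0, B0) from a cubic fit of E against x = V**(-2/3)."""
    x = np.asarray(volumes, dtype=float) ** (-2.0 / 3.0)
    # Polynomial.fit rescales x internally, keeping the fit well conditioned
    poly = np.polynomial.Polynomial.fit(x, energies, deg=3)
    d1, d2 = poly.deriv(1), poly.deriv(2)
    # stationary points dE/dx = 0; the energy minimum has positive curvature
    roots = d1.roots()
    roots = roots[np.isreal(roots)].real
    x0 = next(r for r in roots if r > 0 and d2(r) > 0)
    v0 = x0**-1.5
    # at x0, d2E/dV2 = d2E/dx2 * (dx/dV)**2 = d2E/dx2 * 4/9 * V0**(-10/3)
    b0 = 4.0 / 9.0 * d2(x0) * v0 ** (-7.0 / 3.0)
    return v0, b0


volumes = np.linspace(0.94, 1.06, 11) * 40.0  # 11 points, like n_points=11
energies = birch_murnaghan_energy(volumes, e0=-10.0, v0=40.0, b0=0.5, b0_prime=4.5)
v0, b0 = fit_bulk_modulus(volumes, energies)
print(v0, b0)  # recovers V0 of about 40 A^3 and B0 of about 0.5 eV/A^3
```

Because synthetic Birch-Murnaghan data is exactly cubic in x, the fit recovers the input parameters; real relaxed energies would leave a residual.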
method get_bulk_modulus
get_bulk_modulus(unit: 'str' = 'eV/A^3') → float
Get the bulk modulus from the fitted Birch-Murnaghan equation of state.
Args:
unit
(str): The unit of bulk modulus. Can be “eV/A^3” or “GPa”. Default = “eV/A^3”
Returns: Bulk Modulus (float)
method get_compressibility
get_compressibility(unit: 'str' = 'A^3/eV') → float
Get the compressibility from the fitted Birch-Murnaghan equation of state.
Args:
unit
(str): The unit of compressibility. Can be “A^3/eV”, “GPa^-1”, “Pa^-1” or “m^2/N”. Default = “A^3/eV”
Returns: Compressibility (float)
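stress_weight, get_bulk_modulus(unit=...) and get_compressibility(unit=...) rest on two facts: 1 eV/A^3 ≈ 160.2177 GPa, and compressibility is the reciprocal of the bulk modulus. A stdlib sketch of those conversions, with hypothetical helper names (not the library's code):

```python
EV_PER_A3_IN_GPA = 160.21766208  # 1 eV/A^3 in GPa, i.e. 1 / stress_weight default


def bulk_modulus_in(b_ev_per_a3, unit="eV/A^3"):
    """Express a bulk modulus given in eV/A^3 (hypothetical helper)."""
    if unit == "eV/A^3":
        return b_ev_per_a3
    if unit == "GPa":
        return b_ev_per_a3 * EV_PER_A3_IN_GPA
    raise ValueError(f"unsupported unit {unit!r}")


def compressibility_in(b_ev_per_a3, unit="A^3/eV"):
    """Compressibility = 1 / bulk modulus (hypothetical helper)."""
    kappa = 1.0 / b_ev_per_a3  # A^3/eV
    if unit == "A^3/eV":
        return kappa
    if unit == "GPa^-1":
        return kappa / EV_PER_A3_IN_GPA
    if unit in ("Pa^-1", "m^2/N"):  # 1 Pa = 1 N/m^2, so these coincide
        return kappa / EV_PER_A3_IN_GPA * 1e-9
    raise ValueError(f"unsupported unit {unit!r}")


b = 2.0 / EV_PER_A3_IN_GPA  # the 2 GPa fallback for water, in eV/A^3
print(bulk_modulus_in(b, "GPa"))  # about 2.0
print(compressibility_in(b, "GPa^-1"))  # about 0.5
```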
module model.encoders
class AtomEmbedding
Encode an atom by its atomic number using a learnable embedding layer.
method __init__
__init__(atom_feature_dim: 'int', max_num_elements: 'int' = 94) → None
Initialize the Atom featurizer.
Args:
atom_feature_dim
(int): dimension of atomic embedding.
max_num_elements
(int): maximum number of elements in the dataset. Default = 94
method forward
forward(atomic_numbers: 'Tensor') → Tensor
Convert atomic numbers to an atom embedding tensor.
Args:
atomic_numbers
(Tensor): [n_atom, 1].
Returns:
atom_fea
(Tensor): atom embeddings [n_atom, atom_feature_dim].
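An embedding layer is a learnable lookup table with one row per element. The NumPy sketch below shows only the lookup, with random numbers standing in for learned weights; it assumes element Z is stored in row Z - 1 (an assumption for illustration), and embed_atoms is a hypothetical name.

```python
import numpy as np

max_num_elements = 94
atom_feature_dim = 8

rng = np.random.default_rng(0)
# stand-in for the learnable weights: one row per element
embedding_table = rng.normal(size=(max_num_elements, atom_feature_dim))


def embed_atoms(atomic_numbers):
    """Look up one feature row per atom (hypothetical helper)."""
    z = np.asarray(atomic_numbers)
    # assumption: element Z lives in row Z - 1, so H (Z = 1) is row 0
    return embedding_table[z - 1]


atom_fea = embed_atoms([8, 1, 1])  # O, H, H
print(atom_fea.shape)  # (3, 8)
```

Atoms of the same element share one row, so training updates that row for every occurrence of the element.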
class BondEncoder
Encode a chemical bond given the positions of two atoms using Gaussian distance.
method __init__
__init__(
atom_graph_cutoff: 'float' = 5,
bond_graph_cutoff: 'float' = 3,
num_radial: 'int' = 9,
cutoff_coeff: 'int' = 5,
learnable: 'bool' = False
) → None
Initialize the bond encoder.
Args:
atom_graph_cutoff
(float): The cutoff for constructing AtomGraph. Default = 5
bond_graph_cutoff
(float): The cutoff for constructing BondGraph. Default = 3
num_radial
(int): The number of radial components. Default = 9
cutoff_coeff
(int): Strength for graph cutoff smoothness. Default = 5
learnable
(bool): Whether the frequency in rbf expansion is learnable. Default = False
method forward
forward(
center: 'Tensor',
neighbor: 'Tensor',
undirected2directed: 'Tensor',
image: 'Tensor',
lattice: 'Tensor'
) → tuple[Tensor, Tensor, Tensor]
Compute the pairwise distances between pairs of 3D coordinates.
Args:
center
(Tensor): 3D Cartesian coordinates of center atoms [n_bond, 3]
neighbor
(Tensor): 3D Cartesian coordinates of neighbor atoms [n_bond, 3]
undirected2directed
(Tensor): mapping from each undirected bond to one of its directed bonds [n_bond]
image
(Tensor): the periodic image specifying the location of the neighboring atom [n_bond, 3]
lattice
(Tensor): the lattice of this structure [3, 3]
Returns:
bond_basis_ag
(Tensor): the bond basis in AtomGraph [n_bond, num_radial]
bond_basis_bg
(Tensor): the bond basis in BondGraph [n_bond, num_radial]
bond_vectors
(Tensor): normalized bond vectors, for tracking the bond directions [n_bond, 3]
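Before any basis expansion, each bond vector is formed from the neighbor position shifted into its periodic image. A NumPy sketch, assuming integer image indices are turned into Cartesian offsets as image @ lattice with lattice vectors stored as rows (a common convention, assumed here); bond_vectors is a hypothetical helper, not the library's code.

```python
import numpy as np


def bond_vectors(center, neighbor, image, lattice):
    """Return bond lengths [n_bond] and unit bond vectors [n_bond, 3]."""
    # shift each neighbor into the periodic image it was found in
    vec = neighbor + image @ lattice - center
    length = np.linalg.norm(vec, axis=1)
    return length, vec / length[:, None]


lattice = np.eye(3) * 3.0  # cubic cell with a = 3 A, rows are lattice vectors
center = np.array([[0.0, 0.0, 0.0]])
neighbor = np.array([[0.0, 0.0, 0.0]])  # the same atom...
image = np.array([[1.0, 0.0, 0.0]])  # ...seen in the next cell along a
dist, unit = bond_vectors(center, neighbor, image, lattice)
print(dist, unit)  # [3.] [[1. 0. 0.]]
```

The example shows why image is needed: an atom can be bonded to its own periodic copy, which has the same in-cell coordinates.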
class AngleEncoder
Encode an angle given the two bond vectors using Fourier Expansion.
method __init__
__init__(num_angular: 'int' = 9, learnable: 'bool' = True) → None
Initialize the angle encoder.
Args:
num_angular
(int): number of angular basis functions to use. Must be an odd integer. Default = 9
learnable
(bool): whether to set the frequencies of the Fourier expansion as learnable parameters. Default = True
method forward
forward(bond_i: 'Tensor', bond_j: 'Tensor') → Tensor
Compute the angles between normalized vectors.
Args:
bond_i
(Tensor): normalized left bond vector [n_angle, 3]
bond_j
(Tensor): normalized right bond vector [n_angle, 3]
Returns:
angle_fea
(Tensor): expanded cos_ij [n_angle, angle_feature_dim]
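Why num_angular must be odd: a Fourier expansion of the bond angle theta uses one constant term plus a cosine and a sine for each frequency k = 1..K, giving 1 + 2K features. A NumPy sketch of that expansion; the term ordering and the orthonormal-on-[0, 2π] normalization are assumptions, and the frequencies are fixed integers here even though the layer can make them learnable.

```python
import numpy as np


def fourier_angle_features(bond_i, bond_j, num_angular=9):
    """Expand the angle between unit vectors into num_angular Fourier terms."""
    if num_angular % 2 != 1:
        raise ValueError("num_angular must be odd: 1 constant + K cos + K sin")
    cos_ij = np.clip(np.sum(bond_i * bond_j, axis=1), -1.0, 1.0)
    theta = np.arccos(cos_ij)  # [n_angle]
    k = np.arange(1, (num_angular - 1) // 2 + 1)  # frequencies 1..K
    kt = theta[:, None] * k[None, :]
    const = np.full((len(theta), 1), 1.0 / np.sqrt(2.0))
    # dividing by sqrt(pi) makes the basis orthonormal on [0, 2*pi]
    return np.concatenate([const, np.cos(kt), np.sin(kt)], axis=1) / np.sqrt(np.pi)


bond_i = np.array([[1.0, 0.0, 0.0]])
bond_j = np.array([[0.0, 1.0, 0.0]])  # 90 degrees apart
feas = fourier_angle_features(bond_i, bond_j, num_angular=9)
print(feas.shape)  # (1, 9)
```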
module model.functions
function aggregate
aggregate(
data: 'Tensor',
owners: 'Tensor',
average=True,
num_owner=None
) → Tensor
Aggregate rows in data by specifying the owners.
Args:
data
(Tensor): data tensor to aggregate [n_row, feature_dim]
owners
(Tensor): the owner of each row [n_row, 1]
average
(bool): if True, average the rows; if False, sum the rows. Default = True
num_owner
(int, optional): the number of owners. This is needed if the highest owner index does not appear in the owners tensor. Default = None
Returns:
output
(Tensor): [num_owner, feature_dim]
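aggregate behaves like a scatter reduction: each row is added into the slot of its owner and optionally divided by the number of rows per owner. A NumPy sketch of that behavior (aggregate_np is a hypothetical name; leaving owners without rows as zero rows is a choice of this sketch):

```python
import numpy as np


def aggregate_np(data, owners, average=True, num_owner=None):
    """Sum or average the rows of data that share an owner index."""
    owners = np.asarray(owners).reshape(-1)  # accept [n_row] or [n_row, 1]
    if num_owner is None:
        num_owner = int(owners.max()) + 1
    out = np.zeros((num_owner, data.shape[1]), dtype=float)
    np.add.at(out, owners, data)  # unbuffered scatter-add, handles repeats
    if average:
        counts = np.bincount(owners, minlength=num_owner)
        # owners without rows keep a zero row instead of dividing by zero
        out = out / np.maximum(counts, 1)[:, None]
    return out


data = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
owners = np.array([0, 0, 1])
print(aggregate_np(data, owners))  # rows [2, 3] and [10, 20]
# num_owner=3 adds an empty third owner that max(owners) alone would miss
print(aggregate_np(data, owners, average=False, num_owner=3).shape)  # (3, 2)
```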
function find_activation
find_activation(name: 'str') → Module
Return an activation function using name.
function find_normalization
find_normalization(name: 'str', dim: 'int | None' = None) → Module | None
Return a normalization function using name.
class MLP
Multi-Layer Perceptron used for non-linear regression.
method __init__
__init__(
input_dim: 'int',
output_dim: 'int' = 1,
hidden_dim: 'int | Sequence[int] | None' = (64, 64),
dropout: 'float' = 0,
activation: 'str' = 'silu',
bias: 'bool' = True
) → None
Initialize the MLP.
Args:
input_dim
(int): the input dimension
output_dim
(int): the output dimension
hidden_dim
(list[int] | int): a list of integers or a single integer representing the number of hidden units in each layer of the MLP. Default = [64, 64]
dropout
(float): the dropout rate before each linear layer. Default = 0
activation
(str, optional): The name of the activation function to use in the MLP. Must be one of “relu”, “silu”, “tanh”, or “gelu”. Default = “silu”
bias
(bool): whether to use bias in each Linear layer. Default = True
method forward
forward(x: 'Tensor') → Tensor
Performs a forward pass through the MLP.
Args:
x
(Tensor): a tensor of shape (batch_size, input_dim)
Returns:
Tensor
: a tensor of shape (batch_size, output_dim)
class GatedMLP
Gated MLP; a similar model structure is used in CGCNN and M3GNet.
method __init__
__init__(
input_dim: 'int',
output_dim: 'int',
hidden_dim: 'int | list[int] | None' = None,
dropout: 'float' = 0,
activation: 'str' = 'silu',
norm: 'str' = 'batch',
bias: 'bool' = True
) → None
Initialize a gated MLP.
Args:
input_dim
(int): the input dimension
output_dim
(int): the output dimension
hidden_dim
(list[int] | int): a list of integers or a single integer representing the number of hidden units in each layer of the MLP. Default = None
dropout
(float): the dropout rate before each linear layer. Default = 0
activation
(str, optional): The name of the activation function to use in the gated MLP. Must be one of “relu”, “silu”, “tanh”, or “gelu”. Default = “silu”
norm
(str, optional): The name of the normalization layer to use on the updated atom features. Must be one of “batch”, “layer”, or None. Default = “batch”
bias
(bool): whether to use bias in each Linear layer. Default = True
method forward
forward(x: 'Tensor') → Tensor
Performs a forward pass through the MLP.
Args:
x
(Tensor): a tensor of shape (batch_size, input_dim)
Returns:
Tensor
: a tensor of shape (batch_size, output_dim)
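The gating idea in the CGCNN/M3GNet style: two parallel branches see the same input, one producing candidate features and the other a sigmoid gate in (0, 1), and the output is their element-wise product. A NumPy sketch with a single linear layer per branch and no normalization or dropout; the weights are random placeholders and the function names are hypothetical.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gated_mlp(x, w_core, b_core, w_gate, b_gate):
    """out = silu(x @ W_core + b_core) * sigmoid(x @ W_gate + b_gate)."""
    core = silu(x @ w_core + b_core)  # candidate features
    gate = sigmoid(x @ w_gate + b_gate)  # per-feature weights in (0, 1)
    return core * gate


rng = np.random.default_rng(0)
input_dim, output_dim, batch_size = 6, 4, 5
x = rng.normal(size=(batch_size, input_dim))
params = [
    rng.normal(size=(input_dim, output_dim)), np.zeros(output_dim),
    rng.normal(size=(input_dim, output_dim)), np.zeros(output_dim),
]
y = gated_mlp(x, *params)
print(y.shape)  # (5, 4)
```

The gate can only shrink each candidate feature, so the network learns which features to pass through for a given input.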
class ScaledSiLU
Scaled Sigmoid Linear Unit.
method __init__
__init__() → None
Initialize a scaled SiLU.
method forward
forward(x: 'Tensor') → Tensor
Forward pass.
module model.layers
class AtomConv
A convolution Layer to update atom features.
method __init__
__init__(
atom_fea_dim: 'int',
bond_fea_dim: 'int',
hidden_dim: 'int' = 64,
dropout: 'float' = 0,
activation: 'str' = 'silu',
norm: 'str | None' = None,
use_mlp_out: 'bool' = True,
mlp_out_bias: 'bool' = False,
resnet: 'bool' = True,
gMLP_norm: 'str | None' = None
) → None
Initialize the AtomConv layer.
Args:
atom_fea_dim
(int): The dimensionality of the input atom features.
bond_fea_dim
(int): The dimensionality of the input bond features.
hidden_dim
(int, optional): The dimensionality of the hidden layers in the gated MLP. Default = 64
dropout
(float, optional): The dropout rate to apply to the gated MLP. Default = 0
activation
(str, optional): The name of the activation function to use in the gated MLP. Must be one of “relu”, “silu”, “tanh”, or “gelu”. Default = “silu”
norm
(str, optional): The name of the normalization layer to use on the updated atom features. Must be one of “batch”, “layer”, or None. Default = None
use_mlp_out
(bool, optional): Whether to apply an MLP output layer to the updated atom features. Default = True
mlp_out_bias
(bool): whether to use bias in the output MLP Linear layer. Default = False
resnet
(bool, optional): Whether to apply a residual connection to the updated atom features. Default = True
gMLP_norm
(str, optional): The name of the normalization layer to use on the gated MLP. Must be one of “batch”, “layer”, or None. Default = None
method forward
forward(
atom_feas: 'Tensor',
bond_feas: 'Tensor',
bond_weights: 'Tensor',
atom_graph: 'Tensor',
directed2undirected: 'Tensor'
) → Tensor
Forward pass of AtomConv module that updates the atom features and optionally bond features.
Args:
atom_feas
(Tensor): Input tensor with shape [num_batch_atoms, atom_fea_dim]
bond_feas
(Tensor): Input tensor with shape [num_undirected_bonds, bond_fea_dim]
bond_weights
(Tensor): AtomGraph bond weights with shape [num_undirected_bonds, bond_fea_dim]
atom_graph
(Tensor): Directed AtomGraph adjacency list with shape [num_directed_bonds, 2]
directed2undirected
(Tensor): Index tensor that maps directed bonds to undirected bonds, with shape [num_directed_bonds]
Returns:
Tensor
: the updated atom features tensor with shape [num_batch_atoms, atom_fea_dim]
Notes:
- num_batch_atoms = sum(num_atoms) in batch
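Schematically, an atom convolution builds one message per directed bond from the center atom, the bond and the neighbor atom, weights it, aggregates the messages onto the center atom, and adds the result through the residual connection. A NumPy sketch under simplifying assumptions: a SiLU of one linear map stands in for the gated MLP, bond weights are per-bond scalars (the docs above list them as [num_undirected_bonds, bond_fea_dim]), messages are summed, and there is no output MLP or normalization. All names are hypothetical.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def atom_conv_sketch(atom_feas, bond_feas, bond_weights, atom_graph,
                     directed2undirected, w_msg):
    """One simplified message-passing update of atom features."""
    center, neighbor = atom_graph[:, 0], atom_graph[:, 1]
    # both directions of a bond share the same undirected bond features
    bond_dir = bond_feas[directed2undirected]
    msg_in = np.concatenate([atom_feas[center], bond_dir, atom_feas[neighbor]], axis=1)
    messages = silu(msg_in @ w_msg)  # stand-in for the gated MLP
    messages *= bond_weights[directed2undirected][:, None]  # per-bond weight
    new_feas = np.zeros_like(atom_feas)
    np.add.at(new_feas, center, messages)  # sum messages onto center atoms
    return atom_feas + new_feas  # residual connection (resnet=True)


rng = np.random.default_rng(0)
atom_feas = rng.normal(size=(2, 3))  # 2 atoms, atom_fea_dim = 3
bond_feas = rng.normal(size=(1, 2))  # 1 undirected bond, bond_fea_dim = 2
atom_graph = np.array([[0, 1], [1, 0]])  # both directions of that bond
directed2undirected = np.array([0, 0])
w_msg = rng.normal(size=(3 + 2 + 3, 3))
out = atom_conv_sketch(atom_feas, bond_feas, np.array([0.7]), atom_graph,
                       directed2undirected, w_msg)
print(out.shape)  # (2, 3)
```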
class BondConv
A convolution Layer to update bond features.
method __init__
__init__(
atom_fea_dim: 'int',
bond_fea_dim: 'int',
angle_fea_dim: 'int',
hidden_dim: 'int' = 64,
dropout: 'float' = 0,
activation: 'str' = 'silu',
norm: 'str | None' = None,
use_mlp_out: 'bool' = True,
mlp_out_bias: 'bool' = False,
resnet=True,
gMLP_norm: 'str | None' = None
) → None
Initialize the BondConv layer.
Args:
atom_fea_dim
(int): The dimensionality of the input atom features.
bond_fea_dim
(int): The dimensionality of the input bond features.
angle_fea_dim
(int): The dimensionality of the input angle features.
hidden_dim
(int, optional): The dimensionality of the hidden layers in the gated MLP. Default = 64
dropout
(float, optional): The dropout rate to apply to the gated MLP. Default = 0
activation
(str, optional): The name of the activation function to use in the gated MLP. Must be one of “relu”, “silu”, “tanh”, or “gelu”. Default = “silu”
norm
(str, optional): The name of the normalization layer to use on the updated bond features. Must be one of “batch”, “layer”, or None. Default = None
use_mlp_out
(bool, optional): Whether to apply an MLP output layer to the updated bond features. Default = True
mlp_out_bias
(bool): whether to use bias in the output MLP Linear layer. Default = False
resnet
(bool, optional): Whether to apply a residual connection to the updated bond features. Default = True
gMLP_norm
(str, optional): The name of the normalization layer to use on the gated MLP. Must be one of “batch”, “layer”, or None. Default = None
method forward
forward(
atom_feas: 'Tensor',
bond_feas: 'Tensor',
bond_weights: 'Tensor',
angle_feas: 'Tensor',
bond_graph: 'Tensor'
) → Tensor
Update the bond features.
Args:
atom_feas
(Tensor): atom features tensor with shape [num_batch_atoms, atom_fea_dim]
bond_feas
(Tensor): bond features tensor with shape [num_undirected_bonds, bond_fea_dim]
bond_weights
(Tensor): BondGraph bond weights with shape [num_undirected_bonds, bond_fea_dim]
angle_feas
(Tensor): angle features tensor with shape [num_batch_angles, angle_fea_dim]
bond_graph
(Tensor): Directed BondGraph tensor with shape [num_batched_angles, 3]
Returns:
new_bond_feas
(Tensor): bond feature tensor with shape [num_undirected_bonds, bond_fea_dim]
Notes:
- num_batch_atoms = sum(num_atoms) in batch
class AngleUpdate
Update angle features.
method __init__
__init__(
atom_fea_dim: 'int',
bond_fea_dim: 'int',
angle_fea_dim: 'int',
hidden_dim: 'int' = 0,
dropout: 'float' = 0,
activation: 'str' = 'silu',
norm: 'str | None' = None,
resnet: 'bool' = True,
gMLP_norm: 'str | None' = None
) → None
Initialize the AngleUpdate layer.
Args:
atom_fea_dim
(int): The dimensionality of the input atom features.
bond_fea_dim
(int): The dimensionality of the input bond features.
angle_fea_dim
(int): The dimensionality of the input angle features.
hidden_dim
(int, optional): The dimensionality of the hidden layers in the gated MLP. Default = 0
dropout
(float, optional): The dropout rate to apply to the gated MLP. Default = 0
activation
(str, optional): The name of the activation function to use in the gated MLP. Must be one of “relu”, “silu”, “tanh”, or “gelu”. Default = “silu”
norm
(str, optional): The name of the normalization layer to use on the updated angle features. Must be one of “batch”, “layer”, or None. Default = None
resnet
(bool, optional): Whether to apply a residual connection to the updated angle features. Default = True
gMLP_norm
(str, optional): The name of the normalization layer to use on the gated MLP. Must be one of “batch”, “layer”, or None. Default = None
method forward
forward(
atom_feas: 'Tensor',
bond_feas: 'Tensor',
angle_feas: 'Tensor',
bond_graph: 'Tensor'
) → Tensor
Update the angle features using bond graph.
Args:
atom_feas
(Tensor): atom features tensor with shape [num_batch_atoms, atom_fea_dim]
bond_feas
(Tensor): bond features tensor with shape [num_undirected_bonds, bond_fea_dim]
angle_feas
(Tensor): angle features tensor with shape [num_batch_angles, angle_fea_dim]
bond_graph
(Tensor): Directed BondGraph tensor with shape [num_batched_angles, 3]
Returns:
new_angle_feas
(Tensor): angle features tensor with shape [num_batch_angles, angle_fea_dim]
Notes:
- num_batch_atoms = sum(num_atoms) in batch
class GraphPooling
Pool the sub-graphs in the batched graph.
method __init__
__init__(average: 'bool' = False) → None
Args:
average
(bool): whether to average the features. Default = False
method forward
forward(atom_feas: 'Tensor', atom_owner: 'Tensor') → Tensor
Merge the atom features that belong to the same graph in a batched graph.
Args:
atom_feas
(Tensor): batched atom features after convolution layers. [num_batch_atoms, atom_fea_dim or 1]
atom_owner
(Tensor): graph indices for each atom. [num_batch_atoms]
Returns:
crystal_feas
(Tensor): crystal feature matrix. [n_crystals, atom_fea_dim or 1]
class GraphAttentionReadOut
Multi-head attention readout layer that merges the information from atom_feas into crystal_feas.
method __init__
__init__(
atom_fea_dim: 'int',
num_head: 'int' = 3,
hidden_dim: 'int' = 32,
average=False
) → None
Initialize the layer.
Args:
atom_fea_dim
(int): atom feature dimension
num_head
(int): number of attention heads used. Default = 3
hidden_dim
(int): dimension of the hidden layer. Default = 32
average
(bool): whether to average the features. Default = False
method forward
forward(atom_feas: 'Tensor', atom_owner: 'Tensor') → Tensor
Merge the atom features that belong to the same graph in a batched graph.
Args:
atom_feas
(Tensor): batched atom features after convolution layers. [num_batch_atoms, atom_fea_dim]
atom_owner
(Tensor): graph indices for each atom. [num_batch_atoms]
Returns:
crystal_feas
(Tensor): crystal feature matrix. [n_crystals, atom_fea_dim]
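An attention readout replaces the plain average with learned per-atom weights: each head scores every atom, the scores are normalized with a softmax over the atoms of the same crystal, and the atom features are summed with those weights. A NumPy sketch of that segment softmax and weighted sum, assuming one linear scoring vector per head and concatenated head outputs; the actual layer may use its hidden layer and combine heads differently, and all names are hypothetical.

```python
import numpy as np


def attention_readout(atom_feas, atom_owner, score_weights):
    """Pool atom features per crystal with per-head softmax attention.

    atom_feas: [n_atoms, d]; atom_owner: [n_atoms] contiguous crystal indices;
    score_weights: [d, num_head]. Returns [n_crystals, num_head * d].
    """
    n_crystals = int(atom_owner.max()) + 1
    scores = atom_feas @ score_weights  # [n_atoms, num_head]
    pooled = []
    for c in range(n_crystals):
        mask = atom_owner == c
        s = np.exp(scores[mask] - scores[mask].max(axis=0))  # stable softmax...
        alpha = s / s.sum(axis=0)  # ...over the atoms of crystal c only
        # weighted sum per head, then concatenate heads: [num_head * d]
        pooled.append((alpha.T @ atom_feas[mask]).reshape(-1))
    return np.stack(pooled)


rng = np.random.default_rng(0)
atom_feas = rng.normal(size=(5, 4))
atom_owner = np.array([0, 0, 0, 1, 1])  # two crystals with 3 and 2 atoms
score_weights = rng.normal(size=(4, 3))  # num_head = 3
crystal_feas = attention_readout(atom_feas, atom_owner, score_weights)
print(crystal_feas.shape)  # (2, 12)
```

Because the weights in each crystal sum to 1, a crystal whose atoms all share one feature vector pools to exactly that vector in every head.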
module model.model
Global Variables
- TYPE_CHECKING
- module_dir
class CHGNet
Crystal Hamiltonian Graph Neural Network: a model that takes in a crystal graph and outputs energy, force, magmom, and stress.
method __init__
__init__(
atom_fea_dim: 'int' = 64,
bond_fea_dim: 'int' = 64,
angle_fea_dim: 'int' = 64,
composition_model: 'str | Module' = 'MPtrj',
num_radial: 'int' = 31,
num_angular: 'int' = 31,
n_conv: 'int' = 4,
atom_conv_hidden_dim: 'Sequence[int] | int' = 64,
update_bond: 'bool' = True,
bond_conv_hidden_dim: 'Sequence[int] | int' = 64,
update_angle: 'bool' = True,
angle_layer_hidden_dim: 'Sequence[int] | int' = 0,
conv_dropout: 'float' = 0,
read_out: 'str' = 'ave',
mlp_hidden_dims: 'Sequence[int] | int' = (64, 64, 64),
mlp_dropout: 'float' = 0,
mlp_first: 'bool' = True,
is_intensive: 'bool' = True,
non_linearity: "Literal['silu', 'relu', 'tanh', 'gelu']" = 'silu',
atom_graph_cutoff: 'float' = 6,
bond_graph_cutoff: 'float' = 3,
graph_converter_algorithm: "Literal['legacy', 'fast']" = 'fast',
cutoff_coeff: 'int' = 8,
learnable_rbf: 'bool' = True,
gMLP_norm: 'str | None' = 'layer',
readout_norm: 'str | None' = 'layer',
version: 'str | None' = None,
**kwargs
) → None
Initialize CHGNet.
Args:
atom_fea_dim
(int): atom feature vector embedding dimension. Default = 64
bond_fea_dim
(int): bond feature vector embedding dimension. Default = 64
angle_fea_dim
(int): angle feature vector embedding dimension. Default = 64
composition_model
(nn.Module, optional): attach a composition model to predict energy or initialize a pretrained linear regression (AtomRef). The default ‘MPtrj’ is the atom reference energy linear regression trained on all Materials Project relaxation trajectories. Default = ‘MPtrj’
num_radial
(int): number of radial basis functions used in the bond basis expansion. Default = 31
num_angular
(int): number of angular basis functions used in the angle basis expansion. Default = 31
n_conv
(int): number of interaction blocks. Default = 4
Note
: the last interaction block contains only an atom_conv layer
atom_conv_hidden_dim
(List or int): hidden dimensions of atom convolution layers. Default = 64
update_bond
(bool): whether to use bond_conv_layer in the bond graph to update bond embeddings. Default = True
bond_conv_hidden_dim
(List or int): hidden dimensions of bond convolution layers. Default = 64
update_angle
(bool): whether to use angle_update_layer to update angle embeddings. Default = True
angle_layer_hidden_dim
(List or int): hidden dimensions of angle layers. Default = 0
conv_dropout
(float): dropout rate in all conv_layers. Default = 0
read_out
(str): method for the pooling layer: ‘ave’ for standard average pooling, ‘attn’ for multi-head attention. Default = “ave”
mlp_hidden_dims
(int or list): readout multilayer perceptron hidden dimensions. Default = [64, 64, 64]
mlp_dropout
(float): dropout rate in the readout MLP. Default = 0
is_intensive
(bool): whether the energy training label is intensive, i.e. energy per atom. Default = True
non_linearity
(‘silu’ | ‘relu’ | ‘tanh’ | ‘gelu’): The name of the activation function to use in the gated MLP. Default = “silu”
mlp_first
(bool): whether to apply the MLP first and then pooling. If set to True, CHGNet essentially calculates an energy for each atom and then sums them up; this is used for the pretrained model. Default = True
atom_graph_cutoff
(float): cutoff radius (A) for creating atom_graph; this needs to be consistent with the value in the training dataloader. Default = 6
bond_graph_cutoff
(float): cutoff radius (A) for creating bond_graph; this needs to be consistent with the value in the training dataloader. Default = 3
graph_converter_algorithm
(‘legacy’ | ‘fast’): algorithm to use for converting pymatgen.core.Structure to CrystalGraph.
'legacy'
: Python implementation of graph creation
'fast'
: C implementation of graph creation; this is faster, but needs the cygraph.c file to be correctly compiled during pip install. Default = ‘fast’
cutoff_coeff
(int): cutoff strength used in the graph smooth cutoff function. The smaller this coefficient is, the smoother the basis. Default = 8
learnable_rbf
(bool): whether to set the frequencies in the rbf and Fourier basis functions as learnable. Default = True
gMLP_norm
(str): normalization layer to use in the gated MLP. Default = ‘layer’
readout_norm
(str): normalization layer to use before the readout layer. Default = ‘layer’
version
(str): Pretrained checkpoint version.
**kwargs
: Additional keyword arguments
property n_params
Return the number of parameters in the model.
property version
Return the version of the loaded checkpoint.
method as_dict
as_dict() → dict
Return the CHGNet weights and args in a dictionary.
method forward
forward(
graphs: 'Sequence[CrystalGraph]',
task: 'PredTask' = 'e',
return_site_energies: 'bool' = False,
return_atom_feas: 'bool' = False,
return_crystal_feas: 'bool' = False
) → dict[str, Tensor]
Get predictions for the input graphs.
Args:
graphs
(List): a list of CrystalGraphs
task
(str): the prediction task. One of ‘e’, ‘em’, ‘ef’, ‘efs’, ‘efsm’. Default = ‘e’
return_site_energies
(bool): whether to return per-site energies; only available if self.mlp_first == True. Default = False
return_atom_feas
(bool): whether to return the atom features before the last conv layer. Default = False
return_crystal_feas
(bool): whether to return the crystal features. Default = False
Returns: model output (dict).
classmethod from_dict
from_dict(dct: 'dict', **kwargs) → Self
Build a CHGNet from a saved dictionary.
classmethod from_file
from_file(path: 'str', **kwargs) → Self
Build a CHGNet from a saved file.
classmethod load
load(
model_name: 'str' = '0.3.0',
use_device: 'str | None' = None,
check_cuda_mem: 'bool' = False,
verbose: 'bool' = True
) → Self
Load pretrained CHGNet model.
Args:
model_name
(str, optional): the pretrained model version to load. Default = “0.3.0”
use_device
(str, optional): The device to be used for predictions, either “cpu”, “cuda”, or “mps”. If not specified, the default device is automatically selected based on the available options. Default = None
check_cuda_mem
(bool): Whether to use the CUDA device with the most available memory. Default = False
verbose
(bool): whether to print model device information. Default = True
Raises:
ValueError
: On unknown model_name.
method predict_graph
predict_graph(
graph: 'CrystalGraph | Sequence[CrystalGraph]',
task: 'PredTask' = 'efsm',
return_site_energies: 'bool' = False,
return_atom_feas: 'bool' = False,
return_crystal_feas: 'bool' = False,
batch_size: 'int' = 16
) → dict[str, Tensor] | list[dict[str, Tensor]]
Predict from CrystalGraph.
Args:
graph
(CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
task
(str): can be ‘e’, ‘ef’, ‘em’, ‘efs’, ‘efsm’. Default = “efsm”
return_site_energies
(bool): whether to return per-site energies. Default = False
return_atom_feas
(bool): whether to return atom features. Default = False
return_crystal_feas
(bool): whether to return crystal features. Default = False
batch_size
(int): batch size for predicting structures. Default = 16
Returns:
prediction
(dict): dict or list of dicts containing the fields:
e (Tensor)
: energy of structures, float in eV/atom
f (Tensor)
: forces on atoms [num_atoms, 3] in eV/A
s (Tensor)
: stress of structure [3, 3] in GPa
m (Tensor)
: magnetic moments of sites [num_atoms] in Bohr magneton mu_B
method predict_structure
predict_structure(
structure: 'Structure | Sequence[Structure]',
task: 'PredTask' = 'efsm',
return_site_energies: 'bool' = False,
return_atom_feas: 'bool' = False,
return_crystal_feas: 'bool' = False,
batch_size: 'int' = 16
) → dict[str, Tensor] | list[dict[str, Tensor]]
Predict from pymatgen.core.Structure.
Args:
structure
(Structure | Sequence[Structure]): structure or a list of structures to predict.
task
(str): can be ‘e’, ‘ef’, ‘em’, ‘efs’, ‘efsm’. Default = “efsm”
return_site_energies
(bool): whether to return per-site energies. Default = False
return_atom_feas
(bool): whether to return atom features. Default = False
return_crystal_feas
(bool): whether to return crystal features. Default = False
batch_size
(int): batch size for predicting structures. Default = 16
Returns:
prediction
(dict): dict or list of dicts containing the fields:
e (Tensor)
: energy of structures, float in eV/atom
f (Tensor)
: forces on atoms [num_atoms, 3] in eV/A
s (Tensor)
: stress of structure [3, 3] in GPa
m (Tensor)
: magnetic moments of sites [num_atoms] in Bohr magneton mu_B
method todict
todict() → dict
Needed for ASE JSON serialization when saving CHGNet potential to trajectory file (https://github.com/CederGroupHub/chgnet/issues/48).
class BatchedGraph
Batched crystal graph for parallel computing.
Attributes:
atomic_numbers
(Tensor): atomic numbers vector [num_batch_atoms]
bond_bases_ag
(Tensor): bond bases vector for atom_graph [num_batch_bonds_ag, num_radial]
bond_bases_bg
(Tensor): bond bases vector for bond_graph [num_batch_bonds_bg, num_radial]
angle_bases
(Tensor): angle bases vector [num_batch_angles, num_angular]
batched_atom_graph (Tensor)
: batched atom graph adjacency list [num_batch_bonds, 2]
batched_bond_graph (Tensor)
: bond graph adjacency list [num_batch_angles, 3]
atom_owners
(Tensor): graph indices for each atom, used to aggregate the batched graph back to single graphs [num_batch_atoms]
directed2undirected
(Tensor): the utility tensor used to quickly map directed edges to undirected edges in the graph [num_directed]
atom_positions
(list[Tensor]): Cartesian coordinates of the atoms from structures [[num_atoms_1, 3], [num_atoms_2, 3], …]
strains
(list[Tensor]): a list of strains that are initialized to zeros [[3, 3], [3, 3], …]
volumes
(Tensor): the volume of each structure in the batch [batch_size]
method __init__
__init__(
atomic_numbers: 'Tensor',
bond_bases_ag: 'Tensor',
bond_bases_bg: 'Tensor',
angle_bases: 'Tensor',
batched_atom_graph: 'Tensor',
batched_bond_graph: 'Tensor',
atom_owners: 'Tensor',
directed2undirected: 'Tensor',
atom_positions: 'Sequence[Tensor]',
strains: 'Sequence[Tensor]',
volumes: 'Sequence[Tensor] | Tensor'
) → None
classmethod from_graphs
from_graphs(
graphs: 'Sequence[CrystalGraph]',
bond_basis_expansion: 'Module',
angle_basis_expansion: 'Module',
compute_stress: 'bool' = False
) → Self
Featurize and assemble a list of graphs.
Args:
graphs
(list[CrystalGraph]): a list of CrystalGraphs
bond_basis_expansion
(nn.Module): bond basis expansion layer in CHGNet
angle_basis_expansion
(nn.Module): angle basis expansion layer in CHGNet
compute_stress
(bool): whether to compute stress. Default = False
Returns:
BatchedGraph
: assembled graphs ready for batched CHGNet forward pass
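Assembling a batch mainly means concatenating per-crystal arrays while shifting every atom index by the number of atoms in the preceding crystals, and recording which crystal owns each atom (atom_owners). A NumPy sketch of that bookkeeping for the atom-graph adjacency lists only; it omits the basis expansions, bond-graph offsets and stress setup, and batch_atom_graphs is a hypothetical name.

```python
import numpy as np


def batch_atom_graphs(n_atoms_per_graph, atom_graphs):
    """Concatenate [n_bond_i, 2] adjacency lists with atom-index offsets."""
    # offset of each crystal = number of atoms in all crystals before it
    offsets = np.concatenate([[0], np.cumsum(n_atoms_per_graph)[:-1]])
    batched = np.concatenate(
        [graph + offset for graph, offset in zip(atom_graphs, offsets)]
    )
    atom_owners = np.repeat(np.arange(len(n_atoms_per_graph)), n_atoms_per_graph)
    return batched, atom_owners


g1 = np.array([[0, 1], [1, 0]])  # crystal 0: two atoms bonded to each other
g2 = np.array([[0, 2], [2, 0], [1, 2]])  # crystal 1: three atoms
batched, owners = batch_atom_graphs([2, 3], [g1, g2])
print(batched.tolist())  # [[0, 1], [1, 0], [2, 4], [4, 2], [3, 4]]
print(owners.tolist())  # [0, 0, 1, 1, 1]
```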
module trainer.trainer
Global Variables
- TYPE_CHECKING
- wandb
- LogEachEpoch
- LogEachBatch
class Trainer
A trainer to train CHGNet using energy, force, stress and magmom.
method __init__
__init__(
model: 'CHGNet | None' = None,
targets: 'TrainTask' = 'ef',
energy_loss_ratio: 'float' = 1,
force_loss_ratio: 'float' = 1,
stress_loss_ratio: 'float' = 0.1,
mag_loss_ratio: 'float' = 0.1,
optimizer: 'str' = 'Adam',
scheduler: 'str' = 'CosLR',
criterion: 'str' = 'MSE',
epochs: 'int' = 50,
starting_epoch: 'int' = 0,
learning_rate: 'float' = 0.001,
print_freq: 'int' = 100,
torch_seed: 'int | None' = None,
data_seed: 'int | None' = None,
use_device: 'str | None' = None,
check_cuda_mem: 'bool' = False,
wandb_path: 'str | None' = None,
wandb_init_kwargs: 'dict | None' = None,
extra_run_config: 'dict | None' = None,
**kwargs
) → None
Initialize all hyper-parameters for trainer.
Args:
model
(nn.Module): a CHGNet model
targets
(“ef” | “efs” | “efsm”): The training targets. Default = “ef”
energy_loss_ratio
(float): energy loss ratio in the loss function. Default = 1
force_loss_ratio
(float): force loss ratio in the loss function. Default = 1
stress_loss_ratio
(float): stress loss ratio in the loss function. Default = 0.1
mag_loss_ratio
(float): magmom loss ratio in the loss function. Default = 0.1
optimizer
(str): optimizer to update the model. Can be “Adam”, “SGD”, “AdamW”, “RAdam”. Default = ‘Adam’
scheduler
(str): learning rate scheduler. Can be “CosLR”, “ExponentialLR”, “CosRestartLR”. Default = ‘CosLR’
criterion
(str): loss function criterion. Can be “MSE”, “Huber”, “MAE”. Default = ‘MSE’
epochs
(int): number of epochs for training. Default = 50
starting_epoch
(int): The epoch number to start training at. Default = 0
learning_rate
(float): initial learning rate. Default = 1e-3
print_freq
(int): frequency to print training output. Default = 100
torch_seed
(int): random seed for torch. Default = None
data_seed
(int): random seed for random. Default = None
use_device
(str, optional): The device to be used for predictions, either “cpu”, “cuda”, or “mps”. If not specified, the default device is automatically selected based on the available options. Default = None
check_cuda_mem
(bool): Whether to use the CUDA device with the most available memory. Default = False
wandb_path
(str | None): The project and run name separated by a slash: “project/run_name”. If None, wandb logging is not used. Default = None
wandb_init_kwargs
(dict): Additional kwargs to pass to wandb.init. Default = None
extra_run_config
(dict): Additional hyper-params to be recorded by wandb that are not included in the trainer_args. Default = None
**kwargs (dict)
: additional hyper-params for optimizer, scheduler, etc.
Raises:
NotImplementedError
: If the optimizer or scheduler is not implemented
ImportError
: If wandb_path is specified but wandb is not installed
ValueError
: If wandb_path is specified but not in the format ‘project/run_name’
method get_best_model
get_best_model() → CHGNet
Get best model recorded in the trainer.
Returns:
CHGNet
: the model with lowest validation set energy error
classmethod load
load(path: 'str') → Self
Load trainer state_dict.
Args:
path
(str): path to the saved model
Returns:
Trainer
: the loaded trainer
method move_to
move_to(obj: 'Tensor | list[Tensor]', device: 'device') → Tensor | list[Tensor]
Move object to device.
Args:
obj
(Tensor | list[Tensor]): object(s) to move to device
device
(torch.device): device to move object to
Raises:
TypeError
: if obj is not a tensor or list of tensors
Returns:
Tensor | list[Tensor]
: moved object(s)
method save
save(filename: 'str' = 'training_result.pth.tar') → None
Save the model, graph_converter, etc.
method save_checkpoint
save_checkpoint(epoch: 'int', mae_error: 'dict', save_dir: 'str') → None
Function to save CHGNet trained weights after each epoch.
Args:
epoch
(int): the epoch number
mae_error
(dict): dictionary that stores the MAEs
save_dir
(str): the directory to save trained weights
method train
train(
train_loader: 'DataLoader',
val_loader: 'DataLoader',
test_loader: 'DataLoader | None' = None,
save_dir: 'str | None' = None,
save_test_result: 'bool' = False,
train_composition_model: 'bool' = False,
wandb_log_freq: 'LogFreq' = 'batch'
) → None
Train the model using torch DataLoaders.
Args:
train_loader
(DataLoader): train loader to update CHGNet weights
val_loader
(DataLoader): val loader to test accuracy after each epoch
test_loader
(DataLoader): test loader to test accuracy at the end of training. Can be None. Default = None
save_dir
(str): the dir name to save the trained weights. Default = None
save_test_result
(bool): Whether to save the test set prediction in a JSON file. Default = False
train_composition_model
(bool): whether to train the composition model (AtomRef). This is suggested when the fine-tuning dataset has a large elemental energy shift from the pretrained CHGNet, which typically comes from different DFT pseudo-potentials. Default = False
wandb_log_freq
(“epoch” | “batch”): Frequency of logging to wandb. ‘epoch’ logs once per epoch, ‘batch’ logs after every batch. Default = “batch”
Raises:
ValueError
: If model is not initialized
class CombinedLoss
A combined loss function of energy, force, stress and magmom.
method __init__
__init__(
target_str: 'str' = 'ef',
criterion: 'str' = 'MSE',
is_intensive: 'bool' = True,
energy_loss_ratio: 'float' = 1,
force_loss_ratio: 'float' = 1,
stress_loss_ratio: 'float' = 0.1,
mag_loss_ratio: 'float' = 0.1,
delta: 'float' = 0.1
) → None
Initialize the combined loss.
Args:
target_str
(str): the training target label. Can be “e”, “ef”, “efs”, “efsm” etc. Default = “ef”
criterion
(str): loss criterion to use. Default = “MSE”
is_intensive
(bool): whether the energy label is intensive. Default = True
energy_loss_ratio
(float): energy loss ratio in loss function. Default = 1
force_loss_ratio
(float): force loss ratio in loss function. Default = 1
stress_loss_ratio
(float): stress loss ratio in loss function. Default = 0.1
mag_loss_ratio
(float): magmom loss ratio in loss function. Default = 0.1
delta
(float): delta for torch.nn.HuberLoss. Default = 0.1
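One plausible reading of how `target_str` and the `*_loss_ratio` arguments fit together is a weighted sum: each character of `target_str` selects a target, and that target's loss is scaled by its ratio. The sketch below assumes exactly that and takes the per-target losses as already computed floats; it ignores `criterion`, `is_intensive` and `delta`, and `combine_losses` is a hypothetical name, not part of CHGNet.

```python
def combine_losses(
    per_target_loss: dict[str, float],
    target_str: str = "ef",
    energy_loss_ratio: float = 1,
    force_loss_ratio: float = 1,
    stress_loss_ratio: float = 0.1,
    mag_loss_ratio: float = 0.1,
) -> float:
    """Assumed weighted sum of per-target losses keyed by 'e', 'f', 's', 'm'."""
    ratios = {
        "e": energy_loss_ratio,
        "f": force_loss_ratio,
        "s": stress_loss_ratio,
        "m": mag_loss_ratio,
    }
    # Only the targets named in target_str contribute to the total.
    return sum(ratios[key] * per_target_loss[key] for key in target_str)


losses = {"e": 0.25, "f": 0.5, "s": 2.0, "m": 1.0}  # made-up per-target losses
print(round(combine_losses(losses, target_str="efs"), 6))  # 0.95
```

With the default ratios, stress and magmom terms are down-weighted by 10x relative to energy and force.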
method forward
forward(
targets: 'dict[str, Tensor]',
prediction: 'dict[str, Tensor]'
) → dict[str, Tensor]
Compute the combined loss using CHGNet prediction and labels. This function can automatically mask out the magmom loss contribution of data points without magmom labels.
Args:
targets
(dict): DFT labels
prediction
(dict): CHGNet prediction
Returns: dictionary of all the losses, MAEs and MAE_size
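The magmom masking described above can be illustrated with plain Python lists. This sketch assumes that a structure without magmom labels is represented by `None`; such structures are skipped, so only labeled atoms enter the magmom MAE. `masked_magmom_mae` is a hypothetical helper, not CHGNet's implementation, which operates on batched tensors.

```python
from __future__ import annotations


def masked_magmom_mae(
    pred_magmoms: list[list[float]],
    true_magmoms: list[list[float] | None],
) -> float | None:
    """MAE over atoms of structures that have magmom labels.

    Returns None when no structure in the batch is labeled.
    """
    errors: list[float] = []
    for pred, true in zip(pred_magmoms, true_magmoms):
        if true is None:
            continue  # mask out: this structure has no magmom label
        errors.extend(abs(p - t) for p, t in zip(pred, true))
    if not errors:
        return None
    return sum(errors) / len(errors)


print(masked_magmom_mae([[1.0, 2.0], [3.0]], [[1.5, 2.0], None]))  # 0.25
```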
module utils.common_utils
function determine_device
determine_device(
use_device: 'str | None' = None,
check_cuda_mem: 'bool' = False
) → str
Determine the device to use for torch model.
Args:
use_device
(str): User-specified device name
check_cuda_mem
(bool): Whether to return the CUDA device with the most available memory. Default = False
Returns:
device
(str): device name to be passed to model.to(device)
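The selection logic can be modeled as below. This is a hypothetical sketch, not CHGNet's code: hardware availability is passed in as booleans instead of being queried from PyTorch, and the fallback order (explicit `use_device`, then CUDA, then MPS, then CPU) is an assumption. With `check_cuda_mem`, it pins the CUDA device with the most free memory, i.e. the last id of a list sorted by increasing free memory.

```python
from __future__ import annotations


def pick_device(
    use_device: str | None = None,
    check_cuda_mem: bool = False,
    cuda_available: bool = False,
    mps_available: bool = False,
    cuda_ids_by_free_mem: list[int] | None = None,
) -> str:
    """Hypothetical device picker; the preference order is an assumption."""
    if use_device is not None:
        device = use_device  # an explicit user choice always wins
    elif cuda_available:
        device = "cuda"
    elif mps_available:
        device = "mps"
    else:
        device = "cpu"
    # Pin to the CUDA device with the most free memory (last in sorted list).
    if device == "cuda" and check_cuda_mem and cuda_ids_by_free_mem:
        device = f"cuda:{cuda_ids_by_free_mem[-1]}"
    return device


print(pick_device(check_cuda_mem=True, cuda_available=True, cuda_ids_by_free_mem=[2, 0, 1]))  # cuda:1
```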
function cuda_devices_sorted_by_free_mem
cuda_devices_sorted_by_free_mem() → list[int]
List available CUDA devices sorted by increasing available memory.
To get the device with the most free memory, use the last list item.
Returns:
list[int]
: CUDA device numbers sorted by increasing free memory.
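The ordering convention can be shown with a hypothetical mapping from device id to free bytes. With PyTorch, such numbers could come from `torch.cuda.mem_get_info(i)`, which returns a `(free, total)` tuple in bytes; here the values are made up so the sketch runs on any machine. `sort_devices_by_free_mem` is a hypothetical name.

```python
def sort_devices_by_free_mem(free_mem: dict[int, int]) -> list[int]:
    """Return device ids sorted by increasing free memory."""
    return sorted(free_mem, key=free_mem.__getitem__)


free_bytes = {0: 2_000_000_000, 1: 8_000_000_000, 2: 500_000_000}  # made-up values
order = sort_devices_by_free_mem(free_bytes)
print(order)      # [2, 0, 1]
print(order[-1])  # 1, the device with the most free memory
```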
function mae
mae(prediction: 'Tensor', target: 'Tensor') → Tensor
Computes the mean absolute error between prediction and target.
Args:
prediction
: Tensor (N, 1)
target
: Tensor (N, 1).
Returns: tensor
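For reference, the quantity is MAE = (1/N) * sum_i |prediction_i - target_i|. A minimal pure-Python sketch of the same formula over flat lists follows; for tensors of matching shape, the equivalent PyTorch expression is `torch.mean(torch.abs(prediction - target))`. `mae_list` is a hypothetical name chosen to avoid confusion with the documented function.

```python
def mae_list(prediction: list[float], target: list[float]) -> float:
    """Mean absolute error between two equal-length sequences."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


print(mae_list([1.0, 2.0, 4.0], [1.5, 2.0, 3.0]))  # 0.5
```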
function read_json
read_json(filepath: 'str') → dict
Read the JSON file.
Args:
filepath
(str): file name of JSON to read.
Returns:
dict
: data stored in filepath
function write_json
write_json(dct: 'dict', filepath: 'str') → dict
Write the JSON file.
Args:
dct
(dict): dictionary to write
filepath
(str): file name of JSON to write.
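A round trip through read_json and write_json can be sketched with the standard `json` module. The two functions below are hypothetical stand-ins that follow the documented intent, not CHGNet's source, and the dictionary written is a made-up example.

```python
import json
import os
import tempfile


def write_json_sketch(dct: dict, filepath: str) -> None:
    """Serialize `dct` to `filepath` as JSON."""
    with open(filepath, "w") as file:
        json.dump(dct, file)


def read_json_sketch(filepath: str) -> dict:
    """Load the JSON object stored in `filepath`."""
    with open(filepath) as file:
        return json.load(file)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "result.json")
    write_json_sketch({"epoch": 3, "mae": {"e": 0.25}}, path)
    print(read_json_sketch(path))  # {'epoch': 3, 'mae': {'e': 0.25}}
```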
function mkdir
mkdir(path: 'str') → str
Make directory.
Args:
path
(str): directory name
Returns: path
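Behavior like this can be sketched with `os.makedirs`. Whether CHGNet's `mkdir` tolerates an already existing directory is not stated here; the sketch assumes it should (`exist_ok=True`) and, as documented, returns the path it was given. `mkdir_sketch` is a hypothetical name.

```python
import os
import tempfile


def mkdir_sketch(path: str) -> str:
    """Create `path` (including parent directories) if needed and return it."""
    os.makedirs(path, exist_ok=True)
    return path


with tempfile.TemporaryDirectory() as tmp_dir:
    save_dir = mkdir_sketch(os.path.join(tmp_dir, "runs", "epoch_1"))
    print(os.path.isdir(save_dir))  # True
```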
class AverageMeter
Computes and stores the average and current value.
method __init__
__init__() → None
Initialize the meter.
method reset
reset() → None
Reset the meter value, average, sum and count to 0.
method update
update(val: 'float', n: 'int' = 1) → None
Update the meter value, average, sum and count.
Args:
val
(float): New value to be added to the running average.
n
(int, optional): Number of times the value is added. Default = 1.
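A meter with the documented `reset` and `update(val, n)` behavior can be sketched using the common running-average pattern below. The attribute names `val`, `avg`, `sum` and `count` follow the reset description above; the class name `AverageMeterSketch` is hypothetical and this is not CHGNet's source.

```python
class AverageMeterSketch:
    """Track the latest value and the running average of a metric."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        # Treat `val` as the mean over `n` samples, e.g. one batch.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
meter.update(2.0, n=3)  # a batch of 3 samples with mean error 2.0
meter.update(6.0)
print(meter.val, meter.sum, meter.count, meter.avg)  # 6.0 12.0 4 3.0
```

Weighting by `n` keeps the average correct when the last batch of an epoch is smaller than the others.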
module utils.vasp_utils
function parse_vasp_dir
parse_vasp_dir(
base_dir: 'str',
check_electronic_convergence: 'bool' = True,
save_path: 'str | None' = None
) → dict[str, list]
Parse VASP output files into structures and labels. By default, the magnetization is read from mag_x from VASP; please modify the code if the magnetization is along (y) or (z).
Args:
base_dir
(str): the directory of the VASP calculation outputs
check_electronic_convergence
(bool): if set to True, this function will raise an Exception for VASP calculations that did not achieve electronic convergence. Default = True
save_path
(str | None): path to save the parsed VASP labels. Default = None
Raises:
NotADirectoryError
: if the base_dir is not a directory
Returns:
dict
: a dictionary of lists with keys for structure, uncorrected_total_energy, energy_per_atom, force, magmom, stress.
function solve_charge_by_mag
solve_charge_by_mag(
structure: 'Structure',
default_ox: 'dict[str, float] | None' = None,
ox_ranges: 'dict[str, dict[tuple[float, float], int]] | None' = None
) → Structure | None
Solve oxidation states by magmom.
Args:
structure
(Structure): pymatgen structure with magmoms in site_properties. Dict key must be either magmom or final_magmom.
default_ox
(dict[str, float]): default oxidation state for elements. Default = dict(Li=1, O=-2)
ox_ranges
(dict[str, dict[tuple[float, float], int]]): user-defined range to convert magmoms into formal valence. Example for Mn (Default):
{"Mn": {(0.5, 1.5): 2, (1.5, 2.5): 3, (2.5, 3.5): 4, (3.5, 4.2): 3, (4.2, 5): 2}}
Returns:
Structure
: pymatgen Structure with oxidation states assigned based on magmoms.
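The way an `ox_ranges` table maps a magmom to a formal valence can be sketched as an interval lookup. This is an illustrative helper, not CHGNet's implementation: `valence_from_magmom` is hypothetical, the half-open interval `[low, high)` and the use of the magmom's absolute value are assumptions, and `None` is returned when no range matches.

```python
from __future__ import annotations

# Default Mn ranges as documented for ox_ranges.
MN_OX_RANGES: dict[str, dict[tuple[float, float], int]] = {
    "Mn": {(0.5, 1.5): 2, (1.5, 2.5): 3, (2.5, 3.5): 4, (3.5, 4.2): 3, (4.2, 5): 2}
}


def valence_from_magmom(
    element: str,
    magmom: float,
    ox_ranges: dict[str, dict[tuple[float, float], int]] = MN_OX_RANGES,
) -> int | None:
    """Return the formal valence whose magmom range contains |magmom|."""
    for (low, high), valence in ox_ranges.get(element, {}).items():
        if low <= abs(magmom) < high:
            return valence
    return None  # element not in the table, or magmom outside every range


print(valence_from_magmom("Mn", 3.0))   # 4
print(valence_from_magmom("Mn", -4.6))  # 2
```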