module data.dataset

function collate_graphs

    batch_data: 'list'
) → tuple[list[CrystalGraph], dict[str, Tensor]]

Collate a list of (graph, target) pairs into batch data.



function get_train_val_test_loader

    dataset: 'Dataset',
    batch_size: 'int' = 64,
    train_ratio: 'float' = 0.8,
    val_ratio: 'float' = 0.1,
    return_test: 'bool' = True,
    num_workers: 'int' = 0,
    pin_memory: 'bool' = True
) → tuple[DataLoader, DataLoader, DataLoader]

Randomly partition a dataset into train, val, test loaders.


Returns: train_loader, val_loader and optionally test_loader
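The random partition can be pictured as a shuffle of sample indices followed by two cuts. The sketch below is a minimal illustration under that assumption; `split_indices` and its `seed` argument are hypothetical names, not part of the documented API, and the library may round or shuffle differently.

```python
import random


def split_indices(n_samples, train_ratio=0.8, val_ratio=0.1, seed=0):
    """Hypothetical helper: shuffle indices, then cut into train/val/test."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = int(n_samples * train_ratio)
    n_val = int(n_samples * val_ratio)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains becomes the test set
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100)
```

With the default ratios, the remaining 10% of the samples form the test split.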

function get_loader

    batch_size: 'int' = 64,
    num_workers: 'int' = 0,
    pin_memory: 'bool' = True
) → DataLoader

Get a dataloader from a dataset.


Returns: data_loader

class StructureData

A simple torch Dataset of structures.

method __init__

    structures: 'list[Structure]',
    energies: 'list[float]',
    forces: 'list[Sequence[Sequence[float]]]',
    stresses: 'list[Sequence[Sequence[float]]] | None' = None,
    magmoms: 'list[Sequence[Sequence[float]]] | None' = None,
    structure_ids: 'list | None' = None,
    graph_converter: 'CrystalGraphConverter | None' = None,
    shuffle: 'bool' = True

Initialize the dataset.



classmethod from_vasp

    file_root: 'str',
    check_electronic_convergence: 'bool' = True,
    save_path: 'str | None' = None,
    graph_converter: 'CrystalGraphConverter | None' = None,
    shuffle: 'bool' = True
) → Self

Parse VASP output files into structures and labels and feed into the dataset.


class CIFData

A dataset from CIFs.

method __init__

    cif_path: 'str',
    labels: 'str | dict' = 'labels.json',
    targets: 'TrainTask' = 'efsm',
    graph_converter: 'CrystalGraphConverter | None' = None,
    energy_key: 'str' = 'energy_per_atom',
    force_key: 'str' = 'force',
    stress_key: 'str' = 'stress',
    magmom_key: 'str' = 'magmom',
    shuffle: 'bool' = True

Initialize the dataset from a directory containing CIFs.


class GraphData

A dataset of graphs, compatible with the graph.pt files made by make_graphs.py. We recommend using this dataset to avoid graph conversion steps.

method __init__

    graph_path: 'str',
    labels: 'str | dict' = 'labels.json',
    targets: 'TrainTask' = 'efsm',
    exclude: 'str | list | None' = None,
    energy_key: 'str' = 'energy_per_atom',
    force_key: 'str' = 'force',
    stress_key: 'str' = 'stress',
    magmom_key: 'str' = 'magmom',
    shuffle: 'bool' = True

Initialize the dataset from a directory containing saved crystal graphs.


method get_train_val_test_loader

    train_ratio: 'float' = 0.8,
    val_ratio: 'float' = 0.1,
    train_key: 'list[str] | None' = None,
    val_key: 'list[str] | None' = None,
    test_key: 'list[str] | None' = None,
) → tuple[DataLoader, DataLoader, DataLoader]

Partition the GraphData by materials id. Either randomly select the train_keys, val_keys, and test_keys according to the train/val/test ratios, or use pre-defined train_keys, val_keys, and test_keys to create the train, val, and test loaders.


Returns: train_loader, val_loader, test_loader
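Splitting by materials id rather than by individual graph keeps all frames of one material in the same split. A minimal sketch of that idea follows, assuming each graph id maps to one materials id; `split_keys` and the example ids are hypothetical and only illustrate the grouping.

```python
import random


def split_keys(material_ids, train_ratio=0.8, val_ratio=0.1, seed=0):
    """Hypothetical helper: split unique material ids, not individual frames."""
    unique = sorted(set(material_ids))
    random.Random(seed).shuffle(unique)
    n_train = int(len(unique) * train_ratio)
    n_val = int(len(unique) * val_ratio)
    return (
        set(unique[:n_train]),
        set(unique[n_train:n_train + n_val]),
        set(unique[n_train + n_val:]),
    )


# Made-up data: 10 materials with 3 frames (graphs) each.
frame_to_material = {f"mp-{m}-{f}": f"mp-{m}" for m in range(10) for f in range(3)}
train_keys, val_keys, test_keys = split_keys(frame_to_material.values())

# Every frame follows its material into the same split.
train_frames = [g for g, m in frame_to_material.items() if m in train_keys]
```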

class StructureJsonData

Read structure and targets from a JSON file. This class is used to load the MPtrj dataset.

method __init__

    data: 'str | dict',
    graph_converter: 'CrystalGraphConverter',
    targets: 'TrainTask' = 'efsm',
    energy_key: 'str' = 'energy_per_atom',
    force_key: 'str' = 'force',
    stress_key: 'str' = 'stress',
    magmom_key: 'str' = 'magmom',
    shuffle: 'bool' = True

Initialize the dataset by reading JSON files.


method get_train_val_test_loader

    train_ratio: 'float' = 0.8,
    val_ratio: 'float' = 0.1,
    train_key: 'list[str] | None' = None,
    val_key: 'list[str] | None' = None,
    test_key: 'list[str] | None' = None,
) → tuple[DataLoader, DataLoader, DataLoader]

Partition the Dataset by materials id. Either randomly select the train_keys, val_keys, and test_keys according to the train/val/test ratios, or use pre-defined train_keys, val_keys, and test_keys to create the train, val, and test loaders.


Returns: train_loader, val_loader, test_loader

module graph.converter

class CrystalGraphConverter

Convert a pymatgen.core.Structure to a CrystalGraph. The CrystalGraph dataclass stores the essential fields so that gradients such as force and stress can be calculated through back-propagation later.

method __init__

    atom_graph_cutoff: 'float' = 6,
    bond_graph_cutoff: 'float' = 3,
    algorithm: "Literal['legacy', 'fast']" = 'fast',
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error',
    verbose: 'bool' = False

Initialize the Crystal Graph Converter.


method as_dict

as_dict() → dict[str, str | float]

Save the args of the graph converter.

method forward

forward(structure: 'Structure', graph_id=None, mp_id=None) → CrystalGraph

Convert a structure, return a CrystalGraph.


Return: CrystalGraph that is ready to use by CHGNet

classmethod from_dict

from_dict(dct: 'dict') → Self

Create converter from dictionary.

method set_isolated_atom_response

    on_isolated_atoms: "Literal['ignore', 'warn', 'error']"

Set the graph converter’s response to an isolated atom graph.


Returns: None

module graph.crystalgraph

class CrystalGraph

A data class for crystal graph.

method __init__

    atomic_number: 'Tensor',
    atom_frac_coord: 'Tensor',
    atom_graph: 'Tensor',
    atom_graph_cutoff: 'float',
    neighbor_image: 'Tensor',
    directed2undirected: 'Tensor',
    undirected2directed: 'Tensor',
    bond_graph: 'Tensor',
    bond_graph_cutoff: 'float',
    lattice: 'Tensor',
    graph_id: 'str | None' = None,
    mp_id: 'str | None' = None,
    composition: 'str | None' = None

Initialize the crystal graph.

Attention! This data class is not intended to be created manually. A CrystalGraph should be returned by a CrystalGraphConverter.



property num_isolated_atoms

Number of isolated atoms given the atom graph cutoff. Isolated atoms are disconnected nodes in the atom graph that will not get updated in CHGNet. These atoms will always have a calculated force equal to zero.

With the default CHGNet atom graph cutoff radius, only ~0.1% of MPtrj dataset structures have isolated atoms.

classmethod from_dict

from_dict(dic: 'dict[str, Any]') → Self

Load a CrystalGraph from a dictionary.

classmethod from_file

from_file(file_name: 'str') → Self

Load a crystal graph from a file.



method save

save(fname: 'str | None' = None, save_dir: 'str' = '.') → str

Save the graph to a file.



method to

to(device: 'str' = 'cpu') → CrystalGraph

Move the graph to a device. Default = ‘cpu’.

method to_dict

to_dict() → dict[str, Any]

Convert the graph to a dictionary.

module graph.graph

class Node

A node in a graph.

method __init__

__init__(index: 'int', info: 'dict | None' = None) → None

Initialize a Node.


method add_neighbor

add_neighbor(index, edge) → None

Draw a directed edge between self and the node specified by index.


class Edge

Abstract base class for edges in a graph.

method __init__

    nodes: 'list',
    index: 'int | None' = None,
    info: 'dict | None' = None

Initialize an Edge.

class UndirectedEdge

An undirected/bi-directed edge in a graph.

method __init__

    nodes: 'list',
    index: 'int | None' = None,
    info: 'dict | None' = None

Initialize an Edge.

class DirectedEdge

A directed edge in a graph.

method __init__

    nodes: 'list',
    index: 'int | None' = None,
    info: 'dict | None' = None

Initialize an Edge.

method make_undirected

make_undirected(index: 'int', info: 'dict | None' = None) → UndirectedEdge

Make a directed edge undirected.

class Graph

A graph for storing the neighbor information of atoms.

method __init__

__init__(nodes: 'list[Node]') → None

Initialize a Graph from a list of nodes.

method add_edge

    dist_tol: 'float' = 1e-06

Add a directed edge to the graph.


method adjacency_list

adjacency_list() → tuple[list[list[int]], list[int]]

Get the adjacency list.

Return:
graph: the adjacency list [[0, 1], [0, 2], … [5, 2], … ]. The first column specifies the center/source node; the second column specifies the neighbor/destination node.
directed2undirected: [0, 1, …], a list of length num_directed_edge that specifies the undirected edge index corresponding to the directed edge represented in each row of the graph adjacency list.
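To make the two return values concrete, the sketch below builds them from a plain list of directed (center, neighbor) pairs. It is an illustration only: `build_adjacency` is a hypothetical helper, and it identifies an undirected edge by its unordered node pair, ignoring the periodic images that a real crystal graph has to distinguish.

```python
def build_adjacency(directed_edges):
    """Hypothetical helper: adjacency list plus directed-to-undirected map."""
    graph = []
    directed2undirected = []
    undirected_index = {}  # unordered node pair -> undirected edge index
    for center, neighbor in directed_edges:
        key = (min(center, neighbor), max(center, neighbor))
        if key not in undirected_index:
            undirected_index[key] = len(undirected_index)
        graph.append([center, neighbor])
        directed2undirected.append(undirected_index[key])
    return graph, directed2undirected


graph, d2u = build_adjacency([(0, 1), (1, 0), (0, 2), (2, 0)])
```

Here the two directions of each bond share one undirected index, so `d2u` has one entry per row of `graph`.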

method as_dict


Return dictionary serialization of a Graph.

method line_graph_adjacency_list

line_graph_adjacency_list(cutoff) → tuple[list[list[int]], list[int]]

Get the line graph adjacency list.


Return:
line_graph: [[0, 1, 1, 2, 2], [0, 1, 1, 4, 23], [1, 4, 23, 5, 66], … ]. The first column specifies the node (atom) index at this angle, the second column specifies the 1st undirected edge (left bond) index, the third column specifies the 1st directed edge (left bond) index, the fourth column specifies the 2nd undirected edge (right bond) index, and the fifth column specifies the 2nd directed edge (right bond) index.
undirected2directed: [32, 45, …], a list of length num_undirected_edge that maps each undirected edge index to one of its directed edge indices.

method to


Save graph dictionary to file.

The index map from undirected_edge index to one of its directed_edge index.

Fourier Expansion for angle features.

__init__(order: 'int' = 5, learnable: 'bool' = False) → None

Initialize the Fourier expansion.


method forward

forward(x: 'Tensor') → Tensor

Apply Fourier expansion to a feature Tensor.
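As a rough picture of what a Fourier expansion of an angle feature looks like, the sketch below maps one angle to a constant term plus cosine and sine harmonics up to `order`. The function name, the ordering of the terms, and the absence of any normalization are assumptions made for illustration; the library's layout and scaling may differ.

```python
import math


def fourier_features(angle, order=5):
    """Hypothetical sketch: [1, cos(n*x) for n=1..order, sin(n*x) for n=1..order]."""
    cos_terms = [math.cos(n * angle) for n in range(1, order + 1)]
    sin_terms = [math.sin(n * angle) for n in range(1, order + 1)]
    return [1.0] + cos_terms + sin_terms


feats = fourier_features(0.0)
```

Under these assumptions an expansion of order 5 yields 2 * 5 + 1 = 11 features per angle.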

class RadialBessel

1D Bessel Basis from: https://github.com/TUM-DAML/gemnet_pytorch/.

method __init__

    num_radial: 'int' = 9,
    cutoff: 'float' = 5,
    learnable: 'bool' = False,
    smooth_cutoff: 'int' = 5

Initialize the SmoothRBF function.


method forward

    dist: 'Tensor',
    return_smooth_factor: 'bool' = False
) → Tensor | tuple[Tensor, Tensor]

Apply Bessel expansion to a feature Tensor.
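For orientation, a commonly used 1D radial Bessel basis has the form sqrt(2 / c) * sin(n * pi * d / c) / d for n = 1 … num_radial, where c is the cutoff. The sketch below evaluates that formula for a single distance; it is an assumption that this is the normalization used here, and the smooth cutoff envelope is left out.

```python
import math


def radial_bessel(dist, num_radial=9, cutoff=5.0):
    """Hypothetical sketch of a 1D Bessel basis; dist must be > 0."""
    norm = math.sqrt(2.0 / cutoff)
    return [
        norm * math.sin(n * math.pi * dist / cutoff) / dist
        for n in range(1, num_radial + 1)
    ]


mid_values = radial_bessel(2.5)
edge_values = radial_bessel(5.0)  # every basis function vanishes at the cutoff
```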



class GaussianExpansion

Expands the distance by Gaussian basis. Unit: angstrom.

method __init__

    min: 'float' = 0,
    max: 'float' = 5,
    step: 'float' = 0.5,
    var: 'float | None' = None

Gaussian Expansion expands a scalar feature into a soft-one-hot feature vector.


method expand

expand(features: 'Tensor') → Tensor

Apply Gaussian filter to a feature Tensor.
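The "soft-one-hot" idea can be sketched as follows: place Gaussian centers on a grid from min to max and evaluate each Gaussian at the input value. The width convention (dividing by the squared width, with the width defaulting to the step) and the inclusive end point are assumptions for this illustration, and `gaussian_expand` is a hypothetical name.

```python
import math


def gaussian_expand(value, min_=0.0, max_=5.0, step=0.5, var=None):
    """Hypothetical sketch: soft one-hot encoding of a scalar distance."""
    n_centers = int(round((max_ - min_) / step)) + 1  # grid includes both ends
    centers = [min_ + i * step for i in range(n_centers)]
    width = step if var is None else var  # assumed default width
    return [math.exp(-((value - c) ** 2) / width**2) for c in centers]


features = gaussian_expand(1.0)
```

The feature closest to the input value is the largest, and neighboring features decay smoothly, which is what makes the encoding "soft".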



class CutoffPolynomial

Polynomial soft-cutoff function for the atom graph. Ref: https://github.com/TUM-DAML/gemnet_pytorch/blob/-/gemnet/model/layers/envelope.py.

method __init__

__init__(cutoff: 'float' = 5, cutoff_coeff: 'float' = 5) → None

Initialize the polynomial cutoff function.


method forward

forward(r: 'Tensor') → Tensor

Polynomial cutoff function.
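A polynomial envelope of the kind used in the referenced GemNet code decays smoothly from 1 at r = 0 to 0 at the cutoff. The sketch below writes out that polynomial for a single distance, assuming `cutoff_coeff` plays the role of the exponent p; treat it as an illustration rather than the exact implementation.

```python
def cutoff_polynomial(r, cutoff=5.0, p=5):
    """Hypothetical sketch of a GemNet-style polynomial envelope."""
    x = r / cutoff
    if x >= 1.0:
        return 0.0  # no contribution beyond the cutoff radius
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * x**p + b * x ** (p + 1) + c * x ** (p + 2)


at_zero = cutoff_polynomial(0.0)
at_half = cutoff_polynomial(2.5)
at_cutoff = cutoff_polynomial(5.0)
```

The coefficients are chosen so that the value and its first two derivatives vanish at the cutoff, which keeps forces continuous when a neighbor crosses the cutoff radius.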



module model.composition_model

class CompositionModel

A simple FC model that takes in a chemical composition (no structure info) and outputs energy.

method __init__

    atom_fea_dim: 'int' = 64,
    activation: 'str' = 'silu',
    is_intensive: 'bool' = True,
    max_num_elements: 'int' = 94

Initialize a CompositionModel.

method forward

forward(graphs: 'list[CrystalGraph]') → Tensor

Get the energy of a list of CrystalGraphs as Tensor.

class AtomRef

A linear regression for elemental energy. From: https://github.com/materialsvirtuallab/m3gnet/.

method __init__

__init__(is_intensive: 'bool' = True, max_num_elements: 'int' = 94) → None

Initialize an AtomRef model.

method fit

    structures_or_graphs: 'Sequence[Structure | CrystalGraph]',
    energies: 'Sequence[float]'

Fit the model to a list of crystals and energies.


method forward

forward(graphs: 'list[CrystalGraph]') → Tensor

Get the energy of a list of CrystalGraphs.


Returns: energy (tensor)
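A linear elemental-energy model predicts the energy of a structure as the sum of one fitted reference energy per atom, divided by the atom count when the target is intensive. The sketch below shows that arithmetic; the helper name and the reference energies are made up for illustration and are not fitted values.

```python
def reference_energy(atomic_numbers, element_energy, is_intensive=True):
    """Hypothetical sketch: sum per-element reference energies."""
    total = sum(element_energy[z] for z in atomic_numbers)
    # Intensive targets are reported per atom.
    return total / len(atomic_numbers) if is_intensive else total


# Made-up reference energies (eV/atom) keyed by atomic number.
element_energy = {3: -1.9, 8: -4.9}
li2o = [3, 3, 8]

per_atom = reference_energy(li2o, element_energy)
total = reference_energy(li2o, element_energy, is_intensive=False)
```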

method get_site_energies

get_site_energies(graphs: 'list[CrystalGraph]') → list[Tensor]

Predict the site energies given a list of CrystalGraphs.


Returns: a list of tensors corresponding to site energies of each graph [batchsize].

method initialize_from

initialize_from(dataset: 'str') → None

Initialize pre-fitted weights from a dataset.

method initialize_from_MPF


Initialize pre-fitted weights from MPF dataset.

method initialize_from_MPtrj


Initialize pre-fitted weights from MPtrj dataset.

method initialize_from_numpy

initialize_from_numpy(file_name: 'str | Path') → None

Initialize pre-fitted weights from numpy file.

module model.dynamics

class CHGNetCalculator

CHGNet Calculator for ASE applications.

method __init__

    model: 'CHGNet | None' = None,
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False,
    stress_weight: 'float | None' = 0.006241509125883258,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn',

Provide a CHGNet instance to calculate various atomic properties using ASE.


property directory

property label

property n_params

The number of parameters in CHGNet.

property name

property version

The version of CHGNet.

method calculate

    atoms: 'Atoms | None' = None,
    properties: 'list | None' = None,
    system_changes: 'list | None' = None

Calculate various properties of the atoms using CHGNet.


classmethod from_file

from_file(path: 'str', use_device: 'str | None' = None, **kwargs) → Self

Load a user’s CHGNet model and initialize the Calculator.

class StructOptimizer

Wrapper class for structural relaxation.

method __init__

    model: 'CHGNet | CHGNetCalculator | None' = None,
    optimizer_class: 'Optimizer | str | None' = 'FIRE',
    use_device: 'str | None' = None,
    stress_weight: 'float' = 0.006241509125883258,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn'

Provide a trained CHGNet model and an optimizer to relax crystal structures.


property n_params

The number of parameters in CHGNet.

property version

The version of CHGNet.

method relax

    atoms: 'Structure | Atoms',
    fmax: 'float | None' = 0.1,
    steps: 'int | None' = 500,
    relax_cell: 'bool | None' = True,
    ase_filter: 'str | None' = 'FrechetCellFilter',
    save_path: 'str | None' = None,
    loginterval: 'int | None' = 1,
    crystal_feas_save_path: 'str | None' = None,
    verbose: 'bool' = True,
    assign_magmoms: 'bool' = True,
) → dict[str, Structure | TrajectoryObserver]

Relax the Structure/Atoms until maximum force is smaller than fmax.


Returns: dict[str, Structure | TrajectoryObserver]: A dictionary with ‘final_structure’ and ‘trajectory’.

class TrajectoryObserver

Trajectory observer is a hook in the relaxation process that saves the intermediate structures.

method __init__

__init__(atoms: 'Atoms') → None

Create a TrajectoryObserver from an Atoms object.


method compute_energy


Calculate the potential energy.


method save

save(filename: 'str') → None

Save the trajectory to file.


class CrystalFeasObserver

CrystalFeasObserver is a hook in the relaxation and MD process that saves the intermediate crystal feature structures.

method __init__

__init__(atoms: 'Atoms') → None

Create a CrystalFeasObserver from an Atoms object.

method save

save(filename: 'str') → None

Save the crystal feature vectors to filename in pickle format.

class MolecularDynamics

Molecular dynamics class.

method __init__

    atoms: 'Atoms | Structure',
    model: 'CHGNet | CHGNetCalculator | None' = None,
    ensemble: 'str' = 'nvt',
    thermostat: 'str' = 'Berendsen_inhomogeneous',
    temperature: 'int' = 300,
    starting_temperature: 'int | None' = None,
    timestep: 'float' = 2.0,
    pressure: 'float' = 0.000101325,
    taut: 'float | None' = None,
    taup: 'float | None' = None,
    bulk_modulus: 'float | None' = None,
    trajectory: 'str | Trajectory | None' = None,
    logfile: 'str | None' = None,
    loginterval: 'int' = 1,
    crystal_feas_logfile: 'str | None' = None,
    append_trajectory: 'bool' = False,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn',
    use_device: 'str | None' = None

Initialize the MD class.


In the NPT ensemble, the effective damping time for pressure is multiplied by the compressibility. In LAMMPS, the bulk modulus defaults to 10.

If the bulk modulus is not provided here, CHGNet calculates it by fitting the Birch-Murnaghan equation of state (EOS). Note that the EOS fit can fail on a non-parabolic potential energy surface, which is common for soft systems such as liquids and gases. In that case, provide an input bulk modulus for better barostat coupling; otherwise a guessed bulk modulus of 2 GPa (the bulk modulus of water) is used.

Default = None

method run

run(steps: 'int') → None

Thin wrapper of ase MD run.


method set_atoms

set_atoms(atoms: 'Atoms') → None

Set new atoms to run MD.


method upper_triangular_cell

upper_triangular_cell(verbose: 'bool | None' = False) → None

Transform to an upper-triangular cell. The ASE Nose-Hoover implementation only supports upper-triangular cells, while ASE’s canonical description is a lower-triangular cell.


class EquationOfState

Class to calculate equation of state.

method __init__

    model: 'CHGNet | CHGNetCalculator | None' = None,
    optimizer_class: 'Optimizer | str | None' = 'FIRE',
    use_device: 'str | None' = None,
    stress_weight: 'float' = 0.006241509125883258,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error'

Initialize a structure optimizer object for calculation of bulk modulus.


method fit

    atoms: 'Structure | Atoms',
    n_points: 'int' = 11,
    fmax: 'float | None' = 0.1,
    steps: 'int | None' = 500,
    verbose: 'bool | None' = False,

Relax the Structure/Atoms and fit the Birch-Murnaghan equation of state.


Returns: Bulk Modulus (float)

method get_bulk_modulus

get_bulk_modulus(unit: 'str' = 'eV/A^3') → float

Get the bulk modulus from the fitted Birch-Murnaghan equation of state.


Returns: Bulk Modulus (float)
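The default unit here is eV/A^3, whereas bulk moduli are usually quoted in GPa. The sketch below shows the conversion, using 1 eV/Å^3 = 160.21766208 GPa, together with the compressibility as the reciprocal of the bulk modulus. The helper names and the example value are hypothetical, and which unit strings the method accepts is not stated here. The default `stress_weight` of 0.006241509125883258 in the signatures above appears to be the reciprocal of the same factor.

```python
EV_PER_A3_IN_GPA = 160.21766208  # 1 eV/Å^3 expressed in GPa


def bulk_modulus_to_gpa(bulk_modulus_ev_a3):
    """Hypothetical helper: convert a bulk modulus from eV/Å^3 to GPa."""
    return bulk_modulus_ev_a3 * EV_PER_A3_IN_GPA


def compressibility(bulk_modulus):
    """Compressibility is the reciprocal of the bulk modulus."""
    return 1.0 / bulk_modulus


b_ev = 0.5  # made-up bulk modulus in eV/Å^3
b_gpa = bulk_modulus_to_gpa(b_ev)
kappa = compressibility(b_ev)  # in Å^3/eV
```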

method get_compressibility

get_compressibility(unit: 'str' = 'A^3/eV') → float

Get the compressibility from the fitted Birch-Murnaghan equation of state.


Returns: Compressibility (float)

module model.encoders

class AtomEmbedding

Encode an atom by its atomic number using a learnable embedding layer.

method __init__

__init__(atom_feature_dim: 'int', max_num_elements: 'int' = 94) → None

Initialize the Atom featurizer.


method forward

forward(atomic_numbers: 'Tensor') → Tensor

Convert the structure to an atom embedding tensor.



class BondEncoder

Encode a chemical bond given the positions of two atoms using Gaussian distance.

method __init__

    atom_graph_cutoff: 'float' = 5,
    bond_graph_cutoff: 'float' = 3,
    num_radial: 'int' = 9,
    cutoff_coeff: 'int' = 5,
    learnable: 'bool' = False

method forward

    center: 'Tensor',
    neighbor: 'Tensor',
    undirected2directed: 'Tensor',
    image: 'Tensor',
    lattice: 'Tensor'
) → tuple[Tensor, Tensor, Tensor]

class AngleEncoder

Encode an angle given the two bond vectors using Fourier Expansion.

method __init__

__init__(num_angular: 'int' = 9, learnable: 'bool' = True) → None

method forward

forward(bond_i: 'Tensor', bond_j: 'Tensor') → Tensor

module model.functions

function aggregate

    data: 'Tensor',
    owners: 'Tensor',
) → Tensor

function find_activation

find_activation(name: 'str') → Module

function find_normalization

find_normalization(name: 'str', dim: 'int | None' = None) → Module | None

class MLP

Multi-Layer Perceptron used for non-linear regression.

method __init__

    input_dim: 'int',
    output_dim: 'int' = 1,
    hidden_dim: 'int | Sequence[int] | None' = (64, 64),
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    bias: 'bool' = True

method forward

forward(x: 'Tensor') → Tensor

class GatedMLP

Gated MLP. A similar model structure is used in CGCNN and M3GNet.

method __init__

    input_dim: 'int',
    output_dim: 'int',
    hidden_dim: 'int | list[int] | None' = None,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str' = 'batch',
    bias: 'bool' = True

method forward

forward(x: 'Tensor') → Tensor

class ScaledSiLU

Scaled Sigmoid Linear Unit.

method __init__


Initialize a scaled SiLU.

method forward

forward(x: 'Tensor') → Tensor

Forward pass.

module model.layers

class AtomConv

A convolution Layer to update atom features.

method __init__

    atom_fea_dim: 'int',
    bond_fea_dim: 'int',
    hidden_dim: 'int' = 64,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str | None' = None,
    use_mlp_out: 'bool' = True,
    mlp_out_bias: 'bool' = False,
    resnet: 'bool' = True,
    gMLP_norm: 'str | None' = None

method forward

    atom_feas: 'Tensor',
    bond_feas: 'Tensor',
    bond_weights: 'Tensor',
    atom_graph: 'Tensor',
    directed2undirected: 'Tensor'
) → Tensor

Forward pass of AtomConv module that updates the atom features and optionally bond features.




  • num_batch_atoms = sum(num_atoms) in batch

class BondConv

A convolution Layer to update bond features.

method __init__

    atom_fea_dim: 'int',
    bond_fea_dim: 'int',
    angle_fea_dim: 'int',
    hidden_dim: 'int' = 64,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str | None' = None,
    use_mlp_out: 'bool' = True,
    mlp_out_bias: 'bool' = False,
    gMLP_norm: 'str | None' = None

method forward

    atom_feas: 'Tensor',
    bond_feas: 'Tensor',
    bond_weights: 'Tensor',
    angle_feas: 'Tensor',
    bond_graph: 'Tensor'
) → Tensor

Update the bond features.




  • num_batch_atoms = sum(num_atoms) in batch

class AngleUpdate

Update angle features.

method __init__

    atom_fea_dim: 'int',
    bond_fea_dim: 'int',
    angle_fea_dim: 'int',
    hidden_dim: 'int' = 0,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str | None' = None,
    resnet: 'bool' = True,
    gMLP_norm: 'str | None' = None

method forward

    atom_feas: 'Tensor',
    bond_feas: 'Tensor',
    angle_feas: 'Tensor',
    bond_graph: 'Tensor'
) → Tensor

Update the angle features using bond graph.




  • num_batch_atoms = sum(num_atoms) in batch

class GraphPooling

Pooling the sub-graphs in the batched graph.

method __init__

__init__(average: 'bool' = False) → None

Args: average (bool): whether to average the features.

method forward

forward(atom_feas: 'Tensor', atom_owner: 'Tensor') → Tensor

Merge the atom features that belong to same graph in a batched graph.
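Pooling by owner means that every atom carries the index of the graph it belongs to, and the features of atoms with the same owner are summed (or averaged). The sketch below shows this with plain lists instead of tensors; `pool_by_owner` is a hypothetical helper and assumes owner indices run from 0 to the number of graphs minus one.

```python
def pool_by_owner(atom_feas, atom_owner, average=False):
    """Hypothetical sketch: sum (or average) atom features per graph."""
    n_graphs = max(atom_owner) + 1
    dim = len(atom_feas[0])
    pooled = [[0.0] * dim for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for fea, owner in zip(atom_feas, atom_owner):
        counts[owner] += 1
        for k, value in enumerate(fea):
            pooled[owner][k] += value
    if average:
        # Guard against an owner index that has no atoms.
        pooled = [[v / max(counts[g], 1) for v in row] for g, row in enumerate(pooled)]
    return pooled


atom_feas = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
atom_owner = [0, 0, 1]  # the first two atoms belong to graph 0
summed = pool_by_owner(atom_feas, atom_owner)
averaged = pool_by_owner(atom_feas, atom_owner, average=True)
```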



class GraphAttentionReadOut

Multi-head attention readout layer that merges the information from atom_feas into crystal_fea.

method __init__

    atom_fea_dim: 'int',
    num_head: 'int' = 3,
    hidden_dim: 'int' = 32,

method forward

forward(atom_feas: 'Tensor', atom_owner: 'Tensor') → Tensor

module model.model

class CHGNet

Crystal Hamiltonian Graph neural Network. A model that takes in a crystal graph and outputs energy, force, magmom, and stress.

method __init__

    atom_fea_dim: 'int' = 64,
    bond_fea_dim: 'int' = 64,
    angle_fea_dim: 'int' = 64,
    composition_model: 'str | Module' = 'MPtrj',
    num_radial: 'int' = 31,
    num_angular: 'int' = 31,
    n_conv: 'int' = 4,
    atom_conv_hidden_dim: 'Sequence[int] | int' = 64,
    update_bond: 'bool' = True,
    bond_conv_hidden_dim: 'Sequence[int] | int' = 64,
    update_angle: 'bool' = True,
    angle_layer_hidden_dim: 'Sequence[int] | int' = 0,
    conv_dropout: 'float' = 0,
    read_out: 'str' = 'ave',
    mlp_hidden_dims: 'Sequence[int] | int' = (64, 64, 64),
    mlp_dropout: 'float' = 0,
    mlp_first: 'bool' = True,
    is_intensive: 'bool' = True,
    non_linearity: "Literal['silu', 'relu', 'tanh', 'gelu']" = 'silu',
    atom_graph_cutoff: 'float' = 6,
    bond_graph_cutoff: 'float' = 3,
    graph_converter_algorithm: "Literal['legacy', 'fast']" = 'fast',
    cutoff_coeff: 'int' = 8,
    learnable_rbf: 'bool' = True,
    gMLP_norm: 'str | None' = 'layer',
    readout_norm: 'str | None' = 'layer',
    version: 'str | None' = None,

Initialize CHGNet.


property n_params

Return the number of parameters in the model.

property version

Return the version of the loaded checkpoint.

method as_dict


Return the CHGNet weights and args in a dictionary.

method forward

    graphs: 'Sequence[CrystalGraph]',
    task: 'PredTask' = 'e',
    return_site_energies: 'bool' = False,
    return_atom_feas: 'bool' = False,
    return_crystal_feas: 'bool' = False
) → dict[str, Tensor]

Get the predictions associated with the input graphs.


Returns: model output (dict).

classmethod from_dict

from_dict(dct: 'dict', **kwargs) → Self

Build a CHGNet from a saved dictionary.

classmethod from_file

from_file(path: 'str', **kwargs) → Self

Build a CHGNet from a saved file.

classmethod load

    model_name: 'str' = '0.3.0',
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False,
    verbose: 'bool' = True
) → Self

Load pretrained CHGNet model.

Args: model_name (str, optional): Default = “0.3.0”.


method predict_graph

    graph: 'CrystalGraph | Sequence[CrystalGraph]',
    task: 'PredTask' = 'efsm',
    return_site_energies: 'bool' = False,
    return_atom_feas: 'bool' = False,
    return_crystal_feas: 'bool' = False,
    batch_size: 'int' = 16
) → dict[str, Tensor] | list[dict[str, Tensor]]

Predict from CrystalGraph.



method predict_structure

    structure: 'Structure | Sequence[Structure]',
    task: 'PredTask' = 'efsm',
    return_site_energies: 'bool' = False,
    return_atom_feas: 'bool' = False,
    return_crystal_feas: 'bool' = False,
    batch_size: 'int' = 16
) → dict[str, Tensor] | list[dict[str, Tensor]]

Predict from pymatgen.core.Structure.



method todict


Needed for ASE JSON serialization when saving CHGNet potential to trajectory file (https://github.com/CederGroupHub/chgnet/issues/48).

class BatchedGraph

Batched crystal graph for parallel computing.


method __init__

    atomic_numbers: 'Tensor',
    bond_bases_ag: 'Tensor',
    bond_bases_bg: 'Tensor',
    angle_bases: 'Tensor',
    batched_atom_graph: 'Tensor',
    batched_bond_graph: 'Tensor',
    atom_owners: 'Tensor',
    directed2undirected: 'Tensor',
    atom_positions: 'Sequence[Tensor]',
    strains: 'Sequence[Tensor]',
    volumes: 'Sequence[Tensor] | Tensor'

classmethod from_graphs

    graphs: 'Sequence[CrystalGraph]',
    bond_basis_expansion: 'Module',
    angle_basis_expansion: 'Module',
    compute_stress: 'bool' = False
) → Self

module trainer.trainer

class Trainer

A trainer to train CHGNet using energy, force, stress and magmom.

method __init__

    model: 'CHGNet | None' = None,
    targets: 'TrainTask' = 'ef',
    energy_loss_ratio: 'float' = 1,
    force_loss_ratio: 'float' = 1,
    stress_loss_ratio: 'float' = 0.1,
    mag_loss_ratio: 'float' = 0.1,
    optimizer: 'str' = 'Adam',
    scheduler: 'str' = 'CosLR',
    criterion: 'str' = 'MSE',
    epochs: 'int' = 50,
    starting_epoch: 'int' = 0,
    learning_rate: 'float' = 0.001,
    print_freq: 'int' = 100,
    torch_seed: 'int | None' = None,
    data_seed: 'int | None' = None,
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False,
    wandb_path: 'str | None' = None,
    wandb_init_kwargs: 'dict | None' = None,
    extra_run_config: 'dict | None' = None,

Initialize all hyper-parameters for trainer.


method get_best_model

get_best_model() → CHGNet

Get best model recorded in the trainer.

classmethod load

load(path: 'str') → Self

Load trainer state_dict.

method move_to

move_to(obj, device) → Tensor | list[Tensor]

Move object to device.

method save

save(filename: 'str' = 'training_result.pth.tar') → None

Save the model, graph_converter, etc.

method save_checkpoint

save_checkpoint(epoch: 'int', mae_error: 'dict', save_dir: 'str') → None

Function to save CHGNet trained weights after each epoch.


method train

    train_loader: 'DataLoader',
    val_loader: 'DataLoader',
    test_loader: 'DataLoader | None' = None,
    save_dir: 'str | None' = None,
    save_test_result: 'bool' = False,
    train_composition_model: 'bool' = False,
    wandb_log_freq: 'LogFreq' = 'batch'

Train the model using torch data_loaders.


class CombinedLoss

A combined loss function of energy, force, stress and magmom.

method __init__

    target_str: 'str' = 'ef',
    criterion: 'str' = 'MSE',
    is_intensive: 'bool' = True,
    energy_loss_ratio: 'float' = 1,
    force_loss_ratio: 'float' = 1,
    stress_loss_ratio: 'float' = 0.1,
    mag_loss_ratio: 'float' = 0.1,
    delta: 'float' = 0.1

method forward

    targets: 'dict[str, Tensor]',
    prediction: 'dict[str, Tensor]'
) → dict[str, Tensor]

Compute the combined loss using the CHGNet predictions and labels. This function can automatically mask out the magmom loss contribution of data points without magmom labels.


Returns: dictionary of all the loss, MAE and MAE_size
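The combination can be pictured as a weighted sum of per-target losses in which entries without a label are skipped. The sketch below uses plain lists, MSE as the criterion, the single-letter keys "e", "f", "m", and `None` as the marker for a missing magmom label; all of these are assumptions for illustration, not the library's data layout.

```python
def mse(pred, target):
    """Mean squared error over paired values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def combined_loss(targets, prediction, ratios):
    """Hypothetical sketch: weighted sum of per-target MSE, masking missing labels."""
    total = 0.0
    for key, ratio in ratios.items():
        # Keep only the entries that actually have a label.
        pairs = [(p, t) for p, t in zip(prediction[key], targets[key]) if t is not None]
        if not pairs:
            continue
        preds, labels = zip(*pairs)
        total += ratio * mse(preds, labels)
    return total


ratios = {"e": 1.0, "f": 1.0, "m": 0.1}
targets = {"e": [1.0], "f": [0.0, 0.5], "m": [None, 2.0]}  # first magmom unlabeled
prediction = {"e": [1.5], "f": [0.0, 0.0], "m": [1.0, 1.0]}
loss = combined_loss(targets, prediction, ratios)
```

In this made-up example the unlabeled magmom entry contributes nothing, so the total is 1.0 * 0.25 + 1.0 * 0.125 + 0.1 * 1.0.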

module utils.common_utils

function determine_device

    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False

function cuda_devices_sorted_by_free_mem


List available CUDA devices sorted by increasing available memory.

To get the device with the most free memory, use the last list item.

function mae

mae(prediction: 'Tensor', target: 'Tensor') → Tensor

Computes the mean absolute error between prediction and target.


Returns: tensor
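For reference, the mean absolute error is the average of |prediction - target| over all entries. A plain-Python sketch on lists rather than tensors (the function name is chosen here to avoid clashing with the documented `mae`):

```python
def mean_absolute_error(prediction, target):
    """Average absolute difference between paired values."""
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(prediction)


error = mean_absolute_error([1.0, 2.0, 4.0], [1.0, 3.0, 2.0])
```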

function read_json

read_json(filepath: 'str') → dict

function write_json

write_json(dct: 'dict', filepath: 'str') → dict

Write the json file.


Returns: written dictionary

function mkdir

mkdir(path: 'str') → str

Make directory.


Returns: path

class AverageMeter

Computes and stores the average and current value.

method __init__


Initialize the meter.

method reset


Reset the meter value, average, sum and count to 0.

method update

update(val: 'float', n: 'int' = 1) → None

Update the meter value, average, sum and count.
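A meter of this kind typically keeps the latest value, a weighted running sum, and a count, and recomputes the average on every update. The class below is a minimal sketch of that pattern under a different name (`RunningAverage`); the attribute names are assumptions and may not match the library's.

```python
class RunningAverage:
    """Hypothetical sketch of an average meter."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Current value, running average, weighted sum, and sample count.
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # n is the number of samples that val was averaged over.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = RunningAverage()
meter.update(2.0, n=2)  # e.g. a batch of 2 samples with mean loss 2.0
meter.update(5.0)       # a single sample with loss 5.0
```

Passing the batch size as `n` keeps the average correct when batches have different sizes.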


module utils.vasp_utils

function parse_vasp_dir

    base_dir: 'str',
    check_electronic_convergence: 'bool' = True,
    save_path: 'str | None' = None
) → dict[str, list]

Parse VASP output files into structures and labels. By default, the magnetization is read from mag_x from VASP; please modify the code if the magnetization is for (y) and (z).


function solve_charge_by_mag

    structure: 'Structure',
    default_ox: 'dict[str, float] | None' = None,
    ox_ranges: 'dict[str, dict[tuple[float, float], int]] | None' = None
) → Structure | None

Solve oxidation states by magmom.
