API

module data.dataset


function collate_graphs

collate_graphs(
    batch_data: 'list'
) → tuple[list[CrystalGraph], dict[str, Tensor]]

Collate a list of (graph, target) pairs into batch data.

Args:

Returns:


function get_train_val_test_loader

get_train_val_test_loader(
    dataset: 'Dataset',
    batch_size: 'int' = 64,
    train_ratio: 'float' = 0.8,
    val_ratio: 'float' = 0.1,
    return_test: 'bool' = True,
    num_workers: 'int' = 0,
    pin_memory: 'bool' = True
) → tuple[DataLoader, DataLoader, DataLoader]

Randomly partition a dataset into train, val, test loaders.

Args:

Returns: train_loader, val_loader and optionally test_loader
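The random partition can be sketched with index shuffling, assuming the train and val sizes are truncated to integers and everything left over becomes the test set. split_indices is a hypothetical helper; the exact rounding and seeding inside get_train_val_test_loader may differ.

```python
import random


def split_indices(n: int, train_ratio: float = 0.8, val_ratio: float = 0.1, seed: int = 0):
    """Randomly partition range(n) into train, val and test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded shuffle for reproducibility
    n_train = int(train_ratio * n)
    n_val = int(val_ratio * n)
    train = indices[:n_train]
    val = indices[n_train : n_train + n_val]
    test = indices[n_train + n_val :]  # the remainder becomes the test set
    return train, val, test


train, val, test = split_indices(100)
print(len(train), len(val), len(test))  # 80 10 10
```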


function get_loader

get_loader(
    dataset,
    batch_size: 'int' = 64,
    num_workers: 'int' = 0,
    pin_memory: 'bool' = True
) → DataLoader

Get a dataloader from a dataset.

Args:

Returns: data_loader


class StructureData

A simple torch Dataset of structures.

method __init__

__init__(
    structures: 'list[Structure]',
    energies: 'list[float]',
    forces: 'list[Sequence[Sequence[float]]]',
    stresses: 'list[Sequence[Sequence[float]]] | None' = None,
    magmoms: 'list[Sequence[Sequence[float]]] | None' = None,
    structure_ids: 'list | None' = None,
    graph_converter: 'CrystalGraphConverter | None' = None,
    shuffle: 'bool' = True
) → None

Initialize the dataset.

Args:

Raises:


classmethod from_vasp

from_vasp(
    file_root: 'str',
    check_electronic_convergence: 'bool' = True,
    save_path: 'str | None' = None,
    graph_converter: 'CrystalGraphConverter | None' = None,
    shuffle: 'bool' = True
) → Self

Parse VASP output files into structures and labels and feed into the dataset.

Args:


class CIFData

A dataset from CIFs.

method __init__

__init__(
    cif_path: 'str',
    labels: 'str | dict' = 'labels.json',
    targets: 'TrainTask' = 'efsm',
    graph_converter: 'CrystalGraphConverter | None' = None,
    energy_key: 'str' = 'energy_per_atom',
    force_key: 'str' = 'force',
    stress_key: 'str' = 'stress',
    magmom_key: 'str' = 'magmom',
    shuffle: 'bool' = True
) → None

Initialize the dataset from a directory containing CIFs.

Args:


class GraphData

A dataset of graphs. This is compatible with the graph.pt files made by make_graphs.py. We recommend using this dataset to avoid graph conversion steps.

method __init__

__init__(
    graph_path: 'str',
    labels: 'str | dict' = 'labels.json',
    targets: 'TrainTask' = 'efsm',
    exclude: 'str | list | None' = None,
    energy_key: 'str' = 'energy_per_atom',
    force_key: 'str' = 'force',
    stress_key: 'str' = 'stress',
    magmom_key: 'str' = 'magmom',
    shuffle: 'bool' = True
) → None

Initialize the dataset from a directory containing saved crystal graphs.

Args:


method get_train_val_test_loader

get_train_val_test_loader(
    train_ratio: 'float' = 0.8,
    val_ratio: 'float' = 0.1,
    train_key: 'list[str] | None' = None,
    val_key: 'list[str] | None' = None,
    test_key: 'list[str] | None' = None,
    batch_size=32,
    num_workers=0,
    pin_memory=True
) → tuple[DataLoader, DataLoader, DataLoader]

Partition the GraphData by materials ID: either randomly select the train, val, and test keys according to train_ratio and val_ratio, or use the pre-defined train_key, val_key, and test_key lists, then create the train, val, and test loaders from them.

Args:

Returns: train_loader, val_loader, test_loader
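Splitting by materials ID rather than by individual graph keeps all graphs that share an mp_id (for example, several frames from one relaxation trajectory) in the same subset, so near-duplicate structures cannot straddle train and test. A minimal sketch, assuming a mapping from graph id to mp_id; split_by_mp_id is a hypothetical helper, not the method's code.

```python
from __future__ import annotations

import random
from collections import defaultdict


def split_by_mp_id(graph_to_mp: dict[str, str], train_ratio: float = 0.8, val_ratio: float = 0.1, seed: int = 0) -> dict[str, list[str]]:
    """Split graph ids so that every graph sharing a materials id lands in the same subset."""
    groups: dict[str, list[str]] = defaultdict(list)
    for graph_id, mp_id in graph_to_mp.items():
        groups[mp_id].append(graph_id)
    mp_ids = sorted(groups)  # sort first so the seeded shuffle is reproducible
    random.Random(seed).shuffle(mp_ids)
    n_train = int(train_ratio * len(mp_ids))
    n_val = int(val_ratio * len(mp_ids))
    keys = {
        "train": mp_ids[:n_train],
        "val": mp_ids[n_train : n_train + n_val],
        "test": mp_ids[n_train + n_val :],
    }
    # Expand each selected materials id back into the graph ids it owns
    return {name: [g for mp_id in ids for g in groups[mp_id]] for name, ids in keys.items()}


# Ten hypothetical materials, each with two graphs (e.g. two relaxation frames)
graph_to_mp = {f"mp-{i}-{frame}": f"mp-{i}" for i in range(10) for frame in range(2)}
split = split_by_mp_id(graph_to_mp)
print({name: len(ids) for name, ids in split.items()})  # {'train': 16, 'val': 2, 'test': 2}
```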


class StructureJsonData

Read structure and targets from a JSON file. This class is used to load the MPtrj dataset.

method __init__

__init__(
    data: 'str | dict',
    graph_converter: 'CrystalGraphConverter',
    targets: 'TrainTask' = 'efsm',
    energy_key: 'str' = 'energy_per_atom',
    force_key: 'str' = 'force',
    stress_key: 'str' = 'stress',
    magmom_key: 'str' = 'magmom',
    shuffle: 'bool' = True
) → None

Initialize the dataset by reading JSON files.

Args:


method get_train_val_test_loader

get_train_val_test_loader(
    train_ratio: 'float' = 0.8,
    val_ratio: 'float' = 0.1,
    train_key: 'list[str] | None' = None,
    val_key: 'list[str] | None' = None,
    test_key: 'list[str] | None' = None,
    batch_size=32,
    num_workers=0,
    pin_memory=True
) → tuple[DataLoader, DataLoader, DataLoader]

Partition the Dataset by materials ID: either randomly select the train, val, and test keys according to train_ratio and val_ratio, or use the pre-defined train_key, val_key, and test_key lists, then create the train, val, and test loaders from them.

Args:

Returns: train_loader, val_loader, test_loader

module graph.converter


class CrystalGraphConverter

Convert a pymatgen.core.Structure to a CrystalGraph. The CrystalGraph dataclass stores the essential fields to make sure that gradients such as force and stress can be calculated through back-propagation later.

method __init__

__init__(
    atom_graph_cutoff: 'float' = 6,
    bond_graph_cutoff: 'float' = 3,
    algorithm: "Literal['legacy', 'fast']" = 'fast',
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error',
    verbose: 'bool' = False
) → None

Initialize the Crystal Graph Converter.

Args:


method as_dict

as_dict() → dict[str, str | float]

Save the args of the graph converter.


method forward

forward(structure: 'Structure', graph_id=None, mp_id=None) → CrystalGraph

Convert a structure, return a CrystalGraph.

Args:

Returns: a CrystalGraph that is ready to be used by CHGNet.


classmethod from_dict

from_dict(dct: 'dict') → Self

Create converter from dictionary.


method set_isolated_atom_response

set_isolated_atom_response(
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']"
) → None

Set the graph converter’s response to isolated atoms in the graph.

Args:

Returns: None

module graph.crystalgraph


class CrystalGraph

A data class for crystal graph.

method __init__

__init__(
    atomic_number: 'Tensor',
    atom_frac_coord: 'Tensor',
    atom_graph: 'Tensor',
    atom_graph_cutoff: 'float',
    neighbor_image: 'Tensor',
    directed2undirected: 'Tensor',
    undirected2directed: 'Tensor',
    bond_graph: 'Tensor',
    bond_graph_cutoff: 'float',
    lattice: 'Tensor',
    graph_id: 'str | None' = None,
    mp_id: 'str | None' = None,
    composition: 'str | None' = None
) → None

Initialize the crystal graph.

Attention! This data class is not intended to be created manually. A CrystalGraph should be returned by a CrystalGraphConverter.

Args:

Raises:


property num_isolated_atoms

Number of isolated atoms given the atom graph cutoff. Isolated atoms are disconnected nodes in the atom graph that will not get updated in CHGNet. These atoms will always have a calculated force equal to zero.

With the default CHGNet atom graph cutoff radius, only about 0.1% of MPtrj dataset structures have isolated atoms.
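The definition of an isolated atom can be made concrete with a small pure-Python sketch, assuming the atom graph is given as (center, neighbor) index pairs in the style of Graph.adjacency_list. count_isolated_atoms is a hypothetical helper; the property itself works on the CrystalGraph tensors.

```python
from __future__ import annotations


def count_isolated_atoms(num_atoms: int, atom_graph: list[tuple[int, int]]) -> int:
    """Count atoms that appear in no edge of the atom graph."""
    connected: set[int] = set()
    for center, neighbor in atom_graph:
        connected.add(center)
        connected.add(neighbor)
    return num_atoms - len(connected)


# Atoms 0 and 1 are bonded in both directions; atom 2 has no neighbor within the cutoff
print(count_isolated_atoms(3, [(0, 1), (1, 0)]))  # 1
```

An edge such as (2, 2), an atom bonded to its own periodic image, counts atom 2 as connected in this sketch, since it still has a neighbor within the cutoff.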


classmethod from_dict

from_dict(dic: 'dict[str, Any]') → Self

Load a CrystalGraph from a dictionary.


classmethod from_file

from_file(file_name: 'str') → Self

Load a crystal graph from a file.

Args:

Returns:


method save

save(fname: 'str | None' = None, save_dir: 'str' = '.') → str

Save the graph to a file.

Args:

Returns:


method to

to(device: 'str' = 'cpu') → CrystalGraph

Move the graph to a device. Default = ‘cpu’.


method to_dict

to_dict() → dict[str, Any]

Convert the graph to a dictionary.

module graph.graph


class Node

A node in a graph.

method __init__

__init__(index: 'int', info: 'dict | None' = None) → None

Initialize a Node.

Args:


method add_neighbor

add_neighbor(index, edge) → None

Draw a directed edge between self and the node specified by index.

Args:


class Edge

Abstract base class for edges in a graph.

method __init__

__init__(
    nodes: 'list',
    index: 'int | None' = None,
    info: 'dict | None' = None
) → None

Initialize an Edge.


class UndirectedEdge

An undirected/bi-directed edge in a graph.

method __init__

__init__(
    nodes: 'list',
    index: 'int | None' = None,
    info: 'dict | None' = None
) → None

Initialize an Edge.


class DirectedEdge

A directed edge in a graph.

method __init__

__init__(
    nodes: 'list',
    index: 'int | None' = None,
    info: 'dict | None' = None
) → None

Initialize an Edge.


method make_undirected

make_undirected(index: 'int', info: 'dict | None' = None) → UndirectedEdge

Make a directed edge undirected.


class Graph

A graph for storing the neighbor information of atoms.

method __init__

__init__(nodes: 'list[Node]') → None

Initialize a Graph from a list of nodes.


method add_edge

add_edge(
    center_index,
    neighbor_index,
    image,
    distance,
    dist_tol: 'float' = 1e-06
) → None

Add a directed edge to the graph.

Args:


method adjacency_list

adjacency_list() → tuple[list[list[int]], list[int]]

Get the adjacency list.

Returns:

  • graph: the adjacency list [[0, 1], [0, 2], …, [5, 2], …]; the first column specifies the center/source node and the second column specifies the neighbor/destination node.
  • directed2undirected: [0, 1, …], a list of length num_directed_edge that specifies the undirected edge index corresponding to the directed edge in each row of the graph adjacency list.
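The relation between the directed rows, directed2undirected, and undirected2directed can be sketched in plain Python, assuming every undirected edge contributes exactly two directed rows (one per direction) and ignoring the periodic-image and distance bookkeeping the real Graph keeps. build_adjacency is a hypothetical helper.

```python
from __future__ import annotations


def build_adjacency(undirected_edges: list[tuple[int, int]]) -> tuple[list[list[int]], list[int], list[int]]:
    """Expand undirected edges into directed rows plus index maps in both directions."""
    graph: list[list[int]] = []
    directed2undirected: list[int] = []
    undirected2directed: list[int] = []
    for u_idx, (i, j) in enumerate(undirected_edges):
        # Record the first directed edge created for this undirected edge
        undirected2directed.append(len(graph))
        for center, neighbor in ((i, j), (j, i)):
            graph.append([center, neighbor])
            directed2undirected.append(u_idx)
    return graph, directed2undirected, undirected2directed


graph, d2u, u2d = build_adjacency([(0, 1), (0, 2)])
print(graph)  # [[0, 1], [1, 0], [0, 2], [2, 0]]
print(d2u)    # [0, 0, 1, 1]
print(u2d)    # [0, 2]
```

Following an undirected index to its directed edge and back (d2u[u2d[u]]) always returns u, which is the round trip the two maps are meant to support.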


method as_dict

as_dict() → dict

Return dictionary serialization of a Graph.


method line_graph_adjacency_list

line_graph_adjacency_list(cutoff) → tuple[list[list[int]], list[int]]

Get the line graph adjacency list.

Args:

Returns:

  • line_graph: [[0, 1, 1, 2, 2], [0, 1, 1, 4, 23], [1, 4, 23, 5, 66], …]; the first column specifies the node (atom) index at this angle, the second column the 1st undirected edge (left bond) index, the third column the 1st directed edge (left bond) index, the fourth column the 2nd undirected edge (right bond) index, and the fifth column the 2nd directed edge (right bond) index.
  • undirected2directed: [32, 45, …], a list of length num_undirected_edge that maps each undirected edge index to one of its directed edge indices.
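A sketch of how such five-column rows could be enumerated, assuming every ordered pair of distinct directed bonds that leave the same atom and are both shorter than cutoff forms one angle. Whether the library keeps both orderings of a pair, and how it orders the rows, are assumptions; line_graph_rows is hypothetical.

```python
from __future__ import annotations

from itertools import permutations


def line_graph_rows(directed_edges: list[tuple[int, int, float]], cutoff: float) -> list[list[int]]:
    """Build [center, u1, d1, u2, d2] angle rows.

    directed_edges[d] = (center atom, undirected edge index, bond length) for directed edge d.
    """
    by_center: dict[int, list[int]] = {}
    for d, (center, _u, length) in enumerate(directed_edges):
        if length < cutoff:  # only bonds shorter than the cutoff take part in angles
            by_center.setdefault(center, []).append(d)
    rows = []
    for center, edges in by_center.items():
        # Each ordered pair of distinct bonds leaving the same atom defines one angle row
        for d1, d2 in permutations(edges, 2):
            rows.append([center, directed_edges[d1][1], d1, directed_edges[d2][1], d2])
    return rows


# Directed edges 0/1, 2/3 and 4/5 are the two directions of undirected edges 0, 1 and 2
edges = [(0, 0, 1.5), (1, 0, 1.5), (0, 1, 2.0), (2, 1, 2.0), (0, 2, 3.5), (3, 2, 3.5)]
print(line_graph_rows(edges, cutoff=3.0))  # [[0, 0, 0, 1, 2], [0, 1, 2, 0, 0]]
```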


method to

to(filename='graph.json') → None

Save graph dictionary to file.


method undirected2directed

undirected2directed() → list[int]

The index map from an undirected_edge index to one of its directed_edge indices.

module model.basis


class Fourier

Fourier Expansion for angle features.

method __init__

__init__(order: 'int' = 5, learnable: 'bool' = False) → None

Initialize the Fourier expansion.

Args:


method forward

forward(x: 'Tensor') → Tensor

Apply Fourier expansion to a feature Tensor.
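A generic Fourier expansion of an angle can be sketched as below, producing 2 * order + 1 features [1, cos(k * theta), sin(k * theta)] for k = 1..order. The normalization constants, feature ordering, and learnable variant of the actual Fourier class are not reproduced; fourier_expand is hypothetical.

```python
import math


def fourier_expand(theta: float, order: int = 5) -> list:
    """Expand an angle (radians) into [1, cos(k*theta), sin(k*theta) for k = 1..order]."""
    features = [1.0]
    for k in range(1, order + 1):
        features.extend([math.cos(k * theta), math.sin(k * theta)])
    return features


print(len(fourier_expand(0.3)))      # 11
print(fourier_expand(0.0, order=2))  # [1.0, 1.0, 0.0, 1.0, 0.0]
```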


class RadialBessel

1D Bessel Basis from: https://github.com/TUM-DAML/gemnet_pytorch/.

method __init__

__init__(
    num_radial: 'int' = 9,
    cutoff: 'float' = 5,
    learnable: 'bool' = False,
    smooth_cutoff: 'int' = 5
) → None

Initialize the SmoothRBF function.

Args:


method forward

forward(
    dist: 'Tensor',
    return_smooth_factor: 'bool' = False
) → Tensor | tuple[Tensor, Tensor]

Apply Bessel expansion to a feature Tensor.

Args:

Returns:
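The GemNet radial Bessel basis that this class references is commonly written as sqrt(2/c) * sin(n * pi * d / c) / d for n = 1..num_radial, where c is the cutoff. A minimal sketch of that formula, assuming d > 0 and omitting the smooth cutoff factor (compare return_smooth_factor); radial_bessel is hypothetical.

```python
import math


def radial_bessel(dist: float, num_radial: int = 9, cutoff: float = 5.0) -> list:
    """sqrt(2/c) * sin(n*pi*d/c) / d for n = 1..num_radial (requires dist > 0)."""
    prefactor = math.sqrt(2.0 / cutoff)
    return [prefactor * math.sin(n * math.pi * dist / cutoff) / dist for n in range(1, num_radial + 1)]


basis = radial_bessel(2.5)
print(len(basis))  # 9
```

Every basis function vanishes at d = cutoff, since sin(n * pi) = 0 for integer n.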


class GaussianExpansion

Expand the distance in a Gaussian basis. Unit: angstrom.

method __init__

__init__(
    min: 'float' = 0,
    max: 'float' = 5,
    step: 'float' = 0.5,
    var: 'float | None' = None
) → None

Gaussian expansion expands a scalar feature into a soft one-hot feature vector.

Args:


method expand

expand(features: 'Tensor') → Tensor

Apply Gaussian filter to a feature Tensor.

Args:

Returns:
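A sketch of the soft one-hot encoding, assuming Gaussian centers on a grid from min to max in steps of step and a width that falls back to step when var is None. The parameters are renamed lo and hi here to avoid shadowing Python's built-in min and max; gaussian_expand is hypothetical, and the exact width convention inside the class is an assumption.

```python
from __future__ import annotations

import math


def gaussian_expand(x: float, lo: float = 0.0, hi: float = 5.0, step: float = 0.5, var: float | None = None) -> list[float]:
    """Soft one-hot encoding of x on Gaussian centers lo, lo + step, ..., hi."""
    width = step if var is None else var  # assumed: the width defaults to the grid step
    n_centers = round((hi - lo) / step) + 1
    centers = [lo + i * step for i in range(n_centers)]
    return [math.exp(-((x - mu) ** 2) / width**2) for mu in centers]


vec = gaussian_expand(1.0)
print(len(vec))             # 11
print(vec.index(max(vec)))  # 2, the center at 1.0 Å
```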


class CutoffPolynomial

Polynomial soft-cutoff function for the atom graph. Ref: https://github.com/TUM-DAML/gemnet_pytorch/blob/-/gemnet/model/layers/envelope.py.

method __init__

__init__(cutoff: 'float' = 5, cutoff_coeff: 'float' = 5) → None

Initialize the polynomial cutoff function.

Args:


method forward

forward(r: 'Tensor') → Tensor

Polynomial cutoff function.

Args:

Returns:
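The GemNet envelope referenced above is a polynomial in u = r / cutoff that equals 1 at r = 0 and reaches 0, with vanishing first and second derivatives, at r = cutoff. A sketch with p playing the role of cutoff_coeff, assuming the GemNet coefficients carry over unchanged; polynomial_cutoff is hypothetical.

```python
def polynomial_cutoff(r: float, cutoff: float = 5.0, cutoff_coeff: float = 5.0) -> float:
    """GemNet-style polynomial envelope; returns 0 for r >= cutoff."""
    p = cutoff_coeff
    u = r / cutoff
    if u >= 1.0:
        return 0.0  # no contribution beyond the cutoff
    return (
        1.0
        - 0.5 * (p + 1) * (p + 2) * u**p
        + p * (p + 2) * u ** (p + 1)
        - 0.5 * p * (p + 1) * u ** (p + 2)
    )


print(polynomial_cutoff(0.0))  # 1.0
print(polynomial_cutoff(2.5))  # 0.7734375
```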

module model.composition_model


class CompositionModel

A simple FC model that takes in a chemical composition (no structure info) and outputs energy.

method __init__

__init__(
    atom_fea_dim: 'int' = 64,
    activation: 'str' = 'silu',
    is_intensive: 'bool' = True,
    max_num_elements: 'int' = 94
) → None

Initialize a CompositionModel.


method forward

forward(graphs: 'list[CrystalGraph]') → Tensor

Get the energy of a list of CrystalGraphs as a Tensor.


class AtomRef

A linear regression for elemental energy. From: https://github.com/materialsvirtuallab/m3gnet/.

method __init__

__init__(is_intensive: 'bool' = True, max_num_elements: 'int' = 94) → None

Initialize an AtomRef model.


method fit

fit(
    structures_or_graphs: 'Sequence[Structure | CrystalGraph]',
    energies: 'Sequence[float]'
) → None

Fit the model to a list of crystals and energies.

Args:
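The idea behind fitting elemental reference energies can be sketched as an ordinary least-squares problem: find one energy per element so that each crystal's energy is close to the sum of element counts times those energies. This sketch solves the normal equations in plain Python; fit_atom_refs is hypothetical, and whether the library regresses on element counts or fractions (which depends on is_intensive) is not shown here.

```python
from __future__ import annotations


def fit_atom_refs(compositions: list[dict[str, float]], energies: list[float]) -> dict[str, float]:
    """Least-squares per-element energies w such that energy ≈ sum(count[el] * w[el])."""
    elements = sorted({el for comp in compositions for el in comp})
    X = [[comp.get(el, 0.0) for el in elements] for comp in compositions]
    k = len(elements)
    # Augmented normal equations [X^T X | X^T y]
    A = []
    for i in range(k):
        lhs = [sum(row[i] * row[j] for row in X) for j in range(k)]
        rhs = sum(row[i] * y for row, y in zip(X, energies))
        A.append(lhs + [rhs])
    # Gauss-Jordan elimination with partial pivoting
    for col in range(k):
        pivot = max(range(col, k), key=lambda r: abs(A[r][col]))
        A[col], A[pivot] = A[pivot], A[col]
        for r in range(k):
            if r != col:
                factor = A[r][col] / A[col][col]
                A[r] = [a - factor * b for a, b in zip(A[r], A[col])]
    return {el: A[i][k] / A[i][i] for i, el in enumerate(elements)}


# Energies generated from hypothetical references Li = -2 eV and O = -5 eV
comps = [{"Li": 2, "O": 1}, {"Li": 1, "O": 1}, {"Li": 1}, {"O": 2}]
energies = [-9.0, -7.0, -2.0, -10.0]
print(fit_atom_refs(comps, energies))  # approximately {'Li': -2.0, 'O': -5.0}
```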


method forward

forward(graphs: 'list[CrystalGraph]') → Tensor

Get the energy of a list of CrystalGraphs.

Args:

Returns: energy (tensor)


method get_site_energies

get_site_energies(graphs: 'list[CrystalGraph]') → list[Tensor]

Predict the site energies given a list of CrystalGraphs.

Args:

Returns: a list of tensors corresponding to site energies of each graph [batchsize].


method initialize_from

initialize_from(dataset: 'str') → None

Initialize pre-fitted weights from a dataset.


method initialize_from_MPF

initialize_from_MPF() → None

Initialize pre-fitted weights from MPF dataset.


method initialize_from_MPtrj

initialize_from_MPtrj() → None

Initialize pre-fitted weights from MPtrj dataset.


method initialize_from_numpy

initialize_from_numpy(file_name: 'str | Path') → None

Initialize pre-fitted weights from numpy file.

module model.dynamics

Global Variables


class CHGNetCalculator

CHGNet Calculator for ASE applications.

method __init__

__init__(
    model: 'CHGNet | None' = None,
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False,
    stress_weight: 'float | None' = 0.006241509125883258,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn',
    **kwargs
) → None

Provide a CHGNet instance to calculate various atomic properties using ASE.

Args:
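The default stress_weight of 0.006241509125883258 equals 1/160.21766208, the factor that converts a stress in GPa to eV/Å^3 (1 eV/Å^3 = 160.21766208 GPa with the CODATA 2014 electron volt). Reading it as the conversion from a GPa stress to the eV/Å^3 units ASE works in is our interpretation rather than something this page states; the names below are hypothetical.

```python
# 1 eV/Å^3 = 1.6021766208e-19 J / 1e-30 m^3 = 1.6021766208e11 Pa = 160.21766208 GPa
EV_PER_A3_IN_GPA = 160.21766208

# Multiplying a stress in GPa by this factor gives eV/Å^3
stress_weight = 1 / EV_PER_A3_IN_GPA


def gpa_to_ev_per_a3(stress_gpa: float) -> float:
    """Convert a stress from GPa to eV/Å^3."""
    return stress_gpa * stress_weight


print(stress_weight)  # approximately 0.006241509125883258
```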


property directory


property label


property n_params

The number of parameters in CHGNet.


property name


property version

The version of CHGNet.


method calculate

calculate(
    atoms: 'Atoms | None' = None,
    properties: 'list | None' = None,
    system_changes: 'list | None' = None
) → None

Calculate various properties of the atoms using CHGNet.

Args:


classmethod from_file

from_file(path: 'str', use_device: 'str | None' = None, **kwargs) → Self

Load a user’s CHGNet model and initialize the Calculator.


class StructOptimizer

Wrapper class for structural relaxation.

method __init__

__init__(
    model: 'CHGNet | CHGNetCalculator | None' = None,
    optimizer_class: 'Optimizer | str | None' = 'FIRE',
    use_device: 'str | None' = None,
    stress_weight: 'float' = 0.006241509125883258,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None

Provide a trained CHGNet model and an optimizer to relax crystal structures.

Args:


property n_params

The number of parameters in CHGNet.


property version

The version of CHGNet.


method relax

relax(
    atoms: 'Structure | Atoms',
    fmax: 'float | None' = 0.1,
    steps: 'int | None' = 500,
    relax_cell: 'bool | None' = True,
    ase_filter: 'str | None' = 'FrechetCellFilter',
    save_path: 'str | None' = None,
    loginterval: 'int | None' = 1,
    crystal_feas_save_path: 'str | None' = None,
    verbose: 'bool' = True,
    assign_magmoms: 'bool' = True,
    **kwargs
) → dict[str, Structure | TrajectoryObserver]

Relax the Structure/Atoms until the maximum force is smaller than fmax.

Args:

Returns: dict[str, Structure | TrajectoryObserver]: A dictionary with ‘final_structure’ and ‘trajectory’.


class TrajectoryObserver

Trajectory observer is a hook in the relaxation process that saves the intermediate structures.

method __init__

__init__(atoms: 'Atoms') → None

Create a TrajectoryObserver from an Atoms object.

Args:


method compute_energy

compute_energy() → float

Calculate the potential energy.

Returns:


method save

save(filename: 'str') → None

Save the trajectory to file.

Args:


class CrystalFeasObserver

CrystalFeasObserver is a hook in the relaxation and MD process that saves the intermediate crystal feature structures.

method __init__

__init__(atoms: 'Atoms') → None

Create a CrystalFeasObserver from an Atoms object.


method save

save(filename: 'str') → None

Save the crystal feature vectors to filename in pickle format.


class MolecularDynamics

Molecular dynamics class.

method __init__

__init__(
    atoms: 'Atoms | Structure',
    model: 'CHGNet | CHGNetCalculator | None' = None,
    ensemble: 'str' = 'nvt',
    thermostat: 'str' = 'Berendsen_inhomogeneous',
    temperature: 'int' = 300,
    starting_temperature: 'int | None' = None,
    timestep: 'float' = 2.0,
    pressure: 'float' = 0.000101325,
    taut: 'float | None' = None,
    taup: 'float | None' = None,
    bulk_modulus: 'float | None' = None,
    trajectory: 'str | Trajectory | None' = None,
    logfile: 'str | None' = None,
    loginterval: 'int' = 1,
    crystal_feas_logfile: 'str | None' = None,
    append_trajectory: 'bool' = False,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'warn',
    use_device: 'str | None' = None
) → None

Initialize the MD class.

Args:

In the NPT ensemble, the effective damping time for pressure is multiplied by the compressibility. In LAMMPS, the bulk modulus defaults to 10.

If the bulk modulus is not provided here, it will be calculated by CHGNet through the Birch-Murnaghan equation of state (EOS). Note that the EOS fit can fail because of a non-parabolic potential energy surface, which is common in soft systems such as liquids and gases. In such cases, the user should provide an input bulk modulus for better barostat coupling; otherwise a guessed bulk modulus of 2 GPa (the bulk modulus of water) will be used.

Default = None


method run

run(steps: 'int') → None

Thin wrapper around the ASE MD run.

Args:


method set_atoms

set_atoms(atoms: 'Atoms') → None

Set new atoms to run MD.

Args:


method upper_triangular_cell

upper_triangular_cell(verbose: 'bool | None' = False) → None

Transform to an upper-triangular cell. ASE’s Nose-Hoover implementation only supports upper-triangular cells, while ASE’s canonical description is a lower-triangular cell.

Args:


class EquationOfState

Class to calculate equation of state.

method __init__

__init__(
    model: 'CHGNet | CHGNetCalculator | None' = None,
    optimizer_class: 'Optimizer | str | None' = 'FIRE',
    use_device: 'str | None' = None,
    stress_weight: 'float' = 0.006241509125883258,
    on_isolated_atoms: "Literal['ignore', 'warn', 'error']" = 'error'
) → None

Initialize a structure optimizer object for calculation of bulk modulus.

Args:


method fit

fit(
    atoms: 'Structure | Atoms',
    n_points: 'int' = 11,
    fmax: 'float | None' = 0.1,
    steps: 'int | None' = 500,
    verbose: 'bool | None' = False,
    **kwargs
) → None

Relax the Structure/Atoms and fit the Birch-Murnaghan equation of state.

Args:

Returns: Bulk Modulus (float)
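For reference, the third-order Birch-Murnaghan form is E(V) = E0 + (9 V0 B0 / 16) {[(V0/V)^(2/3) - 1]^3 B0' + [(V0/V)^(2/3) - 1]^2 [6 - 4 (V0/V)^(2/3)]}. The sketch below evaluates it and checks numerically that B0 = V0 d²E/dV² at V0, the relation that ties the fitted curve to the bulk modulus. The parameter values are hypothetical, and this page does not say which fitting backend fit uses.

```python
def birch_murnaghan_energy(volume: float, e0: float, v0: float, b0: float, b0_prime: float) -> float:
    """Third-order Birch-Murnaghan E(V); b0 in energy per volume (e.g. eV/Å^3)."""
    eta = (v0 / volume) ** (2.0 / 3.0)
    return e0 + 9.0 * v0 * b0 / 16.0 * ((eta - 1.0) ** 3 * b0_prime + (eta - 1.0) ** 2 * (6.0 - 4.0 * eta))


# Hypothetical parameters: E0 = -10 eV, V0 = 20 Å^3, B0 = 0.5 eV/Å^3, B0' = 4
params = {"e0": -10.0, "v0": 20.0, "b0": 0.5, "b0_prime": 4.0}
h = 0.01
# Central finite difference for d2E/dV2 at V0
curvature = (
    birch_murnaghan_energy(20.0 + h, **params)
    - 2.0 * birch_murnaghan_energy(20.0, **params)
    + birch_murnaghan_energy(20.0 - h, **params)
) / h**2
print(birch_murnaghan_energy(20.0, **params))  # -10.0
print(curvature * 20.0)  # approximately 0.5, i.e. B0 = V0 * d2E/dV2 at V0
```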


method get_bulk_modulus

get_bulk_modulus(unit: 'str' = 'eV/A^3') → float

Get the bulk modulus from the fitted Birch-Murnaghan equation of state.

Args:

Returns: Bulk Modulus (float)


method get_compressibility

get_compressibility(unit: 'str' = 'A^3/eV') → float

Get the compressibility from the fitted Birch-Murnaghan equation of state.

Args:

Returns: Compressibility (float)
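Compressibility is the reciprocal of the bulk modulus, which is why the default unit is A^3/eV here and eV/A^3 for get_bulk_modulus. A small sketch of that relation and of the conversion to GPa, assuming 1 eV/Å^3 = 160.21766208 GPa; the function names and the sample value are hypothetical.

```python
EV_PER_A3_IN_GPA = 160.21766208  # 1 eV/Å^3 expressed in GPa


def bulk_modulus_in_gpa(b_ev_per_a3: float) -> float:
    """Convert a bulk modulus from eV/Å^3 to GPa."""
    return b_ev_per_a3 * EV_PER_A3_IN_GPA


def compressibility_from_bulk_modulus(b_ev_per_a3: float) -> float:
    """Compressibility in Å^3/eV is 1 / B when B is in eV/Å^3."""
    return 1.0 / b_ev_per_a3


b = 0.5  # hypothetical bulk modulus in eV/Å^3
print(bulk_modulus_in_gpa(b))                # approximately 80.1 (GPa)
print(compressibility_from_bulk_modulus(b))  # 2.0
```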

module model.encoders


class AtomEmbedding

Encode an atom by its atomic number using a learnable embedding layer.

method __init__

__init__(atom_feature_dim: 'int', max_num_elements: 'int' = 94) → None

Initialize the Atom featurizer.

Args:


method forward

forward(atomic_numbers: 'Tensor') → Tensor

Convert the atomic numbers to an atom embedding tensor.

Args:

Returns:


class BondEncoder

Encode a chemical bond given the positions of two atoms using Gaussian distance.

method __init__

__init__(
    atom_graph_cutoff: 'float' = 5,
    bond_graph_cutoff: 'float' = 3,
    num_radial: 'int' = 9,
    cutoff_coeff: 'int' = 5,
    learnable: 'bool' = False
) → None

Initialize the bond encoder.

Args:


method forward

forward(
    center: 'Tensor',
    neighbor: 'Tensor',
    undirected2directed: 'Tensor',
    image: 'Tensor',
    lattice: 'Tensor'
) → tuple[Tensor, Tensor, Tensor]

Compute the pairwise distance between two 3D coordinates.

Args:

Returns:
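The role of the periodic image offset in the distance calculation can be sketched as follows, assuming fractional coordinates, integer image offsets, and the row-vector convention cart = frac @ lattice with lattice vectors as rows (pymatgen's convention). bond_vector is hypothetical, and the real forward works on batched tensors.

```python
from __future__ import annotations

import math


def bond_vector(center_frac: list[float], neighbor_frac: list[float], image: list[int], lattice: list[list[float]]) -> list[float]:
    """Cartesian vector from a center atom to a neighbor in the periodic image given by image."""
    # Shift the neighbor by whole lattice vectors, then take the fractional difference
    delta_frac = [n + i - c for n, i, c in zip(neighbor_frac, image, center_frac)]
    # Row-vector convention: cart = frac @ lattice, with lattice vectors as rows
    return [sum(delta_frac[k] * lattice[k][j] for k in range(3)) for j in range(3)]


lattice = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
vec = bond_vector([0.0, 0.0, 0.0], [0.9, 0.0, 0.0], [-1, 0, 0], lattice)
print(math.dist(vec, [0.0, 0.0, 0.0]))  # approximately 0.3 (Å)
```

Without the [-1, 0, 0] image the same pair would be 2.7 Å apart; the image places the neighbor just across the cell boundary.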


class AngleEncoder

Encode an angle given the two bond vectors using Fourier Expansion.

method __init__

__init__(num_angular: 'int' = 9, learnable: 'bool' = True) → None

Initialize the angle encoder.

Args:


method forward

forward(bond_i: 'Tensor', bond_j: 'Tensor') → Tensor

Compute the angles between normalized vectors.

Args:

Returns:
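The angle follows from the dot product of the normalized bond vectors; the resulting value is what a Fourier-type expansion then encodes. bond_angle is a hypothetical scalar helper, whereas the real encoder works on batched tensors; the clamp guards against round-off pushing the cosine slightly outside [-1, 1].

```python
from __future__ import annotations

import math


def bond_angle(bond_i: list[float], bond_j: list[float]) -> float:
    """Angle in radians between two bond vectors, from their normalized dot product."""
    norm_i = math.sqrt(sum(c * c for c in bond_i))
    norm_j = math.sqrt(sum(c * c for c in bond_j))
    cos_theta = sum(a * b for a, b in zip(bond_i, bond_j)) / (norm_i * norm_j)
    cos_theta = max(-1.0, min(1.0, cos_theta))  # keep acos in its domain despite round-off
    return math.acos(cos_theta)


print(math.degrees(bond_angle([1.0, 0.0, 0.0], [0.0, 2.0, 0.0])))  # approximately 90.0
```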

module model.functions


function aggregate

aggregate(
    data: 'Tensor',
    owners: 'Tensor',
    average=True,
    num_owner=None
) → Tensor

Aggregate rows in data by specifying the owners.

Args:

Returns:
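A pure-Python sketch of the aggregation semantics, assuming owners[i] is the index of the group that row i belongs to (for example, the crystal each atom belongs to in a batch) and that num_owner, when given, fixes the number of output rows. aggregate_rows is hypothetical; the real function operates on tensors.

```python
from __future__ import annotations


def aggregate_rows(data: list[list[float]], owners: list[int], average: bool = True, num_owner: int | None = None) -> list[list[float]]:
    """Sum, or average, the rows of data that share the same owner index."""
    n_owner = num_owner if num_owner is not None else max(owners) + 1
    width = len(data[0])
    sums = [[0.0] * width for _ in range(n_owner)]
    counts = [0] * n_owner
    for row, owner in zip(data, owners):
        counts[owner] += 1
        sums[owner] = [s + v for s, v in zip(sums[owner], row)]
    if average:
        # Owners without any rows keep a zero vector instead of dividing by zero
        return [[s / c for s in total] if c else total for total, c in zip(sums, counts)]
    return sums


atom_feas = [[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]]
atom_owners = [0, 0, 1]  # atoms 0 and 1 belong to crystal 0, atom 2 to crystal 1
print(aggregate_rows(atom_feas, atom_owners))                 # [[2.0, 3.0], [10.0, 0.0]]
print(aggregate_rows(atom_feas, atom_owners, average=False))  # [[4.0, 6.0], [10.0, 0.0]]
```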


function find_activation

find_activation(name: 'str') → Module

Return an activation function by name.


function find_normalization

find_normalization(name: 'str', dim: 'int | None' = None) → Module | None

Return a normalization function by name.


class MLP

Multi-Layer Perceptron used for non-linear regression.

method __init__

__init__(
    input_dim: 'int',
    output_dim: 'int' = 1,
    hidden_dim: 'int | Sequence[int] | None' = (64, 64),
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    bias: 'bool' = True
) → None

Initialize the MLP.

Args:


method forward

forward(x: 'Tensor') → Tensor

Performs a forward pass through the MLP.

Args:

Returns:


class GatedMLP

Gated MLP; a similar model structure is used in CGCNN and M3GNet.

method __init__

__init__(
    input_dim: 'int',
    output_dim: 'int',
    hidden_dim: 'int | list[int] | None' = None,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str' = 'batch',
    bias: 'bool' = True
) → None

Initialize a gated MLP.

Args:


method forward

forward(x: 'Tensor') → Tensor

Performs a forward pass through the MLP.

Args:

Returns:


class ScaledSiLU

Scaled Sigmoid Linear Unit.

method __init__

__init__() → None

Initialize a scaled SiLU.


method forward

forward(x: 'Tensor') → Tensor

Forward pass.
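SiLU is x * sigmoid(x), and the scaled variant multiplies it by a constant. The factor 1/0.6 used below is an assumption borrowed from the GemNet-style ScaledSiLU, since this page does not state the value; scaled_silu is hypothetical.

```python
import math


def scaled_silu(x: float, scale: float = 1 / 0.6) -> float:
    """scale * SiLU(x), where SiLU(x) = x * sigmoid(x)."""
    # Two algebraically equal forms keep exp() from overflowing for large |x|
    if x >= 0:
        silu = x / (1.0 + math.exp(-x))
    else:
        e = math.exp(x)
        silu = x * e / (1.0 + e)
    return scale * silu


print(scaled_silu(0.0))  # 0.0
```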

module model.layers


class AtomConv

A convolution Layer to update atom features.

method __init__

__init__(
    atom_fea_dim: 'int',
    bond_fea_dim: 'int',
    hidden_dim: 'int' = 64,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str | None' = None,
    use_mlp_out: 'bool' = True,
    mlp_out_bias: 'bool' = False,
    resnet: 'bool' = True,
    gMLP_norm: 'str | None' = None
) → None

Initialize the AtomConv layer.

Args:


method forward

forward(
    atom_feas: 'Tensor',
    bond_feas: 'Tensor',
    bond_weights: 'Tensor',
    atom_graph: 'Tensor',
    directed2undirected: 'Tensor'
) → Tensor

Forward pass of AtomConv module that updates the atom features and optionally bond features.

Args:

Returns:

Notes:

  • num_batch_atoms = sum(num_atoms) in batch

class BondConv

A convolution Layer to update bond features.

method __init__

__init__(
    atom_fea_dim: 'int',
    bond_fea_dim: 'int',
    angle_fea_dim: 'int',
    hidden_dim: 'int' = 64,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str | None' = None,
    use_mlp_out: 'bool' = True,
    mlp_out_bias: 'bool' = False,
    resnet=True,
    gMLP_norm: 'str | None' = None
) → None

Initialize the BondConv layer.

Args:


method forward

forward(
    atom_feas: 'Tensor',
    bond_feas: 'Tensor',
    bond_weights: 'Tensor',
    angle_feas: 'Tensor',
    bond_graph: 'Tensor'
) → Tensor

Update the bond features.

Args:

Returns:

Notes:

  • num_batch_atoms = sum(num_atoms) in batch

class AngleUpdate

Update angle features.

method __init__

__init__(
    atom_fea_dim: 'int',
    bond_fea_dim: 'int',
    angle_fea_dim: 'int',
    hidden_dim: 'int' = 0,
    dropout: 'float' = 0,
    activation: 'str' = 'silu',
    norm: 'str | None' = None,
    resnet: 'bool' = True,
    gMLP_norm: 'str | None' = None
) → None

Initialize the AngleUpdate layer.

Args:


method forward

forward(
    atom_feas: 'Tensor',
    bond_feas: 'Tensor',
    angle_feas: 'Tensor',
    bond_graph: 'Tensor'
) → Tensor

Update the angle features using bond graph.

Args:

Returns:

Notes:

  • num_batch_atoms = sum(num_atoms) in batch

class GraphPooling

Pool the sub-graphs in the batched graph.

method __init__

__init__(average: 'bool' = False) → None

Args: average (bool): whether to average the features.


method forward

forward(atom_feas: 'Tensor', atom_owner: 'Tensor') → Tensor

Merge the atom features that belong to the same graph in a batched graph.

Args:

Returns:


class GraphAttentionReadOut

Multi-head attention readout layer that merges the information from atom_feas into crystal_fea.

method __init__

__init__(
    atom_fea_dim: 'int',
    num_head: 'int' = 3,
    hidden_dim: 'int' = 32,
    average=False
) → None

Initialize the layer.

Args:


method forward

forward(atom_feas: 'Tensor', atom_owner: 'Tensor') → Tensor

Merge the atom features that belong to the same graph in a batched graph.

Args:

Returns:
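Attention readout replaces a plain average with softmax weights computed separately within each crystal. A single-head sketch, assuming one scalar attention logit per atom is already available (in the layer it would come from learned projections of atom_feas, repeated for num_head heads) and that every graph has at least one atom; attention_pool is hypothetical and does not show how the heads are combined.

```python
from __future__ import annotations

import math


def attention_pool(atom_feas: list[list[float]], atom_owner: list[int], scores: list[float]) -> list[list[float]]:
    """Softmax-weighted sum of atom features, with the softmax taken within each graph."""
    n_graph = max(atom_owner) + 1
    width = len(atom_feas[0])
    pooled = []
    for g in range(n_graph):
        idx = [i for i, owner in enumerate(atom_owner) if owner == g]
        top = max(scores[i] for i in idx)  # subtract the max for a numerically stable softmax
        weights = [math.exp(scores[i] - top) for i in idx]
        total = sum(weights)
        pooled.append([sum(w * atom_feas[i][k] for w, i in zip(weights, idx)) / total for k in range(width)])
    return pooled


atom_feas = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
atom_owner = [0, 0, 1]  # atoms 0 and 1 form graph 0, atom 2 forms graph 1
print(attention_pool(atom_feas, atom_owner, scores=[0.0, 0.0, 3.0]))  # [[0.5, 0.5], [5.0, 5.0]]
```

With equal logits the result reduces to a per-graph average; a large logit on one atom pulls the crystal feature toward that atom.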

module model.model

Global Variables


class CHGNet

Crystal Hamiltonian Graph neural Network: a model that takes in a crystal graph and outputs energy, force, magmom, and stress.

method __init__

__init__(
    atom_fea_dim: 'int' = 64,
    bond_fea_dim: 'int' = 64,
    angle_fea_dim: 'int' = 64,
    composition_model: 'str | Module' = 'MPtrj',
    num_radial: 'int' = 31,
    num_angular: 'int' = 31,
    n_conv: 'int' = 4,
    atom_conv_hidden_dim: 'Sequence[int] | int' = 64,
    update_bond: 'bool' = True,
    bond_conv_hidden_dim: 'Sequence[int] | int' = 64,
    update_angle: 'bool' = True,
    angle_layer_hidden_dim: 'Sequence[int] | int' = 0,
    conv_dropout: 'float' = 0,
    read_out: 'str' = 'ave',
    mlp_hidden_dims: 'Sequence[int] | int' = (64, 64, 64),
    mlp_dropout: 'float' = 0,
    mlp_first: 'bool' = True,
    is_intensive: 'bool' = True,
    non_linearity: "Literal['silu', 'relu', 'tanh', 'gelu']" = 'silu',
    atom_graph_cutoff: 'float' = 6,
    bond_graph_cutoff: 'float' = 3,
    graph_converter_algorithm: "Literal['legacy', 'fast']" = 'fast',
    cutoff_coeff: 'int' = 8,
    learnable_rbf: 'bool' = True,
    gMLP_norm: 'str | None' = 'layer',
    readout_norm: 'str | None' = 'layer',
    version: 'str | None' = None,
    **kwargs
) → None

Initialize CHGNet.

Args:


property n_params

Return the number of parameters in the model.


property version

Return the version of the loaded checkpoint.


method as_dict

as_dict() → dict

Return the CHGNet weights and args in a dictionary.


method forward

forward(
    graphs: 'Sequence[CrystalGraph]',
    task: 'PredTask' = 'e',
    return_site_energies: 'bool' = False,
    return_atom_feas: 'bool' = False,
    return_crystal_feas: 'bool' = False
) → dict[str, Tensor]

Get the predictions associated with the input graphs.

Args:

Returns: model output (dict).


classmethod from_dict

from_dict(dct: 'dict', **kwargs) → Self

Build a CHGNet from a saved dictionary.


classmethod from_file

from_file(path: 'str', **kwargs) → Self

Build a CHGNet from a saved file.


classmethod load

load(
    model_name: 'str' = '0.3.0',
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False,
    verbose: 'bool' = True
) → Self

Load pretrained CHGNet model.

Args: model_name (str, optional): Default = “0.3.0”.

Raises:


method predict_graph

predict_graph(
    graph: 'CrystalGraph | Sequence[CrystalGraph]',
    task: 'PredTask' = 'efsm',
    return_site_energies: 'bool' = False,
    return_atom_feas: 'bool' = False,
    return_crystal_feas: 'bool' = False,
    batch_size: 'int' = 16
) → dict[str, Tensor] | list[dict[str, Tensor]]

Predict from CrystalGraph.

Args:

Returns:


method predict_structure

predict_structure(
    structure: 'Structure | Sequence[Structure]',
    task: 'PredTask' = 'efsm',
    return_site_energies: 'bool' = False,
    return_atom_feas: 'bool' = False,
    return_crystal_feas: 'bool' = False,
    batch_size: 'int' = 16
) → dict[str, Tensor] | list[dict[str, Tensor]]

Predict from pymatgen.core.Structure.

Args:

Returns:


method todict

todict() → dict

Needed for ASE JSON serialization when saving CHGNet potential to trajectory file (https://github.com/CederGroupHub/chgnet/issues/48).


class BatchedGraph

Batched crystal graph for parallel computing.

Attributes:

method __init__

__init__(
    atomic_numbers: 'Tensor',
    bond_bases_ag: 'Tensor',
    bond_bases_bg: 'Tensor',
    angle_bases: 'Tensor',
    batched_atom_graph: 'Tensor',
    batched_bond_graph: 'Tensor',
    atom_owners: 'Tensor',
    directed2undirected: 'Tensor',
    atom_positions: 'Sequence[Tensor]',
    strains: 'Sequence[Tensor]',
    volumes: 'Sequence[Tensor] | Tensor'
) → None

classmethod from_graphs

from_graphs(
    graphs: 'Sequence[CrystalGraph]',
    bond_basis_expansion: 'Module',
    angle_basis_expansion: 'Module',
    compute_stress: 'bool' = False
) → Self

Featurize and assemble a list of graphs.

Args:

Returns:

module trainer.trainer

Global Variables


class Trainer

A trainer to train CHGNet using energy, force, stress and magmom.

method __init__

__init__(
    model: 'CHGNet | None' = None,
    targets: 'TrainTask' = 'ef',
    energy_loss_ratio: 'float' = 1,
    force_loss_ratio: 'float' = 1,
    stress_loss_ratio: 'float' = 0.1,
    mag_loss_ratio: 'float' = 0.1,
    optimizer: 'str' = 'Adam',
    scheduler: 'str' = 'CosLR',
    criterion: 'str' = 'MSE',
    epochs: 'int' = 50,
    starting_epoch: 'int' = 0,
    learning_rate: 'float' = 0.001,
    print_freq: 'int' = 100,
    torch_seed: 'int | None' = None,
    data_seed: 'int | None' = None,
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False,
    wandb_path: 'str | None' = None,
    wandb_init_kwargs: 'dict | None' = None,
    extra_run_config: 'dict | None' = None,
    **kwargs
) → None

Initialize all hyper-parameters for trainer.

Args:


method get_best_model

get_best_model() → CHGNet

Get the best model recorded in the trainer.


classmethod load

load(path: 'str') → Self

Load trainer state_dict.


method move_to

move_to(obj, device) → Tensor | list[Tensor]

Move object to device.


method save

save(filename: 'str' = 'training_result.pth.tar') → None

Save the model, graph_converter, etc.


method save_checkpoint

save_checkpoint(epoch: 'int', mae_error: 'dict', save_dir: 'str') → None

Function to save CHGNet trained weights after each epoch.

Args:


method train

train(
    train_loader: 'DataLoader',
    val_loader: 'DataLoader',
    test_loader: 'DataLoader | None' = None,
    save_dir: 'str | None' = None,
    save_test_result: 'bool' = False,
    train_composition_model: 'bool' = False,
    wandb_log_freq: 'LogFreq' = 'batch'
) → None

Train the model using torch data_loaders.

Args:


class CombinedLoss

A combined loss function of energy, force, stress and magmom.

method __init__

__init__(
    target_str: 'str' = 'ef',
    criterion: 'str' = 'MSE',
    is_intensive: 'bool' = True,
    energy_loss_ratio: 'float' = 1,
    force_loss_ratio: 'float' = 1,
    stress_loss_ratio: 'float' = 0.1,
    mag_loss_ratio: 'float' = 0.1,
    delta: 'float' = 0.1
) → None

Initialize the combined loss.

Args:


method forward

forward(
    targets: 'dict[str, Tensor]',
    prediction: 'dict[str, Tensor]'
) → dict[str, Tensor]

Compute the combined loss using the CHGNet prediction and labels. This function can automatically mask out the magmom loss contribution of data points without magmom labels.

Args:

Returns: dictionary of all the losses, MAE and MAE_size
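The magmom masking idea can be sketched as follows, assuming MSE for every target, loss ratios keyed by the target letters used in targets (for example "e" and "m"), and per-structure magmom labels in which None marks a missing label. combined_loss is hypothetical; it omits forces, stress, other criteria, and the MAE bookkeeping that the real forward returns.

```python
from __future__ import annotations


def mse(pred: list[float], target: list[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def combined_loss(targets: dict, prediction: dict, ratios: dict[str, float]) -> float:
    """Weighted sum of per-target MSE losses; structures without magmom labels are skipped."""
    total = 0.0
    for key, ratio in ratios.items():
        if key == "m":
            # targets["m"] holds one list per structure, or None when the label is missing
            pairs = [(p, t) for p, t in zip(prediction["m"], targets["m"]) if t is not None]
            if not pairs:
                continue  # no magmom labels in this batch, so no magmom loss
            pred = [v for p, _ in pairs for v in p]
            true = [v for _, t in pairs for v in t]
        else:
            pred, true = prediction[key], targets[key]
        total += ratio * mse(pred, true)
    return total


targets = {"e": [-3.0, -2.0], "m": [[0.5, 0.0], None]}
prediction = {"e": [-2.9, -2.0], "m": [[0.4, 0.0], [0.8, 0.1]]}
print(combined_loss(targets, prediction, ratios={"e": 1.0, "m": 0.1}))  # approximately 0.0055
```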

module utils.common_utils


function determine_device

determine_device(
    use_device: 'str | None' = None,
    check_cuda_mem: 'bool' = False
) → str

Determine the device to use for torch model.

Args:

Returns:


function cuda_devices_sorted_by_free_mem

cuda_devices_sorted_by_free_mem() → list[int]

List available CUDA devices sorted by increasing available memory.

To get the device with the most free memory, use the last list item.


function mae

mae(prediction: 'Tensor', target: 'Tensor') → Tensor

Computes the mean absolute error between prediction and target.

Args:

Returns: tensor


function read_json

read_json(filepath: 'str') → dict

Read a JSON file.

Args:

Returns:


function write_json

write_json(dct: 'dict', filepath: 'str') → dict

Write a JSON file.

Args:

Returns: written dictionary


function mkdir

mkdir(path: 'str') → str

Make directory.

Args:

Returns: path


class AverageMeter

Computes and stores the average and current value.

method __init__

__init__() → None

Initialize the meter.


method reset

reset() → None

Reset the meter value, average, sum and count to 0.


method update

update(val: 'float', n: 'int' = 1) → None

Update the meter value, average, sum and count.

Args:
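This follows the familiar running-average meter pattern from PyTorch training scripts. A sketch under a hypothetical class name, assuming n is the number of samples that val averages over (so val is weighted by n); it is not copied from the library.

```python
class RunningAverage:
    """Track the latest value together with a running sum, count and average."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        # val is treated as an average over n samples, so it is weighted by n
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = RunningAverage()
meter.update(2.0, n=2)  # e.g. a batch of 2 with mean loss 2.0
meter.update(5.0)       # a batch of 1 with loss 5.0
print(meter.avg)        # 3.0
```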

module utils.vasp_utils


function parse_vasp_dir

parse_vasp_dir(
    base_dir: 'str',
    check_electronic_convergence: 'bool' = True,
    save_path: 'str | None' = None
) → dict[str, list]

Parse VASP output files into structures and labels. By default, the magnetization is read from mag_x in VASP; please modify the code if the magnetization is along y or z.

Args:


function solve_charge_by_mag

solve_charge_by_mag(
    structure: 'Structure',
    default_ox: 'dict[str, float] | None' = None,
    ox_ranges: 'dict[str, dict[tuple[float, float], int]] | None' = None
) → Structure | None

Solve oxidation states by magmom.

Args: