Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AtomRef Updates #158

Merged
17 commits merged on Sep 1, 2023
Merged
46 changes: 24 additions & 22 deletions matgl/layers/_atom_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,53 @@
import torch
from torch import nn

import matgl


class AtomRef(nn.Module):
"""Get total property offset for a system."""

def __init__(
self,
property_offset: np.array, # type: ignore
) -> None:
def __init__(self, property_offset: torch.Tensor | None = None, max_z: int = 89) -> None:
"""
Args:
property_offset (np.array): a array of elemental property offset.
property_offset (Tensor): a tensor containing the property offset for each element
if given max_z is ignored, and the size of the tensor is used instead
max_z (int): maximum atomic number.
"""
super().__init__()
self.property_offset = torch.tensor(property_offset)
self.max_z = self.property_offset.size(dim=0)
if property_offset is None:
property_offset = torch.zeros(max_z, dtype=matgl.float_th)
elif isinstance(property_offset, (np.ndarray, list)): # for backward compatibility of saved models
property_offset = torch.tensor(property_offset, dtype=matgl.float_th)

self.max_z = property_offset.shape[-1]
self.register_buffer("property_offset", property_offset)
self.register_buffer("onehot", torch.eye(self.max_z))

def get_feature_matrix(self, graphs: list) -> np.typing.NDArray:
def get_feature_matrix(self, graphs: list[dgl.DGLGraph]) -> torch.Tensor:
"""Get the number of atoms for different elements in the structure.

Args:
graphs (list): a list of dgl graph

Returns:
features (np.array): a matrix (num_structures, num_elements)
features (torch.Tensor): a matrix (num_structures, num_elements)
"""
n = len(graphs)
features = np.zeros(shape=(n, self.max_z))
for i, s in enumerate(graphs):
atomic_numbers = s.ndata["node_type"].numpy().tolist()
features[i] = np.bincount(atomic_numbers, minlength=self.max_z)
features = torch.zeros(len(graphs), self.max_z, dtype=matgl.float_th)
for i, graph in enumerate(graphs):
atomic_numbers = graph.ndata["node_type"]
features[i] = torch.bincount(atomic_numbers, minlength=self.max_z)
return features

def fit(self, graphs: list, properties: np.typing.NDArray) -> None:
def fit(self, graphs: list[dgl.DGLGraph], properties: torch.Tensor) -> None:
"""Fit the elemental reference values for the properties.

Args:
graphs: dgl graphs
properties (np.ndarray): array of extensive properties
properties (torch.Tensor): tensor of extensive properties
"""
features = self.get_feature_matrix(graphs)
self.property_offset = np.linalg.pinv(features.T.dot(features)).dot(features.T.dot(properties))
self.property_offset = torch.tensor(self.property_offset)
self.property_offset = torch.linalg.lstsq(features, properties).solution

def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None):
"""Get the total property offset for a system.
Expand All @@ -58,10 +63,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None):
Returns:
offset_per_graph
"""
num_elements = (
self.property_offset.size(dim=1) if self.property_offset.ndim > 1 else self.property_offset.size(dim=0)
)
one_hot = torch.eye(num_elements)[g.ndata["node_type"]]
one_hot = self.onehot[g.ndata["node_type"]]
if self.property_offset.ndim > 1:
offset_batched_with_state = []
for i in range(self.property_offset.size(dim=0)):
Expand Down
6 changes: 3 additions & 3 deletions matgl/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def load(cls, path: str | Path | dict, **kwargs):
d[k] = cls_(**v["init_args"])
d = {k: v for k, v in d.items() if not k.startswith("@")}
model = cls(**d)
model.load_state_dict(state) # type: ignore
model.load_state_dict(state, strict=False) # type: ignore

return model

Expand Down Expand Up @@ -209,11 +209,11 @@ def load_model(path: Path, **kwargs):
mod = __import__(modname, globals(), locals(), [classname], 0)
cls_ = getattr(mod, classname)
return cls_.load(fpaths, **kwargs)
except BaseException:
except BaseException as err:
raise ValueError(
"Bad serialized model or bad model name. It is possible that you have an older model cached. Please "
'clear your cache by running `python -c "import matgl; matgl.clear_cache()"`'
) from None
) from err


def _get_file_paths(path: Path, **kwargs):
Expand Down
7 changes: 4 additions & 3 deletions tests/layers/test_atom_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,24 @@
class TestAtomRef:
def test_atom_ref(self, graph_MoSH):
_, g1, _ = graph_MoSH
element_ref = AtomRef(np.array([0.5, 1.0, 2.0]))
element_ref = AtomRef(torch.tensor([0.5, 1.0, 2.0]))

atom_ref = element_ref(g1)
assert atom_ref == 3.5

def test_atom_ref_fit(self, graph_MoSH):
_, g1, _ = graph_MoSH
element_ref = AtomRef(np.array([0.5, 1.0, 2.0]))
element_ref = AtomRef(torch.tensor([0.5, 1.0, 2.0]))
properties = torch.tensor([2.0, 2.0])
bg = dgl.batch([g1, g1])
element_ref.fit([g1, g1], properties)

atom_ref = element_ref(bg)
assert list(np.round(atom_ref.numpy())) == [2.0, 2.0]

def test_atom_ref_with_states(self, graph_MoSH):
_, g1, _ = graph_MoSH
element_ref = AtomRef(np.array([[0.5, 1.0, 2.0], [2.0, 3.0, 5.0]]))
element_ref = AtomRef(torch.tensor([[0.5, 1.0, 2.0], [2.0, 3.0, 5.0]]))
state_label = torch.tensor([1])
atom_ref = element_ref(g1, state_label)
assert atom_ref == 10