Skip to content

Commit

Permalink
Fix typing for DGL 2.0 (#79)
Browse files Browse the repository at this point in the history
* use dglteam dgl

* set node type

* fix astype

* type edges

* update changelog
  • Loading branch information
lilyminium authored Feb 9, 2024
1 parent fd2d78f commit 869993d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ The rules for this file:

### Authors
<!-- GitHub usernames of contributors to this release -->
- lilyminium

### Reviewers

### Added

### Fixed
<!-- Bug fixes -->
- Fixed node and edge typing, adding DGL 2.0 compatibility (Issue #78, PR #79)

### Changed
<!-- Changes in existing functionality -->
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: openff-nagl-test
channels:
- openeye
- dglteam
- conda-forge
- defaults
dependencies:
Expand Down
14 changes: 7 additions & 7 deletions openff/nagl/molecule/_dgl/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def from_openff(
for offmol in offmols
]
graph = dgl.batch(subgraphs)
graph.set_batch_num_nodes(graph.batch_num_nodes().sum().reshape((-1,)))
graph.set_batch_num_edges(
{
e_type: graph.batch_num_edges(e_type).sum().reshape((-1,))
for e_type in graph.canonical_etypes
}
)
n_nodes = graph.batch_num_nodes().sum().reshape((-1,))
graph.set_batch_num_nodes(n_nodes.type(torch.int32))
edges = {}
for e_type in graph.canonical_etypes:
n_edge = graph.batch_num_edges(e_type).sum().reshape((-1,))
edges[e_type] = n_edge.type(torch.int32)
graph.set_batch_num_edges(edges)

mapped_smiles = offmols[0].to_smiles(mapped=True)

Expand Down

0 comments on commit 869993d

Please sign in to comment.