# Source code for espaloma.graphs.utils.regenerate_impropers

import dgl
import numpy as np
import torch

from .offmol_indices import improper_torsion_indices
from ..graph import Graph


def regenerate_impropers(g: Graph, improper_def="smirnoff"):
    """
    Regenerate the improper-torsion nodes of an esp.Graph.

    All existing ``n4_improper`` nodes are removed from ``g.heterograph``.
    They are replaced by one node per improper permutation returned by
    ``improper_torsion_indices(g.mol, improper_def)``. Each new node is
    connected in both directions to its four ``n1`` atom nodes and to graph
    node 0 of type ``g``. Its four atom indices are stored in the ``"idxs"``
    node feature. ``g.heterograph`` is replaced with the rebuilt graph on
    every return path, including when no impropers are generated.

    NOTE: This will clear the data on all n4_improper nodes, including
    previously generated improper from JanossyPoolingImproper.

    Parameters
    ----------
    g : Graph
        Graph whose ``heterograph`` is rebuilt; ``g.mol`` is passed to
        ``improper_torsion_indices``.
    improper_def : str, default "smirnoff"
        Improper-permutation scheme forwarded to
        ``improper_torsion_indices``.

    Returns
    -------
    Graph
        The same ``g`` object, with ``g.heterograph`` updated.
    """
    ## First get rid of the old nodes/edges
    # dgl.remove_nodes is out-of-place, so the result must be stored back
    # on `g` before every return.
    hg = g.heterograph
    hg = dgl.remove_nodes(hg, hg.nodes("n4_improper"), "n4_improper")

    ## Generate new improper torsion permutations
    idxs = improper_torsion_indices(g.mol, improper_def)
    if len(idxs) == 0:
        # Still commit the cleared graph so no stale impropers survive.
        g.heterograph = hg
        return g

    ## Add new nodes of type n4_improper (one for each permutation)
    n_impropers = idxs.shape[0]
    hg = dgl.add_nodes(hg, n_impropers, ntype="n4_improper")

    ## New edges between improper permutations and n1 nodes
    permut_ids = np.arange(n_impropers)
    for i in range(4):
        n1_ids = idxs[:, i]
        # edge from improper node to its i-th n1 node
        hg = dgl.add_edges(
            hg,
            permut_ids,
            n1_ids,
            etype=("n4_improper", f"n4_improper_has_{i}_n1", "n1"),
        )
        # edge from n1 node back to the improper
        hg = dgl.add_edges(
            hg,
            n1_ids,
            permut_ids,
            etype=("n1", f"n1_as_{i}_in_n4_improper", "n4_improper"),
        )

    ## New edges between improper permutations and the graph node
    ## (used for global pooling); every improper attaches to node 0.
    graph_ids = np.zeros_like(permut_ids)
    hg = dgl.add_edges(
        hg,
        permut_ids,
        graph_ids,
        etype=("n4_improper", "n4_improper_in_g", "g"),
    )
    hg = dgl.add_edges(
        hg,
        graph_ids,
        permut_ids,
        etype=("g", "g_has_n4_improper", "n4_improper"),
    )

    hg.nodes["n4_improper"].data["idxs"] = torch.tensor(idxs)
    g.heterograph = hg
    return g