
Graph Neural Network (GNN) for chemistry

Graph Neural Network (GNN)

Because molecules closely resemble (heterogeneous) graphs, graph neural networks (GNNs) have been introduced to predict molecular properties in recent years. As an introduction to GNNs, this article will not go through the mathematical details but will only introduce the concepts behind the GNN algorithm.

To make a long story short, the purpose of a GNN is to update the node (atom) features in the graph (molecule) to simulate the electron interaction effects in the molecule. GNN models are therefore particularly effective for quantum chemistry and reaction prediction.

There are many types of GNNs, but the four main steps are the same, namely
1. Initializing Node Features
2. Node Feature Embedding and Updating (Main GNN algorithm)
3. Readout
4. Prediction

(Figure: the four steps of a GNN)
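To make the update idea concrete before introducing any library code, here is a minimal sketch of what a single feature-update step conceptually does, written in plain PyTorch with a dense adjacency matrix. This is not the DGL code used below, and message_passing_step is a hypothetical helper just for illustration.

import torch

def message_passing_step(adj, feats, weight):
    # adj: (N, N) adjacency matrix, feats: (N, D) atom features,
    # weight: (D, D) learnable matrix
    messages = adj @ feats  # sum each atom's neighbor features
    return torch.relu((feats + messages) @ weight)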

For convenience, I use DGL-LifeSci to perform all of these steps.
To install dgl-lifesci, run

conda install -c dglteam dgllife
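If you prefer pip, both packages are also available on PyPI (this installs the CPU build of dgl):

pip install dgl dgllife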

Initializing Node Features

To start the GNN training, you need to one-hot encode all the atom and bond features with featurization functions of your choice, based on properties such as atomic number and bond type.
Here I chose WeaveAtomFeaturizer for atoms and CanonicalBondFeaturizer for bonds.

import torch
import dgl
from dgllife.utils import smiles_to_bigraph, WeaveAtomFeaturizer, CanonicalBondFeaturizer

def GraphFromSmiles(smiles):
    # Featurize each atom and bond, then build a bidirectional DGL graph
    node_featurizer = WeaveAtomFeaturizer()
    edge_featurizer = CanonicalBondFeaturizer()
    return smiles_to_bigraph(smiles, node_featurizer=node_featurizer,
                             edge_featurizer=edge_featurizer)

smiles = 'O=C1NC(=O)CCC1N3C(=O)c2cccc(c2C3)N'  # Lenalidomide
graph = GraphFromSmiles(smiles)

If you print the variable graph, it should show
Graph(num_nodes=19, num_edges=42,
      ndata_schemes={'h': Scheme(shape=(27,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(12,), dtype=torch.float32)})

This means there are 19 atoms and 42 directed bond edges in the molecule (each of the 21 chemical bonds appears twice, once in each direction), where each atom has a feature vector with 27 features and each bond a feature vector with 12 features.
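You can also inspect the stored feature tensors directly:

print(graph.ndata['h'].shape)  # torch.Size([19, 27]) - atom features
print(graph.edata['e'].shape)  # torch.Size([42, 12]) - bond features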

Node Feature Embedding and Updating

Here is the core part of a GNN: each atom is embedded by a linear layer and updated from its environment (neighboring atoms and bonds) by the GNN algorithm. There are many types of GNNs, depending on the neural network used to update the atoms:
If you use a convolution algorithm, it is called a Graph Convolutional Network (GCN).
If you use an attention mechanism, it is called a Graph Attention Network (GAT).
If you combine convolution with a recurrent update, it is called a Message Passing Neural Network (MPNN).

I will use a GCN in this tutorial; however, you can switch to another network as you wish (see the GAT sketch after the code block below).
Because graph convolution is already defined in the dgl module, we only need to import and stack the GCN layers.

import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(GCN, self).__init__()
        # Stack one graph-convolution layer per hidden size,
        # with a ReLU activation between layers
        self.gnn_layers = nn.ModuleList()
        for i in range(len(hidden_feats)):
            self.gnn_layers.append(GraphConv(in_feats, hidden_feats[i],
                                             activation=F.relu))
            in_feats = hidden_feats[i]

    def forward(self, g, feats):
        # Each layer updates every atom from its neighbors
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats
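As an example of swapping in another layer type, here is a minimal sketch using dgl's GATConv in place of GraphConv, with a single attention head (the in_feats=27 assumes the WeaveAtomFeaturizer features from above):

from dgl.nn.pytorch import GATConv

# One-head graph attention over the same molecular graph;
# GATConv outputs shape (num_nodes, num_heads, out_feats)
gat = GATConv(in_feats=27, out_feats=64, num_heads=1)
out = gat(graph, graph.ndata['h']).squeeze(1)  # drop the head dimension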

Readout

The readout function is used to summarize the feature of the whole molecule from the updated atom features. One can use a recurrent network, a convolutional network, or Set2Set for the readout, but I found that many people just use summation or averaging and still get fine results.

class Readout(nn.Module):
    def __init__(self, mode):
        super(Readout, self).__init__()
        self.mode = mode

    def forward(self, feats):
        # Aggregate all atom features along the node dimension
        if self.mode == 'mean':
            feats = torch.mean(feats, 0)
        elif self.mode == 'sum':
            feats = torch.sum(feats, 0)
        return feats
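Note that dgl also ships built-in readout helpers that work on (batched) graphs. A minimal sketch, assuming the updated features are stored back on the graph first ('h_out' and updated_feats are hypothetical names):

graph.ndata['h_out'] = updated_feats       # (num_nodes, D) updated atom features
mol_feats = dgl.sum_nodes(graph, 'h_out')  # sums nodes per graph in a batch
# dgl.mean_nodes(graph, 'h_out') is the averaging counterpart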

Prediction

Lastly, you can simply add a linear layer to map the readout feature to your data label.

class Predictor(nn.Module):
    def __init__(self, in_feats):
        super(Predictor, self).__init__()
        self.predictor = nn.Linear(in_feats, 1)

    def forward(self, feats):
        feats = self.predictor(feats)
        # Sigmoid maps the output to (0, 1) for binary classification
        return torch.sigmoid(feats)

GNN Network

Now that we have defined all the needed components, let's put them together and make an end-to-end model!

class GNN(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(GNN, self).__init__()
        self.gcn = GCN(in_feats, hidden_feats)
        self.readout = Readout('sum')
        self.predictor = Predictor(hidden_feats[-1])

    def forward(self, graph, feats):
        feats = self.gcn(graph, feats)   # update atom features
        feats = self.readout(feats)      # summarize to a molecule feature
        output = self.predictor(feats)   # map to a label in (0, 1)
        return output

hidden_size = [256, 64, 32]
node_feats = graph.ndata.pop('h')
model = GNN(node_feats.shape[-1], hidden_size)
output = model(graph, node_feats)

If you print your output, you will get something like
tensor([0.7552], grad_fn=<SigmoidBackward>)

Get your data ready and train your model with a GNN!
(See the PyTorch tutorials if you do not know how to train a PyTorch model.)
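For reference, here is a minimal sketch of a training loop over the model built above, assuming you have prepared your own lists of graphs, node-feature tensors, and binary labels (graphs, feats_list, and labels are hypothetical names):

criterion = nn.BCELoss()  # matches the sigmoid output of the model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for g, feats, label in zip(graphs, feats_list, labels):
        optimizer.zero_grad()
        pred = model(g, feats)         # shape (1,), in (0, 1)
        loss = criterion(pred, label)  # label: float tensor of shape (1,)
        loss.backward()
        optimizer.step()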