Graph Neural Network (GNN)
Because molecules map naturally onto (heterogeneous) graphs, GNNs have been used in recent years to predict molecular properties. As an introduction to GNNs, this article skips the mathematical details and only introduces the concepts behind the algorithm.
In short, a GNN updates the features of each node (atom) in the graph (molecule) from its neighbors, which mimics how electronic interactions propagate through the molecule. This makes GNN models particularly effective for predicting quantum-chemical properties and reaction outcomes.
There are many types of GNN, but they all share the same four main steps:
1. Initializing Node Feature
2. Node Feature Embedding and Updating (Main GNN algorithm)
3. Readout
4. Prediction
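To make the four steps concrete before bringing in any library, here is a minimal sketch in plain Python on a toy three-atom chain. Everything in it is an assumption for illustration: the initial features, the fixed weights and the function names are made up, and the update step simply adds neighbor features without the learned weights a real GNN layer would apply.

```python
import math

# Toy "molecule": 3 atoms in a chain (0-1-2), stored as neighbor lists.
neighbors = {0: [1], 1: [0, 2], 2: [1]}

# Step 1: initial node features (two made-up numbers per atom).
feats = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 0.0]}

def update(feats, neighbors):
    """Step 2: replace each atom's vector with itself plus the sum of its neighbors."""
    new_feats = {}
    for atom, vec in feats.items():
        agg = list(vec)
        for nb in neighbors[atom]:
            agg = [a + b for a, b in zip(agg, feats[nb])]
        new_feats[atom] = agg
    return new_feats

def readout(feats):
    """Step 3: sum all atom vectors into one molecule vector."""
    dim = len(next(iter(feats.values())))
    return [sum(vec[i] for vec in feats.values()) for i in range(dim)]

def predict(mol_vec, weights, bias):
    """Step 4: linear layer followed by a sigmoid."""
    z = sum(w * x for w, x in zip(weights, mol_vec)) + bias
    return 1.0 / (1.0 + math.exp(-z))

updated = update(feats, neighbors)
mol_vec = readout(updated)
prob = predict(mol_vec, weights=[0.1, -0.2], bias=0.0)  # hypothetical weights
```

After one update the middle atom carries information from both ends of the chain, which is the whole point of step 2. Stacking more updates lets information travel further along the bonds.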
For convenience, I use DGL-LifeSci to perform all of these steps.
To install dgl-lifesci, run
conda install -c dglteam dgllife
Initializing Node Feature
To start GNN training, you first need to initialize the atom (node) and bond (edge) features. This is usually done by one-hot encoding properties such as atomic number and bond type with a featurizer function, which you can also define yourself.
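If one-hot encoding is new to you, the idea can be sketched in a few lines of plain Python. The vocabularies and the featurize_atom helper below are hypothetical and chosen only for this example; they are not what any DGL-LifeSci featurizer actually uses.

```python
def one_hot(value, choices):
    """Return a list with 1.0 at the position of `value` in `choices`, else 0.0."""
    return [1.0 if value == c else 0.0 for c in choices]

# Hypothetical vocabularies, chosen only for this example.
ATOM_TYPES = ['C', 'N', 'O', 'F']
BOND_TYPES = ['single', 'double', 'triple', 'aromatic']

def featurize_atom(symbol, num_hydrogens):
    # Concatenate a one-hot atom type with a one-hot hydrogen count (0 to 4).
    return one_hot(symbol, ATOM_TYPES) + one_hot(num_hydrogens, [0, 1, 2, 3, 4])

carbonyl_oxygen = featurize_atom('O', 0)
double_bond = one_hot('double', BOND_TYPES)
```

Each property contributes one block of the final vector, so the feature length is just the sum of the vocabulary sizes (9 for the atom vector here).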
I chose WeaveAtomFeaturizer for atoms and CanonicalBondFeaturizer for bonds.
import torch
import dgl
from dgllife.utils import smiles_to_bigraph, WeaveAtomFeaturizer, CanonicalBondFeaturizer

def GraphFromSmiles(smiles):
    # Featurizers turn each atom and bond into a fixed-length feature vector.
    node_featurizer = WeaveAtomFeaturizer()
    edge_featurizer = CanonicalBondFeaturizer()
    # Build a bidirected DGL graph with the features attached to nodes and edges.
    return smiles_to_bigraph(smiles, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer)

smiles = 'O=C1NC(=O)CCC1N3C(=O)c2cccc(c2C3)N'  # Lenalidomide
graph = GraphFromSmiles(smiles)
If you print the variable graph, it should show
Graph(num_nodes=19, num_edges=42,
      ndata_schemes={'h': Scheme(shape=(27,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(12,), dtype=torch.float32)})
This means the molecule has 19 atoms and 42 directed edges: each of its 21 bonds is stored once in each direction. Each atom has a feature vector with 27 features and each bond has a feature vector with 12 features.
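The edge count is twice the number of chemical bonds because a bidirected graph stores every bond as two directed edges (21 bonds give 42 edges for Lenalidomide). A minimal sketch of that conversion in plain Python, using a made-up helper name and a CO2-like toy molecule rather than DGL-LifeSci's own code:

```python
def bonds_to_directed_edges(bonds):
    """Turn each undirected bond (u, v) into two directed edges u->v and v->u."""
    src, dst = [], []
    for u, v in bonds:
        src += [u, v]
        dst += [v, u]
    return src, dst

# O=C=O: atom 0 (O), atom 1 (C), atom 2 (O), with two bonds.
bonds = [(0, 1), (1, 2)]
src, dst = bonds_to_directed_edges(bonds)
num_edges = len(src)
```

Storing both directions lets information flow both ways along a bond when the atoms are updated.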
Node Feature Embedding and Updating
This is the core of a GNN: each atom is embedded by a linear layer and then updated from its environment (neighboring atoms and bonds). There are many types of GNN, depending on the neural network used to update the atoms:
If you use a convolution algorithm, it is called a Graph Convolutional Network (GCN).
If you use an attention mechanism, it is called a Graph Attention Network (GAT).
If you combine convolution-style message passing with a recurrent update, it is called a Message Passing Neural Network (MPNN).
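To see how the first two variants differ, here is a rough sketch in plain Python. It is a simplification under stated assumptions: the convolution-style version averages neighbors equally, and the attention-style version scores neighbors with a plain dot product, which only stands in for the learned scoring function that GAT uses. All function names are made up for this example.

```python
import math

def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mean_aggregate(center, neighbor_feats):
    """Convolution-style: every neighbor contributes equally."""
    n = len(neighbor_feats)
    return [sum(vec[i] for vec in neighbor_feats) / n for i in range(len(center))]

def attention_aggregate(center, neighbor_feats):
    """Attention-style: neighbors are weighted by a softmax over similarity scores."""
    # Dot-product score between the center atom and each neighbor (a stand-in).
    scores = [sum(c * x for c, x in zip(center, vec)) for vec in neighbor_feats]
    weights = softmax(scores)
    agg = [sum(w * vec[i] for w, vec in zip(weights, neighbor_feats))
           for i in range(len(center))]
    return agg, weights

center = [1.0, 0.0]
neighbor_feats = [[1.0, 0.0], [0.0, 1.0]]
uniform = mean_aggregate(center, neighbor_feats)
attended, weights = attention_aggregate(center, neighbor_feats)
```

With attention, the neighbor that looks more like the center atom gets the larger weight, while the averaging version treats both neighbors the same.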
I will use a GCN in this tutorial, but you can switch to another network if you wish.
Because graph convolution is already implemented in DGL, we only need to import the GCN layers and connect them.
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(GCN, self).__init__()
        # Stack one GraphConv layer per entry in hidden_feats.
        self.gnn_layers = nn.ModuleList()
        for out_feats in hidden_feats:
            # ReLU after each layer; without a nonlinearity the stack
            # collapses into a single linear map.
            self.gnn_layers.append(GraphConv(in_feats, out_feats, activation=F.relu))
            in_feats = out_feats

    def forward(self, g, feats):
        # Each layer updates every atom from its bonded neighbors.
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats
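For readers curious what a single graph-convolution layer does to the atom features, here is a hedged sketch in plain Python. It assumes the symmetric normalization of the standard GCN formulation (each message is scaled by one over the square root of the two node degrees), a bidirected edge list, no bias term, and a fixed weight matrix in place of a learned one. The function name graph_conv is made up, and this is not DGL's implementation.

```python
import math

def graph_conv(feats, src, dst, weight):
    """One graph-convolution step: normalized neighbor sum, then a linear map."""
    num_nodes = len(feats)
    in_dim, out_dim = len(weight), len(weight[0])
    # Degree of each node, counted from the incoming directed edges.
    # In a bidirected graph this equals the number of bonded neighbors.
    degree = [0] * num_nodes
    for v in dst:
        degree[v] += 1
    # Sum neighbor features, scaling each message by 1/sqrt(deg(u) * deg(v)).
    agg = [[0.0] * in_dim for _ in range(num_nodes)]
    for u, v in zip(src, dst):
        norm = 1.0 / math.sqrt(degree[u] * degree[v])
        for k in range(in_dim):
            agg[v][k] += norm * feats[u][k]
    # Apply the weight matrix (fixed here, learned in a real layer).
    out = []
    for n in range(num_nodes):
        out.append([sum(agg[n][k] * weight[k][j] for k in range(in_dim))
                    for j in range(out_dim)])
    return out

# Chain 0-1-2 with both directions of each bond listed.
src, dst = [0, 1, 1, 2], [1, 0, 2, 1]
feats = [[1.0], [2.0], [3.0]]
weight = [[1.0, -1.0]]  # maps 1 input feature to 2 output features
new_feats = graph_conv(feats, src, dst, weight)
```

Note that in this form an atom's new features come only from its neighbors; its own features enter only if the graph contains self-loops.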
Readout
The readout function summarizes the updated atom features into a single feature vector for the molecule. You can use a recurrent network, a convolutional network, or Set2Set for the readout, but I found that many people simply use a sum or an average and still get good results.
class Readout(nn.Module):
    def __init__(self, mode):
        super(Readout, self).__init__()
        self.mode = mode

    def forward(self, feats):
        # feats has shape (num_atoms, hidden_size) for a single molecule;
        # reducing over dim 0 gives one vector of shape (hidden_size,).
        if self.mode == 'mean':
            feats = torch.mean(feats, 0)
        elif self.mode == 'sum':
            feats = torch.sum(feats, 0)
        else:
            raise ValueError("mode must be 'mean' or 'sum'")
        return feats
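The choice between sum and mean matters more than it looks: a sum grows with the number of atoms, while a mean does not. A small sketch in plain Python with made-up feature values (the function mirrors the idea of the class above but is not the same code):

```python
def readout(atom_feats, mode):
    """Collapse a list of per-atom vectors into one molecule vector."""
    dim = len(atom_feats[0])
    total = [sum(vec[i] for vec in atom_feats) for i in range(dim)]
    if mode == 'sum':
        return total
    if mode == 'mean':
        return [t / len(atom_feats) for t in total]
    raise ValueError(f"unknown readout mode: {mode}")

small = [[1.0, 2.0]] * 2  # 2 atoms with identical features
large = [[1.0, 2.0]] * 6  # 6 atoms with the same features

sum_small, sum_large = readout(small, 'sum'), readout(large, 'sum')
mean_small, mean_large = readout(small, 'mean'), readout(large, 'mean')
```

Here the sum tells the two molecules apart by size, while the mean gives the same vector for both. Which behavior you want depends on whether the target property scales with molecule size.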
Prediction
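Lastly, a single linear layer maps the readout features to your data label. The sigmoid squashes the output into the range (0, 1), which suits binary labels.
class Predictor(nn.Module):
    def __init__(self, in_feats):
        super(Predictor, self).__init__()
        # One linear layer from the molecule vector to a single score.
        self.predictor = nn.Linear(in_feats, 1)

    def forward(self, feats):
        feats = self.predictor(feats)
        # Convert the raw score into a value between 0 and 1.
        return torch.sigmoid(feats)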
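A linear layer followed by a sigmoid is logistic regression on the molecule vector, so fitting it to labels can be sketched without any framework. The example below assumes a binary 0/1 label and uses binary cross-entropy with one plain gradient-descent step; the molecule vector, learning rate and function names are all made up for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict(mol_vec, weights, bias):
    """Linear layer plus sigmoid, as in the Predictor above."""
    return sigmoid(sum(w * x for w, x in zip(weights, mol_vec)) + bias)

def bce_loss(prob, label):
    """Binary cross-entropy for one example with a 0/1 label."""
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))

def sgd_step(mol_vec, label, weights, bias, lr):
    """One gradient-descent step; for sigmoid + BCE the gradient
    with respect to the pre-sigmoid score is (prob - label)."""
    error = predict(mol_vec, weights, bias) - label
    new_weights = [w - lr * error * x for w, x in zip(weights, mol_vec)]
    return new_weights, bias - lr * error

mol_vec, label = [0.5, -1.0, 2.0], 1  # hypothetical readout vector and label
weights, bias = [0.0, 0.0, 0.0], 0.0
loss_before = bce_loss(predict(mol_vec, weights, bias), label)
weights, bias = sgd_step(mol_vec, label, weights, bias, lr=0.1)
loss_after = bce_loss(predict(mol_vec, weights, bias), label)
```

In a real model the same gradient also flows back through the readout and the GCN layers, which is what the PyTorch optimizer handles for you.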
GNN Network
Now that we have defined all the components, put them together to make an end-to-end model!
class GNN(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(GNN, self).__init__()
        self.gcn = GCN(in_feats, hidden_feats)
        self.readout = Readout('sum')
        self.predictor = Predictor(hidden_feats[-1])

    def forward(self, graph, feats):
        feats = self.gcn(graph, feats)      # update atom features
        feats = self.readout(feats)         # atoms -> one molecule vector
        output = self.predictor(feats)      # molecule vector -> prediction
        return output

hidden_size = [256, 64, 32]
# Take the atom features out of the graph; the last dimension is the input size.
node_feats = graph.ndata.pop('h')
model = GNN(node_feats.shape[-1], hidden_size)
output = model(graph, node_feats)
If you print the output, you will get something like the following (the exact value varies because the weights are randomly initialized)
tensor([0.7552], grad_fn=<SigmoidBackward>)
Get your data ready and train your GNN model!
(See the PyTorch tutorials if you do not know how to train a PyTorch model.)