Now Available on Amazon SageMaker: The Deep Graph Library
In recent years, deep learning has taken the world by storm thanks to its uncanny ability to extract elaborate patterns from complex data, such as free-form text, images, or videos. However, lots of datasets don’t fit these categories and are better expressed with graphs. Intuitively, we can feel that traditional neural network architectures like convolutional neural networks or recurrent neural networks are not a good fit for such datasets, and a new approach is required.
A Primer On Graph Neural Networks
Graph neural networks (GNNs) are one of the most exciting developments in machine learning today, and these reference papers will get you started.
GNNs are used to train predictive models on datasets such as:
- Social networks, where graphs show connections between related people,
- Recommender systems, where graphs show interactions between customers and items,
- Chemical analysis, where compounds are modeled as graphs of atoms and bonds,
- Cybersecurity, where graphs describe connections between source and destination IP addresses,
- And more!
Most of the time, these datasets are extremely large and only partially labeled. Consider a fraud detection scenario where we would try to predict the likelihood that an individual is a fraudulent actor by analyzing their connections to known fraudsters. This problem could be defined as a semi-supervised learning task, where only a fraction of graph nodes would be labeled (‘fraudster’ or ‘legitimate’). This should be a better solution than trying to build a large hand-labeled dataset and “linearizing” it to apply traditional machine learning algorithms.
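To make the semi-supervised setting concrete, here is a minimal sketch in plain Python. It does not use a GNN: it swaps in classical label propagation, a much simpler technique, only to show how a few known labels can inform scores for unlabeled neighbors. The toy graph, the node names, and the `propagate` function are all made up for illustration.

```python
# Hypothetical toy graph: each account maps to the accounts it is connected to.
graph = {
    "a": ["b", "c"],
    "b": ["a", "c"],
    "c": ["a", "b", "d"],
    "d": ["c", "e"],
    "e": ["d"],
}

# Only a fraction of nodes are labeled: 1.0 = fraudster, 0.0 = legitimate.
known = {"a": 1.0, "e": 0.0}


def propagate(graph, known, n_iter=100):
    """Label propagation: unlabeled nodes repeatedly take the mean score of
    their neighbors, while labeled nodes stay clamped to their known label."""
    # Unlabeled nodes start at 0.5 (no evidence either way).
    scores = {node: known.get(node, 0.5) for node in graph}
    for _ in range(n_iter):
        updated = {}
        for node, neighbors in graph.items():
            if node in known:
                updated[node] = known[node]
            else:
                updated[node] = sum(scores[n] for n in neighbors) / len(neighbors)
        scores = updated
    return scores


scores = propagate(graph, known)
for node in sorted(scores):
    print(node, round(scores[node], 3))
```

Nodes closer to the known fraudster end up with higher scores than nodes closer to the known legitimate account. A GNN goes further by learning from node and edge features as well, which this sketch ignores.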
Working on these problems requires domain knowledge (retail, finance, chemistry, etc.), computer science knowledge (Python, deep learning, open source tools), and infrastructure knowledge (training, deploying, and scaling models). Very few people master all these skills, which is why tools like the Deep Graph Library and Amazon SageMaker are needed.
Introducing The Deep Graph Library
First released on GitHub in December 2018, the Deep Graph Library (DGL) is an open source Python library that helps researchers and scientists quickly build, train, and evaluate GNNs on their datasets.
DGL is built on top of popular deep learning frameworks like PyTorch and Apache MXNet. If you know either one of these, you’ll find yourself quite at home. No matter which framework you use, you can get started easily thanks to these beginner-friendly examples. I also found the slides and code for the GTC 2019 workshop very useful.
Once you’re done with toy examples, you can start exploring the collection of cutting-edge models already implemented in DGL. For example, you can train a document classification model using a Graph Convolutional Network (GCN) and the Cora dataset by simply running:
$ python3 train.py --dataset cora --gpu 0 --self-loop
The code for all models is available for inspection and tweaking. These implementations have been carefully validated by AWS teams, who verified performance claims and made sure results could be reproduced.
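To give a feel for what a graph convolution computes, here is a minimal pure-Python sketch of a single GCN layer using the commonly cited propagation rule ReLU(D^-1/2 (A + I) D^-1/2 X W). It is not DGL's implementation: the function name `gcn_layer` and the toy graph are made up for illustration, and relating the added identity matrix to the `--self-loop` flag above is an assumption.

```python
import math


def gcn_layer(adj, feats, weight):
    """One graph convolution: ReLU(D^-1/2 (A + I) D^-1/2 X W).

    adj:    n x n adjacency matrix (list of lists of 0/1)
    feats:  n x f node feature matrix X
    weight: f x o weight matrix W
    """
    n = len(adj)
    n_in, n_out = len(weight), len(weight[0])
    # Add self-loops so every node also keeps its own features.
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
             for i in range(n)]
    # Node degrees, counted after adding self-loops.
    deg = [sum(row) for row in a_hat]
    # Symmetric normalization of the adjacency matrix.
    norm = [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
            for i in range(n)]
    # Aggregate neighbor features: norm @ feats.
    agg = [[sum(norm[i][k] * feats[k][f] for k in range(n))
            for f in range(n_in)] for i in range(n)]
    # Linear transform followed by ReLU: relu(agg @ weight).
    return [[max(0.0, sum(agg[i][f] * weight[f][o] for f in range(n_in)))
             for o in range(n_out)] for i in range(n)]


# Hypothetical 3-node path graph 0 - 1 - 2 with one feature per node.
adj = [[0, 1, 0],
       [1, 0, 1],
       [0, 1, 0]]
feats = [[1.0], [0.0], [0.0]]
weight = [[1.0]]

out = gcn_layer(adj, feats, weight)
print(out)
```

After one layer, the feature that started on node 0 has spread to its direct neighbor (node 1) but not yet to node 2. Stacking layers lets information travel further across the graph.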
DGL also includes a collection of graph datasets that you can easily download and experiment with.
Of course, you can install and run DGL locally, but to make your life simpler, we added it to the Deep Learning Containers for PyTorch and Apache MXNet. This makes it easy to use DGL on Amazon SageMaker, in order to train and deploy models at any scale, without having to manage a single server. Let me show you how.
The problem we’re trying to solve is figuring out the potential toxicity of new chemical compounds with respect to 12 different targets (receptors inside biological cells, etc.). As you can imagine, this type of analysis is crucial when designing new drugs, and being able to quickly predict results without having to run in vitro experiments helps researchers focus their efforts on the most promising drug candidates.
The dataset contains a little over 8,000 compounds: each one is modeled as a graph (atoms are vertices, atomic bonds are edges), and labeled 12 times (one label per target). Using a GNN, we’re going to build a multi-label binary classification model, allowing us to predict the potential toxicity of candidate molecules.
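As a rough illustration of this setup, the sketch below represents one molecule as a graph with 12 binary labels and computes a multi-label binary cross-entropy loss from raw model outputs (logits). The record layout, the `mask` field for missing labels, and the function name `masked_bce_with_logits` are assumptions made for this example, not the DGL data format or the loss used by the training script.

```python
import math

# Hypothetical record: a toy three-atom molecule plus one label per target.
molecule = {
    "atoms": ["O", "H", "H"],   # vertices
    "bonds": [(0, 1), (0, 2)],  # edges, as pairs of atom indices
    "labels": [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0],  # 12 binary targets
    "mask": [1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1],    # 0 = no label (assumption)
}


def masked_bce_with_logits(logits, labels, mask):
    """Mean binary cross-entropy over the targets that actually have a label.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)).
    """
    total, count = 0.0, 0
    for z, y, m in zip(logits, labels, mask):
        if not m:
            continue  # skip targets without a label
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        count += 1
    return total / count if count else 0.0


# An untrained model that outputs 0 for every target predicts p = 0.5.
logits = [0.0] * 12
loss = masked_bce_with_logits(logits, molecule["labels"], molecule["mask"])
print(round(loss, 4))
```

Each of the 12 targets is treated as its own yes/no question, which is what makes this a multi-label problem rather than a 12-way multi-class one.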
In the training script, we can easily download the dataset from the DGL collection.
from dgl.data.chem import Tox21
dataset = Tox21()
Similarly, we can easily build a GNN classifier using the DGL model zoo.
from dgl import model_zoo
model = model_zoo.chem.GCNClassifier(
gcn_hidden_feats=[args['n_hidden'] for _ in range(args['n_layers'])],
The rest of the code is mostly vanilla PyTorch, and you should be able to find your bearings if you’re familiar with this library.
When it comes to running this code on Amazon SageMaker, all we have to do is use a SageMaker Estimator, passing the full name of our DGL container and the name of the training script as a hyperparameter.
estimator = sagemaker.estimator.Estimator(container,
code_location = sess.upload_data(CODE_PATH,
epoch 23/100, batch 48/49, loss 0.4684
epoch 23/100, batch 49/49, loss 0.5389
epoch 23/100, training roc-auc 0.9451
EarlyStopping counter: 10 out of 10
epoch 23/100, validation roc-auc 0.8375, best validation roc-auc 0.8495
Best validation score 0.8495
Test score 0.8273
2019-11-21 14:11:03 Uploading - Uploading generated training model
2019-11-21 14:11:03 Completed - Training job completed
Training seconds: 209
Billable seconds: 209
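The `EarlyStopping counter: 10 out of 10` line in the log suggests that training stops once the validation score has failed to improve for 10 consecutive epochs. The sketch below shows one common way to implement such a rule. The class name, the `patience` parameter, and the toy scores are assumptions for illustration, not the code used by the training script.

```python
class EarlyStopping:
    """Stop training when a score has not improved for `patience` epochs."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_score = None
        self.counter = 0

    def step(self, score):
        """Record a validation score; return True when training should stop."""
        if self.best_score is None or score > self.best_score:
            # New best score: remember it and reset the counter.
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Hypothetical validation scores, with a short patience to keep the demo small.
stopper = EarlyStopping(patience=3)
val_scores = [0.80, 0.85, 0.84, 0.83, 0.82, 0.90]

stopped_at = None
for epoch, score in enumerate(val_scores, start=1):
    if stopper.step(score):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_score)
```

Note that the last score (0.90) is never seen: the rule trades a possible late improvement for shorter, cheaper training runs.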
Now, we could grab the trained model in S3 and use it to predict toxicity for a large number of compounds, without having to run actual experiments. Fascinating stuff!