Did you know that AI was successfully used to discover a novel antibiotic, halicin, in 2020? The discovery was noteworthy because halicin is structurally unique, highly differentiated from conventional antibiotics such as penicillins, and it opened a new direction in countering growing antibiotic resistance (Stokes et al. 2020).
Fig 1. Halicin, a structurally unique and novel antibiotic, discovered using Chemprop
Halicin was discovered using Chemprop (git repo). You can do the same to discover a new drug or a new disease indication for an existing drug (aka drug repurposing) using Chemprop and other open source AI tools on Databricks. This blog shows how highly specialized libraries like Chemprop can be easily integrated into Databricks for drug discovery.
Databricks facilitates production-grade research by providing a unified platform for data processing, model training, and deployment.
Its Unity Catalog for managing data and model assets promotes discovery and collaboration, making it easy to search and reuse large, complex datasets and models.
MLflow, while open source, is a first-class citizen on Databricks, offering both no-code and SDK options for experiment tracking, model registration, serving, and monitoring. MLflow simplifies MLOps so research scientists can focus on developing models and interpreting results.
Chemprop is a suite of AI tools based on a directed message-passing neural network (MPNN), which treats molecules as graphs (atoms as nodes and bonds as edges). The model applies a series of message-passing steps in which it aggregates information from neighboring atoms and bonds to build up an understanding of local chemistry. The resulting learned fingerprint representation is fed into a feed-forward neural network (FFN) that outputs a molecular property such as toxicity or, in halicin's case, the ability to inhibit bacterial growth.
Fig 2. Chemprop treats molecules as graph structures and uses a message-passing neural network (MPNN) to learn feature representations. The MPNN is coupled with a feed-forward neural network (FFN) for property prediction. Source: Heid et al. 2023
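To make the message-passing idea concrete, here is a minimal sketch of a single message-passing update over a molecular graph in plain PyTorch. This illustrates the general technique only and is not Chemprop's actual implementation; all tensor names and dimensions are assumptions.

import torch

def message_passing_step(h_atoms, edge_index, W):
    """One simplified message-passing update: each atom aggregates its
    neighbors' hidden states, then applies a learned transform.

    h_atoms:    (n_atoms, d) hidden state per atom
    edge_index: (2, n_bonds) pairs of (source, target) atom indices
    W:          (d, d) learnable weight matrix
    """
    src, dst = edge_index
    messages = torch.zeros_like(h_atoms)
    messages.index_add_(0, dst, h_atoms[src])  # sum incoming neighbor states
    return torch.relu(h_atoms + messages @ W)  # combine own state with messages

# Toy example: 3 atoms, 2 bonds (stored as 4 directed edges)
h = torch.randn(3, 8)
edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
W = torch.randn(8, 8) * 0.1
h = message_passing_step(h, edges, W)  # in practice, repeat for several steps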
In the high-stakes race to develop much-needed drugs, fast and accurate molecular property prediction by AI is key. Beyond halicin, Chemprop has been used by many pharma researchers to predict drug potency, IR spectra, combination drug synergy, and more. Below are a few example workflows showing how you can reuse existing models or train new ones to predict drug properties and find candidates with desirable profiles.
Example 1: Load existing models for inference, e.g. solubility prediction
Chemprop is built on the PyTorch framework, so you can load existing models from checkpoint files and use them directly.
Example 2: Train a new model for a specific chemical library or property, e.g. a toxicity classifier
Sometimes existing models are inadequate: none may exist for the molecular property of interest, or those that do may not cover the relevant chemical space. In such cases, it is necessary to train a custom model.
Example 3: Load a newly trained model registered in MLflow on Databricks for inference, e.g. toxicity prediction
After you train a model (Example 2), MLflow logs the model metrics and saves the model artifacts. The model can be reloaded in a notebook for prediction. It can also be registered and served so others can discover and reuse it via a REST API provided by Databricks Model Serving (a sketch of querying such an endpoint follows).
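To give a flavor of what consuming a served model looks like, below is a minimal sketch of querying a Databricks Model Serving endpoint over REST. The workspace URL, endpoint name, token, and payload columns are all placeholders/assumptions; consult your endpoint's query schema for the exact format your model expects.

import requests

# All names below are placeholders; replace with your own workspace values
url = "https://<workspace-url>/serving-endpoints/<endpoint_name>/invocations"
headers = {"Authorization": "Bearer <databricks_token>"}
payload = {"dataframe_split": {"columns": ["smiles"],
                               "data": [["CC(=O)Nc1ccc(O)cc1"]]}}

response = requests.post(url, headers=headers, json=payload)
print(response.json())  # predicted property values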
Example 4: Multi-task training, e.g. ADMET regression model
For a compound to be considered a good drug candidate, it must possess several desirable ADMET (absorption, distribution, metabolism, excretion, toxicity) properties: it should be readily absorbed into the body, distributed to the target site, and safely metabolized and excreted. Since there may be hundreds to thousands of ADMET properties to predict, it is common to do multi-task training and inference to find candidates possessing many of them. Multi-task training is advantageous when the predicted properties are highly correlated, as joint learning allows knowledge gained from one task to improve another.
Setup is straightforward, as Chemprop is available as a Python package on PyPI and on GitHub. You will also need to install its dependency, rdkit-pypi:
pip install chemprop rdkit-pypi
If you already have models as PyTorch checkpoint files (*.ckpt), you can load them directly with Chemprop and use them for inference.
import torch
from lightning.pytorch import Trainer
from chemprop import data, featurizers, models

# Load the trained model as an MPNN
checkpoint_path = "<ckpt_file_path>"  # path to your .ckpt file
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)

# ... build a data_loader for the molecules to score ...

# Predict with the loaded MPNN
with torch.inference_mode():
    trainer = Trainer(
        logger=None,
        enable_progress_bar=True,
        accelerator="cpu",
        devices=1
    )
    test_preds = trainer.predict(mpnn, data_loader)
See the example NB, which uses a multi-component regressor (source: Chemprop repo) to predict whether a compound will dissolve in a solvent. It expects two columns of Simplified Molecular Input Line Entry System (SMILES) strings, a text representation of molecular structures: one for the compound to be dissolved and another for the solvent dissolving it (see the dataset for inference). The output is solubility as measured by UV-Vis spectroscopy. A minimal sketch of this two-component inference follows the table below.
Compound | Solvent |
CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S | ClCCl |
C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1 | ClCCl |
CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1 | O |
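For orientation, here is a minimal sketch of what inference with such a two-component (compound, solvent) model could look like in Chemprop v2. The checkpoint path and SMILES lists are assumptions; confirm the exact multicomponent API against the linked notebook.

from lightning.pytorch import Trainer
from chemprop import data, featurizers, models

# Hypothetical parallel lists: compound SMILES and the solvent dissolving it
compounds = ["CCO", "c1ccccc1O"]
solvents = ["O", "ClCCl"]

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
datasets = [
    data.MoleculeDataset([data.MoleculeDatapoint.from_smi(s) for s in comp], featurizer)
    for comp in (compounds, solvents)
]
mc_dset = data.MulticomponentDataset(datasets)
loader = data.build_dataloader(mc_dset, shuffle=False)

mcmpnn = models.MulticomponentMPNN.load_from_checkpoint("<ckpt_file_path>")
preds = Trainer(logger=False).predict(mcmpnn, loader)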
If no satisfactory model exists, you can train one on a specific chemical library or prediction property. For example, we can train a classifier on the ClinTox database (Wu et al.), which consists of 1491 drugs labelled according to whether they exhibited toxicity during clinical trials (source: Huggingface). See the accompanying NB.
Generally, Chemprop MPNN models consist of a message-passing module, an aggregation module, and a final feed-forward network (FFN) module. These modules should be configured to fit the task at hand: for example, use a BinaryClassificationFFN as the output layer for binary classification and a RegressionFFN for regression of continuous properties.
from chemprop import models, nn

mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()

# If classification
ffn = nn.BinaryClassificationFFN()
metric_list = [
    nn.metrics.BinaryAUROC(),
    nn.metrics.BinaryAUPRC(),
    nn.metrics.BinaryAccuracy(),
    nn.metrics.BinaryF1Score(),
]

# If regression
# ffn = nn.RegressionFFN()
# metric_list = [nn.metrics.R2Score(), nn.metrics.RMSE()]

mpnn = models.MPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list)
Chemprop expects molecules to be represented as SMILES text strings. It converts each SMILES string into its MoleculeDatapoint class, which tracks the target property, the atoms and bonds (as graph nodes and edges, respectively), and any additional molecular descriptors.
from chemprop import data

# Convert SMILES strings (smis) and targets (ys) -> MoleculeDatapoint objects
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
As per best practice, split the data into train, validation, and test sets. Chemprop's data module provides splitting helper functions to make this easy. However, they expect molecules as RDKit Mol objects, so we convert them accordingly:
from chemprop import data

# MoleculeDatapoint -> RDKit Mol
mols = [d.mol for d in all_data]
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
Once split, generate graph representations using an appropriate featurizer, and finally convert to PyTorch dataloaders.
from chemprop import data, featurizers

# Featurization: MoleculeDatapoint -> molecular graph representations
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset = data.MoleculeDataset(train_data[0], featurizer)
val_dset = data.MoleculeDataset(val_data[0], featurizer)
test_dset = data.MoleculeDataset(test_data[0], featurizer)

num_workers = 0  # dataloader worker processes; tune for your cluster
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers)
Once the above model architecture and datasets are defined, you can start training with the following few lines.
from lightning.pytorch import Trainer

trainer = Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    max_epochs=20
)
trainer.fit(mpnn, train_loader, val_loader)
You can also evaluate the model on the held-out test set.
# To get test statistics
test_stats = Trainer(logger=False).test(mpnn, test_loader)
# Inference to get prediction values
test_preds = Trainer(logger=False).predict(mpnn, test_loader)
Since Databricks provides MLflow for model lifecycle management, we highly recommend logging, registering, and serving your trained models with MLflow. Just wrap the trainer.fit call with a couple of MLflow commands.
Log and register models (also autosaves models)
import mlflow
import mlflow.pytorch

mlflow.pytorch.autolog(registered_model_name="<some_model_name>")  # your model name
with mlflow.start_run() as run:
    trainer.fit(mpnn, train_loader, val_loader)
# The run ends automatically when the `with` block exits
Optionally save model artifacts to a volume on Unity Catalog
You can optionally save the model files to a volume path on Unity Catalog, although the Model Registry already saves them upon registration.
mlflow.artifacts.download_artifacts(
    run_id=run.info.run_id,
    artifact_path="some_artifact_path",
    dst_path="some_vol_path"
)
See NB for the end-to-end execution of Example 2.
If you have registered the model with MLflow, you can load it for inference with just the following:
import mlflow.pytorch
from lightning.pytorch import Trainer

model_uri = "models:/registered_model_name/model_version"
model = mlflow.pytorch.load_model(model_uri)

# drugbank_loader: a dataloader over DrugBank molecules (sketched below)
test_preds_reloaded = Trainer(logger=False).predict(model, drugbank_loader)
To get the model_uri, check out these options.
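For instance, one option is to look up the latest version of a registered model by name with the MLflow client; the model name below is a placeholder.

from mlflow import MlflowClient

client = MlflowClient()
# "my_tox_model" is a placeholder registered model name
versions = client.search_model_versions("name = 'my_tox_model'")
latest = max(versions, key=lambda v: int(v.version))  # pick the highest version
model_uri = f"models:/{latest.name}/{latest.version}"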
This example NB shows how the model trained in Example 2 was loaded from MLflow and used to predict the clinical toxicity of DrugBank, a database of over 2000 FDA-approved small-molecule drugs (source: doi.org/10.5281/zenodo.10372418).
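As an aside, the drugbank_loader used above could be built along the following lines; the file and column names are illustrative assumptions, not taken from the linked notebook.

import pandas as pd
from chemprop import data, featurizers

# Hypothetical CSV of DrugBank molecules with a "smiles" column
drugbank_smis = pd.read_csv("drugbank.csv")["smiles"].tolist()
drugbank_data = [data.MoleculeDatapoint.from_smi(smi) for smi in drugbank_smis]
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
drugbank_dset = data.MoleculeDataset(drugbank_data, featurizer)
drugbank_loader = data.build_dataloader(drugbank_dset, shuffle=False)  # keep input order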
For a compound to be considered a good drug candidate, it must possess several desirable ADMET properties such as bioavailability, high potency, and low toxicity. In this example, we train a multi-task regression model on 10 continuous ADMET properties from the Therapeutics Data Commons simultaneously (doi.org/10.5281/zenodo.10372418). As noted earlier, multi-task training is advantageous because the predicted properties are highly correlated, and joint learning lets knowledge gained from one task improve another.
Table 1: sample training data with 10 continuous ADMET properties for multi-task regression
SMILES | C=C[C@H]1CN2CC[C@H]1C[C@@H]2[C@@H](O)c1ccnc2ccc(OC)cc12 | CC(=O)Nc1ccc(O)cc1 | C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1 |
Caco2_Wang | -4.69 | -4.44 | -4.47 |
Clearance_Hepatocyte_AZ | 6.17 | 6.31 | 7.41 |
Clearance_Microsome_AZ | null | 3 | 18.62 |
Half_Life_Obach | 6.6 | 2.5 | null |
HydrationFreeEnergy_FreeSolv | null | null | null |
LD50_Zhu | null | 1.799 | null |
Lipophilicity_AstraZeneca | 2.21 | 0.25 | 3.35 |
PPBR_AZ | 85.48 | 26.64 | 95.82 |
Solubility_AqSolDB | -2.81214297 | -1.033323213 | null |
VDss_Lombardo | null | 1 | 0.77 |
The code is similar to single-task training (Example 2), except that the final FFN must accommodate the 10 co-trained targets. Thus, we define the FFN with 10 tasks as follows (a sketch of wiring the 10 targets into the datapoints follows below):
ffn = nn.RegressionFFN(n_tasks=10)
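For context, here is a minimal sketch of how the 10-column targets could be wired into the datapoints, assuming a DataFrame df laid out like Table 1 (one SMILES column plus 10 property columns). Missing labels stay as NaN, which Chemprop is designed to mask during multi-task training; verify this against the multi-task docs for your version.

import pandas as pd
from chemprop import data

target_cols = ["Caco2_Wang", "Clearance_Hepatocyte_AZ", "Clearance_Microsome_AZ",
               "Half_Life_Obach", "HydrationFreeEnergy_FreeSolv", "LD50_Zhu",
               "Lipophilicity_AstraZeneca", "PPBR_AZ", "Solubility_AqSolDB",
               "VDss_Lombardo"]

df = pd.read_csv("admet_train.csv")  # hypothetical file with Table 1's layout
ys = df[target_cols].to_numpy(dtype=float)  # NaNs mark missing labels

# One datapoint per molecule, each carrying a 10-element target vector
all_data = [data.MoleculeDatapoint.from_smi(smi, y)
            for smi, y in zip(df["SMILES"], ys)]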
Once trained, the multi-task regressor can be used to predict the 10 ADMET properties of the DrugBank compounds loaded for inference in Example 3.
See these links for the end-to-end execution of multi-task training and multi-task inference.
This blog provides a quickstart guide on using Chemprop to load existing models, train new ones, and perform a variety of single- and multi-task molecular property predictions. To learn more, check out the official tutorials from Chemprop.