Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.
yenlow
Databricks Employee

Did you know that AI was successfully used to discover a novel antibiotic, halicin, in 2020? This was noteworthy because halicin is structurally unique, highly differentiated from conventional antibiotics such as penicillins, and opened a new direction in countering growing antibiotic resistance (Stokes et al. 2020).


Fig 1. Halicin, a structurally unique and novel antibiotic, discovered using Chemprop

Halicin was discovered using Chemprop (git repo). You can do the same to discover a new drug or a new disease indication for an existing drug (aka drug repurposing) using Chemprop and other open source AI tools on Databricks. This blog shows how highly specialized libraries like Chemprop can be easily integrated into Databricks for drug discovery. 

Why do AI Drug Design on Databricks?

Databricks facilitates production-grade research by providing a unified platform for data processing, model training, and deployment. 

Its Unity Catalog, which manages data and model assets, promotes discovery and collaboration, making it easy to search for and reuse large, complex datasets and models.

MLflow, while open source, is a first-class citizen on Databricks, offering both no-code and SDK options for experiment tracking, model registration, serving, and monitoring. By simplifying MLOps, MLflow lets research scientists focus on developing models and interpreting results.

Why use Chemprop?

Chemprop is a suite of AI tools based on a directed message-passing neural network (MPNN) which treats molecules as graphs (atoms as nodes and bonds as edges). The model applies a series of message-passing steps where it aggregates information from neighboring atoms and bonds to build an understanding of local chemistry. This learned fingerprint representation is fed into a feed-forward neural network (FFN) that outputs a molecular property such as toxicity, or in halicin’s case, the ability to inhibit bacterial growth. 
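To make the message-passing idea concrete, here is a toy sketch in plain Python. It is not Chemprop's implementation: each atom carries a single scalar feature (the atomic number, an assumption for illustration), the weights `w_i`, `w_h`, and `w_o` are hand-set constants rather than learned matrices, and the "FFN" is one hypothetical logistic unit. It does follow the directed scheme described above: messages live on directed bonds, a bond never receives back the message it just sent, and the final atom states are averaged into one fingerprint value.

```python
import math


def relu(x: float) -> float:
    return max(0.0, x)


def dmpnn_fingerprint(atom_feats, bonds, depth=3, w_i=0.5, w_h=0.3, w_o=0.7):
    """Toy directed message passing with scalar features and hand-set weights.

    atom_feats: one float per atom (hypothetical scalar atom feature)
    bonds: undirected bonds as (u, v) pairs; each becomes two directed edges
    Returns the mean of the final atom states (a scalar "fingerprint").
    """
    edges = list(bonds) + [(v, u) for u, v in bonds]
    # Initial state of directed edge u->v, computed from its source atom.
    h0 = {(u, v): relu(w_i * atom_feats[u]) for u, v in edges}
    h = dict(h0)
    for _ in range(depth - 1):
        # Edge u->v gathers the states of edges k->u, skipping its own reverse v->u.
        h = {
            (u, v): relu(
                h0[(u, v)] + w_h * sum(h[(k, x)] for k, x in edges if x == u and k != v)
            )
            for u, v in edges
        }
    # Each atom combines its own feature with the states of all incoming edges.
    atom_states = [
        relu(w_o * (x_a + sum(h[(k, x)] for k, x in edges if x == a)))
        for a, x_a in enumerate(atom_feats)
    ]
    return sum(atom_states) / len(atom_states)  # mean aggregation


def toy_ffn(fingerprint: float, w: float = 0.2, b: float = -1.0) -> float:
    """A single logistic unit standing in for the property-prediction FFN."""
    return 1.0 / (1.0 + math.exp(-(w * fingerprint + b)))


# Heavy atoms of ethanol (C-C-O), featurized by atomic number.
fp = dmpnn_fingerprint([6.0, 6.0, 8.0], [(0, 1), (1, 2)])
print(fp, toy_ffn(fp))
```

Because every step only sums over neighbors and the readout is a mean, renumbering the atoms does not change the fingerprint, which is the property that lets a graph model treat molecules independently of how their atoms are listed.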


Fig 2. Chemprop treats molecules as graph structures and uses a message-passing neural network (MPNN) to learn feature representations. The MPNN is coupled with a feed-forward neural network (FFN) for property prediction. Source: Heid et al. 2023

In the high-stakes race to develop much-needed drugs, fast and accurate molecular property prediction by AI is key. Besides halicin, Chemprop has been used successfully by many pharma researchers to predict drug potency, IR spectra, combination drug synergy, and more. Below are a few example workflows showing how you can reuse existing models or train new ones to predict drug properties and find promising drug candidates.

Example 1: Load existing models for inference, e.g. solubility prediction

Chemprop is built on PyTorch (via PyTorch Lightning), so existing models can be loaded directly from checkpoint files.

Example 2: Train a new model for specific chemical libraries or properties, e.g. a toxicity classifier

Sometimes existing models are inadequate: none may exist for a particular molecular property, or they may not apply to the chemical space of interest. In such cases, it may be necessary to train a new custom model.

Example 3: Load a newly trained model registered in MLflow on Databricks for inference, e.g. toxicity prediction

After you train a model (Example 2), MLflow logs the model metrics and saves the model artifacts. The model can be reloaded in a notebook for prediction. It can also be registered and served so that others can find and reuse it via a REST API provided by Databricks Model Serving.

Example 4: Multi-task training, e.g. ADMET regression model

For a compound to be considered a good drug candidate, it must possess several desirable ADMET (absorption, distribution, metabolism, excretion, toxicity) properties so that it is readily absorbed into the body, distributed to the target site, and safely metabolized and excreted. Because there may be hundreds to thousands of ADMET properties to predict, it is common to use multi-task training and inference to find candidates possessing many of them. Multi-task training is advantageous because the predicted properties are highly correlated, and analyzing the data jointly allows knowledge gained from one task to improve another.

 

Setup

Setup is straightforward, as Chemprop is available as a Python package on PyPI and on GitHub. You will also need to install its dependency, rdkit-pypi:

pip install chemprop rdkit-pypi
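After installing, a quick check in the notebook can confirm which versions were actually picked up. The sketch below uses only the standard library's `importlib.metadata`; the helper name `installed_version` is our own, and it looks up both `rdkit-pypi` and `rdkit` because RDKit wheels have been published to PyPI under both distribution names.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for pkg in ("chemprop", "rdkit-pypi", "rdkit"):
    print(pkg, installed_version(pkg) or "not installed")
```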

 

Example 1: Load existing solubility model for inference

If you already have models as PyTorch checkpoint files (*.ckpt), you can load them directly with Chemprop and use them for inference.

import torch
from lightning.pytorch import Trainer
from chemprop import data, featurizers, models

# Load model as mpnn
checkpoint_path = "<ckpt_file_path>"  # path to your *.ckpt file
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)
...  # build data_loader from the input SMILES (see the example notebook)

# Predict with the loaded mpnn
with torch.inference_mode():
    trainer = Trainer(
        logger=None,
        enable_progress_bar=True,
        accelerator="cpu",
        devices=1,
    )
    test_preds = trainer.predict(mpnn, data_loader)

The example notebook uses a multi-component regressor (source: Chemprop repo) to predict whether a compound will dissolve in a solvent. It expects two columns in Simplified Molecular Input Line Entry System (SMILES) format, a text representation of molecular structure: one for the compound to be dissolved and another for the solvent dissolving it (see the dataset for inference). The output is solubility as measured by UV-Vis spectroscopy.

Compound (SMILES) | Solvent (SMILES)
CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S | ClCCl
C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1 | ClCCl
CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1 | O
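Before running inference, it can help to validate that the input really has the two SMILES columns a multi-component model expects. The sketch below uses Python's `csv` module and assumes headers named `Compound` and `Solvent`, as in the table above; the helper `read_component_pairs` is hypothetical and not part of Chemprop.

```python
import csv
import io


def read_component_pairs(csv_text, compound_col="Compound", solvent_col="Solvent"):
    """Return (compound_smiles, solvent_smiles) tuples, rejecting rows with a blank column."""
    reader = csv.DictReader(io.StringIO(csv_text))
    pairs = []
    for row in reader:
        compound = row[compound_col].strip()
        solvent = row[solvent_col].strip()
        if not compound or not solvent:
            raise ValueError(f"Missing SMILES in row: {row}")
        pairs.append((compound, solvent))
    return pairs


# Raw string so the backslash in the SMILES bond notation is kept literally.
example = r"""Compound,Solvent
CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1,O
C(=C/c1cnccn1)\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1,ClCCl
"""

for compound, solvent in read_component_pairs(example):
    print(solvent, compound)
```

Keeping the two components as separate columns (rather than concatenating them) matters because each column is featurized as its own molecular graph.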

 

Example 2: Train a single-task classifier

If no existing model is satisfactory, you can train one on specific chemical libraries or prediction properties. For example, we can train a classifier on the ClinTox database (Wu et al.), which consists of 1491 drugs labeled by whether they exhibited toxicity during clinical trials (source: Hugging Face). See the accompanying notebook.

Define model architecture

Generally, Chemprop MPNN models consist of a message-passing module, an aggregation module, and a final feed-forward network (FFN) module. These modules should be configured to fit the task at hand. For example, use BinaryClassificationFFN as the output layer for binary classification and RegressionFFN for regression of continuous properties.

from chemprop import models, nn

mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()

# If classification
ffn = nn.BinaryClassificationFFN()
metric_list = [nn.metrics.BinaryAUROC(),
              nn.metrics.BinaryAUPRC(),
              nn.metrics.BinaryAccuracy(),
              nn.metrics.BinaryF1Score()]

# If regression
# ffn = nn.RegressionFFN()
# metric_list = [nn.metrics.R2Score(),
#              nn.metrics.RMSE()]

mpnn = models.MPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list)
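Chemprop computes the metrics listed above for you. To show what the headline classification metric measures, here is a from-scratch AUROC in plain Python using its rank interpretation: the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counting one half. The function name is our own, and the pairwise loop is O(P·N), fine for illustration but not meant for large datasets.

```python
def binary_auroc(labels, scores):
    """AUROC via the Mann-Whitney statistic over all positive/negative pairs."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0  # positive correctly ranked above negative
            elif p == n:
                wins += 0.5  # tie counts as half a win
    return wins / (len(pos) * len(neg))


print(binary_auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

An AUROC of 0.5 corresponds to random ranking, which is why it is often paired with AUPRC on imbalanced datasets where positives are rare.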

Data preparation

Chemprop expects molecules to be represented as SMILES text strings. It converts each SMILES string into a MoleculeDatapoint, which holds the molecule (as an RDKit Mol, whose atoms and bonds later become graph nodes and edges), the target property, and any additional molecular descriptors.

from chemprop import data

# Convert SMILES strings (smis) and targets (ys) -> MoleculeDatapoint objects
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

Data splitting

As per best practice, split the data into training, validation, and test sets. Chemprop's data module has many splitting helper functions to make this easy. These helpers expect molecules as RDKit Mol objects, so we extract them from the datapoints first:

from chemprop import data

# MoleculeDatapoint -> RDKit Mol
mols = [d.mol for d in all_data]

# Returns one list of indices per replicate (here a single replicate), hence the [0] below
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
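For intuition about what a random 80/10/10 split does, here is a plain-Python sketch over indices. It is an illustration, not Chemprop's `make_split_indices`; the function name `random_split_indices` and its `seed` default are assumptions. Seeding the shuffle keeps the split reproducible across runs. Note that random splits put structurally similar molecules on both sides of the split, so test scores can look optimistic for truly novel chemistry, which is why scaffold-based splits are also common in practice.

```python
import random


def random_split_indices(n, sizes=(0.8, 0.1, 0.1), seed=0):
    """Shuffle range(n) with a fixed seed and cut it into train/val/test index lists."""
    if abs(sum(sizes) - 1.0) > 1e-9:
        raise ValueError("sizes must sum to 1")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # local RNG: does not touch global random state
    n_train = round(sizes[0] * n)
    n_val = round(sizes[1] * n)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train_idx, val_idx, test_idx = random_split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))
```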

Once the data are split, generate molecular graphs using an appropriate featurizer, then convert each split into a PyTorch DataLoader.

from chemprop import data, featurizers

# Featurization: MoleculeDatapoint -> molecular graph (atom and bond features)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset = data.MoleculeDataset(train_data[0], featurizer)
val_dset = data.MoleculeDataset(val_data[0], featurizer)
test_dset = data.MoleculeDataset(test_data[0], featurizer)

num_workers = 0  # DataLoader worker processes; 0 loads data in the main process

# Shuffle only the training set so validation/test batches keep a fixed order
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
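A graph dataloader cannot simply stack molecules of different sizes into one fixed-shape tensor. A common approach, sketched below in plain Python, is to merge the molecules of a batch into one disconnected graph: each molecule's atom indices are shifted by the number of atoms that came before it, and a `batch` vector records which molecule every atom belongs to, so per-molecule aggregation can be done afterwards. This is a conceptual sketch under that assumption, not Chemprop's collate code; `batch_graphs` is a hypothetical helper.

```python
def batch_graphs(graphs):
    """Merge molecular graphs into one disconnected graph.

    graphs: list of (num_atoms, edges), with edges as (src, dst) atom indices
    local to each molecule. Returns (total_atoms, merged_edges, batch), where
    batch[i] is the index of the molecule that atom i belongs to.
    """
    offset = 0
    merged_edges, batch = [], []
    for mol_idx, (num_atoms, edges) in enumerate(graphs):
        # Shift local atom indices so they point into the merged atom list.
        merged_edges.extend((u + offset, v + offset) for u, v in edges)
        batch.extend([mol_idx] * num_atoms)
        offset += num_atoms
    return offset, merged_edges, batch


# A 3-atom chain and a 2-atom molecule, each bond stored in both directions.
print(batch_graphs([(3, [(0, 1), (1, 0), (1, 2), (2, 1)]), (2, [(0, 1), (1, 0)])]))
```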

Training

Once the above model architecture and datasets are defined, you can start training with the following few lines.

from lightning.pytorch import Trainer

trainer = Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    max_epochs=20,
)
trainer.fit(mpnn, train_loader, val_loader)

Testing

You can then evaluate the model on the held-out test set.

# To get test statistics
test_stats = Trainer(logger=False).test(mpnn, test_loader)

# Inference to get prediction values
test_preds = Trainer(logger=False).predict(mpnn, test_loader)
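`trainer.predict` returns one output per batch rather than a single array. With real PyTorch tensors you would typically join them with `torch.cat`; the sketch below uses plain lists of floats in place of tensors, assumes a single-task binary classifier whose outputs are probabilities, and uses a 0.5 cut-off that is only a placeholder to be tuned for your application. Both helper names are hypothetical.

```python
def flatten_batches(batch_preds):
    """Concatenate per-batch prediction lists into one flat list, preserving order."""
    return [p for batch in batch_preds for p in batch]


def to_labels(probs, threshold=0.5):
    """Turn probabilities into 0/1 labels using a (placeholder) decision threshold."""
    return [1 if p >= threshold else 0 for p in probs]


probs = flatten_batches([[0.9, 0.2], [0.5]])  # two batches, as predict() might return
print(probs, to_labels(probs))
```

Order is only meaningful if the test loader was not shuffled, since the flattened predictions are matched back to the input rows by position.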

Manage your trained models with MLflow

Because Databricks integrates MLflow for model lifecycle management, we highly recommend logging, registering, and serving your trained models with MLflow. Just wrap the trainer.fit call with a couple of MLflow commands.

Log and register models (also autosaves models)

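import mlflow
import mlflow.pytorch

# Autolog parameters and metrics, and register the trained model under this name
mlflow.pytorch.autolog(registered_model_name="<some_model_name>")

with mlflow.start_run() as run:
    trainer.fit(mpnn, train_loader, val_loader)
# The context manager ends the run on exit, so no explicit mlflow.end_run() is needed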

Optionally save model artifacts to a volume on Unity Catalog

Although the Model Registry already stores the model files upon registration, you can optionally copy them to a volume path in Unity Catalog.

import mlflow

# Copy the run's logged artifacts to a Unity Catalog volume path
mlflow.artifacts.download_artifacts(
    run_id=run.info.run_id,
    artifact_path="some_artifact_path",
    dst_path="some_vol_path",
)

See the notebook for the end-to-end execution of Example 2.

 

Example 3: Load model from MLflow for inferencing

If you have registered the model with MLflow, you can load it for inference with just a few lines:

import mlflow.pytorch
from lightning.pytorch import Trainer

model_uri = "models:/registered_model_name/model_version"
model = mlflow.pytorch.load_model(model_uri)
# drugbank_loader: a DataLoader built from the DrugBank SMILES, as in Example 2
test_preds_reloaded = Trainer(logger=False).predict(model, drugbank_loader)

For the different ways to specify the model_uri, see the options in the MLflow documentation.
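As a quick reference, the helpers below only build the strings for three common MLflow model URI forms: a registered model by version (`models:/<name>/<version>`), a registered model by alias (`models:/<name>@<alias>`), and a model logged in a specific run (`runs:/<run_id>/<artifact_path>`). The helper names, the model name `main.default.clintox_mpnn`, the alias, and the run ID are hypothetical placeholders; Unity Catalog model names take the three-level `catalog.schema.model` form.

```python
def registry_uri(name, version=None, alias=None):
    """Build a models:/ URI from either a version number or an alias (exactly one)."""
    if (version is None) == (alias is None):
        raise ValueError("Specify exactly one of version or alias")
    if version is not None:
        return f"models:/{name}/{version}"
    return f"models:/{name}@{alias}"


def run_uri(run_id, artifact_path):
    """Build a runs:/ URI pointing at a model logged under a run."""
    return f"runs:/{run_id}/{artifact_path}"


print(registry_uri("main.default.clintox_mpnn", version=1))
print(registry_uri("main.default.clintox_mpnn", alias="champion"))
print(run_uri("abc123", "model"))
```

Any of these strings can be passed as `model_uri` to `mlflow.pytorch.load_model`.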

This example notebook shows how the model trained in Example 2 was loaded from MLflow and used to predict the clinical toxicity of compounds in DrugBank, a database of over 2000 FDA-approved small-molecule drugs (source: doi.org/10.5281/zenodo.10372418).

 

Example 4: Train multi-task ADMET regression model

For a compound to be considered a good drug candidate, it must possess several desirable ADMET properties, such as bioavailability, high potency, and low toxicity. In this example, we trained a multi-task regression model simultaneously on 10 continuous ADMET properties from the Therapeutics Data Commons (doi.org/10.5281/zenodo.10372418). Multi-task training is advantageous because the predicted properties are highly correlated, and joint training lets knowledge gained from one task improve another.

Table 1: sample training data with 10 continuous ADMET properties for multi-task regression

Property | Compound 1 | Compound 2 | Compound 3
SMILES | C=C[C@H]1CN2CC[C@H]1C[C@@H]2[C@@H](O)c1ccnc2ccc(OC)cc12 | CC(=O)Nc1ccc(O)cc1 | C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1
Caco2_Wang | -4.6900001 | -4.4400001 | -4.4699998
Clearance_Hepatocyte_AZ | 6.17 | 6.31 | 7.41
Clearance_Microsome_AZ | null | 3 | 18.62
Half_Life_Obach | 6.6 | 2.5 | null
HydrationFreeEnergy_FreeSolv | null | null | null
LD50_Zhu | null | 1.799 | null
Lipophilicity_AstraZeneca | 2.21 | 0.25 | 3.35
PPBR_AZ | 85.48 | 26.64 | 95.82
Solubility_AqSolDB | -2.81214297 | -1.033323213 | null
VDss_Lombardo | null | 1 | 0.77
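The nulls in Table 1 mean that most compounds are labeled for only some of the 10 tasks. A common way to train on such sparse labels is to mask missing entries out of the loss, so each compound contributes only the targets it actually has. The sketch below computes a masked mean squared error in plain Python, using `None` for a missing label; it illustrates the idea and is not Chemprop's loss function.

```python
def masked_mse(preds, targets):
    """Mean squared error over observed targets only; None marks a missing label."""
    total, count = 0.0, 0
    for pred_row, target_row in zip(preds, targets):
        for p, t in zip(pred_row, target_row):
            if t is None:
                continue  # missing label: skipped, contributes nothing to the loss
            total += (p - t) ** 2
            count += 1
    if count == 0:
        raise ValueError("No observed targets")
    return total / count


# Two compounds x two tasks; each compound has one missing label.
print(masked_mse([[1.0, 2.0], [3.0, 4.0]], [[1.5, None], [None, 5.0]]))
```

Averaging over observed entries only (rather than over all cells) keeps the loss scale independent of how sparse each batch happens to be.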

The code is similar to single-task training (Example 2), except that the final FFN must accommodate the 10 co-trained targets. We therefore define the FFN with 10 tasks:

ffn = nn.RegressionFFN(n_tasks=10)

Once trained, the multi-task regressor can be used to predict the 10 ADMET properties for the DrugBank compounds used for inference in Example 3.

See these links for the end-to-end execution for multi-task training and multi-task inferencing.

 

Conclusions

This blog provides a quickstart guide to using Chemprop to load existing models, train new ones, and perform a variety of single- and multi-task molecular property predictions. To learn more, check out the official Chemprop tutorials.