Official implementation of SynFlowNet, a GFlowNet model with a synthesis action space. The paper is available on arXiv.
## Primer
SynFlowNet is a GFlowNet model that generates molecules from chemical reactions and available building blocks. This repo contains instructions for how to train SynFlowNet and sample synthesisable molecules with probability proportional to a reward specified by the user. The code builds upon the recursionpharma/gflownet codebase, available under the MIT license. For a primer and repo overview visit recursionpharma/gflownet.
This package is installable as a pip package, but since it depends on some torch-geometric wheels, the `--find-links` argument must be specified as well:
```bash
conda create --name sfn python=3.10.14
conda activate sfn
pip install -e . --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html
```
Or for CPU use:
```bash
pip install -e . --find-links https://data.pyg.org/whl/torch-2.1.2+cpu.html
```
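To check that the environment resolved correctly, a quick sanity check (not part of the repo) is:

```python
# Hypothetical sanity check that the core dependencies installed correctly;
# not part of the repo.
import torch
import torch_geometric

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())
```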
The training relies on two data sources: reaction templates and building blocks. Filenames are specified in the `ReactionTaskConfig`. The model uses pre-computed masks to ensure compatibility between the building blocks and the reaction templates. Instructions for preprocessing building blocks and for computing masks can be found in `synflownet/data/building_blocks`.
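For intuition, the masks encode which building blocks can serve as which reactant of each template. Below is a minimal RDKit sketch of that kind of compatibility check; the template SMARTS and the per-slot substructure test are illustrative only, and the repo's actual preprocessing scripts live in `synflownet/data/building_blocks`:

```python
from rdkit import Chem
from rdkit.Chem import AllChem

# Toy amide-coupling template; the repo ships its own template files.
template = AllChem.ReactionFromSmarts(
    "[C:1](=[O:2])[OH].[N;!$(N=*):3]>>[C:1](=[O:2])[N:3]"
)
blocks = [Chem.MolFromSmiles(s) for s in ("CC(=O)O", "CCN", "c1ccccc1")]

# A building block is compatible with reactant slot j of a template
# if it matches that slot's SMARTS pattern.
mask = [
    [b.HasSubstructMatch(template.GetReactantTemplate(j))
     for j in range(template.GetNumReactantTemplates())]
    for b in blocks
]
print(mask)  # [[True, False], [False, True], [False, False]]
```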
SynFlowNet uses a reward to guide the molecule generation process. We have implemented a few reward functions in the `ReactionTask` class. These include the SeH binding proxy, QED, oracles from PyTDC and Vina docking (see below). Other reward functions can be imported in the `synflownet/tasks/reactions_task.py` file.
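As an illustration, a reward in the spirit of the QED objective can be computed with RDKit alone. The sketch below is hypothetical and only shows the general shape of such a function, not the actual `ReactionTask` hook:

```python
from rdkit import Chem
from rdkit.Chem import QED

# Hypothetical reward function: maps SMILES strings to scalar rewards
# in [0, 1]. The real integration point is the ReactionTask class.
def qed_reward(smiles_list):
    mols = [Chem.MolFromSmiles(s) for s in smiles_list]
    return [QED.qed(m) if m is not None else 0.0 for m in mols]

print(qed_reward(["CCO", "c1ccccc1C(=O)O"]))
```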
The model can be trained by running `synflownet/tasks/reactions_task.py` using different reward functions. You may want to change the default configuration in `main()`.
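For example, a minimal tweak of the defaults might look like the sketch below. This assumes the `Config`/`init_empty` pattern inherited from recursionpharma/gflownet; the exact module path and field names in this repo may differ:

```python
# Hypothetical configuration sketch, following the recursionpharma/gflownet
# conventions this codebase builds on; field names are assumptions.
from synflownet.config import Config, init_empty

config = init_empty(Config())
config.log_dir = "./logs/debug_run"    # where logs and checkpoints are written
config.device = "cuda"                 # or "cpu"
config.num_training_steps = 1_000      # short smoke-test run
```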
For easy adoption to other targets, a GPU-accelerated version of Vina docking can be used to calculate rewards as binding affinities to targets of interest. Follow the instructions at this repo to compile an executable for QuickVina2-GPU-2-1. Once done, place the executable in `bin/`.
Below is a breakdown of the different components of SynFlowNet.
We separate experiment concerns into four categories:

- The `ReactionTemplateEnv` is the definition of the reaction MDP; it implements stepping forward and backward in the environment.
- The `ReactionTemplateEnvContext` provides an interface between the agent and the environment; it
  - maps graphs to other molecule representations and to `torch_geometric` `Data` instances
  - maps `GraphAction`s to action indices
  - creates masks for actions
  - communicates to the model what inputs it should expect
- The `ReactionTask` class is responsible for computing rewards and for sampling conditional information.
- The `ReactionTrainer` class is responsible for instantiating everything and running the training loop.
The `GraphTransformerSynGFN` class is used to parameterize the policies and outputs a specific categorical distribution type for the actions defined in `ReactionTemplateEnvContext`. If `config.model.graph_transformer.continuous_action_embs` is set to `True`, then the probability of sampling building blocks is computed from the normalized dot product of the molecule representation and the embedding vector of the state. The `ActionCategorical` class contains the logic to sample from the hierarchical distribution of actions.
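To make the hierarchy concrete, here is a simplified, self-contained sketch of two-level action sampling (pick an action type, then an index within that type). The real `ActionCategorical` handles batching, masking, and log-probabilities more carefully, and the action-type names below are assumptions for illustration:

```python
import torch

torch.manual_seed(0)

# Logits for 4 hypothetical action types (e.g. Stop, ReactUni, ReactBi,
# AddFirstReactant) and per-type index logits with validity masks.
type_logits = torch.randn(4)
index_logits = [torch.randn(n) for n in (1, 10, 5, 8)]
index_masks = [torch.ones(n, dtype=torch.bool) for n in (1, 10, 5, 8)]
index_masks[2][:3] = False  # e.g. some templates are invalid in this state

# Invalid indices get -inf logits so they receive zero probability.
masked = [l.masked_fill(~m, float("-inf"))
          for l, m in zip(index_logits, index_masks)]

# Sample the action type first, then an index within that type.
t = torch.distributions.Categorical(logits=type_logits).sample().item()
i = torch.distributions.Categorical(logits=masked[t]).sample().item()
print(f"action type {t}, index {i}")
```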
The data used for training the GFlowNet can come from multiple sources:
- Generating new trajectories on-policy from s_0
- Generating new trajectories on-policy backwards from samples stored in a replay buffer (for training the backwards policy with REINFORCE)
- Sampling trajectories from a fixed, offline dataset
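A minimal sketch of how such sources can be mixed into one training batch (illustrative only; the repo's actual data pipeline lives in its trainer and replay-buffer code):

```python
import random

# Hypothetical batch builder mixing freshly sampled on-policy trajectories
# with trajectories replayed from a buffer.
def build_batch(sample_trajectory, replay_buffer, batch_size, replay_ratio=0.5):
    n_replay = min(int(batch_size * replay_ratio), len(replay_buffer))
    fresh = [sample_trajectory() for _ in range(batch_size - n_replay)]
    replayed = random.sample(replay_buffer, n_replay)
    return fresh + replayed
```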
If you use this code in your research, please cite the following paper:
```bibtex
@article{cretu2024synflownet,
  title={SynFlowNet: Design of Diverse and Novel Molecules with Synthesis Constraints},
  author={Cretu, Miruna and Harris, Charles and Igashov, Ilia and Schneuing, Arne and Segler, Marwin and Correia, Bruno and Roy, Julien and Bengio, Emmanuel and Li{\`o}, Pietro},
  journal={arXiv preprint arXiv},
  year={2024}
}
```