
molecularai / molbart

Pretrained SMILES transformation model for finetuning for diverse molecular tasks.

License: Apache License 2.0

Shell 1.19% Python 92.87% Makefile 0.06% C++ 5.42% TeX 0.09% Cuda 0.36%


MolBART

The MolBART project aims to pre-train a BART transformer language model [2] on molecular SMILES strings [4] by optimising a de-noising objective [2] as well as a chemical format transformation specific to the SMILES language (heteroencoding) [3]. Pre-training leads to improved generalisation, performance, training speed and validity on downstream fine-tuned tasks. The project has also been called Chemformer, and the approach has been tested on downstream tasks such as reaction prediction, retrosynthetic prediction, molecular optimisation and molecular property prediction [1].

Installation

First, download Apex and pysmilesutils; the project dependencies can then be installed as follows:

  • conda create --name molbart rdkit -c rdkit
  • conda install pytorch==1.8.0 torchvision cudatoolkit==11.1 -c pytorch
  • conda install gcc_linux-64 gxx_linux-64 mpi4py
  • pip install -r requirements.txt
  • cd ../pysmilesutils && python setup.py install

Code

The codebase is broadly split into the following parts:

  • Models
  • Data helpers
  • Tokenisation
  • Decoding
  • Scripts

Models

The models.py file contains a PyTorch Lightning implementation of the BART language model, as well as PyTorch Lightning implementations of models for downstream tasks.

Data Helpers

The dataset.py file contains a number of classes used to load, batch and process the data before it is passed to the model. Classes which inherit from _AbsDataset are subclasses of PyTorch's torch.utils.data.Dataset and are simply used to store data (molecules, reactions, etc.) and split it into its relevant subsets (train, val, test).
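
A minimal sketch of this store-and-split pattern is shown below. The class name, the `split` method and its fractions are illustrative assumptions, not the project's actual `_AbsDataset` API:

```python
import random

class SmilesDataset:
    """Toy dataset that stores SMILES strings and splits them into
    train/val/test subsets (illustrative sketch, not molbart's class)."""

    def __init__(self, smiles):
        self.smiles = list(smiles)

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        return self.smiles[idx]

    def split(self, val_frac=0.1, test_frac=0.1, seed=0):
        # Shuffle indices deterministically, then carve out val/test slices.
        idxs = list(range(len(self)))
        random.Random(seed).shuffle(idxs)
        n_val = round(len(idxs) * val_frac)
        n_test = round(len(idxs) * test_frac)
        val = [self.smiles[i] for i in idxs[:n_val]]
        test = [self.smiles[i] for i in idxs[n_val:n_val + n_test]]
        train = [self.smiles[i] for i in idxs[n_val + n_test:]]
        return train, val, test
```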

Our _AbsDataModule class inherits from PyTorch Lightning's LightningDataModule class, and its subclasses are used to augment, tokenise and tensorise the data before it is passed to the model.

Finally, we include a TokenSampler class which categorises sequences into buckets based on their length, and is able to sample a different batch size of sequences from each bucket. This helps ensure that the model sees approximately the same number of tokens in each batch, and it dramatically improves training speed.
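
The bucketing idea can be sketched as follows. This is a simplified illustration of the technique, assuming fixed-width length buckets; function and parameter names are not TokenSampler's actual API:

```python
import random
from collections import defaultdict

def bucket_batches(lengths, tokens_per_batch, bucket_width=10, seed=0):
    """Group sequence indices into length buckets, then emit batches whose
    size varies per bucket so each batch holds roughly the same number
    of tokens (sketch of the idea behind TokenSampler)."""
    buckets = defaultdict(list)
    for idx, length in enumerate(lengths):
        buckets[length // bucket_width].append(idx)

    rng = random.Random(seed)
    batches = []
    for bucket_id, idxs in buckets.items():
        rng.shuffle(idxs)
        # Longest sequence this bucket can contain bounds the token count,
        # so shorter buckets get proportionally larger batch sizes.
        max_len = (bucket_id + 1) * bucket_width
        batch_size = max(1, tokens_per_batch // max_len)
        for i in range(0, len(idxs), batch_size):
            batches.append(idxs[i:i + batch_size])
    rng.shuffle(batches)
    return batches
```

Because every batch is drawn from a single bucket, padding within a batch is minimal, which is where the speed-up comes from.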

Tokenisation

Our tokenise.py file includes the MolEncTokeniser class which is capable of random 'BERT-style' masking of tokens, as well as padding each batch of sequences to be the same length. The tokeniser makes use of the SMILESTokenizer from the pysmilesutils library for tokenising SMILES into their constituent atoms.

Decoding

We include implementations of greedy and beam search decoding in the decode.py file. Both implementations use batch decoding for improved evaluation speed. They do not, however, cache results from previous decode steps; instead, at each step they pass the entire sequence of tokens produced so far through the transformer decoder.
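
The no-cache greedy loop can be sketched as below, where `step_fn` stands in for the transformer decoder and returns next-token scores for a prefix. This is an illustrative sketch of the strategy described above, not decode.py itself:

```python
def greedy_decode(step_fn, bos, eos, max_len=10):
    """Greedy decoding without a key/value cache: at every step the whole
    prefix generated so far is re-fed through step_fn, which returns a
    dict of next-token scores."""
    tokens = [bos]
    for _ in range(max_len):
        scores = step_fn(tokens)                # full prefix re-encoded each step
        next_tok = max(scores, key=scores.get)  # greedy: take the best token
        tokens.append(next_tok)
        if next_tok == eos:
            break
    return tokens
```

Re-running the full prefix is simpler than caching decoder states, at the cost of redundant computation per step.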

Scripts

The repository includes the following scripts:

  • train.py runs the pre-training
  • fine_tune.py runs fine-tuning on a specified task
  • evaluate.py evaluates the performance of a fine-tuned model
  • build_tokeniser.py creates a tokeniser from a dataset and stores it in a pickle file

Each script can be run using python -m molbart.<script_name> <args>.

See the ArgumentParser args in each file for more details on each argument.

To run on multiple GPUs, use the --gpus <num> argument with the train or fine-tune scripts. This will run the script with PyTorch Lightning's distributed data parallel (DDP) processing. Validation is disabled when using DDP to keep the GPUs synchronised and to prevent possible deadlocks.

References

[1] Ross Irwin, Spyridon Dimitriadis, Jiazhen He, and Esben Jannik Bjerrum, "Chemformer: A Pre-Trained Transformer for Computational Chemistry", ChemRxiv (2021).

[2] Lewis, Mike, et al., "BART: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension.", arXiv preprint arXiv:1910.13461 (2019).

[3] Esben J Bjerrum and Boris Sattarov, “Improving Chemical Autoencoder Latent Space and Molecular De Novo Generation Diversity with Heteroencoders”, Biomolecules, (2018) http://doi.org/10.3390/biom8040131

[4] Weininger, David., "SMILES, a chemical language and information system. 1. Introduction to methodology and encoding rules.", Journal of chemical information and computer sciences 28.1 (1988): 31-36.

molbart's People

Contributors

mdemouth, rahulmohan, rssrwn, sirelkhatim, vellamike


molbart's Issues

Error with pysmilesutils and MolRandomizer

Hi, I am trying to run the Molbart train script and I am having problems with pysmilesutils.

molbart/data/datamodules.py uses MolRandomizer, but in the pysmilesutils repo I can only see MolAugmenter. This throws an error that prevents the training from running.

Are you using an older version of the library? If so, which one? I am trying to work around it by switching MolRandomizer to MolAugmenter, but I have no idea whether this will work. Any hints?

See below:

diff --git a/MolBART/molbart/data/datamodules.py b/MolBART/molbart/data/datamodules.py
index 3c2ad67..bc909f1 100644
--- a/MolBART/molbart/data/datamodules.py
+++ b/MolBART/molbart/data/datamodules.py
@@ -5,7 +5,7 @@ from rdkit import Chem
 from functools import partial
 from typing import List, Optional
 from torch.utils.data import DataLoader
-from pysmilesutils.augment import MolRandomizer
+from pysmilesutils.augment import MolAugmenter

 from molbart.tokeniser import MolEncTokeniser
 from molbart.data.util import TokenSampler
@@ -180,7 +180,7 @@ class MoleculeDataModule(_AbsDataModule):

         if augment:
             print("Using molecule data module with augmentations.")
-            self.aug = MolRandomizer()
+            self.aug = MolAugmenter()
         else:
             print("No molecular augmentation.")
             self.aug = None
@@ -328,7 +328,7 @@ class FineTuneReactionDataModule(_AbsDataModule):
             print("Training on backward prediction task.")

         self.augment = augment
-        self.aug = MolRandomizer() if augment is not None else None
+        self.aug = MolAugmenter() if augment is not None else None
         self.forward_pred = forward_pred

     def _collate(self, batch, train=True):

ImportError: cannot import name 'MolRandomizer' from 'pysmilesutils.augment'

Dear authors,
I have installed the latest pysmilesutils from your repo, but I cannot run MolBART due to the following: ImportError: cannot import name 'MolRandomizer' from 'pysmilesutils.augment' (/xxxx/xxxx/xxxx/anaconda3/envs/xxxx/lib/python3.9/site-packages/pysmilesutils/augment.py).

Is it a versioning problem? How can I solve it?

Best Regards,

Eric

Available Models and Datasets

Hello! Thank you for your outstanding work.
May I ask whether the trained models and datasets of Chemformer or MolBART are available now?
It would be great if you could share them.
