
SpliceBERT: RNA language model pre-trained on vertebrate primary RNA sequences

SpliceBERT (manuscript, preprint) is a primary RNA sequence language model pre-trained on over 2 million vertebrate RNA sequences. It can be used to study RNA splicing and other biological problems related to RNA sequences.

For additional benchmarks and applications of SpliceBERT (e.g., on SpliceAI's and DeepSTARR's datasets), see SpliceBERT-analysis.

SpliceBERT overview

Data availability

The model weights and data for analysis are available at zenodo:7995778.

How to use SpliceBERT?

SpliceBERT is implemented with Huggingface transformers and FlashAttention in PyTorch. Users should install PyTorch, transformers, and (optionally) FlashAttention to load the SpliceBERT model.

SpliceBERT can be easily used for a series of downstream tasks through the official API. See official guide for more details.

Download SpliceBERT

The weights of SpliceBERT can be downloaded from zenodo: https://zenodo.org/record/7995778/files/models.tar.gz?download=1

System requirements

We recommend running SpliceBERT on a Linux system with an NVIDIA GPU that has at least 4 GB of memory. (Running the model on CPU only is possible, but it will be very slow.)
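
As a minimal sketch (not from the original README), the model and inputs can be placed on a GPU when one is available and fall back to CPU otherwise:

import torch
from transformers import AutoModel

SPLICEBERT_PATH = "/path/to/SpliceBERT/models/model_folder"  # hypothetical path; point it to your local copy of the weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU, fall back to CPU
model = AutoModel.from_pretrained(SPLICEBERT_PATH).to(device).eval()
# inputs prepared with the tokenizer (see Examples below) should be moved to the same device:
# input_ids = input_ids.to(device)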

Examples
We provide a demo script showing how to use SpliceBERT through the official API of Huggingface transformers in the first part of the following code block.
Users can also use SpliceBERT with FlashAttention by replacing the official API with the custom API, as shown in the second part of the following code block. Note that flash-attention requires automatic mixed precision (amp) mode to be enabled, and it currently does not support attention_mask.

Use SpliceBERT through the official API of Huggingface transformers:

SPLICEBERT_PATH = "/path/to/SpliceBERT/models/model_folder"  # set the path to the folder of pre-trained SpliceBERT
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForTokenClassification, AutoModelForSequenceClassification

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)

# prepare input sequence
seq = "ACGUACGuacguaCGu"  ## WARNING: this is just a demo. SpliceBERT may not work on sequences shorter than 64nt as it was trained on sequences of 64-1024nt in length
seq = ' '.join(list(seq.upper().replace("U", "T"))) # U -> T and add whitespace
input_ids = tokenizer.encode(seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
input_ids = torch.as_tensor(input_ids) # convert python list to Tensor
input_ids = input_ids.unsqueeze(0) # add batch dimension, shape: (batch_size, sequence_length)

# use huggingface's official API to use SpliceBERT
# get nucleotide embeddings (hidden states)
model = AutoModel.from_pretrained(SPLICEBERT_PATH) # load model
last_hidden_state = model(input_ids).last_hidden_state # get hidden states from last layer
hidden_states = model(input_ids, output_hidden_states=True).hidden_states # hidden states from the embedding layer (nn.Embedding) and the 6 transformer encoder layers

# get nucleotide type logits in masked language modeling
model = AutoModelForMaskedLM.from_pretrained(SPLICEBERT_PATH) # load model
logits = model(input_ids).logits # shape: (batch_size, sequence_length, vocab_size)

# finetuning SpliceBERT for token classification tasks
model = AutoModelForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3; logits shape: (batch_size, sequence_length, num_labels)

# finetuning SpliceBERT for sequence classification tasks
model = AutoModelForSequenceClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3; logits shape: (batch_size, num_labels)
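
For fine-tuning, the classification heads are trained with task-specific labels. Below is a minimal sketch of one training step for token classification (reusing SPLICEBERT_PATH and input_ids from above; the labels tensor here is a hypothetical placeholder, not part of the original examples):

import torch
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

labels = torch.zeros_like(input_ids)  # placeholder per-token class ids; set positions such as [CLS]/[SEP] to -100 to ignore them
outputs = model(input_ids, labels=labels)  # Huggingface computes the cross-entropy loss internally when labels are given
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()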

Or use SpliceBERT with FlashAttention by replacing the official API with the custom API (currently flash-attention does not support attention_mask, so all sequences in a batch must have the same length):

SPLICEBERT_PATH = "/path/to/SpliceBERT/models/model_folder"  # set the path to the folder of pre-trained SpliceBERT
import os
import sys
import torch
from torch.cuda.amp import autocast
sys.path.append(os.path.dirname(os.path.abspath(SPLICEBERT_PATH)))
from transformers import AutoTokenizer
from splicebert_model import BertModel, BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(SPLICEBERT_PATH)

# prepare input sequence
seq = "ACGUACGuacguaCGu"  ## WARNING: this is just a demo. SpliceBERT may not work on sequences shorter than 64nt as it was trained on sequences of 64-1024nt in length
seq = ' '.join(list(seq.upper().replace("U", "T"))) # U -> T and add whitespace
input_ids = tokenizer.encode(seq) # N -> 5, A -> 6, C -> 7, G -> 8, T(U) -> 9. NOTE: a [CLS] and a [SEP] token will be added to the start and the end of seq
input_ids = torch.as_tensor(input_ids) # convert python list to Tensor
input_ids = input_ids.unsqueeze(0) # add batch dimension, shape: (batch_size, sequence_length)

# Or use custom BertModel with FlashAttention
# get nucleotide embeddings (hidden states)
model = BertModel.from_pretrained(SPLICEBERT_PATH) # load model
with autocast():
    last_hidden_state = model(input_ids).last_hidden_state # get hidden states from last layer
    hidden_states = model(input_ids, output_hidden_states=True).hidden_states # hidden states from the embedding layer (nn.Embedding) and the 6 transformer encoder layers

# get nucleotide type logits in masked language modeling
model = BertForMaskedLM.from_pretrained(SPLICEBERT_PATH) # load model
with autocast():
    logits = model(input_ids).logits # shape: (batch_size, sequence_length, vocab_size)

# finetuning SpliceBERT for token classification tasks (run the forward pass under autocast)
model = BertForTokenClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3; logits shape: (batch_size, sequence_length, num_labels)

# finetuning SpliceBERT for sequence classification tasks (run the forward pass under autocast)
model = BertForSequenceClassification.from_pretrained(SPLICEBERT_PATH, num_labels=3) # assume the class number is 3; logits shape: (batch_size, num_labels)
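
Since the FlashAttention-based model ignores attention_mask, all sequences within a batch must have the same length (pad or crop them beforehand). Here is a minimal batching sketch on a GPU, reusing the tokenizer, BertModel, and autocast imported above (the sequences are illustrative only):

seqs = ["ACGTACGT" * 16, "TTGCATGC" * 16]  # two demo sequences of identical length (128 nt)
batch = [tokenizer.encode(' '.join(s)) for s in seqs]
input_ids = torch.as_tensor(batch).cuda()  # shape: (2, 130), including [CLS] and [SEP]
model = BertModel.from_pretrained(SPLICEBERT_PATH).cuda()
with autocast():
    last_hidden_state = model(input_ids).last_hidden_state  # shape: (2, 130, hidden_size)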

Reproduce the analysis

  1. Configure the environment.

    We run the scripts in a conda environment with python 3.9.7 on a Linux system (Ubuntu 20.04.3 LTS). The required packages are:

    • Python packages:
      • Python (3.9.7)
      • transformers (4.24.0)
      • pytorch (1.12.1)
      • h5py (3.2.1)
      • numpy (1.23.3)
      • scipy (1.8.0)
      • scikit-learn (1.1.1)
      • scanpy (1.8.2)
      • matplotlib (3.5.1)
      • seaborn (0.11.2)
      • tqdm (4.64.0)
      • pyBigWig (0.3.18)
      • cython (0.29.28)
    • Command line tools (optional):
      • bedtools (2.30.0)
      • MaxEntScan (2004)
      • gtfToGenePred (v377)

    Note: the version numbers only indicate the software versions used in our study. In most cases, users do not need to install exactly the same versions to run the code.

  2. Clone this repository, download data and setup scripts.

    git clone git@github.com:biomed-AI/SpliceBERT.git
    cd SpliceBERT
    bash download.sh # download model weights and data, or manually download them from [zenodo](https://doi.org/10.5281/zenodo.7995778)
    cd examples
    bash setup.sh # compile selene utils, cython is required
  3. (Optional) Download pre-computed results for sections 1-4 from Google Drive and decompress them in the examples folder.

    # users should manually download `pre-computed_results.tar.gz` and put it in the `./examples` folder and run the following command to decompress it
    tar -zxvf pre-computed_results.tar.gz

    If the pre-computed results have been downloaded and decompressed correctly, users can skip running pipeline.sh in the Jupyter notebooks of sections 1-4.

  4. Run the Jupyter notebooks (sections 1-4) or the bash scripts pipeline.sh (sections 5-6).

Contact

For issues related to the scripts, create an issue at https://github.com/biomed-AI/SpliceBERT/issues.

For any other questions, feel free to contact chenkenbio {at} gmail.com.

Citation

@article{Chen2023.01.31.526427,
	author = {Chen, Ken and Zhou, Yue and Ding, Maolin and Wang, Yu and Ren, Zhixiang and Yang, Yuedong},
	title = {Self-supervised learning on millions of primary RNA sequences from 72 vertebrates improves sequence-based RNA splicing prediction},
	year = {2024},
	doi = {10.1093/bib/bbae163},
	publisher = {Oxford University Press},
	URL = {https://doi.org/10.1093/bib/bbae163},
	journal = {Briefings in Bioinformatics}
}

splicebert's Issues

In silico mutagenesis

Hi there,

Thank you so much for this awesome paper!
I am wondering how you implemented the in silico mutagenesis analysis.
Are there any materials that I can look into?

Thank you so much!

No module named 'biock'

Hi there,
When I try to reproduce example 03, I get the following error: ModuleNotFoundError: No module named 'biock' when running:

from biock.plot._plot import boxplot_with_scatter
fig = plt.figure(figsize=get_figure_size(0.5, 0.35))
ax = plt.subplot()
# _ = ax.boxplot(scores.values(), positions=xticks, sym='', medianprops=dict(color='black'))
xticks = np.arange(len(scores))
print('\n'.join(["{}\t{:.3e}".format(k, np.mean(v)) for k, v in scores.items()]))
boxplot_with_scatter(
    list(scores.values()), 
    # list(scores_norm.values()), 
    positions=xticks, 
    ax=ax, 
    scatter_kwargs=dict(marker='.', alpha=0.8), max_sample=1000, size=0.5, medianprops=dict(color='black'))
xticklabels = list(scores.keys())
xticklabels = ["D-A(intron)", "D-A(exon)", "D-A(unpair)", "D-D", "A-A", "random pair\n(control)"]
# _ = ax.boxplot(scores.values(), positions=xticks, sym='', medianprops=dict(color='black'))
_ = ax.set_xticks(xticks, xticklabels, rotation=15)
ax.set_yscale("log")
set_spines(ax)
ax.set_ylabel("attention weights")
plt.tight_layout()
# plt.savefig("./Figure_4A.jpg", dpi=600)
# plt.savefig("../figures/attention_pair_by_groups.jpg", dpi=600)
# plt.savefig("../figures/attention_pair_by_groups.svg")
