TransHLA: A Hybrid Transformer Model for HLA-Presented Epitope Detection

This repository contains the source code of TransHLA, including the comparison models, the datasets, and the model training and inference scripts.

This document contains a tutorial on training the model and shows how to use the pre-trained TransHLA models with the transformers library.

TransHLA is a tool designed to discern whether a peptide will be recognized by HLA as an epitope. TransHLA is the first tool capable of directly identifying peptides as epitopes without the need for inputting HLA alleles. Due to the different lengths of epitopes, we trained two models. The first is TransHLA_I, which is used for the detection of HLA-I epitopes; the other is TransHLA_II, which is used for the detection of HLA-II epitopes. This repository provides the code and the datasets.

Model description

TransHLA is a hybrid transformer model that combines a transformer encoder module with a deep CNN module. It is trained using pretrained sequence embeddings from ESM2 and contact map structural features as inputs. It can serve as a preliminary screening step ahead of the currently popular tools that predict HLA-epitope binding affinity.

Intended uses

Due to variations in peptide lengths, TransHLA is divided into TransHLA_I and TransHLA_II, which identify epitopes presented by HLA class I and class II molecules, respectively. Specifically, TransHLA_I is designed for shorter peptides ranging from 8 to 14 amino acids in length, while TransHLA_II targets longer peptides with lengths of 13 to 21 amino acids.

The output consists of two parts. The first output indicates whether the peptide is an epitope, presented in a two-column format where each row contains two probabilities that sum to 1. If the number in the second column is greater than or equal to 0.5, the peptide is classified as an epitope; otherwise, it is considered a normal peptide. The second output is the sequence embedding generated by the model. For both models, we have written separate tutorials in this file to facilitate ease of use.
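The decision rule described above can be sketched in a few lines of plain Python. This is a minimal illustration, not part of the TransHLA code base: the function name `classify_epitopes` and the sample probabilities are hypothetical, and the input is assumed to be the first model output converted to a nested list.

```python
from typing import List, Sequence, Tuple


def classify_epitopes(
    probabilities: Sequence[Sequence[float]],
    threshold: float = 0.5,
) -> List[Tuple[float, bool]]:
    """Return (epitope_probability, is_epitope) for each row.

    Each row is expected to hold two probabilities that sum to 1;
    the second column is treated as the epitope probability.
    """
    results = []
    for row in probabilities:
        if len(row) != 2:
            raise ValueError(f"expected 2 columns per row, got {len(row)}")
        epitope_probability = float(row[1])
        # A peptide counts as an epitope when the second column is >= threshold.
        results.append((epitope_probability, epitope_probability >= threshold))
    return results


if __name__ == "__main__":
    # Made-up probabilities for illustration only, not real model output.
    example_rows = [[0.9, 0.1], [0.25, 0.75], [0.5, 0.5]]
    for probability, is_epitope in classify_epitopes(example_rows):
        label = "epitope" if is_epitope else "normal peptide"
        print(f"{probability:.2f} -> {label}")
```

When the first output is a PyTorch tensor, something like `outputs.detach().cpu().tolist()` should produce a nested list suitable for this helper.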

How to train your own model

First, download this repository.

git clone https://github.com/SkywalkerLuke/TransHLA.git
cd TransHLA

Then, install the dependencies listed in requirements.txt:

pip install -r requirements.txt

Then, change to the model_train_test directory and run train.py:

cd model_train_test
python train.py --train_path your_train.csv --validation_path your_validation.csv --model_path your_path_to_save_model --model_name your_model_name.pt

You can then run inference.py with your own trained model:

python inference.py --test_path your_test.csv --model_path your_model.pt --ouputs_path your_outputs_path.npy

How to use in transformers

First, install the following packages: pytorch, fair-esm, and transformers. To run on a GPU, the CUDA version must be 11.8 or higher; otherwise, the model will need to be run on CPU.

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers
pip install fair-esm

Here is how to use the TransHLA_I model to predict whether a peptide is an epitope:

from transformers import AutoModel, AutoTokenizer
import torch


def pad_inner_lists_to_length(outer_list, target_length=16):
    # Pad each token ID list in place with 1 up to target_length.
    # 16 = longest supported peptide (14) + 2 special tokens from the tokenizer.
    for inner_list in outer_list:
        padding_length = target_length - len(inner_list)
        if padding_length > 0:
            inner_list.extend([1] * padding_length)
    return outer_list


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device")
    # The ESM2 tokenizer converts peptide strings into token IDs.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    # trust_remote_code=True lets transformers load the custom model class.
    model = AutoModel.from_pretrained("SkywalkerLu/TransHLA_I", trust_remote_code=True)
    model.to(device)
    model.eval()  # inference mode, as in the TransHLA_II example
    peptide_examples = ['EDSAIVTPSR', 'SVWEPAKAKYVFR']
    peptide_encoding = tokenizer(peptide_examples)['input_ids']
    peptide_encoding = pad_inner_lists_to_length(peptide_encoding)
    print(peptide_encoding)
    peptide_encoding = torch.tensor(peptide_encoding)
    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs, representations = model(peptide_encoding.to(device))
    print(outputs)          # two-column probabilities; second column = epitope
    print(representations)  # sequence embeddings
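Because the two models cover different peptide lengths (8 to 14 amino acids for TransHLA_I and 13 to 21 for TransHLA_II, overlapping at 13 and 14), it can help to check which model applies before running inference. The following is a minimal sketch under that assumption; `LENGTH_RANGES`, `applicable_models`, and `group_by_model` are hypothetical names and are not provided by the repository.

```python
from typing import Dict, Iterable, List

# Inclusive length ranges as stated in the "Intended uses" section.
LENGTH_RANGES = {
    "TransHLA_I": (8, 14),
    "TransHLA_II": (13, 21),
}


def applicable_models(peptide: str) -> List[str]:
    """List the models whose length range covers this peptide."""
    length = len(peptide)
    return [name for name, (lo, hi) in LENGTH_RANGES.items() if lo <= length <= hi]


def group_by_model(peptides: Iterable[str]) -> Dict[str, List[str]]:
    """Group peptides by applicable model.

    Peptides of length 13 or 14 appear under both models;
    peptides outside both ranges are left out.
    """
    groups: Dict[str, List[str]] = {name: [] for name in LENGTH_RANGES}
    for peptide in peptides:
        for name in applicable_models(peptide):
            groups[name].append(peptide)
    return groups


if __name__ == "__main__":
    examples = ["EDSAIVTPSR", "SVWEPAKAKYVFR", "ARGDFFRATSRLTTDFG"]
    for model_name, members in group_by_model(examples).items():
        print(model_name, members)
```

Which model to prefer for peptides in the overlapping range is left to the user; the sketch simply reports both.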

And here is how to use the TransHLA_II model to predict whether a peptide is an epitope:

from transformers import AutoModel, AutoTokenizer
import torch


def pad_inner_lists_to_length(outer_list, target_length=23):
    # Pad each token ID list in place with 1 up to target_length.
    # 23 = longest supported peptide (21) + 2 special tokens from the tokenizer.
    for inner_list in outer_list:
        padding_length = target_length - len(inner_list)
        if padding_length > 0:
            inner_list.extend([1] * padding_length)
    return outer_list


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device")
    # The ESM2 tokenizer converts peptide strings into token IDs.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    # trust_remote_code=True lets transformers load the custom model class.
    model = AutoModel.from_pretrained("SkywalkerLu/TransHLA_II", trust_remote_code=True)
    model.to(device)
    model.eval()
    peptide_examples = ['KMIYSYSSHAASSL', 'ARGDFFRATSRLTTDFG']
    peptide_encoding = tokenizer(peptide_examples)['input_ids']
    peptide_encoding = pad_inner_lists_to_length(peptide_encoding)
    peptide_encoding = torch.tensor(peptide_encoding)
    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs, representations = model(peptide_encoding.to(device))
    print(outputs)          # two-column probabilities; second column = epitope
    print(representations)  # sequence embeddings
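Both examples above pad the token ID lists in place to a fixed length (16 for TransHLA_I, 23 for TransHLA_II) using the value 1, which appears to correspond to the ESM2 tokenizer's padding token (this can be checked with `tokenizer.pad_token_id`). The padding helper leaves over-length lists untouched, so a peptide that is too long would only surface later as a ragged-list error in `torch.tensor`. A possible variant, sketched here under those assumptions, returns new lists and fails early instead; `pad_token_ids` and `ESM2_PAD_TOKEN_ID` are hypothetical names.

```python
from typing import List, Sequence

# Assumption: 1 is the padding ID, matching the value used in the examples above.
ESM2_PAD_TOKEN_ID = 1


def pad_token_ids(
    token_ids: Sequence[Sequence[int]],
    target_length: int,
    pad_id: int = ESM2_PAD_TOKEN_ID,
) -> List[List[int]]:
    """Return padded copies of each token ID list without modifying the input."""
    padded = []
    for index, ids in enumerate(token_ids):
        if len(ids) > target_length:
            # Fail early rather than produce rows of unequal length.
            raise ValueError(
                f"sequence {index} has {len(ids)} tokens, "
                f"which exceeds target_length={target_length}"
            )
        padded.append(list(ids) + [pad_id] * (target_length - len(ids)))
    return padded


if __name__ == "__main__":
    # Made-up token IDs for illustration; real IDs come from the tokenizer.
    example_ids = [[0, 5, 6, 2], [0, 5, 2]]
    print(pad_token_ids(example_ids, target_length=6))
```

In the examples above, this would be called with `target_length=16` for TransHLA_I and `target_length=23` for TransHLA_II.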
