---
license: cc-by-nc-sa-4.0
widget:
  - text: ACCTGA<mask>TTCTGAGTC
tags:
  - DNA
  - biology
  - genomics
datasets:
  - zhangtaolab/plant_reference_genomes
---

Plant foundation DNA large language models

The plant DNA large language models (LLMs) comprise a series of foundation models built on different model architectures, all pre-trained on various plant reference genomes.
All the models have comparable sizes, between 90 MB and 150 MB. Each uses a BPE tokenizer with a vocabulary of 8,000 tokens.
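The card does not describe how the BPE vocabulary was built. As a toy illustration of the general byte-pair-encoding idea applied to DNA (not the model's actual tokenizer or training code), the sketch below starts from single nucleotides and repeatedly merges the most frequent adjacent pair. The helper `learn_bpe_merges` and the two-sequence corpus are hypothetical; a real tokenizer would be trained on whole genomes and would keep merging until the vocabulary reaches its target size.

```python
from collections import Counter


def learn_bpe_merges(sequences, num_merges):
    """Toy byte-pair encoding: learn `num_merges` merges over DNA strings."""
    # Start with every sequence split into single-nucleotide symbols.
    corpus = [list(seq) for seq in sequences]
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair across the corpus.
        pairs = Counter()
        for symbols in corpus:
            pairs.update(zip(symbols, symbols[1:]))
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair (first seen wins ties)
        merges.append(best)
        merged = best[0] + best[1]
        # Replace non-overlapping occurrences of the pair, scanning left to right.
        new_corpus = []
        for symbols in corpus:
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(merged)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_corpus.append(out)
        corpus = new_corpus
    return merges, corpus


merges, tokenized = learn_bpe_merges(["ATATACGGCCGA", "GGGTATCGCTTCCGAC"], num_merges=3)
print(merges)
print(tokenized)
```

Each merge adds one vocabulary entry, so frequent motifs end up as single tokens while rare stretches stay split into shorter pieces; concatenating the tokens always reproduces the input sequence.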

Part of this collection is the nucleotide-transformer-v2-100m-multi-species, a 100M-parameter transformer pre-trained on a collection of 850 genomes from a wide range of species, including model and non-model organisms.

Developed by: zhangtaolab

Model Sources

Architecture

The model is built on the Google BERT-base architecture, with a tokenizer adapted specifically to DNA sequences.
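For a rough sense of scale, the arithmetic below counts the parameters of a BERT encoder with a masked-LM head. It assumes standard BERT-base hyperparameters (12 layers, hidden size 768, feed-forward size 3072, 512 positions, 2 token types) and exactly 8,000 vocabulary entries; these are assumptions, and the released checkpoint's configuration may differ (for example, extra special tokens). The function name `bert_mlm_param_count` is hypothetical.

```python
def bert_mlm_param_count(vocab_size=8000, hidden=768, layers=12,
                         ffn=3072, max_positions=512, type_vocab=2):
    """Count trainable parameters of a BERT encoder with a masked-LM head.

    Assumes no pooler and an output decoder tied to the input word
    embeddings (only its per-token bias is extra), the defaults for
    Hugging Face's BertForMaskedLM.
    """
    # Word, position and token-type embeddings plus one LayerNorm (weight + bias).
    embeddings = hidden * (vocab_size + max_positions + type_vocab) + 2 * hidden
    # Per layer: Q, K, V and output projections (weights + biases) ...
    attention = 4 * (hidden * hidden + hidden)
    # ... the two feed-forward projections ...
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
    # ... and two LayerNorms.
    layer_norms = 2 * 2 * hidden
    encoder = layers * (attention + feed_forward + layer_norms)
    # MLM head: dense transform, LayerNorm, and per-token output bias.
    mlm_head = (hidden * hidden + hidden) + 2 * hidden + vocab_size
    return embeddings + encoder + mlm_head


total = bert_mlm_param_count()
print(f"{total:,} parameters")
```

With an 8,000-token vocabulary the embedding table is small, so almost all of the roughly 92 million parameters sit in the 12 encoder layers. Setting `vocab_size=30522` should reproduce the count for the original English BERT-base vocabulary, a handy sanity check on the formula.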

How to use

Install the runtime libraries first:

```bash
pip install transformers torch
```

Here is a simple example of inference:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model_name = 'plant-dnabert'
# load model and tokenizer
model = AutoModelForMaskedLM.from_pretrained(f'zhangtaolab/{model_name}', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(f'zhangtaolab/{model_name}', trust_remote_code=True)

# example sequences and tokenization
sequences = ['ATATACGGCCGNC', 'GGGTATCGCTTCCGAC']
tokens = tokenizer(sequences, padding="longest")['input_ids']
print(f"Tokenized sequence: {tokenizer.batch_decode(tokens)}")

# inference
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
inputs = tokenizer(sequences, truncation=True, padding='max_length', max_length=512,
                   return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():  # gradients are not needed for inference
    outs = model(
        **inputs,
        output_hidden_states=True
    )

# get the final layer embeddings and prediction logits
# (move tensors to the CPU first; .numpy() fails on CUDA tensors)
embeddings = outs['hidden_states'][-1].cpu().numpy()
logits = outs['logits'].cpu().numpy()
```
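The final-layer embeddings contain one vector per token position, and because the example pads every sequence to `max_length=512`, most of those positions are padding. One common way to get a single vector per sequence (an assumption here, not something this card prescribes) is to average only over positions where the attention mask is 1. The sketch below does this with a hypothetical `mean_pool` helper on small toy arrays, and also takes the top-scoring token id at each position from the logits.

```python
import numpy as np


def mean_pool(embeddings, attention_mask):
    """Average token embeddings over real (non-padding) positions.

    embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., None].astype(embeddings.dtype)  # (batch, seq_len, 1)
    summed = (embeddings * mask).sum(axis=1)                     # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1.0, None)                # avoid division by zero
    return summed / counts


# Toy stand-ins for the model outputs: 2 sequences, 4 positions, hidden size 3.
toy_embeddings = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
toy_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
toy_logits = np.random.default_rng(0).normal(size=(2, 4, 5))  # toy vocab size 5

seq_embeddings = mean_pool(toy_embeddings, toy_mask)
top_token_ids = toy_logits.argmax(axis=-1)  # most likely token id at each position
print(seq_embeddings.shape, top_token_ids.shape)
```

With the real outputs, `mean_pool(embeddings, inputs['attention_mask'].cpu().numpy())` would give one vector per input sequence, and `tokenizer.convert_ids_to_tokens(logits.argmax(axis=-1)[0].tolist())` maps the predicted ids for the first sequence back to token strings.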

Training data

We pre-train the model with the masked language modeling (MLM) objective; tokenized sequences have a maximum length of 512 tokens.
The detailed training procedure can be found in our manuscript.
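The card states only that masked language modeling was used with sequences of at most 512 tokens. The sketch below assumes the standard BERT recipe (select about 15% of non-special tokens; replace 80% of those with the mask token, 10% with a random token, and keep 10% unchanged), which is also what Hugging Face's `DataCollatorForLanguageModeling` does by default; the manuscript, not this sketch, is authoritative. The helper `mask_tokens` and all token ids (`MASK_ID`, and 2 and 3 standing in for the start and end tokens) are hypothetical.

```python
import random

# Hypothetical ids; the real values come from the model's tokenizer.
MASK_ID, VOCAB_SIZE, IGNORE = 4, 8000, -100


def mask_tokens(token_ids, special_ids, mask_prob=0.15, seed=0):
    """BERT-style masking: returns (inputs, labels) for masked-LM training."""
    rng = random.Random(seed)
    inputs, labels = list(token_ids), [IGNORE] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if tok in special_ids or rng.random() >= mask_prob:
            continue  # position not selected; the loss ignores it
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK_ID  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


MAX_LEN = 512
ids = [2] + list(range(10, 10 + MAX_LEN - 2)) + [3]  # toy start ... end sequence
masked, labels = mask_tokens(ids[:MAX_LEN], special_ids={2, 3})
print(sum(lab != IGNORE for lab in labels), "positions selected out of", len(ids))
```

Only the selected positions carry a label, so the loss is computed on roughly 15% of each 512-token sequence while the special tokens are never masked.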

Hardware

The model was pre-trained on an NVIDIA RTX 4090 GPU (24 GB).