Example for causal eQTL identification using prompt-enhanced ChromBERT

ChromBERT’s integration of a DNA sequence prompt from DNABERT-2 enhances its versatility for genomic applications, such as fine-mapping eQTLs.
By incorporating DNA sequence variation, ChromBERT can identify causal variants influencing gene expression.
As an example of this capability, we fine-tune ChromBERT on data from the latest eQTL Catalogue to classify variants as causal or non-causal.

Attention: You should go through this tutorial first to get familiar with the basic usage of ChromBERT.

[1]:
import os

# Choose which GPU this notebook uses. This must run before torch initialises
# CUDA, so it stays above `import torch`. `setdefault` keeps any value already
# exported in the shell; device "0" exists on every CUDA machine, whereas a
# hard-coded "1" hides all GPUs on single-GPU hosts and makes `.cuda()` fail.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import sys
import pathlib
import pickle

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary

import lightning.pytorch as pl

import sklearn
from sklearn import metrics

# Seed Python's `random`, NumPy and PyTorch so model initialisation, dataloader
# shuffling and dropout are reproducible across Restart & Run All.
pl.seed_everything(42)

# Root of the ChromBERT data cache (HDF5 features, configs, checkpoints, demo tables).
basedir = os.path.expanduser("~/.cache/chrombert/data")

Prepare datasets

We provide a demo eQTL dataset for lung tissue. Compared with the datasets used in other tasks, it contains additional columns: base_ref, base_alt, variant_id, label, and pos.

  • pos: Specifies the variant position.

  • base_ref: Indicates the reference allele.

  • base_alt: Represents the alternative allele.

The variant_id serves as a unique identifier for each variant and can be any unique string for convenience. The label column classifies variants as causal (1) or non-causal (0). If the dataset is used only for prediction, the label column can be omitted.

[2]:
table_eqtls = os.path.join(basedir, "demo","eqtl", "lung_eqtl.tsv")
!head $table_eqtls
chrom   start   end     build_region_index      base_ref        base_alt        variant_id      label   pos
chr17   29595000        29596000        771566  A       G       chr17_29595257_A_G      1       29595257
chr11   78139000        78140000        342669  C       T       chr11_78139778_C_T      1       78139778
chr19   35309000        35310000        898784  G       T       chr19_35309759_G_T      1       35309759
chr4    6695000 6696000 1359650 T       C       chr4_6695946_T_C        1       6695946
chr6    139029000       139030000       1719458 G       A       chr6_139029045_G_A      1       139029045
chr15   41554000        41555000        640548  T       G       chr15_41554978_T_G      1       41554978
chr16   4343000 4344000 694493  A       C       chr16_4343375_A_C       1       4343375
chr9    97238000        97239000        2025247 A       G       chr9_97238449_A_G       1       97238449
chr5    96757000        96758000        1551457 G       A       chr5_96757518_G_A       1       96757518
[3]:
from sklearn.model_selection import train_test_split

# Working directory for the split tables and the fine-tuned checkpoint.
odir = pathlib.Path("tmp_eqtl")
odir.mkdir(exist_ok=True, parents=True)

df_full = pd.read_csv(table_eqtls, sep="\t")
# 60% train / 40% test. Stratifying on `label` keeps the causal (1) vs
# non-causal (0) ratio identical in both splits, so test metrics are not
# distorted by an accidental class imbalance.
df_train, df_test = train_test_split(
    df_full,
    test_size=0.4,
    random_state=42,
    stratify=df_full["label"],
)
assert len(df_train) + len(df_test) == len(df_full)
df_train.to_csv(odir / "train.tsv", sep="\t", index=False)
df_test.to_csv(odir / "test.tsv", sep="\t", index=False)
len(df_train), len(df_test)
[3]:
(921, 614)
[4]:
# Dataset configuration for DNA-prompted ChromBERT, pointed at the full demo table.
dc = chrombert.get_preset_dataset_config(
    "prompt_dna",
    supervised_file=table_eqtls,
)
print(dc)
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
{
    "hdf5_file": "/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5",
    "supervised_file": "/home/chenqianqian/.cache/chrombert/data/demo/eqtl/lung_eqtl.tsv",
    "kind": "PromptDataset",
    "meta_file": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json",
    "ignore": false,
    "ignore_object": null,
    "batch_size": 8,
    "num_workers": 20,
    "shuffle": false,
    "pin_memory": true,
    "perturbation": false,
    "perturbation_object": null,
    "perturbation_value": 0,
    "prompt_kind": "dna",
    "prompt_regulator": null,
    "prompt_regulator_cache_file": null,
    "prompt_celltype": null,
    "prompt_celltype_cache_file": null,
    "prompt_regulator_cache_pin_memory": false,
    "prompt_regulator_cache_limit": 3,
    "fasta_file": "/home/chenqianqian/.cache/chrombert/data/other/hg38.fa",
    "flank_window": 0
}
[5]:
# Build the dataset from the configuration and display one example (index 1)
# to inspect the fields it yields.
ds = dc.init_dataset()
ds[1]
[5]:
{'input_ids': tensor([9, 9, 9,  ..., 6, 6, 6], dtype=torch.int8),
 'position_ids': tensor([   1,    2,    3,  ..., 6389, 6390, 6391]),
 'region': tensor([      11, 78139000, 78140000], dtype=torch.int32),
 'build_region_index': 342669,
 'label': 1,
 'seq_raw': 'GATTTGTTCCAAATCAGACAGCGCCAGGTCTGAACCTAGCCAGCTGGGGCTAAGTCAAGTAACAACTGGCGAAACAGAAAGCTTAGCAAAGGCAGGATAGCGACAAACACGACCTAAAGTTTTCTCTTCATACCCAGGGATATCCACACCTTTCTCTCCCGCCCTGACCGACCGCGGGGCCTCCCCGCCCAGCCCCTGGCCGTGCGAGTCCCTTACTATGTGGGGATGAGAAGGCATTTGAGAAGAGTCACCCCGAGCGCCAAAGCCGAAAACCAATTGCCAGTACCCGTGGCAATTGTGAGCGCCGCCATTGCTGCGGCACCGCACGCTTCCCACCAACTTGATCCACATCCGGGATCCCGCGCATGCGGAGAAAGCCCTCTGAAGCCGTGCCCGCTAGCTGCGCGCATGCGGCGAGCGGCGCAGCCAGTCCGGGGACTGCAGTCAGCTATTTAAACCTCCCGCCCACCTTTTCTTTAGACCCGCGTCTCACCCCGGGCCGGAAGGGCTCCTGCGCAGGCGTTTGTAGCCACTTTTAAGTTTTATCAGCTAGTTCATGCTTGCGTTGAAAGAGTGGTCGTTTGCGCTGGGTCATCACTGTGTAGTATTGGGGATACTTAGGTGAGAAAAAAACTTAACGCTAGAGACGTTCACGCACTAGTGGAGAAGCCAGGATTGTTGCCCTAGAGTTACAGTAGATAAAAGTACCTCAGAGAACTGCGGGGGCTCCCAACCTGGACGCTTGCACCGGAGTATTAAATCCAGCTAGAGAATGGCATGTGCAAAGATACAGAGGTGAGAAACATTGTGTTTTTAGAACTCTGAGCGAGGCTCTTGGCTCACCTCCTGCTTGAGCGGAACCCATTCTGGAAGCAGGGTAGAGGCTAGTCCTAACGCTTAGTGTACAAATAGCCTACGGTTCATGTTAAAATAATTCGGATTCTGATTCAGTAGGCCCAACAAACTCGCAGATTGCGTAATGACTGAGGCACATGCAATT',
 'seq_alt': 'GATTTGTTCCAAATCAGACAGCGCCAGGTCTGAACCTAGCCAGCTGGGGCTAAGTCAAGTAACAACTGGCGAAACAGAAAGCTTAGCAAAGGCAGGATAGCGACAAACACGACCTAAAGTTTTCTCTTCATACCCAGGGATATCCACACCTTTCTCTCCCGCCCTGACCGACCGCGGGGCCTCCCCGCCCAGCCCCTGGCCGTGCGAGTCCCTTACTATGTGGGGATGAGAAGGCATTTGAGAAGAGTCACCCCGAGCGCCAAAGCCGAAAACCAATTGCCAGTACCCGTGGCAATTGTGAGCGCCGCCATTGCTGCGGCACCGCACGCTTCCCACCAACTTGATCCACATCCGGGATCCCGCGCATGCGGAGAAAGCCCTCTGAAGCCGTGCCCGCTAGCTGCGCGCATGCGGCGAGCGGCGCAGCCAGTCCGGGGACTGCAGTCAGCTATTTAAACCTCCCGCCCACCTTTTCTTTAGACCCGCGTCTCACCCCGGGCTGGAAGGGCTCCTGCGCAGGCGTTTGTAGCCACTTTTAAGTTTTATCAGCTAGTTCATGCTTGCGTTGAAAGAGTGGTCGTTTGCGCTGGGTCATCACTGTGTAGTATTGGGGATACTTAGGTGAGAAAAAAACTTAACGCTAGAGACGTTCACGCACTAGTGGAGAAGCCAGGATTGTTGCCCTAGAGTTACAGTAGATAAAAGTACCTCAGAGAACTGCGGGGGCTCCCAACCTGGACGCTTGCACCGGAGTATTAAATCCAGCTAGAGAATGGCATGTGCAAAGATACAGAGGTGAGAAACATTGTGTTTTTAGAACTCTGAGCGAGGCTCTTGGCTCACCTCCTGCTTGAGCGGAACCCATTCTGGAAGCAGGGTAGAGGCTAGTCCTAACGCTTAGTGTACAAATAGCCTACGGTTCATGTTAAAATAATTCGGATTCTGATTCAGTAGGCCCAACAAACTCGCAGATTGCGTAATGACTGAGGCACATGCAATT'}

Prepare the model

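The model is loaded in the same way as for other tasks, except that the kind is set to prompt_dna and a DNABERT-2 checkpoint is supplied via dnabert2_ckpt (either a Hugging Face model id or a local path).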

[6]:
# Model configuration for ChromBERT with a DNABERT-2 sequence prompt.
# `dnabert2_ckpt` takes either a Hugging Face model id or a local checkpoint path.
mc = chrombert.get_preset_model_config(
    "prompt_dna",
    dnabert2_ckpt="zhihan1996/DNABERT-2-117M",
)
print(mc)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
{
    "genome": "hg38",
    "task": "prompt",
    "dim_output": 1,
    "mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
    "dropout": 0.1,
    "pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    "finetune_ckpt": null,
    "ignore": false,
    "ignore_index": [
        null,
        null
    ],
    "gep_flank_window": 4,
    "gep_parallel_embedding": false,
    "gep_gradient_checkpoint": false,
    "gep_zero_inflation": false,
    "prompt_kind": "dna",
    "prompt_dim_external": 768,
    "dnabert2_ckpt": "zhihan1996/DNABERT-2-117M"
}
[7]:
# Instantiate the model from the configuration and show a per-layer parameter summary.
model = mc.init_model()
summary(model)
use organisim hg38; max sequence length is 6391
Warning: zhihan1996/DNABERT-2-117M does not exist! Try to use huggingface cached...
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_clean/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/transformers/modeling_utils.py:442: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(checkpoint_file, map_location="cpu")
/home/chenqianqian/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:126: UserWarning: Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).
  warnings.warn(
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[7]:
==========================================================================================
Layer (type:depth-idx)                                            Param #
==========================================================================================
ChromBERTPromptDNA                                                --
├─ChromBERT: 1-1                                                  --
│    └─BERTEmbedding: 2-1                                         --
│    │    └─TokenEmbedding: 3-1                                   7,680
│    │    └─PositionalEmbedding: 3-2                              4,909,056
│    │    └─Dropout: 3-3                                          --
│    └─ModuleList: 2-2                                            --
│    │    └─EncoderTransformerBlock: 3-4                          6,497,280
│    │    └─EncoderTransformerBlock: 3-5                          6,497,280
│    │    └─EncoderTransformerBlock: 3-6                          6,497,280
│    │    └─EncoderTransformerBlock: 3-7                          6,497,280
│    │    └─EncoderTransformerBlock: 3-8                          6,497,280
│    │    └─EncoderTransformerBlock: 3-9                          6,497,280
│    │    └─EncoderTransformerBlock: 3-10                         6,497,280
│    │    └─EncoderTransformerBlock: 3-11                         6,497,280
├─DNABERT2Interface: 1-2                                          --
│    └─BertModel: 2-3                                             --
│    │    └─BertEmbeddings: 3-12                                  3,148,800
│    │    └─BertEncoder: 3-13                                     113,329,152
│    │    └─BertPooler: 3-14                                      590,592
├─AdapterExternalEmb: 1-3                                         --
│    └─ResidualBlock: 2-4                                         --
│    │    └─Linear: 3-15                                          590,592
│    │    └─Linear: 3-16                                          590,592
│    │    └─LayerNorm: 3-17                                       1,536
│    │    └─Sequential: 3-18                                      --
│    │    └─Dropout: 3-19                                         --
│    └─ResidualBlock: 2-5                                         --
│    │    └─Linear: 3-20                                          590,592
│    │    └─Linear: 3-21                                          590,592
│    │    └─LayerNorm: 3-22                                       1,536
│    │    └─Sequential: 3-23                                      --
│    │    └─Dropout: 3-24                                         --
├─GeneralHeader: 1-4                                              --
│    └─CistromeEmbeddingManager: 2-6                              --
│    └─Conv2d: 2-7                                                769
│    └─ReLU: 2-8                                                  --
│    └─ResidualBlock: 2-9                                         --
│    │    └─Linear: 3-25                                          1,099,776
│    │    └─Linear: 3-26                                          1,049,600
│    │    └─LayerNorm: 3-27                                       2,048
│    │    └─Linear: 3-28                                          1,099,776
│    │    └─Dropout: 3-29                                         --
│    └─ResidualBlock: 2-10                                        --
│    │    └─Linear: 3-30                                          787,200
│    │    └─Linear: 3-31                                          590,592
│    │    └─LayerNorm: 3-32                                       1,536
│    │    └─Linear: 3-33                                          787,200
│    │    └─Dropout: 3-34                                         --
│    └─ResidualBlock: 2-11                                        --
│    │    └─Linear: 3-35                                          196,864
│    │    └─Linear: 3-36                                          65,792
│    │    └─LayerNorm: 3-37                                       512
│    │    └─Linear: 3-38                                          196,864
│    │    └─Dropout: 3-39                                         --
│    └─Linear: 2-12                                               257
├─PromptHeader: 1-5                                               --
│    └─Sequential: 2-13                                           --
│    │    └─ResidualBlock: 3-40                                   4,724,736
│    │    └─ResidualBlock: 3-41                                   2,952,960
│    │    └─ResidualBlock: 3-42                                   1,182,720
│    │    └─ResidualBlock: 3-43                                   102,720
│    │    └─Linear: 3-44                                          65
==========================================================================================
Total params: 191,170,947
Trainable params: 191,170,947
Non-trainable params: 0
==========================================================================================
[8]:
# Keep only the last 2 pre-trained ChromBERT transformer blocks trainable and
# freeze every DNABERT-2 parameter; the adapters and prompt head stay trainable.
model.freeze_pretrain(trainable=2)
model.dnabert2.freeze()
# List each parameter with its frozen/trainable status and the totals.
model.display_trainable_parameters()
{'total_params': 191170947, 'trainable_params': 30201987}
pretrain_model.embedding.token.weight : frozen
pretrain_model.embedding.position.pe.pe.weight : frozen
pretrain_model.transformer_blocks.0.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.0.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.0.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.0.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.0.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.0.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.1.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.1.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.1.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.1.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.1.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.1.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.2.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.2.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.2.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.2.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.2.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.2.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.3.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.3.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.3.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.3.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.3.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.3.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.4.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.4.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.4.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.4.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.4.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.4.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.5.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.5.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.5.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.5.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.5.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.5.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.6.attention.Wqkv.weight : trainable
pretrain_model.transformer_blocks.6.attention.Wqkv.bias : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_1.weight : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_1.bias : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_2.weight : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_2.bias : trainable
pretrain_model.transformer_blocks.6.input_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.6.input_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.6.output_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.6.output_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.7.attention.Wqkv.weight : trainable
pretrain_model.transformer_blocks.7.attention.Wqkv.bias : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_1.weight : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_1.bias : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_2.weight : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_2.bias : trainable
pretrain_model.transformer_blocks.7.input_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.7.input_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.7.output_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.7.output_sublayer.norm.b_2 : trainable
dnabert2.model.embeddings.word_embeddings.weight : frozen
dnabert2.model.embeddings.token_type_embeddings.weight : frozen
dnabert2.model.embeddings.LayerNorm.weight : frozen
dnabert2.model.embeddings.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.0.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.0.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.0.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.0.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.0.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.0.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.0.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.0.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.0.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.0.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.0.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.1.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.1.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.1.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.1.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.1.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.1.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.1.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.1.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.1.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.1.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.1.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.2.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.2.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.2.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.2.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.2.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.2.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.2.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.2.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.2.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.2.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.2.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.3.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.3.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.3.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.3.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.3.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.3.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.3.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.3.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.3.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.3.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.3.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.4.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.4.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.4.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.4.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.4.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.4.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.4.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.4.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.4.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.4.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.4.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.5.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.5.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.5.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.5.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.5.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.5.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.5.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.5.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.5.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.5.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.5.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.6.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.6.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.6.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.6.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.6.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.6.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.6.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.6.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.6.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.6.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.6.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.7.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.7.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.7.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.7.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.7.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.7.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.7.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.7.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.7.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.7.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.7.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.8.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.8.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.8.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.8.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.8.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.8.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.8.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.8.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.8.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.8.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.8.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.9.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.9.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.9.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.9.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.9.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.9.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.9.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.9.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.9.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.9.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.9.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.10.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.10.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.10.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.10.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.10.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.10.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.10.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.10.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.10.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.10.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.10.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.11.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.11.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.11.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.11.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.11.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.11.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.11.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.11.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.11.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.11.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.11.mlp.layernorm.bias : frozen
dnabert2.model.pooler.dense.weight : frozen
dnabert2.model.pooler.dense.bias : frozen
adapter_dna_emb.fc1.fc1.weight : trainable
adapter_dna_emb.fc1.fc1.bias : trainable
adapter_dna_emb.fc1.fc2.weight : trainable
adapter_dna_emb.fc1.fc2.bias : trainable
adapter_dna_emb.fc1.norm.a_2 : trainable
adapter_dna_emb.fc1.norm.b_2 : trainable
adapter_dna_emb.fc2.fc1.weight : trainable
adapter_dna_emb.fc2.fc1.bias : trainable
adapter_dna_emb.fc2.fc2.weight : trainable
adapter_dna_emb.fc2.fc2.bias : trainable
adapter_dna_emb.fc2.norm.a_2 : trainable
adapter_dna_emb.fc2.norm.b_2 : trainable
adapter_chrombert.conv.weight : trainable
adapter_chrombert.conv.bias : trainable
adapter_chrombert.res1.fc1.weight : trainable
adapter_chrombert.res1.fc1.bias : trainable
adapter_chrombert.res1.fc2.weight : trainable
adapter_chrombert.res1.fc2.bias : trainable
adapter_chrombert.res1.norm.a_2 : trainable
adapter_chrombert.res1.norm.b_2 : trainable
adapter_chrombert.res1.shortcut.weight : trainable
adapter_chrombert.res1.shortcut.bias : trainable
adapter_chrombert.res2.fc1.weight : trainable
adapter_chrombert.res2.fc1.bias : trainable
adapter_chrombert.res2.fc2.weight : trainable
adapter_chrombert.res2.fc2.bias : trainable
adapter_chrombert.res2.norm.a_2 : trainable
adapter_chrombert.res2.norm.b_2 : trainable
adapter_chrombert.res2.shortcut.weight : trainable
adapter_chrombert.res2.shortcut.bias : trainable
adapter_chrombert.res3.fc1.weight : trainable
adapter_chrombert.res3.fc1.bias : trainable
adapter_chrombert.res3.fc2.weight : trainable
adapter_chrombert.res3.fc2.bias : trainable
adapter_chrombert.res3.norm.a_2 : trainable
adapter_chrombert.res3.norm.b_2 : trainable
adapter_chrombert.res3.shortcut.weight : trainable
adapter_chrombert.res3.shortcut.bias : trainable
adapter_chrombert.fc.weight : trainable
adapter_chrombert.fc.bias : trainable
head_output.fcs.0.fc1.weight : trainable
head_output.fcs.0.fc1.bias : trainable
head_output.fcs.0.fc2.weight : trainable
head_output.fcs.0.fc2.bias : trainable
head_output.fcs.0.norm.a_2 : trainable
head_output.fcs.0.norm.b_2 : trainable
head_output.fcs.1.fc1.weight : trainable
head_output.fcs.1.fc1.bias : trainable
head_output.fcs.1.fc2.weight : trainable
head_output.fcs.1.fc2.bias : trainable
head_output.fcs.1.norm.a_2 : trainable
head_output.fcs.1.norm.b_2 : trainable
head_output.fcs.1.shortcut.weight : trainable
head_output.fcs.1.shortcut.bias : trainable
head_output.fcs.2.fc1.weight : trainable
head_output.fcs.2.fc1.bias : trainable
head_output.fcs.2.fc2.weight : trainable
head_output.fcs.2.fc2.bias : trainable
head_output.fcs.2.norm.a_2 : trainable
head_output.fcs.2.norm.b_2 : trainable
head_output.fcs.3.fc1.weight : trainable
head_output.fcs.3.fc1.bias : trainable
head_output.fcs.3.fc2.weight : trainable
head_output.fcs.3.fc2.bias : trainable
head_output.fcs.3.norm.a_2 : trainable
head_output.fcs.3.norm.b_2 : trainable
head_output.fcs.3.shortcut.weight : trainable
head_output.fcs.3.shortcut.bias : trainable
head_output.fcs.4.weight : trainable
head_output.fcs.4.bias : trainable
[8]:
{'total_params': 191170947, 'trainable_params': 30201987}

Fine-tune

This task has fewer training samples than the others, so we use fewer training steps. For simplicity, the model is trained with a plain PyTorch loop instead of PyTorch Lightning.

[9]:
# One dataset configuration per split; the split DataFrames are passed directly.
dc_train, dc_test = (
    chrombert.get_preset_dataset_config("prompt_dna", supervised_file=split_df)
    for split_df in (df_train, df_test)
)

ds_train = dc_train.init_dataset()
ds_test = dc_test.init_dataset()
# Small batches; only the training loader is shuffled.
dl_train = dc_train.init_dataloader(batch_size=2, shuffle=True, num_workers=4)
dl_test = dc_test.init_dataloader(batch_size=2, shuffle=False, num_workers=4)

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
[10]:
from transformers import get_linear_schedule_with_warmup
from torch import nn

def train(m, dl, lr = 5e-5, grad_accumulation_steps=4, min_epochs = 2, max_epochs = 5, max_steps = 200):
    """Fine-tune a model with a plain PyTorch loop (no PyTorch Lightning).

    Uses AdamW with a linear, warmup-free learning-rate schedule, binary
    cross-entropy on the logits, and gradient accumulation.

    Args:
        m: model to train. Called as ``m(batch)``; its output is flattened to
            one logit per sample. Batch tensors are moved to the device of
            the model's parameters.
        dl: training DataLoader yielding dict batches that contain at least
            "input_ids", "position_ids" and "label".
        lr: learning rate for AdamW.
        grad_accumulation_steps: number of batches whose gradients are
            accumulated before each optimizer step.
        min_epochs: early stopping via ``max_steps`` is only allowed once the
            epoch index is at least this value.
        max_epochs: maximum number of passes over ``dl``.
        max_steps: once more than this many optimizer steps have been taken
            (and ``min_epochs`` is satisfied) training returns immediately.

    Returns:
        The trained model ``m`` (it is updated in place).

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    num_batches = len(dl)
    if num_batches == 0:
        raise ValueError("training dataloader is empty")

    device = next(m.parameters()).device

    # When num_batches is not a multiple of grad_accumulation_steps, the last
    # `tail` batches of each epoch form a shorter window. It is flushed at the
    # end of the epoch instead of leaking into the next epoch's first window.
    tail = num_batches % grad_accumulation_steps
    tail_start = num_batches - tail
    steps_per_epoch = -(-num_batches // grad_accumulation_steps)  # ceil division
    num_training_steps = steps_per_epoch * max_epochs

    optimizer = torch.optim.AdamW(m.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
    loss_f = nn.BCEWithLogitsLoss()  # default reduction="mean" already averages over the batch

    keys_to_device = ["input_ids", "position_ids", "label"]
    total_steps = 0
    optimizer.zero_grad()  # discard stale gradients, e.g. from an interrupted earlier run
    for e in range(max_epochs):
        m.train()
        epoch_loss = torch.zeros((), device=device)
        for i, batch in enumerate(tqdm(dl)):
            for k in keys_to_device:
                batch[k] = batch[k].to(device)

            logits = m(batch).view(-1)
            loss = loss_f(logits, batch["label"].to(torch.float))
            epoch_loss += loss.detach()  # unscaled, accumulated on-device (no per-step sync)

            # Scale by the size of the current window so every optimizer step
            # uses the mean gradient of the batches it covers.
            window_size = tail if i >= tail_start else grad_accumulation_steps
            (loss / window_size).backward()

            if (i + 1) % grad_accumulation_steps == 0 or (i + 1) == num_batches:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                total_steps += 1
                if total_steps > max_steps and e >= min_epochs:
                    return m

        print(f"epoch {e} loss {epoch_loss.item() / num_batches}")
    return m
[11]:
# Fine-tune on the GPU using train()'s default hyperparameters.
model_tuned = train(m=model.cuda(), dl=dl_train)
  0%|          | 0/461 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 461/461 [01:30<00:00,  5.10it/s]
epoch 0 loss 0.06694883108139038
100%|██████████| 461/461 [01:26<00:00,  5.32it/s]
epoch 1 loss 0.2063915878534317
  1%|          | 3/461 [00:01<02:42,  2.82it/s]

Evaluation

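Typically, we save the fine-tuned checkpoint and reload it for evaluation, as done below with dropout set to 0. Alternatively, the in-memory `model_tuned` can be evaluated directly if a small amount of additional randomness is acceptable.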

[12]:
# Persist the fine-tuned weights; the next cell reloads them from this path.
finetuned_ckpt_path = odir / "model.ckpt"
model_tuned.save_ckpt(finetuned_ckpt_path)

[15]:
# Rebuild the model from the saved fine-tuned checkpoint with dropout set to 0,
# move it to the GPU, and show its parameter counts.
model = (
    mc.init_model(finetune_ckpt=odir / "model.ckpt", dropout=0)
    .cuda()
)
model.display_trainable_parameters(verbose=False)
use organisim hg38; max sequence length is 6391
Warning: zhihan1996/DNABERT-2-117M does not exist! Try to use huggingface cached...
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_clean/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/transformers/modeling_utils.py:442: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(checkpoint_file, map_location="cpu")
/home/chenqianqian/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:126: UserWarning: Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).
  warnings.warn(
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Loading checkpoint from tmp_eqtl/model.ckpt
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_clean/chrombert/finetune/model/basic_model.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  new_state = torch.load(ckpt)
Loaded 290/290 parameters
{'total_params': 191170947, 'trainable_params': 191170947}
[15]:
{'total_params': 191170947, 'trainable_params': 191170947}
[16]:
# Score the held-out variants: the sigmoid of each logit is the predicted
# probability of label 1 (causal). Then report ROC-AUC and average precision.
model = model.cuda()
model.eval()

tensor_keys = ["input_ids", "position_ids", "label"]
prob_chunks, label_chunks = [], []
with torch.no_grad():
    for batch in tqdm(dl_test):
        for key in tensor_keys:
            batch[key] = batch[key].cuda()
        batch_logits = model(batch).view(-1)
        prob_chunks.append(batch_logits.sigmoid().float().cpu().numpy())
        label_chunks.append(batch["label"].cpu().numpy())

probs = np.concatenate(prob_chunks)
labels = np.concatenate(label_chunks)
auc = metrics.roc_auc_score(labels, probs)
aupr = metrics.average_precision_score(labels, probs)
print(f"test auc {auc}, aupr {aupr}")
  0%|          | 0/307 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 307/307 [00:38<00:00,  8.01it/s]
test auc 0.8395420949791861, aupr 0.8312050511192879

Identification of differential regulators

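We can identify regulators that differ between causal and non-causal variants by averaging each regulator's embedding over the causal variants and over the non-causal variants, then measuring the cosine distance (the "shift") between the two averages. Notably, DNase-seq data often shows a large shift.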

[17]:
# Obtain the embedding wrapper around the fine-tuned model; the cells below use
# it to extract per-regulator embeddings for each variant.
model_emb = model.get_embedding_manager()
model_emb  # last expression: displays the wrapper's module structure
[17]:
ChromBERTEmbedding(
  (pretrain_model): ChromBERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(10, 768, padding_idx=0)
      (position): PositionalEmbedding(
        (pe): PositionalEmbeddingTrainable(
          (pe): Embedding(6392, 768)
        )
      )
      (dropout): Dropout(p=0, inplace=False)
    )
    (transformer_blocks): ModuleList(
      (0-7): 8 x EncoderTransformerBlock(
        (attention): SelfAttentionFlashMHA(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0, inplace=False)
          (activation): GELU()
        )
        (input_sublayer): SublayerConnection(
          (norm): LayerNorm()
          (dropout): Dropout(p=0, inplace=False)
        )
        (output_sublayer): SublayerConnection(
          (norm): LayerNorm()
          (dropout): Dropout(p=0, inplace=False)
        )
        (dropout): Dropout(p=0, inplace=False)
      )
    )
  )
  (CistromeEmbeddingManager): CistromeEmbeddingManager()
)
[18]:
# Embed every variant in the demo eQTL table, then average the regulator
# embeddings separately over causal (label 1) and non-causal (label 0) variants.
# Class-wise running sums are kept instead of stacking every per-variant
# embedding: for this table the stacked array would need several GB of RAM.
dc = chrombert.get_preset_dataset_config("prompt_dna", supervised_file = table_eqtls)
dl = dc.init_dataloader(batch_size = 2, shuffle=False, num_workers=4)
keys_to_cuda = ["input_ids", "position_ids", "label"]
labels = []
emb_sum = {1: 0.0, 0: 0.0}
emb_count = {1: 0, 0: 0}
with torch.no_grad():  # inference only: avoid building an autograd graph
    for batch in tqdm(dl):
        for k in keys_to_cuda:
            batch[k] = batch[k].cuda()
        batch_labels = batch["label"].cpu().numpy()
        batch_emb = model_emb(batch).float().cpu().numpy()
        labels.append(batch_labels)
        for label_value in (1, 0):
            mask = batch_labels == label_value
            # float64 accumulation keeps the mean accurate over many batches
            emb_sum[label_value] = emb_sum[label_value] + batch_emb[mask].sum(axis=0, dtype=np.float64)
            emb_count[label_value] += int(mask.sum())
labels = np.concatenate(labels)
emb_causal = (emb_sum[1] / emb_count[1]).astype(np.float32)
emb_noncausal = (emb_sum[0] / emb_count[0]).astype(np.float32)
emb_causal.shape, emb_noncausal.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
100%|██████████| 768/768 [01:27<00:00,  8.82it/s]
[18]:
((1073, 768), (1073, 768))
[19]:
# Per-regulator cosine similarity between the mean causal and mean non-causal
# embeddings (diagonal of the pairwise matrix); shift = 1 - similarity.
regulator_cos_sim = sklearn.metrics.pairwise.cosine_similarity(emb_causal, emb_noncausal).diagonal()
df_shift = (
    pd.DataFrame({
        "regulator": model_emb.list_regulator,
        "shift": 1 - regulator_cos_sim,
    })
    .sort_values("shift", ascending=False, ignore_index=True)
)
# The regulator embeddings of "input" were set to a zero vector, so its cosine similarity calculation is not meaningful
df_shift.head(20)
[19]:
regulator shift
0 input 1.000000
1 rbl1 0.178612
2 dnase 0.161799
3 hcfc1r1 0.146511
4 sox4 0.141581
5 h3k4me3 0.140027
6 kdm2b 0.131659
7 lmnb1 0.130848
8 h3k4me2 0.123314
9 hexim1 0.121096
10 cnot3 0.116232
11 maz 0.116230
12 h1.4 0.115697
13 kdm4c 0.112475
14 zfx 0.111389
15 cfp1 0.111306
16 ints11 0.110976
17 kdm5b 0.107938
18 cdca2 0.106124
19 fgfr1 0.105282
[ ]: