Example for causal eQTL identification using prompt-enhanced ChromBERT¶
Attention: You should go through the basic ChromBERT tutorial first to get familiar with the basic usage of ChromBERT.
[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # expose only GPU index 1 to this process; set here, before torch is imported -- change the index to match your machine
import sys
import pathlib
import pickle
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary
import lightning.pytorch as pl
import sklearn
from sklearn import metrics
# Root directory of the ChromBERT data; the demo eQTL table used below is read from under it.
basedir = os.path.expanduser("~/.cache/chrombert/data" )
Prepare datasets¶
We provide a demo eQTL dataset for the lung, which includes additional columns compared to those used in other tasks: base_ref, base_alt, variant_id, label, and pos.
pos: Specifies the variant position. base_ref: Indicates the reference allele. base_alt: Represents the alternative allele.
The variant_id serves as a unique identifier for each variant and can be any unique string for convenience. The label column classifies variants as causal (1) or non-causal (0). If the dataset is used only for prediction, the label column can be omitted.
[2]:
table_eqtls = os.path.join(basedir, "demo","eqtl", "lung_eqtl.tsv")
!head $table_eqtls
chrom start end build_region_index base_ref base_alt variant_id label pos
chr17 29595000 29596000 771566 A G chr17_29595257_A_G 1 29595257
chr11 78139000 78140000 342669 C T chr11_78139778_C_T 1 78139778
chr19 35309000 35310000 898784 G T chr19_35309759_G_T 1 35309759
chr4 6695000 6696000 1359650 T C chr4_6695946_T_C 1 6695946
chr6 139029000 139030000 1719458 G A chr6_139029045_G_A 1 139029045
chr15 41554000 41555000 640548 T G chr15_41554978_T_G 1 41554978
chr16 4343000 4344000 694493 A C chr16_4343375_A_C 1 4343375
chr9 97238000 97239000 2025247 A G chr9_97238449_A_G 1 97238449
chr5 96757000 96758000 1551457 G A chr5_96757518_G_A 1 96757518
[3]:
from sklearn.model_selection import train_test_split

# Working directory for the split tables (the fine-tuned checkpoint is saved here later).
odir = pathlib.Path("tmp_eqtl")
odir.mkdir(parents=True, exist_ok=True)

# Hold out 40% of the demo variants for testing; the fixed seed keeps the split reproducible.
df_full = pd.read_csv(table_eqtls, sep="\t")
df_train, df_test = train_test_split(df_full, test_size=0.4, random_state=42)

# Write both splits as tab-separated files without the pandas index column.
for split_name, split_df in (("train", df_train), ("test", df_test)):
    split_df.to_csv(odir / f"{split_name}.tsv", sep="\t", index=False)

len(df_train), len(df_test)
[3]:
(921, 614)
[4]:
# Dataset configuration for the DNA-prompt task, pointing at the full demo eQTL table.
dc = chrombert.get_preset_dataset_config(
    "prompt_dna",
    supervised_file=table_eqtls,
)
print(dc)
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
{
"hdf5_file": "/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5",
"supervised_file": "/home/chenqianqian/.cache/chrombert/data/demo/eqtl/lung_eqtl.tsv",
"kind": "PromptDataset",
"meta_file": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json",
"ignore": false,
"ignore_object": null,
"batch_size": 8,
"num_workers": 20,
"shuffle": false,
"pin_memory": true,
"perturbation": false,
"perturbation_object": null,
"perturbation_value": 0,
"prompt_kind": "dna",
"prompt_regulator": null,
"prompt_regulator_cache_file": null,
"prompt_celltype": null,
"prompt_celltype_cache_file": null,
"prompt_regulator_cache_pin_memory": false,
"prompt_regulator_cache_limit": 3,
"fasta_file": "/home/chenqianqian/.cache/chrombert/data/other/hg38.fa",
"flank_window": 0
}
[5]:
# Instantiate the dataset from the config and display a single example.
ds = dc.init_dataset()
example_index = 1
ds[example_index]
[5]:
{'input_ids': tensor([9, 9, 9, ..., 6, 6, 6], dtype=torch.int8),
'position_ids': tensor([ 1, 2, 3, ..., 6389, 6390, 6391]),
'region': tensor([ 11, 78139000, 78140000], dtype=torch.int32),
'build_region_index': 342669,
'label': 1,
'seq_raw': 'GATTTGTTCCAAATCAGACAGCGCCAGGTCTGAACCTAGCCAGCTGGGGCTAAGTCAAGTAACAACTGGCGAAACAGAAAGCTTAGCAAAGGCAGGATAGCGACAAACACGACCTAAAGTTTTCTCTTCATACCCAGGGATATCCACACCTTTCTCTCCCGCCCTGACCGACCGCGGGGCCTCCCCGCCCAGCCCCTGGCCGTGCGAGTCCCTTACTATGTGGGGATGAGAAGGCATTTGAGAAGAGTCACCCCGAGCGCCAAAGCCGAAAACCAATTGCCAGTACCCGTGGCAATTGTGAGCGCCGCCATTGCTGCGGCACCGCACGCTTCCCACCAACTTGATCCACATCCGGGATCCCGCGCATGCGGAGAAAGCCCTCTGAAGCCGTGCCCGCTAGCTGCGCGCATGCGGCGAGCGGCGCAGCCAGTCCGGGGACTGCAGTCAGCTATTTAAACCTCCCGCCCACCTTTTCTTTAGACCCGCGTCTCACCCCGGGCCGGAAGGGCTCCTGCGCAGGCGTTTGTAGCCACTTTTAAGTTTTATCAGCTAGTTCATGCTTGCGTTGAAAGAGTGGTCGTTTGCGCTGGGTCATCACTGTGTAGTATTGGGGATACTTAGGTGAGAAAAAAACTTAACGCTAGAGACGTTCACGCACTAGTGGAGAAGCCAGGATTGTTGCCCTAGAGTTACAGTAGATAAAAGTACCTCAGAGAACTGCGGGGGCTCCCAACCTGGACGCTTGCACCGGAGTATTAAATCCAGCTAGAGAATGGCATGTGCAAAGATACAGAGGTGAGAAACATTGTGTTTTTAGAACTCTGAGCGAGGCTCTTGGCTCACCTCCTGCTTGAGCGGAACCCATTCTGGAAGCAGGGTAGAGGCTAGTCCTAACGCTTAGTGTACAAATAGCCTACGGTTCATGTTAAAATAATTCGGATTCTGATTCAGTAGGCCCAACAAACTCGCAGATTGCGTAATGACTGAGGCACATGCAATT',
'seq_alt': 'GATTTGTTCCAAATCAGACAGCGCCAGGTCTGAACCTAGCCAGCTGGGGCTAAGTCAAGTAACAACTGGCGAAACAGAAAGCTTAGCAAAGGCAGGATAGCGACAAACACGACCTAAAGTTTTCTCTTCATACCCAGGGATATCCACACCTTTCTCTCCCGCCCTGACCGACCGCGGGGCCTCCCCGCCCAGCCCCTGGCCGTGCGAGTCCCTTACTATGTGGGGATGAGAAGGCATTTGAGAAGAGTCACCCCGAGCGCCAAAGCCGAAAACCAATTGCCAGTACCCGTGGCAATTGTGAGCGCCGCCATTGCTGCGGCACCGCACGCTTCCCACCAACTTGATCCACATCCGGGATCCCGCGCATGCGGAGAAAGCCCTCTGAAGCCGTGCCCGCTAGCTGCGCGCATGCGGCGAGCGGCGCAGCCAGTCCGGGGACTGCAGTCAGCTATTTAAACCTCCCGCCCACCTTTTCTTTAGACCCGCGTCTCACCCCGGGCTGGAAGGGCTCCTGCGCAGGCGTTTGTAGCCACTTTTAAGTTTTATCAGCTAGTTCATGCTTGCGTTGAAAGAGTGGTCGTTTGCGCTGGGTCATCACTGTGTAGTATTGGGGATACTTAGGTGAGAAAAAAACTTAACGCTAGAGACGTTCACGCACTAGTGGAGAAGCCAGGATTGTTGCCCTAGAGTTACAGTAGATAAAAGTACCTCAGAGAACTGCGGGGGCTCCCAACCTGGACGCTTGCACCGGAGTATTAAATCCAGCTAGAGAATGGCATGTGCAAAGATACAGAGGTGAGAAACATTGTGTTTTTAGAACTCTGAGCGAGGCTCTTGGCTCACCTCCTGCTTGAGCGGAACCCATTCTGGAAGCAGGGTAGAGGCTAGTCCTAACGCTTAGTGTACAAATAGCCTACGGTTCATGTTAAAATAATTCGGATTCTGATTCAGTAGGCCCAACAAACTCGCAGATTGCGTAATGACTGAGGCACATGCAATT'}
Prepare the model¶
The model can be loaded in the same way as for other tasks, except that the kind is set to prompt_dna.
[6]:
# DNABERT-2 weights: a Hugging Face model id, or a local path to the checkpoint.
dnabert2_source = "zhihan1996/DNABERT-2-117M"
mc = chrombert.get_preset_model_config("prompt_dna", dnabert2_ckpt=dnabert2_source)
print(mc)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
{
"genome": "hg38",
"task": "prompt",
"dim_output": 1,
"mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
"dropout": 0.1,
"pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
"finetune_ckpt": null,
"ignore": false,
"ignore_index": [
null,
null
],
"gep_flank_window": 4,
"gep_parallel_embedding": false,
"gep_gradient_checkpoint": false,
"gep_zero_inflation": false,
"prompt_kind": "dna",
"prompt_dim_external": 768,
"dnabert2_ckpt": "zhihan1996/DNABERT-2-117M"
}
[7]:
# Build the model from the config above, then display its layer tree and parameter counts.
model = mc.init_model()
summary(model)
use organisim hg38; max sequence length is 6391
Warning: zhihan1996/DNABERT-2-117M does not exist! Try to use huggingface cached...
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_clean/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/transformers/modeling_utils.py:442: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
return torch.load(checkpoint_file, map_location="cpu")
/home/chenqianqian/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:126: UserWarning: Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).
warnings.warn(
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[7]:
==========================================================================================
Layer (type:depth-idx) Param #
==========================================================================================
ChromBERTPromptDNA --
├─ChromBERT: 1-1 --
│ └─BERTEmbedding: 2-1 --
│ │ └─TokenEmbedding: 3-1 7,680
│ │ └─PositionalEmbedding: 3-2 4,909,056
│ │ └─Dropout: 3-3 --
│ └─ModuleList: 2-2 --
│ │ └─EncoderTransformerBlock: 3-4 6,497,280
│ │ └─EncoderTransformerBlock: 3-5 6,497,280
│ │ └─EncoderTransformerBlock: 3-6 6,497,280
│ │ └─EncoderTransformerBlock: 3-7 6,497,280
│ │ └─EncoderTransformerBlock: 3-8 6,497,280
│ │ └─EncoderTransformerBlock: 3-9 6,497,280
│ │ └─EncoderTransformerBlock: 3-10 6,497,280
│ │ └─EncoderTransformerBlock: 3-11 6,497,280
├─DNABERT2Interface: 1-2 --
│ └─BertModel: 2-3 --
│ │ └─BertEmbeddings: 3-12 3,148,800
│ │ └─BertEncoder: 3-13 113,329,152
│ │ └─BertPooler: 3-14 590,592
├─AdapterExternalEmb: 1-3 --
│ └─ResidualBlock: 2-4 --
│ │ └─Linear: 3-15 590,592
│ │ └─Linear: 3-16 590,592
│ │ └─LayerNorm: 3-17 1,536
│ │ └─Sequential: 3-18 --
│ │ └─Dropout: 3-19 --
│ └─ResidualBlock: 2-5 --
│ │ └─Linear: 3-20 590,592
│ │ └─Linear: 3-21 590,592
│ │ └─LayerNorm: 3-22 1,536
│ │ └─Sequential: 3-23 --
│ │ └─Dropout: 3-24 --
├─GeneralHeader: 1-4 --
│ └─CistromeEmbeddingManager: 2-6 --
│ └─Conv2d: 2-7 769
│ └─ReLU: 2-8 --
│ └─ResidualBlock: 2-9 --
│ │ └─Linear: 3-25 1,099,776
│ │ └─Linear: 3-26 1,049,600
│ │ └─LayerNorm: 3-27 2,048
│ │ └─Linear: 3-28 1,099,776
│ │ └─Dropout: 3-29 --
│ └─ResidualBlock: 2-10 --
│ │ └─Linear: 3-30 787,200
│ │ └─Linear: 3-31 590,592
│ │ └─LayerNorm: 3-32 1,536
│ │ └─Linear: 3-33 787,200
│ │ └─Dropout: 3-34 --
│ └─ResidualBlock: 2-11 --
│ │ └─Linear: 3-35 196,864
│ │ └─Linear: 3-36 65,792
│ │ └─LayerNorm: 3-37 512
│ │ └─Linear: 3-38 196,864
│ │ └─Dropout: 3-39 --
│ └─Linear: 2-12 257
├─PromptHeader: 1-5 --
│ └─Sequential: 2-13 --
│ │ └─ResidualBlock: 3-40 4,724,736
│ │ └─ResidualBlock: 3-41 2,952,960
│ │ └─ResidualBlock: 3-42 1,182,720
│ │ └─ResidualBlock: 3-43 102,720
│ │ └─Linear: 3-44 65
==========================================================================================
Total params: 191,170,947
Trainable params: 191,170,947
Non-trainable params: 0
==========================================================================================
[8]:
# Freeze the pre-trained ChromBERT encoder except for its last two transformer blocks
# (trainable=2; blocks 6 and 7 are reported as trainable in the listing this cell prints),
# and freeze every DNABERT-2 parameter. The adapters and output head stay trainable.
model.freeze_pretrain(trainable=2)
model.dnabert2.freeze()
# Print the per-parameter frozen/trainable status together with the parameter totals.
model.display_trainable_parameters()
{'total_params': 191170947, 'trainable_params': 30201987}
pretrain_model.embedding.token.weight : frozen
pretrain_model.embedding.position.pe.pe.weight : frozen
pretrain_model.transformer_blocks.0.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.0.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.0.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.0.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.0.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.0.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.0.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.1.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.1.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.1.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.1.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.1.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.1.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.1.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.2.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.2.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.2.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.2.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.2.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.2.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.2.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.3.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.3.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.3.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.3.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.3.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.3.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.3.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.4.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.4.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.4.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.4.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.4.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.4.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.4.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.5.attention.Wqkv.weight : frozen
pretrain_model.transformer_blocks.5.attention.Wqkv.bias : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_1.weight : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_1.bias : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_2.weight : frozen
pretrain_model.transformer_blocks.5.feed_forward.w_2.bias : frozen
pretrain_model.transformer_blocks.5.input_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.5.input_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.5.output_sublayer.norm.a_2 : frozen
pretrain_model.transformer_blocks.5.output_sublayer.norm.b_2 : frozen
pretrain_model.transformer_blocks.6.attention.Wqkv.weight : trainable
pretrain_model.transformer_blocks.6.attention.Wqkv.bias : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_1.weight : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_1.bias : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_2.weight : trainable
pretrain_model.transformer_blocks.6.feed_forward.w_2.bias : trainable
pretrain_model.transformer_blocks.6.input_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.6.input_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.6.output_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.6.output_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.7.attention.Wqkv.weight : trainable
pretrain_model.transformer_blocks.7.attention.Wqkv.bias : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_1.weight : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_1.bias : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_2.weight : trainable
pretrain_model.transformer_blocks.7.feed_forward.w_2.bias : trainable
pretrain_model.transformer_blocks.7.input_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.7.input_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.7.output_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.7.output_sublayer.norm.b_2 : trainable
dnabert2.model.embeddings.word_embeddings.weight : frozen
dnabert2.model.embeddings.token_type_embeddings.weight : frozen
dnabert2.model.embeddings.LayerNorm.weight : frozen
dnabert2.model.embeddings.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.0.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.0.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.0.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.0.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.0.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.0.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.0.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.0.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.0.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.0.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.0.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.1.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.1.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.1.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.1.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.1.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.1.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.1.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.1.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.1.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.1.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.1.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.2.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.2.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.2.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.2.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.2.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.2.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.2.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.2.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.2.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.2.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.2.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.3.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.3.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.3.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.3.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.3.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.3.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.3.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.3.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.3.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.3.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.3.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.4.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.4.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.4.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.4.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.4.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.4.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.4.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.4.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.4.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.4.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.4.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.5.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.5.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.5.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.5.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.5.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.5.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.5.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.5.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.5.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.5.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.5.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.6.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.6.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.6.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.6.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.6.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.6.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.6.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.6.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.6.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.6.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.6.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.7.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.7.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.7.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.7.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.7.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.7.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.7.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.7.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.7.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.7.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.7.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.8.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.8.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.8.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.8.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.8.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.8.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.8.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.8.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.8.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.8.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.8.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.9.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.9.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.9.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.9.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.9.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.9.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.9.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.9.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.9.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.9.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.9.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.10.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.10.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.10.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.10.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.10.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.10.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.10.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.10.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.10.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.10.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.10.mlp.layernorm.bias : frozen
dnabert2.model.encoder.layer.11.attention.self.Wqkv.weight : frozen
dnabert2.model.encoder.layer.11.attention.self.Wqkv.bias : frozen
dnabert2.model.encoder.layer.11.attention.output.dense.weight : frozen
dnabert2.model.encoder.layer.11.attention.output.dense.bias : frozen
dnabert2.model.encoder.layer.11.attention.output.LayerNorm.weight : frozen
dnabert2.model.encoder.layer.11.attention.output.LayerNorm.bias : frozen
dnabert2.model.encoder.layer.11.mlp.gated_layers.weight : frozen
dnabert2.model.encoder.layer.11.mlp.wo.weight : frozen
dnabert2.model.encoder.layer.11.mlp.wo.bias : frozen
dnabert2.model.encoder.layer.11.mlp.layernorm.weight : frozen
dnabert2.model.encoder.layer.11.mlp.layernorm.bias : frozen
dnabert2.model.pooler.dense.weight : frozen
dnabert2.model.pooler.dense.bias : frozen
adapter_dna_emb.fc1.fc1.weight : trainable
adapter_dna_emb.fc1.fc1.bias : trainable
adapter_dna_emb.fc1.fc2.weight : trainable
adapter_dna_emb.fc1.fc2.bias : trainable
adapter_dna_emb.fc1.norm.a_2 : trainable
adapter_dna_emb.fc1.norm.b_2 : trainable
adapter_dna_emb.fc2.fc1.weight : trainable
adapter_dna_emb.fc2.fc1.bias : trainable
adapter_dna_emb.fc2.fc2.weight : trainable
adapter_dna_emb.fc2.fc2.bias : trainable
adapter_dna_emb.fc2.norm.a_2 : trainable
adapter_dna_emb.fc2.norm.b_2 : trainable
adapter_chrombert.conv.weight : trainable
adapter_chrombert.conv.bias : trainable
adapter_chrombert.res1.fc1.weight : trainable
adapter_chrombert.res1.fc1.bias : trainable
adapter_chrombert.res1.fc2.weight : trainable
adapter_chrombert.res1.fc2.bias : trainable
adapter_chrombert.res1.norm.a_2 : trainable
adapter_chrombert.res1.norm.b_2 : trainable
adapter_chrombert.res1.shortcut.weight : trainable
adapter_chrombert.res1.shortcut.bias : trainable
adapter_chrombert.res2.fc1.weight : trainable
adapter_chrombert.res2.fc1.bias : trainable
adapter_chrombert.res2.fc2.weight : trainable
adapter_chrombert.res2.fc2.bias : trainable
adapter_chrombert.res2.norm.a_2 : trainable
adapter_chrombert.res2.norm.b_2 : trainable
adapter_chrombert.res2.shortcut.weight : trainable
adapter_chrombert.res2.shortcut.bias : trainable
adapter_chrombert.res3.fc1.weight : trainable
adapter_chrombert.res3.fc1.bias : trainable
adapter_chrombert.res3.fc2.weight : trainable
adapter_chrombert.res3.fc2.bias : trainable
adapter_chrombert.res3.norm.a_2 : trainable
adapter_chrombert.res3.norm.b_2 : trainable
adapter_chrombert.res3.shortcut.weight : trainable
adapter_chrombert.res3.shortcut.bias : trainable
adapter_chrombert.fc.weight : trainable
adapter_chrombert.fc.bias : trainable
head_output.fcs.0.fc1.weight : trainable
head_output.fcs.0.fc1.bias : trainable
head_output.fcs.0.fc2.weight : trainable
head_output.fcs.0.fc2.bias : trainable
head_output.fcs.0.norm.a_2 : trainable
head_output.fcs.0.norm.b_2 : trainable
head_output.fcs.1.fc1.weight : trainable
head_output.fcs.1.fc1.bias : trainable
head_output.fcs.1.fc2.weight : trainable
head_output.fcs.1.fc2.bias : trainable
head_output.fcs.1.norm.a_2 : trainable
head_output.fcs.1.norm.b_2 : trainable
head_output.fcs.1.shortcut.weight : trainable
head_output.fcs.1.shortcut.bias : trainable
head_output.fcs.2.fc1.weight : trainable
head_output.fcs.2.fc1.bias : trainable
head_output.fcs.2.fc2.weight : trainable
head_output.fcs.2.fc2.bias : trainable
head_output.fcs.2.norm.a_2 : trainable
head_output.fcs.2.norm.b_2 : trainable
head_output.fcs.3.fc1.weight : trainable
head_output.fcs.3.fc1.bias : trainable
head_output.fcs.3.fc2.weight : trainable
head_output.fcs.3.fc2.bias : trainable
head_output.fcs.3.norm.a_2 : trainable
head_output.fcs.3.norm.b_2 : trainable
head_output.fcs.3.shortcut.weight : trainable
head_output.fcs.3.shortcut.bias : trainable
head_output.fcs.4.weight : trainable
head_output.fcs.4.bias : trainable
[8]:
{'total_params': 191170947, 'trainable_params': 30201987}
Fine-tune¶
This task has fewer training samples compared to others, so we use a reduced number of training steps. To simplify the process, the model is trained directly without using PyTorch Lightning.
[9]:
# Build the train/test datasets from the TSV splits written to `odir` earlier.
# The dataset config holds `supervised_file` as a path string (as in the printed config above),
# so pass the files on disk rather than the in-memory DataFrames.
train_tsv = str((odir / "train.tsv").resolve())
test_tsv = str((odir / "test.tsv").resolve())
dc_train = chrombert.get_preset_dataset_config("prompt_dna", supervised_file=train_tsv)
dc_test = chrombert.get_preset_dataset_config("prompt_dna", supervised_file=test_tsv)
ds_train = dc_train.init_dataset()
ds_test = dc_test.init_dataset()
# Only the training loader is shuffled; the test loader keeps file order.
dl_train = dc_train.init_dataloader(batch_size=2, shuffle=True, num_workers=4)
dl_test = dc_test.init_dataloader(batch_size=2, shuffle=False, num_workers=4)
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
[10]:
from transformers import get_linear_schedule_with_warmup
from torch import nn
def train(m, dl, lr = 5e-5, grad_accumulation_steps=4, min_epochs = 2, max_epochs = 5, max_steps = 200):
    """Fine-tune ``m`` on ``dl`` with BCE-with-logits loss and gradient accumulation.

    Args:
        m: Model invoked as ``m(batch)``; its output is flattened to one logit
            per sample. The batch tensors are moved with ``.cuda()``, so the
            model is expected to already live on the GPU.
        dl: Sized iterable of dict-like batches holding at least the tensors
            "input_ids", "position_ids" and "label".
        lr: Initial AdamW learning rate; decayed linearly with no warmup.
        grad_accumulation_steps: Number of batches accumulated per optimizer step.
        min_epochs: Early stopping on ``max_steps`` is only allowed once the
            epoch index has reached this value.
        max_epochs: Hard upper bound on the number of epochs.
        max_steps: Once more than this many optimizer steps have been taken
            (and ``min_epochs`` is satisfied), training returns immediately.

    Returns:
        The same model object ``m`` (trained in place).
    """
    num_training_steps = len(dl) * max_epochs // grad_accumulation_steps
    optimizer = torch.optim.AdamW(m.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
    loss_f = nn.BCEWithLogitsLoss()  # default reduction is already the batch mean
    keys_to_cuda = ["input_ids", "position_ids", "label"]
    total_steps = 0
    for e in range(max_epochs):
        m.train()
        # When len(dl) is not a multiple of grad_accumulation_steps, the tail of the
        # previous epoch leaves gradients behind; drop them so every optimizer step
        # averages exactly grad_accumulation_steps batches.
        optimizer.zero_grad()
        epoch_loss = 0.0
        n_batches = 0
        for i, batch in enumerate(tqdm(dl)):
            for k in keys_to_cuda:
                batch[k] = batch[k].cuda()
            logits = m(batch).view(-1)
            loss = loss_f(logits, batch["label"].to(torch.float))
            # Scale only the value that is back-propagated; keep `loss` unscaled for reporting.
            (loss / grad_accumulation_steps).backward()
            epoch_loss = epoch_loss + loss.detach()
            n_batches += 1
            if (i+1) % grad_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                total_steps += 1
                if total_steps > max_steps and e >= min_epochs:
                    return m
        # Report the mean unscaled loss over the epoch (0.0 for an empty loader).
        print(f"epoch {e} loss {float(epoch_loss) / max(n_batches, 1)}")
    return m
[11]:
# Move the model to the GPU and fine-tune it on the training loader with the default hyper-parameters.
model_tuned = train(m=model.cuda(), dl=dl_train)
0%| | 0/461 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 461/461 [01:30<00:00, 5.10it/s]
epoch 0 loss 0.06694883108139038
100%|██████████| 461/461 [01:26<00:00, 5.32it/s]
epoch 1 loss 0.2063915878534317
1%| | 3/461 [00:01<02:42, 2.82it/s]
Evaluation¶
Typically, we save the checkpoint and reload it for evaluation later. Alternatively, the fine-tuned model can be evaluated directly if a small amount of randomness is acceptable.
[12]:
# Write the fine-tuned weights into the output directory.
ckpt_path = odir / "model.ckpt"
model_tuned.save_ckpt(ckpt_path)
[15]:
# Rebuild the model from the saved fine-tuned checkpoint with dropout set to 0, then move it to the GPU.
model = mc.init_model(finetune_ckpt=odir / "model.ckpt", dropout=0)
model = model.cuda()
# Last expression of the cell: summary of total / trainable parameter counts.
model.display_trainable_parameters(verbose=False)
use organisim hg38; max sequence length is 6391
Warning: zhihan1996/DNABERT-2-117M does not exist! Try to use huggingface cached...
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_clean/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
/home/chenqianqian/.conda/envs/chrombert_clean/lib/python3.9/site-packages/transformers/modeling_utils.py:442: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
return torch.load(checkpoint_file, map_location="cpu")
/home/chenqianqian/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:126: UserWarning: Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).
warnings.warn(
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Loading checkpoint from tmp_eqtl/model.ckpt
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_clean/chrombert/finetune/model/basic_model.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
new_state = torch.load(ckpt)
Loaded 290/290 parameters
{'total_params': 191170947, 'trainable_params': 191170947}
[15]:
{'total_params': 191170947, 'trainable_params': 191170947}
[16]:
# Score the held-out split in inference mode, then compute ROC-AUC and average precision.
model = model.cuda()
model.eval()
keys_to_cuda = ["input_ids", "position_ids", "label"]
probs, labels = [], []
with torch.no_grad():
    for batch in tqdm(dl_test):
        for k in keys_to_cuda:
            batch[k] = batch[k].cuda()
        logits = model(batch).view(-1)
        batch_probs = logits.sigmoid().float().cpu().numpy()
        batch_labels = batch["label"].cpu().numpy()
        probs.append(batch_probs)
        labels.append(batch_labels)

# Flatten the per-batch arrays into one vector each before scoring.
probs = np.concatenate(probs)
labels = np.concatenate(labels)
auc = metrics.roc_auc_score(labels, probs)
aupr = metrics.average_precision_score(labels, probs)
print(f"test auc {auc}, aupr {aupr}")
0%| | 0/307 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 307/307 [00:38<00:00, 8.01it/s]
test auc 0.8395420949791861, aupr 0.8312050511192879
Identification of differential regulators¶
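We can infer regulators that differ between causal and non-causal variants by comparing their regulator embeddings: for each regulator, we compute the cosine distance between its mean embedding over causal variants and its mean embedding over non-causal variants. Notably, DNase-seq often shows a large difference in these distances.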
[17]:
# Obtain the embedding manager from the fine-tuned model; it is called on batches
# below to produce regulator embeddings.
model_emb = model.get_embedding_manager()
# Bare last expression so the notebook renders the object's repr as the cell output.
model_emb
[17]:
ChromBERTEmbedding(
(pretrain_model): ChromBERT(
(embedding): BERTEmbedding(
(token): TokenEmbedding(10, 768, padding_idx=0)
(position): PositionalEmbedding(
(pe): PositionalEmbeddingTrainable(
(pe): Embedding(6392, 768)
)
)
(dropout): Dropout(p=0, inplace=False)
)
(transformer_blocks): ModuleList(
(0-7): 8 x EncoderTransformerBlock(
(attention): SelfAttentionFlashMHA(
(Wqkv): Linear(in_features=768, out_features=2304, bias=True)
)
(feed_forward): PositionwiseFeedForward(
(w_1): Linear(in_features=768, out_features=3072, bias=True)
(w_2): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0, inplace=False)
(activation): GELU()
)
(input_sublayer): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0, inplace=False)
)
(output_sublayer): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0, inplace=False)
)
(dropout): Dropout(p=0, inplace=False)
)
)
)
(CistromeEmbeddingManager): CistromeEmbeddingManager()
)
[18]:
# Embed every variant of the full eQTL table and average the regulator embeddings
# separately over causal (label == 1) and non-causal (label == 0) variants.
dc = chrombert.get_preset_dataset_config("prompt_dna", supervised_file = table_eqtls)
dl = dc.init_dataloader(batch_size = 2, shuffle=False, num_workers=4)
keys_to_cuda = ["input_ids", "position_ids", "label"]  # loop-invariant, built once
labels = []
embeddings = []
# Pure inference: disable autograd so no graph is tracked, matching the evaluation cell.
with torch.no_grad():
    for batch in tqdm(dl):
        for k in keys_to_cuda:
            batch[k] = batch[k].cuda()
        labels.append(batch["label"].cpu().numpy())
        embeddings.append(model_emb(batch).float().cpu().numpy())
labels = np.concatenate(labels)
embeddings = np.concatenate(embeddings)
# Mean over the variant axis leaves one vector per regulator for each group.
emb_causal = embeddings[labels == 1].mean(axis=0)
emb_noncausal = embeddings[labels == 0].mean(axis=0)
emb_causal.shape, emb_noncausal.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
100%|██████████| 768/768 [01:27<00:00, 8.82it/s]
[18]:
((1073, 768), (1073, 768))
[19]:
# Cosine similarity between each regulator's mean causal and mean non-causal embedding
# sits on the diagonal of the pairwise similarity matrix; the shift is 1 minus that value.
pairwise_similarity = sklearn.metrics.pairwise.cosine_similarity(emb_causal, emb_noncausal)
regulator_shift = 1 - pairwise_similarity.diagonal()
shift_columns = {
    "regulator": model_emb.list_regulator,
    "shift": regulator_shift,
}
df_shift = pd.DataFrame(shift_columns).sort_values("shift", ascending=False, ignore_index=True)
# The regulator embeddings of "input" were set to a zero vector, so its cosine similarity is not meaningful
df_shift.head(20)
[19]:
| | regulator | shift |
|---|---|---|
| 0 | input | 1.000000 |
| 1 | rbl1 | 0.178612 |
| 2 | dnase | 0.161799 |
| 3 | hcfc1r1 | 0.146511 |
| 4 | sox4 | 0.141581 |
| 5 | h3k4me3 | 0.140027 |
| 6 | kdm2b | 0.131659 |
| 7 | lmnb1 | 0.130848 |
| 8 | h3k4me2 | 0.123314 |
| 9 | hexim1 | 0.121096 |
| 10 | cnot3 | 0.116232 |
| 11 | maz | 0.116230 |
| 12 | h1.4 | 0.115697 |
| 13 | kdm4c | 0.112475 |
| 14 | zfx | 0.111389 |
| 15 | cfp1 | 0.111306 |
| 16 | ints11 | 0.110976 |
| 17 | kdm5b | 0.107938 |
| 18 | cdca2 | 0.106124 |
| 19 | fgfr1 | 0.105282 |
[ ]: