Example for key regulators inference during cell state transition: transcriptome
To comprehensively infer key regulators during cell state transitions, it is crucial to integrate analyses of chromatin accessibility and the transcriptome. In this tutorial, we will demonstrate how to use ChromBERT to infer key regulators involved in a specific transdifferentiation process (fibroblast to myoblast) through transcriptome analysis.
[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import chrombert
import pandas as pd
import numpy as np
from torchinfo import summary
import subprocess
import torch
import lightning.pytorch as pl
import glob
from tqdm import tqdm
import torchmetrics as tm
base_dir = os.path.expanduser("~/.cache/chrombert/data")  # change this to your ChromBERT data directory if it is stored elsewhere
Preprocess dataset
This section walks you through preparing raw transcriptome data, including TSS and TPM values for each gene, and transforming it into the format required by ChromBERT.
To identify key regulators, we will fine-tune ChromBERT to predict log1p-transformed gene expression fold changes during transdifferentiation. This requires careful preparation of the transformed data. Additionally, to analyze shifts in regulator embeddings, we need to identify and prepare datasets for both upregulated and unchanged genes.
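As a quick check of the transformation: the log1p fold change equals log((TPM_myoblast + 1) / (TPM_fibroblast + 1)), so the pseudocount of 1 keeps genes with zero expression finite. The standard-library sketch below applies it to a few genes from the demo expression tables. `log1p_fold_change` is a hypothetical helper for illustration, not part of ChromBERT.

```python
import math


def log1p_fold_change(tpm_target: float, tpm_source: float) -> float:
    """Hypothetical helper: log1p(target) - log1p(source).

    Mathematically equal to log((target + 1) / (source + 1)); the +1
    pseudocount keeps genes with 0 TPM finite.
    """
    return math.log1p(tpm_target) - math.log1p(tpm_source)


# (gene_id, fibroblast TPM, myoblast TPM) taken from the demo tables
genes = [
    ("ENSG00000121410", 12.774133, 22.236894),
    ("ENSG00000268895", 2.939181, 9.317134),
    ("ENSG00000148584", 0.0, 0.0),
    ("ENSG00000245105", 0.226091, 0.124228),
]

fold_changes = {gene: log1p_fold_change(myo, fib) for gene, fib, myo in genes}
for gene, fc in fold_changes.items():
    print(f"{gene}\t{fc:.6f}")
```

Positive values mean higher expression in myoblasts; a gene with 0 TPM in both cell types gets a label of exactly 0.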
[2]:
# We provide tables of gene expression data for fibroblast and myoblast.
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'
fibroblast_exp = pd.read_csv(f'{gep_dir}/fibroblast_expression.csv')
myoblast_exp = pd.read_csv(f'{gep_dir}/myoblast_expression.csv')
myoblast_exp.head(), fibroblast_exp.head()
[2]:
( chrom tss gene_id tpm
0 chr19 58353492 ENSG00000121410 22.236894
1 chr19 58347718 ENSG00000268895 9.317134
2 chr10 50885675 ENSG00000148584 0.000000
3 chr12 9116229 ENSG00000175899 0.993828
4 chr12 9065163 ENSG00000245105 0.124228,
chrom tss gene_id tpm
0 chr19 58353492 ENSG00000121410 12.774133
1 chr19 58347718 ENSG00000268895 2.939181
2 chr10 50885675 ENSG00000148584 0.000000
3 chr12 9116229 ENSG00000175899 0.226091
4 chr12 9065163 ENSG00000245105 0.226091)
Prepare the log1p-transformed gene expression fold changes data
[3]:
from chrombert.scripts.chrombert_make_dataset import get_regions
# Merge the two tables and compute the log1p-transformed fold change (myoblast vs. fibroblast).
merge_exp = pd.merge(fibroblast_exp, myoblast_exp, on=['chrom', 'tss', 'gene_id'], suffixes=['_fibroblast', '_myoblast'])
merge_exp['fold_change'] = np.log1p(merge_exp['tpm_myoblast']) - np.log1p(merge_exp['tpm_fibroblast'])
# Assign each TSS to the 1-kb bin that contains it.
merge_exp['start'] = merge_exp['tss'] // 1000 * 1000
merge_exp['end'] = merge_exp['start'] + 1000
foldchange_exp = merge_exp[['chrom', 'start', 'end', 'tss', 'gene_id', 'fold_change']]
# Align genomic coordinates to the predefined 1-kb bins; genes whose bin is absent are dropped (inner join).
chrom_regions = get_regions(base_dir, genome='hg38', high_resolution=False)  # 1-kb resolution
chrom_regions_df = pd.read_csv(chrom_regions, sep='\t', names=['chrom', 'start', 'end', 'build_region_index'])
merge_region = pd.merge(foldchange_exp, chrom_regions_df, on=['chrom', 'start', 'end'], how='inner')[['chrom', 'start', 'end', 'build_region_index', 'fold_change', 'tss', 'gene_id']]
gep_df = merge_region.rename(columns={'fold_change': 'label'})
gep_df.to_csv(f'{gep_dir}/fibroblast_to_myoblast_expression_changes.csv', index=False)
gep_df.head()  # `label` is the log1p-transformed gene expression fold change
[3]:
| | chrom | start | end | build_region_index | label | tss | gene_id |
|---|---|---|---|---|---|---|---|
| 0 | chr19 | 58353000 | 58354000 | 917950 | 0.522949 | 58353492 | ENSG00000121410 |
| 1 | chr19 | 58347000 | 58348000 | 917944 | 0.962833 | 58347718 | ENSG00000268895 |
| 2 | chr10 | 50885000 | 50886000 | 221904 | 0.000000 | 50885675 | ENSG00000148584 |
| 3 | chr12 | 9116000 | 9117000 | 393001 | 0.486225 | 9116229 | ENSG00000175899 |
| 4 | chr12 | 9065000 | 9066000 | 392961 | -0.086734 | 9065163 | ENSG00000245105 |
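The binning step floors each TSS to the start of its 1-kb bin and then keeps only genes whose bin appears in ChromBERT's predefined region table (an inner join). The standard-library sketch below mirrors that logic. `tss_to_bin`, `assign_regions`, and the tiny `region_index` lookup (three rows copied from the output above) are illustrative only; in the tutorial, the real table comes from `get_regions`.

```python
def tss_to_bin(tss: int, bin_size: int = 1000) -> tuple:
    """Return (start, end) of the bin_size-bp bin containing a TSS."""
    start = (tss // bin_size) * bin_size
    return start, start + bin_size


def assign_regions(genes, region_index, bin_size=1000):
    """Mimic an inner join: genes whose bin is not in region_index are dropped."""
    rows = []
    for chrom, tss, gene_id in genes:
        start, end = tss_to_bin(tss, bin_size)
        key = (chrom, start, end)
        if key in region_index:
            rows.append((chrom, start, end, region_index[key], tss, gene_id))
    return rows


# Three (chrom, start, end) -> build_region_index rows taken from the table above
region_index = {
    ("chr19", 58353000, 58354000): 917950,
    ("chr19", 58347000, 58348000): 917944,
    ("chr10", 50885000, 50886000): 221904,
}
genes = [
    ("chr19", 58353492, "ENSG00000121410"),
    ("chr10", 50885675, "ENSG00000148584"),
    ("chr12", 9116229, "ENSG00000175899"),  # bin not in this toy lookup -> dropped
]
for row in assign_regions(genes, region_index):
    print(row)
```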
Prepare the fine-tuning data
We split the data into training, test, and validation sets at an 8:1:1 ratio, then downsample each split (80, 50, and 20 regions, respectively) so that the fine-tuning run finishes quickly.
[4]:
train_data = gep_df.sample(frac=0.8,random_state=55)
test_data = gep_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = gep_df.drop(train_data.index).drop(test_data.index)
train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=50,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)
train_data_sample.to_csv(f'{gep_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{gep_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{gep_dir}/valid_sample.csv',index=False)
train_data_sample.head()
[4]:
| | chrom | start | end | build_region_index | label | tss | gene_id |
|---|---|---|---|---|---|---|---|
| 4496 | chr6 | 138795000 | 138796000 | 1719247 | 0.000000 | 138795911 | ENSG00000203734 |
| 13237 | chr8 | 96261000 | 96262000 | 1933979 | -0.351699 | 96261902 | ENSG00000156471 |
| 17079 | chr22 | 29555000 | 29556000 | 1186796 | -0.565257 | 29555216 | ENSG00000100296 |
| 4118 | chr8 | 1622000 | 1623000 | 1866390 | 0.000000 | 1622417 | ENSG00000253267 |
| 13482 | chr7 | 66682000 | 66683000 | 1792879 | 0.402664 | 66682164 | ENSG00000154710 |
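The split draws 80% of the rows for training and then halves the remainder into test and validation sets, so the three sets are disjoint. The standard-library sketch below reproduces that logic on plain row indices. `split_811` is a hypothetical helper; it does not reproduce pandas' sampling order for `random_state=55`, only the proportions and the disjointness.

```python
import random


def split_811(indices, seed=55):
    """Shuffle row indices and split them 8:1:1 into train, test, and validation."""
    rng = random.Random(seed)
    shuffled = list(indices)
    rng.shuffle(shuffled)
    n_train = round(0.8 * len(shuffled))
    train, rest = shuffled[:n_train], shuffled[n_train:]
    n_test = round(0.5 * len(rest))  # half of the remaining 20%
    return train, rest[:n_test], rest[n_test:]


train, test, valid = split_811(range(1000))
print(len(train), len(test), len(valid))
```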
Prepare the upregulated genes and unchanged genes
Here we identify upregulated genes (label > 1) and unchanged genes (-0.5 < label < 0.5) from the log1p-transformed gene expression fold changes. To save time, we then downsample each group to 100 genes.
[5]:
up_data = gep_df[gep_df['label']>1]
nochange_data = gep_df[(gep_df['label']>-0.5) & (gep_df['label']<0.5)]
up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)
up_data_sample.to_csv(f'{gep_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{gep_dir}/nochange_data_sample.csv',index=False)
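Because the label is a log ratio, the thresholds have a direct interpretation on the (TPM + 1) scale: label > 1 means expression rises by more than a factor of e ≈ 2.72, and -0.5 < label < 0.5 means the (TPM + 1) ratio stays between about 0.61 and 1.65. Genes in between (for example, 0.5 ≤ label ≤ 1) and down-regulated genes fall into neither group. A small sketch with a hypothetical `classify_label` helper:

```python
import math


def classify_label(label: float, up_threshold: float = 1.0, flat_threshold: float = 0.5) -> str:
    """Hypothetical helper mirroring the filters above: 'up', 'unchanged', or 'other'."""
    if label > up_threshold:
        return "up"
    if -flat_threshold < label < flat_threshold:
        return "unchanged"
    return "other"


# The same cut-offs expressed as ratios of (TPM + 1)
print(f"up: ratio > {math.exp(1.0):.2f}")
print(f"unchanged: {math.exp(-0.5):.2f} < ratio < {math.exp(0.5):.2f}")

for label in (1.2, 0.962833, 0.0, -0.086734, -0.6):
    print(label, classify_label(label))
```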
Fine-tune
This section shows how to fine-tune ChromBERT to predict genome-wide changes in the transcriptome. Because each TSS is modeled together with its nearby flanking regions, the setup differs from basic fine-tuning in two places:
- Dataset configuration: use the `multi_flank_window` preset and set the `flank_window` parameter to 4.
- Model instantiation: use the `gep` preset and set the `gep_flank_window` parameter to 4.
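With `flank_window=4`, each gene is represented by 2 × 4 + 1 = 9 regions: its TSS bin plus four regions on each side. Assuming those flanking regions are the adjacent 1-kb bins (an assumption for illustration; the actual selection happens inside ChromBERT's `MultiFlankwindowDataset`), the genomic span covered per gene can be computed as below. `flank_window_span` is a hypothetical helper.

```python
def flank_window_span(bin_start: int, flank_window: int = 4, bin_size: int = 1000):
    """Return (span_start, span_end, n_bins) for a TSS bin and its flanking bins.

    Assumes the flanking regions are the `flank_window` adjacent bins on each side.
    """
    span_start = bin_start - flank_window * bin_size
    span_end = bin_start + (flank_window + 1) * bin_size
    n_bins = 2 * flank_window + 1
    return span_start, span_end, n_bins


# TSS bin chr19:58353000-58354000 (gene ENSG00000121410)
start, end, n_bins = flank_window_span(58353000)
print(start, end, n_bins, end - start)
```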
Configure dataset and data module
[6]:
# Use the `multi_flank_window` preset dataset config and set `flank_window` to 4:
# each TSS region is modeled together with its +/- 4 nearest genomic regions.
dataset_config = chrombert.get_preset_dataset_config(
"multi_flank_window",
supervised_file = None,
batch_size = 2,
num_workers = 4,
flank_window=4
)
dataset_config
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[6]:
DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'MultiFlankwindowDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 2, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'prompt_regulator_cache_pin_memory': False, 'prompt_regulator_cache_limit': 3, 'fasta_file': None, 'flank_window': 4})
[7]:
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'
[8]:
# We use the `LitChromBERTFTDataModule` to create a data module for fine-tuning.
data_module = chrombert.LitChromBERTFTDataModule(
config = dataset_config,
train_params = {'supervised_file': f'{gep_dir}/train_sample.csv'},
val_params = {'supervised_file':f'{gep_dir}/valid_sample.csv'},
test_params = {'supervised_file':f'{gep_dir}/test_sample.csv'}
)
data_module.setup()
Configure and instantiate the model
[9]:
# We use the `gep` preset model config and also set the `gep_flank_window` parameter to four.
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4)
model_config
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
[9]:
ChromBERTFTConfig:
{
"genome": "hg38",
"task": "gep",
"dim_output": 1,
"mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
"dropout": 0.1,
"pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
"finetune_ckpt": null,
"ignore": false,
"ignore_index": [
null,
null
],
"gep_flank_window": 4,
"gep_parallel_embedding": false,
"gep_gradient_checkpoint": false,
"gep_zero_inflation": false,
"prompt_kind": "cistrome",
"prompt_dim_external": 512,
"dnabert2_ckpt": null
}
[10]:
model = model_config.init_model()
model.freeze_pretrain(2)  # keep 2 transformer blocks trainable; freeze the embedding layer and the other 6 blocks
summary(model)
use organisim hg38; max sequence length is 6391
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
[10]:
=====================================================================================
Layer (type:depth-idx) Param #
=====================================================================================
ChromBERTGEP --
├─PoolFlankWindow: 1-1 --
│ └─ChromBERT: 2-1 --
│ │ └─BERTEmbedding: 3-1 (4,916,736)
│ │ └─ModuleList: 3-2 51,978,240
├─GeneralHeader: 1-2 --
│ └─CistromeEmbeddingManager: 2-2 --
│ └─Conv2d: 2-3 769
│ └─ReLU: 2-4 --
│ └─ResidualBlock: 2-5 --
│ │ └─Linear: 3-3 1,099,776
│ │ └─Linear: 3-4 1,049,600
│ │ └─LayerNorm: 3-5 2,048
│ │ └─Linear: 3-6 1,099,776
│ │ └─Dropout: 3-7 --
│ └─ResidualBlock: 2-6 --
│ │ └─Linear: 3-8 787,200
│ │ └─Linear: 3-9 590,592
│ │ └─LayerNorm: 3-10 1,536
│ │ └─Linear: 3-11 787,200
│ │ └─Dropout: 3-12 --
│ └─ResidualBlock: 2-7 --
│ │ └─Linear: 3-13 196,864
│ │ └─Linear: 3-14 65,792
│ │ └─LayerNorm: 3-15 512
│ │ └─Linear: 3-16 196,864
│ │ └─Dropout: 3-17 --
│ └─Linear: 2-8 257
=====================================================================================
Total params: 62,773,762
Trainable params: 18,873,346
Non-trainable params: 43,900,416
=====================================================================================
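The parameter counts in the summary are consistent with the freezing setting. Subtracting the header's parameters from the trainable total, and the frozen embedding from the non-trainable total, leaves the trainable and frozen parts of the transformer stack. The arithmetic below uses only numbers copied from the summary; interpreting the 1:3 ratio as 2 trainable and 6 frozen blocks assumes all transformer blocks have the same size, which the summary does not show directly.

```python
# Parameter counts copied from the summary(model) output above
EMBEDDING = 4_916_736            # BERTEmbedding (frozen)
ENCODER = 51_978_240             # ModuleList of transformer blocks
HEADER = [                       # GeneralHeader layers with parameters
    769,                                     # Conv2d
    1_099_776, 1_049_600, 2_048, 1_099_776,  # ResidualBlock 2-5
    787_200, 590_592, 1_536, 787_200,        # ResidualBlock 2-6
    196_864, 65_792, 512, 196_864,           # ResidualBlock 2-7
    257,                                     # final Linear
]
TRAINABLE = 18_873_346
NON_TRAINABLE = 43_900_416

header_total = sum(HEADER)
trainable_encoder = TRAINABLE - header_total   # encoder params left trainable
frozen_encoder = NON_TRAINABLE - EMBEDDING     # encoder params that were frozen
ratio = frozen_encoder / trainable_encoder

print(f"header: {header_total:,}")
print(f"trainable encoder: {trainable_encoder:,}")
print(f"frozen encoder: {frozen_encoder:,}")
print(f"frozen / trainable: {ratio}")
```

The frozen part of the encoder is exactly three times the trainable part, matching 6 frozen and 2 trainable equal-sized blocks.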
Configure training parameters and fine-tune
The model is fine-tuned using PyTorch Lightning with a straightforward parameter configuration. To save time, fine-tuning is performed on a smaller dataset.
Note: Due to the stochastic nature of the tuning process, results may vary. For improved performance, consider increasing the number of epochs and expanding the dataset size.
[11]:
train_config = chrombert.finetune.TrainConfig(
kind='regression',
loss='rmse',
max_epochs=2,
accumulate_grad_batches=2,
val_check_interval=2,
limit_val_batches=10,
tag='gep',
checkpoint_mode='max',
checkpoint_metric='pcc'
)
pl_module = train_config.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs) # noqa: B028
[11]:
chrombert.finetune.train.pl_module.RegressionPLModule
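With the demo settings, the training schedule is small enough to work out by hand. The sketch below assumes 80 training and 20 validation samples (the downsampled CSVs), `batch_size=2` from the dataset config, no dropped last batch, and PyTorch Lightning's convention that an integer `val_check_interval` counts training batches. These are estimates, not output from the run.

```python
import math

n_train_samples = 80          # train_sample.csv
n_valid_samples = 20          # valid_sample.csv
batch_size = 2                # dataset_config batch_size
accumulate_grad_batches = 2   # TrainConfig
max_epochs = 2
val_check_interval = 2        # validate every 2 training batches
limit_val_batches = 10

batches_per_epoch = math.ceil(n_train_samples / batch_size)
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
effective_batch_size = batch_size * accumulate_grad_batches
total_optimizer_steps = optimizer_steps_per_epoch * max_epochs
validation_runs_per_epoch = batches_per_epoch // val_check_interval
val_batches_per_run = min(limit_val_batches, math.ceil(n_valid_samples / batch_size))

print(batches_per_epoch, optimizer_steps_per_epoch, effective_batch_size,
      total_optimizer_steps, validation_runs_per_epoch, val_batches_per_run)
```

Validating every two batches is only practical because the validation set is tiny; with the full dataset, a much larger `val_check_interval` would be more reasonable.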
[12]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
max_epochs=train_config.max_epochs,
log_every_n_steps=1,
limit_val_batches = train_config.limit_val_batches,
val_check_interval = train_config.val_check_interval,
accelerator="gpu",
accumulate_grad_batches= train_config.accumulate_grad_batches,
fast_dev_run=False,
precision="bf16-mixed",
strategy="auto",
callbacks=[
pl.callbacks.LearningRateMonitor(),
callback_ckpt,
],
logger=pl.loggers.TensorBoardLogger("lightning_logs", name='gep'))
trainer.fit(pl_module,data_module)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/gep
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
| Name | Type | Params | Mode
-----------------------------------------------
0 | model | ChromBERTGEP | 62.8 M | train
-----------------------------------------------
18.9 M Trainable params
43.9 M Non-trainable params
62.8 M Total params
251.095 Total estimated model params size (MB)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
`Trainer.fit` stopped: `max_epochs=2` reached.
Evaluate the fine-tuned model
ChromBERT has been successfully fine-tuned! You can evaluate the model directly through `pl_module.model`. Note, however, that with the flash-attention settings, the attention dropout is not switched off by `model.eval()`, which can introduce some randomness into the output.
To obtain consistent results, we reload the saved fine-tuned checkpoint into a freshly configured model with `dropout=0` for evaluation.
In addition, to save time, this tutorial evaluates only on the downsampled test data.
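The evaluation cells report five regression metrics through `torchmetrics`. For reference, the NumPy sketch below spells out what each one measures. `regression_metrics` and `average_ranks` are hypothetical helpers, not the torchmetrics implementation. Tied values receive their average rank, the usual definition for Spearman correlation; this matters here because several labels are exactly 0.

```python
import numpy as np


def average_ranks(x: np.ndarray) -> np.ndarray:
    """1-based ranks; tied values share the mean of their positions."""
    order = np.argsort(x, kind="mergesort")
    ranks = np.empty(len(x), dtype=float)
    ranks[order] = np.arange(1, len(x) + 1)
    for value in np.unique(x):
        tied = x == value
        ranks[tied] = ranks[tied].mean()
    return ranks


def regression_metrics(preds, labels) -> dict:
    preds = np.asarray(preds, dtype=float)
    labels = np.asarray(labels, dtype=float)
    residuals = labels - preds
    return {
        "pearsonr": np.corrcoef(preds, labels)[0, 1],
        "spearmanr": np.corrcoef(average_ranks(preds), average_ranks(labels))[0, 1],
        "mse": np.mean(residuals ** 2),
        "mae": np.mean(np.abs(residuals)),
        # R^2 = 1 - SS_res / SS_tot, with SS_tot measured around the label mean
        "r2": 1.0 - np.sum(residuals ** 2) / np.sum((labels - labels.mean()) ** 2),
    }


metrics = regression_metrics([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
print(metrics)
```

The toy example shows why several metrics are reported together. The predictions are perfectly correlated with the labels (Pearson and Spearman equal 1), but they are on the wrong scale, so R² is negative.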
Evaluate the model fine-tuned with limited data
We first load the fine-tuned model and evaluate it on the downsampled test data here.
[13]:
gep_ft_ckpt = os.path.abspath(glob.glob('./lightning_logs/gep/version*/checkpoints/*.ckpt')[0])
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)
dl = data_module.test_dataloader()
ft_model.cuda()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for idx, batch in enumerate(tqdm(dl, total=len(dl))):
        # move every tensor in the batch to the GPU
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
"pearsonr": score_pearsonr,
"spearmanr": score_spearmanr,
"mse": score_mse,
"mae": score_mae,
"r2": score_r2,
}
print(scores)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/examples/tutorials/lightning_logs/gep/version_0/checkpoints/epoch=0-step=2.ckpt
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/chrombert/finetune/model/basic_model.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
new_state = torch.load(ckpt)
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
100%|██████████| 25/25 [00:23<00:00, 1.06it/s]
{'pearsonr': tensor(-0.0272), 'spearmanr': tensor(0.0601), 'mse': tensor(0.5545), 'mae': tensor(0.5203), 'r2': tensor(-0.1555)}
Evaluate the model fine-tuned on the full dataset
The model fine-tuned on this small subset performs poorly (Pearson r ≈ -0.03 on the downsampled test data). We therefore provide a checkpoint fine-tuned on the entire dataset and evaluate its performance here.
[14]:
gep_ft_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)
dl = data_module.test_dataloader()
ft_model.cuda()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl, total=len(dl)):
        # move every tensor in the batch to the GPU
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
"pearsonr": score_pearsonr,
"spearmanr": score_spearmanr,
"mse": score_mse,
"mae": score_mae,
"r2": score_r2,
}
print(scores)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
100%|██████████| 25/25 [00:23<00:00, 1.08it/s]
{'pearsonr': tensor(0.7469), 'spearmanr': tensor(0.5639), 'mse': tensor(0.2352), 'mae': tensor(0.3242), 'r2': tensor(0.5099)}
Infer key regulators
Using the fine-tuned model, we analyze embedding similarities between upregulated and unchanged genomic regions. Lower embedding similarity is hypothesized to indicate a greater functional shift, highlighting the potential role of key regulators in cell state transitions.
To save time, we perform the analysis on a downsampled set of upregulated and unchanged genomic regions.
Note: Only the selected factors are considered for this analysis, excluding histone modifications and chromatin accessibility.
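The steps that follow reduce the per-region embeddings to one vector per regulator. The embeddings have shape (regions, regulators, 768); they are averaged over regions, and a boolean mask then keeps only the regulators in the factor list. The NumPy sketch below shows those shape manipulations on toy data. The regulator names, sizes, and random values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_regions, dim = 5, 8
# Hypothetical regulator list mixing factors with non-factor features
list_regulator = ["ctcf", "h3k27ac", "myod1", "dnase", "yap1"]
factors = ["ctcf", "myod1", "yap1"]

# (regions, regulators, embedding dim), like the embedding manager output
embs = rng.normal(size=(n_regions, len(list_regulator), dim))

mask = np.isin(list_regulator, factors)   # True for regulators in the factor list
names = np.array(list_regulator)[mask]
pooled = embs.mean(axis=0)[mask]          # average over regions, then keep factors only

print(names.tolist(), pooled.shape)
```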
[15]:
# Load the fine-tuned model
model_tuned = chrombert.get_preset_model_config(
"gep",
gep_flank_window = 4,
dropout = 0,
finetune_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt').init_model() # use an absolute path here so it is not resolved against the preset's data directory
# Get the embedding manager
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[15]:
=====================================================================================
Layer (type:depth-idx) Param #
=====================================================================================
ChromBERTEmbedding --
├─PoolFlankWindow: 1-1 --
│ └─ChromBERT: 2-1 --
│ │ └─BERTEmbedding: 3-1 4,916,736
│ │ └─ModuleList: 3-2 51,978,240
├─CistromeEmbeddingManager: 1-2 --
=====================================================================================
Total params: 56,894,976
Trainable params: 56,894,976
Non-trainable params: 0
=====================================================================================
Gather regulator embeddings in upregulated genes
[16]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window", supervised_file = f'{gep_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4, flank_window = 4)  # match gep_flank_window=4 used by the model
dl = dataset_config.init_dataloader()
up_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        # move every tensor in the batch to the GPU
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_gep_embs.append(emb)
up_gep_embs = torch.cat(up_gep_embs)
up_gep_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:43<00:00, 10.95s/it]
[16]:
torch.Size([100, 1073, 768])
Gather regulator embeddings in unchanged genes
[17]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window", supervised_file = f'{gep_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4, flank_window = 4)  # match gep_flank_window=4 used by the model
dl = dataset_config.init_dataloader()
nochange_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        # move every tensor in the batch to the GPU
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_gep_embs.append(emb)
nochange_gep_embs = torch.cat(nochange_gep_embs)
nochange_gep_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:43<00:00, 10.98s/it]
[17]:
torch.Size([100, 1073, 768])
We keep only the regulators in the factor list, removing histone modifications and chromatin accessibility.
[18]:
with open(os.path.join(base_dir, "config", "hg38_6k_factors_list.txt"), "r") as f:
    factors = f.read().strip().split("\n")
factors = [name.strip().lower() for name in factors]
indices = np.isin(model_emb.list_regulator, factors)  # boolean mask over all regulators (np.in1d is deprecated)
names = np.array(model_emb.list_regulator)[indices]
up_gep_embs = up_gep_embs.mean(dim=0)[indices]  # average over regions, keep factors only
nochange_gep_embs = nochange_gep_embs.mean(dim=0)[indices]
up_gep_embs.shape, nochange_gep_embs.shape
[18]:
(torch.Size([991, 768]), torch.Size([991, 768]))
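Analyze the embedding similarity between upregulated and unchanged genes
For each regulator, we calculate the cosine similarity between its mean embedding in upregulated regions and its mean embedding in unchanged regions. Regulators with the lowest similarity are ranked first as candidate key regulators.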
[19]:
from sklearn.metrics.pairwise import cosine_similarity
gep_similarity = [cosine_similarity(up_gep_embs[i].reshape(1, -1), nochange_gep_embs[i].reshape(1, -1))[0, 0] for i in range(up_gep_embs.shape[0])]
gep_similarity_df = pd.DataFrame({'factors':names,'similarity':gep_similarity}).sort_values(by='similarity').reset_index(drop=True)
gep_similarity_df['rank']=gep_similarity_df.index + 1
gep_similarity_df.to_csv(f'{gep_dir}/gep_similarity_df.csv',index=False)
gep_similarity_df
[19]:
| | factors | similarity | rank |
|---|---|---|---|
| 0 | chd4 | 0.924975 | 1 |
| 1 | esco2 | 0.935212 | 2 |
| 2 | cbx7 | 0.946856 | 3 |
| 3 | cbx6 | 0.946874 | 4 |
| 4 | cbx8 | 0.949367 | 5 |
| ... | ... | ... | ... |
| 986 | zbtb10 | 0.998598 | 987 |
| 987 | zbed4 | 0.998614 | 988 |
| 988 | nkx2-5 | 0.998627 | 989 |
| 989 | snapc4 | 0.998678 | 990 |
| 990 | dbp | 0.998807 | 991 |
991 rows × 3 columns
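The per-regulator loop with scikit-learn's `cosine_similarity` can also be written as one vectorized row-wise computation. The NumPy sketch below uses a hypothetical `rowwise_cosine` helper and toy vectors. Ranking sorts in ascending order, so the regulator whose embedding differs most between the two groups (lowest similarity) gets rank 1, as in the table above.

```python
import numpy as np


def rowwise_cosine(a, b) -> np.ndarray:
    """Cosine similarity between matching rows of two (n, d) arrays."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    dots = np.sum(a * b, axis=1)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dots / norms


names = np.array(["reg_a", "reg_b", "reg_c"])  # hypothetical regulators
up = np.array([[1.0, 0.0], [1.0, 1.0], [2.0, 1.0]])
unchanged = np.array([[0.0, 1.0], [2.0, 2.0], [1.0, 1.0]])

similarity = rowwise_cosine(up, unchanged)
order = np.argsort(similarity, kind="stable")  # lowest similarity first
for rank, i in enumerate(order, start=1):
    print(rank, names[i], round(float(similarity[i]), 4))
```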
[20]:
identified_factors = gep_similarity_df[gep_similarity_df['rank'] <= 25]['factors'].tolist()
identified_factors
[20]:
['chd4',
'esco2',
'cbx7',
'cbx6',
'cbx8',
'brd7',
'hira',
'myf5',
'ring1',
'neurog2',
'kdm6b',
'ubn1',
'nr3c2',
'rpa2',
'sumo1',
'yap1',
'ptpn11',
'myod1',
'klf11',
'phf2',
'prkdc',
'brdu',
'ssrp1',
'tead',
'tead1']
[21]:
gep_similarity_df[gep_similarity_df['factors']=='myod1']
[21]:
| | factors | similarity | rank |
|---|---|---|---|
| 17 | myod1 | 0.983192 | 18 |
The top 25 regulators ranked by the transcriptome analysis include the myogenic regulators MYF5 (rank 8) and MYOD1 (rank 18).
Combine the regulator rankings from both chromatin accessibility and transcriptome
The final top 25 key regulators are obtained by averaging each regulator's ranks from the chromatin accessibility and transcriptome analyses.
[22]:
chrom_accessibility_path = f'{gep_dir}/../chrom_accessibility/chromatin_accessibility_similarity_df.csv'
if not os.path.exists(chrom_accessibility_path):
    raise FileNotFoundError("Please run the tutorial for key regulators inference during cell state transition: chromatin accessibility first; it generates this file.")
chrom_acc_similarity_df = pd.read_csv(chrom_accessibility_path)
# Keep regulators present in both analyses, average their two ranks, and re-rank the averages.
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
average_rank_df['average_rank'] = ((average_rank_df['rank_gep'] + average_rank_df['rank_chrom_acc']) / 2).rank().astype(int)
average_rank_df = average_rank_df.sort_values(by='average_rank')
average_rank_df[average_rank_df['factors'] == 'myod1']
[22]:
| | factors | similarity_gep | rank_gep | similarity_chrom_acc | rank_chrom_acc | average_rank |
|---|---|---|---|---|---|---|
| 17 | myod1 | 0.983192 | 18 | -0.064511 | 1 | 2 |
[23]:
final_identified_factors = average_rank_df[average_rank_df['average_rank'] <= 25]['factors'].tolist()
final_identified_factors
[23]:
['myf5',
'myod1',
'neurog2',
'yap1',
'nr3c2',
'tead1',
'tead',
'dux4',
'pgbd3',
'chd4',
'myog',
'hira',
'pax3-foxo1a',
'tbx5',
'nr3c1',
'snai2',
'ss18',
'prmt5',
'ubn1',
'rb1',
'six2',
'klf11',
'ercc6',
'sumo1',
'esco2']
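The combination step keeps regulators present in both analyses, averages their two ranks, and ranks the averages. The standard-library sketch below illustrates the idea with a hypothetical `combine_ranks` helper. Only the MYOD1 ranks (18 for the transcriptome, 1 for chromatin accessibility) come from the outputs above; the other entries are made up. This sketch also differs from the pandas `.rank().astype(int)` call: pandas gives tied averages fractional ranks, which `astype(int)` then truncates, whereas here ties are ordered alphabetically.

```python
def combine_ranks(rank_a: dict, rank_b: dict) -> list:
    """Inner-join two {factor: rank} maps and order factors by their mean rank.

    Returns (factor, mean_rank, final_rank) tuples; ties are broken alphabetically.
    """
    shared = sorted(set(rank_a) & set(rank_b))             # inner join on factor name
    mean_rank = {f: (rank_a[f] + rank_b[f]) / 2 for f in shared}
    ordered = sorted(shared, key=lambda f: mean_rank[f])   # stable sort keeps alphabetical ties
    return [(f, mean_rank[f], i) for i, f in enumerate(ordered, start=1)]


rank_gep = {"myod1": 18, "factor_x": 3, "factor_y": 40, "only_in_gep": 2}
rank_chrom_acc = {"myod1": 1, "factor_x": 30, "factor_y": 5}

combined = combine_ranks(rank_gep, rank_chrom_acc)
for row in combined:
    print(row)
```

A regulator that ranks highly in only one analysis (like `factor_x` here) can fall behind one that ranks moderately well in both, which is the intent of averaging.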