Example for key regulators inference during cell state transition: transcriptome

To comprehensively infer key regulators during cell state transitions, it is crucial to integrate analyses of chromatin accessibility and the transcriptome. In this tutorial, we will demonstrate how to use ChromBERT to infer key regulators involved in a specific transdifferentiation process (fibroblast to myoblast) through transcriptome analysis.

[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import chrombert
import pandas as pd
import numpy as np
from torchinfo import summary
import subprocess
import torch
import lightning.pytorch as pl
import glob
from tqdm import tqdm
import torchmetrics as tm
base_dir =  os.path.expanduser("~/.cache/chrombert/data") ### to_path_chrombert/data

Preprocess dataset

This section walks you through preparing raw transcriptome data, including TSS and TPM values for each gene, and transforming it into the format required by ChromBERT.

To identify key regulators, we will fine-tune ChromBERT to predict log1p-transformed gene expression fold changes during transdifferentiation, defined here as log1p(TPM) in myoblasts minus log1p(TPM) in fibroblasts. To analyze shifts in regulator embeddings, we also prepare datasets of upregulated and unchanged genes.
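As a worked example of this label, the sketch below recomputes it by hand for one gene (ENSG00000121410), using the TPM values from the expression tables provided with this tutorial. It only uses the standard library; the variable names are illustrative.

```python
import math

# TPM values for ENSG00000121410, copied from the tutorial's expression tables.
tpm_fibroblast = 12.774133
tpm_myoblast = 22.236894

# Label used for fine-tuning: difference of log1p-transformed TPM values.
label = math.log1p(tpm_myoblast) - math.log1p(tpm_fibroblast)

# The same quantity written as a log ratio with a pseudocount of 1.
log_ratio = math.log((1 + tpm_myoblast) / (1 + tpm_fibroblast))

# A gene that is silent in both cell types gets a label of exactly 0.
silent_label = math.log1p(0.0) - math.log1p(0.0)

print(round(label, 6), round(log_ratio, 6), silent_label)
```

The pseudocount keeps the label finite for genes with zero TPM in one or both cell types, which a plain log ratio would not.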

[2]:
# We provide tables of gene expression data for fibroblast and myoblast.
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'
fibroblast_exp = pd.read_csv(f'{gep_dir}/fibroblast_expression.csv')
myoblast_exp = pd.read_csv(f'{gep_dir}/myoblast_expression.csv')
myoblast_exp.head(), fibroblast_exp.head()
[2]:
(   chrom       tss          gene_id        tpm
 0  chr19  58353492  ENSG00000121410  22.236894
 1  chr19  58347718  ENSG00000268895   9.317134
 2  chr10  50885675  ENSG00000148584   0.000000
 3  chr12   9116229  ENSG00000175899   0.993828
 4  chr12   9065163  ENSG00000245105   0.124228,
    chrom       tss          gene_id        tpm
 0  chr19  58353492  ENSG00000121410  12.774133
 1  chr19  58347718  ENSG00000268895   2.939181
 2  chr10  50885675  ENSG00000148584   0.000000
 3  chr12   9116229  ENSG00000175899   0.226091
 4  chr12   9065163  ENSG00000245105   0.226091)

Prepare the log1p-transformed gene expression fold changes data

[3]:
from chrombert.scripts.chrombert_make_dataset import get_regions
# We merge the two datasets, and calculate the log1p-transformed gene expression fold changes.
merge_exp = pd.merge(fibroblast_exp,myoblast_exp,left_on=['chrom','tss','gene_id'],right_on=['chrom','tss','gene_id'],suffixes=['_fibroblast','_myoblast'])
merge_exp['fold_change']= np.log1p(merge_exp['tpm_myoblast']) - np.log1p(merge_exp['tpm_fibroblast'])
merge_exp['start'] = merge_exp['tss']//1000 * 1000
merge_exp['end'] = (merge_exp['tss']//1000 + 1) * 1000
foldchange_exp = merge_exp [['chrom','start','end','tss','gene_id','fold_change']]

# align genomic coordinates to the predefined 1-kb bins
chrom_regions = get_regions(base_dir,genome='hg38',high_resolution=False) # 1kb
chrom_regions
chrom_regions_df = pd.read_csv(chrom_regions,sep='\t',names=['chrom','start','end','build_region_index'])
chrom_regions_df
merge_region = pd.merge(foldchange_exp,chrom_regions_df,left_on=['chrom','start','end'],right_on=['chrom','start','end'],how='inner')[['chrom','start','end','build_region_index','fold_change','tss','gene_id']]
gep_df = merge_region.rename(columns={'fold_change':'label'})
gep_df.to_csv(f'{gep_dir}/fibroblast_to_myoblast_expression_changes.csv',index=False)
gep_df.head()  # The label column holds the log1p-transformed gene expression fold change.
[3]:
   chrom     start       end  build_region_index      label       tss          gene_id
0  chr19  58353000  58354000              917950   0.522949  58353492  ENSG00000121410
1  chr19  58347000  58348000              917944   0.962833  58347718  ENSG00000268895
2  chr10  50885000  50886000              221904   0.000000  50885675  ENSG00000148584
3  chr12   9116000   9117000              393001   0.486225   9116229  ENSG00000175899
4  chr12   9065000   9066000              392961  -0.086734   9065163  ENSG00000245105
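The binning and alignment step can be illustrated on a toy table. This is a minimal sketch: the `regions` frame is a small stand-in for the region file returned by `get_regions`, and the third gene (`HYPOTHETICAL_GENE` on chrX) is made up to show what happens when a TSS falls outside the predefined bins.

```python
import pandas as pd

# Toy gene table: two TSS positions from the table above plus one made-up gene.
genes = pd.DataFrame({
    "chrom": ["chr19", "chr12", "chrX"],
    "tss": [58353492, 9116229, 123456],
    "gene_id": ["ENSG00000121410", "ENSG00000175899", "HYPOTHETICAL_GENE"],
})

# Floor each TSS to the start of its 1-kb bin.
genes["start"] = genes["tss"] // 1000 * 1000
genes["end"] = genes["start"] + 1000

# Toy region table standing in for the file returned by get_regions.
regions = pd.DataFrame({
    "chrom": ["chr19", "chr12"],
    "start": [58353000, 9116000],
    "end": [58354000, 9117000],
    "build_region_index": [917950, 393001],
})

# The inner merge keeps only genes whose bin exists in the region table.
aligned = genes.merge(regions, on=["chrom", "start", "end"], how="inner")
dropped = sorted(set(genes["gene_id"]) - set(aligned["gene_id"]))

print(aligned[["gene_id", "start", "end", "build_region_index"]])
print("dropped:", dropped)
```

Because the merge is an inner join, genes without a matching bin are dropped silently, so it can be worth comparing the row counts before and after the merge.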

Prepare the fine-tuning data

We split the data into training, testing, and validation sets with an 8:1:1 ratio and downsample the data to test the fine-tuning process.

[4]:
train_data = gep_df.sample(frac=0.8,random_state=55)
test_data = gep_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = gep_df.drop(train_data.index).drop(test_data.index)
train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=50,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)
train_data_sample.to_csv(f'{gep_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{gep_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{gep_dir}/valid_sample.csv',index=False)
train_data_sample.head()
[4]:
       chrom      start        end  build_region_index      label        tss          gene_id
 4496   chr6  138795000  138796000             1719247   0.000000  138795911  ENSG00000203734
13237   chr8   96261000   96262000             1933979  -0.351699   96261902  ENSG00000156471
17079  chr22   29555000   29556000             1186796  -0.565257   29555216  ENSG00000100296
 4118   chr8    1622000    1623000             1866390   0.000000    1622417  ENSG00000253267
13482   chr7   66682000   66683000             1792879   0.402664   66682164  ENSG00000154710
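To see how the `sample` and `drop` pattern above yields an 8:1:1 split with no shared rows, here is a minimal sketch on a toy frame of 100 rows (`toy_df` is a stand-in for `gep_df`).

```python
import pandas as pd

# Toy stand-in for gep_df with 100 rows.
toy_df = pd.DataFrame({"label": range(100)})

# 80% for training, then the remaining 20% is halved into test and validation.
train = toy_df.sample(frac=0.8, random_state=55)
rest = toy_df.drop(train.index)
test = rest.sample(frac=0.5, random_state=55)
valid = rest.drop(test.index)

sizes = (len(train), len(test), len(valid))

# The three index sets should be pairwise disjoint and cover every row.
train_idx, test_idx, valid_idx = set(train.index), set(test.index), set(valid.index)
overlap = (train_idx & test_idx) | (train_idx & valid_idx) | (test_idx & valid_idx)
covered = train_idx | test_idx | valid_idx

print(sizes, len(overlap), len(covered))
```

Dropping by index only works as intended when the index is unique, which holds for the default integer index used here.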

Prepare the upregulated genes and unchanged genes

Here we identify the upregulated genes and unchanged genes, based on the log1p-transformed gene expression fold changes. In addition, we downsample the list of upregulated and unchanged genes to save time.

[5]:
up_data = gep_df[gep_df['label']>1]
nochange_data = gep_df[(gep_df['label']>-0.5) & (gep_df['label']<0.5)]

up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)


up_data_sample.to_csv(f'{gep_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{gep_dir}/nochange_data_sample.csv',index=False)
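The thresholds above are on the log1p-difference scale. The sketch below converts them back to ratios of (1 + TPM) and shows which labels fall into each group; `classify` is a hypothetical helper written for illustration only.

```python
import math

up_threshold = 1.0      # label > 1 is treated as upregulated
unchanged_bound = 0.5   # -0.5 < label < 0.5 is treated as unchanged

# Back-transform the thresholds to ratios of (1 + TPM_myoblast) / (1 + TPM_fibroblast).
up_ratio = math.exp(up_threshold)             # about 2.72-fold
unchanged_low = math.exp(-unchanged_bound)    # about 0.61-fold
unchanged_high = math.exp(unchanged_bound)    # about 1.65-fold


def classify(label):
    """Hypothetical helper mirroring the two filters used above."""
    if label > up_threshold:
        return "up"
    if -unchanged_bound < label < unchanged_bound:
        return "unchanged"
    return "other"


print(round(up_ratio, 3), round(unchanged_low, 3), round(unchanged_high, 3))
print([classify(x) for x in (1.2, 0.0, 0.962833, -0.8)])
```

Genes labeled "other" here (moderately upregulated or downregulated) belong to neither set, which keeps a margin between the two groups being compared.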

Fine-tune

This section provides a tutorial for fine-tuning ChromBERT to predict genome-wide changes in the transcriptome. The parameters for transcriptome changes are adjusted to account for the inclusion of nearby flank regions for each TSS. The process involves the following key modifications:

  • Dataset Configuration: Use the multi_flank_window preset for dataset configuration and set the flank_window parameter to four.

  • Model Instantiation: Use the gep preset for model configuration and set the gep_flank_window parameter to four.
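The effect of the flank window can be pictured with a small sketch. It assumes that the window is taken over consecutive `build_region_index` values and that it includes the TSS bin itself; both are assumptions made for illustration. `flank_indices` is a hypothetical helper, not part of ChromBERT, and the actual dataset class may treat chromosome ends or gaps between regions differently.

```python
def flank_indices(center_index, flank_window):
    """Hypothetical helper: region indices within +/- flank_window of the TSS bin."""
    return list(range(center_index - flank_window, center_index + flank_window + 1))


# 917950 is the build_region_index of the bin containing the TSS of ENSG00000121410.
window = flank_indices(917950, 4)

print(len(window), window[0], window[-1])
```

Under these assumptions, a flank window of four gives nine regions per gene: the TSS bin plus four on each side.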

Configure dataset and data module

[6]:
# We use the `multi_flank_window` preset dataset config.
# `flank_window` sets how many neighboring genomic regions are included around each TSS:
# 4 means the 4 nearest regions on each side of the TSS are used.
dataset_config = chrombert.get_preset_dataset_config(
    "multi_flank_window",
    supervised_file = None,
    batch_size = 2,
    num_workers = 4,
    flank_window=4
    )
dataset_config
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[6]:
DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'MultiFlankwindowDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 2, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'prompt_regulator_cache_pin_memory': False, 'prompt_regulator_cache_limit': 3, 'fasta_file': None, 'flank_window': 4})
[7]:
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'
[8]:
# We use the `LitChromBERTFTDataModule` to create a data module for fine-tuning.
data_module = chrombert.LitChromBERTFTDataModule(
    config = dataset_config,
    train_params = {'supervised_file': f'{gep_dir}/train_sample.csv'},
    val_params = {'supervised_file':f'{gep_dir}/valid_sample.csv'},
    test_params = {'supervised_file':f'{gep_dir}/test_sample.csv'}
)
data_module.setup()

Configure model and instantiation

[9]:
# We use the `gep` preset model config and also set the `gep_flank_window` parameter to four.

model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4)
model_config
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
[9]:
ChromBERTFTConfig:
{
    "genome": "hg38",
    "task": "gep",
    "dim_output": 1,
    "mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
    "dropout": 0.1,
    "pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    "finetune_ckpt": null,
    "ignore": false,
    "ignore_index": [
        null,
        null
    ],
    "gep_flank_window": 4,
    "gep_parallel_embedding": false,
    "gep_gradient_checkpoint": false,
    "gep_zero_inflation": false,
    "prompt_kind": "cistrome",
    "prompt_dim_external": 512,
    "dnabert2_ckpt": null
}
[10]:
model = model_config.init_model()
model.freeze_pretrain(2)  # keep 2 transformer blocks trainable; the other 6 ChromBERT transformer blocks are frozen during fine-tuning
summary(model)
use organisim hg38; max sequence length is 6391
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
[10]:
=====================================================================================
Layer (type:depth-idx)                                       Param #
=====================================================================================
ChromBERTGEP                                                 --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               (4,916,736)
│    │    └─ModuleList: 3-2                                  51,978,240
├─GeneralHeader: 1-2                                         --
│    └─CistromeEmbeddingManager: 2-2                         --
│    └─Conv2d: 2-3                                           769
│    └─ReLU: 2-4                                             --
│    └─ResidualBlock: 2-5                                    --
│    │    └─Linear: 3-3                                      1,099,776
│    │    └─Linear: 3-4                                      1,049,600
│    │    └─LayerNorm: 3-5                                   2,048
│    │    └─Linear: 3-6                                      1,099,776
│    │    └─Dropout: 3-7                                     --
│    └─ResidualBlock: 2-6                                    --
│    │    └─Linear: 3-8                                      787,200
│    │    └─Linear: 3-9                                      590,592
│    │    └─LayerNorm: 3-10                                  1,536
│    │    └─Linear: 3-11                                     787,200
│    │    └─Dropout: 3-12                                    --
│    └─ResidualBlock: 2-7                                    --
│    │    └─Linear: 3-13                                     196,864
│    │    └─Linear: 3-14                                     65,792
│    │    └─LayerNorm: 3-15                                  512
│    │    └─Linear: 3-16                                     196,864
│    │    └─Dropout: 3-17                                    --
│    └─Linear: 2-8                                           257
=====================================================================================
Total params: 62,773,762
Trainable params: 18,873,346
Non-trainable params: 43,900,416
=====================================================================================
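The parameter counts in this summary can be checked by hand to see how much of the pretrained model stays trainable. The sketch below uses only numbers copied from the summary and assumes that every parameter in the `GeneralHeader` is trainable.

```python
# Parameter counts copied from the summary above.
embedding_params = 4_916_736        # BERTEmbedding (shown in parentheses, i.e. frozen)
transformer_params = 51_978_240     # ModuleList of transformer blocks
head_layers = [769, 1_099_776, 1_049_600, 2_048, 1_099_776,
               787_200, 590_592, 1_536, 787_200,
               196_864, 65_792, 512, 196_864, 257]
total_params = 62_773_762
trainable_params = 18_873_346

head_params = sum(head_layers)

# Assumption: the whole header is trainable, so the rest of the trainable
# parameters must sit in the transformer blocks.
trainable_transformer = trainable_params - head_params
trainable_fraction = trainable_transformer / transformer_params
frozen_params = embedding_params + (transformer_params - trainable_transformer)

print(head_params, trainable_transformer, trainable_fraction, frozen_params)
```

A quarter of the transformer parameters remain trainable, which matches 2 trainable blocks out of 8 if the blocks are equally sized (the total of 8 is inferred from the counts, not stated in the summary).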

Configure training parameters and fine-tune

The model is fine-tuned using PyTorch Lightning with a straightforward parameter configuration. To save time, fine-tuning is performed on a smaller dataset.

Note: Due to the stochastic nature of the tuning process, results may vary. For improved performance, consider increasing the number of epochs and expanding the dataset size.

[11]:
train_config = chrombert.finetune.TrainConfig(
    kind='regression',
    loss='rmse',
    max_epochs=2,
    accumulate_grad_batches=2,
    val_check_interval=2,
    limit_val_batches=10,
    tag='gep',
    checkpoint_mode='max',
    checkpoint_metric='pcc'
    )
pl_module = train_config.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)  # noqa: B028
[11]:
chrombert.finetune.train.pl_module.RegressionPLModule
Then we start fine-tuning!
The trainer will save logs in a format compatible with TensorBoard, and the checkpoint with the lowest validation loss will be saved during the process.
[12]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs=train_config.max_epochs,
    log_every_n_steps=1,
    limit_val_batches = train_config.limit_val_batches,
    val_check_interval = train_config.val_check_interval,
    accelerator="gpu",
    accumulate_grad_batches= train_config.accumulate_grad_batches,
    fast_dev_run=False,
    precision="bf16-mixed",
    strategy="auto",
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        callback_ckpt,
    ],
    logger=pl.loggers.TensorBoardLogger("lightning_logs", name='gep'))
trainer.fit(pl_module,data_module)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/gep
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type         | Params | Mode
-----------------------------------------------
0 | model | ChromBERTGEP | 62.8 M | train
-----------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.095   Total estimated model params size (MB)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
`Trainer.fit` stopped: `max_epochs=2` reached.
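For orientation, the settings used in this run imply a very short training schedule. The arithmetic below is a sketch that assumes a single GPU, no dropped batches, and the downsampled training set of 80 genes; it also rebuilds the metric name that the checkpoint callback monitors.

```python
import math

n_train = 80                  # rows in train_sample.csv
batch_size = 2                # from the dataset config
accumulate_grad_batches = 2   # from the train config
max_epochs = 2

batches_per_epoch = math.ceil(n_train / batch_size)
effective_batch_size = batch_size * accumulate_grad_batches
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
total_optimizer_steps = optimizer_steps_per_epoch * max_epochs

# Metric name monitored by the ModelCheckpoint callback in the cell above.
tag, loss = "gep", "rmse"
monitor_key = f"{tag}_validation/{loss}"

print(batches_per_epoch, effective_batch_size, total_optimizer_steps, monitor_key)
```

With so few optimizer steps, the resulting checkpoint is only useful for testing the pipeline, not for drawing biological conclusions.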

Evaluate the fine-tuned model

ChromBERT has been successfully fine-tuned! You can evaluate the model directly using pl_module.model. However, note that because of the flash-attention settings, model.eval() does not disable dropout, which may introduce some randomness into the output.

To ensure consistent results, we suggest saving the fine-tuned checkpoint and reloading it into a model configuration with dropout set to 0 for evaluation.

In addition, we use only downsampled test data for evaluation in this tutorial to save time.

Evaluate the model fine-tuned with limited data

We first load the fine-tuned model and evaluate it on the downsampled test data here.

[13]:
gep_ft_ckpt = os.path.abspath(glob.glob('./lightning_logs/gep/version*/checkpoints/*.ckpt')[0])
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)

dl = data_module.test_dataloader()
ft_model.cuda()

with torch.no_grad():
    y_preds = []
    y_labels = []
    for idx, batch in enumerate(tqdm(dl, total=len(dl))):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/examples/tutorials/lightning_logs/gep/version_0/checkpoints/epoch=0-step=2.ckpt
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/chrombert/finetune/model/basic_model.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  new_state = torch.load(ckpt)
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
100%|██████████| 25/25 [00:23<00:00,  1.06it/s]
{'pearsonr': tensor(-0.0272), 'spearmanr': tensor(0.0601), 'mse': tensor(0.5545), 'mae': tensor(0.5203), 'r2': tensor(-0.1555)}


Performance of the Fine-Tuned Model

The model fine-tuned on the downsampled data performs poorly on the test sample (Pearson r = -0.03, R² = -0.16). We therefore provide a checkpoint fine-tuned on the entire dataset and evaluate its performance here.
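A negative R², as reported for the downsampled model, means the predictions fit the labels worse than simply predicting the mean label. The sketch below restates the standard definitions of these metrics in plain NumPy on toy numbers; it is not the torchmetrics implementation, and `regression_scores` is a hypothetical helper.

```python
import numpy as np


def regression_scores(pred, label):
    """Hypothetical helper: standard regression metrics on 1-D arrays."""
    pred = np.asarray(pred, dtype=float)
    label = np.asarray(label, dtype=float)
    residual = label - pred
    ss_res = float(np.sum(residual ** 2))                  # residual sum of squares
    ss_tot = float(np.sum((label - label.mean()) ** 2))    # total sum of squares
    return {
        "pearsonr": float(np.corrcoef(pred, label)[0, 1]),
        "mse": ss_res / len(label),
        "mae": float(np.mean(np.abs(residual))),
        "r2": 1.0 - ss_res / ss_tot,
    }


labels = [0.0, 1.0, 2.0, 3.0]
good = regression_scores([0.1, 0.9, 2.1, 2.9], labels)   # close to the labels
bad = regression_scores([3.0, 2.0, 1.0, 0.0], labels)    # reversed order

print(good)
print(bad)
```

R² is bounded above by 1 but has no lower bound, so it drops below zero whenever the residual sum of squares exceeds the total sum of squares.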

[14]:
gep_ft_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)

dl = data_module.test_dataloader()
ft_model.cuda()

with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
100%|██████████| 25/25 [00:23<00:00,  1.08it/s]
{'pearsonr': tensor(0.7469), 'spearmanr': tensor(0.5639), 'mse': tensor(0.2352), 'mae': tensor(0.3242), 'r2': tensor(0.5099)}


Infer key regulators

Using the fine-tuned model, we analyze embedding similarities between upregulated and unchanged genomic regions. Lower embedding similarity is hypothesized to indicate a greater functional shift, highlighting the potential role of key regulators in cell state transitions.

To save time, we perform the analysis on a downsampled set of upregulated and unchanged genomic regions.

Note: Only the selected factors are considered for this analysis, excluding histone modifications and chromatin accessibility.

[15]:
# Load the fine-tuned model
model_tuned = chrombert.get_preset_model_config(
    "gep",
    gep_flank_window = 4,
    dropout = 0,
    finetune_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt').init_model() # use absolute path here, to avoid mixing of preset
# Get the embedding manager
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[15]:
=====================================================================================
Layer (type:depth-idx)                                       Param #
=====================================================================================
ChromBERTEmbedding                                           --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               4,916,736
│    │    └─ModuleList: 3-2                                  51,978,240
├─CistromeEmbeddingManager: 1-2                              --
=====================================================================================
Total params: 56,894,976
Trainable params: 56,894,976
Non-trainable params: 0
=====================================================================================

Gather regulator embeddings in upregulated genes

[16]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4)
dl = dataset_config.init_dataloader()
up_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_gep_embs.append(emb)
up_gep_embs = torch.cat(up_gep_embs)
up_gep_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:43<00:00, 10.95s/it]
[16]:
torch.Size([100, 1073, 768])

Gather regulator embeddings in unchanged genes

[17]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4)
dl = dataset_config.init_dataloader()
nochange_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_gep_embs.append(emb)
nochange_gep_embs = torch.cat(nochange_gep_embs)
nochange_gep_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:43<00:00, 10.98s/it]
[17]:
torch.Size([100, 1073, 768])

We keep only the regulators in the factor list, excluding histone modifications and chromatin accessibility.

[18]:
with open(os.path.join(base_dir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]


indices = np.isin(model_emb.list_regulator,factors)  # np.in1d is deprecated in NumPy 2.0; np.isin is the replacement
names = np.array(model_emb.list_regulator)[indices]
up_gep_embs = up_gep_embs.mean(axis=0)[indices]
nochange_gep_embs = nochange_gep_embs.mean(axis=0)[indices]
up_gep_embs.shape, nochange_gep_embs.shape
[18]:
(torch.Size([991, 768]), torch.Size([991, 768]))
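The averaging and masking steps in this cell can be followed on a small array. This is a toy sketch in NumPy: the shapes are shrunk from (100, 1073, 768) to (5, 4, 3), and the regulator names are placeholders chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy embeddings with shape (genes, regulators, embedding dimension).
n_genes, n_regulators, dim = 5, 4, 3
embs = rng.normal(size=(n_genes, n_regulators, dim))

# Placeholder regulator names; only some of them are in the factor list.
regulators = np.array(["ctcf", "h3k27ac", "myod1", "dnase"])
factors = ["ctcf", "myod1"]

# Boolean mask over the regulator axis.
keep = np.isin(regulators, factors)

# Average over genes first, then keep only the selected regulators.
mean_embs = embs.mean(axis=0)      # shape (regulators, dim)
factor_embs = mean_embs[keep]      # shape (selected regulators, dim)
names = regulators[keep]

print(mean_embs.shape, factor_embs.shape, names.tolist())
```

After this step each regulator is represented by a single vector per gene group, which is what the similarity comparison operates on.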

Analyze the embedding similarity between upregulated and unchanged genes

We calculate the cosine similarity between the embeddings of upregulated and unchanged genomic regions to identify key regulators.

[19]:
from sklearn.metrics.pairwise import cosine_similarity
gep_similarity = [cosine_similarity(up_gep_embs[i].reshape(1, -1), nochange_gep_embs[i].reshape(1, -1))[0, 0] for i in range(up_gep_embs.shape[0])]
gep_similarity_df = pd.DataFrame({'factors':names,'similarity':gep_similarity}).sort_values(by='similarity').reset_index(drop=True)
gep_similarity_df['rank']=gep_similarity_df.index + 1
gep_similarity_df.to_csv(f'{gep_dir}/gep_similarity_df.csv',index=False)
gep_similarity_df
[19]:
    factors  similarity  rank
0      chd4    0.924975     1
1     esco2    0.935212     2
2      cbx7    0.946856     3
3      cbx6    0.946874     4
4      cbx8    0.949367     5
..      ...         ...   ...
986  zbtb10    0.998598   987
987   zbed4    0.998614   988
988  nkx2-5    0.998627   989
989  snapc4    0.998678   990
990     dbp    0.998807   991

991 rows × 3 columns
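The per-regulator comparison can also be written as one vectorized operation. The sketch below computes a row-wise cosine similarity in NumPy and ranks the rows so that the lowest similarity gets rank 1, as in the table above; `rowwise_cosine` is a hypothetical helper and the vectors are toy values.

```python
import numpy as np


def rowwise_cosine(a, b):
    """Hypothetical helper: cosine similarity between matching rows of a and b."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    dot = np.sum(a * b, axis=1)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dot / norms


# One row per regulator: mean embedding in upregulated vs unchanged genes.
up = np.array([[1.0, 0.0], [1.0, 1.0], [3.0, 4.0]])
unchanged = np.array([[0.0, 1.0], [2.0, 2.0], [4.0, 3.0]])

similarity = rowwise_cosine(up, unchanged)

# Rank 1 goes to the regulator with the lowest similarity (largest shift).
order = np.argsort(similarity)
ranks = np.empty_like(order)
ranks[order] = np.arange(1, len(order) + 1)

print(similarity, ranks)
```

This version does not guard against all-zero rows, which would lead to a division by zero.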

[20]:
identified_factor = gep_similarity_df[gep_similarity_df['rank']<=25]['factors'].tolist()
identified_factor
[20]:
['chd4',
 'esco2',
 'cbx7',
 'cbx6',
 'cbx8',
 'brd7',
 'hira',
 'myf5',
 'ring1',
 'neurog2',
 'kdm6b',
 'ubn1',
 'nr3c2',
 'rpa2',
 'sumo1',
 'yap1',
 'ptpn11',
 'myod1',
 'klf11',
 'phf2',
 'prkdc',
 'brdu',
 'ssrp1',
 'tead',
 'tead1']
[21]:
gep_similarity_df[gep_similarity_df['factors']=='myod1']
[21]:
   factors  similarity  rank
17   myod1    0.983192    18

We take the 25 regulators with the lowest embedding similarity as key regulators in the transcriptome analysis. They include the myogenic regulators MYF5 (rank 8) and MYOD1 (rank 18).

Combine the regulator rankings from both chromatin accessibility and transcriptome

The final top 25 key regulators are derived from the average of each regulator's ranks in the chromatin accessibility and transcriptome analyses.

[22]:
chrom_accessibility_path = f'{gep_dir}/../chrom_accessibility/chromatin_accessibility_similarity_df.csv'
if not os.path.exists(chrom_accessibility_path):
    raise FileNotFoundError("Please follow the tutorial for key regulators inference during cell state transition: Chromatin accessibility")
chrom_acc_similarity_df = pd.read_csv(chrom_accessibility_path)
# Merge the two rankings on factor name, then rank the factors by their mean rank.
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
average_rank_df['average_rank'] = ((average_rank_df['rank_gep']+average_rank_df['rank_chrom_acc'])/2).rank().astype(int)
average_rank_df = average_rank_df.sort_values(by='average_rank')

average_rank_df[average_rank_df['factors']=='myod1']
[22]:
   factors  similarity_gep  rank_gep  similarity_chrom_acc  rank_chrom_acc  average_rank
17   myod1        0.983192        18             -0.064511               1             2
[23]:
final_identified_factor = average_rank_df[average_rank_df['average_rank']<=25]['factors'].tolist()
final_identified_factor
[23]:
['myf5',
 'myod1',
 'neurog2',
 'yap1',
 'nr3c2',
 'tead1',
 'tead',
 'dux4',
 'pgbd3',
 'chd4',
 'myog',
 'hira',
 'pax3-foxo1a',
 'tbx5',
 'nr3c1',
 'snai2',
 'ss18',
 'prmt5',
 'ubn1',
 'rb1',
 'six2',
 'klf11',
 'ercc6',
 'sumo1',
 'esco2']
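The rank combination used here can be followed on a toy example. For MYOD1, the mean of its two ranks is (18 + 1) / 2 = 9.5, and this mean is then ranked again across all factors. The sketch below uses made-up factors `a` to `d` to show how ties in the mean rank behave with the default pandas ranking method compared with `method="min"`.

```python
import pandas as pd

# Toy rankings for four made-up factors in the two analyses.
gep = pd.DataFrame({"factors": ["a", "b", "c", "d"], "rank": [1, 2, 3, 4]})
acc = pd.DataFrame({"factors": ["a", "b", "c", "d"], "rank": [3, 2, 1, 4]})

merged = gep.merge(acc, on="factors", how="inner", suffixes=("_gep", "_chrom_acc"))
mean_rank = (merged["rank_gep"] + merged["rank_chrom_acc"]) / 2

# Default method="average": tied factors share the mean of their positions.
merged["rank_default"] = mean_rank.rank().astype(int)

# method="min": tied factors all receive the best position in the tie.
merged["rank_min"] = mean_rank.rank(method="min").astype(int)

print(merged)
```

With the default method, a three-way tie for first place gives every tied factor rank 2, and fractional ranks such as 1.5 are truncated by `astype(int)`. This matters when a fixed cutoff such as rank 25 is applied afterwards.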