Example: inferring key regulators during cell state transitions from chromatin accessibility

To comprehensively infer key regulators during cell state transitions, it is crucial to integrate analyses of chromatin accessibility and the transcriptome. In this tutorial, we will demonstrate how to use ChromBERT to infer key regulators involved in a specific transdifferentiation process (fibroblast to myoblast) through chromatin accessibility analysis.

[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import chrombert
import pandas as pd
import numpy as np
from torchinfo import summary
import subprocess
import torch
import lightning.pytorch as pl
base_dir = os.path.expanduser("~/.cache/chrombert/data")  # default base directory for ChromBERT data

Preprocess dataset

In this section, we prepare raw chromatin accessibility data, including peak and BigWig files for the cell types involved in transdifferentiation. This tutorial will guide you through formatting the data for use with ChromBERT.

To identify key regulators, we fine-tune ChromBERT to predict the log2-transformed fold change in chromatin accessibility signal during transdifferentiation, so this fold change must be computed first. In addition, to compare regulator embeddings between upregulated and unchanged loci, we need to identify both sets of loci. The preprocessing steps are:

  • Collecting the peaks from both cell types, plus TSS flank regions (±10 kb around every TSS) as background.

  • Overlapping these regions with the 1 kb regions used in ChromBERT.

  • Extracting chromatin accessibility signals from BigWig files for each of these 1 kb regions.

  • Calculating log2-transformed chromatin accessibility changes.
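For reference, the regression label computed later in this tutorial (cell [12]) can be written as below. Here $\bar{x}_i^{c}$ is the mean BigWig signal over 1 kb bin $i$ in cell type $c$, and $\mu^{c}$ is the genome-wide mean signal reported in that BigWig file's summary; the pseudocount of 1 keeps bins with zero signal finite. This is a restatement of the code in cells [9] and [12], not an additional step.

```latex
y_i = \log_2\!\left(1 + \frac{\bar{x}_i^{\mathrm{myo}}}{\mu^{\mathrm{myo}}}\right)
    - \log_2\!\left(1 + \frac{\bar{x}_i^{\mathrm{fib}}}{\mu^{\mathrm{fib}}}\right)
```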

[2]:

chromatin_accessibility_dir = f'{base_dir}/demo/transdifferentiation/chrom_accessibility'
[3]:
# Download the ENCODE files needed for this tutorial (skipped if they already exist)
os.makedirs(chromatin_accessibility_dir, exist_ok=True)
encode_files = {
    'fibroblast_ENCFF184KAM_peak.bed': 'https://www.encodeproject.org/files/ENCFF184KAM/@@download/ENCFF184KAM.bed.gz',
    'fibroblast_ENCFF361BTT_signal.bigwig': 'https://www.encodeproject.org/files/ENCFF361BTT/@@download/ENCFF361BTT.bigWig',
    'myoblast_ENCFF647RNC_peak.bed': 'https://www.encodeproject.org/files/ENCFF647RNC/@@download/ENCFF647RNC.bed.gz',
    'myoblast_ENCFF149ERN_signal.bigwig': 'https://www.encodeproject.org/files/ENCFF149ERN/@@download/ENCFF149ERN.bigWig',
}
for file_name, url in encode_files.items():
    out_path = f'{chromatin_accessibility_dir}/{file_name}'
    if os.path.exists(out_path):
        continue
    if url.endswith('.gz'):
        # peak files are served gzip-compressed; decompress so the later cat/sort/bedtools steps see plain BED
        cmd = f'wget -qO- {url} | gunzip -c > {out_path}'
    else:
        cmd = f'wget {url} -O {out_path}'
    subprocess.run(cmd, shell=True, check=True)

Prepare merged peaks

Here we merge the peaks for the cell types involved in transdifferentiation to generate a comprehensive region list.
Then, we align the genomic coordinates of these regions with ChromBERT’s 1 kb bins.
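Conceptually, aligning a region to 1 kb bins means listing every fixed-size bin that the region overlaps. The sketch below illustrates this with a hypothetical helper, `intervals_to_bins`, assuming BED-style half-open coordinates and bins that start at multiples of 1,000 bp. It is not ChromBERT's `process` function, which additionally matches bins against ChromBERT's own region table and attaches `build_region_index`.

```python
import pandas as pd

BIN_SIZE = 1000  # assumption: 1 kb bins starting at multiples of 1,000 bp


def intervals_to_bins(intervals: pd.DataFrame, bin_size: int = BIN_SIZE) -> pd.DataFrame:
    """Expand BED-style half-open intervals into every fixed-size bin they overlap."""
    rows = []
    for chrom, start, end in intervals[["chrom", "start", "end"]].itertuples(index=False):
        first = start // bin_size * bin_size     # bin containing the first base
        last = (end - 1) // bin_size * bin_size  # bin containing the last base (end is exclusive)
        for bin_start in range(first, last + 1, bin_size):
            rows.append((chrom, bin_start, bin_start + bin_size))
    # overlapping peaks can hit the same bin, so keep each bin once
    return (pd.DataFrame(rows, columns=["chrom", "start", "end"])
            .drop_duplicates()
            .reset_index(drop=True))


peaks = pd.DataFrame({"chrom": ["chr1", "chr1"],
                      "start": [180500, 181200],
                      "end": [181300, 182100]})
bins = intervals_to_bins(peaks)
print(bins)
```

Both toy peaks touch the bin starting at 181,000, which appears only once in the result.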
[4]:
cmd = f"cat {chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed {chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed > {chromatin_accessibility_dir}/tmp_peak.bed"
subprocess.run(cmd, shell=True)
cmd = f"sort -k1,1 -k2,2n {chromatin_accessibility_dir}/tmp_peak.bed > {chromatin_accessibility_dir}/tmp_peak_sorted.bed"
subprocess.run(cmd, shell=True)
cmd = f"bedtools merge -i {chromatin_accessibility_dir}/tmp_peak_sorted.bed > {chromatin_accessibility_dir}/total_peak.bed"
subprocess.run(cmd, shell=True)
[4]:
CompletedProcess(args='bedtools merge -i /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/tmp_peak_sorted.bed > /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/total_peak.bed', returncode=0)
[5]:
from chrombert.scripts.chrombert_make_dataset import get_regions,process
chrom_regions = get_regions(base_dir,genome='hg38',high_resolution=False) # 1kb
total_peak_process = process(f'{chromatin_accessibility_dir}/total_peak.bed',chrom_regions,mode='region')[['chrom','start','end','build_region_index']]
len(total_peak_process),total_peak_process.head()
[5]:
(396195,
   chrom   start     end  build_region_index
 0  chr1  180000  181000                  38
 1  chr1  181000  182000                  39
 2  chr1  182000  183000                  40
 3  chr1  191000  192000                  46
 4  chr1  268000  269000                  54)

Generate background regions

In this tutorial, we use TSS flank regions (±10 kb around each transcription start site) as background samples; together with the peaks, they form the region set used for fine-tuning and for identifying key regulators.
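A minimal sketch of building the ±10 kb windows, assuming a table with `chrom` and `tss` columns like the expression file loaded below. The helper name `tss_flanks` is hypothetical. It clips window starts at 0 so that genes near a chromosome start do not produce negative coordinates; it does not clip ends to chromosome length, which would require a chromosome-sizes table.

```python
import pandas as pd


def tss_flanks(genes: pd.DataFrame, flank: int = 10_000) -> pd.DataFrame:
    """Return BED-style windows of +/- `flank` bp around each TSS, clipped at position 0."""
    return pd.DataFrame({
        "chrom": genes["chrom"],
        "start": (genes["tss"] - flank).clip(lower=0),  # avoid negative coordinates
        "end": genes["tss"] + flank,
    })


genes = pd.DataFrame({"chrom": ["chr19", "chr1"], "tss": [58353492, 4000]})
windows = tss_flanks(genes)
print(windows)
```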

[6]:
gep_df = pd.read_csv(f'{chromatin_accessibility_dir}/../transcriptome/fibroblast_to_myoblast_expression_changes.csv')
# +/- 10 kb around each TSS; clip at 0 so genes near chromosome starts do not yield negative coordinates
gep_df_tss_10kb = pd.DataFrame({'chrom': gep_df['chrom'],
                                'start': (gep_df['tss'] - 10000).clip(lower=0),
                                'end': gep_df['tss'] + 10000})
gep_df_tss_10kb.to_csv(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed', sep='\t', index=False, header=False)
gep_df_tss_10kb.head()
[6]:
chrom start end
0 chr19 58343492 58363492
1 chr19 58337718 58357718
2 chr10 50875675 50895675
3 chr12 9106229 9126229
4 chr12 9055163 9075163
[7]:
gep_df_tss_10kb_process = process(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed',chrom_regions,mode='region').drop_duplicates(subset='build_region_index')[['chrom','start','end','build_region_index']]
len(gep_df_tss_10kb_process),gep_df_tss_10kb_process.head()
[7]:
(295682,
   chrom   start     end  build_region_index
 0  chr1  815000  816000                 126
 1  chr1  816000  817000                 127
 2  chr1  817000  818000                 128
 3  chr1  818000  819000                 129
 4  chr1  819000  820000                 130)

Collect total regions and further process

Here we concatenate the peak regions and the background regions into a single deduplicated region set. Then, for each region, we extract the chromatin accessibility signal in both cell types and compute the log2-transformed fold change.
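The normalization and fold change used below can be illustrated on toy numbers. In this sketch (the helper name and the signal values are hypothetical), each track is divided by its genome-wide mean so that tracks with different overall signal levels become comparable, and a pseudocount of 1 is added before taking log2 so that bins with zero signal give a finite value.

```python
import numpy as np


def normalized_log2_fold_change(sig_from, sig_to, mean_from, mean_to):
    """log2(1 + to/mean_to) - log2(1 + from/mean_from) for per-bin mean signals."""
    norm_from = np.asarray(sig_from, dtype=float) / mean_from  # scale by the 'from' track's genome-wide mean
    norm_to = np.asarray(sig_to, dtype=float) / mean_to        # scale by the 'to' track's genome-wide mean
    return np.log2(1 + norm_to) - np.log2(1 + norm_from)


fib = [0.0, 2.0, 6.0]  # hypothetical per-bin mean signal in fibroblasts
myo = [0.0, 6.0, 2.0]  # hypothetical per-bin mean signal in myoblasts
fc = normalized_log2_fold_change(fib, myo, mean_from=2.0, mean_to=2.0)
print(fc)
```

A bin with no signal in either cell type gets a fold change of 0, and swapping the two signals flips the sign.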

[8]:
total_region_processed = pd.concat([total_peak_process,gep_df_tss_10kb_process],axis=0).drop_duplicates().reset_index(drop=True)
total_region_processed.to_csv(f'{chromatin_accessibility_dir}/total_region_processed.csv',index=False)
len(total_region_processed),total_region_processed.head()
[8]:
(614861,
   chrom   start     end  build_region_index
 0  chr1  180000  181000                  38
 1  chr1  181000  182000                  39
 2  chr1  182000  183000                  40
 3  chr1  191000  192000                  46
 4  chr1  268000  269000                  54)
[9]:
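# Extract the chromatin accessibility signals

import bbi  # pip install pybbi

def bw_getSignal_bins(bw, regions: pd.DataFrame, name: str) -> pd.DataFrame:
    """Mean BigWig signal per region, normalized by the genome-wide mean of the track."""
    with bbi.open(str(bw)) as bwf:
        # one summary value (mean) per region; regions without data are filled with 0
        mtx = bwf.stackup(regions["chrom"], regions["start"], regions["end"], bins=1, missing=0)
        # divide by the track's genome-wide mean so the two cell types are comparable
        mean = bwf.info["summary"]["mean"]
        mtx = mtx / mean
    return pd.DataFrame(data=mtx, columns=[f'{name}_signal'])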

fibroblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig',regions=total_region_processed,name='fibroblast')
myoblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig',regions=total_region_processed,name='myoblast')


[12]:
# Prepare the log2-transformed chromatin accessibility signal fold changes data

total_region_signal_processed = pd.concat([total_region_processed, fibroblast_signal, myoblast_signal], axis=1)
# label = log2(1 + myoblast) - log2(1 + fibroblast); the pseudocount keeps zero-signal bins finite
total_region_signal_processed['fold_change'] = np.log2(1 + total_region_signal_processed['myoblast_signal']) - np.log2(1 + total_region_signal_processed['fibroblast_signal'])
chrom_accessibility_df = (
    total_region_signal_processed[
        ['chrom','start','end','build_region_index','fold_change','fibroblast_signal','myoblast_signal']
    ].rename(columns={'fold_change':'label'})
)
chrom_accessibility_df.to_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv',index=False)
chrom_accessibility_df.head()
[12]:
chrom start end build_region_index label fibroblast_signal myoblast_signal
0 chr1 180000 181000 38 0.091821 0.066471 0.136553
1 chr1 181000 182000 39 0.001996 2.543848 2.548754
2 chr1 182000 183000 40 0.142237 0.102401 0.216626
3 chr1 191000 192000 46 -0.678009 1.386900 0.491877
4 chr1 268000 269000 54 -0.329329 1.324022 0.849704

Prepare fine-tuning data

The dataset is split into training, test, and validation sets in an 8:1:1 ratio.
To keep this demo fast, each split is then randomly downsampled to 80 training, 20 test, and 20 validation regions.
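The split in the next cell relies on the DataFrame having a unique index: the test and validation sets are taken from the rows left after dropping the training indices. A minimal sketch of the same 8:1:1 pattern on a toy table (the helper name `split_811` is hypothetical) lets you check the sizes and that the three sets are disjoint.

```python
import pandas as pd


def split_811(df: pd.DataFrame, seed: int = 55):
    """Randomly split rows 80/10/10 into disjoint train/test/validation sets (requires a unique index)."""
    train = df.sample(frac=0.8, random_state=seed)
    rest = df.drop(train.index)                     # remaining 20%
    test = rest.sample(frac=0.5, random_state=seed)
    valid = rest.drop(test.index)                   # whatever is left
    return train, test, valid


toy_df = pd.DataFrame({"label": range(1000)})
train, test, valid = split_811(toy_df)
print(len(train), len(test), len(valid))
```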
[13]:
train_data = chrom_accessibility_df.sample(frac=0.8,random_state=55)
test_data = chrom_accessibility_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = chrom_accessibility_df.drop(train_data.index).drop(test_data.index)

train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=20,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)


train_data_sample.to_csv(f'{chromatin_accessibility_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{chromatin_accessibility_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{chromatin_accessibility_dir}/valid_sample.csv',index=False)
train_data_sample.head()
[13]:
chrom start end build_region_index label fibroblast_signal myoblast_signal
58450 chr11 16348000 16349000 300336 -0.422661 1.251264 0.679549
81335 chr12 49808000 49809000 422035 1.398294 0.950350 4.140920
346805 chr7 148855000 148856000 1856383 -0.905614 2.071367 0.639512
35488 chr1 246417000 246418000 182254 -2.684815 27.481988 3.429557
247095 chr3 132783000 132784000 1302011 3.011465 3.356764 34.132198

Identification of upregulated and unchanged genomic regions

Genomic regions whose log2 fold change (the `label` column) is greater than 2 are classified as upregulated, i.e. more accessible in myoblasts. To define unchanged regions, we first exclude regions with zero signal in either cell type and then select the 40,000 remaining regions with the smallest absolute fold change.
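The selection rule can be sketched as a single function on toy data (the function name and the values are hypothetical). Using `.assign` creates a new DataFrame, which avoids modifying a filtered slice in place. Note how the region with zero signal in both cell types is excluded from the unchanged set even though its fold change is exactly 0.

```python
import pandas as pd


def select_up_and_unchanged(df: pd.DataFrame, up_threshold: float = 2.0, n_unchanged: int = 40_000):
    """Split regions into 'upregulated' and 'unchanged' sets from log2 fold-change labels."""
    up = df[df["label"] > up_threshold]
    # only bins with signal in both cell types are eligible as 'unchanged'
    covered = df[(df["fibroblast_signal"] > 0) & (df["myoblast_signal"] > 0)]
    unchanged = (covered.assign(label_abs=covered["label"].abs())
                 .sort_values("label_abs")
                 .head(n_unchanged))
    return up, unchanged


toy = pd.DataFrame({  # toy values for illustration only
    "label": [3.1, 0.05, -0.02, 0.0, 2.5, -1.5],
    "fibroblast_signal": [0.2, 1.0, 0.8, 0.0, 0.1, 3.0],
    "myoblast_signal": [2.0, 1.1, 0.7, 0.0, 1.5, 0.9],
})
up, unchanged = select_up_and_unchanged(toy, n_unchanged=2)
print(up.index.tolist(), unchanged.index.tolist())
```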

[14]:
chrom_accessibility_df = pd.read_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv')
up_data = chrom_accessibility_df[chrom_accessibility_df['label']>2]

covered_region = chrom_accessibility_df[(chrom_accessibility_df['fibroblast_signal']>0) & (chrom_accessibility_df['myoblast_signal']>0)].copy()  # .copy() avoids SettingWithCopyWarning
covered_region['label_abs'] = np.abs(covered_region['label'])
nochange_data = covered_region.sort_values('label_abs').reset_index(drop=True).iloc[0:40000]

up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)

up_data_sample.to_csv(f'{chromatin_accessibility_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{chromatin_accessibility_dir}/nochange_data_sample.csv',index=False)

Fine-tune

This section shows how to fine-tune ChromBERT to predict genome-wide changes in chromatin accessibility.
The procedure follows the general fine-tuning workflow:
  • Dataset configuration: Use the general preset dataset configuration.

  • Model instantiation: Use the general preset model configuration.

[15]:
# Use the `general` preset dataset config
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = None, batch_size = 4, num_workers = 4)
dataset_config
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[15]:
DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'GeneralDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 4, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'prompt_regulator_cache_pin_memory': False, 'prompt_regulator_cache_limit': 3, 'fasta_file': None, 'flank_window': 0})
[16]:
# We use the `LitChromBERTFTDataModule` to load the data and facilitate the fine-tuning process.
data_module = chrombert.LitChromBERTFTDataModule(
    config = dataset_config,
    train_params = {'supervised_file': f'{chromatin_accessibility_dir}/train_sample.csv'},
    val_params = {'supervised_file':f'{chromatin_accessibility_dir}/valid_sample.csv'},
    test_params = {'supervised_file':f'{chromatin_accessibility_dir}/test_sample.csv'}
)
data_module.setup()
[17]:
# we use `general` preset model config
model_config = chrombert.get_preset_model_config("general")
model_config
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
[17]:
ChromBERTFTConfig:
{
    "genome": "hg38",
    "task": "general",
    "dim_output": 1,
    "mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
    "dropout": 0.1,
    "pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    "finetune_ckpt": null,
    "ignore": false,
    "ignore_index": [
        null,
        null
    ],
    "gep_flank_window": 4,
    "gep_parallel_embedding": false,
    "gep_gradient_checkpoint": false,
    "gep_zero_inflation": false,
    "prompt_kind": "cistrome",
    "prompt_dim_external": 512,
    "dnabert2_ckpt": null
}
[18]:
model = model_config.init_model()
model.freeze_pretrain(2)  # keep only the last 2 transformer blocks trainable; the embeddings and first 6 blocks are frozen
summary(model)
use organisim hg38; max sequence length is 6391
[18]:
================================================================================
Layer (type:depth-idx)                                  Param #
================================================================================
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         (7,680)
│    │    └─PositionalEmbedding: 3-2                    (4,909,056)
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-5                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-6                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-7                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-8                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-9                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock: 3-11               6,497,280
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               --
│    │    └─Linear: 3-12                                1,099,776
│    │    └─Linear: 3-13                                1,049,600
│    │    └─LayerNorm: 3-14                             2,048
│    │    └─Linear: 3-15                                1,099,776
│    │    └─Dropout: 3-16                               --
│    └─ResidualBlock: 2-7                               --
│    │    └─Linear: 3-17                                787,200
│    │    └─Linear: 3-18                                590,592
│    │    └─LayerNorm: 3-19                             1,536
│    │    └─Linear: 3-20                                787,200
│    │    └─Dropout: 3-21                               --
│    └─ResidualBlock: 2-8                               --
│    │    └─Linear: 3-22                                196,864
│    │    └─Linear: 3-23                                65,792
│    │    └─LayerNorm: 3-24                             512
│    │    └─Linear: 3-25                                196,864
│    │    └─Dropout: 3-26                               --
│    └─Linear: 2-9                                      257
================================================================================
Total params: 62,773,762
Trainable params: 18,873,346
Non-trainable params: 43,900,416
================================================================================
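In the summary above, torchinfo puts the parameter counts of non-trainable layers in parentheses: the embeddings and the first 6 of the 8 transformer blocks are frozen, so `freeze_pretrain(2)` appears to leave only the last 2 blocks (plus the task head) trainable. The generic PyTorch pattern for this kind of partial freezing is sketched below on a hypothetical toy model; it is not ChromBERT's implementation.

```python
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a BERT-style encoder: an embedding plus stacked blocks."""

    def __init__(self, n_blocks: int = 8, dim: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(10, dim)
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))


def freeze_all_but_last(encoder: ToyEncoder, n_trainable: int) -> None:
    """Freeze the embedding and every block except the last `n_trainable` ones."""
    for p in encoder.embedding.parameters():
        p.requires_grad = False
    n_frozen = len(encoder.blocks) - n_trainable
    for block in encoder.blocks[:n_frozen]:
        for p in block.parameters():
            p.requires_grad = False


enc = ToyEncoder()
freeze_all_but_last(enc, n_trainable=2)
trainable = sum(p.numel() for p in enc.parameters() if p.requires_grad)
print(trainable)  # 2 blocks x (16*16 weights + 16 biases)
```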

The model is fine-tuned using PyTorch Lightning with a straightforward parameter configuration. To save time, fine-tuning is performed on a smaller dataset.

Note: Because fine-tuning is stochastic, results may vary between runs. For better performance, increase the number of epochs and train on more of the data.
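The training configuration below sets `loss='rmse'`. Assuming this denotes the standard root mean squared error (ChromBERT's internal implementation is not shown here and may differ in details), the quantity being minimized looks like the following sketch; the function name `rmse_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def rmse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Root mean squared error: sqrt(mean((pred - target)^2))."""
    return torch.sqrt(F.mse_loss(pred, target))


pred = torch.tensor([0.5, 1.0, 2.0])
target = torch.tensor([0.5, 2.0, 0.0])
print(rmse_loss(pred, target))  # sqrt((0 + 1 + 4) / 3)
```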

[19]:
train_config = chrombert.finetune.TrainConfig(kind='regression',
                                              loss='rmse',
                                              max_epochs=2,
                                              accumulate_grad_batches=2,
                                              val_check_interval=2,
                                              limit_val_batches=10,
                                              tag='chrom_accessibility')
train_config
train_module = train_config.init_pl_module(model)
[20]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs=train_config.max_epochs,
    log_every_n_steps=1,
    limit_val_batches = train_config.limit_val_batches,
    val_check_interval = train_config.val_check_interval,
    accelerator="gpu",
    accumulate_grad_batches= train_config.accumulate_grad_batches,
    fast_dev_run=False,
    precision="bf16-mixed",
    strategy="auto",
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        callback_ckpt,
    ],
    logger=pl.loggers.TensorBoardLogger("lightning_logs", name='chrom_accessibility'),
    )
trainer.fit(train_module,data_module)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/chrom_accessibility
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type             | Params | Mode
---------------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M | train
---------------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.095   Total estimated model params size (MB)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
`Trainer.fit` stopped: `max_epochs=2` reached.

Evaluate the fine-tuned model

ChromBERT has been fine-tuned. You could evaluate the model directly through `train_module.model`; however, because of the flash-attention settings, calling `model.eval()` does not disable attention dropout, so the output retains some randomness.

To obtain consistent results, it is therefore recommended to save the fine-tuned checkpoint and reload it into a fresh model built from the original configuration with `dropout=0`.
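Why `model.eval()` may not be enough can be illustrated with PyTorch's fused attention function (available in PyTorch 2.0 and later). The layer below is hypothetical and not ChromBERT's code: a fused kernel applies whatever `dropout_p` it is given, so unless the layer itself checks `self.training`, dropout stays active after `.eval()`. Setting the dropout rate itself to zero, as the reload step does with `dropout=0`, removes the randomness regardless of how the layer is written.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    """Hypothetical attention layer that forwards a dropout rate to the fused kernel."""

    def __init__(self, dropout_p: float, respect_eval: bool):
        super().__init__()
        self.dropout_p = dropout_p
        self.respect_eval = respect_eval

    def forward(self, q, k, v):
        # The fused kernel only sees the number it is given; it does not check self.training.
        p = self.dropout_p if (self.training or not self.respect_eval) else 0.0
        return F.scaled_dot_product_attention(q, k, v, dropout_p=p)


torch.manual_seed(0)
q = k = v = torch.randn(1, 2, 8, 4)  # (batch, heads, seq_len, head_dim)
leaky = ToyAttention(0.5, respect_eval=False).eval()
fixed = ToyAttention(0.5, respect_eval=True).eval()
print(torch.equal(leaky(q, k, v), leaky(q, k, v)))  # dropout still applied despite eval()
print(torch.equal(fixed(q, k, v), fixed(q, k, v)))  # dropout disabled in eval()
```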

To save time, this tutorial evaluates the model on the downsampled test data only.

We first load the checkpoint from the run above and evaluate it on that downsampled test set.

[21]:
import glob
# pick the most recently written checkpoint (glob order is arbitrary if several runs exist)
chrom_accessibility_ft_ckpt = os.path.abspath(max(glob.glob('./lightning_logs/chrom_accessibility/version_*/checkpoints/*.ckpt'), key=os.path.getmtime))
chrom_accessibility_ft_ckpt
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/examples/tutorials/lightning_logs/chrom_accessibility/version_0/checkpoints/epoch=1-step=12.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[22]:
from tqdm import tqdm
import torchmetrics as tm
dl = data_module.test_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)
100%|██████████| 5/5 [00:01<00:00,  2.74it/s]
{'pearsonr': tensor(0.1227), 'spearmanr': tensor(0.2571), 'mse': tensor(0.7629), 'mae': tensor(0.5203), 'r2': tensor(-0.0613)}


The model fine-tuned on this small subset performs poorly. For comparison, we provide a checkpoint fine-tuned on the full dataset and evaluate it below.

[28]:
chrom_accessibility_ft_ckpt = f'{chromatin_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[29]:
from tqdm import tqdm
import torchmetrics as tm
dl = data_module.test_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)
100%|██████████| 5/5 [00:01<00:00,  2.80it/s]
{'pearsonr': tensor(0.9576), 'spearmanr': tensor(0.8451), 'mse': tensor(0.0629), 'mae': tensor(0.2033), 'r2': tensor(0.9125)}


Infer key regulators

Using the fine-tuned model, we compare each regulator's embeddings between upregulated and unchanged genomic regions. We hypothesize that regulators whose embeddings are least similar between the two sets undergo the largest functional shift, which makes them candidate key regulators of the cell state transition.

To save time, we perform the analysis on a downsampled set of upregulated and unchanged genomic regions.

Note: Only the selected regulator factors are considered in this analysis; histone modifications and chromatin accessibility are excluded.
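The ranking idea can be sketched with NumPy on synthetic embeddings. This is an assumed formulation (cosine similarity between each regulator's mean embedding over upregulated regions and over unchanged regions), not necessarily the exact computation ChromBERT performs; the function name, array shapes, and regulator names are hypothetical. Regulators with the lowest similarity are ranked first.

```python
import numpy as np


def rank_regulators_by_shift(emb_up, emb_unchanged, names):
    """Rank regulators by cosine similarity of their mean embeddings in two region sets.

    emb_up, emb_unchanged: arrays of shape (n_regions, n_regulators, dim).
    Lower similarity means a larger shift, so those regulators come first.
    """
    mean_up = emb_up.mean(axis=0)        # (n_regulators, dim)
    mean_un = emb_unchanged.mean(axis=0)
    cos = (mean_up * mean_un).sum(axis=1) / (
        np.linalg.norm(mean_up, axis=1) * np.linalg.norm(mean_un, axis=1)
    )
    order = np.argsort(cos)              # ascending: least similar first
    return [(names[i], float(cos[i])) for i in order]


rng = np.random.default_rng(0)
base = rng.normal(size=(3, 8))                       # shared per-regulator embedding
emb_unchanged = base + 0.01 * rng.normal(size=(50, 3, 8))
emb_up = emb_unchanged.copy()
emb_up[:, 1, :] = -emb_up[:, 1, :]                  # regulator 1 'flips' at upregulated regions
ranking = rank_regulators_by_shift(emb_up, emb_unchanged, ["REG_A", "REG_B", "REG_C"])
print(ranking)
```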

[30]:
# Load the fine-tuned model
model_tuned = chrombert.get_preset_model_config(
    "general",
    dropout = 0,
    finetune_ckpt = f'{chromatin_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt').init_model() # use an absolute path so the checkpoint is not resolved relative to the preset data directory
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[31]:
# Get the embedding manager
model_emb = model_tuned.get_embedding_manager().cuda()

[32]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chromatin_accessibility_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4)
up_chrom_acc_embs=[]
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_chrom_acc_embs.append(emb)
up_chrom_acc_embs = torch.cat(up_chrom_acc_embs,dim=0)
up_chrom_acc_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:05<00:00,  1.34s/it]
[32]:
torch.Size([100, 1073, 768])
[33]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chromatin_accessibility_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4)
nochange_chrom_acc_embs=[]
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_chrom_acc_embs.append(emb)
nochange_chrom_acc_embs = torch.cat(nochange_chrom_acc_embs,dim=0)
nochange_chrom_acc_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:05<00:00,  1.39s/it]
[33]:
torch.Size([100, 1073, 768])
[34]:
with open(os.path.join(base_dir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]


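# Keep only regulator factors (np.isin replaces the deprecated np.in1d)
indices = np.isin(model_emb.list_regulator, factors)
names = np.array(model_emb.list_regulator)[indices]
# Average over regions, then keep factor rows: (n_regions, n_regulators, 768) -> (n_factors, 768)
up_chrom_acc_embs = up_chrom_acc_embs.mean(dim=0)[indices]
nochange_chrom_acc_embs = nochange_chrom_acc_embs.mean(dim=0)[indices]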
[35]:
up_chrom_acc_embs.shape, nochange_chrom_acc_embs.shape
[35]:
(torch.Size([991, 768]), torch.Size([991, 768]))

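For each regulator, we calculate the cosine similarity between its mean embedding over upregulated regions and its mean embedding over unchanged regions, then rank regulators from lowest to highest similarity. Regulators at the top of the ranking are candidate key regulators.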
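The cell below loops over regulators with scikit-learn's `cosine_similarity`. As a sketch, the same row-wise similarities can be computed in one vectorized call with `torch.nn.functional.cosine_similarity`; the random tensors stand in for the `(n_factors, 768)` mean embeddings and are not real data.

```python
import torch
import torch.nn.functional as F

def per_regulator_cosine(up_embs: torch.Tensor, nochange_embs: torch.Tensor) -> torch.Tensor:
    """Row-wise cosine similarity between two (n_regulators, dim) tensors."""
    return F.cosine_similarity(up_embs, nochange_embs, dim=1)

torch.manual_seed(0)
up = torch.randn(5, 768)        # stand-in for up_chrom_acc_embs
nochange = torch.randn(5, 768)  # stand-in for nochange_chrom_acc_embs
sims = per_regulator_cosine(up, nochange)

# Cross-check against an explicit per-row computation, as in the loop below
manual = torch.stack([torch.dot(u, v) / (u.norm() * v.norm()) for u, v in zip(up, nochange)])
order = torch.argsort(sims)  # ascending: lowest similarity (largest shift) first
print(torch.allclose(sims, manual, atol=1e-6), order)
```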

[36]:
from sklearn.metrics.pairwise import cosine_similarity
chrom_acc_similarity = [cosine_similarity(up_chrom_acc_embs[i].reshape(1, -1), nochange_chrom_acc_embs[i].reshape(1, -1))[0, 0] for i in range(up_chrom_acc_embs.shape[0])]
chrom_acc_similarity_df = pd.DataFrame({'factors':names,'similarity':chrom_acc_similarity}).sort_values(by='similarity').reset_index(drop=True)
chrom_acc_similarity_df['rank']=chrom_acc_similarity_df.index + 1
chrom_acc_similarity_df.to_csv(f'{chromatin_accessibility_dir}/chromatin_accessibility_similarity_df.csv',index=False)
chrom_acc_similarity_df
[36]:
factors similarity rank
0 myod1 -0.064511 1
1 myf5 0.080093 2
2 myog 0.101807 3
3 pax3-foxo1a 0.107334 4
4 tead1 0.214081 5
... ... ... ...
986 znf250 0.984589 987
987 hoxa1 0.984709 988
988 zbtb10 0.985032 989
989 znf706 0.985040 990
990 hoxa10 0.986486 991

991 rows × 3 columns

[37]:
chrom_acc_similarity_df[chrom_acc_similarity_df['factors']=='myod1']
[37]:
factors similarity rank
0 myod1 -0.064511 1

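Taking the 25 regulators with the lowest similarity as key regulators from the chromatin accessibility analysis, the list includes the well-known myogenic regulator MYOD1, which ranks first among all 991 regulators.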

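Finally, we combine both analyses: each regulator's ranks from the chromatin accessibility and transcriptome analyses are averaged, and the 25 regulators with the best average rank are reported as the final key regulators. This step requires `gep_similarity_df.csv`, produced by the transcriptome tutorial.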
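The rank-averaging logic can be sketched on two tiny made-up tables. This is an illustration only; `combine_ranks` is hypothetical, and it uses `rank(method="min")` so tied regulators share the best integer rank, whereas the notebook cell uses pandas' default average method truncated with `astype(int)`, which can give slightly different values when ties occur.

```python
import pandas as pd

def combine_ranks(gep_df: pd.DataFrame, acc_df: pd.DataFrame, top_n: int) -> list:
    """Hypothetical helper: top_n factors by the mean of two per-analysis ranks."""
    # Inner merge keeps factors present in both analyses; suffixes split the 'rank' columns
    merged = pd.merge(gep_df, acc_df, on="factors", how="inner", suffixes=("_gep", "_chrom_acc"))
    mean_rank = (merged["rank_gep"] + merged["rank_chrom_acc"]) / 2
    merged["average_rank"] = mean_rank.rank(method="min").astype(int)
    merged = merged.sort_values("average_rank")
    return merged[merged["average_rank"] <= top_n]["factors"].tolist()

# Made-up rankings: 'd' and 'e' appear in only one table and are dropped by the merge
gep = pd.DataFrame({"factors": ["a", "b", "c", "d"], "rank": [1, 2, 3, 4]})
acc = pd.DataFrame({"factors": ["c", "a", "b", "e"], "rank": [1, 2, 3, 4]})
print(combine_ranks(gep, acc, top_n=2))
```

Here the mean ranks are 1.5 for `a`, 2.0 for `c`, and 2.5 for `b`, so the top two are `a` and `c`.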

[38]:
gep_similarity_path = f'{chromatin_accessibility_dir}/../transcriptome/gep_similarity_df.csv'
if not os.path.exists(gep_similarity_path):
    raise FileNotFoundError("gep_similarity_df.csv not found. Please first run the tutorial for key regulators inference during cell state transition: transcriptome")
gep_similarity_df = pd.read_csv(gep_similarity_path)
# Keep factors present in both analyses; suffixes distinguish the shared column names
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
# Average the two ranks, then re-rank the averages (ties get the mean rank, truncated to int)
average_rank_df['average_rank'] = ((average_rank_df['rank_gep'] + average_rank_df['rank_chrom_acc']) / 2).rank().astype(int)
average_rank_df = average_rank_df.sort_values(by='average_rank')

average_rank_df[average_rank_df['factors'] == 'myod1']
[38]:
factors similarity_gep rank_gep similarity_chrom_acc rank_chrom_acc average_rank
17 myod1 0.983192 18 -0.064511 1 2
[39]:
final_identified_factors = average_rank_df[average_rank_df['average_rank'] <= 25]['factors'].tolist()
final_identified_factors
[39]:
['myf5',
 'myod1',
 'neurog2',
 'yap1',
 'nr3c2',
 'tead1',
 'tead',
 'dux4',
 'pgbd3',
 'chd4',
 'myog',
 'hira',
 'pax3-foxo1a',
 'tbx5',
 'nr3c1',
 'snai2',
 'ss18',
 'prmt5',
 'ubn1',
 'rb1',
 'six2',
 'klf11',
 'ercc6',
 'sumo1',
 'esco2']