Example: inferring key regulators during a cell state transition from chromatin accessibility
To comprehensively infer key regulators during cell state transitions, it is crucial to integrate analyses of chromatin accessibility and the transcriptome. In this tutorial, we will demonstrate how to use ChromBERT to infer key regulators involved in a specific transdifferentiation process (fibroblast to myoblast) through chromatin accessibility analysis.
[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import chrombert
import pandas as pd
import numpy as np
from torchinfo import summary
import subprocess
import torch
import lightning.pytorch as pl
base_dir = os.path.expanduser("~/.cache/chrombert/data") # default base directory where ChromBERT data are stored
Preprocess dataset
In this section, we prepare raw chromatin accessibility data, including peak and BigWig files for the cell types involved in transdifferentiation. This tutorial will guide you through formatting the data for use with ChromBERT.
To identify key regulators, we fine-tune ChromBERT to predict the log2-transformed fold change in chromatin accessibility signal during transdifferentiation, so these fold changes must be computed first. In addition, to compare regulator embeddings between upregulated and unchanged loci, we need to identify both sets of loci. The preprocessing steps are:
Collect the peaks from both cell types involved in transdifferentiation, plus TSS flank regions (10 kb on either side of every TSS) as background.
Overlap these regions with the 1 kb regions used by ChromBERT.
Extract the chromatin accessibility signal from the BigWig files for each of these 1 kb regions.
Calculate the log2-transformed chromatin accessibility fold changes.
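The overlap step above can be pictured with a small plain-Python sketch. It assumes that ChromBERT's 1 kb regions are aligned to multiples of 1,000 bp, as the region coordinates printed in this tutorial suggest (e.g. 180000-181000), and that intervals are BED-style half-open. `overlapping_bins` is a hypothetical helper for illustration only; it is not the `process` function from `chrombert.scripts.chrombert_make_dataset`, whose output additionally carries each bin's `build_region_index`.

```python
def overlapping_bins(start, end, bin_size=1000):
    """Return (bin_start, bin_end) for every fixed-size bin that overlaps the
    half-open interval [start, end). Hypothetical helper for illustration."""
    if end <= start:
        return []
    first = (start // bin_size) * bin_size
    # the last overlapping bin is the one containing position end - 1
    last = ((end - 1) // bin_size) * bin_size
    return [(b, b + bin_size) for b in range(first, last + 1, bin_size)]

# a 1.7 kb peak spanning three 1 kb bins
print(overlapping_bins(180500, 182200))
```

A peak that only touches a bin boundary (for example one ending exactly at 181000) does not reach into the next bin under the half-open convention assumed here.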
[2]:
chromatin_accessibility_dir = f'{base_dir}/demo/transdifferentiation/chrom_accessibility'
[3]:
# Download the data needed for this tutorial (skipped if the files already exist)
os.makedirs(chromatin_accessibility_dir, exist_ok=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed'):
    # ENCODE peak files are gzip-compressed; decompress them while downloading
    cmd = f'wget -qO- https://www.encodeproject.org/files/ENCFF184KAM/@@download/ENCFF184KAM.bed.gz | gunzip -c > {chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed'
    subprocess.run(cmd, shell=True, check=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig'):
    cmd = f'wget https://www.encodeproject.org/files/ENCFF361BTT/@@download/ENCFF361BTT.bigWig -O {chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig'
    subprocess.run(cmd, shell=True, check=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed'):
    cmd = f'wget -qO- https://www.encodeproject.org/files/ENCFF647RNC/@@download/ENCFF647RNC.bed.gz | gunzip -c > {chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed'
    subprocess.run(cmd, shell=True, check=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig'):
    cmd = f'wget https://www.encodeproject.org/files/ENCFF149ERN/@@download/ENCFF149ERN.bigWig -O {chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig'
    subprocess.run(cmd, shell=True, check=True)
Prepare merged peaks
[4]:
cmd = f"cat {chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed {chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed > {chromatin_accessibility_dir}/tmp_peak.bed"
subprocess.run(cmd, shell=True)
cmd = f"sort -k1,1 -k2,2n {chromatin_accessibility_dir}/tmp_peak.bed > {chromatin_accessibility_dir}/tmp_peak_sorted.bed"
subprocess.run(cmd, shell=True)
cmd = f"bedtools merge -i {chromatin_accessibility_dir}/tmp_peak_sorted.bed > {chromatin_accessibility_dir}/total_peak.bed"
subprocess.run(cmd, shell=True)
[4]:
CompletedProcess(args='bedtools merge -i /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/tmp_peak_sorted.bed > /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/total_peak.bed', returncode=0)
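The `cat | sort | bedtools merge` pipeline above pools the peaks of both cell types and collapses overlapping ones into a single peak set. The plain-Python sketch below illustrates the idea, assuming the default `bedtools merge` behaviour in which overlapping and book-ended intervals are merged. `merge_intervals` is a hypothetical helper, not part of ChromBERT or bedtools.

```python
def merge_intervals(intervals):
    """Merge overlapping or book-ended (chrom, start, end) intervals, in the
    spirit of `sort -k1,1 -k2,2n | bedtools merge`. Hypothetical helper."""
    merged = []
    # sort by chromosome name, then start, like `sort -k1,1 -k2,2n`
    for chrom, start, end in sorted(intervals):
        if merged and merged[-1][0] == chrom and start <= merged[-1][2]:
            # overlaps or touches the previous interval: extend it
            last = merged[-1]
            merged[-1] = (chrom, last[1], max(last[2], end))
        else:
            merged.append((chrom, start, end))
    return merged

fibroblast = [("chr1", 100, 200), ("chr1", 500, 600)]
myoblast = [("chr1", 150, 300), ("chr1", 600, 650), ("chr2", 10, 20)]
print(merge_intervals(fibroblast + myoblast))
```

Merging before binning means a region that is a peak in either cell type is kept exactly once in the union.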
[5]:
from chrombert.scripts.chrombert_make_dataset import get_regions,process
chrom_regions = get_regions(base_dir,genome='hg38',high_resolution=False) # 1kb
total_peak_process = process(f'{chromatin_accessibility_dir}/total_peak.bed',chrom_regions,mode='region')[['chrom','start','end','build_region_index']]
len(total_peak_process),total_peak_process.head()
[5]:
(396195,
chrom start end build_region_index
0 chr1 180000 181000 38
1 chr1 181000 182000 39
2 chr1 182000 183000 40
3 chr1 191000 192000 46
4 chr1 268000 269000 54)
Generate background regions
In this tutorial, we use TSS flank regions (within 10 kb of each transcription start site) as background samples to support fine-tuning and the identification of key regulators.
[6]:
gep_df = pd.read_csv(f'{chromatin_accessibility_dir}/../transcriptome/fibroblast_to_myoblast_expression_changes.csv')
# TSS ± 10 kb windows; clip the start at 0 so windows near chromosome starts remain valid BED intervals
gep_df_tss_10kb = pd.DataFrame({'chrom':gep_df['chrom'],'start':(gep_df['tss']-10000).clip(lower=0),'end':gep_df['tss']+10000})
gep_df_tss_10kb.to_csv(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed',sep='\t',index=False,header=None)
gep_df_tss_10kb.head()
[6]:
| | chrom | start | end |
|---|---|---|---|
| 0 | chr19 | 58343492 | 58363492 |
| 1 | chr19 | 58337718 | 58357718 |
| 2 | chr10 | 50875675 | 50895675 |
| 3 | chr12 | 9106229 | 9126229 |
| 4 | chr12 | 9055163 | 9075163 |
[7]:
gep_df_tss_10kb_process = process(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed',chrom_regions,mode='region').drop_duplicates(subset='build_region_index')[['chrom','start','end','build_region_index']]
len(gep_df_tss_10kb_process),gep_df_tss_10kb_process.head()
[7]:
(295682,
chrom start end build_region_index
0 chr1 815000 816000 126
1 chr1 816000 817000 127
2 chr1 817000 818000 128
3 chr1 818000 819000 129
4 chr1 819000 820000 130)
Collect total regions and further process
Here we concatenate the merged peak regions and the background regions into a single region set and drop duplicate 1 kb regions. We then extract the chromatin accessibility signal of each region from the BigWig files and compute the log2-transformed fold change.
[8]:
total_region_processed = pd.concat([total_peak_process,gep_df_tss_10kb_process],axis=0).drop_duplicates().reset_index(drop=True)
total_region_processed.to_csv(f'{chromatin_accessibility_dir}/total_region_processed.csv',index=False)
len(total_region_processed),total_region_processed.head()
[8]:
(614861,
chrom start end build_region_index
0 chr1 180000 181000 38
1 chr1 181000 182000 39
2 chr1 182000 183000 40
3 chr1 191000 192000 46
4 chr1 268000 269000 54)
[9]:
# Extract the chromatin accessibility signals
import bbi # pip install pybbi
def bw_getSignal_bins(bw, regions: pd.DataFrame, name):
    """Mean BigWig signal over each region, normalized by the track's genome-wide mean."""
    regions = regions.copy()
    with bbi.open(str(bw)) as bwf:
        # one bin per region -> array of shape (n_regions, 1); missing data counted as 0
        mtx = bwf.stackup(regions["chrom"], regions["start"], regions["end"], bins=1, missing=0)
        mean = bwf.info["summary"]["mean"]
    mtx = mtx / mean
    df_signal = pd.DataFrame(data=mtx, columns=[f'{name}_signal'])
    return df_signal
fibroblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig',regions=total_region_processed,name='fibroblast')
myoblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig',regions=total_region_processed,name='myoblast')
[12]:
# Prepare the log2-transformed chromatin accessibility signal fold changes data
total_region_signal_processed = pd.concat([total_region_processed,fibroblast_signal,myoblast_signal],axis=1)
total_region_signal_processed['fold_change'] = np.log2(1+total_region_signal_processed['myoblast_signal']) - np.log2(1+total_region_signal_processed['fibroblast_signal'])
chrom_accessibility_df = (
    total_region_signal_processed[
        ['chrom','start','end','build_region_index','fold_change','fibroblast_signal','myoblast_signal']
    ].rename(columns={'fold_change':'label'})
)
chrom_accessibility_df.to_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv',index=False)
chrom_accessibility_df.head()
[12]:
| | chrom | start | end | build_region_index | label | fibroblast_signal | myoblast_signal |
|---|---|---|---|---|---|---|---|
| 0 | chr1 | 180000 | 181000 | 38 | 0.091821 | 0.066471 | 0.136553 |
| 1 | chr1 | 181000 | 182000 | 39 | 0.001996 | 2.543848 | 2.548754 |
| 2 | chr1 | 182000 | 183000 | 40 | 0.142237 | 0.102401 | 0.216626 |
| 3 | chr1 | 191000 | 192000 | 46 | -0.678009 | 1.386900 | 0.491877 |
| 4 | chr1 | 268000 | 269000 | 54 | -0.329329 | 1.324022 | 0.849704 |
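As a sanity check, the `label` column in the table above can be recomputed by hand from the two signal columns (which `bw_getSignal_bins` has already divided by each track's genome-wide mean). The sketch below re-implements the transform from the cell above in plain Python; `log2_fold_change` is an illustrative name, not a ChromBERT function.

```python
import math

def log2_fold_change(fibroblast_signal, myoblast_signal):
    # Same transform as the cell above: log2(1 + myoblast) - log2(1 + fibroblast).
    # The pseudocount of 1 keeps regions with zero signal finite.
    return math.log2(1 + myoblast_signal) - math.log2(1 + fibroblast_signal)

# (fibroblast_signal, myoblast_signal, label) copied from rows 0, 1 and 3 of the table above
rows = [(0.066471, 0.136553, 0.091821),
        (2.543848, 2.548754, 0.001996),
        (1.386900, 0.491877, -0.678009)]
for fib, myo, label in rows:
    print(f"{log2_fold_change(fib, myo):.6f}  (table: {label})")
```

Positive labels mean higher accessibility in myoblasts than in fibroblasts; a region with equal signal in both cell types gets a label of 0.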
Prepare fine-tuning data
[13]:
train_data = chrom_accessibility_df.sample(frac=0.8,random_state=55)
test_data = chrom_accessibility_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = chrom_accessibility_df.drop(train_data.index).drop(test_data.index)
train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=20,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)
train_data_sample.to_csv(f'{chromatin_accessibility_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{chromatin_accessibility_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{chromatin_accessibility_dir}/valid_sample.csv',index=False)
train_data_sample.head()
[13]:
| | chrom | start | end | build_region_index | label | fibroblast_signal | myoblast_signal |
|---|---|---|---|---|---|---|---|
| 58450 | chr11 | 16348000 | 16349000 | 300336 | -0.422661 | 1.251264 | 0.679549 |
| 81335 | chr12 | 49808000 | 49809000 | 422035 | 1.398294 | 0.950350 | 4.140920 |
| 346805 | chr7 | 148855000 | 148856000 | 1856383 | -0.905614 | 2.071367 | 0.639512 |
| 35488 | chr1 | 246417000 | 246418000 | 182254 | -2.684815 | 27.481988 | 3.429557 |
| 247095 | chr3 | 132783000 | 132784000 | 1302011 | 3.011465 | 3.356764 | 34.132198 |
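The split in the cell above sends 80% of the regions to training and divides the remaining 20% evenly into test and validation sets; the three sets are disjoint because `chrom_accessibility_df` carries a unique default index, so `drop(train_data.index)` removes exactly the training rows. The stdlib sketch below shows the same 80/10/10 logic on row indices. It is an illustration under that assumption, not a reproduction of `DataFrame.sample`: the same seed will not select the same rows as pandas.

```python
import random

def split_indices(n, train_frac=0.8, seed=55):
    """Shuffle row indices 0..n-1 and split them into train/test/valid,
    with the non-training rows divided 50/50 between test and valid."""
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    n_train = round(n * train_frac)
    rest = idx[n_train:]
    n_test = round(len(rest) * 0.5)
    return idx[:n_train], rest[:n_test], rest[n_test:]

train, test, valid = split_indices(1000)
print(len(train), len(test), len(valid))
# no region appears in more than one set
assert not (set(train) & set(test) or set(train) & set(valid) or set(test) & set(valid))
```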
Identification of upregulated and unchanged genomic regions
Genomic regions with a log2-transformed chromatin accessibility fold change greater than 2 are classified as upregulated, indicating increased accessibility. Unchanged regions are the 40,000 regions with the smallest absolute fold change, selected only among regions with nonzero signal in both cell types (regions lacking coverage are excluded).
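The selection rules implemented in the cell below can be summarized in a few lines of plain Python. `select_regions` is a hypothetical helper operating on dicts with the same column names as `chrom_accessibility_df`; the toy rows are made-up values for illustration only.

```python
def select_regions(rows, up_threshold=2.0, n_unchanged=40_000):
    """rows: list of dicts with 'label', 'fibroblast_signal', 'myoblast_signal'.
    Returns (upregulated, unchanged) following the rules described above."""
    up = [r for r in rows if r["label"] > up_threshold]
    # only regions with signal in both cell types are eligible as "unchanged"
    covered = [r for r in rows
               if r["fibroblast_signal"] > 0 and r["myoblast_signal"] > 0]
    unchanged = sorted(covered, key=lambda r: abs(r["label"]))[:n_unchanged]
    return up, unchanged

# toy values for illustration only
rows = [
    {"label": 3.0, "fibroblast_signal": 0.5, "myoblast_signal": 11.0},
    {"label": 0.01, "fibroblast_signal": 1.0, "myoblast_signal": 1.02},
    {"label": 0.0, "fibroblast_signal": 0.0, "myoblast_signal": 0.0},
    {"label": -0.2, "fibroblast_signal": 2.0, "myoblast_signal": 1.6},
]
up, unchanged = select_regions(rows, n_unchanged=2)
print(len(up), [r["label"] for r in unchanged])
```

The third toy row has a fold change of exactly 0 but no signal in either cell type, so it is not counted as unchanged: a label of 0 there reflects missing data rather than stable accessibility.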
[14]:
chrom_accessibility_df = pd.read_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv')
up_data = chrom_accessibility_df[chrom_accessibility_df['label']>2]
# .copy() makes this an independent DataFrame, avoiding a SettingWithCopyWarning on the next line
covered_region = chrom_accessibility_df[(chrom_accessibility_df['fibroblast_signal']>0) & (chrom_accessibility_df['myoblast_signal']>0)].copy()
covered_region['label_abs'] = np.abs(covered_region['label'])
nochange_data = covered_region.sort_values('label_abs').reset_index(drop=True).iloc[0:40000]
up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)
up_data_sample.to_csv(f'{chromatin_accessibility_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{chromatin_accessibility_dir}/nochange_data_sample.csv',index=False)
Fine-tune
Dataset configuration: use the `general` preset dataset configuration.
Model instantiation: use the `general` preset model configuration.
[15]:
# Use the `general` preset dataset config
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = None, batch_size = 4, num_workers = 4)
dataset_config
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[15]:
DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'GeneralDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 4, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'prompt_regulator_cache_pin_memory': False, 'prompt_regulator_cache_limit': 3, 'fasta_file': None, 'flank_window': 0})
[16]:
# We use the `LitChromBERTFTDataModule` to load the data and facilitate the fine-tuning process.
data_module = chrombert.LitChromBERTFTDataModule(
config = dataset_config,
train_params = {'supervised_file': f'{chromatin_accessibility_dir}/train_sample.csv'},
val_params = {'supervised_file':f'{chromatin_accessibility_dir}/valid_sample.csv'},
test_params = {'supervised_file':f'{chromatin_accessibility_dir}/test_sample.csv'}
)
data_module.setup()
[17]:
# we use `general` preset model config
model_config = chrombert.get_preset_model_config("general")
model_config
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
[17]:
ChromBERTFTConfig:
{
"genome": "hg38",
"task": "general",
"dim_output": 1,
"mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
"dropout": 0.1,
"pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
"finetune_ckpt": null,
"ignore": false,
"ignore_index": [
null,
null
],
"gep_flank_window": 4,
"gep_parallel_embedding": false,
"gep_gradient_checkpoint": false,
"gep_zero_inflation": false,
"prompt_kind": "cistrome",
"prompt_dim_external": 512,
"dnabert2_ckpt": null
}
[18]:
model = model_config.init_model()
model.freeze_pretrain(2) # keep only the last 2 of the 8 ChromBERT transformer blocks trainable (6 frozen) during fine-tuning
summary(model)
use organisim hg38; max sequence length is 6391
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/chrombert/base/model.py:56: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
[18]:
================================================================================
Layer (type:depth-idx) Param #
================================================================================
ChromBERTGeneral --
├─ChromBERT: 1-1 --
│ └─BERTEmbedding: 2-1 --
│ │ └─TokenEmbedding: 3-1 (7,680)
│ │ └─PositionalEmbedding: 3-2 (4,909,056)
│ │ └─Dropout: 3-3 --
│ └─ModuleList: 2-2 --
│ │ └─EncoderTransformerBlock: 3-4 (6,497,280)
│ │ └─EncoderTransformerBlock: 3-5 (6,497,280)
│ │ └─EncoderTransformerBlock: 3-6 (6,497,280)
│ │ └─EncoderTransformerBlock: 3-7 (6,497,280)
│ │ └─EncoderTransformerBlock: 3-8 (6,497,280)
│ │ └─EncoderTransformerBlock: 3-9 (6,497,280)
│ │ └─EncoderTransformerBlock: 3-10 6,497,280
│ │ └─EncoderTransformerBlock: 3-11 6,497,280
├─GeneralHeader: 1-2 --
│ └─CistromeEmbeddingManager: 2-3 --
│ └─Conv2d: 2-4 769
│ └─ReLU: 2-5 --
│ └─ResidualBlock: 2-6 --
│ │ └─Linear: 3-12 1,099,776
│ │ └─Linear: 3-13 1,049,600
│ │ └─LayerNorm: 3-14 2,048
│ │ └─Linear: 3-15 1,099,776
│ │ └─Dropout: 3-16 --
│ └─ResidualBlock: 2-7 --
│ │ └─Linear: 3-17 787,200
│ │ └─Linear: 3-18 590,592
│ │ └─LayerNorm: 3-19 1,536
│ │ └─Linear: 3-20 787,200
│ │ └─Dropout: 3-21 --
│ └─ResidualBlock: 2-8 --
│ │ └─Linear: 3-22 196,864
│ │ └─Linear: 3-23 65,792
│ │ └─LayerNorm: 3-24 512
│ │ └─Linear: 3-25 196,864
│ │ └─Dropout: 3-26 --
│ └─Linear: 2-9 257
================================================================================
Total params: 62,773,762
Trainable params: 18,873,346
Non-trainable params: 43,900,416
================================================================================
The model is fine-tuned with PyTorch Lightning using a simple parameter configuration. To save time, fine-tuning uses only the small subsample prepared above.
Note: Due to the stochastic nature of the tuning process, results may vary. For improved performance, consider increasing the number of epochs and expanding the dataset size.
[19]:
train_config = chrombert.finetune.TrainConfig(kind='regression',
loss='rmse',
max_epochs=2,
accumulate_grad_batches=2,
val_check_interval=2,
limit_val_batches=10,
tag='chrom_accessibility')
train_config
train_module = train_config.init_pl_module(model)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs) # noqa: B028
[20]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
max_epochs=train_config.max_epochs,
log_every_n_steps=1,
limit_val_batches = train_config.limit_val_batches,
val_check_interval = train_config.val_check_interval,
accelerator="gpu",
accumulate_grad_batches= train_config.accumulate_grad_batches,
fast_dev_run=False,
precision="bf16-mixed",
strategy="auto",
callbacks=[
pl.callbacks.LearningRateMonitor(),
callback_ckpt,
],
logger=pl.loggers.TensorBoardLogger("lightning_logs", name='chrom_accessibility'),
)
trainer.fit(train_module,data_module)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/chrom_accessibility
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
| Name | Type | Params | Mode
---------------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M | train
---------------------------------------------------
18.9 M Trainable params
43.9 M Non-trainable params
62.8 M Total params
251.095 Total estimated model params size (MB)
/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
`Trainer.fit` stopped: `max_epochs=2` reached.
Evaluate the fine-tuned model
ChromBERT has been fine-tuned. You could evaluate the model directly through the Lightning module (`train_module.model`). Note, however, that with the flash-attention settings, `model.eval()` cannot switch off dropout, so the outputs may vary between runs.
To get consistent results, we recommend saving the fine-tuned checkpoint and reloading it into a model built from the original preset configuration with `dropout=0` before evaluation.
To save time, this tutorial evaluates only on the downsampled test data. We first load the fine-tuned checkpoint and evaluate it on that test subset.
[21]:
import glob
# if several runs exist, use the most recently written checkpoint
ckpts = glob.glob('./lightning_logs/chrom_accessibility/version_*/checkpoints/*.ckpt')
chrom_accessibility_ft_ckpt = os.path.abspath(max(ckpts, key=os.path.getmtime))
chrom_accessibility_ft_ckpt
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/examples/tutorials/lightning_logs/chrom_accessibility/version_0/checkpoints/epoch=1-step=12.ckpt
/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/chrombert/finetune/model/basic_model.py:49: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
new_state = torch.load(ckpt)
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[22]:
from tqdm import tqdm
import torchmetrics as tm
dl = data_module.test_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
y_preds = torch.cat(y_preds)
y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
}
print(scores)
100%|██████████| 5/5 [00:01<00:00, 2.74it/s]
{'pearsonr': tensor(0.1227), 'spearmanr': tensor(0.2571), 'mse': tensor(0.7629), 'mae': tensor(0.5203), 'r2': tensor(-0.0613)}
The model fine-tuned on the small subsample for only 2 epochs performs poorly on the test subset (Pearson r = 0.12, R² = -0.06). We therefore provide a checkpoint fine-tuned on the entire dataset and evaluate it on the same test subset.
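To read these scores, it helps to recall the textbook definitions behind the metrics printed above. The plain-Python sketch below computes Pearson r, MSE, MAE and R² for two sequences; it is an illustration of the formulas, not torchmetrics' implementation (Spearman, which is Pearson computed on ranks, is omitted). Note that R² = 1 - SSE/SST becomes negative whenever the predictions are worse than always predicting the mean label, which is what the R² of -0.06 above indicates. The training loss `rmse` is the square root of the MSE.

```python
import math

def regression_metrics(preds, labels):
    """Pearson r, MSE, MAE and R^2 for two equal-length sequences of floats."""
    n = len(labels)
    mean_p = sum(preds) / n
    mean_y = sum(labels) / n
    cov = sum((p - mean_p) * (y - mean_y) for p, y in zip(preds, labels))
    var_p = sum((p - mean_p) ** 2 for p in preds)
    sst = sum((y - mean_y) ** 2 for y in labels)        # total sum of squares
    sse = sum((p - y) ** 2 for p, y in zip(preds, labels))  # residual sum of squares
    return {
        "pearsonr": cov / math.sqrt(var_p * sst),
        "mse": sse / n,
        "mae": sum(abs(p - y) for p, y in zip(preds, labels)) / n,
        # negative when the predictions are worse than predicting the mean label
        "r2": 1 - sse / sst,
    }

labels = [0.5, -1.0, 2.0, 0.0]
good = regression_metrics([0.4, -0.8, 1.7, 0.1], labels)  # tracks the labels closely
bad = regression_metrics([0.0, 2.0, -1.0, 0.5], labels)   # anti-correlated predictions
print(good)
print(bad)
```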
[28]:
chrom_accessibility_ft_ckpt = f'{chromatin_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[29]:
from tqdm import tqdm
import torchmetrics as tm
dl = data_module.test_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
y_preds = torch.cat(y_preds)
y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
}
print(scores)
100%|██████████| 5/5 [00:01<00:00, 2.80it/s]
{'pearsonr': tensor(0.9576), 'spearmanr': tensor(0.8451), 'mse': tensor(0.0629), 'mae': tensor(0.2033), 'r2': tensor(0.9125)}
Infer key regulators
Using the fine-tuned model, we compare each regulator's embeddings between upregulated and unchanged genomic regions. A regulator whose embedding is less similar between the two sets is hypothesized to undergo a larger functional shift, which makes it a candidate key regulator of the cell state transition.
To save time, we perform the analysis on a downsampled set of upregulated and unchanged genomic regions.
Note: only the regulators listed in `hg38_6k_factors_list.txt` are considered; histone modifications and chromatin accessibility tracks are excluded.
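The cells below produce one mean embedding per regulator at upregulated regions (`up_chrom_acc_embs`) and at unchanged regions (`nochange_chrom_acc_embs`), with regulator names in `names`. One way to compare them, assuming cosine similarity as the metric (an assumption made for this sketch), is to rank regulators from least to most similar. The plain-Python version below uses toy 3-dimensional vectors and hypothetical regulator names; with the real tensors, the same per-regulator similarity could be computed in one call with `torch.nn.functional.cosine_similarity(up_chrom_acc_embs, nochange_chrom_acc_embs, dim=-1)`.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def rank_regulators(names, up_embs, nochange_embs):
    """Rank regulators by the cosine similarity between their mean embedding at
    upregulated regions and at unchanged regions, lowest similarity first."""
    sims = {name: cosine_similarity(u, c)
            for name, u, c in zip(names, up_embs, nochange_embs)}
    return sorted(sims.items(), key=lambda kv: kv[1])

# toy 3-dimensional embeddings for three hypothetical regulators
names = ["reg_a", "reg_b", "reg_c"]
up = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
nochange = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.1], [1.0, 1.0, 0.0]]
for name, sim in rank_regulators(names, up, nochange):
    print(f"{name}\t{sim:.3f}")
```

Regulators at the top of such a ranking would be the candidates whose context changes most between the two region sets.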
[30]:
# Load the fine-tuned model
model_tuned = chrombert.get_preset_model_config(
"general",
dropout = 0,
finetune_ckpt = f'{chromatin_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt').init_model() # use absolute path here, to avoid mixing of preset
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length is 6391
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters
[31]:
# Get the embedding manager
model_emb = model_tuned.get_embedding_manager().cuda()
[32]:
from tqdm import tqdm

# Embed the upregulated regions; each batch yields (regions, regulators, hidden) embeddings
dataset_config = chrombert.get_preset_dataset_config("general", supervised_file=f'{chromatin_accessibility_dir}/up_data_sample.csv', batch_size=32, num_workers=4)
up_chrom_acc_embs = []
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        # Move the tensor fields of the batch to the GPU
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_chrom_acc_embs.append(emb)
up_chrom_acc_embs = torch.cat(up_chrom_acc_embs, dim=0)
up_chrom_acc_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:05<00:00, 1.34s/it]
[32]:
torch.Size([100, 1073, 768])
[33]:
dataset_config = chrombert.get_preset_dataset_config("general", supervised_file=f'{chromatin_accessibility_dir}/nochange_data_sample.csv', batch_size=32, num_workers=4)
nochange_chrom_acc_embs = []
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_chrom_acc_embs.append(emb)
nochange_chrom_acc_embs = torch.cat(nochange_chrom_acc_embs, dim=0)
nochange_chrom_acc_embs.shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
100%|██████████| 4/4 [00:05<00:00, 1.39s/it]
[33]:
torch.Size([100, 1073, 768])
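Cells [32] and [33] run the same extraction loop on two different CSV files. As a minimal sketch, that loop could be factored into a helper. The name `collect_embeddings` and the `ToyEmbedder` model below are hypothetical, and the sketch assumes the embedding manager behaves like a standard `torch.nn.Module` that takes a dict batch and returns a `(regions, regulators, hidden)` tensor, as the shapes printed above suggest.

```python
import torch
from torch import nn


def collect_embeddings(model: nn.Module, dataloader, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: run `model` over every dict batch and concatenate outputs on dim 0."""
    # eval() disables dropout; this is a choice of the sketch, not taken from the tutorial
    model = model.to(device).eval()
    chunks = []
    with torch.no_grad():
        for batch in dataloader:
            # Move only the tensor fields to the target device, leave the rest untouched
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}
            chunks.append(model(batch).cpu())
    return torch.cat(chunks, dim=0)


class ToyEmbedder(nn.Module):
    """Toy stand-in (assumption): maps a (batch, n_regulators) input to (batch, n_regulators, hidden)."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.hidden = hidden

    def forward(self, batch):
        x = batch["input_ids"].float()
        return x.unsqueeze(-1).expand(-1, -1, self.hidden)


# Two toy batches of 4 and 2 regions, 5 regulators each; "region_id" is a non-tensor field
toy_batches = [{"input_ids": torch.ones(4, 5), "region_id": [0, 1, 2, 3]},
               {"input_ids": torch.zeros(2, 5), "region_id": [4, 5]}]
embs = collect_embeddings(ToyEmbedder(), toy_batches)
print(embs.shape)  # torch.Size([6, 5, 8])
```

In the tutorial, the call would presumably look like `collect_embeddings(model_emb, dataset_config.init_dataloader(), device="cuda")`, once per CSV file.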
[34]:
# Keep only the regulators whose (lower-cased) names appear in the factor list
with open(os.path.join(base_dir, "config", "hg38_6k_factors_list.txt"), "r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]
indices = np.isin(model_emb.list_regulator, factors)
names = np.array(model_emb.list_regulator)[indices]
# Average each regulator's embedding over all regions, then select the factor rows
up_chrom_acc_embs = up_chrom_acc_embs.mean(dim=0)[indices]
nochange_chrom_acc_embs = nochange_chrom_acc_embs.mean(dim=0)[indices]
[35]:
up_chrom_acc_embs.shape, nochange_chrom_acc_embs.shape
[35]:
(torch.Size([991, 768]), torch.Size([991, 768]))
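Cell [34] averages each regulator's embedding over all regions and then keeps only the rows whose names appear in the factor list. The toy sketch below reproduces those two steps with made-up names and values (assumptions, not tutorial data), using `np.isin`, which gives the same result as the older `np.in1d` for 1-D inputs.

```python
import numpy as np
import torch

# Toy stand-ins (assumptions): 3 regions x 4 regulators x 2 hidden dims
list_regulator = ["myod1", "h3k27ac", "myog", "ctcf"]
factors = ["myod1", "myog", "ctcf"]
embs = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)

mask = np.isin(list_regulator, factors)                # boolean mask over regulators
names = np.array(list_regulator)[mask]                 # names of the kept regulators
mean_embs = embs.mean(dim=0)[torch.from_numpy(mask)]   # average over regions, then select rows
print(names)            # ['myod1' 'myog' 'ctcf']
print(mean_embs.shape)  # torch.Size([3, 2])
```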
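For each regulator, we calculate the cosine similarity between its mean embedding over upregulated regions and its mean embedding over unchanged regions. Regulators are then ranked in ascending order of similarity: a low similarity means the regulator's embedding differs most between the two sets of regions, which marks it as a candidate key regulator.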
[36]:
from sklearn.metrics.pairwise import cosine_similarity
chrom_acc_similarity = [cosine_similarity(up_chrom_acc_embs[i].reshape(1, -1), nochange_chrom_acc_embs[i].reshape(1, -1))[0, 0] for i in range(up_chrom_acc_embs.shape[0])]
chrom_acc_similarity_df = pd.DataFrame({'factors':names,'similarity':chrom_acc_similarity}).sort_values(by='similarity').reset_index(drop=True)
chrom_acc_similarity_df['rank']=chrom_acc_similarity_df.index + 1
chrom_acc_similarity_df.to_csv(f'{chromatin_accessibility_dir}/chromatin_accessibility_similarity_df.csv',index=False)
chrom_acc_similarity_df
[36]:
| | factors | similarity | rank |
|---|---|---|---|
| 0 | myod1 | -0.064511 | 1 |
| 1 | myf5 | 0.080093 | 2 |
| 2 | myog | 0.101807 | 3 |
| 3 | pax3-foxo1a | 0.107334 | 4 |
| 4 | tead1 | 0.214081 | 5 |
| ... | ... | ... | ... |
| 986 | znf250 | 0.984589 | 987 |
| 987 | hoxa1 | 0.984709 | 988 |
| 988 | zbtb10 | 0.985032 | 989 |
| 989 | znf706 | 0.985040 | 990 |
| 990 | hoxa10 | 0.986486 | 991 |
991 rows × 3 columns
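The list comprehension in cell [36] calls scikit-learn once per regulator. A vectorized alternative, sketched here on toy tensors (assumed values), computes all row-wise cosine similarities in one call with `torch.nn.functional.cosine_similarity`; applied to `up_chrom_acc_embs` and `nochange_chrom_acc_embs`, it should give the same values up to floating-point error.

```python
import torch
import torch.nn.functional as F

# Toy per-regulator mean embeddings (assumption): 3 regulators x 2 dims
up = torch.tensor([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
nochange = torch.tensor([[-1.0, 0.0], [2.0, 2.0], [1.0, 0.0]])

# Row-wise cosine similarity: one value per regulator, approximately [-1, 1, 0]
sim = F.cosine_similarity(up, nochange, dim=1)

# Ascending order, as in cell [36]: the most strongly shifted regulators come first
order = torch.argsort(sim)
print(order.tolist())  # [0, 2, 1]
```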
[37]:
chrom_acc_similarity_df[chrom_acc_similarity_df['factors']=='myod1']
[37]:
| | factors | similarity | rank |
|---|---|---|---|
| 0 | myod1 | -0.064511 | 1 |
Based on chromatin accessibility alone, we take the 25 top-ranked factors as key regulators; they include the notable regulator MYOD1, which ranks first.
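Selecting the key regulators amounts to a threshold on the `rank` column. A minimal sketch on the first five rows of the ranking above (values copied from cell [36]; the cutoff of 3 is only for illustration, the tutorial uses 25 on the full table):

```python
import pandas as pd

# First five rows of the ranking in cell [36]
df = pd.DataFrame({
    "factors": ["myod1", "myf5", "myog", "pax3-foxo1a", "tead1"],
    "similarity": [-0.064511, 0.080093, 0.101807, 0.107334, 0.214081],
})
df = df.sort_values("similarity").reset_index(drop=True)
df["rank"] = df.index + 1

top_k = 3  # illustrative cutoff
top_factors = df.loc[df["rank"] <= top_k, "factors"].tolist()
print(top_factors)  # ['myod1', 'myf5', 'myog']
```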
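Finally, we combine both modalities: each factor's chromatin accessibility rank and transcriptome rank are averaged, and the 25 factors with the best average rank are reported as the final key regulators. This step requires the output of the transcriptome tutorial.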
[38]:
gep_similarity_path = f'{chromatin_accessibility_dir}/../transcriptome/gep_similarity_df.csv'
# The transcriptome ranking is produced by the companion transcriptome tutorial
if not os.path.exists(gep_similarity_path):
    raise FileNotFoundError(f"{gep_similarity_path} not found. Please run the tutorial for key regulators inference during cell state transition: Transcriptome first.")
gep_similarity_df = pd.read_csv(gep_similarity_path)
# Keep factors ranked in both modalities, then rank them by the mean of their two ranks
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
average_rank_df['averge_rank'] = ((average_rank_df['rank_gep'] + average_rank_df['rank_chrom_acc']) / 2).rank().astype(int)
average_rank_df = average_rank_df.sort_values(by='averge_rank')
average_rank_df
average_rank_df[average_rank_df['factors']=='myod1']
[38]:
| | factors | similarity_gep | rank_gep | similarity_chrom_acc | rank_chrom_acc | averge_rank |
|---|---|---|---|---|---|---|
| 17 | myod1 | 0.983192 | 18 | -0.064511 | 1 | 2 |
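Averaging two integer ranks can produce ties, and `Series.rank()` defaults to `method='average'`, so tied factors share a (possibly fractional) rank that `astype(int)` then truncates. The toy example below, with hypothetical factors `a` to `e`, shows how the tie-breaking method changes how many factors pass a `<= N` cutoff; `method='first'` breaks ties by row order and always yields exactly N.

```python
import pandas as pd

# Hypothetical ranks for five factors in the two modalities
gep = pd.DataFrame({"factors": list("abcde"), "rank": [1, 2, 3, 4, 5]})
chrom = pd.DataFrame({"factors": list("abcde"), "rank": [5, 1, 4, 3, 2]})

merged = pd.merge(gep, chrom, on="factors", how="inner", suffixes=("_gep", "_chrom_acc"))
mean_rank = (merged["rank_gep"] + merged["rank_chrom_acc"]) / 2
# mean_rank = [3.0, 1.5, 3.5, 3.5, 3.5]: factors c, d and e are tied

merged["rank_default"] = mean_rank.rank().astype(int)              # 'average': the tie gets 4
merged["rank_min"] = mean_rank.rank(method="min").astype(int)      # the tie shares rank 3
merged["rank_first"] = mean_rank.rank(method="first").astype(int)  # ties broken by row order

# Number of factors passing a cutoff of 3 under each method
for col in ["rank_default", "rank_min", "rank_first"]:
    print(col, (merged[col] <= 3).sum())  # rank_default 2, rank_min 5, rank_first 3
```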
[39]:
final_indentified_factor = average_rank_df[average_rank_df['averge_rank']<=25]['factors'].tolist()
final_indentified_factor
[39]:
['myf5',
'myod1',
'neurog2',
'yap1',
'nr3c2',
'tead1',
'tead',
'dux4',
'pgbd3',
'chd4',
'myog',
'hira',
'pax3-foxo1a',
'tbx5',
'nr3c1',
'snai2',
'ss18',
'prmt5',
'ubn1',
'rb1',
'six2',
'klf11',
'ercc6',
'sumo1',
'esco2']