Example for context-specific TRN: functional collaborations with EZH2 at functionally distinct loci¶
Inference of transcriptional regulatory networks (TRNs) at specific loci is a complex and dynamic process. In this tutorial, we guide you through context-specific TRN analysis using ChromBERT, taking EZH2 as an example of functional collaborations at distinct loci.
Attention: You should go through this tutorial first to get familiar with the basic usage of ChromBERT.
Preprocessing dataset¶
To identify classical and non-classical EZH2 sites in human embryonic stem cells (hESCs), we utilize the EZH2 peak dataset (GSM1003524) and the H3K27me3 peak dataset (GSM1498900). Classical EZH2 sites are defined as regions where EZH2 co-localizes with H3K27me3, while non-classical EZH2 sites are identified as regions where EZH2 is present without H3K27me3 co-localization.
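The definition above amounts to simple set logic on 1-kb regions. A minimal sketch, using toy region indices in place of the real peak overlaps (the variable names are illustrative and not part of ChromBERT):

```python
# Toy region indices; in the tutorial they come from overlapping peaks
# with the reference regions.
ezh2_regions = {174, 204, 220, 221, 222}
h3k27me3_regions = {174, 204, 220, 222, 300}

# Classical sites: EZH2 co-localizes with H3K27me3.
classical = ezh2_regions & h3k27me3_regions
# Non-classical sites: EZH2 is present without H3K27me3.
nonclassical = ezh2_regions - h3k27me3_regions

print(sorted(classical))
print(sorted(nonclassical))
```

In the cells below, the same distinction is encoded as a boolean label on the EZH2-bound regions.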
[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # select which GPU to use
import sys
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary
import lightning.pytorch as pl
basedir = os.path.expanduser("~/.cache/chrombert/data")
[2]:
peak_ezh2 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1003524_EZH2.bed")
!head -3 {peak_ezh2}
chr3 93470270 93470880 peak7668 2339 . 23.98034 243.39191 233.90086 232
chr2 91477646 91478694 peak6148 1127 . 10.28381 119.74835 112.79007 430
chr16 46390276 46390857 peak4186 1039 . 8.55891 110.78978 103.94416 350
[3]:
peak_h3k27me3 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1498900_H3K27me3.bed")
!head -3 {peak_h3k27me3}
chr10 100480580 100483730 peak1145 6.37280
chr10 100519337 100521146 peak1146 6.86372
chr10 100655029 100656371 peak1147 15.71588
[4]:
# Align genomic coordinates from the narrowPeak file to the Human-Cistrome-6k dataset regions
from chrombert.scripts.chrombert_make_dataset import get_overlap
ref_regions = os.path.join(basedir, "config", "hg38_6k_1kb_region.bed")
df1 = get_overlap(
supervised = peak_ezh2,
regions = ref_regions,
no_filter = False,
).assign(label = lambda df: df["label"] > 0 )
df2 = get_overlap(
supervised = peak_h3k27me3,
regions = ref_regions,
no_filter = True,
).assign(label = lambda df: df["label"] > 0 )
df1.head(), df2.head()
[4]:
( chrom start end build_region_index label
0 chr1 870000 871000 174 True
1 chr1 905000 906000 204 True
2 chr1 923000 924000 220 True
3 chr1 924000 925000 221 True
4 chr1 925000 926000 222 True,
chrom start end build_region_index label
0 chr1 10000 11000 0 False
1 chr1 16000 17000 1 False
2 chr1 17000 18000 2 False
3 chr1 29000 30000 3 False
4 chr1 30000 31000 4 False)
[5]:
df_supervised = df1.rename(columns = {"label": "EZH2"}).merge(df2).assign(label = lambda df: df["label"])
df_supervised
[5]:
| | chrom | start | end | build_region_index | EZH2 | label |
|---|---|---|---|---|---|---|
| 0 | chr1 | 870000 | 871000 | 174 | True | True |
| 1 | chr1 | 905000 | 906000 | 204 | True | True |
| 2 | chr1 | 923000 | 924000 | 220 | True | True |
| 3 | chr1 | 924000 | 925000 | 221 | True | False |
| 4 | chr1 | 925000 | 926000 | 222 | True | True |
| ... | ... | ... | ... | ... | ... | ... |
| 11003 | chrX | 154750000 | 154751000 | 2134828 | True | True |
| 11004 | chrY | 5001000 | 5002000 | 2135730 | True | False |
| 11005 | chrY | 10994000 | 10995000 | 2136403 | True | False |
| 11006 | chrY | 26670000 | 26671000 | 2137888 | True | False |
| 11007 | chrY | 26671000 | 26672000 | 2137889 | True | False |
11008 rows × 6 columns
[6]:
df_supervised.groupby("label").size() # the dataset is nearly balanced
[6]:
label
False 5272
True 5736
dtype: int64
[7]:
# Then we split the dataset into training, validation and test sets
from sklearn.model_selection import train_test_split
df_train, df_temp = train_test_split(df_supervised, test_size=0.2, random_state=42, stratify = df_supervised['label'])
df_valid, df_test = train_test_split(df_temp, test_size=0.5, random_state=42, stratify = df_temp['label'])
os.makedirs("tmp_ezh2", exist_ok=True)
df_train.to_csv(os.path.join("tmp_ezh2", "train.csv"))
df_valid.to_csv(os.path.join("tmp_ezh2", "valid.csv"))
df_test.to_csv(os.path.join("tmp_ezh2", "test.csv"))
len(df_train), len(df_valid), len(df_test)
[7]:
(8806, 1101, 1101)
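The split sizes above follow from the two `test_size` values. A short arithmetic check, assuming the held-out count is rounded up at each split (which is how scikit-learn is understood to size the test set):

```python
import math

n_total = 11008  # rows in df_supervised

# First split: test_size=0.2, held-out count rounded up.
n_temp = math.ceil(0.2 * n_total)
n_train = n_total - n_temp

# Second split: test_size=0.5 of the held-out part.
n_test = math.ceil(0.5 * n_temp)
n_valid = n_temp - n_test

print(n_train, n_valid, n_test)
```

This gives an approximate 80/10/10 split while `stratify` keeps the label proportions similar across the three sets.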
Fine-tune¶
In this section, we walk through fine-tuning ChromBERT for this task. The process closely follows the standard ChromBERT workflow, with two important modifications:
Dataset Preparation: The ignore_object parameter is used to omit H3K27me3-related cistromes from the original ChromBERT dataset, ensuring that H3K27me3 does not interfere with the analysis.
Model Instantiation: A special ignore_index parameter, derived from the dataset, is introduced to properly configure the model.
Let’s get started!
Instructions for dataset: omit specified regulators¶
[8]:
dc = chrombert.get_preset_dataset_config(
"general",
supervised_file = None,
ignore = False, ignore_object = "h3k27me3" # turn off omission
)
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))
ds[1]["input_ids"].shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[8]:
torch.Size([6391])
[9]:
# We omit the h3k27me3 related cistrome, to avoid data leakage
dc = chrombert.get_preset_dataset_config(
"general",
supervised_file = None,
ignore = True, ignore_object = "h3k27me3"
)
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))
# Get ignore_index used to instantiate model.
# Currently, only one ignore object is supported per dataset,
# so ignore_index can be taken from any sample.
ignore_index = ds[0]["ignore_index"]
ds[1]["input_ids"].shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[9]:
torch.Size([6188])
As shown above, the dataset works as expected after the specified cistromes are omitted. The input sequence length drops from 6391 to 6188, because the 203 H3K27me3-related cistromes are omitted and do not participate in training.
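The effect of omission on sequence length can be pictured with a toy example. This is only a sketch: the token values and positions are hypothetical, and the real `ignore_index` object returned by the dataset may be structured differently.

```python
# Toy input: one token per cistrome (hypothetical values).
input_ids = [5, 7, 6, 8, 9, 5]
# Hypothetical positions of the cistromes to omit.
ignore_positions = {1, 4}

kept = [tok for i, tok in enumerate(input_ids) if i not in ignore_positions]
print(len(input_ids), "->", len(kept))

# The same arithmetic with the two shapes printed by the cells above.
full_length, reduced_length = 6391, 6188
n_omitted = full_length - reduced_length
print(n_omitted)
```

The difference between the two printed shapes matches the number of ignored cistromes reported when the model is instantiated.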
A small note: the model is fine-tuned using PyTorch Lightning, and the dataset is wrapped in the lightning.pytorch.LightningDataModule class for seamless integration.
[10]:
data_module = chrombert.LitChromBERTFTDataModule(
config = dc,
train_params = dict(supervised_file = os.path.join("tmp_ezh2", "train.csv")),
val_params = dict(supervised_file = os.path.join("tmp_ezh2", "valid.csv")),
test_params = dict(supervised_file = os.path.join("tmp_ezh2", "test.csv")),
)
data_module
[10]:
<chrombert.finetune.dataset.data_module.LitChromBERTFTDataModule at 0x7f6abdb92f80>
Instantiate the Model¶
Next, we can instantiate the model using the ignore_index parameter.
[11]:
model = chrombert.get_preset_model_config(
"general",
ignore = True, ignore_index =ignore_index # ignore_index from above
).init_model()
model.freeze_pretrain(trainable=2) # we only fine-tune the last two layers
summary(model, depth=2)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Ignoring 203 cistromes and 1 regulators
[11]:
================================================================================
Layer (type:depth-idx) Param #
================================================================================
ChromBERTGeneral --
├─ChromBERT: 1-1 --
│ └─BERTEmbedding: 2-1 (4,916,736)
│ └─ModuleList: 2-2 51,978,240
├─GeneralHeader: 1-2 --
│ └─CistromeEmbeddingManager: 2-3 --
│ └─Conv2d: 2-4 769
│ └─ReLU: 2-5 --
│ └─ResidualBlock: 2-6 3,249,152
│ └─ResidualBlock: 2-7 2,166,528
│ └─ResidualBlock: 2-8 460,032
│ └─Linear: 2-9 257
================================================================================
Total params: 62,771,714
Trainable params: 18,871,298
Non-trainable params: 43,900,416
================================================================================
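The trainable parameter count in the summary can be cross-checked by hand. The sketch below assumes that `freeze_pretrain(trainable=2)` leaves the last two transformer blocks and the whole task header trainable; the per-block size and the block count are taken from the embedding-manager summary shown later in this tutorial.

```python
# Parameter counts copied from the model summaries in this tutorial.
embedding = 4_916_736
block = 6_497_280  # one EncoderTransformerBlock
n_blocks = 8
header = 769 + 3_249_152 + 2_166_528 + 460_032 + 257

# Assumption: the last two blocks plus the header are trainable.
trainable = 2 * block + header
frozen = embedding + (n_blocks - 2) * block

print(trainable, frozen, trainable + frozen)
```

Under this assumption the numbers reproduce the trainable, non-trainable, and total counts reported above.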
Fine-tune¶
We fine-tune the model with PyTorch Lightning, using a simple TrainConfig to hold the training parameters. To save time, tuning is performed on a limited dataset.
Note: The tuning process involves randomness, so results may vary. For improved performance, consider increasing the number of epochs and expanding the size of the dataset used.
[12]:
tc = chrombert.finetune.train.TrainConfig(
kind = "classification",
loss = "bce", # specify "bce" to use Binary Cross-Entropy (BCE) loss. Use "focal" to apply Focal Loss instead.
max_epochs = 1,
lr = 1e-4
)
pl_module = tc.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)
[12]:
chrombert.finetune.train.pl_module.ClassificationPLModule
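The `loss` option in the configuration above switches between binary cross-entropy and focal loss. As a rough, library-independent illustration of how the two differ, the sketch below uses the textbook formulation for a single prediction. It is not necessarily how ChromBERT implements either loss, and `gamma` and `alpha` are illustrative values.

```python
import math

def bce(p, y):
    """Binary cross-entropy for one predicted probability p and label y in {0, 1}."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)

def focal(p, y, gamma=2.0, alpha=0.25):
    """Focal loss for one prediction; gamma and alpha are illustrative values."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # The (1 - p_t) ** gamma factor shrinks the loss of confident, correct predictions.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# Compare both losses for a positive example at different confidence levels.
for p in (0.9, 0.6, 0.1):
    print(p, round(bce(p, 1), 4), round(focal(p, 1), 4))
```

For a nearly balanced dataset such as this one, plain BCE is a reasonable default; focal loss is mainly useful when easy examples dominate.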
[13]:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # for tensorboard compatibility
callback_ckpt = pl.callbacks.ModelCheckpoint( monitor = f"{tc.tag}_validation/{tc.loss}", mode = "min")
#
#
trainer = pl.Trainer(
max_epochs = tc.max_epochs,
accelerator = "gpu",
precision = "bf16-mixed",
fast_dev_run = False,
accumulate_grad_batches = 16,
logger = pl.loggers.TensorBoardLogger(os.path.join("tmp_ezh2","logs"), name = "ezh2"),
val_check_interval = 128,
limit_val_batches = 128,
log_every_n_steps = 1,
callbacks = [ callback_ckpt, pl.callbacks.LearningRateMonitor() ],
)
trainer.fit(pl_module, data_module)
pl_module.save_ckpt(os.path.join("tmp_ezh2", "ezh2.ckpt"))
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: tmp_ezh2/logs/ezh2
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
| Name | Type | Params
-------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M
-------------------------------------------
18.9 M Trainable params
43.9 M Non-trainable params
62.8 M Total params
251.087 Total estimated model params size (MB)
/miniconda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:480: PossibleUserWarning: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=1` reached.
Use Fine-Tuned Model to Obtain Regulator Embeddings¶
ChromBERT has been successfully fine-tuned! You can access the tuned model directly via pl_module.model. However, due to specific settings in flash-attention, model.eval() does not change the dropout probability, which may introduce some randomness in the output.
To ensure consistent results, we recommend saving the checkpoint and loading it into a newly instantiated model with dropout set to 0. This approach guarantees you are working with the fine-tuned version.
[14]:
model_tuned = chrombert.get_preset_model_config(
"general",
dropout = 0,
ignore = True, ignore_index =ignore_index, # ignore_index from above
finetune_ckpt = os.path.abspath(os.path.join("tmp_ezh2", "ezh2.ckpt")) # use absolute path here
).init_model()
# or use model_tuned = pl_module.model
summary(model_tuned, depth = 2)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/yangdongxu/work/source/repos/ChromBERT/examples/tutorials/tmp_ezh2/ezh2.ckpt
use organisim hg38; max sequence length is 6391
Ignoring 203 cistromes and 1 regulators
Loading checkpoint from /home/yangdongxu/work/source/repos/ChromBERT/examples/tutorials/tmp_ezh2/ezh2.ckpt
Loaded 110/110 parameters
[14]:
================================================================================
Layer (type:depth-idx) Param #
================================================================================
ChromBERTGeneral --
├─ChromBERT: 1-1 --
│ └─BERTEmbedding: 2-1 4,916,736
│ └─ModuleList: 2-2 51,978,240
├─GeneralHeader: 1-2 --
│ └─CistromeEmbeddingManager: 2-3 --
│ └─Conv2d: 2-4 769
│ └─ReLU: 2-5 --
│ └─ResidualBlock: 2-6 3,249,152
│ └─ResidualBlock: 2-7 2,166,528
│ └─ResidualBlock: 2-8 460,032
│ └─Linear: 2-9 257
================================================================================
Total params: 62,771,714
Trainable params: 62,771,714
Non-trainable params: 0
================================================================================
Next, we obtain the embedding manager, following the tutorial on extracting embeddings.
[15]:
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)
Ignoring 203 cistromes and 1 regulators
[15]:
================================================================================
Layer (type:depth-idx) Param #
================================================================================
ChromBERTEmbedding --
├─ChromBERT: 1-1 --
│ └─BERTEmbedding: 2-1 --
│ │ └─TokenEmbedding: 3-1 7,680
│ │ └─PositionalEmbedding: 3-2 4,909,056
│ │ └─Dropout: 3-3 --
│ └─ModuleList: 2-2 --
│ │ └─EncoderTransformerBlock: 3-4 6,497,280
│ │ └─EncoderTransformerBlock: 3-5 6,497,280
│ │ └─EncoderTransformerBlock: 3-6 6,497,280
│ │ └─EncoderTransformerBlock: 3-7 6,497,280
│ │ └─EncoderTransformerBlock: 3-8 6,497,280
│ │ └─EncoderTransformerBlock: 3-9 6,497,280
│ │ └─EncoderTransformerBlock: 3-10 6,497,280
│ │ └─EncoderTransformerBlock: 3-11 6,497,280
├─CistromeEmbeddingManager: 1-2 --
================================================================================
Total params: 56,894,976
Trainable params: 56,894,976
Non-trainable params: 0
================================================================================
[16]:
dc_test = data_module.test_config
ds_test = dc_test.init_dataset()
dl_test = dc_test.init_dataloader(batch_size = 1)
len(ds_test), list(ds_test[0].keys())
[16]:
(1101,
['input_ids',
'position_ids',
'region',
'build_region_index',
'ignore_index',
'label'])
[17]:
# Obtain embeddings for both classical and non-classical EZH2 sites
embs_classicial = []
embs_nonclassicial = []
for batch in tqdm(dl_test):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch)
        if batch["label"].item() == 1:
            embs_classicial.append(emb)
        else:
            embs_nonclassicial.append(emb)
print(len(embs_classicial), len(embs_nonclassicial))
embs_classicial = torch.cat(embs_classicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_nonclassicial = torch.cat(embs_nonclassicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_classicial.shape, embs_nonclassicial.shape
100%|██████████| 1101/1101 [01:28<00:00, 12.39it/s]
573 528
[17]:
((1072, 768), (1072, 768))
We focus exclusively on transcription factors, ignoring histone modifications and chromatin accessibility. This allows us to calculate the similarity between transcription factors, representing their potential interactions.
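The similarity used here is cosine similarity between regulator embeddings. A minimal stdlib sketch on toy vectors (hypothetical values; the real embeddings have 768 dimensions, and the notebook itself uses scikit-learn for this step):

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Toy 3-dimensional embeddings with hypothetical values.
embeddings = {
    "factor_a": [1.0, 0.0, 1.0],
    "factor_b": [2.0, 0.0, 2.0],
    "factor_c": [0.0, 1.0, 0.0],
}
names = list(embeddings)

# Pairwise similarity matrix, rows and columns ordered as in `names`.
matrix = [[cosine(embeddings[i], embeddings[j]) for j in names] for i in names]
for name, row in zip(names, matrix):
    print(name, [round(x, 3) for x in row])
```

Because cosine similarity ignores vector length, `factor_a` and `factor_b` are treated as identical in direction, while `factor_c` is orthogonal to both.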
[18]:
with open(os.path.join(basedir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]
factors[:3], len(factors)
[18]:
(['adnp', 'aebp2', 'aff1'], 991)
[19]:
factors[-1]
[19]:
'zzz3'
[20]:
indices = np.isin(model_emb.list_regulator, factors)  # np.in1d is deprecated in recent NumPy
names = np.array(model_emb.list_regulator)[indices]
embs_classicial = embs_classicial[indices]
embs_nonclassicial = embs_nonclassicial[indices]
[21]:
from sklearn.metrics.pairwise import cosine_similarity
cos_classicial_matrix = cosine_similarity(embs_classicial)
cos_nonclassicial_matrix = cosine_similarity(embs_nonclassicial)
df_cos_classicial = pd.DataFrame(cos_classicial_matrix, columns = names, index = names)
df_cos_nonclassicial = pd.DataFrame(cos_nonclassicial_matrix, columns = names, index = names)
[22]:
# Define thresholds (95th percentile of all pairwise similarities) to select the most related regulator pairs
thre_class = np.percentile(cos_classicial_matrix.flatten(), 95)
thre_nonclass = np.percentile(cos_nonclassicial_matrix.flatten(), 95)
thre_class, thre_nonclass
[22]:
(0.5776264667510986, 0.5518490076065063)
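The thresholding step keeps only similarities above the 95th percentile. A small sketch with toy values (hypothetical numbers), using `statistics.quantiles` with linear interpolation, which should correspond to the default behaviour of `np.percentile`:

```python
import statistics

# Toy similarity values standing in for the flattened matrix (hypothetical).
similarities = [
    0.10, 0.12, 0.15, 0.18, 0.20, 0.22, 0.25, 0.27, 0.30, 0.31,
    0.33, 0.35, 0.38, 0.40, 0.42, 0.45, 0.50, 0.55, 0.60, 0.90,
]

# 99 cut points for n=100; index 94 is the 95th percentile.
threshold = statistics.quantiles(similarities, n=100, method="inclusive")[94]

# Keep only the values strictly above the threshold.
selected = [s for s in similarities if s > threshold]
print(round(threshold, 4), selected)
```

Note that the flattened matrix in the notebook includes the diagonal (self-similarity of 1) and counts each regulator pair twice, so the threshold is computed over all matrix entries rather than over unique pairs.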
Now we identify the TRNs associated with the classical and non-classical functions of EZH2. As the tables below show, factors related to the classical functions of EZH2, such as SUZ12, are associated with the Polycomb complex. In contrast, factors linked to EZH2’s non-classical functions, including EP300 and STAT3, tend to be associated with transcriptional activation.
[23]:
df_cos_ezh2 = pd.DataFrame(index =names, data = {"classical":df_cos_classicial.loc["ezh2",:],"nonclassical":df_cos_nonclassicial.loc["ezh2",:]})
df_cos_ezh2["diff"] = df_cos_ezh2["classical"] - df_cos_ezh2["nonclassical"]
df_cos_ezh2
[23]:
| | classical | nonclassical | diff |
|---|---|---|---|
| adnp | 0.193229 | 0.359759 | -0.166530 |
| aebp2 | 0.190192 | 0.309486 | -0.119294 |
| aff1 | 0.396427 | 0.479881 | -0.083454 |
| aff4 | 0.300634 | 0.452184 | -0.151550 |
| ago1 | 0.234017 | 0.291643 | -0.057626 |
| ... | ... | ... | ... |
| zscan5a | 0.218051 | 0.334104 | -0.116052 |
| zta | 0.181115 | 0.251131 | -0.070015 |
| zxdb | 0.213376 | 0.294275 | -0.080898 |
| zxdc | 0.149416 | 0.287630 | -0.138214 |
| zzz3 | 0.248419 | 0.346867 | -0.098448 |
991 rows × 3 columns
[24]:
df_cos_ezh2.query("classical > @thre_class ").sort_values("diff", ascending = False).head(10)
[24]:
| | classical | nonclassical | diff |
|---|---|---|---|
| ezh1 | 0.644102 | 0.578448 | 6.565380e-02 |
| pcgf1 | 0.627553 | 0.564524 | 6.302953e-02 |
| kdm2b | 0.600312 | 0.543989 | 5.632281e-02 |
| jarid2 | 0.660526 | 0.609975 | 5.055112e-02 |
| rybp | 0.594199 | 0.549152 | 4.504716e-02 |
| suz12 | 0.875248 | 0.833695 | 4.155296e-02 |
| bcor | 0.616664 | 0.587787 | 2.887654e-02 |
| eed | 0.597332 | 0.576189 | 2.114320e-02 |
| ezh2 | 1.000000 | 1.000000 | -3.576279e-07 |
| cbx8 | 0.584771 | 0.621309 | -3.653848e-02 |
[25]:
df_cos_ezh2.query("nonclassical > @thre_nonclass ").sort_values("diff", ascending = True).head(10)
[25]:
| | classical | nonclassical | diff |
|---|---|---|---|
| foxm1 | 0.395081 | 0.604784 | -0.209703 |
| med1 | 0.349747 | 0.556351 | -0.206604 |
| stat3 | 0.424211 | 0.614199 | -0.189988 |
| ep300 | 0.411506 | 0.599940 | -0.188434 |
| hinfp | 0.391241 | 0.577772 | -0.186532 |
| rela | 0.385995 | 0.568461 | -0.182466 |
| smarca4 | 0.383161 | 0.552061 | -0.168900 |
| brca1 | 0.415508 | 0.583485 | -0.167977 |
| stat1 | 0.431552 | 0.594642 | -0.163090 |
| e2f4 | 0.423399 | 0.568790 | -0.145391 |
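The filtering and ranking done in the last two cells can be sketched without pandas. The values below are copied from the tables above for a handful of factors; the helper `diff` and the variable names are illustrative only.

```python
# (classical, nonclassical) cosine similarity to EZH2, copied from the tables above.
cos_ezh2 = {
    "suz12": (0.875248, 0.833695),
    "ezh1": (0.644102, 0.578448),
    "stat3": (0.424211, 0.614199),
    "ep300": (0.411506, 0.599940),
    "foxm1": (0.395081, 0.604784),
}
thre_class, thre_nonclass = 0.5776264667510986, 0.5518490076065063

def diff(name):
    """Classical minus non-classical similarity for one factor."""
    classical, nonclassical = cos_ezh2[name]
    return classical - nonclassical

# Factors above the classical threshold, most classical-biased first.
classical_partners = sorted(
    (n for n, (c, _) in cos_ezh2.items() if c > thre_class), key=diff, reverse=True
)
# Factors above the non-classical threshold, most non-classical-biased first.
nonclassical_partners = sorted(
    (n for n, (_, nc) in cos_ezh2.items() if nc > thre_nonclass), key=diff
)
print(classical_partners)
print(nonclassical_partners)
```

Factors such as SUZ12 pass both thresholds, so the sign and size of `diff` is what separates Polycomb-related partners from activation-related ones.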
The end¶