Example for context-specific TRN: functional collaborations with EZH2 at functionally distinct loci

Inference of transcriptional regulatory networks (TRNs) at specific loci is a complex and dynamic process. In this tutorial, we walk through context-specific TRN analysis with ChromBERT, using EZH2's functional collaborations at distinct loci as an example.

Attention: You should go through this tutorial first to become familiar with the basic usage of ChromBERT.

Preprocessing dataset

To identify classical and non-classical EZH2 sites in human embryonic stem cells (hESCs), we use the EZH2 peak dataset (GSM1003524) and the H3K27me3 peak dataset (GSM1498900). Classical EZH2 sites are regions where EZH2 co-localizes with H3K27me3; non-classical EZH2 sites are regions bound by EZH2 without H3K27me3.
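The definition above boils down to two interval-overlap tests on 1-kb bins: does a bin overlap an EZH2 peak, and if so, does it also overlap an H3K27me3 peak? The sketch below is not ChromBERT's get_overlap (used in cell [4]); it is a plain-pandas illustration with made-up coordinates, assuming half-open BED intervals and a brute-force per-bin scan that only suits small inputs. The helper name overlaps_any is hypothetical.

```python
import pandas as pd

def overlaps_any(bins: pd.DataFrame, peaks: pd.DataFrame) -> pd.Series:
    """True for each bin that overlaps at least one peak on the same chromosome.

    BED intervals are half-open, so [s1, e1) and [s2, e2) overlap
    when s1 < e2 and s2 < e1. Brute force: fine for toy data only.
    """
    hits = []
    for _, b in bins.iterrows():
        same = peaks[peaks["chrom"] == b["chrom"]]
        hit = ((same["start"] < b["end"]) & (b["start"] < same["end"])).any()
        hits.append(bool(hit))
    return pd.Series(hits, index=bins.index)

# Made-up 1-kb bins and peaks, for illustration only
bins = pd.DataFrame({"chrom": ["chr1", "chr1", "chr1", "chr2", "chr2"],
                     "start": [0, 1000, 2000, 0, 1000],
                     "end":   [1000, 2000, 3000, 1000, 2000]})
ezh2_peaks = pd.DataFrame({"chrom": ["chr1", "chr1", "chr2"],
                           "start": [500, 2100, 100],
                           "end":   [1500, 2400, 300]})
k27_peaks = pd.DataFrame({"chrom": ["chr1"], "start": [2200], "end": [2900]})

ezh2_bins = bins[overlaps_any(bins, ezh2_peaks)].copy()   # bins bound by EZH2
ezh2_bins["label"] = overlaps_any(ezh2_bins, k27_peaks)   # True = classical site
print(ezh2_bins)
```

Only the third bin overlaps H3K27me3 here, so it is the single classical site; the other EZH2-bound bins are non-classical, and the last chr2 bin is dropped because no EZH2 peak touches it.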

[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # select which GPU to use

import sys
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary

import lightning.pytorch as pl
basedir = os.path.expanduser("~/.cache/chrombert/data")


[2]:
peak_ezh2 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1003524_EZH2.bed")
!head -3 {peak_ezh2}
chr3    93470270        93470880        peak7668        2339    .       23.98034        243.39191       233.90086       232
chr2    91477646        91478694        peak6148        1127    .       10.28381        119.74835       112.79007       430
chr16   46390276        46390857        peak4186        1039    .       8.55891 110.78978       103.94416       350
[3]:
peak_h3k27me3 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1498900_H3K27me3.bed")
!head -3 {peak_h3k27me3}
chr10   100480580       100483730       peak1145        6.37280
chr10   100519337       100521146       peak1146        6.86372
chr10   100655029       100656371       peak1147        15.71588
[4]:
# Map peaks from the BED files onto the 1-kb regions of the Human-Cistrome-6k reference
from chrombert.scripts.chrombert_make_dataset import get_overlap
ref_regions = os.path.join(basedir, "config", "hg38_6k_1kb_region.bed")
df1 = get_overlap(
    supervised = peak_ezh2,
    regions = ref_regions,
    no_filter = False,
).assign(label = lambda df: df["label"] > 0 )
df2 = get_overlap(
    supervised = peak_h3k27me3,
    regions = ref_regions,
    no_filter = True,
).assign(label = lambda df: df["label"] > 0 )
df1.head(), df2.head()
[4]:
(  chrom   start     end  build_region_index  label
 0  chr1  870000  871000                 174   True
 1  chr1  905000  906000                 204   True
 2  chr1  923000  924000                 220   True
 3  chr1  924000  925000                 221   True
 4  chr1  925000  926000                 222   True,
   chrom  start    end  build_region_index  label
 0  chr1  10000  11000                   0  False
 1  chr1  16000  17000                   1  False
 2  chr1  17000  18000                   2  False
 3  chr1  29000  30000                   3  False
 4  chr1  30000  31000                   4  False)
[5]:
# label = overlap with H3K27me3: True -> classical EZH2 site, False -> non-classical
df_supervised = df1.rename(columns = {"label": "EZH2"}).merge(df2)
df_supervised
[5]:
       chrom      start        end  build_region_index  EZH2  label
0       chr1     870000     871000                 174  True   True
1       chr1     905000     906000                 204  True   True
2       chr1     923000     924000                 220  True   True
3       chr1     924000     925000                 221  True  False
4       chr1     925000     926000                 222  True   True
...      ...        ...        ...                 ...   ...    ...
11003   chrX  154750000  154751000             2134828  True   True
11004   chrY    5001000    5002000             2135730  True  False
11005   chrY   10994000   10995000             2136403  True  False
11006   chrY   26670000   26671000             2137888  True  False
11007   chrY   26671000   26672000             2137889  True  False

11008 rows × 6 columns

[6]:
df_supervised.groupby("label").size() # the dataset is nearly balanced
[6]:
label
False    5272
True     5736
dtype: int64
[7]:
# Split the dataset into training (80%), validation (10%) and test (10%) sets

from sklearn.model_selection import train_test_split
df_train, df_temp = train_test_split(df_supervised, test_size=0.2, random_state=42, stratify = df_supervised['label'])
df_valid, df_test = train_test_split(df_temp, test_size=0.5, random_state=42, stratify = df_temp['label'])

os.makedirs("tmp_ezh2", exist_ok=True)
df_train.to_csv(os.path.join("tmp_ezh2", "train.csv"))
df_valid.to_csv(os.path.join("tmp_ezh2", "valid.csv"))
df_test.to_csv(os.path.join("tmp_ezh2", "test.csv"))

len(df_train), len(df_valid), len(df_test)
[7]:
(8806, 1101, 1101)
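The sizes above follow from how scikit-learn rounds a float test_size. Assuming its documented rule (the test part receives ceil(test_size * n) samples and the train part the remainder), the 80/10/10 split of the 11,008 sites can be checked by hand; split_sizes is a hypothetical helper, not part of scikit-learn.

```python
import math

def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    # Assumes scikit-learn's rule for a float test_size:
    # n_test = ceil(test_size * n), n_train = n - n_test
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_total = 5272 + 5736                        # False + True sites from cell [6]
n_train, n_temp = split_sizes(n_total, 0.2)  # first split: 80 / 20
n_valid, n_test = split_sizes(n_temp, 0.5)   # second split: halve the 20%
print(n_train, n_valid, n_test)              # 8806 1101 1101
```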

Fine-tune

In this section, we fine-tune ChromBERT to distinguish classical from non-classical EZH2 sites. The process closely follows the standard ChromBERT fine-tuning workflow, with two important modifications:

  • Dataset preparation: the ignore_object parameter omits H3K27me3-related cistromes from the ChromBERT input. Because the labels are defined by H3K27me3 co-localization, keeping these cistromes would leak the label into the input.

  • Model instantiation: an ignore_index, obtained from the dataset, is passed to the model configuration so that the model ignores the same cistromes.

Let’s get started!

Dataset preparation: omitting specified regulators

[8]:
dc = chrombert.get_preset_dataset_config(
    "general",
    supervised_file = None,
    ignore = False, ignore_object = "h3k27me3" # turn off omission
    )
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))
ds[1]["input_ids"].shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[8]:
torch.Size([6391])
[9]:
# We omit the h3k27me3 related cistrome, to avoid data leakage
dc = chrombert.get_preset_dataset_config(
    "general",
    supervised_file = None,
    ignore = True, ignore_object = "h3k27me3"
    )
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))

# Get ignore_index used to instantiate model.
# Currently only one ignore object per dataset is supported, so ignore_index
# is the same for every sample and can be taken from any of them.
ignore_index = ds[0]["ignore_index"]
ds[1]["input_ids"].shape
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
[9]:
torch.Size([6188])

As shown above, the dataset works as expected after omitting the specified cistromes: the input sequence length drops from 6391 to 6188, because 203 H3K27me3-related cistromes are omitted and do not take part in training.
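The reduction is plain arithmetic, 6391 - 203 = 6188. As a rough mental model (an assumption, not ChromBERT's actual code), ignoring an object behaves like removing a fixed set of positions from every sample with a boolean mask. In the sketch below the 203 positions are drawn at random purely for illustration; the real positions, and the format of ignore_index, come from ChromBERT's metadata and may differ.

```python
import numpy as np

n_cistromes = 6391                      # full input length (cell [8])
rng = np.random.default_rng(0)

# Hypothetical: positions of the H3K27me3 cistromes. The real ones come from
# ChromBERT's metadata; here we just draw 203 distinct indices at random.
ignore_positions = rng.choice(n_cistromes, size=203, replace=False)

input_ids = np.arange(n_cistromes)      # stand-in for one sample's tokens
keep = np.ones(n_cistromes, dtype=bool)
keep[ignore_positions] = False          # drop the ignored cistromes
reduced = input_ids[keep]
print(reduced.shape)                    # (6188,)
```

Because the same positions are removed from every sample, the reduced inputs stay aligned across the dataset, which is why the model must be told about them too.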

Note: the model is fine-tuned with PyTorch Lightning, so the datasets are wrapped in a lightning.pytorch.LightningDataModule for seamless integration.

[10]:
data_module = chrombert.LitChromBERTFTDataModule(
    config = dc,
    train_params = dict(supervised_file = os.path.join("tmp_ezh2", "train.csv")),
    val_params = dict(supervised_file = os.path.join("tmp_ezh2", "valid.csv")),
    test_params = dict(supervised_file = os.path.join("tmp_ezh2", "test.csv")),
)
data_module
[10]:
<chrombert.finetune.dataset.data_module.LitChromBERTFTDataModule at 0x7f6abdb92f80>

Instantiate the Model

Next, we instantiate the model, passing the ignore_index obtained above.

[11]:
model = chrombert.get_preset_model_config(
    "general",
    ignore = True, ignore_index = ignore_index  # ignore_index from above
).init_model()
model.freeze_pretrain(trainable=2) # we only fine-tune the last two layers
summary(model, depth=2)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Ignoring 203 cistromes and 1 regulators
[11]:
================================================================================
Layer (type:depth-idx)                                  Param #
================================================================================
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               (4,916,736)
│    └─ModuleList: 2-2                                  51,978,240
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               3,249,152
│    └─ResidualBlock: 2-7                               2,166,528
│    └─ResidualBlock: 2-8                               460,032
│    └─Linear: 2-9                                      257
================================================================================
Total params: 62,771,714
Trainable params: 18,871,298
Non-trainable params: 43,900,416
================================================================================
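The trainable and frozen counts can be reconciled with the per-module numbers in the summary. Assuming freeze_pretrain(trainable=2) leaves the last two of the eight encoder blocks trainable and the GeneralHeader fully trainable (consistent with the code comment, but inferred here from the counts rather than from ChromBERT's source), the totals add up exactly. The 251.087 MB that Lightning reports during training corresponds to 4 bytes per parameter.

```python
# Per-module parameter counts copied from the summary above
embedding = 4_916_736                 # BERTEmbedding (frozen)
encoder_block = 6_497_280             # one EncoderTransformerBlock
n_blocks = 8                          # 51,978,240 / 6,497,280
header = 769 + 3_249_152 + 2_166_528 + 460_032 + 257   # GeneralHeader modules

trainable_blocks = 2                  # assumed effect of freeze_pretrain(trainable=2)
trainable = trainable_blocks * encoder_block + header
frozen = embedding + (n_blocks - trainable_blocks) * encoder_block
total = trainable + frozen

print(f"{trainable:,} {frozen:,} {total:,}")
# 18,871,298 43,900,416 62,771,714
print(f"{total * 4 / 1e6:.3f} MB at 4 bytes/param")   # 251.087 MB at 4 bytes/param
```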

Train the Model

We fine-tune the model with PyTorch Lightning, using a simple TrainConfig to hold the training parameters. To save time, training runs for only one epoch and validation is limited to 128 batches.

Note: training involves randomness, so results may vary. For better performance, consider training for more epochs and using a larger dataset.

[12]:
tc = chrombert.finetune.train.TrainConfig(
    kind = "classification",
    loss = "bce", # specify "bce" to use Binary Cross-Entropy (BCE) loss. Use "focal" to apply Focal Loss instead.
    max_epochs = 1,
    lr = 1e-4
)
pl_module = tc.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)
[12]:
chrombert.finetune.train.pl_module.ClassificationPLModule
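The loss argument selects between binary cross-entropy and focal loss. For reference, here are NumPy versions of both, written from the standard definitions; the gamma = 2 and alpha = 0.25 defaults are the commonly used values and are an assumption, since ChromBERT's focal loss may use other settings or weighting. Focal loss scales each sample's cross-entropy by (1 - p_t)^gamma, so confidently classified sites contribute less and training concentrates on hard examples.

```python
import numpy as np

def _sigmoid(logits: np.ndarray) -> np.ndarray:
    # Clip so that log() below never sees exactly 0 or 1
    return np.clip(1.0 / (1.0 + np.exp(-logits)), 1e-7, 1 - 1e-7)

def bce_loss(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean binary cross-entropy from raw logits."""
    p = _sigmoid(logits)
    return float(np.mean(-(targets * np.log(p) + (1 - targets) * np.log(1 - p))))

def focal_loss(logits: np.ndarray, targets: np.ndarray,
               gamma: float = 2.0, alpha: float = 0.25) -> float:
    """Mean binary focal loss: alpha-weighted BCE scaled by (1 - p_t) ** gamma."""
    p = _sigmoid(logits)
    p_t = np.where(targets == 1, p, 1 - p)              # probability of the true class
    alpha_t = np.where(targets == 1, alpha, 1 - alpha)  # class weighting
    return float(np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t)))

logits = np.array([3.0, -3.0, 0.2])   # two confident, correct calls and one uncertain call
targets = np.array([1, 0, 1])
print(bce_loss(logits, targets), focal_loss(logits, targets))
```

With gamma = 0 and alpha = 0.5 the focal loss reduces to half the BCE, which is a quick sanity check on either implementation.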
Next, we start training.
The trainer saves logs in a TensorBoard-compatible format and may write several checkpoints along the way.
In this tutorial, however, we use the final model parameters rather than a saved checkpoint, because the run is too short for checkpoint selection to be meaningful.
[13]:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # for tensorboard compatibility
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{tc.tag}_validation/{tc.loss}", mode = "min") # keep the checkpoint with the lowest validation loss
trainer = pl.Trainer(
    max_epochs = tc.max_epochs,
    accelerator = "gpu",
    precision = "bf16-mixed",
    fast_dev_run = False,
    accumulate_grad_batches = 16, # one optimizer step per 16 batches
    logger = pl.loggers.TensorBoardLogger(os.path.join("tmp_ezh2", "logs"), name = "ezh2"),
    val_check_interval = 128, # validate every 128 training batches
    limit_val_batches = 128, # use at most 128 validation batches per check
    log_every_n_steps = 1,
    callbacks = [callback_ckpt, pl.callbacks.LearningRateMonitor()],
)
trainer.fit(pl_module, data_module)
pl_module.save_ckpt(os.path.join("tmp_ezh2", "ezh2.ckpt"))
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: tmp_ezh2/logs/ezh2
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M
-------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.087   Total estimated model params size (MB)
/miniconda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:480: PossibleUserWarning: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=1` reached.

Use Fine-Tuned Model to Obtain Regulator Embeddings

ChromBERT has been fine-tuned! The tuned model is directly accessible as pl_module.model. However, because of how dropout is configured in flash-attention, calling model.eval() does not disable it, so the model's output may still be stochastic.

To get deterministic results, we recommend saving the checkpoint and loading it into a freshly instantiated model with dropout = 0. This guarantees you are working with the fine-tuned weights without residual dropout.
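Why eval() is not enough: nn.Dropout layers check the module's training flag, but an attention kernel that receives a dropout probability as a plain argument applies it regardless. The NumPy sketch below imitates that situation with a hypothetical attention_dropout function (an assumption about how the flash-attention call is wired, not ChromBERT code); only a probability of 0, which is what dropout = 0 in the config is meant to set, makes repeated passes identical.

```python
import numpy as np

def attention_dropout(x: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical inverted dropout with no train/eval switch.

    Mimics a kernel that takes the dropout probability as a plain argument:
    nothing like model.eval() can turn it off; only p == 0 does.
    """
    if p == 0.0:
        return x
    keep = rng.random(x.shape) >= p          # drop each entry with probability p
    return x * keep / (1.0 - p)              # rescale so the expected value is unchanged

rng = np.random.default_rng(0)
x = np.ones((4, 8))

a, b = attention_dropout(x, 0.1, rng), attention_dropout(x, 0.1, rng)  # p > 0: passes usually differ
c, d = attention_dropout(x, 0.0, rng), attention_dropout(x, 0.0, rng)  # p = 0: identical passes
print(np.array_equal(a, b), np.array_equal(c, d))
```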

[14]:
model_tuned = chrombert.get_preset_model_config(
    "general",
    dropout = 0,
    ignore = True, ignore_index = ignore_index,   # ignore_index from above
    finetune_ckpt = os.path.abspath(os.path.join("tmp_ezh2", "ezh2.ckpt")) # use absolute path here
).init_model()
# or use model_tuned = pl_module.model (dropout may then remain active)
summary(model_tuned, depth = 2)
update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/yangdongxu/work/source/repos/ChromBERT/examples/tutorials/tmp_ezh2/ezh2.ckpt
use organisim hg38; max sequence length is 6391
Ignoring 203 cistromes and 1 regulators
Loading checkpoint from /home/yangdongxu/work/source/repos/ChromBERT/examples/tutorials/tmp_ezh2/ezh2.ckpt
Loaded 110/110 parameters
[14]:
================================================================================
Layer (type:depth-idx)                                  Param #
================================================================================
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               4,916,736
│    └─ModuleList: 2-2                                  51,978,240
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               3,249,152
│    └─ResidualBlock: 2-7                               2,166,528
│    └─ResidualBlock: 2-8                               460,032
│    └─Linear: 2-9                                      257
================================================================================
Total params: 62,771,714
Trainable params: 62,771,714
Non-trainable params: 0
================================================================================

Next, we obtain the embedding manager, as described in the tutorial on extracting embeddings.

[15]:
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)
Ignoring 203 cistromes and 1 regulators
[15]:
================================================================================
Layer (type:depth-idx)                                  Param #
================================================================================
ChromBERTEmbedding                                      --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         7,680
│    │    └─PositionalEmbedding: 3-2                    4,909,056
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                6,497,280
│    │    └─EncoderTransformerBlock: 3-5                6,497,280
│    │    └─EncoderTransformerBlock: 3-6                6,497,280
│    │    └─EncoderTransformerBlock: 3-7                6,497,280
│    │    └─EncoderTransformerBlock: 3-8                6,497,280
│    │    └─EncoderTransformerBlock: 3-9                6,497,280
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock: 3-11               6,497,280
├─CistromeEmbeddingManager: 1-2                         --
================================================================================
Total params: 56,894,976
Trainable params: 56,894,976
Non-trainable params: 0
================================================================================
[16]:
dc_test = data_module.test_config
ds_test = dc_test.init_dataset()
dl_test = dc_test.init_dataloader(batch_size = 1)
len(ds_test), list(ds_test[0].keys())
[16]:
(1101,
 ['input_ids',
  'position_ids',
  'region',
  'build_region_index',
  'ignore_index',
  'label'])
[17]:
# Obtain embeddings for both classical and non-classical EZH2 sites
embs_classicial = []
embs_nonclassicial = []
for batch in tqdm(dl_test):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch)
    if batch["label"].item() == 1:
        embs_classicial.append(emb)
    else:
        embs_nonclassicial.append(emb)

print(len(embs_classicial), len(embs_nonclassicial))
embs_classicial = torch.cat(embs_classicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_nonclassicial = torch.cat(embs_nonclassicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_classicial.shape, embs_nonclassicial.shape
100%|██████████| 1101/1101 [01:28<00:00, 12.39it/s]
573 528
[17]:
((1072, 768), (1072, 768))
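Cell [17] keeps every per-site embedding until the loop ends; judging by the averaged result, each has shape (1, 1072, 768), so ~1,100 sites amount to about 3.6 GB if stored as float32. If memory is tight, a running sum per class yields the same means. The sketch uses NumPy and toy sizes; in the notebook the accumulated values would be torch tensors, and the per-site shape is an inference from the output above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_regulators, hidden = 5, 4            # toy sizes; cell [17] yields 1072 x 768
# Stand-in for the loop over sites: (per-site embedding, is_classical)
stream = [(rng.normal(size=(1, n_regulators, hidden)), i % 2 == 0) for i in range(20)]

sums = {True: np.zeros((n_regulators, hidden)), False: np.zeros((n_regulators, hidden))}
counts = {True: 0, False: 0}
for emb, is_classical in stream:
    sums[is_classical] += emb[0]       # add this site's matrix instead of storing it
    counts[is_classical] += 1

mean_classical = sums[True] / counts[True]
mean_nonclassical = sums[False] / counts[False]

# Same result as the store-then-average approach of cell [17]
stored = np.concatenate([e for e, c in stream if c], axis=0).mean(axis=0)
print(mean_classical.shape, np.allclose(mean_classical, stored))
```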

We focus on transcriptional regulators only, leaving out histone modifications and chromatin accessibility. The cosine similarity between two regulators' embeddings then serves as a measure of their potential interaction in each context.
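Cosine similarity compares the direction of two embedding vectors, cos(u, v) = u·v / (|u| |v|), and ranges from -1 to 1. Cell [21] uses sklearn's cosine_similarity; the NumPy sketch below computes the same pairwise matrix by L2-normalising each row and taking a matrix product, shown on three toy 2-D vectors. The function name cosine_similarity_matrix is our own.

```python
import numpy as np

def cosine_similarity_matrix(x: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of x (n_items x dim)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    unit = x / np.clip(norms, 1e-12, None)   # L2-normalise rows (guard against zero rows)
    return unit @ unit.T

embs = np.array([[1.0, 0.0],
                 [1.0, 1.0],
                 [-1.0, 0.0]])               # three toy "regulator embeddings"
sim = cosine_similarity_matrix(embs)
print(np.round(sim, 3))
```

Rows 0 and 1 are 45 degrees apart (similarity about 0.707), while rows 0 and 2 point in opposite directions (similarity -1).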

[18]:
with open(os.path.join(basedir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]
factors[:3], len(factors)
[18]:
(['adnp', 'aebp2', 'aff1'], 991)
[19]:
factors[-1]
[19]:
'zzz3'
[20]:
indices = np.isin(model_emb.list_regulator, factors) # keep only regulators in the factor list
names = np.array(model_emb.list_regulator)[indices]
embs_classicial = embs_classicial[indices]
embs_nonclassicial = embs_nonclassicial[indices]

[21]:
from sklearn.metrics.pairwise import cosine_similarity
cos_classicial_matrix = cosine_similarity(embs_classicial)
cos_nonclassicial_matrix = cosine_similarity(embs_nonclassicial)
df_cos_classicial = pd.DataFrame(cos_classicial_matrix, columns = names, index = names)
df_cos_nonclassicial = pd.DataFrame(cos_nonclassicial_matrix, columns = names, index = names)
[22]:
# Use the 95th percentile of all pairwise similarities as the cutoff for the most strongly related regulator pairs
thre_class = np.percentile(cos_classicial_matrix.flatten(), 95)
thre_nonclass = np.percentile(cos_nonclassicial_matrix.flatten(), 95)
thre_class, thre_nonclass
[22]:
(0.5776264667510986, 0.5518490076065063)
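The thresholds in cell [22] are percentiles of the full matrix, which includes the diagonal (each regulator's similarity to itself, 1.0) and counts every pair twice. With ~991 regulators the diagonal is only about 0.1% of the entries, so the effect is small. A variant that uses only distinct pairs (i < j) is sketched below; it is an alternative we suggest, not what the tutorial computes, and the helper top_pairs is hypothetical.

```python
import numpy as np

def top_pairs(sim: np.ndarray, names: list[str], q: float = 95.0):
    """Regulator pairs whose similarity exceeds the q-th percentile.

    Only the upper triangle (i < j) is used, so self-similarities on the
    diagonal and mirrored duplicates are excluded from the threshold.
    """
    iu = np.triu_indices_from(sim, k=1)
    values = sim[iu]
    thr = np.percentile(values, q)
    keep = values > thr
    pairs = [(names[i], names[j], float(v))
             for i, j, v in zip(iu[0][keep], iu[1][keep], values[keep])]
    return thr, pairs

sim = np.array([[1.0, 0.9, 0.1, 0.2],
                [0.9, 1.0, 0.3, 0.4],
                [0.1, 0.3, 1.0, 0.5],
                [0.2, 0.4, 0.5, 1.0]])
thr, pairs = top_pairs(sim, ["a", "b", "c", "d"], q=80)
print(thr, pairs)
```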

Finally, we compare how closely EZH2 relates to each regulator at classical versus non-classical sites. Regulators most associated with EZH2 at classical sites are Polycomb components such as SUZ12, EED and JARID2, whereas regulators linked to EZH2's non-classical functions tend to be associated with transcriptional activation, including EP300 and STAT3.
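The ranking below boils down to one subtraction and two filtered sorts. The sketch mirrors it on four rows copied from the outputs of cells [24] and [25], with the cell [22] thresholds rounded to four decimals; boolean indexing is used in place of DataFrame.query, which is equivalent here.

```python
import pandas as pd

# Four rows copied from the outputs of cells [24] and [25]
df = pd.DataFrame(
    {"classical": [0.875248, 0.660526, 0.424211, 0.411506],
     "nonclassical": [0.833695, 0.609975, 0.614199, 0.599940]},
    index=["suz12", "jarid2", "stat3", "ep300"],
)
df["diff"] = df["classical"] - df["nonclassical"]  # > 0: closer to EZH2 at classical sites

thre_class, thre_nonclass = 0.5776, 0.5518         # cell [22] thresholds, rounded
classical_partners = df[df["classical"] > thre_class].sort_values("diff", ascending=False)
nonclassical_partners = df[df["nonclassical"] > thre_nonclass].sort_values("diff")
print(classical_partners.index.tolist())     # ['jarid2', 'suz12']
print(nonclassical_partners.index.tolist())  # ['stat3', 'ep300', 'suz12', 'jarid2']
```

In this subset SUZ12 and JARID2 also exceed the non-classical threshold, which is why the lists are ordered by diff rather than by raw similarity.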

[23]:
df_cos_ezh2 = pd.DataFrame(index = names, data = {
    "classical": df_cos_classicial.loc["ezh2", :],
    "nonclassical": df_cos_nonclassicial.loc["ezh2", :],
})
df_cos_ezh2["diff"] = df_cos_ezh2["classical"] - df_cos_ezh2["nonclassical"] # > 0: closer to EZH2 at classical sites
df_cos_ezh2
[23]:
         classical  nonclassical       diff
adnp      0.193229      0.359759  -0.166530
aebp2     0.190192      0.309486  -0.119294
aff1      0.396427      0.479881  -0.083454
aff4      0.300634      0.452184  -0.151550
ago1      0.234017      0.291643  -0.057626
...            ...           ...        ...
zscan5a   0.218051      0.334104  -0.116052
zta       0.181115      0.251131  -0.070015
zxdb      0.213376      0.294275  -0.080898
zxdc      0.149416      0.287630  -0.138214
zzz3      0.248419      0.346867  -0.098448

991 rows × 3 columns

[24]:
df_cos_ezh2.query("classical > @thre_class ").sort_values("diff", ascending = False).head(10)
[24]:
        classical  nonclassical           diff
ezh1     0.644102      0.578448   6.565380e-02
pcgf1    0.627553      0.564524   6.302953e-02
kdm2b    0.600312      0.543989   5.632281e-02
jarid2   0.660526      0.609975   5.055112e-02
rybp     0.594199      0.549152   4.504716e-02
suz12    0.875248      0.833695   4.155296e-02
bcor     0.616664      0.587787   2.887654e-02
eed      0.597332      0.576189   2.114320e-02
ezh2     1.000000      1.000000  -3.576279e-07
cbx8     0.584771      0.621309  -3.653848e-02
[25]:
df_cos_ezh2.query("nonclassical > @thre_nonclass ").sort_values("diff", ascending = True).head(10)
[25]:
         classical  nonclassical       diff
foxm1     0.395081      0.604784  -0.209703
med1      0.349747      0.556351  -0.206604
stat3     0.424211      0.614199  -0.189988
ep300     0.411506      0.599940  -0.188434
hinfp     0.391241      0.577772  -0.186532
rela      0.385995      0.568461  -0.182466
smarca4   0.383161      0.552061  -0.168900
brca1     0.415508      0.583485  -0.167977
stat1     0.431552      0.594642  -0.163090
e2f4      0.423399      0.568790  -0.145391

The end

This tutorial has walked through context-specific TRN inference, using EZH2's functional collaborations as an example.
We hope you find it both helpful and informative.
If you have any questions or require further assistance, please don’t hesitate to reach out. Thank you for following along!