Compiled Scripts for fine-tuning of ChromBERT

Overview

For a hands-on walkthrough, see the Tutorial for fine-tuning ChromBERT.

We provide three ready-to-use scripts for fine-tuning. Each script can be downloaded and run from any directory, provided that ChromBERT is installed correctly.

For detailed usage instructions, run the following command:

python <script.py> --help

Fine-Tune Scripts

| Type | Download | Description |
| --- | --- | --- |
| Cell-type-specific regulatory effects | download | Designed for fine-tuning the model on cell-type-specific regulatory effects. |
| Prompt-enhanced | download | Designed for scenarios that require incorporating additional information into the model. |
| Gene expression prediction | download | Intended for tasks that use multiple 1-kb bins as input, such as gene expression prediction. |

Details

Cell-type-specific regulatory effects

This script enables fine-tuning ChromBERT for analyzing cell-type-specific regulatory effects. Users can selectively perturb or omit specific genomic features, making it valuable for simulating regulatory changes and testing hypotheses about the role of individual regulatory elements in cell-type-specific gene regulation.

python ft_general.py [OPTIONS] --train TRAIN_PATH --valid VALID_PATH --test TEST_PATH

Options

--lr

Learning rate. Default is 1e-4.

--warmup-ratio

Warmup ratio. Default is 0.1.

--grad-samples

Number of gradient samples. Automatically scaled according to the batch size and GPU number. Default is 512.

--max-epochs

Number of epochs to train. Default is 10.

--pretrain-trainable

Number of pretrained layers to be trainable. Default is 2.

--tag

Tag of the trainer, used for grouping logged results. Default is default.

--limit-val-batches

Number of batches to use for each validation. Default is 64.

--val-check-interval

Validation check interval. Default is 64.

--name

Name of the trainer. Default is chrombert-ft-general.

--save-top-k

Save top k checkpoints. Default is 3.

--checkpoint-metric

Checkpoint metric. Default is the same as the loss function if not specified.

--checkpoint-mode

Checkpoint mode. Default is min.

--log-every-n-steps

Log every n steps. Default is 50.

--kind

Kind of the task. Choose from classification, regression, or zero_inflation. Default is classification.

--loss

Loss function. Default is focal.

--train

Path to the training data. This option is required.

--valid

Path to the validation data. This option is required.

--test

Path to the test data. This option is required.

--batch-size

Batch size. Default is 8.

--num-workers

Number of workers. Default is 4.

--basedir

Path to the base directory. Default is set to the value of os.path.expanduser("~/.cache/chrombert/data").

-g, --genome

Genome version. For example, hg38 or mm10. Only hg38 is supported now. Default is hg38.

-k, --ckpt

Path to the pretrain checkpoint. Optional if it could be inferred from other arguments.

--mask

Path to the mtx mask file. Optional if it could be inferred from other arguments.

-d, --hdf5-file

Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments.

--dropout

Dropout rate. Default is 0.1.

-hr, --high-resolution

Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is reserved for a future release of ChromBERT and is not available yet.

--ignore

Ignore given targets.

--ignore-object

Objects to ignore: regulators or dataset IDs, separated by ;.

--perturbation

Use perturbation model.

--perturbation-object

Objects to perturb: regulators or dataset IDs, separated by ;.

--perturbation-value

Perturbation target level. 0 means knock-out perturbation, and 4 means over-expression perturbation. Default is 0.
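
As an illustration of how the perturbation options combine, the following minimal sketch assembles an ft_general.py command line using only the options documented above. The file names and regulator names are hypothetical placeholders, and the sketch assumes that --perturbation is an on/off switch that takes no value; check `python ft_general.py --help` before relying on it.

```python
import shlex


def build_ft_general_command(train, valid, test,
                             perturb_objects=None, perturb_value=0):
    """Assemble an ft_general.py command line from documented options."""
    cmd = ["python", "ft_general.py",
           "--train", train, "--valid", valid, "--test", test]
    if perturb_objects:
        # Assumption: --perturbation is a switch without an argument.
        cmd += ["--perturbation",
                # Multiple objects are joined with ";" as the option describes.
                "--perturbation-object", ";".join(perturb_objects),
                # 0 = knock-out, 4 = over-expression (per the option text).
                "--perturbation-value", str(perturb_value)]
    return cmd


# Hypothetical data files and regulator names, for illustration only.
cmd = build_ft_general_command("train.csv", "valid.csv", "test.csv",
                               perturb_objects=["ctcf", "ep300"])

# shlex.join quotes the ";"-separated list so the shell does not split it.
print(shlex.join(cmd))
```

When typing the command by hand, the ;-separated list likewise needs to be quoted, because an unquoted ; ends the command in most shells.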


Prompt-enhanced

This script allows you to fine-tune ChromBERT by adding extra information as prompts. You can include things like cell-type features or DNA sequence patterns to help the model make better predictions. The model uses these prompts as additional clues when analyzing genomic data.

python ft_prompt_enhanced.py [OPTIONS] --prompt-kind KIND \
    --train TRAIN_PATH \
    --valid VALID_PATH \
    --test TEST_PATH

# use cache file for acceleration
python ft_prompt_enhanced.py [OPTIONS] \
    --prompt-kind KIND  \
    --prompt-regulator-cache-file CACHE_PATH1 \
    --prompt-celltype-cache-file CACHE_PATH2 \
    --train TRAIN_PATH \
    --valid VALID_PATH \
    --test TEST_PATH

Options

--lr

Learning rate. Default is 1e-4.

--warmup-ratio

Warmup ratio. Default is 0.1.

--grad-samples

Number of gradient samples. Automatically scaled according to the batch size and GPU number. Default is 512.

--pretrain-trainable

Number of pretrained layers to be trainable. Default is 0.

--max-epochs

Number of epochs to train. Default is 10.

--tag

Tag of the trainer, used for grouping logged results. Default is default.

--limit-val-batches

Number of batches to use for each validation. Default is 64.

--val-check-interval

Validation check interval. Default is 64.

--name

Name of the trainer. Default is chrombert-ft-prompt-enhanced.

--save-top-k

Save top k checkpoints. Default is 3.

--checkpoint-metric

Checkpoint metric. Default is bce.

--checkpoint-mode

Checkpoint mode. Default is min.

--log-every-n-steps

Log every n steps. Default is 50.

--kind

Kind of the task. Choose from classification, regression, or zero_inflation. Default is classification.

--loss

Loss function. Default is focal.

--train

Path to the training data. This option is required.

--valid

Path to the validation data. This option is required.

--test

Path to the test data. This option is required.

--batch-size

Batch size. Default is 8. A larger value is recommended here to speed up training.

--num-workers

Number of workers. Default is 4.

--basedir

Path to the base directory. Default is set to the value of os.path.expanduser("~/.cache/chrombert/data").

-g, --genome

Genome version. For example, hg38 or mm10. Only hg38 is supported now. Default is hg38.

-k, --ckpt

Path to the checkpoint used to initialize the model. Optional. Default is the pretrain checkpoint provided in the base directory.

--mask

Path to the mtx mask file. Optional if it could be inferred from other arguments.

-d, --hdf5-file

Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments.

--dropout

Dropout rate. Default is 0.1.

-hr, --high-resolution

Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is reserved for a future release of ChromBERT and is not available yet.

--prompt-kind

Kind of prompt data. Choose from cistrome or expression. There is no default (None), so this option is required.

--prompt-dim-external

Dimension of external data. Use 512 for scGPT, and 768 for ChromBERT’s embedding. Default is 512.

--prompt-celltype-cache-file

Path to the cell-type-specific prompt cache file. Provide it to accelerate training with cached prompts. Optional. By default, no cache file is used.

--prompt-regulator-cache-file

Path to the regulator prompt cache file. Provide it to accelerate training with cached prompts. Optional. By default, no cache file is used.
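
Because the cache files are optional, a launcher can add them only when they are present. The sketch below shows one way to do this with the options documented above; the helper name and all paths are hypothetical, and it only builds the command without running it.

```python
import shlex
from pathlib import Path


def build_prompt_command(prompt_kind, train, valid, test,
                         regulator_cache=None, celltype_cache=None):
    """Assemble an ft_prompt_enhanced.py command from documented options."""
    # --prompt-kind is required and limited to these two documented choices.
    if prompt_kind not in ("cistrome", "expression"):
        raise ValueError(f"unsupported prompt kind: {prompt_kind!r}")

    cmd = ["python", "ft_prompt_enhanced.py", "--prompt-kind", prompt_kind]

    # Pass a cache flag only if the file actually exists on disk.
    optional = (("--prompt-regulator-cache-file", regulator_cache),
                ("--prompt-celltype-cache-file", celltype_cache))
    for flag, path in optional:
        if path is not None and Path(path).is_file():
            cmd += [flag, str(path)]

    cmd += ["--train", train, "--valid", valid, "--test", test]
    return cmd


# Hypothetical paths; the missing cache file is silently skipped.
cmd = build_prompt_command("expression", "train.csv", "valid.csv", "test.csv",
                           regulator_cache="no_such_dir/regulator_cache.h5")
print(shlex.join(cmd))
```

Skipping a missing cache keeps the run valid, only slower, which matches the description of the cache files as an optional acceleration.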


Gene expression prediction

Gene expression is influenced by multiple regulatory regions, which can be spread over considerable genomic distances around the transcription start site (TSS). This task uses a flank window to take several neighboring regions into account, giving a more complete view of regulatory effects on gene expression.

python ft_gep.py [OPTIONS] --flank-window FLANK_WINDOW_SIZE \
    --train TRAIN_PATH \
    --valid VALID_PATH \
    --test TEST_PATH
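
To make the flank window concrete, here is a minimal sketch of which 1-kb bins a window might cover. It assumes the window extends FLANK_WINDOW_SIZE bins on each side of the bin containing the TSS; this symmetric layout is an assumption for illustration, and the script's actual windowing may differ. The function name and the TSS coordinate are hypothetical.

```python
BIN_SIZE = 1000  # 1-kb bins, as used by the fine-tuning scripts


def flank_bins(center_bin, flank_window=4):
    """Bin indices covered by a symmetric flank window (assumed layout)."""
    return list(range(center_bin - flank_window,
                      center_bin + flank_window + 1))


tss = 1_000_500            # hypothetical TSS coordinate in bp
center = tss // BIN_SIZE   # index of the 1-kb bin containing the TSS
bins = flank_bins(center, flank_window=4)

# Under this assumption a window of 4 spans 2 * 4 + 1 = 9 bins.
print(len(bins), bins[0], bins[-1])
```

Under this reading, a larger window increases the number of bins per gene, and with it the memory needed per sample.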

Options

--lr

Learning rate. Default is 1e-4.

--warmup-ratio

Warmup ratio for the learning rate. Default is 0.1.

--grad-samples

Number of gradient samples, scaled by batch size and GPU count. Default is 128.

--pretrain-trainable

Number of pretrained layers to be trainable. Default is 2.

--max-epochs

Maximum number of training epochs. Default is 10.

--tag

Tag of the trainer, used for grouping logged results. Default is default.

--limit-val-batches

Number of batches to use for each validation. Default is 64.

--val-check-interval

Interval for validation checks. Default is 64.

--name

Name of the training session. Default is chrombert-ft-gep.

--save-top-k

Number of top-performing checkpoints to save. Default is 3.

--checkpoint-metric

Metric for checkpointing. Default is pcc.

--checkpoint-mode

Mode for checkpointing. Default is max.

--log-every-n-steps

Logging frequency in terms of steps. Default is 50.

--kind

Type of task: regression or zero_inflation. Default is regression.

--loss

Loss function to be used. Default is rmse.

--train

Path to the training data. This option is required.

--valid

Path to the validation data. This option is required.

--test

Path to the test data. This option is required.

--batch-size

Batch size for training. Default is 2.

--num-workers

Number of workers for data loading. Default is 4.

--basedir

Path to the base directory for model and data files. Default is os.path.expanduser("~/.cache/chrombert/data").

-g, --genome

Genome version. Only hg38 is supported now. Default is hg38.

-k, --ckpt

Path to the pretrained checkpoint. Optional if it could be inferred from other arguments.

--mask

Path to the mtx mask file. Optional if it could be inferred from other arguments.

-d, --hdf5-file

Path to the HDF5 file that contains the dataset. Optional if it could be inferred from other arguments.

--dropout

Dropout rate for the model. Default is 0.1.

-hr, --high-resolution

Use 200-bp resolution instead of 1-kb resolution. Note: 200-bp resolution is reserved for a future release and is not available yet.

--flank-window

Flank window size for genomic data embedding. Default is 4.

--gep-zero-inflation

Specifies whether to include zero inflation in the GEP header. Default is False.

--gep-parallel-embedding

Enable parallel embedding, which is faster but requires more GPU memory.

--gep-gradient-checkpoint

Use gradient checkpointing to reduce GPU memory usage during training.
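
The --grad-samples option is described in all three option lists as being scaled by batch size and GPU count. One plausible reading, which is an assumption and not confirmed on this page, is that the scripts accumulate gradients until roughly that many samples have been seen. The sketch below works through the arithmetic with the documented defaults; the function name is hypothetical.

```python
def accumulation_steps(grad_samples, batch_size, num_gpus=1):
    """Batches to accumulate per optimizer step (assumed interpretation)."""
    samples_per_step = batch_size * num_gpus
    # Never go below one batch per update.
    return max(1, grad_samples // samples_per_step)


# Documented defaults: (--grad-samples, --batch-size) for each script.
defaults = {
    "ft_general.py": (512, 8),
    "ft_prompt_enhanced.py": (512, 8),
    "ft_gep.py": (128, 2),
}

for script, (grad_samples, batch_size) in defaults.items():
    steps = accumulation_steps(grad_samples, batch_size, num_gpus=1)
    print(f"{script}: accumulate {steps} batches per update")
```

If this interpretation holds, raising --batch-size or adding GPUs lowers the number of accumulated batches while keeping the number of samples per update about the same.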