Compiled Scripts for fine-tuning of ChromBERT
Overview
For a hands-on walkthrough, see Tutorial for fine-tuning ChromBERT in the documentation.
We provide three fine-tuning scripts. Each script can be downloaded and run from any directory, provided that ChromBERT is installed correctly.
For detailed usage instructions, run the following command:
python <script.py> --help
| Type | Download | Description |
|---|---|---|
| Cell-type-specific regulatory effects | ft_general.py | Designed for fine-tuning the model on cell-type-specific regulatory effects. |
| Prompt-enhanced | ft_prompt_enhanced.py | Designed for scenarios that require incorporating additional information into the model. |
| Gene expression prediction | ft_gep.py | Intended for tasks that use multiple 1-kb bins as input, such as gene expression prediction. |
Details
Cell-type-specific regulatory effects
This script enables fine-tuning ChromBERT for analyzing cell-type-specific regulatory effects. Users can selectively perturb or omit specific genomic features, making it valuable for simulating regulatory changes and testing hypotheses about the role of individual regulatory elements in cell-type-specific gene regulation.
python ft_general.py [OPTIONS] --train TRAIN_PATH --valid VALID_PATH --test TEST_PATH
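The same invocation can be assembled from Python, for example to launch several knock-out experiments in a loop. This is an illustrative sketch only: `build_ft_general_cmd` is a hypothetical helper, the uppercase paths and the regulator names are placeholders, and it assumes that --perturbation and --ignore are on/off switches that take no value. It uses only options listed for this script.

```python
import shlex
from typing import List, Optional, Sequence


def build_ft_general_cmd(
    train: str,
    valid: str,
    test: str,
    perturbation_objects: Optional[Sequence[str]] = None,
    perturbation_value: int = 0,
    ignore_objects: Optional[Sequence[str]] = None,
    extra: Sequence[str] = (),
) -> List[str]:
    """Hypothetical helper: build an argv list for ft_general.py."""
    cmd = ["python", "ft_general.py"]
    if perturbation_objects:
        # Assumption: --perturbation is a switch. Objects are joined with ";".
        # A value of 0 means knock-out, 4 means over-expression.
        cmd += [
            "--perturbation",
            "--perturbation-object", ";".join(perturbation_objects),
            "--perturbation-value", str(perturbation_value),
        ]
    if ignore_objects:
        # Assumption: --ignore is a switch used together with --ignore-object.
        cmd += ["--ignore", "--ignore-object", ";".join(ignore_objects)]
    cmd += list(extra)
    cmd += ["--train", train, "--valid", valid, "--test", test]
    return cmd


# Placeholder paths and regulator names, for illustration only.
cmd = build_ft_general_cmd(
    "TRAIN_PATH", "VALID_PATH", "TEST_PATH",
    perturbation_objects=["CTCF", "YY1"],
    extra=["--lr", "5e-5"],
)
print(shlex.join(cmd))
```

Passing the list to subprocess.run(cmd, check=True) launches the run without shell-quoting issues (the ";" separator would otherwise need quoting); shlex.join is used here only to render the command for logging.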
Options
- --lr
Learning rate. Default is 1e-4.
- --warmup-ratio
Warmup ratio. Default is 0.1.
- --grad-samples
Number of samples per gradient update. Scaled automatically according to the batch size and the number of GPUs. Default is 512.
- --max-epochs
Number of epochs to train. Default is 10.
- --pretrain-trainable
Number of pretrained layers to be trainable. Default is 2.
- --tag
Tag of the trainer, used for grouping logged results. Default is default.
- --limit-val-batches
Number of batches to use for each validation. Default is 64.
- --val-check-interval
Validation check interval. Default is 64.
- --name
Name of the trainer. Default is chrombert-ft-general.
- --save-top-k
Number of best checkpoints to keep. Default is 3.
- --checkpoint-metric
Metric used to select checkpoints. Defaults to the loss function if not specified.
- --checkpoint-mode
Checkpoint mode, either min or max. Default is min.
- --log-every-n-steps
Log every n steps. Default is 50.
- --kind
Kind of the task. Choose from classification, regression, or zero_inflation. Default is classification.
- --loss
Loss function. Default is focal.
- --train
Path to the training data. This option is required.
- --valid
Path to the validation data. This option is required.
- --test
Path to the test data. This option is required.
- --batch-size
Batch size. Default is 8.
- --num-workers
Number of data-loading workers. Default is 4.
- --basedir
Path to the base directory. Default is set to the value of
os.path.expanduser("~/.cache/chrombert/data").
- -g, --genome
Genome version, for example hg38 or mm10. Only hg38 is currently supported. Default is hg38.
- -k, --ckpt
Path to the pretrained checkpoint. Optional if it can be inferred from other arguments.
- --mask
Path to the mtx mask file. Optional if it can be inferred from other arguments.
- -d, --hdf5-file
Path to the HDF5 file that contains the dataset. Optional if it can be inferred from other arguments.
- --dropout
Dropout rate. Default is 0.1.
- -hr, --high-resolution
Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is reserved for a future release of ChromBERT and is not available yet.
- --ignore
Ignore the targets specified with --ignore-object.
- --ignore-object
Objects to ignore: regulator names or dataset IDs, separated by ;.
- --perturbation
Use the perturbation model.
- --perturbation-object
Objects to perturb: regulator names or dataset IDs, separated by ;.
- --perturbation-value
Perturbation target level: 0 means knock-out perturbation, and 4 means over-expression perturbation. Default is 0.
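The --grad-samples option is described only as being scaled by the batch size and the number of GPUs. A common way to realize this is gradient accumulation, and the sketch below assumes the scripts accumulate gradients over floor(grad_samples / (batch_size * num_gpus)) steps, with a minimum of 1. The helper `accumulation_steps` is hypothetical and the rounding rule is an assumption, not confirmed behavior of the scripts.

```python
def accumulation_steps(grad_samples: int, batch_size: int, num_gpus: int = 1) -> int:
    """Hypothetical: forward/backward steps accumulated per optimizer update.

    Assumes each optimizer update should see roughly `grad_samples` samples and
    that each step processes `batch_size` samples on each of `num_gpus` GPUs.
    Rounds down with a floor of 1 (the rounding rule is an assumption).
    """
    if grad_samples <= 0 or batch_size <= 0 or num_gpus <= 0:
        raise ValueError("all arguments must be positive")
    samples_per_step = batch_size * num_gpus
    return max(1, grad_samples // samples_per_step)


# Defaults of ft_general.py: --grad-samples 512, --batch-size 8
print(accumulation_steps(512, 8))     # 64 on one GPU
print(accumulation_steps(512, 8, 4))  # 16 on four GPUs
# Defaults of ft_gep.py: --grad-samples 128, --batch-size 2
print(accumulation_steps(128, 2))     # 64 on one GPU
```

Under this assumption, the number of samples per update stays close to --grad-samples when you add GPUs or change --batch-size, so the learning-rate settings need not be retuned for each hardware setup.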
Prompt-enhanced
This script fine-tunes ChromBERT with extra information supplied as prompts. You can include information such as cell-type features or DNA sequence patterns, which the model uses as additional context when analyzing genomic data to improve its predictions.
python ft_prompt_enhanced.py [OPTIONS] --prompt-kind KIND \
--train TRAIN_PATH \
--valid VALID_PATH \
--test TEST_PATH
# use cache file for acceleration
python ft_prompt_enhanced.py [OPTIONS] \
--prompt-kind KIND \
--prompt-regulator-cache-file CACHE_PATH1 \
--prompt-celltype-cache-file CACHE_PATH2 \
--train TRAIN_PATH \
--valid VALID_PATH \
--test TEST_PATH
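When launching many runs, the cached and uncached variants above can be generated from one helper. This is a hypothetical sketch: `build_ft_prompt_cmd` is not part of ChromBERT, the uppercase paths are placeholders, and only options documented for this script are used. It checks --prompt-kind against the two documented choices before building the command.

```python
import shlex
from typing import List, Optional, Sequence

# Choices documented for --prompt-kind.
PROMPT_KINDS = ("cistrome", "expression")


def build_ft_prompt_cmd(
    kind: str,
    train: str,
    valid: str,
    test: str,
    regulator_cache: Optional[str] = None,
    celltype_cache: Optional[str] = None,
    extra: Sequence[str] = (),
) -> List[str]:
    """Hypothetical helper: build an argv list for ft_prompt_enhanced.py."""
    if kind not in PROMPT_KINDS:
        raise ValueError(f"--prompt-kind must be one of {PROMPT_KINDS}, got {kind!r}")
    cmd = ["python", "ft_prompt_enhanced.py", "--prompt-kind", kind]
    # Cache files are optional; add them only when provided.
    if regulator_cache:
        cmd += ["--prompt-regulator-cache-file", regulator_cache]
    if celltype_cache:
        cmd += ["--prompt-celltype-cache-file", celltype_cache]
    cmd += list(extra)
    cmd += ["--train", train, "--valid", valid, "--test", test]
    return cmd


# Placeholder paths; a larger --batch-size is suggested for this script.
cmd = build_ft_prompt_cmd(
    "cistrome", "TRAIN_PATH", "VALID_PATH", "TEST_PATH",
    regulator_cache="CACHE_PATH1", celltype_cache="CACHE_PATH2",
    extra=["--batch-size", "64"],
)
print(shlex.join(cmd))
```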
Options
- --lr
Learning rate. Default is 1e-4.
- --warmup-ratio
Warmup ratio. Default is 0.1.
- --grad-samples
Number of samples per gradient update. Scaled automatically according to the batch size and the number of GPUs. Default is 512.
- --pretrain-trainable
Number of pretrained layers to be trainable. Default is 0.
- --max-epochs
Number of epochs to train. Default is 10.
- --tag
Tag of the trainer, used for grouping logged results. Default is default.
- --limit-val-batches
Number of batches to use for each validation. Default is 64.
- --val-check-interval
Validation check interval. Default is 64.
- --name
Name of the trainer. Default is chrombert-ft-prompt-enhanced.
- --save-top-k
Number of best checkpoints to keep. Default is 3.
- --checkpoint-metric
Metric used to select checkpoints. Default is bce.
- --checkpoint-mode
Checkpoint mode, either min or max. Default is min.
- --log-every-n-steps
Log every n steps. Default is 50.
- --kind
Kind of the task. Choose from classification, regression, or zero_inflation. Default is classification.
- --loss
Loss function. Default is focal.
- --train
Path to the training data. This option is required.
- --valid
Path to the validation data. This option is required.
- --test
Path to the test data. This option is required.
- --batch-size
Batch size. Default is 8. A larger value is recommended here to accelerate training.
- --num-workers
Number of data-loading workers. Default is 4.
- --basedir
Path to the base directory. Default is set to the value of
os.path.expanduser("~/.cache/chrombert/data").
- -g, --genome
Genome version, for example hg38 or mm10. Only hg38 is currently supported. Default is hg38.
- -k, --ckpt
Path to the checkpoint used to initialize the model. Optional. Default is the pretrained checkpoint provided in the base directory.
- -d, --hdf5-file
Path to the HDF5 file that contains the dataset. Optional if it can be inferred from other arguments.
- --dropout
Dropout rate. Default is 0.1.
- -hr, --high-resolution
Use 200-bp resolution instead of 1-kb resolution. Caution: 200-bp resolution is reserved for a future release of ChromBERT and is not available yet.
- --prompt-kind
Prompt data class. Choose from cistrome or expression. This option is required and has no default.
- --prompt-dim-external
Dimension of external data. Use 512 for scGPT embeddings and 768 for ChromBERT embeddings. Default is 512.
- --prompt-celltype-cache-file
Path to the cell-type-specific prompt cache file. Optional; provide it to accelerate training. Not used by default.
- --prompt-regulator-cache-file
Path to the regulator prompt cache file. Optional; provide it to accelerate training. Not used by default.
Gene expression prediction
Gene expression is influenced by multiple regulatory regions, concentrated around the transcription start site (TSS) but often spread over considerable genomic distances. This task therefore uses a flank window to include multiple nearby regions, giving a more complete view of the regulatory inputs to gene expression.
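To make the flank window concrete, the sketch below lists the 1-kb bins that would be considered for one gene. It assumes that --flank-window counts bins on each side of the bin containing the TSS, so the default of 4 gives 2 * 4 + 1 = 9 bins. That interpretation and the helper `flank_bins` are assumptions for illustration, not the script's confirmed behavior; the sketch clips at the chromosome start but not at the chromosome end.

```python
from typing import List

BIN_SIZE = 1_000  # 1-kb bins, the default resolution


def flank_bins(tss_position: int, flank_window: int = 4) -> List[int]:
    """Hypothetical: indices of 1-kb bins within `flank_window` bins of the TSS bin."""
    if flank_window < 0:
        raise ValueError("flank_window must be non-negative")
    tss_bin = tss_position // BIN_SIZE  # bin containing the TSS (0-based coordinate)
    first = max(0, tss_bin - flank_window)  # clip at the chromosome start only
    return list(range(first, tss_bin + flank_window + 1))


bins = flank_bins(1_234_567)  # placeholder TSS coordinate
print(bins)       # [1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238]
print(len(bins))  # 9
```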
python ft_gep.py [OPTIONS] --flank-window FLANK_WINDOW_SIZE \
--train TRAIN_PATH \
--valid VALID_PATH \
--test TEST_PATH
Options
- --lr
Learning rate. Default is 1e-4.
- --warmup-ratio
Warmup ratio for the learning rate. Default is 0.1.
- --grad-samples
Number of samples per gradient update, scaled according to the batch size and the number of GPUs. Default is 128.
- --pretrain-trainable
Number of pretrained layers to be trainable. Default is 2.
- --max-epochs
Maximum number of training epochs. Default is 10.
- --tag
Tag of the trainer, used for grouping logged results. Default is default.
- --limit-val-batches
Number of batches to use for each validation. Default is 64.
- --val-check-interval
Interval for validation checks. Default is 64.
- --name
Name of the training session. Default is chrombert-ft-gep.
- --save-top-k
Number of top-performing checkpoints to save. Default is 3.
- --checkpoint-metric
Metric for checkpointing. Default is pcc.
- --checkpoint-mode
Mode for checkpointing, either min or max. Default is max.
- --log-every-n-steps
Logging frequency in steps. Default is 50.
- --kind
Type of task, such as regression or zero_inflation. Default is regression.
- --loss
Loss function to be used. Default is rmse.
- --train
Path to the training data. This option is required.
- --valid
Path to the validation data. This option is required.
- --test
Path to the test data. This option is required.
- --batch-size
Batch size for training. Default is 2.
- --num-workers
Number of workers for data loading. Default is 4.
- --basedir
Path to the base directory for model and data files. Default is
os.path.expanduser("~/.cache/chrombert/data").
- -g, --genome
Genome version. Only hg38 is currently supported. Default is hg38.
- -k, --ckpt
Path to the pretrained checkpoint. Optional if it can be inferred from other arguments.
- --mask
Path to the mtx mask file. Optional if it can be inferred from other arguments.
- -d, --hdf5-file
Path to the HDF5 file that contains the dataset. Optional if it can be inferred from other arguments.
- --dropout
Dropout rate for the model. Default is 0.1.
- -hr, --high-resolution
Use 200-bp resolution instead of 1-kb resolution. Note: 200-bp resolution is reserved for a future release and is not available yet.
- --flank-window
Flank window size for genomic data embedding. Default is 4.
- --gep-zero-inflation
Whether to include zero inflation in the gene expression prediction (GEP) header. Default is False.
- --gep-parallel-embedding
Enable parallel embedding, which is faster but requires more GPU memory.
- --gep-gradient-checkpoint
Use gradient checkpointing to reduce GPU memory usage during training.
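The checkpoint options work together: --checkpoint-metric picks the score, --checkpoint-mode says whether lower (min, as for a loss) or higher (max, as for pcc) is better, and --save-top-k limits how many checkpoints are kept. The sketch below mimics that selection with a hypothetical `top_k_checkpoints` helper and made-up scores; it is not the scripts' own checkpointing code.

```python
from typing import Dict, List


def top_k_checkpoints(scores: Dict[str, float], k: int = 3, mode: str = "max") -> List[str]:
    """Hypothetical: names of the k best checkpoints under `mode`."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # "max" (e.g. pcc): higher is better. "min" (e.g. a loss): lower is better.
    ranked = sorted(scores, key=scores.get, reverse=(mode == "max"))
    return ranked[:k]


# Made-up validation pcc values, for illustration only.
pcc = {"step64.ckpt": 0.61, "step128.ckpt": 0.72, "step192.ckpt": 0.68, "step256.ckpt": 0.70}
print(top_k_checkpoints(pcc, k=3, mode="max"))
# ['step128.ckpt', 'step256.ckpt', 'step192.ckpt']
```

This is why ft_gep.py pairs pcc with max, while the other two scripts default to min for their loss-based metrics.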