GRNFormer is a variational graph transformer autoencoder model designed to accurately infer regulatory relationships between transcription factors (TFs) and target genes from single-cell RNA-seq transcriptomics data, while generalizing across species and cell types.
GRNFormer introduces three main novel designs:
- TFWalker: A de novo, transcription factor (TF)-centered subgraph sampling method that extracts the local co-expression neighborhood of each TF to facilitate GRN inference.
- End-to-End Learning:
  - GeneTranscoder: A transformer encoder representation module that encodes single-cell RNA-seq (scRNA-seq) gene expression data across different species and cell types.
  - A graph transformer model with a GRNFormer encoder and a variational GRNFormer decoder, coupled with a GRN inference module, for reconstructing GRNs.
- Novel Inference Strategy: Incorporates both node features and edge features to infer GRNs from gene expression data of any length.
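This README does not spell out how TFWalker traverses the co-expression network. Purely to illustrate what TF-centered sampling means, the sketch below assumes a breadth-first expansion from a TF that visits strongly co-expressed neighbors first and stops once `max_subgraph_size` nodes are collected. The function `sample_tf_subgraph` and the adjacency-dict graph format are hypothetical, not GRNFormer's API.

```python
from collections import deque


def sample_tf_subgraph(graph, tf, max_subgraph_size=100):
    """Collect up to max_subgraph_size genes around `tf` by breadth-first search.

    graph: dict mapping gene -> {neighbor: co-expression weight}.
    Ordering neighbors by absolute weight is an assumption for illustration,
    not necessarily how TFWalker chooses its next step.
    """
    if tf not in graph:
        return set()
    visited = {tf}
    queue = deque([tf])
    while queue and len(visited) < max_subgraph_size:
        node = queue.popleft()
        # Visit the most strongly co-expressed neighbors first.
        neighbors = sorted(graph.get(node, {}).items(), key=lambda kv: -abs(kv[1]))
        for neighbor, _weight in neighbors:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
                if len(visited) >= max_subgraph_size:
                    break
    return visited
```

The function returns only the node set; the edges of the induced subgraph can be read back from `graph` for those nodes.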
Given a scRNA-seq dataset, GRNFormer first constructs a gene co-expression network, from which TFWalker samples a set of TF-centered subgraphs. GeneTranscoder processes the subgraphs to generate node and edge embeddings, which are fed to the variational graph transformer autoencoder to learn a GRN representation. This representation is used to infer a gene regulatory sub-network for each subgraph, and the sub-networks are then aggregated into a full GRN.
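How the sub-networks are merged is not detailed here. A minimal sketch, assuming each sub-network is a dict mapping `(source, target)` to a score and that an edge predicted in several overlapping subgraphs receives the mean of its scores (taking the maximum would be an equally plausible rule). `aggregate_subnetworks` is a hypothetical helper.

```python
from collections import defaultdict


def aggregate_subnetworks(subnetworks):
    """Merge per-subgraph edge scores into one GRN by averaging (an assumption).

    subnetworks: iterable of dicts mapping (source, target) -> score.
    Returns a list of (source, target, score) sorted by descending score.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for subnet in subnetworks:
        for edge, score in subnet.items():
            totals[edge] += score
            counts[edge] += 1
    merged = [(src, tgt, totals[(src, tgt)] / counts[(src, tgt)]) for (src, tgt) in totals]
    # Highest-confidence edges first, matching a ranked edge list.
    merged.sort(key=lambda row: row[2], reverse=True)
    return merged
```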
- Python 3.11+
- CUDA-capable GPU (recommended for training)
- Conda or Miniconda
- Clone the repository:
git clone https://github.com/BioinfoMachineLearning/GRNformer.git
cd GRNformer
- Set up the conda environment and install the necessary packages using the setup script:
./setup.sh
Alternatively, you can create the environment manually:
conda env create -f environment.yml
conda activate grnformer_env
Run GRNFormer inference on a sample gene expression file:
python infer_grn.py \
--exp_file /path/to/expression-file.csv \
--tf_file /path/to/listoftfs.csv \
--output_file /path/to/predicted-edges.csv \
--coexpression_threshold 0.1 \
--max_subgraph_size 100
Input File Formats:
- expression-file.csv: Gene expression matrix with genes as rows and cells as columns (or vice versa; the script handles both orientations)
- listoftfs.csv: List of transcription factor gene names (one per line or comma-separated)
- output_file: Path where the predicted GRN edges will be saved (CSV format: source, target, weight/score)
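How infer_grn.py detects the matrix orientation is not documented. The sketch below shows one plausible heuristic, stated as an assumption: transpose the matrix when more TF names appear among the column labels than among the row labels. It also parses a TF list written one name per line or comma-separated. Both helpers (`read_tf_list`, `read_expression`) are hypothetical and use only the standard library.

```python
import csv


def read_tf_list(path):
    """Return TF names from a file with one name per line or comma-separated names."""
    with open(path) as handle:
        text = handle.read()
    # Treat commas like newlines, then drop blanks and surrounding whitespace.
    return [name.strip() for name in text.replace(",", "\n").splitlines() if name.strip()]


def read_expression(path, tf_names):
    """Read an expression CSV and return (genes, cells, matrix) with genes as rows.

    Assumes the first row holds column labels and the first column holds row
    labels. Heuristic (an assumption): if more TF names match the column labels
    than the row labels, genes are columns, so the matrix is transposed.
    """
    with open(path, newline="") as handle:
        rows = [row for row in csv.reader(handle) if row]
    col_labels = rows[0][1:]
    row_labels = [row[0] for row in rows[1:]]
    matrix = [[float(value) for value in row[1:]] for row in rows[1:]]
    tfs = set(tf_names)
    if len(tfs & set(col_labels)) > len(tfs & set(row_labels)):
        matrix = [list(column) for column in zip(*matrix)]
        row_labels, col_labels = col_labels, row_labels
    return row_labels, col_labels, matrix
```

A TF file that starts with a header line (for example `TF`) would need that line removed before using `read_tf_list` as written.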
Optional Parameters:
- --coexpression_threshold (default: 0.1): Threshold for constructing the co-expression network. Lower values result in denser networks; higher values result in sparser networks.
- --max_subgraph_size (default: 100): Maximum number of nodes in each TF-centered subgraph sampled by TFWalker. Adjust it to your dataset size and computational resources.
Run GRNFormer to evaluate performance when a ground truth network is available:
python eval_grn.py \
--exp_file /path/to/expression-file.csv \
--tf_file /path/to/listoftfs.csv \
--net_file /path/to/ground-truth-network.csv \
--output_file /path/to/predicted-edges.csv
Additional Input:
- ground-truth-network.csv: Ground-truth network edges (CSV format: source, target)
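The metrics reported by eval_grn.py are not listed in this README. AUROC is a common choice for scoring a ranked edge list against a ground-truth network, so as an assumption-labeled illustration, the sketch below computes it from the predicted-edges format (source, target, score), treating scored edges absent from the ground truth as negatives. `auroc` is a hypothetical helper that uses the pairwise (Mann-Whitney) form for clarity rather than speed.

```python
def auroc(predictions, true_edges):
    """AUROC of edge scores, with labels taken from a ground-truth edge set.

    predictions: list of (source, target, score).
    true_edges: set of (source, target) pairs.
    Each (positive, negative) pair ordered correctly counts 1; ties count 0.5.
    """
    pos = [score for src, tgt, score in predictions if (src, tgt) in true_edges]
    neg = [score for src, tgt, score in predictions if (src, tgt) not in true_edges]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative edge")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

For example, scores 0.9 and 0.3 for two true edges against 0.8 and 0.1 for two false ones give 3 correct orderings out of 4, so an AUROC of 0.75.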
To evaluate with a custom checkpoint, co-expression threshold, and subgraph size:
python eval_grn_custom.py \
--exp_file /path/to/expression-file.csv \
--tf_file /path/to/listoftfs.csv \
--net_file /path/to/ground-truth-network.csv \
--output_file /path/to/predicted-edges.csv \
--ckpt_path /path/to/checkpoint.ckpt \
--coexpression_threshold 0.1 \
--max_subgraph_size 100
Additional Parameters:
- --ckpt_path: Path to the trained model checkpoint file
- --coexpression_threshold (default: 0.1): Threshold for co-expression network construction
- --max_subgraph_size (default: 100): Maximum subgraph size for TFWalker sampling
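To run eval_grn_custom.py over several settings, a small driver can assemble each command from the documented flags. The sketch below is hypothetical (`build_eval_command` is not part of the repository) and only prints the commands; replace `print` with `subprocess.run(cmd, check=True)` to execute them from the repository root.

```python
import shlex
import sys


def build_eval_command(exp_file, tf_file, net_file, output_file, ckpt_path,
                       coexpression_threshold=0.1, max_subgraph_size=100):
    """Build the argument list for eval_grn_custom.py using its documented flags."""
    return [
        sys.executable, "eval_grn_custom.py",
        "--exp_file", str(exp_file),
        "--tf_file", str(tf_file),
        "--net_file", str(net_file),
        "--output_file", str(output_file),
        "--ckpt_path", str(ckpt_path),
        "--coexpression_threshold", str(coexpression_threshold),
        "--max_subgraph_size", str(max_subgraph_size),
    ]


if __name__ == "__main__":
    # Hypothetical sweep: one run, and one output file, per threshold.
    for threshold in (0.05, 0.1, 0.2):
        cmd = build_eval_command(
            "/path/to/expression-file.csv",
            "/path/to/listoftfs.csv",
            "/path/to/ground-truth-network.csv",
            f"/path/to/predicted-edges-t{threshold}.csv",
            "/path/to/checkpoint.ckpt",
            coexpression_threshold=threshold,
        )
        print(shlex.join(cmd))
```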
Evaluate model robustness under various perturbation conditions (noise and dropout):
Single test with specific perturbation:
python eval_grn_perturb.py \
--single_test \
--exp_file /path/to/expression-file.csv \
--tf_file /path/to/listoftfs.csv \
--net_file /path/to/ground-truth-network.csv \
--output_file /path/to/predicted-edges.csv \
--ckpt_path /path/to/checkpoint.ckpt \
--noise_std 0.1 \
--dropout_fraction 0.05 \
--coexpression_threshold 0.1 \
--max_subgraph_size 100
Full perturbation sweep (tests multiple noise and dropout levels):
python eval_grn_perturb.py \
--exp_file /path/to/expression-file.csv \
--tf_file /path/to/listoftfs.csv \
--net_file /path/to/ground-truth-network.csv \
--output_file /path/to/predicted-edges.csv \
--ckpt_path /path/to/checkpoint.ckpt \
--noise_levels 0.0 0.05 0.1 0.15 0.2 \
--dropout_levels 0.0 0.05 0.1 0.15 \
--output_dir ./outputs/perturbation_results \
--coexpression_threshold 0.1 \
--max_subgraph_size 100
Perturbation Parameters:
- --noise_std: Standard deviation of the Gaussian noise added to the expression data (single test)
- --dropout_fraction: Fraction of genes to randomly drop (single test)
- --noise_levels: Space-separated list of noise levels for the sweep (e.g., "0.0 0.05 0.1 0.15 0.2")
- --dropout_levels: Space-separated list of dropout fractions for the sweep (e.g., "0.0 0.05 0.1 0.15")
- --absolute_noise: Use absolute noise values instead of scaled ones (by default, noise is scaled relative to the data's standard deviation)
- --output_dir: Directory in which to save perturbation sweep results
- --coexpression_threshold (default: 0.1): Threshold for co-expression network construction
- --max_subgraph_size (default: 100): Maximum subgraph size for TFWalker sampling
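One way to read these perturbation flags is sketched below; it is an assumption, not the code in eval_grn_perturb.py. Gaussian noise uses a standard deviation of `noise_std` times the data's standard deviation (or `noise_std` itself when absolute noise is requested), and dropout is modeled by zeroing the expression of a random `dropout_fraction` of genes. `perturb_expression` is hypothetical and works on a genes x cells list of lists.

```python
import random
import statistics


def perturb_expression(matrix, noise_std=0.1, dropout_fraction=0.05,
                       absolute_noise=False, seed=0):
    """Return a perturbed copy of a genes x cells matrix (list of lists).

    Noise sigma is noise_std * std(all values), or noise_std when
    absolute_noise is True. Dropout zeroes whole gene rows. Both rules are
    assumptions mirroring the flag descriptions.
    """
    rng = random.Random(seed)  # seeded for reproducible perturbations
    values = [x for row in matrix for x in row]
    sigma = noise_std if absolute_noise else noise_std * statistics.pstdev(values)
    noisy = [[x + rng.gauss(0.0, sigma) for x in row] for row in matrix]
    n_drop = int(round(dropout_fraction * len(matrix)))
    for gene_index in rng.sample(range(len(matrix)), n_drop):
        noisy[gene_index] = [0.0] * len(noisy[gene_index])
    return noisy
```

The input matrix is never modified, so one clean matrix can feed every level of a sweep.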
Download the BEELINE scRNA-seq datasets:
python collect_data.py --data_dir ./Data/scRNA-seq/
The downloaded datasets can be found in:
- Data/scRNA-seq/: Expression data
- Data/scRNA-seq-Networks/: Network data
Run the full evaluation pipeline on the test datasets, including creation of all subsets:
python evaluation_pipeline.py \
--dataset_file Data/mESC.csv \
--output_dir ./outputs/evaluation
To train the model, first download the BEELINE scRNA-seq datasets:
python collect_data.py --data_dir ./Data/scRNA-seq/
Note: Before training, copy all the regulatory networks (Non-specific-Chip-seq-network.csv, STRING-network.csv, [cell-type]-Chip-seq-network.csv) and the TFs.csv file into the corresponding cell-type dataset folder, ./Data/scRNA-seq/[cell-type]/.
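The copy step in the note can be scripted. The sketch below is hypothetical (`stage_cell_type` is not part of the repository): you pass the paths of the three network files and TFs.csv for one cell type, since their exact location under Data/scRNA-seq-Networks/ depends on the download and is not assumed here, and it returns any expected file name still missing from the cell-type folder.

```python
import shutil
from pathlib import Path


def stage_cell_type(cell_type, network_files, tf_file, data_root="./Data/scRNA-seq"):
    """Copy network CSVs and TFs.csv into <data_root>/<cell_type>/.

    Returns the expected file names (from the note above) that are still missing.
    """
    target = Path(data_root) / cell_type
    target.mkdir(parents=True, exist_ok=True)
    for src in [*network_files, tf_file]:
        shutil.copy2(src, target / Path(src).name)
    expected = [
        "Non-specific-Chip-seq-network.csv",
        "STRING-network.csv",
        f"{cell_type}-Chip-seq-network.csv",
        "TFs.csv",
    ]
    return [name for name in expected if not (target / name).exists()]
```

An empty return value means the folder holds every file the note lists.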
For generalization training, GRNFormer combines all the networks for each training dataset:
python dataset_combiner.py \
--cell-type-network ./Data/scRNA-seq/hESC/hESC-Chip-seq-network.csv \
--non-specific-network ./Data/scRNA-seq/hESC/Non-specific-Chip-seq-network.csv \
--string-network ./Data/scRNA-seq/hESC/STRING-network.csv \
--output-file ./Data/scRNA-seq/hESC/hESC-combined.csv
Create the dataset and splits for training, validation, and testing:
python create_dataset.py \
--dataset_dir ./Data/scRNA-seq \
--dataset_name ./Data/train_list.csv
Train the model from scratch using the configuration file:
python main.py fit --config config/grnformer.yaml
You can customize training parameters by editing config/grnformer.yaml or by passing command-line arguments.
- BEELINE: https://zenodo.org/records/3701939
- DREAM5: https://www.synapse.org/Synapse:syn2787209/wiki/70351
- PBMC3k: https://support.10xgenomics.com/single-cell-gene-expression/datasets/1.1.0/pbmc3k
- Preprocessed PBMC: Can be accessed from the scanpy Python package
GRNformer/
├── src/
│ ├── models/
│ │ └── grnformer/
│ │ ├── model.py # Main GRNFormer model
│ │ └── network.py # Network architecture
│ └── datamodules/
│ ├── grn_datamodule.py # Training data module
│ ├── grn_dataset_inference.py # Inference dataset
│ └── grn_dataset_test.py # Test dataset
├── config/
│ └── grnformer.yaml # Training configuration
├── main.py # Training entry point
├── infer_grn.py # Inference script
├── eval_grn.py # Standard evaluation script
├── eval_grn_custom.py # Custom evaluation with configurable parameters
├── eval_grn_perturb.py # Perturbation evaluation script
├── evaluation_pipeline.py # Full evaluation pipeline
├── create_dataset.py # Dataset creation
├── dataset_combiner.py # Network combination
├── collect_data.py # Data download
└── environment.yml # Conda environment
If you use GRNFormer in your research, please cite:
@article {Hegde2025.01.26.634966,
author = {Hegde, Akshata and Cheng, Jianlin},
title = {GRNFormer: Accurate Gene Regulatory Network Inference Using Graph Transformer},
elocation-id = {2025.01.26.634966},
year = {2025},
doi = {10.1101/2025.01.26.634966},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2025/01/27/2025.01.26.634966},
eprint = {https://www.biorxiv.org/content/early/2025/01/27/2025.01.26.634966.full.pdf},
journal = {bioRxiv}
}
See the LICENSE file for details.
For questions or issues, please open an issue on the GitHub repository.
