Causal inference code for GNN-TARNET models and baseline models for comparison.
Code for the paper "Graph Neural Networks for Individual Treatment Effect Estimation".
This repository provides a compact framework for training, evaluating, and comparing GNN-based and baseline models for individual treatment effect estimation, including the GNN-TARNET architecture used in the associated paper. It supports datasets with graph-structured data (e.g., social networks or molecular graphs) for use cases such as personalized medicine or recommendation systems.
- Implementations of GNN-TARNET and baseline models
- Training, evaluation, and hyperparameter sweep utilities
- Dockerized environment for reproducible runs
- Example configuration and dataset loaders
- Docker (recommended for reproducibility)
- Python 3.8+ (if running without Docker)
- CUDA and a recent NVIDIA driver for GPU training (optional)
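Before launching a long GPU training run, it can be worth confirming that TensorFlow actually sees a GPU. A minimal check (this helper is illustrative, not part of the repository; it degrades gracefully if TensorFlow is absent):

```python
def gpu_status():
    """Report whether TensorFlow is installed and whether it can see a GPU."""
    try:
        import tensorflow as tf  # heavy import; only needed for this check
    except ImportError:
        return "tensorflow not installed"
    gpus = tf.config.list_physical_devices("GPU")
    return f"{len(gpus)} GPU(s) visible" if gpus else "no GPU visible (CPU only)"

print(gpu_status())
```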
- Ensure Docker is installed and running on your system.
- Build the Docker image and run a container:

./docker_commands.sh

This will build the image and start a container with all dependencies installed.
If you prefer to run without Docker, ensure you have Python 3.8+ installed, then install the required dependencies:
pip install tensorflow tensorflow-probability==0.19.0 seaborn codecarbon==2.3.1 causal-learn keras-tuner==1.1.3 tf2onnx==1.14.0 onnxruntime==1.15.1 pandas numpy scikit-learn matplotlib
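When installing manually, it helps to verify that the pinned packages actually resolved to the expected versions. A stdlib-only sketch (pins copied from the pip command above; `importlib.metadata` requires Python 3.8+):

```python
from importlib import metadata

# Pinned versions from the pip install command above.
PINNED = {
    "tensorflow-probability": "0.19.0",
    "codecarbon": "2.3.1",
    "keras-tuner": "1.1.3",
    "tf2onnx": "1.14.0",
    "onnxruntime": "1.15.1",
}

def check_pins(pins):
    """Return a list of human-readable problems (missing or mismatched packages)."""
    problems = []
    for name, wanted in pins.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if found != wanted:
            problems.append(f"{name}: have {found}, want {wanted}")
    return problems

for problem in check_pins(PINNED):
    print(problem)
```

An empty result means every pinned dependency is present at the expected version.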
- Clone the repository:
git clone https://github.com/causal-lab-miism/gnntarnet.git
cd gnntarnet
- Build and launch the Docker container (or install dependencies as shown above):
./docker_commands.sh
- Inside the container, run the main experiment driver:
python main_hyper.py --model-name GNNTARnet --dataset-name ihdp_a
Training and evaluation are controlled by the main_hyper.py script. The following command-line options are available:
Available models (--model-name):
- GNNTARnet: Graph Neural Network TARNET (default)
- TARnet: Treatment-Agnostic Representation Network
- CFRNet: Counterfactual Regression Network
- SLearner: S-Learner baseline
- TLearner: T-Learner baseline
- GANITE: Generative Adversarial Nets for Inference of Individualized Treatment Effects
- TEDVAE: Treatment Effect Disentangled Variational Autoencoder
Available datasets (--dataset-name):
- ihdp_a: IHDP dataset variant A
- ihdp_b: IHDP dataset variant B
- jobs: Jobs dataset
Train the GNN-TARNET model on IHDP dataset variant A:
python main_hyper.py --model-name GNNTARnet --dataset-name ihdp_a

Train a baseline T-Learner model on the JOBS dataset:

python main_hyper.py --model-name TLearner --dataset-name jobs

Run with custom hyperparameters:

python main_hyper.py --model-name GNNTARnet --dataset-name ihdp_a --num 100 --num_layers 5

- --model-name: Model architecture to use (default: GNNTARnet)
- --dataset-name: Dataset to train on (default: ihdp_a)
- --tuner-name: Hyperparameter tuner type (default: random)
- --num: Number of experimental runs (default: 1)
- --num_layers: Number of GNN layers (default: 5)
- --defaults: Use default hyperparameters (default: True)
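The flags above map naturally onto an `argparse` interface. The following is an illustrative sketch of how a driver like main_hyper.py might declare them, with defaults taken from the list above; it is not the repository's actual parser:

```python
import argparse

def build_parser():
    """Sketch of a CLI matching the documented options (illustrative only)."""
    parser = argparse.ArgumentParser(
        description="Train and evaluate treatment effect estimation models.")
    parser.add_argument("--model-name", default="GNNTARnet",
                        choices=["GNNTARnet", "TARnet", "CFRNet", "SLearner",
                                 "TLearner", "GANITE", "TEDVAE"],
                        help="model architecture to use")
    parser.add_argument("--dataset-name", default="ihdp_a",
                        choices=["ihdp_a", "ihdp_b", "jobs"],
                        help="dataset to train on")
    parser.add_argument("--tuner-name", default="random",
                        help="hyperparameter tuner type")
    parser.add_argument("--num", type=int, default=1,
                        help="number of experimental runs")
    parser.add_argument("--num_layers", type=int, default=5,
                        help="number of GNN layers")
    # argparse's type=bool is a known pitfall, so parse the string explicitly.
    parser.add_argument("--defaults", default=True,
                        type=lambda s: s.lower() in {"true", "1", "yes"},
                        help="use default hyperparameters")
    return parser

args = build_parser().parse_args(["--model-name", "TLearner", "--dataset-name", "jobs"])
print(args.model_name, args.dataset_name, args.num)  # TLearner jobs 1
```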
For a complete list of options, run:
python main_hyper.py --help

If you use this code in your research, please cite the related publication describing the GNN-TARNET approach:
A. Sirazitdinov, M. Buchwald, V. Heuveline and J. Hesser, "Graph Neural Networks for Individual Treatment Effect Estimation," in IEEE Access, vol. 12, pp. 106884-106894, 2024, doi: 10.1109/ACCESS.2024.3437665.
See the LICENSE file for license details.
For questions about usage or to report problems, open an issue in this repository.