Ordered Causal Discovery (two-stage causal structure discovery with deep learning)

License: MIT License

Topics: causal-discovery, causal-ordering, causality, pytorch-lightning

OSLow: Order-based Structure Learning with Normalizing Flows

(Figure: main_fig, method overview)


Installation • Experiments • License

This repository provides the codebase for the experiments in our paper, built on top of lightning-toolbox and dypy. These tools facilitate dynamic training processes and model designs for our experiments. For large-scale benchmarking and hyperparameter tuning, we use the dysweep package.

Installation

Begin by cloning the repository and navigating to the root directory. In a Python (>= 3.9) environment, run the following commands:

git clone https://github.com/vahidzee/ocdaf.git # clone the repository
cd ocdaf
pip install -r requirements.txt # install dependencies

You will need a working wandb account and a wandb-enabled workspace to run the experiments. Start here! In addition, this codebase relies heavily on wandb sweep configurations. Take a look at the official documentation here to learn how to navigate the sweep workspaces.

Experiments

The details of all the experiments mentioned in the paper can be found in the experiments directory. Please read the following for a big-picture guide on how to navigate the experimental details and reproduce the results:

Running Single Experiments

Single experiments can be conducted by defining a configuration file. For instance, the following command creates a visualization of our training process based on the 2D PCA projection of sampled permutations:

python trainer.py fit --config experiments/examples/birkhoff-gumbel-sinkhorn.yaml --seed_everything=555

We have provided a sample configuration file with extensive documentation here to familiarize you with the components of our configurations. Furthermore, the trainer.py file is a standard LightningCLI runnable that performs causal discovery with a given configuration.

Conducting Sweeps

Our experiments leverage dysweep to employ the Weights and Biases API for executing comprehensive experiments across different configurations. Each sweep configuration is a .yaml file located in the experiments/sweep directory.

To initiate a specific sweep, use the following command:

python sweep.py --config path/to/sweep-config.yaml

This generates a sweep object in the designated project with a unique ID. Subsequently, execute the following command across multiple machines to run the configurations in parallel:

python sweep.py --entity=<Wandb-Entity> --project=<Wandb-Project> --sweep_id=<Wandb-Sweep-id> --count=<#-of-configurations-to-run>

Alternatively, simply add the sweep_id option to the initial command. Since we use jsonargparse, this overwrites the sweep_id in the previous configuration and seamlessly starts the sweep.

python sweep.py --config path/to/sweep-config.yaml --sweep_id=<Wandb-Sweep-id>

To completely reproduce our paper's experimental results, refer to the following sections:

Sachs dataset

We have created a sweep configuration here that runs the Sachs dataset with different seeds using the seed_everything option that Lightning provides. It automatically creates a sweep that runs the Sachs dataset over different hyperparameter configurations, each repeated with five different seeds.

Finally, the run produces a set of model results as json files in the experiments/saves/sachs directory. These json files contain full details of the final ordering the model converged to, which can later be used for pruning.
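As a minimal sketch of how such a result file could be consumed downstream (the file layout and the `final_ordering` field name here are illustrative assumptions; consult the actual saved files for the real schema):

```python
import json
import tempfile
from pathlib import Path

# The exact schema of the saved json files is defined by the repo; the
# "final_ordering" field below is a hypothetical stand-in for illustration.
result = {"final_ordering": [2, 0, 1, 3]}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "data-0.json"
    path.write_text(json.dumps(result))   # mimic a saved result file
    loaded = json.loads(path.read_text())
    ordering = loaded["final_ordering"]

print(ordering)  # the recovered ordering can then be passed to pruning
```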

Syntren dataset

Similar to Sachs, the sweep configuration for this run is available here. This is a simple sweep that runs all of the Syntren datasets (with identifiers ranging from 0 to 9) and produces the same set of result json files in experiments/saves/syntren.

Synthetic datasets

We provide several sweep configurations for synthetic datasets, each covering a specific set of conditions and scenarios. The results are conveniently summarized using the Weights and Biases UI.

Small parametric datasets

The configuration for these experiments can be found here. It covers graphs with 3, 4, 5, and 6 covariates generated by different algorithms (tournaments, paths, and Erdos-Renyi graphs). The functional forms included are sinusoidal, polynomial, and linear, all accompanied by Gaussian noise. For a comparative study, both affine and additive options are included. Each configuration is run five times with different seeds.

We test each dataset using three algorithms: Gumbel top-k, Gumbel Sinkhorn, and Soft. In total, this sweep contains 1480 different configurations.
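As a rough illustration of what one such parametric dataset looks like, here is a minimal numpy sketch (not the repo's actual generator; the sampling conventions below, e.g. the sinusoid of the parents' sum, are simplifying assumptions):

```python
import numpy as np

def random_dag(d, p, rng):
    """Sample an Erdos-Renyi DAG: a strictly upper-triangular
    adjacency matrix guarantees acyclicity."""
    return np.triu(rng.random((d, d)) < p, k=1).astype(float)

def sample_sinusoidal(A, n, rng):
    """Ancestrally sample x_j = sin(sum of parents) + Gaussian noise."""
    d = A.shape[0]
    X = np.zeros((n, d))
    for j in range(d):  # nodes are already in topological order
        parents = np.flatnonzero(A[:, j])
        X[:, j] = np.sin(X[:, parents].sum(axis=1)) + rng.normal(size=n)
    return X

rng = np.random.default_rng(0)
A = random_dag(4, 0.5, rng)        # 4-covariate graph, edge prob. 0.5
X = sample_sinusoidal(A, 100, rng)  # 100 observational samples
```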

Small non-parametric datasets

You can find the sweep configuration for these datasets here. Similar to the parametric configuration, it covers graphs with 3, 4, 5, and 6 covariates. However, these datasets are generated using Gaussian processes to sample the scale and shift functions. Both affine and additive options are included for comparison, and each configuration is run with five different seeds, totalling 240 different configurations.
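Sampling a scale and a shift function from a Gaussian process can be sketched as follows (a toy single-parent illustration with an RBF kernel; the kernel choice and the absolute value used to keep the scale positive are assumptions, not the repo's exact scheme):

```python
import numpy as np

def sample_gp_function(x, rng, length_scale=1.0):
    """Draw one function from a zero-mean GP with an RBF kernel,
    evaluated at the points x."""
    K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / length_scale ** 2)
    return rng.multivariate_normal(np.zeros(len(x)), K + 1e-6 * np.eye(len(x)))

rng = np.random.default_rng(0)
pa = rng.normal(size=200)                            # a single parent's values
shift = sample_gp_function(pa, rng)                  # GP-sampled shift function
scale = np.abs(sample_gp_function(pa, rng)) + 0.1    # GP-sampled, kept positive
child = shift + scale * rng.normal(size=200)         # location-scale mechanism
```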

Small Linear Laplace Datasets

The configuration for the linear Laplace runs can be found here. This experiment demonstrates that our model can handle broader classes of location-scale noise models (LSNMs), providing insight into possible relaxations of our theoretical conditions. For these configurations, we use small graphs with different generation schemes, but we employ linear functions for the scale and shift and choose standard Laplace noise. Across the different seeds, this sweep totals 480 runs.
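A linear location-scale model with Laplace noise can be written out concretely; the two-node chain below is a toy illustration under assumed coefficients (the absolute value keeping the scale positive is our simplification, not necessarily the repo's parameterization):

```python
import numpy as np

# Toy linear LSNM on a two-node chain x1 -> x2:
#   x2 = (a + b * x1) + |c + d * x1| * eps,   eps ~ Laplace(0, 1)
rng = np.random.default_rng(0)
n = 1000
x1 = rng.laplace(size=n)
a, b, c, d = 0.5, 1.5, 1.0, 0.3
shift = a + b * x1            # linear shift function
scale = np.abs(c + d * x1)    # linear scale; |.| keeps it positive (assumption)
x2 = shift + scale * rng.laplace(size=n)
X = np.stack([x1, x2], axis=1)
```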

Large Datasets

For large synthetic datasets, the sweep configuration can be found here. This set includes three different functional forms: sinusoidal, polynomial, and a non-parametric scheme. The number of covariates is either 10 or 20, and each configuration is run with five different seeds. The final 30 synthetic configurations are passed to the Gumbel top-k method to evaluate model scalability.

You may refer to the dysweep documentation to learn how to generate your own sweep configurations.

Pruning

Our code also allows for pruning the final model ordering, which is facilitated by the prune.py file. Execute the pruning process with the following command:

python prune.py --method=<cam/pc> --data_type=<syntren/sachs> --data_num=<data_id (Optional)> --order=<dash-separated-ordering> --saved_permutations_dir=<directory-to-saved-permutations> 
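Judging by the flag name, the `--order` argument takes a permutation of node indices joined by dashes; this tiny snippet illustrates the assumed round-trip between a Python list and that CLI form:

```python
# Assumed format of the --order flag: node indices joined by dashes.
order = [2, 0, 1, 3]
order_arg = "-".join(str(i) for i in order)       # build the CLI argument
parsed = [int(tok) for tok in order_arg.split("-")]  # and parse it back
```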

In order to reproduce our results for the Sachs and Syntren datasets, you need to execute a series of steps after obtaining the experiments/saves directory:

  1. Run the sweep for the dataset you're interested in. For instance, if you're working on the Syntren dataset, execute the Syntren sweep.
  2. After the sweep, a set of saved files will be available in the experiments/saves/syntren directory.
  3. These files will follow the data-i format, where i represents the identifier of the Syntren dataset.
  4. You can then use these saved files to run CAM pruning on all of the Syntren datasets. Run the command below, which iterates over all dataset IDs and performs pruning for each:
for i in {0..9}
do
    python prune.py --method=cam --data_type=syntren --data_num=$i --saved_permutations_dir=experiments/saves/syntren/data-$i
done

This process streamlines the replication of our results for the Sachs and Syntren datasets, using the CAM pruning method on all datasets generated by the sweep. See the results for a table comparing the different pruning techniques on the Sachs and Syntren datasets.

Intervention

To reproduce the results of our interventional experiments, see this notebook for further instructions. The resulting checkpoints of the trained models are also available here.

License

This project is licensed under the MIT License; see the LICENSE file for details.

Contributors

hamidrezakmk, vahidzee, vdblm

