jim's Introduction

Jim - A JAX-based gravitational-wave inference toolkit

Jim comprises a set of tools for estimating parameters of gravitational-wave sources through Bayesian inference. At its core, Jim relies on the JAX-based sampler flowMC, which leverages normalizing flows to enhance the convergence of a gradient-based MCMC sampler.

Since it is based on JAX, Jim can also leverage hardware acceleration to achieve significant speedups on GPUs. Jim also takes advantage of likelihood heterodyning (Cornish 2010, Cornish 2021) to compute the gravitational-wave likelihood more efficiently.

See the accompanying paper, Wong, Isi, Edwards (2023) for details.

Warning

Jim is under heavy development, so the API is constantly changing. Use at your own risk! One way to mitigate this inconvenience is to fork a specific version for now. We expect to hit a stable version this year. Stay tuned.

[Documentation and examples are a work in progress]

Installation

You may install the latest released version of Jim through pip by doing

pip install jimGW

You may install the bleeding edge version by cloning this repo, or doing

pip install git+https://github.com/kazewong/jim

If you would like to take advantage of CUDA, you will additionally need to install a specific version of JAX by doing

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

NOTE: Jim is currently only compatible with Python 3.10.

Performance

The performance of Jim will vary depending on the hardware available. Under optimal conditions, the CUDA installation can achieve parameter estimation in ~1 min on an Nvidia A100 GPU for a binary neutron star (see paper for details). If a GPU is not available, JAX will fall back on CPUs, and you will see a message like this on execution:

No GPU/TPU found, falling back to CPU.
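To check which backend JAX picked up before starting a long run, a minimal sketch using only the public jax.devices() and jax.default_backend() calls may help. Nothing in it is specific to Jim.

```python
import jax

# List the devices JAX can see; on a CPU-only machine this holds only CPU devices.
devices = jax.devices()
backend = jax.default_backend()

print(f"JAX backend: {backend}")
for device in devices:
    print(f"  {device.platform}: {device}")

if backend == "cpu":
    print("No accelerator in use; expect much longer run times than on a GPU.")
```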

Directory

Parameter estimation examples are in example/ParameterEstimation.

Attribution

Please cite the accompanying paper, Wong, Isi, Edwards (2023).

jim's Issues

Faster optimization for heterodyne likelihood using Jax

Currently, the optimization is done through SciPy's differential evolution, which probably calls the JAX kernel sequentially and is therefore pretty slow.

Ideally we want something completely Jaxxed, such as a vmapped L-BFGS, to solve for the optimum.
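A rough sketch of what a fully vmapped optimizer could look like. It is not Jim's implementation: plain gradient ascent stands in for L-BFGS to keep the example dependency-free, and the quadratic log_prob, the step size, and the number of starting points are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Toy target with a known maximum at MU (stand-in for a real log-likelihood).
MU = jnp.array([1.0, -2.0])

def log_prob(x):
    return -jnp.sum((x - MU) ** 2)

def ascend(x0, step_size=0.1, n_steps=200):
    """Run fixed-step gradient ascent from a single starting point."""
    grad_fn = jax.grad(log_prob)

    def body(_, x):
        return x + step_size * grad_fn(x)

    return jax.lax.fori_loop(0, n_steps, body, x0)

# Optimize from many starting points in parallel instead of sequentially.
starts = jax.random.uniform(
    jax.random.PRNGKey(0), (64, 2), minval=-5.0, maxval=5.0
)
finals = jax.jit(jax.vmap(ascend))(starts)

# Keep the end point with the highest log-probability.
values = jax.vmap(log_prob)(finals)
best = finals[jnp.argmax(values)]
```

The same vmap-over-starting-points pattern would apply if the inner loop were replaced by a quasi-Newton update.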

Weird behaviour in terms of training the normalizing flow.

There is a report of the following failure mode:

When analyzing GW170817 as shown in the example, if the prior on cos(iota) is changed to [-1, -0.9], then choosing number of chains = 800 gives nonsensical results, such as nan in the log_prob.

Further investigation shows that number of chains = 601 or 801 gives the same behavior, but 100, 500, and 501 perform normally.

This is very weird; I think it has to do with the training of the normalizing flow and thinning.

Problems in GW170817 example

In a recent PR, there were two updated examples for GW170817, GW170817_PhenomD_NRTv2.py and GW170817_TaylorF2.py. Currently they contain some paths that point to a local drive of one of the contributors, @ThibeauWouters.

@ThibeauWouters Can you update those two scripts?

Using Pydantic to validate the code structure

I discovered Pydantic when I was playing with LangChain, and I think it is beneficial for building applications in Python in general.

That said, this is one more dependency in our list, which might be overkill for astro software anyway.

It may be worth thinking about whether we want this in the future.

Extrinsic parameters marginalization

Marginalizing over extrinsic parameters should alleviate our computational load, but more importantly it should smooth out the likelihood surface, so HMC will have a better time traversing it.

Three marginalizations that can be done readily are over time of coalescence, phase of coalescence (which is only valid when the 22 mode is the dominant mode), and distance.

References:
https://arxiv.org/pdf/1809.02293.pdf
Marginalization over time: https://dcc.ligo.org/public/0114/T1400460/002/margtime.pdf
Marginalization over phase: https://dcc.ligo.org/LIGO-T1300326/public
Marginalization over distance: https://journals.aps.org/prd/abstract/10.1103/PhysRevD.93.024013
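
For orientation, the phase-marginalized likelihood has a closed form. The block below sketches the standard result, assuming a uniform prior on the coalescence phase $\phi_c$ and a waveform containing only the 22 mode, so that $h(\phi_c) = h_0 e^{2i\phi_c}$. The symbols $h_0$, $z$, $\tilde d$ and $S_n$ are notation chosen here, not names from the Jim code. The inner product is $\langle a | b \rangle = 4\,\mathrm{Re}\int_0^\infty \tilde a^*(f)\,\tilde b(f)/S_n(f)\,\mathrm{d}f$ and $I_0$ is the modified Bessel function of the first kind.

```latex
\begin{aligned}
z &= 4 \int_0^\infty \frac{\tilde d^*(f)\, \tilde h_0(f)}{S_n(f)}\, \mathrm{d}f, \\
\ln \mathcal{L}_{\rm marg}
  &= \ln \left[ \frac{1}{2\pi} \int_0^{2\pi}
     \exp\!\left( \mathrm{Re}\!\left[ z\, e^{2i\phi_c} \right]
     - \tfrac{1}{2} \langle h_0 | h_0 \rangle
     - \tfrac{1}{2} \langle d | d \rangle \right) \mathrm{d}\phi_c \right] \\
  &= \ln I_0\!\left( |z| \right)
     - \tfrac{1}{2} \langle h_0 | h_0 \rangle
     - \tfrac{1}{2} \langle d | d \rangle .
\end{aligned}
```

Only $|z|$ enters the result, which is why the sampled phase dimension disappears.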

Restyling the code to work with pre-commit

To facilitate best practices in maintaining the code base, we should start rolling in support for pre-commit checks, which check style and typing before committing.

Command line tools for runs.

Currently the main way to do an inference run is to go through the examples, write a script using the Jim library, and run something along the lines of python my_script.py.

This is good for exploration, but in production it would be nice to have more infrastructure that supports programmatically generating config files and then running the inference.
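As a sketch of what programmatic config generation might look like, the snippet below writes one JSON file per value of n_chains. The JSON layout, the file names, and the write_configs helper are hypothetical; only the key names are borrowed from the dictionaries used in the example run scripts.

```python
import copy
import json
import tempfile
from pathlib import Path

# Hypothetical base configuration; the file layout itself is an assumption.
BASE_CONFIG = {
    "seed": 0,
    "detectors": ["H1", "L1"],
    "waveform_parameters": {"name": "RippleIMRPhenomPv2", "f_ref": 20.0},
    "jim_parameters": {"n_chains": 500, "n_local_steps": 150, "learning_rate": 0.001},
}

def write_configs(base, n_chains_values, outdir):
    """Write one JSON config per n_chains value and return the file paths."""
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    paths = []
    for n_chains in n_chains_values:
        config = copy.deepcopy(base)  # do not mutate the shared base config
        config["jim_parameters"]["n_chains"] = n_chains
        path = outdir / f"run_n_chains_{n_chains}.json"
        path.write_text(json.dumps(config, indent=2))
        paths.append(path)
    return paths

config_paths = write_configs(BASE_CONFIG, [100, 500, 800], tempfile.mkdtemp())
```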

There seem to be some missing files in the project

Hello, I am very interested in your project. I would like to ask where the files injection_WithParser.py, RealDataAnalysis.py, and utils_plotting.py are located in the project. Could you please provide them? Thank you.

Installation is unstable now

The examples don't necessarily run after simply installing everything from pip.

We need to pin the jax version and put the necessary dependencies into setup.cfg

Add population analysis functionality

The main focus of jim has been on individual event parameter estimation. One obvious extension that can leverage the infrastructure in jim and flowMC is the addition of population analysis, at least in the way that is commonly done with hierarchical Bayesian analysis.

This would require the following steps:

  1. Data ingestion pipeline for posterior samples
  2. Selection function data
  3. Hierarchical Bayesian model likelihood
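
The likelihood in the last step is commonly written as a Monte Carlo sum over each event's posterior samples. The block below sketches that standard form, assuming $N$ events, $S_i$ posterior samples $\theta_{ij}$ for event $i$ drawn under the parameter-estimation prior $\pi(\theta)$, a population model $p(\theta \mid \Lambda)$, and a detection fraction $\xi(\Lambda)$ estimated from the selection-function data. All symbols are notation chosen here.

```latex
\mathcal{L}\left(\{d_i\} \mid \Lambda\right) \propto
\prod_{i=1}^{N} \frac{1}{\xi(\Lambda)}\,
\frac{1}{S_i} \sum_{j=1}^{S_i}
\frac{p(\theta_{ij} \mid \Lambda)}{\pi(\theta_{ij})},
\qquad
\xi(\Lambda) = \int p_{\rm det}(\theta)\, p(\theta \mid \Lambda)\, \mathrm{d}\theta .
```

Each of the three steps supplies one ingredient: the samples $\theta_{ij}$, the estimate of $\xi(\Lambda)$, and the product itself.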

Adding generic transformation function

It would be good to have a more generic transformation function/interface than the currently available 1-to-1 mapping function in jim.

That would be particularly useful for nuclear equation-of-state study, in which one would need to sample over the masses (2 parameters) and the EOS parameters (1 to N), then transform them into the tidal deformabilities (2 parameters) by solving the TOV equations.

Heterodyned likelihood not working

The HeterodynedLikelihood function is not working. Changing from TransientLikelihood to HeterodynedLikelihood in GW150914.py leads to an error.

The Jacobian determinant for transformation should take absolute value

Currently, JIM computes the density change in the parameter space due to transformation by calculating the Jacobian determinant. The absolute value of the Jacobian determinant should be used; otherwise, feeding it into the calculation of the log Jacobian determinant may result in nan values.
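
A small illustration of the failure and one possible fix, using a toy transform (an assumption, not one of Jim's transforms) whose Jacobian determinant is negative. It relies only on jax.jacfwd and jnp.linalg.slogdet.

```python
import math

import jax
import jax.numpy as jnp

def transform(x):
    # Toy orientation-reversing map (an assumption for illustration):
    # its Jacobian is [[0, 2], [1, 0]], with determinant -2.
    return jnp.array([2.0 * x[1], x[0]])

x = jnp.array([0.3, 0.7])
jacobian = jax.jacfwd(transform)(x)

# Taking the log of the signed determinant yields nan for a negative determinant.
naive_log_det = jnp.log(jnp.linalg.det(jacobian))

# slogdet returns the sign and log|det| separately, so the result stays finite.
sign, log_abs_det = jnp.linalg.slogdet(jacobian)
```

Using log_abs_det directly also avoids forming the determinant itself, which can overflow or underflow in higher dimensions.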

Batch job parser

Currently the user needs to go into the Python code to set up their own run. For future production, there should be a wrapper that interfaces with the core and a configuration file, so that the user can just specify the configuration they want for the run (probably in a somewhat restrictive way).

On top of this, it would be good to have some sort of batch job option in which the code consumes an array of configuration files and runs on them, so that we don't need to pay the compilation overhead repeatedly.
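
One possible shape for such a wrapper, sketched with the standard library only. The run_one function is a hypothetical placeholder for an actual inference run, and the flat JSON layout with label and n_chains keys is an assumption.

```python
import argparse
import json
import tempfile
from pathlib import Path

def run_one(config):
    # Placeholder for a real inference run (hypothetical); here it only
    # reports which settings it would have used.
    return {"label": config["label"], "n_chains": config["n_chains"]}

def main(argv=None):
    parser = argparse.ArgumentParser(description="Run a batch of inference configs.")
    parser.add_argument("configs", nargs="+", type=Path, help="JSON config files")
    args = parser.parse_args(argv)
    # All configs run in one Python process, so jitted functions that are
    # reused across runs with the same array shapes need compiling only once.
    return [run_one(json.loads(path.read_text())) for path in args.configs]

# Demo: write two small configs to a temporary directory and run them.
tmp = Path(tempfile.mkdtemp())
paths = []
for label, n_chains in [("run_a", 100), ("run_b", 500)]:
    path = tmp / f"{label}.json"
    path.write_text(json.dumps({"label": label, "n_chains": n_chains}))
    paths.append(str(path))
results = main(paths)
```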

Add more prior classes and add composite prior example

Currently the only prior available is a uniform prior with a guard outside the domain.

This creates the issue of a broken waveform returning nan: when the gradient is large in HMC, it shoots the test point outside the domain, which raises nan.

Here is a list of priors that could be useful to have:

  • Unconstrained uniform. Sample from an unconstrained distribution that corresponds to a uniform distribution over [a,b].
  • Uniform on sphere. Sample from a sphere and transform the samples into Cartesian coordinates. Useful for spins.
  • Add example of composite prior. This should take a list of priors and return a prior.
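
The first item can be sketched with a logistic transform: if z follows a standard logistic distribution on the real line, then x = a + (b - a) * sigmoid(z) is uniform on [a, b], and no value of z can leave the domain. The function names below are hypothetical, and plain Python floats are used instead of JAX arrays for brevity.

```python
import math

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def to_bounded(z, a, b):
    """Map an unconstrained z on the real line to x in [a, b]."""
    return a + (b - a) * sigmoid(z)

def log_prob_unconstrained(z):
    """Log density of z such that x = to_bounded(z, a, b) is uniform on [a, b].

    This is the standard logistic density, sigmoid(z) * (1 - sigmoid(z)),
    written in a form that does not overflow for large |z|.
    """
    return -abs(z) - 2.0 * math.log1p(math.exp(-abs(z)))

x_mid = to_bounded(0.0, 10.0, 80.0)   # centre of the interval
x_far = to_bounded(1e3, 10.0, 80.0)   # a huge step still lands inside [a, b]
```

Because the log density is finite for every real z, a large gradient step produces a low-probability point rather than nan.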

Working with scaling factors

It might be beneficial to work with whitened data, so that the absolute scale and dynamic range of the data do not enter the analysis. This could make it possible to leverage lower-precision formats like float16/float32.
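
A toy check of the idea with NumPy. The flat PSD value, the frequency spacing, and the random series are assumptions chosen to mimic strain-like magnitudes. It shows that the noise-weighted inner product is unchanged by whitening, and that the raw PSD does not survive a cast to float32 while the whitened series does.

```python
import numpy as np

rng = np.random.default_rng(0)
n_freq = 1024
delta_f = 0.25                      # frequency spacing in Hz (assumed 4 s segment)
psd = np.full(n_freq, 1e-46)        # toy flat one-sided PSD in 1/Hz (assumed value)

# Toy frequency-domain series with amplitudes set by the PSD.
sigma = np.sqrt(psd / (4.0 * delta_f))
data = sigma * (rng.standard_normal(n_freq) + 1j * rng.standard_normal(n_freq))
template = sigma * (rng.standard_normal(n_freq) + 1j * rng.standard_normal(n_freq))

# Standard noise-weighted inner product on the raw series.
inner_raw = 4.0 * delta_f * np.real(np.sum(np.conj(data) * template / psd))

# Whitening absorbs the PSD and the frequency spacing into the series.
scale = np.sqrt(4.0 * delta_f / psd)
data_w, template_w = data * scale, template * scale
inner_white = np.real(np.sum(np.conj(data_w) * template_w))

# The raw PSD underflows to zero in float32; the whitened series does not.
psd_underflows = bool(np.all(psd.astype(np.float32) == 0.0))
data_w32 = data_w.astype(np.complex64)
```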

Adam optimizer doesn't work for some cases

Script:

from jimgw.single_event.runManager import SingleEventPERunManager, SingleEventRun
import jax.numpy as jnp
import jax

import os
outdir = os.path.dirname(__file__)
label = os.path.splitext(os.path.basename(__file__))[0]

jax.config.update("jax_enable_x64", True)

mass_matrix = jnp.eye(15)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[9, 9].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}
bounds = jnp.array(
    [
        [10.0, 40.0],
        [0.125, 1.0],
        [0, jnp.pi],
        [0, 2*jnp.pi],
        [0.0, 1.0],
        [0, jnp.pi],
        [0, 2*jnp.pi],
        [0.0, 1.0],
        [0.0, 2000.0],
        [-0.05, 0.05],
        [0.0, 2 * jnp.pi],
        [-1.0, 1.0],
        [0.0, jnp.pi],
        [0.0, 2 * jnp.pi],
        [-1.0, 1.0],
    ]
)


run = SingleEventRun(
    seed=0,
    path='',
    detectors=["H1", "L1"],
    priors={
        "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
        "q": {"name": "MassRatio"},
        "s1": {"name": "Sphere"},
        "s2": {"name": "Sphere"},
        "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
        "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
        "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
        "cos_iota": {"name": "CosIota"},
        "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
        "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
        "sin_dec": {"name": "SinDec"},
    },
    waveform_parameters={"name": "RippleIMRPhenomPv2", "f_ref": 20.0},
    jim_parameters={
        "n_loop_training": 10,
        "n_loop_production": 10,
        "n_local_steps": 150,
        "n_global_steps": 150,
        "n_chains": 500,
        "n_epochs": 50,
        "learning_rate": 0.001,
        "max_samples": 45000,
        "momentum": 0.9,
        "batch_size": 50000,
        "use_global": True,
        "keep_quantile": 0.0,
        "train_thinning": 1,
        "output_thinning": 10,
        "local_sampler_arg": local_sampler_arg,
    },
    likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds},
    injection=True,
    injection_parameters={
        "M_c": 28.6,
        "eta": 0.24,
        "s1_x": 0.05,
        "s1_y": -0.05,
        "s1_z": 0.05,
        "s2_x": -0.05,
        "s2_y": 0.05,
        "s2_z": 0.05,
        "d_L": 440.0,
        "t_c": 0.0,
        "phase_c": 0.0,
        "iota": 0.5,
        "psi": 0.7,
        "ra": 1.2,
        "dec": 0.3,
    },
    data_parameters={
        "trigger_time": 1126259462.4,
        "duration": 4,
        "post_trigger_duration": 2,
        "f_min": 20.0,
        "f_max": 1024.0,
        "tukey_alpha": 0.2,
        "f_sampling": 4096.0,
    },
)

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))
samples = run_manager.jim.get_samples()
run_manager.save(outdir+'/'+label)
jnp.save(outdir+'/'+label+"_samples.npy", samples)
run_manager.jim.print_summary()

Output:

Run instance provided. Loading from instance.
Initializing detectors.
Injection mode. Need to wait until waveform model is loaded.
Injection mode. Need to wait until waveform model is loaded.
Initializing waveform.
Grabbing GWTC-2 PSD for H1
For detector H1:
The injected optimal SNR is 29.713682403670045
The injected match filter SNR is (30.11265963611454-0.563276239330949j)
Grabbing GWTC-2 PSD for L1
For detector L1:
The injected optimal SNR is 51.04914603748342
The injected match filter SNR is (50.302015972725656-0.13345233070825305j)
Initializing heterodyned likelihood..
No reference parameters are provided, finding it...
Starting the optimizer
Using Adam optimization
Warning: Optimization accessed infinite or NaN log-probabilities.
The reference parameters are {'M_c': nan, 'eta': nan, 's1_x': nan, 's1_y': nan, 's1_z': nan, 's2_x': nan, 's2_y': nan, 's2_z': nan, 'd_L': nan, 't_c': nan, 'phase_c': nan, 'iota': nan, 'psi': nan, 'ra': nan, 'dec': nan}
Constructing reference waveforms..
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/ckng/project/jim_testing/sky_location_frame/ra_dec.py", line 103, in <module>
    run_manager = SingleEventPERunManager(run=run)
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jimgw/single_event/runManager.py", line 127, in __init__
    local_likelihood = self.initialize_likelihood(local_prior)
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jimgw/single_event/runManager.py", line 185, in initialize_likelihood
    return likelihood_presets[name](
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jimgw/single_event/likelihood.py", line 298, in __init__
    f_max = jnp.max(f_valid)
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 268, in max
    return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 260, in _reduce_max
    return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
  File "/home/user/ckng/.conda/envs/jim/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 115, in _reduction
    raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
ValueError: zero-size array to reduction operation max which has no identity

Allow freedom in choosing different parametrization

It would be nice to allow users to input injection parameters or prior parameters in any parametrization, while keeping options for which parameters are sampled. Similarly, allowing users to access the posteriors in any parametrization would be handy.
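
As an illustration of the kind of conversion involved, the sketch below maps injection-style parameters (eta, iota, dec) onto the sampled ones (q, cos_iota, sin_dec), following the parameter names in the example script of the Adam optimizer issue. The function names are hypothetical, and q is assumed to be defined as m2/m1 <= 1.

```python
import math

def q_to_eta(q):
    """Symmetric mass ratio from mass ratio q = m2/m1 <= 1."""
    return q / (1.0 + q) ** 2

def eta_to_q(eta):
    """Inverse of q_to_eta for 0 < eta <= 0.25, returning the root with q <= 1."""
    root = math.sqrt(1.0 - 4.0 * eta)
    return (1.0 - 2.0 * eta - root) / (2.0 * eta)

def injection_to_sampling(params):
    """Convert (eta, iota, dec) to the sampled (q, cos_iota, sin_dec)."""
    out = dict(params)  # leave the caller's dictionary untouched
    out["q"] = eta_to_q(out.pop("eta"))
    out["cos_iota"] = math.cos(out.pop("iota"))
    out["sin_dec"] = math.sin(out.pop("dec"))
    return out

sampling = injection_to_sampling({"M_c": 28.6, "eta": 0.24, "iota": 0.5, "dec": 0.3})
```

A generic interface would register such forward and inverse pairs once and apply them wherever injections, priors, or posteriors are read or written.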

Documentation and packaging

  1. Write a README that describes the structure of the project.
  2. Write documentation strings in the code before it is too late.
  3. Separate common functions from example scripts and put them into the module.

Posterior/Prior class

In order to overhaul the likelihood construction to be more flexible and user-friendly, we plan to have a likelihood class to construct kernels that will eventually be sampled over.

Here is a non-exhaustive list of what the class should do:

  1. Keep track of the number of parameters. I.e. n_dim
  2. Do coordinate transformation upon request. I.e. theta_wf = transf(theta_samp)
  3. Interface with waveform generator.

Tagging @maxisi @tedwards2412

Potential phasing error in IMRPhenomC

Match_filter_SNR with IMRPhenomC drops extremely fast with a slight change in mass, which makes learning the inverse mass matrix very difficult and hence makes the inference wrong.

  • The HMC part seems to work with TaylorF2
  • It doesn't work well with IMRPhenomC

Update the run manager

Here are some things we might want to add to the run manager:

  • Compute the SNR and save it somewhere.
  • Maybe add support for an SNR threshold, to skip injections below this threshold. See this example script
  • Add support to randomly sample injection parameters from the given prior ranges if no injection parameters are given.
  • Plotting functionalities (see #107 for work in progress)
  • Add functionality to save everything in the runmanager: hyperparameters, save the output chains separately.
  • Add flexible functionality for loading from other RunManagers, e.g., load the injection parameters of a previous run in case the user wants to play around with hyperparameters on a specific injection
  • Add command line functionalities (see #18 as well)
  • Check how to handle real GW events, load data, preprocess it et cetera.
  • Add dedicated examples of the runmanager
  • Add functionality for doing multiple event runs in a batch (see #145)
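
The SNR items above could start from something as small as the following sketch. The helper names and the threshold value are hypothetical; the per-detector numbers are rounded from the optimal SNRs printed in the run log of the Adam optimizer issue.

```python
import math

def network_snr(detector_snrs):
    """Combine per-detector optimal SNRs in quadrature."""
    return math.sqrt(sum(snr ** 2 for snr in detector_snrs.values()))

def passes_threshold(detector_snrs, threshold):
    """Return True if the injection is loud enough to be worth running."""
    return network_snr(detector_snrs) >= threshold

# Per-detector optimal SNRs as printed in a run log (values rounded).
snrs = {"H1": 29.71, "L1": 51.05}
total = network_snr(snrs)
keep = passes_threshold(snrs, threshold=12.0)  # threshold value is an arbitrary choice
```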

Moving naming tracking into Jim class from Prior class

Currently tracking the naming of parameters is done within the prior class, which limits the flexibility and legibility of parameters within Jim.

It is probably more ergonomic to define parameters and track their naming within Jim, and specify transform only at the interface with the likelihood object and prior object.

This should solve the problem specified in #89 .

Redundant transform in prior class

Since we have added boundToUnbound transforms in sample transform, I think it is no longer necessary to build UniformPrior on top of a base distribution of LogisticDistribution. The only things from the prior class that are currently accessed in JIM are sample() and log_prob(), so as long as these two functions provide the correct output, the implementation details of the prior class under the hood are no longer important. We could simplify the implementation.

Refactoring Jim for an easier adoption for production in LVK

Currently Jim is still in its adolescent stage: it is usable and energetic, but not very social or user-friendly. To make Jim usable and, more importantly, hired by the community, we need to refactor Jim to have a cleaner API with some features useful for production runs in the future.

Note that at the top level, while we want to be as user-friendly and modular as possible, this should not come at a cost to performance.

Here is a preliminary list of features needed for the refactoring and next release:

This should be the development main branch where the aforementioned features can be branched off from.

Working example for GW170817 PE with IMRPhenomD

Dear,

Peter Pang and I are currently exploring the use of jim to perform parameter estimation on BNS events, for which we wish to use the TaylorF2 waveform model that we have currently implemented in ripple and will push to the main ripple repository once we have cleaned up the code and finalized all checks against the lalsuite implementation. In order to be able to make full use of jim, we are kindly requesting to have a fully working and updated example that shows how to analyze GW170817 with jim with the IMRPhenomD waveform, in order to reproduce the PE reported in the jim paper. We have been trying to get a good run set up but have been struggling with adapting the hyperparameters starting from the provided GW150914 run to get satisfactory results for GW170817 that agree with the results from Bilby or those mentioned in the paper.

Thanks in advance for any help provided!

Nan in optimization

Currently, the optimization routine uses a population Adam optimizer to compute the reference parameters. In some cases this optimizer might return nan, usually due to some boundary issue in the prior.

Without modifying the code, using the unconstrained_uniform prior should alleviate, if not solve, the problem.

As a first patch, it would be good to let the user decide what kind of optimizer they want to use. But more importantly, I think we should think about the use of hard boundaries in the prior. Because the sampler relies heavily on gradients, perhaps we should discourage such behavior in general.

3G detectors in Jim

Currently, Jim only supports 2G detectors (LIGO and Virgo). It would be beneficial to add 3G detectors such as Einstein Telescope and Cosmic Explorer for a lot of future applications with Jim, and test the robustness of relative binning for these scenarios.

Unit test with Jim

Currently jim has a minimal number of unit tests. Since more and more devs are joining the development, it is better to add unit tests now than later.

Here is a non-exhaustive list of unit tests that are needed:

  • Test the construction of the jim object.
  • Test waveform interface with ripple
  • Test prior construction interface.
  • Test data fetching API
