
aehmc's Introduction

Aehmc


AeHMC provides implementations for the HMC and NUTS samplers in Aesara.


Get started

import aesara
from aesara import tensor as at
from aesara.tensor.random.utils import RandomStream

from aeppl import joint_logprob

from aehmc import nuts

# A simple normal distribution
Y_rv = at.random.normal(0, 1)


def logprob_fn(y):
    return joint_logprob(realized={Y_rv: y})[0]


# Build the transition kernel
srng = RandomStream(seed=0)
kernel = nuts.new_kernel(srng, logprob_fn)

# Compile a function that updates the chain
y_vv = Y_rv.clone()
initial_state = nuts.new_state(y_vv, logprob_fn)

step_size = at.as_tensor(1e-2)
inverse_mass_matrix = at.as_tensor(1.0)
chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix)

next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)

print(next_step_fn(0))
# 1.1034719409361107

Install

The latest release of AeHMC can be installed from PyPI using pip:

pip install aehmc

Or via conda-forge:

conda install -c conda-forge aehmc

The current development branch of AeHMC can be installed from GitHub using pip:

pip install git+https://github.com/aesara-devs/aehmc

Get help

Report bugs by opening an issue. If you have a question about using AeHMC, start a discussion. For real-time feedback or more general chat about AeHMC, use our Discord server or Gitter room.

Contribute

AeHMC welcomes contributions. A good place to start contributing is by looking at the issues.

If you want to implement a new feature, open a discussion or come chat with us on Discord or Gitter.

aehmc's People

Contributors

brandonwillard, rlouf, twiecki, zoj613


aehmc's Issues

Add `KernelType` type annotation

[W]e could really use some type annotations for arguments like this. I honestly don't know the exact form of a valid kernel type (other than that it's a callable of some sort), so this is one instance where the type info really helps. In order to figure this out, I would have to trace where it's used in this codebase, and/or guess and check.

A custom type (alias) that explicitly ties together the functions in this library that construct valid kernel arguments is really all that's necessary to answer basic questions, so there's no need to delve much deeper than KernelType = Callable[[...], ...]—although something like KernelType = NewType("KernelType", Callable[[...], ...]) might be better for maintaining the "this function takes/outputs acceptable kernels" information when the local variable name KernelType isn't present in context (e.g. some Mypy reporting).

Originally posted by @brandonwillard in #66 (comment)
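A minimal sketch of what such an annotation could look like. Two caveats apply. First, type checkers expect an arbitrary-argument callable to be spelled `Callable[..., ReturnType]`, not `Callable[[...], ...]`. Second, as far as we know, mypy rejects `NewType` over a `Callable` because it requires a subclassable base type, so a callback `Protocol` is a common alternative for giving kernels a named type. The `Kernel` protocol and the `run_one_step` helper are hypothetical. The call shape is assumed from the Get started example (`chain_info, updates = kernel(state, step_size, inverse_mass_matrix)`) rather than taken from AeHMC's code.

```python
from typing import Any, Callable, Dict, Protocol, Tuple

# Loosest form: "some callable", usable as a shared alias across the library.
KernelType = Callable[..., Any]


class Kernel(Protocol):
    """Hypothetical structural type for a transition kernel.

    Assumes the call shape used in the README's Get started example.
    """

    def __call__(
        self, state: Any, step_size: Any, inverse_mass_matrix: Any
    ) -> Tuple[Any, Dict[Any, Any]]:
        ...


def run_one_step(
    kernel: Kernel, state: Any, step_size: Any, inverse_mass_matrix: Any
) -> Tuple[Any, Dict[Any, Any]]:
    # A static type checker can verify that `kernel` matches the protocol's signature.
    return kernel(state, step_size, inverse_mass_matrix)
```

Any plain function with matching parameter names satisfies `Kernel` structurally, so existing kernel constructors would not need to change.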

Make HMC and NUTS return extra information

Information such as whether a divergence occurred, the number of integration steps (for NUTS), the value of the energy, etc. can be useful to diagnose sampling issues. Here are some examples:

  • momentum
  • acceptance_probability
  • is_accepted
  • is_divergent
  • energy
  • proposal
  • num_integration_steps
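The following is a minimal sketch of a record that carries the fields listed above. It assumes scalar positions and a standard Metropolis accept/reject step on the change in total energy. `HMCInfo`, `summarize_transition`, and the divergence threshold of 1000 are illustrative assumptions, not AeHMC's API; to our knowledge, 1000 is the default that Stan and BlackJAX use. In AeHMC these values would be Aesara tensors rather than Python floats.

```python
import math
from typing import NamedTuple


class HMCInfo(NamedTuple):
    """Hypothetical per-transition diagnostics, one field per item in the issue."""

    momentum: float
    acceptance_probability: float
    is_accepted: bool
    is_divergent: bool
    energy: float
    proposal: float
    num_integration_steps: int


def summarize_transition(
    momentum: float,
    initial_energy: float,
    proposal: float,
    proposal_energy: float,
    num_integration_steps: int,
    uniform_draw: float,
    divergence_threshold: float = 1000.0,
) -> HMCInfo:
    """Build diagnostics for one HMC/NUTS transition (illustrative only).

    Energies are total energies: -logprob(position) + kinetic energy(momentum).
    `uniform_draw` is a Uniform(0, 1) sample supplied by the caller.
    """
    delta_energy = proposal_energy - initial_energy
    if math.isnan(delta_energy):
        # Treat a numerical failure as an infinitely bad proposal.
        delta_energy = math.inf
    # min(1, exp(-delta_energy)), written so that exp never overflows.
    acceptance_probability = math.exp(min(0.0, -delta_energy))
    is_accepted = uniform_draw < acceptance_probability
    return HMCInfo(
        momentum=momentum,
        acceptance_probability=acceptance_probability,
        is_accepted=is_accepted,
        is_divergent=delta_energy > divergence_threshold,
        energy=proposal_energy,
        proposal=proposal,
        num_integration_steps=num_integration_steps,
    )
```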

Add Dual Averaging

Related to #21. Dual averaging is an optimization algorithm that is commonly used to adapt the step size of HMC kernels by targeting an acceptance rate.
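The following is a minimal pure-Python sketch of the dual averaging update from Hoffman and Gelman's NUTS paper. It adapts log(step_size) so that the running average of the acceptance probability approaches a target. The defaults (`target=0.8`, `gamma=0.05`, `t0=10`, `kappa=0.75`, and `mu = log(10 * initial_step_size)`) are values commonly used by Stan-like samplers. The names `DualAveragingState`, `dual_averaging_update`, and `adapt` are hypothetical. An AeHMC version would express the same arithmetic with Aesara tensors.

```python
import math
from dataclasses import dataclass


@dataclass
class DualAveragingState:
    """Running quantities of the dual averaging scheme (hypothetical names)."""

    step: int = 0                   # iteration counter m
    h_bar: float = 0.0              # running average of (target - acceptance)
    log_step_size: float = 0.0      # log(epsilon_m): step size used during warmup
    log_step_size_avg: float = 0.0  # log(epsilon_bar_m): step size kept after warmup


def dual_averaging_update(
    state: DualAveragingState,
    acceptance_prob: float,
    mu: float,
    target: float = 0.8,
    gamma: float = 0.05,
    t0: float = 10.0,
    kappa: float = 0.75,
) -> DualAveragingState:
    """One dual averaging step on log(step_size)."""
    m = state.step + 1
    w = 1.0 / (m + t0)
    # Weighted average of the gap between target and observed acceptance.
    h_bar = (1.0 - w) * state.h_bar + w * (target - acceptance_prob)
    # Below-target acceptance pushes the step size under exp(mu); above-target pushes it over.
    log_step_size = mu - math.sqrt(m) / gamma * h_bar
    # Average the iterates with a decaying weight m**(-kappa).
    eta = m ** (-kappa)
    log_step_size_avg = eta * log_step_size + (1.0 - eta) * state.log_step_size_avg
    return DualAveragingState(m, h_bar, log_step_size, log_step_size_avg)


def adapt(acceptance_probs, initial_step_size: float = 1e-2) -> float:
    """Run the update over a sequence of acceptance probabilities; return the averaged step size."""
    mu = math.log(10.0 * initial_step_size)
    state = DualAveragingState(log_step_size=math.log(initial_step_size))
    for a in acceptance_probs:
        state = dual_averaging_update(state, a, mu)
    return math.exp(state.log_step_size_avg)
```

When the acceptance probability sits exactly at the target, `h_bar` stays at zero and the step size settles at `exp(mu)`. Consistently lower acceptance drives the step size down, and consistently higher acceptance drives it up. The averaged iterate is the value one would freeze at the end of warmup.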

pre-commit hooks setup is failing.

Description of your problem or feature request

Setting up pre-commit hooks fails due to a bug in isort explained here: PyCQA/isort#2083
Please provide a minimal, self-contained, and reproducible example.

pre-commit install-hooks

Please provide the full traceback of any errors.

[INFO] Installing environment for https://github.com/pycqa/isort.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
An unexpected error has occurred: CalledProcessError: command: ('/home/zoj/.cache/pre-commit/repopytli6no/py_env-python3/bin/python', '-mpip', 'install', '.')
return code: 1
stdout:
    Processing /home/zoj/.cache/pre-commit/repopytli6no
      Installing build dependencies: started
      Installing build dependencies: finished with status 'done'
      Getting requirements to build wheel: started
      Getting requirements to build wheel: finished with status 'done'
      Preparing metadata (pyproject.toml): started
      Preparing metadata (pyproject.toml): finished with status 'error'
stderr:
      error: subprocess-exited-with-error

      × Preparing metadata (pyproject.toml) did not run successfully.
      │ exit code: 1
      ╰─> [17 lines of output]
          Traceback (most recent call last):
            File "/home/zoj/.cache/pre-commit/repopytli6no/py_env-python3/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
              main()
            File "/home/zoj/.cache/pre-commit/repopytli6no/py_env-python3/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
              json_out['return_val'] = hook(**hook_input['kwargs'])
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            File "/home/zoj/.cache/pre-commit/repopytli6no/py_env-python3/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 149, in prepare_metadata_for_build_wheel
              return hook(metadata_directory, config_settings)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            File "/tmp/pip-build-env-218ucduw/overlay/lib/python3.11/site-packages/poetry/core/masonry/api.py", line 40, in prepare_metadata_for_build_wheel
              poetry = Factory().create_poetry(Path(".").resolve(), with_groups=False)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            File "/tmp/pip-build-env-218ucduw/overlay/lib/python3.11/site-packages/poetry/core/factory.py", line 57, in create_poetry
              raise RuntimeError("The Poetry configuration is invalid:\n" + message)
          RuntimeError: The Poetry configuration is invalid:
            - [extras.pipfile_deprecated_finder.2] 'pip-shims<=0.3.4' does not match '^[a-zA-Z-_.0-9]+$'

          [end of output]

      note: This error originates from a subprocess, and is likely not a problem with pip.
    error: metadata-generation-failed

    × Encountered error while generating package metadata.
    ╰─> See above for output.

    note: This is an issue with the package mentioned above, not pip.
    hint: See above for details.
Check the log at /home/zoj/.cache/pre-commit/pre-commit.log

Please provide any additional information below.
We need to update the isort version in the pre-commit config to 5.12.0.

Versions and main components

  • Aesara version: 2.8.12
  • Aesara config (python -c "import aesara; print(aesara.config)")
  • Python version: 3.11
  • Operating system: Linux
  • How did you install Aesara: conda

Slow compilation time in warmup tests.

Description of your problem or feature request

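The warmup tests, particularly `test_warmup_vector`, are very slow because `warmup_fn` takes a long time to compile. In the profile below, compilation takes about 166 s in total. About 115 s of that is linker time, and almost all of the linker time goes to the single `window_adaptation` Scan node. The call itself takes only about 15 s.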

Please provide a minimal, self-contained, and reproducible example.

python -m pytest tests/test_hmc.py::test_warmup_vector -v

Please provide the full traceback of any errors.
Running the test with profiling turned on produces the following:

Function profiling
==================
  Message:
  Time in 1 calls to Function.__call__: 1.529328e+01s
  Time in Function.vm.__call__: 15.29313341199304s (99.999%)
  Time in thunks: 15.29306435585022s (99.999%)
  Total compilation time: 1.663600e+02s
    Number of Apply nodes: 48
    Aesara rewrite time: 5.106596e+01s
       Aesara validate time: 3.606322e-02s
    Aesara Linker time (includes C, CUDA code generation/compiling): 115.27304236099008s
       Import time 2.332967e-01s
       Node make_thunk time 1.152699e+02s
           Node forall_inplace,cpu,window_adaptation}(TensorConstant{1000}, TensorConstant{[False Fal..lse  True]}, TensorConstant{[  0   1  ..7 998 999]}, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F36B5424F20>), RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F36B54265E0>), RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F36B54242E0>), RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F36B5427760>), IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0) time 1.151257e+02s
           Node AllocEmpty{dtype='float64'}(TensorConstant{11}) time 1.173822e-02s
           Node InplaceDimShuffle{x}(Sum{acc_dtype=float64}.0) time 9.378754e-03s
           Node InplaceDimShuffle{x,0}(<TensorType(float64, (2,))>) time 9.352558e-03s
           Node Elemwise{Composite{(i0 - (i1 * i2))}}[(0, 2)](TensorConstant{(1,) of -2..2469692907}, TensorConstant{(1,) of 0.5}, InplaceDimShuffle{x}.0) time 5.734075e-03s

Time in all call to aesara.grad() 1.157906e-01s
Time since aesara import 207.771s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  100.0%   100.0%      15.293s       1.53e+01s     Py       1       1   aesara.scan.op.Scan
   0.0%   100.0%       0.000s       1.33e-04s     Py       2       2   aesara.tensor.slinalg.SolveTriangular
   0.0%   100.0%       0.000s       2.24e-06s     C       17      17   aesara.tensor.subtensor.IncSubtensor
   0.0%   100.0%       0.000s       9.06e-06s     C        3       3   aesara.tensor.elemwise.DimShuffle
   0.0%   100.0%       0.000s       1.33e-06s     C       12      12   aesara.tensor.basic.AllocEmpty
   0.0%   100.0%       0.000s       6.20e-07s     C        5       5   aesara.tensor.subtensor.Subtensor
   0.0%   100.0%       0.000s       7.75e-07s     C        4       4   aesara.tensor.elemwise.Elemwise
   0.0%   100.0%       0.000s       1.91e-06s     C        1       1   aesara.tensor.math.Sum
   0.0%   100.0%       0.000s       3.18e-07s     C        3       3   aesara.tensor.shape.Unbroadcast
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  100.0%   100.0%      15.293s       1.53e+01s     Py       1        1   forall_inplace,cpu,window_adaptation}
   0.0%   100.0%       0.000s       2.09e-04s     Py       1        1   SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True}
   0.0%   100.0%       0.000s       5.72e-05s     Py       1        1   SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True}
   0.0%   100.0%       0.000s       2.09e-06s     C       12       12   IncSubtensor{InplaceSet;:int64:}
   0.0%   100.0%       0.000s       1.00e-05s     C        2        2   InplaceDimShuffle{x,0}
   0.0%   100.0%       0.000s       1.38e-06s     C       10       10   AllocEmpty{dtype='float64'}
   0.0%   100.0%       0.000s       2.62e-06s     C        5        5   IncSubtensor{Set;:int64:}
   0.0%   100.0%       0.000s       7.15e-06s     C        1        1   InplaceDimShuffle{x}
   0.0%   100.0%       0.000s       6.20e-07s     C        5        5   Subtensor{uint8}
   0.0%   100.0%       0.000s       1.07e-06s     C        2        2   AllocEmpty{dtype='int64'}
   0.0%   100.0%       0.000s       2.15e-06s     C        1        1   Elemwise{sub,no_inplace}
   0.0%   100.0%       0.000s       1.91e-06s     C        1        1   Sum{acc_dtype=float64}
   0.0%   100.0%       0.000s       3.18e-07s     C        3        3   Unbroadcast{0}
   0.0%   100.0%       0.000s       9.54e-07s     C        1        1   Elemwise{Composite{(i0 - (i1 * i2))}}[(0, 2)]
   0.0%   100.0%       0.000s       0.00e+00s     C        1        1   Elemwise{Sqr}[(0, 0)]
   0.0%   100.0%       0.000s       0.00e+00s     C        1        1   Elemwise{Neg}[(0, 0)]
   ... (remaining 0 Ops account for   0.00%(0.00s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  100.0%   100.0%      15.293s       1.53e+01s      1    42   forall_inplace,cpu,window_adaptation}(TensorConstant{1000}, TensorConstant{[False Fal..lse  True]}, TensorConstant{[  0   1  ..7 998 999]}, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{Set;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{Inpl
   0.0%   100.0%       0.000s       2.09e-04s      1    25   SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True}(TensorConstant{[[1. 0.]
 [0. 2.]]}, Elemwise{sub,no_inplace}.0)
   0.0%   100.0%       0.000s       5.72e-05s      1    30   SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True}(TensorConstant{[[1. 0.]
 [0. 2.]]}, SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True}.0)
   0.0%   100.0%       0.000s       1.50e-05s      1    10   InplaceDimShuffle{x,0}(<TensorType(float64, (2,))>)
   0.0%   100.0%       0.000s       7.15e-06s      1    36   InplaceDimShuffle{x}(Sum{acc_dtype=float64}.0)
   0.0%   100.0%       0.000s       7.15e-06s      1    31   IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, Unbroadcast{0}.0, ScalarConstant{1})
   0.0%   100.0%       0.000s       5.01e-06s      1    33   InplaceDimShuffle{x,0}(SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True}.0)
   0.0%   100.0%       0.000s       4.05e-06s      1    19   IncSubtensor{Set;:int64:}(AllocEmpty{dtype='int64'}.0, TensorConstant{(1,) of 1}, ScalarConstant{1})
   0.0%   100.0%       0.000s       3.10e-06s      1    37   IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, Unbroadcast{0}.0, ScalarConstant{1})
   0.0%   100.0%       0.000s       3.10e-06s      1    22   IncSubtensor{Set;:int64:}(AllocEmpty{dtype='float64'}.0, TensorConstant{(1,) of 0.0}, ScalarConstant{1})
   0.0%   100.0%       0.000s       3.10e-06s      1    13   AllocEmpty{dtype='float64'}(TensorConstant{2}, TensorConstant{2})
   0.0%   100.0%       0.000s       2.86e-06s      1    41   IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, Elemwise{Neg}[(0, 0)].0, ScalarConstant{1})
   0.0%   100.0%       0.000s       2.15e-06s      1    23   IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, TensorConstant{(1, 2) of 0.0}, ScalarConstant{1})
   0.0%   100.0%       0.000s       2.15e-06s      1    14   IncSubtensor{Set;:int64:}(AllocEmpty{dtype='float64'}.0, TensorConstant{(1,) of -inf}, ScalarConstant{1})
   0.0%   100.0%       0.000s       2.15e-06s      1    11   AllocEmpty{dtype='float64'}(TensorConstant{2}, TensorConstant{2})
   0.0%   100.0%       0.000s       2.15e-06s      1     9   Elemwise{sub,no_inplace}(<TensorType(float64, (2,))>, TensorConstant{[0. 3.]})
   0.0%   100.0%       0.000s       1.91e-06s      1    47   Subtensor{uint8}(forall_inplace,cpu,window_adaptation}.0, ScalarConstant{1})
   0.0%   100.0%       0.000s       1.91e-06s      1    34   Sum{acc_dtype=float64}(Elemwise{Sqr}[(0, 0)].0)
   0.0%   100.0%       0.000s       1.91e-06s      1    24   IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, TensorConstant{(1, 2) of 0.0}, ScalarConstant{1})
   0.0%   100.0%       0.000s       1.91e-06s      1    21   IncSubtensor{Set;:int64:}(AllocEmpty{dtype='float64'}.0, TensorConstant{(1,) of 0.0}, ScalarConstant{1})
   ... (remaining 28 Apply instances account for 0.00%(0.00s) of the runtime)


Scan overhead:
<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(% scan op time)> <sub scan op time(% scan op time)> <node>
  One scan node do not have its inner profile enabled. If you enable Aesara profiler with 'aesara.function(..., profile=True)', you must manually enable the profiling for each scan too: 'aesara.scan(...,profile=True)'. Or use Aesara flag 'profile=True'.
  No scan have its inner profile enabled.
Here are tips to potentially make your code run faster
                 (if you think of new ones, suggest them on the mailing list).
                 Test them first, as they are not guaranteed to always provide a speedup.
  - Try the Aesara flag floatX=float32

real    3m30.446s
user    3m12.514s
sys     0m13.513s

Please provide any additional information below.

Versions and main components

  • Aesara version: 2.8.12
  • Aesara config (python -c "import aesara; print(aesara.config)"):

floatX ({'float32', 'float64', 'float16'}) 
    Doc:  Default floating-point precision for python casts.

Note: float16 support is experimental, use at your own risk.
    Value:  float64

warn_float64 ({'pdb', 'ignore', 'warn', 'raise'}) 
    Doc:  Do an action when a tensor variable with float64 dtype is created.
    Value:  ignore

pickle_test_value (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdca2cc2490>>) 
    Doc:  Dump test values while pickling model. If True, test values will be dumped with model.
    Value:  True

cast_policy ({'custom', 'numpy+floatX'}) 
    Doc:  Rules for implicit type casting
    Value:  custom

deterministic ({'default', 'more'}) 
    Doc:  If `more`, sometimes we will select some implementation that are more deterministic, but slower.  Also see the dnn.conv.algo* flags to cover more cases.
    Value:  default

device (cpu)
    Doc:  Default device for computations. only cpu is supported for now
    Value:  cpu

force_device (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9f150>>) 
    Doc:  Raise an error if we can't use the specified device
    Value:  False

conv__assert_shape (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc98f5f190>>) 
    Doc:  If True, AbstractConv* ops will verify that user-provided shapes match the runtime shapes (debugging option, may slow down compilation)
    Value:  False

print_global_stats (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93d150d0>>) 
    Doc:  Print some global statistics (time spent) at the end
    Value:  False

assert_no_cpu_op ({'pdb', 'ignore', 'warn', 'raise'}) 
    Doc:  Raise an error/warning if there is a CPU op in the computational graph.
    Value:  ignore

unpickle_function (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9f450>>) 
    Doc:  Replace unpickled Aesara functions with None. This is useful to unpickle old graphs that pickled them when it shouldn't
    Value:  True

<aesara.configparser.ConfigParam object at 0x7fdc93c9f510>
    Doc:  Default compilation mode
    Value:  Mode

cxx (<class 'str'>) 
    Doc:  The C++ compiler to use. Currently only g++ is supported, but supporting additional compilers should not be too difficult. If it is empty, no C++ code is compiled.
    Value:  /home/zoj/micromamba/envs/aehmc-dev/bin/g++

linker ({'vm', 'cvm_nogc', 'c|py_nogc', 'c', 'py', 'c|py', 'cvm', 'vm_nogc'}) 
    Doc:  Default linker used if the aesara flags mode is Mode
    Value:  cvm

allow_gc (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9f690>>) 
    Doc:  Do we default to delete intermediate results during Aesara function calls? Doing so lowers the memory requirement, but asks that we reallocate memory at the next function call. This is implemented for the default linker, but may not work for all linkers.
    Value:  True

optimizer ({'None', 'o1', 'fast_compile', 'o2', 'o3', 'unsafe', 'merge', 'fast_run', 'o4'}) 
    Doc:  Default optimizer. If not None, will use this optimizer with the Mode
    Value:  o4

optimizer_verbose (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdca174c050>>) 
    Doc:  If True, we print all optimization being applied
    Value:  False

on_opt_error ({'pdb', 'ignore', 'warn', 'raise'}) 
    Doc:  What to do when an optimization crashes: warn and skip it, raise the exception, or fall into the pdb debugger.
    Value:  warn

nocleanup (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9f950>>) 
    Doc:  Suppress the deletion of code files that did not compile cleanly
    Value:  False

on_unused_input ({'ignore', 'warn', 'raise'}) 
    Doc:  What to do if a variable in the 'inputs' list of  aesara.function() is not used in the graph.
    Value:  raise

gcc__cxxflags (<class 'str'>) 
    Doc:  Extra compiler flags for gcc
    Value:  

cmodule__warn_no_version (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9fad0>>) 
    Doc:  If True, will print a warning when compiling one or more Op with C code that can't be cached because there is no c_code_cache_version() function associated to at least one of those Ops.
    Value:  False

cmodule__remove_gxx_opt (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc98f5f250>>) 
    Doc:  If True, will remove the -O* parameter passed to g++.This is useful to debug in gdb modules compiled by Aesara.The parameter -g is passed by default to g++
    Value:  False

cmodule__compilation_warning (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9fcd0>>) 
    Doc:  If True, will print compilation warnings.
    Value:  False

cmodule__preload_cache (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9fdd0>>) 
    Doc:  If set to True, will preload the C module cache at import time
    Value:  False

cmodule__age_thresh_use (<class 'int'>) 
    Doc:  In seconds. The time after which Aesara won't reuse a compile c module.
    Value:  2073600

cmodule__debug (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93c9fd50>>) 
    Doc:  If True, define a DEBUG macro (if not exists) for any compiled C code.
    Value:  False

compile__wait (<class 'int'>) 
    Doc:  Time to wait before retrying to acquire the compile lock.
    Value:  5

compile__timeout (<class 'int'>) 
    Doc:  In seconds, time that a process will wait before deciding to
    override an existing lock. An override only happens when the existing
    lock is held by the same owner *and* has not been 'refreshed' by this
    owner for more than this period. Refreshes are done every half timeout
    period for running processes.
    Value:  120

ctc__root (<class 'str'>) 
    Doc:  Directory which contains the root of Baidu CTC library. It is assumed         that the compiled library is either inside the build, lib or lib64     
    subdirectory, and the header inside the include directory.
    Value:  

tensor__cmp_sloppy (<class 'int'>) 
    Doc:  Relax aesara.tensor.math._allclose (0) not at all, (1) a bit, (2) more
    Value:  0

tensor__local_elemwise_fusion (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca8250>>) 
    Doc:  Enable or not in fast_run mode(fast_run optimization) the elemwise fusion optimization
    Value:  True

lib__amblibm (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca8350>>) 
    Doc:  Use amd's amdlibm numerical library
    Value:  False

tensor__insert_inplace_optimizer_validate_nb (<class 'int'>) 
    Doc:  -1: auto, if graph have less then 500 nodes 1, else 10
    Value:  -1

traceback__limit (<class 'int'>) 
    Doc:  The number of stack to trace. -1 mean all.
    Value:  8

traceback__compile_limit (<class 'int'>) 
    Doc:  The number of stack to trace to keep during compilation. -1 mean all. If greater then 0, will also make us save Aesara internal stack trace.
    Value:  0

experimental__local_alloc_elemwise (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca8750>>) 
    Doc:  DEPRECATED: If True, enable the experimental optimization local_alloc_elemwise. Generates error if not True. Use optimizer_excluding=local_alloc_elemwise to disable.
    Value:  True

experimental__local_alloc_elemwise_assert (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdca2cc0ed0>>) 
    Doc:  When the local_alloc_elemwise is applied, add an assert to highlight shape errors.
    Value:  True

warn__ignore_bug_before ({'0.3', '0.5', '1.0', '0.9', '0.7', '0.8', '1.0.5', '1.0.4', '1.0.2', '1.0.3', '0.8.2', '0.4.1', 'None', '0.8.1', '0.4', 'all', '0.10', '1.0.1', '0.6'}) 
    Doc:  If 'None', we warn about all Aesara bugs found by default. If 'all', we don't warn about Aesara bugs found by default. If a version, we print only the warnings relative to Aesara bugs found after that version. Warning for specific bugs can be configured with specific [warn] flags.
    Value:  0.9

exception_verbosity ({'low', 'high'}) 
    Doc:  If 'low', the text of exceptions will generally refer to apply nodes with short names such as Elemwise{add_no_inplace}. If 'high', some exceptions will also refer to apply nodes with long descriptions  like:
        A. Elemwise{add_no_inplace}
                B. log_likelihood_v_given_h
                C. log_likelihood_h
    Value:  low

print_test_value (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdca2cc0dd0>>) 
    Doc:  If 'True', the __eval__ of an Aesara variable will return its test_value when this is available. This has the practical conseguence that, e.g., in debugging `my_var` will print the same as `my_var.tag.test_value` when a test value is defined.
    Value:  False

compute_test_value ({'ignore', 'warn', 'raise', 'off', 'pdb'}) 
    Doc:  If 'True', Aesara will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.
    Value:  off

compute_test_value_opt ({'ignore', 'warn', 'raise', 'off', 'pdb'}) 
    Doc:  For debugging Aesara optimization only. Same as compute_test_value, but is used during Aesara optimization
    Value:  off

check_input (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdca0e97310>>) 
    Doc:  Specify if types should check their input in their C code. It can be used to speed up compilation, reduce overhead (particularly for scalars) and reduce the number of generated C files.
    Value:  True

NanGuardMode__nan_is_error (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdca2cc0e90>>) 
    Doc:  Default value for nan_is_error
    Value:  True

NanGuardMode__inf_is_error (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca8dd0>>) 
    Doc:  Default value for inf_is_error
    Value:  True

NanGuardMode__big_is_error (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca8fd0>>) 
    Doc:  Default value for big_is_error
    Value:  True

NanGuardMode__action ({'pdb', 'warn', 'raise'}) 
    Doc:  What NanGuardMode does when it finds a problem
    Value:  raise

DebugMode__patience (<class 'int'>) 
    Doc:  Optimize graph this many times to detect inconsistency
    Value:  10

DebugMode__check_c (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9150>>) 
    Doc:  Run C implementations where possible
    Value:  True

DebugMode__check_py (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9290>>) 
    Doc:  Run Python implementations where possible
    Value:  True

DebugMode__check_finite (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9390>>) 
    Doc:  True -> complain about NaN/Inf results
    Value:  True

DebugMode__check_strides (<class 'int'>) 
    Doc:  Check that Python- and C-produced ndarrays have same strides. On difference: (0) - ignore, (1) warn, or (2) raise error
    Value:  0

DebugMode__warn_input_not_reused (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca93d0>>) 
    Doc:  Generate a warning when destroy_map or view_map says that an op works inplace, but the op did not reuse the input for its output.
    Value:  True

DebugMode__check_preallocated_output (<class 'str'>) 
    Doc:  Test thunks with pre-allocated memory as output storage. This is a list of strings separated by ":". Valid values are: "initial" (initial storage in storage map, happens with Scan),"previous" (previously-returned memory), "c_contiguous", "f_contiguous", "strided" (positive and negative strides), "wrong_size" (larger and smaller dimensions), and "ALL" (all of the above).
    Value:  

DebugMode__check_preallocated_output_ndim (<class 'int'>) 
    Doc:  When testing with "strided" preallocated output memory, test all combinations of strides over that number of (inner-most) dimensions. You may want to reduce that number to reduce memory or time usage, but it is advised to keep a minimum of 2.
    Value:  4

profiling__time_thunks (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca96d0>>) 
    Doc:  Time individual thunks when profiling
    Value:  True

profiling__n_apply (<class 'int'>) 
    Doc:  Number of Apply instances to print by default
    Value:  20

profiling__n_ops (<class 'int'>) 
    Doc:  Number of Ops to print by default
    Value:  20

profiling__output_line_width (<class 'int'>) 
    Doc:  Max line width for the profiling output
    Value:  512

profiling__min_memory_size (<class 'int'>) 
    Doc:  For the memory profile, do not print Apply nodes if the size
                 of their outputs (in bytes) is lower than this threshold
    Value:  1024

profiling__min_peak_memory (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9990>>) 
    Doc:  The min peak memory usage of the order
    Value:  False

profiling__destination (<class 'str'>) 
    Doc:  File destination of the profiling output
    Value:  stderr

profiling__debugprint (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9c90>>) 
    Doc:  Do a debugprint of the profiled functions
    Value:  False

profiling__ignore_first_call (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9bd0>>) 
    Doc:  Do we ignore the first call of an Aesara function.
    Value:  False

on_shape_error ({'warn', 'raise'}) 
    Doc:  warn: print a warning and use the default value. raise: raise an error
    Value:  warn

openmp (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93ca9a50>>) 
    Doc:  Allow (or not) parallel computation on the CPU with OpenMP. This is the default value used when creating an Op that supports OpenMP parallelization. It is preferable to define it via the Aesara configuration file ~/.aesararc or with the environment variable AESARA_FLAGS. Parallelization is only done for some operations that implement it, and even for operations that implement parallelism, each operation is free to respect this flag or not. You can control the number of threads used with the environment variable OMP_NUM_THREADS. If it is set to 1, we disable openmp in Aesara by default.
    Value:  False

openmp_elemwise_minsize (<class 'int'>) 
    Doc:  If OpenMP is enabled, this is the minimum size of vectors for which the openmp parallelization is enabled in element wise ops.
    Value:  200000

optimizer_excluding (<class 'str'>) 
    Doc:  When using the default mode, we will remove optimizer with these tags. Separate tags with ':'.
    Value:  

optimizer_including (<class 'str'>) 
    Doc:  When using the default mode, we will add optimizer with these tags. Separate tags with ':'.
    Value:  

optimizer_requiring (<class 'str'>) 
    Doc:  When using the default mode, we will require optimizer with these tags. Separate tags with ':'.
    Value:  

optdb__position_cutoff (<class 'float'>) 
    Doc:  Where to stop eariler during optimization. It represent the position of the optimizer where to stop.
    Value:  inf

optdb__max_use_ratio (<class 'float'>) 
    Doc:  A ratio that prevent infinite loop in EquilibriumGraphRewriter.
    Value:  8.0

cycle_detection ({'fast', 'regular'}) 
    Doc:  If cycle_detection is set to regular, most inplaces are allowed,but it is slower. If cycle_detection is set to faster, less inplacesare allowed, but it makes the compilation faster.The interaction of which one give the lower peak memory usage iscomplicated and not predictable, so if you are close to the peakmemory usage, triyng both could give you a small gain.
    Value:  regular

check_stack_trace ({'off', 'warn', 'raise', 'log'}) 
    Doc:  A flag for checking the stack trace during the optimization process. default (off): does not check the stack trace of any optimization log: inserts a dummy stack trace that identifies the optimizationthat inserted the variable that had an empty stack trace.warn: prints a warning if a stack trace is missing and also a dummystack trace is inserted that indicates which optimization insertedthe variable that had an empty stack trace.raise: raises an exception if a stack trace is missing
    Value:  off

metaopt__verbose (<class 'int'>) 
    Doc:  0 for silent, 1 for only warnings, 2 for full output withtimings and selected implementation
    Value:  0

metaopt__optimizer_excluding (<class 'str'>) 
    Doc:  exclude optimizers with these tags. Separate tags with ':'.
    Value:  

metaopt__optimizer_including (<class 'str'>) 
    Doc:  include optimizers with these tags. Separate tags with ':'.
    Value:  

profile (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93caa4d0>>) 
    Doc:  If VM should collect profile information
    Value:  False

profile_optimizer (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93caa550>>) 
    Doc:  If VM should collect optimizer profile information
    Value:  False

profile_memory (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93caa5d0>>) 
    Doc:  If VM should collect memory profile information and print it
    Value:  False

<aesara.configparser.ConfigParam object at 0x7fdc93caa650>
    Doc:  Useful only for the VM Linkers. When lazy is None, auto detect if lazy evaluation is needed and use the appropriate version. If the C loop isn't being used and lazy is True, use the Stack VM; otherwise, use the Loop VM.
    Value:  None

unittests__rseed (<class 'str'>) 
    Doc:  Seed to use for randomized unit tests. Special value 'random' means using a seed of None.
    Value:  666

warn__round (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93caa7d0>>) 
    Doc:  Warn when using `tensor.round` with the default mode. Round changed its default from `half_away_from_zero` to `half_to_even` to have the same default as NumPy.
    Value:  False

numba__vectorize_target ({'cuda', 'parallel', 'cpu'}) 
    Doc:  Default target for numba.vectorize.
    Value:  cpu

numba__fastmath (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93caa950>>) 
    Doc:  If True, use Numba's fastmath mode.
    Value:  True

numba__cache (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93caa9d0>>) 
    Doc:  If True, use Numba's file based caching.
    Value:  True

compiledir_format (<class 'str'>) 
    Doc:  Format string for platform-dependent compiled module subdirectory
(relative to base_compiledir). Available keys: aesara_version, device,
gxx_version, hostname, numpy_version, platform, processor,
python_bitwidth, python_int_bitwidth, python_version, short_platform.
Defaults to compiledir_%(short_platform)s-%(processor)s-
%(python_version)s-%(python_bitwidth)s.
    Value:  compiledir_%(short_platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s

<aesara.configparser.ConfigParam object at 0x7fdca25add90>
    Doc:  platform-independent root directory for compiled modules
    Value:  /home/zoj/.aesara

<aesara.configparser.ConfigParam object at 0x7fdc93cab250>
    Doc:  platform-dependent cache directory for compiled modules
    Value:  /home/zoj/.aesara/compiledir_Linux-6.1--lts-x86_64-with-glibc2.37--3.11.3-64

blas__ldflags (<class 'str'>) 
    Doc:  lib[s] to include for [Fortran] level-3 blas implementation
    Value:  -L/home/zoj/micromamba/envs/aehmc-dev/lib -lcblas -lblas -lcblas -lblas

blas__check_openmp (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc93922b90>>) 
    Doc:  Check for openmp library conflict.
WARNING: Setting this to False leaves you open to wrong results in blas-related operations.
    Value:  True

scan__allow_gc (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc8ffbae50>>) 
    Doc:  Allow/disallow gc inside of Scan (default: False)
    Value:  False

scan__allow_output_prealloc (<bound method BoolParam._apply of <aesara.configparser.BoolParam object at 0x7fdc907f4c50>>) 
    Doc:  Allow/disallow memory preallocation for outputs inside of scan (default: True)
    Value:  True

  • Python version: 3.11.3
  • Operating system: Linux
  • How did you install Aesara: conda

Fix the example notebook

It is currently broken, probably due to a poorly handled merge conflict. We should consider ditching notebooks for markdown files (for future integration in documentation) like I did with Blackjax. I do not regret that decision.

`test_hmc_mcse` unit test suddenly failing to pass due to `aesara.gradient.DisconnectedInputError` error.

Description of your problem or feature request

CI tests are failing because test_hmc_mcse raises an aesara.gradient.DisconnectedInputError exception.

Please provide a minimal, self-contained, and reproducible example.

python -m pytest tests/test_hmc.py::test_hmc_mcse -v

Please provide the full traceback of any errors.

tests/test_hmc.py:203:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../micromamba/envs/aehmc-dev/lib/python3.11/site-packages/aesara/scan/basic.py:864: in scan
    raw_inner_outputs = fn(*args)
aehmc/hmc.py:132: in step
    ), updates = proposal_generator(
aehmc/hmc.py:222: in propose
    ), updates = integrate(q, p, potential_energy, potential_energy_grad, step_size)
aehmc/trajectory.py:92: in integrate
    [q, p, energy, energy_grad], updates = aesara.scan(
../../micromamba/envs/aehmc-dev/lib/python3.11/site-packages/aesara/scan/basic.py:864: in scan
    raw_inner_outputs = fn(*args)
aehmc/trajectory.py:87: in one_step
    new_state = integrator(
aehmc/integrators.py:68: in one_step
    potential_energy_grad = aesara.grad(potential_energy, position)
../../micromamba/envs/aehmc-dev/lib/python3.11/site-packages/aesara/gradient.py:608: in grad
    handle_disconnected(elem)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

var = Elemwise{add,no_inplace}.0

    def handle_disconnected(var):
        message = (
            "grad method was asked to compute the gradient "
            "with respect to a variable that is not part of "
            "the computational graph of the cost, or is used "
            f"only by a non-differentiable operator: {var}"
        )
        if disconnected_inputs == "ignore":
            pass
        elif disconnected_inputs == "warn":
            warnings.warn(message, stacklevel=2)
        elif disconnected_inputs == "raise":
            message = utils.get_variable_trace_string(var)
>           raise DisconnectedInputError(message)
E           aesara.gradient.DisconnectedInputError:
E           Backtrace when that variable is created:
E
E             File "/home/zoj/dev/aehmc/tests/test_hmc.py", line 203, in test_hmc_mcse
E               trajectory, updates = aesara.scan(
E             File "/home/zoj/micromamba/envs/aehmc-dev/lib/python3.11/site-packages/aesara/scan/basic.py", line 864, in scan
E               raw_inner_outputs = fn(*args)
E             File "/home/zoj/dev/aehmc/aehmc/hmc.py", line 132, in step
E               ), updates = proposal_generator(
E             File "/home/zoj/dev/aehmc/aehmc/hmc.py", line 222, in propose
E               ), updates = integrate(q, p, potential_energy, potential_energy_grad, step_size)
E             File "/home/zoj/dev/aehmc/aehmc/trajectory.py", line 92, in integrate
E               [q, p, energy, energy_grad], updates = aesara.scan(
E             File "/home/zoj/micromamba/envs/aehmc-dev/lib/python3.11/site-packages/aesara/scan/basic.py", line 864, in scan
E               raw_inner_outputs = fn(*args)
E             File "/home/zoj/dev/aehmc/aehmc/trajectory.py", line 87, in one_step
E               new_state = integrator(
E             File "/home/zoj/dev/aehmc/aehmc/integrators.py", line 65, in one_step
E               position = position + a2 * step_size * kinetic_grad

../../micromamba/envs/aehmc-dev/lib/python3.11/site-packages/aesara/gradient.py:594: DisconnectedInputError

Versions and main components

  • Aesara version: 2.8.12
  • Aesara config (python -c "import aesara; print(aesara.config)")
  • Python version: 3.11
  • Operating system: Linux
  • How did you install Aesara: conda

Initialize `RaveledParamsMap` with dictionaries

Currently one has to pass an iterable (that is then converted to a tuple) to initialize RaveledParamsMap:

import aesara.tensor as at
from aehmc.utils import RaveledParamsMap

tau_vv = at.vector("tau")
lambda_vv = at.vector("lambda")

rp_map = RaveledParamsMap((tau_vv, lambda_vv))

q = rp_map.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map.unravel_params(q)[tau_vv]
lambda_part = rp_map.unravel_params(q)[lambda_vv]

In some circumstances we need the map to be indexed with other variables. For instance, when we work with transformed variables, we need the map to link the original value variables to the transformed variables (which may have different shapes/dtypes). In this case we need to overwrite the ref_params property:

from aeppl.transforms import LogTransform

lambda_vv_trans = LogTransform().forward(lambda_vv)

rp_map_trans = RaveledParamsMap((tau_vv, lambda_vv_trans))
rp_map_trans.ref_params = (tau_vv, lambda_vv)

q = rp_map_trans.ravel_params((tau_vv, lambda_vv))
tau_part = rp_map_trans.unravel_params(q)[tau_vv]
lambda_trans_part = rp_map_trans.unravel_params(q)[lambda_vv]

I suggest simplifying this by allowing RaveledParamsMap to be initialized with a dictionary:

rp_map_trans = RaveledParamsMap({tau_vv: tau_vv, lambda_vv: lambda_vv_trans})

Shapes and dtypes would be inferred from the dictionary's values, and the map would be indexed by the dictionary's keys.
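The proposed behaviour can be sketched in plain NumPy. This is a hypothetical stand-in, not AeHMC code: the class name `DictRaveledMap` is made up, NumPy arrays replace Aesara variables, and string keys replace the value variables (NumPy arrays are not hashable). Shapes and dtypes come from the dictionary's values, while lookups use its keys.

```python
import numpy as np


class DictRaveledMap:
    """Hypothetical dict-initialized raveled-parameter map (NumPy sketch).

    Keys index the map; shapes and dtypes are inferred from the values.
    """

    def __init__(self, params):
        self.keys = tuple(params.keys())
        values = [np.asarray(v) for v in params.values()]
        self.shapes = tuple(v.shape for v in values)
        self.sizes = tuple(v.size for v in values)
        # A single common dtype for the flat vector
        self.dtype = np.result_type(*values)

    def ravel_params(self, values):
        # Flatten each parameter (in key order) and concatenate them
        return np.concatenate(
            [np.ravel(np.asarray(v, dtype=self.dtype)) for v in values]
        )

    def unravel_params(self, flat):
        # Split the flat vector and restore each parameter's shape
        out, start = {}, 0
        for key, shape, size in zip(self.keys, self.shapes, self.sizes):
            out[key] = flat[start:start + size].reshape(shape)
            start += size
        return out


# "lambda" is keyed by the original name but shaped like the transformed value
tau = np.zeros(3)
lam_trans = np.log(np.array([2.0, 3.0]))
rp_map = DictRaveledMap({"tau": tau, "lambda": lam_trans})

q = rp_map.ravel_params((tau, lam_trans))
parts = rp_map.unravel_params(q)
```

With this design the `ref_params` override disappears: the key/value pairs already record which variable indexes which (possibly transformed) shape.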

README example not working as expected.

The README example throws an exception on the latest version: the kernel function call is missing the required positional argument step_size.

Please provide a minimal, self-contained, and reproducible example.

Please provide the full traceback of any errors.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-4acbdcf584da> in <module>
     16 # Build the transition kernel
     17 srng = RandomStream(seed=0)
---> 18 kernel = nuts.kernel(
     19     srng,
     20     logprob_fn,

TypeError: kernel() missing 1 required positional argument: 'step_size'

Please provide any additional information below.

Versions and main components

  • Aesara version: 2.3.3
  • Python version: 3.10.0
  • Operating system: Arch
  • How did you install Aesara: pip

Initial Aesara HMC conversion

We want to start this Aesara HMC implementation by converting a simple, yet representative function from blackjax. I believe we decided to try the velocity_verlet function.

From what I can tell, the following will need to be done:

  • Clarify the inputs and return value of velocity_verlet.
    • Aesara isn't based on the compilation/tracing of functions, so, instead of potential and kinetic energy functions, those arguments will need to be potential and kinetic energy graphs. Such graphs are essentially the body of the original functions, so there's not much of a change.
  • Create an Aesara Type for IntegratorState and, eventually, a corresponding Numba conversion.
    • This is probably a worthwhile convenience, but it's not absolutely necessary. I believe that ParamsType is essentially what we want, so there might not be much to do anyway.
  • Replace jax.tree_util.tree_multimap with Aesara's Scan Op.
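For reference while converting, here is a minimal NumPy sketch of a velocity Verlet (leapfrog) step. It assumes a diagonal metric (a 1-D `inverse_mass_matrix`) and takes the gradient of the potential energy as a plain Python callable, whereas the Aesara version would operate on potential and kinetic energy graphs as described above. The function name mirrors `velocity_verlet`, but the signature here is an assumption.

```python
import numpy as np


def velocity_verlet(grad_potential, inverse_mass_matrix):
    """Return a one-step velocity Verlet integrator (diagonal metric assumed).

    grad_potential(q) returns dU/dq for the potential energy U(q).
    inverse_mass_matrix is a 1-D array holding the diagonal of M^{-1}.
    """

    def one_step(q, p, step_size):
        # Half step for the momentum
        p = p - 0.5 * step_size * grad_potential(q)
        # Full step for the position, using dK/dp = M^{-1} p
        q = q + step_size * inverse_mass_matrix * p
        # Second half step for the momentum
        p = p - 0.5 * step_size * grad_potential(q)
        return q, p

    return one_step


def grad_potential(q):
    # Standard normal target: U(q) = q^2 / 2, so dU/dq = q
    return q


step = velocity_verlet(grad_potential, np.ones(1))

# Integrate forward for 10 steps
q, p = np.array([1.0]), np.array([0.5])
for _ in range(10):
    q, p = step(q, p, 0.1)

# Flip the momentum and integrate again: leapfrog is time-reversible,
# so this should bring the position back to its starting value
q_back, p_back = q, -p
for _ in range(10):
    q_back, p_back = step(q_back, p_back, 0.1)
```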

Update step size dtypes

In #86, I updated the step size dtypes to int64, but I imagine it would be better to perform that conversion somewhere within AeHMC; however, I don't know the best place for that at the moment.

@rlouf, if you get a minute, tell me where you think an astype could be performed so that users can specify int32 and avoid those Scan casting issues.

Originally posted by @brandonwillard in #86 (comment)

Handle `dtype`s in a systematic way

Type errors in #35 are a consequence of my not putting enough thought into types. I should go through the whole code base and pay attention to places where type errors might creep in.

Change `at.inv` to `at.reciprocal`

Is there a reason why at.inv is used instead of 1 / inverse_mass_matrix when inverse_mass_matrix is 1-dimensional?

mass_matrix_sqrt = aet.sqrt(aet.inv(inverse_mass_matrix))

Does it have an advantage over the latter? Apologies if this is obvious to some but I was confused by the line when I read it.
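For a 1-D (diagonal) inverse mass matrix, an elementwise reciprocal and `1 / inverse_mass_matrix` compute the same values. The NumPy sketch below checks this, with `np.reciprocal` standing in for the elementwise reciprocal; it says nothing about whether the two Aesara graphs end up identical after rewriting.

```python
import numpy as np

# Diagonal inverse mass matrix: each entry is 1 / m_i
inverse_mass_matrix = np.array([4.0, 0.25, 1.0])

# Square root of the diagonal mass matrix, computed two ways
via_reciprocal = np.sqrt(np.reciprocal(inverse_mass_matrix))
via_division = np.sqrt(1.0 / inverse_mass_matrix)
```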

Warning when running the `HMC.ipynb` notebook

Although the code runs and the results seem unaffected, cell 12 returns the following warning:

WARNING (aesara.tensor.basic_opt): Failed to infer_shape from Op normal_rv{0, (0, 0), floatX, False}.
Input shapes: [None, (TensorConstant{1},), (), (), ()]
Exception encountered during infer_shape: <class 'ValueError'>
Exception message: Length of <TensorType(int64, vector)> cannot be determined

This was discovered in #29 but also happens on main with aeppl v0.0.12. This warning did not appear with previous versions of aeppl, so I believe it is linked to a recent change in the library. Here is the corresponding subgraph:

|   |InplaceDimShuffle{} [id DJ] ''
|   | |Gemv{no_inplace} [id DK] ''
|   |   |AllocEmpty{dtype='float64'} [id DL] ''
|   |   | |TensorConstant{1} [id DM]
|   |   |TensorConstant{1.0} [id DN]
|   |   |InplaceDimShuffle{x,0} [id DO] ''
|   |   | |Elemwise{Mul}[(0, 1)] [id DP] ''
|   |   |   |<TensorType(float64, vector)> [id DQ] -> [id DA]
|   |   |   |normal_rv{0, (0, 0), floatX, False}.1 [id DR] ''
|   |   |     |<RandomGeneratorType> [id DS] -> [id CU]
|   |   |     |<TensorType(int64, vector)> [id DT] -> [id CY]
|   |   |     |TensorConstant{11} [id DU]
|   |   |     |TensorConstant{0} [id DV]
|   |   |     |TensorConstant{1} [id DM]
|   |   |Elemwise{mul,no_inplace} [id DW] ''
|   |   | |inverse_mass_matrix_copy [id DX] -> [id CV]
|   |   | |<TensorType(float64, vector)> [id DQ] -> [id DA]
|   |   | |normal_rv{0, (0, 0), floatX, False}.1 [id DR] ''
|   |   |TensorConstant{0.0} [id DY]

Interestingly the previous cells that apply the kernel for one step outside of the scan loop do not return this warning. I am working on a minimal example and will move the issue to aeppl if necessary.

Rename `hmc.kernel` and `nuts.kernel`

The function name kernel can be misleading as this function creates a new kernel. We should rename it to something along the lines of build_kernel.

Kernels should have `kernel(state, *parameters)` signature

We currently specialize the HMC and NUTS kernels in the factory using closures. However, this is impractical, and we are moving away from this design in Blackjax (see the related discussion).

The HMC kernel factory has the following signature

new_kernel(srng, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps, divergence_threshold)

and the HMC kernel:

kernel(q, log_prob, log_prob_grad)

And we suggest to instead have:

new_kernel(srng, logprob_fn, inverse_mass_matrix, num_integration_steps, divergence_threshold)
kernel(q, log_prob, log_prob_grad, step_size)

I bumped into this design issue while implementing algorithms for step size adaptation, where we have to "create" a new kernel every time the parameter values change.

I think this issue should be addressed before we move forward with the adaptation.
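A minimal NumPy sketch of the proposed split, under simplifying assumptions: the state is just the position `q` (log-probability and gradient are recomputed), a NumPy `Generator` replaces `srng`, the metric is diagonal, and `divergence_threshold` is left out. The factory fixes the structural settings, while `step_size` is passed on every call, so an adaptation loop can change it without building a new kernel. Names are hypothetical and do not reflect AeHMC's actual API.

```python
import numpy as np


def new_kernel(logprob_fn, logprob_grad_fn, inverse_mass_matrix, num_integration_steps):
    """Hypothetical HMC kernel factory: step_size is a kernel argument."""

    def kinetic_energy(p):
        # K(p) = p^T M^{-1} p / 2 for a diagonal metric
        return 0.5 * np.sum(inverse_mass_matrix * p**2)

    def kernel(rng, q, step_size):
        # Momentum p ~ N(0, M), with M = diag(1 / inverse_mass_matrix)
        p = rng.normal(size=np.shape(q)) / np.sqrt(inverse_mass_matrix)
        h0 = -logprob_fn(q) + kinetic_energy(p)

        q_new, p_new = q, p
        for _ in range(num_integration_steps):
            # Velocity Verlet; the potential gradient is -logprob_grad_fn
            p_new = p_new + 0.5 * step_size * logprob_grad_fn(q_new)
            q_new = q_new + step_size * inverse_mass_matrix * p_new
            p_new = p_new + 0.5 * step_size * logprob_grad_fn(q_new)

        # Metropolis correction on the change in total energy
        delta = h0 - (-logprob_fn(q_new) + kinetic_energy(p_new))
        p_accept = min(1.0, np.exp(delta)) if np.isfinite(delta) else 0.0
        if rng.uniform() < p_accept:
            return q_new, p_accept
        return q, p_accept

    return kernel


def logprob(q):
    return -0.5 * np.sum(q**2)


def logprob_grad(q):
    return -q


# Build the kernel once, then vary the step size between calls,
# e.g. as a step size adaptation scheme would
kernel = new_kernel(logprob, logprob_grad, np.ones(2), num_integration_steps=10)
rng = np.random.default_rng(0)
q = np.zeros(2)
for step_size in (0.5, 0.1, 0.01):
    q, p_accept = kernel(rng, q, step_size)
```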

Broadcasting issues with scalar random variables

I am not sure this is the best repository to raise this issue; we can always transfer it to aeppl or aesara should we choose to address it there.

I was working on a very simple example for the README when I stumbled upon this issue. Consider the following:

import aesara
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream
from aeppl import joint_logprob
from aehmc import hmc as hmc


Y_rv = at.random.normal(0, 1)


def logprob_fn(y):
    logprob = joint_logprob(Y_rv, {Y_rv: y})
    return logprob


# Build the transition kernel
srng = RandomStream(seed=0)
kernel = hmc.kernel(
    srng,
    logprob_fn,
    step_size=1e-3,
    inverse_mass_matrix=at.ones(1),
    num_integration_steps=10,
)

# Compile a function that updates the chain
q = Y_rv.clone()
potential_energy = -logprob_fn(q)
potential_energy_grad = aesara.grad(potential_energy, wrt=q)

next_step = kernel(q, potential_energy, potential_energy_grad)

which raises the following exception:

TypeError: Cannot convert Type TensorType(float64, (True,)) (of Variable Elemwise{add,no_inplace}.0) into Type TensorType(float64, scalar). You can try to manually convert Elemwise{add,no_inplace}.0 into a TensorType(float64, scalar).

The exception is raised when computing log_prob in the integrator. Indeed, the momentum generator returns a vector of size 1, so the updated value of q is no longer a scalar but a vector of size 1.

I have tried adding ifelse statements so the momentum is drawn from normal(0, 1) instead of normal(0, 1, size=size), but then the two branches do not have the same type. How would you go about this edge case?
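The shape problem itself can be reproduced in NumPy: adding a size-1 momentum to a 0-d position yields a vector of shape `(1,)`, while drawing the momentum with the position's own shape keeps it 0-d. This is only an illustration of one possible direction; the analogous Aesara change would be to pass the position's symbolic shape as `size`, and whether that also resolves the static broadcastable-type mismatch is untested here.

```python
import numpy as np

rng = np.random.default_rng(0)
q = np.asarray(0.5)  # 0-d position, like a scalar random variable

# Momentum with a fixed size of 1 upcasts the updated position to shape (1,)
p_fixed = rng.normal(0, 1, size=1)
q_fixed = q + 0.1 * p_fixed

# Momentum drawn with the position's own shape keeps the position 0-d
p_matched = rng.normal(0, 1, size=q.shape)
q_matched = q + 0.1 * p_matched
```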

Incorrect momentum samples generated when inverse mass matrix is 2D.

Description of your problem or feature request

The aim is to compute the lower-triangular factor of the inverse mass matrix (let's call it $L$) such that the inverse mass matrix can be factored as $M^{-1} = LL^T$. We then want to use $L$ to generate a sample from $\mathcal{N}(0, M)$ using the formula $(L^T)^{-1} z$, where $z$ is a standard normal sample. This works because the covariance of the target distribution, $M$, factors as $(L^T)^{-1}L^{-1}$. Looking at https://github.com/aesara-devs/aehmc/blob/d54e2d05512d8d3d4aea92b8732854e6794296e8/aehmc/metrics.py#L54-L57, it seems that the call to cholesky returns $L$ (the lower-triangular factor is returned by default). We then try to invert this lower-triangular matrix to obtain the factor used to generate a sample from this distribution (in my notation above, this is $(L^T)^{-1}$). However, the call to solve_triangular appears to be solving $LX = I$, which returns $L^{-1}$ instead of $(L^T)^{-1}$.

Please provide any additional information below.
To fix this, we need to either make sure cholesky returns $L^T$ or make sure the solve_triangular call solves the right equation by specifying trans=True.
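The argument can be checked numerically. In this NumPy sketch, a sample $p = Az$ with $z \sim \mathcal{N}(0, I)$ has covariance $AA^T$, so the correct factor must satisfy $AA^T = M$. `np.linalg.solve(L.T, I)` solves $L^T X = I$; SciPy's `scipy.linalg.solve_triangular(L, I, lower=True, trans="T")` solves the same system. The matrix values are arbitrary.

```python
import numpy as np

# A non-diagonal inverse mass matrix M^{-1} and the mass matrix M
inverse_mass_matrix = np.array([[2.0, 0.5], [0.5, 1.0]])
mass_matrix = np.linalg.inv(inverse_mass_matrix)

# Lower-triangular Cholesky factor: M^{-1} = L L^T
L = np.linalg.cholesky(inverse_mass_matrix)
identity = np.eye(2)

# Correct factor: solve L^T X = I, so X = (L^T)^{-1}
correct_factor = np.linalg.solve(L.T, identity)
# Factor described in the issue: solve L X = I, so X = L^{-1}
reported_factor = np.linalg.solve(L, identity)

# One momentum draw p ~ N(0, M) using the correct factor
rng = np.random.default_rng(0)
momentum = correct_factor @ rng.standard_normal(2)
```

Only the correct factor reproduces $M$; the reported factor gives $(L^T L)^{-1}$, which matches $M$ only in special cases such as a diagonal inverse mass matrix.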

Versions and main components

  • Aesara version: 2.8.12
  • Python version: 3.11.3
  • Operating system: Linux
  • How did you install Aesara: conda

Add function to create a new HMC state

It is cumbersome to have to systematically write:

import aesara
import aesara.tensor as aet

q = aet.vector('q')
potential_energy = -logprob_fn(q)
potential_energy_grad = aesara.grad(potential_energy, wrt=q)

to initialize the chain's state (HMC or NUTS), when we could instead simply write:

init_state = hmc.new_state(q, logprob_fn)

We should implement this new_state helper function which will work for any algorithm in the HMC family.
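A minimal sketch of what such a helper returns, written in NumPy. NumPy has no symbolic differentiation, so this version takes the log-probability gradient as an extra callable, whereas in Aesara it would come from `aesara.grad`. The `IntegratorState` tuple and the `new_state` signature here are assumptions for illustration.

```python
from typing import Callable, NamedTuple

import numpy as np


class IntegratorState(NamedTuple):
    position: np.ndarray
    potential_energy: float
    potential_energy_grad: np.ndarray


def new_state(q, logprob_fn: Callable, logprob_grad_fn: Callable) -> IntegratorState:
    """Hypothetical helper: build an initial HMC/NUTS state from a position."""
    q = np.asarray(q, dtype=float)
    # The potential energy is the negative log-probability
    return IntegratorState(q, -logprob_fn(q), -logprob_grad_fn(q))


def logprob(q):
    return -0.5 * np.sum(q**2)


def logprob_grad(q):
    return -q


state = new_state([1.0, -2.0], logprob, logprob_grad)
```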

The warmup should be exposed as a kernel

This follows the outline in blackjax-devs/blackjax#171. The reason we want to do this is to be able to interleave sampling and warmup steps when not all variables are sampled with the HMC/NUTS sampler; the current API to the warmup just runs the warmup for a set number of steps.

Be more specific

This also means we could separate step functions and adaptation steps (that take chain state as an input)?

Add high-level API

Currently, setting up the kernel and related pieces takes a lot of code; it would be nice to wrap this in a high-level function similar to pm.sample().

Typo in variable name.

I think kinetic_ernergy is supposed to be kinetic_energy. See:

aehmc/aehmc/hmc.py

Lines 22 to 27 in 87b0c4d

momentum_generator, kinetic_ernergy_fn, _ = metrics.gaussian_metric(
    inverse_mass_matrix
)
symplectic_integrator = integrators.velocity_verlet(
    potential_fn, kinetic_ernergy_fn

The latter spelling is used in other places like

kinetic_grad = aesara.grad(kinetic_energy_fn(new_momentum), new_momentum)

NUTS kernel fails with large log-probability values and step size

import aesara
import aesara.tensor as at
from aesara.tensor.random import RandomStream
from aeppl import joint_logprob
from aehmc import nuts

srng = at.random.RandomStream(seed=0)
Y_rv = srng.normal(1, 2)

def logprob_fn(y):
    logprob = 1e20 * joint_logprob({Y_rv: y})
    return logprob

y_vv = Y_rv.clone()
kernel = nuts.new_kernel(srng, logprob_fn)
initial_state = nuts.new_state(y_vv, logprob_fn)

params = (at.scalar(), at.scalar())
new_state, updates = kernel(*initial_state, *params)
nuts_step_fn = aesara.function(
    (y_vv, *params), new_state, updates=updates
)

step_size = 1.
inverse_mass_matrix = 1.
print(nuts_step_fn(100., step_size, inverse_mass_matrix))
# [array(100.), array(1.22673709e+23), array(2.475e+21), array(0.), array(1), array(True), array(True)]

step_size = 1e40
inverse_mass_matrix = 1.
print(nuts_step_fn(100., step_size, inverse_mass_matrix))
# Exception due to `p_accept` evaluating to `NaN`

p_accept should never evaluate to NaN in this case. Instead, we expect the kernel to return the initial state and is_diverging=True:

step_size = 1e40
inverse_mass_matrix = 1.
print(nuts_step_fn(100., step_size, inverse_mass_matrix))
# [array(100.), array(1.22673709e+23), array(2.475e+21), array(0.), array(1), array(True), array(True)]

I currently suspect (but this needs to be confirmed) that the trajectory builder does not return immediately when the first step diverges, leading to an at.exp(np.inf - np.inf) operation, which returns NaN.
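The arithmetic behind the suspected NaN is easy to check in NumPy: `inf - inf` is NaN, and so is its exponential. The sketch also shows one hypothetical guard (the function name is made up) that treats a non-finite energy difference as a divergence with zero acceptance probability. The fix the issue actually asks for is to stop the trajectory builder at the first divergent step; this only illustrates why the acceptance probability ends up NaN without such a check.

```python
import numpy as np

# The failure mode: exp(inf - inf) evaluates to NaN
with np.errstate(invalid="ignore"):
    naive = np.exp(np.float64(np.inf) - np.float64(np.inf))


def acceptance_probability(initial_energy, new_energy):
    """Hypothetical guard: return (p_accept, is_diverging)."""
    with np.errstate(invalid="ignore", over="ignore"):
        delta = np.float64(initial_energy) - np.float64(new_energy)
        if not np.isfinite(delta):
            # A non-finite energy change is reported as a divergence
            return 0.0, True
        return float(min(1.0, np.exp(delta))), False
```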
