

jonasmac16 commented on May 25, 2024

As suggested by Kai on Slack, tempering would need to be added to the `step` function.

Secondly, you need to add the tempering steps to:

```julia
function step(
    lf::Leapfrog{F},
    h::Hamiltonian,
    θ::AbstractVector{T},
    r::AbstractVector{T},
    n_steps::Int=1
) where {F<:AbstractFloat,T<:Real}
    fwd = n_steps > 0 # simulate hamiltonian backward when n_steps < 0
    ϵ = fwd ? lf.ϵ : -lf.ϵ
    n = abs(n_steps)  # number of leapfrog steps in either direction
    n_valid = 0
    # Initial half step for the momentum
    r_new, _is_valid_1 = lf_momentum(ϵ / 2, h, θ, r)
    for i = 1:n
        θ_new = lf_position(ϵ, h, θ, r_new)
        # Full momentum steps inside the trajectory, a half step at the end
        r_new, _is_valid_2 = lf_momentum(i == n ? ϵ / 2 : ϵ, h, θ_new, r_new)
        if (_is_valid_1 && _is_valid_2)
            θ, r = θ_new, r_new
            n_valid = n_valid + 1
        else
            # Reverse half leapfrog step for r when breaking the loop
            # prematurely: after any accepted step i - 1 < n, r is half a
            # step ahead of θ.
            if i > 1
                r, _ = lf_momentum(-ϵ / 2, h, θ, r)
            end
            break
        end
    end
    return θ, r, n_valid > 0
end
```

We need to make sure that we "heat up" and "cool down" symmetrically, as suggested by Neal (2012). As Kai pointed out, because the leapfrog implementation takes full momentum steps within the trajectory (the half steps of consecutive leapfrog steps are merged), extra care is needed.
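For reference, the symmetric scheme as we read it in Neal (2012) can be sketched without the merged momentum steps: each leapfrog step is preceded and followed by multiplying the momentum by √α in the first half of the trajectory and dividing by √α in the second half; when the number of steps `L` is odd, the middle step multiplies before and divides after. This sketch assumes a standard Gaussian target (U(θ) = θ⋅θ/2, so ∇U(θ) = θ) and an identity mass matrix; `grad_U`, `leapfrog` and `tempered_leapfrog` are hypothetical helpers, not AdvancedHMC functions.

```julia
# Hypothetical toy target (assumption): standard Gaussian, U(θ) = θ⋅θ/2, so ∇U(θ) = θ.
grad_U(θ) = θ

# One untempered leapfrog step with step size ϵ and identity mass matrix.
function leapfrog(θ, r, ϵ)
    r = r .- (ϵ / 2) .* grad_U(θ)   # half step for the momentum
    θ = θ .+ ϵ .* r                 # full step for the position
    r = r .- (ϵ / 2) .* grad_U(θ)   # half step for the momentum
    return θ, r
end

# L tempered leapfrog steps: the momentum is multiplied by √α before and after
# each step in the first half ("heat up") and divided by √α before and after
# each step in the second half ("cool down"); for odd L the middle step
# multiplies before and divides after.
function tempered_leapfrog(θ, r, ϵ, L, α)
    s = sqrt(α)
    for i in 1:L
        if 2i < L + 1            # first half of the trajectory
            pre, post = s, s
        elseif 2i == L + 1       # middle step (only exists when L is odd)
            pre, post = s, 1 / s
        else                     # second half of the trajectory
            pre, post = 1 / s, 1 / s
        end
        θ, r = leapfrog(θ, pre .* r, ϵ)
        r = post .* r
    end
    return θ, r
end

# Usage: run forward, flip the momentum and run again; because the rescalings
# are symmetric, this should return to the starting point.
θ0, r0 = [1.0, -0.5], [0.3, 0.8]
θ1, r1 = tempered_leapfrog(θ0, r0, 0.1, 5, 1.02)
θ2, r2 = tempered_leapfrog(θ1, -r1, 0.1, 5, 1.02)
println(isapprox(θ2, θ0), " ", isapprox(-r2, r0))
```

With `α = 1` every rescaling is the identity and the function reduces to `L` plain leapfrog steps.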

An `if`/`elseif` scheme for tempering within the trajectory, like the following, should probably work:

```julia
n_steps = 5

println("Step:0")
println("half*sqrt(α)")
for i in 1:n_steps
    if i*2 == n_steps && i != n_steps
        #θ, r = θ_new, r_new/α
        println("half*sqrt(α) + half/sqrt(α) = full")
        println("Step:", i, ".5")
    elseif 2*i < n_steps
        #θ, r = θ_new, r_new*α
        println("full*α")
        println("Step:", i, ".5")
    elseif i*2 > n_steps && i != n_steps
        #θ, r = θ_new, r_new/α
        println("full/α")
        println("Step:", i, ".5")
    else
        # last step: only a half momentum step
        #θ, r = θ_new, r_new/α_sqrt
        println("half/sqrt(α)")
        println("Step:", i)
    end
end
```

Lastly, @xukai92 raised the question of how exactly to implement the tempering code within the existing code structure, and how it would interact with AHMC sampling:

So it comes down to a design issue: should I implement these conditions inside `step` and check whether `lf.α == 1`, or would it be better to introduce another integrator type that does this job, implement `lf_momentum` and `lf_position` separately for each type, and rely on multiple dispatch to handle these conditions? I'm not sure which is the better way. But for your purpose, adding these conditions should work.
However, I'm not sure how this actually interacts with NUTS (for plain HMC it all seems fine), as NUTS only calls `step` with `n_steps=1` and its trajectory building relies on its own tree-doubling algorithm.
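The second design option can be sketched under some assumptions: give the tempered integrator its own type and let multiple dispatch pick the momentum rescaling, so the plain leapfrog path needs no `α == 1` check. To take the extra care needed with merged full momentum steps, this sketch splits each full step back into two half steps and applies the combined factor between them, which (as we read Neal) reproduces the unmerged scheme; rescaling after the full step instead gives a slightly different map. `AbstractToyIntegrator`, `ToyLeapfrog`, `ToyTemperedLeapfrog`, `momentum_scale`, `trajectory` and the Gaussian `grad_U` are hypothetical stand-ins, not AdvancedHMC's types or API.

```julia
# Hypothetical stand-in types for illustration (not AdvancedHMC's definitions).
abstract type AbstractToyIntegrator end

struct ToyLeapfrog{F<:AbstractFloat} <: AbstractToyIntegrator
    ϵ::F
end

struct ToyTemperedLeapfrog{F<:AbstractFloat} <: AbstractToyIntegrator
    ϵ::F
    α::F
end

# Momentum factor at slot i of an n-step trajectory: i = 0 is before the first
# half step, i = n after the last one, and 0 < i < n sits between the two half
# steps that make up the merged full momentum step after position step i.
momentum_scale(::ToyLeapfrog, i, n) = 1.0

function momentum_scale(lf::ToyTemperedLeapfrog, i, n)
    s = sqrt(lf.α)
    pre(j) = 2j <= n + 1 ? s : 1 / s   # before step j: √α up to and incl. the middle
    post(j) = 2j <= n ? s : 1 / s      # after step j: √α only in the first half
    i == 0 && return pre(1)
    i == n && return post(n)
    return post(i) * pre(i + 1)        # α, 1 or 1/α between consecutive steps
end

grad_U(θ) = θ   # assumption: standard Gaussian target, U(θ) = θ⋅θ/2

# n_steps leapfrog steps with merged momentum steps; the same code serves both
# integrator types, and dispatch on `momentum_scale` decides whether to temper.
function trajectory(lf::AbstractToyIntegrator, θ, r, n_steps::Int)
    ϵ = lf.ϵ
    r = momentum_scale(lf, 0, n_steps) .* r
    r = r .- (ϵ / 2) .* grad_U(θ)
    for i in 1:n_steps
        θ = θ .+ ϵ .* r
        r = r .- (ϵ / 2) .* grad_U(θ)              # second half of the momentum step
        r = momentum_scale(lf, i, n_steps) .* r    # rescale between the two halves
        if i < n_steps
            r = r .- (ϵ / 2) .* grad_U(θ)          # first half of the next momentum step
        end
    end
    return θ, r
end
```

For example, `trajectory(ToyTemperedLeapfrog(0.1, 1.02), θ, r, 5)` tempers the trajectory, while `trajectory(ToyLeapfrog(0.1), θ, r, 5)` runs the same code with every factor equal to 1.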

from advancedhmc.jl.

jonasmac16 commented on May 25, 2024

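An update to the code above (fixing an error: in the middle of an even-length trajectory the momentum should not be rescaled, so `r_new/α` becomes `r_new`). Conceptually, the tempering could be done like this: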

```julia
n_steps = 5

α = 1.02
α_sqrt = sqrt(α)

println("Step:0")
#θ, r = θ_new, r_new*α_sqrt
println("half*sqrt(α)")
for i in 1:n_steps
    if i*2 == n_steps && i != n_steps
        # middle of an even-length trajectory: heating and cooling cancel
        #θ, r = θ_new, r_new
        println("half*sqrt(α) + half/sqrt(α) = full")
        println("Step:", i, ".5")
    elseif 2*i < n_steps
        # first half: heat up
        #θ, r = θ_new, r_new*α
        println("full*α")
        println("Step:", i, ".5")
    elseif i*2 > n_steps && i != n_steps
        # second half: cool down
        #θ, r = θ_new, r_new/α
        println("full/α")
        println("Step:", i, ".5")
    else
        # last step: only a half momentum step
        #θ, r = θ_new, r_new/α_sqrt
        println("half/sqrt(α)")
        println("Step:", i)
    end
end
```
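Condensing the branches above into a function makes the symmetry easy to check: the factors applied over one trajectory should multiply to 1 for any `n_steps`, so the momentum is cooled exactly as much as it was heated. `tempering_factor` and `tempering_schedule` are hypothetical helpers written for this check, not part of AdvancedHMC. This checks only the bookkeeping of the factors, not where each one is applied relative to the gradient steps.

```julia
# Rescaling factor applied to the momentum after momentum update i of an
# n-step trajectory (i = 0 is the initial half step), mirroring the
# if/elseif branches above.
function tempering_factor(i::Int, n::Int, α::Real)
    if i == 0
        return sqrt(α)          # initial half step: heat by √α
    elseif 2i == n && i != n
        return one(α)           # even n, middle: √α and 1/√α cancel
    elseif 2i < n
        return α                # first half: heat up
    elseif 2i > n && i != n
        return 1 / α            # second half: cool down
    else
        return 1 / sqrt(α)      # final half step: cool by 1/√α
    end
end

# Full schedule of factors for one trajectory of n steps.
tempering_schedule(n::Int, α::Real) = [tempering_factor(i, n, α) for i in 0:n]
```

For `n_steps = 5` the schedule is √α, α, α, 1/α, 1/α, 1/√α; for `n_steps = 4` it is √α, α, 1, 1/α, 1/√α.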

