
juliatrustworthyai / LaplaceRedux.jl


Effortless Bayesian Deep Learning through Laplace Approximation for Flux.jl neural networks.

Home Page: https://juliatrustworthyai.github.io/LaplaceRedux.jl/

License: MIT License

Julia 66.07% TeX 33.47% Lua 0.43% CSS 0.03%
bayesian-deep-learning julia laplace-approximation machine-learning

LaplaceRedux.jl's People

Contributors

adelinacazacu, andrei32ionescu, eeeerrw, markardman, navimakarov, pat-alt, pitmonticone, rockdeldiablo, severinbratus


LaplaceRedux.jl's Issues

🚀 Lift off

This is an issue reserved for the TU Delft Student Software Project '23 - LaplaceRedux

This part forms the core of the project and is all about scaling things up to lift the package off the ground. The following two issues should probably be tackled in this order:

Interfacing this package to MLJ (or perhaps directly to MLJFlux) would be really great, but is less of a priority than the other two issues:

Extended paper

Rough game plan for turning the extended abstract into a full paper (@severinbratus to see). The paper lives in paper/.

  • Update docs to close issues #61 #60 #37
  • Make sure LaplaceRedux can be used with CounterfactualExplanations #7 (add section to paper)
  • Rework MLJ interface (#39) and ensure LaplaceRedux can be used with ConformalPrediction (add section to paper)
  • #67
  • #68
  • #69
  • #70

Interface MLJ

It would be nice to interface this library to MLJ.jl using MLJFlux.jl.

EDIT: Following #33, still need support for regressor.

  • Add NN regressor: it turns out the wrapper can be used for both regression and classification. It seems to work for now, but we may refactor in the future to distinguish the two cases more clearly.

Laplace input format

I have realized that Laplace has no problem dealing with input in the format

Base.Iterators.Zip{Tuple{Vector{Vector{Float32}}, Vector{Float32}}}

but it doesn't work if the data are in the format

Base.Iterators.Zip{Tuple{Matrix{Float64}, Vector{Float64}}}

To avoid errors I had to use zip(eachrow(X), y) and forgo the DataLoader, which would otherwise be convenient when the dataset is big.
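
For illustration, here is a minimal sketch of the workaround, assuming a simple regression network; the data shapes and architecture are hypothetical, and only the zip(eachrow(X), y) pattern comes from this issue:

using Flux, LaplaceRedux

# Hypothetical toy data: 100 observations with 2 features each.
X = randn(Float32, 100, 2)
y = randn(Float32, 100)

nn = Chain(Dense(2, 10, relu), Dense(10, 1))
la = Laplace(nn; likelihood=:regression)

# Works: each element of the zipped iterator is an (input, output) tuple.
fit!(la, zip(eachrow(X), y))

# Fails (per this issue): passing the raw (Matrix, Vector) pair,
# e.g. via a DataLoader, is not handled yet.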

Refactor gradient and jacobians as multi-dimensional arrays

This is an issue reserved for the TU Delft Student Software Project '23 - LaplaceRedux

I think this is a fairly challenging problem, so do expect to spend some time on this. It may involve some breaking changes to the core architecture of the package, but on the bright side, it should help with some outstanding issues.

Currently gradients and Jacobians are mapped to two-dimensional arrays, since so far I had not considered training LA in batches (see related #19). This currently leads to problems for the implementation of GGN for the multi-class cases (see related #17). Refactoring things properly should help. In this context it may be worth using Tullio.jl for multi-dimensional array computations.
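
To give a flavour of what Tullio.jl enables here, below is a hedged sketch of a batched GGN-style contraction over a three-dimensional Jacobian array; the shapes and the identity output-Hessian are illustrative assumptions, not the package's current implementation:

using Tullio

# Hypothetical batch of Jacobians: (output_dim, n_params, batch_size).
J = randn(3, 10, 32)

# H[p, q] = Σ_{o, b} J[o, p, b] * J[o, q, b], i.e. a GGN-style J'J
# summed over outputs and batch (identity output-Hessian assumed).
@tullio H[p, q] := J[o, p, b] * J[o, q, b]

size(H)  # (10, 10)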

Update docs

Update docs to incorporate changes/improvements made by students #33

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!

Add support for mini-batch training

Currently the Hessian is approximated per data point,

# Training: loop over individual (x, y) tuples, accumulating the
# loss and the Hessian approximation one instance at a time.
for d in data
    loss_batch, H_batch = hessian_approximation(la, d)
    loss += loss_batch
    H += H_batch
    n_data += 1
end

where d is a tuple of inputs and outputs (x,y) at instance level. For various reasons it would be better to add support for training in batches.
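
A hedged sketch of what batched accumulation might look like, reusing the accumulators from the snippet above and assuming hessian_approximation were extended to accept a whole batch (a hypothetical change, not the current API):

using Flux: DataLoader

# Hypothetical: iterate over mini-batches instead of single instances.
loader = DataLoader((X, y); batchsize=32)
for (x_batch, y_batch) in loader
    loss_batch, H_batch = hessian_approximation(la, (x_batch, y_batch))
    loss += loss_batch
    H += H_batch
    n_data += length(y_batch)  # number of observations in this batch
end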

`predict` method not differentiable using Zygote.jl

A while ago I ran into autodiff issues when trying to use Laplace models with CounterfactualExplanations.jl, but unfortunately was silly enough to not report the exact error here.

It would be good to check if this is really an issue. To do that, try to use the LaplaceReduxModel wrapper in CounterfactualExplanations.jl, which is currently neither exported nor properly documented.

If no issue pops up here, then try to directly differentiate through a predict call.
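
A minimal sketch of that direct check, assuming la is a fitted Laplace object and x a single input; the scalar reduction is illustrative, since predict may return a tuple (e.g. mean and variance for regression) that needs adapting:

using Zygote

# Differentiate a scalar function of the posterior predictive w.r.t. x.
∇x = Zygote.gradient(x) do x
    sum(first(predict(la, x)))
end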

If there's still no issue, happy days!

Error when Plotting: Cannot convert LaplaceRedux.Laplace to series data for plotting

Hi

I am running into an issue when getting to the plotting section of the MLP regression tutorial.

I had originally run into the error on my own example, but I still face it when following the tutorial. The fitting and optimisation step runs without error!

Below is my environment (Pluto notebook; Julia 1.9.3):

Status `/tmp/jl_HCANJ7/Project.toml`
  [587475ba] Flux v0.14.6
  [c52c1a26] LaplaceRedux v0.1.3
  [91a5bcdd] Plots v1.39.0
  [44cfe95a] Pkg v1.9.2
  [10745b16] Statistics v1.9.0

For clarification, below is the code:

begin
	subset_w = :all
	la = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)
	fit!(la, data)
	plot(la, X, y; zoom=-5, size=(400,400))
end

And below is the full error:

Cannot convert LaplaceRedux.Laplace to series data for plotting

    error(::String)@error.jl:35
    _prepare_series_data(::LaplaceRedux.Laplace)@series.jl:8
    _series_data_vector(::LaplaceRedux.Laplace, ::Dict{Symbol, Any})@series.jl:36
    macro expansion@series.jl:128[inlined]
    apply_recipe(::AbstractDict{Symbol, Any}, ::Type{RecipesPipeline.SliceIt}, ::Any, ::Any, ::Any)@RecipesBase.jl:300
    _process_userrecipes!(::Any, ::Any, ::Any)@user_recipe.jl:38
    recipe_pipeline!(::Any, ::Any, ::Any)@RecipesPipeline.jl:72
    _plot!(::Plots.Plot, ::Any, ::Any)@plot.jl:223
    #plot#…@plot.jl:102[inlined]
    plot@plot.jl:93[inlined]
    top-level scope@Local: 5[inlined]

Thank you for any assistance.

Kind regards
Patrick

Add support for Last-Layer and Subnet Laplace

Instead of applying LA to the full network, it's also possible to only work with the last layer or a subnetwork of the network. See if you can add support for this and consider if we really need separate classes here or if we can instead work with the existing classes for the full network.
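
One possible shape for this, reusing the existing constructor rather than separate classes: the subset_of_weights keyword already appears in the plotting issue above, and extending its accepted values is one option (a sketch, not a settled design):

# Hypothetical: select the trainable subset at construction time.
la_full = Laplace(nn; likelihood=:classification, subset_of_weights=:all)
la_last = Laplace(nn; likelihood=:classification, subset_of_weights=:last_layer)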

🌯 Wrapping Up

This is an issue reserved for the TU Delft Student Software Project '23 - LaplaceRedux

To wrap up the project, it would be nice if you could summarise what you've done in an accessible manner (e.g. a blog post) and share it with the Julia community (and beyond if you like). You are entirely free to decide on the extent and form of this post. In other words, there is no expectation for this to be particularly extensive. But it is expected that you deliver something that wraps up the project.

Revise extended abstract

  • The citation block in the intro seems unnecessary.
  • In the discussion and outlook, the statement that the package is in its infancy is no longer up to date; you can instead mention the perspective and potential future work to incorporate.
  • There could be more detail on the methods using the space gained, and more detail on the API. The only code example is three lines; we don't see how la can be used after optimizing the prior (see the sketch below).
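
For reference, a hedged sketch of what such an extended example might show; fit!, predict, and plot all appear elsewhere in this thread, while the prior-optimization call is assumed from the tutorial and its exact signature may differ:

la = Laplace(nn; likelihood=:regression)
fit!(la, data)
optimize_prior!(la)   # assumed API: tune the prior precision via the marginal likelihood
ŷ = predict(la, X)    # posterior predictive after optimizing the prior
plot(la, X, y)        # visualize predictive uncertainty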

Compatibility with Flux 0.14

Hey really cool work!

Is there a reason why 0.14 is not supported? If not, what needs to be fixed? Maybe we can take a stab at it.

Add support for block-diagonal Hessian approximations

This is an issue reserved for the TU Delft Student Software Project '23 - LaplaceRedux

Again a larger piece of work ...

Take a look at the Python package and note that they have implemented classes for different ways to approximate the Hessian (e.g. Laplace approximation with Kronecker-factored log-likelihood Hessian approximation). Our package currently only supports the Empirical Fisher, and you may have already tackled #17 to implement GGN. Both of these approaches approximate the full Hessian, which is costly.

The goal here is to implement more approaches.
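
For orientation, the approximations mentioned above differ as follows (standard definitions, stated for reference). With per-sample loss gradients g_n, network Jacobians J_n, and output-Hessians Λ_n:

Empirical Fisher:  H ≈ Σₙ gₙ gₙᵀ
GGN:               H ≈ Σₙ Jₙᵀ Λₙ Jₙ
KFAC:              per-layer block of H ≈ Aₗ ⊗ Bₗ (Kronecker factors)

The first two build a full P×P matrix in the number of parameters P, whereas the block-diagonal Kronecker-factored route keeps storage and inversion per-layer, which is what makes larger networks tractable.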

๐Ÿƒ๐Ÿฝ Getting started

This is an issue reserved for the TU Delft Student Software Project '23 - LaplaceRedux

Here are two issues that should help you get started with this package ...

Firstly, what I believe is a small bug, but which I failed to document properly a while ago (see issue for details):

Secondly, I'd like you to revisit this previously closed issue for adding multi-class support. In particular, I would like you to test whether this is actually working: I previously found that for multi-class problems we seem to end up with overly conservative posteriors (high predictive uncertainty everywhere).

If you find the same, then I think this may be related to #20, which is the next item on the agenda and the first task of #25.
