
julianst / delta-prox


This project forked from princeton-computational-imaging/delta-prox


Official code repository for ∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization (SIGGRAPH TOG 2023)

Home Page: https://light.princeton.edu/publication/delta_prox/


Delta Prox

Differentiable Proximal Algorithm Modeling for Large-Scale Optimization



∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers. Instead of handwriting these solvers and differentiating through them via autograd, ∇-Prox requires only a few lines of code to define a solver, which can then be specialized to the user's memory constraints or training budget through optimized algorithm unrolling, deep equilibrium learning, or deep reinforcement learning. This makes it easier to prototype learning-based bi-level optimization problems for a diverse range of applications. Compared with naive implementations of existing methods, ∇-Prox needs significantly fewer lines of code and compares favorably in memory consumption across application domains.

News

  • August 2023: $\nabla$-Prox is presented at SIGGRAPH 2023, and its codebase is now public.
  • May 2023: $\nabla$-Prox is accepted as a journal paper at SIGGRAPH 2023.

Installation

We recommend installing $\nabla$-Prox in a virtual environment from PyPI.

```bash
pip install dprox
```

Please refer to the Installation guide for other options.

Quickstart


Consider a simple image deconvolution problem: given a blurred observation $y$, we want to recover the clean image $x$ that minimizes the objective

$$ \arg\min_x \; \frac{1}{2} \|Dx - y\|_2^2 + g(x), $$

where $D$ denotes convolution with the point spread function (blur kernel) and $g(x)$ denotes an implicit plug-and-play denoiser prior. We can solve this problem in ∇-Prox with the following code:
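For reference, splitting the objective with an auxiliary variable $z = x$ and applying textbook scaled-form ADMM (penalty $\rho$, scaled dual variable $u$) gives the iterations below. This is a standard derivation shown for illustration; the splitting that ∇-Prox compiles may differ. The data step inverts the blur operator $D$ in closed form, and in the plug-and-play setting the proximal step on $g$ is replaced by a call to the denoiser:

$$
\begin{aligned}
x^{k+1} &= \arg\min_x \; \frac{1}{2}\|Dx - y\|_2^2 + \frac{\rho}{2}\|x - z^k + u^k\|_2^2 = \left(D^\top D + \rho I\right)^{-1}\left(D^\top y + \rho\,(z^k - u^k)\right), \\
z^{k+1} &= \operatorname{prox}_{g/\rho}\!\left(x^{k+1} + u^k\right) \approx \mathrm{denoise}\!\left(x^{k+1} + u^k\right), \\
u^{k+1} &= u^k + x^{k+1} - z^{k+1}.
\end{aligned}
$$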

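```python
from dprox import *
from dprox.utils import *
from dprox.contrib import *

img = sample()                      # sample test image
psf = point_spread_function(15, 5)  # blur kernel (point spread function)
b = blurring(img, psf)              # simulated blurred observation (y in the objective)

x = Variable()
data_term = sum_squares(conv(x, psf) - b)          # least-squares data-fidelity term
reg_term = deep_prior(x, denoiser='ffdnet_color')  # plug-and-play prior g(x) backed by FFDNet
prob = Problem(data_term + reg_term)

prob.solve(method='admm', x0=b)     # solve with ADMM, initialized at the observation
```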
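To see what a solver like `prob.solve(method='admm')` does conceptually, here is a hand-written plug-and-play ADMM sketch for the same deconvolution objective in plain NumPy. It is not ∇-Prox's compiled solver. It assumes circular boundary conditions, so the data step $(D^\top D + \rho I)^{-1}$ becomes an elementwise division in the Fourier domain, and it uses a 3x3 box filter as a hypothetical stand-in for the FFDNet denoiser. The helper names (`psf_to_otf`, `box_denoise`, `x_update`, `pnp_admm`) and the defaults (`rho=0.5`, `iters=30`) are illustrative assumptions.

```python
import numpy as np


def psf_to_otf(psf, shape):
    """Zero-pad a PSF to `shape` and circularly shift its center to index (0, 0),
    so that multiplying by its 2-D FFT implements circular convolution."""
    otf = np.zeros(shape)
    kh, kw = psf.shape
    otf[:kh, :kw] = psf
    otf = np.roll(otf, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(otf)


def box_denoise(v):
    """Hypothetical stand-in for a learned denoiser: a 3x3 circular box filter."""
    out = np.zeros_like(v)
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            out += np.roll(v, (dy, dx), axis=(0, 1))
    return out / 9.0


def x_update(Y, K, v, rho):
    """Data step: solve (D^T D + rho I) x = D^T y + rho v, where D is circular
    convolution with transfer function K and Y = fft2(y). Elementwise in Fourier space."""
    X = (np.conj(K) * Y + rho * np.fft.fft2(v)) / (np.abs(K) ** 2 + rho)
    return np.real(np.fft.ifft2(X))


def pnp_admm(y, psf, rho=0.5, iters=30, denoise=box_denoise):
    """Plug-and-play ADMM for 0.5 * ||Dx - y||^2 + g(x), with prox_{g/rho} replaced by `denoise`."""
    K = psf_to_otf(psf, y.shape)
    Y = np.fft.fft2(y)
    z = y.copy()          # initialize at the blurred observation
    u = np.zeros_like(y)  # scaled dual variable
    for _ in range(iters):
        x = x_update(Y, K, z - u, rho)  # x-step: closed-form deconvolution
        z = denoise(x + u)              # z-step: denoiser in place of the proximal operator
        u = u + x - z                   # dual update
    return z
```

For example, `pnp_admm(y_obs, psf)` on a blurred 2-D NumPy image `y_obs` returns the restored image. Because the prior is only reached through `denoise`, a learned network can be dropped in without touching the data step, which is the separation plug-and-play priors rely on.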

We can also specialize the solver via bi-level optimization. For example, we can specialize the solver into a reinforcement learning (RL) solver for automatic parameter tuning.

```python
solver = compile(data_term + reg_term, method='admm')
rl_solver = specialize(solver, method='rl')
rl_solver = train(rl_solver, **training_kwargs)  # training_kwargs: user-defined training configuration
```
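The RL specialization learns how to set solver parameters instead of hand-tuning them. As a toy illustration of that idea (not ∇-Prox's RL solver), the sketch below treats the ADMM penalty $\rho$ at each iteration as an action chosen from a small discrete set, scores an episode by the negative error after a fixed number of iterations, and updates the policy with REINFORCE. To keep it short, the problem is 1-D $\ell_1$-regularized denoising, which has a closed-form minimizer, and the policy is state-independent (one categorical distribution per iteration); a practical tuner would typically condition on the current iterate. `RHO_CHOICES`, `run_admm`, and `train_rho_policy` are hypothetical names.

```python
import numpy as np

# Hypothetical discrete action set: candidate penalty parameters rho for each iteration.
RHO_CHOICES = np.array([0.1, 1.0, 10.0])


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def run_admm(y, lam, rhos):
    """Scaled-form ADMM for min_x 0.5 * ||x - y||^2 + lam * ||x||_1 with a per-iteration rho."""
    z = np.zeros_like(y)
    u = np.zeros_like(y)
    prev_rho = None
    for rho in rhos:
        if prev_rho is not None:
            u = u * prev_rho / rho             # rescale the scaled dual when rho changes
        x = (y + rho * (z - u)) / (1.0 + rho)  # prox of the data term
        z = soft_threshold(x + u, lam / rho)   # prox of the l1 regularizer
        u = u + x - z                          # dual update
        prev_rho = rho
    return z


def softmax(logits):
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def train_rho_policy(y, lam, iters=5, episodes=500, lr=0.1, seed=0):
    """REINFORCE on a state-independent policy: one categorical distribution per ADMM iteration."""
    rng = np.random.default_rng(seed)
    x_star = soft_threshold(y, lam)  # closed-form minimizer, used only to score episodes
    logits = np.zeros((iters, len(RHO_CHOICES)))
    baseline = 0.0
    for _ in range(episodes):
        probs = softmax(logits)
        actions = np.array([rng.choice(len(RHO_CHOICES), p=p) for p in probs])
        z = run_admm(y, lam, RHO_CHOICES[actions])
        reward = -np.sum((z - x_star) ** 2)        # negative error after `iters` steps
        advantage = reward - baseline
        baseline = 0.9 * baseline + 0.1 * reward   # moving-average baseline reduces variance
        grad = -probs                              # d log pi(a_k) / d logits_k = onehot(a_k) - p_k
        grad[np.arange(iters), actions] += 1.0
        logits += lr * advantage * grad            # gradient ascent on expected reward
    return softmax(logits)
```

Calling `train_rho_policy(np.array([3.0, -0.2, 1.5, 0.0]), 0.5)` returns one row of action probabilities per ADMM iteration, i.e. a learned $\rho$ schedule.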

Alternatively, we can specialize the solver into an unrolled solver for end-to-end optics optimization.

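```python
import torch
import torch.nn as nn

x = Variable()
y = Placeholder()    # observation, set at every training step
PSF = Placeholder()  # PSF produced by the DOE model, set at every training step
data_term = sum_squares(conv_doe(x, PSF, circular=True) - y)
reg_term = deep_prior(x, denoiser='ffdnet_color')
solver = compile(data_term + reg_term, method='admm')
unrolled_solver = specialize(solver, method='unroll', max_iter=10)

# jointly train the DOE (diffractive optical element) model and the solver hyperparameters;
# rhos, lams (initial values) and sigma (noise level) are user-defined
doe_model = build_doe_model()
doe_model.rhos = nn.parameter.Parameter(rhos)
doe_model.lams = nn.parameter.Parameter(lams)

def step_fn(gt):
    psf = doe_model.get_psf()
    inp = img_psf_conv(gt, psf, circular=True)  # simulate the sensor measurement
    inp = inp + torch.randn_like(inp) * sigma   # Gaussian noise on the same device/dtype as inp
    y.value = inp
    PSF.value = psf

    # gradients flow through the unrolled iterations into the DOE and the hyperparameters
    out = unrolled_solver.solve(x0=inp, rhos=doe_model.rhos, lams={reg_term: doe_model.lams})
    return gt, inp, out

train(doe_model, step_fn, dataset)
```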
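The unrolled solver works because a fixed number of solver iterations is just a chain of differentiable tensor operations. The sketch below shows that mechanism in plain PyTorch, without ∇-Prox and without optics: ten unrolled ADMM iterations for a toy sparse-denoising problem (an $\ell_1$ prior stands in for the deep denoiser), with one learnable $\rho$ per iteration and a learnable $\lambda$, trained by backpropagating a reconstruction loss. The function names, the log-space parameterization that keeps $\rho$ and $\lambda$ positive, and all numeric settings are illustrative assumptions.

```python
import torch


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1 (differentiable almost everywhere)."""
    return torch.sign(v) * torch.clamp(v.abs() - t, min=0.0)


def unrolled_admm(y, log_rhos, log_lam):
    """Unrolled ADMM for min_x 0.5 * ||x - y||^2 + lam * ||x||_1, one rho per iteration."""
    lam = log_lam.exp()
    z = torch.zeros_like(y)
    u = torch.zeros_like(y)
    prev_rho = None
    for log_rho in log_rhos:                   # iterating a tensor unbinds it; autograd still tracks it
        rho = log_rho.exp()
        if prev_rho is not None:
            u = u * prev_rho / rho             # rescale the scaled dual when rho changes
        x = (y + rho * (z - u)) / (1.0 + rho)  # data step
        z = soft_threshold(x + u, lam / rho)   # prior step
        u = u + x - z                          # dual update
        prev_rho = rho
    return z


torch.manual_seed(0)
gt = torch.zeros(64)
gt[::8] = 1.0                                      # sparse ground-truth signal
log_rhos = torch.nn.Parameter(torch.zeros(10))     # rho = 1 for each of 10 unrolled iterations
log_lam = torch.nn.Parameter(torch.tensor(-2.0))   # lam = exp(-2) at initialization
opt = torch.optim.Adam([log_rhos, log_lam], lr=0.05)

for _ in range(200):
    y = gt + 0.1 * torch.randn_like(gt)            # fresh noisy observation each step
    loss = torch.mean((unrolled_admm(y, log_rhos, log_lam) - gt) ** 2)
    opt.zero_grad()
    loss.backward()                                # backpropagate through all unrolled iterations
    opt.step()
```

Storing every iteration for the backward pass makes memory grow with the number of unrolled steps; deep equilibrium learning avoids this by differentiating implicitly at the fixed point instead.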

Want to learn more? Check out the step-by-step tutorials for the framework and its applications.

Citation

```bibtex
@article{deltaprox2023,
  title = {$\nabla$-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization},
  author = {Lai, Zeqiang and Wei, Kaixuan and Fu, Ying and H\"{a}rtel, Philipp and Heide, Felix},
  journal = {ACM Transactions on Graphics (TOG)},
  volume = {42},
  number = {4},
  articleno = {105},
  pages = {1--19},
  year = {2023},
  publisher = {Association for Computing Machinery},
  address = {New York, NY, USA},
  doi = {10.1145/3592144},
}
```

Acknowledgement

ProxImaL, ODL, DPIR, DPHSIR, DGUNet

