Comments (9)

AtlantixJJ commented on September 6, 2024

Thanks. I see.

I think this is brilliant software, even at this midway stage. Compared to the old and messy examples in torchdiffeq, it is quite clean.

One suggestion: GAN people like me would like to see a real MNIST/CIFAR/CelebA image generation demo on the home page. In fact, I am also working on an MNIST CNF.

from torchdyn.

Zymrael commented on September 6, 2024

Hi @AtlantixJJ, thanks for the question. Using those alternative layers with torchdyn is as simple as this:

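import torch
import torch.nn as nn
# DepthCat and NeuralDE are torchdyn classes; import them from your installed torchdyn version

class ConcatSquashLinear(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1, dim_out)

    def forward(self, x):
        # DepthCat(1) appended t as the last column of x
        x, t = x[:, :-1], x[:, -1:]
        # gate the layer output, then add the time-dependent bias (as in the FFJORD layer)
        return self._layer(x) * torch.sigmoid(self._hyper_gate(t)) + self._hyper_bias(t)

f = nn.Sequential(DepthCat(1),
                  ConcatSquashLinear(2, 2))

nde = NeuralDE(f)

x = torch.randn(100, 2)
out = nde(x)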

In general, after DepthCat(idx) you always have access to t / s as the last element of the input x along axis idx. You can add other similar layers from https://github.com/rtqichen/ffjord/blob/master/lib/layers/diffeq_layers/basic.py in exactly the same fashion.
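To make the "last element along axis idx" convention concrete, here is a minimal sketch in plain PyTorch. `append_time` is a hypothetical stand-in for what a depth-concatenation layer does, not torchdyn's actual DepthCat implementation; it only illustrates where t ends up for 2D and 4D inputs.

```python
import torch

def append_time(x: torch.Tensor, t: float, idx: int = 1) -> torch.Tensor:
    """Hypothetical helper: append t as one extra slice along axis `idx`."""
    shape = list(x.shape)
    shape[idx] = 1
    t_slice = torch.full(shape, t, dtype=x.dtype, device=x.device)
    return torch.cat([x, t_slice], dim=idx)

x2d = torch.randn(4, 2)          # batch of 2-D states
x4d = torch.randn(4, 5, 8, 8)    # batch of 5-channel images

y2d = append_time(x2d, t=0.5)    # shape (4, 3)
y4d = append_time(x4d, t=0.5)    # shape (4, 6, 8, 8)

# A downstream layer recovers state and time by slicing along the same axis.
state, time = y4d[:, :-1], y4d[:, -1:]
```

The same slicing pattern (`x[:, :-1]` and `x[:, -1:]`) is what ConcatSquashLinear uses above, which is why it carries over to image-shaped inputs.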

AtlantixJJ commented on September 6, 2024

Thanks for your response. I noticed that to concatenate along the depth dimension, the input x needs to be 2D. However, for some flow models, such as an MNIST image generator, the input is 4D, which makes DepthCat difficult to use. Would you consider accepting the network input as (x, t)?

In fact, I tried modifying your code and found that this may be pretty easy to do. Changing forward in both DEFunc and CNF works:

#defunc.py
        # regular forward
        else:
            if self.order > 1: x = self.horder_forward(s, x)
            else: x = self.m(x, s) # [modified]

            self.dxds = x
            return x
#normflow.py
    def forward(self, x, t):
        with torch.set_grad_enabled(True):
            x_in = torch.autograd.Variable(x[:,1:], requires_grad=True).to(x) # first dimension reserved to divergence propagation

            # the neural network will handle the data-dynamics here
            if self.order > 1: self.higher_order(x_in)
            else: x_out = self.net(x_in, t) # [modified] to take t

            trJ = self.trace_estimator(x_out, x_in, noise=self.noise)
        return torch.cat([-trJ[:, None], x_out], 1) + 0 * x # `+ 0*x` has the only purpose of connecting x[:, 0] to autograd graph

AtlantixJJ commented on September 6, 2024

Besides, the CNF implementation also assumes 2D input/output, so for now I have to flatten both. Any plan to remove this constraint?
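For reference, the flattening workaround described here can be sketched as a small wrapper in plain PyTorch. `FlatImageField` and its `img_shape` argument are hypothetical names for illustration, not part of torchdyn, and the sketch ignores time dependence.

```python
import torch
import torch.nn as nn

class FlatImageField(nn.Module):
    """Hypothetical wrapper: 2D in, 2D out, with an image-shaped network in between."""
    def __init__(self, net, img_shape):
        super().__init__()
        self.net, self.img_shape = net, tuple(img_shape)

    def forward(self, x):
        img = x.view(x.shape[0], *self.img_shape)    # (B, C*H*W) -> (B, C, H, W)
        return self.net(img).flatten(start_dim=1)    # back to (B, C*H*W)

field = FlatImageField(nn.Conv2d(1, 1, kernel_size=3, padding=1), img_shape=(1, 28, 28))
x = torch.randn(16, 28 * 28)
dx = field(x)   # shape (16, 784)
```

This keeps everything outside the wrapper 2D, at the cost of a reshape on every call.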

Zymrael commented on September 6, 2024

> Thanks for your response. I noticed that to concatenate along depth dimension, the input x needs to be 2D. However, for some flow models like a MNIST image generator, the input is 4D, introducing difficulty in using DepthCat. Do you consider accepting network input as (x, t)?

DepthCat can be used with any number of dimensions, see e.g.:

f = nn.Sequential(DepthCat(1),
                  nn.Conv2d(6, 5, kernel_size=3, padding=1))

nde = NeuralDE(f)

x = torch.randn(100, 5, 20, 20)
out = nde(x)

The same can be done for CNFs :). The reason for this implicit usage of t is to avoid having to define specific nn.Modules every time a NeuralDE is used, allowing for example direct usage of torch.nn layers.
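As an illustration of how a custom layer could still read t in the 4D case, here is a sketch of a convolutional analogue of ConcatSquashLinear. `ConcatSquashConv2d` as written here is an assumption for this thread (FFJORD ships its own variants), and it assumes t has been appended as the last channel with the same value at every pixel.

```python
import torch
import torch.nn as nn

class ConcatSquashConv2d(nn.Module):
    """Sketch: conv output gated and shifted by t, with t read from the last input channel."""
    def __init__(self, ch_in, ch_out, kernel_size=3, padding=1):
        super().__init__()
        self._layer = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, padding=padding)
        self._hyper_gate = nn.Linear(1, ch_out)
        self._hyper_bias = nn.Linear(1, ch_out, bias=False)

    def forward(self, x):
        # x: (B, ch_in + 1, H, W); the last channel is assumed to hold t everywhere
        x, t = x[:, :-1], x[:, -1:, 0, 0]                              # t: (B, 1)
        gate = torch.sigmoid(self._hyper_gate(t))[:, :, None, None]    # (B, ch_out, 1, 1)
        bias = self._hyper_bias(t)[:, :, None, None]                   # (B, ch_out, 1, 1)
        return self._layer(x) * gate + bias

layer = ConcatSquashConv2d(5, 5)
x = torch.randn(100, 5, 20, 20)
t = torch.full((100, 1, 20, 20), 0.3)     # stand-in for the channel a depth concat would add
out = layer(torch.cat([x, t], dim=1))     # shape (100, 5, 20, 20)
```

In a torchdyn model, a layer like this would presumably sit right after DepthCat(1) inside nn.Sequential, in place of the plain nn.Conv2d.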

AtlantixJJ commented on September 6, 2024

What do you mean by "the same can be done for CNFs"? Without flattening the output to 2D, I got a dimension mismatch error: trJ is 2D and x_in is 4D.

  File "./torchdyn/models/normflows.py", line 62, in forward
    return torch.cat([-trJ[:, None], x_out], 1) + 0*x # `+ 0*x` has the only purpose of connecting x[:, 0] to autograd graph
RuntimeError: Tensors must have same number of dimensions: got 4 and 2
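The mismatch can be reproduced in isolation with plain PyTorch. The sketch below uses made-up shapes and random tensors in place of the real trace and network output. It shows the failing concatenation and one way to broadcast the per-sample trace into an extra channel.

```python
import torch

B, C, H, W = 4, 5, 8, 8
trJ = torch.randn(B)               # one trace value per sample
x_out = torch.randn(B, C, H, W)    # image-shaped dynamics output

# 2D vs 4D: this is the concatenation that fails
try:
    torch.cat([-trJ[:, None], x_out], 1)
    failed = False
except RuntimeError:
    failed = True

# Broadcast the trace to B x 1 x H x W so it can occupy the extra channel
trJ_4d = trJ[:, None, None, None].expand(B, 1, H, W)
aug = torch.cat([-trJ_4d, x_out], 1)    # shape (B, C + 1, H, W)
```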

Zymrael commented on September 6, 2024

I went back to the CNF source. Here is a way to do it by changing the trace_estimator (intended) and the CNF class (not intended - expect a change in the library to make this easier soon!). Let me know if that works for you; if it does, I'll work on a few torchdyn changes to make this slightly easier than below.

class CNF(nn.Module):
    def __init__(self, net, trace_estimator=None, noise_dist=None):
        super().__init__()
        self.net = net
        self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace
        self.noise_dist, self.noise = noise_dist, None
        if self.trace_estimator in REQUIRES_NOISE:
            assert self.noise_dist is not None, 'This type of trace estimator requires specification of a noise distribution'
            
    def forward(self, x):   
        with torch.set_grad_enabled(True):
            x_in = torch.autograd.Variable(x[:,1:], requires_grad=True).to(x) # first dimension reserved to divergence propagation          
            # the neural network will handle the data-dynamics here
            x_out = self.net(x_in)
                
            trJ = self.trace_estimator(x_out, x_in, noise=self.noise)
            
            # changed for 4D. `B -> B x 1 x H x W`
            trJ = trJ[:, None, None, None].expand((x_out.shape[0], 1, *x_out.shape[2:]))
        return torch.cat([-trJ, x_out], 1) + 0*x # `+ 0*x` has the only purpose of connecting x[:, 0] to autograd graph


# changed for 4D
def hutch_trace_4D(x_out, x_in, noise=None, **kwargs):
    """Hutchinson's trace Jacobian estimator, O(1) call to autograd"""
    noise = noise.view(x_out.shape)
    jvp = torch.autograd.grad(x_out, x_in, noise, create_graph=True)[0]
    trJ = torch.einsum('bijk,bijk->b', jvp, noise)
    return trJ

f = nn.Sequential(
        DepthCat(1),
        nn.Conv2d(6, 5, kernel_size=3, padding=1)
    )

from torch.distributions import MultivariateNormal
noise_dist = MultivariateNormal(torch.zeros(5 * 20 * 20), torch.eye(5 * 20 * 20))

# use the 4D estimator defined above
cnf = CNF(f, trace_estimator=hutch_trace_4D, noise_dist=noise_dist)
nde = NeuralDE(cnf, solver='dopri5', s_span=torch.linspace(0, 1, 2), sensitivity='adjoint', atol=1e-4, rtol=1e-4)
model = nn.Sequential(Augmenter(augment_idx=1, augment_dims=1, order='first'),
                      nde)

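x = torch.randn(100, 5, 20, 20)
# hutch_trace_4D reshapes cnf.noise, so sample it before the forward pass
cnf.noise = noise_dist.sample((x.shape[0],))
out = model(x)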
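As a sanity check on the 4D estimator, the sketch below applies the same autograd-plus-einsum pattern to a map whose Jacobian trace is known in closed form. It is self-contained and does not use torchdyn. The Rademacher (+1/-1) noise is an assumption made here so that the estimate is exact for a diagonal Jacobian; with the Gaussian noise used above, the estimate matches the trace only in expectation.

```python
import torch

def hutch_trace_4d(x_out, x_in, noise):
    # vector-Jacobian product eps^T J, then contract with eps per sample
    vjp = torch.autograd.grad(x_out, x_in, noise, create_graph=True)[0]
    return torch.einsum('bijk,bijk->b', vjp, noise)

B, C, H, W = 3, 2, 4, 4
x_in = torch.randn(B, C, H, W, requires_grad=True)
x_out = 3.0 * x_in    # Jacobian is 3 * I, so the exact trace is 3 * C * H * W

# Rademacher noise: entries are +1 or -1 (assumption for an exact check)
noise = torch.randint(0, 2, (B, C, H, W)).float() * 2 - 1
trJ = hutch_trace_4d(x_out, x_in, noise)    # shape (B,)
exact = 3.0 * C * H * W
```

Each entry of `trJ` should equal `exact` here, because eps^T (3 I) eps = 3 * sum(eps^2) and every squared Rademacher entry is 1.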

Zymrael commented on September 6, 2024

Thank you for the kind words; we're planning on adding more image examples for CNFs in the next release :)

Zymrael commented on September 6, 2024

Closing as the issue has been solved.
