Comments (9)

gaurvigoyal commented on May 27, 2024

Right. I don't remember closing this; it must have been a mistake (I clicked "Close with comment" by accident, I guess). Thanks! I hadn't seen the PR! Clever things indeed :) I'll incorporate that.
And I'll try these and get back to you! Thanks!

gaurvigoyal commented on May 27, 2024

@Jegp

gaurvigoyal commented on May 27, 2024

Right, so DataParallel assumes that the first dimension is the batch dimension, but with Norse the first dimension is the timestep. This causes the samples to be scattered across the GPUs incorrectly. I have now tried DataParallel(dim=1), and that gives a very different error:

torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
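
For reference, the wrapping I tried looks roughly like the sketch below (a placeholder linear module stands in for the actual model, and the shapes are only illustrative):

import torch
import torch.nn as nn

# Placeholder module standing in for the actual network
model = nn.Linear(192 * 192, 10).to("cuda:0")

# Norse tensors are time-first: (timesteps, batch, ...), so dim=1 tells
# DataParallel to scatter/gather along the batch dimension instead of time
par = nn.DataParallel(model, dim=1)

x = torch.randn(40, 8, 192 * 192, device="cuda:0")  # (T=40, B=8, features)
out = par(x)  # each replica sees all 40 timesteps and a slice of the batch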

Jegp commented on May 27, 2024

Thanks for reporting this @gaurvigoyal! OK, so are you saying that the parallelization is not an issue? Because I can actually imagine it breaking down if the tensor is sent to multiple devices after initialization.

Regarding the tracing, is it correct that this only happens during backprop? One potential problem could be that the autodiff graph isn't properly "cleared" between timesteps. Could you share the code you use to optimize the model? Happy to take it offline as well.
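
To illustrate what I mean by "clearing": a minimal sketch (a single LIBoxCell and a dummy loss, not your model) that detaches the neuron state between optimization steps, so the backward pass never reaches into an earlier step's graph:

import torch
import norse.torch as norse

lin = torch.nn.Linear(2, 2)
cell = norse.LIBoxCell()                     # leaky integrator, keeps a membrane state
optimizer = torch.optim.Adam(lin.parameters())

state = None
for x in torch.randn(10, 1, 2):              # 10 timesteps, batch of 1, 2 features
    v, state = cell(lin(x), state)
    loss = v.pow(2).sum()                     # placeholder loss, just for illustration
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Detach the membrane state so the next step starts a fresh autodiff graph
    state = state._replace(v=state.v.detach())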

gaurvigoyal commented on May 27, 2024

Hey @jens, thanks for responding so quickly. Have you tried DataParallel training or any other multi-GPU setup with Norse yet? With dim=1, PyTorch now scatters the data along the batch dimension, so that problem seems to be solved for the data, but I guess it doesn't work for the model, which leads to the issue with the trace.

I am working on my movenet.pytorch repository in the spiking-data-loader branch (here).

The resulting error log contains a graph diff for the trace that is too long to paste, but here is the rest of the traceback:

Traceback (most recent call last):
  File "train.py", line 50, in <module>
    main(cfg)
  File "train.py", line 40, in main
    run_task = Task(cfg, model)
  File "/projects/edpr_hpe/movenet.pytorch/lib/task/task.py", line 59, in __init__
    self.tb.add_graph(self.model, torch.randn(40, 1, 1, 192, 192).to(self.device))
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/tensorboard/writer.py", line 736, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, use_strict_trace))
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/tensorboard/_pytorch_graph.py", line 289, in graph
    trace = torch.jit.trace(model, args, strict=use_strict_trace)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 741, in trace
    return trace_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 983, in trace_module
    _check_trace(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 526, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!

My first guess is that maybe the data and the model need to be parallelized separately? I'd hoped PyTorch and Norse would handle this together internally.

Jegp commented on May 27, 2024

Happy to help! Did you see that I updated the receptive fields in a PR? I fixed some distributions and added some cleverness that should give you better performance in both runtime and accuracy.

I can't seem to reproduce your error, I'm afraid. I tried a small toy example pasted below, and it worked for me. Could I ask you to try it out and see if it works for you?

import torch
import norse.torch as norse

# A stateful Norse model: a leaky integrator followed by a linear readout
model = norse.SequentialState(norse.LIBoxCell(), torch.nn.Linear(1, 10)).to("cuda:0")
# Scatter/gather along dim=1 (batch), since Norse tensors are time-first
par = torch.nn.DataParallel(model, dim=1)
par(torch.empty(100, 100, 100, 1).to("cuda:1"))

A separate point: I have had much better experience using PyTorch Lightning for the parallelization, because it takes care of mapping both the data and the model to the various devices. I pasted a small (pretty dumb) example below. The "parallelization magic" comes from pl.Trainer(..., strategy=pl.strategies.DDPStrategy()) (Distributed Data Parallel), which is a multi-processing strategy with a single device per process and could maybe also work in your case. Would that be an option? Although I realize it may require a bit of rewriting of your code :P

import os
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

import norse.torch as norse

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = norse.SequentialState(norse.LIBoxCell(), nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        # Run the stateful decoder for a couple of timesteps, threading the state
        s = None
        for i in range(2):
            x_hat, s = self.decoder(z, s)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1, strategy=pl.strategies.DDPStrategy())
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

Jegp commented on May 27, 2024

Any news @gaurvigoyal? Anything I can do?

gaurvigoyal commented on May 27, 2024

Hey @Jegp, I rewrote the code in Lightning, but distributing it over multiple GPUs still ran into errors. At some point other projects took higher priority and this went on the back burner; at this point I can't dedicate as much time to it as it needs. Did you already publish the paper on these integrators?

Jegp commented on May 27, 2024

> Hey @Jegp, I rewrote the code in Lightning, but distributing it over multiple GPUs still ran into errors. At some point other projects took higher priority and this went on the back burner; at this point I can't dedicate as much time to it as it needs. Did you already publish the paper on these integrators?

Hey @gaurvigoyal, we just put up a preprint on exactly this: https://arxiv.org/abs/2405.00318
I know it's like a couple of months late, but I'm happy to pick up the discussion :-)
