I can work around it by passing a clone of x to einsum: `torch.einsum("ij,jk->ik", (x.clone(), torch.randn(3, 3)))`
The problem comes from the torch.einsum call in s2_rft. It can be reproduced with the following code:
```python
import torch

x = torch.randn(3, 3, requires_grad=True)
z1 = torch.einsum("ij,jk->ik", (x, torch.randn(3, 3)))
z2 = torch.einsum("ij,jk->ik", (x, torch.randn(3, 3)))
z1.sum().backward()
```
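Applying the x.clone() workaround mentioned above to this reproduction gives the sketch below. It assumes a recent PyTorch, where einsum still accepts its operands as a single tuple; the names y1, y2 and expected are only for illustration, and the gradient check is an addition, not part of the original report. Because clone() is differentiable, gradients still reach x.

```python
import torch

torch.manual_seed(0)

x = torch.randn(3, 3, requires_grad=True)
y1 = torch.randn(3, 3)
y2 = torch.randn(3, 3)

# Workaround from the thread: give each einsum call its own copy of x.
# clone() is recorded by autograd, so gradients still flow back to x.
z1 = torch.einsum("ij,jk->ik", (x.clone(), y1))
z2 = torch.einsum("ij,jk->ik", (x.clone(), y2))
z1.sum().backward()

# d(sum(x @ y1)) / dx[i, j] = sum_k y1[j, k], the same for every row i
expected = y1.sum(dim=1).expand(3, 3)
print(torch.allclose(x.grad, expected))
```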
Hi,
No, I have never observed this error; we always used a single model.
Did you try simplifying the model to see whether the error still occurs, for instance using only s2conv or only so3conv?
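One way to follow this suggestion is a small helper that runs a single sub-module twice on the same input and backpropagates through both outputs, which is the pattern that fails later in this thread. The helper name survives_double_forward is hypothetical, and the nn.Linear blocks are only stand-ins; in practice you would pass an S2Convolution or SO3Convolution instance with a suitably shaped input.

```python
import torch
import torch.nn as nn


def survives_double_forward(module: nn.Module, x: torch.Tensor) -> bool:
    """Hypothetical helper: run `module` twice on the same input, sum both
    outputs and backpropagate. Returns False if autograd raises."""
    module.zero_grad()
    try:
        out = module(x) + module(x)
        out.mean().backward()
    except RuntimeError as err:
        print(f"{module.__class__.__name__}: {err}")
        return False
    return True


# Check each block in isolation, one at a time.
x = torch.randn(4, 8)
for block in (nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 8), nn.ReLU())):
    print(block.__class__.__name__, survives_double_forward(block, x))
```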
Yes, I get the same error using only s2conv, with the following code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from s2cnn import s2_near_identity_grid, S2Convolution


def S2conv2d(in_c, out_c, in_b, out_b):
    # kernel sampling grid near the identity, with 2 * in_b points in alpha
    grid = s2_near_identity_grid(n_alpha=2 * in_b)
    return S2Convolution(in_c, out_c, in_b, out_b, grid)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 28x28 MNIST images are used as S^2 signals with bandwidth 14 (2 * 14 = 28)
        self.conv1 = S2conv2d(1, 5, 14, 7)

    def forward(self, x):
        return self.conv1(x)


def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    WORKERS = 1

    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = DataLoader(MNIST('./data', train=True, transform=img_transform, download=True),
                              batch_size=256, num_workers=WORKERS, pin_memory=True, shuffle=True)

    model = Model().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,
                                momentum=0.9, weight_decay=5e-4)

    def train():
        model.train()
        for batch_idx, (image, target) in enumerate(train_loader):
            image = image.to(device)
            optimizer.zero_grad()
            # "multi model": run the same model twice on the same batch
            output1 = model(image)
            output2 = model(image)
            loss = (output1 + output2).mean()
            loss.backward()  # backpropagate through both forward passes
            optimizer.step()
            print("OK")
            break  # one batch is enough to reproduce the problem

    train()


if __name__ == "__main__":
    main()
```
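A quick way to confirm that a fix really handles the two-forward pattern in the script above (output1 + output2) is to compare gradients: two identical passes should give exactly twice the gradient of one. The sketch below does this for a small hypothetical EinsumLayer standing in for S2Convolution, with the x.clone() workaround from earlier in the thread applied inside forward; it is not s2cnn's actual implementation, and weight_grad is likewise a hypothetical helper.

```python
import torch
import torch.nn as nn


class EinsumLayer(nn.Module):
    """Hypothetical stand-in for an einsum-based layer (not s2cnn code)."""

    def __init__(self, n: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n, n))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clone the input before einsum, as in the workaround from the thread.
        return torch.einsum("ij,jk->ik", (x.clone(), self.weight))


def weight_grad(layer: nn.Module, x: torch.Tensor, passes: int) -> torch.Tensor:
    """Sum `passes` forward outputs, take the mean, return d(loss)/d(weight)."""
    layer.zero_grad()
    loss = sum(layer(x) for _ in range(passes)).mean()
    loss.backward()
    return layer.weight.grad.clone()


torch.manual_seed(0)
layer = EinsumLayer(3)
image = torch.randn(5, 3)
g1 = weight_grad(layer, image, passes=1)
g2 = weight_grad(layer, image, passes=2)
# Two identical forward passes should simply double the gradient.
print(torch.allclose(g2, 2 * g1))
```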
Thank you so much!