Comments (4)
It sounds good. I can't say like that what could be wrong ... maybe a transpose of conjugate missing??
from s2cnn.
Hi Mario! This is quite a fast reply! Thank you :)
Here is my code:
from numpy.lib.function_base import _percentile_dispatcher
import torch
import math
import numpy as np
from s2cnn.s2_ft import s2_rft
from s2cnn.soft.so3_fft import SO3_ifft_real
from s2cnn.soft.s2_fft import S2_fft_real
from utils.utils import fftshift3d
from data.simulation_3d import get_simulation_3d
from log_sphere.log_sphere import sphere_transformer
import numpy as np
from s2cnn import s2_mm
def unravel_indices(indices,shape):
"""Converts flat indices into unraveled coordinates in a target shape.
Args:
indices: A tensor of (flat) indices, (*, N).
shape: The targeted shape, (D,).
Returns:
The unraveled coordinates, (*, N, D).
"""
coord = []
for dim in reversed(shape):
coord.append(indices % dim)
indices = indices // dim
coord = torch.stack(coord[::-1], dim=-1)
return coord
template, source, rotation_gt, translation_gt, scale_gt = get_simulation_3d(50,50,50,1)
template = torch.tensor(template[0]).float().permute(3,0,1,2)
source = torch.tensor(source[0]).float().permute(3,0,1,2)
device = torch.device("cpu")
# create two tensors on a sphere with the shape of [b,feature_in,beta,alpha]
bw_in = 25
bw_out = 25
print("grid", source.shape)
sphere1 = sphere_transformer(source.unsqueeze(-1), (50, 50, 50), device)[0].squeeze(-1)[...,20:40].sum(-1).float()
sphere2 = sphere_transformer(template.unsqueeze(-1), (50, 50, 50), device)[0].squeeze(-1)[...,20:40].sum(-1).float()
# sphere with the size of [b, theta, phi]
sphere_1_fft = S2_fft_real.apply(sphere1,bw_out)
sphere_2_fft = S2_fft_real.apply(sphere2,bw_out)
z = s2_mm(sphere_1_fft, sphere_2_fft).unsqueeze(-2) # [l * m * n, batch, feature_out, complex]
z = SO3_ifft_real.apply(z) # [batch, feature_out, beta, alpha, gamma]
z = z.squeeze(1)
# z = fftshift3d(z).unsqueeze(1)
cor_argmax = torch.argmax(z.view(z.size(0), z.size(1), -1), -1)
index = unravel_indices(cor_argmax, (z.size(2), z.size(3), z.size(4)))
print(index)
and the modified s2_mm is something like this:
def s2_mm(x, y):
'''
:param x: [l * m, batch, complex]
:param y: [l * m, batch, complex]
:return: [l * m * n, batch, complex]
'''
from s2cnn.utils.complex import complex_mm
# assert y.size(3) == 2
# assert x.size(3) == 2
nbatch = x.size(1)
# nfeature_in = x.size(2)
# nfeature_out = y.size(2)
# assert y.size(1) == nfeature_in
nspec = x.size(0)
# assert y.size(0) == nspec
# if x.is_cuda:
# return _cuda_S2_mm.apply(x, y)
nl = round(nspec**0.5)
batch_list = []
for b in range(nbatch):
x_batch = x[:, b, ...].unsqueeze(1) # [l * m, 1, complex]
y_batch = y[:, b, ...].unsqueeze(1) # [l * m, 1, complex]
Fz_list = []
begin = 0
for l in range(nl):
L = 2 * l + 1
size = L
Fx = x_batch[begin:begin+size] # [m, 1, complex]
Fy = y_batch[begin:begin+size] # [m, 1, complex]
Fy = Fy.transpose(0, 1) # [1, m, complex]
Fy = Fy.contiguous()
Fz = complex_mm(Fx, Fy, conj_y=True) # [m, m, complex]
Fz = Fz.view(L * L, 2) # [m * m, complex]
# print('fffff', Fz.shape)
Fz_list.append(Fz)
begin += size
z_batch = torch.cat(Fz_list, 0) # [l * m * m, complex]
batch_list.append(z_batch)
z = torch.stack(batch_list, dim=1) # [l * m * m, batch, complex]
print('shape', z.shape)
return z
When I tried to make Sphere1 = Sphere2, the output of S2fft is consistent while the argmax coordinate keeps randomizing.
Hope this might gives you a hint on my mistakes?
from s2cnn.
Hi Mario!
So after playing with all these spherical things around for a while, I finally managed to calculate the zyz transformation between two spheres now. However, since the torch has upgraded to 1.9 and the fft APIs have changed, I tried to upgrade all these ffts in the s2_fft and the so3_fft. I believe that the last merge of these repos is somehow wrong and is not consistent with what the code
used to be, and here is my version. Could you please check the forward and backward of the S2_fft and SO3_fft below, if it is correct, I would love to request a pull to update the code.
For S2FFT
def s2_fft(x, for_grad=False, b_out=None):
'''
:param x: [..., beta, alpha, complex]
:return: [l * m, ..., complex]
'''
assert x.size(-1) == 2
b_in = x.size(-2) // 2
assert x.size(-2) == 2 * b_in
assert x.size(-3) == 2 * b_in
if b_out is None:
b_out = b_in
assert b_out <= b_in
batch_size = x.size()[:-3]
x = x.view(-1, 2 * b_in, 2 * b_in, 2) # [batch, beta, alpha, complex]
'''
:param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
:return: [l * m, batch, complex] (b_out**2, nbatch, 2)
'''
nspec = b_out ** 2
nbatch = x.size(0)
wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)
wigner = wigner.view(2 * b_in, -1) # [beta, l * m] (2 * b_in, nspec)
x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x))) # [batch, beta, m, complex]
# x = torch.fft.fft(x,1) # [batch, beta, m, complex]
output = x.new_empty((nspec, nbatch, 2))
cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=x.device.index)
stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
cuda_kernel(block=(1024, 1, 1),
grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
stream=stream)
# for l in range(b_out):
# s = slice(l ** 2, l ** 2 + 2 * l + 1)
# xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
# output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))
output = output.view(-1, *batch_size, 2) # [l * m, ..., complex] (nspec, ..., 2)
return output
def s2_ifft(x, for_grad=False, b_out=None):
'''
:param x: [l * m, ..., complex]
'''
assert x.size(-1) == 2
nspec = x.size(0)
b_in = round(nspec ** 0.5)
assert nspec == b_in ** 2
if b_out is None:
b_out = b_in
assert b_out >= b_in
batch_size = x.size()[1:-1]
x = x.view(nspec, -1, 2) # [l * m, batch, complex] (nspec, nbatch, 2)
'''
:param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
:return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2)
'''
nbatch = x.size(1)
wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
wigner = wigner.view(2 * b_out, -1) # [beta, l * m] (2 * b_out, nspec)
cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
cuda_kernel(block=(1024, 1, 1),
grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
stream=stream)
# [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
# output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
# for l in range(b_in):
# s = slice(l ** 2, l ** 2 + 2 * l + 1)
# out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
# output[:, :, :l + 1] += out[:, :, -l - 1:]
# if l > 0:
# output[:, :, -l:] += out[:, :, :l]
output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * output.size(-2) # [batch, beta, alpha, complex]
output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
return output
class S2_fft_real(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b_out=None): # pylint: disable=W
from s2cnn.utils.complex import as_complex
ctx.b_out = b_out
ctx.b_in = x.size(-1) // 2
return s2_fft(as_complex(x), b_out=ctx.b_out)
@staticmethod
def backward(ctx, grad_output): # pylint: disable=W
return s2_ifft(grad_output, for_grad=True, b_out=ctx.b_in)[..., 0], None
For SO3_FFT
def so3_ifft(x, for_grad=False, b_out=None):
'''
:param x: [l * m * n, ..., complex]
'''
assert x.size(-1) == 2
nspec = x.size(0)
b_in = round((3 / 4 * nspec) ** (1 / 3))
assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
if b_out is None:
b_out = b_in
batch_size = x.size()[1:-1]
x = x.view(nspec, -1, 2) # [l * m * n, batch, complex] (nspec, nbatch, 2)
'''
:param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
:return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
'''
nbatch = x.size(1)
wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device) # [beta, l * m * n] (2 * b_out, nspec)
output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
# if x.is_cuda and x.dtype == torch.float32:
cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
cuda_kernel(x, wigner, output) # [batch, beta, m, n, complex]
# else:
# output.fill_(0)
# for l in range(min(b_in, b_out)):
# s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
# out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
# l1 = min(l, b_out - 1) # if b_out < b_in
# output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
# if l > 0:
# output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
# output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
# output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]
output = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(output), dim=[2,3])) * output.size(-2) ** 2 # [batch, beta, alpha, gamma, complex]
# output = torch.view_as_real((torch.fft.ifft(output, 2) * output.size(-2) ** 2)[...,0]) # [batch, beta, alpha, gamma, complex]
return output
def so3_fft(x, for_grad=False, b_out=None):
'''
:param x: [..., beta, alpha, gamma, complex]
:return: [l * m * n, ..., complex]
'''
assert x.size(-1) == 2, x.size()
b_in = x.size(-2) // 2
assert x.size(-2) == 2 * b_in
assert x.size(-3) == 2 * b_in
assert x.size(-4) == 2 * b_in
if b_out is None:
b_out = b_in
batch_size = x.size()[:-4]
# x = x.view(-1, 2 * b_in, 2 * b_in, 2 * b_in, 2) # [batch, beta, alpha, gamma, complex]
'''
:param x: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_in, 2 b_in, 2 b_in, 2)
:return: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2)
'''
nspec = b_out * (4 * b_out ** 2 - 1) // 3
nbatch = x.size(0)
wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device) # [beta, l * m * n]
# x = torch.fft(x, 2) # [batch, beta, m, n, complex]
x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x),dim=[2,3]))
output = x.new_empty((nspec, nbatch, 2))
cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=False, device=x.device.index)
cuda_kernel(x, wigner, output) # [l * m * n, batch, complex]
output = output.view(-1, *batch_size, 2) # [l * m * n, ..., complex]
return output
class SO3_ifft_real(torch.autograd.Function):
@staticmethod
def forward(ctx, x, b_out=None): # pylint: disable=W
nspec = x.size(0)
ctx.b_out = b_out
ctx.b_in = round((3 / 4 * nspec) ** (1 / 3))
return so3_ifft(x, b_out=ctx.b_out)
@staticmethod
def backward(ctx, grad_output): # pylint: disable=W
output = so3_fft(grad_output, for_grad=True, b_out=ctx.b_in).unsqueeze(-2)
return output, None
from s2cnn.
Very nice! This will help some people. Yes please make a PR! I will revert the last merge such that I will be able to compare the original code with your in the PR
from s2cnn.
Related Issues (20)
- shrec17 dataset HOT 15
- Cannot run the code in Mac, as there is no CUDA
- some question when I run gendata.py in /examples/mnist folder HOT 4
- query about feature maps HOT 4
- Equivariance error issue HOT 6
- About the signal transform
- SO3_fft_real and SO3_ifft_real do not seem to be inverses of each other? HOT 12
- Some questions about the rotation of kernels HOT 1
- How to choose different grid HOT 2
- Visualizations
- Questions about the computations HOT 2
- Running MNIST Example Problems HOT 3
- Error with einsum in Equivariance plot HOT 3
- Error in so3_rotation (Jd matrix size) with custom data
- No module named 'lie_learn.representations.SO3.irrep_bases' HOT 4
- Error running example HOT 4
- Theoretical Problems about SO(3) Fourier Transformation HOT 2
- s2cnn
- How can I specify GPU to run s2cnn?
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from s2cnn.