jonkhler avatar jonkhler commented on July 20, 2024
Correlation Between Spheres

Comments (4)

mariogeiger avatar mariogeiger commented on July 20, 2024 1

It sounds good. I can't tell offhand what could be wrong... maybe a transpose or a conjugate is missing?

jessychen1016 avatar jessychen1016 commented on July 20, 2024

Hi Mario! This is quite a fast reply! Thank you :)
Here is my code:

from numpy.lib.function_base import _percentile_dispatcher
import torch
import math
import numpy as np
from s2cnn.s2_ft import s2_rft
from s2cnn.soft.so3_fft import SO3_ifft_real
from s2cnn.soft.s2_fft import S2_fft_real
from utils.utils import fftshift3d
from data.simulation_3d import get_simulation_3d
from log_sphere.log_sphere import sphere_transformer
import numpy as np

from s2cnn import s2_mm

def unravel_indices(indices, shape):
    """Convert flat indices into unraveled coordinates in a target shape.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """
    coord = []

    # Peel off one dimension at a time, starting from the last entry of
    # ``shape``: the remainder is the coordinate along that dimension and the
    # quotient is the flat index into the remaining leading dimensions.
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim

    # Coordinates were collected last-dimension-first; reverse them so the
    # trailing axis of the result follows the order of ``shape``.
    return torch.stack(coord[::-1], dim=-1)

# Driver script: build two spherical signals from a simulated 3-D pair,
# correlate their spectra and read off the location of the correlation peak.
template, source, rotation_gt, translation_gt, scale_gt =  get_simulation_3d(50,50,50,1)
# Keep the first sample and move the last axis to the front.
# NOTE(review): presumably this turns (D, H, W, C) into (C, D, H, W) - confirm against get_simulation_3d.
template = torch.tensor(template[0]).float().permute(3,0,1,2)
source = torch.tensor(source[0]).float().permute(3,0,1,2)
# NOTE(review): everything below stays on the CPU; confirm the FFT helpers used here accept CPU tensors.
device = torch.device("cpu")

# create two tensors on a sphere with the shape of [b,feature_in,beta,alpha]
bw_in = 25  # NOTE(review): not used anywhere in this script
bw_out = 25
print("grid", source.shape)
# Resample each volume with sphere_transformer, keep indices 20..39 of the last axis and sum over them.
# NOTE(review): the last axis is presumably the radial coordinate - confirm against sphere_transformer.
sphere1 = sphere_transformer(source.unsqueeze(-1), (50, 50, 50), device)[0].squeeze(-1)[...,20:40].sum(-1).float()
sphere2 = sphere_transformer(template.unsqueeze(-1), (50, 50, 50), device)[0].squeeze(-1)[...,20:40].sum(-1).float()

# sphere with the size of [b, theta, phi]
sphere_1_fft = S2_fft_real.apply(sphere1,bw_out)
sphere_2_fft = S2_fft_real.apply(sphere2,bw_out)

# Correlate the two spectra, then add a singleton feature axis before the inverse SO(3) transform.
z = s2_mm(sphere_1_fft, sphere_2_fft).unsqueeze(-2)  # [l * m * n, batch, feature_out, complex]
z = SO3_ifft_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]

z = z.squeeze(1)
# z = fftshift3d(z).unsqueeze(1)
# Flatten everything after the first two axes and take the argmax of the flattened values.
cor_argmax = torch.argmax(z.view(z.size(0), z.size(1), -1), -1)

# NOTE(review): z.size(4) requires z to still have at least 5 axes after the squeeze above,
# and the argmax is taken over axes 2.. only - confirm the intended layout of z.
index = unravel_indices(cor_argmax, (z.size(2), z.size(3), z.size(4)))

and the modified s2_mm is something like this:

def s2_mm(x, y):
    """Correlate two batches of S2 spectra, degree by degree.

    For every batch entry and every degree ``l`` the ``2l + 1`` coefficients
    of ``x`` (as a column) are multiplied with the ``2l + 1`` coefficients of
    ``y`` (as a row, with ``conj_y=True``), giving a ``(2l + 1) x (2l + 1)``
    block. The flattened blocks of all degrees are concatenated.

    :param x: [l * m,     batch, complex]
    :param y: [l * m,     batch, complex]
    :return:  [l * m * n, batch, complex]
    """
    from s2cnn.utils.complex import complex_mm

    nspec = x.size(0)
    nbatch = x.size(1)
    # nspec = sum over l of (2l + 1) = nl ** 2
    nl = round(nspec ** 0.5)

    batch_list = []
    for b in range(nbatch):
        x_batch = x[:, b, ...].unsqueeze(1)  # [l * m, 1, complex]
        y_batch = y[:, b, ...].unsqueeze(1)  # [l * m, 1, complex]

        Fz_list = []
        begin = 0
        for degree in range(nl):
            size = 2 * degree + 1

            Fx = x_batch[begin:begin + size]  # [m, 1, complex]
            Fy = y_batch[begin:begin + size]  # [m, 1, complex]
            Fy = Fy.transpose(0, 1).contiguous()  # [1, m, complex]

            # NOTE(review): assumes complex_mm(..., conj_y=True) multiplies
            # Fx with the complex conjugate of Fy - confirm in s2cnn.
            Fz = complex_mm(Fx, Fy, conj_y=True)  # [m, m, complex]
            Fz_list.append(Fz.view(size * size, 2))  # [m * m, complex]

            begin += size

        batch_list.append(torch.cat(Fz_list, 0))  # [l * m * m, complex]

    z = torch.stack(batch_list, dim=1)  # [l * m * m, batch, complex]
    print('shape', z.shape)
    return z

When I set Sphere1 = Sphere2, the output of the S2 FFT is consistent, but the argmax coordinate keeps changing randomly.

I hope this gives you a hint about my mistake.

jessychen1016 avatar jessychen1016 commented on July 20, 2024

Hi Mario!
After playing around with all these spherical things for a while, I have finally managed to calculate the zyz transformation between two spheres. However, since torch has been upgraded to 1.9 and the FFT APIs have changed, I tried to update all the FFT calls in s2_fft and so3_fft. I believe the last merge in this repo is somehow wrong and is not consistent with what the code
used to be, so here is my version. Could you please check the forward and backward passes of S2_fft and SO3_fft below? If they are correct, I would love to open a pull request to update the code.


def s2_fft(x, for_grad=False, b_out=None):
    """Forward transform of a complex signal sampled on a 2b x 2b spherical grid.

    The transform is an FFT along ``alpha`` followed by a contraction of the
    ``beta`` axis with the table returned by ``_setup_wigner``.

    :param x: [..., beta, alpha, complex]
    :param for_grad: when True the wigner table is requested with
        ``weighted=False`` (used when this runs as a backward pass)
    :param b_out: output bandwidth, defaults to the input bandwidth; must not
        exceed it
    :return:  [l * m, ..., complex]
    """
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in
    batch_size = x.size()[:-3]

    # Collapse all leading axes into one batch axis.
    x = x.view(-1, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)

    nspec = b_out ** 2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    # torch.fft.fft works on complex tensors and transforms the last complex
    # axis (alpha); convert back to the trailing real/imag layout afterwards.
    x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x)))  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        # NOTE(review): cuda_utils is expected to be available at module scope
        # (s2cnn.utils.cuda upstream) - confirm it is imported in this file.
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
                    stream=stream)
        # [l * m, batch, complex]
    else:
        # Pure-torch path for tensors the CUDA kernel cannot take.
        for l in range(b_out):
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            # Gather the frequencies m = -l..l: negative ones sit at the end
            # of the FFT output, non-negative ones at the start.
            xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", wigner[:, s], xx)

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output

def s2_ifft(x, for_grad=False, b_out=None):
    """Inverse transform from S2 spectral coefficients to a 2b x 2b grid.

    The coefficients are contracted with the table returned by
    ``_setup_wigner`` to obtain the signal per ``(beta, m)``, followed by an
    unnormalised inverse FFT along the last grid axis.

    :param x: [l * m, ..., complex]
    :param for_grad: forwarded to ``_setup_wigner`` as ``weighted`` (used when
        this runs as a backward pass)
    :param b_out: output bandwidth, defaults to the input bandwidth; must not
        be smaller than it
    :return: [..., beta, alpha, complex]
    """
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    # Collapse all middle axes into one batch axis.
    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)

    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        # NOTE(review): cuda_utils is expected to be available at module scope
        # (s2cnn.utils.cuda upstream) - confirm it is imported in this file.
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        # Pure-torch path for tensors the CUDA kernel cannot take.
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", x[s], wigner[:, s])
            # Scatter m = 0..l to the start of the frequency axis and
            # m = -l..-1 to its end (FFT ordering).
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # torch.fft.ifft divides by the transform length; multiplying by that
    # length (2 * b_out) turns it back into an unnormalised inverse FFT.
    output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * output.size(-2)  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output

class S2_fft_real(torch.autograd.Function):
    """Autograd wrapper around ``s2_fft`` for real-valued spherical signals.

    ``forward`` and ``backward`` are static methods, as required by
    ``torch.autograd.Function.apply``.
    """

    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        """Transform a real signal to its S2 spectrum.

        :param x: real signal whose last axis has length ``2 * b_in``
            (presumably [..., beta, alpha] - confirm against ``as_complex``)
        :param b_out: output bandwidth forwarded to ``s2_fft``
        :return: the result of ``s2_fft`` on the complex version of ``x``
        """
        from s2cnn.utils.complex import as_complex
        ctx.b_out = b_out
        # Remember the input bandwidth so backward can rebuild a grid of the same size.
        ctx.b_in = x.size(-1) // 2
        return s2_fft(as_complex(x), b_out=ctx.b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        """Propagate the gradient through ``s2_ifft``.

        Only index 0 of the trailing complex axis is kept (the real
        component, matching the real-valued input). ``b_out`` gets no gradient.
        """
        return s2_ifft(grad_output, for_grad=True, b_out=ctx.b_in)[..., 0], None


def so3_ifft(x, for_grad=False, b_out=None):
    """Inverse transform from SO(3) spectral coefficients to a (2b)^3 grid.

    The coefficients are contracted with the table returned by
    ``_setup_wigner`` to obtain the signal per ``(beta, m, n)``, followed by
    an unnormalised 2-D inverse FFT over the two frequency axes.

    NOTE: all axes between the first and the last of ``x`` are flattened into
    a single batch axis and are not restored on return.

    :param x: [l * m * n, ..., complex]
    :param for_grad: forwarded to ``_setup_wigner`` as ``weighted`` (used when
        this runs as a backward pass)
    :param b_out: output bandwidth, defaults to the input bandwidth
    :return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
    """
    assert x.size(-1) == 2
    nspec = x.size(0)
    # nspec = b (4 b^2 - 1) / 3, so b is roughly (3 / 4 * nspec) ** (1 / 3).
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        # Pure-torch path for tensors the CUDA kernel cannot take.
        output.fill_(0)
        for l in range(min(b_in, b_out)):
            s = slice(l * (4 * l ** 2 - 1) // 3, l * (4 * l ** 2 - 1) // 3 + (2 * l + 1) ** 2)
            out = torch.einsum(
                "mnzc,bmn->zbmnc",
                x[s].view(2 * l + 1, 2 * l + 1, -1, 2),
                wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1),
            )
            l1 = min(l, b_out - 1)  # if b_out < b_in
            # Scatter the (m, n) block into FFT ordering: non-negative
            # frequencies at the start of each axis, negative ones at the end.
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]

    # ifftn divides by the product of the two transform lengths; multiplying
    # by (2 * b_out) ** 2 turns it back into an unnormalised inverse FFT.
    output = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(output), dim=[2, 3])) * output.size(-2) ** 2  # [batch, beta, alpha, gamma, complex]

    return output

def so3_fft(x, for_grad=False, b_out=None):
    """Forward transform of a complex signal sampled on a (2b)^3 SO(3) grid.

    The transform is a 2-D FFT over ``alpha`` and ``gamma`` followed by the
    ``beta`` contraction performed by the CUDA kernel.

    NOTE(review): only the CUDA kernel path is present in this function; CPU
    tensors are not handled here.

    :param x: [..., beta, alpha, gamma, complex]
    :param for_grad: when True the wigner table is requested with
        ``weighted=False`` (used when this runs as a backward pass)
    :param b_out: output bandwidth, defaults to the input bandwidth
    :return: [l * m * n, ..., complex]
    """
    assert x.size(-1) == 2, x.size()
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    assert x.size(-4) == 2 * b_in
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[:-4]

    # Collapse all leading axes into one batch axis so that ``nbatch`` and the
    # FFT axes below are correct for any number of leading axes. The copy via
    # contiguous() keeps view()/view_as_complex() valid for strided inputs
    # such as incoming gradients.
    x = x.contiguous().view(-1, 2 * b_in, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, gamma, complex]

    nspec = b_out * (4 * b_out ** 2 - 1) // 3
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)  # [beta, l * m * n]

    # Axes 2 and 3 of the complex view are alpha and gamma.
    x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=[2, 3]))  # [batch, beta, m, n, complex]

    output = x.new_empty((nspec, nbatch, 2))
    cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=False, device=x.device.index)
    cuda_kernel(x, wigner, output)  # [l * m * n, batch, complex]

    output = output.view(-1, *batch_size, 2)  # [l * m * n, ..., complex]
    return output

class SO3_ifft_real(torch.autograd.Function):
    """Autograd wrapper around ``so3_ifft`` / ``so3_fft``.

    ``forward`` and ``backward`` are static methods, as required by
    ``torch.autograd.Function.apply``.
    """

    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        """Apply ``so3_ifft`` to the spectral coefficients ``x``.

        :param x: [l * m * n, ..., complex]
        :param b_out: output bandwidth forwarded to ``so3_ifft``
        :return: the result of ``so3_ifft``
        """
        nspec = x.size(0)
        ctx.b_out = b_out
        # Recover the input bandwidth from nspec = b (4 b^2 - 1) / 3 so that
        # backward can produce a gradient with the same number of coefficients.
        ctx.b_in = round((3 / 4 * nspec) ** (1 / 3))
        return so3_ifft(x, b_out=ctx.b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        """Propagate the gradient through ``so3_fft``; ``b_out`` gets no gradient.

        NOTE(review): the unsqueeze(-2) presumably re-inserts the singleton
        feature axis that the forward input carried - confirm against callers.
        """
        output = so3_fft(grad_output, for_grad=True, b_out=ctx.b_in).unsqueeze(-2)
        return output, None

mariogeiger avatar mariogeiger commented on July 20, 2024

Very nice! This will help some people. Yes, please make a PR! I will revert the last merge so that I can compare the original code with yours in the PR.

