
Usage documentation? (s2cnn, closed, 14 comments)

jonkhler commented on August 21, 2024
Usage documentation?


Comments (14)

tscohen commented on August 21, 2024

Regarding the different grids used for the kernel: you are free to use either one, or define your own grid. Spherical CNNs are very new, so we don't yet know what kind of grid / kernel support is appropriate for each kind of task. There is probably a lot that can be improved in terms of architecture details (for 2D CNNs, it took years to find good architectures, and they're still improving).

You can think of so3_integrate as being analogous to "global average pooling" in a standard CNN. This is sometimes done at the end to get approximate translation invariance (in the spherical CNN so3_integrate leads to rotation invariance). The reason you can't just sum up/average the values at all the points is that the grid points are not spread uniformly over the sphere, so we have to weigh them by the inverse of the density, which is given by the Haar measure. For instance, in the SOFT grid we have a lot of points near the north pole. Now imagine a signal that has high values near the north pole. After we rotate the signal, the large values could be near the equator, in which case we'd have fewer points with a high value. Simply summing the original and rotated pixels would not be rotation invariant, but so3_integrate() would be.
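
A minimal sketch of that difference (the tensor shapes here are assumptions following the SOFT convention, not copied from the examples):

import torch
from s2cnn import so3_integrate

b = 8                                        # assumed bandwidth
x = torch.randn(4, 16, 2 * b, 2 * b, 2 * b)  # SO(3) signal: [batch, feature, beta, alpha, gamma]

# Rotation-invariant "global pooling": each grid point is weighted by the
# Haar measure (quadrature weights) before summing.
y = so3_integrate(x)                         # -> [batch, feature]

# A plain average ignores the non-uniform point density and is therefore
# NOT rotation invariant.
y_naive = x.mean(dim=(-1, -2, -3))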


mariogeiger commented on August 21, 2024
  • so3_equatorial_grid and so3_near_identity_grid are used to define the kernel's support
  • so3_integrate integrates a signal over SO(3) using the Haar measure

so3_soft_grid defines the SOFT grid that we use to represent our signals. You can use it to compute the Fourier transform of a signal, but we optimized the special case of the SOFT grid using FFT.
The following two snippets do the same thing (but the second one is much faster):

Slow

# generic Fourier transform evaluated on an explicit grid
grid = so3_soft_grid(b_in)
so3_rft(x, b_out, grid)

Fast

# FFT-based transform, specialized for the SOFT grid
so3_rfft(x, b_out=b_out)


tscohen commented on August 21, 2024

Indeed this is what so3_integrate counteracts. Note that the non-uniform grid is not the solution (high res exactly uniform grids on the sphere do not exist). The problem of having a non-uniform grid is solved by using quadrature weights / Haar measure to weigh each sample differently when averaging.

Yes, a Spherical ResNet is just a ResNet where you replace the 2D conv by an S2 or SO3 conv. ResNets compute y = f(x) + x, where f is usually a sequence of convolution / batch norm / relu. I currently don't have a code sample at hand, but it should be pretty straightforward.
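
Roughly, an untested sketch along these lines (using SO3Convolution and so3_near_identity_grid from s2cnn; the feature count and bandwidth are placeholders):

import torch
import torch.nn as nn
from s2cnn import SO3Convolution, so3_near_identity_grid

class SO3ResBlock(nn.Module):
    # y = f(x) + x, with f = conv -> BN -> ReLU -> conv -> BN
    def __init__(self, features, b):
        super().__init__()
        grid = so3_near_identity_grid()
        # keep features and bandwidth fixed so f(x) and x have matching shapes
        self.conv1 = SO3Convolution(nfeature_in=features, nfeature_out=features,
                                    b_in=b, b_out=b, grid=grid)
        self.conv2 = SO3Convolution(nfeature_in=features, nfeature_out=features,
                                    b_in=b, b_out=b, grid=grid)
        self.bn1 = nn.BatchNorm3d(features)
        self.bn2 = nn.BatchNorm3d(features)

    def forward(self, x):
        # x: [batch, features, 2b, 2b, 2b]
        h = torch.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return torch.relu(h + x)  # the skip connection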


tscohen commented on August 21, 2024

If you want to do things properly, you first need to figure out the coordinate on the sphere associated with each point in your rectangular / panorama grid. Once you have that, you can create a new square grid and use it to sample the rectangular one. The bandwidth of this grid can be anything, but I would think that setting bandwidth = 512 would preserve more detail than bandwidth = 256.


meder411 commented on August 21, 2024

Thanks for the detailed explanations! This is a fascinating area of study.

As I’m sure you know, one unusual effect of processing equirectangular images is that the pixel variance along the top and bottom image borders is 0. At the poles the same pixel is just repeated along the length of the image. In addition to the greater warping required of the filter to capture information, this means that a typical perspective CNN will end up overweighting certain pixels. I presume the non-uniform gridding is a mechanism to circumvent this.

One last question (for now at least ;-)): you have an so3_shortcut function that's commented as "useful for ResNets." How do spherical convolutions, as you've defined them, translate to a residual network architecture? Is it just straightforward substitution of a spherical convolution for a typical planar one? Could you please give a short toy code example of a skip-connection?


meder411 commented on August 21, 2024

This makes sense. I will have to play around with the code a bit I think. If it helps anyone else, I visualized the S2 and SO3 sampling grids with PyPlot in Python. You can find the file attached here.
grid_viz.tar.gz


mariogeiger commented on August 21, 2024

For the S2 visualization you exchanged sin and cos for the beta angle. We use beta = 0 <=> north pole.

And also you should not set the bandwidth of the soft grid to pi.


meder411 commented on August 21, 2024

Ah, I thought I may have done that. Yeah, what are good ranges for the soft grid? I used pi because it just showed a lot of points.


Jiankai-Sun commented on August 21, 2024

For SO3Shortcut, why do you simply assign grid=((0, 0, 0),)? Does this grid parameter need to be changed when we use it?

Thank you!


mariogeiger commented on August 21, 2024

The grid parameter can be compared to the parameter kernel_size in standard CNNs.
But you have more degrees of freedom with grid because you can choose each point of the support of the kernel (compared to standard CNN where you can only have square grids).

Here grid=((0, 0, 0), ) plays the same role as kernel_size=1 (like here)
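
As an illustrative (not repo-verbatim) sketch of that analogy: a single-point grid can be used on a skip path purely to change the number of features, just like a 1x1 conv; the sizes below are placeholders.

from s2cnn import SO3Convolution

# "1x1"-style projection: the kernel support is a single point at the identity,
# so this layer only mixes feature channels, e.g. on the skip path of a block:
#   y = conv_branch(x) + proj(x)
proj = SO3Convolution(nfeature_in=32, nfeature_out=64,
                      b_in=8, b_out=8, grid=((0, 0, 0),))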


Jiankai-Sun commented on August 21, 2024

Thank you for your quick reply!

I noticed that in the provided s2cnn MNIST Example,

grid_s2 = s2_near_identity_grid()
grid_so3 = so3_near_identity_grid()

self.conv1 = S2Convolution(
            nfeature_in=1,
            nfeature_out=f1,
            b_in=b_in,
            b_out=b_l1,
            grid=grid_s2)

self.conv2 = SO3Convolution(
            nfeature_in=f1,
            nfeature_out=f2,
            b_in=b_l1,
            b_out=b_l2,
            grid=grid_so3)

Here, grid_s2 is a tuple with length 24, and grid_so3 is a tuple with length 72.

>>> pprint(grid_s2)
((0.1308996938995747, 0.0),
 (0.1308996938995747, 0.7853981633974483),
 (0.1308996938995747, 1.5707963267948966),
 (0.1308996938995747, 2.356194490192345),
 (0.1308996938995747, 3.141592653589793),
 (0.1308996938995747, 3.9269908169872414),
 (0.1308996938995747, 4.71238898038469),
 (0.1308996938995747, 5.497787143782138),
 (0.2617993877991494, 0.0),
 (0.2617993877991494, 0.7853981633974483),
 (0.2617993877991494, 1.5707963267948966),
 (0.2617993877991494, 2.356194490192345),
 (0.2617993877991494, 3.141592653589793),
 (0.2617993877991494, 3.9269908169872414),
 (0.2617993877991494, 4.71238898038469),
 (0.2617993877991494, 5.497787143782138),
 (0.39269908169872414, 0.0),
 (0.39269908169872414, 0.7853981633974483),
 (0.39269908169872414, 1.5707963267948966),
 (0.39269908169872414, 2.356194490192345),
 (0.39269908169872414, 3.141592653589793),
 (0.39269908169872414, 3.9269908169872414),
 (0.39269908169872414, 4.71238898038469),
 (0.39269908169872414, 5.497787143782138))
>>> pprint(grid_so3)
((0.1308996938995747, 0.0, -0.39269908169872414),
 (0.1308996938995747, 0.0, 0.0),
 (0.1308996938995747, 0.0, 0.39269908169872414),
 (0.1308996938995747, 0.7853981633974483, -1.1780972450961724),
 (0.1308996938995747, 0.7853981633974483, -0.7853981633974483),
 (0.1308996938995747, 0.7853981633974483, -0.39269908169872414),
 (0.1308996938995747, 1.5707963267948966, -1.9634954084936207),
 (0.1308996938995747, 1.5707963267948966, -1.5707963267948966),
 (0.1308996938995747, 1.5707963267948966, -1.1780972450961724),
 (0.1308996938995747, 2.356194490192345, -2.748893571891069),
 (0.1308996938995747, 2.356194490192345, -2.356194490192345),
 (0.1308996938995747, 2.356194490192345, -1.9634954084936207),
 (0.1308996938995747, 3.141592653589793, -3.5342917352885173),
 (0.1308996938995747, 3.141592653589793, -3.141592653589793),
 (0.1308996938995747, 3.141592653589793, -2.748893571891069),
 (0.1308996938995747, 3.9269908169872414, -4.319689898685965),
 (0.1308996938995747, 3.9269908169872414, -3.9269908169872414),
 (0.1308996938995747, 3.9269908169872414, -3.5342917352885173),
 (0.1308996938995747, 4.71238898038469, -5.105088062083414),
 (0.1308996938995747, 4.71238898038469, -4.71238898038469),
 (0.1308996938995747, 4.71238898038469, -4.319689898685965),
 (0.1308996938995747, 5.497787143782138, -5.890486225480862),
 (0.1308996938995747, 5.497787143782138, -5.497787143782138),
 (0.1308996938995747, 5.497787143782138, -5.105088062083414),
 (0.2617993877991494, 0.0, -0.39269908169872414),
 (0.2617993877991494, 0.0, 0.0),
 (0.2617993877991494, 0.0, 0.39269908169872414),
 (0.2617993877991494, 0.7853981633974483, -1.1780972450961724),
 (0.2617993877991494, 0.7853981633974483, -0.7853981633974483),
 (0.2617993877991494, 0.7853981633974483, -0.39269908169872414),
 (0.2617993877991494, 1.5707963267948966, -1.9634954084936207),
 (0.2617993877991494, 1.5707963267948966, -1.5707963267948966),
 (0.2617993877991494, 1.5707963267948966, -1.1780972450961724),
 (0.2617993877991494, 2.356194490192345, -2.748893571891069),
 (0.2617993877991494, 2.356194490192345, -2.356194490192345),
 (0.2617993877991494, 2.356194490192345, -1.9634954084936207),
 (0.2617993877991494, 3.141592653589793, -3.5342917352885173),
 (0.2617993877991494, 3.141592653589793, -3.141592653589793),
 (0.2617993877991494, 3.141592653589793, -2.748893571891069),
 (0.2617993877991494, 3.9269908169872414, -4.319689898685965),
 (0.2617993877991494, 3.9269908169872414, -3.9269908169872414),
 (0.2617993877991494, 3.9269908169872414, -3.5342917352885173),
 (0.2617993877991494, 4.71238898038469, -5.105088062083414),
 (0.2617993877991494, 4.71238898038469, -4.71238898038469),
 (0.2617993877991494, 4.71238898038469, -4.319689898685965),
 (0.2617993877991494, 5.497787143782138, -5.890486225480862),
 (0.2617993877991494, 5.497787143782138, -5.497787143782138),
 (0.2617993877991494, 5.497787143782138, -5.105088062083414),
 (0.39269908169872414, 0.0, -0.39269908169872414),
 (0.39269908169872414, 0.0, 0.0),
 (0.39269908169872414, 0.0, 0.39269908169872414),
 (0.39269908169872414, 0.7853981633974483, -1.1780972450961724),
 (0.39269908169872414, 0.7853981633974483, -0.7853981633974483),
 (0.39269908169872414, 0.7853981633974483, -0.39269908169872414),
 (0.39269908169872414, 1.5707963267948966, -1.9634954084936207),
 (0.39269908169872414, 1.5707963267948966, -1.5707963267948966),
 (0.39269908169872414, 1.5707963267948966, -1.1780972450961724),
 (0.39269908169872414, 2.356194490192345, -2.748893571891069),
 (0.39269908169872414, 2.356194490192345, -2.356194490192345),
 (0.39269908169872414, 2.356194490192345, -1.9634954084936207),
 (0.39269908169872414, 3.141592653589793, -3.5342917352885173),
 (0.39269908169872414, 3.141592653589793, -3.141592653589793),
 (0.39269908169872414, 3.141592653589793, -2.748893571891069),
 (0.39269908169872414, 3.9269908169872414, -4.319689898685965),
 (0.39269908169872414, 3.9269908169872414, -3.9269908169872414),
 (0.39269908169872414, 3.9269908169872414, -3.5342917352885173),
 (0.39269908169872414, 4.71238898038469, -5.105088062083414),
 (0.39269908169872414, 4.71238898038469, -4.71238898038469),
 (0.39269908169872414, 4.71238898038469, -4.319689898685965),
 (0.39269908169872414, 5.497787143782138, -5.890486225480862),
 (0.39269908169872414, 5.497787143782138, -5.497787143782138),
 (0.39269908169872414, 5.497787143782138, -5.105088062083414))

Question 1: So does that mean the kernel sizes for the MNIST example are 27 and 72, respectively? Is there any possible explanation for adopting such a large kernel size for MNIST? Does it mean learning will work better if we use such a large kernel size in the same way for VGG or ResNet?

Question 2: If it is recommended to just use the same kernel size as the original VGG or ResNet implementations (for VGG the kernel size is usually 2 or 3; for ResNet it is usually 3, 1, or 7), are there any suggestions or rules about how to choose a grid of the specified size? Should we manually specify 3 values (e.g. ((0, 0, 0), (1, 1, 1), (2, 2, 2))) or randomly select 3 elements from the existing grids grid_s2 and grid_so3? What if there are repeating elements among the chosen kernel points (e.g. ((0, 0, 0), (0, 0, 0), (1, 1, 1)))? Presumably a small kernel size can reduce the needed memory :) If the digit is projected onto the northern hemisphere, should we just choose the points of the kernel support whose point[2] > 1?

Question 3: I wonder how to set the parameter bandwidth_in for SO3Convolution() and S2Convolution() if the input is a rectangle of size (batch_size, channel, bandwidth, 2 * bandwidth) instead of (batch_size, channel, 2 * bandwidth, 2 * bandwidth)? I am afraid that a rectangular input is not an easy situation to handle.

Question 4: How can we use different weight initialization methods for s2cnn? E.g. Glorot initialization, Kaiming initialization, and so on.

(You can ignore the next 2 questions if they are not clearly stated :) )
Question 5: I noticed that here, in order to ensure that only the southern hemisphere (the paper says each digit is projected onto the northern hemisphere) gets projected, you choose grid[2] <= 1. So grid[2] == 1 is the equator, grid[2] <= 1 is the southern hemisphere, and grid[2] > 1 is the northern hemisphere? If we want to project the digit onto the whole sphere, do we just need to remove this line? Probably this also has something to do with these 2 lines.

Question 6: Does s2cnn support projecting a rectangular image, instead of a square one, onto the sphere? I noticed that lie_learn.spaces.S2.meshgrid(b=b, grid_type=grid_type) can only return theta, phi with shape [2 * bandwidth, 2 * bandwidth]. It seems that only square meshgrids are supported. What do we need to do to project a 512 * 1024 * 3 image onto a sphere and then onto an s2_grid?

Thank you!


tscohen commented on August 21, 2024

Q1: It is a kernel with 27 points. This is comparable to a 2D kernel of size 5, which has 5x5 = 25 points. In 3D you have another dimension, which again increases the number of points / parameters.

The grid is a flat list of coordinates. In a 3x3 kernel in a 2D CNN, the analog would be ((0,0), (0, 1), (0,2), (1,0),(1,1),(1,2),(2,0),(2,1),(2,2)).

Q2: It's not about choosing a set of points from grid_s2 or grid_so3, it's about defining the right grid. The length of the grid defines the number of samples / points / parameters. The coordinates themselves (phi, theta) define where on the sphere each point lies. The near_identity_grid is one type of grid that we proposed, where the points are near the north pole (for S2) or near the identity transformation (for SO3). You can come up with any grid you like, though. Whether it will work well depends on the characteristics of your data. Figuring this out is a research problem; we have only demonstrated that near_identity_grid etc. work reasonably well for the problems we looked at.
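
For instance, a hypothetical custom S2 grid (following the (beta, alpha) ordering of the grids printed above) could be built like this and passed to S2Convolution via its grid argument:

import numpy as np

# one point at the pole plus a ring of 6 points at beta = 0.2 rad;
# any tuple of (beta, alpha) pairs is a valid kernel support
ring = tuple((0.2, 2 * np.pi * k / 6) for k in range(6))
my_s2_grid = ((0.0, 0.0),) + ring
# S2Convolution(nfeature_in, nfeature_out, b_in, b_out, grid=my_s2_grid)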

Q3: bandwidth_in of one layer should equal the output bandwidth of the previous layer. Not sure what you are asking wrt the square / rectangular grid. Can you rephrase?

Q4: S2Conv is just a pytorch Module, with parameters stored in .kernel and .bias. You can change them as you like.
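
A minimal sketch for Q4, assuming .kernel and .bias as the parameter names and placeholder layer sizes (whether e.g. Kaiming's fan assumptions are meaningful for this kernel parametrization is a separate question):

import torch.nn.init as init
from s2cnn import S2Convolution, s2_near_identity_grid

conv = S2Convolution(nfeature_in=1, nfeature_out=8, b_in=30, b_out=10,
                     grid=s2_near_identity_grid())
init.kaiming_normal_(conv.kernel)  # re-initialize the weights in place
init.zeros_(conv.bias)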

Q5: whether that's the north or south pole is not an objective mathematical fact. It depends on how we describe the grid in prose / the name we give to a certain coordinate. I think we said that (theta, phi) = (0, 0) corresponds to the north pole.

Q6: we currently only support the SOFT grid, which is square. You could implement spherical convolution for other grids as well, and this would have other benefits, such as a more homogeneous sampling of the sphere. However, you can also resample a rectangular image onto a square grid. The question is how you want to project your image onto the sphere.

Is your image a planar image or a spherical image? If planar, you need to define a projection map from R2 to the S2, e.g. the stereographic projection. Then you can work out for each point in some square grid on the sphere, where it gets mapped to on the plane. Then you can sample that point in the image using bilinear interpolation.

If your image is a spherical image, it must be associated with some grid on the sphere. You need to figure out what it is, i.e. for each pixel what its (theta, phi) coordinates are. Then again you can figure out, for each point in the SOFT grid, where in the rectangular image you need to sample.
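
A rough sketch of that last step (panorama_to_grid is a hypothetical helper; it assumes an equirectangular panorama with rows covering theta in [0, pi] and columns covering phi in [0, 2*pi), and uses nearest-neighbour sampling for brevity instead of bilinear interpolation):

import numpy as np
import lie_learn.spaces.S2 as S2

def panorama_to_grid(img, b):
    # img: (H, W, C) equirectangular image; returns a (2b, 2b, C) spherical signal
    H, W, _ = img.shape
    theta, phi = S2.meshgrid(b=b, grid_type='Driscoll-Healy')  # each of shape (2b, 2b)
    rows = np.clip(np.round(theta / np.pi * (H - 1)).astype(int), 0, H - 1)
    cols = np.round((phi % (2 * np.pi)) / (2 * np.pi) * (W - 1)).astype(int)
    return img[rows, cols]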


Jiankai-Sun commented on August 21, 2024

For Question 3, what I mean is: how do we set the parameter bandwidth_in if the height and width of the input image differ after it is projected onto a grid? For example, if the input is a panorama (probably a panorama is a kind of spherical image that is already associated with a grid on the sphere) with height 512, width 1024 and 3 channels, should we set bandwidth_in for the first S2Convolution layer to 512/2 = 256 or 1024/2 = 512? This differs from the s2cnn MNIST example with size 60x60x1 (same height and width), where we can easily set the bandwidth to 60/2 = 30.

Thank you!


Jiankai-Sun commented on August 21, 2024

Sure, thank you for your reply!

