Comments (2)

CYXYZ commented on June 8, 2024

Hello Vivek, I've changed registration.py, and register.py now works fine. The specific changes are shown in the code below. However, the earlier issue with parameterizations still persists. Could you please help me with that when you have time? Thank you for your assistance.


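import torch
from diffdrr.drr import DRR
from diffdrr.detector import Detector
from diffdrr.pose import RigidTransform, convert  # assumed import path for the names used below
from diffdrr.renderers import Siddon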

class SparseRegistration(torch.nn.Module):
    def __init__(
        self,
        drr: DRR,
        pose: RigidTransform,
        parameterization: str,
        convention: str = None,
        features=None,  # Used to compute biased estimate of mNCC
        n_patches: int = None,  # If n_patches is None, render the whole image
        patch_size: int = 13,
    ):
        super().__init__()
        self.drr = drr

        rotation, translation = pose.convert(parameterization, convention)
        self.parameterization = parameterization
        self.convention = convention
        self.rotation = torch.nn.Parameter(rotation)
        self.translation = torch.nn.Parameter(translation)

        # Crop pixels off the edge such that pixels don't fall outside the image
        self.n_patches = n_patches
        self.patch_size = patch_size
        self.patch_radius = self.patch_size // 2 + 1
        self.height = self.drr.detector.height
        self.width = self.drr.detector.width
        self.f_height = self.height - 2 * self.patch_radius
        self.f_width = self.width - 2 * self.patch_radius

        # Define the distribution over patch centers
        if features is None:
            features = torch.ones(
                self.height, self.width, device=self.rotation.device
            ) / (self.height * self.width)
        self.patch_centers = torch.distributions.categorical.Categorical(
            probs=features.squeeze()[
                self.patch_radius : -self.patch_radius,
                self.patch_radius : -self.patch_radius,
            ].flatten()
        )

    def forward(self, n_patches=None, patch_size=None):
        # Parse initial density
        if not hasattr(self.drr, "density"):
            self.drr.set_bone_attenuation_multiplier(
                self.drr.bone_attenuation_multiplier
            )

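        # Update each setting only when it is passed, so that supplying one
        # argument does not overwrite the other with None
        if n_patches is not None:
            self.n_patches = n_patches
        if patch_size is not None:
            self.patch_size = patch_size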

        # Make the mask for sparse rendering
        if self.n_patches is None:
            mask = torch.ones(
                1,
                self.height,
                self.width,
                dtype=torch.bool,
                device=self.rotation.device,
            )
        else:
            mask = torch.zeros(
                self.n_patches,
                self.height,
                self.width,
                dtype=torch.bool,
                device=self.rotation.device,
            )
            radius = self.patch_size // 2
            idxs = self.patch_centers.sample(sample_shape=torch.Size([self.n_patches]))
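            # Unravel flat indices over the cropped (f_height, f_width) grid:
            # the row is the flat index divided by the number of columns
            idxs, jdxs = (
                idxs // self.f_width + self.patch_radius,
                idxs % self.f_width + self.patch_radius,
            )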

            idx = torch.arange(-radius, radius + 1, device=self.rotation.device)
            patches = torch.cartesian_prod(idx, idx).expand(self.n_patches, -1, -1)
            patches = patches + torch.stack([idxs, jdxs], dim=-1).unsqueeze(1)
            patches = torch.concat(
                [
                    torch.arange(self.n_patches, device=self.rotation.device)
                    .unsqueeze(-1)
                    .expand(-1, self.patch_size**2)
                    .unsqueeze(-1),
                    patches,
                ],
                dim=-1,
            )
            mask[
                patches[..., 0],
                patches[..., 1],
                patches[..., 2],
            ] = True

        pose = convert(
            self.rotation,
            self.translation,
            parameterization=self.parameterization,
            convention=self.convention,
        )
        source, target = self.drr.detector.forward(pose)

        # Render the sparse image
        target = target[mask.any(dim=0).view(1, -1)]

        siddon = Siddon(eps=1e-8)
        img = siddon(self.drr.density, self.drr.spacing, source, target)
        
        if self.n_patches is None:
            img = self.drr.reshape_transform(img, batch_size=len(self.rotation))
        return img, mask

    def get_current_pose(self):
        return convert(
            self.rotation,
            self.translation,
            parameterization=self.parameterization,
            convention=self.convention,
        )
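
For reference, the patch centers are sampled as flat indices over the cropped grid of shape (f_height, f_width), so the row is recovered by dividing by the number of columns. For a square detector, dividing by f_height gives the same result, which is why the difference only shows up on non-square images. Below is a minimal sketch in plain Python (no torch) that checks this; the helper name `unravel_center` and the 20 x 30 detector size are hypothetical and chosen only for illustration.

```python
def unravel_center(flat_idx, f_width, patch_radius):
    """Map a flat index over the cropped grid to a (row, col) in the full image."""
    row = flat_idx // f_width + patch_radius
    col = flat_idx % f_width + patch_radius
    return row, col


# Hypothetical non-square detector, used only for this illustration
height, width, patch_size = 20, 30, 5
patch_radius = patch_size // 2 + 1  # same margin as in the class above
f_height = height - 2 * patch_radius
f_width = width - 2 * patch_radius

# Enumerate every center the categorical distribution could return
radius = patch_size // 2
centers = [
    unravel_center(i, f_width, patch_radius) for i in range(f_height * f_width)
]

# Every patch around every possible center should stay inside the image
in_bounds = all(
    0 <= r - radius and r + radius < height
    and 0 <= c - radius and c + radius < width
    for r, c in centers
)
print(in_bounds)
```

With these sizes, dividing by f_height instead of f_width would produce row indices past the bottom edge of the image for the largest flat indices.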

from diffpose.

CYXYZ commented on June 8, 2024

It's in the refactor-se3 branch.
