Coder Social home page Coder Social logo

Comments (6)

ivanstepanovftw avatar ivanstepanovftw commented on June 20, 2024

An email has been sent to the paper's authors regarding this concern. Still waiting for a reply.

from partialconv.

ivanstepanovftw avatar ivanstepanovftw commented on June 20, 2024

CC: @liuguilin1225 @fitsumreda @bryancatanzaro

I have added a plain Conv2d to the comparison, since many people have asked to reproduce your results, and the outcome is the complete opposite of what the paper reports.

Code:

Details

from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    """Partial convolution layer (the reference implementation being compared).

    Accepts every ``nn.Conv2d`` argument plus two extra keyword options, which
    are removed from ``kwargs`` before ``nn.Conv2d.__init__`` runs:

    multi_channel (default False):
        If True the mask-update kernel is ``ones(out, in, kH, kW)`` (one mask
        channel per input channel); otherwise it is ``ones(1, 1, kH, kW)``.
    return_mask (default False):
        If True ``forward`` returns ``(output, update_mask)``; otherwise only
        ``output``.
    """

    def __init__(self, *args, **kwargs):
        # Strip the partial-conv options so nn.Conv2d never sees them.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            mask_kernel_shape = (self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            mask_kernel_shape = (1, 1, self.kernel_size[0], self.kernel_size[1])
        # Non-persistent: the all-ones kernel is rebuilt here and is not part of state_dict.
        self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(*mask_kernel_shape))

        # Number of mask elements covered by one sliding window (mask channels * kH * kW).
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        # Cached mask statistics, reused while the input shape stays the same.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None
        # True when the cached statistics were derived from an explicit mask_in.
        self._cache_from_mask = False

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        Args:
            input: 4-D tensor (the rank is asserted).
            mask_in: optional mask multiplied element-wise into ``input``; when
                omitted an all-ones mask of the matching layout is used.

        Returns:
            ``output`` or, if ``return_mask`` is set, ``(output, update_mask)``
            where ``update_mask`` is the window count clamped to ``[0, 1]``.
        """
        assert len(input.shape) == 4
        # Recompute when a mask is given, when the shape changed, or when the
        # cached statistics came from an earlier explicit mask; otherwise a
        # mask-less call would silently reuse that stale mask's ratio/zeros.
        if mask_in is not None or self._cache_from_mask or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            self._cache_from_mask = mask_in is not None

            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                # Match the mask's device/dtype (e.g. half precision) before convolving.
                mask_kernel = self.weight_maskUpdater.to(device=mask.device, dtype=mask.dtype)
                self.update_mask = F.conv2d(mask, mask_kernel, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio wherever the window saw no valid pixel.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask_in) if mask_in is not None else input)

        if self.bias is not None:
            # Re-weight only the convolution term, not the bias, then re-mask.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> (torch.Tensor, torch.Tensor):
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> (torch.Tensor, torch.Tensor | None):
        if mask is not None:
            input *= mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    """Temporarily attach ``hook`` as a forward hook on selected submodules.

    Every ``(name, module)`` from ``model.named_modules()`` for which
    ``predicate(name, module)`` is true gets ``partial(hook, name=name,
    **hook_kwargs)`` registered. The context yields the list of hook handles.
    All handles registered so far are removed on exit, including when an
    exception is raised.
    """
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind into a fresh local: rebinding ``hook`` itself would wrap the
                # previous module's partial rather than the original callable.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handles.append(module.register_forward_hook(bound_hook))
        yield handles
    finally:
        for handle in handles:
            handle.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    """Forward hook that records ``output`` under ``storage[name]``.

    The first call for a name stores the output directly. A second call turns
    the entry into ``[first, second]``, and later calls append to that list.
    """
    if name not in storage:
        storage[name] = output
        return
    previous = storage[name]
    if isinstance(previous, list):
        previous.append(output)
    else:
        storage[name] = [previous, output]


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    """Run ``model(*model_args, **model_kwargs)`` while recording activations.

    Returns ``(model_output, activations)``. ``activations`` maps the qualified
    name of each submodule accepted by ``predicate`` to its forward output, or
    to a list of outputs if that submodule ran more than once. The hooks are
    removed again before this function returns.
    """
    recorded: dict[str, Any] = {}
    hooks = register_hooks(model, activations_recorder_hook, predicate, storage=recorded)
    with hooks:
        result = model(*model_args, **model_kwargs)
    return result, recorded


def test_it():
    """Compare plain, partial and masked convolution stacks on random input.

    Builds four ``depth``-level encoders, each level being a PixelUnshuffle
    followed by a 3x3 bias-free conv:
    - plain ``nn.Conv2d``;
    - ``PartialConv2d``;
    - ``MaskedConv2d`` in partial-conv mode;
    - ``MaskedConv2d`` in plain masked mode.

    The three masked stacks load the plain stack's weights, and all four run on
    the same input with all-ones masks. For each stack, the first
    ``visualize_depth`` conv activations are summarised (mean, std, skewness,
    kurtosis) and plotted.
    """
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 6
    eps = 1e-8

    def channels_at(level):
        """Return the (in, out) channel counts of the conv at ``level``."""
        if level > 0:
            c_in = scale * base ** (level + 1) * downscale_factor ** 2
        else:
            c_in = in_channels * downscale_factor ** 2
        c_out = scale * base ** level * downscale_factor ** 2
        return c_in, c_out

    def build_stack(unshuffle_cls, container_cls, conv_factory):
        """Stack ``depth`` (unshuffle, conv) pairs, constructing them in order."""
        layers = []
        for level in range(depth):
            c_in, c_out = channels_at(level)
            layers.append(unshuffle_cls(downscale_factor))
            layers.append(conv_factory(c_in, c_out))
        return container_cls(*layers)

    # Construction order matters: each conv's init consumes the seeded RNG before x is drawn.
    conv = build_stack(nn.PixelUnshuffle, nn.Sequential, lambda c_in, c_out: nn.Conv2d(
        in_channels=c_in, out_channels=c_out,
        kernel_size=(3, 3), padding=1, bias=False))
    pconv = build_stack(MaskedPixelUnshuffle, MaskedSequential, lambda c_in, c_out: PartialConv2d(
        in_channels=c_in, out_channels=c_out,
        kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True))
    mpconv = build_stack(MaskedPixelUnshuffle, MaskedSequential, lambda c_in, c_out: MaskedConv2d(
        in_channels=c_in, out_channels=c_out,
        kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True))
    mconv = build_stack(MaskedPixelUnshuffle, MaskedSequential, lambda c_in, c_out: MaskedConv2d(
        in_channels=c_in, out_channels=c_out,
        kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False))

    with torch.no_grad():
        print(f"{conv=}")
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(conv.state_dict().keys())=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        # Share the plain stack's weights so only the masking logic differs.
        reference_weights = conv.state_dict()
        for target in (pconv, mpconv, mconv):
            target.load_state_dict(reference_weights)

        x = torch.randn(1, in_channels, 512, 512)
        mask_pconv, mask_mpconv, mask_mconv = (torch.ones_like(x) for _ in range(3))

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x)
        (y_pconv, _), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, _), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, _), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        assert not torch.allclose(y_conv, y_mpconv)
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # One row per implementation, one column per visualised depth.
        fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        implementations = [
            ("conv", activations_conv),
            ("pconv", activations_pconv),
            ("mpconv", activations_mpconv),
            ("mconv", activations_mconv),
        ]
        batch_i = 0
        for impl_i, (name, activations) in enumerate(implementations):
            for depth_i in range(visualize_depth):
                ax = axs[impl_i * visualize_depth + depth_i]

                # Conv modules sit at odd indices (each follows an unshuffle layer).
                recorded = activations[f"{depth_i * 2 + 1}"]
                # Masked layers record an (output, mask) tuple; plain convs record a tensor.
                output = recorded[batch_i] if isinstance(recorded, torch.Tensor) else recorded[0][batch_i]

                assert output.dim() == 3

                mean = output.mean()
                std = output.std(unbiased=False)
                centered = output - mean
                skewness = (centered ** 3).mean() / (std ** 3 + eps)
                kurtosis = (centered ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        plt.show()


# Run the comparison only when this file is executed as a script.
if __name__ == '__main__':
    test_it()

Output:

Details

conv=Sequential(
  (0): PixelUnshuffle(downscale_factor=2)
  (1): Conv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (2): PixelUnshuffle(downscale_factor=2)
  (3): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): PixelUnshuffle(downscale_factor=2)
  (5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (6): PixelUnshuffle(downscale_factor=2)
  (7): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): PixelUnshuffle(downscale_factor=2)
  (9): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): PixelUnshuffle(downscale_factor=2)
  (11): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (12): PixelUnshuffle(downscale_factor=2)
  (13): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (14): PixelUnshuffle(downscale_factor=2)
  (15): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
pconv=MaskedSequential(
  (0): MaskedPixelUnshuffle(downscale_factor=2)
  (1): PartialConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (2): MaskedPixelUnshuffle(downscale_factor=2)
  (3): PartialConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): MaskedPixelUnshuffle(downscale_factor=2)
  (5): PartialConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (6): MaskedPixelUnshuffle(downscale_factor=2)
  (7): PartialConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): MaskedPixelUnshuffle(downscale_factor=2)
  (9): PartialConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): MaskedPixelUnshuffle(downscale_factor=2)
  (11): PartialConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (12): MaskedPixelUnshuffle(downscale_factor=2)
  (13): PartialConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (14): MaskedPixelUnshuffle(downscale_factor=2)
  (15): PartialConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
mpconv=MaskedSequential(
  (0): MaskedPixelUnshuffle(downscale_factor=2)
  (1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (2): MaskedPixelUnshuffle(downscale_factor=2)
  (3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (4): MaskedPixelUnshuffle(downscale_factor=2)
  (5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (6): MaskedPixelUnshuffle(downscale_factor=2)
  (7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (8): MaskedPixelUnshuffle(downscale_factor=2)
  (9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (10): MaskedPixelUnshuffle(downscale_factor=2)
  (11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (12): MaskedPixelUnshuffle(downscale_factor=2)
  (13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (14): MaskedPixelUnshuffle(downscale_factor=2)
  (15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
)
mconv=MaskedSequential(
  (0): MaskedPixelUnshuffle(downscale_factor=2)
  (1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (2): MaskedPixelUnshuffle(downscale_factor=2)
  (3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (4): MaskedPixelUnshuffle(downscale_factor=2)
  (5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (6): MaskedPixelUnshuffle(downscale_factor=2)
  (7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (8): MaskedPixelUnshuffle(downscale_factor=2)
  (9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (10): MaskedPixelUnshuffle(downscale_factor=2)
  (11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (12): MaskedPixelUnshuffle(downscale_factor=2)
  (13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (14): MaskedPixelUnshuffle(downscale_factor=2)
  (15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
)
list(conv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(pconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(mpconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(mconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
activations_pconv.keys()=dict_keys(['1', '3', '5', '7', '9', '11', '13', '15'])
name='conv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3238), skewness=tensor(0.0080), kurtosis=tensor(3.0212)
name='conv', depth_i=2, mean=tensor(6.5161e-06), std=tensor(0.1855), skewness=tensor(0.0049), kurtosis=tensor(3.0922)
name='conv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1054), skewness=tensor(0.0081), kurtosis=tensor(3.0650)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0589), skewness=tensor(-0.0125), kurtosis=tensor(3.1699)
name='conv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0316), skewness=tensor(-0.0147), kurtosis=tensor(3.2110)
name='pconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276)
name='pconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518)
name='pconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635)
name='pconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829)
name='pconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324)
name='pconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953)
name='mpconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276)
name='mpconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518)
name='mpconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635)
name='mpconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829)
name='mpconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324)
name='mpconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953)
name='mconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3232), skewness=tensor(0.0085), kurtosis=tensor(3.0263)
name='mconv', depth_i=2, mean=tensor(-9.7408e-06), std=tensor(0.1844), skewness=tensor(0.0053), kurtosis=tensor(3.1119)
name='mconv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1039), skewness=tensor(0.0093), kurtosis=tensor(3.1074)
name='mconv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0571), skewness=tensor(-0.0164), kurtosis=tensor(3.2821)
name='mconv', depth_i=5, mean=tensor(0.0006), std=tensor(0.0296), skewness=tensor(-0.0277), kurtosis=tensor(3.3867)

image

from partialconv.

ivanstepanovftw avatar ivanstepanovftw commented on June 20, 2024

It looks even worse with a partially occluded mask.

Code:

Details

from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    """Partial convolution layer (the reference implementation being compared).

    Accepts every ``nn.Conv2d`` argument plus two extra keyword options, which
    are removed from ``kwargs`` before ``nn.Conv2d.__init__`` runs:

    multi_channel (default False):
        If True the mask-update kernel is ``ones(out, in, kH, kW)`` (one mask
        channel per input channel); otherwise it is ``ones(1, 1, kH, kW)``.
    return_mask (default False):
        If True ``forward`` returns ``(output, update_mask)``; otherwise only
        ``output``.
    """

    def __init__(self, *args, **kwargs):
        # Strip the partial-conv options so nn.Conv2d never sees them.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            mask_kernel_shape = (self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            mask_kernel_shape = (1, 1, self.kernel_size[0], self.kernel_size[1])
        # Non-persistent: the all-ones kernel is rebuilt here and is not part of state_dict.
        self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=torch.ones(*mask_kernel_shape))

        # Number of mask elements covered by one sliding window (mask channels * kH * kW).
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        # Cached mask statistics, reused while the input shape stays the same.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None
        # True when the cached statistics were derived from an explicit mask_in.
        self._cache_from_mask = False

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        Args:
            input: 4-D tensor (the rank is asserted).
            mask_in: optional mask multiplied element-wise into ``input``; when
                omitted an all-ones mask of the matching layout is used.

        Returns:
            ``output`` or, if ``return_mask`` is set, ``(output, update_mask)``
            where ``update_mask`` is the window count clamped to ``[0, 1]``.
        """
        assert len(input.shape) == 4
        # Recompute when a mask is given, when the shape changed, or when the
        # cached statistics came from an earlier explicit mask; otherwise a
        # mask-less call would silently reuse that stale mask's ratio/zeros.
        if mask_in is not None or self._cache_from_mask or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            self._cache_from_mask = mask_in is not None

            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                # Match the mask's device/dtype (e.g. half precision) before convolving.
                mask_kernel = self.weight_maskUpdater.to(device=mask.device, dtype=mask.dtype)
                self.update_mask = F.conv2d(mask, mask_kernel, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio wherever the window saw no valid pixel.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask_in) if mask_in is not None else input)

        if self.bias is not None:
            # Re-weight only the convolution term, not the bias, then re-mask.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> (torch.Tensor, torch.Tensor):
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> (torch.Tensor, torch.Tensor | None):
        if mask is not None:
            input *= mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    """Temporarily attach ``hook`` as a forward hook on selected submodules.

    Every ``(name, module)`` from ``model.named_modules()`` for which
    ``predicate(name, module)`` is true gets ``partial(hook, name=name,
    **hook_kwargs)`` registered. The context yields the list of hook handles.
    All handles registered so far are removed on exit, including when an
    exception is raised.
    """
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind into a fresh local: rebinding ``hook`` itself would wrap the
                # previous module's partial rather than the original callable.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handles.append(module.register_forward_hook(bound_hook))
        yield handles
    finally:
        for handle in handles:
            handle.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    """Forward hook that records ``output`` under ``storage[name]``.

    The first call for a name stores the output directly. A second call turns
    the entry into ``[first, second]``, and later calls append to that list.
    """
    if name not in storage:
        storage[name] = output
        return
    previous = storage[name]
    if isinstance(previous, list):
        previous.append(output)
    else:
        storage[name] = [previous, output]


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    """Run ``model(*model_args, **model_kwargs)`` while recording activations.

    Returns ``(model_output, activations)``. ``activations`` maps the qualified
    name of each submodule accepted by ``predicate`` to its forward output, or
    to a list of outputs if that submodule ran more than once. The hooks are
    removed again before this function returns.
    """
    recorded: dict[str, Any] = {}
    hooks = register_hooks(model, activations_recorder_hook, predicate, storage=recorded)
    with hooks:
        result = model(*model_args, **model_kwargs)
    return result, recorded


def _build_unshuffle_conv_stack(
        depth: int,
        in_channels: int,
        scale: int,
        base: int,
        downscale_factor: int,
        make_unshuffle: Callable[[int], nn.Module],
        make_conv: Callable[..., nn.Module],
        make_container: Callable[..., nn.Module],
) -> nn.Module:
    """Build ``depth`` (pixel-unshuffle -> conv) stages and wrap them with ``make_container``.

    Stage ``i`` produces ``scale * base ** i * downscale_factor ** 2`` channels. Its input
    channel count is the previous stage's output (``in_channels`` for stage 0) multiplied
    by ``downscale_factor ** 2``, because pixel-unshuffle folds each
    ``downscale_factor x downscale_factor`` spatial block into the channel dimension.
    ``make_conv`` is called with only the ``in_channels=`` and ``out_channels=`` keywords.
    """
    layers = []
    channels = in_channels
    for i in range(depth):
        out_channels = scale * base ** i * downscale_factor ** 2
        layers.append(make_unshuffle(downscale_factor))
        layers.append(make_conv(in_channels=channels * downscale_factor ** 2, out_channels=out_channels))
        channels = out_channels
    return make_container(*layers)


def _standardized_moments(x: Tensor, eps: float) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Return mean, population std, skewness and (non-excess) kurtosis over all elements of ``x``.

    ``eps`` is added to the std powers in the denominators to avoid division by zero.
    """
    mean = x.mean()
    std = x.std(unbiased=False)
    centered = x - mean
    skewness = (centered ** 3).mean() / (std ** 3 + eps)
    kurtosis = (centered ** 4).mean() / (std ** 4 + eps)
    return mean, std, skewness, kurtosis


def test_it():
    """Compare activation statistics of plain, partial and masked convolution stacks.

    Builds four identically-shaped stacks:
      - ``nn.Conv2d`` fed the pre-masked input;
      - ``PartialConv2d``;
      - ``MaskedConv2d`` with ``partial_conv=True``;
      - ``MaskedConv2d`` with ``partial_conv=False``.

    The plain stack's weights are copied into the other three. All four are run on
    a random image whose mask has a square hole. Per-depth moments are printed, and
    the channel-mean activations are plotted.
    """
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 6
    eps = 1e-8

    def build(make_unshuffle, make_conv, make_container):
        return _build_unshuffle_conv_stack(
            depth, in_channels, scale, base, downscale_factor,
            make_unshuffle, make_conv, make_container)

    # Keep the construction order conv -> pconv -> mpconv -> mconv. Layer
    # initialisation may consume the seeded RNG, and reordering would change ``x`` below.
    conv = build(
        nn.PixelUnshuffle,
        partial(nn.Conv2d, kernel_size=(3, 3), padding=1, bias=False),
        nn.Sequential)
    pconv = build(
        MaskedPixelUnshuffle,
        partial(PartialConv2d, kernel_size=(3, 3), padding=1, bias=False,
                multi_channel=True, return_mask=True),
        MaskedSequential)
    mpconv = build(
        MaskedPixelUnshuffle,
        partial(MaskedConv2d, kernel_size=(3, 3), padding=1, bias=False,
                multichannel=True, partial_conv=True),
        MaskedSequential)
    mconv = build(
        MaskedPixelUnshuffle,
        partial(MaskedConv2d, kernel_size=(3, 3), padding=1, bias=False,
                multichannel=True, partial_conv=False),
        MaskedSequential)

    with torch.no_grad():
        print(f"{conv=}")
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(conv.state_dict().keys())=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        # Share one set of weights so only the masking strategy differs.
        pconv.load_state_dict(conv.state_dict())
        mpconv.load_state_dict(conv.state_dict())
        mconv.load_state_dict(conv.state_dict())

        x = torch.randn(1, in_channels, 512, 512)
        x_mask = torch.ones_like(x)
        x_mask[..., 128:256, 128:256] = 0

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x * x_mask)
        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, x_mask)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, x_mask)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, x_mask)

        assert not torch.allclose(y_conv, y_mpconv)
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # One row per implementation, one column per visualized depth.
        fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        batch_i = 0
        for impl_i, (name, activations) in enumerate([
            ("conv", activations_conv),
            ("pconv", activations_pconv),
            ("mpconv", activations_mpconv),
            ("mconv", activations_mconv),
        ]):
            for depth_i in range(visualize_depth):
                ax = axs[impl_i * visualize_depth + depth_i]

                # Each stage is (unshuffle, conv), so conv layers sit at odd indices.
                layer_output = activations[f"{depth_i * 2 + 1}"]
                # Plain convs record a tensor. Otherwise the first element of the
                # recorded pair is the activation (the masked variants return (output, mask)).
                if isinstance(layer_output, torch.Tensor):
                    output = layer_output[batch_i]
                else:
                    output = layer_output[0][batch_i]

                assert output.dim() == 3

                mean, std, skewness, kurtosis = _standardized_moments(output, eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        plt.show()


if __name__ == '__main__':
    test_it()

Output (note the large kurtosis, which indicates heavier tails, i.e. more outliers, in the distribution):

name='conv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3134), skewness=tensor(0.0081), kurtosis=tensor(3.2148)
name='conv', depth_i=2, mean=tensor(0.0002), std=tensor(0.1794), skewness=tensor(0.0086), kurtosis=tensor(3.2706)
name='conv', depth_i=3, mean=tensor(6.1037e-06), std=tensor(0.1016), skewness=tensor(0.0055), kurtosis=tensor(3.2192)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0566), skewness=tensor(-0.0155), kurtosis=tensor(3.2757)
name='conv', depth_i=5, mean=tensor(0.0004), std=tensor(0.0301), skewness=tensor(-0.0230), kurtosis=tensor(3.1709)
name='pconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='pconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='pconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='pconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='pconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='pconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mpconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='mpconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='mpconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='mpconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='mpconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='mpconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3124), skewness=tensor(0.0086), kurtosis=tensor(3.2303)
name='mconv', depth_i=2, mean=tensor(0.0001), std=tensor(0.1776), skewness=tensor(0.0087), kurtosis=tensor(3.3192)
name='mconv', depth_i=3, mean=tensor(-1.3955e-05), std=tensor(0.0991), skewness=tensor(0.0069), kurtosis=tensor(3.3181)
name='mconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0537), skewness=tensor(-0.0256), kurtosis=tensor(3.4699)
name='mconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0266), skewness=tensor(-0.0291), kurtosis=tensor(3.3908)

image

from partialconv.

ivanstepanovftw avatar ivanstepanovftw commented on June 20, 2024

PartialConv performs even worse than regular convolution on an object detection task (a DETR-like model trained to minimize the Hungarian loss). Training was performed on batches of images of different sizes, each with its respective mask.

image 1:
image

image 2:
image

  • black: nn.Conv2d
  • red: PartialConv2d (this repo)
  • blue: MaskedConv2d (my implementation of masked convolution)

from partialconv.

ivanstepanovftw avatar ivanstepanovftw commented on June 20, 2024

I have received a response from the authors. I will provide further details via email.

from partialconv.

ivanstepanovftw avatar ivanstepanovftw commented on June 20, 2024

I have released Masked Convolution for Diverse Sample Sizes, so you can now use the fix from this issue under a permissive license: https://github.com/ivanstepanovftw/masked_torch

from partialconv.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.