Comments (6)
An email has been sent to the paper authors regarding this concern. Still waiting for an answer.
from partialconv.
CC: @liuguilin1225 @fitsumreda @bryancatanzaro
I have added a basic Conv2d to the comparison, as many have asked to reproduce your results, and I see that the outcome is completely opposite to what the paper reports.
Code:
Details
from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor
class PartialConv2d(nn.Conv2d):
    """2-D partial convolution: a convolution re-weighted by mask coverage.

    Accepts every ``nn.Conv2d`` argument plus two extra keyword options:

    multi_channel: if true, the mask has one channel per input channel;
        otherwise a single-channel ``(1, 1, H, W)`` mask is assumed.
        Defaults to ``False``.
    return_mask: if true, ``forward`` returns ``(output, updated_mask)``
        instead of only ``output``. Defaults to ``False``.
    """

    def __init__(self, *args, **kwargs):
        # Strip the extra options so that nn.Conv2d never sees them.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super().__init__(*args, **kwargs)

        # All-ones kernel used to count the valid mask entries under each window.
        if self.multi_channel:
            mask_kernel = torch.ones(self.out_channels, self.in_channels,
                                     self.kernel_size[0], self.kernel_size[1])
        else:
            mask_kernel = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
        self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=mask_kernel)

        # Number of mask entries covered by one sliding window.
        self.slide_winsize = (self.weight_maskUpdater.shape[1]
                              * self.weight_maskUpdater.shape[2]
                              * self.weight_maskUpdater.shape[3])

        # Input shape for which the cached all-ones-mask results are valid.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        Args:
            input: 4-D input tensor.
            mask_in: optional mask; when omitted an all-ones mask is used.

        Returns:
            ``output`` or, when ``return_mask`` is set, ``(output, update_mask)``.
        """
        assert len(input.shape) == 4

        if mask_in is not None or self.last_size != tuple(input.shape):
            with torch.no_grad():
                if mask_in is None:
                    # No mask provided: build an all-ones mask and remember the
                    # shape so that later mask-less calls can reuse the result.
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.shape[2], input.shape[3],
                                          device=input.device, dtype=input.dtype)
                    self.last_size = tuple(input.shape)
                else:
                    mask = mask_in
                    # The cached values now belong to this explicit mask, so a
                    # following mask-less call with the same shape must not
                    # reuse them: invalidate the shape key.
                    self.last_size = (None, None, None, None)

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None,
                                            stride=self.stride, padding=self.padding,
                                            dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super().forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            # Re-weight only the bias-free part, then add the bias back.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output
class MaskedConv2d(nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
eps=1e-8,
multichannel: bool = False,
partial_conv: bool = False,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
if multichannel:
self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
else:
self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
self.eps = eps
self.multichannel = multichannel
self.partial_conv = partial_conv
def get_mask(
self,
input: torch.Tensor,
mask: torch.Tensor | None
) -> (torch.Tensor, torch.Tensor):
if mask is None:
if self.multichannel:
mask = torch.ones_like(input)
else:
mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
else:
if self.multichannel:
mask = mask.expand_as(input)
else:
mask = mask.expand(1, 1, *input.shape[2:])
return mask
def forward(
self,
input: torch.Tensor,
mask: torch.Tensor | None = None
) -> (torch.Tensor, torch.Tensor | None):
if mask is not None:
input *= mask
mask = self.get_mask(input, mask)
if self.partial_conv:
output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
mask_ratio = mask_kernel_numel / (mask + self.eps)
mask.clamp_(0, 1)
# Apply re-weighting and bias
output *= mask_ratio
if self.bias is not None:
output += self.bias.view(-1, 1, 1)
output *= mask
else:
output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
mask = mask / max_vals
return output, mask
def extra_repr(self):
return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"
class MaskedPixelUnshuffle(nn.PixelUnshuffle):
def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
return super().forward(input), super().forward(mask) if mask is not None else None
class MaskedSequential(nn.Sequential):
def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
for module in self:
input, mask = module(input, mask)
return input, mask
@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    """Temporarily attach a forward hook to every sub-module selected by ``predicate``.

    Args:
        model: module whose ``named_modules()`` are inspected.
        hook: callable invoked as ``hook(module, input, output, name=..., **hook_kwargs)``.
        predicate: ``predicate(name, module)`` decides whether a module is hooked.
        **hook_kwargs: extra keyword arguments bound to every hook call.

    Yields:
        The list of hook handles. All handles are removed on exit, even if
        the body of the ``with`` block raises.
    """
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind to a fresh name: rebinding ``hook`` here would wrap the
                # previous iteration's partial instead of the original hook.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handles.append(module.register_forward_hook(bound_hook))
        yield handles
    finally:
        for handle in handles:
            handle.remove()
def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    """Forward hook that records ``output`` in ``storage`` under ``name``.

    The first output for a name is stored as is. On the second output the
    entry becomes a two-element list; later outputs are appended to it.
    """
    if name not in storage:
        storage[name] = output
        return

    previous = storage[name]
    if isinstance(previous, list):
        previous.append(output)
    else:
        storage[name] = [previous, output]
def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    """Run ``model`` once while recording the outputs of selected sub-modules.

    Args:
        model: module to call.
        predicate: ``predicate(name, module)`` selects which sub-modules are recorded.
        *model_args, **model_kwargs: forwarded to ``model``.

    Returns:
        ``(model_output, recorded)`` where ``recorded`` maps sub-module names
        to the outputs captured by ``activations_recorder_hook``.
    """
    recorded = {}
    hooks = register_hooks(model, activations_recorder_hook, predicate, storage=recorded)
    with hooks:
        result = model(*model_args, **model_kwargs)
    return result, recorded
def test_it():
    """Compare Conv2d, PartialConv2d and MaskedConv2d stacks with an all-ones mask.

    Builds four stacks of ``depth`` (PixelUnshuffle, 3x3 conv) stages, copies
    the plain-conv weights into the three masked stacks, runs all of them on
    the same random input, checks how their outputs relate, prints the mean,
    std, skewness and kurtosis of the first ``visualize_depth`` conv outputs,
    and plots the channel-mean of each of those outputs.
    """
    torch.manual_seed(37)
    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 6
    eps = 1e-8

    def build_stack(container_cls, unshuffle_cls, conv_cls, **conv_kwargs):
        # One (unshuffle, conv) pair per stage; channel widths follow the
        # same schedule for every implementation so weights are interchangeable.
        layers = []
        for i in range(depth):
            if i > 0:
                stage_in = scale * base ** (i + 1) * downscale_factor ** 2
            else:
                stage_in = in_channels * downscale_factor ** 2
            layers.append(unshuffle_cls(downscale_factor))
            layers.append(conv_cls(
                in_channels=stage_in,
                out_channels=scale * base ** i * downscale_factor ** 2,
                kernel_size=(3, 3), padding=1, bias=False, **conv_kwargs)
            )
        return container_cls(*layers)

    # Construction order is kept (conv, pconv, mpconv, mconv) so the random
    # number generator is consumed exactly as before.
    conv = build_stack(nn.Sequential, nn.PixelUnshuffle, nn.Conv2d)
    pconv = build_stack(MaskedSequential, MaskedPixelUnshuffle, PartialConv2d,
                        multi_channel=True, return_mask=True)
    mpconv = build_stack(MaskedSequential, MaskedPixelUnshuffle, MaskedConv2d,
                         multichannel=True, partial_conv=True)
    mconv = build_stack(MaskedSequential, MaskedPixelUnshuffle, MaskedConv2d,
                        multichannel=True, partial_conv=False)

    def is_conv_predicate(name: str, module: torch.nn.Module):
        return isinstance(module, torch.nn.Conv2d)

    with torch.no_grad():
        print(f"{conv=}")
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")
        print(f"{list(conv.state_dict().keys())=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        # Give every implementation identical weights.
        pconv.load_state_dict(conv.state_dict())
        mpconv.load_state_dict(conv.state_dict())
        mconv.load_state_dict(conv.state_dict())

        x = torch.randn(1, in_channels, 512, 512)
        mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)

        y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x)
        (y_pconv, _), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, _), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, _), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        assert not torch.allclose(y_conv, y_mpconv)
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # One row per implementation, one column per visualised depth.
        _, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()
        batch_i = 0
        for impl_i, (name, activations) in enumerate([
            ("conv", activations_conv),
            ("pconv", activations_pconv),
            ("mpconv", activations_mpconv),
            ("mconv", activations_mconv)
        ]):
            for depth_i in range(visualize_depth):
                ax = axs[impl_i * visualize_depth + depth_i]
                # Conv layers sit at the odd indices of each stack.
                layer_output = activations[f"{depth_i * 2 + 1}"]
                if isinstance(layer_output, torch.Tensor):
                    output = layer_output[batch_i]
                else:
                    # Masked layers return (output, mask); only the output is analysed.
                    output = layer_output[0][batch_i]
                assert output.dim() == 3
                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")
                ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')
        plt.show()
if __name__ == '__main__':
test_it()
Output:
Details
conv=Sequential(
(0): PixelUnshuffle(downscale_factor=2)
(1): Conv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(2): PixelUnshuffle(downscale_factor=2)
(3): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): PixelUnshuffle(downscale_factor=2)
(5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(6): PixelUnshuffle(downscale_factor=2)
(7): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(8): PixelUnshuffle(downscale_factor=2)
(9): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(10): PixelUnshuffle(downscale_factor=2)
(11): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(12): PixelUnshuffle(downscale_factor=2)
(13): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(14): PixelUnshuffle(downscale_factor=2)
(15): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
pconv=MaskedSequential(
(0): MaskedPixelUnshuffle(downscale_factor=2)
(1): PartialConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(2): MaskedPixelUnshuffle(downscale_factor=2)
(3): PartialConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): MaskedPixelUnshuffle(downscale_factor=2)
(5): PartialConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(6): MaskedPixelUnshuffle(downscale_factor=2)
(7): PartialConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(8): MaskedPixelUnshuffle(downscale_factor=2)
(9): PartialConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(10): MaskedPixelUnshuffle(downscale_factor=2)
(11): PartialConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(12): MaskedPixelUnshuffle(downscale_factor=2)
(13): PartialConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(14): MaskedPixelUnshuffle(downscale_factor=2)
(15): PartialConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
mpconv=MaskedSequential(
(0): MaskedPixelUnshuffle(downscale_factor=2)
(1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(2): MaskedPixelUnshuffle(downscale_factor=2)
(3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(4): MaskedPixelUnshuffle(downscale_factor=2)
(5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(6): MaskedPixelUnshuffle(downscale_factor=2)
(7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(8): MaskedPixelUnshuffle(downscale_factor=2)
(9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(10): MaskedPixelUnshuffle(downscale_factor=2)
(11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(12): MaskedPixelUnshuffle(downscale_factor=2)
(13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
(14): MaskedPixelUnshuffle(downscale_factor=2)
(15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
)
mconv=MaskedSequential(
(0): MaskedPixelUnshuffle(downscale_factor=2)
(1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(2): MaskedPixelUnshuffle(downscale_factor=2)
(3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(4): MaskedPixelUnshuffle(downscale_factor=2)
(5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(6): MaskedPixelUnshuffle(downscale_factor=2)
(7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(8): MaskedPixelUnshuffle(downscale_factor=2)
(9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(10): MaskedPixelUnshuffle(downscale_factor=2)
(11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(12): MaskedPixelUnshuffle(downscale_factor=2)
(13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
(14): MaskedPixelUnshuffle(downscale_factor=2)
(15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
)
list(conv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(pconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(mpconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(mconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
activations_pconv.keys()=dict_keys(['1', '3', '5', '7', '9', '11', '13', '15'])
name='conv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3238), skewness=tensor(0.0080), kurtosis=tensor(3.0212)
name='conv', depth_i=2, mean=tensor(6.5161e-06), std=tensor(0.1855), skewness=tensor(0.0049), kurtosis=tensor(3.0922)
name='conv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1054), skewness=tensor(0.0081), kurtosis=tensor(3.0650)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0589), skewness=tensor(-0.0125), kurtosis=tensor(3.1699)
name='conv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0316), skewness=tensor(-0.0147), kurtosis=tensor(3.2110)
name='pconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276)
name='pconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518)
name='pconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635)
name='pconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829)
name='pconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324)
name='pconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953)
name='mpconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276)
name='mpconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518)
name='mpconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635)
name='mpconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829)
name='mpconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324)
name='mpconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953)
name='mconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3232), skewness=tensor(0.0085), kurtosis=tensor(3.0263)
name='mconv', depth_i=2, mean=tensor(-9.7408e-06), std=tensor(0.1844), skewness=tensor(0.0053), kurtosis=tensor(3.1119)
name='mconv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1039), skewness=tensor(0.0093), kurtosis=tensor(3.1074)
name='mconv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0571), skewness=tensor(-0.0164), kurtosis=tensor(3.2821)
name='mconv', depth_i=5, mean=tensor(0.0006), std=tensor(0.0296), skewness=tensor(-0.0277), kurtosis=tensor(3.3867)
from partialconv.
It looks even worse with a partially occluded mask.
Code:
Details
from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor
class PartialConv2d(nn.Conv2d):
    """2-D partial convolution: a convolution re-weighted by mask coverage.

    Accepts every ``nn.Conv2d`` argument plus two extra keyword options:

    multi_channel: if true, the mask has one channel per input channel;
        otherwise a single-channel ``(1, 1, H, W)`` mask is assumed.
        Defaults to ``False``.
    return_mask: if true, ``forward`` returns ``(output, updated_mask)``
        instead of only ``output``. Defaults to ``False``.
    """

    def __init__(self, *args, **kwargs):
        # Strip the extra options so that nn.Conv2d never sees them.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super().__init__(*args, **kwargs)

        # All-ones kernel used to count the valid mask entries under each window.
        if self.multi_channel:
            mask_kernel = torch.ones(self.out_channels, self.in_channels,
                                     self.kernel_size[0], self.kernel_size[1])
        else:
            mask_kernel = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
        self.register_buffer(name='weight_maskUpdater', persistent=False, tensor=mask_kernel)

        # Number of mask entries covered by one sliding window.
        self.slide_winsize = (self.weight_maskUpdater.shape[1]
                              * self.weight_maskUpdater.shape[2]
                              * self.weight_maskUpdater.shape[3])

        # Input shape for which the cached all-ones-mask results are valid.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        Args:
            input: 4-D input tensor.
            mask_in: optional mask; when omitted an all-ones mask is used.

        Returns:
            ``output`` or, when ``return_mask`` is set, ``(output, update_mask)``.
        """
        assert len(input.shape) == 4

        if mask_in is not None or self.last_size != tuple(input.shape):
            with torch.no_grad():
                if mask_in is None:
                    # No mask provided: build an all-ones mask and remember the
                    # shape so that later mask-less calls can reuse the result.
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.shape[2], input.shape[3],
                                          device=input.device, dtype=input.dtype)
                    self.last_size = tuple(input.shape)
                else:
                    mask = mask_in
                    # The cached values now belong to this explicit mask, so a
                    # following mask-less call with the same shape must not
                    # reuse them: invalidate the shape key.
                    self.last_size = (None, None, None, None)

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None,
                                            stride=self.stride, padding=self.padding,
                                            dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super().forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            # Re-weight only the bias-free part, then add the bias back.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output
class MaskedConv2d(nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
eps=1e-8,
multichannel: bool = False,
partial_conv: bool = False,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
if multichannel:
self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
else:
self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
self.eps = eps
self.multichannel = multichannel
self.partial_conv = partial_conv
def get_mask(
self,
input: torch.Tensor,
mask: torch.Tensor | None
) -> (torch.Tensor, torch.Tensor):
if mask is None:
if self.multichannel:
mask = torch.ones_like(input)
else:
mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
else:
if self.multichannel:
mask = mask.expand_as(input)
else:
mask = mask.expand(1, 1, *input.shape[2:])
return mask
def forward(
self,
input: torch.Tensor,
mask: torch.Tensor | None = None
) -> (torch.Tensor, torch.Tensor | None):
if mask is not None:
input *= mask
mask = self.get_mask(input, mask)
if self.partial_conv:
output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
mask_ratio = mask_kernel_numel / (mask + self.eps)
mask.clamp_(0, 1)
# Apply re-weighting and bias
output *= mask_ratio
if self.bias is not None:
output += self.bias.view(-1, 1, 1)
output *= mask
else:
output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
mask = mask / max_vals
return output, mask
def extra_repr(self):
return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"
class MaskedPixelUnshuffle(nn.PixelUnshuffle):
def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
return super().forward(input), super().forward(mask) if mask is not None else None
class MaskedSequential(nn.Sequential):
def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
for module in self:
input, mask = module(input, mask)
return input, mask
@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    """Temporarily attach a forward hook to every sub-module selected by ``predicate``.

    Args:
        model: module whose ``named_modules()`` are inspected.
        hook: callable invoked as ``hook(module, input, output, name=..., **hook_kwargs)``.
        predicate: ``predicate(name, module)`` decides whether a module is hooked.
        **hook_kwargs: extra keyword arguments bound to every hook call.

    Yields:
        The list of hook handles. All handles are removed on exit, even if
        the body of the ``with`` block raises.
    """
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind to a fresh name: rebinding ``hook`` here would wrap the
                # previous iteration's partial instead of the original hook.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handles.append(module.register_forward_hook(bound_hook))
        yield handles
    finally:
        for handle in handles:
            handle.remove()
def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    """Forward hook that records ``output`` in ``storage`` under ``name``.

    The first output for a name is stored as is. On the second output the
    entry becomes a two-element list; later outputs are appended to it.
    """
    if name not in storage:
        storage[name] = output
        return

    previous = storage[name]
    if isinstance(previous, list):
        previous.append(output)
    else:
        storage[name] = [previous, output]
def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    """Run ``model`` once while recording the outputs of selected sub-modules.

    Args:
        model: module to call.
        predicate: ``predicate(name, module)`` selects which sub-modules are recorded.
        *model_args, **model_kwargs: forwarded to ``model``.

    Returns:
        ``(model_output, recorded)`` where ``recorded`` maps sub-module names
        to the outputs captured by ``activations_recorder_hook``.
    """
    recorded = {}
    hooks = register_hooks(model, activations_recorder_hook, predicate, storage=recorded)
    with hooks:
        result = model(*model_args, **model_kwargs)
    return result, recorded
def test_it():
    """Compare Conv2d, PartialConv2d and MaskedConv2d stacks with a partially occluded mask.

    Builds four stacks of ``depth`` (PixelUnshuffle, 3x3 conv) stages, copies
    the plain-conv weights into the three masked stacks, and runs all of them
    on the same random input with a square region masked out (the plain stack
    receives the pre-masked input). It then checks how the outputs relate,
    prints the mean, std, skewness and kurtosis of the first
    ``visualize_depth`` conv outputs, and plots their channel-means.
    """
    torch.manual_seed(37)
    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 6
    eps = 1e-8

    def build_stack(container_cls, unshuffle_cls, conv_cls, **conv_kwargs):
        # One (unshuffle, conv) pair per stage; channel widths follow the
        # same schedule for every implementation so weights are interchangeable.
        layers = []
        for i in range(depth):
            if i > 0:
                stage_in = scale * base ** (i + 1) * downscale_factor ** 2
            else:
                stage_in = in_channels * downscale_factor ** 2
            layers.append(unshuffle_cls(downscale_factor))
            layers.append(conv_cls(
                in_channels=stage_in,
                out_channels=scale * base ** i * downscale_factor ** 2,
                kernel_size=(3, 3), padding=1, bias=False, **conv_kwargs)
            )
        return container_cls(*layers)

    # Construction order is kept (conv, pconv, mpconv, mconv) so the random
    # number generator is consumed exactly as before.
    conv = build_stack(nn.Sequential, nn.PixelUnshuffle, nn.Conv2d)
    pconv = build_stack(MaskedSequential, MaskedPixelUnshuffle, PartialConv2d,
                        multi_channel=True, return_mask=True)
    mpconv = build_stack(MaskedSequential, MaskedPixelUnshuffle, MaskedConv2d,
                         multichannel=True, partial_conv=True)
    mconv = build_stack(MaskedSequential, MaskedPixelUnshuffle, MaskedConv2d,
                        multichannel=True, partial_conv=False)

    def is_conv_predicate(name: str, module: torch.nn.Module):
        return isinstance(module, torch.nn.Conv2d)

    with torch.no_grad():
        print(f"{conv=}")
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")
        print(f"{list(conv.state_dict().keys())=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        # Give every implementation identical weights.
        pconv.load_state_dict(conv.state_dict())
        mpconv.load_state_dict(conv.state_dict())
        mconv.load_state_dict(conv.state_dict())

        x = torch.randn(1, in_channels, 512, 512)
        # Occlude a 128x128 square of the input.
        x_mask = torch.ones_like(x)
        x_mask[..., 128:256, 128:256] = 0

        y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x * x_mask)
        (y_pconv, _), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, x_mask)
        (y_mpconv, _), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, x_mask)
        (y_mconv, _), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, x_mask)

        assert not torch.allclose(y_conv, y_mpconv)
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # One row per implementation, one column per visualised depth.
        _, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()
        batch_i = 0
        for impl_i, (name, activations) in enumerate([
            ("conv", activations_conv),
            ("pconv", activations_pconv),
            ("mpconv", activations_mpconv),
            ("mconv", activations_mconv)
        ]):
            for depth_i in range(visualize_depth):
                ax = axs[impl_i * visualize_depth + depth_i]
                # Conv layers sit at the odd indices of each stack.
                layer_output = activations[f"{depth_i * 2 + 1}"]
                if isinstance(layer_output, torch.Tensor):
                    output = layer_output[batch_i]
                else:
                    # Masked layers return (output, mask); only the output is analysed.
                    output = layer_output[0][batch_i]
                assert output.dim() == 3
                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")
                ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')
        plt.show()
if __name__ == '__main__':
test_it()
Output (notice the large kurtosis, which means that there are more extreme outliers in the distribution):
name='conv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3134), skewness=tensor(0.0081), kurtosis=tensor(3.2148)
name='conv', depth_i=2, mean=tensor(0.0002), std=tensor(0.1794), skewness=tensor(0.0086), kurtosis=tensor(3.2706)
name='conv', depth_i=3, mean=tensor(6.1037e-06), std=tensor(0.1016), skewness=tensor(0.0055), kurtosis=tensor(3.2192)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0566), skewness=tensor(-0.0155), kurtosis=tensor(3.2757)
name='conv', depth_i=5, mean=tensor(0.0004), std=tensor(0.0301), skewness=tensor(-0.0230), kurtosis=tensor(3.1709)
name='pconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='pconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='pconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='pconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='pconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='pconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mpconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='mpconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='mpconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='mpconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='mpconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='mpconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3124), skewness=tensor(0.0086), kurtosis=tensor(3.2303)
name='mconv', depth_i=2, mean=tensor(0.0001), std=tensor(0.1776), skewness=tensor(0.0087), kurtosis=tensor(3.3192)
name='mconv', depth_i=3, mean=tensor(-1.3955e-05), std=tensor(0.0991), skewness=tensor(0.0069), kurtosis=tensor(3.3181)
name='mconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0537), skewness=tensor(-0.0256), kurtosis=tensor(3.4699)
name='mconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0266), skewness=tensor(-0.0291), kurtosis=tensor(3.3908)
from partialconv.
partialconv is even worse than regular convolution on an object detection task (a DETR-like model trained to minimize a Hungarian loss). Training was performed on batches of images of different sizes, each with its respective mask.
- black: nn.Conv2d
- red: PartialConv2d (this repo)
- blue: MaskedConv2d (my implementation of masked convolution)
from partialconv.
I have received a response from the authors. I will provide further details via email.
from partialconv.
I have released Masked Convolution for Diverse Sample Sizes, so you can now use the fix from this issue under a permissive license: https://github.com/ivanstepanovftw/masked_torch
from partialconv.
Related Issues (20)
- Pretrained Checkpoints
- Demo not working HOT 12
- About train details
- papaer arch partial conv num question HOT 1
- Problem with Pretrained checkpoints
- Some comments about code of PartialConv2d HOT 4
- How to test the code with the different ratios mask? HOT 1
- About mask training dataset HOT 5
- Doesn't take 2 channel mask as input HOT 2
- Online Demo down? HOT 7
- Pytorch export trace/script
- Blurry results and non-recoverable facial features in CelebA-HQ dataset HOT 3
- image inpainting error
- I can't import models in main.py
- About args: multi-channel for image inpainting
- partial con
- Inpainting demo not working HOT 2
- The updating of mask HOT 1
- 2d and 3d implementation differences
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: using data as art.
-
Game
Something interesting about games; make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
Tencent's open source team in China.
from partialconv.