
partialconv's Introduction

Partial Convolution Layer for Padding and Image Inpainting

This is the PyTorch implementation of the partial convolution layer. It can serve as a new padding scheme, and it can also be used for image inpainting.

Partial Convolution based Padding
Guilin Liu, Kevin J. Shih, Ting-Chun Wang, Fitsum A. Reda, Karan Sapra, Zhiding Yu, Andrew Tao, Bryan Catanzaro
NVIDIA Corporation
Technical Report (arXiv preprint arXiv:1811.11718), 2018

Image Inpainting for Irregular Holes Using Partial Convolutions
Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao, Bryan Catanzaro
NVIDIA Corporation
In The European Conference on Computer Vision (ECCV) 2018

Comparison with Zero Padding

Installation

Installation instructions can be found at https://github.com/pytorch/examples/tree/master/imagenet

Usage:

  • Using partial conv for padding:
#typical convolution layer with zero padding
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)

#partial convolution based padding
PartialConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
  • Using partial conv for image inpainting: set both multi_channel and return_mask to True (a forward-pass sketch follows below)
#partial convolution for inpainting (multi-channel mask that is updated and returned)
PartialConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, multi_channel=True, return_mask=True)
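
A minimal forward-pass sketch for the inpainting configuration (assuming this repo's partialconv2d.py is importable; all shapes are illustrative):

import torch
from partialconv2d import PartialConv2d  # assumed import path

pconv = PartialConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False,
                      multi_channel=True, return_mask=True)

image = torch.randn(1, 3, 256, 256)  # input batch
mask = torch.ones(1, 3, 256, 256)    # 1 = valid pixel, 0 = hole
mask[:, :, 96:160, 96:160] = 0       # punch a square hole

# with return_mask=True the layer returns the features and the updated mask
features, updated_mask = pconv(image * mask, mask_in=mask)
print(features.shape, updated_mask.shape)  # both torch.Size([1, 16, 256, 256])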

Mixed Precision Training with AMP for image inpainting

  • Installation: to train with mixed precision support, first install apex from https://github.com/NVIDIA/apex
  • Required change #1 (typical AMP changes): the standard modifications needed for AMP
  from apex import amp
  
  #initializing model and optimizer
  self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=args.amp_opt_level)
  
  #initializing vgg loss function/extractor
  self.vgg_feat_loss = amp.initialize(self.vgg_feat_loss, opt_level=args.amp_opt_level)
  
  #scale loss
  with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
      scaled_loss.backward()

  • Required change #2 (Gram matrix loss): in the Gram matrix loss computation, change the one-step division into two smaller divisions (a full function sketch follows this list)
    input = torch.zeros(b, ch, ch).type(features.type())
    gram = torch.baddbmm(input, features, features_t, beta=0, alpha=1./(ch * h), out=None)
    gram = gram / w
  • Required change #3 (small constant): make the small epsilon constant a bit larger (e.g. from 1e-8 to 1e-6):
    • change from 1e-8: self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
    • to 1e-6: self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-6)
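
Putting change #2 in context, here is a hedged sketch of the full Gram matrix computation using the two-step division (function name and shapes are illustrative rather than the repo's exact code):

import torch

def gram_matrix(features):
    # features: (b, ch, h, w) activations from the VGG feature extractor
    b, ch, h, w = features.size()
    features = features.view(b, ch, h * w)
    features_t = features.transpose(1, 2)
    # fold 1/(ch*h) into alpha and divide by w afterwards; splitting the
    # division keeps the fp16 intermediates in range during AMP training
    input = torch.zeros(b, ch, ch).type(features.type())
    gram = torch.baddbmm(input, features, features_t, beta=0, alpha=1./(ch * h), out=None)
    return gram / w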

Using partial conv based padding to train on ImageNet

  • ResNet50 using zero padding (default padding)
python main.py -a resnet50 --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/
  • ResNet50 using partial conv based padding
python main.py -a pdresnet50 --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/
  • vgg16_bn using zero padding (default padding)
python main.py -a vgg16_bn --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/
  • vgg16_bn using partial conv based padding
python main.py -a pdvgg16_bn --data_train /path/ILSVRC/Data/CLS-LOC/train --data_val /path/ILSVRC/Data/CLS-LOC/perfolder_val --batch-size 192 --workers 32 --prefix multigpu_b192 --ckptdirprefix experiment_1/

Pretrained checkpoints (weights) for VGG and ResNet networks with partial convolution based padding:

https://www.dropbox.com/sh/t6flbuoipyzqid8/AACJ8rtrF6V5b9348aG5PIhia?dl=0

Comparison with Zero Padding, Reflection Padding and Replication Padding over 5 Runs

The best top-1 accuracies for each run with 1-crop testing. *_zero, *_pd, *_ref and *_rep indicate the corresponding model with zero padding, partial convolution based padding, reflection padding and replication padding, respectively. *_best means the best validation score for each training run. Average represents the average accuracy of the 5 runs. The diff column gives the difference from the corresponding network using zero padding. The stdev column gives the standard deviation of the accuracies over the 5 runs. PT_official represents the corresponding official accuracies published on the PyTorch website: https://pytorch.org/docs/stable/torchvision/models.html

Citation

@article{liu2018partialpadding,
   author    = {Guilin Liu and Kevin J. Shih and Ting-Chun Wang and Fitsum A. Reda and Karan Sapra and Zhiding Yu and Andrew Tao and Bryan Catanzaro},
   title     = {Partial Convolution based Padding},
   journal   = {arXiv preprint arXiv:1811.11718},
   year      = {2018},
}
@inproceedings{liu2018partialinpainting,
   author    = {Guilin Liu and Fitsum A. Reda and Kevin J. Shih and Ting-Chun Wang and Andrew Tao and Bryan Catanzaro},
   title     = {Image Inpainting for Irregular Holes Using Partial Convolutions},
   booktitle = {The European Conference on Computer Vision (ECCV)},   
   year      = {2018},
}

Contact: Guilin Liu ([email protected])

Acknowledgments

We thank Jinwei Gu, Matthieu Le, Andrzej Sulecki, Marek Kolodziej and Hongfu Liu for helpful discussions.

partialconv's People

Contributors

bryancatanzaro, fitsumreda, liuguilin1225, ruthcfong


partialconv's Issues

Pretrained Checkpoints

First, congratulations to the PartialConv team; it is a fascinating contribution in many ways.

Also, thank you for making pdresnet checkpoints available, and for the insightful comparison with the other padding modes!
Would it also be possible to have pre-trained weights for REFLECT and REPLICATE?
Thank you!

Is there any implementation of 3D partial convolution in Keras or TF?

I'm trying to build a model for video inpainting rather than image inpainting. While researching inpainting models, I found that partial convolution works well, so I would like to build a 3D partial convolution model, but there is no standard or well-tested code that I can follow.

Does your team have plans for, or ongoing work on, this? And is there any reference code (not papers) that I could read?

References:
https://github.com/MathiasGruber/PConv-Keras
https://arxiv.org/pdf/1804.07723.pdf

About mask training dataset

Thank you for your great project!
I'm sorry, but I couldn't find the website for the mask training dataset mentioned in the paper. Could you provide the link? I would appreciate it!
Looking forward to your reply.

Blurry results and non-recoverable facial features in CelebA-HQ dataset

Hi. Has anyone tried to reproduce the inpainting results on the CelebA-HQ dataset? The model was trained on a randomly selected 27k-image subset of CelebA-HQ for 3 days on a GPU and fine-tuned for half a day using batch size 6. Here are some outputs I get, where the inputs come from my screenshot of the paper:
[image: my results]

For comparison, here are the results shown in the paper (Fig 8):
[image: sample results from the paper]

With my code, the inpainted regions are blurry when the masked areas are large, and the masked facial features are not recovered. I wonder if anyone has encountered similar issues or has any guess as to the reason. Thanks!

Hello, Chief Scientist. Would you be willing to publicize the image inpainting model?

I am a postgraduate student in China. I have found a new partial-conv-based (mask-based) convolution method for a special inpainting application, and I have been re-implementing your inpainting model for a long time, but it does not produce results as good as yours. Could you help by releasing the official model code, or send it to me? I am in a hurry to publish my first paper, which modifies your method. Thank you very much.

image inpainting error

I followed your usage guide for image inpainting.

PartialConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, multi_channel=True, return_mask=True)

and started to run main.py,
but after the partial convolution, the code fails with this message:

AttributeError: 'tuple' object has no attribute 'dim'

How can I solve this error?
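
For reference, with return_mask=True the forward pass returns an (output, updated_mask) tuple, so it must be unpacked before the result reaches a module that expects a plain tensor; a hedged sketch (import path assumed):

import torch
from partialconv2d import PartialConv2d  # assumed import path

pconv = PartialConv2d(3, 16, kernel_size=3, padding=1, bias=False,
                      multi_channel=True, return_mask=True)
x = torch.randn(1, 3, 32, 32)
mask = torch.ones(1, 3, 32, 32)

out, mask = pconv(x, mask_in=mask)  # unpack the tuple here
out = torch.relu(out)               # downstream modules then receive a plain tensor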

Pytorch export trace/script

Hello. Has anyone tried to save the model with partial convolutions as a PyTorch script? I got this error when calling torch.jit.script:

Traceback (most recent call last):
  File "export.py", line 47, in <module>
    net_script = torch.jit.script(model)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_script.py", line 943, in script
    obj, torch.jit._recursive.infer_methods_to_compile
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
cannot statically infer the expected size of a list in this context:
  File "/home/user/Projects/gitlab/2d22d-poc/pconv.py", line 56
    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4, "Input shape for partial convolution must have 4 dimensions"
        if mask_in is not None or self.last_size != tuple(input.shape):
                                                    ~~~~~~~~~~~~~~~~~ <--- HERE
            self.last_size = tuple(input.shape)

Doesn't take a 2-channel mask as input

Hi,

In your inpainting paper, you mention that in the decoder part of the U-Net, you add skip connections for both features and masks by concatenating them. I assume you concatenate them along the feature channel dimension. But if I give this concatenated mask as input to the PConv layer, it throws the following error, so it looks like I can't pass the concatenated mask to this layer directly. How should I concatenate the masks?

RuntimeError: Given groups=1, weight of size [1, 1, 3, 3], expected input[8, 2, 32, 32] to have 1 channels, but got 2 channels instead
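
One hedged sketch of keeping the mask aligned with concatenated skip-connection features: concatenate the masks along the channel dimension exactly like the features, and build the next layer with multi_channel=True so its mask updater has a matching channel count (names and shapes here are illustrative):

import torch

# decoder features/mask and encoder features/mask joined by the skip connection
dec_feat, enc_feat = torch.randn(8, 1, 32, 32), torch.randn(8, 1, 32, 32)
dec_mask, enc_mask = torch.ones(8, 1, 32, 32), torch.ones(8, 1, 32, 32)

feat = torch.cat([dec_feat, enc_feat], dim=1)  # (8, 2, 32, 32)
mask = torch.cat([dec_mask, enc_mask], dim=1)  # (8, 2, 32, 32)

# a PartialConv2d built with in_channels=2 and multi_channel=True accepts this
# pair; with multi_channel=False the mask updater is a (1, 1, 3, 3) kernel,
# which is exactly what produces the channel-mismatch error above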

Inpainting demo not working

The inpainted result image is not shown; a gateway error occurs. The problem occurs in several different browsers.

The updating of mask

Thank you for your great contributions! May I ask about the updating of the mask in partial convolution? The mask and weight_maskUpdater are initialized with torch.ones, so the resulting update_mask always comes out as an all-ones tensor. Thank you for your help!
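
A small runnable check (a sketch, not repo code) shows that update_mask comes out as all ones only because the input mask has no holes; with a hole, it marks which output positions saw at least one valid pixel:

import torch
import torch.nn.functional as F

mask = torch.ones(1, 1, 5, 5)
mask[:, :, 1:4, 1:4] = 0               # punch a 3x3 hole
weight_maskUpdater = torch.ones(1, 1, 3, 3)

update_mask = F.conv2d(mask, weight_maskUpdater, padding=1)
print(torch.clamp(update_mask, 0, 1))  # 0 at the hole center, 1 elsewhere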

2d and 3d implementation differences

While comparing the 2D and 3D implementations line by line, I noticed that the 2D version uses mask, while the 3D implementation uses mask_in.
Here is the line-by-line comparison:

# 2d
raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
# 3d
raw_out = super(PartialConv3d, self).forward(torch.mul(input, mask_in) if mask_in is not None else input)

Is there a bug? Which one is correct?

I can't import models in main.py

When I import modules, I get the following error:
import models as models_partial # partial conv based padding

ModuleNotFoundError: No module named 'models'

partial conv

Some questions about class PartialConv2d(nn.Conv2d):
First, what is the purpose of self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)? If update_mask is what gets passed to the next layer, why multiply it into mask_ratio?
Second, what is the purpose of output = torch.mul(output, self.update_mask)? Is the next feature layer then the feature with the mask applied?
Lastly, there is no explicit computation of the filter weight matrix with X (WX); since the class is declared as class PartialConv2d(nn.Conv2d), does the code first jump to nn.Conv2d?
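
For reference, the inpainting paper defines the operation as follows (paraphrased), which answers all three questions: the inherited nn.Conv2d.forward computes the raw convolution term, mask_ratio supplies the sum(1)/sum(M) re-weighting (multiplying it by the clamped update_mask implements the zero branch), and the final torch.mul(output, self.update_mask) keeps outputs at still-invalid positions at zero:

x' = \begin{cases} W^\top (X \odot M)\,\frac{\mathrm{sum}(\mathbf{1})}{\mathrm{sum}(M)} + b, & \text{if } \mathrm{sum}(M) > 0 \\ 0, & \text{otherwise,} \end{cases}
\qquad
m' = \begin{cases} 1, & \text{if } \mathrm{sum}(M) > 0 \\ 0, & \text{otherwise.} \end{cases}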

Map at edges is peaking (PartialConv2d implementation + fix)

I have implemented partial conv and stumbled on the problem that layer activations peak at the edges, even though Figure 5 of the "Partial Convolution based Padding" paper explicitly says that "Red rectangles show the strong activation regions from VGG19 network with zero padding":
[image: Figure 5 from the paper]


I started to double-check my implementation, and it turned out to be similar to the one in this repo. Then I started to think about why this happens. After trial and error I came up with a simple solution: convolve the mask with mask_weight, then normalize the mask by dividing it by its maximum value.

Here is code so you can double-check your implementation, my implementation, and the fix yourself:

Code

from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(self.out_channels, self.in_channels,
                                                   self.kernel_size[0], self.kernel_size[1]))
        else:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> (torch.Tensor, torch.Tensor):
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> (torch.Tensor, torch.Tensor | None):
        if mask is not None:
            input *= mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                hook: Callable = partial(hook, name=name, **hook_kwargs)
                handle = module.register_forward_hook(hook)
                handles.append(handle)
        yield handles
    finally:
        for handle in handles:
            handle.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    if name in storage:
        if isinstance(storage[name], list):
            storage[name].append(output)
        else:
            storage[name] = [storage[name], output]
    else:
        storage[name] = output


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    storage = {}
    with register_hooks(model, activations_recorder_hook, predicate, storage=storage):
        output = model(*model_args, **model_kwargs)
    return output, storage


def test_it():
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 4
    eps = 1e-8

    pconv = []
    for i in range(depth):
        pconv.append(MaskedPixelUnshuffle(downscale_factor))
        pconv.append(PartialConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True)
        )
    pconv = MaskedSequential(*pconv)

    mpconv = []
    for i in range(depth):
        mpconv.append(MaskedPixelUnshuffle(downscale_factor))
        mpconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True)
        )
    mpconv = MaskedSequential(*mpconv)

    mconv = []
    for i in range(depth):
        mconv.append(MaskedPixelUnshuffle(downscale_factor))
        mconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False)
        )
    mconv = MaskedSequential(*mconv)

    with torch.no_grad():
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        mpconv.load_state_dict(pconv.state_dict())
        mconv.load_state_dict(pconv.state_dict())

        x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth)
        mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # fig, axs = plt.subplots(nrows=visualize_depth, ncols=3, figsize=(12, 8), dpi=180)
        fig, axs = plt.subplots(nrows=3, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        for impl_i, (name, y, mask, activations) in enumerate([
            ("pconv", y_pconv, mask_pconv, activations_pconv),
            ("mpconv", y_mpconv, mask_mpconv, activations_mpconv),
            ("mconv", y_mconv, mask_mconv, activations_mconv)
        ]):
            batch_i = 0
            for depth_i in range(visualize_depth):
                # ax = axs[depth_i * 3 + impl_i]
                ax = axs[impl_i * visualize_depth + depth_i]

                output = activations[f"{depth_i * 2 + 1}"][0][batch_i]
                mask_output = activations[f"{depth_i * 2 + 1}"][1][batch_i]

                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        # plt.suptitle(f"Depth {depth_i}")
        plt.show()


if __name__ == '__main__':
    test_it()

Output:

name='pconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='pconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='pconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='pconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mpconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='mpconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='mpconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='mpconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mconv', depth_i=0, mean=tensor(-0.0039), std=tensor(0.5769), skewness=tensor(0.0052), kurtosis=tensor(3.0468)
name='mconv', depth_i=1, mean=tensor(-0.0016), std=tensor(0.3209), skewness=tensor(-0.0099), kurtosis=tensor(3.0444)
name='mconv', depth_i=2, mean=tensor(-0.0003), std=tensor(0.1796), skewness=tensor(-0.0102), kurtosis=tensor(3.1047)
name='mconv', depth_i=3, mean=tensor(-0.0011), std=tensor(0.0973), skewness=tensor(-0.0421), kurtosis=tensor(3.3349)

[image: visualized activations for pconv, mpconv and mconv]

  • pconv is the original implementation of partial conv (this repo)
  • mpconv is my implementation of partial conv
  • mconv is my masked-convolution approach

Here are also the activations on real images:


  • PartialConv2d:
    [image: activations on a real image]

  • MaskedConv2d:
    [image: activations on a real image]

Questions about loss function modification

Hi, I want to use and modify your code for image inpainting. I have a dataset with thousands of pictures. Here are some questions; I hope someone can help me: @liuguilin1225
1. If I want to fill a hole in image A with more content from image A itself and less imaginary content from other images, what should I do? Is it right to modify the total loss function? If so, which part of the loss should be modified?
2. If I modify the loss function, will training on my own dataset work with both the pretrained model provided in this repo and the modified loss function?

In-place change of Conv2d

Hello! First off thanks for your paper and code, it's great insight.

I wanted to try using PartialConv2d to replace nn.Conv2d for an upscaling task, to test whether that would help reduce the artifacts that appear at the image edges (which I guess are caused by the padding), but when I tried it, the first loss of the network became extremely large and on the second iteration the loss tensor was just NaN.

Do I have to take anything else into account when replacing Conv2d with partial convolution based padding?

Cheers!

About the comparison experiment with GL

Hi,
Thanks a lot for your brilliant work on inpainting with partial conv. It helped me a lot. :)

I'm a bit puzzled by your comparison experiment with Globally and Locally Consistent Image Completion (GL). According to the paper, you use the pretrained GL model, which was trained on the Places2 dataset, and you compare the results on both Places2 and ImageNet (Figures 5 & 6).
I just wonder why you decided to test it on ImageNet. Do you mean that the images in Places2 and ImageNet are similar enough that the latter can be used to validate models trained on the former?

Looking forward to your reply. Thanks again!

Are ordinary convolution (with bias=0 and no BN) and partial convolution the same?

Thank you for your ingenious work. However, there is a small problem that puzzles me.
For the image inpainting task, when the pixel values in the missing area of the input image are set to 0,
aren't ordinary convolution (bias=False, no BN) and partial convolution the same? Under this circumstance, the result of the ordinary convolution is naturally 0 in the missing area (if all pixels in the sliding window are 0), and the gradients in these missing areas are also 0, just as for partial convolution. I'm confused by this small question. Can you tell me the essential difference between them? Thanks!
@liuguilin1225
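
A small numeric check (a hedged sketch of the re-weighting, not the repo's code) shows the difference: deep inside the hole both outputs are indeed zero, but wherever the window straddles valid and missing (or padded) pixels, the partial convolution re-scales the result by sum(1)/sum(M), so outputs and gradients differ exactly at those boundary positions:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 5, 5)
mask = torch.ones(1, 1, 5, 5)
mask[:, :, :, 2:] = 0                      # right part of the image is missing
w = torch.randn(1, 1, 3, 3)

plain = F.conv2d(x * mask, w, padding=1)   # ordinary conv on the zeroed input

ones = torch.ones_like(w)
valid = F.conv2d(mask, ones, padding=1)    # valid pixels under each window
ratio = ones.numel() / valid.clamp(min=1)  # sum(1)/sum(M), guarded against /0
partial = plain * ratio * valid.clamp(0, 1)

print((plain - partial).abs()[0, 0])       # zero where windows are fully valid or
                                           # fully empty; non-zero at the boundaries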

About training details

Hello, I can't find the training details in this project. To train the inpainting model, is the input a batch of images with a single mask or a batch of masks? Could this difference mislead the final result? Thanks!

Some comments about code of PartialConv2d

Hi,
I really like your work - thank you for sharing the code.

I have some remarks on the code of PartialConv2d:

1. Using buffers instead of simple "member tensors":

The class attribute self.weight_maskUpdater is defined as a "plain" tensor:

if self.multi_channel:
    self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
else:
    self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

As a result, when the model is transferred to the GPU, or the data type changes, you need to explicitly check for it and convert the tensor:

if self.weight_maskUpdater.type() != input.type():
    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

A more elegant way is to use non-persistent buffers:

        if self.multi_channel:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(self.out_channels, self.in_channels,
                                                   self.kernel_size[0], self.kernel_size[1]))
        else:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))

This way self.weight_maskUpdater will be affected by any .to(...) / .cuda() / .cpu() invoked on the model hosting this layer, making the condition on line 49 redundant.

2. Use of torch.ones_like

Instead of torch.ones(...).to(...)

if self.multi_channel:
    mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
else:
    mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)

You can use torch.ones_like which is simpler and easier to read:

                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones_like(input[:1, :1, ...])

3. No need to import Variable

You import Variable, but never use it.

from torch.autograd import Variable

By the way, all of these comments also apply to the code of PartialConv3d.

Generalization for ResNeXt

Hello!
Great results! How can this layer be generalized for a ResNeXt network, where the conv layers inside the Bottleneck block have groups != 1?
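
One hedged way to generalize (an assumption based on the layer's structure, not an official extension): give the mask updater the same grouped shape as the feature weights and run the mask convolution with the same groups argument, as the MaskedConv2d posted in an issue earlier on this page also does:

import torch
import torch.nn.functional as F

# grouped mask updater mirroring a grouped feature convolution
out_channels, in_channels, groups, k = 8, 8, 4, 3
mask_weight = torch.ones(out_channels, in_channels // groups, k, k)

mask = torch.ones(1, in_channels, 16, 16)
update_mask = F.conv2d(mask, mask_weight, padding=1, groups=groups)
mask_ratio = mask_weight[0].numel() / (update_mask + 1e-8)  # per-group sum(1)/sum(M)
update_mask = torch.clamp(update_mask, 0, 1)
print(update_mask.shape)  # torch.Size([1, 8, 16, 16])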

Very sparse pixels

Hello,

I was trying the web-based version of your toolbox, but without success (https://www.nvidia.com/research/inpainting/selection).

I have a very sparse pixel image that I want to inpaint (see the attached image).
Basically I want to inpaint within the convex hull; there I draw a mask and use scipy inpainting, with good results. When I try this toolbox, it returns the same image as the input.
I understand that the missing pixels far outnumber the actual information, but I wonder if I can tweak some parameters to make it work.

Can you share a hint on where I should start looking?
Thanks

[attached image: out]

About model input

Hello, I have a question: is the model input = image * mask, or input = image * mask + (1 - mask)? Is the hole white or black in the image?

Question about the number of partial conv layers in the paper's architecture

Hi,
I have a question about the architecture design. Section 3.2 of the paper says:

replacing all convolutional layers with partial convolutional layers and using nearest neighbor up-sampling in the decoding stage

The input comes with an inpainting mask, so partial conv is clearly necessary there.
But for the intermediate feature maps, would ordinary conv also be okay?

About the ImageNet pretrained checkpoints.

Really simple and interesting work.

I am wondering when it would be convenient for you to share the ImageNet pretrained models.

I hope to try your pd_resnet on scene parsing tasks.

About the multi_channel arg for image inpainting

Hello! In partialconv2d.py:
if 'multi_channel' in kwargs:
    self.multi_channel = kwargs['multi_channel']
    kwargs.pop('multi_channel')
else:
    self.multi_channel = False
Why should we set the arg multi_channel to True if we want to use PConv for image inpainting? As far as I can tell, leaving it False makes no difference to the result when used for inpainting, so setting it to True seems like a waste.

Please correct me if I am wrong.
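
A quick check suggests the flag does matter once holes differ across channels (a hedged sketch, assuming this repo's layer is importable): with multi_channel=False the updater is a single (1, 1, k, k) kernel sharing one mask across all channels, while multi_channel=True tracks validity per channel:

import torch
from partialconv2d import PartialConv2d  # assumed import path

x = torch.randn(1, 3, 8, 8)
mask = torch.ones(1, 3, 8, 8)
mask[:, 0, :4, :] = 0                   # hole only in channel 0

single = PartialConv2d(3, 3, 3, padding=1, return_mask=True)
multi = PartialConv2d(3, 3, 3, padding=1, multi_channel=True, return_mask=True)

_, m1 = single(x, mask_in=mask[:, :1])  # one mask shared by every channel
_, m2 = multi(x, mask_in=mask)          # per-channel masks
print(m1.shape, m2.shape)               # torch.Size([1, 1, 8, 8]) torch.Size([1, 3, 8, 8])

For masks that are identical across all channels, the two settings should indeed behave the same up to the shape of the returned mask.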

Online Demo down?

Is the online demo down for good, or will it be coming back online?

TypeError: mul() received an invalid combination of arguments - got (Tensor, NoneType), but expected one of:

I was impressed by your efficient padding method, so I used it in my code.
But I have a problem.

error :
raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask)
TypeError: mul() received an invalid combination of arguments - got (Tensor, NoneType), but expected one of:

  • (Tensor input, Tensor other, Tensor out)
  • (Tensor input, float other, Tensor out)

The first iteration does not produce an error, but on the second iteration this error occurs.
I think the error is due to the mask value being None.
As a workaround, I either set self.last_size = (None, None) on line 44, or kept the mask in a self variable.

If you have another way to fix this error, could you recommend it?

Thank you for opening the code.

Pretraining

You mentioned in a previous issue that we can load a pretrained model and convert Conv2d to PartialConv2d. How would you do the conversion, given that the model structure is fixed in pretrained models? My model is:

class convNet(torch.nn.Module):      # pretrained, only taking up to the last 3 children of resnet18
    #constructor
    def __init__(self):
        super(convNet, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.mycut=torch.nn.Sequential(*list(resnet.children())[:-3])
        self.max_pool=torch.nn.MaxPool2d(2)
        self.conv1 = PartialConv2d(256,  256, kernel_size=3, stride=1, padding=1)
        self.conv2 = PartialConv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.Flattern=Lambda(flat)
        
        self.fc1=torch.nn.Linear(25600,120)
        self.fc2=torch.nn.Linear(120,50)
        self.fc3=torch.nn.Linear(50,8)
        
    def forward(self, x):
        x=self.mycut(x)
        x=F.relu(self.conv1(x))
        x=F.relu(self.conv2(x))
        
        x=self.Flattern(x)

        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=F.relu(self.fc3(x))
        return x

Using this makes the model train much slower than with normal Conv2d layers, with fine-tuning both off and on. How can I change all the layers to PartialConv2d in the model above?
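
One hedged way to swap the layers in a fixed pretrained model is to walk the module tree and replace every nn.Conv2d in place while copying its weights (a sketch; it relies on PartialConv2d accepting the same constructor arguments, which holds since it subclasses nn.Conv2d):

import torch.nn as nn
from partialconv2d import PartialConv2d  # assumed import path

def convert_to_partial(module: nn.Module) -> None:
    # recursively replace every nn.Conv2d with a weight-preserving PartialConv2d
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and not isinstance(child, PartialConv2d):
            pconv = PartialConv2d(child.in_channels, child.out_channels,
                                  kernel_size=child.kernel_size, stride=child.stride,
                                  padding=child.padding, dilation=child.dilation,
                                  groups=child.groups, bias=child.bias is not None)
            pconv.load_state_dict(child.state_dict())
            setattr(module, name, pconv)
        else:
            convert_to_partial(child)

# e.g. convert_to_partial(net.mycut) to convert the resnet18 backbone above

Some slowdown is expected in any case: each partial conv layer runs an extra convolution over the mask to compute the re-weighting.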

Also, are there any plans to release a PartialConv3d? I would be interested in it for medical segmentation.

Thanks,
Rahul
