robust-unsupervised's Introduction

Robust Unsupervised StyleGAN Image Restoration

Code for the paper Robust Unsupervised StyleGAN Image Restoration presented at CVPR 2023.

Installation

  1. First install the same environment as https://github.com/NVlabs/stylegan2-ada-pytorch.git. The custom CUDA kernels do not have to compile correctly; they only make things run ~30% faster.

  2. Run pip install tyro. For running the evaluation you will also need to pip install torchmetrics git+https://github.com/jwblangley/pytorch-fid.git.

  3. Download the pretrained StyleGAN model:

wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl -O pretrained/ffhq.pkl

Restoring images

To run the tasks presented in the paper, use:

python run.py --dataset_path datasets/samples

Some sample images have already been provided in datasets/samples.

Other datasets

First, download a pretrained StyleGAN2 generator for your dataset (.pkl), and pass its path to the --pkl_path option. If the resolution of your data is different from 1024, you also need to set it using the --resolution option. This resolution does not need to match the pretrained generator's resolution; for best results, pick a high-resolution generator even if your images are smaller.

Finally, on datasets other than faces you may need to scale all learning rates up or down by a constant amount to compensate for the different scale of the latent space. For this you can use the CLI option --global_lr_scale.
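As a rough illustration of what a global scale does, the sketch below multiplies every phase's base learning rate by one constant and lists a few log-spaced scales one might try. Only the W++ base rate of 0.005 is taken from the traceback quoted in the issue further down this page (`config.global_lr_scale * 0.005`); the W and W+ base rates, the helper names, and the search range are hypothetical placeholders, not values from run.py.

```python
# Base learning rate per optimization phase.
# Only the W++ value (0.005) appears on this page; "W" and "W+" are made-up placeholders.
BASE_LRS = {"W": 0.1, "W+": 0.01, "W++": 0.005}


def scaled_lrs(global_lr_scale, base_lrs=BASE_LRS):
    """Multiply every phase's base learning rate by the same constant."""
    return {phase: global_lr_scale * lr for phase, lr in base_lrs.items()}


def candidate_scales(low=0.1, high=10.0, steps=5):
    """Return log-spaced scales between low and high (inclusive) to try in turn."""
    ratio = (high / low) ** (1 / (steps - 1))
    return [low * ratio ** i for i in range(steps)]


for scale in candidate_scales():
    print(f"--global_lr_scale {scale:g} ->", scaled_lrs(scale))
```

A log-spaced grid is used here because the text only says the rates may need to move "up or down by a constant amount"; the right range for a given dataset is not stated and has to be found empirically.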

Restoring your own degradations

Use the option --tasks custom, then find the following code in run.py and update it with your degradation function:

class YourDegradation:
    def degrade_ground_truth(self, x):
        """
        The true degradation you are attempting to invert.
        This assumes you are testing against clean ground truth images.
        """
        raise NotImplementedError
    
    def degrade_prediction(self, x):
        """
        Differentiable approximation to the degradation in question. 
        Can be identical to the true degradation if it is invertible.
        """
        raise NotImplementedError

If you do not have access to ground truth images, you can open degraded images directly and make degrade_ground_truth an identity function.
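For the case without ground truth, a minimal sketch could look like the following. It keeps the two method names from the class above; the constructor argument `approx_degradation` is a hypothetical addition for illustration, and run.py may instantiate the class differently (for example, without arguments).

```python
from typing import Any, Callable


class YourDegradation:
    """Sketch for inputs that are already degraded (no clean ground truth available)."""

    def __init__(self, approx_degradation: Callable[[Any], Any]):
        # Hypothetical constructor argument: a differentiable approximation
        # of the degradation that produced the input images.
        self.approx_degradation = approx_degradation

    def degrade_ground_truth(self, x):
        # The loaded images are already degraded, so return them unchanged.
        return x

    def degrade_prediction(self, x):
        # Apply the differentiable approximation to the predicted image.
        return self.approx_degradation(x)
```

With PyTorch tensors one could, for instance, pass `lambda x: torch.nn.functional.avg_pool2d(x, 4)` to model 4x downsampling; that particular choice is only an example.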

Evaluation

Coming soon.

robust-unsupervised's People

Contributors

mlomnitz, yohan-pg

robust-unsupervised's Issues

Fail to run in styleGAN3

Hi, thanks for sharing such a great repo for solving image degradation. I am currently trying to port the code from StyleGAN2-ADA to StyleGAN3. I only modified some imports, since StyleGAN3 uses different names and paths, and my understanding is that this image restoration method mainly performs image generation (please correct me if I have misunderstood). My changes are pasted below:

'''
File: robust_unsupervised/prelude.py
'''

from typing import *

import copy
import os

import pickle

import functools
import sys
import torch.optim as optim
import tqdm
import dataclasses
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math

# >>>>>>>> Modified
import dnnlib
import legacy
# <<<<<<<<

import shutil
from functools import partial
import itertools
import warnings
from warnings import warn
import datetime
import torchvision.transforms.functional as TF
from torchvision.utils import save_image, make_grid

# >>>>>>>> Modified
# import training.networks as networks
import training.networks_stylegan3 as networks
# <<<<<<<<

from abc import ABC, abstractmethod, abstractstaticmethod, abstractclassmethod
from dataclasses import dataclass, field

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

warnings.filterwarnings("ignore", r"Named tensors and all their associated APIs.*")
warnings.filterwarnings("ignore", r"Arguments other than a weight enum.*")
warnings.filterwarnings("ignore", r"The parameter 'pretrained' is deprecated.*")
'''
File: robust_unsupervised/io_utils.py
'''
from robust_unsupervised.prelude import *
from robust_unsupervised.variables import *

import shutil
import torch_utils as torch_utils
import torch_utils.misc as misc
import contextlib

import PIL.Image as Image

def open_generator(pkl_path: str, refresh=True, float=True, ema=True) -> networks.Generator:
    print(f"Loading generator from {pkl_path}...")

    # >>>>>>>> Modified
    # with dnnlib.util.open_url(pkl_path) as fp:
    #     G = legacy.load_network_pkl(fp)["G_ema" if ema else "G"].cuda().eval()
    #     if float:
    #         G = G.float()

    with open(pkl_path, 'rb') as f:
        G = pickle.load(f)['G_ema'].cuda()
        if float:
            G = G.float()
    # <<<<<<<<

    if refresh:
        with torch.no_grad():
            old_G = G
            G = networks.Generator(*old_G.init_args, **old_G.init_kwargs).cuda()
            misc.copy_params_and_buffers(old_G, G, require_all=True)
            for param in G.parameters():
                param.requires_grad = False

    return G
import tyro
from dataclasses import dataclass
from typing import *

import sys
# >>>>>>>> Modified
sys.path.append("stylegan3")
# <<<<<<<<

@dataclass
class Config:
    name: str = f"restored_samples"
    "A name used to group log files."

    pkl_path: str = "pretrained_networks/stylegan3-r-ffhq-1024x1024.pkl"
    "The location of the pretrained StyleGAN."

    dataset_path: str = "datasets/samples"
    "The location of the images to process."
    
    resolution: int = 1024
    "The resolution of your images. Images which are smaller or larger will be resized."

    global_lr_scale: float = 1.0
    "A global factor which scales up and down all learning rates. This may need adjustment for datasets other than faces."

    tasks: Literal["all", "single", "composed", "custom"] = "all"
    "Selects which tasks to run."

def parse_config() -> Config:
    return tyro.cli(Config)

So, I downloaded a new pre-trained model from StyleGAN3 named stylegan3-r-ffhq-1024x1024.pkl. I kept using the sample images that this repo provides (sample_1.png and sample_2.png), and all other code remains the same.
The program runs the W and W+ (Wp) phases successfully, but terminates with an error on entering the W++ (Wpp) phase. The error message is copied below:

>$ python run.py --dataset_path datasets/samples
restored_samples
Loading generator from pretrained_networks/stylegan3-r-ffhq-1024x1024.pkl...
out/restored_samples/2024-05-04T211319/single_tasks/inpainting/XL/
/project/robust-unsupervised/out/restored_samples/2024-05-04T211319/single_tasks/inpainting/XL/datasets/samples
- 0000
W:   0%|     | 0/150 [00:00<?, ?it/s]
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
W: 100%|███████| 150/150 [00:27<00:00,  5.48it/s]
W+: 100%|██████| 150/150 [00:25<00:00,  5.79it/s]
W++:   0%|     | 0/150 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/project/robust-unsupervised/run.py", line 154, in <module>
    run_phase("W++", Wpp_variable, config.global_lr_scale * 0.005)
  File "/project/robust-unsupervised/run.py", line 24, in run_phase
    x = variable.to_image()
  File "/project/robust-unsupervised/robust_unsupervised/variables.py", line 29, in to_image
    return self.render_image(self.to_input_tensor())
  File "/project/robust-unsupervised/robust_unsupervised/variables.py", line 35, in render_image
    return (self.G.synthesis(ws, noise_mode="const", force_fp32=True) + 1.0) / 2.0
  File "/project/miniconda3/envs/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/project/miniconda3/envs/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/project/robust-unsupervised/stylegan3/training/networks_stylegan3.py", line 465, in forward
    misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
  File "/project/robust-unsupervised/stylegan3/torch_utils/misc.py", line 95, in assert_shape
    raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
AssertionError: Wrong size for dimension 1: got 8192, expected 16

Could you share any suggestions on solving this issue? Thanks.
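
One observation that may help narrow this down: 8192 = 16 * 512, so the reported size equals num_ws * w_dim for this generator if w_dim is 512 (the error message confirms only num_ws = 16; w_dim = 512 is an assumption). The sketch below is a simplified, pure-Python stand-in for the shape check that raised the error, not the code in torch_utils/misc.py, and the example shapes are assumptions chosen to reproduce the message.

```python
def check_shape(shape, ref_shape):
    """Simplified stand-in for a shape assertion; None in ref_shape matches any size."""
    if len(shape) != len(ref_shape):
        raise AssertionError(
            f"Wrong number of dimensions: got {len(shape)}, expected {len(ref_shape)}"
        )
    for idx, (size, ref_size) in enumerate(zip(shape, ref_shape)):
        if ref_size is not None and size != ref_size:
            raise AssertionError(
                f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
            )


NUM_WS, W_DIM = 16, 512  # num_ws comes from the error message; w_dim = 512 is assumed

# A latent with one vector per layer passes the check.
check_shape((1, NUM_WS, W_DIM), [None, NUM_WS, W_DIM])

# A hypothetical latent with NUM_WS * W_DIM vectors fails in dimension 1.
try:
    check_shape((1, NUM_WS * W_DIM, W_DIM), [None, NUM_WS, W_DIM])
except AssertionError as err:
    message = str(err)
    print(message)
```

If the W++ variable really does carry more than num_ws latent vectors, an unmodified StyleGAN3 synthesis network would reject it while accepting the W and W+ tensors, which would match the log above. This is a guess from the numbers, not a confirmed diagnosis.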


My venv (TL;DR: this environment can run both StyleGAN2-ADA and StyleGAN3):

python: 3.10
cuda: 12.3
torch: 2.3.0
torchvision: 0.18.0
OS: Linux
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - binutils=2.40=h4852527_0
  - binutils_impl_linux-64=2.40=ha885e6a_0
  - binutils_linux-64=2.40=hdade7a5_3
  - blas=1.0=mkl
  - brotli-python=1.1.0=py310hc6cd4ac_1
  - bzip2=1.0.8=hd590300_5
  - c-compiler=1.7.0=hd590300_0
  - ca-certificates=2024.3.11=h06a4308_0
  - certifi=2024.2.2=pyhd8ed1ab_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - cuda=12.1.0=0
  - cuda-cccl=12.4.127=0
  - cuda-command-line-tools=12.1.1=0
  - cuda-compiler=12.4.1=0
  - cuda-cudart=12.1.105=0
  - cuda-cudart-dev=12.1.105=0
  - cuda-cudart-static=12.1.105=0
  - cuda-cuobjdump=12.4.127=0
  - cuda-cupti=12.1.105=0
  - cuda-cupti-static=12.1.105=0
  - cuda-cuxxfilt=12.4.127=0
  - cuda-demo-suite=12.4.127=0
  - cuda-documentation=12.4.127=0
  - cuda-driver-dev=12.4.127=0
  - cuda-gdb=12.4.127=0
  - cuda-libraries=12.1.0=0
  - cuda-libraries-dev=12.1.0=0
  - cuda-libraries-static=12.1.0=0
  - cuda-nsight=12.4.127=0
  - cuda-nsight-compute=12.4.1=0
  - cuda-nvcc=12.4.131=0
  - cuda-nvdisasm=12.4.127=0
  - cuda-nvml-dev=12.4.127=0
  - cuda-nvprof=12.4.127=0
  - cuda-nvprune=12.4.127=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvrtc-dev=12.1.105=0
  - cuda-nvrtc-static=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-nvvp=12.4.127=0
  - cuda-opencl=12.4.127=0
  - cuda-opencl-dev=12.4.127=0
  - cuda-profiler-api=12.4.127=0
  - cuda-runtime=12.1.0=0
  - cuda-sanitizer-api=12.4.127=0
  - cuda-toolkit=12.1.0=0
  - cuda-tools=12.1.0=0
  - cuda-visual-tools=12.1.0=0
  - cudatoolkit=11.7.0=hd8887f6_10
  - cxx-compiler=1.7.0=h00ab1b0_0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.4=pyhd8ed1ab_0
  - freetype=2.12.1=h267a509_2
  - gcc=12.3.0=h915e2ae_6
  - gcc_impl_linux-64=12.3.0=h1562d66_6
  - gcc_linux-64=12.3.0=h6477408_3
  - gds-tools=1.9.1.3=0
  - gmp=6.3.0=h59595ed_1
  - gmpy2=2.1.5=py310hc3586ac_0
  - gnutls=3.6.13=h85f3911_1
  - gxx=12.3.0=h915e2ae_6
  - gxx_impl_linux-64=12.3.0=h1562d66_6
  - gxx_linux-64=12.3.0=h4a1b8e8_3
  - icu=73.2=h59595ed_0
  - idna=3.7=pyhd8ed1ab_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.3=pyhd8ed1ab_0
  - jpeg=9e=h166bdaf_2
  - kernel-headers_linux-64=2.6.32=he073ed8_17
  - lame=3.100=h166bdaf_1003
  - lcms2=2.15=hfd0df8a_0
  - ld_impl_linux-64=2.40=h55db66e_0
  - lerc=4.0.0=h27087fc_0
  - libblas=3.9.0=1_h86c2bf4_netlib
  - libcblas=3.9.0=5_h92ddd45_netlib
  - libcublas=12.1.0.26=0
  - libcublas-dev=12.1.0.26=0
  - libcublas-static=12.1.0.26=0
  - libcufft=11.0.2.4=0
  - libcufft-dev=11.0.2.4=0
  - libcufft-static=11.0.2.4=0
  - libcufile=1.9.1.3=0
  - libcufile-dev=1.9.1.3=0
  - libcufile-static=1.9.1.3=0
  - libcurand=10.3.5.147=0
  - libcurand-dev=10.3.5.147=0
  - libcurand-static=10.3.5.147=0
  - libcusolver=11.4.4.55=0
  - libcusolver-dev=11.4.4.55=0
  - libcusolver-static=11.4.4.55=0
  - libcusparse=12.0.2.55=0
  - libcusparse-dev=12.0.2.55=0
  - libcusparse-static=12.0.2.55=0
  - libdeflate=1.17=h0b41bf4_0
  - libffi=3.4.2=h7f98852_5
  - libgcc-devel_linux-64=12.3.0=h2af2641_106
  - libgcc-ng=13.2.0=hc881cc4_6
  - libgfortran-ng=13.2.0=h69a702a_6
  - libgfortran5=13.2.0=h43f5ff8_6
  - libgomp=13.2.0=hc881cc4_6
  - libhwloc=2.10.0=default_h2fb2949_1000
  - libiconv=1.17=hd590300_2
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - liblapack=3.9.0=5_h92ddd45_netlib
  - libnpp=12.0.2.50=0
  - libnpp-dev=12.0.2.50=0
  - libnpp-static=12.0.2.50=0
  - libnsl=2.0.1=hd590300_0
  - libnvjitlink=12.1.105=0
  - libnvjitlink-dev=12.1.105=0
  - libnvjpeg=12.1.1.14=0
  - libnvjpeg-dev=12.1.1.14=0
  - libnvjpeg-static=12.1.1.14=0
  - libnvvm-samples=12.1.105=0
  - libpng=1.6.43=h2797004_0
  - libsanitizer=12.3.0=h2af2641_6
  - libsqlite=3.45.3=h2797004_0
  - libstdcxx-devel_linux-64=12.3.0=h2af2641_106
  - libstdcxx-ng=13.2.0=h95c4c6d_6
  - libtiff=4.5.0=h6adf6a1_2
  - libuuid=2.38.1=h0b41bf4_0
  - libwebp-base=1.4.0=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libxcrypt=4.4.36=hd590300_1
  - libxml2=2.12.6=h232c23b_2
  - libzlib=1.2.13=hd590300_5
  - llvm-openmp=15.0.7=h0cdce71_0
  - markupsafe=2.1.5=py310h2372a71_0
  - mkl=2023.1.0=h213fc3f_46344
  - mpc=1.3.1=hfe3b2da_0
  - mpfr=4.2.1=h9458935_1
  - mpmath=1.3.0=pyhd8ed1ab_0
  - ncurses=6.4.20240210=h59595ed_0
  - nettle=3.6=he412f7d_0
  - networkx=3.3=pyhd8ed1ab_1
  - ninja=1.12.0=h00ab1b0_0
  - nsight-compute=2024.1.1.4=0
  - numpy=1.26.4=py310hb13e2d6_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.5.0=hfec8fc6_2
  - openssl=3.2.1=hd590300_1
  - pillow=9.4.0=py310h023d228_1
  - pip=24.0=pyhd8ed1ab_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.10.14=hd12c33a_0_cpython
  - python_abi=3.10=4_cp310
  - pytorch=2.3.0=py3.10_cuda12.1_cudnn8.9.2_0
  - pytorch-cuda=12.1=ha16c6d3_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py310h2372a71_1
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - setuptools=69.5.1=pyhd8ed1ab_0
  - sympy=1.12=pypyh9d50eac_103
  - sysroot_linux-64=2.12=he073ed8_17
  - tbb=2021.12.0=h00ab1b0_0
  - tk=8.6.13=noxft_h4845f30_101
  - torchaudio=2.3.0=py310_cu121
  - torchtriton=2.3.0=py310
  - torchvision=0.18.0=py310_cu121
  - typing_extensions=4.11.0=pyha770c72_0
  - tzdata=2024a=h0c530f3_0
  - urllib3=2.2.1=pyhd8ed1ab_0
  - wheel=0.43.0=pyhd8ed1ab_1
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.5=hfc55251_0
  - pip:
      - absl-py==2.1.0
      - beautifulsoup4==4.12.3
      - cachetools==5.3.3
      - click==8.1.7
      - contourpy==1.2.1
      - cycler==0.12.1
      - fonttools==4.51.0
      - gdown==5.1.0
      - grpcio==1.62.2
      - imageio-ffmpeg==0.4.9
      - kiwisolver==1.4.5
      - markdown==3.6
      - matplotlib==3.8.4
      - nvidia-ml-py==12.535.161
      - nvitop==1.3.2
      - packaging==24.0
      - protobuf==5.26.1
      - psutil==5.9.8
      - pyparsing==3.1.2
      - pyspng==0.1.1
      - python-dateutil==2.9.0.post0
      - scipy==1.13.0
      - six==1.16.0
      - soupsieve==2.5
      - tensorboard==2.16.2
      - tensorboard-data-server==0.7.2
      - termcolor==2.4.0
      - tqdm==4.66.2
      - werkzeug==3.0.2

about the train

Thanks for your wonderful work!
In your paper, I see that you use the target y, which is a degraded image, as the ground truth. I do not understand why using this image as ground truth is enough for training the model to do restoration. I would expect the model's output to be similar to the target y, not to the original clean image. I hope you can clarify!
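
A toy sketch of why matching a degraded target can still recover a clean signal, assuming (as the degrade_prediction hook in the README suggests) that the prediction is degraded before it is compared with y. Everything here is hypothetical: a one-parameter "generator", pair averaging as the degradation, and finite-difference gradient descent standing in for the real optimizer. It illustrates the principle only and is not the paper's implementation.

```python
TEMPLATE = [1.0, 2.0, 3.0, 4.0]  # toy "generator": every output is w * TEMPLATE


def generator(w):
    """Map a scalar latent w to a 4-sample signal."""
    return [w * t for t in TEMPLATE]


def degrade(x):
    """Toy degradation: average neighbouring pairs (2x downsampling)."""
    return [(x[i] + x[i + 1]) / 2 for i in range(0, len(x), 2)]


def loss(w, y):
    """Squared error between the degraded prediction and the degraded target."""
    return sum((a - b) ** 2 for a, b in zip(degrade(generator(w)), y))


def grad(w, y, eps=1e-4):
    """Central finite-difference estimate of d(loss)/dw."""
    return (loss(w + eps, y) - loss(w - eps, y)) / (2 * eps)


clean = generator(2.0)  # hidden clean signal, never shown to the optimizer
y = degrade(clean)      # the only observation: the degraded target

w = 0.0
for _ in range(100):
    w -= 0.01 * grad(w, y)

restored = generator(w)
print("restored:", [round(v, 4) for v in restored])
print("clean:   ", clean)
```

The loss never sees the clean signal, yet the optimized output matches it, because the generator can only produce signals of one family and just one member of that family degrades to y. As far as the README and logs on this page show, a pretrained StyleGAN plays that role for images, and latents are optimized per image rather than a model being trained.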

cli

NameError: name 'Literal' is not defined
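
`typing.Literal` was added in Python 3.8, so this error is consistent with running on an older interpreter (an assumption; the report does not state the Python version). A possible workaround, assuming the `typing_extensions` backport package is installed on older versions, is a guarded import:

```python
import sys

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    # Backport package; install with `pip install typing_extensions`.
    from typing_extensions import Literal

# Same annotation as the `tasks` field of the Config class shown on this page.
TasksOption = Literal["all", "single", "composed", "custom"]
```

Upgrading to Python 3.8 or newer avoids the need for the fallback branch altogether.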

pre-trained model

I could not find the pre-trained model and the corresponding test code in your project. Will these be released later, and if so, approximately when?
