Hi, thanks for sharing such a great repo for solving image degradation. I am currently trying to update the code from StyleGAN2-ADA to StyleGAN3. I only modified some imports, since StyleGAN3 uses different module names and paths; my understanding is that this image restoration method is mainly doing image generation (please correct me if I misunderstood). The modified code is pasted below:
'''
File: robust_unsupervised/prelude.py
'''
from typing import *
import copy
import os
import pickle
import functools
import sys
import torch.optim as optim
import tqdm
import dataclasses
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math
# >>>>>>>> Modified
import dnnlib
import legacy
# <<<<<<<<
import shutil
from functools import partial
import itertools
import warnings
from warnings import warn
import datetime
import torchvision.transforms.functional as TF
from torchvision.utils import save_image, make_grid
# >>>>>>>> Modified
# import training.networks as networks
import training.networks_stylegan3 as networks
# <<<<<<<<
from abc import ABC, abstractmethod, abstractstaticmethod, abstractclassmethod
from dataclasses import dataclass, field
import ssl
# NOTE(review): this replaces the factory Python uses for default HTTPS
# contexts with one that skips certificate verification, for every HTTPS
# request made through that default in this process. Presumably added so
# pretrained-weight downloads work behind broken TLS setups -- consider
# removing it or making it opt-in.
ssl._create_default_https_context = ssl._create_unverified_context

# Silence known, noisy deprecation/experimental-API warnings. Filters are
# installed in the same order as before.
for _message_pattern in (
    r"Named tensors and all their associated APIs.*",
    r"Arguments other than a weight enum.*",
    r"The parameter 'pretrained' is deprecated.*",
):
    warnings.filterwarnings("ignore", _message_pattern)
del _message_pattern
'''
File: robust_unsupervised/io_utils.py
'''
from robust_unsupervised.prelude import *
from robust_unsupervised.variables import *
import shutil
import torch_utils as torch_utils
import torch_utils.misc as misc
import contextlib
import PIL.Image as Image
def open_generator(pkl_path: str, refresh=True, float=True, ema=True) -> networks.Generator:
    """Load a pretrained StyleGAN generator from a network pickle.

    Args:
        pkl_path: Local path to a network pickle
            (e.g. ``stylegan3-r-ffhq-1024x1024.pkl``).
        refresh: If True, build a fresh ``networks.Generator`` from the
            pickled constructor arguments (``init_args`` / ``init_kwargs``)
            and copy every parameter and buffer into it. Presumably this is
            so that the locally imported network code is used instead of the
            code stored in the pickle -- confirm.
        float: If True, cast the loaded generator to float32. (The name
            shadows the builtin; it is kept for interface compatibility.)
        ema: If True, load the ``"G_ema"`` entry of the pickle, otherwise
            the ``"G"`` entry.

    Returns:
        The generator on the CUDA device, in eval mode, with
        ``requires_grad`` disabled for all parameters.

    NOTE(review): with ``training.networks_stylegan3`` as ``networks``,
    ``G.synthesis`` asserts that ``ws`` has shape ``[N, num_ws, w_dim]``.
    Any latent fed to it in another layout fails that assertion.
    """
    print(f"Loading generator from {pkl_path}...")
    # NOTE(review): pickle.load can execute arbitrary code -- only load
    # trusted .pkl files. Unpickling StyleGAN3 networks also needs the
    # stylegan3 packages (e.g. torch_utils) to be importable.
    key = "G_ema" if ema else "G"  # previously 'G_ema' was hard-coded and `ema` ignored
    with open(pkl_path, "rb") as f:
        G = pickle.load(f)[key].cuda()
    if float:
        G = G.float()
    if refresh:
        with torch.no_grad():
            old_G = G
            G = networks.Generator(*old_G.init_args, **old_G.init_kwargs).cuda()
            misc.copy_params_and_buffers(old_G, G, require_all=True)
    # A freshly constructed module starts in train mode, so set eval mode
    # after the optional refresh; this also restores the .eval() call the
    # original StyleGAN2-ADA loading path had.
    G.eval()
    # The generator is frozen; only the latent variables are optimized.
    G.requires_grad_(False)
    return G
import tyro
from dataclasses import dataclass
from typing import *
import sys
# >>>>>>>> Modified
# Make the packages in the "stylegan3" checkout (e.g. training, torch_utils)
# importable. This is a relative entry, so Python resolves it against the
# current working directory when an import runs: launch the program from the
# directory that contains "stylegan3/". Presumably it must be in place before
# those packages are first imported -- confirm the import order.
sys.path += ["stylegan3"]
# <<<<<<<<
# Run configuration, filled in from the command line by parse_config() via
# tyro.cli. NOTE(review): the string literal under each field is presumably
# used by tyro as that field's --help text, so those strings are kept verbatim
# and extra notes live here rather than inside the class body.
# NOTE(review): `resolution` presumably should match the pretrained
# generator's output size (the default pkl name says 1024x1024) -- confirm.
@dataclass
class Config:
    name: str = "restored_samples"
    "A name used to group log files."
    pkl_path: str = "pretrained_networks/stylegan3-r-ffhq-1024x1024.pkl"
    "The location of the pretrained StyleGAN."
    dataset_path: str = "datasets/samples"
    "The location of the images to process."
    resolution: int = 1024
    "The resolution of your images. Images which are smaller or larger will be resized."
    global_lr_scale: float = 1.0
    "A global factor which scales up and down all learning rates. This may need adjustment for datasets other than faces."
    tasks: Literal["all", "single", "composed", "custom"] = "all"
    "Selects which tasks to run."
def parse_config() -> Config:
    """Parse the command-line arguments into a ``Config`` using ``tyro.cli``."""
    parsed_config = tyro.cli(Config)
    return parsed_config
So I downloaded a new pre-trained StyleGAN3 model, `stylegan3-r-ffhq-1024x1024.pkl`. I am still using the same sample images that this repo provides (sample_1.png and sample_2.png), and all other code remains the same.
However, while the program runs the `W` and `Wp` phases successfully, it terminates with an error when it enters the `Wpp` phase. The error message is copied below:
>$ python run.py --dataset_path datasets/samples
restored_samples
Loading generator from pretrained_networks/stylegan3-r-ffhq-1024x1024.pkl...
out/restored_samples/2024-05-04T211319/single_tasks/inpainting/XL/
/project/robust-unsupervised/out/restored_samples/2024-05-04T211319/single_tasks/inpainting/XL/datasets/samples
- 0000
W: 0%| | 0/150 [00:00<?, ?it/s]
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
W: 100%|███████| 150/150 [00:27<00:00, 5.48it/s]
W+: 100%|██████| 150/150 [00:25<00:00, 5.79it/s]
W++: 0%| | 0/150 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/project/robust-unsupervised/run.py", line 154, in <module>
run_phase("W++", Wpp_variable, config.global_lr_scale * 0.005)
File "/project/robust-unsupervised/run.py", line 24, in run_phase
x = variable.to_image()
File "/project/robust-unsupervised/robust_unsupervised/variables.py", line 29, in to_image
return self.render_image(self.to_input_tensor())
File "/project/robust-unsupervised/robust_unsupervised/variables.py", line 35, in render_image
return (self.G.synthesis(ws, noise_mode="const", force_fp32=True) + 1.0) / 2.0
File "/project/miniconda3/envs/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/project/miniconda3/envs/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/project/robust-unsupervised/stylegan3/training/networks_stylegan3.py", line 465, in forward
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
File "/project/robust-unsupervised/stylegan3/torch_utils/misc.py", line 95, in assert_shape
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
AssertionError: Wrong size for dimension 1: got 8192, expected 16
Could you share any suggestions on how to solve this issue? Thanks.
My venv:
TL;DR:
This environment can run both StyleGAN2-ADA and StyleGAN3.
python: 3.10
cuda: 12.3
torch: 2.3.0
torchvision: 0.18.0
OS: Linux
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- binutils=2.40=h4852527_0
- binutils_impl_linux-64=2.40=ha885e6a_0
- binutils_linux-64=2.40=hdade7a5_3
- blas=1.0=mkl
- brotli-python=1.1.0=py310hc6cd4ac_1
- bzip2=1.0.8=hd590300_5
- c-compiler=1.7.0=hd590300_0
- ca-certificates=2024.3.11=h06a4308_0
- certifi=2024.2.2=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- cuda=12.1.0=0
- cuda-cccl=12.4.127=0
- cuda-command-line-tools=12.1.1=0
- cuda-compiler=12.4.1=0
- cuda-cudart=12.1.105=0
- cuda-cudart-dev=12.1.105=0
- cuda-cudart-static=12.1.105=0
- cuda-cuobjdump=12.4.127=0
- cuda-cupti=12.1.105=0
- cuda-cupti-static=12.1.105=0
- cuda-cuxxfilt=12.4.127=0
- cuda-demo-suite=12.4.127=0
- cuda-documentation=12.4.127=0
- cuda-driver-dev=12.4.127=0
- cuda-gdb=12.4.127=0
- cuda-libraries=12.1.0=0
- cuda-libraries-dev=12.1.0=0
- cuda-libraries-static=12.1.0=0
- cuda-nsight=12.4.127=0
- cuda-nsight-compute=12.4.1=0
- cuda-nvcc=12.4.131=0
- cuda-nvdisasm=12.4.127=0
- cuda-nvml-dev=12.4.127=0
- cuda-nvprof=12.4.127=0
- cuda-nvprune=12.4.127=0
- cuda-nvrtc=12.1.105=0
- cuda-nvrtc-dev=12.1.105=0
- cuda-nvrtc-static=12.1.105=0
- cuda-nvtx=12.1.105=0
- cuda-nvvp=12.4.127=0
- cuda-opencl=12.4.127=0
- cuda-opencl-dev=12.4.127=0
- cuda-profiler-api=12.4.127=0
- cuda-runtime=12.1.0=0
- cuda-sanitizer-api=12.4.127=0
- cuda-toolkit=12.1.0=0
- cuda-tools=12.1.0=0
- cuda-visual-tools=12.1.0=0
- cudatoolkit=11.7.0=hd8887f6_10
- cxx-compiler=1.7.0=h00ab1b0_0
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.4=pyhd8ed1ab_0
- freetype=2.12.1=h267a509_2
- gcc=12.3.0=h915e2ae_6
- gcc_impl_linux-64=12.3.0=h1562d66_6
- gcc_linux-64=12.3.0=h6477408_3
- gds-tools=1.9.1.3=0
- gmp=6.3.0=h59595ed_1
- gmpy2=2.1.5=py310hc3586ac_0
- gnutls=3.6.13=h85f3911_1
- gxx=12.3.0=h915e2ae_6
- gxx_impl_linux-64=12.3.0=h1562d66_6
- gxx_linux-64=12.3.0=h4a1b8e8_3
- icu=73.2=h59595ed_0
- idna=3.7=pyhd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.3=pyhd8ed1ab_0
- jpeg=9e=h166bdaf_2
- kernel-headers_linux-64=2.6.32=he073ed8_17
- lame=3.100=h166bdaf_1003
- lcms2=2.15=hfd0df8a_0
- ld_impl_linux-64=2.40=h55db66e_0
- lerc=4.0.0=h27087fc_0
- libblas=3.9.0=1_h86c2bf4_netlib
- libcblas=3.9.0=5_h92ddd45_netlib
- libcublas=12.1.0.26=0
- libcublas-dev=12.1.0.26=0
- libcublas-static=12.1.0.26=0
- libcufft=11.0.2.4=0
- libcufft-dev=11.0.2.4=0
- libcufft-static=11.0.2.4=0
- libcufile=1.9.1.3=0
- libcufile-dev=1.9.1.3=0
- libcufile-static=1.9.1.3=0
- libcurand=10.3.5.147=0
- libcurand-dev=10.3.5.147=0
- libcurand-static=10.3.5.147=0
- libcusolver=11.4.4.55=0
- libcusolver-dev=11.4.4.55=0
- libcusolver-static=11.4.4.55=0
- libcusparse=12.0.2.55=0
- libcusparse-dev=12.0.2.55=0
- libcusparse-static=12.0.2.55=0
- libdeflate=1.17=h0b41bf4_0
- libffi=3.4.2=h7f98852_5
- libgcc-devel_linux-64=12.3.0=h2af2641_106
- libgcc-ng=13.2.0=hc881cc4_6
- libgfortran-ng=13.2.0=h69a702a_6
- libgfortran5=13.2.0=h43f5ff8_6
- libgomp=13.2.0=hc881cc4_6
- libhwloc=2.10.0=default_h2fb2949_1000
- libiconv=1.17=hd590300_2
- libjpeg-turbo=2.0.0=h9bf148f_0
- liblapack=3.9.0=5_h92ddd45_netlib
- libnpp=12.0.2.50=0
- libnpp-dev=12.0.2.50=0
- libnpp-static=12.0.2.50=0
- libnsl=2.0.1=hd590300_0
- libnvjitlink=12.1.105=0
- libnvjitlink-dev=12.1.105=0
- libnvjpeg=12.1.1.14=0
- libnvjpeg-dev=12.1.1.14=0
- libnvjpeg-static=12.1.1.14=0
- libnvvm-samples=12.1.105=0
- libpng=1.6.43=h2797004_0
- libsanitizer=12.3.0=h2af2641_6
- libsqlite=3.45.3=h2797004_0
- libstdcxx-devel_linux-64=12.3.0=h2af2641_106
- libstdcxx-ng=13.2.0=h95c4c6d_6
- libtiff=4.5.0=h6adf6a1_2
- libuuid=2.38.1=h0b41bf4_0
- libwebp-base=1.4.0=hd590300_0
- libxcb=1.13=h7f98852_1004
- libxcrypt=4.4.36=hd590300_1
- libxml2=2.12.6=h232c23b_2
- libzlib=1.2.13=hd590300_5
- llvm-openmp=15.0.7=h0cdce71_0
- markupsafe=2.1.5=py310h2372a71_0
- mkl=2023.1.0=h213fc3f_46344
- mpc=1.3.1=hfe3b2da_0
- mpfr=4.2.1=h9458935_1
- mpmath=1.3.0=pyhd8ed1ab_0
- ncurses=6.4.20240210=h59595ed_0
- nettle=3.6=he412f7d_0
- networkx=3.3=pyhd8ed1ab_1
- ninja=1.12.0=h00ab1b0_0
- nsight-compute=2024.1.1.4=0
- numpy=1.26.4=py310hb13e2d6_0
- openh264=2.1.1=h780b84a_0
- openjpeg=2.5.0=hfec8fc6_2
- openssl=3.2.1=hd590300_1
- pillow=9.4.0=py310h023d228_1
- pip=24.0=pyhd8ed1ab_0
- pthread-stubs=0.4=h36c2ea0_1001
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.14=hd12c33a_0_cpython
- python_abi=3.10=4_cp310
- pytorch=2.3.0=py3.10_cuda12.1_cudnn8.9.2_0
- pytorch-cuda=12.1=ha16c6d3_5
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py310h2372a71_1
- readline=8.2=h8228510_1
- requests=2.31.0=pyhd8ed1ab_0
- setuptools=69.5.1=pyhd8ed1ab_0
- sympy=1.12=pypyh9d50eac_103
- sysroot_linux-64=2.12=he073ed8_17
- tbb=2021.12.0=h00ab1b0_0
- tk=8.6.13=noxft_h4845f30_101
- torchaudio=2.3.0=py310_cu121
- torchtriton=2.3.0=py310
- torchvision=0.18.0=py310_cu121
- typing_extensions=4.11.0=pyha770c72_0
- tzdata=2024a=h0c530f3_0
- urllib3=2.2.1=pyhd8ed1ab_0
- wheel=0.43.0=pyhd8ed1ab_1
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
- pip:
- absl-py==2.1.0
- beautifulsoup4==4.12.3
- cachetools==5.3.3
- click==8.1.7
- contourpy==1.2.1
- cycler==0.12.1
- fonttools==4.51.0
- gdown==5.1.0
- grpcio==1.62.2
- imageio-ffmpeg==0.4.9
- kiwisolver==1.4.5
- markdown==3.6
- matplotlib==3.8.4
- nvidia-ml-py==12.535.161
- nvitop==1.3.2
- packaging==24.0
- protobuf==5.26.1
- psutil==5.9.8
- pyparsing==3.1.2
- pyspng==0.1.1
- python-dateutil==2.9.0.post0
- scipy==1.13.0
- six==1.16.0
- soupsieve==2.5
- tensorboard==2.16.2
- tensorboard-data-server==0.7.2
- termcolor==2.4.0
- tqdm==4.66.2
- werkzeug==3.0.2