- Python 3.7
- torch 1.12.1
- accelerate 0.17.1
Using conda (CUDA 10.2 build):
```bash
conda create -n efill python=3.7
conda activate efill
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
```
Or using venv on Windows (CPU-only wheels):
```bash
python -m venv efill
efill\Scripts\activate
pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
```
Download the pretrained models above and put them in the `checkpoints` folder.
```bash
cd demo
python demo.py \
    --port 8000 \
    --model_path ../checkpoints/place_best.pth
```
Then open the link printed in the terminal.
### Download the dataset
Please refer to the lama repository to download the CelebA-HQ and Places datasets.
### Download the pretrained models
- Download the model for calculating the perceptual loss.
- Download the AlexNet and Inception models for metric calculation:
```bash
mkdir -p ./hub/checkpoints
cd ./hub/checkpoints
wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth
wget https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth
```
- Prepare the images and masks for validation:
```bash
sh prepare.sh
```
- Download the pretrained teacher models (recommended).
  Note: this is optional; you can also train the teacher model from scratch.
### Configure the accelerator
We use the accelerate framework to speed up training. Before starting training, specify a config file for it by running the following command in a terminal:
```bash
accelerate config --config_file acc_config.yaml
```
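The interactive prompts write your answers into `acc_config.yaml`. For a single-machine, single-GPU run, the generated file typically looks something like this (illustrative values only — your answers to the prompts determine the actual contents):

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'   # single process, no distributed training
mixed_precision: 'no'
num_machines: 1
num_processes: 1
use_cpu: false
```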
### Training the teacher
Modify the following items in `example_train.yaml`:
```yaml
mode: 2
Generator: Teacher_concat_WithAtt
...
```
Then run:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch --config_file ./acc_config.yaml ./run.py --configs ./config/example_train.yaml
```
### Training EFill
Modify the following items in `example_train.yaml`:
```yaml
mode: 1
Generator: DistInpaintModel_SPADE_IN_LFFC_Base_concat_WithAtt
st_TeacherPath: ./checkpoints/celeba-hq_latest.pth
...
```
### Prepare the images and masks
```bash
python prepare_masks.py \
    --dataset_name "Celeba" \
    --mask_type "thick_256" \
    --target_size 256 \
    --aspect_ratio_kept \
    --fixed_size \
    --total_num 10000 \
    --img_dir "/home/codeoops/CV/data/celeba/test" \
    --save_dir "./dataset/validation"
```
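`prepare_masks.py` generates the binary masks used for validation. As an illustration of the kind of "thick" inpainting mask this produces — large opaque regions the model must fill — here is a minimal pure-Python sketch (the function name and stroke parameters are hypothetical, not the repository's actual implementation):

```python
import random

def make_thick_mask(height=256, width=256, num_rects=3, max_frac=0.4, seed=0):
    """Illustrative stand-in for a 'thick' inpainting mask: a binary
    height x width grid where 255 marks the regions to be filled in.
    Not the repository's actual mask generator."""
    rng = random.Random(seed)
    mask = [[0] * width for _ in range(height)]
    for _ in range(num_rects):
        # Each hole is a rectangle covering up to max_frac of each dimension.
        rh = rng.randint(height // 8, int(height * max_frac))
        rw = rng.randint(width // 8, int(width * max_frac))
        top = rng.randint(0, height - rh)
        left = rng.randint(0, width - rw)
        for y in range(top, top + rh):
            for x in range(left, left + rw):
                mask[y][x] = 255
    return mask

mask = make_thick_mask()
covered = sum(v == 255 for row in mask for v in row) / (256 * 256)
print(f"masked fraction: {covered:.2%}")
```

A real pipeline would save such masks as single-channel images alongside the validation photos, which is what the directory layout under `./dataset/validation` reflects.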
### Evaluate the performance
```bash
python performance.py \
    --dataset_name celeba \
    --config_path ./config/celeba_train.yaml \
    --model_path ./checkpoints/celeba_best.pth \
    --mask_type thick_256 \
    --target_size 256 \
    --total_num 10000 \
    --img_dir ./dataset/validation/Celeba/thick_256/imgs \
    --mask_dir ./dataset/validation/Celeba/thick_256/masks \
    --save_dir ./results
```
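Among the standard inpainting metrics, PSNR is the simplest to compute between an inpainted result and its ground truth (the AlexNet/Inception checkpoints downloaded earlier serve the perceptual and distribution-based metrics). The helper below is an illustrative sketch, not the repository's `performance.py` code:

```python
import math

def psnr(img_a, img_b, max_val=255.0):
    """Peak signal-to-noise ratio between two equal-sized images,
    each given as a flat sequence of pixel intensities.
    Illustrative helper, not the repository's metric code."""
    assert len(img_a) == len(img_b) and len(img_a) > 0
    mse = sum((a - b) ** 2 for a, b in zip(img_a, img_b)) / len(img_a)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

# Small worked example: nearly identical 4-pixel images give a high PSNR.
print(psnr([10, 20, 30, 40], [12, 18, 33, 41]))
```

Higher values indicate the reconstruction is closer to the original; identical images yield infinity.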
Our code is built upon the following repositories: