sound-guided-semantic-image-manipulation's Introduction

🔉 Sound-guided Semantic Image Manipulation (CVPR2022)

Official PyTorch Implementation

Teaser image

Sound-guided Semantic Image Manipulation
IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022

Paper: CVPR 2022 Open Access
Project Page: https://kuai-lab.github.io/cvpr2022sound/
Seung Hyun Lee, Wonseok Roh, Wonmin Byeon, Sang Ho Yoon, Chanyoung Kim, Jinkyu Kim*, and Sangpil Kim*

Abstract: The recent success of the generative model shows that leveraging the multi-modal embedding space can manipulate an image using text information. However, manipulating an image with other sources rather than text, such as sound, is not easy due to the dynamic characteristics of the sources. Especially, sound can convey vivid emotions and dynamic expressions of the real world. Here, we propose a framework that directly encodes sound into the multi-modal (image-text) embedding space and manipulates an image from the space. Our audio encoder is trained to produce a latent representation from an audio input, which is forced to be aligned with image and text representations in the multi-modal embedding space. We use a direct latent optimization method based on aligned embeddings for sound-guided image manipulation. We also show that our method can mix different modalities, i.e., text and audio, which enrich the variety of the image modification. The experiments on zero-shot audio classification and semantic-level image classification show that our proposed model outperforms other text and sound-guided state-of-the-art methods.

💾 Installation

All the methods described in the paper require Anaconda and CLIP.

Specific requirements for each method are described in its section. To install CLIP please run the following commands:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
pip install ftfy regex tqdm gdown
pip install git+https://github.com/openai/CLIP.git

🔨 Method

Method image

1. CLIP-based Contrastive Latent Representation Learning.

Dataset Curation.

We build an audio-text pair dataset from the VGGSound dataset. We also used the AudioSet dataset, processed with the same scripts as below.

  1. Please download vggsound.csv from the link.
  2. Execute download.py to download the audio file of the vggsound dataset.
  3. Execute curate.py to preprocess the audio file (wav to mel-spectrogram).
cd soundclip
python3 download.py
python3 curate.py
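
The fixed-size part of step 3 can be sketched as follows. This is a minimal sketch, not the code in curate.py: it assumes 128 mel bins and 864 time frames (the values used by the inference script in the "About StyleGAN3" issue further down this page), and `fix_time_length` is a hypothetical helper name.

```python
import numpy as np

def fix_time_length(spec: np.ndarray, time_length: int = 864) -> np.ndarray:
    """Center-crop or right-pad a (n_mels, frames) spectrogram to `time_length` frames."""
    n_mels, frames = spec.shape
    if frames >= time_length:
        # Long clip: keep the central `time_length` frames.
        start = (frames - time_length) // 2
        return spec[:, start:start + time_length]
    # Short clip: copy into a zero-filled array, leaving zeros on the right.
    padded = np.zeros((n_mels, time_length), dtype=spec.dtype)
    padded[:, :frames] = spec
    return padded

long_spec = np.arange(128 * 1000, dtype=np.float32).reshape(128, 1000)
short_spec = np.ones((128, 300), dtype=np.float32)
print(fix_time_length(long_spec).shape)   # (128, 864)
print(fix_time_length(short_spec).shape)  # (128, 864)
```

Cropping from the center rather than the start is the choice made in the inference script; whether the training-time preprocessing does the same is an assumption here.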

Training.

python3 train.py
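
To illustrate what "contrastive" means here, the following is a minimal NumPy sketch of a symmetric CLIP-style contrastive loss between a batch of audio embeddings and their paired text embeddings. It is not the loss in train.py: the function names, the temperature of 0.07, and the random toy data are all assumptions made for illustration.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale each row to unit length."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def log_softmax(logits, axis=-1):
    """Numerically stable log-softmax."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def symmetric_contrastive_loss(audio_emb, text_emb, temperature=0.07):
    """Row i of audio_emb is the positive for row i of text_emb; all other rows are negatives."""
    a = l2_normalize(audio_emb)
    t = l2_normalize(text_emb)
    logits = a @ t.T / temperature              # (batch, batch) cosine similarities
    idx = np.arange(len(a))
    loss_a2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # audio -> text
    loss_t2a = -log_softmax(logits, axis=0)[idx, idx].mean()  # text -> audio
    return 0.5 * (loss_a2t + loss_t2a)

rng = np.random.default_rng(0)
text = rng.normal(size=(4, 512))
aligned_audio = text + 0.01 * rng.normal(size=(4, 512))  # nearly identical to the text embeddings
random_audio = rng.normal(size=(4, 512))                 # unrelated to the text embeddings

print(symmetric_contrastive_loss(aligned_audio, text))
print(symmetric_contrastive_loss(random_audio, text))
```

The loss is small when each audio embedding sits close to its own text embedding and far from the others, which is the alignment the audio encoder is trained for.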

2. Sound-Guided Image Manipulation.

Direct Latent Code Optimization.

The code relies on the StyleCLIP PyTorch implementation.

python3 optimization/run_optimization.py --lambda_similarity 0.002 --lambda_identity 0.0 --truncation 0.7 --lr 0.1 --audio_path "./audiosample/explosion.wav" --ckpt ./pretrained_models/landscape.pt --stylegan_size 256
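
The idea behind direct latent code optimization can be sketched on a toy problem. This is not the repository's run_optimization.py: the random matrix `A` is a hypothetical stand-in for "generate an image and embed it with CLIP", plain gradient descent is used for brevity, and the sketch assumes `--lambda_similarity` weights a squared distance to the starting latent code (the role `l2_lambda` plays in StyleCLIP's optimization script).

```python
import numpy as np

def cosine_distance_and_grad(feat, target):
    """Return 1 - cos(feat, target) and its gradient w.r.t. feat (target must be unit-norm)."""
    norm = np.linalg.norm(feat)
    unit = feat / norm
    cos = unit @ target
    grad = -(target - cos * unit) / norm
    return 1.0 - cos, grad

rng = np.random.default_rng(0)
latent_dim, embed_dim = 16, 8
A = rng.normal(size=(embed_dim, latent_dim))  # hypothetical stand-in for CLIP(G(w))
target = rng.normal(size=embed_dim)           # stand-in for the audio embedding
target /= np.linalg.norm(target)

w_init = rng.normal(size=latent_dim)          # source latent code
w = w_init.copy()
lambda_similarity, lr = 0.002, 0.1            # values taken from the command above

start_dist, _ = cosine_distance_and_grad(A @ w, target)
for _ in range(300):
    _, grad_feat = cosine_distance_and_grad(A @ w, target)
    # Total gradient: cosine term plus the pull back toward the source latent.
    grad_w = A.T @ grad_feat + 2.0 * lambda_similarity * (w - w_init)
    w -= lr * grad_w
end_dist, _ = cosine_distance_and_grad(A @ w, target)

print(f"cosine distance: {start_dist:.3f} -> {end_dist:.3f}")
```

A larger `lambda_similarity` keeps the edited latent closer to the source, trading manipulation strength for content preservation.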

⛳ Results

Zero-shot Audio Classification Accuracy.

| Model | Supervised Setting | Zero-Shot | ESC-50 | UrbanSound 8K |
| --- | --- | --- | --- | --- |
| ResNet50 | ✅ | - | 66.8% | 71.3% |
| Ours (Without Self-Supervised) | - | - | 58.7% | 63.3% |
| ✨ Ours (Logistic Regression) | - | - | 72.2% | 66.8% |
| Wav2clip | - | ✅ | 41.4% | 40.4% |
| AudioCLIP | - | ✅ | 69.4% | 68.8% |
| Ours (Without Self-Supervised) | - | ✅ | 49.4% | 45.6% |
| ✨ Ours | - | ✅ | 57.8% | 45.7% |

Manipulation Results.

LSUN. LSUN image

FFHQ. FFHQ image

To see more diverse examples, please visit our project page!

Citation

@InProceedings{Lee_2022_CVPR,
    author    = {Lee, Seung Hyun and Roh, Wonseok and Byeon, Wonmin and Yoon, Sang Ho and Kim, Chanyoung and Kim, Jinkyu and Kim, Sangpil},
    title     = {Sound-Guided Semantic Image Manipulation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {3377-3386}
}

sound-guided-semantic-image-manipulation's People

Contributors

kochanha, lsh3163

sound-guided-semantic-image-manipulation's Issues

Pretrained model for ArcFace

Hi,
thanks for the great work!

Could you also share the pretrained model used for ArcFace? (model_ir_se50.pth)

Thank you!

Zero-shot Code

Thanks for the great work! Can you share the code for zero-shot transfer on ESC-50 and UrbanSound 8K? Thanks.

add Gradio Demo for cvpr 2022 call for demos

Hi, would you be interested in adding sound-guided-semantic-image-manipulation to Hugging Face as a Gradio Web Demo for CVPR 2022 call for Demos? The Hub offers free hosting, and it would make your work more accessible and visible to the rest of the ML community. Models/datasets/spaces(web demos) can be added to a user account or organization similar to github.

more info on CVPR call for demos: https://huggingface.co/CVPR

and here are guides for adding web demo to the org

How to add a Space: https://huggingface.co/blog/gradio-spaces

Please let us know if you would be interested and if you have any questions, we can also help with the technical implementation.

Does part 1 of training (CLIP-based training) include the image modality?

Hi,
Thanks for the great work!!

The paper states that during part-1 training (i.e., the CLIP-based Contrastive Latent Representation Learning step) you consider the image, text, and audio modalities. But the code uses only the audio and text modalities for this part of training.

Is this an old version of the code, or did I misinterpret the training part of the paper?
Thanks

Computation requirement for training

Hi,

First of all, great contribution to the field of image manipulation. Could you please share how many GPUs you used and how long it took to train the model?

Thanks,
Himangi

About StyleGAN3

The main code is borrowed from the link below
Link : https://github.com/ouhenio/StyleGAN3-CLIP-notebooks

StyleGAN3 + Our CLIP-based sound representation

import sys

import io
import os, time, glob
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
import unicodedata
import re
from tqdm import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from einops import rearrange
from collections import OrderedDict

import timm
import librosa
import cv2

def make_transform(translate, angle):
    m = np.eye(3)
    s = np.sin(angle/360.0*np.pi*2)
    c = np.cos(angle/360.0*np.pi*2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m
    
class AudioEncoder(torch.nn.Module):
    def __init__(self):
        super(AudioEncoder, self).__init__()
        self.conv = torch.nn.Conv2d(1, 3, (3, 3))
        self.feature_extractor = timm.create_model("resnet18", num_classes=512, pretrained=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.feature_extractor(x)
        return x

def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict

class CLIP(object):
  def __init__(self):
    clip_model = "ViT-B/32"
    self.model, _ = clip.load(clip_model)
    self.model = self.model.requires_grad_(False)
    self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                          std=[0.26862954, 0.26130258, 0.27577711])

  @torch.no_grad()
  def embed_text(self, prompt):
      "Normalized clip text embedding."
      return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())

  def embed_cutout(self, image):
      "Normalized clip image embedding."
      # return norm1(self.model.encode_image(self.normalize(image)))
      return norm1(self.model.encode_image(image))

tf = Compose([
  Resize(224),
  lambda x: torch.clamp((x+1)/2,min=0,max=1),
  ])

def norm1(prompt):
    "Normalize to the unit sphere."
    return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1: # Keeps consistent results vs previous method for single objective guidance
      return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)  

class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)

make_cutouts = MakeCutouts(224, 32, 0.5)

def embed_image(image):
  n = image.shape[0]
  cutouts = make_cutouts(image)
  embeds = clip_model.embed_cutout(cutouts)
  embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
  return embeds

def run(timestring):
  torch.manual_seed(seed)

  # Init
  # Sample 32 inits (8 batches of 4) and keep the one closest to the targets

  with torch.no_grad():
    qs = []
    losses = []
    for _ in range(8):
      q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
      images = G.synthesis(q * w_stds + G.mapping.w_avg)
      embeds = embed_image(images.add(1).div(2))
      loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
      i = torch.argmin(loss)
      qs.append(q[i])
      losses.append(loss[i])
    qs = torch.stack(qs)
    losses = torch.stack(losses)
    i = torch.argmin(losses)
    q = qs[i].unsqueeze(0).requires_grad_()

  w_init = (q * w_stds + G.mapping.w_avg).detach().clone()
  # Sampling loop
  q_ema = q
  opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
  loop = tqdm(range(steps))
  for i in loop:
    opt.zero_grad()
    w = q * w_stds + G.mapping.w_avg
    image = G.synthesis(w , noise_mode='const')
    embed = embed_image(image.add(1).div(2))
    loss = 0.1 *  prompts_dist_loss(embed, targets, spherical_dist_loss).mean() + ((w - w_init) ** 2).mean()
    # loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
    loss.backward()
    opt.step()
    loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())

    q_ema = q_ema * 0.9 + q * 0.1

    final_code = q_ema * w_stds + G.mapping.w_avg
    final_code[:,6:,:] = w_init[:,6:,:]
    image = G.synthesis(final_code, noise_mode='const')

    if i % 10 == 9 or i % 10 == 0:
      # display(TF.to_pil_image(tf(image)[0]))
      print(f"Image {i}/{steps} | Current loss: {loss}")
      pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1).cpu())
      os.makedirs(f'samples/{timestring}', exist_ok=True)
      pil_image.save(f'samples/{timestring}/{i:04}.jpg')


device = torch.device('cuda:0')
print('Using device:', device, file=sys.stderr)

model_url = "./pretrained_models/stylegan3-r-afhqv2-512x512.pkl"

with open(model_url, 'rb') as fp:
  G = pickle.load(fp)['G_ema'].to(device)

zs = torch.randn([100000, G.mapping.z_dim], device=device)
w_stds = G.mapping(zs, None).std(0)

m = make_transform([0,0], 0)
m = np.linalg.inv(m)
G.synthesis.input.transform.copy_(torch.from_numpy(m))
# audio_paths = "./audio/sweet-kitty-meow.wav"
#audio_paths = "./audio/dog-sad.wav"
audio_paths = "./audio/cartoon-voice-laugh.wav"
steps = 200
seed = 14 + 22
#seed = 22

audio_paths = [frase.strip() for frase in audio_paths.split("|") if frase]

clip_model = CLIP()
audio_encoder = AudioEncoder()
audio_encoder.load_state_dict(copyStateDict(torch.load("./pretrained_models/resnet18.pth", map_location=device)))
audio_encoder = audio_encoder.to(device)
audio_encoder.eval()

targets = []
n_mels = 128
time_length = 864
resize_resolution = 512

for audio_path in audio_paths:
    y, sr = librosa.load(audio_path, sr=44100)
    audio_inputs = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
    audio_inputs = librosa.power_to_db(audio_inputs, ref=np.max) / 80.0 + 1

    zero = np.zeros((n_mels, time_length))
    h, w = audio_inputs.shape
    if w >= time_length:
        j = (w - time_length) // 2
        audio_inputs = audio_inputs[:,j:j+time_length]
    else:
        j = (time_length - w) // 2
        zero[:,:w] = audio_inputs[:,:w]
        audio_inputs = zero
    
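    # Note: cv2.resize takes dsize as (width, height), so this returns an array of
    # shape (resize_resolution, n_mels); the reshape below reinterprets it as
    # (n_mels, resize_resolution) rather than transposing it.
    audio_inputs = cv2.resize(audio_inputs, (n_mels, resize_resolution))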
    audio_inputs = np.array([audio_inputs])
    audio_inputs = torch.from_numpy(audio_inputs.reshape((1, 1, n_mels, resize_resolution))).float().to(device)
    with torch.no_grad():
        audio_embedding = audio_encoder(audio_inputs)
        audio_embedding = audio_embedding / audio_embedding.norm(dim=-1, keepdim=True)
    targets.append(audio_embedding)

timestring = time.strftime('%Y%m%d%H%M%S')
run(timestring)
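
A note on `spherical_dist_loss` in the script above: for unit vectors, ‖x − y‖/2 = sin(θ/2), where θ is the angle between them, so the arcsin gives θ/2, squaring gives θ²/4, and doubling gives θ²/2. The NumPy check below is a re-implementation for illustration, not part of the posted script; the `np.clip` guard against floating-point overshoot is an addition.

```python
import numpy as np

def spherical_dist(x, y):
    """NumPy version of spherical_dist_loss: 2 * arcsin(||x - y|| / 2) ** 2."""
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    y = y / np.linalg.norm(y, axis=-1, keepdims=True)
    half_chord = np.clip(np.linalg.norm(x - y, axis=-1) / 2, 0.0, 1.0)
    return 2 * np.arcsin(half_chord) ** 2

x = np.array([1.0, 0.0])
for deg in (0, 60, 90, 180):
    theta = np.deg2rad(deg)
    y = np.array([np.cos(theta), np.sin(theta)])
    # The two printed values should agree: the loss equals theta**2 / 2.
    print(deg, spherical_dist(x, y), theta ** 2 / 2)
```

Unlike 1 − cos θ, this distance keeps growing steadily as the embeddings approach opposite directions, so the gradient does not vanish near θ = π.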

Is there something wrong with my zero-shot code?

Thanks for the great work!

I tried to reproduce your work in my environment. However, the accuracy of your pre-trained model (ResNet18) was 3% with my zero-shot code.
Could you please check what differs from your code?
Thank you so much!

This is my zero-shot classification code.

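import math

import clip
import torch
from tqdm import tqdm

# AudioEncoder, dataset, text, and LABELS are assumed to be defined elsewhere.
# model init from ckpt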
audioencoder = AudioEncoder().cuda()
ckpt = torch.load('./audio-diffusion/resnet18_57.pth')
audioencoder.load_state_dict(ckpt)
textencoder, _ = clip.load("ViT-B/32")
textencoder = textencoder.cuda()

correct = 0

for i in tqdm(range(len(dataset))):
    audio = dataset[i][0][None].cuda() # (B, 1, 128, 512)

    # embeddings
    with torch.no_grad():
        text_token = torch.cat([clip.tokenize(txt) for txt in text])
        text_features = textencoder.encode_text(text_token.cuda())
        audio_features = audioencoder(audio)
        audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
    audio_features = audio_features.float()
    text_features = text_features.float()

    # logit
    proj_per_audio = (audio_features @ text_features.T) * math.exp(0.07)

    # classification
    # confidence = proj_per_audio.softmax(dim=1)
    # conf_values, ids = confidence.topk(1)
    label_idx = torch.argmax(proj_per_audio, axis=1)

    pred = LABELS[label_idx]
    label = dataset[i][-1]

    if label == pred:
        correct += 1
    
    if i%100 == 0:
        print(correct)
        
print(f'Classification accuracy: {correct/len(dataset)*100} %')
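
One difference that may be worth checking (an observation, not a confirmed cause of the 3% accuracy): the snippet above L2-normalizes `audio_features` but not `text_features`, whereas CLIP-style zero-shot classification normalizes both before taking the dot product. Multiplying the logits by a positive constant such as `math.exp(0.07)` does not change the argmax, but unequal text-embedding norms can. A toy NumPy sketch with made-up 2-D embeddings:

```python
import numpy as np

def zero_shot_predict(audio_features, text_features):
    """Pick the class whose text embedding has the highest cosine similarity."""
    a = audio_features / np.linalg.norm(audio_features, axis=-1, keepdims=True)
    t = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)
    return (a @ t.T).argmax(axis=-1)

audio = np.array([[1.0, 0.1]])
text = np.array([[2.0, 0.0],    # class 0: points the same way as the audio embedding
                 [5.0, 5.0]])   # class 1: less aligned, but much larger norm

# Same as the posted snippet: only the audio side is normalized.
audio_unit = audio / np.linalg.norm(audio, axis=-1, keepdims=True)
unnormalized_pred = (audio_unit @ text.T).argmax(axis=-1)

print(zero_shot_predict(audio, text))  # [0]
print(unnormalized_pred)               # [1]
```

With both sides normalized the aligned class wins; with raw text features the larger-norm class wins instead.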

pretrained model

I am very interested in your work.
I'd like to know how to get the pretrained model.
Thanks
