import sys
import io
import os, time, glob
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
import unicodedata
import re
from tqdm import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from einops import rearrange
from collections import OrderedDict
import timm
import librosa
import cv2
def make_transform(translate, angle):
    """Build a 3x3 homogeneous 2D transform: rotation by `angle` degrees
    plus a translation of `translate` = (tx, ty)."""
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    s = np.sin(theta)
    c = np.cos(theta)
    return np.array([
        [c,  s, translate[0]],
        [-s, c, translate[1]],
        [0.0, 0.0, 1.0],
    ])
class AudioEncoder(torch.nn.Module):
    """Map a 1-channel spectrogram tensor to a 512-d embedding.

    A small conv adapts the single-channel input to the 3-channel
    ResNet-18 backbone from timm, whose head outputs 512 features.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 3 channels so the pretrained RGB backbone can consume the input.
        self.conv = torch.nn.Conv2d(1, 3, (3, 3))
        self.feature_extractor = timm.create_model("resnet18", num_classes=512, pretrained=True)

    def forward(self, x):
        """Return the 512-d feature vector for spectrogram batch `x`."""
        return self.feature_extractor(self.conv(x))
def copyStateDict(state_dict):
    """Return a copy of `state_dict` with any leading "module" component stripped.

    Checkpoints saved from a torch.nn.DataParallel model wrap every key in a
    "module." prefix; this normalizes such dicts so they load into an
    unwrapped model. Dicts without the prefix are copied unchanged.

    Args:
        state_dict: mapping of parameter name -> tensor (insertion order kept).

    Returns:
        OrderedDict with normalized keys.
    """
    # Guard against an empty dict: the previous list(...)[0] indexing
    # raised IndexError before any useful work happened.
    first_key = next(iter(state_dict), "")
    start_idx = 1 if first_key.startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
class CLIP(object):
    """Thin wrapper around OpenAI CLIP ViT-B/32 for text/image embeddings."""

    def __init__(self):
        self.model, _ = clip.load("ViT-B/32")
        self.model = self.model.requires_grad_(False)
        # CLIP's published preprocessing stats; currently bypassed in embed_cutout.
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])

    @torch.no_grad()
    def embed_text(self, prompt):
        "Normalized clip text embedding."
        tokens = clip.tokenize(prompt).to(device)
        return norm1(self.model.encode_text(tokens).float())

    def embed_cutout(self, image):
        "Normalized clip image embedding."
        # return norm1(self.model.encode_image(self.normalize(image)))
        return norm1(self.model.encode_image(image))
# Display-time preprocessing: resize to CLIP's 224 resolution and map the
# generator's [-1, 1] output into [0, 1].
# NOTE(review): only referenced from the commented-out display call in run().
tf = Compose([
    Resize(224),
    lambda x: torch.clamp((x+1)/2,min=0,max=1),
])
def norm1(prompt):
    """Project `prompt` onto the unit sphere along its last dimension."""
    length = torch.linalg.vector_norm(prompt, dim=-1, keepdim=True)
    return prompt / length
def spherical_dist_loss(x, y):
    """Squared great-circle distance between embeddings after unit-normalization."""
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * half_chord.arcsin().pow(2)
def prompts_dist_loss(x, targets, loss):
    """Sum of `loss(x, t)` over all targets; single-target case is special-cased
    to keep results bit-identical with the earlier single-objective code path."""
    if len(targets) == 1:  # Keeps consitent results vs previous method for single objective guidance
        return loss(x, targets[0])
    per_target = torch.stack([loss(x, t) for t in targets], dim=-1)
    return per_target.sum(dim=-1)
class MakeCutouts(torch.nn.Module):
    """Produce `cutn` random square crops of a batch, each resized to `cut_size`.

    Crop sizes are drawn as rand()**cut_pow scaled between cut_size and the
    image's short side, so cut_pow < 1 biases toward larger crops.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return the crops concatenated along the batch dimension:
        shape (cutn * batch, channels, cut_size, cut_size)."""
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        crops = []
        for _ in range(self.cutn):
            # RNG call order (size, x-offset, y-offset) matches the original
            # so seeded runs reproduce identical crops.
            size = int(torch.rand([]) ** self.cut_pow * (largest - smallest) + smallest)
            ox = torch.randint(0, width - size + 1, ())
            oy = torch.randint(0, height - size + 1, ())
            patch = input[:, :, oy:oy + size, ox:ox + size]
            crops.append(F.adaptive_avg_pool2d(patch, self.cut_size))
        return torch.cat(crops)
# 32 random crops at CLIP's 224x224 input size; cut_pow=0.5 biases toward larger crops.
make_cutouts = MakeCutouts(224, 32, 0.5)
def embed_image(image):
    """CLIP-embed a batch of images via random cutouts.

    Returns a tensor reshaped to (cutouts, batch, embed_dim) using the
    module-level `make_cutouts` and `clip_model`.
    """
    batch = image.shape[0]
    crops = make_cutouts(image)
    crop_embeds = clip_model.embed_cutout(crops)
    return rearrange(crop_embeds, '(cc n) c -> cc n c', n=batch)
def run(timestring):
    """Optimize a StyleGAN3 latent toward the CLIP-space `targets`.

    Relies on module-level globals: seed, steps, targets, G, w_stds, device
    (plus embed_image / prompts_dist_loss / spherical_dist_loss helpers).
    Writes progress frames to samples/<timestring>/<iter>.jpg.
    """
    torch.manual_seed(seed)
    # Init
    # Sample 32 inits and choose the one closest to prompt
    with torch.no_grad():
        qs = []
        losses = []
        # 8 batches of 4 candidates = 32 initial latents; keep the best per batch.
        for _ in range(8):
            # q is expressed in units of per-layer W std around the mean w.
            q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
            images = G.synthesis(q * w_stds + G.mapping.w_avg)
            embeds = embed_image(images.add(1).div(2))
            loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
            i = torch.argmin(loss)
            qs.append(q[i])
            losses.append(loss[i])
        qs = torch.stack(qs)
        losses = torch.stack(losses)
        # Overall best candidate becomes the optimization starting point.
        i = torch.argmin(losses)
        q = qs[i].unsqueeze(0).requires_grad_()
        w_init = (q * w_stds + G.mapping.w_avg).detach().clone()
    # Sampling loop
    q_ema = q
    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
    loop = tqdm(range(steps))
    for i in loop:
        opt.zero_grad()
        w = q * w_stds + G.mapping.w_avg
        image = G.synthesis(w , noise_mode='const')
        embed = embed_image(image.add(1).div(2))
        # CLIP-guidance term plus a quadratic pull back toward the initial latent.
        loss = 0.1 * prompts_dist_loss(embed, targets, spherical_dist_loss).mean() + ((w - w_init) ** 2).mean()
        # loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
        # Exponential moving average of the latent for smoother saved frames.
        q_ema = q_ema * 0.9 + q * 0.1
        final_code = q_ema * w_stds + G.mapping.w_avg
        # Layers 6+ (finer details) are pinned to the initial latent.
        final_code[:,6:,:] = w_init[:,6:,:]
        image = G.synthesis(final_code, noise_mode='const')
        if i % 10 == 9 or i % 10 == 0:
            # display(TF.to_pil_image(tf(image)[0]))
            print(f"Image {i}/{steps} | Current loss: {loss}")
            pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1).cpu())
            os.makedirs(f'samples/{timestring}', exist_ok=True)
            pil_image.save(f'samples/{timestring}/{i:04}.jpg')
# ---- Script setup (module level) ----
device = torch.device('cuda:0')
print('Using device:', device, file=sys.stderr)
model_url = "./pretrained_models/stylegan3-r-afhqv2-512x512.pkl"
# NOTE(review): pickle.load executes arbitrary code from the file --
# only load checkpoints from trusted sources.
with open(model_url, 'rb') as fp:
    G = pickle.load(fp)['G_ema'].to(device)
# Per-layer std of W over 100k mapped samples; scales the latent search space.
zs = torch.randn([100000, G.mapping.z_dim], device=device)
w_stds = G.mapping(zs, None).std(0)
# Reset the synthesis input transform to identity (no translation/rotation).
m = make_transform([0,0], 0)
m = np.linalg.inv(m)
G.synthesis.input.transform.copy_(torch.from_numpy(m))
# ---- Run configuration ----
# audio_paths = "./audio/sweet-kitty-meow.wav"
#audio_paths = "./audio/dog-sad.wav"
audio_paths = "./audio/cartoon-voice-laugh.wav"
steps = 200  # optimization iterations in run()
seed = 14 + 22
#seed = 22
# "|"-separated list of audio files; empty entries are dropped.
audio_paths = [frase.strip() for frase in audio_paths.split("|") if frase]
clip_model = CLIP()
audio_encoder = AudioEncoder()
# NOTE(review): checkpoint is assumed to be a state_dict matching AudioEncoder
# (conv + resnet18 with 512-d head) -- confirm against how it was saved.
audio_encoder.load_state_dict(copyStateDict(torch.load("./pretrained_models/resnet18.pth", map_location=device)))
audio_encoder = audio_encoder.to(device)
audio_encoder.eval()
targets = []
# Mel-spectrogram geometry used to build the audio-encoder input.
n_mels = 128
time_length = 864
resize_resolution = 512
# Build one normalized audio embedding per input file; these become the
# CLIP-space guidance targets for run().
for audio_path in audio_paths:
    y, sr = librosa.load(audio_path, sr=44100)
    audio_inputs = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
    # Log-scale and map roughly into [0, 1] (power_to_db floors at -80 dB here).
    audio_inputs = librosa.power_to_db(audio_inputs, ref=np.max) / 80.0 + 1
    zero = np.zeros((n_mels, time_length))
    h, w = audio_inputs.shape
    if w >= time_length:
        # Center-crop along the time axis to time_length frames.
        j = (w - time_length) // 2
        audio_inputs = audio_inputs[:,j:j+time_length]
    else:
        # NOTE(review): j is computed but unused -- the short clip is padded
        # left-aligned, not centered. Confirm whether zero[:, j:j+w] was intended.
        j = (time_length - w) // 2
        zero[:,:w] = audio_inputs[:,:w]
        audio_inputs = zero
    # NOTE(review): cv2.resize takes (width, height), so this yields an array of
    # shape (resize_resolution, n_mels); the reshape below reinterprets it as
    # (n_mels, resize_resolution) rather than transposing. Verify this matches
    # the preprocessing the pretrained encoder was trained with.
    audio_inputs = cv2.resize(audio_inputs, (n_mels, resize_resolution))
    audio_inputs = np.array([audio_inputs])
    audio_inputs = torch.from_numpy(audio_inputs.reshape((1, 1, n_mels, resize_resolution))).float().to(device)
    with torch.no_grad():
        audio_embedding = audio_encoder(audio_inputs)
        # L2-normalize so targets live on the unit sphere like CLIP embeddings.
        audio_embedding = audio_embedding / audio_embedding.norm(dim=-1, keepdim=True)
        targets.append(audio_embedding)
# Timestamped output directory name for this run, then kick off optimization.
timestring = time.strftime('%Y%m%d%H%M%S')
run(timestring)