本記事では、画像を生成するStyleGAN-XLと、画像とテキストの関連性スコアを算出するCLIPを用いて、テキストから画像を生成する方法をご紹介します。
StyleGAN-XL
概要
StyleGANは、画質と制御性に優れ画像生成タスクのスタンダードとも言えます。
しかしながら、ImageNetなどの大規模な非構造化データセットにおいて、トレーニング時のパフォーマンスが大幅に低下するという問題がありました。
StyleGAN-XLでは、ProjectedGANパラダイムに従い、neural network
priorsとprogressive growing strategyを活用し、ImageNetを用いたStyleGAN3
generatorの正常なトレーニングを実現しました。
この結果ImageNetやFFHQ,
CIFARなどのデータセットでSOTAを達成し、大規模データセットで1024x1024の解像度の画像を生成する最初のモデルとなっています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、Text to Imageタスクを実行する方法をご紹介します。
デモ(Colaboratory)
それでは、実際に動かしながらテキストから画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
%cd /content
!git clone https://github.com/autonomousvision/stylegan_xl
!git clone https://github.com/openai/CLIP
次にライブラリをインストールします。
%cd /content
!pip install -e ./CLIP
!pip install einops ninja
!pip install timm
最後にライブラリをインポートします。
%cd /content
import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan_xl')
import io
import os, time, glob
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
import unicodedata
import re
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange
from google.colab import files
import dnnlib
import legacy
from base64 import b64encode
from IPython.display import HTML
device = torch.device('cuda:0')
print('Using device:', device, file=sys.stderr)
以上で環境セットアップは完了です。
クラス・関数定義
次に、画像やテキストの関連度を算出するためエンコーダーなどを定義しておきます。
def fetch(url_or_path):
if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
r = requests.get(url_or_path)
r.raise_for_status()
fd = io.BytesIO()
fd.write(r.content)
fd.seek(0)
return fd
return open(url_or_path, 'rb')
def fetch_model(url_or_path):
!wget -c '{url_or_path}'
def slugify(value, allow_unicode=False):
value = str(value)
if allow_unicode:
value = unicodedata.normalize('NFKC', value)
else:
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
value = re.sub(r'[^\w\s-]', '', value.lower())
return re.sub(r'[-\s]+', '-', value).strip('-_')
def norm1(prompt):
"Normalize to the unit sphere."
return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()
def spherical_dist_loss(x, y):
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
def prompts_dist_loss(x, targets, loss):
if len(targets) == 1: # Keeps consitent results vs previous method for single objective guidance
return loss(x, targets[0])
distances = [loss(x, target) for target in targets]
return torch.stack(distances, dim=-1).sum(dim=-1)
def embed_image(image):
n = image.shape[0]
cutouts = make_cutouts(image)
embeds = clip_model.embed_cutout(cutouts)
embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
return embeds
def embed_url(url):
image = Image.open(fetch(url)).convert('RGB')
return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
class MakeCutouts(torch.nn.Module):
def __init__(self, cut_size, cutn, cut_pow=1.):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.cut_pow = cut_pow
def forward(self, input):
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
cutouts = []
for _ in range(self.cutn):
size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
return torch.cat(cutouts)
class CLIP(object):
def __init__(self):
clip_model = "ViT-B/16"
self.model, _ = clip.load(clip_model)
self.model = self.model.requires_grad_(False)
self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
@torch.no_grad()
def embed_text(self, prompt):
"Normalized clip text embedding."
return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
def embed_cutout(self, image):
"Normalized clip image embedding."
return norm1(self.model.encode_image(self.normalize(image)))
make_cutouts = MakeCutouts(224, 32, 0.5)
clip_model = CLIP()
def fetch_model(url_or_path):
!wget -c '{url_or_path}'
学習済みモデルのセットアップ
次に学習済みモデルをセットアップします。
まずは、Imagenetでトレーニングされたモデルをロードします。
Model = 'Imagenet' #@param ["Imagenet", "Pokemon", "FFHQ"]
network_url = {
"Imagenet": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl",
"Pokemon": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl",
"FFHQ": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl"
}
network_name = network_url[Model].split("/")[-1]
fetch_model(network_url[Model])
with dnnlib.util.open_url(network_name) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
zs = torch.randn([10000, G.mapping.z_dim], device=device)
cs = torch.zeros([10000, G.mapping.c_dim], device=device)
if G.mapping.c_dim != 0:
for i in range(cs.shape[0]):
cs[i,i//10]=1
w_stds = G.mapping(zs, cs)
w_stds = w_stds.reshape(10, 1000, G.num_ws, -1)
w_stds=w_stds.std(0).mean(0)[0]
w_all_classes_avg = G.mapping.w_avg.mean(0)
テキスト設定
入力となるテキストや、最適化ステップ数、Seedなどを定義します。
今回はピラミッド状の犬という現実ではありえないテキストを入力してみます。
texts = "Pyramid dog"#@param {type:"string"}
steps = 500#@param {type:"number"}
seed = 12#@param {type:"number"}
if seed == -1:
seed = np.random.randint(0,9e9)
print(f"Your random seed is: {seed}")
texts = [frase.strip() for frase in texts.split("|") if frase]
targets = [clip_model.embed_text(text) for text in texts]
tf = Compose([
# Resize(224),
lambda x: torch.clamp((x+1)/2,min=0,max=1),
])
initial_batch=4 #actually that will be multiplied by initial_image_steps
initial_image_steps=32
Text to Image
それでは、テキストに沿って画像を生成させてみます。
def run(out_dir):
torch.manual_seed(seed)
with torch.no_grad():
qs = []
losses = []
for _ in range(initial_image_steps):
a = torch.randn([initial_batch, 512], device=device)*0.4 + w_stds*0.4
q = ((a-w_all_classes_avg)/w_stds)
images = G.synthesis((q * w_stds + w_all_classes_avg).unsqueeze(1).repeat([1, G.num_ws, 1]))
embeds = embed_image(images.add(1).div(2))
loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
i = torch.argmin(loss)
qs.append(q[i])
losses.append(loss[i])
qs = torch.stack(qs)
losses = torch.stack(losses)
# print(losses)
# print(losses.shape, qs.shape)
i = torch.argmin(losses)
q = qs[i].unsqueeze(0).repeat([G.num_ws, 1]).requires_grad_()
# Sampling loop
q_ema = q
print(q.shape)
opt = torch.optim.AdamW([q], lr=0.05, betas=(0., 0.999), weight_decay=0.025)
loop = tqdm(range(steps))
for i in loop:
opt.zero_grad()
w = q * w_stds
image = G.synthesis((q * w_stds + w_all_classes_avg)[None], noise_mode='const')
embed = embed_image(image.add(1).div(2))
loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
loss.backward()
opt.step()
loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
q_ema = q_ema * 0.98 + q * 0.02
image = G.synthesis((q_ema * w_stds + w_all_classes_avg)[None], noise_mode='const')
if i % 50 == 0:
display(TF.to_pil_image(tf(image)[0]))
print(f"Image {i}/{steps} | Current loss: {loss}")
pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
os.makedirs(f'{out_dir}', exist_ok=True)
pil_image.save(f'{out_dir}/{i:04}.jpg')
run('./outputs')
!ffmpeg -r 60 -i 'outputs/%04d.jpg' -vcodec libx264 -crf 18 -pix_fmt yuv420p result.mp4
def show_mp4(filename, width):
mp4 = open(filename + '.mp4', 'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
display(HTML("""
<video width="%d" controls autoplay loop>
<source src="%s" type="video/mp4">
</video>
""" % (width, data_url)))
show_mp4("result", width=514)
出力結果は以下の通りです。
少しおどろおどろしい様相の画像になりましたが、テキストが反映された画像ではありそうです。
最後にstepごとの画像を動画にした結果は以下の通りです。
最初にピラミッドが作られ、その後、犬が現れる様子が見て取れます。
まとめ
本記事では、StyleGAN-XLとCLIPを用いてテキストから画像を生成する方法をご紹介しました。
seedを変えるだけで初期画像が変わり、生成結果も変わってくるので試行錯誤しながら様々な画像生成を楽しむことができます。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1. 論文 - StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets
2. GitHub - autonomousvision/stylegan_xl
0 件のコメント :
コメントを投稿