[StyleGAN-XL] GAN+CLIPでテキストから画像を生成する

2022年6月13日月曜日

Artificial Intelligence

本記事では、画像を生成するStyleGAN-XLと、画像とテキストの関連性スコアを算出するCLIPを用いて、テキストから画像を生成する方法をご紹介します。

アイキャッチ
出典: autonomousvision/stylegan_xl

StyleGAN-XL

概要

StyleGANは、画質と制御性に優れ画像生成タスクのスタンダードとも言えます。
しかしながら、ImageNetなどの大規模な非構造化データセットにおいて、トレーニング時のパフォーマンスが大幅に低下するという問題がありました。

StyleGAN-XLでは、ProjectedGANパラダイムに従い、neural network priorsとprogressive growing strategyを活用し、ImageNetを用いたStyleGAN3 generatorの正常なトレーニングを実現しました。
この結果ImageNetやFFHQ, CIFARなどのデータセットでSOTAを達成し、大規模データセットで1024x1024の解像度の画像を生成する最初のモデルとなっています。

overview
出典: StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、Text to Imageタスクを実行する方法をご紹介します。

デモ(Colaboratory)

それでは、実際に動かしながらテキストから画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/autonomousvision/stylegan_xl
!git clone https://github.com/openai/CLIP

次にライブラリをインストールします。

%cd /content

!pip install -e ./CLIP
!pip install einops ninja
!pip install timm

最後にライブラリをインポートします。

%cd /content

import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan_xl')

import io
import os, time, glob
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
import unicodedata
import re
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange
from google.colab import files
import dnnlib
import legacy
from base64 import b64encode
from IPython.display import HTML

device = torch.device('cuda:0')
print('Using device:', device, file=sys.stderr)

以上で環境セットアップは完了です。

クラス・関数定義

次に、画像やテキストの関連度を算出するためエンコーダーなどを定義しておきます。

def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')

def fetch_model(url_or_path):
    !wget -c '{url_or_path}'

def slugify(value, allow_unicode=False):

    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value.lower())
    return re.sub(r'[-\s]+', '-', value).strip('-_')

def norm1(prompt):
    "Normalize to the unit sphere."
    return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1: # Keeps consitent results vs previous method for single objective guidance 
      return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)

def embed_image(image):
  n = image.shape[0]
  cutouts = make_cutouts(image)
  embeds = clip_model.embed_cutout(cutouts)
  embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
  return embeds

def embed_url(url):
  image = Image.open(fetch(url)).convert('RGB')
  return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
  
class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)

class CLIP(object):
  def __init__(self):
    clip_model = "ViT-B/16"
    self.model, _ = clip.load(clip_model)
    self.model = self.model.requires_grad_(False)
    self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                          std=[0.26862954, 0.26130258, 0.27577711])

  @torch.no_grad()
  def embed_text(self, prompt):
      "Normalized clip text embedding."
      return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())

  def embed_cutout(self, image):
      "Normalized clip image embedding."
      return norm1(self.model.encode_image(self.normalize(image)))
  
make_cutouts = MakeCutouts(224, 32, 0.5)
clip_model = CLIP()

def fetch_model(url_or_path):
  !wget -c '{url_or_path}'

学習済みモデルのセットアップ

次に学習済みモデルをセットアップします。
まずは、Imagenetでトレーニングされたモデルをロードします。

Model = 'Imagenet' #@param ["Imagenet", "Pokemon", "FFHQ"]

network_url = {
    "Imagenet": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl",
    "Pokemon": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl",
    "FFHQ": "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl"
}

network_name = network_url[Model].split("/")[-1]
fetch_model(network_url[Model])

with dnnlib.util.open_url(network_name) as f:
  G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

zs = torch.randn([10000, G.mapping.z_dim], device=device)
cs = torch.zeros([10000, G.mapping.c_dim], device=device)

if G.mapping.c_dim != 0:
  for i in range(cs.shape[0]):
    cs[i,i//10]=1

w_stds = G.mapping(zs, cs)
w_stds = w_stds.reshape(10, 1000, G.num_ws, -1)
w_stds=w_stds.std(0).mean(0)[0]
w_all_classes_avg = G.mapping.w_avg.mean(0)

テキスト設定

入力となるテキストや、最適化ステップ数、Seedなどを定義します。
今回はピラミッド状の犬という現実ではありえないテキストを入力してみます。

texts = "Pyramid dog"#@param {type:"string"}
steps = 500#@param {type:"number"}
seed = 12#@param {type:"number"}

if seed == -1:
    seed = np.random.randint(0,9e9)
    print(f"Your random seed is: {seed}")

texts = [frase.strip() for frase in texts.split("|") if frase]

targets = [clip_model.embed_text(text) for text in texts]

tf = Compose([
  # Resize(224),
  lambda x: torch.clamp((x+1)/2,min=0,max=1),
])

initial_batch=4 #actually that will be multiplied by initial_image_steps
initial_image_steps=32

Text to Image

それでは、テキストに沿って画像を生成させてみます。

def run(out_dir):
  torch.manual_seed(seed)
  with torch.no_grad():
    qs = []
    losses = []
    for _ in range(initial_image_steps):
      a = torch.randn([initial_batch, 512], device=device)*0.4 + w_stds*0.4
      q = ((a-w_all_classes_avg)/w_stds)
      images = G.synthesis((q * w_stds + w_all_classes_avg).unsqueeze(1).repeat([1, G.num_ws, 1]))
      embeds = embed_image(images.add(1).div(2))
      loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
      i = torch.argmin(loss)
      qs.append(q[i])
      losses.append(loss[i])
    qs = torch.stack(qs)
    losses = torch.stack(losses)
    # print(losses)
    # print(losses.shape, qs.shape)
    i = torch.argmin(losses)
    q = qs[i].unsqueeze(0).repeat([G.num_ws, 1]).requires_grad_()


  # Sampling loop
  q_ema = q
  print(q.shape)
  opt = torch.optim.AdamW([q], lr=0.05, betas=(0., 0.999), weight_decay=0.025)
  loop = tqdm(range(steps))
  for i in loop:
    opt.zero_grad()
    w = q * w_stds
    image = G.synthesis((q * w_stds + w_all_classes_avg)[None], noise_mode='const')
    embed = embed_image(image.add(1).div(2))
    loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
    loss.backward()
    opt.step()
    loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())

    q_ema = q_ema * 0.98 + q * 0.02
    image = G.synthesis((q_ema * w_stds + w_all_classes_avg)[None], noise_mode='const')

    if i % 50 == 0:
      display(TF.to_pil_image(tf(image)[0]))
      print(f"Image {i}/{steps} | Current loss: {loss}")
    pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
    os.makedirs(f'{out_dir}', exist_ok=True)
    pil_image.save(f'{out_dir}/{i:04}.jpg')
    
run('./outputs')

!ffmpeg -r 60 -i 'outputs/%04d.jpg' -vcodec libx264 -crf 18 -pix_fmt yuv420p result.mp4

def show_mp4(filename, width):
  mp4 = open(filename + '.mp4', 'rb').read()
  data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
  display(HTML("""
  <video width="%d" controls autoplay loop>
    <source src="%s" type="video/mp4">
  </video>
  """ % (width, data_url)))
  
show_mp4("result", width=514)

出力結果は以下の通りです。

Pyramid dog

少しおどろおどろしい様相の画像になりましたが、テキストが反映された画像ではありそうです。
最後にstepごとの画像を動画にした結果は以下の通りです。

pyramid dog gif

最初にピラミッドが作られ、その後、犬が現れる様子が見て取れます。

まとめ

本記事では、StyleGAN-XLとCLIPを用いてテキストから画像を生成する方法をご紹介しました。
seedを変えるだけで初期画像が変わり、生成結果も変わってくるので試行錯誤しながら様々な画像生成を楽しむことができます。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1. 論文 - StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets

2. GitHub - autonomousvision/stylegan_xl

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology