[DALL-E 2] AIでテキストから画像を生成する

2022年8月11日木曜日

Artificial Intelligence

本記事では、機械学習手法Dalle2の再現実装を用いてテキストから画像を生成する方法をご紹介します。

eye catch
出典: DALL-E 2

DALL-E 2

概要

DALL-E 2はOpenAIが発表した画像合成モデルで、DALL-Eの後継となります。主にテキストから画像を生成するText to Imageタスクに使用されます。

自然言語のテキストを入力に、画像埋め込み(Image embedding)を出力するpriorモデルと、
画像埋め込みから、画像を生成するdecoderモデル2stageで構成されています。

priorモデルではCLIPを、decoderモデルでは拡散モデルがベースとして用いられています。
この2段階の構成により、生成画像の写実さと入力テキストとの意味的関連性の損失を抑えつつ、画像の多様性向上を実現しています。

overview
出典: Hierarchical Text-Conditional Image Generation with CLIP Latents

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、任意のテキストから画像を生成(Text to Image)していきます。

デモ(Colaboratory)

それでは、実際に動かしながらテキストから画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、本デモはこちらのHuggingfaceに公開された学習済みモデルを使用しますが、論文に掲載されたモデルと比較して1%以下のトレーニングしか実施されていません。これにより精度が低いため、忠実な画像生成ではなく、DALL-E 2の推論のさせ方のご参考としてご覧いただければ幸いです。

また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。
取得するSwinIRはDALL-E 2が生成した64x64の画像を256x256に拡大するために使用します。

# for SR
!git clone https://github.com/JingyunLiang/SwinIR.git

次にライブラリをインストールします。

!pip install -q dalle2_pytorch==0.15.4

# for swinIR
!pip install -q timm
!pip install -q opencv-python

最後にライブラリをインポートします。

import os
import json
from IPython.display import Image as IPythonImage

from PIL import Image
import numpy as np

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP, OpenAIClipAdapter

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

続いて、学習済みモデルをGoogle Colaboratoryにダウンロードします。
prior, decoderのそれぞれをダウンロードします。

!mkdir decoder prior

# Laion2B
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B/latest.pth \
      -O ./decoder/latest.pth
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/decoder_config.json \
      -O ./decoder/decoder_config.json

# prior
!wget -c https://huggingface.co/zenglishuci/conditioned-prior/resolve/main/vit-l-14/prior_aes_finetune.pth \
      -O ./prior/prior_aes_finetune.pth

学習済みモデルのロード

ダウンロードしたモデルをそれぞれロードします。
まず、CLIPとpriorをロードします。

decoder_model_path = './decoder/latest.pth'
decoder_config_path = './decoder/decoder_config.json'
prior_model_path = './prior/prior_aes_finetune.pth'
with open(decoder_config_path, "r") as f:
  decoder_config = json.load(f)
  
clip = OpenAIClipAdapter(decoder_config['decoder']['clip']['model']).to(device)

prior_network = DiffusionPriorNetwork(
    dim=768,
    depth=24,
    dim_head=64,
    heads=32,
    normformer=True,
    attn_dropout=5e-2,
    ff_dropout=5e-2,
    num_time_embeds=1,
    num_image_embeds=1,
    num_text_embeds=1,
    num_timesteps=1000,
    ff_mult=4
).to(device)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    image_embed_dim=768,
    timesteps=1000,
    cond_drop_prob=0.1,
    loss_type="l2",
    condition_on_text_encodings=True,
).to(device)

prior_state_dict = torch.load(prior_model_path, map_location='cpu')
diffusion_prior.load_state_dict(prior_state_dict['ema_model'], strict=True)

del prior_state_dict

次に、decoderをロードします。

unet = Unet(
  dim = decoder_config['decoder']['unets'][0]['dim'],
  cond_dim = decoder_config['decoder']['unets'][0]['cond_dim'],
  image_embed_dim = decoder_config['decoder']['unets'][0]['image_embed_dim'],
  text_embed_dim = decoder_config['decoder']['unets'][0]['text_embed_dim'],
  cond_on_text_encodings = decoder_config['decoder']['unets'][0]['cond_on_text_encodings'],
  channels = decoder_config['decoder']['unets'][0]['channels'],
  dim_mults = decoder_config['decoder']['unets'][0]['dim_mults'],
  num_resnet_blocks= decoder_config['decoder']['unets'][0]['num_resnet_blocks'],
  attn_heads= decoder_config['decoder']['unets'][0]['attn_heads'],
  attn_dim_head= decoder_config['decoder']['unets'][0]['attn_dim_head'],
  sparse_attn=decoder_config['decoder']['unets'][0]['sparse_attn'],
  memory_efficient=decoder_config['decoder']['unets'][0]['memory_efficient'],
  self_attn =decoder_config['decoder']['unets'][0]['self_attn'],
)

decoder = Decoder(
    unet = (unet),
    clip = clip,
    image_sizes = decoder_config['decoder']['image_sizes'],
    channels = decoder_config['decoder']['channels'],
    timesteps = decoder_config['decoder']['timesteps'],
    loss_type =decoder_config['decoder']['loss_type'],
    beta_schedule = decoder_config['decoder']['beta_schedule'],
    learned_variance =decoder_config['decoder']['learned_variance'],
).cuda()

decoder_state_dict = torch.load(decoder_model_path, map_location='cpu')
decoder.load_state_dict(decoder_state_dict, strict=False)
decoder.eval()

del decoder_state_dict

Text to Image

ロードしたモデルを使用して、Text to Imageを実行します。
モデルさえロードできてしまえば、あとはテキストをモデルに入力するのみです。

今回は'strawberry cake', 'tiger', 'desk'の3つを入力してみます。

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['strawberry cake', 'tiger', 'desk'],
    cond_scale = 3. # classifier free guidance strength (> 1 would strengthen the condition)
)

def show_images(np_images):
  for i, np_img in enumerate(np_images):
    image = Image.fromarray(np.uint8(np_img * 255))
    display(image)

def save_images(output_dir, np_images):
  os.makedirs(output_dir, exist_ok=True)
  for i, np_img in enumerate(np_images):
    image = Image.fromarray(np.uint8(np_img * 255))
    output_path = os.path.join(output_dir, f'{i}.png')
    image.save(output_path)

images_np = images.cpu().permute(0, 2, 3, 1).numpy()
save_images('./outputs', images_np)
show_images(images_np)

出力結果は以下の通りです。
かろうじて一枚目の画像がcakeの形をしていますが、残りのtiger、deskは関連性のない画像が生成されています。

results

Super Resolution

最後に、出力画像を拡大します。

!python SwinIR/main_test_swinir.py \
  --task real_sr \
  --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth \
  --folder_lq './outputs' \
  --scale 4 \
  --large_model
  
IPythonImage('/content/results/swinir_real_sr_x4_large/0_SwinIR.png')

出力結果は以下の通りです。
解像度を4倍に超解像した画像が生成されています。

result 256

まとめ

本記事では、DALL-E 2の再現実装を用いてテキストから画像を生成しました。
最新の手法ではありますが、十分にトレーニングされた学習済みモデルが存在しないため、その実力をコードベースでは見ることができない点が歯がゆいですね。

膨大なデータセットであるLAIONをトレーニングできるだけのハードウェアリソースをお持ちの方は論文に近い精度のモデルの生成が可能ですが、それ以外の場合学習済みモデルが公開されることを待つしかないかもしれません。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - Hierarchical Text-Conditional Image Generation with CLIP Latents

2. GitHub - lucidrains/DALLE2-pytorch

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology