本記事では、機械学習手法Dalle2の再現実装を用いてテキストから画像を生成する方法をご紹介します。
DALL-E 2
概要
DALL-E
2はOpenAIが発表した画像合成モデルで、DALL-Eの後継となります。主にテキストから画像を生成するText
to Imageタスクに使用されます。
自然言語のテキストを入力に、画像埋め込み(Image
embedding)を出力するpriorモデルと、
画像埋め込みから、画像を生成するdecoderモデルの2stageで構成されています。
priorモデルではCLIPを、decoderモデルでは拡散モデルがベースとして用いられています。
この2段階の構成により、生成画像の写実さと入力テキストとの意味的関連性の損失を抑えつつ、画像の多様性向上を実現しています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、任意のテキストから画像を生成(Text to
Image)していきます。
デモ(Colaboratory)
それでは、実際に動かしながらテキストから画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、本デモはこちらのHuggingfaceに公開された学習済みモデルを使用しますが、論文に掲載されたモデルと比較して1%以下のトレーニングしか実施されていません。これにより精度が低いため、忠実な画像生成ではなく、DALL-E
2の推論のさせ方のご参考としてご覧いただければ幸いです。
また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
取得するSwinIRはDALL-E
2が生成した64x64の画像を256x256に拡大するために使用します。
# for SR
!git clone https://github.com/JingyunLiang/SwinIR.git
次にライブラリをインストールします。
!pip install -q dalle2_pytorch==0.15.4
# for swinIR
!pip install -q timm
!pip install -q opencv-python
最後にライブラリをインポートします。
import os
import json
from IPython.display import Image as IPythonImage
from PIL import Image
import numpy as np
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP, OpenAIClipAdapter
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
以上で環境セットアップは完了です。
学習済みモデルのセットアップ
続いて、学習済みモデルをGoogle Colaboratoryにダウンロードします。
prior, decoderのそれぞれをダウンロードします。
!mkdir decoder prior
# Laion2B
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B/latest.pth \
-O ./decoder/latest.pth
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/decoder_config.json \
-O ./decoder/decoder_config.json
# prior
!wget -c https://huggingface.co/zenglishuci/conditioned-prior/resolve/main/vit-l-14/prior_aes_finetune.pth \
-O ./prior/prior_aes_finetune.pth
学習済みモデルのロード
ダウンロードしたモデルをそれぞれロードします。
まず、CLIPとpriorをロードします。
decoder_model_path = './decoder/latest.pth'
decoder_config_path = './decoder/decoder_config.json'
prior_model_path = './prior/prior_aes_finetune.pth'
with open(decoder_config_path, "r") as f:
decoder_config = json.load(f)
clip = OpenAIClipAdapter(decoder_config['decoder']['clip']['model']).to(device)
prior_network = DiffusionPriorNetwork(
dim=768,
depth=24,
dim_head=64,
heads=32,
normformer=True,
attn_dropout=5e-2,
ff_dropout=5e-2,
num_time_embeds=1,
num_image_embeds=1,
num_text_embeds=1,
num_timesteps=1000,
ff_mult=4
).to(device)
diffusion_prior = DiffusionPrior(
net=prior_network,
clip=clip,
image_embed_dim=768,
timesteps=1000,
cond_drop_prob=0.1,
loss_type="l2",
condition_on_text_encodings=True,
).to(device)
prior_state_dict = torch.load(prior_model_path, map_location='cpu')
diffusion_prior.load_state_dict(prior_state_dict['ema_model'], strict=True)
del prior_state_dict
次に、decoderをロードします。
unet = Unet(
dim = decoder_config['decoder']['unets'][0]['dim'],
cond_dim = decoder_config['decoder']['unets'][0]['cond_dim'],
image_embed_dim = decoder_config['decoder']['unets'][0]['image_embed_dim'],
text_embed_dim = decoder_config['decoder']['unets'][0]['text_embed_dim'],
cond_on_text_encodings = decoder_config['decoder']['unets'][0]['cond_on_text_encodings'],
channels = decoder_config['decoder']['unets'][0]['channels'],
dim_mults = decoder_config['decoder']['unets'][0]['dim_mults'],
num_resnet_blocks= decoder_config['decoder']['unets'][0]['num_resnet_blocks'],
attn_heads= decoder_config['decoder']['unets'][0]['attn_heads'],
attn_dim_head= decoder_config['decoder']['unets'][0]['attn_dim_head'],
sparse_attn=decoder_config['decoder']['unets'][0]['sparse_attn'],
memory_efficient=decoder_config['decoder']['unets'][0]['memory_efficient'],
self_attn =decoder_config['decoder']['unets'][0]['self_attn'],
)
decoder = Decoder(
unet = (unet),
clip = clip,
image_sizes = decoder_config['decoder']['image_sizes'],
channels = decoder_config['decoder']['channels'],
timesteps = decoder_config['decoder']['timesteps'],
loss_type =decoder_config['decoder']['loss_type'],
beta_schedule = decoder_config['decoder']['beta_schedule'],
learned_variance =decoder_config['decoder']['learned_variance'],
).cuda()
decoder_state_dict = torch.load(decoder_model_path, map_location='cpu')
decoder.load_state_dict(decoder_state_dict, strict=False)
decoder.eval()
del decoder_state_dict
Text to Image
ロードしたモデルを使用して、Text to Imageを実行します。
モデルさえロードできてしまえば、あとはテキストをモデルに入力するのみです。
今回は'strawberry cake', 'tiger', 'desk'
の3つを入力してみます。
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['strawberry cake', 'tiger', 'desk'],
cond_scale = 3. # classifier free guidance strength (> 1 would strengthen the condition)
)
def show_images(np_images):
for i, np_img in enumerate(np_images):
image = Image.fromarray(np.uint8(np_img * 255))
display(image)
def save_images(output_dir, np_images):
os.makedirs(output_dir, exist_ok=True)
for i, np_img in enumerate(np_images):
image = Image.fromarray(np.uint8(np_img * 255))
output_path = os.path.join(output_dir, f'{i}.png')
image.save(output_path)
images_np = images.cpu().permute(0, 2, 3, 1).numpy()
save_images('./outputs', images_np)
show_images(images_np)
出力結果は以下の通りです。
かろうじて一枚目の画像がcakeの形をしていますが、残りのtiger、deskは関連性のない画像が生成されています。
Super Resolution
最後に、出力画像を拡大します。
!python SwinIR/main_test_swinir.py \
--task real_sr \
--model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth \
--folder_lq './outputs' \
--scale 4 \
--large_model
IPythonImage('/content/results/swinir_real_sr_x4_large/0_SwinIR.png')
出力結果は以下の通りです。
解像度を4倍に超解像した画像が生成されています。
まとめ
本記事では、DALL-E 2の再現実装を用いてテキストから画像を生成しました。
最新の手法ではありますが、十分にトレーニングされた学習済みモデルが存在しないため、その実力をコードベースでは見ることができない点が歯がゆいですね。
膨大なデータセットであるLAIONをトレーニングできるだけのハードウェアリソースをお持ちの方は論文に近い精度のモデルの生成が可能ですが、それ以外の場合学習済みモデルが公開されることを待つしかないかもしれません。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1.
論文 - Hierarchical Text-Conditional Image Generation with CLIP Latents
2. GitHub - lucidrains/DALLE2-pytorch
0 件のコメント :
コメントを投稿