[UniDiffuser] 一つのモデルで複数のタスクに対応するdiffusionモデル

2023年3月19日日曜日

Artificial Intelligence

本記事では、UniDiffuserと呼ばれる機械学習手法を用いて画像生成や画像編集などを行う方法をご紹介します。

アイキャッチ
出典: One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale

UniDiffuser

概要

UniDiffuserは、一つのモデルで、複数のタスクに適応するマルチモーダルな拡散モデルです。

UniDiffuserは、拡散モデルに最低限の変更を加えることで周辺分布、条件付き分布、結合分布(marginal, conditional, and joint)のすべてを同時に学習します。

大規模な画像とテキストのペアデータセットでトレーニングされたUniDIffuserは、追加のオーバヘッドなしで適切なタイムステップを設定することにより、
画像生成、テキスト生成、テキストから画像の生成、画像とテキストのペアの生成を実行でき、かつ、すべてのタスクでFIDやCLIPスコアなどの定量的な結果において、既存の汎用モデルより優れていることが論文では示されています。

Architecture
出典: One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、Image Variation(Image to Text to Image)を動かしてみます。

デモ(Colaboratory)

それでは、実際に動かしながらi2t2iを動かしていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/thu-ml/unidiffuser.git
!git clone https://github.com/openai/CLIP.git

%cd /content/unidiffuser
# Commits on Mar 15, 2023
!git checkout 4450333d2d95774ad1e6951a5d32016c0be7332d

%cd /content/CLIP
# Commits on Feb 20, 2023
!git checkout a9b1bf5920416aaeaec965c25dd9e8f98c864f16

次にライブラリをインストールします。

%cd /content/CLIP
!pip install -e .

%cd /content/unidiffuser
!pip install accelerate==0.12.0 absl-py ml_collections einops ftfy==6.1.1 transformers==4.23.1
!pip install -U xformers
!pip install -U --pre triton

最後にライブラリをインポートします。

%cd /content/unidiffuser

import os
import sys
sys.path.append(".")
sys.path.append('/content/CLIP')

import ml_collections
import torch
import random
import utils
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import einops
import libs.autoencoder
import libs.clip
from torchvision.utils import save_image, make_grid
import torchvision.transforms as standard_transforms
import numpy as np
import clip
from PIL import Image

from libs.uvit_multi_post_ln_v1 import UViT
from libs.caption_decoder import CaptionDecoder


device = 'cuda' if torch.cuda.is_available() else 'cpu'

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

ここでは、論文発表元が公開する学習済みモデルをダウンロードしていきます。

%cd /content/unidiffuser

!mkdir models

!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/autoencoder_kl.pth \
      -O models/autoencoder_kl.pth
!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/caption_decoder.pth \
      -O models/caption_decoder.pth
!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/uvit_v1.pth \
      -O models/uvit_v1.pth

ダウンロードしたモデルをメモリ上にロードします。

# load UniDiffuser-v1.
nnet = UViT(
    img_size=64,
    in_chans=4,
    patch_size=2,
    embed_dim=1536,
    depth=30,
    num_heads=24,
    text_dim=64,
    num_text_tokens=77,
    clip_img_dim=512,
    use_checkpoint=True
)
nnet.to(device)
nnet.load_state_dict(torch.load('models/uvit_v1.pth', map_location='cpu'))
nnet.eval()

# load caption decoder
caption_decoder = CaptionDecoder(device=device, pretrained_path="models/caption_decoder.pth", hidden_dim=64)

clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)
clip_text_model.eval()
clip_text_model.to(device)

# load autoencoder
autoencoder = libs.autoencoder.get_model(pretrained_path='models/autoencoder_kl.pth')
autoencoder.to(device)

# load clip
clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)

@torch.cuda.amp.autocast()
def encode(_batch):
  return autoencoder.encode(_batch)

@torch.cuda.amp.autocast()
def decode(_batch):
  return autoencoder.decode(_batch)

Utility関数の定義

ここでは、各種Utility関数を定義しておきます。

def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
  _betas = (
      torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
  )
  return _betas.numpy()
_betas = stable_diffusion_beta_schedule()
N = len(_betas)

def split(x):
  C, H, W = 4, 64, 64
  z_dim = C * H * W
  z, clip_img = x.split([z_dim, 512], dim=1)
  z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
  clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=512)
  return z, clip_img

def combine(z, clip_img):
  z = einops.rearrange(z, 'B C H W -> B (C H W)')
  clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
  return torch.concat([z, clip_img], dim=-1)

def combine_joint(z, clip_img, text):
  z = einops.rearrange(z, 'B C H W -> B (C H W)')
  clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
  text = einops.rearrange(text, 'B L D -> B (L D)')
  return torch.concat([z, clip_img, text], dim=-1)

def split_joint(x):
  C, H, W = 4, 64, 64
  z_dim = C * H * W
  z, clip_img, text = x.split([z_dim, 512, 77 * 64], dim=1)
  z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
  clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=512)
  text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=64)
  return z, clip_img, text

def unpreprocess(v):  # to B C H W and [0, 1]
  v = 0.5 * (v + 1.)
  v.clamp_(0., 1.)
  return v


def set_seed(seed: int):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

def watermarking(save_path):
  img_pre = Image.open(save_path)
  img_pos = utils.add_water(img_pre)
  img_pos.save(save_path)

ハイパーパラメータの設定

ここでは、タスクや、テキストプロンプト、入力画像パスなどを設定します。

以下では、i2t2iタスク(画像のバリエーションを生成するタスク)を選択しています。

# 入力画像のダウンロード
!wget -c https://www.musey.net/wp-content/uploads/2014/08/Vincent-van-Gogh-Starry-Night_1920x1200-1024x640.jpg \
      -O ./assets/test_01.jpg

mode = "i2t2i" #@param ["t2i", "i2t", "joint", "i", "t", "t2i2t", "i2t2i"]
"""
t2i: text-to-image generation
i2t: image-to-text generation
joint: joint generation
i: image generation
t: text generation
t2i2t: text variation
i2t2i: image variation
"""
prompt = "" #@param {type:"string"}
img = './assets/test_01.jpg' #@param {type:"string"}
seed = 1234 #@param {type:"number"}
steps = 50 #@param {type:"slider", min:0, max:100, step:1}
cfg_scale = 8 #@param {type:"slider", min:0, max:10, step:0.1}
n_samples = 4 #@param {type:"number"}
nrow = 1 #@param {type:"number"}
data_type = 1
output_path = 'out' #@param {type:"string"}

タスクに応じて、テキストプロンプトや画像をロードします。

if mode == 't2i' or mode == 't2i2t':
  prompts = [ prompt ] * n_samples
  contexts = clip_text_model.encode(prompts)
  contexts_low_dim = caption_decoder.encode_prefix(contexts)
elif mode == 'i2t' or mode == 'i2t2i':
  from PIL import Image
  img_contexts = []
  clip_imgs = []

  def get_img_feature(image):
    image = np.array(image).astype(np.uint8)
    image = utils.center_crop(512, 512, image)
    clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
    
    image = (image / 127.5 - 1.0).astype(np.float32)
    image = einops.rearrange(image, 'h w c -> 1 c h w')
    image = torch.tensor(image, device=device)
    moments = autoencoder.encode_moments(image)
    
    return clip_img_feature, moments

  image = Image.open(img).convert('RGB')
  clip_img, img_context = get_img_feature(image)

  img_contexts.append(img_context)
  clip_imgs.append(clip_img)
  img_contexts = img_contexts * n_samples
  clip_imgs = clip_imgs * n_samples

  img_contexts = torch.concat(img_contexts, dim=0)
  z_img = autoencoder.sample(img_contexts)
  clip_imgs = torch.stack(clip_imgs, dim=0)

タスクごとのネットワーク定義

ここでは、それぞれのタスクに対応したネットワークを定義します。

def t2i_nnet(x, timesteps, text):  # text is the low dimension version of the text clip embedding
  """
  1. calculate the conditional model output
  2. calculate unconditional model output
      config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
      config.sample.t2i_cfg_mode == 'true_uncond: using the unconditional model learned by our method
  3. return linear combination of conditional output and unconditional output
  """
  z, clip_img = split(x)

  t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

  z_out, clip_img_out, text_out = nnet(
      z, clip_img, 
      text=text, t_img=timesteps, t_text=t_text,
      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + data_type)
  x_out = combine(z_out, clip_img_out)

  if cfg_scale == 0.:
    return x_out

  text_N = torch.randn_like(text)  # 3 other possible choices
  z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
      z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + data_type)
  x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)

  return x_out + cfg_scale * (x_out - x_out_uncond)

def i_nnet(x, timesteps):
  z, clip_img = split(x)
  text = torch.randn(x.size(0), 77, 64, device=device)
  t_text = torch.ones_like(timesteps) * N
  z_out, clip_img_out, text_out = nnet(
      z, clip_img, text=text, t_img=timesteps, t_text=t_text,
      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + data_type)
  x_out = combine(z_out, clip_img_out)
  return x_out

def t_nnet(x, timesteps):
  z = torch.randn(x.size(0), *[4, 64, 64], device=device)
  clip_img = torch.randn(x.size(0), 1, 512, device=device)
  z_out, clip_img_out, text_out = nnet(
      z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
      data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
  return text_out

def i2t_nnet(x, timesteps, z, clip_img):
  """
  1. calculate the conditional model output
  2. calculate unconditional model output
  3. return linear combination of conditional output and unconditional output
  """
  t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

  z_out, clip_img_out, text_out = nnet(
      z, clip_img, text=x, t_img=t_img, t_text=timesteps,
      data_type=torch.zeros_like(t_img, device=device, dtype=torch.int) + data_type)

  if cfg_scale == 0.:
    return text_out

  z_N = torch.randn_like(z)  # 3 other possible choices
  clip_img_N = torch.randn_like(clip_img)
  z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
      z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
      data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)

  return text_out + cfg_scale * (text_out - text_out_uncond)

def joint_nnet(x, timesteps):
  z, clip_img, text = split_joint(x)
  z_out, clip_img_out, text_out = nnet(
      z, clip_img, text=text, t_img=timesteps, t_text=timesteps,
      data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
  x_out = combine_joint(z_out, clip_img_out, text_out)

  if cfg_scale == 0.:
    return x_out

  z_noise = torch.randn(x.size(0), *(4, 64, 64), device=device)
  clip_img_noise = torch.randn(x.size(0), 1, 512, device=device)
  text_noise = torch.randn(x.size(0), 77, 64, device=device)

  _, _, text_out_uncond = nnet(
      z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
      data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
  z_out_uncond, clip_img_out_uncond, _ = nnet(
      z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
      data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)

  x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)

  return x_out + cfg_scale * (x_out - x_out_uncond)


def sample_fn(mode, **kwargs):
  _z_init = torch.randn(n_samples, *(4, 64, 64), device=device)
  _clip_img_init = torch.randn(n_samples, 1, 512, device=device)
  _text_init = torch.randn(n_samples, 77, 64, device=device)
  if mode == 'joint':
    _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
  elif mode in ['t2i', 'i']:
    _x_init = combine(_z_init, _clip_img_init)
  elif mode in ['i2t', 't']:
    _x_init = _text_init
  noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

  def model_fn(x, t_continuous):
    t = t_continuous * N
    if mode == 'joint':
      return joint_nnet(x, t)
    elif mode == 't2i':
      return t2i_nnet(x, t, **kwargs)
    elif mode == 'i2t':
      return i2t_nnet(x, t, **kwargs)
    elif mode == 'i':
      return i_nnet(x, t)
    elif mode == 't':
      return t_nnet(x, t)

  dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
  with torch.no_grad():
    with torch.autocast(device_type=device):
      x = dpm_solver.sample(_x_init, steps=steps, eps=1. / N, T=1.)

  os.makedirs(output_path, exist_ok=True)
  if mode == 'joint':
    _z, _clip_img, _text = split_joint(x)
    return _z, _clip_img, _text
  elif mode in ['t2i', 'i']:
    _z, _clip_img = split(x)
    return _z, _clip_img
  elif mode in ['i2t', 't']:
    return x

Sampling

それでは、設定したタスクのサンプリングを実行し画像を生成してみます。

set_seed(seed)

def show(path):
  samples = Image.open(path)
  display(samples)
  


if mode in ['joint']:
  _z, _clip_img, _text = sample_fn(mode)
  samples = unpreprocess(decode(_z))
  prompts = caption_decoder.generate_captions(_text)
  os.makedirs(os.path.join(output_path, mode), exist_ok=True)
  print(prompts)
  with open(os.path.join(output_path, mode, 'prompts.txt'), 'w') as f:
    print('\n'.join(prompts), file=f)
  for idx, sample in enumerate(samples):
    save_path = os.path.join(output_path, mode, f'{idx}.png')
    save_image(sample, save_path)
    watermarking(save_path)
  # save a grid of generated images
  samples_pos = []
  for idx, sample in enumerate(samples):
    sample_pil = standard_transforms.ToPILImage()(sample)
    sample_pil = utils.add_water(sample_pil)
    sample = standard_transforms.ToTensor()(sample_pil)
    samples_pos.append(sample)
  samples = make_grid(samples_pos, nrow)
  save_path = os.path.join(output_path, mode, f'grid.png')
  save_image(samples, save_path)
  show(save_path)

elif mode in ['t2i', 'i', 'i2t2i']:
  if mode == 't2i':
    _z, _clip_img = sample_fn(mode, text=contexts_low_dim)  # conditioned on the text embedding
  elif mode == 'i':
    _z, _clip_img = sample_fn(mode)
  elif mode == 'i2t2i':
    _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)  # conditioned on the image embedding
    _z, _clip_img = sample_fn('t2i', text=_text)
  samples = unpreprocess(decode(_z))
  os.makedirs(os.path.join(output_path, mode), exist_ok=True)
  for idx, sample in enumerate(samples):
    save_path = os.path.join(output_path, mode, f'{idx}.png')
    save_image(sample, save_path)
    watermarking(save_path)
  # save a grid of generated images
  samples_pos = []
  for idx, sample in enumerate(samples):
    sample_pil = standard_transforms.ToPILImage()(sample)
    sample_pil = utils.add_water(sample_pil)
    sample = standard_transforms.ToTensor()(sample_pil)
    samples_pos.append(sample)
  samples = make_grid(samples_pos, nrow)
  save_path = os.path.join(output_path, mode, f'grid.png')
  save_image(samples, save_path)
  show(save_path)

elif mode in ['i2t', 't', 't2i2t']:
  if mode == 'i2t':
    _text = sample_fn(mode, z=z_img, clip_img=clip_imgs)  # conditioned on the image embedding
  elif mode == 't':
    _text = sample_fn(mode)
  elif mode == 't2i2t':
    _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
    _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
  samples = caption_decoder.generate_captions(_text)
  print(samples)
  logging.info(samples)
  os.makedirs(os.path.join(output_path, mode), exist_ok=True)
  with open(os.path.join(output_path, mode, f'{mode}.txt'), 'w') as f:
    print('\n'.join(samples), file=f)

出力結果は以下の通りです。
入力画像から、テキストを生成し、そのテキストから画像が生成されています。

variation image

まとめ

本記事では、UniDIffuserを用いて、様々なタスクを実行する方法をご紹介しました。
他のタスクを実行したい場合mode = "i2t2i"を切り替え適切なプロンプトを設定するだけで実行することができます。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale

2. GitHub - thu-ml/unidiffuser

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology