本記事では、UniDiffuserと呼ばれる機械学習手法を用いて画像生成や画像編集などを行う方法をご紹介します。
UniDiffuser
概要
UniDiffuserは、一つのモデルで、複数のタスクに適応するマルチモーダルな拡散モデルです。
UniDiffuserは、拡散モデルに最低限の変更を加えることで周辺分布、条件付き分布、結合分布(marginal, conditional, and joint)のすべてを同時に学習します。
大規模な画像とテキストのペアデータセットでトレーニングされたUniDIffuserは、追加のオーバヘッドなしで適切なタイムステップを設定することにより、
画像生成、テキスト生成、テキストから画像の生成、画像とテキストのペアの生成を実行でき、かつ、すべてのタスクでFIDやCLIPスコアなどの定量的な結果において、既存の汎用モデルより優れていることが論文では示されています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、Image Variation(Image to Text to Image)を動かしてみます。
デモ(Colaboratory)
それでは、実際に動かしながらi2t2iを動かしていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
%cd /content
!git clone https://github.com/thu-ml/unidiffuser.git
!git clone https://github.com/openai/CLIP.git
%cd /content/unidiffuser
# Commits on Mar 15, 2023
!git checkout 4450333d2d95774ad1e6951a5d32016c0be7332d
%cd /content/CLIP
# Commits on Feb 20, 2023
!git checkout a9b1bf5920416aaeaec965c25dd9e8f98c864f16
次にライブラリをインストールします。
%cd /content/CLIP
!pip install -e .
%cd /content/unidiffuser
!pip install accelerate==0.12.0 absl-py ml_collections einops ftfy==6.1.1 transformers==4.23.1
!pip install -U xformers
!pip install -U --pre triton
最後にライブラリをインポートします。
%cd /content/unidiffuser
import os
import sys
sys.path.append(".")
sys.path.append('/content/CLIP')
import ml_collections
import torch
import random
import utils
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import einops
import libs.autoencoder
import libs.clip
from torchvision.utils import save_image, make_grid
import torchvision.transforms as standard_transforms
import numpy as np
import clip
from PIL import Image
from libs.uvit_multi_post_ln_v1 import UViT
from libs.caption_decoder import CaptionDecoder
device = 'cuda' if torch.cuda.is_available() else 'cpu'
以上で環境セットアップは完了です。
学習済みモデルのセットアップ
ここでは、論文発表元が公開する学習済みモデルをダウンロードしていきます。
%cd /content/unidiffuser
!mkdir models
!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/autoencoder_kl.pth \
-O models/autoencoder_kl.pth
!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/caption_decoder.pth \
-O models/caption_decoder.pth
!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/uvit_v1.pth \
-O models/uvit_v1.pth
ダウンロードしたモデルをメモリ上にロードします。
# load UniDiffuser-v1.
nnet = UViT(
img_size=64,
in_chans=4,
patch_size=2,
embed_dim=1536,
depth=30,
num_heads=24,
text_dim=64,
num_text_tokens=77,
clip_img_dim=512,
use_checkpoint=True
)
nnet.to(device)
nnet.load_state_dict(torch.load('models/uvit_v1.pth', map_location='cpu'))
nnet.eval()
# load caption decoder
caption_decoder = CaptionDecoder(device=device, pretrained_path="models/caption_decoder.pth", hidden_dim=64)
clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)
clip_text_model.eval()
clip_text_model.to(device)
# load autoencoder
autoencoder = libs.autoencoder.get_model(pretrained_path='models/autoencoder_kl.pth')
autoencoder.to(device)
# load clip
clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)
@torch.cuda.amp.autocast()
def encode(_batch):
return autoencoder.encode(_batch)
@torch.cuda.amp.autocast()
def decode(_batch):
return autoencoder.decode(_batch)
Utility関数の定義
ここでは、各種Utility関数を定義しておきます。
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
_betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
return _betas.numpy()
_betas = stable_diffusion_beta_schedule()
N = len(_betas)
def split(x):
C, H, W = 4, 64, 64
z_dim = C * H * W
z, clip_img = x.split([z_dim, 512], dim=1)
z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=512)
return z, clip_img
def combine(z, clip_img):
z = einops.rearrange(z, 'B C H W -> B (C H W)')
clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
return torch.concat([z, clip_img], dim=-1)
def combine_joint(z, clip_img, text):
z = einops.rearrange(z, 'B C H W -> B (C H W)')
clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
text = einops.rearrange(text, 'B L D -> B (L D)')
return torch.concat([z, clip_img, text], dim=-1)
def split_joint(x):
C, H, W = 4, 64, 64
z_dim = C * H * W
z, clip_img, text = x.split([z_dim, 512, 77 * 64], dim=1)
z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=512)
text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=64)
return z, clip_img, text
def unpreprocess(v): # to B C H W and [0, 1]
v = 0.5 * (v + 1.)
v.clamp_(0., 1.)
return v
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def watermarking(save_path):
img_pre = Image.open(save_path)
img_pos = utils.add_water(img_pre)
img_pos.save(save_path)
ハイパーパラメータの設定
ここでは、タスクや、テキストプロンプト、入力画像パスなどを設定します。
以下では、i2t2iタスク(画像のバリエーションを生成するタスク)を選択しています。
# 入力画像のダウンロード
!wget -c https://www.musey.net/wp-content/uploads/2014/08/Vincent-van-Gogh-Starry-Night_1920x1200-1024x640.jpg \
-O ./assets/test_01.jpg
mode = "i2t2i" #@param ["t2i", "i2t", "joint", "i", "t", "t2i2t", "i2t2i"]
"""
t2i: text-to-image generation
i2t: image-to-text generation
joint: joint generation
i: image generation
t: text generation
t2i2t: text variation
i2t2i: image variation
"""
prompt = "" #@param {type:"string"}
img = './assets/test_01.jpg' #@param {type:"string"}
seed = 1234 #@param {type:"number"}
steps = 50 #@param {type:"slider", min:0, max:100, step:1}
cfg_scale = 8 #@param {type:"slider", min:0, max:10, step:0.1}
n_samples = 4 #@param {type:"number"}
nrow = 1 #@param {type:"number"}
data_type = 1
output_path = 'out' #@param {type:"string"}
タスクに応じて、テキストプロンプトや画像をロードします。
if mode == 't2i' or mode == 't2i2t':
prompts = [ prompt ] * n_samples
contexts = clip_text_model.encode(prompts)
contexts_low_dim = caption_decoder.encode_prefix(contexts)
elif mode == 'i2t' or mode == 'i2t2i':
from PIL import Image
img_contexts = []
clip_imgs = []
def get_img_feature(image):
image = np.array(image).astype(np.uint8)
image = utils.center_crop(512, 512, image)
clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
image = (image / 127.5 - 1.0).astype(np.float32)
image = einops.rearrange(image, 'h w c -> 1 c h w')
image = torch.tensor(image, device=device)
moments = autoencoder.encode_moments(image)
return clip_img_feature, moments
image = Image.open(img).convert('RGB')
clip_img, img_context = get_img_feature(image)
img_contexts.append(img_context)
clip_imgs.append(clip_img)
img_contexts = img_contexts * n_samples
clip_imgs = clip_imgs * n_samples
img_contexts = torch.concat(img_contexts, dim=0)
z_img = autoencoder.sample(img_contexts)
clip_imgs = torch.stack(clip_imgs, dim=0)
タスクごとのネットワーク定義
ここでは、それぞれのタスクに対応したネットワークを定義します。
def t2i_nnet(x, timesteps, text): # text is the low dimension version of the text clip embedding
"""
1. calculate the conditional model output
2. calculate unconditional model output
config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
config.sample.t2i_cfg_mode == 'true_uncond: using the unconditional model learned by our method
3. return linear combination of conditional output and unconditional output
"""
z, clip_img = split(x)
t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
z_out, clip_img_out, text_out = nnet(
z, clip_img,
text=text, t_img=timesteps, t_text=t_text,
data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + data_type)
x_out = combine(z_out, clip_img_out)
if cfg_scale == 0.:
return x_out
text_N = torch.randn_like(text) # 3 other possible choices
z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + data_type)
x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
return x_out + cfg_scale * (x_out - x_out_uncond)
def i_nnet(x, timesteps):
z, clip_img = split(x)
text = torch.randn(x.size(0), 77, 64, device=device)
t_text = torch.ones_like(timesteps) * N
z_out, clip_img_out, text_out = nnet(
z, clip_img, text=text, t_img=timesteps, t_text=t_text,
data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + data_type)
x_out = combine(z_out, clip_img_out)
return x_out
def t_nnet(x, timesteps):
z = torch.randn(x.size(0), *[4, 64, 64], device=device)
clip_img = torch.randn(x.size(0), 1, 512, device=device)
z_out, clip_img_out, text_out = nnet(
z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
return text_out
def i2t_nnet(x, timesteps, z, clip_img):
"""
1. calculate the conditional model output
2. calculate unconditional model output
3. return linear combination of conditional output and unconditional output
"""
t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
z_out, clip_img_out, text_out = nnet(
z, clip_img, text=x, t_img=t_img, t_text=timesteps,
data_type=torch.zeros_like(t_img, device=device, dtype=torch.int) + data_type)
if cfg_scale == 0.:
return text_out
z_N = torch.randn_like(z) # 3 other possible choices
clip_img_N = torch.randn_like(clip_img)
z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
return text_out + cfg_scale * (text_out - text_out_uncond)
def joint_nnet(x, timesteps):
z, clip_img, text = split_joint(x)
z_out, clip_img_out, text_out = nnet(
z, clip_img, text=text, t_img=timesteps, t_text=timesteps,
data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
x_out = combine_joint(z_out, clip_img_out, text_out)
if cfg_scale == 0.:
return x_out
z_noise = torch.randn(x.size(0), *(4, 64, 64), device=device)
clip_img_noise = torch.randn(x.size(0), 1, 512, device=device)
text_noise = torch.randn(x.size(0), 77, 64, device=device)
_, _, text_out_uncond = nnet(
z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
z_out_uncond, clip_img_out_uncond, _ = nnet(
z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type)
x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)
return x_out + cfg_scale * (x_out - x_out_uncond)
def sample_fn(mode, **kwargs):
_z_init = torch.randn(n_samples, *(4, 64, 64), device=device)
_clip_img_init = torch.randn(n_samples, 1, 512, device=device)
_text_init = torch.randn(n_samples, 77, 64, device=device)
if mode == 'joint':
_x_init = combine_joint(_z_init, _clip_img_init, _text_init)
elif mode in ['t2i', 'i']:
_x_init = combine(_z_init, _clip_img_init)
elif mode in ['i2t', 't']:
_x_init = _text_init
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
def model_fn(x, t_continuous):
t = t_continuous * N
if mode == 'joint':
return joint_nnet(x, t)
elif mode == 't2i':
return t2i_nnet(x, t, **kwargs)
elif mode == 'i2t':
return i2t_nnet(x, t, **kwargs)
elif mode == 'i':
return i_nnet(x, t)
elif mode == 't':
return t_nnet(x, t)
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
with torch.no_grad():
with torch.autocast(device_type=device):
x = dpm_solver.sample(_x_init, steps=steps, eps=1. / N, T=1.)
os.makedirs(output_path, exist_ok=True)
if mode == 'joint':
_z, _clip_img, _text = split_joint(x)
return _z, _clip_img, _text
elif mode in ['t2i', 'i']:
_z, _clip_img = split(x)
return _z, _clip_img
elif mode in ['i2t', 't']:
return x
Sampling
それでは、設定したタスクのサンプリングを実行し画像を生成してみます。
set_seed(seed)
def show(path):
samples = Image.open(path)
display(samples)
if mode in ['joint']:
_z, _clip_img, _text = sample_fn(mode)
samples = unpreprocess(decode(_z))
prompts = caption_decoder.generate_captions(_text)
os.makedirs(os.path.join(output_path, mode), exist_ok=True)
print(prompts)
with open(os.path.join(output_path, mode, 'prompts.txt'), 'w') as f:
print('\n'.join(prompts), file=f)
for idx, sample in enumerate(samples):
save_path = os.path.join(output_path, mode, f'{idx}.png')
save_image(sample, save_path)
watermarking(save_path)
# save a grid of generated images
samples_pos = []
for idx, sample in enumerate(samples):
sample_pil = standard_transforms.ToPILImage()(sample)
sample_pil = utils.add_water(sample_pil)
sample = standard_transforms.ToTensor()(sample_pil)
samples_pos.append(sample)
samples = make_grid(samples_pos, nrow)
save_path = os.path.join(output_path, mode, f'grid.png')
save_image(samples, save_path)
show(save_path)
elif mode in ['t2i', 'i', 'i2t2i']:
if mode == 't2i':
_z, _clip_img = sample_fn(mode, text=contexts_low_dim) # conditioned on the text embedding
elif mode == 'i':
_z, _clip_img = sample_fn(mode)
elif mode == 'i2t2i':
_text = sample_fn('i2t', z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
_z, _clip_img = sample_fn('t2i', text=_text)
samples = unpreprocess(decode(_z))
os.makedirs(os.path.join(output_path, mode), exist_ok=True)
for idx, sample in enumerate(samples):
save_path = os.path.join(output_path, mode, f'{idx}.png')
save_image(sample, save_path)
watermarking(save_path)
# save a grid of generated images
samples_pos = []
for idx, sample in enumerate(samples):
sample_pil = standard_transforms.ToPILImage()(sample)
sample_pil = utils.add_water(sample_pil)
sample = standard_transforms.ToTensor()(sample_pil)
samples_pos.append(sample)
samples = make_grid(samples_pos, nrow)
save_path = os.path.join(output_path, mode, f'grid.png')
save_image(samples, save_path)
show(save_path)
elif mode in ['i2t', 't', 't2i2t']:
if mode == 'i2t':
_text = sample_fn(mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
elif mode == 't':
_text = sample_fn(mode)
elif mode == 't2i2t':
_z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
_text = sample_fn('i2t', z=_z, clip_img=_clip_img)
samples = caption_decoder.generate_captions(_text)
print(samples)
logging.info(samples)
os.makedirs(os.path.join(output_path, mode), exist_ok=True)
with open(os.path.join(output_path, mode, f'{mode}.txt'), 'w') as f:
print('\n'.join(samples), file=f)
出力結果は以下の通りです。
入力画像から、テキストを生成し、そのテキストから画像が生成されています。
まとめ
本記事では、UniDIffuserを用いて、様々なタスクを実行する方法をご紹介しました。
他のタスクを実行したい場合mode = "i2t2i"
を切り替え適切なプロンプトを設定するだけで実行することができます。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1.
論文 - One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale
2. GitHub - thu-ml/unidiffuser
0 件のコメント :
コメントを投稿