[MaskGIT] 機械学習で画像生成、画像編集 [Python]

2022年4月16日土曜日

Artificial Intelligence

本記事では、Masked Generative Image Transformer(MaskGIT)と呼ばれる機械学習手法を用いて、クラス条件付き画像生成、画像編集を行う方法をご紹介します。

アイキャッチ
出典: google-research/maskgit

MaskGIT: Masked Generative Image Transformer

概要

Masked Generative Image Transformer(以下、MaskGIT)は、双方向トランスデコーダ(bidirectional transformer decoder)を使用した新しい画像合成パラダイムです。

従来の画像合成を実現するTransformerは、画像をトークンのシーケンスとして扱い、ラスタースキャンのように行ごとに画像をデコードします。

MaskGITでは、双方向トランスデコーダを用いて、すべての方向のトークンに注意を払い、ランダムにマスクされたトークンを予測することを学習します。推論時には、画像のすべてのトークンを同時に生成することから始め、次のEpochでそれらの画像を改良していきます。

MaskGITは、ImageNetデータセットを用いた最先端のTransformerモデルを大幅に上回り、自己回帰デコードを最大64倍高速化すると論文で示されています。
前述の画像の通り、MaskGITは、Image Generation, Image Manipulation, Image Extrapolationなどの様々な画像編集タスクに拡張できます。

compare method

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、クラス条件付き画像生成・画像編集を行います。

デモ(Colaboratory)

それでは、実際に動かしながらクラス条件付き画像生成・画像編集を行っていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めに、論文発表元のGithubからソースコードを取得します

%cd /content

!git clone https://github.com/google-research/maskgit.git

# TypeError: take_along_axis indices must be of integer type, got float32回避
!sed -E -i "s/mask_len\)\)/mask_len\)\).astype\('int32'\)/" /content/maskgit/maskgit/libml/parallel_decode.py

次にライブラリをインストールします。

%cd /content

!pip install jax flax
!pip install numpy tensorflow matplotlib ml_collections

最後にライブラリをインポートしておきます。

%cd /content/maskgit

import numpy as np
import jax
import jax.numpy as jnp
import os
import itertools
from timeit import default_timer as timer
from PIL import Image
import matplotlib.pyplot as plt

import maskgit
from maskgit.utils import visualize_images, read_image_from_url, restore_from_path, draw_image_with_bbox, Bbox
from maskgit.inference import ImageNet_class_conditional_generator

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

続いて、論文発表元が提供している学習済みモデルをダウンロードします。

%cd /content/maskgit

!mkdir -p checkpoints/

models_to_download = itertools.product( 
    *[ ["maskgit", "tokenizer"],   [256, 512] ])

for (type_, resolution) in models_to_download:
  canonical_path = ImageNet_class_conditional_generator.checkpoint_canonical_path(type_, resolution)
  if os.path.isfile(canonical_path):
    print(f"Checkpoint for {resolution} {type_} already exists, not downloading again")
  else:
    source_url = f'https://storage.googleapis.com/maskgit-public/checkpoints/{type_}_imagenet{resolution}_checkpoint'
    !wget {source_url} -O {canonical_path}

checkpointsに学習済みモデルがダウンロードされます。

Class-conditional Image Synthesis(クラス条件付き画像生成)

それでは、まずImageNetのラベルをクラスとして指定し、クラスに応じた画像生成を行っていきます。
はじめにモデルをビルドします。

generator_256 = ImageNet_class_conditional_generator(image_size=256)

SEED = 13 #@param {type:"integer"}
rng = jax.random.PRNGKey(SEED)

p_generate_256_samples = generator_256.p_generate_samples()

SEEDは任意の整数を指定してください。SEEDによって生成される画像が変化します。

次に生成する画像のクラスを定義します。今回は黒い白鳥を指定します。

category = "100) black swan, Cygnus atratus"

それでは、画像を生成してみます。
なお、初回生成時はTesla K80で45秒ほど、2回目以降は25秒ほど生成時間に要します。

image_size = 256

rng, sample_rng = jax.random.split(rng)

start_timer = timer()
results = generator_256.generate_samples(input_tokens, sample_rng)
end_timer = timer()

print(f"generated {generator_256.eval_batch_size()} images in {end_timer - start_timer} seconds")

# Visualize
visualize_images(results, title=f'results')

出力結果は以下の通りです。

black swan seed 13

SEEDを12に変更した結果は以下の通りです。

black swan seed 12

写真のように鮮明ですね。

Class-conditional Image Editing(クラス条件付き画像編集)

続いて画像編集を行います。

先ほどと同様にモデルをビルドします。

generator_512 = ImageNet_class_conditional_generator(image_size=512)

SEED = 123 #@param {type:"integer"}
rng = jax.random.PRNGKey(SEED)

p_edit_512_samples = generator_512.p_edit_samples()

次にクラスを定義します。

今回はショッピングバスケットを設定してみます。

category = "790) shopping basket"

次に、入力するテスト画像をダウンロードします。
今回はぱくたそ様の画像を使用させていただきます。

!mkdir test_imgs

IMG_URL = "https://www.pakutaso.com/shared/img/thumb/TM20152A201110-0030_TP_V4.jpg" # @param {type:"string"}
img_path = "test_imgs/TM20152A201110-0030_TP_V4.jpg" # @param {type:"string"}

!wget -P test_imgs {IMG_URL}

# ダウンロードした画像を512x512にcrop
pil_img = Image.open(img_path).convert('RGB')
crop_pil_img = pil_img.crop((150, 0, 662, 512))

plt.imshow(np.array(crop_pil_img))

モデルの入力に合わせて512x512にクロップした画像は以下の通りです。

入力画像

次に、画像のどの範囲をショッピングバスケットに編集するか範囲を指定します。

image_size = 512

bbox_top_left_height_width = '130_80_320_335' # @param {type:"string"}

bbox = Bbox(bbox_top_left_height_width)

# Load the input image, and visualize it with our bounding box
pil_image = crop_pil_img.resize((image_size, image_size), Image.BICUBIC)
image = np.float32(pil_image) / 255.

draw_image_with_bbox(image, bbox)

latent_mask, input_tokens = generator_512.create_latent_mask_and_input_tokens_for_image_editing(
    image, bbox, label)

pmap_input_tokens = generator_512.pmap_input_tokens(input_tokens)

画像編集を実行します。

Tesla K80で約60秒ほど要します。

rng, sample_rng = jax.random.split(rng)

start_timer = timer()
results = generator_512.generate_samples(
    input_tokens, 
    sample_rng, 
    start_iter=2, 
    num_iterations=12 
    )
end_timer = timer()
print(f"edited {generator_512.eval_batch_size()} images in {end_timer - start_timer} seconds")

composite_images = generator_512.composite_outputs(image, latent_mask, results)

visualize_images(composite_images, title=f'outputs')

出力結果は以下の通りです。

Image manipulation結果

入力画像の犬の名残を残しつつ、ショッピングバスケットが追加された画像が多く出力されました。やや怖いですね。

まとめ

本記事では、MaskGITを用いたクラス条件付き画像生成、画像編集を行いました。
指定範囲の編集は精度が良ければ様々な使い方が考えられます。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - MaskGIT: Masked Generative Image Transformer

2. GitHub - google-research/maskgit

AIエンジニア向けフリーランスならここがおすすめです

まずは無料会員登録

プロフィール

自分の写真
製造業に勤務する傍ら、日々AIの技術動向を調査しブログにアウトプットしています。 AIに関するご相談やお仕事のご依頼はブログのお問い合わせフォームか以下のアドレスまでお気軽にお問い合わせください。 bhupb13511@yahoo.co.jp

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology