[Pez Dispenser] AIでStable Diffusionのプロンプトを最適化する

2023年2月12日日曜日

Artificial Intelligence

本記事では、Pez Dispenserと呼ばれる機械学習手法を用いて、任意の画像を生成するために最適なプロンプトを探索する方法をご紹介します。

アイキャッチ
出典: YuxinWenRick/hard-prompts-made-easy

Pez Dispenser

概要

Pez Dispenserは、テキストから画像または、テキストからテキストへのタスクにおいて、テキストプロンプトを効果的に最適化する機械学習手法です。

Pez Dispenserでは、効果的なグラデーションベースの最適化方法により、テキストプロンプトを最適化します。
この手法により、ユーザーは、テキストプロンプトに関する事前知識がなくとも、画像の概念を表現したプロンプトを生成することができ、Text to Imageタスクなどの利用が簡単になります。

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、任意の画像のから最適化されたプロンプトを生成していきます。

デモ(Colaboratory)

それでは、実際に動かしながらテキストプロンプトの最適化を行います。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/YuxinWenRick/hard-prompts-made-easy.git

# using Commits on Feb 10, 2023
%cd /content/hard-prompts-made-easy
!git checkout 9d6254a77fa4aa440cc83507cfaf210b35204d16

次にライブラリをインストールします。

%cd /content/hard-prompts-made-easy

# for PEZ dispenser
!pip3 install transformers==4.23.1 sentence-transformers==2.2.2 ftfy==6.1.1 mediapy==1.1.2 diffusers==0.11.1

最後にライブラリをインポートします。

%cd /content/hard-prompts-made-easy

import argparse

import torch
device = 'cuda' if torch.cuda.is_available() else "cpu"
print("using device is", device)

import open_clip
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

import mediapy as media

from optim_utils import (download_image, optimize_prompt)

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

ここでは、論文発表元が公開する学習済みモデルをダウンロードします。

まずCLIPモデルをダウンロードします。

clip_model = 'ViT-H-14'
clip_pretrain = 'laion2b_s32b_b79k'

# load clip model
model, _, clip_preprocess = open_clip.create_model_and_transforms(
    clip_model, 
    pretrained = clip_pretrain,
    device=device)

続いて、Stable Diffusionをダウンロードします。

model_id = "stabilityai/stable-diffusion-2-1-base"

# load scheduler
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id, 
    subfolder="scheduler")
# load stable diffusion model
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    revision="fp16")
pipe = pipe.to(device)
# setting image length
image_length = 512

テスト画像のセットアップ

続いて、プロンプトを生成するための入力画像をWeb上から取得します。

# define image urls
urls = [
        "https://www.pakutaso.com/shared/img/thumb/tokyoPAUI3257_TP_V4.jpg",
       ]

# download image
orig_images = list(filter(None,[download_image(url) for url in urls]))
# show image
media.show_images(orig_images, height=512)
input画像

Optimize Prompt

それでは、画像から最適なテキストプロンプトを探索していきます。

まず探索時のコンフィグを設定します。

args = argparse.Namespace()

args.prompt_len = 16
args.iter = 1500
args.lr = 0.1
args.weight_decay = 0.1
args.prompt_bs = 1
args.print_step = 100
args.batch_size = 1
args.clip_model = clip_model
args.clip_pretrain = clip_pretrain

args

設定に従い最適化を実行します。

# target imageを表現する最適なpromptの探索
learned_prompt = optimize_prompt(
    model, 
    clip_preprocess, 
    args, 
    device, 
    target_images = orig_images)

コサイン類似度と共に、探索されたプロンプトが出力されます。

best cosine sim: 0.45574772357940674
best prompt: dayo charge markets cityscape lit exterior 🇩🇪 cre partner relocation atin yokohama ultimatefancorriarbitration lawyer

生成されたプロンプトを使用して、Stable Diffusionで画像を生成してみます。

prompt = learned_prompt

num_images = 4
guidance_scale = 9
num_inference_steps = 25

images = pipe(
    prompt,
    num_images_per_prompt = num_images,
    guidance_scale = guidance_scale,
    num_inference_steps = num_inference_steps,
    height = image_length,
    width = image_length,
    generator = torch.Generator(device).manual_seed(0)
    ).images

print("prompt:", prompt)
media.show_images(images, width=128)

出力結果は以下の通りです。
元画像に近い画像が出力されています。

output result

まとめ

本記事では、Pez Dispenserを用いてテキストプロンプトを最適化する方法をご紹介しました。
これまでStable Diffusionに入力するテキストを呪文として覚えておく必要がありましたが、それすらも必要なくなってきているのかもしれません。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - Hard Prompts Made Easy: Gradient-Based Discrete Optimization for Prompt Tuning and Discovery

2. GitHub - YuxinWenRick/hard-prompts-made-easy

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology