[DreamBooth] Stable Diffusionを追加学習して独自のモデルを作成 [無料Colab]

2022年10月5日水曜日

Artificial Intelligence

本記事では、DreamBoothと呼ばれる機械学習手法を用いて事前学習済みStable Diffusionを追加学習する方法をご紹介します。

アイキャッチ

Dream Booth

概要

Dreamboothは、Text to Imageタスクを実現する拡散モデルの追加学習手法です。

Dreamboothでは、まず、事前に学習されたText to Imageの拡散モデルに数枚の特定の被写体が写る画像と、識別子となるプロンプトを与え、追加学習させます。

追加学習されたモデルは、識別子を用いた様々なプロンプトに応じて画像を生成することが可能となります。

出力例

上図では、いらすとや様の男の子の画像を数枚追加学習させ、「自転車に乗る男の子」をモデルに出力してもらった結果を示しています。

追加学習させたモデルでは、識別子「男の子」を与えると、いらすとや様の男の子風の画像を生成することが可能になっています。

このように、数枚の画像を追加学習させることにより、モデルをカスタマイズすることが可能になります。

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、Stable Diffusionの事前学習済みモデルを追加学習していきます。

デモ(Colaboratory)

それでは、実際に動かしながらStable Diffusionの追加学習を行っていきます。
本記事にはソースコードの要点を記載しています。全文は下記のGitHubをご参照下さい。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきますが、Hugging FaceのAccess Tokenは取得済みであるという前提のもと進めていきます。
Access Tokenの取得がまだの方は以下の記事をご参照ください。

Access Tokenの取得後、Colaboratoryを開き、下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

はじめに、ライブラリをインストールします。

!pip install -qq git+https://github.com/huggingface/diffusers.git accelerate tensorboard transformers ftfy gradio
!pip install -qq "ipywidgets>=7,<8"
!pip install -qq bitsandbytes

次に、Hugging Faceへのログインを済ませておきます。
以下のセルを実行後表示されるプロンプトに取得したAccess Tokenを入力します。

from huggingface_hub import notebook_login
!git config --global credential.helper store

notebook_login()

最後にライブラリをインポートします。

import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import bitsandbytes as bnb

以上で環境セットアップは完了です。

学習済みモデルのセットアップ、追加学習画像の設定

ここでは、追加学習に用いるベースとなるモデルを設定します。

今回はStable diffusionを使用します。

pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4" #@param {type:"string"}

続いて、追加学習させる画像を指定します。

ここでは、いらすとや様の男の子の画像を5枚使用します。

urls = [
    'https://1.bp.blogspot.com/-JLT7FAbMiF4/X3hGHAgEHjI/AAAAAAABboc/OWpiTz5HLHg9A1b5tK7RVz1orj2H0XKvwCNcBGAsYHQ/s250/penlight_man01_blue.png',
    'https://1.bp.blogspot.com/-qrkLCElBrZU/X3hGIGBcv0I/AAAAAAABbos/uumBDo0U7TgnTwvhbqMWEWSlqeA3kuNVACNcBGAsYHQ/s250/penlight_man05_orange.png',
    'https://1.bp.blogspot.com/-TPcG8H-94Q8/XAnwX_0l2FI/AAAAAAABQw4/wJI0CD-SxO4UWBm8YVegxlls031mKKqzgCLcBGAs/s330/stand_naname1_boy.png',
    'https://4.bp.blogspot.com/-ld3QWamGpWI/XAnwYQ4RPeI/AAAAAAABQw8/AcQhuWaS_CAFHfINOJDm3wxmFG57AhmFgCLcBGAs/s380/stand_naname2_school_boy.png',
    'https://1.bp.blogspot.com/-LgVrdZ34XRM/XAnwY68nMrI/AAAAAAABQxA/P71YbDy3z3YCWiX9rLRsPQPw0iUzevxGwCLcBGAs/s400/stand_naname3_man.png',
    'https://1.bp.blogspot.com/-598mkkGFaD8/XAnwZDh48DI/AAAAAAABQxE/y-rIs9aaPQMVaYIuM_LaY2Q1HJojNIT0wCLcBGAs/s400/stand_naname4_businessman.png'
          ]
inputs

識別子の設定

ここでは、学習させる被写体の識別子を定義します。

追加学習させた識別子であることを示す"sks"+オブジェクト名でプロンプトを指定します。
今回はan illustration of sks boyを指定します。

# オブジェクト、画風の説明
instance_prompt = "an illustration of sks boy" #@param {type:"string"}

# コンセプトクラスの指定、画質が向上
prior_preservation = False #@param {type:"boolean"}
prior_preservation_class_prompt = " an illustration of a boy" #@param {type:"string"}

num_class_images = 12 
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root=prior_preservation_class_folder
class_prompt=prior_preservation_class_prompt

あとは残りのセルをすべて実行することでトレーニング始まります。
無料のGoogle Colaboratory環境(Tesla T4)で約30分ほどで完了します。

追加学習モデルによるText to Image

それでは、追加学習させたモデルで様々な画像を生成していきます。

pipe = StableDiffusionPipeline.from_pretrained(
        args.output_dir,
        torch_dtype=torch.float16,
    ).to("cuda")

from torch import autocast
prompt = "an real illustration of sks boy swimming sea" #@param {type:"string"}

num_samples = 1 #@param {type:"number"}
num_rows = 1 #@param {type:"number"}

all_images = [] 
for _ in range(num_rows):
    with autocast("cuda"):
        images = pipe([prompt] * num_samples, num_inference_steps=50, guidance_scale=7.5).images
        all_images.extend(images)

grid = image_grid(all_images, num_samples, num_rows)
grid

プロンプトan real illustration of sks boy swimming seaの出力結果は以下の通りです。

出力画像1

an illustration of sks boy eating hamburgerは以下の通りです。

出力画像2

いらすとや様の男の子の画風を継承した様々な画像が生成されています。

まとめ

本記事では、DreamBoothを用いたStable Diffusionの追加学習方法をご紹介しました。

高い表現力を誇るStable Diffusionを使って特定のオブジェクトに特化した画像が生成できるので求める画像が生成しやすくなります。
一方で、悪用厳禁であることは言うまでもありません。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation

2. GitHub - XavierXiao/Dreambooth-Stable-Diffusion

3. GitHub - XavierXiao/Dreambooth-Stable-Diffusion

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology