本記事では、DreamBoothと呼ばれる機械学習手法を用いて事前学習済みStable Diffusionを追加学習する方法をご紹介します。
Dream Booth
概要
Dreamboothは、Text to Imageタスクを実現する拡散モデルの追加学習手法です。
Dreamboothでは、まず、事前に学習されたText to
Imageの拡散モデルに数枚の特定の被写体が写る画像と、識別子となるプロンプトを与え、追加学習させます。
追加学習されたモデルは、識別子を用いた様々なプロンプトに応じて画像を生成することが可能となります。
上図では、いらすとや様の男の子の画像を数枚追加学習させ、「自転車に乗る男の子」をモデルに出力してもらった結果を示しています。
追加学習させたモデルでは、識別子「男の子」を与えると、いらすとや様の男の子風の画像を生成することが可能になっています。
このように、数枚の画像を追加学習させることにより、モデルをカスタマイズすることが可能になります。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、Stable
Diffusionの事前学習済みモデルを追加学習していきます。
デモ(Colaboratory)
それでは、実際に動かしながらStable Diffusionの追加学習を行っていきます。
本記事にはソースコードの要点を記載しています。全文は下記のGitHubをご参照下さい。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきますが、Hugging FaceのAccess
Tokenは取得済みであるという前提のもと進めていきます。
Access Tokenの取得がまだの方は以下の記事をご参照ください。
[Stable Diffusion] AIでテキストから画像を生成する[text2img]
本記事では、機械学習手法Stable
Diffusionを用いてテキストから画像を生成する方法をご紹介しています。
Access
Tokenの取得後、Colaboratoryを開き、下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
はじめに、ライブラリをインストールします。
!pip install -qq git+https://github.com/huggingface/diffusers.git accelerate tensorboard transformers ftfy gradio
!pip install -qq "ipywidgets>=7,<8"
!pip install -qq bitsandbytes
次に、Hugging Faceへのログインを済ませておきます。
以下のセルを実行後表示されるプロンプトに取得したAccessTokenを入力します。
from huggingface_hub import notebook_login
!git config --global credential.helper store
notebook_login()
最後にライブラリをインポートします。
import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import bitsandbytes as bnb
以上で環境セットアップは完了です。
学習済みモデルのセットアップ、追加学習画像の設定
ここでは、追加学習に用いるベースとなるモデルを設定します。
今回はStable diffusionを使用します。
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4" #@param {type:"string"}
続いて、追加学習させる画像を指定します。
ここでは、いらすとや様の男の子の画像を5枚使用します。
urls = [
'https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiCSk6WI_XdfyJdIab2Iebe3KZtcvlN4izAAJZMENM2qz7UH7QP3_M-zp6bmopKLTGo4_Z6M1sd0-ZRm7xFQeH1xbCjteSL4EBdygOiEfoUGP5Gx7QHrvZSxLusmgcQQerQofQHNMEts64y/s250/penlight_man01_blue.png',
'https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgZ5F3JkOSY-AsniE91XOcEEVf_hpTrc07G9VdOo0AuqEIOSFLDb96OScI00vZsQfcVfzM5nwDmRg3m40LNv0W10Gz0vpTa8WOMn5wUJMk6leURo9Owb76VnmqrkFx-2dcJjGv_K2IeobER/s250/penlight_man05_orange.png',
'https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEirusu8jzE5W0PVsa9Czgx5izBJIpKHomRYROtGKmly0_WSpFPtoJdskjTW3nW-4BEaVkM7Mm_XjCTfA8tbXCw-TRPVENIthTmqWbBDx_UXVkdOHeykUIKNz7MrVX8-6O8ZYEZwhYoyxM9C/s330/stand_naname1_boy.png',
'https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjdJK13Y1oP76dyVbVtPCI6tq1i1EfOzzRNrElxEvXB7lFCCr3GJphn1yCtMfqTugRa24LsfqRcylAbwjU5s501QbG1C4mXY4EgW3QxLCO8xf5lsUCCejoU_au3sxZhsYg62JlnwgIJ2gsK/s380/stand_naname2_school_boy.png',
'https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg66dr3gmrwA9X9hRQYJgZKMUh5E6xII5blwlOmZT_t3jYDTSoAfkRJ_gYnn0ChnDpC1qi8svE2X99cU9SMq1QSj1CIzjj7IlPiv9hza05QT7HKoUjR33P4qU7kv5sKQG_jgGjFzGwdewSQ/s400/stand_naname3_man.png',
'https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgtJcfF9UV1libQ1VhBXGxqt98VL1vObciKA6HUsUpD9jM7IcFncOBuqrbDZa2uoN1NMpy-11WGxlg9G_zlHrVXQucRB8mGAbmdDHnJjGh8CvXS1axe9PTyfaziuQhcVpqb9XFF1qknJqTC/s400/stand_naname4_businessman.png'
]
識別子の設定
ここでは、学習させる被写体の識別子を定義します。
追加学習させた識別子であることを示す"sks"+オブジェクト名でプロンプトを指定します。
今回はan illustration of sks boy
を指定します。
# オブジェクト、画風の説明
instance_prompt = "an illustration of sks boy" #@param {type:"string"}
# コンセプトクラスの指定、画質が向上
prior_preservation = False #@param {type:"boolean"}
prior_preservation_class_prompt = " an illustration of a boy" #@param {type:"string"}
num_class_images = 12
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root=prior_preservation_class_folder
class_prompt=prior_preservation_class_prompt
あとは残りのセルをすべて実行することでトレーニング始まります。
無料のGoogle Colaboratory環境(Tesla T4)で約30分ほどで完了します。
追加学習モデルによるText to Image
それでは、追加学習させたモデルで様々な画像を生成していきます。
pipe = StableDiffusionPipeline.from_pretrained(
args.output_dir,
torch_dtype=torch.float16,
).to("cuda")
from torch import autocast
prompt = "an real illustration of sks boy swimming sea" #@param {type:"string"}
num_samples = 1 #@param {type:"number"}
num_rows = 1 #@param {type:"number"}
all_images = []
for _ in range(num_rows):
with autocast("cuda"):
images = pipe([prompt] * num_samples, num_inference_steps=50, guidance_scale=7.5).images
all_images.extend(images)
grid = image_grid(all_images, num_samples, num_rows)
grid
プロンプトan real illustration of sks boy swimming sea
の出力結果は以下の通りです。
an illustration of sks boy eating hamburger
は以下の通りです。
いらすとや様の男の子の画風を継承した様々な画像が生成されています。
まとめ
本記事では、DreamBoothを用いたStable Diffusionの追加学習方法をご紹介しました。
高い表現力を誇るStable Diffusionを使って特定のオブジェクトに特化した画像が生成できるので求める画像が生成しやすくなります。
一方で、悪用厳禁であることは言うまでもありません。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1. 論文 - DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation
2. GitHub - XavierXiao/Dreambooth-Stable-Diffusion
3. GitHub - XavierXiao/Dreambooth-Stable-Diffusion
text_encoder = CLIPTextModel.from_pretrained() あたりで、リクエストエラーになるみたいで、HTTPError Traceback (most recent call last)
返信削除/usr/local/lib/python3.7/dist-packages/huggingface_hub/utils/_errors.py in hf_raise_for_status(response, endpoint_name) 先に進めません。なにか、設定ミスしているか、アクセストークン関係のなにかを間違えてしまったからでしょうか。
OSErrorでなくHTTPErrorのためHuggingFaceのログインは完了しているが、Stable-diffusion-v1-4の条項に同意していない状態だと考えられます。
返信削除https://huggingface.co/CompVis/stable-diffusion-v1-4
上記リンクにアクセスし、 I have read the License and agree with its termsにチェックした後に、Agree and access repositoryにクリックする必要があります。
>あとは残りのセルをすべて実行することでトレーニング始まります。
返信削除残りのセルとはどこの、何のことでしょうか?
https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/Dreambooth_demo.ipynbに記載の「Advanced settings for prior preservation (optional)」以降のセルです。
返信削除Google Colabの「ランタイム」→「以降のセルを実行」から実行ください。
最後のセルで"NameError: name 'args' is not defined"が発生してしまいます。直し方を教えてほしいです。
削除args = Namespace(
削除から始まるセルが実行されておらずargsが定義されていないものと思われます。
args = Namespace(
から始まるセルを実行できているかご確認ください。
はじめまして。
返信削除こちらの方法で追加学習させたモデルを、google driveにダウンロードする方法はありますか。
https://github.com/kaz12tech/ai_demos/blob/main/Dreambooth_demo.ipynb
削除上記の最後にGoogle Driveに保存するコードを追加しました。
ImportError Traceback (most recent call last)
返信削除in
17 from accelerate.utils import set_seed
18 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
---> 19 from diffusers.hub_utils import init_git_repo, push_to_hub
20 from diffusers.optimization import get_scheduler
21 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
ImportError: cannot import name 'init_git_repo' from 'diffusers.hub_utils' (/usr/local/lib/python3.8/dist-packages/diffusers/hub_utils.py)
---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.
To view examples of installing some
ライブラリをインポートする際にこういったエラーが出ます。
何か解決策は御座いますか。
diffusersのソースコードが更新されinit_git_repo, push_to_hubが削除されたことにより発生しているエラーです。
削除下記の行を削除していただくことで解消します。
from diffusers.hub_utils import init_git_repo, push_to_hub
本ブログのソースコードも修正いたしました。
保存したモデルを使って再度イラストを出力するにはどうすればよいでしょうか?
返信削除「追加学習モデルによるText to Image」のソースコードが保存したモデルを使って再度イラストを出力しています。
削除書き込み失礼致します。
返信削除こちらのGoogleclabを使用させて頂いてGoogledriveにモデルを保存したのですが、stable diffusion に適用されるckptファイルはなくjsonファイルがありました。再度イラストを生成する時にこのファイルをどう使ったら良いか教えて頂けると大変有り難いのですが、お願い出来ますでしょうか?無知で申し訳ないです。
Google Driveに保存した結果は以下の通りになります。
削除もしファイルが欠落している場合は、Google Driveの容量が上限に達しコピーに失敗していると考えられます。
/content/drive/MyDrive/dreambooth-concept/
├── feature_extractor
│ └── preprocessor_config.json
├── model_index.json
├── safety_checker
│ ├── config.json
│ └── pytorch_model.bin
├── scheduler
│ └── scheduler_config.json
├── text_encoder
│ ├── config.json
│ └── pytorch_model.bin
├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── unet
│ ├── config.json
│ └── diffusion_pytorch_model.bin
└── vae
├── config.json
└── diffusion_pytorch_model.bin
使用する際は、以下の"args.output_dir"を"/content/drive/MyDrive/dreambooth-concept/"に変更することでGoogle Driveに保存したモデルを使用することができます。
pipe = StableDiffusionPipeline.from_pretrained(
args.output_dir,
torch_dtype=torch.float16,
).to("cuda")
何度も申し訳御座いません。
削除教えていただいた箇所を書き換えて実行したのですが、
File "", line 2
/content/drive/MyDrive/dreambooth-concept/,
^
SyntaxError: invalid syntax
と失敗しているようです
(文字はコピーペーストしてそのまま置き換えているので誤字はないと思います)
Google Driveのファイル名まで書いて指定する必要があるのでしょうか?
最後のスラッシュを除いた以下はいかがでしょうか?
削除pipe = StableDiffusionPipeline.from_pretrained(
'/content/drive/MyDrive/dreambooth-concept',
torch_dtype=torch.float16,
).to("cuda")
度々本当に申し訳ありません。
削除教えて頂いた通り入力すると
---------------------------------------------------------------------------
HFValidationError Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/diffusers/configuration_utils.py in load_config(cls, pretrained_model_name_or_path, return_unused_kwargs, **kwargs)
325 # Load from URL or cache if already cached
--> 326 config_file = hf_hub_download(
327 pretrained_model_name_or_path,
4 frames
HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/content/drive/MyDrive/dreambooth-concept'. Use `repo_type` argument if needed.
During handling of the above exception, another exception occurred:
OSError Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/diffusers/configuration_utils.py in load_config(cls, pretrained_model_name_or_path, return_unused_kwargs, **kwargs)
361 )
362 except ValueError:
--> 363 raise EnvironmentError(
364 f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
365 f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
OSError: We couldn't connect to 'https://huggingface.co' to load this model, couldn't find it in the cached files and it looks like /content/drive/MyDrive/dreambooth-concept is not the path to a directory containing a model_index.json file.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'.
とエラー結果が出てしまいました。
これはノートブックのトレーニングより前のセルを全て実行した後に推論を実行すると出てしまいました。
トレーニングセルはとても時間がかかるので出来ればここを実行せずに作成したモデルを呼び出したいのですが不可能でしょうか?もしよろしければご教授頂けると大変有り難いです。
https://github.com/kaz12tech/ai_demos/blob/main/Dreambooth_demo.ipynb
削除上記にGoogle Driveに保存したモデルで推論する例を追加いたしました。
恐れ入ります。
返信削除トレーニングセルを実行する際短時間で完了できたのですが、stabble defusionのデフォルト画像しか出力されなくなりました。
改善して頂いたのに申し訳ございません。
max_train_stepsを10→300に修正しました。申し訳ございません。
削除動作確認時に学習ステップを少なくしたままにしておりました。
初めまして。質問失礼します。
返信削除学習済みモデル、追加学習画像のセットアップ の項目で、画像のURLを貼った後、画像のImport requestを実行した所でエラーが出ます。
import requests
import glob
from io import BytesIO
def download_image(url):
try:
response = requests.get(url)
except:
return None
return Image.open(BytesIO(response.content)).convert("RGB")
images = list(filter(None,[download_image(url) for url in urls]))
save_path = "./my_concept"
if not os.path.exists(save_path):
os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
image_grid(images, 1, len(images))
エラー内容
in :18
NameError: name 'image_grid' is not defined
原因と解決法がもし分かればご教示頂けないでしょうか。
ご覧いただきありがとうございます。
削除image_grid関数が定義されていないことが原因で発生しているエラーです。
下記セルが実行されているかご確認ください。
def image_grid(imgs, rows, cols):
assert len(imgs) == rows*cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
お世話になっております。
返信削除こちらのページの内容を順番に実行しており、
「追加学習モデルによるText to Image」まで進みましたが、
「"NameError: name 'args' is not defined"」というエラーが発生し、うまくいかないです。
コメント欄にも同じ人がいて、回答もついているのですが、理解できず…💦
お手数をおかけしますが、解消方法について回答をお願いいたします。
ご覧いただきありがとうございます。
削除一度下記からGoogle Colabを開いていただき「Hugging Faceログイン」まで完了後、
「ライブラリのインポート」のセルにて"ランタイム"→"以降のセルを実行"で動作完了可能かご確認いただけますでしょうか。