本記事では、ControlNetと呼ばれる機械学習手法を用いて、落書きからイラストを生成する方法をご紹介します。
ControlNet
概要
ControlNetは、事前トレーニングされた大規模な拡散モデルを制御するためのニューラルネットワーク構造です。
論文中では、エッジマップ、セグメンテーションマップ、キーポイントなどの条件付き入力ををエンドツーエンドで学習し、50k以下の小さいデータセットでも堅牢に学習すると示され
大規模な拡散モデルにも拡張可能なため、Stable Diffusionの制御方法が充実すると述べられています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、落書きからイラストを生成していきます。
デモ(Colaboratory)
それでは、実際に動かしながら落書きからイラストを生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
%cd /content
!git clone https://github.com/lllyasviel/ControlNet.git
# Commits on Feb 15, 2023
%cd /content/ControlNet
!git checkout f8a359543a6bbc1080b81401d40323cb61822f13
次にライブラリをインストールします。
%cd /content/ControlNet
!pip uninstall torch torchtext -y
!pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1
!pip install omegaconf==2.1.1 einops==0.3.0
!pip install pytorch-lightning==1.5.0
!pip install transformers==4.19.2 open_clip_torch==2.0.2
最後にライブラリをインポートします。
%cd /content/ControlNet
import numpy as np
from PIL import Image as PilImage
import einops
import matplotlib.pyplot as plt
from IPython.display import HTML, Image
from google.colab.output import eval_js
from base64 import b64decode
import torch
from pytorch_lightning import seed_everything
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from annotator.util import resize_image, HWC3
以上で環境セットアップは完了です。
学習済みモデルのセットアップ
ここでは、HuggingFaceから学習済みモデルをダウンロードします。
%cd /content/ControlNet
# download model from huggingface
!wget -c https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_scribble.pth \
-O ./models/control_sd15_scribble.pth
%cd /content/ControlNet
model = create_model('./models/cldm_v15.yaml')
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
プロンプトの設定
ここでは、モデルに入力するテキストプロンプトなどを設定します。
prompt = 'a albino muscular and attractive lion.' # @param {type:"string"}
# active and negative prompt
a_prompt = 'best quality, extremely detailed'
n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
image_resolution = 512
num_samples = 1
seed = 12
ddim_steps = 20
eta = 0.0
scale = 9.0
落書きの作成
Google Colab上にCanvasを定義し、落書きを作成します。
canvas_html = """
<canvas width=%d height=%d></canvas>
<button>Finish</button>
<script>
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')
ctx.lineWidth = %d
var button = document.querySelector('button')
var mouse = {x: 0, y: 0}
canvas.addEventListener('mousemove', function(e) {
mouse.x = e.pageX - this.offsetLeft
mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
ctx.beginPath()
ctx.moveTo(mouse.x, mouse.y)
canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
ctx.lineTo(mouse.x, mouse.y)
ctx.stroke()
}
var data = new Promise(resolve=>{
button.onclick = ()=>{
resolve(canvas.toDataURL('image/png'))
}
})
</script>
"""
def draw(filename='drawing.png', w=512, h=512, line_width=1):
display(HTML(canvas_html % (w, h, line_width)))
data = eval_js("data")
binary = b64decode(data.split(',')[1])
with open(filename, 'wb') as f:
f.write(binary)
return len(binary)
draw()
img_path = 'drawing.png'
PilImage.open(img_path)
例として、以下の落書きを作成しました。
Inference
落書きにpromptを反映し、画像を生成します。
def inference(img_path, prompt):
# preprocess image
input_img = np.array(PilImage.open(img_path))
img = resize_image(HWC3(input_img), image_resolution)
H, W, C = img.shape
# initialize detect map
detected_map = np.zeros_like(img, dtype=np.uint8)
detected_map[np.min(img, axis=2) < 127] = 255
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
# set random seed
seed_everything(seed)
# get conftioning and unconditioning
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)
# sampling
samples, intermediates = ddim_sampler.sample(
ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
# post process
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
np_imgs = [255 - detected_map] + results
return np_imgs
np_imgs = inference(img_path, prompt)
出力結果は以下の通りです。
src = PilImage.fromarray(np_imgs[0])
dst = PilImage.fromarray(np_imgs[1])
fig = plt.figure(figsize=(25, 10))
ax1 = fig.add_subplot(1, 2, 1)
plt.title('Scribble image', fontsize=16)
ax1.axis('off')
ax1.imshow(src)
ax2 = fig.add_subplot(1, 2, 2)
plt.title('Generate image', fontsize=16)
ax2.axis('off')
ax2.imshow(dst)
plt.show()
まとめ
本記事では、ControlNetを用いてStable Diffusionを制御し、落書きからイラストを生成する方法をご紹介しました。
画像の構図の制御が容易になったため、複数回乱数を変えながら画像生成を繰り返すといった作業が減ることが想定されます。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1.
論文 - Adding Conditional Control to Text-to-Image Diffusion Models
2. GitHub - lllyasviel/ControlNet
0 件のコメント :
コメントを投稿