[ControlNet] AIで落書きからイラストを生成する

2023年2月21日火曜日

Artificial Intelligence

本記事では、ControlNetと呼ばれる機械学習手法を用いて、落書きからイラストを生成する方法をご紹介します。

アイキャッチ

ControlNet

概要

ControlNetは、事前トレーニングされた大規模な拡散モデルを制御するためのニューラルネットワーク構造です。

論文中では、エッジマップ、セグメンテーションマップ、キーポイントなどの条件付き入力ををエンドツーエンドで学習し、50k以下の小さいデータセットでも堅牢に学習すると示され
大規模な拡散モデルにも拡張可能なため、Stable Diffusionの制御方法が充実すると述べられています。

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、落書きからイラストを生成していきます。

デモ(Colaboratory)

それでは、実際に動かしながら落書きからイラストを生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/lllyasviel/ControlNet.git

# Commits on Feb 15, 2023
%cd /content/ControlNet
!git checkout f8a359543a6bbc1080b81401d40323cb61822f13

次にライブラリをインストールします。

%cd /content/ControlNet

!pip uninstall torch torchtext -y
!pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1
!pip install omegaconf==2.1.1 einops==0.3.0
!pip install pytorch-lightning==1.5.0
!pip install transformers==4.19.2 open_clip_torch==2.0.2

最後にライブラリをインポートします。

%cd /content/ControlNet

import numpy as np
from PIL import Image as PilImage
import einops
import matplotlib.pyplot as plt

from IPython.display import HTML, Image
from google.colab.output import eval_js
from base64 import b64decode

import torch
from pytorch_lightning import seed_everything

from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler

from annotator.util import resize_image, HWC3

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

ここでは、HuggingFaceから学習済みモデルをダウンロードします。

%cd /content/ControlNet

# download model from huggingface
!wget -c https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_scribble.pth \
      -O ./models/control_sd15_scribble.pth

%cd /content/ControlNet

model = create_model('./models/cldm_v15.yaml')
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)

プロンプトの設定

ここでは、モデルに入力するテキストプロンプトなどを設定します。

prompt = 'a albino muscular and attractive lion.' # @param {type:"string"}

# active and negative prompt
a_prompt = 'best quality, extremely detailed'
n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

image_resolution = 512
num_samples = 1
seed = 12
ddim_steps = 20
eta = 0.0
scale = 9.0

落書きの作成

Google Colab上にCanvasを定義し、落書きを作成します。

canvas_html = """
<canvas width=%d height=%d></canvas>
<button>Finish</button>
<script>
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')
ctx.lineWidth = %d
var button = document.querySelector('button')
var mouse = {x: 0, y: 0}
canvas.addEventListener('mousemove', function(e) {
  mouse.x = e.pageX - this.offsetLeft
  mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
  ctx.beginPath()
  ctx.moveTo(mouse.x, mouse.y)
  canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
  canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
  ctx.lineTo(mouse.x, mouse.y)
  ctx.stroke()
}
var data = new Promise(resolve=>{
  button.onclick = ()=>{
    resolve(canvas.toDataURL('image/png'))
  }
})
</script>
"""

def draw(filename='drawing.png', w=512, h=512, line_width=1):
  display(HTML(canvas_html % (w, h, line_width)))
  data = eval_js("data")
  binary = b64decode(data.split(',')[1])
  with open(filename, 'wb') as f:
    f.write(binary)
  return len(binary)
  
draw()

img_path = 'drawing.png'
PilImage.open(img_path)

例として、以下の落書きを作成しました。

scribble iamge

Inference

落書きにpromptを反映し、画像を生成します。

def inference(img_path, prompt):
  # preprocess image
  input_img = np.array(PilImage.open(img_path))
  img = resize_image(HWC3(input_img), image_resolution)
  H, W, C = img.shape

  # initialize detect map
  detected_map = np.zeros_like(img, dtype=np.uint8)
  detected_map[np.min(img, axis=2) < 127] = 255
  control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
  control = torch.stack([control for _ in range(num_samples)], dim=0)
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()

  # set random seed
  seed_everything(seed)

  # get conftioning and unconditioning
  cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
  un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
  shape = (4, H // 8, W // 8)

  # sampling
  samples, intermediates = ddim_sampler.sample(
      ddim_steps, num_samples,
      shape, cond, verbose=False, eta=eta,
      unconditional_guidance_scale=scale,
      unconditional_conditioning=un_cond)
  
  # post process
  x_samples = model.decode_first_stage(samples)
  x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

  results = [x_samples[i] for i in range(num_samples)]
  np_imgs = [255 - detected_map] + results

  return np_imgs
  
np_imgs = inference(img_path, prompt)

出力結果は以下の通りです。

src = PilImage.fromarray(np_imgs[0])
dst = PilImage.fromarray(np_imgs[1])

fig = plt.figure(figsize=(25, 10))

ax1 = fig.add_subplot(1, 2, 1)
plt.title('Scribble image', fontsize=16)
ax1.axis('off')
ax1.imshow(src)

ax2 = fig.add_subplot(1, 2, 2)
plt.title('Generate image', fontsize=16)
ax2.axis('off')
ax2.imshow(dst)

plt.show()
結果

まとめ

本記事では、ControlNetを用いてStable Diffusionを制御し、落書きからイラストを生成する方法をご紹介しました。

画像の構図の制御が容易になったため、複数回乱数を変えながら画像生成を繰り返すといった作業が減ることが想定されます。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - Adding Conditional Control to Text-to-Image Diffusion Models

2. GitHub - lllyasviel/ControlNet

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology