[StyleGAN] AIでユーザー操作に応じて画像を編集する [UserControllableLT]

2023年6月1日木曜日

Artificial Intelligence

本記事では、UserControllableLTと呼ばれる機械学習手法を用いて画像をリアルタイムに編集する方法をご紹介します。

eye catch
出典: endo-yuki-t/UserControllableLT

UserControllableLT

概要

User-Controllable Latent Transformer(以下、UserControllableLT)は、ユーザー入力に従って潜在コードを操作し、StyleGAN imageを編集する対話型フレームワークです。

このフレームワークでは、ユーザーは移動したい箇所、移動したくない箇所に注釈をつけ、マウスのドラッグによって移動方向を指定します。
これらの入力と、初期の潜在コードからトランスフォーマーベースの latent transformerは潜在コードを推定し、StyleGAN Generatorで画像を出力しています。

overview
出典: User-Controllable Latent Transformer for StyleGAN Image Layout Editing

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、ユーザー操作に応じたStyleGAN imageの画像編集を行っていきます。

デモ(Colaboratory)

それでは、実際に動かしながら画像編集を行っていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://huggingface.co/spaces/radames/UserControllableLT-Latent-Transformer UserControllableLT

%cd /content/UserControllableLT
# Commits on May 31, 2023
!git checkout a8d6a354abcf249b64e6ccf9b5949832877b345b

次にライブラリをインストールします。

%cd /content/UserControllableLT

!pip install -r requirements.txt

以上で環境セットアップは完了です。

Launch Gradio

ここでは、Gradioを使用してアプリを起動します。
GradioでWebアプリを定義し起動しています

%cd /content/UserControllableLT/interface

import gradio as gr
import sys

sys.path.append(".")
sys.path.append("..")
from model_loader import Model
from inversion import InversionModel
from PIL import Image
import cv2
from huggingface_hub import snapshot_download
import json

# disable if running on another environment
RESIZE = True

models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")


# models fron pretrained/latent_transformer folder
models_files = {
    "anime": "anime.pt",
    "car": "car.pt",
    "cat": "cat.pt",
    "church": "church.pt",
    "ffhq": "ffhq.pt",
}

models = {name: Model(models_path + "/" + path) for name, path in models_files.items()}
inversion_model = InversionModel(
    models_path + "/psp_ffhq_encode.pt",
    models_path + "/shape_predictor_68_face_landmarks.dat",
)

canvas_html = """<draggan-canvas id="canvas-root" style='display:flex;max-width: 500px;margin: 0 auto;'></draggan-canvas>"""
load_js = """
async () => {
  const script = document.createElement('script');
  script.type = "module"
  script.src = "file=custom_component.js"
  document.head.appendChild(script);
}
"""
image_change = """
async (base64img) => {
  const canvasEl = document.getElementById("canvas-root");
  canvasEl.loadBase64Image(base64img);
}   
"""
reset_stop_points = """
async () => {
  const canvasEl = document.getElementById("canvas-root");
  canvasEl.resetStopPoints();
}
"""

default_dxdysxsy = json.dumps(
    {"dx": 1, "dy": 0, "sx": 128, "sy": 128, "stopPoints": []}
)


def cv_to_pil(img):
    img = Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
    if RESIZE:
        img = img.resize((128, 128))
    return img


def random_sample(model_name: str):
    model = models[model_name]
    img, latents = model.random_sample()
    img_pil = cv_to_pil(img)
    return img_pil, model_name, latents


def load_from_img_file(image_path: str):
    img_pil, latents = inversion_model.inference(image_path)
    if RESIZE:
        img_pil = img_pil.resize((128, 128))
    return img_pil, "ffhq", latents


def transform(model_state, latents_state, dxdysxsy=default_dxdysxsy, dz=0):
    if "w1" not in latents_state or "w1_initial" not in latents_state:
        raise gr.Error("Generate a random sample first")

    data = json.loads(dxdysxsy)

    model = models[model_state]
    dx = int(data["dx"])
    dy = int(data["dy"])
    sx = int(data["sx"])
    sy = int(data["sy"])
    stop_points = [[int(x), int(y)] for x, y in data["stopPoints"]]
    img, latents_state = model.transform(
        latents_state, dz, dxy=[dx, dy], sxsy=[sx, sy], stop_points=stop_points
    )
    img_pil = cv_to_pil(img)
    return img_pil, latents_state


def change_style(image: Image.Image, model_state, latents_state):
    model = models[model_state]
    img, latents_state = model.change_style(latents_state)
    img_pil = cv_to_pil(img)
    return img_pil, latents_state


def reset(model_state, latents_state):
    model = models[model_state]
    img, latents_state = model.reset(latents_state)
    img_pil = cv_to_pil(img)
    return img_pil, latents_state


def image_click(evt: gr.SelectData):
    click_pos = evt.index
    return click_pos


with gr.Blocks() as block:
    model_state = gr.State(value="ffhq")
    latents_state = gr.State({})
    gr.Markdown(
        """# UserControllableLT: User Controllable Latent Transformer
Unofficial Gradio Demo

**Author**: Yuki Endo\\
**Paper**:  [2208.12408](https://huggingface.co/papers/2208.12408)\\
**Code**: [UserControllableLT](https://github.com/endo-yuki-t/UserControllableLT)

<small>
Double click to add or remove stop points.
<small>
"""
    )

    with gr.Row():
        with gr.Column():
            model_name = gr.Dropdown(
                choices=list(models_files.keys()),
                label="Select Pretrained Model",
                value="ffhq",
            )
            with gr.Row():
                button = gr.Button("Random sample")
                reset_btn = gr.Button("Reset")
                change_style_bt = gr.Button("Change style")
            dxdysxsy = gr.Textbox(
                label="dxdysxsy",
                value=default_dxdysxsy,
                elem_id="dxdysxsy",
                visible=False,
            )
            dz = gr.Slider(
                minimum=-15, maximum=15, step_size=0.01, label="zoom", value=0.0
            )
            image = gr.Image(type="pil", visible=False, preprocess=False)
            with gr.Accordion(label="Upload your face image", open=False):
                gr.Markdown("<small> This only works on FFHQ model </small>")
                with gr.Row():
                    image_path = gr.Image(
                        type="filepath", label="input image", interactive=True
                    )
                    examples = gr.Examples(
                        examples=[
                            "./examples/benedict.jpg",
                            "./examples/obama.jpg",
                            "./examples/me.jpg",
                        ],
                        fn=load_from_img_file,
                        run_on_click=True,
                        inputs=[image_path],
                        outputs=[image, model_state, latents_state],
                    )
        with gr.Column():
            html = gr.HTML(canvas_html, label="output")

    button.click(
        random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
    )
    reset_btn.click(
        reset,
        inputs=[model_state, latents_state],
        outputs=[image, latents_state],
        queue=False,
    ).then(None, None, None, _js=reset_stop_points, queue=False)

    change_style_bt.click(
        change_style,
        inputs=[image, model_state, latents_state],
        outputs=[image, latents_state],
    )
    dxdysxsy.change(
        transform,
        inputs=[model_state, latents_state, dxdysxsy, dz],
        outputs=[image, latents_state],
        show_progress=False,
    )
    dz.change(
        transform,
        inputs=[model_state, latents_state, dxdysxsy, dz],
        outputs=[image, latents_state],
        show_progress=False,
    )
    image.change(None, inputs=[image], outputs=None, _js=image_change)
    image_path.upload(
        load_from_img_file,
        inputs=[image_path],
        outputs=[image, model_state, latents_state],
    )

    block.load(None, None, None, _js=load_js)
    block.load(
        random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
    )

block.queue(api_open=False)
block.launch(show_api=False)

モデルのロードなどが完了した後以下のように、GradioのWebアプリがインライン表示されます。
以下のようにドラッグ操作で顔の向きの編集が可能です。処理速度上ややもっさりしますが、顔の向きが変わっていることが確認できます。

result

まとめ

本記事では、UserControllableLTを用いたマウス操作による画像編集をご紹介しました。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - User-Controllable Latent Transformer for StyleGAN Image Layout Editing

2. GitHub - endo-yuki-t/UserControllableLT

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology