本記事では、UserControllableLTと呼ばれる機械学習手法を用いて画像をリアルタイムに編集する方法をご紹介します。
UserControllableLT
概要
User-Controllable Latent Transformer(以下、UserControllableLT)は、ユーザー入力に従って潜在コードを操作し、StyleGAN imageを編集する対話型フレームワークです。
このフレームワークでは、ユーザーは移動したい箇所、移動したくない箇所に注釈をつけ、マウスのドラッグによって移動方向を指定します。
これらの入力と、初期の潜在コードからトランスフォーマーベースの latent transformerは潜在コードを推定し、StyleGAN Generatorで画像を出力しています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、ユーザー操作に応じたStyleGAN imageの画像編集を行っていきます。
デモ(Colaboratory)
それでは、実際に動かしながら画像編集を行っていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
%cd /content
!git clone https://huggingface.co/spaces/radames/UserControllableLT-Latent-Transformer UserControllableLT
%cd /content/UserControllableLT
# Commits on May 31, 2023
!git checkout a8d6a354abcf249b64e6ccf9b5949832877b345b
次にライブラリをインストールします。
%cd /content/UserControllableLT
!pip install -r requirements.txt
以上で環境セットアップは完了です。
Launch Gradio
ここでは、Gradioを使用してアプリを起動します。
GradioでWebアプリを定義し起動しています
%cd /content/UserControllableLT/interface
import gradio as gr
import sys
sys.path.append(".")
sys.path.append("..")
from model_loader import Model
from inversion import InversionModel
from PIL import Image
import cv2
from huggingface_hub import snapshot_download
import json
# disable if running on another environment
RESIZE = True
models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
# models fron pretrained/latent_transformer folder
models_files = {
"anime": "anime.pt",
"car": "car.pt",
"cat": "cat.pt",
"church": "church.pt",
"ffhq": "ffhq.pt",
}
models = {name: Model(models_path + "/" + path) for name, path in models_files.items()}
inversion_model = InversionModel(
models_path + "/psp_ffhq_encode.pt",
models_path + "/shape_predictor_68_face_landmarks.dat",
)
canvas_html = """<draggan-canvas id="canvas-root" style='display:flex;max-width: 500px;margin: 0 auto;'></draggan-canvas>"""
load_js = """
async () => {
const script = document.createElement('script');
script.type = "module"
script.src = "file=custom_component.js"
document.head.appendChild(script);
}
"""
image_change = """
async (base64img) => {
const canvasEl = document.getElementById("canvas-root");
canvasEl.loadBase64Image(base64img);
}
"""
reset_stop_points = """
async () => {
const canvasEl = document.getElementById("canvas-root");
canvasEl.resetStopPoints();
}
"""
default_dxdysxsy = json.dumps(
{"dx": 1, "dy": 0, "sx": 128, "sy": 128, "stopPoints": []}
)
def cv_to_pil(img):
img = Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
if RESIZE:
img = img.resize((128, 128))
return img
def random_sample(model_name: str):
model = models[model_name]
img, latents = model.random_sample()
img_pil = cv_to_pil(img)
return img_pil, model_name, latents
def load_from_img_file(image_path: str):
img_pil, latents = inversion_model.inference(image_path)
if RESIZE:
img_pil = img_pil.resize((128, 128))
return img_pil, "ffhq", latents
def transform(model_state, latents_state, dxdysxsy=default_dxdysxsy, dz=0):
if "w1" not in latents_state or "w1_initial" not in latents_state:
raise gr.Error("Generate a random sample first")
data = json.loads(dxdysxsy)
model = models[model_state]
dx = int(data["dx"])
dy = int(data["dy"])
sx = int(data["sx"])
sy = int(data["sy"])
stop_points = [[int(x), int(y)] for x, y in data["stopPoints"]]
img, latents_state = model.transform(
latents_state, dz, dxy=[dx, dy], sxsy=[sx, sy], stop_points=stop_points
)
img_pil = cv_to_pil(img)
return img_pil, latents_state
def change_style(image: Image.Image, model_state, latents_state):
model = models[model_state]
img, latents_state = model.change_style(latents_state)
img_pil = cv_to_pil(img)
return img_pil, latents_state
def reset(model_state, latents_state):
model = models[model_state]
img, latents_state = model.reset(latents_state)
img_pil = cv_to_pil(img)
return img_pil, latents_state
def image_click(evt: gr.SelectData):
click_pos = evt.index
return click_pos
with gr.Blocks() as block:
model_state = gr.State(value="ffhq")
latents_state = gr.State({})
gr.Markdown(
"""# UserControllableLT: User Controllable Latent Transformer
Unofficial Gradio Demo
**Author**: Yuki Endo\\
**Paper**: [2208.12408](https://huggingface.co/papers/2208.12408)\\
**Code**: [UserControllableLT](https://github.com/endo-yuki-t/UserControllableLT)
<small>
Double click to add or remove stop points.
<small>
"""
)
with gr.Row():
with gr.Column():
model_name = gr.Dropdown(
choices=list(models_files.keys()),
label="Select Pretrained Model",
value="ffhq",
)
with gr.Row():
button = gr.Button("Random sample")
reset_btn = gr.Button("Reset")
change_style_bt = gr.Button("Change style")
dxdysxsy = gr.Textbox(
label="dxdysxsy",
value=default_dxdysxsy,
elem_id="dxdysxsy",
visible=False,
)
dz = gr.Slider(
minimum=-15, maximum=15, step_size=0.01, label="zoom", value=0.0
)
image = gr.Image(type="pil", visible=False, preprocess=False)
with gr.Accordion(label="Upload your face image", open=False):
gr.Markdown("<small> This only works on FFHQ model </small>")
with gr.Row():
image_path = gr.Image(
type="filepath", label="input image", interactive=True
)
examples = gr.Examples(
examples=[
"./examples/benedict.jpg",
"./examples/obama.jpg",
"./examples/me.jpg",
],
fn=load_from_img_file,
run_on_click=True,
inputs=[image_path],
outputs=[image, model_state, latents_state],
)
with gr.Column():
html = gr.HTML(canvas_html, label="output")
button.click(
random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
)
reset_btn.click(
reset,
inputs=[model_state, latents_state],
outputs=[image, latents_state],
queue=False,
).then(None, None, None, _js=reset_stop_points, queue=False)
change_style_bt.click(
change_style,
inputs=[image, model_state, latents_state],
outputs=[image, latents_state],
)
dxdysxsy.change(
transform,
inputs=[model_state, latents_state, dxdysxsy, dz],
outputs=[image, latents_state],
show_progress=False,
)
dz.change(
transform,
inputs=[model_state, latents_state, dxdysxsy, dz],
outputs=[image, latents_state],
show_progress=False,
)
image.change(None, inputs=[image], outputs=None, _js=image_change)
image_path.upload(
load_from_img_file,
inputs=[image_path],
outputs=[image, model_state, latents_state],
)
block.load(None, None, None, _js=load_js)
block.load(
random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
)
block.queue(api_open=False)
block.launch(show_api=False)
モデルのロードなどが完了した後以下のように、GradioのWebアプリがインライン表示されます。
以下のようにドラッグ操作で顔の向きの編集が可能です。処理速度上ややもっさりしますが、顔の向きが変わっていることが確認できます。
まとめ
本記事では、UserControllableLTを用いたマウス操作による画像編集をご紹介しました。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1.
論文 - User-Controllable Latent Transformer for StyleGAN Image Layout Editing
2. GitHub - endo-yuki-t/UserControllableLT
0 件のコメント :
コメントを投稿