[DualStyleGAN] 機械学習で写真を似顔絵風にスタイル転送

2022年5月31日火曜日

Artificial Intelligence

本記事では、DualStyleGANと呼ばれる機械学習手法を用いて、写真に似顔絵スタイルを転送し、似顔絵風画像を生成する方法をご紹介します。

アイキャッチ
出典: Pastiche Master: Exemplar-Based High-Resolution Portrait Style Transfer

DualStyleGAN

概要

DualStyleGANは、オリジナル顔画像のドメインと拡張された芸術的なポートレートドメインのDualStyleを柔軟に制御可能なexemplarベースの高解像度ポートレートスタイル転送技術です。

DualStyleGANは、ポートレートのコンテンツとスタイルをそれぞれintrinsic style pathとextrinsic style pathで特徴づけることにより自然なスタイル転送を実現しています。

extrinsic style pathによりモデルは色や複雑なスタイルの両方を段階的にスタイル転送することが可能となっています。

overview
出典: Pastiche Master: Exemplar-Based High-Resolution Portrait Style Transfer

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、スタイル転送を行っていきます。

デモ(Colaboratory)

それでは、実際に動かしながらスタイル転送を行っていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/williamyang1991/DualStyleGAN.git

次にライブラリをインストールします。

%cd /content

# ninja
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 

!pip install faiss-cpu
!pip install wget
!pip install --upgrade --no-cache-dir gdown

最後にライブラリをインポートします。

%cd /content/DualStyleGAN

%load_ext autoreload
%autoreload 2

import sys
sys.path.append(".")
sys.path.append("..")

import numpy as np
import torch
from util import save_image, load_image, visualize
import argparse
from argparse import Namespace
from torchvision import transforms
from torch.nn import functional as F
import torchvision
import matplotlib.pyplot as plt
from model.dualstylegan import DualStyleGAN
from model.sampler.icp import ICPTrainer
from model.encoder.psp import pSp
from model.encoder.align_all_parallel import align_face

import os
import gdown
import wget
import bz2
import dlib

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MODEL_DIR = '/content/DualStyleGAN/checkpoint'
DATA_DIR = '/content/DualStyleGAN/data'

以上で環境セットアップは完了です。

モデルのセットアップ

続いて転送するスタイルに合わせたモデルをセットアップしていきます。

はじめに転送するスタイルを設定します。

style_type = 'caricature' #@param ['cartoon', 'caricature', 'anime']

os.makedirs(os.path.join(MODEL_DIR, style_type), exist_ok=True)

設定に基づき学習済みモデルをダウンロードしていきます。

MODEL_PATHS = {
    "encoder": {"id": "1NgI4mPkboYvYw3MWcdUaQhkr0OWgs9ej", "name": "encoder.pt"},
    "cartoon-G": {"id": "1exS9cSFkg8J4keKPmq2zYQYfJYC5FkwL", "name": "generator.pt"},
    "cartoon-N": {"id": "1JSCdO0hx8Z5mi5Q5hI9HMFhLQKykFX5N", "name": "sampler.pt"},
    "cartoon-S": {"id": "1ce9v69JyW_Dtf7NhbOkfpH77bS_RK0vB", "name": "refined_exstyle_code.npy"},
    "caricature-G": {"id": "1BXfTiMlvow7LR7w8w0cNfqIl-q2z0Hgc", "name": "generator.pt"},
    "caricature-N": {"id": "1eJSoaGD7X0VbHS47YLehZayhWDSZ4L2Q", "name": "sampler.pt"},
    "caricature-S": {"id": "1-p1FMRzP_msqkjndRK_0JasTdwQKDsov", "name": "refined_exstyle_code.npy"},
    "anime-G": {"id": "1BToWH-9kEZIx2r5yFkbjoMw0642usI6y", "name": "generator.pt"},
    "anime-N": {"id": "19rLqx_s_SUdiROGnF_C6_uOiINiNZ7g2", "name": "sampler.pt"},
    "anime-S": {"id": "17-f7KtrgaQcnZysAftPogeBwz5nOWYuM", "name": "refined_exstyle_code.npy"},  
}

def get_download_model_command(file_id, file_name):
  download_path = os.path.join(MODEL_DIR, file_name)
  if not os.path.exists(download_path):
    gdown.download('https://drive.google.com/uc?id='+file_id, download_path, quiet=False)
    
# download pSp encoder
get_download_model_command(MODEL_PATHS["encoder"]["id"], MODEL_PATHS["encoder"]["name"])
# download dualstylegan
get_download_model_command(
    MODEL_PATHS[style_type+'-G']["id"], 
    os.path.join(style_type, MODEL_PATHS[style_type+'-G']["name"]) )
# download sampler
get_download_model_command(
    MODEL_PATHS[style_type+'-N']["id"], 
    os.path.join(style_type, MODEL_PATHS[style_type+'-N']["name"]) )
# download extrinsic style code
get_download_model_command(
    MODEL_PATHS[style_type+'-S']["id"], 
    os.path.join(style_type, MODEL_PATHS[style_type+'-S']["name"]) )

ダウンロードした学習済みモデルをロードします。

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# DualStyleGANのロード
generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)
generator.eval()
ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'), map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g_ema"])
generator = generator.to(device)

# encoderのロード
model_path = os.path.join(MODEL_DIR, 'encoder.pt')
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)
opts.device = device
encoder = pSp(opts)
encoder.eval()
encoder = encoder.to(device)

# extrinsic style codeのロード
exstyles = np.load(os.path.join(MODEL_DIR, style_type, MODEL_PATHS[style_type+'-S']["name"]), allow_pickle='TRUE').item()

# sampler networkのロード
icptc = ICPTrainer(np.empty([0,512*11]), 128)
icpts = ICPTrainer(np.empty([0,512*7]), 128)
ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'), map_location=lambda storage, loc: storage)
icptc.icp.netT.load_state_dict(ckpt['color'])
icpts.icp.netT.load_state_dict(ckpt['structure'])
icptc.icp.netT = icptc.icp.netT.to(device)
icpts.icp.netT = icpts.icp.netT.to(device)

print('Model successfully loaded!')

画像のセットアップ

スタイル転送したい写真をセットアップします。

似顔絵風に変換したいお好きな画像をご用意ください。 本記事ではぱくたそ様の以下の画像を使用します。

test image
%cd /content/DualStyleGAN
!rm -rf images output_images
!mkdir images output_images

!wget -c https://www.pakutaso.com/shared/img/thumb/soraPAR59476_TP_V.jpg \
      -O ./images/test1.jpg

用意した画像から顔部分を抜き出し位置合わせします。

def run_alignment(image_path):
    modelname = os.path.join(MODEL_DIR, 'shape_predictor_68_face_landmarks.dat')
    if not os.path.exists(modelname):
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname+'.bz2')
        zipfile = bz2.BZ2File(modelname+'.bz2')
        data = zipfile.read()
        open(modelname, 'wb').write(data) 
    predictor = dlib.shape_predictor(modelname)
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    return aligned_image

I = transform(run_alignment(image_path)).unsqueeze(dim=0).to(device)

plt.figure(figsize=(10,10),dpi=30)
visualize(I[0].cpu())
plt.show()
align_face

スタイル転送

それでは先ほどの顔写真にスタイルを転送していきます。

まずスタイル転送元の画像を指定します。

if style_type == "anime":
  # stylepath = "/content/DualStyleGAN/data/anime/images/train/16031200.jpg"
  stylepath = "/content/DualStyleGAN/data/anime/images/train/23075800.jpg"
elif style_type == "caricature":
  stylepath = "/content/DualStyleGAN/data/caricature/images/train/Hillary_Clinton_C00034.jpg"
  # stylepath = "/content/DualStyleGAN/data/caricature/images/train/Liv_Tyler_C00009.jpg"
elif style_type == "cartoon":
  stylepath = "/content/DualStyleGAN/data/cartoon/images/train/Cartoons_00003_01.jpg"
  # stylepath = "/content/DualStyleGAN/data/cartoon/images/train/Cartoons_00038_07.jpg"
  # stylepath = "/content/DualStyleGAN/data/cartoon/images/train/Cartoons_00167_01.jpg"
else:
  print(exstyles.keys())
  raise Exception("Please download train images.")

stylename = os.path.basename(stylepath)

# style imageのロード
print('loading %s'%stylepath)
if os.path.exists(stylepath):
    S = load_image(stylepath)
    plt.figure(figsize=(10,10),dpi=30)
    visualize(S[0])
    plt.show()
else:
    print('%s is not found'%stylename)
style image

それではスタイル転送します。

with torch.no_grad():
    img_rec, instyle = encoder(I, randomize_noise=False, return_latents=True, 
                            z_plus_latent=True, return_z_plus_latent=True, resize=False)    
    img_rec = torch.clamp(img_rec.detach(), -1, 1)
    
    latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).to(device)
    # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer
    latent[1,7:18] = instyle[0,7:18]
    exstyle = generator.generator.style(latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])).reshape(latent.shape)
    
    img_gen, _ = generator([instyle.repeat(2,1,1)], exstyle, z_plus_latent=True, 
                           truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[0.6]*7+[1]*11)
    img_gen = torch.clamp(img_gen.detach(), -1, 1)
    # deactivate color-related layers by setting w_c = 0
    img_gen2, _ = generator([instyle], exstyle[0:1], z_plus_latent=True, 
                            truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[0.6]*7+[0]*11)
    img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
    
vis = torchvision.utils.make_grid(F.adaptive_avg_pool2d(torch.cat([img_rec, img_gen, img_gen2], dim=0), 256), 4, 1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

出力結果は以下の通りです。 左から以下の順で並んでいます。

  1. pSpで再構成したコンテンツ画像
  2. colorとstructureをスタイル転送した画像
  3. コンテンツ画像の色に置き換え、コンテンツ画像の色を再現したスタイル転送画像
  4. 色関連のレイヤーを非アクティブにすることによりコンテンツ画像の色を保持したスタイル転送画像
transfer result

段階的スタイル転送

次に段階的にスタイルを転送していき、徐々に似顔絵風に変換していきます。

results = []
s_root = 12
num = s_root*s_root
for i in range(num): 
  structrue_w = [i/num]*7 # structure codesのweightを変更
  color_w = [i/num]*11 # color codesのweightを変更

  w = structrue_w + color_w  
  img_gen, _ = generator(
      [instyle], exstyle[0:1], z_plus_latent=True, 
      truncation=0.7, truncation_latent=0, use_res=True, interp_weights=w)
  img_gen = torch.clamp(F.adaptive_avg_pool2d(img_gen.detach(), 512), -1, 1)
  results += [img_gen]
  
  # save image
  sv_img = torchvision.utils.make_grid(torch.cat([img_gen], dim=0), 1, 1)
  sv_img = ((sv_img.cpu().detach().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
  filename = os.path.join("/content/DualStyleGAN/output_images", "result_" + f'{i:06}' + ".jpg")
  plt.imsave(filename, sv_img)
        
vis = torchvision.utils.make_grid(torch.cat(results, dim=0), s_root, 1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()
step by step result

徐々に変化する一連の画像から動画を生成します。

!ffmpeg -i "/content/DualStyleGAN/output_images/result_%06d.jpg" -c:v libx264 -vf "format=yuv420p" "/content/DualStyleGAN/output_images/result.mp4"

from moviepy.editor import *
from moviepy.video.fx.resize import resize
clip = VideoFileClip("/content/DualStyleGAN/output_images/result.mp4")
clip = resize(clip, height=420)
clip.ipython_display()

出力結果は以下の通りです。

result video

まとめ

本記事では、DualStyleGANと呼ばれる機械学習手法を用いて、写真に似顔絵スタイルを転送し、似顔絵風画像を生成する方法をご紹介しました。

段階的に変化する様子は見ていて知的好奇心をくすぐられます。 一方でオリジナル画像を潜在空間に再構成する際にオリジナル画像からやや顔が変わってため、別の手法でinverseして改善したいところです。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - Pastiche Master: Exemplar-Based High-Resolution Portrait Style Transfer

2. GitHub - williamyang1991/DualStyleGAN

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology