[VQGAN+CLIP] Generating Animations with Machine Learning

Monday, February 14, 2022

Artificial Intelligence

This article shows how to create animations from arbitrary text and images using VQGAN+CLIP.


VQGAN+CLIP

Overview

VQGAN+CLIP is a technique for the Text-to-Image task: given a natural-language text as input, it outputs an image that is semantically related to that text.

Put simply, a GAN generates an image and CLIP scores how well that generated image matches the input text. The score is then fed back and another image is generated, and by repeating this loop the output becomes increasingly related to the input text.

VQGAN+CLIP architecture (figure)
Source: CLIP を用いた画像ランキングによるパラメータ最適化に基づいた絵本の挿絵生成 (picture-book illustration generation based on parameter optimization using CLIP image ranking)

As the figure above shows, the Text-to-Image task is realized by a generator (the GAN) producing an image and an encoder (CLIP) scoring the image against the text (computing a CLIP score).
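
For intuition only, the core loop can be sketched as follows: the VQGAN latent is decoded into an image, CLIP scores the image against the text embedding, and the score is back-propagated to update the latent. The callables decode_image and embed_image and the text embedding here are placeholders; the actual implementation appears in the cells later in this article.

# Illustrative sketch of the VQGAN+CLIP optimization loop (placeholder callables).
import torch

def optimize_latent(z, decode_image, embed_image, text_embedding, steps=200, lr=0.1):
    z = z.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([z], lr=lr)
    for _ in range(steps):
        image = decode_image(z)                      # GAN: latent -> image
        image_embedding = embed_image(image)         # CLIP: image -> embedding
        # Higher cosine similarity to the text is better, so minimize its negative.
        loss = -torch.cosine_similarity(image_embedding, text_embedding, dim=-1).mean()
        optimizer.zero_grad()
        loss.backward()                              # feed the score back
        optimizer.step()                             # update the latent
    return z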

In this article, we gradually vary the parameters fed to the GAN to generate a large number of gradually changing images (frames), like a flip-book.
The frames are then combined into a single video to create the animation.

The following article explains how to generate images from arbitrary text with Text-to-Image methods. If you would like to start by simply generating images, please refer to it first.

Demo (Colaboratory)

Now, let's actually run VQGAN+CLIP and create an animation.
The source code is included in this article and can also be obtained from the GitHub repository below.
GitHub - Colaboratory demo

You can also open it directly in Google Colaboratory from the link below.
Open In Colab

Note that this demo is implemented in Python.
If you are not yet confident with Python, or want to study machine learning with Python in more depth, the following books and online courses are recommended.

Environment Setup

Let's start by setting up the environment. After opening Colaboratory, change the setting below so that a GPU is used.

"Change runtime type" → set "Hardware accelerator" to GPU
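
To confirm that the GPU runtime is actually active, you can run a quick optional check like the following (not required by the notebook):

# Optional: verify that a GPU runtime is active.
import torch

if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))  # e.g. a Tesla T4 on Colab
else:
    print('No GPU detected; check the runtime settings above.')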

First, mount Google Drive so that it can be accessed from Google Colaboratory.
After mounting, create a working directory on Google Drive; the generated images and videos will be stored there.

from google.colab import drive
drive.mount('/content/drive')

!mkdir '/content/drive/MyDrive/vqgan'
!mkdir '/content/drive/MyDrive/vqgan/images'
working_dir = '/content/drive/MyDrive/vqgan'

Next, install the libraries.

%cd /content/

print("Downloading CLIP...")
!git clone https://github.com/openai/CLIP                 &> /dev/null

print("Downloading Python AI libraries...")
!git clone https://github.com/CompVis/taming-transformers &> /dev/null
!pip install ftfy regex tqdm omegaconf pytorch-lightning  &> /dev/null
!pip install kornia                                       &> /dev/null
!pip install einops                                       &> /dev/null
 
print("Installing libraries for handling metadata...")
!pip install stegano                                      &> /dev/null
!apt install exempi                                       &> /dev/null
!pip install python-xmp-toolkit                           &> /dev/null
!pip install imgtag                                       &> /dev/null
!pip install pillow==7.1.2                                &> /dev/null
 
print("Installing Python video creation libraries...")
!pip install imageio-ffmpeg &> /dev/null
path = f'{working_dir}/steps'
!mkdir --parents {path}
print("Installation finished.")

Import the libraries.

import argparse
import math
from pathlib import Path
import sys
import os
import cv2
import pandas as pd
import numpy as np
import subprocess
import ast
 
sys.path.append('/content/taming-transformers')

# Some models include transformers, others need explicit pip install
try:
    import transformers
except Exception:
    !pip install transformers
    import transformers

from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
 
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
from imgtag import ImgTag    # metadata 
from libxmp import *         # metadata
import libxmp                # metadata
from stegano import lsb
import json
ImageFile.LOAD_TRUNCATED_IMAGES = True

To finish the environment setup, define the functions that will be used in the later steps.
You can create the animation without worrying about their implementation; if you want to inspect it, expand the collapsed section below.
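
The collapsed cell defines helpers such as load_vqgan_model, MakeCutouts, Prompt, parse_prompt, resize_image, vector_quantize, and clamp_with_grad, which the cells below call. As one representative example, load_vqgan_model typically looks roughly like the sketch below in VQGAN+CLIP notebooks (the authoritative version is in the linked notebook; this sketch reuses the OmegaConf, vqgan, and cond_transformer imports from above):

# Sketch of load_vqgan_model as found in typical VQGAN+CLIP notebooks.
def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss  # the training loss module is not needed for generation
    return model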

Getting the Default Images

With VQGAN+CLIP, you can specify an initial image and target images alongside the input text when generating images.
Here we download in advance the default images used as the initial image and so on.

%cd /content/drive/MyDrive/vqgan/images
!wget https://www.pakutaso.com/shared/img/thumb/nantoshi21PAR519902088_TP_V4.jpg
!wget https://www.pakutaso.com/shared/img/thumb/yuka16011215IMG_5574_TP_V4.jpg

src_img = Image.open('/content/drive/MyDrive/vqgan/images/nantoshi21PAR519902088_TP_V4.jpg')
dst_img = Image.open('/content/drive/MyDrive/vqgan/images/yuka16011215IMG_5574_TP_V4.jpg')
src_img = src_img.resize((src_img.width // 2, src_img.height // 2))
dst_img = dst_img.resize((dst_img.width // 2, dst_img.height // 2))
src_img.save('/content/drive/MyDrive/vqgan/images/nantoshi21PAR519902088_TP_V4.jpg')
dst_img.save('/content/drive/MyDrive/vqgan/images/yuka16011215IMG_5574_TP_V4.jpg')


%cd /content/

The following images are downloaded and saved to Google Drive.

You can also use images other than these, but be careful about the image size: large images may exceed the available memory.

In this article, the images are resized to half their original size, because at full size they would exceed the available memory.

Default images

Parameter Settings

Set the parameters for generating the frame images.
The most important parameters are the following.

  1. text_prompts
    The query text fed to the GAN.
    With 'Moon': {10: 0, 60: 1}, the query text "Moon" is given a weight of 0 (minimum) at frame 10 and 1 (maximum) at frame 60 (see the interpolation sketch after this list).
  2. initial_image
    The initial image. May be left blank if not needed.
    This article uses the image prepared earlier: "/content/drive/MyDrive/vqgan/images/nantoshi21PAR519902088_TP_V4.jpg"
  3. target_images
    Target image(s). May be left blank if not needed.
    This article uses the image prepared earlier: "'/content/drive/MyDrive/vqgan/images/yuka16011215IMG_5574_TP_V4.jpg': {10: 0, 60: 1}"
    With this setting, just like the query text, the weight is minimal at frame 10 and maximal at frame 60.
  4. zoom
    The zoom factor applied to each frame: 1 means no zoom, values below 1 zoom out, and values above 1 zoom in (positive values only) (E.g. 10: 1, 30: 1.2, 50: 0.9)
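
To make the keyframe syntax concrete, here is a minimal sketch (not one of the notebook cells) of the linear interpolation that the get_inbetweens helper defined later performs. For 'Moon': {10: 0, 60: 1}, the weight is held at 0 up to frame 10, rises linearly to 1 at frame 60, and every frame in between is filled in automatically; the variable names here are illustrative.

# Minimal sketch of keyframe interpolation for "'Moon': {10: 0, 60: 1}".
import numpy as np
import pandas as pd

max_frames = 60
keyframes = {10: 0.0, 60: 1.0}

weights = pd.Series([np.nan] * (max_frames + 1))
for frame, value in keyframes.items():
    weights[frame] = value
weights = weights.interpolate(limit_direction='both')  # fill in-between frames linearly

print(weights[10], weights[35], weights[60])  # 0.0 0.5 1.0

With that in mind, set the parameters in the cell below.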
key_frames = True #@param {type:"boolean"}
text_prompts = "'Moon': {10: 0, 60: 1}, 'Sun': {10: 1, 60: 0}" #@param {type:"string"}
width =  400 #@param {type:"number"}
height =  400 #@param {type:"number"}
model = "vqgan_imagenet_f16_16384" #@param ["vqgan_imagenet_f16_16384", "vqgan_imagenet_f16_1024", "wikiart_16384", "coco", "faceshq", "sflckr"]
interval =  1#@param {type:"number"}
initial_image = "/content/drive/MyDrive/vqgan/images/nantoshi21PAR519902088_TP_V4.jpg"#@param {type:"string"}
target_images = "'/content/drive/MyDrive/vqgan/images/yuka16011215IMG_5574_TP_V4.jpg': {10: 0, 60: 1}"#@param {type:"string"}
seed = 1#@param {type:"number"}
max_frames = 60#@param {type:"number"}
angle = "10: 0, 30: 1, 60: -1"#@param {type:"string"}

# @markdown Careful: do not use negative or 0 zoom. If you want to zoom out, use a number between 0 and 1.
zoom = "10: 1, 30: 1.2, 60: 0.9"#@param {type:"string"}
translation_x = "0: 0"#@param {type:"string"}
translation_y = "0: 0"#@param {type:"string"}
iterations_per_frame = "0: 10"#@param {type:"string"}
save_all_iterations = False#@param {type:"boolean"}

Based on the parameters set above, generate the parameter values fed to each frame.

# The -C - option skips the download if the file already exists
!curl -L -o {model}.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 1024
!curl -L -o {model}.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1'  #ImageNet 1024

if initial_image != "":
    print(
        "WARNING: You have specified an initial image. Note that the image resolution "
        "will be inherited from this image, not whatever width and height you specified. "
        "If the initial image resolution is too high, this can result in out of memory errors."
    )
elif width * height > 160000:
    print(
        "WARNING: The width and height you have specified may be too high, in which case "
        "you will encounter out of memory errors either at the image generation stage or the "
        "video synthesis stage. If so, try reducing the resolution"
    )
model_names={
    "vqgan_imagenet_f16_16384": 'ImageNet 16384',
    "vqgan_imagenet_f16_1024":"ImageNet 1024", 
    "wikiart_1024":"WikiArt 1024",
    "wikiart_16384":"WikiArt 16384",
    "coco":"COCO-Stuff",
    "faceshq":"FacesHQ",
    "sflckr":"S-FLCKR"
}
model_name = model_names[model]

if seed == -1:
    seed = None

def parse_key_frames(string, prompt_parser=None):
    """Given a string representing frame numbers paired with parameter values at that frame,
    return a dictionary with the frame numbers as keys and the parameter values as the values.

    Parameters
    ----------
    string: string
        Frame numbers paired with parameter values at that frame number, in the format
        'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'
    prompt_parser: function or None, optional
        If provided, prompt_parser will be applied to each string of parameter values.
    
    Returns
    -------
    dict
        Frame numbers as keys, parameter values at that frame number as values

    Raises
    ------
    RuntimeError
        If the input string does not match the expected format.
    
    Examples
    --------
    >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
    {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}

    >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower()))
    {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
    """

    try:
        # This is the preferred way, the regex way will eventually be deprecated.
        frames = ast.literal_eval('{' + string + '}')
        if isinstance(frames, set):
            # If user forgot keyframes, just set value of frame 0
            (frame,) = list(frames)
            frames = {0: frame}
        return frames
    except Exception:
        import re
        pattern = r'((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])'
        frames = dict()
        for match_object in re.finditer(pattern, string):
            frame = int(match_object.groupdict()['frame'])
            param = match_object.groupdict()['param']
            if prompt_parser:
                frames[frame] = prompt_parser(param)
            else:
                frames[frame] = param

        if frames == {} and len(string) != 0:
            raise RuntimeError(f'Key Frame string not correctly formatted: {string}')
        return frames

# Defaults, if left empty
if angle == "":
    angle = "0"
if zoom == "":
    zoom = "1"
if translation_x == "":
    translation_x = "0"
if translation_y == "":
    translation_y = "0"
if iterations_per_frame == "":
    iterations_per_frame = "10"

if key_frames:
    parameter_dicts = dict()
    parameter_dicts['zoom'] = parse_key_frames(zoom, prompt_parser=float)
    parameter_dicts['angle'] = parse_key_frames(angle, prompt_parser=float)
    parameter_dicts['translation_x'] = parse_key_frames(translation_x, prompt_parser=float)
    parameter_dicts['translation_y'] = parse_key_frames(translation_y, prompt_parser=float)
    parameter_dicts['iterations_per_frame'] = parse_key_frames(iterations_per_frame, prompt_parser=int)

    text_prompts_dict = parse_key_frames(text_prompts)
    if all([isinstance(value, dict) for value in list(text_prompts_dict.values())]):
        for key, value in list(text_prompts_dict.items()):
            parameter_dicts[f'text_prompt: {key}'] = value
    else:
        # Old format
        text_prompts_dict = parse_key_frames(text_prompts, prompt_parser=lambda x: x.split('|'))
        for frame, prompt_list in text_prompts_dict.items():
            for prompt in prompt_list:
                prompt_key, prompt_value = prompt.split(":")
                prompt_key = f'text_prompt: {prompt_key.strip()}'
                prompt_value = prompt_value.strip()
                if prompt_key not in parameter_dicts:
                    parameter_dicts[prompt_key] = dict()
                parameter_dicts[prompt_key][frame] = prompt_value


    image_prompts_dict = parse_key_frames(target_images)
    if all([isinstance(value, dict) for value in list(image_prompts_dict.values())]):
        for key, value in list(image_prompts_dict.items()):
            parameter_dicts[f'image_prompt: {key}'] = value
    else:
        # Old format
        image_prompts_dict = parse_key_frames(target_images, prompt_parser=lambda x: x.split('|'))
        for frame, prompt_list in image_prompts_dict.items():
            for prompt in prompt_list:
                prompt_key, prompt_value = prompt.split(":")
                prompt_key = f'image_prompt: {prompt_key.strip()}'
                prompt_value = prompt_value.strip()
                if prompt_key not in parameter_dicts:
                    parameter_dicts[prompt_key] = dict()
                parameter_dicts[prompt_key][frame] = prompt_value


def add_inbetweens():
    global text_prompts
    global target_images
    global zoom
    global angle
    global translation_x
    global translation_y
    global iterations_per_frame

    global text_prompts_series
    global target_images_series
    global zoom_series
    global angle_series
    global translation_x_series
    global translation_y_series
    global iterations_per_frame_series
    global model
    global args
    def get_inbetweens(key_frames_dict, integer=False):
        """Given a dict with frame numbers as keys and a parameter value as values,
        return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
        Any values not provided in the input dict are calculated by linear interpolation between
        the values of the previous and next provided frames. If there is no previous provided frame, then
        the value is equal to the value of the next provided frame, or if there is no next provided frame,
        then the value is equal to the value of the previous provided frame. If no frames are provided,
        all frame values are NaN.

        Parameters
        ----------
        key_frames_dict: dict
            A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
        integer: Bool, optional
            If True, the values of the output series are converted to integers.
            Otherwise, the values are floats.
        
        Returns
        -------
        pd.Series
            A Series with length max_frames representing the parameter values for each frame.
        
        Examples
        --------
        >>> max_frames = 5
        >>> get_inbetweens({1: 5, 3: 6})
        0    5.0
        1    5.0
        2    5.5
        3    6.0
        4    6.0
        dtype: float64

        >>> get_inbetweens({1: 5, 3: 6}, integer=True)
        0    5
        1    5
        2    5
        3    6
        4    6
        dtype: int64
        """
        key_frame_series = pd.Series([np.nan for a in range(max_frames)])
        for i, value in key_frames_dict.items():
            key_frame_series[i] = value
        key_frame_series = key_frame_series.astype(float)
        key_frame_series = key_frame_series.interpolate(limit_direction='both')
        if integer:
            return key_frame_series.astype(int)
        return key_frame_series

    if key_frames:
        text_prompts_series_dict = dict()
        for parameter in parameter_dicts.keys():
            if len(parameter_dicts[parameter]) > 0:
                if parameter.startswith('text_prompt:'):
                    try:
                        text_prompts_series_dict[parameter] = get_inbetweens(parameter_dicts[parameter])
                    except RuntimeError as e:
                        raise RuntimeError(
                            "WARNING: You have selected to use key frames, but you have not "
                            "formatted `text_prompts` correctly for key frames.\n"
                            "Please read the instructions to find out how to use key frames "
                            "correctly.\n"
                        )
        text_prompts_series = pd.Series([np.nan for a in range(max_frames)])
        for i in range(max_frames):
            combined_prompt = []
            for parameter, value in text_prompts_series_dict.items():
                parameter = parameter[len('text_prompt:'):].strip()
                combined_prompt.append(f'{parameter}: {value[i]}')
            text_prompts_series[i] = ' | '.join(combined_prompt)

        image_prompts_series_dict = dict()
        for parameter in parameter_dicts.keys():
            if len(parameter_dicts[parameter]) > 0:
                if parameter.startswith('image_prompt:'):
                    try:
                        image_prompts_series_dict[parameter] = get_inbetweens(parameter_dicts[parameter])
                    except RuntimeError as e:
                        raise RuntimeError(
                            "WARNING: You have selected to use key frames, but you have not "
                            "formatted `image_prompts` correctly for key frames.\n"
                            "Please read the instructions to find out how to use key frames "
                            "correctly.\n"
                        )
        target_images_series = pd.Series([np.nan for a in range(max_frames)])
        for i in range(max_frames):
            combined_prompt = []
            for parameter, value in image_prompts_series_dict.items():
                parameter = parameter[len('image_prompt:'):].strip()
                combined_prompt.append(f'{parameter}: {value[i]}')
            target_images_series[i] = ' | '.join(combined_prompt)

        try:
            angle_series = get_inbetweens(parameter_dicts['angle'])
        except RuntimeError as e:
            print(
                "WARNING: You have selected to use key frames, but you have not "
                "formatted `angle` correctly for key frames.\n"
                "Attempting to interpret `angle` as "
                f'"0: ({angle})"\n'
                "Please read the instructions to find out how to use key frames "
                "correctly.\n"
            )
            angle = f"0: ({angle})"
            angle_series = get_inbetweens(parse_key_frames(angle))

        try:
            zoom_series = get_inbetweens(parameter_dicts['zoom'])
        except RuntimeError as e:
            print(
                "WARNING: You have selected to use key frames, but you have not "
                "formatted `zoom` correctly for key frames.\n"
                "Attempting to interpret `zoom` as "
                f'"0: ({zoom})"\n'
                "Please read the instructions to find out how to use key frames "
                "correctly.\n"
            )
            zoom = f"0: ({zoom})"
            zoom_series = get_inbetweens(parse_key_frames(zoom))
        for i, zoom in enumerate(zoom_series):
            if zoom <= 0:
                print(
                    f"WARNING: You have selected a zoom of {zoom} at frame {i}. "
                    "This is meaningless. "
                    "If you want to zoom out, use a value between 0 and 1. "
                    "If you want no zoom, use a value of 1."
                )

        try:
            translation_x_series = get_inbetweens(parameter_dicts['translation_x'])
        except RuntimeError as e:
            print(
                "WARNING: You have selected to use key frames, but you have not "
                "formatted `translation_x` correctly for key frames.\n"
                "Attempting to interpret `translation_x` as "
                f'"0: ({translation_x})"\n'
                "Please read the instructions to find out how to use key frames "
                "correctly.\n"
            )
            translation_x = f"0: ({translation_x})"
            translation_x_series = get_inbetweens(parse_key_frames(translation_x))

        try:
            translation_y_series = get_inbetweens(parameter_dicts['translation_y'])
        except RuntimeError as e:
            print(
                "WARNING: You have selected to use key frames, but you have not "
                "formatted `translation_y` correctly for key frames.\n"
                "Attempting to interpret `translation_y` as "
                f'"0: ({translation_y})"\n'
                "Please read the instructions to find out how to use key frames "
                "correctly.\n"
            )
            translation_y = f"0: ({translation_y})"
            translation_y_series = get_inbetweens(parse_key_frames(translation_y))

        try:
            iterations_per_frame_series = get_inbetweens(
                parameter_dicts['iterations_per_frame'], integer=True
            )
        except RuntimeError as e:
            print(
                "WARNING: You have selected to use key frames, but you have not "
                "formatted `iterations_per_frame` correctly for key frames.\n"
                "Attempting to interpret `iterations_per_frame` as "
                f'"0: ({iterations_per_frame})"\n'
                "Please read the instructions to find out how to use key frames "
                "correctly.\n"
            )
            iterations_per_frame = f"0: ({iterations_per_frame})"
            
            iterations_per_frame_series = get_inbetweens(
                parse_key_frames(iterations_per_frame), integer=True
            )
    else:
        text_prompts = [phrase.strip() for phrase in text_prompts.split("|")]
        if text_prompts == ['']:
            text_prompts = []
        if target_images == "None" or not target_images:
            target_images = []
        else:
            target_images = target_images.split("|")
            target_images = [image.strip() for image in target_images]

        angle = float(angle)
        zoom = float(zoom)
        translation_x = float(translation_x)
        translation_y = float(translation_y)
        iterations_per_frame = int(iterations_per_frame)
        if zoom <= 0:
            print(
                f"WARNING: You have selected a zoom of {zoom}. "
                "This is meaningless. "
                "If you want to zoom out, use a value between 0 and 1. "
                "If you want no zoom, use a value of 1."
            )

    args = argparse.Namespace(
        prompts=text_prompts,
        image_prompts=target_images,
        noise_prompt_seeds=[],
        noise_prompt_weights=[],
        size=[width, height],
        init_weight=0.,
        clip_model='ViT-B/32',
        vqgan_config=f'{model}.yaml',
        vqgan_checkpoint=f'{model}.ckpt',
        step_size=0.1,
        cutn=64,
        cut_pow=1.,
        display_freq=interval,
        seed=seed,
    )

add_inbetweens()

Generating Frame Images

Now generate the frame images based on the parameters you set.
If the parameters are configured correctly, frame images will be generated one after another.

path = f'{working_dir}/steps'
!rm -r {path}
!mkdir --parents {path}

#@title Actually do the run...

# Delete memory from previous runs
!nvidia-smi -caa
for var in ['device', 'model', 'perceptor', 'z']:
  try:
      del globals()[var]
  except:
      pass

try:
    import gc
    gc.collect()
except:
    pass

try:
    torch.cuda.empty_cache()
except:
    pass

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if not key_frames:
    if text_prompts:
        print('Using text prompts:', text_prompts)
    if target_images:
        print('Using image prompts:', target_images)
if args.seed is None:
    seed = torch.seed()
else:
    seed = args.seed
torch.manual_seed(seed)
print('Using seed:', seed)
 
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)
 
cut_size = perceptor.visual.input_resolution
e_dim = model.quantize.e_dim
f = 2**(model.decoder.num_resolutions - 1)
make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
n_toks = model.quantize.n_e
toksX, toksY = args.size[0] // f, args.size[1] // f
sideX, sideY = toksX * f, toksY * f
z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
stop_on_next_loop = False  # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete

def read_image_workaround(path):
    """OpenCV reads images as BGR, Pillow saves them as RGB. Work around
    this incompatibility to avoid colour inversions."""
    im_tmp = cv2.imread(path)
    return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)

for i in range(max_frames):
    if stop_on_next_loop:
      break
    if key_frames:
        text_prompts = text_prompts_series[i]
        text_prompts = [phrase.strip() for phrase in text_prompts.split("|")]
        if text_prompts == ['']:
            text_prompts = []
        args.prompts = text_prompts

        target_images = target_images_series[i]

        if target_images == "None" or not target_images:
            target_images = []
        else:
            target_images = target_images.split("|")
            target_images = [image.strip() for image in target_images]
        args.image_prompts = target_images

        angle = angle_series[i]
        zoom = zoom_series[i]
        translation_x = translation_x_series[i]
        translation_y = translation_y_series[i]
        iterations_per_frame = iterations_per_frame_series[i]
        print(
            f'text_prompts: {text_prompts}',
            f'image_prompts: {target_images}',
            f'angle: {angle}',
            f'zoom: {zoom}',
            f'translation_x: {translation_x}',
            f'translation_y: {translation_y}',
            f'iterations_per_frame: {iterations_per_frame}'
        )
    try:
        if i == 0 and initial_image != "":
            img_0 = read_image_workaround(initial_image)
            z, *_ = model.encode(TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1)
        elif i == 0 and not os.path.isfile(f'{working_dir}/steps/{i:04d}.png'):
            one_hot = F.one_hot(
                torch.randint(n_toks, [toksY * toksX], device=device), n_toks
            ).float()
            z = one_hot @ model.quantize.embedding.weight
            z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        else:
            if save_all_iterations:
                img_0 = read_image_workaround(
                    f'{working_dir}/steps/{i:04d}_{iterations_per_frame}.png')
            else:
                img_0 = read_image_workaround(f'{working_dir}/steps/{i:04d}.png')

            center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)
            trans_mat = np.float32(
                [[1, 0, translation_x],
                [0, 1, translation_y]]
            )
            rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )

            trans_mat = np.vstack([trans_mat, [0,0,1]])
            rot_mat = np.vstack([rot_mat, [0,0,1]])
            transformation_matrix = np.matmul(rot_mat, trans_mat)

            img_0 = cv2.warpPerspective(
                img_0,
                transformation_matrix,
                (img_0.shape[1], img_0.shape[0]),
                borderMode=cv2.BORDER_WRAP
            )
            z, *_ = model.encode(TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1)
        i += 1

        z_orig = z.clone()
        z.requires_grad_(True)
        opt = optim.Adam([z], lr=args.step_size)

        normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                        std=[0.26862954, 0.26130258, 0.27577711])

        pMs = []

        for prompt in args.prompts:
            txt, weight, stop = parse_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

        for prompt in args.image_prompts:
            path, weight, stop = parse_prompt(prompt)
            img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
            batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
            embed = perceptor.encode_image(normalize(batch)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

        for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
            gen = torch.Generator().manual_seed(seed)
            embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
            pMs.append(Prompt(embed, weight).to(device))

        def synth(z):
            z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
            return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

        def add_xmp_data(filename):
            imagen = ImgTag(filename=filename)
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'creator', 'VQGAN+CLIP', {"prop_array_is_ordered":True, "prop_value_is_array":True})
            if args.prompts:
                imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', " | ".join(args.prompts), {"prop_array_is_ordered":True, "prop_value_is_array":True})
            else:
                imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', 'None', {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'i', str(i), {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'model', model_name, {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'seed',str(seed) , {"prop_array_is_ordered":True, "prop_value_is_array":True})
            imagen.close()

        def add_stegano_data(filename):
            data = {
                "title": " | ".join(args.prompts) if args.prompts else None,
                "notebook": "VQGAN+CLIP",
                "i": i,
                "model": model_name,
                "seed": str(seed),
            }
            lsb.hide(filename, json.dumps(data)).save(filename)

        @torch.no_grad()
        def checkin(i, losses):
            losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
            tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
            out = synth(z)
            TF.to_pil_image(out[0].cpu()).save('progress.png')
            add_stegano_data('progress.png')
            add_xmp_data('progress.png')
            display.display(display.Image('progress.png'))

        def save_output(i, img, suffix=None):
            filename = \
                f"{working_dir}/steps/{i:04}{'_' + suffix if suffix else ''}.png"
            imageio.imwrite(filename, np.array(img))
            add_stegano_data(filename)
            add_xmp_data(filename)

        def ascend_txt(i, save=True, suffix=None):
            out = synth(z)
            iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

            result = []

            if args.init_weight:
                result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)

            for prompt in pMs:
                result.append(prompt(iii))
            img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
            img = np.transpose(img, (1, 2, 0))
            if save:
                save_output(i, img, suffix=suffix)
            return result

        def train(i, save=True, suffix=None):
            opt.zero_grad()
            lossAll = ascend_txt(i, save=save, suffix=suffix)
            if i % args.display_freq == 0 and save:
                checkin(i, lossAll)
            loss = sum(lossAll)
            loss.backward()
            opt.step()
            with torch.no_grad():
                z.copy_(z.maximum(z_min).minimum(z_max))

        with tqdm() as pbar:
            if iterations_per_frame == 0:
                save_output(i, img_0)
            j = 1
            while True:
                suffix = (str(j) if save_all_iterations else None)
                if j >= iterations_per_frame:
                    train(i, save=True, suffix=suffix)
                    break
                if save_all_iterations:
                    train(i, save=True, suffix=suffix)
                else:
                    train(i, save=False, suffix=suffix)
                j += 1
                pbar.update()
    except KeyboardInterrupt:
      stop_on_next_loop = True
      pass

Generating the Video

Generate a video from the frame images produced above. By specifying last_frame, you can choose the range of frame images used in the video.

# @title Create video
# import subprocess in case this cell is run without the above cells
import subprocess

# Try to avoid OOM errors
torch.cuda.empty_cache()

init_frame = 1#@param {type:"number"} This is the frame where the video will start
last_frame = 60#@param {type:"number"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.
fps = 12#@param {type:"number"}

try:
    key_frames
except NameError:
    filename = "video.mp4"
else:
    if key_frames:
        # key frame filename would be too long
        filename = "video.mp4"
    else:
        filename = f"{'_'.join(text_prompts).replace(' ', '')}.mp4"
filepath = f'{working_dir}/{filename}'

frames = []
# tqdm.write('Generating video...')
try:
    zoomed
except NameError:
    image_path = f'{working_dir}/steps/%04d.png'
else:
    image_path = f'{working_dir}/steps/zoomed_%04d.png'

cmd = [
    'ffmpeg',
    '-y',
    '-vcodec',
    'png',
    '-r',
    str(fps),
    '-start_number',
    str(init_frame),
    '-i',
    image_path,
    '-c:v',
    'libx264',
    '-frames:v',
    str(last_frame-init_frame),
    '-vf',
    f'fps={fps}',
    '-pix_fmt',
    'yuv420p',
    '-crf',
    '17',
    '-preset',
    'veryslow',
    filepath
]

process = subprocess.Popen(cmd, cwd=f'{working_dir}/steps/', stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()
if process.returncode != 0:
    print(stderr)
    print(
        "You may be able to avoid this error by backing up the frames, "
        "restarting the notebook, and running only the google drive/local connection and video synthesis cells, "
        "or by decreasing the resolution of the image generation steps. "
        "If these steps do not work, please post the traceback in the github."
    )
    raise RuntimeError(stderr)
else:
    print("The video is ready")

The video will be written to Google Drive.
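
If you prefer to preview the result directly in the notebook instead of opening Google Drive, a small convenience snippet like the following (not part of the original cells) can be run after the video cell; it reuses the filepath variable and the base64/IPython imports from above.

# Optional: preview the generated video inline in Colab (convenience sketch).
mp4 = open(filepath, 'rb').read()
data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
display.display(display.HTML(f'<video controls loop width=400 src="{data_url}"></video>'))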

text_prompts = "'Moon': {10: 0, 60: 1}, 'Sun': {10: 1, 60: 0}"
initial_image = "/content/drive/MyDrive/vqgan/images/nantoshi21PAR519902088_TP_V4.jpg"
target_images = "'/content/drive/MyDrive/vqgan/images/yuka16011215IMG_5574_TP_V4.jpg': {10: 0, 60: 1}"
の出力結果は以下の通りです。

Animation result 1

text_prompts = "'Death': {10: 0, 60: 1}
initial_image = "/content/drive/MyDrive/vqgan/images/nantoshi21PAR519902088_TP_V4.jpg"
target_images = ""
の出力結果は以下の通りです。

Animation result 2

It is interesting that "death" is expressed with crosses and skulls.

Summary

In this article, we used VQGAN+CLIP to create animations with machine learning.
The output changes dramatically with just a different seed, so try experimenting with various settings.

I hope this article sparks an interest in machine learning for at least one more reader.

This article focused on getting the machine learning code running.
If you would like to study the topic in a more systematic, academic way, the following books are recommended.


As a stepping stone from merely running code to understanding and applying it as an engineer, the Udemy courses below are also recommended.

References

1. Paper - Taming Transformers for High-Resolution Image Synthesis

2. GitHub - chigozienri/VQGAN-CLIP-animations
