[DMVFN] AIで動画の次のフレームを予測する [Video Prediction]

2023年3月28日火曜日

Artificial Intelligence

本記事では、DMVFNと呼ばれる機械学習手法を用いて任意の動画の次のフレームを予測する方法をご紹介します。

eye catch
出典: A Dynamic Multi-Scale Voxel Flow Network for Video Prediction

DMVFN

概要

Dynamic Multi-Scale Voxel Flow Network(DMVFN)は、2023年3月に論文発表されたVideo Prediction技術です。

Video Predictionタスクでは、過去のフレームを元に未来のフレームを予測します。
このタスクを実現する従来技術は、モデルサイズが大きいという課題があり、かつ、高い精度を得るには、セマンティックセグメンテーションや、Depthマップなど追加の入力情報が必要な場合がありました。

DMVFNでは、Dynamic Multi-scale Voxel Flow Networkを構築し、RGB画像のみで従来より低い計算コストで優れた予測パフォーマンスを実現しています。

Architecture
出典: A Dynamic Multi-Scale Voxel Flow Network for Video Prediction

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、動画のフレーム画像から未来のフレーム画像を予測していきます。

デモ(Colaboratory)

それでは、実際に動かしながらVideo Predictionを試していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/megvii-research/CVPR2023-DMVFN.git ./DMVFN

%cd /content/DMVFN
# Commits on Mar 24, 2023
!git checkout 2ffe0399ecb82e77ef5f386d0be75c8ce5bcef2f

次にライブラリをインストールします。

%cd /content/DMVFN

!pip install -r requirements.txt
!pip install --upgrade gdown

最後にライブラリをインポートします。

%cd /content/DMVFN

import os
import gdown
import random
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import glob

import torch
device = 'cuda' if torch.cuda.is_available() else "cpu"
print("using device is", device)

from utils.util import *
from model.model import Model

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

ここでは、論文発表元が公開する学習済みモデルをダウンロードしていきます。

%cd /content/DMVFN

os.makedirs('./pretrained_models', exist_ok=True)

if not os.path.exists("./pretrained_models/dmvfn_city.pkl"):
  gdown.download('https://drive.google.com/uc?id=1jILbS8Gm4E5Xx4tDCPZh_7rId0eo8r9W', "./pretrained_models/dmvfn_city.pkl", quiet=False)
if not os.path.exists("./pretrained_models/dmvfn_kitti.pkl"):
  gdown.download('https://drive.google.com/uc?id=1WrV30prRiS4hWOQBnVPUxdaTlp9XxmVK', "./pretrained_models/dmvfn_kitti.pkl", quiet=False)
if not os.path.exists("./pretrained_models/dmvfn_vimeo.pkl"):
  gdown.download('https://drive.google.com/uc?id=14_xQ3Yl3mO89hr28hbcQW3h63lLrcYY0', "./pretrained_models/dmvfn_vimeo.pkl", quiet=False)

テストデータのセットアップ

ここでは、モデルに入力するテストデータをダウンロードしていきます。

%cd /content/DMVFN

os.makedirs('./data/cityscapes/test', exist_ok=True)

if not os.path.exists("./data/cityscapes/test/test.zip"):
  gdown.download('https://drive.google.com/uc?id=10zCt-uZFOqgF3tpdhluRqbs-4aScvGR4&confirm=t', "./data/cityscapes/test/test.zip", quiet=False)
  %cd /content/DMVFN/data/cityscapes/test
  !unzip -q ./test.zip

%cd /content/DMVFN

Inference

それでは、Video Predictionを実行していきます。

モデルを選択し、ランダムシードを固定します。

%cd /content/DMVFN

pretrained_weights = 'city' #@param ['city', 'kitti', 'vimeo']

model_path = './pretrained_models/dmvfn_' + pretrained_weights + '.pkl'

seed = 12
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True

推論を実行します。

# input image dir
image_dir = "./data/cityscapes/test/000000"
image_list = glob.glob(image_dir + "/*.png")
image_list.sort()

# output dir
out_dir = "./output"
os.makedirs('./output', exist_ok=True)

# load model
model = Model(load_path=model_path, training=False)

pred_num = 3 # max->len(image_list) - 2

with torch.no_grad():
  for i in range(pred_num):
    if i == 0:
      # load image
      cvimg_0 = cv2.imread(image_list[0])
      cvimg_1 = cv2.imread(image_list[1])
    else:
      cvimg_0 = cvimg_1
      cvimg_1 = pred
    # preprocess image
    img_0 = cvimg_0.transpose(2, 0, 1).astype('float32')
    img_1 = cvimg_1.transpose(2, 0, 1).astype('float32')
    img = torch.cat([torch.tensor(img_0),torch.tensor(img_1)], dim=0)
    img = img.unsqueeze(0).unsqueeze(0).to(device, non_blocking=True) # NCHW
    img = img.to(device, non_blocking=True) / 255.
    # inference
    pred = model.eval(img, 'single_test') # 1CHW
    # post process
    pred = np.array(pred.cpu().squeeze() * 255).transpose(1, 2, 0) # CHW -> HWC
    # save result image
    cv2.imwrite(
        os.path.join(out_dir, f'pred_{i:06}.png'), pred)

推論結果を表示して確認します。

show_num = pred_num + 2

# show original images
fig = plt.figure(figsize=(15, 10))
for i in range(show_num):
  ax = fig.add_subplot(1, show_num, i+1)
  plt.title(os.path.basename(image_list[i]), fontsize=16)
  ax.axis('off')
  ax.imshow( Image.open(image_list[i]) )
plt.show()

# show predict images
pred_list = glob.glob(os.path.join(out_dir, '*.png'))
pred_list.sort()
fig = plt.figure(figsize=(15, 10))
for i in range(show_num):
  if i < 2:
    image_path = image_list[i]
  else:
    image_path = pred_list[i-2]

  ax = fig.add_subplot(1, show_num, i+1)
  plt.title(os.path.basename(image_path), fontsize=16)
  ax.axis('off')
  ax.imshow( Image.open(image_path) )
plt.show()

pred_***が予測されたフレーム画像です。
正解データより若干進みが遅いですが、車が進んでいるフレームが予測されています。

result 01
result 02

まとめ

本記事では、DMVFNを用いたVideo Predictionをご紹介しました。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - A Dynamic Multi-Scale Voxel Flow Network for Video Prediction

2. GitHub - megvii-research/CVPR2023-DMVFN

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology