本記事では、DMVFNと呼ばれる機械学習手法を用いて任意の動画の次のフレームを予測する方法をご紹介します。
DMVFN
概要
Dynamic Multi-Scale Voxel Flow Network(DMVFN)は、2023年3月に論文発表されたVideo Prediction技術です。
Video Predictionタスクでは、過去のフレームを元に未来のフレームを予測します。
このタスクを実現する従来技術は、モデルサイズが大きいという課題があり、かつ、高い精度を得るには、セマンティックセグメンテーションや、Depthマップなど追加の入力情報が必要な場合がありました。
DMVFNでは、Dynamic Multi-scale Voxel Flow Networkを構築し、RGB画像のみで従来より低い計算コストで優れた予測パフォーマンスを実現しています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、動画のフレーム画像から未来のフレーム画像を予測していきます。
デモ(Colaboratory)
それでは、実際に動かしながらVideo Predictionを試していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
%cd /content
!git clone https://github.com/megvii-research/CVPR2023-DMVFN.git ./DMVFN
%cd /content/DMVFN
# Commits on Mar 24, 2023
!git checkout 2ffe0399ecb82e77ef5f386d0be75c8ce5bcef2f
次にライブラリをインストールします。
%cd /content/DMVFN
!pip install -r requirements.txt
!pip install --upgrade gdown
最後にライブラリをインポートします。
%cd /content/DMVFN
import os
import gdown
import random
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import glob
import torch
device = 'cuda' if torch.cuda.is_available() else "cpu"
print("using device is", device)
from utils.util import *
from model.model import Model
以上で環境セットアップは完了です。
学習済みモデルのセットアップ
ここでは、論文発表元が公開する学習済みモデルをダウンロードしていきます。
%cd /content/DMVFN
os.makedirs('./pretrained_models', exist_ok=True)
if not os.path.exists("./pretrained_models/dmvfn_city.pkl"):
gdown.download('https://drive.google.com/uc?id=1jILbS8Gm4E5Xx4tDCPZh_7rId0eo8r9W', "./pretrained_models/dmvfn_city.pkl", quiet=False)
if not os.path.exists("./pretrained_models/dmvfn_kitti.pkl"):
gdown.download('https://drive.google.com/uc?id=1WrV30prRiS4hWOQBnVPUxdaTlp9XxmVK', "./pretrained_models/dmvfn_kitti.pkl", quiet=False)
if not os.path.exists("./pretrained_models/dmvfn_vimeo.pkl"):
gdown.download('https://drive.google.com/uc?id=14_xQ3Yl3mO89hr28hbcQW3h63lLrcYY0', "./pretrained_models/dmvfn_vimeo.pkl", quiet=False)
テストデータのセットアップ
ここでは、モデルに入力するテストデータをダウンロードしていきます。
%cd /content/DMVFN
os.makedirs('./data/cityscapes/test', exist_ok=True)
if not os.path.exists("./data/cityscapes/test/test.zip"):
gdown.download('https://drive.google.com/uc?id=10zCt-uZFOqgF3tpdhluRqbs-4aScvGR4&confirm=t', "./data/cityscapes/test/test.zip", quiet=False)
%cd /content/DMVFN/data/cityscapes/test
!unzip -q ./test.zip
%cd /content/DMVFN
Inference
それでは、Video Predictionを実行していきます。
モデルを選択し、ランダムシードを固定します。
%cd /content/DMVFN
pretrained_weights = 'city' #@param ['city', 'kitti', 'vimeo']
model_path = './pretrained_models/dmvfn_' + pretrained_weights + '.pkl'
seed = 12
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True
推論を実行します。
# input image dir
image_dir = "./data/cityscapes/test/000000"
image_list = glob.glob(image_dir + "/*.png")
image_list.sort()
# output dir
out_dir = "./output"
os.makedirs('./output', exist_ok=True)
# load model
model = Model(load_path=model_path, training=False)
pred_num = 3 # max->len(image_list) - 2
with torch.no_grad():
for i in range(pred_num):
if i == 0:
# load image
cvimg_0 = cv2.imread(image_list[0])
cvimg_1 = cv2.imread(image_list[1])
else:
cvimg_0 = cvimg_1
cvimg_1 = pred
# preprocess image
img_0 = cvimg_0.transpose(2, 0, 1).astype('float32')
img_1 = cvimg_1.transpose(2, 0, 1).astype('float32')
img = torch.cat([torch.tensor(img_0),torch.tensor(img_1)], dim=0)
img = img.unsqueeze(0).unsqueeze(0).to(device, non_blocking=True) # NCHW
img = img.to(device, non_blocking=True) / 255.
# inference
pred = model.eval(img, 'single_test') # 1CHW
# post process
pred = np.array(pred.cpu().squeeze() * 255).transpose(1, 2, 0) # CHW -> HWC
# save result image
cv2.imwrite(
os.path.join(out_dir, f'pred_{i:06}.png'), pred)
推論結果を表示して確認します。
show_num = pred_num + 2
# show original images
fig = plt.figure(figsize=(15, 10))
for i in range(show_num):
ax = fig.add_subplot(1, show_num, i+1)
plt.title(os.path.basename(image_list[i]), fontsize=16)
ax.axis('off')
ax.imshow( Image.open(image_list[i]) )
plt.show()
# show predict images
pred_list = glob.glob(os.path.join(out_dir, '*.png'))
pred_list.sort()
fig = plt.figure(figsize=(15, 10))
for i in range(show_num):
if i < 2:
image_path = image_list[i]
else:
image_path = pred_list[i-2]
ax = fig.add_subplot(1, show_num, i+1)
plt.title(os.path.basename(image_path), fontsize=16)
ax.axis('off')
ax.imshow( Image.open(image_path) )
plt.show()
pred_***が予測されたフレーム画像です。
正解データより若干進みが遅いですが、車が進んでいるフレームが予測されています。
まとめ
本記事では、DMVFNを用いたVideo Predictionをご紹介しました。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1.
論文 - A Dynamic Multi-Scale Voxel Flow Network for Video Prediction
2. GitHub - megvii-research/CVPR2023-DMVFN
0 件のコメント :
コメントを投稿