[ShAPO] AIで6Dポーズ推定する

2022年11月17日木曜日

Artificial Intelligence

本記事では、ShAPOと呼ばれる機械学習手法を用いて、6Dポーズ推定などを行う方法をご紹介します。

アイキャッチ
出典: zubair-irshad/shapo

ShAPO

概要

Multi-Object Shape, Appearance, and Pose Optimization(以下、ShAPO)は、単一のRGB-Dから3Dオブジェクトに関連するいくつかのタスクを実現する技術です。

ShAPOが実現するタスクは以下の通りです。

  • Joint multi-object detection
  • 3D textured reconstruction
  • 6D object pose estimation
  • 6D object size estimation

ShAPOの特徴は、それぞれのオブジェクトインスタンスのマスクと共に、形状、外観、ポーズの潜在コードを回帰するためのSingle shotパイプラインです。
このパイプラインでは、最初にオブジェクトをそれぞれの形状と外観空間に埋め込むように学習されます。また、octree-basedの微分可能な最適化ステップにより、学習した潜在空間の下で、オブジェクトの形状、ポーズ、及び外観の改善を行っています。

この暗黙的なテクスチャオブジェクト表現により、3Dメッシュを必要とせず、3Dオブジェクトの様々なタスクの再構成において優れた精度を実現しています。

Architecture
出典: ShAPO: Implicit Representations for Multi-Object Shape, Appearance, and Pose Optimization

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、6Dポーズ推定などを行っていきます。

デモ(Colaboratory)

それでは、実際に動かしながら6D pose estimationを行っていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/zubair-irshad/shapo.git

# Commits on Nov 8, 2022
%cd /content/shapo
!git checkout 9269bf4ba74bc86022449ebc598cd79b9c2630c4

次にライブラリのインストールとテストデータのダウンロードを行います。

%cd /content/shapo

!pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html

!wget https://www.dropbox.com/s/cvqyhr67zpxyq36/test_subset.tar.xz?dl=1 -O test_subset.tar.xz
!tar -xvf test_subset.tar.xz

!wget https://www.dropbox.com/s/929kz7zuxw8jajy/sdf_rgb_pretrained.tar.xz?dl=1 -O sdf_rgb_pretrained.tar.xz
! tar -xvf sdf_rgb_pretrained.tar.xz
!wget https://www.dropbox.com/s/nrsl67ir6fml9ro/ckpts.tar.xz?dl=1 -O ckpts.tar.xz
!tar -xvf ckpts.tar.xz

!mkdir test_data
!mv test_subset/* test_data
!mv sdf_rgb_pretrained test_data

最後にライブラリをインポートします。

%cd /content/shapo

import argparse
import pathlib
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import open3d as o3d
import matplotlib.pyplot as plt
import os
import time
import pytorch_lightning as pl
import _pickle as cPickle
import os, sys

from simnet.lib.net import common
from simnet.lib import camera
from simnet.lib.net.panoptic_trainer import PanopticModel
from utils.nocs_utils import load_img_NOCS, create_input_norm
from utils.viz_utils import depth2inv, viz_inv_depth
from utils.transform_utils import get_gt_pointclouds, transform_coordinates_3d, calculate_2d_projections
from utils.transform_utils import project, get_pc_absposes, transform_pcd_to_canonical
from utils.viz_utils import save_projected_points, draw_bboxes, line_set_mesh, display_gird, draw_geometries, show_projected_points
from sdf_latent_codes.get_surface_pointcloud import get_surface_pointclouds_octgrid_viz, get_surface_pointclouds
from sdf_latent_codes.get_rgb import get_rgbnet, get_rgb_from_rgbnet
from sdf_latent_codes.get_surface_pointcloud import get_sdfnet
from sdf_latent_codes.get_rgb import get_rgbnet
from utils.transform_utils import get_abs_pose_vector_from_matrix, get_abs_pose_from_vector
from utils.nocs_utils import get_masks_out, get_aligned_masks_segout, get_masked_textured_pointclouds
from opt.optimization_all import Optimizer

device = 'cuda' if torch.cuda.is_available() else "cpu"
print("using device is", device)

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

それでは、まず論文発表元が公開する学習済みモデルをロードします。

%cd /content

# load config
sys.argv = ['', '@shapo/configs/net_config.txt']
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
common.add_train_args(parser)
app_group = parser.add_argument_group('app')
app_group.add_argument('--app_output', default='inference', type=str)
app_group.add_argument('--result_name', default='shapo_inference', type=str)
app_group.add_argument('--data_dir', default='shapo/test_data', type=str)

# load model
hparams = parser.parse_args()
min_confidence = 0.50
hparams.checkpoint = 'shapo/ckpts/shapo_real.ckpt'
model = PanopticModel(hparams, 0, None, None)
model.eval()
if device == "cuda":
    model.cuda()

# set dataset 
data_path = open(os.path.join(hparams.data_dir, 'Real', 'test_list_subset.txt')).read().splitlines()
_CAMERA = camera.NOCS_Real()
sdf_pretrained_dir = os.path.join(hparams.data_dir, 'sdf_rgb_pretrained')
rgb_model_dir = os.path.join(hparams.data_dir, 'sdf_rgb_pretrained', 'rgb_net_weights')

推論

ロードしたモデルを使って、6Dポーズ推定などを行います。

#num from 0 to 3 (small subset of data)
num = 1
img_full_path = os.path.join(hparams.data_dir, 'Real', data_path[num])
img_vis = cv2.imread(img_full_path + '_color.png')

left_linear, depth, actual_depth = load_img_NOCS(img_full_path + '_color.png' , img_full_path + '_depth.png')
input = create_input_norm(left_linear, depth)[None, :, :, :]
    
input = input.to(torch.device(device))

with torch.no_grad():
    seg_output, _, _ , pose_output = model.forward(input)
    _, _, _ , pose_output = model.forward(input)
    shape_emb_outputs, appearance_emb_outputs, abs_pose_outputs, peak_output, scores_out, output_indices = pose_output.compute_shape_pose_and_appearance(min_confidence,is_target = False)

初めに、顕著性マップと深度画像を出力します。

display_gird(img_vis, depth, peak_output)
depth result

続いて、点群を用いて、3Dオブジェクトを表現します。

rotated_pcds = []
points_2d = []
box_obb = []
axes = [] 
lod = 7 # Choose from LOD 3-7 here, going higher means more memory and finer details

# Here we visualize the output of our network
for j in range(len(shape_emb_outputs)):
    shape_emb = shape_emb_outputs[j]
    # appearance_emb = appearance_emb_putputs[j]
    appearance_emb = appearance_emb_outputs[j]
    is_oct_grid = True
    if is_oct_grid:
        # pcd_dsdf_actual = get_surface_pointclouds_octgrid_sparse(shape_emb, sdf_latent_code_dir = sdf_pretrained_dir, lods=[2,3,4,5,6])
        pcd_dsdf, nrm_dsdf = get_surface_pointclouds_octgrid_viz(shape_emb, lod=lod, sdf_latent_code_dir=sdf_pretrained_dir)
    else:
        pcd_dsdf = get_surface_pointclouds(shape_emb)
    rgbnet = get_rgbnet(rgb_model_dir)
    pred_rgb = get_rgb_from_rgbnet(shape_emb, pcd_dsdf, appearance_emb, rgbnet)
    rotated_pc, rotated_box, _ = get_pc_absposes(abs_pose_outputs[j], pcd_dsdf)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.copy(rotated_pc))
    pcd.colors = o3d.utility.Vector3dVector(pred_rgb.detach().cpu().numpy())
    pcd.normals = o3d.utility.Vector3dVector(nrm_dsdf)
    rotated_pcds.append(pcd)
    
    cylinder_segments = line_set_mesh(rotated_box)
    # draw 3D bounding boxes around the object
    for k in range(len(cylinder_segments)):
      rotated_pcds.append(cylinder_segments[k])

    # draw 3D coordinate frames around each object
    mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
    T = abs_pose_outputs[j].camera_T_object
    mesh_t = mesh_frame.transform(T)
    rotated_pcds.append(mesh_t)
    
    points_mesh = camera.convert_points_to_homopoints(rotated_pc.T)
    points_2d.append(project(_CAMERA.K_matrix, points_mesh).T)
    #2D output
    points_obb = camera.convert_points_to_homopoints(np.array(rotated_box).T)
    box_obb.append(project(_CAMERA.K_matrix, points_obb).T)
    xyz_axis = 0.3*np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]).transpose()
    sRT = abs_pose_outputs[j].camera_T_object @ abs_pose_outputs[j].scale_matrix
    transformed_axes = transform_coordinates_3d(xyz_axis, sRT)
    axes.append(calculate_2d_projections(transformed_axes, _CAMERA.K_matrix[:3,:3]))
draw_geometries(rotated_pcds)

予測結果は以下の通りです。

3d shape

最後に6Dポーズの出力結果をプロットします。

color_img = np.copy(img_vis) 
projected_points_img = show_projected_points(color_img, points_2d)
colors_box = [(63, 237, 234)]
im = np.array(np.copy(img_vis)).copy()
for k in range(len(colors_box)):
    for points_2d, axis in zip(box_obb, axes):
        points_2d = np.array(points_2d)
        im = draw_bboxes(im, points_2d, axis, colors_box[k])

plt.gca().invert_yaxis()
plt.axis('off')
plt.imshow(im[...,::-1])
plt.show()

出力結果は以下の通りです。

draw
6d pose

まとめ

本記事では、ShAPOをを用いて6Dポーズ推定を行う方法をご紹介しました。
3Dメッシュ不要のため学習データ生成の負荷軽減が想定されますね。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - ShAPO: Implicit Representations for Multi-Object Shape, Appearance, and Pose Optimization

2. GitHub - zubair-irshad/shapo

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology