[ShAPO] AIで6Dポーズ推定する

2022年11月17日木曜日

Artificial Intelligence

本記事では、ShAPOと呼ばれる機械学習手法を用いて、6Dポーズ推定などを行う方法をご紹介します。

アイキャッチ
出典: zubair-irshad/shapo

ShAPO

概要

Multi-Object Shape, Appearance, and Pose Optimization(以下、ShAPO)は、単一のRGB-Dから3Dオブジェクトに関連するいくつかのタスクを実現する技術です。

ShAPOが実現するタスクは以下の通りです。

  • Joint multi-object detection
  • 3D textured reconstruction
  • 6D object pose estimation
  • 6D object size estimation

ShAPOの特徴は、それぞれのオブジェクトインスタンスのマスクと共に、形状、外観、ポーズの潜在コードを回帰するためのSingle shotパイプラインです。
このパイプラインでは、最初にオブジェクトをそれぞれの形状と外観空間に埋め込むように学習されます。また、octree-basedの微分可能な最適化ステップにより、学習した潜在空間の下で、オブジェクトの形状、ポーズ、及び外観の改善を行っています。

この暗黙的なテクスチャオブジェクト表現により、3Dメッシュを必要とせず、3Dオブジェクトの様々なタスクの再構成において優れた精度を実現しています。

Architecture
出典: ShAPO: Implicit Representations for Multi-Object Shape, Appearance, and Pose Optimization

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、6Dポーズ推定などを行っていきます。

スポンサーリンク

デモ(Colaboratory)

それでは、実際に動かしながら6D pose estimationを行っていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

  1. %cd /content
  2.  
  3. !git clone https://github.com/zubair-irshad/shapo.git
  4.  
  5. # Commits on Nov 8, 2022
  6. %cd /content/shapo
  7. !git checkout 9269bf4ba74bc86022449ebc598cd79b9c2630c4

次にライブラリのインストールとテストデータのダウンロードを行います。

  1. %cd /content/shapo
  2.  
  3. !pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
  4. !pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html
  5.  
  6. !wget https://www.dropbox.com/s/cvqyhr67zpxyq36/test_subset.tar.xz?dl=1 -O test_subset.tar.xz
  7. !tar -xvf test_subset.tar.xz
  8.  
  9. !wget https://www.dropbox.com/s/929kz7zuxw8jajy/sdf_rgb_pretrained.tar.xz?dl=1 -O sdf_rgb_pretrained.tar.xz
  10. ! tar -xvf sdf_rgb_pretrained.tar.xz
  11. !wget https://www.dropbox.com/s/nrsl67ir6fml9ro/ckpts.tar.xz?dl=1 -O ckpts.tar.xz
  12. !tar -xvf ckpts.tar.xz
  13.  
  14. !mkdir test_data
  15. !mv test_subset/* test_data
  16. !mv sdf_rgb_pretrained test_data

最後にライブラリをインポートします。

  1. %cd /content/shapo
  2.  
  3. import argparse
  4. import pathlib
  5. import cv2
  6. import numpy as np
  7. import torch
  8. import torch.nn.functional as F
  9. import open3d as o3d
  10. import matplotlib.pyplot as plt
  11. import os
  12. import time
  13. import pytorch_lightning as pl
  14. import _pickle as cPickle
  15. import os, sys
  16.  
  17. from simnet.lib.net import common
  18. from simnet.lib import camera
  19. from simnet.lib.net.panoptic_trainer import PanopticModel
  20. from utils.nocs_utils import load_img_NOCS, create_input_norm
  21. from utils.viz_utils import depth2inv, viz_inv_depth
  22. from utils.transform_utils import get_gt_pointclouds, transform_coordinates_3d, calculate_2d_projections
  23. from utils.transform_utils import project, get_pc_absposes, transform_pcd_to_canonical
  24. from utils.viz_utils import save_projected_points, draw_bboxes, line_set_mesh, display_gird, draw_geometries, show_projected_points
  25. from sdf_latent_codes.get_surface_pointcloud import get_surface_pointclouds_octgrid_viz, get_surface_pointclouds
  26. from sdf_latent_codes.get_rgb import get_rgbnet, get_rgb_from_rgbnet
  27. from sdf_latent_codes.get_surface_pointcloud import get_sdfnet
  28. from sdf_latent_codes.get_rgb import get_rgbnet
  29. from utils.transform_utils import get_abs_pose_vector_from_matrix, get_abs_pose_from_vector
  30. from utils.nocs_utils import get_masks_out, get_aligned_masks_segout, get_masked_textured_pointclouds
  31. from opt.optimization_all import Optimizer
  32.  
  33. device = 'cuda' if torch.cuda.is_available() else "cpu"
  34. print("using device is", device)

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

それでは、まず論文発表元が公開する学習済みモデルをロードします。

  1. %cd /content
  2.  
  3. # load config
  4. sys.argv = ['', '@shapo/configs/net_config.txt']
  5. parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
  6. common.add_train_args(parser)
  7. app_group = parser.add_argument_group('app')
  8. app_group.add_argument('--app_output', default='inference', type=str)
  9. app_group.add_argument('--result_name', default='shapo_inference', type=str)
  10. app_group.add_argument('--data_dir', default='shapo/test_data', type=str)
  11.  
  12. # load model
  13. hparams = parser.parse_args()
  14. min_confidence = 0.50
  15. hparams.checkpoint = 'shapo/ckpts/shapo_real.ckpt'
  16. model = PanopticModel(hparams, 0, None, None)
  17. model.eval()
  18. if device == "cuda":
  19. model.cuda()
  20.  
  21. # set dataset
  22. data_path = open(os.path.join(hparams.data_dir, 'Real', 'test_list_subset.txt')).read().splitlines()
  23. _CAMERA = camera.NOCS_Real()
  24. sdf_pretrained_dir = os.path.join(hparams.data_dir, 'sdf_rgb_pretrained')
  25. rgb_model_dir = os.path.join(hparams.data_dir, 'sdf_rgb_pretrained', 'rgb_net_weights')

推論

ロードしたモデルを使って、6Dポーズ推定などを行います。

  1. #num from 0 to 3 (small subset of data)
  2. num = 1
  3. img_full_path = os.path.join(hparams.data_dir, 'Real', data_path[num])
  4. img_vis = cv2.imread(img_full_path + '_color.png')
  5.  
  6. left_linear, depth, actual_depth = load_img_NOCS(img_full_path + '_color.png' , img_full_path + '_depth.png')
  7. input = create_input_norm(left_linear, depth)[None, :, :, :]
  8. input = input.to(torch.device(device))
  9.  
  10. with torch.no_grad():
  11. seg_output, _, _ , pose_output = model.forward(input)
  12. _, _, _ , pose_output = model.forward(input)
  13. shape_emb_outputs, appearance_emb_outputs, abs_pose_outputs, peak_output, scores_out, output_indices = pose_output.compute_shape_pose_and_appearance(min_confidence,is_target = False)

初めに、顕著性マップと深度画像を出力します。

  1. display_gird(img_vis, depth, peak_output)
depth result

続いて、点群を用いて、3Dオブジェクトを表現します。

  1. rotated_pcds = []
  2. points_2d = []
  3. box_obb = []
  4. axes = []
  5. lod = 7 # Choose from LOD 3-7 here, going higher means more memory and finer details
  6.  
  7. # Here we visualize the output of our network
  8. for j in range(len(shape_emb_outputs)):
  9. shape_emb = shape_emb_outputs[j]
  10. # appearance_emb = appearance_emb_putputs[j]
  11. appearance_emb = appearance_emb_outputs[j]
  12. is_oct_grid = True
  13. if is_oct_grid:
  14. # pcd_dsdf_actual = get_surface_pointclouds_octgrid_sparse(shape_emb, sdf_latent_code_dir = sdf_pretrained_dir, lods=[2,3,4,5,6])
  15. pcd_dsdf, nrm_dsdf = get_surface_pointclouds_octgrid_viz(shape_emb, lod=lod, sdf_latent_code_dir=sdf_pretrained_dir)
  16. else:
  17. pcd_dsdf = get_surface_pointclouds(shape_emb)
  18. rgbnet = get_rgbnet(rgb_model_dir)
  19. pred_rgb = get_rgb_from_rgbnet(shape_emb, pcd_dsdf, appearance_emb, rgbnet)
  20. rotated_pc, rotated_box, _ = get_pc_absposes(abs_pose_outputs[j], pcd_dsdf)
  21. pcd = o3d.geometry.PointCloud()
  22. pcd.points = o3d.utility.Vector3dVector(np.copy(rotated_pc))
  23. pcd.colors = o3d.utility.Vector3dVector(pred_rgb.detach().cpu().numpy())
  24. pcd.normals = o3d.utility.Vector3dVector(nrm_dsdf)
  25. rotated_pcds.append(pcd)
  26. cylinder_segments = line_set_mesh(rotated_box)
  27. # draw 3D bounding boxes around the object
  28. for k in range(len(cylinder_segments)):
  29. rotated_pcds.append(cylinder_segments[k])
  30.  
  31. # draw 3D coordinate frames around each object
  32. mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
  33. T = abs_pose_outputs[j].camera_T_object
  34. mesh_t = mesh_frame.transform(T)
  35. rotated_pcds.append(mesh_t)
  36. points_mesh = camera.convert_points_to_homopoints(rotated_pc.T)
  37. points_2d.append(project(_CAMERA.K_matrix, points_mesh).T)
  38. #2D output
  39. points_obb = camera.convert_points_to_homopoints(np.array(rotated_box).T)
  40. box_obb.append(project(_CAMERA.K_matrix, points_obb).T)
  41. xyz_axis = 0.3*np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]).transpose()
  42. sRT = abs_pose_outputs[j].camera_T_object @ abs_pose_outputs[j].scale_matrix
  43. transformed_axes = transform_coordinates_3d(xyz_axis, sRT)
  44. axes.append(calculate_2d_projections(transformed_axes, _CAMERA.K_matrix[:3,:3]))
  45. draw_geometries(rotated_pcds)

予測結果は以下の通りです。

3d shape

最後に6Dポーズの出力結果をプロットします。

  1. color_img = np.copy(img_vis)
  2. projected_points_img = show_projected_points(color_img, points_2d)
  3. colors_box = [(63, 237, 234)]
  4. im = np.array(np.copy(img_vis)).copy()
  5. for k in range(len(colors_box)):
  6. for points_2d, axis in zip(box_obb, axes):
  7. points_2d = np.array(points_2d)
  8. im = draw_bboxes(im, points_2d, axis, colors_box[k])
  9.  
  10. plt.gca().invert_yaxis()
  11. plt.axis('off')
  12. plt.imshow(im[...,::-1])
  13. plt.show()

出力結果は以下の通りです。

draw
6d pose

まとめ

本記事では、ShAPOをを用いて6Dポーズ推定を行う方法をご紹介しました。
3Dメッシュ不要のため学習データ生成の負荷軽減が想定されますね。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

スポンサーリンク

参考文献

1.  論文 - ShAPO: Implicit Representations for Multi-Object Shape, Appearance, and Pose Optimization

2. GitHub - zubair-irshad/shapo

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology