[CLIP-ODS] 任意のキーワードに対応した物体検出 [画像検索]

2022年1月26日水曜日

Artificial Intelligence

本記事では、CLIP Object Detection and Segmentationと呼ばれる機械学習手法を使って、画像中から入力したキーワードに適した物体を検出する方法を紹介します。

アイキャッチ

CLIP Object Detection and Segmentation(CLIP-ODS)

概要

CLIP ODSは、事前学習済みの視覚言語モデルCLIPを用いて、ゼロショットセマンティックセグメンテーションを実現する手法です。

CLIPとは、ImageToTextタスクを実現する手法で、画像を認識し適切なキャプション(画像の説明)を生成する技術です。

下図のように入力された画像の特徴から、対応する画像の説明をテキストで出力します。

CLIP概要図

CLIPは、他にもTextToImageタスクと呼ばれるテキストから画像を生成する技術などでも利用されており、CLIPをベースとした様々な手法が提案されています。

TextToImageの一例であるFuseDreamを下記の記事で紹介しているので宜しければご参照ください。

CLIP ODSは、上記CLIPの技術を応用し入力されたキーワードに応じて、事前学習していないカテゴリの画像(ゼロショット)をセマンティックセグメンテーションします。

セマンティックセグメンテーションとは、画像中の全てのピクセルをクラスに分類することを指し、下図の例では「wheel of bicycle」や「face of dog」など入力テキストに対応したピクセル部分が図示されています。
セマンティックセグメンテーション結果サンプル
出典: https://github.com/shonenkov/CLIP-ODS

また、CLIP ODSは以下のような構成で実現されています。

CLIP-ODSアーキテクチャ
出典: https://arxiv.org/pdf/2112.14757.pdf

簡単な理解としては、以下のとおりです。

  1. 入力画像からマスク画像(上図白黒画像)を生成
    この時、画像中のオブジェクトをピクセル単位で全て分類
  2. CLIPを用いてマスク画像をCLASSに分類
  3. 入力されたテキストに近しいCLASSのマスク画像を選定
  4. 入力テキストに応じたセマンティックセグメンテーションを出力

デモ(Colaboratory)

それでは早速CLIP-ODSをGoogle Colaboratoryで動かしていきます。

なお、これから紹介するソースコードは全てこちらのGitHubに掲載しております。以下のボタンをクリックするとColaboratoryを開くことも可能です。

Open In Colab

また、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

まず、Pytorchや、その他ライブラリをインストールします。

import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
elif "11." in CUDA_version:
    torch_version_suffix = "+cu110"
else:
  torch_version_suffix = "+cpu"

print("CUDA version:", CUDA_version, " Suffix:", torch_version_suffix)

上記ではColaboratoryにインストールされているCudaドライバーのバージョンを確認しています。
以降ではCudaドライバーに対応したPytorchをインストールします。

!pip install --upgrade pip
!pip uninstall torch -y
!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

!pip install ftfy regex
!pip install clip-ods==0.0.1rc2

モデルのロード

次に、学習済みモデルをロードします。学習済みモデルはclip_odsライブラリを介して自動でダウンロードされます。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = torch.device(device)
base_model = 'RN50x4' #@param ["ViT-B/32","RN50","RN101","RN50x4"]
model, preprocess = clip.load("RN50x4", device=device)
clip_detector = CLIPDetectorV1(model, preprocess, device)

物体検出用画像の準備

uploadを選択するとPCからアップロードした画像を物体検出に使用します。
sampleを選択するとCLIP ODSライブラリ提供元から提供された画像をダウンロードして使用します。

本記事では、以下の画像を使用します。

物体検出対象画像
%cd /content/
!mkdir images
%cd /content/images

image_type ='upload' #@param ['sample', 'upload']
# sample選択時
if image_type == "sample":
  for google_drive_file_id in ['1nMPyWquE7U7_fuh0Rk4ZGgeWAtCFEqi8','1bsaZ1FSAfMByWeT4Ftr5_J5YaWARUu-x', '1lwhhDDBGztqxW4AVYqjjwKMpda19Vpbu']:
    gdown.download(f'https://drive.google.com/uc?id={google_drive_file_id}', './', quiet=True)
  image_path = "/content/images/example11.jpg"
# upload選択時
else:
  uploaded = files.upload()
  uploaded = list(uploaded.keys())
  file_name = uploaded[0]
  image_path = "/content/images/" + file_name

# 画像の表示
image = Image.open(image_path).convert("RGB")
plt.figure(figsize=(6, 6))
plt.imshow(image)

セマンティックセグメンテーション

入力画像に対してセマンティックセグメンテーションを行います。
本記事では、「watch」を入力し、画像中から"時計"を検出してみます。

coords, masks = clip_detector.get_coords_and_masks(Image.open(image_path))
anchor_features = clip_detector.get_anchor_features(Image.open(image_path), coords)


text = 'watch' #@param {type:"string"}

result = clip_detector.detect_by_text(
    texts=[text],
    img=Image.open(image_path),
    coords=coords, # detect coords
    masks=masks, # Segmentation mask
    anchor_features=anchor_features,
    skip_box_thr=0.7
)

img = Image.open(image_path)
colour = (0,255,0)

img = clip_detector.draw(
    img, 
    result,
    label=text,
    colour=colour,
    font_colour=colour,
    font_scale=0.5, 
    font_thickness=1,
)

plt.figure(num=None, figsize=(8, 8), dpi=120, facecolor='w', edgecolor='k')
plt.imshow(img)

セマンティックセグメンテーションの結果は以下の通りです。

セマンティックセグメンテーション結果

物体検出

セグメンテーションマスクをオフすることで、矩形のみを表示することもできます。

%%time

text = 'watch' #@param {type:"string"}

result = clip_detector.detect_by_text(
    texts=[text],
    img=Image.open(image_path),
    coords=coords,
    #masks=masks, # Segmentation maskなし
    anchor_features=anchor_features,
    skip_box_thr=0.7
)

img = Image.open(image_path)
colour = (0,255,0)

img = clip_detector.draw(
    img, 
    result,
    label=text,
    colour=colour,
    font_colour=colour,
    font_scale=0.5, 
    font_thickness=1,
)

plt.figure(num=None, figsize=(8, 8), dpi=120, facecolor='w', edgecolor='k')
plt.imshow(img)

物体検出結果は以下の通りです。

物体検出結果

検出精度にややずれがありますね。

下記は2022年にMeta(旧Facebook)が発表したDeticです。検出精度は最新なだけありかなりのものです。

まとめ

本記事では、CLIP-ODSと呼ばれる機械学習手法を用いて、キーワードに応じたセマンティックセグメンテーションを行いました。
この技術の精度が向上すれば、膨大な写真の中からキーワードに対応する画像のみを切り出して出力するなんてことが可能になるかもしれません。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
機械学習について学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - A Simple Baseline for Zero-shot Semantic Segmentation with Pre-trained Vision-language Model

2. GitHub - shonenkov/CLIP-ODS

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology