[ClipSeg] AIでテキストに応じたImage Segmentation

2022年10月1日土曜日

Artificial Intelligence

本記事では、ClipSegと呼ばれる機械学習手法を用いてテキストに応じたImage segmentationを行う方法をご紹介します。

アイキャッチ

ClipSeg

概要

ClipSegは、任意のテキストプロンプトに基づいた物体を検出するセグメンテーション技術です。

従来のセグメンテーションは検出対象のオブジェクト、および対応したクラスが固定されてトレーニングされます。このため、クラスの追加には再トレーニングが伴うため追加コストが発生します。
ClipSegでは、CLIPをバックボーンとしTransformerベースのデコーダーを拡張することにより、任意のプロンプトに基づいてゼロショットのセグメンテーションを実現しています。

overview

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、任意のプロンプトに基づいたImage Segmentationを行っていきます。

デモ(Colaboratory)

それでは、実際に動かしながらImage Segmentationを動かしていきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content
!git clone https://github.com/timojl/clipseg.git

# Commits on Sep 27, 2022
%cd /content/clipseg
!git checkout 515ca6ec2d066d447240c1dd79f3bbbee685bd29

次にライブラリをインストールします。

!pip install git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1

最後にライブラリをインポートします。

%cd /content/clipseg

import os

import torch
import requests

from models.clipseg import CLIPDensePredT
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt

device = 'cuda' if torch.cuda.is_available() else "cpu"
print("using device is", device)

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

ここでは、論文発表元が公開する事前学習済みモデルをGoogle Colaboratoryにダウンロードします。

%cd /content/clipseg
!mkdir pretrained

if not os.path.exists('pretrained/weights.zip'):
  !wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O pretrained/weights.zip
  !unzip -d pretrained/weights -j pretrained/weights.zip

Image Segmentation using Text

それでは、実際にテキストに応じたセグメンテーションを試してみます。

まずダウンロードしたモデルをロードします。

%cd /content/clipseg

# load model
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval();

model.load_state_dict(torch.load('pretrained/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False);

続いて、入力する画像をセットアップします。

!wget -c https://www.pakutaso.com/shared/img/thumb/smIMGL4174_TP_V4.jpg \
      -O test_01.jpg
      
%cd /content/clipseg

# 画像のロード
input_image = Image.open('test_01.jpg')

# Normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((352, 352)),
])
img = transform(input_image).unsqueeze(0)

input_image

以下の画像をセグメンテーションに使用します。

入力画像

最後にセグメンテーションを実行します。

prompts = ['a lemon', 'a girl', 'wood']
num_of_p = len(prompts)

# predict
with torch.no_grad():
  preds = model(img.repeat(num_of_p,1,1,1), prompts)[0]

# visualize prediction
_, ax = plt.subplots(1, 5, figsize=(15, num_of_p))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(input_image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(num_of_p)];
[ax[i+1].text(0, -15, prompts[i]) for i in range(num_of_p)];

出力結果は以下の通りです。

レモンや女性などうまくセグメンテーションできていそうです。woodに関しては、そもそも人間の目視でも判別しにくいオブジェクトなので木製に見える窓枠を検出しているだけでも十分だと言えそうです。

予測結果

まとめ

本記事では、ClipSegを用いて任意のテキストに応じたImage Segmentationを行う方法をご紹介しました。

自然言語表現によるセグメンテーションにより活用の幅が広がりそうです。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - Image Segmentation Using Text and Image Prompts

2. GitHub - https://github.com/timojl/clipseg

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology