本記事では、ClipCapと呼ばれる機械学習手法を用いて、Pythonで画像のキャプションを生成する方法を紹介します。
ClipCapとは
ClipCapは、画像のキャプションタスクをシンプル、高速、軽量に実現した機械学習手法です。
従来、画像キャプション生成は複雑なタスクであり事前にトレーニングされた検出ネットワークが必要でした。
ClipCapでは、画像とキャプションのみで画像キャプショントレーニングが可能であるため、モデルのトレーニング時間は従来の方法より高速でありながら、
最先端技術と同様の精度を示しています。
また、ClipCapは事前に多くの画像を学習しているCLIPモデルを使用しており、追加のsupervisionなしに任意の画像のセマンティックエンコーディングの生成を実現しています。
この手法の重要なアイデアはトレーニング済みのCLIPモデルと言語モデルを組み合わせ、単純なマッピングネットワークを使って画像キャプションタスクを実現している点です。
ClipCapの導入手順
セットアップ1: conda環境構築
それでは早速、開発環境にClipCapをセットアップしていきます。
動作確認は下記の環境で行っています。
OS: Ubuntu 18.04.3 LTS
GPU: GeForce GTX 1080
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
それでは、ClipCapをインストールしていきます。他の機械学習環境に影響を与えないためにMinicondaの仮想環境上に構築していきます。Minicondaのインストール手順は公式ドキュメントをご参照ください。
conda環境構築手順を下記に記載します。
# codeをgit clone
$ git clone https://github.com/rmokady/CLIP_prefix_caption.git
# environment.yamlからconda環境を作成
$ conda env create -f CLIP_prefix_caption/environment.yml
# cloneしたコードを作成したconda環境内に配置する
$ mv CLIP_prefix_caption/ clip_prefix_caption/
# 作成した環境をアクティベート
$ conda activate clip_prefix_caption
$ cd clip_prefix_caption/CLIP_prefix_caption
# Proxy配下の場合はインストール前に下記を設定
$ export http_proxy="http://"username":"password"@proxy:port"
$ export https_proxy="http://"username":"password"@proxy:port"
# ライブラリをインストール
$ pip3 install transformers
$ pip3 install git+https://github.com/openai/CLIP.git
# torchの動作確認
$ python
>>> import torch
>>> print(torch.cuda.is_available())
True
>>> print(torch.cuda.device_count())
2
以上でconda環境構築は完了です。
セットアップ2: 学習済みモデルとテストデータ準備
続いて学習済みモデルとテストデータを準備していきます。
本記事のテストデータはぱくたそ様の以下の2画像を使用します。
# ディレクトリ確認
$ pwd
clip_prefix_caption/CLIP_prefix_caption
# 学習済みモデル格納用ディレクトリ作成
$ mkdir models
# 学習済みモデルをダウンロード
# https://drive.google.com/file/d/14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT/view?usp=sharing
# https://drive.google.com/file/d/1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX/view?usp=sharing
# modelsに学習済みモデルを格納
# テスト用画像格納用ディレクトリ作成
$ mkdir inputs
# inputsにテスト用画像を格納
画像キャプションの実行
最後に、画像キャプションを実行します。
Githubに公開されているCLIP_prefix_caption/notebooks/clip_prefix_captioning_inference.ipynb
は、Google
Colaboratoryで動作することを前提としているためこちらを参考に下記CLIP_prefix_caption/demo.py
を作成します。
なお、下記と同様のコードをGitHubに公開しています。
https://github.com/kaz12tech/CLIP_prefix_caption/blob/main/demo.py
import argparse
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import clip
import glob
import skimage.io as io
import PIL.Image
from tqdm import tqdm, trange
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
D = torch.device
CPU = torch.device('cpu')
def get_device(device_id: int) -> D:
if not torch.cuda.is_available():
return CPU
device_id = min(torch.cuda.device_count() - 1, device_id)
return torch.device(f'cuda:{device_id}')
CUDA = get_device
#@title Model
class MLP(nn.Module):
def forward(self, x: T) -> T:
return self.model(x)
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
super(MLP, self).__init__()
layers = []
for i in range(len(sizes) -1):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
if i < len(sizes) - 2:
layers.append(act())
self.model = nn.Sequential(*layers)
class ClipCaptionModel(nn.Module):
#@functools.lru_cache #FIXME
def get_dummy_token(self, batch_size: int, device: D) -> T:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
embedding_text = self.gpt.transformer.wte(tokens)
prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
#print(embedding_text.size()) #torch.Size([5, 67, 768])
#print(prefix_projections.size()) #torch.Size([5, 1, 768])
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
if labels is not None:
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
labels = torch.cat((dummy_token, tokens), dim=1)
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
return out
def __init__(self, prefix_length: int, prefix_size: int = 512):
super(ClipCaptionModel, self).__init__()
self.prefix_length = prefix_length
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
if prefix_length > 10: # not enough memory
self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
else:
self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
class ClipCaptionPrefix(ClipCaptionModel):
def parameters(self, recurse: bool = True):
return self.clip_project.parameters()
def train(self, mode: bool = True):
super(ClipCaptionPrefix, self).train(mode)
self.gpt.eval()
return self
def generate2(
model,
tokenizer,
tokens=None,
prompt=None,
embed=None,
entry_count=1,
entry_length=67, # maximum number of words
top_p=0.8,
temperature=1.,
stop_token: str = '.',
):
model.eval()
generated_num = 0
generated_list = []
stop_token_index = tokenizer.encode(stop_token)[0]
filter_value = -float("Inf")
device = next(model.parameters()).device
with torch.no_grad():
for entry_idx in trange(entry_count):
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = model.gpt.transformer.wte(tokens)
for i in range(entry_length):
outputs = model.gpt(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[:, indices_to_remove] = filter_value
next_token = torch.argmax(logits, -1).unsqueeze(0)
next_token_embed = model.gpt.transformer.wte(next_token)
if tokens is None:
tokens = next_token
else:
tokens = torch.cat((tokens, next_token), dim=1)
generated = torch.cat((generated, next_token_embed), dim=1)
if stop_token_index == next_token.item():
break
output_list = list(tokens.squeeze().cpu().numpy())
output_text = tokenizer.decode(output_list)
generated_list.append(output_text)
return generated_list[0]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--img_path", type=str, default='./inputs/')
parser.add_argument("--model_path", type=str, default='./models/')
args = parser.parse_args()
return (args)
def main():
args = get_args()
print('img_dir', args.img_path)
print('model_dir', args.model_path)
model_path = os.path.join(args.model_path, 'conceptual_weights.pt')
device = CUDA(0) if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load(model_path, map_location=CPU))
model = model.eval()
device = CUDA(0) if torch.cuda.is_available() else "cpu"
model = model.to(device)
files = glob.glob(args.img_path + "/*")
print(files)
for file in files:
image = io.imread(file)
pil_image = PIL.Image.fromarray(image)
image = preprocess(pil_image).unsqueeze(0).to(device)
with torch.no_grad():
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
print('\n')
print(generated_text_prefix)
if __name__ == '__main__':
main()
最後に作成したdemo.py
を下記コマンドで実行します。
$ pwd
clip_prefix_caption/CLIP_prefix_caption
$ python demo.py --img_path ./inputs/ --model_path ./models/
上記コマンドを実行するとinputs
に格納した画像のキャプションが出力されます。
画像に対して、十分に意味の通るキャプションが生成されていますね。
まとめ
本記事では、ClipCapで画像からキャプションを生成する方法を紹介しました。
この技術には、Transformerによる自然言語処理や画像検出など様々なタスクが組み込まれていますが、
これらを意識せずに実現したいタスクのみを実現できるモジュラー性が非常に優秀です。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1. 論文 - ClipCap: CLIP Prefix for Image Captioning
2. GitHub - rmokady/CLIP_prefix_caption
0 件のコメント :
コメントを投稿