[ClipCap] 機械学習で画像からキャプション(説明)を生成する [Python]

2021年11月27日土曜日

Artificial Intelligence

本記事では、ClipCapと呼ばれる機械学習手法を用いて、Pythonで画像のキャプションを生成する方法を紹介します。

アイキャッチ

ClipCapとは

ClipCapは、画像のキャプションタスクをシンプル、高速、軽量に実現した機械学習手法です。

従来、画像キャプション生成は複雑なタスクであり事前にトレーニングされた検出ネットワークが必要でした。

ClipCapでは、画像キャプションのみで画像キャプショントレーニングが可能であるため、モデルのトレーニング時間は従来の方法より高速でありながら、 最先端技術と同様の精度を示しています。

また、ClipCapは事前に多くの画像を学習しているCLIPモデルを使用しており、追加のsupervisionなしに任意の画像のセマンティックエンコーディングの生成を実現しています。

この手法の重要なアイデアはトレーニング済みのCLIPモデルと言語モデルを組み合わせ、単純なマッピングネットワークを使って画像キャプションタスクを実現している点です。

ClipCapの導入手順

セットアップ1: conda環境構築

それでは早速、開発環境にClipCapをセットアップしていきます。
動作確認は下記の環境で行っています。

OS: Ubuntu 18.04.3 LTS
GPU: GeForce GTX 1080

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

それでは、ClipCapをインストールしていきます。他の機械学習環境に影響を与えないためにMinicondaの仮想環境上に構築していきます。Minicondaのインストール手順は公式ドキュメントをご参照ください。

conda環境構築手順を下記に記載します。

# codeをgit clone
$ git clone https://github.com/rmokady/CLIP_prefix_caption.git

# environment.yamlからconda環境を作成
$ conda env create -f CLIP_prefix_caption/environment.yml
# cloneしたコードを作成したconda環境内に配置する
$ mv CLIP_prefix_caption/ clip_prefix_caption/
 
# 作成した環境をアクティベート
$ conda activate clip_prefix_caption
$ cd clip_prefix_caption/CLIP_prefix_caption
 
# Proxy配下の場合はインストール前に下記を設定
$ export http_proxy="http://"username":"password"@proxy:port"
$ export https_proxy="http://"username":"password"@proxy:port"

# ライブラリをインストール
$ pip3 install transformers
$ pip3 install git+https://github.com/openai/CLIP.git

# torchの動作確認
$ python
>>> import torch
>>> print(torch.cuda.is_available())
True
>>> print(torch.cuda.device_count())
2

以上でconda環境構築は完了です。

セットアップ2: 学習済みモデルとテストデータ準備

続いて学習済みモデルとテストデータを準備していきます。
本記事のテストデータはぱくたそ様の以下の2画像を使用します。

test用画像群
# ディレクトリ確認
$ pwd
clip_prefix_caption/CLIP_prefix_caption

# 学習済みモデル格納用ディレクトリ作成
$ mkdir models


# 学習済みモデルをダウンロード
# https://drive.google.com/file/d/14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT/view?usp=sharing
# https://drive.google.com/file/d/1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX/view?usp=sharing
# modelsに学習済みモデルを格納

# テスト用画像格納用ディレクトリ作成
$ mkdir inputs
# inputsにテスト用画像を格納

画像キャプションの実行

最後に、画像キャプションを実行します。
Githubに公開されているCLIP_prefix_caption/notebooks/clip_prefix_captioning_inference.ipynbは、Google Colaboratoryで動作することを前提としているためこちらを参考に下記CLIP_prefix_caption/demo.pyを作成します。
なお、下記と同様のコードをGitHubに公開しています。
https://github.com/kaz12tech/CLIP_prefix_caption/blob/main/demo.py

import argparse
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import clip
import glob
import skimage.io as io
import PIL.Image
from tqdm import tqdm, trange

N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]

D = torch.device
CPU = torch.device('cpu')

def get_device(device_id: int) -> D:
    if not torch.cuda.is_available():
        return CPU
    device_id = min(torch.cuda.device_count() - 1, device_id)
    return torch.device(f'cuda:{device_id}')

CUDA = get_device

#@title Model

class MLP(nn.Module):

    def forward(self, x: T) -> T:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) -1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)


class ClipCaptionModel(nn.Module):

    #@functools.lru_cache #FIXME
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        #print(embedding_text.size()) #torch.Size([5, 67, 768])
        #print(prefix_projections.size()) #torch.Size([5, 1, 768])
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))


class ClipCaptionPrefix(ClipCaptionModel):

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self

def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in trange(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                                                    ..., :-1
                                                    ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path", type=str, default='./inputs/')
    parser.add_argument("--model_path", type=str, default='./models/')

    args = parser.parse_args()

    return (args)

def main():
    args = get_args()
    print('img_dir', args.img_path)
    print('model_dir', args.model_path)

    model_path = os.path.join(args.model_path, 'conceptual_weights.pt')

    device = CUDA(0) if torch.cuda.is_available() else "cpu"
    clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    prefix_length = 10
    model = ClipCaptionModel(prefix_length)
    model.load_state_dict(torch.load(model_path, map_location=CPU))
    model = model.eval() 
    device = CUDA(0) if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    files = glob.glob(args.img_path + "/*")
    print(files)

    for file in files:
        image = io.imread(file)
        pil_image = PIL.Image.fromarray(image)

        image = preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
        generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)

        print('\n')
        print(generated_text_prefix)

if __name__ == '__main__':
    main()

最後に作成したdemo.pyを下記コマンドで実行します。

$ pwd
clip_prefix_caption/CLIP_prefix_caption
$ python demo.py --img_path ./inputs/ --model_path ./models/

上記コマンドを実行するとinputsに格納した画像のキャプションが出力されます。

キャプション出力結果

画像に対して、十分に意味の通るキャプションが生成されていますね。

まとめ

本記事では、ClipCapで画像からキャプションを生成する方法を紹介しました。

この技術には、Transformerによる自然言語処理や画像検出など様々なタスクが組み込まれていますが、 これらを意識せずに実現したいタスクのみを実現できるモジュラー性が非常に優秀です。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1. 論文 - ClipCap: CLIP Prefix for Image Captioning

2. GitHub - rmokady/CLIP_prefix_caption

AIエンジニア向けフリーランスならここがおすすめです

まずは無料会員登録

プロフィール

自分の写真
製造業に勤務する傍ら、日々AIの技術動向を調査しブログにアウトプットしています。 AIに関するご相談やお仕事のご依頼はブログのお問い合わせフォームか以下のアドレスまでお気軽にお問い合わせください。 bhupb13511@yahoo.co.jp

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology