[Text2Light] AIでテキストからパノラマ画像を生成する

2022年9月23日金曜日

Artificial Intelligence

本記事では、Text2Lightと呼ばれる機械学習手法を用いて任意の自然言語表現からパノラマ画像を生成する方法をご紹介します。

アイキャッチ
出典: FrozenBurning/Text2Light

Text2Light

概要

Text2Lightは、ゼロショットテキスト駆動型のHDRパノラマ画像(High Dynamic Range Images)生成技術です。

Text2Lightでは以下の2ステップでHDRIを生成します。

  1. Text-driven LDR Panorama Generation
    このステップでは、初めにCLIPを用いてテキストから低ダイナミックレンジで低解像度のLDRパノラマを生成
  2. Super-resolution Inverse Tone Mapping
    LDRパノラマをパッチに分解し、球状に構造化された潜在空間にマッピングし、解像度とダイナミックレンジを同時にアップスケーリング

以上のステップで構成されたTex2Lightは、ゼロショットで汎用性の高いHDR画像の生成を可能にしています。

Architecture

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、テキストからパノラマ画像を生成していきます。

デモ(Colaboratory)

それでは、実際に動かしながらテキストからパノラマ画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

%cd /content

!git clone https://github.com/FrozenBurning/Text2Light


# Commits on Sep 22, 2022
%cd /content/Text2Light
!git checkout acd216150a76b8bb4e460801e95100d167a6f7fe

次にライブラリをインストールします。

%cd /content/Text2Light

!sudo apt-get install libomp-dev

!pip install ftfy regex tqdm omegaconf pytorch-lightning tensorboardX einops transformers
!pip install kornia
!pip install imageio-ffmpeg
!pip install faiss
!pip install --upgrade gdown

最後にライブラリをインポートします。

import os

import argparse, os, sys, glob
import cv2
import torch
import faiss
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

import clip
from taming.util import instantiate_from_config
from sritmo.global_sritmo import SRiTMO

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

続いて論文発表元が公開している学習済みモデルをGoogle Colaboratoryにダウンロードします。
Google Driveに公開された学習済みモデルをgdownライブラリを使用してダウンロードしています。

%cd /content/Text2Light
!mkdir -p ckpts


ckpt_path = 'ckpts/text2light_released_model/global_sampler_clip/checkpoints/last.ckpt'
if not os.path.exists(ckpt_path):
  !gdown --folder https://drive.google.com/drive/folders/1HKBjC7oQOzrkGFKMQmSh6PySv6AycDS3?usp=sharing \
         -O 'ckpts/text2light_released_model'


# 学習済みモデルのダウンロードで下記エラーが出た場合はこのセルを実行
# Access denied with the following error:

#  	Too many users have viewed or downloaded this file recently. Please
# 	try accessing the file again later. If the file you are trying to
# 	access is particularly large or is shared with many people, it may
# 	take up to 24 hours to be able to view or download the file. If you
# 	still can't access a file after 24 hours, contact your domain
# 	administrator. 

# text2light_released_modelの「ドライブへのショートカットを追加」後
# from google.colab import drive
# drive.mount('/content/drive')

# !rm -rf /content/Text2Light/ckpts/text2light_released_model
# !cp -r /content/drive/MyDrive/text2light_released_model /content/Text2Light/ckpts/

Utility関数定義

続いて画像保存用関数や、モデルのロード用関数等を定義していきます。

# Tensor imgの保存
def save_image(x, path):
  c,h,w = x.shape
  assert c==3
  x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
  s = Image.fromarray(x)
  s.save(path)
  return s

def get_knn(database: np.array, index: faiss.Index, txt_emb, k = 5):
  dist, idx  = index.search(txt_emb, k)
  return database[idx], idx #[bs, k, 512]

# predict
@torch.no_grad()
def text2light(models: dict, prompts, outdir, params: dict):
  # models
  global_sampler = models["gs"]
  local_sampler = models["ls"]
  # params
  batch_size = len(prompts)
  top_k = params["top_k"]
  temperature = params['temperature']
  database = params['data4knn']
  faiss_index = params['index4knn']
  device = params['device']

  # embed input texts
  lan_model, _ = clip.load("ViT-B/32", device=device)
  lan_model.eval()
  text = clip.tokenize(prompts).to(device)
  text_features = lan_model.encode_text(text)
  target_txt_emb = text_features / text_features.norm(dim=-1, keepdim=True)
  cond, _ = get_knn(database, faiss_index, target_txt_emb.cpu().numpy().astype('float32'))
  txt_cond = torch.from_numpy(cond.reshape(batch_size, 5, cond.shape[-1]))
  txt_cond = torch.cat([txt_cond, txt_cond,], dim=-1).to(device)

  # sample holistic condition
  bs = batch_size
  start = 0
  idx = torch.zeros(bs, 1, dtype=int)[:, :start].to(device)
  cshape = [bs, 256, 8, 16]
  sample = True

  print("Generating holistic conditions according to texts...")
  for i in tqdm(range(start, cshape[2]*cshape[3])):
    logits, _ = global_sampler.transformer(idx, embeddings=txt_cond)
    logits = logits[:, -1, :]
    if top_k is not None:
      logits = global_sampler.top_k_logits(logits, top_k)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    if sample:
      ix = torch.multinomial(probs, num_samples=1)
    else:
      _, ix = torch.topk(probs, k=1, dim=-1)
    idx = torch.cat((idx, ix), dim=1)

  xsample_holistic = global_sampler.decode_to_img(idx, cshape)
  for i in range(xsample_holistic.shape[0]):
    holistic_save = save_image(xsample_holistic[i], os.path.join(outdir, "holistic", "holistic_[{}].png".format(prompts[i])))

  print("Synthesizing patches...")
  # synthesize patch by patch according to holistic condition
  h = 512
  w = 1024
  xx, yy = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
  screen_points = np.stack([xx, yy], axis=-1)
  coord = (screen_points * 2 - 1) * np.array([np.pi, np.pi/2])
  spe = torch.from_numpy(coord).to(xsample_holistic).repeat(xsample_holistic.shape[0], 1, 1, 1).permute(0, 3, 1, 2)
  spe = torch.nn.functional.interpolate(spe, scale_factor=1/8, mode="bicubic", recompute_scale_factor=False, align_corners=True)
  spe = local_sampler.embedder(spe.permute(0, 2, 3, 1))
  spe = spe.permute(0, 3, 1, 2)

  _, h_indices = local_sampler.encode_to_h(xsample_holistic)
  cshape = [xsample_holistic.shape[0], 256, h // 16, w // 16]
  idx = torch.randint(0, 1024, (cshape[0], cshape[2], cshape[3])).to(h_indices)
  idx = idx.reshape(cshape[0], cshape[2], cshape[3])
  
  start = 0
  start_i = start // cshape[3]
  start_j = start % cshape[3]
  sample = True

  for i in tqdm(range(start_i, cshape[2])):
    if i <= 8:
      local_i = i
    elif cshape[2]-i < 8:
      local_i = 16-(cshape[2]-i)
    else:
      local_i = 8
    for j in range(start_j, cshape[3]):
      if j <= 8:
        local_j = j
      elif cshape[3]-j < 8:
        local_j = 16-(cshape[3]-j)
      else:
        local_j = 8

        i_start = i-local_i
        i_end = i_start+16
        j_start = j-local_j
        j_end = j_start+16
        patch = idx[:,i_start:i_end,j_start:j_end]
        patch = patch.reshape(patch.shape[0],-1)
        cpatch = spe[:, :, i_start*2:i_end*2,j_start*2:j_end*2]
        cpatch = cpatch.reshape(cpatch.shape[0], local_sampler.cdim, -1)
        patch = torch.cat((h_indices, patch), dim=1)
        logits, _ = local_sampler.transformer(patch[:,:-1], embeddings=cpatch)
        logits = logits[:, -256:, :]
        logits = logits.reshape(cshape[0],16,16,-1)
        logits = logits[:,local_i,local_j,:]
        logits = logits / temperature

        if top_k is not None:
          logits = local_sampler.top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
          ix = torch.multinomial(probs, num_samples=1)
        else:
          _, ix = torch.topk(probs, k=1, dim=-1)
        idx[:,i,j] = ix.reshape(-1)
  xsample = local_sampler.decode_to_img(idx, cshape)
  for i in range(xsample.shape[0]):
    ldr_save = save_image(xsample[i], os.path.join(outdir, "ldr", "ldr_[{}].png".format(prompts[i])))

  # super-resolution inverse tone mapping
  if params['sritmo'] is not None:
    ldr_hr_samples, hdr_hr_samples = SRiTMO(xsample, params)
  else:
    print("no checkpoint provided, skip Stage II (SR-iTMO)...")
    return
      
  for i in range(xsample.shape[0]):
    ldr_hr_save = (ldr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy() + 1) * 127.5
    cv2.imwrite(os.path.join(outdir, "ldr", "hrldr_[{}].png".format(prompts[i])), ldr_hr_save)
    # cv2.imwrite(os.path.join(outdir, "hdr", "hdr_[{}].exr".format(prompts[i])), hdr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy())
  return ldr_hr_save

# load model
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  if "ckpt_path" in config.params:
    print("Deleting the restore-ckpt path from the config...")
    config.params.ckpt_path = None
  if "downsample_cond_size" in config.params:
    print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
    config.params.downsample_cond_size = -1
    config.params["downsample_cond_factor"] = 0.5
  try:
    if "ckpt_path" in config.params.first_stage_config.params:
      config.params.first_stage_config.params.ckpt_path = None
      print("Deleting the first-stage restore-ckpt path from the config...")
    if "ckpt_path" in config.params.cond_stage_config.params:
      config.params.cond_stage_config.params.ckpt_path = None
      print("Deleting the cond-stage restore-ckpt path from the config...")
    if "ckpt_path" in config.params.holistic_config.params:
      config.params.holistic_config.params.ckpt_path = None
      print("Deleting the global sampler restore-ckpt path from the config...")
  except:
    pass

  model = instantiate_from_config(config)
  if sd is not None:
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(f"Missing Keys in State Dict: {missing}")
    print(f"Unexpected Keys in State Dict: {unexpected}")
  if gpu:
    model.cuda()
  if eval_mode:
    model.eval()
  return {"model": model}

def load_model(config, ckpt, gpu, eval_mode):
  if ckpt:
    raw_model = torch.load(ckpt, map_location="cpu")
    state_dict = raw_model["state_dict"]
  else:
    raise NotImplementedError("checkpoint at [{}] is not found!".format(ckpt))
  model = load_model_from_config(config.model, state_dict, gpu=gpu, eval_mode=eval_mode)["model"]
  return model

Text to High Dynamic Range Images

それでは、テキストを設定し、パノラマ画像を生成していきます。
本記事では、Sunny Parkを設定してみます。

texts = "Sunny Park" #@param {type:"string"}
model = "outdoor" #@param ["full", "outdoor", "indoor"]
sritmo = True #@param {type:"boolean"}
sr_factor = 4 #@param {type:"number"}
top_k = 100 #@param {type:"number"}
temperature = 1.0 #@param {type:"number"}


local_sampler_path = None
if model == "full":
  local_sampler_path = "./ckpts/text2light_released_model/local_sampler/"
elif model == "outdoor":
  local_sampler_path = "./ckpts/text2light_released_model/local_sampler_outdoor/"
elif model == "indoor":
  local_sampler_path = "./ckpts/text2light_released_model/local_sampler_indoor/"
else:
  raise NotImplementedError

sritmo_path = None
if sritmo:
  sritmo_path = "./ckpts/text2light_released_model/sritmo.pth"


opt = argparse.Namespace(
    resume_global="./ckpts/text2light_released_model/global_sampler_clip/",
    resume_local=local_sampler_path,
    sritmo=sritmo_path,
    sr_factor=sr_factor,
    outdir="./text2light_generated",
    clip="./clip_emb.npy",
    text=texts,
    top_k=top_k,
    temperature=temperature,
    bs=1,
)

出力結果は以下の通りです。
テキストに対応したパノラマ画像が生成されていますが、両端の品質が安定していません。

Sunny Park

まとめ

本記事では、Text2Lightを用いてテキストからパノラマ画像を生成する方法をご紹介しました。
3D空間へのレンダリングやVRへの応用などが考えられます。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

参考文献

1.  論文 - Text2Light: Zero-Shot Text-Driven HDR Panorama Generation

2. GitHub - FrozenBurning/Text2Light

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology