[Text2Light] AIでテキストからパノラマ画像を生成する

2022年9月23日金曜日

Artificial Intelligence

本記事では、Text2Lightと呼ばれる機械学習手法を用いて任意の自然言語表現からパノラマ画像を生成する方法をご紹介します。

アイキャッチ
出典: FrozenBurning/Text2Light

Text2Light

概要

Text2Lightは、ゼロショットテキスト駆動型のHDRパノラマ画像(High Dynamic Range Images)生成技術です。

Text2Lightでは以下の2ステップでHDRIを生成します。

  1. Text-driven LDR Panorama Generation
    このステップでは、初めにCLIPを用いてテキストから低ダイナミックレンジで低解像度のLDRパノラマを生成
  2. Super-resolution Inverse Tone Mapping
    LDRパノラマをパッチに分解し、球状に構造化された潜在空間にマッピングし、解像度とダイナミックレンジを同時にアップスケーリング

以上のステップで構成されたTex2Lightは、ゼロショットで汎用性の高いHDR画像の生成を可能にしています。

Architecture

詳細はこちらの論文をご参照ください。

本記事では上記手法を用いて、テキストからパノラマ画像を生成していきます。

スポンサーリンク

デモ(Colaboratory)

それでは、実際に動かしながらテキストからパノラマ画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo

また、下記から直接Google Colaboratoryで開くこともできます。
Open In Colab

なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。

環境セットアップ

それではセットアップしていきます。 Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

初めにGithubからソースコードを取得します。

  1. %cd /content
  2.  
  3. !git clone https://github.com/FrozenBurning/Text2Light
  4.  
  5.  
  6. # Commits on Sep 22, 2022
  7. %cd /content/Text2Light
  8. !git checkout acd216150a76b8bb4e460801e95100d167a6f7fe

次にライブラリをインストールします。

  1. %cd /content/Text2Light
  2.  
  3. !sudo apt-get install libomp-dev
  4.  
  5. !pip install ftfy regex tqdm omegaconf pytorch-lightning tensorboardX einops transformers
  6. !pip install kornia
  7. !pip install imageio-ffmpeg
  8. !pip install faiss
  9. !pip install --upgrade gdown

最後にライブラリをインポートします。

  1. import os
  2.  
  3. import argparse, os, sys, glob
  4. import cv2
  5. import torch
  6. import faiss
  7. import numpy as np
  8. from omegaconf import OmegaConf
  9. from PIL import Image
  10. from tqdm import tqdm
  11.  
  12. import clip
  13. from taming.util import instantiate_from_config
  14. from sritmo.global_sritmo import SRiTMO

以上で環境セットアップは完了です。

学習済みモデルのセットアップ

続いて論文発表元が公開している学習済みモデルをGoogle Colaboratoryにダウンロードします。
Google Driveに公開された学習済みモデルをgdownライブラリを使用してダウンロードしています。

  1. %cd /content/Text2Light
  2. !mkdir -p ckpts
  3.  
  4.  
  5. ckpt_path = 'ckpts/text2light_released_model/global_sampler_clip/checkpoints/last.ckpt'
  6. if not os.path.exists(ckpt_path):
  7. !gdown --folder https://drive.google.com/drive/folders/1HKBjC7oQOzrkGFKMQmSh6PySv6AycDS3?usp=sharing \
  8. -O 'ckpts/text2light_released_model'
  9.  
  10.  
  11. # 学習済みモデルのダウンロードで下記エラーが出た場合はこのセルを実行
  12. # Access denied with the following error:
  13.  
  14. # Too many users have viewed or downloaded this file recently. Please
  15. # try accessing the file again later. If the file you are trying to
  16. # access is particularly large or is shared with many people, it may
  17. # take up to 24 hours to be able to view or download the file. If you
  18. # still can't access a file after 24 hours, contact your domain
  19. # administrator.
  20.  
  21. # text2light_released_modelの「ドライブへのショートカットを追加」後
  22. # from google.colab import drive
  23. # drive.mount('/content/drive')
  24.  
  25. # !rm -rf /content/Text2Light/ckpts/text2light_released_model
  26. # !cp -r /content/drive/MyDrive/text2light_released_model /content/Text2Light/ckpts/

Utility関数定義

続いて画像保存用関数や、モデルのロード用関数等を定義していきます。

  1. # Tensor imgの保存
  2. def save_image(x, path):
  3. c,h,w = x.shape
  4. assert c==3
  5. x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
  6. s = Image.fromarray(x)
  7. s.save(path)
  8. return s
  9.  
  10. def get_knn(database: np.array, index: faiss.Index, txt_emb, k = 5):
  11. dist, idx = index.search(txt_emb, k)
  12. return database[idx], idx #[bs, k, 512]
  13.  
  14. # predict
  15. @torch.no_grad()
  16. def text2light(models: dict, prompts, outdir, params: dict):
  17. # models
  18. global_sampler = models["gs"]
  19. local_sampler = models["ls"]
  20. # params
  21. batch_size = len(prompts)
  22. top_k = params["top_k"]
  23. temperature = params['temperature']
  24. database = params['data4knn']
  25. faiss_index = params['index4knn']
  26. device = params['device']
  27.  
  28. # embed input texts
  29. lan_model, _ = clip.load("ViT-B/32", device=device)
  30. lan_model.eval()
  31. text = clip.tokenize(prompts).to(device)
  32. text_features = lan_model.encode_text(text)
  33. target_txt_emb = text_features / text_features.norm(dim=-1, keepdim=True)
  34. cond, _ = get_knn(database, faiss_index, target_txt_emb.cpu().numpy().astype('float32'))
  35. txt_cond = torch.from_numpy(cond.reshape(batch_size, 5, cond.shape[-1]))
  36. txt_cond = torch.cat([txt_cond, txt_cond,], dim=-1).to(device)
  37.  
  38. # sample holistic condition
  39. bs = batch_size
  40. start = 0
  41. idx = torch.zeros(bs, 1, dtype=int)[:, :start].to(device)
  42. cshape = [bs, 256, 8, 16]
  43. sample = True
  44.  
  45. print("Generating holistic conditions according to texts...")
  46. for i in tqdm(range(start, cshape[2]*cshape[3])):
  47. logits, _ = global_sampler.transformer(idx, embeddings=txt_cond)
  48. logits = logits[:, -1, :]
  49. if top_k is not None:
  50. logits = global_sampler.top_k_logits(logits, top_k)
  51. probs = torch.nn.functional.softmax(logits, dim=-1)
  52. if sample:
  53. ix = torch.multinomial(probs, num_samples=1)
  54. else:
  55. _, ix = torch.topk(probs, k=1, dim=-1)
  56. idx = torch.cat((idx, ix), dim=1)
  57.  
  58. xsample_holistic = global_sampler.decode_to_img(idx, cshape)
  59. for i in range(xsample_holistic.shape[0]):
  60. holistic_save = save_image(xsample_holistic[i], os.path.join(outdir, "holistic", "holistic_[{}].png".format(prompts[i])))
  61.  
  62. print("Synthesizing patches...")
  63. # synthesize patch by patch according to holistic condition
  64. h = 512
  65. w = 1024
  66. xx, yy = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
  67. screen_points = np.stack([xx, yy], axis=-1)
  68. coord = (screen_points * 2 - 1) * np.array([np.pi, np.pi/2])
  69. spe = torch.from_numpy(coord).to(xsample_holistic).repeat(xsample_holistic.shape[0], 1, 1, 1).permute(0, 3, 1, 2)
  70. spe = torch.nn.functional.interpolate(spe, scale_factor=1/8, mode="bicubic", recompute_scale_factor=False, align_corners=True)
  71. spe = local_sampler.embedder(spe.permute(0, 2, 3, 1))
  72. spe = spe.permute(0, 3, 1, 2)
  73.  
  74. _, h_indices = local_sampler.encode_to_h(xsample_holistic)
  75. cshape = [xsample_holistic.shape[0], 256, h // 16, w // 16]
  76. idx = torch.randint(0, 1024, (cshape[0], cshape[2], cshape[3])).to(h_indices)
  77. idx = idx.reshape(cshape[0], cshape[2], cshape[3])
  78. start = 0
  79. start_i = start // cshape[3]
  80. start_j = start % cshape[3]
  81. sample = True
  82.  
  83. for i in tqdm(range(start_i, cshape[2])):
  84. if i <= 8:
  85. local_i = i
  86. elif cshape[2]-i < 8:
  87. local_i = 16-(cshape[2]-i)
  88. else:
  89. local_i = 8
  90. for j in range(start_j, cshape[3]):
  91. if j <= 8:
  92. local_j = j
  93. elif cshape[3]-j < 8:
  94. local_j = 16-(cshape[3]-j)
  95. else:
  96. local_j = 8
  97.  
  98. i_start = i-local_i
  99. i_end = i_start+16
  100. j_start = j-local_j
  101. j_end = j_start+16
  102. patch = idx[:,i_start:i_end,j_start:j_end]
  103. patch = patch.reshape(patch.shape[0],-1)
  104. cpatch = spe[:, :, i_start*2:i_end*2,j_start*2:j_end*2]
  105. cpatch = cpatch.reshape(cpatch.shape[0], local_sampler.cdim, -1)
  106. patch = torch.cat((h_indices, patch), dim=1)
  107. logits, _ = local_sampler.transformer(patch[:,:-1], embeddings=cpatch)
  108. logits = logits[:, -256:, :]
  109. logits = logits.reshape(cshape[0],16,16,-1)
  110. logits = logits[:,local_i,local_j,:]
  111. logits = logits / temperature
  112.  
  113. if top_k is not None:
  114. logits = local_sampler.top_k_logits(logits, top_k)
  115. # apply softmax to convert to probabilities
  116. probs = torch.nn.functional.softmax(logits, dim=-1)
  117. # sample from the distribution or take the most likely
  118. if sample:
  119. ix = torch.multinomial(probs, num_samples=1)
  120. else:
  121. _, ix = torch.topk(probs, k=1, dim=-1)
  122. idx[:,i,j] = ix.reshape(-1)
  123. xsample = local_sampler.decode_to_img(idx, cshape)
  124. for i in range(xsample.shape[0]):
  125. ldr_save = save_image(xsample[i], os.path.join(outdir, "ldr", "ldr_[{}].png".format(prompts[i])))
  126.  
  127. # super-resolution inverse tone mapping
  128. if params['sritmo'] is not None:
  129. ldr_hr_samples, hdr_hr_samples = SRiTMO(xsample, params)
  130. else:
  131. print("no checkpoint provided, skip Stage II (SR-iTMO)...")
  132. return
  133. for i in range(xsample.shape[0]):
  134. ldr_hr_save = (ldr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy() + 1) * 127.5
  135. cv2.imwrite(os.path.join(outdir, "ldr", "hrldr_[{}].png".format(prompts[i])), ldr_hr_save)
  136. # cv2.imwrite(os.path.join(outdir, "hdr", "hdr_[{}].exr".format(prompts[i])), hdr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy())
  137. return ldr_hr_save
  138.  
  139. # load model
  140. def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  141. if "ckpt_path" in config.params:
  142. print("Deleting the restore-ckpt path from the config...")
  143. config.params.ckpt_path = None
  144. if "downsample_cond_size" in config.params:
  145. print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
  146. config.params.downsample_cond_size = -1
  147. config.params["downsample_cond_factor"] = 0.5
  148. try:
  149. if "ckpt_path" in config.params.first_stage_config.params:
  150. config.params.first_stage_config.params.ckpt_path = None
  151. print("Deleting the first-stage restore-ckpt path from the config...")
  152. if "ckpt_path" in config.params.cond_stage_config.params:
  153. config.params.cond_stage_config.params.ckpt_path = None
  154. print("Deleting the cond-stage restore-ckpt path from the config...")
  155. if "ckpt_path" in config.params.holistic_config.params:
  156. config.params.holistic_config.params.ckpt_path = None
  157. print("Deleting the global sampler restore-ckpt path from the config...")
  158. except:
  159. pass
  160.  
  161. model = instantiate_from_config(config)
  162. if sd is not None:
  163. missing, unexpected = model.load_state_dict(sd, strict=False)
  164. print(f"Missing Keys in State Dict: {missing}")
  165. print(f"Unexpected Keys in State Dict: {unexpected}")
  166. if gpu:
  167. model.cuda()
  168. if eval_mode:
  169. model.eval()
  170. return {"model": model}
  171.  
  172. def load_model(config, ckpt, gpu, eval_mode):
  173. if ckpt:
  174. raw_model = torch.load(ckpt, map_location="cpu")
  175. state_dict = raw_model["state_dict"]
  176. else:
  177. raise NotImplementedError("checkpoint at [{}] is not found!".format(ckpt))
  178. model = load_model_from_config(config.model, state_dict, gpu=gpu, eval_mode=eval_mode)["model"]
  179. return model

Text to High Dynamic Range Images

それでは、テキストを設定し、パノラマ画像を生成していきます。
本記事では、Sunny Parkを設定してみます。

  1. texts = "Sunny Park" #@param {type:"string"}
  2. model = "outdoor" #@param ["full", "outdoor", "indoor"]
  3. sritmo = True #@param {type:"boolean"}
  4. sr_factor = 4 #@param {type:"number"}
  5. top_k = 100 #@param {type:"number"}
  6. temperature = 1.0 #@param {type:"number"}
  7.  
  8.  
  9. local_sampler_path = None
  10. if model == "full":
  11. local_sampler_path = "./ckpts/text2light_released_model/local_sampler/"
  12. elif model == "outdoor":
  13. local_sampler_path = "./ckpts/text2light_released_model/local_sampler_outdoor/"
  14. elif model == "indoor":
  15. local_sampler_path = "./ckpts/text2light_released_model/local_sampler_indoor/"
  16. else:
  17. raise NotImplementedError
  18.  
  19. sritmo_path = None
  20. if sritmo:
  21. sritmo_path = "./ckpts/text2light_released_model/sritmo.pth"
  22.  
  23.  
  24. opt = argparse.Namespace(
  25. resume_global="./ckpts/text2light_released_model/global_sampler_clip/",
  26. resume_local=local_sampler_path,
  27. sritmo=sritmo_path,
  28. sr_factor=sr_factor,
  29. outdir="./text2light_generated",
  30. clip="./clip_emb.npy",
  31. text=texts,
  32. top_k=top_k,
  33. temperature=temperature,
  34. bs=1,
  35. )
  36.  

出力結果は以下の通りです。
テキストに対応したパノラマ画像が生成されていますが、両端の品質が安定していません。

Sunny Park

まとめ

本記事では、Text2Lightを用いてテキストからパノラマ画像を生成する方法をご紹介しました。
3D空間へのレンダリングやVRへの応用などが考えられます。

また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。


また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。

スポンサーリンク

参考文献

1.  論文 - Text2Light: Zero-Shot Text-Driven HDR Panorama Generation

2. GitHub - FrozenBurning/Text2Light

AIで副業ならココから!

まずは無料会員登録

プロフィール

メーカーで研究開発を行う現役エンジニア
組み込み機器開発や機会学習モデル開発に従事しています

本ブログでは最新AI技術を中心にソースコード付きでご紹介します


Twitter

カテゴリ

このブログを検索

ブログ アーカイブ

TeDokology