本記事では、Text2Lightと呼ばれる機械学習手法を用いて任意の自然言語表現からパノラマ画像を生成する方法をご紹介します。
Text2Light
概要
Text2Lightは、ゼロショットテキスト駆動型のHDRパノラマ画像(High Dynamic Range
Images)生成技術です。
Text2Lightでは以下の2ステップでHDRIを生成します。
-
Text-driven LDR Panorama Generation
このステップでは、初めにCLIPを用いてテキストから低ダイナミックレンジで低解像度のLDRパノラマを生成
-
Super-resolution Inverse Tone Mapping
LDRパノラマをパッチに分解し、球状に構造化された潜在空間にマッピングし、解像度とダイナミックレンジを同時にアップスケーリング
以上のステップで構成されたTex2Lightは、ゼロショットで汎用性の高いHDR画像の生成を可能にしています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、テキストからパノラマ画像を生成していきます。
デモ(Colaboratory)
それでは、実際に動かしながらテキストからパノラマ画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
%cd /content
!git clone https://github.com/FrozenBurning/Text2Light
# Commits on Sep 22, 2022
%cd /content/Text2Light
!git checkout acd216150a76b8bb4e460801e95100d167a6f7fe
次にライブラリをインストールします。
%cd /content/Text2Light
!sudo apt-get install libomp-dev
!pip install ftfy regex tqdm omegaconf pytorch-lightning tensorboardX einops transformers
!pip install kornia
!pip install imageio-ffmpeg
!pip install faiss
!pip install --upgrade gdown
最後にライブラリをインポートします。
import os
import argparse, os, sys, glob
import cv2
import torch
import faiss
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import clip
from taming.util import instantiate_from_config
from sritmo.global_sritmo import SRiTMO
以上で環境セットアップは完了です。
学習済みモデルのセットアップ
続いて論文発表元が公開している学習済みモデルをGoogle
Colaboratoryにダウンロードします。
Google
Driveに公開された学習済みモデルをgdownライブラリを使用してダウンロードしています。
%cd /content/Text2Light
!mkdir -p ckpts
ckpt_path = 'ckpts/text2light_released_model/global_sampler_clip/checkpoints/last.ckpt'
if not os.path.exists(ckpt_path):
!gdown --folder https://drive.google.com/drive/folders/1HKBjC7oQOzrkGFKMQmSh6PySv6AycDS3?usp=sharing \
-O 'ckpts/text2light_released_model'
# 学習済みモデルのダウンロードで下記エラーが出た場合はこのセルを実行
# Access denied with the following error:
# Too many users have viewed or downloaded this file recently. Please
# try accessing the file again later. If the file you are trying to
# access is particularly large or is shared with many people, it may
# take up to 24 hours to be able to view or download the file. If you
# still can't access a file after 24 hours, contact your domain
# administrator.
# text2light_released_modelの「ドライブへのショートカットを追加」後
# from google.colab import drive
# drive.mount('/content/drive')
# !rm -rf /content/Text2Light/ckpts/text2light_released_model
# !cp -r /content/drive/MyDrive/text2light_released_model /content/Text2Light/ckpts/
Utility関数定義
続いて画像保存用関数や、モデルのロード用関数等を定義していきます。
# Tensor imgの保存
def save_image(x, path):
c,h,w = x.shape
assert c==3
x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
s = Image.fromarray(x)
s.save(path)
return s
def get_knn(database: np.array, index: faiss.Index, txt_emb, k = 5):
dist, idx = index.search(txt_emb, k)
return database[idx], idx #[bs, k, 512]
# predict
@torch.no_grad()
def text2light(models: dict, prompts, outdir, params: dict):
# models
global_sampler = models["gs"]
local_sampler = models["ls"]
# params
batch_size = len(prompts)
top_k = params["top_k"]
temperature = params['temperature']
database = params['data4knn']
faiss_index = params['index4knn']
device = params['device']
# embed input texts
lan_model, _ = clip.load("ViT-B/32", device=device)
lan_model.eval()
text = clip.tokenize(prompts).to(device)
text_features = lan_model.encode_text(text)
target_txt_emb = text_features / text_features.norm(dim=-1, keepdim=True)
cond, _ = get_knn(database, faiss_index, target_txt_emb.cpu().numpy().astype('float32'))
txt_cond = torch.from_numpy(cond.reshape(batch_size, 5, cond.shape[-1]))
txt_cond = torch.cat([txt_cond, txt_cond,], dim=-1).to(device)
# sample holistic condition
bs = batch_size
start = 0
idx = torch.zeros(bs, 1, dtype=int)[:, :start].to(device)
cshape = [bs, 256, 8, 16]
sample = True
print("Generating holistic conditions according to texts...")
for i in tqdm(range(start, cshape[2]*cshape[3])):
logits, _ = global_sampler.transformer(idx, embeddings=txt_cond)
logits = logits[:, -1, :]
if top_k is not None:
logits = global_sampler.top_k_logits(logits, top_k)
probs = torch.nn.functional.softmax(logits, dim=-1)
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
idx = torch.cat((idx, ix), dim=1)
xsample_holistic = global_sampler.decode_to_img(idx, cshape)
for i in range(xsample_holistic.shape[0]):
holistic_save = save_image(xsample_holistic[i], os.path.join(outdir, "holistic", "holistic_[{}].png".format(prompts[i])))
print("Synthesizing patches...")
# synthesize patch by patch according to holistic condition
h = 512
w = 1024
xx, yy = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
screen_points = np.stack([xx, yy], axis=-1)
coord = (screen_points * 2 - 1) * np.array([np.pi, np.pi/2])
spe = torch.from_numpy(coord).to(xsample_holistic).repeat(xsample_holistic.shape[0], 1, 1, 1).permute(0, 3, 1, 2)
spe = torch.nn.functional.interpolate(spe, scale_factor=1/8, mode="bicubic", recompute_scale_factor=False, align_corners=True)
spe = local_sampler.embedder(spe.permute(0, 2, 3, 1))
spe = spe.permute(0, 3, 1, 2)
_, h_indices = local_sampler.encode_to_h(xsample_holistic)
cshape = [xsample_holistic.shape[0], 256, h // 16, w // 16]
idx = torch.randint(0, 1024, (cshape[0], cshape[2], cshape[3])).to(h_indices)
idx = idx.reshape(cshape[0], cshape[2], cshape[3])
start = 0
start_i = start // cshape[3]
start_j = start % cshape[3]
sample = True
for i in tqdm(range(start_i, cshape[2])):
if i <= 8:
local_i = i
elif cshape[2]-i < 8:
local_i = 16-(cshape[2]-i)
else:
local_i = 8
for j in range(start_j, cshape[3]):
if j <= 8:
local_j = j
elif cshape[3]-j < 8:
local_j = 16-(cshape[3]-j)
else:
local_j = 8
i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16
patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = spe[:, :, i_start*2:i_end*2,j_start*2:j_end*2]
cpatch = cpatch.reshape(cpatch.shape[0], local_sampler.cdim, -1)
patch = torch.cat((h_indices, patch), dim=1)
logits, _ = local_sampler.transformer(patch[:,:-1], embeddings=cpatch)
logits = logits[:, -256:, :]
logits = logits.reshape(cshape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]
logits = logits / temperature
if top_k is not None:
logits = local_sampler.top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
idx[:,i,j] = ix.reshape(-1)
xsample = local_sampler.decode_to_img(idx, cshape)
for i in range(xsample.shape[0]):
ldr_save = save_image(xsample[i], os.path.join(outdir, "ldr", "ldr_[{}].png".format(prompts[i])))
# super-resolution inverse tone mapping
if params['sritmo'] is not None:
ldr_hr_samples, hdr_hr_samples = SRiTMO(xsample, params)
else:
print("no checkpoint provided, skip Stage II (SR-iTMO)...")
return
for i in range(xsample.shape[0]):
ldr_hr_save = (ldr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy() + 1) * 127.5
cv2.imwrite(os.path.join(outdir, "ldr", "hrldr_[{}].png".format(prompts[i])), ldr_hr_save)
# cv2.imwrite(os.path.join(outdir, "hdr", "hdr_[{}].exr".format(prompts[i])), hdr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy())
return ldr_hr_save
# load model
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
if "ckpt_path" in config.params:
print("Deleting the restore-ckpt path from the config...")
config.params.ckpt_path = None
if "downsample_cond_size" in config.params:
print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
config.params.downsample_cond_size = -1
config.params["downsample_cond_factor"] = 0.5
try:
if "ckpt_path" in config.params.first_stage_config.params:
config.params.first_stage_config.params.ckpt_path = None
print("Deleting the first-stage restore-ckpt path from the config...")
if "ckpt_path" in config.params.cond_stage_config.params:
config.params.cond_stage_config.params.ckpt_path = None
print("Deleting the cond-stage restore-ckpt path from the config...")
if "ckpt_path" in config.params.holistic_config.params:
config.params.holistic_config.params.ckpt_path = None
print("Deleting the global sampler restore-ckpt path from the config...")
except:
pass
model = instantiate_from_config(config)
if sd is not None:
missing, unexpected = model.load_state_dict(sd, strict=False)
print(f"Missing Keys in State Dict: {missing}")
print(f"Unexpected Keys in State Dict: {unexpected}")
if gpu:
model.cuda()
if eval_mode:
model.eval()
return {"model": model}
def load_model(config, ckpt, gpu, eval_mode):
if ckpt:
raw_model = torch.load(ckpt, map_location="cpu")
state_dict = raw_model["state_dict"]
else:
raise NotImplementedError("checkpoint at [{}] is not found!".format(ckpt))
model = load_model_from_config(config.model, state_dict, gpu=gpu, eval_mode=eval_mode)["model"]
return model
Text to High Dynamic Range Images
それでは、テキストを設定し、パノラマ画像を生成していきます。
本記事では、Sunny Park
を設定してみます。
texts = "Sunny Park" #@param {type:"string"}
model = "outdoor" #@param ["full", "outdoor", "indoor"]
sritmo = True #@param {type:"boolean"}
sr_factor = 4 #@param {type:"number"}
top_k = 100 #@param {type:"number"}
temperature = 1.0 #@param {type:"number"}
local_sampler_path = None
if model == "full":
local_sampler_path = "./ckpts/text2light_released_model/local_sampler/"
elif model == "outdoor":
local_sampler_path = "./ckpts/text2light_released_model/local_sampler_outdoor/"
elif model == "indoor":
local_sampler_path = "./ckpts/text2light_released_model/local_sampler_indoor/"
else:
raise NotImplementedError
sritmo_path = None
if sritmo:
sritmo_path = "./ckpts/text2light_released_model/sritmo.pth"
opt = argparse.Namespace(
resume_global="./ckpts/text2light_released_model/global_sampler_clip/",
resume_local=local_sampler_path,
sritmo=sritmo_path,
sr_factor=sr_factor,
outdir="./text2light_generated",
clip="./clip_emb.npy",
text=texts,
top_k=top_k,
temperature=temperature,
bs=1,
)
出力結果は以下の通りです。
テキストに対応したパノラマ画像が生成されていますが、両端の品質が安定していません。
まとめ
本記事では、Text2Lightを用いてテキストからパノラマ画像を生成する方法をご紹介しました。
3D空間へのレンダリングやVRへの応用などが考えられます。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
リンク
リンク
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
参考文献
1.
論文 - Text2Light: Zero-Shot Text-Driven HDR Panorama Generation
2. GitHub - FrozenBurning/Text2Light
0 件のコメント :
コメントを投稿