本記事では、Text2Lightと呼ばれる機械学習手法を用いて任意の自然言語表現からパノラマ画像を生成する方法をご紹介します。
Text2Light
概要
Text2Lightは、ゼロショットテキスト駆動型のHDRパノラマ画像(High Dynamic Range
Images)生成技術です。
Text2Lightでは以下の2ステップでHDRIを生成します。
-
Text-driven LDR Panorama Generation
このステップでは、初めにCLIPを用いてテキストから低ダイナミックレンジで低解像度のLDRパノラマを生成
-
Super-resolution Inverse Tone Mapping
LDRパノラマをパッチに分解し、球状に構造化された潜在空間にマッピングし、解像度とダイナミックレンジを同時にアップスケーリング
以上のステップで構成されたTex2Lightは、ゼロショットで汎用性の高いHDR画像の生成を可能にしています。
詳細はこちらの論文をご参照ください。
本記事では上記手法を用いて、テキストからパノラマ画像を生成していきます。
スポンサーリンク
デモ(Colaboratory)
それでは、実際に動かしながらテキストからパノラマ画像を生成していきます。
ソースコードは本記事にも記載していますが、下記のGitHubでも取得可能です。
GitHub - Colaboratory demo
また、下記から直接Google Colaboratoryで開くこともできます。
なお、このデモはPythonで実装しています。
Pythonの実装に不安がある方、Pythonを使った機械学習について詳しく勉強したい方は、以下の書籍やオンライン講座などがおすすめです。
環境セットアップ
それではセットアップしていきます。
Colaboratoryを開いたら下記を設定しGPUを使用するようにしてください。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
初めにGithubからソースコードを取得します。
- %cd /content
-
- !git clone https://github.com/FrozenBurning/Text2Light
-
-
- # Commits on Sep 22, 2022
- %cd /content/Text2Light
- !git checkout acd216150a76b8bb4e460801e95100d167a6f7fe
次にライブラリをインストールします。
- %cd /content/Text2Light
-
- !sudo apt-get install libomp-dev
-
- !pip install ftfy regex tqdm omegaconf pytorch-lightning tensorboardX einops transformers
- !pip install kornia
- !pip install imageio-ffmpeg
- !pip install faiss
- !pip install --upgrade gdown
最後にライブラリをインポートします。
- import os
-
- import argparse, os, sys, glob
- import cv2
- import torch
- import faiss
- import numpy as np
- from omegaconf import OmegaConf
- from PIL import Image
- from tqdm import tqdm
-
- import clip
- from taming.util import instantiate_from_config
- from sritmo.global_sritmo import SRiTMO
以上で環境セットアップは完了です。
学習済みモデルのセットアップ
続いて論文発表元が公開している学習済みモデルをGoogle
Colaboratoryにダウンロードします。
Google
Driveに公開された学習済みモデルをgdownライブラリを使用してダウンロードしています。
- %cd /content/Text2Light
- !mkdir -p ckpts
-
-
- ckpt_path = 'ckpts/text2light_released_model/global_sampler_clip/checkpoints/last.ckpt'
- if not os.path.exists(ckpt_path):
- !gdown --folder https://drive.google.com/drive/folders/1HKBjC7oQOzrkGFKMQmSh6PySv6AycDS3?usp=sharing \
- -O 'ckpts/text2light_released_model'
-
-
- # 学習済みモデルのダウンロードで下記エラーが出た場合はこのセルを実行
- # Access denied with the following error:
-
- # Too many users have viewed or downloaded this file recently. Please
- # try accessing the file again later. If the file you are trying to
- # access is particularly large or is shared with many people, it may
- # take up to 24 hours to be able to view or download the file. If you
- # still can't access a file after 24 hours, contact your domain
- # administrator.
-
- # text2light_released_modelの「ドライブへのショートカットを追加」後
- # from google.colab import drive
- # drive.mount('/content/drive')
-
- # !rm -rf /content/Text2Light/ckpts/text2light_released_model
- # !cp -r /content/drive/MyDrive/text2light_released_model /content/Text2Light/ckpts/
Utility関数定義
続いて画像保存用関数や、モデルのロード用関数等を定義していきます。
- # Tensor imgの保存
- def save_image(x, path):
- c,h,w = x.shape
- assert c==3
- x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
- s = Image.fromarray(x)
- s.save(path)
- return s
-
- def get_knn(database: np.array, index: faiss.Index, txt_emb, k = 5):
- dist, idx = index.search(txt_emb, k)
- return database[idx], idx #[bs, k, 512]
-
- # predict
- @torch.no_grad()
- def text2light(models: dict, prompts, outdir, params: dict):
- # models
- global_sampler = models["gs"]
- local_sampler = models["ls"]
- # params
- batch_size = len(prompts)
- top_k = params["top_k"]
- temperature = params['temperature']
- database = params['data4knn']
- faiss_index = params['index4knn']
- device = params['device']
-
- # embed input texts
- lan_model, _ = clip.load("ViT-B/32", device=device)
- lan_model.eval()
- text = clip.tokenize(prompts).to(device)
- text_features = lan_model.encode_text(text)
- target_txt_emb = text_features / text_features.norm(dim=-1, keepdim=True)
- cond, _ = get_knn(database, faiss_index, target_txt_emb.cpu().numpy().astype('float32'))
- txt_cond = torch.from_numpy(cond.reshape(batch_size, 5, cond.shape[-1]))
- txt_cond = torch.cat([txt_cond, txt_cond,], dim=-1).to(device)
-
- # sample holistic condition
- bs = batch_size
- start = 0
- idx = torch.zeros(bs, 1, dtype=int)[:, :start].to(device)
- cshape = [bs, 256, 8, 16]
- sample = True
-
- print("Generating holistic conditions according to texts...")
- for i in tqdm(range(start, cshape[2]*cshape[3])):
- logits, _ = global_sampler.transformer(idx, embeddings=txt_cond)
- logits = logits[:, -1, :]
- if top_k is not None:
- logits = global_sampler.top_k_logits(logits, top_k)
- probs = torch.nn.functional.softmax(logits, dim=-1)
- if sample:
- ix = torch.multinomial(probs, num_samples=1)
- else:
- _, ix = torch.topk(probs, k=1, dim=-1)
- idx = torch.cat((idx, ix), dim=1)
-
- xsample_holistic = global_sampler.decode_to_img(idx, cshape)
- for i in range(xsample_holistic.shape[0]):
- holistic_save = save_image(xsample_holistic[i], os.path.join(outdir, "holistic", "holistic_[{}].png".format(prompts[i])))
-
- print("Synthesizing patches...")
- # synthesize patch by patch according to holistic condition
- h = 512
- w = 1024
- xx, yy = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
- screen_points = np.stack([xx, yy], axis=-1)
- coord = (screen_points * 2 - 1) * np.array([np.pi, np.pi/2])
- spe = torch.from_numpy(coord).to(xsample_holistic).repeat(xsample_holistic.shape[0], 1, 1, 1).permute(0, 3, 1, 2)
- spe = torch.nn.functional.interpolate(spe, scale_factor=1/8, mode="bicubic", recompute_scale_factor=False, align_corners=True)
- spe = local_sampler.embedder(spe.permute(0, 2, 3, 1))
- spe = spe.permute(0, 3, 1, 2)
-
- _, h_indices = local_sampler.encode_to_h(xsample_holistic)
- cshape = [xsample_holistic.shape[0], 256, h // 16, w // 16]
- idx = torch.randint(0, 1024, (cshape[0], cshape[2], cshape[3])).to(h_indices)
- idx = idx.reshape(cshape[0], cshape[2], cshape[3])
-
- start = 0
- start_i = start // cshape[3]
- start_j = start % cshape[3]
- sample = True
-
- for i in tqdm(range(start_i, cshape[2])):
- if i <= 8:
- local_i = i
- elif cshape[2]-i < 8:
- local_i = 16-(cshape[2]-i)
- else:
- local_i = 8
- for j in range(start_j, cshape[3]):
- if j <= 8:
- local_j = j
- elif cshape[3]-j < 8:
- local_j = 16-(cshape[3]-j)
- else:
- local_j = 8
-
- i_start = i-local_i
- i_end = i_start+16
- j_start = j-local_j
- j_end = j_start+16
- patch = idx[:,i_start:i_end,j_start:j_end]
- patch = patch.reshape(patch.shape[0],-1)
- cpatch = spe[:, :, i_start*2:i_end*2,j_start*2:j_end*2]
- cpatch = cpatch.reshape(cpatch.shape[0], local_sampler.cdim, -1)
- patch = torch.cat((h_indices, patch), dim=1)
- logits, _ = local_sampler.transformer(patch[:,:-1], embeddings=cpatch)
- logits = logits[:, -256:, :]
- logits = logits.reshape(cshape[0],16,16,-1)
- logits = logits[:,local_i,local_j,:]
- logits = logits / temperature
-
- if top_k is not None:
- logits = local_sampler.top_k_logits(logits, top_k)
- # apply softmax to convert to probabilities
- probs = torch.nn.functional.softmax(logits, dim=-1)
- # sample from the distribution or take the most likely
- if sample:
- ix = torch.multinomial(probs, num_samples=1)
- else:
- _, ix = torch.topk(probs, k=1, dim=-1)
- idx[:,i,j] = ix.reshape(-1)
- xsample = local_sampler.decode_to_img(idx, cshape)
- for i in range(xsample.shape[0]):
- ldr_save = save_image(xsample[i], os.path.join(outdir, "ldr", "ldr_[{}].png".format(prompts[i])))
-
- # super-resolution inverse tone mapping
- if params['sritmo'] is not None:
- ldr_hr_samples, hdr_hr_samples = SRiTMO(xsample, params)
- else:
- print("no checkpoint provided, skip Stage II (SR-iTMO)...")
- return
-
- for i in range(xsample.shape[0]):
- ldr_hr_save = (ldr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy() + 1) * 127.5
- cv2.imwrite(os.path.join(outdir, "ldr", "hrldr_[{}].png".format(prompts[i])), ldr_hr_save)
- # cv2.imwrite(os.path.join(outdir, "hdr", "hdr_[{}].exr".format(prompts[i])), hdr_hr_samples[i].permute(1, 2, 0).detach().cpu().numpy())
- return ldr_hr_save
-
- # load model
- def load_model_from_config(config, sd, gpu=True, eval_mode=True):
- if "ckpt_path" in config.params:
- print("Deleting the restore-ckpt path from the config...")
- config.params.ckpt_path = None
- if "downsample_cond_size" in config.params:
- print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
- config.params.downsample_cond_size = -1
- config.params["downsample_cond_factor"] = 0.5
- try:
- if "ckpt_path" in config.params.first_stage_config.params:
- config.params.first_stage_config.params.ckpt_path = None
- print("Deleting the first-stage restore-ckpt path from the config...")
- if "ckpt_path" in config.params.cond_stage_config.params:
- config.params.cond_stage_config.params.ckpt_path = None
- print("Deleting the cond-stage restore-ckpt path from the config...")
- if "ckpt_path" in config.params.holistic_config.params:
- config.params.holistic_config.params.ckpt_path = None
- print("Deleting the global sampler restore-ckpt path from the config...")
- except:
- pass
-
- model = instantiate_from_config(config)
- if sd is not None:
- missing, unexpected = model.load_state_dict(sd, strict=False)
- print(f"Missing Keys in State Dict: {missing}")
- print(f"Unexpected Keys in State Dict: {unexpected}")
- if gpu:
- model.cuda()
- if eval_mode:
- model.eval()
- return {"model": model}
-
- def load_model(config, ckpt, gpu, eval_mode):
- if ckpt:
- raw_model = torch.load(ckpt, map_location="cpu")
- state_dict = raw_model["state_dict"]
- else:
- raise NotImplementedError("checkpoint at [{}] is not found!".format(ckpt))
- model = load_model_from_config(config.model, state_dict, gpu=gpu, eval_mode=eval_mode)["model"]
- return model
Text to High Dynamic Range Images
それでは、テキストを設定し、パノラマ画像を生成していきます。
本記事では、Sunny Park
を設定してみます。
- texts = "Sunny Park" #@param {type:"string"}
- model = "outdoor" #@param ["full", "outdoor", "indoor"]
- sritmo = True #@param {type:"boolean"}
- sr_factor = 4 #@param {type:"number"}
- top_k = 100 #@param {type:"number"}
- temperature = 1.0 #@param {type:"number"}
-
-
- local_sampler_path = None
- if model == "full":
- local_sampler_path = "./ckpts/text2light_released_model/local_sampler/"
- elif model == "outdoor":
- local_sampler_path = "./ckpts/text2light_released_model/local_sampler_outdoor/"
- elif model == "indoor":
- local_sampler_path = "./ckpts/text2light_released_model/local_sampler_indoor/"
- else:
- raise NotImplementedError
-
- sritmo_path = None
- if sritmo:
- sritmo_path = "./ckpts/text2light_released_model/sritmo.pth"
-
-
- opt = argparse.Namespace(
- resume_global="./ckpts/text2light_released_model/global_sampler_clip/",
- resume_local=local_sampler_path,
- sritmo=sritmo_path,
- sr_factor=sr_factor,
- outdir="./text2light_generated",
- clip="./clip_emb.npy",
- text=texts,
- top_k=top_k,
- temperature=temperature,
- bs=1,
- )
-
出力結果は以下の通りです。
テキストに対応したパノラマ画像が生成されていますが、両端の品質が安定していません。
まとめ
本記事では、Text2Lightを用いてテキストからパノラマ画像を生成する方法をご紹介しました。
3D空間へのレンダリングやVRへの応用などが考えられます。
また本記事では、機械学習を動かすことにフォーカスしてご紹介しました。
もう少し学術的に体系立てて学びたいという方には以下の書籍などがお勧めです。ぜひご一読下さい。
また動かせるだけから理解して応用できるエンジニアの足掛かりに下記のUdemyなどもお勧めです。
スポンサーリンク
参考文献
1.
論文 - Text2Light: Zero-Shot Text-Driven HDR Panorama Generation
2. GitHub - FrozenBurning/Text2Light
0 件のコメント :
コメントを投稿