はじめに
どーも、可視化研究室でデータサイエンスを学んでいるゆうき(@engieerblog_Yu)です。
今回は研究の方で、depth画像をグレースケール画像化する機会があったので、ログを残しておきます。
GANで出力されるテンソルについて
今回の学習済みのGANで出力されるテンソルは、[[r,g,b,depth]]になっています。
r,g,b,depthは、それぞれ256×256の二次テンソルになっています。
今回は、[depth,depth,depth]の三次テンソルをrgbのように画像化するイメージです。
該当のコード
該当するコードは以下です。
def generate_image(rgb,W,H):
k = 0
image = np.zeros((3,W,H))
for w in range(W):
for h in range(H):
for i in range(3):
image[i][w][h] = rgb[k]
k += 1
return image
# 学習済みのGANを使う(細かい部分は省略)
fake_image = g_model(sparams, vops, vparams)
# RGB画像の出力(今回は出力しない)
# fake_image = fake_image[:,0:3,:,:]
# グレースケール画像の出力
fake_image = fake_image[:,3:4,:,:]
fake_image = torch.cat([fake_image, fake_image,fake_image],dim=1)
W = fake_image.shape[2]
H = fake_image.shape[3]
array = []
for i in range(W):
for j in range(H):
for n in range(3):
array.append(fake_image[0][n][i][j].item())
tuple = (array,W,H)
fake_image = generate_image(tuple[0],tuple[1],tuple[2])
fake_image = torch.from_numpy(fake_image.astype(np.float32)).clone()
save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)
以下のコードのsave_imageでは、各ピクセル値を0~1にする必要があります。
(torchvisionの仕様で自動的に255をかけてくれます。)
現在のfake_imageは-1~1のピクセル値に正規化されているので、0~1に変換しています。
save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)
プログラムによって画像の画素値は違うと思いますが、torchvisionを使う場合は、0~1に変換することが必要です。
こんな感じでdepth画像を出力できました。

全体のコード
一応全体のコードも載せておきます。
やっていることは簡単で、学習済みGANモデルを読み込んで、出力を画像にしています。
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import sys
sys.path.append("../")
from generator import Generator
# parse arguments
def parse_args():
parser = argparse.ArgumentParser(description="InSituNet")
parser.add_argument("--x", type=float, default=False)
parser.add_argument("--y", type=float, default=False)
parser.add_argument("--z", type=float, default=False)
parser.add_argument("--theta", type=float, default=False)
parser.add_argument("--phi", type=float, default=False)
parser.add_argument("--root", required=True, type=str,
help="root of the dataset")
return parser.parse_args()
def generate_imageTuple(x,y,z,args):
# model
def weights_init(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.orthogonal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
def add_sn(m):
for name, c in m.named_children():
m.add_module(name, add_sn(c))
if isinstance(m, (nn.Linear, nn.Conv2d)):
return nn.utils.spectral_norm(m, eps=1e-4)
else:
return m
g_model = Generator(ch=64)
g_model.apply(weights_init)
g_model = add_sn(g_model)
order_dict = torch.load(os.path.join(args.root, "model_" + str(False) + "_" + "relu1_2" + "_" + "none" + "_" + str(520) + ".pth"),map_location=torch.device('cpu'))
g_model.load_state_dict(order_dict)
g_model.eval()
sparams = torch.tensor([[1]],dtype=torch.float32)
vops = torch.tensor([[100001.1,0,0]],dtype=torch.float32)
# vops = torch.tensor([[100001.1,100005.3,100013.7]],dtype=torch.float32)
vparams = torch.tensor([[x,y,z]],dtype=torch.float32)
fake_image = g_model(sparams, vops, vparams)
# RGB画像の出力
# fake_image = fake_image[:,0:3,:,:]
# グレースケール画像の出力
fake_image = fake_image[:,3:4,:,:]
fake_image = torch.cat([fake_image, fake_image,fake_image],dim=1)
W = fake_image.shape[2]
H = fake_image.shape[3]
array = []
for i in range(W):
for j in range(H):
for n in range(3):
array.append(fake_image[0][n][i][j].item())
tuple = (array,W,H)
return tuple
def generate_image(rgb,W,H):
k = 0
image = np.zeros((3,W,H))
for w in range(W):
for h in range(H):
for i in range(3):
image[i][w][h] = rgb[k]
k += 1
return image
def main(args):
x = 12*np.sin(np.deg2rad(args.theta))*np.sin(np.deg2rad(args.phi))
y = 12*np.cos(np.deg2rad(args.theta))
z = 12*np.sin(np.deg2rad(args.theta)) * np.cos(np.deg2rad(args.phi))
tuple = generate_imageTuple(x,y,z,args)
W = tuple[1]
H = tuple[2]
fake_image = generate_image(tuple[0],W,H)
fake_image = torch.from_numpy(fake_image.astype(np.float32)).clone()
save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)
if __name__ == "__main__":
main(parse_args())
GANモデルを用いて何をやっているかは、以下の記事でまとめてあります。
終わりに
今回はdepth画像をグレースケールで出力するコードを紹介しました。
皆さんのお役に立てれば何よりです。

ゆうき
最後まで読んでいただきありがとうございました。
コメント