目次
はじめに
どーも、可視化研究室でデータサイエンスを学んでいるゆうき(@engieerblog_Yu)です。
今回は研究の方で、depth情報を使って、等値面を重ね合わせる機会があったのでログを残しておきます。
あまり綺麗なコードではないのでご了承いただきたいです。
二つの等値面を重ね合わせるアルゴリズム
等値面Aと等値面Bを重ね合わせるアルゴリズムです。
AとBの各ピクセルで、depthの値が小さい方を採用して行きます。
Image A;
Image B;
Image C; // C = A + B
for ( 0 <= j < height )
{
for ( 0 <= i < width )
{
if ( A.depth(i,j) <= B.depth(i,j) )
{
C.r(i,j) = A.r(i,j);
C.g(i,j) = A.g(i,j);
C.b(i,j) = A.b(i,j);
C.depth(i,j) = A.depth(i,j);
}
else
{
C.r(i,j) = B.r(i,j);
C.g(i,j) = B.g(i,j);
C.b(i,j) = B.b(i,j);
C.depth(i,j) = B.depth(i,j);
}
}
}
上記のアルゴリズムをGANの出力テンソルに適用するために、かなり改変したのが以下の関数になります。
depth・rgbテンソルを与えて等値面を重ね合わせる関数
関数は、等値面A,Bの二つの等値面を重ね合わせる関数になります。
depthのテンソルの形は[1,1,256,256]、rgbテンソルは[1,3,256,256]になっています。
引数にAとBのdepthテンソルとrgbテンソルを渡しています。
def image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H):
image_array = []
depth_array = []
for i in range(W):
for j in range(H):
if depth_a[0][0][i][j] <= depth_b[0][0][i][j]:
depth_array.append(depth_a[0][0][i][j])
for n in range(3):
image_array.append(rgb_a[0][n][i][j])
else:
depth_array.append(depth_b[0][0][i][j])
for n in range(3):
image_array.append(rgb_b[0][n][i][j])
k = 0
image = np.zeros((1,3,W,H))
for w in range(W):
for h in range(H):
for i in range(3):
image[0][i][w][h] = image_array[k]
k += 1
depth = np.zeros((1,1,W,H))
l = 0
for w in range(W):
for h in range(H):
depth[0][0][w][h] = depth_array[k]
l += 1
depth_tensor = torch.tensor(depth,dtype=torch.float32)
image_tensor = torch.tensor(image,dtype=torch.float32)
return depth_tensor,image_tensor
モデルの読み込みとdepth,RGB画像の生成
すでに学習済みのGANモデルがあるのでそちらを使っています。
GANの出力は[1,4,256,256]です。
4というのはr,g,b,depthを足した4になっています。
def generate_image(x,y,z,args):
# model
def weights_init(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.orthogonal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
def add_sn(m):
for name, c in m.named_children():
m.add_module(name, add_sn(c))
if isinstance(m, (nn.Linear, nn.Conv2d)):
return nn.utils.spectral_norm(m, eps=1e-4)
else:
return m
g_model = Generator(ch=64)
g_model.apply(weights_init)
g_model = add_sn(g_model)
order_dict = torch.load(os.path.join(args.root, "model_" + str(False) + "_" + "relu1_2" + "_" + "none" + "_" + str(480) + ".pth"),map_location=torch.device('cpu'))
g_model.load_state_dict(order_dict)
g_model.eval()
sparams = torch.tensor([[1]],dtype=torch.float32)
vops_a = torch.tensor([[100001.1]],dtype=torch.float32)
vops_b = torch.tensor([[100005.3]],dtype=torch.float32)
vops_c = torch.tensor([[100013.7]],dtype=torch.float32)
vparams = torch.tensor([[x,y,z]],dtype=torch.float32)
fake_image_a = g_model(sparams, vops_a, vparams)
fake_image_b = g_model(sparams, vops_b, vparams)
fake_image_c = g_model(sparams, vops_c, vparams)
# RGB画像の出力
rgb_a = fake_image_a[:,0:3,:,:]
rgb_b = fake_image_b[:,0:3,:,:]
rgb_c = fake_image_c[:,0:3,:,:]
# depth画像の出力
depth_a = fake_image_a[:,3:4,:,:]
depth_b = fake_image_b[:,3:4,:,:]
depth_c = fake_image_c[:,3:4,:,:]
W = rgb_a.shape[2]
H = rgb_a.shape[3]
if args.A and args.B and args.C:
merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)
merge_depth,merge_rgb = image_merge(merge_depth,depth_c,merge_rgb,rgb_c,W,H)
elif args.A and args.B and not args.C:
merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)
elif args.A and not args.B and args.C:
merge_depth,merge_rgb = image_merge(depth_a,depth_c,rgb_a,rgb_c,W,H)
elif not args.A and args.B and args.C:
merge_depth,merge_rgb = image_merge(depth_b,depth_c,rgb_b,rgb_c,W,H)
elif args.A and not args.B and not args.C:
merge_depth,merge_rgb = depth_a,rgb_a
elif not args.A and args.B and not args.C:
merge_depth,merge_rgb = depth_b,rgb_b
elif not args.A and not args.B and args.C:
merge_depth,merge_rgb = depth_c,rgb_c
return merge_rgb
生成しているdepthテンソルを画像にすると、こんな感じです。
main関数の実行
def main(args):
x = 12*np.sin(np.deg2rad(args.theta))*np.sin(np.deg2rad(args.phi))
y = 12*np.cos(np.deg2rad(args.theta))
z = 12*np.sin(np.deg2rad(args.theta)) * np.cos(np.deg2rad(args.phi))
fake_image = generate_image(x,y,z,args)
save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)
全体のコード
全体のコードです。
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import sys
sys.path.append("../")
from generator import Generator
# parse arguments
def parse_args():
parser = argparse.ArgumentParser(description="InSituNet")
parser.add_argument("--theta", type=float, default=False)
parser.add_argument("--phi", type=float, default=False)
parser.add_argument("--A", type=bool, default=False)
parser.add_argument("--B", type=bool, default=False)
parser.add_argument("--C", type=bool, default=False)
parser.add_argument("--root", required=True, type=str,
help="root of the dataset")
return parser.parse_args()
def image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H):
image_array = []
depth_array = []
for i in range(W):
for j in range(H):
if depth_a[0][0][i][j] <= depth_b[0][0][i][j]:
depth_array.append(depth_a[0][0][i][j])
for n in range(3):
image_array.append(rgb_a[0][n][i][j])
# image_array.append(rgb_b[0][n][i][j].item())
# image_array.append(-1)
elif depth_a[0][0][i][j] > depth_b[0][0][i][j]:
depth_array.append(depth_b[0][0][i][j])
for n in range(3):
image_array.append(rgb_b[0][n][i][j])
# image_array.append(0)
k = 0
image = np.zeros((1,3,W,H))
for w in range(W):
for h in range(H):
for i in range(3):
image[0][i][w][h] = image_array[k]
k += 1
depth = np.zeros((1,1,W,H))
l = 0
for w in range(W):
for h in range(H):
depth[0][0][w][h] = depth_array[k]
l += 1
depth_tensor = torch.tensor(depth,dtype=torch.float32)
image_tensor = torch.tensor(image,dtype=torch.float32)
return depth_tensor,image_tensor
def generate_image(x,y,z,args):
# model
def weights_init(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.orthogonal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
def add_sn(m):
for name, c in m.named_children():
m.add_module(name, add_sn(c))
if isinstance(m, (nn.Linear, nn.Conv2d)):
return nn.utils.spectral_norm(m, eps=1e-4)
else:
return m
g_model = Generator(ch=64)
g_model.apply(weights_init)
g_model = add_sn(g_model)
order_dict = torch.load(os.path.join(args.root, "model_" + str(False) + "_" + "relu1_2" + "_" + "none" + "_" + str(480) + ".pth"),map_location=torch.device('cpu'))
g_model.load_state_dict(order_dict)
g_model.eval()
sparams = torch.tensor([[1]],dtype=torch.float32)
vops_a = torch.tensor([[100001.1,0,0]],dtype=torch.float32)
vops_b = torch.tensor([[0,100005.3,0]],dtype=torch.float32)
vops_c = torch.tensor([[0,0,100013.7]],dtype=torch.float32)
vparams = torch.tensor([[x,y,z]],dtype=torch.float32)
fake_image_a = g_model(sparams, vops_a, vparams)
fake_image_b = g_model(sparams, vops_b, vparams)
fake_image_c = g_model(sparams, vops_c, vparams)
# RGB画像の出力
rgb_a = fake_image_a[:,0:3,:,:]
rgb_b = fake_image_b[:,0:3,:,:]
rgb_c = fake_image_c[:,0:3,:,:]
# depth画像の出力
depth_a = fake_image_a[:,3:4,:,:]
depth_b = fake_image_b[:,3:4,:,:]
depth_c = fake_image_c[:,3:4,:,:]
W = rgb_a.shape[2]
H = rgb_a.shape[3]
if args.A and args.B and args.C:
merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)
merge_depth,merge_rgb = image_merge(merge_depth,depth_c,merge_rgb,rgb_c,W,H)
elif args.A and args.B and not args.C:
merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)
elif args.A and not args.B and args.C:
merge_depth,merge_rgb = image_merge(depth_a,depth_c,rgb_a,rgb_c,W,H)
elif not args.A and args.B and args.C:
merge_depth,merge_rgb = image_merge(depth_b,depth_c,rgb_b,rgb_c,W,H)
elif args.A and not args.B and not args.C:
merge_depth,merge_rgb = depth_a,rgb_a
elif not args.A and args.B and not args.C:
merge_depth,merge_rgb = depth_b,rgb_b
elif not args.A and not args.B and args.C:
merge_depth,merge_rgb = depth_c,rgb_c
return merge_rgb
def main(args):
x = 12*np.sin(np.deg2rad(args.theta))*np.sin(np.deg2rad(args.phi))
y = 12*np.cos(np.deg2rad(args.theta))
z = 12*np.sin(np.deg2rad(args.theta)) * np.cos(np.deg2rad(args.phi))
fake_image = generate_image(x,y,z,args)
save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)
if __name__ == "__main__":
main(parse_args())
終わりに
今回はdepth情報を使って、等値面を重ね合わせるコードを紹介しました。
かなり大変だったので、参考にしていただけると幸いです。
ゆうき
最後まで読んでいただきありがとうございました。
コメント