【GAN】depth画像を用いた等値面の重ね合わせ(PyTorch)

プログラミング
スポンサーリンク

スポンサーリンク

はじめに

どーも、可視化研究室でデータサイエンスを学んでいるゆうき(@engieerblog_Yu)です。

今回は研究の方で、depth情報を使って、等値面を重ね合わせる機会があったのでログを残しておきます。

あまり綺麗なコードではないのでご了承いただきたいです。

二つの等値面を重ね合わせるアルゴリズム

等値面Aと等値面Bを重ね合わせるアルゴリズムです。

AとBの各ピクセルで、depthの値が小さい方を採用して行きます。

Image A;
Image B;
Image C; // C = A + B

for ( 0 <= j < height )
{
    for ( 0 <= i < width )
    {
        if ( A.depth(i,j) <= B.depth(i,j) )
        {
            C.r(i,j) = A.r(i,j);
            C.g(i,j) = A.g(i,j);
            C.b(i,j) = A.b(i,j);
            C.depth(i,j) = A.depth(i,j);
        }
        else
        {
            C.r(i,j) = B.r(i,j);
            C.g(i,j) = B.g(i,j);
            C.b(i,j) = B.b(i,j);
            C.depth(i,j) = B.depth(i,j);
        }
    }
}

上記のアルゴリズムをGANの出力テンソルに適用するために、かなり改変したのが以下の関数になります。

depth・rgbテンソルを与えて等値面を重ね合わせる関数

関数は、等値面A,Bの二つの等値面を重ね合わせる関数になります。

depthのテンソルの形は[1,1,256,256]、rgbテンソルは[1,3,256,256]になっています。

引数にAとBのdepthテンソルとrgbテンソルを渡しています。

def image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H):
  image_array = []
  depth_array = []

  for i in range(W):
    for j in range(H):
      if depth_a[0][0][i][j] <= depth_b[0][0][i][j]:
        depth_array.append(depth_a[0][0][i][j])
        for n in range(3):
          image_array.append(rgb_a[0][n][i][j])
         
      else:
        depth_array.append(depth_b[0][0][i][j])
        for n in range(3):
          image_array.append(rgb_b[0][n][i][j])
  

  k = 0
  image = np.zeros((1,3,W,H))
  for w in range(W):
    for h in range(H):
      for i in range(3):
        image[0][i][w][h] = image_array[k]
        k += 1
  depth = np.zeros((1,1,W,H))
  l = 0
  for w in range(W):
    for h in range(H):
        depth[0][0][w][h] = depth_array[k]
        l += 1
          
  depth_tensor = torch.tensor(depth,dtype=torch.float32)
  image_tensor = torch.tensor(image,dtype=torch.float32)
  return depth_tensor,image_tensor

モデルの読み込みとdepth,RGB画像の生成

すでに学習済みのGANモデルがあるのでそちらを使っています。

GANの出力は[1,4,256,256]です。

4というのはr,g,b,depthを足した4になっています。

def generate_image(x,y,z,args):
  # model
  def weights_init(m):
    if isinstance(m, nn.Linear):
      nn.init.orthogonal_(m.weight)
      if m.bias is not None:
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
      nn.init.orthogonal_(m.weight)
      if m.bias is not None:
        nn.init.zeros_(m.bias)

  def add_sn(m):
    for name, c in m.named_children():
      m.add_module(name, add_sn(c))
    if isinstance(m, (nn.Linear, nn.Conv2d)):
      return nn.utils.spectral_norm(m, eps=1e-4)
    else:
      return m

  g_model = Generator(ch=64)

  g_model.apply(weights_init)

  g_model = add_sn(g_model)


  order_dict = torch.load(os.path.join(args.root, "model_" + str(False) + "_" + "relu1_2" + "_" + "none" + "_" + str(480) + ".pth"),map_location=torch.device('cpu'))

  g_model.load_state_dict(order_dict)

  g_model.eval()

  sparams = torch.tensor([[1]],dtype=torch.float32)
  vops_a = torch.tensor([[100001.1]],dtype=torch.float32)
  vops_b = torch.tensor([[100005.3]],dtype=torch.float32)
  vops_c = torch.tensor([[100013.7]],dtype=torch.float32)
  vparams = torch.tensor([[x,y,z]],dtype=torch.float32)
  fake_image_a = g_model(sparams, vops_a, vparams)
  fake_image_b = g_model(sparams, vops_b, vparams)
  fake_image_c = g_model(sparams, vops_c, vparams)

  # RGB画像の出力
  rgb_a = fake_image_a[:,0:3,:,:]
  rgb_b = fake_image_b[:,0:3,:,:]
  rgb_c = fake_image_c[:,0:3,:,:]


  # depth画像の出力
  depth_a = fake_image_a[:,3:4,:,:]
  depth_b = fake_image_b[:,3:4,:,:]
  depth_c = fake_image_c[:,3:4,:,:]

  W = rgb_a.shape[2]
  H = rgb_a.shape[3]

  if args.A and args.B and args.C:
    merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)
    merge_depth,merge_rgb = image_merge(merge_depth,depth_c,merge_rgb,rgb_c,W,H)

  elif args.A and args.B and not args.C:
    merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)

  elif args.A and not args.B and args.C:
    merge_depth,merge_rgb = image_merge(depth_a,depth_c,rgb_a,rgb_c,W,H)

  elif not args.A and args.B and args.C:
    merge_depth,merge_rgb = image_merge(depth_b,depth_c,rgb_b,rgb_c,W,H)

  elif args.A and not args.B and not args.C:
    merge_depth,merge_rgb = depth_a,rgb_a

  elif not args.A and args.B and not args.C:
    merge_depth,merge_rgb = depth_b,rgb_b

  elif not args.A and not args.B and args.C:
    merge_depth,merge_rgb = depth_c,rgb_c  


  return merge_rgb

生成しているdepthテンソルを画像にすると、こんな感じです。

main関数の実行

def main(args):
    x = 12*np.sin(np.deg2rad(args.theta))*np.sin(np.deg2rad(args.phi))
    y = 12*np.cos(np.deg2rad(args.theta))
    z = 12*np.sin(np.deg2rad(args.theta)) * np.cos(np.deg2rad(args.phi))
    fake_image = generate_image(x,y,z,args)
    save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)

全体のコード

全体のコードです。

import os
import argparse

import numpy as np

import torch
import torch.nn as nn
from torchvision.utils import save_image

import sys
sys.path.append("../")

from generator import Generator


# parse arguments
def parse_args():
  parser = argparse.ArgumentParser(description="InSituNet")

  parser.add_argument("--theta", type=float, default=False)

  parser.add_argument("--phi", type=float, default=False)

  parser.add_argument("--A", type=bool, default=False)

  parser.add_argument("--B", type=bool, default=False)

  parser.add_argument("--C", type=bool, default=False)

  parser.add_argument("--root", required=True, type=str,
                      help="root of the dataset")
  return parser.parse_args()

def image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H):
  image_array = []
  depth_array = []

  for i in range(W):
    for j in range(H):
      if depth_a[0][0][i][j] <= depth_b[0][0][i][j]:
        depth_array.append(depth_a[0][0][i][j])
        for n in range(3):
          image_array.append(rgb_a[0][n][i][j])
          # image_array.append(rgb_b[0][n][i][j].item())
          # image_array.append(-1)
      elif depth_a[0][0][i][j] > depth_b[0][0][i][j]:
        depth_array.append(depth_b[0][0][i][j])
        for n in range(3):
          image_array.append(rgb_b[0][n][i][j])
          # image_array.append(0)

  k = 0
  image = np.zeros((1,3,W,H))
  for w in range(W):
    for h in range(H):
      for i in range(3):
        image[0][i][w][h] = image_array[k]
        k += 1
  depth = np.zeros((1,1,W,H))
  l = 0
  for w in range(W):
    for h in range(H):
        depth[0][0][w][h] = depth_array[k]
        l += 1
          
  depth_tensor = torch.tensor(depth,dtype=torch.float32)
  image_tensor = torch.tensor(image,dtype=torch.float32)
  return depth_tensor,image_tensor



def generate_image(x,y,z,args):
  # model
  def weights_init(m):
    if isinstance(m, nn.Linear):
      nn.init.orthogonal_(m.weight)
      if m.bias is not None:
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
      nn.init.orthogonal_(m.weight)
      if m.bias is not None:
        nn.init.zeros_(m.bias)

  def add_sn(m):
    for name, c in m.named_children():
      m.add_module(name, add_sn(c))
    if isinstance(m, (nn.Linear, nn.Conv2d)):
      return nn.utils.spectral_norm(m, eps=1e-4)
    else:
      return m

  g_model = Generator(ch=64)

  g_model.apply(weights_init)

  g_model = add_sn(g_model)


  order_dict = torch.load(os.path.join(args.root, "model_" + str(False) + "_" + "relu1_2" + "_" + "none" + "_" + str(480) + ".pth"),map_location=torch.device('cpu'))


  g_model.load_state_dict(order_dict)

  g_model.eval()

  sparams = torch.tensor([[1]],dtype=torch.float32)
  vops_a = torch.tensor([[100001.1,0,0]],dtype=torch.float32)
  vops_b = torch.tensor([[0,100005.3,0]],dtype=torch.float32)
  vops_c = torch.tensor([[0,0,100013.7]],dtype=torch.float32)
  vparams = torch.tensor([[x,y,z]],dtype=torch.float32)
  fake_image_a = g_model(sparams, vops_a, vparams)
  fake_image_b = g_model(sparams, vops_b, vparams)
  fake_image_c = g_model(sparams, vops_c, vparams)

  # RGB画像の出力
  rgb_a = fake_image_a[:,0:3,:,:]
  rgb_b = fake_image_b[:,0:3,:,:]
  rgb_c = fake_image_c[:,0:3,:,:]


  # depth画像の出力
  depth_a = fake_image_a[:,3:4,:,:]
  depth_b = fake_image_b[:,3:4,:,:]
  depth_c = fake_image_c[:,3:4,:,:]

  W = rgb_a.shape[2]
  H = rgb_a.shape[3]

  if args.A and args.B and args.C:
    merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)
    merge_depth,merge_rgb = image_merge(merge_depth,depth_c,merge_rgb,rgb_c,W,H)

  elif args.A and args.B and not args.C:
    merge_depth,merge_rgb = image_merge(depth_a,depth_b,rgb_a,rgb_b,W,H)

  elif args.A and not args.B and args.C:
    merge_depth,merge_rgb = image_merge(depth_a,depth_c,rgb_a,rgb_c,W,H)

  elif not args.A and args.B and args.C:
    merge_depth,merge_rgb = image_merge(depth_b,depth_c,rgb_b,rgb_c,W,H)

  elif args.A and not args.B and not args.C:
    merge_depth,merge_rgb = depth_a,rgb_a

  elif not args.A and args.B and not args.C:
    merge_depth,merge_rgb = depth_b,rgb_b

  elif not args.A and not args.B and args.C:
    merge_depth,merge_rgb = depth_c,rgb_c  


  return merge_rgb


def main(args):
    x = 12*np.sin(np.deg2rad(args.theta))*np.sin(np.deg2rad(args.phi))
    y = 12*np.cos(np.deg2rad(args.theta))
    z = 12*np.sin(np.deg2rad(args.theta)) * np.cos(np.deg2rad(args.phi))
    fake_image = generate_image(x,y,z,args)
    save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)

    

if __name__ == "__main__":
  main(parse_args())

終わりに

今回はdepth情報を使って、等値面を重ね合わせるコードを紹介しました。

かなり大変だったので、参考にしていただけると幸いです。

ゆうき
ゆうき

最後まで読んでいただきありがとうございました。

コメント

タイトルとURLをコピーしました