【torchvision】depth画像をグレースケールで出力する方法

プログラミング
スポンサーリンク

スポンサーリンク

はじめに

どーも、可視化研究室でデータサイエンスを学んでいるゆうき(@engieerblog_Yu)です。

今回は研究の方で、depth画像をグレースケール画像化する機会があったので、ログを残しておきます。

GANで出力されるテンソルについて

今回の学習済みのGANで出力されるテンソルは、[[r,g,b,depth]]になっています。

r,g,b,depthは、それぞれ256×256の二次テンソルになっています。

今回は、[depth,depth,depth]の三次テンソルをrgbのように画像化するイメージです。

該当のコード

該当するコードは以下です。

def generate_image(rgb,W,H):
    k = 0
    image = np.zeros((3,W,H))
    for w in range(W):
        for h in range(H):
            for i in range(3):
                image[i][w][h] = rgb[k]
                k += 1
        
    return image

  # 学習済みのGANを使う(細かい部分は省略)
  fake_image = g_model(sparams, vops, vparams)

  # RGB画像の出力(今回は出力しない)
  # fake_image = fake_image[:,0:3,:,:]

  # グレースケール画像の出力
  fake_image = fake_image[:,3:4,:,:]
  fake_image = torch.cat([fake_image, fake_image,fake_image],dim=1)

  W = fake_image.shape[2]
  H = fake_image.shape[3]

  array = []
  for i in range(W):
    for j in range(H):
      for n in range(3):
        array.append(fake_image[0][n][i][j].item())

  tuple = (array,W,H)

  fake_image = generate_image(tuple[0],tuple[1],tuple[2])
  fake_image = torch.from_numpy(fake_image.astype(np.float32)).clone()
  save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)

以下のコードのsave_imageでは、各ピクセル値を0~1にする必要があります。

(torchvisionの仕様で自動的に255をかけてくれます。)

現在のfake_imageは-1~1のピクセル値に正規化されているので、0~1に変換しています。

 save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)

プログラムによって画像の画素値は違うと思いますが、torchvisionを使う場合は、0~1に変換することが必要です。

こんな感じでdepth画像を出力できました。

全体のコード

一応全体のコードも載せておきます。

やっていることは簡単で、学習済みGANモデルを読み込んで、出力を画像にしています。

import os
import argparse

import numpy as np

import torch
import torch.nn as nn
from torchvision.utils import save_image

import sys
sys.path.append("../")

from generator import Generator


# parse arguments
def parse_args():
  parser = argparse.ArgumentParser(description="InSituNet")

  parser.add_argument("--x", type=float, default=False)

  parser.add_argument("--y", type=float, default=False)

  parser.add_argument("--z", type=float, default=False)

  parser.add_argument("--theta", type=float, default=False)

  parser.add_argument("--phi", type=float, default=False)

  parser.add_argument("--root", required=True, type=str,
                      help="root of the dataset")
  return parser.parse_args()


def generate_imageTuple(x,y,z,args):
  # model
  def weights_init(m):
    if isinstance(m, nn.Linear):
      nn.init.orthogonal_(m.weight)
      if m.bias is not None:
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
      nn.init.orthogonal_(m.weight)
      if m.bias is not None:
        nn.init.zeros_(m.bias)

  def add_sn(m):
    for name, c in m.named_children():
      m.add_module(name, add_sn(c))
    if isinstance(m, (nn.Linear, nn.Conv2d)):
      return nn.utils.spectral_norm(m, eps=1e-4)
    else:
      return m

  g_model = Generator(ch=64)

  g_model.apply(weights_init)

  g_model = add_sn(g_model)


  order_dict = torch.load(os.path.join(args.root, "model_" + str(False) + "_" + "relu1_2" + "_" + "none" + "_" + str(520) + ".pth"),map_location=torch.device('cpu'))


  g_model.load_state_dict(order_dict)

  g_model.eval()

  sparams = torch.tensor([[1]],dtype=torch.float32)
  vops = torch.tensor([[100001.1,0,0]],dtype=torch.float32)
  # vops = torch.tensor([[100001.1,100005.3,100013.7]],dtype=torch.float32)
  vparams = torch.tensor([[x,y,z]],dtype=torch.float32)
  fake_image = g_model(sparams, vops, vparams)

  # RGB画像の出力
  # fake_image = fake_image[:,0:3,:,:]

  # グレースケール画像の出力
  fake_image = fake_image[:,3:4,:,:]
  fake_image = torch.cat([fake_image, fake_image,fake_image],dim=1)

  W = fake_image.shape[2]
  H = fake_image.shape[3]
  array = []
  for i in range(W):
    for j in range(H):
      for n in range(3):
        array.append(fake_image[0][n][i][j].item())

  tuple = (array,W,H)

  return tuple

def generate_image(rgb,W,H):
    k = 0
    image = np.zeros((3,W,H))
    for w in range(W):
        for h in range(H):
            for i in range(3):
                image[i][w][h] = rgb[k]
                k += 1
        
    return image


def main(args):
    x = 12*np.sin(np.deg2rad(args.theta))*np.sin(np.deg2rad(args.phi))
    y = 12*np.cos(np.deg2rad(args.theta))
    z = 12*np.sin(np.deg2rad(args.theta)) * np.cos(np.deg2rad(args.phi))
    tuple = generate_imageTuple(x,y,z,args)
    W = tuple[1]
    H = tuple[2]
    fake_image = generate_image(tuple[0],W,H)
    fake_image = torch.from_numpy(fake_image.astype(np.float32)).clone()
    save_image((((fake_image) + 1.) * .5),'image.png',nrow=1)
 
    

if __name__ == "__main__":
  main(parse_args())

GANモデルを用いて何をやっているかは、以下の記事でまとめてあります。

終わりに

今回はdepth画像をグレースケールで出力するコードを紹介しました。

皆さんのお役に立てれば何よりです。

ゆうき
ゆうき

最後まで読んでいただきありがとうございました。

コメント

タイトルとURLをコピーしました