【PyTorch】LPIPSをGPUで使えない時の対処法

プログラミング
スポンサーリンク

スポンサーリンク

はじめに

どーも、学生エンジニアのゆうき(@engineerblog_Yu)です。

研究でLPIPSを使う機会があったのですが、GPUでLPIPSを使った例がWebにあまり見当たらなくて苦戦したので、ログを残しておきます。

LPIPSをGPUで導入するまで

公式ドキュメントより以下のコードを実装しました。

import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

device = torch.device("cuda:0" if args.cuda else "cpu")
lpips_criterion = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

//中略(imageとfake_imageはtorch.tensor型の四次元配列)

print(lpips_criterion(image,fake_image)

すると以下のようなエラーが出てきました。

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

エラーを調べてみたところ、GPUに送られている変数同士ではなく、CPUとGPUで演算されているときに起こるエラーだということが分かりました。

lpips_criterionのところでエラーが起きているので、imageとfake_imageが正しくGPUに送れているのか調べてみました。

print(image.is_cuda)
print(fake_image.is_cuda)

するとどちらもGPUに正しく送れていることが分かりました。

True
True

次に試しにLPIPS自体に問題があるのではないかと考えて、LPIPSではなくMSEで試してみました。

import torch

device = torch.device("cuda:0" if args.cuda else "cpu")
mse_criterion = nn.MSELoss()

//中略(imageとfake_imageはtorch.tensor型の四次元配列)

print(mse_criterion(image,fake_image)

すると正しく値が表示されました。

LPIPSについて色々詳しく調べてみたところ、LPIPSは学習済み画像分類ネットワークを用いていることがわかったので、LPIPSもGPUに送ってみました。

import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

device = torch.device("cuda:0" if args.cuda else "cpu")
lpips_criterion = LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(device)

//中略(imageとfake_imageはtorch.tensor型の四次元配列)

print(lpips_criterion(image,fake_image)

するとうまく実行することができました。

tensor(0.6221, device='cuda:0')

終わりに

今回は、久しぶりにエラー対処法についての記事を書いてみました。

LPIPSは、SSIMやPSNRより優れている画像類似度評価の指標として近年注目されているものです。

GANなどの生成モデルの評価指標として、これから使われていけば嬉しいですね。

ゆうき
ゆうき

最後まで読んでいただきありがとうございました。

コメント

タイトルとURLをコピーしました