From b5bc801b930760556f2a4f4a0b51b38bd5fc75c6 Mon Sep 17 00:00:00 2001 From: bryandlee Date: Tue, 30 Nov 2021 08:26:07 +0900 Subject: [PATCH] Refactor: Remove cv2 depedency --- test.py | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/test.py b/test.py index 7e6b60f..855b451 100644 --- a/test.py +++ b/test.py @@ -1,28 +1,29 @@ +import os import argparse -import torch -import cv2 +from PIL import Image import numpy as np -import os + +import torch +from torchvision.transforms.functional import to_tensor, to_pil_image from model import Generator + torch.backends.cudnn.enabled = False torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True - + + def load_image(image_path, x32=False): - img = cv2.imread(image_path).astype(np.float32) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - h, w = img.shape[:2] + img = Image.open(image_path).convert("RGB") - if x32: # resize image to multiple of 32s + if x32: def to_32s(x): - return 256 if x < 256 else x - x%32 - img = cv2.resize(img, (to_32s(w), to_32s(h))) + return 256 if x < 256 else x - x % 32 + w, h = img.size + img = img.resize((to_32s(w), to_32s(h))) - img = torch.from_numpy(img) - img = img/127.5 - 1.0 return img @@ -43,22 +44,22 @@ def test(args): image = load_image(os.path.join(args.input_dir, image_name), args.x32) with torch.no_grad(): - input = image.permute(2, 0, 1).unsqueeze(0).to(device) - out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy() - out = (out + 1)*127.5 - out = np.clip(out, 0, 255).astype(np.uint8) - - cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_BGR2RGB)) + image = to_tensor(image).unsqueeze(0) * 2 - 1 + out = net(image.to(device), args.upsample_align).cpu() + out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 + out = to_pil_image(out) + + out.save(os.path.join(args.output_dir, image_name)) print(f"image saved: {image_name}") - + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--checkpoint', type=str, - default='./pytorch_generator_Paprika.pt', + default='./weights/paprika.pt', ) parser.add_argument( '--input_dir', @@ -79,12 +80,13 @@ def test(args): '--upsample_align', type=bool, default=False, + help="Align corners in decoder upsampling layers" ) parser.add_argument( '--x32', action="store_true", + help="Resize images to multiple of 32" ) args = parser.parse_args() test(args) -