Skip to content

Commit

Permalink
Refactor: Remove cv2 depedency
Browse files Browse the repository at this point in the history
  • Loading branch information
bryandlee committed Nov 29, 2021
1 parent f11e056 commit b5bc801
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import os
import argparse

import torch
import cv2
from PIL import Image
import numpy as np
import os

import torch
from torchvision.transforms.functional import to_tensor, to_pil_image

from model import Generator


torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



def load_image(image_path, x32=False):
img = cv2.imread(image_path).astype(np.float32)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w = img.shape[:2]
img = Image.open(image_path).convert("RGB")

if x32: # resize image to multiple of 32s
if x32:
def to_32s(x):
return 256 if x < 256 else x - x%32
img = cv2.resize(img, (to_32s(w), to_32s(h)))
return 256 if x < 256 else x - x % 32
w, h = img.size
img = img.resize((to_32s(w), to_32s(h)))

img = torch.from_numpy(img)
img = img/127.5 - 1.0
return img


Expand All @@ -43,22 +44,22 @@ def test(args):
image = load_image(os.path.join(args.input_dir, image_name), args.x32)

with torch.no_grad():
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy()
out = (out + 1)*127.5
out = np.clip(out, 0, 255).astype(np.uint8)
cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
image = to_tensor(image).unsqueeze(0) * 2 - 1
out = net(image.to(device), args.upsample_align).cpu()
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
out = to_pil_image(out)

out.save(os.path.join(args.output_dir, image_name))
print(f"image saved: {image_name}")


if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument(
'--checkpoint',
type=str,
default='./pytorch_generator_Paprika.pt',
default='./weights/paprika.pt',
)
parser.add_argument(
'--input_dir',
Expand All @@ -79,12 +80,13 @@ def test(args):
'--upsample_align',
type=bool,
default=False,
help="Align corners in decoder upsampling layers"
)
parser.add_argument(
'--x32',
action="store_true",
help="Resize images to multiple of 32"
)
args = parser.parse_args()

test(args)

0 comments on commit b5bc801

Please sign in to comment.