Skip to content

Commit

Permalink
Change naming
Browse files Browse the repository at this point in the history
  • Loading branch information
HasnainRaz committed Oct 20, 2019
1 parent b06ea75 commit 72e2ccb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser
from dataloader import DataLoader
from model import MobileSRGAN
from model import FastSRGAN
import tensorflow as tf
import os

Expand Down Expand Up @@ -46,7 +46,7 @@ def pretrain_generator(model, dataset, writer):
"""
with writer.as_default():
iteration = 0
for epoch in range(20):
for epoch in range(2):
for x, y in dataset:
loss = pretrain_step(model, x, y)
if iteration % 20 == 0:
Expand All @@ -66,7 +66,7 @@ def train_step(model, x, y):
Returns:
d_loss: The mean loss of the discriminator.
"""

# Label smoothing for better gradient flow
valid = tf.ones((x.shape[0], 1)) - tf.random.uniform((x.shape[0], 1)) * 0.1
fake = tf.ones((x.shape[0], 1)) * tf.random.uniform((x.shape[0], 1)) * 0.1

Expand Down Expand Up @@ -144,7 +144,7 @@ def main():

with tf.device('GPU:1'):
# Initialize the GAN object.
gan = MobileSRGAN(args)
gan = FastSRGAN(args)

# Define the directory for saving pretrainig loss tensorboard summary.
pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain')
Expand Down
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import tensorflow as tf


class MobileSRGAN(object):
"""MobileNet SRGAN for fast super resolution."""
class FastSRGAN(object):
"""SRGAN for fast super resolution."""

def __init__(self, args):
"""
Expand Down Expand Up @@ -87,7 +87,7 @@ def build_vgg(self):

def build_generator(self):
"""Build the generator that will do the Super Resolution task.
Based on the mobilenet design taken from Galteri et al."""
Based on the Mobilenet design. Idea from Galteri et al."""

def _make_divisible(v, divisor, min_value=None):
if min_value is None:
Expand Down Expand Up @@ -204,7 +204,7 @@ def deconv2d(layer_input):
return keras.models.Model(img_lr, gen_hr)

def build_discriminator(self):
"""Builds a discriminator network based on the Patch-GAN design."""
"""Builds a discriminator network based on the SRGAN design."""

def d_block(layer_input, filters, strides=1, bn=True):
"""Discriminator layer block.
Expand Down

0 comments on commit 72e2ccb

Please sign in to comment.