diff --git a/segmatch/python/autoencoder/model.py b/segmatch/python/autoencoder/model.py index 349fc93..d01b830 100644 --- a/segmatch/python/autoencoder/model.py +++ b/segmatch/python/autoencoder/model.py @@ -409,7 +409,8 @@ def train_on_single_batch(self, batch_input, train_target=None, adversarial=Fals cost = [self.cost] if adversarial: cost = cost + [self.generator_loss_no_MI, self.discriminator_loss_no_MI, self.mutual_information_est] opt = train_target if train_target is not None else self.optimizer - if opt is self.discriminator_optimizer or opt is self.generator_optimizer: dict_[self.stop_gradient_placeholder] = True + if adversarial: + if opt is self.discriminator_optimizer or opt is self.generator_optimizer: dict_[self.stop_gradient_placeholder] = True # compute cost, _, _, summary = self.sess.run((cost, opt, self.catch_nans, self.merged), feed_dict=dict_) if summary_writer is not None: summary_writer.add_summary(summary) diff --git a/segmatch/python/autoencoder_node.ipynb b/segmatch/python/autoencoder_node.ipynb index ef090ea..8d0316f 100644 --- a/segmatch/python/autoencoder_node.ipynb +++ b/segmatch/python/autoencoder_node.ipynb @@ -508,11 +508,15 @@ " step_times = {'batchmaking': zero, 'training': zero, 'plotting': zero}\n", " avg_step_cost = Average()\n", " training_batchmaker = Batchmaker(train_vox, BATCH_SIZE, MP)\n", - " train_order = 4*[vae.optimizer] + 4*[vae.generator_optimizer] + [vae.discriminator_optimizer]\n", + " if ADVERSARIAL:\n", + " train_order = 4*[vae.optimizer] + 4*[vae.generator_optimizer] + [vae.discriminator_optimizer]\n", + " else:\n", + " train_order = [vae.optimizer]\n", " for train_target in itertools.cycle(train_order):\n", - " if train_target is vae.discriminator_optimizer:\n", - " if cost_value[1] > G_THRESHOLD or cost_value[2] < D_THRESHOLD:\n", - " continue\n", + " if ADVERSARIAL:\n", + " if train_target is vae.discriminator_optimizer:\n", + " if cost_value[1] > G_THRESHOLD or cost_value[2] < D_THRESHOLD:\n", + " continue\n", " if training_batchmaker.is_depleted():\n", " break\n", " else:\n",