Skip to content

Commit

Permalink
Fix errors when non-adversarial training
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldugas committed Jan 16, 2017
1 parent cc2dae7 commit b7f0aaa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
3 changes: 2 additions & 1 deletion segmatch/python/autoencoder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def train_on_single_batch(self, batch_input, train_target=None, adversarial=Fals
cost = [self.cost]
if adversarial: cost = cost + [self.generator_loss_no_MI, self.discriminator_loss_no_MI, self.mutual_information_est]
opt = train_target if train_target is not None else self.optimizer
if opt is self.discriminator_optimizer or opt is self.generator_optimizer: dict_[self.stop_gradient_placeholder] = True
if adversarial:
if opt is self.discriminator_optimizer or opt is self.generator_optimizer: dict_[self.stop_gradient_placeholder] = True
# compute
cost, _, _, summary = self.sess.run((cost, opt, self.catch_nans, self.merged), feed_dict=dict_)
if summary_writer is not None: summary_writer.add_summary(summary)
Expand Down
12 changes: 8 additions & 4 deletions segmatch/python/autoencoder_node.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,15 @@
" step_times = {'batchmaking': zero, 'training': zero, 'plotting': zero}\n",
" avg_step_cost = Average()\n",
" training_batchmaker = Batchmaker(train_vox, BATCH_SIZE, MP)\n",
" train_order = 4*[vae.optimizer] + 4*[vae.generator_optimizer] + [vae.discriminator_optimizer]\n",
" if ADVERSARIAL:\n",
" train_order = 4*[vae.optimizer] + 4*[vae.generator_optimizer] + [vae.discriminator_optimizer]\n",
" else:\n",
" train_order = [vae.optimizer]\n",
" for train_target in itertools.cycle(train_order):\n",
" if train_target is vae.discriminator_optimizer:\n",
" if cost_value[1] > G_THRESHOLD or cost_value[2] < D_THRESHOLD:\n",
" continue\n",
" if ADVERSARIAL:\n",
" if train_target is vae.discriminator_optimizer:\n",
" if cost_value[1] > G_THRESHOLD or cost_value[2] < D_THRESHOLD:\n",
" continue\n",
" if training_batchmaker.is_depleted():\n",
" break\n",
" else:\n",
Expand Down

0 comments on commit b7f0aaa

Please sign in to comment.