
Has anyone succeeded in reproducing the results? #44

Open
nashory opened this issue Apr 26, 2021 · 4 comments

Comments


nashory commented Apr 26, 2021

I am still struggling with training the VQ-GAN in the first stage; I haven't even gotten to the conditional transformer in the second stage.
The results look fine before the discriminator loss is injected, but once the discriminator loss is used, the reconstructed images are suddenly ruined. disc_loss remains at 1.0 throughout training. Why?
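For context, with the default hinge loss a disc_loss pinned at exactly 1.0 means the discriminator's logits are sitting near zero, i.e. it has not yet learned to separate real from reconstructed images. A minimal sketch of the hinge discriminator loss, following what taming-transformers implements in vqperceptual.py:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # 0.5 * (E[relu(1 - D(real))] + E[relu(1 + D(fake))])
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

# Untrained logits hover around 0, which gives 0.5 * (1 + 1) = 1.0:
print(hinge_d_loss(torch.zeros(8), torch.zeros(8)))  # tensor(1.)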

@richcmwang

The loss goes up with the default parameters/setup very early in training, even before the discriminator kicks in. I am puzzled by this.


mhh0318 commented May 10, 2021

Same here: using the discriminator loss makes my results worse.

@nicolasfischoeder

Once the discriminator kicks in, everything goes crazy. Does anybody have any advice?

@function2-llx

> Once the discriminator kicks in, everything goes crazy. Does anybody have any advice?

Perhaps it is because the discriminator is not trained at all until g_loss is introduced to the total loss. According to the code, both the training of the discriminator and the introduction of g_loss start at the self.discriminator_iter_start step. Perhaps we can try training the discriminator from the very beginning.

# generator update: g_loss only enters the total loss once global_step >= discriminator_iter_start
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

# discriminator update: gated by the same threshold, so the discriminator receives no gradient before that step either
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
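For reference, adopt_weight is just a step gate that returns 0 until global_step reaches the threshold, roughly:

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # return `value` (0 by default) until global_step reaches `threshold`
    if global_step < threshold:
        weight = value
    return weight

A minimal sketch of the suggestion above, assuming we keep the delayed g_loss in the generator objective but let the discriminator train from step 0 (g_loss_iter_start is a hypothetical new hyperparameter, not in the repo):

# discriminator branch: threshold=0, so it trains from the very beginning
disc_factor_d = adopt_weight(self.disc_factor, global_step, threshold=0)
d_loss = disc_factor_d * self.disc_loss(logits_real, logits_fake)

# generator branch: g_loss still held out until the (hypothetical) later step
disc_factor_g = adopt_weight(self.disc_factor, global_step, threshold=self.g_loss_iter_start)
loss = nll_loss + d_weight * disc_factor_g * g_loss + self.codebook_weight * codebook_loss.mean()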
