Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support losses for DDPM-based models #142

Merged
merged 11 commits into from
Nov 23, 2021

Conversation

LeoXing1996
Copy link
Collaborator

We support two critical features in DDPMLoss:

  1. Support users to define how to collect log_vars by configs.
  2. Support rescales loss corresponding to timesteps.

mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
rescale_mode (str, optional): Mode of the loss rescale method.
Defaults to `''`.
rescale_cfg (dict, optional): Config of the loss rescale method.
log_cfgs (list[dict] | dict | optional): Configs to collect logs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, we do not prefer to widely use config-dict as the argument for these reasons:

  1. Not straightforward: Like this example, users cannot know which information should be defined in this log_cfgs by reading the docs. Besides, the config-dict is always related to another function, which cannot be easily found.
  2. Too complex: There are always some rules in designing the config-dict.
  3. Not Safe: Dict data is mutable and not safe.

Just a suggestion. Next time for designing new modules, try to organize the arguments and offer a clear and straightforward API.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each log_cfg may contain three keywords: type, prefix and reduction. To avoid using config-dict as input, we must set log_prefix and log_reduction as inputs in the initialize function.

Copy link
Collaborator

@nbei nbei Nov 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, you may specify more here and give clear and straightforward guidance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More doc about the design of log collection functions is added.

mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
assert hasattr(self, log_collect_fn)
log_collect_fn = getattr(self, log_collect_fn)

log_cfg_.setdefault('prefix_name', 'loss')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether users will change this argument. It seems that users don't have any reason to change the prefix_name. Is it right? If it is right, we can directly remove the redundant codes. If not, you may specify how to use the prefix.

mmgen/models/losses/ddpm_loss.py Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Show resolved Hide resolved
mmgen/models/losses/ddpm_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/pixelwise_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/utils.py Outdated Show resolved Hide resolved
mmgen/models/losses/pixelwise_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/pixelwise_loss.py Outdated Show resolved Hide resolved
mmgen/models/losses/pixelwise_loss.py Show resolved Hide resolved
@nbei nbei merged commit d4495ec into open-mmlab:master Nov 23, 2021
LeoXing1996 added a commit that referenced this pull request Jul 16, 2022
* support all losses needed for DDPM

* fix typos

* add copyright

* solve divide error in quartile log collection for pt1.5

* add more docstrings

* revise known problems

* cvt default value of rescale_mode to None

* revise known problems

* fix unit test

* revise docstring for ddpm losses
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants