-
Notifications
You must be signed in to change notification settings - Fork 226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support KLD metric and support evaluation for probabilistic models #108
Conversation
Codecov Report
@@ Coverage Diff @@
## master #108 +/- ##
==========================================
- Coverage 76.06% 75.28% -0.79%
==========================================
Files 118 121 +3
Lines 8089 8188 +99
Branches 1519 1561 +42
==========================================
+ Hits 6153 6164 +11
- Misses 1546 1609 +63
- Partials 390 415 +25
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
… models (#108) * support KLD metric * add init under tests * solve import error * add manually type convert for pt<1.8 * remove an invalid element from evaluation. init * fix by comment
Metric Design
Different from gan metrics, probabilistic metrics:
Therefore, we design a list called
probabilistic_metric_name
to contain probabilistic metrics insingle_gpu_online_evaluation
. When evaluation with those metric separately.When evaluating those probabilistic metrics, we use set forward mode as reconstruction.
We default that all probabilistic model all support this mode. Unlike
forward_test
which starts with random noise,mode=reconstruction
performs a reconstruction behavior with given data and returns a dict containing desired probabilistic parameters.We also slightly modify the batch truncation operation to make the
Metric.feed
support dict input.Some further design about
mode=reconstruction
Although we have not implemented any code, we find a specific function to release reconstruction operation is critical for probabilistic models (e.g., DDPM).
In
train_step
, reconstruction, loss calculation, and update operation are performed altogether.For
forward_test
, the interface is fixed and called bysample_from_noise
to perform a random generation process.Therefore, we need a function to implement a separate reconstruction process and return all the intermedia probabilistic parameters. And this function can also be called by
train_step
.