Skip to content

Commit

Permalink
Default diff-sampling to gumbel softmax
Browse files Browse the repository at this point in the history
This make it consistent with the implementation used in https://proceedings.mlr.press/v181/lang22a/lang22a.pdf

See also: #5
  • Loading branch information
braun-steven committed Nov 8, 2023
1 parent 179cc72 commit 29680f3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion simple_einet/sampling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def sample_categorical_differentiably(
tau: float,
logits: torch.Tensor = None,
log_weights: torch.Tensor = None,
method=DiffSampleMethod.SIMPLE,
method=DiffSampleMethod.GUMBEL,
) -> torch.Tensor:
"""
Perform differentiable sampling/mpe on the given input along a specific dimension.
Expand Down

0 comments on commit 29680f3

Please sign in to comment.