From 29fb5883e672fc01c6d84762b5a6a1b12796f54e Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Mon, 17 Jun 2024 07:52:29 +0200 Subject: [PATCH] Fix wrong softmax projection on probs/logits (John Leland) --- simple_einet/layers/distributions/abstract_leaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simple_einet/layers/distributions/abstract_leaf.py b/simple_einet/layers/distributions/abstract_leaf.py index 1936d3b..0beae86 100644 --- a/simple_einet/layers/distributions/abstract_leaf.py +++ b/simple_einet/layers/distributions/abstract_leaf.py @@ -112,7 +112,7 @@ def dist_sample(distribution: dist.Distribution, ctx: SamplingContext = None) -> elif type(distribution) == CustomNormal: distribution = CustomNormal(mu=distribution.mu, sigma=distribution.sigma / ctx.temperature_leaves) elif type(distribution) == dist.Categorical: - distribution = dist.Categorical(logits=F.log_softmax(distribution.probs / ctx.temperature_leaves)) + distribution = dist.Categorical(logits=F.log_softmax(distribution.logits / ctx.temperature_leaves)) samples = distribution.sample(sample_shape=(ctx.num_samples,)).float() assert (