Skip to content

Commit

Permalink
compute_train_stats: Fix logits passed in as proba (HumanCompatibleAI…
Browse files Browse the repository at this point in the history
…#273)

Led to an error when I was training.
  • Loading branch information
shwang committed Mar 14, 2021
1 parent 1c1bf24 commit e7e87b2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/imitation/rewards/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def compute_train_stats(
_n_gen_or_1 = max(1, n_generated)
generated_acc = _n_pred_gen / float(_n_gen_or_1)

label_dist = th.distributions.Bernoulli(disc_logits_gen_is_high)
label_dist = th.distributions.Bernoulli(logits=disc_logits_gen_is_high)
entropy = th.mean(label_dist.entropy())

pairs = [
Expand Down

0 comments on commit e7e87b2

Please sign in to comment.