Changed how checklist balance is used in action prediction #1204
Conversation
I like the simplification to the code; how does this actually affect performance?
```
the current hidden state, attended encoder input and the current checklist balance into the
action space. The size of the checklist balance vector is the same as the number of
terminals. This is not needed if we are training the parser using target action sequences.
use_coverage : ``bool``
```
It'd be nice to mention this is optional here, and give the default value.
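As a rough illustration of what that docstring describes, the checklist balance can enter the action query like this (the names, sizes, and single-`Linear` projection below are assumptions for the sketch, not the parser's actual code):

```python
import torch
import torch.nn as nn

group_size, hidden_dim, num_terminals, action_dim = 4, 10, 6, 8

hidden_state = torch.randn(group_size, hidden_dim)
attended_input = torch.randn(group_size, hidden_dim)
# One balance entry per terminal production, matching the docstring.
checklist_balance = torch.randn(group_size, num_terminals)

# A single projection over all three inputs into the action space.
projection = nn.Linear(2 * hidden_dim + num_terminals, action_dim)
action_query = torch.cat([hidden_state, attended_input, checklist_balance], dim=-1)
predicted_action_embedding = projection(action_query)
```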
```diff
@@ -132,10 +127,12 @@ def take_step(self,  # type: ignore
         # action_mask: (group_size, num_embedded_actions)
         action_embeddings, embedded_action_mask = self._get_action_embeddings(state,
                                                                               global_actions_to_embed)
-        action_query = self._get_action_query(state, hidden_state, attended_sentence)
+        action_query = torch.cat([hidden_state, attended_sentence], dim=-1)
         # (group_size, action_embedding_dim)
         predicted_action_embedding = self._output_projection_layer(action_query)
```
Not related to this PR, but you probably want a non-linearity here. See #1150, where I added this to the wikitables model. There's another spot where you don't have one but probably want one, too.
And, now that I think about it, because this is getting a dot product with action embeddings, I wonder if relu isn't the best non-linearity. We're basically saying that any embedding dimension with a negative value is entirely ignored, and thus artificially constraining the space that's available to the dot product... Maybe tanh would be better?
Did you find that adding relus worked better? In my other PR I added dropout and a relu here to match the wikitables parser, and that seemed to give slightly better results.
Added a non-linearity at the decoder input as well, and made both of them tanh.
I ran it with tanh instead of relu last night, and it didn't really change anything. It looked like it was learning faster - epoch 1 performance was a bit higher - but in the end, final performance was within the normal variance that I've seen across runs. So, it probably doesn't matter.
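For concreteness, here is a minimal sketch of the projection-plus-dot-product scoring this thread is about; the shapes, the `tanh` placement, and the final `bmm` scoring step are assumptions based on the snippet above, not the exact model code:

```python
import torch
import torch.nn as nn

group_size, encoder_dim, action_dim, num_actions = 4, 10, 8, 20

hidden_state = torch.randn(group_size, encoder_dim)
attended_sentence = torch.randn(group_size, encoder_dim)
action_embeddings = torch.randn(group_size, num_actions, action_dim)

output_projection_layer = nn.Linear(2 * encoder_dim, action_dim)

action_query = torch.cat([hidden_state, attended_sentence], dim=-1)
# tanh keeps negative embedding dimensions usable in the dot product
# below; relu would zero them out, constraining the space the dot
# product can cover.
predicted_action_embedding = torch.tanh(output_projection_layer(action_query))

# (group_size, num_actions): score each candidate action by its dot
# product with the predicted action embedding.
action_logits = action_embeddings.bmm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
print(action_logits.shape)  # torch.Size([4, 20])
```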
```python
if state.checklist_state[0] is not None:
    embedding_addition = self._get_predicted_embedding_addition(state)
    predicted_action_embedding += self._checklist_embedding_multiplier * embedding_addition
```
Don't ever use `+=` on a tensor. It doesn't do what you expect. Use `x = x + y` instead.
Yes, I figured out that `+=` causes in-place updates, messing up the computation graph, while `x = x + y` does not. I changed the other places I did in-place updates, but I guess I missed these two. Thanks, fixed them.
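A small self-contained demonstration of the failure mode, using standard PyTorch behavior (not code from this repo):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.exp(x)   # exp saves its output for use in the backward pass

y += 1             # in-place update: bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)     # "... has been modified by an inplace operation"

# Out-of-place form: builds a new tensor, leaving the graph intact.
x = torch.ones(3, requires_grad=True)
y = torch.exp(x)
y = y + 1
y.sum().backward()
print(x.grad)      # tensor([2.7183, 2.7183, 2.7183])
```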
```python
embedding_addition = self._get_predicted_embedding_addition(state,
                                                            self._unlinked_terminal_indices,
                                                            unlinked_balance)
predicted_action_embedding += self._unlinked_checklist_multiplier * embedding_addition
```
Don't use `+=`.
Done.
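For reference, a sketch of the out-of-place form the fix amounts to (the tensors and the multiplier value below are stand-ins, not the model's actual values):

```python
import torch

# Stand-in tensors; in the model these come from the decoder state.
predicted_action_embedding = torch.randn(4, 8)
embedding_addition = torch.randn(4, 8)
unlinked_checklist_multiplier = 0.3  # hypothetical scalar

# New tensor on the left-hand side; nothing in the graph is mutated.
predicted_action_embedding = (predicted_action_embedding
                              + unlinked_checklist_multiplier * embedding_addition)
```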
Force-pushed from 22fd0f2 to 7ef6b24.
@matt-gardner On NLVR, the variant with this change is at least slightly better (about 0.5 pp) than the run without this change. The experiment is still running, though.
Force-pushed from 7ef6b24 to ada48af.
Summary of changes: