Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawn-LX committed Mar 17, 2022
1 parent 7301764 commit 2e53c7b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions models/model_pairwise_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,12 @@ def construct_triplet(self,proposal,pred_logits,pair_ids):
uniq_dura_inters = uniq_dura_inters[mask,:]


# sort by score and select top200
top200ids = uniq_scores.mean(dim=-1).argsort(descending=True)[:200]
uniq_scores = uniq_scores[top200ids,:]
uniq_quintuples = uniq_quintuples[top200ids,:]
uniq_dura_inters = uniq_dura_inters[top200ids,:]
if self.rt_triplets_topk > 0:
# sort by score and select top200 (for save GPU memory when doing the grounding stage)
top200ids = uniq_scores.mean(dim=-1).argsort(descending=True)[:self.rt_triplets_topk]
uniq_scores = uniq_scores[top200ids,:]
uniq_quintuples = uniq_quintuples[top200ids,:]
uniq_dura_inters = uniq_dura_inters[top200ids,:]

uniq_query_ids = torch.empty(size=(uniq_scores.shape[0],))

Expand Down

0 comments on commit 2e53c7b

Please sign in to comment.