Skip to content

Commit

Permalink
fix crash with cosine similarity scores (#364)
Browse files Browse the repository at this point in the history
Co-authored-by: guillaume-be <guillaume.becquin@gmail.com>
  • Loading branch information
laptou and guillaume-be committed Apr 30, 2023
1 parent 5b8dcd2 commit c37eb32
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/pipelines/keywords_extraction/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ fn cosine_similarity_score(
word_embeddings: Tensor,
num_keywords: usize,
) -> Vec<(usize, f32)> {
let similarities = cosine_similarity(Some(&document_embedding), &word_embeddings).squeeze();
let similarities = cosine_similarity(Some(&document_embedding), &word_embeddings).view([-1]);

let (top_scores, top_keywords) = similarities.topk(num_keywords as i64, 0, true, false);
top_scores
.iter::<f64>()
Expand All @@ -92,7 +93,7 @@ fn maximal_margin_relevance_score(
diversity: f64,
) -> Vec<(usize, f32)> {
let word_document_similarities =
cosine_similarity(Some(&document_embedding), &word_embeddings).squeeze();
cosine_similarity(Some(&document_embedding), &word_embeddings).view([-1]);
let word_similarities = cosine_similarity(None, &word_embeddings);

let mut keyword_indices = vec![i64::from(word_document_similarities.argmax(0, false))];
Expand Down Expand Up @@ -139,7 +140,7 @@ fn max_sum_score(
) -> Vec<(usize, f32)> {
let max_sum_candidates = max(num_keywords, max_sum_candidates);
let word_document_similarities =
cosine_similarity(Some(&document_embedding), &word_embeddings).squeeze();
cosine_similarity(Some(&document_embedding), &word_embeddings).view([-1]);
let word_similarities = cosine_similarity(None, &word_embeddings);
let (_, top_keywords) =
word_document_similarities.topk(max_sum_candidates as i64, 0, true, false);
Expand Down

0 comments on commit c37eb32

Please sign in to comment.