Skip to content

Commit

Permalink
#186 implement Jaccard similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasvonM committed Jul 10, 2024
1 parent c70f229 commit a9881ea
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 26 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ unicode-segmentation = "^1.6.0"
unicode-normalization = "^0.1"
smallvec = "1.10.0"
ahash = "0.8.3"
num-traits = "0.2"

[dev-dependencies]
csv = "1.1"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ String comparison:

* Levenshtein Distance
* Damerau-Levenshtein Distance
* Jaccard Index
* Jaro Distance
* Jaro-Winkler Distance
* Match Rating Approach Comparison
Expand Down
12 changes: 12 additions & 0 deletions docs/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ considers extra characters as differing. For example ``hamming_distance('abc',

See the [Hamming distance article at Wikipedia](http://en.wikipedia.org/wiki/Hamming_distance) for more details.

### Jaccard Similarity

``` python
def jaccard_similarity(s1: str, s2: str, ngram_size: Optional[int] = None) -> float
```

Compute the Jaccard index between s1 and s2.

The Jaccard index between two sets is defined as the number of elements of the intersection divided by the number of elements of the union of the two sets. If `ngram_size` is `None`, the strings are split by whitespace and the elements of the sets are the resulting words. Otherwise the elements are ngrams: each string is cut into consecutive, non-overlapping groups of `ngram_size` characters (the last group may be shorter).

The Jaccard index does not consider order of words/ngrams. Hence "hello world" and "world hello" have a Jaccard similarity of 1.

### Jaro Similarity

``` python
Expand Down
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ Issues: [https://github.com/jamesturk/jellyfish/issues](https://github.com/james
[![Test badge](https://github.com/jamesturk/jellyfish/workflows/Python%20package/badge.svg)](https://github.com/jamesturk/jellyfish/actions?query=workflow%3A%22Python+package)
[![Coveralls](https://coveralls.io/repos/jamesturk/jellyfish/badge.png?branch=master)](https://coveralls.io/r/jamesturk/jellyfish)


## Included Algorithms

String comparison:

* Levenshtein Distance
* Damerau-Levenshtein Distance
* Jaccard Similarity
* Jaro Distance
* Jaro-Winkler Distance
* Match Rating Approach Comparison
Expand All @@ -36,7 +36,7 @@ Phonetic encoding:
Each algorithm has Rust and Python implementations.

The Rust implementations are used by default. The Python
implementations are a remnant of an early version of
implementations are a remnant of an early version of
the library and will probably be removed in 1.0.

To explicitly use a specific implementation, refer to the appropriate module::
Expand Down
3 changes: 3 additions & 0 deletions python/jellyfish/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional

def levenshtein_distance(s1: str, s2: str) -> int: ...
def jaccard_similarity(s1: str, s2: str, ngram_size: Optional[int] = None) -> float: ...
def jaro_similarity(s1: str, s2: str) -> float: ...
def jaro_winkler_similarity(s1: str, s2: str, long_tolerance: bool = ...) -> float: ...
def damerau_levenshtein_distance(s1: str, s2: str) -> int: ...
Expand Down
54 changes: 54 additions & 0 deletions src/jaccard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use std::borrow::Cow;
use std::collections::HashSet;

pub fn jaccard_similarity(s1: &str, s2: &str, ngram_size: Option<usize>) -> f64 {
// 1. Tokenize into ngrams
let grams1: HashSet<String> = get_ngrams(s1, ngram_size)
.into_iter()
.map(|cow| cow.into_owned())
.collect();
let grams2: HashSet<String> = get_ngrams(s2, ngram_size)
.into_iter()
.map(|cow| cow.into_owned())
.collect();

// 2. Calculate intersection and union sizes
let intersection_size: usize = grams1.iter().filter(|gram| grams2.contains(*gram)).count();
let union_size: usize = grams1.len() + grams2.len() - intersection_size;

// 3. Calculate Jaccard index
if union_size == 0 {
0.0
} else {
intersection_size as f64 / union_size as f64
}
}

/// Splits `s` into the tokens that `jaccard_similarity` compares.
///
/// * `Some(size)`: the characters of `s` are cut into consecutive,
///   non-overlapping groups of `size` characters; the final group may be
///   shorter. Whitespace receives no special treatment in this mode.
/// * `None`: `s` is split on whitespace and every word is one token, borrowed
///   straight from `s`.
///
/// `Some(0)` produces no tokens at all. `slice::chunks` panics on a chunk size
/// of zero, so that case is answered before any chunking happens.
fn get_ngrams(s: &str, n: Option<usize>) -> Vec<Cow<'_, str>> {
    match n {
        // A zero-width n-gram is meaningless; avoid the `chunks(0)` panic.
        Some(0) => Vec::new(),
        Some(size) => {
            // Chunk by `char`, not by byte, so multi-byte characters stay whole.
            let chars: Vec<char> = s.chars().collect();
            chars
                .chunks(size)
                .map(|chunk| Cow::Owned(chunk.iter().collect::<String>()))
                .collect()
        }
        // Word-level tokens need no allocation per token.
        None => s.split_whitespace().map(Cow::Borrowed).collect(),
    }
}



#[cfg(test)]
mod test {
    use super::jaccard_similarity;
    use crate::testutils::testutils::test_similarity_func_three_args;

    /// Runs `jaccard_similarity` against the CSV fixture through the shared
    /// three-argument similarity test helper.
    #[test]
    fn test_jaccard_similarity() {
        let fixture = "testdata/jaccard.csv";
        test_similarity_func_three_args(fixture, jaccard_similarity);
    }
}
6 changes: 3 additions & 3 deletions src/jaro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ mod test {
use crate::testutils::testutils;
#[test]
fn test_jaro() {
testutils::test_similarity_func("testdata/jaro_distance.csv", jaro_similarity);
testutils::test_similarity_func_two_args("testdata/jaro_distance.csv", jaro_similarity);
}

#[test]
fn test_jaro_winkler() {
testutils::test_similarity_func("testdata/jaro_winkler.csv", jaro_winkler_similarity);
testutils::test_similarity_func_two_args("testdata/jaro_winkler.csv", jaro_winkler_similarity);
}

#[test]
fn test_jaro_winkler_longtol() {
testutils::test_similarity_func(
testutils::test_similarity_func_two_args(
"testdata/jaro_winkler_longtol.csv",
jaro_winkler_similarity_longtol,
);
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod common;
mod hamming;
mod jaccard;
mod jaro;
mod levenshtein;
mod match_rating;
Expand All @@ -9,6 +10,7 @@ mod soundex;
mod testutils;

pub use hamming::{hamming_distance, vec_hamming_distance};
pub use jaccard::jaccard_similarity;
pub use jaro::{
jaro_similarity, jaro_winkler_similarity, jaro_winkler_similarity_longtol, vec_jaro_similarity,
vec_jaro_winkler_similarity, vec_jaro_winkler_similarity_longtol,
Expand Down
8 changes: 8 additions & 0 deletions src/rustyfish.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::damerau_levenshtein_distance as _damerau;
use crate::hamming_distance as _hamming;
use crate::jaccard_similarity as _jaccard;
use crate::jaro_similarity as _jaro;
use crate::jaro_winkler_similarity as _jaro_winkler;
use crate::jaro_winkler_similarity_longtol as _jaro_winkler_long;
Expand All @@ -24,6 +25,12 @@ fn hamming_distance(a: &str, b: &str) -> PyResult<usize> {
Ok(_hamming(a, b))
}

// Calculates the Jaccard index between two strings.
//
// Python-facing wrapper: forwards `a`, `b` and `ngram_size` unchanged to the
// crate's `jaccard_similarity` (imported above as `_jaccard`) and wraps the
// resulting score in `Ok`.
#[pyfunction]
fn jaccard_similarity(a: &str, b: &str, ngram_size: Option<usize>) -> PyResult<f64> {
Ok(_jaccard(a, b, ngram_size))
}

// Calculates the Jaro similarity between two strings.
#[pyfunction]
fn jaro_similarity(a: &str, b: &str) -> PyResult<f64> {
Expand Down Expand Up @@ -84,6 +91,7 @@ fn metaphone(a: &str) -> PyResult<String> {
pub fn _rustyfish(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(damerau_levenshtein_distance, m)?)?;
m.add_function(wrap_pyfunction!(hamming_distance, m)?)?;
m.add_function(wrap_pyfunction!(jaccard_similarity, m)?)?;
m.add_function(wrap_pyfunction!(jaro_similarity, m)?)?;
m.add_function(wrap_pyfunction!(jaro_winkler_similarity, m)?)?;
m.add_function(wrap_pyfunction!(levenshtein_distance, m)?)?;
Expand Down
75 changes: 54 additions & 21 deletions src/testutils.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,78 @@
#[cfg(test)]
pub mod testutils {
use csv;
use num_traits::{Float, FromPrimitive};

pub fn test_distance_func(filename: &str, func: fn(&str, &str) -> usize) {
fn test_generic_func<T, F>(filename: &str, func: F)
where
F: Fn(&str, &str, Option<usize>) -> T, // Signature for functions with ngram_size
T: PartialEq + std::fmt::Debug + std::str::FromStr + Float + FromPrimitive,
<T as std::str::FromStr>::Err: std::fmt::Debug,
{
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.from_path(filename)
.unwrap();
let mut num_tested = 0;
for result in reader.records() {
let rec = result.unwrap();
let expected = rec[2].parse().ok().unwrap();
println!(
"comparing {} to {}, expecting {:?}",
&rec[0], &rec[1], expected
let input1 = &rec[0];
let input2 = &rec[1];
let ngram_size = rec.get(3).and_then(|s| s.parse().ok());

let expected: T = rec[2].parse().expect("Failed to parse expected value");
let output = func(input1, input2, ngram_size);

let abs_diff = (output.to_f64().unwrap() - expected.to_f64().unwrap()).abs();
assert!(
abs_diff < 0.001,
"comparing {} to {} (ngram_size: {:?}), expected {:?}, got {:?} (diff {:?})",
input1,
input2,
ngram_size,
expected,
output,
abs_diff
);
assert_eq!(func(&rec[0], &rec[1]), expected);

num_tested += 1;
}
assert!(num_tested > 0);
}

pub fn test_similarity_func(filename: &str, func: fn(&str, &str) -> f64) {
pub fn test_distance_func(filename: &str, func: fn(&str, &str) -> usize) {
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.from_path(filename)
.unwrap();
let mut num_tested = 0;
for result in reader.records() {
let rec = result.unwrap();
let expected: f64 = rec[2].parse().ok().unwrap();
let output = func(&rec[0], &rec[1]);
let input1 = &rec[0];
let input2 = &rec[1];
let expected: usize = rec[2].parse().expect("Failed to parse expected value");
let output = func(input1, input2);

println!(
"comparing {} to {}, expecting {}, got {}",
&rec[0], &rec[1], expected, output
);
assert!(
(output - expected).abs() < 0.001,
"{} !~= {} [{}]",
output,
expected,
output - expected
"comparing {} to {}, expecting {:?}, got {:?}",
input1, input2, expected, output
);
assert_eq!(output, expected);
num_tested += 1;
}
assert!(num_tested > 0);
}

// Adapter for similarity functions that take only the two strings.
// The shared driver still parses the CSV's optional fourth column; the value
// is simply dropped before `func` is called.
pub fn test_similarity_func_two_args(filename: &str, func: fn(&str, &str) -> f64) {
    let ignore_ngram = |left: &str, right: &str, _ngram: Option<usize>| func(left, right);
    test_generic_func::<f64, _>(filename, ignore_ngram);
}

// Adapter for similarity functions that also accept the optional usize
// (the n-gram size read from the CSV's fourth column).
pub fn test_similarity_func_three_args(filename: &str, func: fn(&str, &str, Option<usize>) -> f64) {
    // `func` already has the call shape the shared driver expects, so it is
    // handed over as-is.
    test_generic_func::<f64, _>(filename, func);
}

pub fn test_str_func(filename: &str, func: fn(&str) -> String) {
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
Expand All @@ -55,9 +81,16 @@ pub mod testutils {
let mut num_tested = 0;
for result in reader.records() {
let rec = result.unwrap();
let output = func(&rec[0]);
println!("testing {}, expecting {}, got {}", &rec[0], &rec[1], output);
assert_eq!(&rec[1], output);
let input1 = &rec[0];
let expected = rec[1].to_string();

let output = func(input1);

println!(
"comparing {}, expecting {:?}, got {:?}",
input1, expected, output
);
assert_eq!(output, expected);
num_tested += 1;
}
assert!(num_tested > 0);
Expand Down

0 comments on commit a9881ea

Please sign in to comment.