Skip to content

Commit

Permalink
adds scorer to AggregateRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
rbs333 committed Oct 8, 2024
1 parent 700045c commit dc9a866
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 1 deletion.
18 changes: 17 additions & 1 deletion redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Union, Optional

FIELDNAME = object()

Expand Down Expand Up @@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
self._cursor = []
self._dialect = None
self._add_scores = False
self._scorer = Optional[str] = None

def load(self, *fields: List[str]) -> "AggregateRequest":
"""
Expand Down Expand Up @@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
self._add_scores = True
return self

def scorer(self, scorer: str) -> "AggregateRequest":
"""
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self

def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
Expand All @@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
if self._verbatim:
ret.append("VERBATIM")

if self._scorer:
ret.extend(["SCORER", self._scorer])

if self._add_scores:
ret.append("ADDSCORES")

Expand All @@ -332,6 +347,7 @@ def build_args(self) -> List[str]:
if self._loadall:
ret.append("LOAD")
ret.append("*")

elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))
Expand Down
60 changes: 60 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,66 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
assert await decoded_r.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
TextField("description", sortable=True, weight=5.0),
VectorField(
"vector",
"HNSW",
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
),
)
)

assert await decoded_r.hset(
"doc1",
mapping={
"name": "cat book",
"description": "a book about cats",
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
},
)
assert await decoded_r.hset(
"doc2",
mapping={
"name": "dog book",
"description": "a book about dogs",
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
},
)

query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]"
req = (
aggregations.AggregateRequest(query_string)
.scorer("BM25")
.add_scores()
.apply(hybrid_score="@__score + @dist")
.load("*")
.dialect(4)
)

res = (
await decoded_r.ft()
.aggregate(
req,
query_params={
"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()
},
)
.rows[0]
)

assert len(res) == 6
assert b"hybrid_score" in res
assert b"__score" in res
assert b"__dist" in res
assert float(res[1]) + float(res[3]) == float(res[5])


@pytest.mark.redismod
@skip_if_redis_enterprise()
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
Expand Down
60 changes: 60 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,66 @@ def test_aggregations_add_scores(client):
assert res.rows[1] == ["__score", "0.2"]


@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(client):
client.ft().create_index(
(
TextField("name", sortable=True, weight=5.0),
TextField("description", sortable=True, weight=5.0),
VectorField(
"vector",
"HNSW",
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
),
)
)

client.hset(
"doc1",
mapping={
"name": "cat book",
"description": "a book about cats",
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
},
)
client.hset(
"doc2",
mapping={
"name": "dog book",
"description": "a book about dogs",
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
},
)

query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]"
req = (
aggregations.AggregateRequest(query_string)
.scorer("BM25")
.add_scores()
.apply(hybrid_score="@__score + @dist")
.load("*")
.dialect(4)
)

res = (
client.ft()
.aggregate(
req,
query_params={
"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()
},
)
.rows[0]
)

assert len(res) == 6
assert b"hybrid_score" in res
assert b"__score" in res
assert b"__dist" in res
assert float(res[1]) + float(res[3]) == float(res[5])


@pytest.mark.redismod
@skip_ifmodversion_lt("2.0.0", "search")
def test_index_definition(client):
Expand Down

0 comments on commit dc9a866

Please sign in to comment.