Skip to content

Commit

Permalink
fix: Default includes for query #45
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Mar 14, 2024
1 parent a463994 commit ab1339d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
2 changes: 1 addition & 1 deletion chroma.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ func (c *Collection) QueryWithOptions(ctx context.Context, queryOptions ...types
}
var localInclude = b.Include
if len(b.Include) == 0 {
localInclude = []types.QueryEnum{types.IDocuments, types.IMetadatas}
localInclude = []types.QueryEnum{types.IDocuments, types.IMetadatas, types.IDistances}
}
_includes := make([]openapiclient.IncludeInner, len(localInclude))
for i, v := range localInclude {
Expand Down
49 changes: 46 additions & 3 deletions test/chroma_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ func Test_chroma_client(t *testing.T) {
require.Contains(t, res.Ids, "ID2")
})

t.Run("Test Query Collection Documents", func(t *testing.T) {
t.Run("Test Query Collection Documents - with Default includes", func(t *testing.T) {
collectionName := "test-collection"
metadata := map[string]interface{}{}
embeddingFunction := types.NewConsistentHashEmbeddingFunction()
Expand Down Expand Up @@ -487,9 +487,52 @@ func Test_chroma_client(t *testing.T) {
require.NoError(t, getErr)
assert.Equal(t, int32(2), colGet)

qr, qrerr := col.Query(context.Background(), []string{"Dogs are my favorite animals"}, 5, nil, nil, nil)
require.NoError(t, qrerr)
qr, err := col.Query(context.Background(), []string{"Dogs are my favorite animals"}, 5, nil, nil, nil)
require.NoError(t, err)
require.Equal(t, 2, len(qr.Documents[0]))
require.Equal(t, 2, len(qr.Metadatas[0]))
require.Equal(t, 2, len(qr.Distances[0]))
require.Equal(t, documents[1], qr.Documents[0][0]) // ensure that the first document is the one about dogs
})

t.Run("Test Query Collection Documents - with document only includes", func(t *testing.T) {
collectionName := "test-collection"
metadata := map[string]interface{}{}
embeddingFunction := types.NewConsistentHashEmbeddingFunction()
_, errRest := client.Reset(context.Background())
require.NoError(t, errRest)
newCollection, err := client.CreateCollection(context.Background(), collectionName, metadata, true, embeddingFunction, types.L2)
require.NoError(t, err)
require.NotNil(t, newCollection)
require.Equal(t, collectionName, newCollection.Name)
require.Equal(t, 2, len(newCollection.Metadata))
// assert the metadata contains key embedding_function
require.Contains(t, chroma.GetStringTypeOfEmbeddingFunction(embeddingFunction), newCollection.Metadata["embedding_function"])
documents := []string{
"This is a document about cats. Cats are great.",
"this is a document about dogs. Dogs are great.",
}
ids := []string{
"ID1",
"ID2",
}

metadatas := []map[string]interface{}{
{"key1": "value1"},
{"key2": "value2"},
}
col, addError := newCollection.Add(context.Background(), nil, metadatas, documents, ids)
require.NoError(t, addError)

colGet, getErr := col.Count(context.Background())
require.NoError(t, getErr)
assert.Equal(t, int32(2), colGet)

qr, err := col.Query(context.Background(), []string{"Dogs are my favorite animals"}, 5, nil, nil, []types.QueryEnum{types.IDocuments})
require.NoError(t, err)
require.Equal(t, 2, len(qr.Documents[0]))
require.Equal(t, 0, len(qr.Metadatas))
require.Equal(t, 0, len(qr.Distances))
require.Equal(t, documents[1], qr.Documents[0][0]) // ensure that the first document is the one about dogs
})

Expand Down

0 comments on commit ab1339d

Please sign in to comment.