Skip to content

Commit

Permalink
Cosmos: strip implicit casts to allow vector search over arrays
Browse files Browse the repository at this point in the history
Fixes #34402
  • Loading branch information
roji committed Aug 14, 2024
1 parent 4305f7f commit 258020d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -820,13 +820,18 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
ExpressionType.Negate or ExpressionType.NegateChecked
=> sqlExpressionFactory.Negate(sqlOperand!),

ExpressionType.Convert or ExpressionType.ConvertChecked
when operand.Type.IsInterface
&& unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// Convert nodes can be an explicit user gesture in the query, or they may get introduced by the compiler (e.g. when a Child is
// passed as an argument for a parameter of type Parent). The latter type should generally get stripped out as a pure C#/LINQ
// artifact that shouldn't affect translation, but the latter may be an indication from the user that they want to apply a
// type change.
ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory<float> (for vector search)
|| unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit"}
|| unaryExpression.Type.UnwrapNullableType() == operand.Type
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum)
// Object convert needs to be converted to explicit cast when mismatching types
// But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
// But we let it pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
|| unaryExpression.Type == typeof(object)
=> sqlOperand!,

Expand Down
67 changes: 43 additions & 24 deletions test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,19 @@ public virtual async Task Query_for_vector_distance_bytes_array()
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 3, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));
Assert.Equal(3, booksFromStore.Count);
Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["BytesArray"]
@__inputVector_1='[2,1,4,3,5,2,5,7,3,1]'

SELECT VALUE VectorDistance(c["Bytes"], @__inputVector_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'})
FROM root c
""");
}
Expand All @@ -119,17 +122,20 @@ public virtual async Task Query_for_vector_distance_singles_array()
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.Select(
e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));
Assert.Equal(3, booksFromStore.Count);
Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["SinglesArray"]
@__inputVector_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]'

SELECT VALUE VectorDistance(c["Singles"], @__inputVector_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'})
FROM root c
""");
}
Expand Down Expand Up @@ -207,14 +213,20 @@ public virtual async Task Vector_distance_bytes_array_in_OrderBy()
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 6, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());

// Assert.Equal(3, booksFromStore.Count);
var booksFromStore = await context
.Set<Book>()
.OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector))
.ToListAsync();

Assert.Equal(3, booksFromStore.Count);
AssertSql(
);
"""
@__p_1='[2,1,4,6,5,2,5,7,3,1]'

SELECT VALUE c
FROM root c
ORDER BY VectorDistance(c["Bytes"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'})
""");
}

[ConditionalFact]
Expand All @@ -223,13 +235,20 @@ public virtual async Task Vector_distance_singles_array_in_OrderBy()
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
Assert.Equal(3, booksFromStore.Count);
AssertSql(
"""
@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78]'

AssertSql();
SELECT VALUE c
FROM root c
ORDER BY VectorDistance(c["Singles"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'})
""");
}

[ConditionalFact]
Expand Down

0 comments on commit 258020d

Please sign in to comment.