diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 95577fd311..5e93d9e33b 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include // DI namespace raft::distance::detail::ops { @@ -33,7 +34,7 @@ struct l2_exp_cutlass_op { // outVal could be negative due to numerical instability, especially when // calculating self distance. // clamp to 0 to avoid potential NaN in sqrt - outVal = outVal * (outVal > DataT(0.0)); + outVal = outVal * (raft::abs(outVal) >= DataT(0.0001)); return sqrt ? raft::sqrt(outVal) : outVal; } @@ -88,7 +89,7 @@ struct l2_exp_distance_op { DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; // val could be negative due to numerical instability, especially when // calculating self distance. Clamp to 0 to avoid potential NaN in sqrt - acc[i][j] = val * (val > DataT(0.0)); + acc[i][j] = val * (raft::abs(val) >= DataT(0.0001)); } } if (sqrt) { diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 18f1906dc5..81779668c4 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -884,6 +884,7 @@ void launch_kernel(Lambda lambda, queries += grid_dim_y * index.dim(); neighbors += grid_dim_y * grid_dim_x * k; distances += grid_dim_y * grid_dim_x * k; + coarse_index += grid_dim_y * n_probes; } } diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index d72d73680a..71d48cdeb7 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -497,6 +497,11 @@ const std::vector> inputs = { raft::matrix::detail::select::warpsort::kMaxCapacity * 4, raft::matrix::detail::select::warpsort::kMaxCapacity * 4, raft::distance::DistanceType::InnerProduct, - false}}; + false}, + + // The following two test cases should show very similar recall. + // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers + {20000, 8712, 3, 10, 51, 66, raft::distance::DistanceType::L2Expanded, false}, + {100000, 8712, 3, 10, 51, 66, raft::distance::DistanceType::L2Expanded, false}}; } // namespace raft::neighbors::ivf_flat diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 3037b9a725..20dadf0275 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -60,9 +60,9 @@ cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \ float metric_arg) except + DISTANCE_TYPES = { - "l2": DistanceType.L2SqrtUnexpanded, - "sqeuclidean": DistanceType.L2Unexpanded, - "euclidean": DistanceType.L2SqrtUnexpanded, + "l2": DistanceType.L2SqrtExpanded, + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, "l1": DistanceType.L1, "cityblock": DistanceType.L1, "inner_product": DistanceType.InnerProduct, diff --git a/python/pylibraft/pylibraft/test/test_brute_force.py b/python/pylibraft/pylibraft/test/test_brute_force.py index 2e118d210d..42095c3b9f 100644 --- a/python/pylibraft/pylibraft/test/test_brute_force.py +++ b/python/pylibraft/pylibraft/test/test_brute_force.py @@ -89,7 +89,7 @@ def test_knn(n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype): cpu_ordered = pw_dists[i, expected_indices] np.testing.assert_allclose( - cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 + cpu_ordered[:k], gpu_dists, atol=1e-3, rtol=1e-3 ) diff --git a/python/pylibraft/pylibraft/test/test_distance.py b/python/pylibraft/pylibraft/test/test_distance.py index 2c0a842fe5..f9d3890ff7 100644 --- a/python/pylibraft/pylibraft/test/test_distance.py +++ b/python/pylibraft/pylibraft/test/test_distance.py @@ -81,4 +81,4 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype): actual[actual <= 1e-5] = 0.0 - assert np.allclose(expected, actual, rtol=1e-4) + assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3)