Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -959,6 +959,30 @@ bool ace_check_use_disk_mode(raft::resources const& res,
return use_disk_mode;
}

// Turn the caller-supplied ACE partition count into the count used by the build.
//   - 0 is the auto-selection sentinel and resolves to 2 without any message.
//   - 1 is bumped to 2 and a warning is logged.
//   - Any value >= 2 is returned unchanged.
inline size_t ace_resolve_partition_count(size_t n_partitions)
{
  constexpr size_t kMinPartitions = 2;

  // Explicit, already-usable requests pass straight through.
  if (n_partitions >= kMinPartitions) { return n_partitions; }

  // Only the explicit request of a single partition is reported; the sentinel 0 is silent.
  if (n_partitions == 1) {
    RAFT_LOG_WARN(
      "ACE: Requested 1 partition; adjusted to 2 before applying partitioning heuristics");
  }
  return kMinPartitions;
}

// Check the structural ACE invariant required by the labeler: the partition count
// must not exceed the number of dataset rows.
//
// n_partitions         partition count to check
// dataset_size         number of rows in the dataset
// adjusted_for_memory  selects the error message only; pass true when the count being
//                      checked was produced by the memory-driven adjustment
//
// Fails through RAFT_EXPECTS when n_partitions > dataset_size.
inline void ace_validate_partition_count(size_t n_partitions,
                                         size_t dataset_size,
                                         bool adjusted_for_memory = false)
{
  if (adjusted_for_memory) {
    RAFT_EXPECTS(n_partitions <= dataset_size,
                 "ACE: configured memory limit is unsatisfiable because the requested partition "
                 "count cannot exceed dataset size");
  } else {
    RAFT_EXPECTS(n_partitions <= dataset_size,
                 "ACE: number of partitions cannot exceed dataset size");
  }
}

// Validate and adjust partitions for disk mode memory requirements
template <typename T, typename IdxT>
void ace_validate_disk_mode_partitions(raft::resources const& res,
Expand Down Expand Up @@ -1154,11 +1178,9 @@ index<T, IdxT> build_ace(raft::resources const& res,
"ACE: Intermediate graph degree must be greater than 0");
RAFT_EXPECTS(params.graph_degree > 0, "ACE: Graph degree must be greater than 0");

size_t n_partitions = npartitions;
if (n_partitions == 0) {
// Default: start with 2 partitions and increase if needed (minimum for ACE to make sense).
n_partitions = 2;
}
size_t n_partitions = ace_resolve_partition_count(npartitions);

ace_validate_partition_count(n_partitions, dataset_size);

size_t min_required_per_partition = 1000;
if (n_partitions > dataset_size / min_required_per_partition) {
Expand Down Expand Up @@ -1211,6 +1233,7 @@ index<T, IdxT> build_ace(raft::resources const& res,
graph_degree,
params.guarantee_connectivity,
mem);
ace_validate_partition_count(n_partitions, dataset_size, true);
}

// Preallocate space for files for better performance and fail early if not enough space.
Expand Down
44 changes: 42 additions & 2 deletions cpp/tests/neighbors/ann_hnsw_ace.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -262,6 +262,31 @@ class AnnHnswAceTest : public ::testing::TestWithParam<AnnHnswAceInputs> {
std::filesystem::remove_all(temp_dir);
}

void testHnswAceRejectsTooManyPartitions()
{
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database_dev.data(), ps.n_rows * ps.dim, stream_);
raft::resource::sync_stream(handle_);

hnsw::index_params hnsw_params;
hnsw_params.metric = ps.metric;
hnsw_params.hierarchy = hnsw::HnswHierarchy::GPU;
hnsw_params.M = 32;
hnsw_params.ef_construction = ps.ef_construction;

auto ace_params = graph_build_params::ace_params();
ace_params.npartitions = static_cast<size_t>(ps.n_rows) + 1;
hnsw_params.graph_build_params = ace_params;

try {
[[maybe_unused]] auto hnsw_index =
hnsw::build(handle_, hnsw_params, raft::make_const_mdspan(database_host.view()));
FAIL() << "ACE accepted more partitions than dataset rows";
} catch (const std::exception& e) {
EXPECT_NE(std::string(e.what()).find("cannot exceed dataset size"), std::string::npos);
}
}

// Verify the in-memory CAGRA -> HNSW conversion spills to disk when the resulting HNSW
// index would not fit in (an artificially constrained) host memory. This exercises
// serialize_to_hnswlib_from_inmem and the batched serializer core, covering:
Expand Down Expand Up @@ -472,7 +497,7 @@ inline std::vector<AnnHnswAceInputs> generate_hnsw_ace_inputs()
{5000}, // n_rows
{64, 128}, // dim
{10}, // k
{2, 4}, // npartitions
{0, 1, 2, 4}, // npartitions (auto, adjusted, and explicit)
{100}, // ef_construction
{false, true}, // use_disk (test both modes)
{cuvs::distance::DistanceType::L2Expanded,
Expand Down Expand Up @@ -502,6 +527,19 @@ inline std::vector<AnnHnswAceInputs> generate_hnsw_ace_memory_fallback_inputs()
};
}

inline std::vector<AnnHnswAceInputs> generate_hnsw_ace_invalid_partition_inputs()
{
return {{10,
5000,
64,
10,
0, // Set to n_rows + 1 by testHnswAceRejectsTooManyPartitions.
100,
false,
cuvs::distance::DistanceType::L2Expanded,
0.0}};
}

// Inputs for testing the in-memory CAGRA -> HNSW disk-spill conversion path.
inline std::vector<AnnHnswAceInputs> generate_hnsw_inmem_spill_inputs()
{
Expand All @@ -524,6 +562,8 @@ inline std::vector<AnnHnswAceInputs> generate_hnsw_inmem_spill_inputs()
const std::vector<AnnHnswAceInputs> hnsw_ace_inputs = generate_hnsw_ace_inputs();
const std::vector<AnnHnswAceInputs> hnsw_ace_memory_fallback_inputs =
generate_hnsw_ace_memory_fallback_inputs();
const std::vector<AnnHnswAceInputs> hnsw_ace_invalid_partition_inputs =
generate_hnsw_ace_invalid_partition_inputs();
const std::vector<AnnHnswAceInputs> hnsw_inmem_spill_inputs = generate_hnsw_inmem_spill_inputs();

} // namespace cuvs::neighbors::hnsw
12 changes: 11 additions & 1 deletion cpp/tests/neighbors/ann_hnsw_ace/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -12,6 +12,16 @@ TEST_P(AnnHnswAceTest_float, AnnHnswAceBuild) { this->testHnswAceBuild(); }

INSTANTIATE_TEST_CASE_P(AnnHnswAceTest, AnnHnswAceTest_float, ::testing::ValuesIn(hnsw_ace_inputs));

// Fixture alias for the ACE invalid-partition-count test, instantiated as
// AnnHnswAceTest<float, float, uint32_t>.
typedef AnnHnswAceTest<float, float, uint32_t> AnnHnswAceInvalidPartitionTest_float;
// Runs the fixture's testHnswAceRejectsTooManyPartitions() once per registered parameter set.
TEST_P(AnnHnswAceInvalidPartitionTest_float, RejectsTooManyPartitions)
{
this->testHnswAceRejectsTooManyPartitions();
}

// Registers the parameter sets in hnsw_ace_invalid_partition_inputs for this fixture.
INSTANTIATE_TEST_CASE_P(AnnHnswAceInvalidPartitionTest,
AnnHnswAceInvalidPartitionTest_float,
::testing::ValuesIn(hnsw_ace_invalid_partition_inputs));

// Test for memory limit fallback to disk mode
typedef AnnHnswAceTest<float, float, uint32_t> AnnHnswAceMemoryFallbackTest_float;
TEST_P(AnnHnswAceMemoryFallbackTest_float, AnnHnswAceMemoryLimitFallback)
Expand Down
12 changes: 11 additions & 1 deletion cpp/tests/neighbors/ann_hnsw_ace/test_half_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -12,6 +12,16 @@ TEST_P(AnnHnswAceTest_half, AnnHnswAceBuild) { this->testHnswAceBuild(); }

INSTANTIATE_TEST_CASE_P(AnnHnswAceTest, AnnHnswAceTest_half, ::testing::ValuesIn(hnsw_ace_inputs));

// Fixture alias for the ACE invalid-partition-count test, instantiated as
// AnnHnswAceTest<float, half, uint32_t>.
typedef AnnHnswAceTest<float, half, uint32_t> AnnHnswAceInvalidPartitionTest_half;
// Runs the fixture's testHnswAceRejectsTooManyPartitions() once per registered parameter set.
TEST_P(AnnHnswAceInvalidPartitionTest_half, RejectsTooManyPartitions)
{
this->testHnswAceRejectsTooManyPartitions();
}

// Registers the parameter sets in hnsw_ace_invalid_partition_inputs for this fixture.
INSTANTIATE_TEST_CASE_P(AnnHnswAceInvalidPartitionTest,
AnnHnswAceInvalidPartitionTest_half,
::testing::ValuesIn(hnsw_ace_invalid_partition_inputs));

// Test for memory limit fallback to disk mode
typedef AnnHnswAceTest<float, half, uint32_t> AnnHnswAceMemoryFallbackTest_half;
TEST_P(AnnHnswAceMemoryFallbackTest_half, AnnHnswAceMemoryLimitFallback)
Expand Down
12 changes: 11 additions & 1 deletion cpp/tests/neighbors/ann_hnsw_ace/test_int8_t_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -14,6 +14,16 @@ INSTANTIATE_TEST_CASE_P(AnnHnswAceTest,
AnnHnswAceTest_int8_t,
::testing::ValuesIn(hnsw_ace_inputs));

// Fixture alias for the ACE invalid-partition-count test, instantiated as
// AnnHnswAceTest<float, int8_t, uint32_t>.
typedef AnnHnswAceTest<float, int8_t, uint32_t> AnnHnswAceInvalidPartitionTest_int8_t;
// Runs the fixture's testHnswAceRejectsTooManyPartitions() once per registered parameter set.
TEST_P(AnnHnswAceInvalidPartitionTest_int8_t, RejectsTooManyPartitions)
{
this->testHnswAceRejectsTooManyPartitions();
}

// Registers the parameter sets in hnsw_ace_invalid_partition_inputs for this fixture.
INSTANTIATE_TEST_CASE_P(AnnHnswAceInvalidPartitionTest,
AnnHnswAceInvalidPartitionTest_int8_t,
::testing::ValuesIn(hnsw_ace_invalid_partition_inputs));

// Test for memory limit fallback to disk mode
typedef AnnHnswAceTest<float, int8_t, uint32_t> AnnHnswAceMemoryFallbackTest_int8_t;
TEST_P(AnnHnswAceMemoryFallbackTest_int8_t, AnnHnswAceMemoryLimitFallback)
Expand Down
12 changes: 11 additions & 1 deletion cpp/tests/neighbors/ann_hnsw_ace/test_uint8_t_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -14,6 +14,16 @@ INSTANTIATE_TEST_CASE_P(AnnHnswAceTest,
AnnHnswAceTest_uint8_t,
::testing::ValuesIn(hnsw_ace_inputs));

// Fixture alias for the ACE invalid-partition-count test, instantiated as
// AnnHnswAceTest<float, uint8_t, uint32_t>.
typedef AnnHnswAceTest<float, uint8_t, uint32_t> AnnHnswAceInvalidPartitionTest_uint8_t;
// Runs the fixture's testHnswAceRejectsTooManyPartitions() once per registered parameter set.
TEST_P(AnnHnswAceInvalidPartitionTest_uint8_t, RejectsTooManyPartitions)
{
this->testHnswAceRejectsTooManyPartitions();
}

// Registers the parameter sets in hnsw_ace_invalid_partition_inputs for this fixture.
INSTANTIATE_TEST_CASE_P(AnnHnswAceInvalidPartitionTest,
AnnHnswAceInvalidPartitionTest_uint8_t,
::testing::ValuesIn(hnsw_ace_invalid_partition_inputs));

// Test for memory limit fallback to disk mode
typedef AnnHnswAceTest<float, uint8_t, uint32_t> AnnHnswAceMemoryFallbackTest_uint8_t;
TEST_P(AnnHnswAceMemoryFallbackTest_uint8_t, AnnHnswAceMemoryLimitFallback)
Expand Down
23 changes: 21 additions & 2 deletions python/cuvs/cuvs/tests/test_cagra_ace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

Expand Down Expand Up @@ -163,7 +163,7 @@ def test_cagra_ace_dtypes_and_metrics(dtype, metric, use_disk):
)


@pytest.mark.parametrize("npartitions", [2, 3, 8])
@pytest.mark.parametrize("npartitions", [0, 1, 2, 3, 8])
def test_cagra_ace_partitions(npartitions):
"""Test ACE with different partition sizes (disk mode only)."""
run_cagra_ace_build_search_test(
Expand All @@ -172,6 +172,25 @@ def test_cagra_ace_partitions(npartitions):
)


@pytest.mark.parametrize(
    "npartitions, message",
    [(33, "cannot exceed dataset size")],
)
def test_cagra_ace_rejects_invalid_partition_count(npartitions, message):
    """An in-memory ACE build must fail when npartitions exceeds the row count.

    The dataset has 32 rows, so the parametrized count of 33 is one more
    partition than there are rows; the build is expected to raise with an
    error whose text matches ``message``.
    """
    n_rows, n_cols = 32, 8
    dataset = np.zeros((n_rows, n_cols), dtype=np.float32)

    build_params = cagra.IndexParams(
        build_algo="ace",
        ace_params=cagra.AceParams(npartitions=npartitions, use_disk=False),
    )

    with pytest.raises(Exception, match=message):
        cagra.build(build_params, dataset)


@pytest.mark.parametrize("ef_construction", [50, 100, 200])
def test_cagra_ace_ef_construction(ef_construction):
"""Test ACE with different ef_construction values (disk mode only)."""
Expand Down
Loading