Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- `KernelThinning.kt_half_recursive` now returns the correct partitions. (https://github.com/gchq/coreax/pull/1088)

### Changed

Expand Down
8 changes: 4 additions & 4 deletions coreax/solvers/coresubset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,10 +1401,10 @@ def kt_half_recursive(
"""
# If m == 0, do nothing except convert the original data to type Coresubset
if m == 0:
return [
Coresubset(Data(jnp.arange(len(current_coreset))), original_dataset)
]

if isinstance(current_coreset, Coresubset):
return [current_coreset]
default_indices = Data(jnp.arange(len(current_coreset)))
return [Coresubset(default_indices, original_dataset)]
# Recursively call self.kt_half on the coreset (or the dataset)
if isinstance(current_coreset, Coresubset):
subset1, subset2 = self.kt_half(current_coreset.points)
Expand Down
45 changes: 43 additions & 2 deletions tests/unit/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,10 @@ def test_reduce(
"""
del kwargs
dataset, solver, _ = reduce_problem
coreset, state = jit_variant(solver.reduce)(dataset)
_reduce = jit_variant(solver.reduce)
coreset, state = _reduce(dataset)
if use_cached_state:
coreset_with_state, recycled_state = solver.reduce(dataset, state)
coreset_with_state, recycled_state = _reduce(dataset, state)
assert eqx.tree_equal(recycled_state, state)
assert eqx.tree_equal(coreset_with_state, coreset)
self.check_solution_invariants(coreset, reduce_problem)
Expand Down Expand Up @@ -3334,6 +3335,46 @@ def deterministic_uniform(_key, _shape=None):
coresets[1], jnp.array([[0.7], [0.6], [0.1], [0.12]])
)

def test_kt_half_recursive(self, solver_factory: jtu.Partial):
    """
    Test that kt_half_recursive partitions the indices of the dataset.

    With a recursion depth of 2 on a dataset of ``coreset_size * 2**2`` points,
    the indices gathered from all returned Coresubsets must contain every
    dataset index exactly once.
    """
    depth, coreset_size = 2, 3
    dataset_size = coreset_size * 2**depth
    dataset = Data(jnp.ones((dataset_size, 1)))
    solver = cast(KernelThinning, solver_factory(coreset_size=coreset_size))

    def _is_coresubset(node):
        # Stop tree traversal at Coresubset objects rather than their arrays.
        return isinstance(node, Coresubset)

    def _indices_of(coresubset):
        return coresubset.unweighted_indices

    partitions = solver.kt_half_recursive(dataset, depth, dataset)
    per_partition_indices = jax.tree.map(
        _indices_of, partitions, is_leaf=_is_coresubset
    )
    observed = jnp.unique(jnp.hstack(per_partition_indices), return_counts=True)

    # Every index from 0 to dataset_size - 1 should occur with a count of one.
    all_indices = jnp.arange(dataset_size)
    expected = (all_indices, jnp.ones(dataset_size, dtype=all_indices.dtype))
    assert eqx.tree_equal(observed, expected)

def test_kt_half_recursive_zero_depth(self, solver_factory: jtu.Partial):
    """
    Test that kt_half_recursive returns valid Coresubsets when m=0.

    Two inputs are checked at zero recursion depth:

    - a plain dataset, which is expected to come back as a single Coresubset
      whose indices are ``arange(len(dataset))``;
    - an existing Coresubset, which is expected to come back unchanged.
    """
    m = 0
    coreset_size = 3
    dataset_size = coreset_size * 2**m
    thinning_solver = cast(
        KernelThinning, solver_factory(coreset_size=coreset_size)
    )
    dataset = Data(jnp.ones((dataset_size, 1)))

    # Case 1: the input is the raw dataset.
    result = thinning_solver.kt_half_recursive(dataset, m, dataset)
    expected_result = [Coresubset(Data(jnp.arange(dataset_size)), dataset)]
    assert eqx.tree_equal(expected_result, result)

    # Case 2: the input is already a Coresubset. The indices must lie within
    # the dataset (out-of-range values would be silently clamped by JAX on
    # lookup), and they are deliberately not ``arange`` so that a wrongly
    # applied default index set would be detected.
    coresubset = Coresubset(Data(jnp.array([2, 0, 1])), dataset)
    result = thinning_solver.kt_half_recursive(coresubset, m, dataset)
    expected_result = [coresubset]
    assert eqx.tree_equal(expected_result, result)


class TestCompressPlusPlus(ExplicitSizeSolverTest):
"""Test cases for :class:`coreax.solvers.coresubset.KernelThinning`."""
Expand Down