From f7cf7d6133ef064027162c3a047a11802cbea6d7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Sep 2025 22:40:55 +0200 Subject: [PATCH 01/62] First draft, needs tests & fixes --- sklearn/tree/_criterion.pyx | 256 +++++++++----------- sklearn/tree/_utils.pxd | 49 ---- sklearn/tree/_utils.pyx | 471 +++++++++++------------------------- 3 files changed, 249 insertions(+), 527 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 9f3db83399569..adacb51c145c6 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -12,7 +12,7 @@ cnp.import_array() from scipy.special.cython_special cimport xlogy from ._utils cimport log -from ._utils cimport WeightedMedianCalculator +from ._utils cimport WeightedHeap # EPSILON is used in the Poisson criterion cdef float64_t EPSILON = 10 * np.finfo('double').eps @@ -1186,11 +1186,11 @@ cdef class MAE(RegressionCriterion): MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true value and f_i is the predicted value.""" - cdef cnp.ndarray left_child - cdef cnp.ndarray right_child - cdef void** left_child_ptr - cdef void** right_child_ptr cdef float64_t[::1] node_medians + cdef float64_t[:, ::1] left_abs_errors + cdef float64_t[:, ::1] right_abs_errors + cdef float64_t[::1] left_medians + cdef float64_t[::1] right_medians def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. @@ -1216,16 +1216,10 @@ cdef class MAE(RegressionCriterion): self.weighted_n_right = 0.0 self.node_medians = np.zeros(n_outputs, dtype=np.float64) - - self.left_child = np.empty(n_outputs, dtype='object') - self.right_child = np.empty(n_outputs, dtype='object') - # initialize WeightedMedianCalculators - for k in range(n_outputs): - self.left_child[k] = WeightedMedianCalculator(n_samples) - self.right_child[k] = WeightedMedianCalculator(n_samples) - - self.left_child_ptr = cnp.PyArray_DATA(self.left_child) - self.right_child_ptr = cnp.PyArray_DATA(self.right_child) + self.left_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) + self.right_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) + self.left_medians = np.empty(n_samples, dtype=np.float64) + self.right_medians = np.empty(n_samples, dtype=np.float64) cdef int init( self, @@ -1241,47 +1235,109 @@ cdef class MAE(RegressionCriterion): This initializes the criterion at node sample_indices[start:end] and children sample_indices[start:start] and sample_indices[start:end]. """ - cdef intp_t i, p, k + cdef intp_t i, k, j cdef float64_t w = 1.0 - + cdef intp_t n = end - start # Initialize fields self.y = y self.sample_weight = sample_weight self.sample_indices = sample_indices self.start = start self.end = end - self.n_node_samples = end - start + self.n_node_samples = n self.weighted_n_samples = weighted_n_samples self.weighted_n_node_samples = 0. - cdef void** left_child = self.left_child_ptr - cdef void** right_child = self.right_child_ptr - - for k in range(self.n_outputs): - ( left_child[k]).reset() - ( right_child[k]).reset() - for p in range(start, end): i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] - - for k in range(self.n_outputs): - # push method ends up calling safe_realloc, hence `except -1` - # push all values to the right side, - # since pos = start initially anyway - ( right_child[k]).push(self.y[i, k], w) - + if self.sample_weight is not None: + w = self.sample_weight[i] self.weighted_n_node_samples += w - # calculate the node medians - for k in range(self.n_outputs): - self.node_medians[k] = ( right_child[k]).get_median() + + for k in range(self.n_outputs - 1, -1, -1): + # TODO: think about indices alignment here and in update/children_impurity etc. + # it's likely wrong for now + self._precompute_absolute_errors(k, start, end, self.left_abs_errors, self.left_medians) + self._precompute_absolute_errors(k, end - 1, start - 1, self.right_abs_errors, self.right_medians) + self.node_medians[k] = self.right_medians[0] # Reset to pos=start self.reset() return 0 + cdef void _precompute_absolute_errors( + intp_t k, + intp_t start, + intp_t end, + float64_t[:, ::1] abs_errors, + float64_t[::1] medians) noexcept nogil: + """Fill `abs_errors` with prefix minimum AEs for (y[:i], w[:i]), i in [1, n-1]. + + Parameters + ---------- + y : 1D float64_t[::1] + Values. + w : 1D float64_t[::1] + Sample weights. + abs_errors : 1D float64_t[::1] + Output buffer, must have shape (n,). + """ + cdef intp_t n, step, j, p, i + if start < end: + step = 1 + j = 0 + else: + step = -1 + j = self.n_node_samples - 1 + + cdef WeightedHeap above = WeightedHeap(self.n_node_samples, True) # min-heap + cdef WeightedHeap below = WeightedHeap(self.n_node_samples, False) # max-heap + cdef float64_t y + cdef float64_t w = 1.0 + cdef float64_t val = 0.0 + cdef float64_t wt = 0.0 + cdef float64_t below_top = 0.0 + cdef float64_t below_wt = 0.0 + cdef float64_t median = 0.0 + cdef float64_t half_weight + + p = start + for _ in range(n): + i = self.sample_indices[p] + if sample_weight is not None: + w = self.sample_weight[i] + y = self.y[i, k] + + # Insert into the appropriate heap + if below.is_empty(): + above.push(y, w) + else: + below.peek(&below_top, &below_wt) + if y > below_top: + above.push(y, w) + else: + below.push(y, w) + + half_weight = (above.get_total_weight() + below.get_total_weight()) / 2.0 + + # Rebalance heaps + while above.get_total_weight() < half_weight and not below.is_empty(): + if below.pop(&val, &wt) == 0: + above.push(val, wt) + while (not above.is_empty() + and (above.get_total_weight() - above.top_weight()) > half_weight): + if above.pop(&val, &wt) == 0: + below.push(val, wt) + + # Current median + above.peek(&median, &wt) + medians[j] = median + abs_errors[k, j] = ((below.get_total_weight() - above.get_total_weight()) * median + - below.get_weighted_sum() + + above.get_weighted_sum()) + p += step + j += step + cdef void init_missing(self, intp_t n_missing) noexcept nogil: """Raise error if n_missing != 0.""" if n_missing == 0: @@ -1295,29 +1351,10 @@ cdef class MAE(RegressionCriterion): Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. """ - cdef intp_t i, k - cdef float64_t value - cdef float64_t weight - - cdef void** left_child = self.left_child_ptr - cdef void** right_child = self.right_child_ptr - self.weighted_n_left = 0.0 self.weighted_n_right = self.weighted_n_node_samples self.pos = self.start - # reset the WeightedMedianCalculators, left should have no - # elements and right should have all elements. - - for k in range(self.n_outputs): - # if left has no elements, it's already reset - for i in range(( left_child[k]).size()): - # remove everything from left and put it into right - ( left_child[k]).pop(&value, - &weight) - # push method ends up calling safe_realloc, hence `except -1` - ( right_child[k]).push(value, - weight) return 0 cdef int reverse_reset(self) except -1 nogil: @@ -1330,22 +1367,6 @@ cdef class MAE(RegressionCriterion): self.weighted_n_left = self.weighted_n_node_samples self.pos = self.end - cdef float64_t value - cdef float64_t weight - cdef void** left_child = self.left_child_ptr - cdef void** right_child = self.right_child_ptr - - # reverse reset the WeightedMedianCalculators, right should have no - # elements and left should have all elements. - for k in range(self.n_outputs): - # if right has no elements, it's already reset - for i in range(( right_child[k]).size()): - # remove everything from right and put it into left - ( right_child[k]).pop(&value, - &weight) - # push method ends up calling safe_realloc, hence `except -1` - ( left_child[k]).push(value, - weight) return 0 cdef int update(self, intp_t new_pos) except -1 nogil: @@ -1353,52 +1374,25 @@ cdef class MAE(RegressionCriterion): Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. + + Time complexity: O(new_pos - pos) (which usually is O(1)) """ cdef const float64_t[:] sample_weight = self.sample_weight cdef const intp_t[:] sample_indices = self.sample_indices - cdef void** left_child = self.left_child_ptr - cdef void** right_child = self.right_child_ptr - cdef intp_t pos = self.pos cdef intp_t end = self.end cdef intp_t i, p, k cdef float64_t w = 1.0 # Update statistics up to new_pos - # - # We are going to update right_child and left_child - # from the direction that require the least amount of - # computations, i.e. from pos to new_pos or from end to new_pos. - if (new_pos - pos) <= (end - new_pos): - for p in range(pos, new_pos): - i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] - - for k in range(self.n_outputs): - # remove y_ik and its weight w from right and add to left - ( right_child[k]).remove(self.y[i, k], w) - # push method ends up calling safe_realloc, hence except -1 - ( left_child[k]).push(self.y[i, k], w) - - self.weighted_n_left += w - else: - self.reverse_reset() - - for p in range(end - 1, new_pos - 1, -1): - i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] + for p in range(pos, new_pos): + i = sample_indices[p] - for k in range(self.n_outputs): - # remove y_ik and its weight w from left and add to right - ( left_child[k]).remove(self.y[i, k], w) - ( right_child[k]).push(self.y[i, k], w) + if sample_weight is not None: + w = sample_weight[i] - self.weighted_n_left -= w + self.weighted_n_left += w self.weighted_n_right = (self.weighted_n_node_samples - self.weighted_n_left) @@ -1418,9 +1412,10 @@ cdef class MAE(RegressionCriterion): Monotonicity constraints are only supported for single-output trees we can safely assume n_outputs == 1. """ + cdef intp_t j = self.pos - self.start return ( - ( self.left_child_ptr[0]).get_median() + - ( self.right_child_ptr[0]).get_median() + self.left_medians[j] + + self.right_medians[j] ) / 2 cdef inline bint check_monotonicity( @@ -1430,11 +1425,11 @@ cdef class MAE(RegressionCriterion): float64_t upper_bound, ) noexcept nogil: """Check monotonicity constraint is satisfied at the current regression split""" - cdef: - float64_t value_left = ( self.left_child_ptr[0]).get_median() - float64_t value_right = ( self.right_child_ptr[0]).get_median() + cdef intp_t j = self.pos - self.start - return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) + return self._check_monotonicity( + monotonic_cst, lower_bound, upper_bound, + self.left_medians[j], self.right_medians[j]) cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. @@ -1466,44 +1461,21 @@ cdef class MAE(RegressionCriterion): i.e. the impurity of the left child (sample_indices[start:pos]) and the impurity the right child (sample_indices[pos:end]). - """ - cdef const float64_t[:] sample_weight = self.sample_weight - cdef const intp_t[:] sample_indices = self.sample_indices - - cdef intp_t start = self.start - cdef intp_t pos = self.pos - cdef intp_t end = self.end - cdef intp_t i, p, k - cdef float64_t median - cdef float64_t w = 1.0 + Time complexity: O(n_outputs) + """ + cdef intp_t j = self.pos - self.start + cdef intp_t k cdef float64_t impurity_left = 0.0 cdef float64_t impurity_right = 0.0 - cdef void** left_child = self.left_child_ptr - cdef void** right_child = self.right_child_ptr - for k in range(self.n_outputs): - median = ( left_child[k]).get_median() - for p in range(start, pos): - i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] - - impurity_left += fabs(self.y[i, k] - median) * w + impurity_left += self.left_abs_errors[k, j] p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs) for k in range(self.n_outputs): - median = ( right_child[k]).get_median() - for p in range(pos, end): - i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] - - impurity_right += fabs(self.y[i, k] - median) * w + impurity_right += self.right_abs_errors[k, j] p_impurity_right[0] = impurity_right / (self.weighted_n_right * self.n_outputs) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index bc1d7668187d7..9eb059b99130b 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -28,7 +28,6 @@ ctypedef fused realloc_ptr: (float32_t*) (intp_t*) (uint8_t*) - (WeightedPQueueRecord*) (float64_t*) (float64_t**) (Node*) @@ -50,51 +49,3 @@ cdef float64_t rand_uniform(float64_t low, float64_t high, cdef float64_t log(float64_t x) noexcept nogil - -# ============================================================================= -# WeightedPQueue data structure -# ============================================================================= - -# A record stored in the WeightedPQueue -cdef struct WeightedPQueueRecord: - float64_t data - float64_t weight - -cdef class WeightedPQueue: - cdef intp_t capacity - cdef intp_t array_ptr - cdef WeightedPQueueRecord* array_ - - cdef bint is_empty(self) noexcept nogil - cdef int reset(self) except -1 nogil - cdef intp_t size(self) noexcept nogil - cdef int push(self, float64_t data, float64_t weight) except -1 nogil - cdef int remove(self, float64_t data, float64_t weight) noexcept nogil - cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil - cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil - cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil - cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil - - -# ============================================================================= -# WeightedMedianCalculator data structure -# ============================================================================= - -cdef class WeightedMedianCalculator: - cdef intp_t initial_capacity - cdef WeightedPQueue samples - cdef float64_t total_weight - cdef intp_t k - cdef float64_t sum_w_0_k # represents sum(weights[0:k]) = w[0] + w[1] + ... + w[k-1] - cdef intp_t size(self) noexcept nogil - cdef int push(self, float64_t data, float64_t weight) except -1 nogil - cdef int reset(self) except -1 nogil - cdef int update_median_parameters_post_push( - self, float64_t data, float64_t weight, - float64_t original_median) noexcept nogil - cdef int remove(self, float64_t data, float64_t weight) noexcept nogil - cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil - cdef int update_median_parameters_post_remove( - self, float64_t data, float64_t weight, - float64_t original_median) noexcept nogil - cdef float64_t get_median(self) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index c5e936ae48eb1..f98a546543a87 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -66,379 +66,178 @@ cdef inline float64_t log(float64_t x) noexcept nogil: return ln(x) / ln(2.0) # ============================================================================= -# WeightedPQueue data structure +# WeightedHeap data structure # ============================================================================= -cdef class WeightedPQueue: - """A priority queue class, always sorted in increasing order. +cdef class WeightedHeap: + """Binary heap with per-item weights, supporting min-heap and max-heap modes. + + Values are stored sign-adjusted internally so that the ordering logic + is always "min-heap" on the stored buffer: + - if min_heap: store v + - else (max-heap): store -v Attributes ---------- capacity : intp_t - The capacity of the priority queue. + Allocated capacity for the heap arrays. + + size_ : intp_t + Current number of elements in the heap. + + heap_ : float64_t* + Array of (possibly sign-adjusted) values that determines ordering. + + weights_ : float64_t* + Parallel array of weights. + + total_weight : float64_t + Sum of all weights currently in the heap. - array_ptr : intp_t - The water mark of the priority queue; the priority queue grows from - left to right in the array ``array_``. ``array_ptr`` is always - less than ``capacity``. + weighted_sum : float64_t + Sum over items of (original_value * weight), i.e. without sign-adjustment. - array_ : WeightedPQueueRecord* - The array of priority queue records. The minimum element is on the - left at index 0, and the maximum element is on the right at index - ``array_ptr-1``. + min_heap : bint + If True, behaves as a min-heap; if False, behaves as a max-heap. """ - def __cinit__(self, intp_t capacity): + def __cinit__(self, intp_t capacity, bint min_heap=True): + if capacity <= 0: + capacity = 1 self.capacity = capacity - self.array_ptr = 0 - safe_realloc(&self.array_, capacity) + self.size_ = 0 + self.min_heap = min_heap + self.total_weight = 0.0 + self.weighted_sum = 0.0 + self.heap_ = NULL + self.weights_ = NULL + # safe_realloc can raise MemoryError -> __cinit__ may propagate + safe_realloc(&self.heap_, capacity) + safe_realloc(&self.weights_, capacity) def __dealloc__(self): - free(self.array_) + if self.heap_ != NULL: + free(self.heap_) + if self.weights_ != NULL: + free(self.weights_) cdef int reset(self) except -1 nogil: - """Reset the WeightedPQueue to its state at construction - - Return -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - """ - self.array_ptr = 0 - # Since safe_realloc can raise MemoryError, use `except -1` - safe_realloc(&self.array_, self.capacity) + """Reset to construction state (keeps capacity).""" + self.size_ = 0 + self.total_weight = 0.0 + self.weighted_sum = 0.0 + # Ensure buffers still allocated (realloc may raise MemoryError) + safe_realloc(&self.heap_, self.capacity) + safe_realloc(&self.weights_, self.capacity) return 0 cdef bint is_empty(self) noexcept nogil: - return self.array_ptr <= 0 + return self.size_ == 0 cdef intp_t size(self) noexcept nogil: - return self.array_ptr + return self.size_ - cdef int push(self, float64_t data, float64_t weight) except -1 nogil: - """Push record on the array. + cdef int push(self, float64_t value, float64_t weight) except -1 nogil: + """Insert a (value, weight). Returns 0 or raises MemoryError on alloc fail.""" + cdef intp_t n = self.size_ + cdef float64_t stored = value if self.min_heap else -value - Return -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - """ - cdef intp_t array_ptr = self.array_ptr - cdef WeightedPQueueRecord* array = NULL - cdef intp_t i - - # Resize if capacity not sufficient - if array_ptr >= self.capacity: + if n >= self.capacity: self.capacity *= 2 - # Since safe_realloc can raise MemoryError, use `except -1` - safe_realloc(&self.array_, self.capacity) - - # Put element as last element of array - array = self.array_ - array[array_ptr].data = data - array[array_ptr].weight = weight - - # bubble last element up according until it is sorted - # in ascending order - i = array_ptr - while(i != 0 and array[i].data < array[i-1].data): - array[i], array[i-1] = array[i-1], array[i] - i -= 1 - - # Increase element count - self.array_ptr = array_ptr + 1 - return 0 - - cdef int remove(self, float64_t data, float64_t weight) noexcept nogil: - """Remove a specific value/weight record from the array. - Returns 0 if successful, -1 if record not found.""" - cdef intp_t array_ptr = self.array_ptr - cdef WeightedPQueueRecord* array = self.array_ - cdef intp_t idx_to_remove = -1 - cdef intp_t i + safe_realloc(&self.heap_, self.capacity) + safe_realloc(&self.weights_, self.capacity) - if array_ptr <= 0: - return -1 - - # find element to remove - for i in range(array_ptr): - if array[i].data == data and array[i].weight == weight: - idx_to_remove = i - break + self.heap_[n] = stored + self.weights_[n] = weight + self.size_ = n + 1 - if idx_to_remove == -1: - return -1 - - # shift the elements after the removed element - # to the left. - for i in range(idx_to_remove, array_ptr-1): - array[i] = array[i+1] + self.total_weight += weight + self.weighted_sum += value * weight - self.array_ptr = array_ptr - 1 + self._perc_up(n) return 0 - cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil: - """Remove the top (minimum) element from array. - Returns 0 if successful, -1 if nothing to remove.""" - cdef intp_t array_ptr = self.array_ptr - cdef WeightedPQueueRecord* array = self.array_ - cdef intp_t i - - if array_ptr <= 0: + cdef int pop(self, float64_t* value, float64_t* weight) noexcept nogil: + """Pop top element into pointers. Returns 0 on success, -1 if empty.""" + cdef intp_t n = self.size_ + if n == 0: return -1 - data[0] = array[0].data - weight[0] = array[0].weight - - # shift the elements after the removed element - # to the left. - for i in range(0, array_ptr-1): - array[i] = array[i+1] - - self.array_ptr = array_ptr - 1 + self._peek_raw(value, weight) + + # Update aggregates with *original* value (undo sign for max-heap) + cdef float64_t orig_v = value[0] + cdef float64_t w = weight[0] + self.total_weight -= w + self.weighted_sum -= orig_v * w + + # Move last to root and sift down + n -= 1 + self.size_ = n + if n > 0: + self.heap_[0] = self.heap_[n] + self.weights_[0] = self.weights_[n] + self._perc_down(0) return 0 - cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil: - """Write the top element from array to a pointer. - Returns 0 if successful, -1 if nothing to write.""" - cdef WeightedPQueueRecord* array = self.array_ - if self.array_ptr <= 0: + cdef int peek(self, float64_t* value, float64_t* weight) noexcept nogil: + """Write top element into pointers without removing it. Returns 0, or -1 if empty.""" + if self.size_ == 0: return -1 - # Take first value - data[0] = array[0].data - weight[0] = array[0].weight + self._peek_raw(value, weight) return 0 - cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil: - """Given an index between [0,self.current_capacity], access - the appropriate heap and return the requested weight""" - cdef WeightedPQueueRecord* array = self.array_ - - # get weight at index - return array[index].weight - - cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil: - """Given an index between [0,self.current_capacity], access - the appropriate heap and return the requested value""" - cdef WeightedPQueueRecord* array = self.array_ - - # get value at index - return array[index].data - -# ============================================================================= -# WeightedMedianCalculator data structure -# ============================================================================= - -cdef class WeightedMedianCalculator: - """A class to handle calculation of the weighted median from streams of - data. To do so, it maintains a parameter ``k`` such that the sum of the - weights in the range [0,k) is greater than or equal to half of the total - weight. By minimizing the value of ``k`` that fulfills this constraint, - calculating the median is done by either taking the value of the sample - at index ``k-1`` of ``samples`` (samples[k-1].data) or the average of - the samples at index ``k-1`` and ``k`` of ``samples`` - ((samples[k-1] + samples[k]) / 2). - - Attributes - ---------- - initial_capacity : intp_t - The initial capacity of the WeightedMedianCalculator. - - samples : WeightedPQueue - Holds the samples (consisting of values and their weights) used in the - weighted median calculation. - - total_weight : float64_t - The sum of the weights of items in ``samples``. Represents the total - weight of all samples used in the median calculation. - - k : intp_t - Index used to calculate the median. - - sum_w_0_k : float64_t - The sum of the weights from samples[0:k]. Used in the weighted - median calculation; minimizing the value of ``k`` such that - ``sum_w_0_k`` >= ``total_weight / 2`` provides a mechanism for - calculating the median in constant time. - - """ - - def __cinit__(self, intp_t initial_capacity): - self.initial_capacity = initial_capacity - self.samples = WeightedPQueue(initial_capacity) - self.total_weight = 0 - self.k = 0 - self.sum_w_0_k = 0 - - cdef intp_t size(self) noexcept nogil: - """Return the number of samples in the - WeightedMedianCalculator""" - return self.samples.size() - - cdef int reset(self) except -1 nogil: - """Reset the WeightedMedianCalculator to its state at construction - - Return -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - """ - # samples.reset (WeightedPQueue.reset) uses safe_realloc, hence - # except -1 - self.samples.reset() - self.total_weight = 0 - self.k = 0 - self.sum_w_0_k = 0 - return 0 - - cdef int push(self, float64_t data, float64_t weight) except -1 nogil: - """Push a value and its associated weight to the WeightedMedianCalculator - - Return -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - """ - cdef int return_value - cdef float64_t original_median = 0.0 - - if self.size() != 0: - original_median = self.get_median() - # samples.push (WeightedPQueue.push) uses safe_realloc, hence except -1 - return_value = self.samples.push(data, weight) - self.update_median_parameters_post_push(data, weight, - original_median) - return return_value - - cdef int update_median_parameters_post_push( - self, float64_t data, float64_t weight, - float64_t original_median) noexcept nogil: - """Update the parameters used in the median calculation, - namely `k` and `sum_w_0_k` after an insertion""" - - # trivial case of one element. - if self.size() == 1: - self.k = 1 - self.total_weight = weight - self.sum_w_0_k = self.total_weight - return 0 - - # get the original weighted median - self.total_weight += weight + cdef float64_t get_total_weight(self) noexcept nogil: + return self.total_weight + + cdef float64_t get_weighted_sum(self) noexcept nogil: + return self.weighted_sum + + # ---------------------------- + # Internal helpers (nogil) + # ---------------------------- + + cdef void _peek_raw(self, float64_t* value, float64_t* weight) noexcept nogil: + """Internal: read top with proper sign restoration.""" + cdef float64_t stored = self.heap_[0] + value[0] = stored if self.min_heap else -stored + weight[0] = self.weights_[0] + + cdef inline void _swap(self, intp_t i, intp_t j) noexcept nogil: + cdef float64_t tv = self.heap_[i] + cdef float64_t tw = self.weights_[i] + self.heap_[i] = self.heap_[j] + self.weights_[i] = self.weights_[j] + self.heap_[j] = tv + self.weights_[j] = tw + + cdef void _perc_up(self, intp_t i) noexcept nogil: + cdef intp_t p + while i > 0: + p = (i - 1) >> 1 + if self.heap_[i] < self.heap_[p]: + self._swap(i, p) + i = p + else: + break - if data < original_median: - # inserting below the median, so increment k and - # then update self.sum_w_0_k accordingly by adding - # the weight that was added. - self.k += 1 - # update sum_w_0_k by adding the weight added - self.sum_w_0_k += weight - - # minimize k such that sum(W[0:k]) >= total_weight / 2 - # minimum value of k is 1 - while(self.k > 1 and ((self.sum_w_0_k - - self.samples.get_weight_from_index(self.k-1)) - >= self.total_weight / 2.0)): - self.k -= 1 - self.sum_w_0_k -= self.samples.get_weight_from_index(self.k) - return 0 - - if data >= original_median: - # inserting above or at the median - # minimize k such that sum(W[0:k]) >= total_weight / 2 - while(self.k < self.samples.size() and - (self.sum_w_0_k < self.total_weight / 2.0)): - self.k += 1 - self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1) - return 0 - - cdef int remove(self, float64_t data, float64_t weight) noexcept nogil: - """Remove a value from the MedianHeap, removing it - from consideration in the median calculation - """ - cdef int return_value - cdef float64_t original_median = 0.0 - - if self.size() != 0: - original_median = self.get_median() - - return_value = self.samples.remove(data, weight) - self.update_median_parameters_post_remove(data, weight, - original_median) - return return_value - - cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil: - """Pop a value from the MedianHeap, starting from the - left and moving to the right. - """ - cdef int return_value - cdef float64_t original_median = 0.0 - - if self.size() != 0: - original_median = self.get_median() - - # no elements to pop - if self.samples.size() == 0: - return -1 + cdef void _perc_down(self, intp_t i) noexcept nogil: + cdef intp_t n = self.size_ + cdef intp_t left, right, mc + while True: + left = (i << 1) + 1 + right = left + 1 + if left >= n: + return + mc = left + if right < n and self.heap_[right] < self.heap_[left]: + mc = right + if self.heap_[i] > self.heap_[mc]: + self._swap(i, mc) + i = mc + else: + return - return_value = self.samples.pop(data, weight) - self.update_median_parameters_post_remove(data[0], - weight[0], - original_median) - return return_value - - cdef int update_median_parameters_post_remove( - self, float64_t data, float64_t weight, - float64_t original_median) noexcept nogil: - """Update the parameters used in the median calculation, - namely `k` and `sum_w_0_k` after a removal""" - # reset parameters because it there are no elements - if self.samples.size() == 0: - self.k = 0 - self.total_weight = 0 - self.sum_w_0_k = 0 - return 0 - - # trivial case of one element. - if self.samples.size() == 1: - self.k = 1 - self.total_weight -= weight - self.sum_w_0_k = self.total_weight - return 0 - - # get the current weighted median - self.total_weight -= weight - - if data < original_median: - # removing below the median, so decrement k and - # then update self.sum_w_0_k accordingly by subtracting - # the removed weight - - self.k -= 1 - # update sum_w_0_k by removing the weight at index k - self.sum_w_0_k -= weight - - # minimize k such that sum(W[0:k]) >= total_weight / 2 - # by incrementing k and updating sum_w_0_k accordingly - # until the condition is met. - while(self.k < self.samples.size() and - (self.sum_w_0_k < self.total_weight / 2.0)): - self.k += 1 - self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1) - return 0 - - if data >= original_median: - # removing above the median - # minimize k such that sum(W[0:k]) >= total_weight / 2 - while(self.k > 1 and ((self.sum_w_0_k - - self.samples.get_weight_from_index(self.k-1)) - >= self.total_weight / 2.0)): - self.k -= 1 - self.sum_w_0_k -= self.samples.get_weight_from_index(self.k) - return 0 - - cdef float64_t get_median(self) noexcept nogil: - """Write the median to a pointer, taking into account - sample weights.""" - if self.sum_w_0_k == (self.total_weight / 2.0): - # split median - return (self.samples.get_value_from_index(self.k) + - self.samples.get_value_from_index(self.k-1)) / 2.0 - if self.sum_w_0_k > (self.total_weight / 2.0): - # whole median - return self.samples.get_value_from_index(self.k-1) def _any_isnan_axis0(const float32_t[:, :] X): From f4edaa2bab93e49d0d921798d1d0224e64fa195b Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 11:15:29 +0200 Subject: [PATCH 02/62] fixed compilation errors --- sklearn/tree/_criterion.pxd | 21 ++++++++++ sklearn/tree/_criterion.pyx | 83 +++++++++++++++++++------------------ sklearn/tree/_utils.pxd | 23 ++++++++++ sklearn/tree/_utils.pyx | 26 +++++++----- 4 files changed, 103 insertions(+), 50 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 84d2e800d6a87..96e20addb4edc 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -3,6 +3,7 @@ # See _criterion.pyx for implementation details. from ..utils._typedefs cimport float64_t, int8_t, intp_t +from ._utils cimport WeightedHeap cdef class Criterion: @@ -107,3 +108,23 @@ cdef class RegressionCriterion(Criterion): cdef float64_t[::1] sum_left # Same as above, but for the left side of the split cdef float64_t[::1] sum_right # Same as above, but for the right side of the split cdef float64_t[::1] sum_missing # Same as above, but for missing values in X + + +cdef class MAE(RegressionCriterion): + + cdef float64_t[::1] node_medians + cdef float64_t[:, ::1] left_abs_errors + cdef float64_t[:, ::1] right_abs_errors + cdef float64_t[::1] left_medians + cdef float64_t[::1] right_medians + cdef WeightedHeap above + cdef WeightedHeap below + + cdef inline void _precompute_absolute_errors( + self, + intp_t k, + intp_t start, + intp_t end, + float64_t[::1] abs_errors, + float64_t[::1] medians + ) noexcept nogil diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index adacb51c145c6..392e03f95d6f5 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1186,12 +1186,6 @@ cdef class MAE(RegressionCriterion): MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true value and f_i is the predicted value.""" - cdef float64_t[::1] node_medians - cdef float64_t[:, ::1] left_abs_errors - cdef float64_t[:, ::1] right_abs_errors - cdef float64_t[::1] left_medians - cdef float64_t[::1] right_medians - def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. @@ -1221,6 +1215,9 @@ cdef class MAE(RegressionCriterion): self.left_medians = np.empty(n_samples, dtype=np.float64) self.right_medians = np.empty(n_samples, dtype=np.float64) + self.above = WeightedHeap(n_samples, True) # min-heap + self.below = WeightedHeap(n_samples, False) # max-heap + cdef int init( self, const float64_t[:, ::1] y, @@ -1250,15 +1247,15 @@ cdef class MAE(RegressionCriterion): for p in range(start, end): i = sample_indices[p] - if self.sample_weight is not None: - w = self.sample_weight[i] + if sample_weight is not None: + w = sample_weight[i] self.weighted_n_node_samples += w for k in range(self.n_outputs - 1, -1, -1): # TODO: think about indices alignment here and in update/children_impurity etc. # it's likely wrong for now - self._precompute_absolute_errors(k, start, end, self.left_abs_errors, self.left_medians) - self._precompute_absolute_errors(k, end - 1, start - 1, self.right_abs_errors, self.right_medians) + self._precompute_absolute_errors(k, start, end, self.left_abs_errors[k], self.left_medians) + self._precompute_absolute_errors(k, end - 1, start - 1, self.right_abs_errors[k], self.right_medians) self.node_medians[k] = self.right_medians[0] # Reset to pos=start @@ -1266,11 +1263,13 @@ cdef class MAE(RegressionCriterion): return 0 cdef void _precompute_absolute_errors( - intp_t k, - intp_t start, - intp_t end, - float64_t[:, ::1] abs_errors, - float64_t[::1] medians) noexcept nogil: + self, + intp_t k, + intp_t start, + intp_t end, + float64_t[::1] abs_errors, + float64_t[::1] medians + ) noexcept nogil: """Fill `abs_errors` with prefix minimum AEs for (y[:i], w[:i]), i in [1, n-1]. Parameters @@ -1282,7 +1281,10 @@ cdef class MAE(RegressionCriterion): abs_errors : 1D float64_t[::1] Output buffer, must have shape (n,). """ - cdef intp_t n, step, j, p, i + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices + cdef const float64_t[:, ::1] ys = self.y + cdef intp_t step, j, p, i if start < end: step = 1 j = 0 @@ -1290,8 +1292,8 @@ cdef class MAE(RegressionCriterion): step = -1 j = self.n_node_samples - 1 - cdef WeightedHeap above = WeightedHeap(self.n_node_samples, True) # min-heap - cdef WeightedHeap below = WeightedHeap(self.n_node_samples, False) # max-heap + self.above.reset() + self.below.reset() cdef float64_t y cdef float64_t w = 1.0 cdef float64_t val = 0.0 @@ -1302,39 +1304,40 @@ cdef class MAE(RegressionCriterion): cdef float64_t half_weight p = start - for _ in range(n): - i = self.sample_indices[p] + for _ in range(self.n_node_samples): + i = sample_indices[p] if sample_weight is not None: - w = self.sample_weight[i] - y = self.y[i, k] + w = sample_weight[i] + y = ys[i, k] # Insert into the appropriate heap - if below.is_empty(): - above.push(y, w) + if self.below.is_empty(): + self.above.push(y, w) else: - below.peek(&below_top, &below_wt) - if y > below_top: - above.push(y, w) + if y > self.below.top(): + self.above.push(y, w) else: - below.push(y, w) + self.below.push(y, w) - half_weight = (above.get_total_weight() + below.get_total_weight()) / 2.0 + half_weight = (self.above.get_total_weight() + self.below.get_total_weight()) / 2.0 # Rebalance heaps - while above.get_total_weight() < half_weight and not below.is_empty(): - if below.pop(&val, &wt) == 0: - above.push(val, wt) - while (not above.is_empty() - and (above.get_total_weight() - above.top_weight()) > half_weight): - if above.pop(&val, &wt) == 0: - below.push(val, wt) + while self.above.get_total_weight() < half_weight and not self.below.is_empty(): + if self.below.pop(&val, &wt) == 0: + self.above.push(val, wt) + while (not self.above.is_empty() + and (self.above.get_total_weight() - self.above.top_weight()) > half_weight): + if self.above.pop(&val, &wt) == 0: + self.below.push(val, wt) # Current median - above.peek(&median, &wt) + median = self.above.top() medians[j] = median - abs_errors[k, j] = ((below.get_total_weight() - above.get_total_weight()) * median - - below.get_weighted_sum() - + above.get_weighted_sum()) + abs_errors[j] = ( + (self.below.get_total_weight() - self.above.get_total_weight()) * median + - self.below.get_weighted_sum() + + self.above.get_weighted_sum() + ) p += step j += step diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 9eb059b99130b..68b7595348a6e 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -47,5 +47,28 @@ cdef intp_t rand_int(intp_t low, intp_t high, cdef float64_t rand_uniform(float64_t low, float64_t high, uint32_t* random_state) noexcept nogil +cdef class WeightedHeap: + cdef intp_t capacity + cdef intp_t size_ + cdef float64_t* heap_ + cdef float64_t* weights_ + cdef float64_t total_weight + cdef float64_t weighted_sum + cdef bint min_heap + + cdef int reset(self) except -1 nogil + cdef bint is_empty(self) noexcept nogil + cdef intp_t size(self) noexcept nogil + cdef int push(self, float64_t value, float64_t weight) except -1 nogil + cdef int pop(self, float64_t* value, float64_t* weight) noexcept nogil + cdef float64_t get_total_weight(self) noexcept nogil + cdef float64_t get_weighted_sum(self) noexcept nogil + cdef float64_t top_weight(self) noexcept nogil + cdef float64_t top(self) noexcept nogil + cdef void _peek_raw(self, float64_t*, float64_t*) noexcept nogil + cdef void _swap(self, intp_t, intp_t) noexcept nogil + cdef void _perc_up(self, intp_t) noexcept nogil + cdef void _perc_down(self, intp_t) noexcept nogil + cdef float64_t log(float64_t x) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index f98a546543a87..2756b32ed0bc8 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -127,6 +127,7 @@ cdef class WeightedHeap: self.total_weight = 0.0 self.weighted_sum = 0.0 # Ensure buffers still allocated (realloc may raise MemoryError) + # TODO: is this really needed? safe_realloc(&self.heap_, self.capacity) safe_realloc(&self.weights_, self.capacity) return 0 @@ -180,24 +181,29 @@ cdef class WeightedHeap: self._perc_down(0) return 0 - cdef int peek(self, float64_t* value, float64_t* weight) noexcept nogil: - """Write top element into pointers without removing it. Returns 0, or -1 if empty.""" - if self.size_ == 0: - return -1 - self._peek_raw(value, weight) - return 0 - cdef float64_t get_total_weight(self) noexcept nogil: return self.total_weight cdef float64_t get_weighted_sum(self) noexcept nogil: return self.weighted_sum + cdef float64_t top_weight(self) noexcept nogil: + if self.size_ == 0: + return 0.0 + return self.weights_[0] + + cdef float64_t top(self) noexcept nogil: + if self.size_ == 0: + return 0.0 + cdef float64_t s = self.heap_[0] + return s if self.min_heap else -s + + # ---------------------------- # Internal helpers (nogil) # ---------------------------- - cdef void _peek_raw(self, float64_t* value, float64_t* weight) noexcept nogil: + cdef inline void _peek_raw(self, float64_t* value, float64_t* weight) noexcept nogil: """Internal: read top with proper sign restoration.""" cdef float64_t stored = self.heap_[0] value[0] = stored if self.min_heap else -stored @@ -211,7 +217,7 @@ cdef class WeightedHeap: self.heap_[j] = tv self.weights_[j] = tw - cdef void _perc_up(self, intp_t i) noexcept nogil: + cdef inline void _perc_up(self, intp_t i) noexcept nogil: cdef intp_t p while i > 0: p = (i - 1) >> 1 @@ -221,7 +227,7 @@ cdef class WeightedHeap: else: break - cdef void _perc_down(self, intp_t i) noexcept nogil: + cdef inline void _perc_down(self, intp_t i) noexcept nogil: cdef intp_t n = self.size_ cdef intp_t left, right, mc while True: From 01fd9b2e0dac52a6599b76c3618c1fe8fb2c0f02 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 11:38:13 +0200 Subject: [PATCH 03/62] fixed compilation errors --- sklearn/tree/_criterion.pyx | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 392e03f95d6f5..cc25c7c733c08 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1254,8 +1254,8 @@ cdef class MAE(RegressionCriterion): for k in range(self.n_outputs - 1, -1, -1): # TODO: think about indices alignment here and in update/children_impurity etc. # it's likely wrong for now - self._precompute_absolute_errors(k, start, end, self.left_abs_errors[k], self.left_medians) - self._precompute_absolute_errors(k, end - 1, start - 1, self.right_abs_errors[k], self.right_medians) + self._precompute_absolute_errors(k, start, 1, self.left_abs_errors[k], self.left_medians) + self._precompute_absolute_errors(k, end - 1, -1, self.right_abs_errors[k], self.right_medians) self.node_medians[k] = self.right_medians[0] # Reset to pos=start @@ -1266,7 +1266,7 @@ cdef class MAE(RegressionCriterion): self, intp_t k, intp_t start, - intp_t end, + intp_t step, float64_t[::1] abs_errors, float64_t[::1] medians ) noexcept nogil: @@ -1285,11 +1285,9 @@ cdef class MAE(RegressionCriterion): cdef const intp_t[:] sample_indices = self.sample_indices cdef const float64_t[:, ::1] ys = self.y cdef intp_t step, j, p, i - if start < end: - step = 1 + if step > 0: j = 0 else: - step = -1 j = self.n_node_samples - 1 self.above.reset() @@ -1472,13 +1470,15 @@ cdef class MAE(RegressionCriterion): cdef float64_t impurity_left = 0.0 cdef float64_t impurity_right = 0.0 - for k in range(self.n_outputs): - impurity_left += self.left_abs_errors[k, j] + if j > 0: + for k in range(self.n_outputs): + impurity_left += self.left_abs_errors[k, j - 1] p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs) - for k in range(self.n_outputs): - impurity_right += self.right_abs_errors[k, j] + if self.pos < self.end: + for k in range(self.n_outputs): + impurity_right += self.right_abs_errors[k, j] p_impurity_right[0] = impurity_right / (self.weighted_n_right * self.n_outputs) From 3f87b99b12c31eb1ab4127ac5253d08b714dc303 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 16:45:07 +0200 Subject: [PATCH 04/62] Moved AE computation in external helper to be able to unit-test it; added print everywhere to debug; fixed some bugs --- sklearn/tree/_criterion.pxd | 9 --- sklearn/tree/_criterion.pyx | 135 +++++++++++------------------------- sklearn/tree/_utils.pxd | 12 ++++ sklearn/tree/_utils.pyx | 114 ++++++++++++++++++++++++++++++ 4 files changed, 166 insertions(+), 104 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 96e20addb4edc..850a31224c10b 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -119,12 +119,3 @@ cdef class MAE(RegressionCriterion): cdef float64_t[::1] right_medians cdef WeightedHeap above cdef WeightedHeap below - - cdef inline void _precompute_absolute_errors( - self, - intp_t k, - intp_t start, - intp_t end, - float64_t[::1] abs_errors, - float64_t[::1] medians - ) noexcept nogil diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index cc25c7c733c08..65971ebeaec22 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -4,6 +4,7 @@ from libc.string cimport memcpy from libc.string cimport memset from libc.math cimport fabs, INFINITY +from libc.stdio cimport printf import numpy as np cimport numpy as cnp @@ -13,6 +14,7 @@ from scipy.special.cython_special cimport xlogy from ._utils cimport log from ._utils cimport WeightedHeap +from ._utils cimport precompute_absolute_errors # EPSILON is used in the Poisson criterion cdef float64_t EPSILON = 10 * np.finfo('double').eps @@ -1231,6 +1233,8 @@ cdef class MAE(RegressionCriterion): This initializes the criterion at node sample_indices[start:end] and children sample_indices[start:start] and sample_indices[start:end]. + + WARNING: sample_indices will be modified in-place externally """ cdef intp_t i, k, j cdef float64_t w = 1.0 @@ -1245,100 +1249,19 @@ cdef class MAE(RegressionCriterion): self.weighted_n_samples = weighted_n_samples self.weighted_n_node_samples = 0. + # printf("start - end: %d %d\n", start, end) + for p in range(start, end): i = sample_indices[p] if sample_weight is not None: w = sample_weight[i] + # printf(" %.2f", y[i, 0]) self.weighted_n_node_samples += w - for k in range(self.n_outputs - 1, -1, -1): - # TODO: think about indices alignment here and in update/children_impurity etc. - # it's likely wrong for now - self._precompute_absolute_errors(k, start, 1, self.left_abs_errors[k], self.left_medians) - self._precompute_absolute_errors(k, end - 1, -1, self.right_abs_errors[k], self.right_medians) - self.node_medians[k] = self.right_medians[0] - # Reset to pos=start self.reset() return 0 - cdef void _precompute_absolute_errors( - self, - intp_t k, - intp_t start, - intp_t step, - float64_t[::1] abs_errors, - float64_t[::1] medians - ) noexcept nogil: - """Fill `abs_errors` with prefix minimum AEs for (y[:i], w[:i]), i in [1, n-1]. - - Parameters - ---------- - y : 1D float64_t[::1] - Values. - w : 1D float64_t[::1] - Sample weights. - abs_errors : 1D float64_t[::1] - Output buffer, must have shape (n,). - """ - cdef const float64_t[:] sample_weight = self.sample_weight - cdef const intp_t[:] sample_indices = self.sample_indices - cdef const float64_t[:, ::1] ys = self.y - cdef intp_t step, j, p, i - if step > 0: - j = 0 - else: - j = self.n_node_samples - 1 - - self.above.reset() - self.below.reset() - cdef float64_t y - cdef float64_t w = 1.0 - cdef float64_t val = 0.0 - cdef float64_t wt = 0.0 - cdef float64_t below_top = 0.0 - cdef float64_t below_wt = 0.0 - cdef float64_t median = 0.0 - cdef float64_t half_weight - - p = start - for _ in range(self.n_node_samples): - i = sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] - y = ys[i, k] - - # Insert into the appropriate heap - if self.below.is_empty(): - self.above.push(y, w) - else: - if y > self.below.top(): - self.above.push(y, w) - else: - self.below.push(y, w) - - half_weight = (self.above.get_total_weight() + self.below.get_total_weight()) / 2.0 - - # Rebalance heaps - while self.above.get_total_weight() < half_weight and not self.below.is_empty(): - if self.below.pop(&val, &wt) == 0: - self.above.push(val, wt) - while (not self.above.is_empty() - and (self.above.get_total_weight() - self.above.top_weight()) > half_weight): - if self.above.pop(&val, &wt) == 0: - self.below.push(val, wt) - - # Current median - median = self.above.top() - medians[j] = median - abs_errors[j] = ( - (self.below.get_total_weight() - self.above.get_total_weight()) * median - - self.below.get_weighted_sum() - + self.above.get_weighted_sum() - ) - p += step - j += step - cdef void init_missing(self, intp_t n_missing) noexcept nogil: """Raise error if n_missing != 0.""" if n_missing == 0: @@ -1352,23 +1275,38 @@ cdef class MAE(RegressionCriterion): Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. """ + if False: + printf("Reset\n") + + printf("indices:") + for p in range(self.start, self.end): + printf(" %d", self.sample_indices[p]) + printf("\n") + self.weighted_n_left = 0.0 self.weighted_n_right = self.weighted_n_node_samples self.pos = self.start + for k in range(self.n_outputs - 1, -1, -1): + precompute_absolute_errors( + self.y, self.sample_weight, self.sample_indices, self.above, self.below, + k, self.start, self.end, self.left_abs_errors[k], self.left_medians + ) + precompute_absolute_errors( + self.y, self.sample_weight, self.sample_indices, self.above, self.below, + k, self.end - 1, self.start - 1, self.right_abs_errors[k], self.right_medians + ) + self.node_medians[k] = self.right_medians[0] + # printf('Node median: %.2f\n', self.right_medians[0]) + return 0 cdef int reverse_reset(self) except -1 nogil: - """Reset the criterion at pos=end. - - Returns -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. """ - self.weighted_n_right = 0.0 - self.weighted_n_left = self.weighted_n_node_samples - self.pos = self.end - - return 0 + In this class, this function is never called + (all calls are from inside other methods of other classes) + """ + return -1 cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left. @@ -1381,6 +1319,9 @@ cdef class MAE(RegressionCriterion): cdef const float64_t[:] sample_weight = self.sample_weight cdef const intp_t[:] sample_indices = self.sample_indices + # printf("update: %d->%d; i=%d\n", self.pos, new_pos, sample_indices[self.pos]) + + assert new_pos > self.pos cdef intp_t pos = self.pos cdef intp_t end = self.end cdef intp_t i, p, k @@ -1405,6 +1346,10 @@ cdef class MAE(RegressionCriterion): cdef intp_t k for k in range(self.n_outputs): dest[k] = self.node_medians[k] + # printf("Node value: %.2f\n", self.node_medians[k]) + # for p in range(self.start, self.end): + # printf("%.2f ", self.y[self.sample_indices[p], k]) + # printf("\n") cdef inline float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average @@ -1470,13 +1415,13 @@ cdef class MAE(RegressionCriterion): cdef float64_t impurity_left = 0.0 cdef float64_t impurity_right = 0.0 - if j > 0: + if self.pos > self.start: # if pos == start, left child is empty, hence impurity is 0 for k in range(self.n_outputs): impurity_left += self.left_abs_errors[k, j - 1] p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs) - if self.pos < self.end: + if self.pos < self.end: # if pos == end, right child is empty, hence impurity is 0 for k in range(self.n_outputs): impurity_right += self.right_abs_errors[k, j] p_impurity_right[0] = impurity_right / (self.weighted_n_right * diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 68b7595348a6e..bdd1560cfce74 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -70,5 +70,17 @@ cdef class WeightedHeap: cdef void _perc_up(self, intp_t) noexcept nogil cdef void _perc_down(self, intp_t) noexcept nogil +cdef void precompute_absolute_errors( + const float64_t[:, ::1] ys, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + WeightedHeap above, + WeightedHeap below, + intp_t k, + intp_t start, + intp_t end, + float64_t[::1] abs_errors, + float64_t[::1] medians +) noexcept nogil cdef float64_t log(float64_t x) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 2756b32ed0bc8..9e604fa44ef80 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -5,6 +5,8 @@ from libc.stdlib cimport free from libc.stdlib cimport realloc from libc.math cimport log as ln from libc.math cimport isnan +from libc.math cimport fabs +from libc.stdio cimport printf import numpy as np cimport numpy as cnp @@ -245,6 +247,118 @@ cdef class WeightedHeap: return +cdef void precompute_absolute_errors( + const float64_t[:, ::1] ys, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + WeightedHeap above, + WeightedHeap below, + intp_t k, + intp_t start, + intp_t end, + float64_t[::1] abs_errors, + float64_t[::1] medians +) noexcept nogil: + """Fill `abs_errors` with prefix minimum AEs for (y[:i], w[:i]), i in [1, n-1]. + + Parameters + ---------- + y : 1D float64_t[::1] + Values. + w : 1D float64_t[::1] + Sample weights. + abs_errors : 1D float64_t[::1] + Output buffer, must have shape (n,). + """ + cdef intp_t j, p, i, step, n + if start < end: + j = 0 + step = 1 + n = end - start + else: + n = start - end + step = -1 + j = n - 1 + + above.reset() + below.reset() + cdef float64_t y + cdef float64_t w = 1.0 + cdef float64_t val = 0.0 + cdef float64_t wt = 0.0 + cdef float64_t below_top = 0.0 + cdef float64_t below_wt = 0.0 + cdef float64_t median = 0.0 + cdef float64_t half_weight + + p = start + for _ in range(n): + i = sample_indices[p] + if sample_weight is not None: + w = sample_weight[i] + y = ys[i, k] + + # Insert into the appropriate heap + if below.is_empty(): + above.push(y, w) + else: + if y > below.top(): + above.push(y, w) + else: + below.push(y, w) + + half_weight = (above.get_total_weight() + below.get_total_weight()) / 2.0 + + # Rebalance heaps + while above.get_total_weight() < half_weight and not below.is_empty(): + if below.pop(&val, &wt) == 0: + above.push(val, wt) + while (not above.is_empty() + and (above.get_total_weight() - above.top_weight()) >= half_weight): + if above.pop(&val, &wt) == 0: + below.push(val, wt) + + # Current median + if above.get_total_weight() > half_weight + 1e-5 * fabs(half_weight): + median = above.top() + else: # above and below weight are almost exaclty equals + median = (above.top() + below.top()) / 2. + medians[j] = median + abs_errors[j] = ( + (below.get_total_weight() - above.get_total_weight()) * median + - below.get_weighted_sum() + + above.get_weighted_sum() + ) + p += step + j += step + +def _py_precompute_absolute_errors( + const float64_t[:, ::1] ys, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + bint suffix=False +): + """ For testing """ + cdef: + intp_t n = sample_weight.size + WeightedHeap above = WeightedHeap(n, True) + WeightedHeap below = WeightedHeap(n, False) + intp_t k = 0 + intp_t start = 0 + intp_t end = n + float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) + float64_t[::1] medians = np.zeros(n, dtype=np.float64) + + if suffix: + start = n - 1 + end = -1 + + precompute_absolute_errors( + ys, sample_weight, sample_indices, above, below, + k, start, end, abs_errors, medians + ) + return np.asarray(abs_errors) + def _any_isnan_axis0(const float32_t[:, :] X): """Same as np.any(np.isnan(X), axis=0)""" From e8adf96d7e8f65e77128f6c5ecca06c750b004a4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 17:30:35 +0200 Subject: [PATCH 05/62] WIP some additional tests that helped me, some will be kept in my final PR but not all --- sklearn/tree/tests/test_mae_split.py | 304 +++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 sklearn/tree/tests/test_mae_split.py diff --git a/sklearn/tree/tests/test_mae_split.py b/sklearn/tree/tests/test_mae_split.py new file mode 100644 index 0000000000000..cb2bb622fa98e --- /dev/null +++ b/sklearn/tree/tests/test_mae_split.py @@ -0,0 +1,304 @@ +import numpy +from sklearn.tree import DecisionTreeRegressor +from sklearn.tree._utils import _py_precompute_absolute_errors + +if False: + def test_first_split(): + reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error') + + def mae_min(y, w): + return min((np.abs(y - yi) * w).sum() for yi in y) + + def leaves_mae(l, y): + return np.array([mae_min(y[l == i], w[l == i]) for i in np.unique(l)]) + + X = np.array([ + [ 2.38, 3.13], + [-0.87, 0.24], + [ 3.42, 2.74], + [ 1.43, 2.57], + [ 0.86, 0.26] + ]) + y = np.array([0.784, 0.654, 1.125, 2.010, 0.614]) + w = np.array([0.622, 1.356, 1.206, 0.912, 1.424]) + + leaves = reg.fit(X, y, sample_weight=w).apply(X) + print(leaves) + assert leaves_mae(leaves, y).sum() < 1.1 + + +def sample_X_y_w(n): + x_true = (numpy.random.rand(n) > 0.5).astype(float) + X = numpy.array([ + numpy.random.randn(n) + 2*x_true, + numpy.round(2*numpy.random.rand(n) + 2*x_true, 2) + ]).T + X_pred = numpy.array([ + numpy.random.randn(n) + 2*x_true, + 2*numpy.random.rand(n) + 2*x_true + ]).T + y = numpy.random.rand(n) + (numpy.random.rand(n) + 0.5) * x_true + w = 0.5 + numpy.random.rand(n) + return X, y, w, X_pred + + + +def test_absolute_errors_precomputation_function(): + """ + Test the main bit of logic of the MAE(RegressionCriterion) class + (used by DecisionTreeRegressor()) + + The implemation of the criterion "repose" on an efficient precomputation + of left/right children absolute error for each split. This test verifies this + part of the computation, in case of major refactor of the MAE class, it can be safely removed + """ + + def compute_abs_error(y: numpy.ndarray, w: numpy.ndarray): + # 1) compute the weighted median + # i.e. once ordered by y, search for i such that: + # sum(w[:i]) <= 1/2 and sum(w[i+1:]) <= 1/2 + sorter = numpy.argsort(y) + wc = numpy.cumsum(w[sorter]) + idx = numpy.searchsorted(wc, wc[-1] / 2) + median = y[sorter[idx]] + print(y, median) + # 2) compute the AE + return (numpy.abs(y - median) * w).sum() + + def compute_prefix_abs_errors_naive(y: numpy.ndarray, w: numpy.ndarray): + y = y.ravel() + return numpy.array([compute_abs_error(y[:i], w[:i]) for i in range(1, y.size + 1)]) + + + for n in [3, 5, 10, 20, 100, 300]: + y = numpy.random.uniform(size=(n, 1)) + w = numpy.random.rand(n) + indices = numpy.arange(n) + abs_errors = _py_precompute_absolute_errors(y, w, indices) + expected = compute_prefix_abs_errors_naive(y, w) + assert numpy.allclose(abs_errors, expected) + + abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) + expected = compute_prefix_abs_errors_naive(y[::-1], w[::-1])[::-1] + assert numpy.allclose(abs_errors, expected) + + x = numpy.random.rand(n) + indices = numpy.argsort(x) + w[:] = 1 + y_sorted = y[indices] + w_sorted = w[indices] + + abs_errors = _py_precompute_absolute_errors(y, w, indices) + expected = compute_prefix_abs_errors_naive(y_sorted, w_sorted) + assert numpy.allclose(abs_errors, expected) + + abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) + expected = compute_prefix_abs_errors_naive(y_sorted[::-1], w_sorted[::-1])[::-1] + assert numpy.allclose(abs_errors, expected) + + + +def test_first_split(): + + def mae_min(y, w): + return min((numpy.abs(y - yi) * w).sum() for yi in y) + + def leaves_mae(l, y, w=None): + if w is None: + w = numpy.ones(y.size) + return numpy.array([mae_min(y[l == i], w[l == i]) for i in numpy.unique(l)]) + + it = 0 + for n in [5]*100 + [10]*100 + [100]*100 + [1000]*10 + [10_000]*3: + it += 1 + X, y, w, _ = sample_X_y_w(n) + + reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error') + sk_leaves = reg.fit(X, y, sample_weight=w).apply(X) + h_leaves = fit_apply(X, y, X_apply=X, sample_weights=w) + are_leaves_the_same = (sk_leaves == h_leaves).all() or (sk_leaves == (3 - h_leaves)).all() + if not are_leaves_the_same: + sk_mae = leaves_mae(sk_leaves, y, w).sum() + h_mae = leaves_mae(h_leaves, y, w).sum() + assert numpy.isclose(sk_mae, h_mae), it + + for n in [5]*100 + [10]*100 + [100]*100 + [1000]*10 + [10_000]*3: + it += 1 + X, y, _, _ = sample_X_y_w(n) + reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error') + sk_leaves = reg.fit(X, y).apply(X) + h_leaves = fit_apply(X, y, X) + are_leaves_the_same = (sk_leaves == h_leaves).all() or (sk_leaves == (3 - h_leaves)).all() + if not are_leaves_the_same: + sk_mae = leaves_mae(sk_leaves, y).sum() + h_mae = leaves_mae(h_leaves, y).sum() + assert numpy.isclose(sk_mae, h_mae), it + + +def fit_apply(X, y, X_apply=None, sample_weights=None): + if sample_weights is None: + sample_weights = numpy.ones(y.size) + X_apply = X if X_apply is None else X_apply + best_mae, best_feature, best_threshold = numpy.inf, -1, numpy.nan + for k, x in enumerate(X.T): + threshold, split_mae = min_mae_split(x, y, sample_weights) + if split_mae < best_mae: + best_mae = split_mae + best_feature = k + best_threshold = threshold + return (X_apply[:, best_feature] >= best_threshold) + 1 + + +def min_mae_split(x, y, w): + """ + Find the best split of x that minimizes the sum of left and right MAEs. + + Sorts and deduplicates x, y, w, then computes the MAE for all possible splits using splits_left_mae. + Returns the split value and the corresponding MAE. + + Parameters + ---------- + x : np.ndarray + Feature values. + y : np.ndarray + Target values. + w : np.ndarray + Sample weights. + + Returns + ------- + x_split : float + The value of x at which to split. + split_mae : float + The minimum sum of left and right MAEs. + """ + sorter = numpy.argsort(x) + x = x[sorter] + y = y[sorter] + w = w[sorter] + prefix_maes = compute_prefix_maes(y, w) + suffix_maes = compute_prefix_maes(y[::-1], w[::-1])[::-1] + maes = prefix_maes + suffix_maes # size: n-1 + maes[x[:-1] == x[1:]] = numpy.inf # impossible to split between 2 points that are exactly equals + best_split = numpy.argmin(maes) + # Choose split point between best_split and its neighbor with lower MAE + x_split = (x[best_split] + x[best_split + 1]) / 2 + split_mae = maes[best_split] + return x_split, split_mae + + +def compute_prefix_maes(y: numpy.ndarray, w: numpy.ndarray): + """ + Compute the minimum mean absolute error (MAE) for all (y[:i], w[:i]) with i ranging in [1, n-1] + O(n log n) complexity, expect for patological cases (w growing faster than x^2) + + Parameters + ---------- + y : numpy.ndarray + Array of target values + w : numpy.ndarray + Array of sample weights. + Returns + ------- + maes : numpy.ndarray + Prefix array of MAE values + """ + n = y.size + above = WeightedHeap(n, True) # Min-heap for values above the median + below = WeightedHeap(n, False) # Max-heap for values below the median + maes = numpy.full(n-1, numpy.inf) + for i in range(n - 1): + # Insert y[i] into the appropriate heap + if above.empty() or y[i] > below.top(): + above.push(y[i], w[i]) + else: + below.push(y[i], w[i]) + + half_weight = (above.total_weight + below.total_weight) / 2 + # Rebalance the heaps, we want to ensure that: + # above.total_weight >= 1/2 and above.total_weight - above.top_weight() <= 1/2 + # which ensures that above.top() is a weighted median of the heap + # and in particular, an argmin for the MAE + while above.total_weight < half_weight: + yt, wt = below.pop() + above.push(yt, wt) + while above.total_weight - above.top_weight() > half_weight: + yt, wt = above.pop() + below.push(yt, wt) + + median = above.top() # Current weighted median + # Compute MAE for this split + maes[i] = ( + (below.total_weight - above.total_weight) * median + - below.weighted_sum + + above.weighted_sum + ) + return maes + + +class WeightedHeap: + + def __init__(self, max_size, min_heap=True): + self.heap = numpy.zeros(max_size, dtype=numpy.float64) + self.weights = numpy.zeros(max_size, dtype=numpy.float64) + self.total_weight = 0 + self.weighted_sum = 0 + self.size = 0 + self.min_heap = min_heap + + def empty(self): + return self.size == 0 + + def push(self, val, weight): + self.heap[self.size] = val if self.min_heap else -val + self.weights[self.size] = weight + self.total_weight += weight + self.weighted_sum += val * weight + self.size += 1 + self._perc_up(self.size - 1) + + def swap(self, i, j): + self.heap[i], self.heap[j] = self.heap[j], self.heap[i] + self.weights[i], self.weights[j] = self.weights[j], self.weights[i] + + def top(self): + return self.heap[0] if self.min_heap else -self.heap[0] + + def top_weight(self): + return self.weights[0] + + def pop(self): + retv = self.top() + retw = self.top_weight() + self.size -= 1 + self.total_weight -= retw + self.weighted_sum -= retv * retw + self.heap[0] = self.heap[self.size] + self.weights[0] = self.weights[self.size] + self._perc_down(0) + return retv, retw + + def _perc_up(self, i): + p = (i - 1) >> 1 + while p >= 0: + if self.heap[i] < self.heap[p]: + self.swap(i, p) + i = p + p = (i - 1) >> 1 + + def _perc_down(self, i): + while (i << 1) + 2 <= self.size: + mc_i = self._min_child_node(i) + if self.heap[i] > self.heap[mc_i]: + self.swap(i, mc_i) + i = mc_i + + def _min_child_node(self, i): + if (i << 1) + 2 == self.size: + return (i << 1) | 1 + else: + if self.heap[(i << 1) | 1] < self.heap[(i << 1) + 2]: + return (i << 1) | 1 + else: + return (i << 1) + 2 + From 4ed868e1f06e1babb604a65a028d554f7ccf8fb6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 17:34:03 +0200 Subject: [PATCH 06/62] tests cleanup --- sklearn/tree/tests/test_mae_split.py | 304 --------------------------- sklearn/tree/tests/test_tree.py | 55 +++++ 2 files changed, 55 insertions(+), 304 deletions(-) delete mode 100644 sklearn/tree/tests/test_mae_split.py diff --git a/sklearn/tree/tests/test_mae_split.py b/sklearn/tree/tests/test_mae_split.py deleted file mode 100644 index cb2bb622fa98e..0000000000000 --- a/sklearn/tree/tests/test_mae_split.py +++ /dev/null @@ -1,304 +0,0 @@ -import numpy -from sklearn.tree import DecisionTreeRegressor -from sklearn.tree._utils import _py_precompute_absolute_errors - -if False: - def test_first_split(): - reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error') - - def mae_min(y, w): - return min((np.abs(y - yi) * w).sum() for yi in y) - - def leaves_mae(l, y): - return np.array([mae_min(y[l == i], w[l == i]) for i in np.unique(l)]) - - X = np.array([ - [ 2.38, 3.13], - [-0.87, 0.24], - [ 3.42, 2.74], - [ 1.43, 2.57], - [ 0.86, 0.26] - ]) - y = np.array([0.784, 0.654, 1.125, 2.010, 0.614]) - w = np.array([0.622, 1.356, 1.206, 0.912, 1.424]) - - leaves = reg.fit(X, y, sample_weight=w).apply(X) - print(leaves) - assert leaves_mae(leaves, y).sum() < 1.1 - - -def sample_X_y_w(n): - x_true = (numpy.random.rand(n) > 0.5).astype(float) - X = numpy.array([ - numpy.random.randn(n) + 2*x_true, - numpy.round(2*numpy.random.rand(n) + 2*x_true, 2) - ]).T - X_pred = numpy.array([ - numpy.random.randn(n) + 2*x_true, - 2*numpy.random.rand(n) + 2*x_true - ]).T - y = numpy.random.rand(n) + (numpy.random.rand(n) + 0.5) * x_true - w = 0.5 + numpy.random.rand(n) - return X, y, w, X_pred - - - -def test_absolute_errors_precomputation_function(): - """ - Test the main bit of logic of the MAE(RegressionCriterion) class - (used by DecisionTreeRegressor()) - - The implemation of the criterion "repose" on an efficient precomputation - of left/right children absolute error for each split. This test verifies this - part of the computation, in case of major refactor of the MAE class, it can be safely removed - """ - - def compute_abs_error(y: numpy.ndarray, w: numpy.ndarray): - # 1) compute the weighted median - # i.e. once ordered by y, search for i such that: - # sum(w[:i]) <= 1/2 and sum(w[i+1:]) <= 1/2 - sorter = numpy.argsort(y) - wc = numpy.cumsum(w[sorter]) - idx = numpy.searchsorted(wc, wc[-1] / 2) - median = y[sorter[idx]] - print(y, median) - # 2) compute the AE - return (numpy.abs(y - median) * w).sum() - - def compute_prefix_abs_errors_naive(y: numpy.ndarray, w: numpy.ndarray): - y = y.ravel() - return numpy.array([compute_abs_error(y[:i], w[:i]) for i in range(1, y.size + 1)]) - - - for n in [3, 5, 10, 20, 100, 300]: - y = numpy.random.uniform(size=(n, 1)) - w = numpy.random.rand(n) - indices = numpy.arange(n) - abs_errors = _py_precompute_absolute_errors(y, w, indices) - expected = compute_prefix_abs_errors_naive(y, w) - assert numpy.allclose(abs_errors, expected) - - abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) - expected = compute_prefix_abs_errors_naive(y[::-1], w[::-1])[::-1] - assert numpy.allclose(abs_errors, expected) - - x = numpy.random.rand(n) - indices = numpy.argsort(x) - w[:] = 1 - y_sorted = y[indices] - w_sorted = w[indices] - - abs_errors = _py_precompute_absolute_errors(y, w, indices) - expected = compute_prefix_abs_errors_naive(y_sorted, w_sorted) - assert numpy.allclose(abs_errors, expected) - - abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) - expected = compute_prefix_abs_errors_naive(y_sorted[::-1], w_sorted[::-1])[::-1] - assert numpy.allclose(abs_errors, expected) - - - -def test_first_split(): - - def mae_min(y, w): - return min((numpy.abs(y - yi) * w).sum() for yi in y) - - def leaves_mae(l, y, w=None): - if w is None: - w = numpy.ones(y.size) - return numpy.array([mae_min(y[l == i], w[l == i]) for i in numpy.unique(l)]) - - it = 0 - for n in [5]*100 + [10]*100 + [100]*100 + [1000]*10 + [10_000]*3: - it += 1 - X, y, w, _ = sample_X_y_w(n) - - reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error') - sk_leaves = reg.fit(X, y, sample_weight=w).apply(X) - h_leaves = fit_apply(X, y, X_apply=X, sample_weights=w) - are_leaves_the_same = (sk_leaves == h_leaves).all() or (sk_leaves == (3 - h_leaves)).all() - if not are_leaves_the_same: - sk_mae = leaves_mae(sk_leaves, y, w).sum() - h_mae = leaves_mae(h_leaves, y, w).sum() - assert numpy.isclose(sk_mae, h_mae), it - - for n in [5]*100 + [10]*100 + [100]*100 + [1000]*10 + [10_000]*3: - it += 1 - X, y, _, _ = sample_X_y_w(n) - reg = DecisionTreeRegressor(max_depth=1, criterion='absolute_error') - sk_leaves = reg.fit(X, y).apply(X) - h_leaves = fit_apply(X, y, X) - are_leaves_the_same = (sk_leaves == h_leaves).all() or (sk_leaves == (3 - h_leaves)).all() - if not are_leaves_the_same: - sk_mae = leaves_mae(sk_leaves, y).sum() - h_mae = leaves_mae(h_leaves, y).sum() - assert numpy.isclose(sk_mae, h_mae), it - - -def fit_apply(X, y, X_apply=None, sample_weights=None): - if sample_weights is None: - sample_weights = numpy.ones(y.size) - X_apply = X if X_apply is None else X_apply - best_mae, best_feature, best_threshold = numpy.inf, -1, numpy.nan - for k, x in enumerate(X.T): - threshold, split_mae = min_mae_split(x, y, sample_weights) - if split_mae < best_mae: - best_mae = split_mae - best_feature = k - best_threshold = threshold - return (X_apply[:, best_feature] >= best_threshold) + 1 - - -def min_mae_split(x, y, w): - """ - Find the best split of x that minimizes the sum of left and right MAEs. - - Sorts and deduplicates x, y, w, then computes the MAE for all possible splits using splits_left_mae. - Returns the split value and the corresponding MAE. - - Parameters - ---------- - x : np.ndarray - Feature values. - y : np.ndarray - Target values. - w : np.ndarray - Sample weights. - - Returns - ------- - x_split : float - The value of x at which to split. - split_mae : float - The minimum sum of left and right MAEs. - """ - sorter = numpy.argsort(x) - x = x[sorter] - y = y[sorter] - w = w[sorter] - prefix_maes = compute_prefix_maes(y, w) - suffix_maes = compute_prefix_maes(y[::-1], w[::-1])[::-1] - maes = prefix_maes + suffix_maes # size: n-1 - maes[x[:-1] == x[1:]] = numpy.inf # impossible to split between 2 points that are exactly equals - best_split = numpy.argmin(maes) - # Choose split point between best_split and its neighbor with lower MAE - x_split = (x[best_split] + x[best_split + 1]) / 2 - split_mae = maes[best_split] - return x_split, split_mae - - -def compute_prefix_maes(y: numpy.ndarray, w: numpy.ndarray): - """ - Compute the minimum mean absolute error (MAE) for all (y[:i], w[:i]) with i ranging in [1, n-1] - O(n log n) complexity, expect for patological cases (w growing faster than x^2) - - Parameters - ---------- - y : numpy.ndarray - Array of target values - w : numpy.ndarray - Array of sample weights. - Returns - ------- - maes : numpy.ndarray - Prefix array of MAE values - """ - n = y.size - above = WeightedHeap(n, True) # Min-heap for values above the median - below = WeightedHeap(n, False) # Max-heap for values below the median - maes = numpy.full(n-1, numpy.inf) - for i in range(n - 1): - # Insert y[i] into the appropriate heap - if above.empty() or y[i] > below.top(): - above.push(y[i], w[i]) - else: - below.push(y[i], w[i]) - - half_weight = (above.total_weight + below.total_weight) / 2 - # Rebalance the heaps, we want to ensure that: - # above.total_weight >= 1/2 and above.total_weight - above.top_weight() <= 1/2 - # which ensures that above.top() is a weighted median of the heap - # and in particular, an argmin for the MAE - while above.total_weight < half_weight: - yt, wt = below.pop() - above.push(yt, wt) - while above.total_weight - above.top_weight() > half_weight: - yt, wt = above.pop() - below.push(yt, wt) - - median = above.top() # Current weighted median - # Compute MAE for this split - maes[i] = ( - (below.total_weight - above.total_weight) * median - - below.weighted_sum - + above.weighted_sum - ) - return maes - - -class WeightedHeap: - - def __init__(self, max_size, min_heap=True): - self.heap = numpy.zeros(max_size, dtype=numpy.float64) - self.weights = numpy.zeros(max_size, dtype=numpy.float64) - self.total_weight = 0 - self.weighted_sum = 0 - self.size = 0 - self.min_heap = min_heap - - def empty(self): - return self.size == 0 - - def push(self, val, weight): - self.heap[self.size] = val if self.min_heap else -val - self.weights[self.size] = weight - self.total_weight += weight - self.weighted_sum += val * weight - self.size += 1 - self._perc_up(self.size - 1) - - def swap(self, i, j): - self.heap[i], self.heap[j] = self.heap[j], self.heap[i] - self.weights[i], self.weights[j] = self.weights[j], self.weights[i] - - def top(self): - return self.heap[0] if self.min_heap else -self.heap[0] - - def top_weight(self): - return self.weights[0] - - def pop(self): - retv = self.top() - retw = self.top_weight() - self.size -= 1 - self.total_weight -= retw - self.weighted_sum -= retv * retw - self.heap[0] = self.heap[self.size] - self.weights[0] = self.weights[self.size] - self._perc_down(0) - return retv, retw - - def _perc_up(self, i): - p = (i - 1) >> 1 - while p >= 0: - if self.heap[i] < self.heap[p]: - self.swap(i, p) - i = p - p = (i - 1) >> 1 - - def _perc_down(self, i): - while (i << 1) + 2 <= self.size: - mc_i = self._min_child_node(i) - if self.heap[i] > self.heap[mc_i]: - self.swap(i, mc_i) - i = mc_i - - def _min_child_node(self, i): - if (i << 1) + 2 == self.size: - return (i << 1) | 1 - else: - if self.heap[(i << 1) | 1] < self.heap[(i << 1) + 2]: - return (i << 1) | 1 - else: - return (i << 1) + 2 - diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index bd8325f6e9a55..3204d8cf46943 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -47,6 +47,7 @@ _check_value_ndarray, ) from sklearn.tree._tree import Tree as CythonTree +from sklearn.tree._utils import _py_precompute_absolute_errors from sklearn.utils import compute_sample_weight from sklearn.utils._array_api import xpx from sklearn.utils._testing import ( @@ -2838,3 +2839,57 @@ def test_sort_log2_build(): ] # fmt: on assert_array_equal(samples, expected_samples) + + +def test_absolute_errors_precomputation_function(): + """ + Test the main bit of logic of the MAE(RegressionCriterion) class + (used by DecisionTreeRegressor()) + + The implemation of the criterion "repose" on an efficient precomputation + of left/right children absolute error for each split. This test verifies this + part of the computation, in case of major refactor of the MAE class, it can be safely removed + """ + + def compute_abs_error(y: np.ndarray, w: np.ndarray): + # 1) compute the weighted median + # i.e. once ordered by y, search for i such that: + # sum(w[:i]) <= 1/2 and sum(w[i+1:]) <= 1/2 + sorter = np.argsort(y) + wc = np.cumsum(w[sorter]) + idx = np.searchsorted(wc, wc[-1] / 2) + median = y[sorter[idx]] + print(y, median) + # 2) compute the AE + return (np.abs(y - median) * w).sum() + + def compute_prefix_abs_errors_naive(y: np.ndarray, w: np.ndarray): + y = y.ravel() + return np.array([compute_abs_error(y[:i], w[:i]) for i in range(1, y.size + 1)]) + + + for n in [3, 5, 10, 20, 100, 300]: + y = np.random.uniform(size=(n, 1)) + w = np.random.rand(n) + indices = np.arange(n) + abs_errors = _py_precompute_absolute_errors(y, w, indices) + expected = compute_prefix_abs_errors_naive(y, w) + assert np.allclose(abs_errors, expected) + + abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) + expected = compute_prefix_abs_errors_naive(y[::-1], w[::-1])[::-1] + assert np.allclose(abs_errors, expected) + + x = np.random.rand(n) + indices = np.argsort(x) + w[:] = 1 + y_sorted = y[indices] + w_sorted = w[indices] + + abs_errors = _py_precompute_absolute_errors(y, w, indices) + expected = compute_prefix_abs_errors_naive(y_sorted, w_sorted) + assert np.allclose(abs_errors, expected) + + abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) + expected = compute_prefix_abs_errors_naive(y_sorted[::-1], w_sorted[::-1])[::-1] + assert np.allclose(abs_errors, expected) From 83d89a43582011edaa448bd44e338ca55fde277c Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 17:38:59 +0200 Subject: [PATCH 07/62] cleanup --- sklearn/tree/_criterion.pyx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 65971ebeaec22..d2dad1c9123f0 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1314,7 +1314,7 @@ cdef class MAE(RegressionCriterion): Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. - Time complexity: O(new_pos - pos) (which usually is O(1)) + Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data) """ cdef const float64_t[:] sample_weight = self.sample_weight cdef const intp_t[:] sample_indices = self.sample_indices @@ -1360,7 +1360,7 @@ cdef class MAE(RegressionCriterion): """ cdef intp_t j = self.pos - self.start return ( - self.left_medians[j] + self.left_medians[j - 1] + self.right_medians[j] ) / 2 @@ -1375,7 +1375,7 @@ cdef class MAE(RegressionCriterion): return self._check_monotonicity( monotonic_cst, lower_bound, upper_bound, - self.left_medians[j], self.right_medians[j]) + self.left_medians[j - 1], self.right_medians[j]) cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. @@ -1383,6 +1383,8 @@ cdef class MAE(RegressionCriterion): Evaluate the MAE criterion as impurity of the current node, i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. + + Time complexity: O(n := end - start) """ cdef const float64_t[:] sample_weight = self.sample_weight cdef const intp_t[:] sample_indices = self.sample_indices From 1ca34bf8cf4a2d1003cbbd850a5b84be9ff03d6d Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 18:06:45 +0200 Subject: [PATCH 08/62] cleanup --- sklearn/tree/_criterion.pyx | 48 ++++++++++++--------------------- sklearn/tree/_utils.pyx | 6 ++--- sklearn/tree/tests/test_tree.py | 20 +++++++------- 3 files changed, 30 insertions(+), 44 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index d2dad1c9123f0..3d2ce9df6f4e2 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -4,7 +4,6 @@ from libc.string cimport memcpy from libc.string cimport memset from libc.math cimport fabs, INFINITY -from libc.stdio cimport printf import numpy as np cimport numpy as cnp @@ -1212,11 +1211,14 @@ cdef class MAE(RegressionCriterion): self.weighted_n_right = 0.0 self.node_medians = np.zeros(n_outputs, dtype=np.float64) + # FIXME? Those arrays could maybe be allocated dynamically to reduce memory + # footprint. This might even be required. self.left_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) self.right_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) self.left_medians = np.empty(n_samples, dtype=np.float64) self.right_medians = np.empty(n_samples, dtype=np.float64) + # FIXME? Same here, I could adapt the code of WeightedHeap to use dynamic allocation self.above = WeightedHeap(n_samples, True) # min-heap self.below = WeightedHeap(n_samples, False) # max-heap @@ -1235,6 +1237,7 @@ cdef class MAE(RegressionCriterion): sample_indices[start:start] and sample_indices[start:end]. WARNING: sample_indices will be modified in-place externally + after this method is called """ cdef intp_t i, k, j cdef float64_t w = 1.0 @@ -1249,13 +1252,10 @@ cdef class MAE(RegressionCriterion): self.weighted_n_samples = weighted_n_samples self.weighted_n_node_samples = 0. - # printf("start - end: %d %d\n", start, end) - for p in range(start, end): i = sample_indices[p] if sample_weight is not None: w = sample_weight[i] - # printf(" %.2f", y[i, 0]) self.weighted_n_node_samples += w # Reset to pos=start @@ -1274,14 +1274,11 @@ cdef class MAE(RegressionCriterion): Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. - """ - if False: - printf("Reset\n") - printf("indices:") - for p in range(self.start, self.end): - printf(" %d", self.sample_indices[p]) - printf("\n") + Reset might be called after an external class has changed + inplace self.sample_indices[start:end], hence re-computing + the absolute errors is needed + """ self.weighted_n_left = 0.0 self.weighted_n_right = self.weighted_n_node_samples @@ -1297,31 +1294,24 @@ cdef class MAE(RegressionCriterion): k, self.end - 1, self.start - 1, self.right_abs_errors[k], self.right_medians ) self.node_medians[k] = self.right_medians[0] - # printf('Node median: %.2f\n', self.right_medians[0]) return 0 cdef int reverse_reset(self) except -1 nogil: """ - In this class, this function is never called - (all calls are from inside other methods of other classes) + For this class, this function is never called """ return -1 cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left. + new_pos is guaranted to be greater than pos Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data) """ - cdef const float64_t[:] sample_weight = self.sample_weight - cdef const intp_t[:] sample_indices = self.sample_indices - - # printf("update: %d->%d; i=%d\n", self.pos, new_pos, sample_indices[self.pos]) - - assert new_pos > self.pos cdef intp_t pos = self.pos cdef intp_t end = self.end cdef intp_t i, p, k @@ -1329,11 +1319,9 @@ cdef class MAE(RegressionCriterion): # Update statistics up to new_pos for p in range(pos, new_pos): - i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] - + i = self.sample_indices[p] + if self.sample_weight is not None: + w = self.sample_weight[i] self.weighted_n_left += w self.weighted_n_right = (self.weighted_n_node_samples - @@ -1346,10 +1334,6 @@ cdef class MAE(RegressionCriterion): cdef intp_t k for k in range(self.n_outputs): dest[k] = self.node_medians[k] - # printf("Node value: %.2f\n", self.node_medians[k]) - # for p in range(self.start, self.end): - # printf("%.2f ", self.y[self.sample_indices[p], k]) - # printf("\n") cdef inline float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average @@ -1417,13 +1401,15 @@ cdef class MAE(RegressionCriterion): cdef float64_t impurity_left = 0.0 cdef float64_t impurity_right = 0.0 - if self.pos > self.start: # if pos == start, left child is empty, hence impurity is 0 + # if pos == start, left child is empty, hence impurity is 0 + if self.pos > self.start: for k in range(self.n_outputs): impurity_left += self.left_abs_errors[k, j - 1] p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs) - if self.pos < self.end: # if pos == end, right child is empty, hence impurity is 0 + # if pos == end, right child is empty, hence impurity is 0 + if self.pos < self.end: for k in range(self.n_outputs): impurity_right += self.right_abs_errors[k, j] p_impurity_right[0] = impurity_right / (self.weighted_n_right * diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 9e604fa44ef80..b26dbc86db2ac 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -6,7 +6,6 @@ from libc.stdlib cimport realloc from libc.math cimport log as ln from libc.math cimport isnan from libc.math cimport fabs -from libc.stdio cimport printf import numpy as np cimport numpy as cnp @@ -259,7 +258,8 @@ cdef void precompute_absolute_errors( float64_t[::1] abs_errors, float64_t[::1] medians ) noexcept nogil: - """Fill `abs_errors` with prefix minimum AEs for (y[:i], w[:i]), i in [1, n-1]. + """Fill `abs_errors` with the optimal AEs for (y[:i], w[:i]) + i in [1, n]. Parameters ---------- @@ -338,7 +338,7 @@ def _py_precompute_absolute_errors( const intp_t[:] sample_indices, bint suffix=False ): - """ For testing """ + """Used for testing precompute_absolute_errors""" cdef: intp_t n = sample_weight.size WeightedHeap above = WeightedHeap(n, True) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 3204d8cf46943..85235e36fb409 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2843,14 +2843,19 @@ def test_sort_log2_build(): def test_absolute_errors_precomputation_function(): """ - Test the main bit of logic of the MAE(RegressionCriterion) class - (used by DecisionTreeRegressor()) + Test the main bit of logic of the MAE(RegressionCriterion) class + (used by DecisionTreeRegressor(criterion="asbolute_error")). - The implemation of the criterion "repose" on an efficient precomputation + The implemation of the criterion relies on an efficient precomputation of left/right children absolute error for each split. This test verifies this - part of the computation, in case of major refactor of the MAE class, it can be safely removed + part of the computation, in case of major refactor of the MAE class, + it can be safely removed. """ + def compute_prefix_abs_errors_naive(y: np.ndarray, w: np.ndarray): + y = y.ravel() + return np.array([compute_abs_error(y[:i], w[:i]) for i in range(1, y.size + 1)]) + def compute_abs_error(y: np.ndarray, w: np.ndarray): # 1) compute the weighted median # i.e. once ordered by y, search for i such that: @@ -2863,12 +2868,7 @@ def compute_abs_error(y: np.ndarray, w: np.ndarray): # 2) compute the AE return (np.abs(y - median) * w).sum() - def compute_prefix_abs_errors_naive(y: np.ndarray, w: np.ndarray): - y = y.ravel() - return np.array([compute_abs_error(y[:i], w[:i]) for i in range(1, y.size + 1)]) - - - for n in [3, 5, 10, 20, 100, 300]: + for n in [3, 5, 10, 20, 50, 100]: y = np.random.uniform(size=(n, 1)) w = np.random.rand(n) indices = np.arange(n) From d46355846d3505b6f6559b2775d19191263057bc Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 22:30:13 +0200 Subject: [PATCH 09/62] WIP fixing linting issues --- sklearn/tree/_criterion.pyx | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 3d2ce9df6f4e2..1aceef20b83ed 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1239,7 +1239,7 @@ cdef class MAE(RegressionCriterion): WARNING: sample_indices will be modified in-place externally after this method is called """ - cdef intp_t i, k, j + cdef intp_t i cdef float64_t w = 1.0 cdef intp_t n = end - start # Initialize fields @@ -1313,8 +1313,7 @@ cdef class MAE(RegressionCriterion): Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data) """ cdef intp_t pos = self.pos - cdef intp_t end = self.end - cdef intp_t i, p, k + cdef intp_t i, p cdef float64_t w = 1.0 # Update statistics up to new_pos @@ -1368,7 +1367,7 @@ cdef class MAE(RegressionCriterion): i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. - Time complexity: O(n := end - start) + Time complexity: O(n := end - start) """ cdef const float64_t[:] sample_weight = self.sample_weight cdef const intp_t[:] sample_indices = self.sample_indices From fa993d4dd67fa553a3441cea2c145c16648a80e3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 22:33:38 +0200 Subject: [PATCH 10/62] fixed linting --- sklearn/tree/_utils.pyx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index b26dbc86db2ac..35f8373cd0955 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -199,7 +199,6 @@ cdef class WeightedHeap: cdef float64_t s = self.heap_[0] return s if self.min_heap else -s - # ---------------------------- # Internal helpers (nogil) # ---------------------------- @@ -286,8 +285,6 @@ cdef void precompute_absolute_errors( cdef float64_t w = 1.0 cdef float64_t val = 0.0 cdef float64_t wt = 0.0 - cdef float64_t below_top = 0.0 - cdef float64_t below_wt = 0.0 cdef float64_t median = 0.0 cdef float64_t half_weight @@ -313,15 +310,17 @@ cdef void precompute_absolute_errors( while above.get_total_weight() < half_weight and not below.is_empty(): if below.pop(&val, &wt) == 0: above.push(val, wt) - while (not above.is_empty() - and (above.get_total_weight() - above.top_weight()) >= half_weight): + while ( + not above.is_empty() + and (above.get_total_weight() - above.top_weight()) >= half_weight + ): if above.pop(&val, &wt) == 0: below.push(val, wt) # Current median if above.get_total_weight() > half_weight + 1e-5 * fabs(half_weight): median = above.top() - else: # above and below weight are almost exaclty equals + else: # above and below weight are almost exactly equals median = (above.top() + below.top()) / 2. medians[j] = median abs_errors[j] = ( @@ -332,6 +331,7 @@ cdef void precompute_absolute_errors( p += step j += step + def _py_precompute_absolute_errors( const float64_t[:, ::1] ys, const float64_t[:] sample_weight, From cbf54057568099227d478b0bab2da85f573db682 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 3 Sep 2025 22:38:37 +0200 Subject: [PATCH 11/62] fix spelling --- sklearn/tree/_criterion.pyx | 2 +- sklearn/tree/tests/test_tree.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 1aceef20b83ed..7919b9deee837 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1305,7 +1305,7 @@ cdef class MAE(RegressionCriterion): cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left. - new_pos is guaranted to be greater than pos + new_pos is guaranteed to be greater than pos Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 85235e36fb409..d214b1aee58e7 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2846,7 +2846,7 @@ def test_absolute_errors_precomputation_function(): Test the main bit of logic of the MAE(RegressionCriterion) class (used by DecisionTreeRegressor(criterion="asbolute_error")). - The implemation of the criterion relies on an efficient precomputation + The implementation of the criterion relies on an efficient precomputation of left/right children absolute error for each split. This test verifies this part of the computation, in case of major refactor of the MAE class, it can be safely removed. From a4bd31043b7643672954086cab3b35f13b7469f2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 4 Sep 2025 13:21:28 +0200 Subject: [PATCH 12/62] Added test that would fail before this PR --- sklearn/tree/tests/test_tree.py | 41 ++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index d214b1aee58e7..04adff64bb4c4 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1663,8 +1663,9 @@ def test_no_sparse_y_support(name, csr_container): def test_mae(): - """Check MAE criterion produces correct results on small toy dataset: + """Check MAE criterion produces correct results on small toys dataset: + ## First toy dataset ------------------ | X | y | weight | ------------------ @@ -1735,6 +1736,31 @@ def test_mae(): = 1.2 / 1.6 = 0.75 ------ + + ## Second toy dataset: + ------------------ + | X | y | weight | + ------------------ + | 1 | 1 | 3 | + | 2 | 1 | 3 | + | 3 | 3 | 2 | + | 4 | 1 | 1 | + | 5 | 2 | 2 | + ------------------ + |sum wt:| 11 | + ------------------ + + The weighted median is 1 + Total error = Absolute(1 - 3) * 2 + Absolute(1 - 2) * 2 = 6 + + The best split is between X values of 2 and 3, with: + - left node being the first 2 data points, both with y=1 + => AE and impurity is 0 + - right node being the last 3 data points, weighted median is 2. + Total error = (Absolute(2 - 3) * 2) + + (Absolute(2 - 1) * 1) + + (Absolute(2 - 2) * 2) + = 3 """ dt_mae = DecisionTreeRegressor( random_state=0, criterion="absolute_error", max_leaf_nodes=2 @@ -1761,6 +1787,19 @@ def test_mae(): assert_array_equal(dt_mae.tree_.impurity, [1.4, 1.5, 4.0 / 3.0]) assert_array_equal(dt_mae.tree_.value.flat, [4, 4.5, 4.0]) + dt_mae = DecisionTreeRegressor( + random_state=0, + criterion="absolute_error", + max_depth=1, # stop after one split + ) + dt_mae.fit( + X=[[1], [2], [3], [4], [5]], + y=[1, 1, 3, 1, 2], + sample_weight=[3, 3, 2, 1, 2], + ) + assert_allclose(dt_mae.tree_.impurity, [6 / 11, 0, 3 / 5]) + assert_array_equal(dt_mae.tree_.value.flat, [1, 1, 2]) + def test_criterion_copy(): # Let's check whether copy of our criterion has the same type From f4a0e0749cb1d284141aa4c6ad66c98f58a729dc Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 4 Sep 2025 13:35:44 +0200 Subject: [PATCH 13/62] added changed logs --- .../upcoming_changes/sklearn.tree/32100.efficiency.rst | 4 ++++ doc/whats_new/upcoming_changes/sklearn.tree/32100.fix.rst | 6 ++++++ 2 files changed, 10 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/sklearn.tree/32100.efficiency.rst create mode 100644 doc/whats_new/upcoming_changes/sklearn.tree/32100.fix.rst diff --git a/doc/whats_new/upcoming_changes/sklearn.tree/32100.efficiency.rst b/doc/whats_new/upcoming_changes/sklearn.tree/32100.efficiency.rst new file mode 100644 index 0000000000000..0df37311f22ce --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.tree/32100.efficiency.rst @@ -0,0 +1,4 @@ +- :class:`tree.DecisionTreeRegressor` with `criterion="absolute_error"` + now runs much faster: O(n log n) complexity against previous O(n^2) + allowing to scale to millions of data points, even hundred of millions. + By :user:`Arthur Lacote ` diff --git a/doc/whats_new/upcoming_changes/sklearn.tree/32100.fix.rst b/doc/whats_new/upcoming_changes/sklearn.tree/32100.fix.rst new file mode 100644 index 0000000000000..7d337131c25e6 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.tree/32100.fix.rst @@ -0,0 +1,6 @@ +- :class:`tree.DecisionTreeRegressor` with `criterion="absolute_error"` + would sometimes make sub-optimal splits + (i.e. splits that don't minimize the absolute error). + Now it's fixed. Hence retraining trees might gives slightly different + results. + By :user:`Arthur Lacote ` From a86a1900c1dbcde0c5c6a766c2d4cba35aed4c75 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 4 Sep 2025 15:52:55 +0200 Subject: [PATCH 14/62] cleanup --- sklearn/tree/_utils.pxd | 2 +- sklearn/tree/_utils.pyx | 12 +++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index bdd1560cfce74..fa0b7da8f56a2 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -56,7 +56,7 @@ cdef class WeightedHeap: cdef float64_t weighted_sum cdef bint min_heap - cdef int reset(self) except -1 nogil + cdef void reset(self) noexcept nogil cdef bint is_empty(self) noexcept nogil cdef intp_t size(self) noexcept nogil cdef int push(self, float64_t value, float64_t weight) except -1 nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 35f8373cd0955..aee23004fffbb 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -122,16 +122,11 @@ cdef class WeightedHeap: if self.weights_ != NULL: free(self.weights_) - cdef int reset(self) except -1 nogil: + cdef void reset(self) noexcept nogil: """Reset to construction state (keeps capacity).""" self.size_ = 0 self.total_weight = 0.0 self.weighted_sum = 0.0 - # Ensure buffers still allocated (realloc may raise MemoryError) - # TODO: is this really needed? - safe_realloc(&self.heap_, self.capacity) - safe_realloc(&self.weights_, self.capacity) - return 0 cdef bint is_empty(self) noexcept nogil: return self.size_ == 0 @@ -145,9 +140,8 @@ cdef class WeightedHeap: cdef float64_t stored = value if self.min_heap else -value if n >= self.capacity: - self.capacity *= 2 - safe_realloc(&self.heap_, self.capacity) - safe_realloc(&self.weights_, self.capacity) + # should never happen as capacity is set to the max possible size + return -1 self.heap_[n] = stored self.weights_[n] = weight From 092af65698f966fcc2be6afd7e6cd582013fbed7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 4 Sep 2025 17:45:36 +0200 Subject: [PATCH 15/62] comments & cleanups --- sklearn/tree/_criterion.pyx | 9 +- sklearn/tree/_utils.pxd | 14 +-- sklearn/tree/_utils.pyx | 240 ++++++++++++++++++------------------ 3 files changed, 131 insertions(+), 132 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 7919b9deee837..6ba36fb67576a 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1211,14 +1211,13 @@ cdef class MAE(RegressionCriterion): self.weighted_n_right = 0.0 self.node_medians = np.zeros(n_outputs, dtype=np.float64) - # FIXME? Those arrays could maybe be allocated dynamically to reduce memory - # footprint. This might even be required. + + # Note: this criterion has an important memory footprint, which is + # fine as it's instantiated only once to build an entire tree self.left_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) self.right_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) self.left_medians = np.empty(n_samples, dtype=np.float64) self.right_medians = np.empty(n_samples, dtype=np.float64) - - # FIXME? Same here, I could adapt the code of WeightedHeap to use dynamic allocation self.above = WeightedHeap(n_samples, True) # min-heap self.below = WeightedHeap(n_samples, False) # max-heap @@ -1367,7 +1366,7 @@ cdef class MAE(RegressionCriterion): i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. - Time complexity: O(n := end - start) + Time complexity: O(n) (n = end - start) """ cdef const float64_t[:] sample_weight = self.sample_weight cdef const intp_t[:] sample_indices = self.sample_indices diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index fa0b7da8f56a2..996daf0679c81 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -49,23 +49,19 @@ cdef float64_t rand_uniform(float64_t low, float64_t high, cdef class WeightedHeap: cdef intp_t capacity - cdef intp_t size_ - cdef float64_t* heap_ - cdef float64_t* weights_ + cdef intp_t size + cdef float64_t* heap + cdef float64_t* weights cdef float64_t total_weight cdef float64_t weighted_sum cdef bint min_heap cdef void reset(self) noexcept nogil cdef bint is_empty(self) noexcept nogil - cdef intp_t size(self) noexcept nogil - cdef int push(self, float64_t value, float64_t weight) except -1 nogil - cdef int pop(self, float64_t* value, float64_t* weight) noexcept nogil - cdef float64_t get_total_weight(self) noexcept nogil - cdef float64_t get_weighted_sum(self) noexcept nogil + cdef void push(self, float64_t value, float64_t weight) noexcept nogil + cdef void pop(self, float64_t* value, float64_t* weight) noexcept nogil cdef float64_t top_weight(self) noexcept nogil cdef float64_t top(self) noexcept nogil - cdef void _peek_raw(self, float64_t*, float64_t*) noexcept nogil cdef void _swap(self, intp_t, intp_t) noexcept nogil cdef void _perc_up(self, intp_t) noexcept nogil cdef void _perc_down(self, intp_t) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index aee23004fffbb..7b42b526a18e3 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -66,6 +66,26 @@ cdef inline float64_t rand_uniform(float64_t low, float64_t high, cdef inline float64_t log(float64_t x) noexcept nogil: return ln(x) / ln(2.0) + +def _any_isnan_axis0(const float32_t[:, :] X): + """Same as np.any(np.isnan(X), axis=0)""" + cdef: + intp_t i, j + intp_t n_samples = X.shape[0] + intp_t n_features = X.shape[1] + uint8_t[::1] isnan_out = np.zeros(X.shape[1], dtype=np.bool_) + + with nogil: + for i in range(n_samples): + for j in range(n_features): + if isnan_out[j]: + continue + if isnan(X[i, j]): + isnan_out[j] = True + break + return np.asarray(isnan_out) + + # ============================================================================= # WeightedHeap data structure # ============================================================================= @@ -78,18 +98,18 @@ cdef class WeightedHeap: - if min_heap: store v - else (max-heap): store -v - Attributes + Attributes (all should be treated as readonly attributes) ---------- capacity : intp_t Allocated capacity for the heap arrays. - size_ : intp_t + size : intp_t Current number of elements in the heap. - heap_ : float64_t* + heap : float64_t* Array of (possibly sign-adjusted) values that determines ordering. - weights_ : float64_t* + weights : float64_t* Parallel array of weights. total_weight : float64_t @@ -106,123 +126,104 @@ cdef class WeightedHeap: if capacity <= 0: capacity = 1 self.capacity = capacity - self.size_ = 0 + self.size = 0 self.min_heap = min_heap self.total_weight = 0.0 self.weighted_sum = 0.0 - self.heap_ = NULL - self.weights_ = NULL + self.heap = NULL + self.weights = NULL # safe_realloc can raise MemoryError -> __cinit__ may propagate - safe_realloc(&self.heap_, capacity) - safe_realloc(&self.weights_, capacity) + safe_realloc(&self.heap, capacity) + safe_realloc(&self.weights, capacity) def __dealloc__(self): - if self.heap_ != NULL: - free(self.heap_) - if self.weights_ != NULL: - free(self.weights_) + if self.heap != NULL: + free(self.heap) + if self.weights != NULL: + free(self.weights) cdef void reset(self) noexcept nogil: """Reset to construction state (keeps capacity).""" - self.size_ = 0 + self.size = 0 self.total_weight = 0.0 self.weighted_sum = 0.0 cdef bint is_empty(self) noexcept nogil: - return self.size_ == 0 + return self.size == 0 - cdef intp_t size(self) noexcept nogil: - return self.size_ - - cdef int push(self, float64_t value, float64_t weight) except -1 nogil: - """Insert a (value, weight). Returns 0 or raises MemoryError on alloc fail.""" - cdef intp_t n = self.size_ + cdef void push(self, float64_t value, float64_t weight) noexcept nogil: + """Insert a (value, weight).""" + cdef intp_t n = self.size cdef float64_t stored = value if self.min_heap else -value - if n >= self.capacity: - # should never happen as capacity is set to the max possible size - return -1 + assert n < self.capacity + # ^ should never raise as capacity is set to the max possible size - self.heap_[n] = stored - self.weights_[n] = weight - self.size_ = n + 1 + self.heap[n] = stored + self.weights[n] = weight + self.size = n + 1 self.total_weight += weight self.weighted_sum += value * weight self._perc_up(n) - return 0 - cdef int pop(self, float64_t* value, float64_t* weight) noexcept nogil: - """Pop top element into pointers. Returns 0 on success, -1 if empty.""" - cdef intp_t n = self.size_ - if n == 0: - return -1 + cdef void pop(self, float64_t* value, float64_t* weight) noexcept nogil: + """Pop top element into pointers.""" + cdef intp_t n = self.size + assert n > 0 - self._peek_raw(value, weight) + cdef float64_t stored = self.heap[0] + cdef float64_t v = stored if self.min_heap else -stored + cdef float64_t w = self.weights[0] + value[0] = v + weight[0] = w - # Update aggregates with *original* value (undo sign for max-heap) - cdef float64_t orig_v = value[0] - cdef float64_t w = weight[0] + # Update aggregates self.total_weight -= w - self.weighted_sum -= orig_v * w + self.weighted_sum -= v * w # Move last to root and sift down n -= 1 - self.size_ = n + self.size = n if n > 0: - self.heap_[0] = self.heap_[n] - self.weights_[0] = self.weights_[n] + self.heap[0] = self.heap[n] + self.weights[0] = self.weights[n] self._perc_down(0) - return 0 - - cdef float64_t get_total_weight(self) noexcept nogil: - return self.total_weight - - cdef float64_t get_weighted_sum(self) noexcept nogil: - return self.weighted_sum cdef float64_t top_weight(self) noexcept nogil: - if self.size_ == 0: - return 0.0 - return self.weights_[0] + assert self.size > 0 + return self.weights[0] cdef float64_t top(self) noexcept nogil: - if self.size_ == 0: - return 0.0 - cdef float64_t s = self.heap_[0] + assert self.size > 0 + cdef float64_t s = self.heap[0] return s if self.min_heap else -s # ---------------------------- # Internal helpers (nogil) # ---------------------------- - cdef inline void _peek_raw(self, float64_t* value, float64_t* weight) noexcept nogil: - """Internal: read top with proper sign restoration.""" - cdef float64_t stored = self.heap_[0] - value[0] = stored if self.min_heap else -stored - weight[0] = self.weights_[0] - cdef inline void _swap(self, intp_t i, intp_t j) noexcept nogil: - cdef float64_t tv = self.heap_[i] - cdef float64_t tw = self.weights_[i] - self.heap_[i] = self.heap_[j] - self.weights_[i] = self.weights_[j] - self.heap_[j] = tv - self.weights_[j] = tw + cdef float64_t tmp = self.heap[i] + self.heap[i] = self.heap[j] + self.heap[j] = tmp + tmp = self.weights[i] + self.weights[i] = self.weights[j] + self.weights[j] = tmp cdef inline void _perc_up(self, intp_t i) noexcept nogil: cdef intp_t p while i > 0: p = (i - 1) >> 1 - if self.heap_[i] < self.heap_[p]: + if self.heap[i] < self.heap[p]: self._swap(i, p) i = p else: break cdef inline void _perc_down(self, intp_t i) noexcept nogil: - cdef intp_t n = self.size_ + cdef intp_t n = self.size cdef intp_t left, right, mc while True: left = (i << 1) + 1 @@ -230,14 +231,17 @@ cdef class WeightedHeap: if left >= n: return mc = left - if right < n and self.heap_[right] < self.heap_[left]: + if right < n and self.heap[right] < self.heap[left]: mc = right - if self.heap_[i] > self.heap_[mc]: + if self.heap[i] > self.heap[mc]: self._swap(i, mc) i = mc else: return +# ============================================================================= +# MAE split precomputations algorithm +# ============================================================================= cdef void precompute_absolute_errors( const float64_t[:, ::1] ys, @@ -251,17 +255,30 @@ cdef void precompute_absolute_errors( float64_t[::1] abs_errors, float64_t[::1] medians ) noexcept nogil: - """Fill `abs_errors` with the optimal AEs for (y[:i], w[:i]) - i in [1, n]. - - Parameters - ---------- - y : 1D float64_t[::1] - Values. - w : 1D float64_t[::1] - Sample weights. - abs_errors : 1D float64_t[::1] - Output buffer, must have shape (n,). + """ + Fill `abs_errors` and `medians`. + + If start < end: + Computes the "prefix" AEs/medians, i.e the AEs for each set of indices + sample_indices[start:start + i] with i in {1, ..., n} + where n = end - start + Else: + Computes the "suffix" AEs/medians, i.e the AEs for each set of indices + sample_indices[i:] with i in {0, ..., n-1} + + Complexity: O(n log n) + This algorithm is an adaptation of the two heaps solution of + the "find median from a data stream" problem + See for instance: https://www.geeksforgeeks.org/dsa/median-of-stream-of-integers-running-integers/ + + But here, it's the weighted median and we also need to compute the AE, so: + - instead of balancing the heaps based on their number of elements, + rebalance them based on the sum of the weights of their element + - rewrite the AE computation by splitting the sum between elements + above and below the median, which allow to express it as a simple + O(1) computation. + See the maths in the PR desc: + https://github.com/scikit-learn/scikit-learn/pull/32100 """ cdef intp_t j, p, i, step, n if start < end: @@ -277,8 +294,7 @@ cdef void precompute_absolute_errors( below.reset() cdef float64_t y cdef float64_t w = 1.0 - cdef float64_t val = 0.0 - cdef float64_t wt = 0.0 + cdef float64_t top_val, top_weight cdef float64_t median = 0.0 cdef float64_t half_weight @@ -292,35 +308,34 @@ cdef void precompute_absolute_errors( # Insert into the appropriate heap if below.is_empty(): above.push(y, w) + elif y > below.top(): + above.push(y, w) else: - if y > below.top(): - above.push(y, w) - else: - below.push(y, w) + below.push(y, w) - half_weight = (above.get_total_weight() + below.get_total_weight()) / 2.0 + half_weight = (above.total_weight + below.total_weight) / 2.0 # Rebalance heaps - while above.get_total_weight() < half_weight and not below.is_empty(): - if below.pop(&val, &wt) == 0: - above.push(val, wt) + while above.total_weight < half_weight and not below.is_empty(): + below.pop(&top_val, &top_weight) + above.push(top_val, top_weight) while ( not above.is_empty() - and (above.get_total_weight() - above.top_weight()) >= half_weight + and (above.total_weight - above.top_weight()) >= half_weight ): - if above.pop(&val, &wt) == 0: - below.push(val, wt) + above.pop(&top_val, &top_weight) + below.push(top_val, top_weight) # Current median - if above.get_total_weight() > half_weight + 1e-5 * fabs(half_weight): + if above.total_weight > half_weight + 1e-5 * fabs(half_weight): median = above.top() else: # above and below weight are almost exactly equals median = (above.top() + below.top()) / 2. medians[j] = median abs_errors[j] = ( - (below.get_total_weight() - above.get_total_weight()) * median - - below.get_weighted_sum() - + above.get_weighted_sum() + (below.total_weight - above.total_weight) * median + - below.weighted_sum + + above.weighted_sum ) p += step j += step @@ -332,7 +347,15 @@ def _py_precompute_absolute_errors( const intp_t[:] sample_indices, bint suffix=False ): - """Used for testing precompute_absolute_errors""" + """ + Used for testing precompute_absolute_errors. + - If `suffix` is False: + Computes the "prefix" AEs, i.e the AEs for each set of indices + sample_indices[:i] with i in {1, ..., n} + - If `suffix` is True: + Computes the "suffix" AEs, i.e the AEs for each set of indices + sample_indices[i:] with i in {0, ..., n-1} + """ cdef: intp_t n = sample_weight.size WeightedHeap above = WeightedHeap(n, True) @@ -352,22 +375,3 @@ def _py_precompute_absolute_errors( k, start, end, abs_errors, medians ) return np.asarray(abs_errors) - - -def _any_isnan_axis0(const float32_t[:, :] X): - """Same as np.any(np.isnan(X), axis=0)""" - cdef: - intp_t i, j - intp_t n_samples = X.shape[0] - intp_t n_features = X.shape[1] - uint8_t[::1] isnan_out = np.zeros(X.shape[1], dtype=np.bool_) - - with nogil: - for i in range(n_samples): - for j in range(n_features): - if isnan_out[j]: - continue - if isnan(X[i, j]): - isnan_out[j] = True - break - return np.asarray(isnan_out) From 4a12dea10b59e729d990bc0422a6dd3c639ecff5 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 4 Sep 2025 18:29:33 +0200 Subject: [PATCH 16/62] slight refactor of class inheritance --- sklearn/tree/_criterion.pxd | 2 +- sklearn/tree/_criterion.pyx | 42 ++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 850a31224c10b..c6f51a3befe49 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -110,7 +110,7 @@ cdef class RegressionCriterion(Criterion): cdef float64_t[::1] sum_missing # Same as above, but for missing values in X -cdef class MAE(RegressionCriterion): +cdef class MAE(Criterion): cdef float64_t[::1] node_medians cdef float64_t[:, ::1] left_abs_errors diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 6ba36fb67576a..3dfefd6635710 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -3,7 +3,7 @@ from libc.string cimport memcpy from libc.string cimport memset -from libc.math cimport fabs, INFINITY +from libc.math cimport INFINITY import numpy as np cimport numpy as cnp @@ -1065,6 +1065,7 @@ cdef class RegressionCriterion(Criterion): return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) + cdef class MSE(RegressionCriterion): """Mean squared error impurity criterion. @@ -1181,11 +1182,15 @@ cdef class MSE(RegressionCriterion): impurity_right[0] /= self.n_outputs -cdef class MAE(RegressionCriterion): +cdef class MAE(Criterion): r"""Mean absolute error impurity criterion. - MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true - value and f_i is the predicted value.""" + MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true + value and f_i is the predicted value. + + It has almost nothing in common with other regression criterions + so it doesn't inherit from RegressionCriterion + """ def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. @@ -1297,9 +1302,7 @@ cdef class MAE(RegressionCriterion): return 0 cdef int reverse_reset(self) except -1 nogil: - """ - For this class, this function is never called - """ + """For this class, this method is never called""" return -1 cdef int update(self, intp_t new_pos) except -1 nogil: @@ -1366,22 +1369,12 @@ cdef class MAE(RegressionCriterion): i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. - Time complexity: O(n) (n = end - start) + Time complexity: O(n_outputs) (precomputed in `.reset()`) """ - cdef const float64_t[:] sample_weight = self.sample_weight - cdef const intp_t[:] sample_indices = self.sample_indices - cdef intp_t i, p, k - cdef float64_t w = 1.0 cdef float64_t impurity = 0.0 for k in range(self.n_outputs): - for p in range(self.start, self.end): - i = sample_indices[p] - - if sample_weight is not None: - w = sample_weight[i] - - impurity += fabs(self.y[i, k] - self.node_medians[k]) * w + impurity += self.right_abs_errors[k, 0] return impurity / (self.weighted_n_node_samples * self.n_outputs) @@ -1413,6 +1406,17 @@ cdef class MAE(RegressionCriterion): p_impurity_right[0] = impurity_right / (self.weighted_n_right * self.n_outputs) + # those 2 methods are copied from the RegressionCriterion abstract class: + def __reduce__(self): + return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) + + cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" + if dest[0] < lower_bound: + dest[0] = lower_bound + elif dest[0] > upper_bound: + dest[0] = upper_bound + cdef class FriedmanMSE(MSE): """Mean squared error impurity criterion with improvement score by Friedman. From 81728c2a5f9d570262703151ec49306beab88062 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 9 Sep 2025 14:47:41 +0200 Subject: [PATCH 17/62] adressed PR comments; simplified dimension of left/right abs errors array --- sklearn/tree/_criterion.pxd | 11 ------- sklearn/tree/_criterion.pyx | 65 ++++++++++++++++++++++++------------- sklearn/tree/_utils.pyx | 2 +- 3 files changed, 44 insertions(+), 34 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index c6f51a3befe49..24ea34892db7b 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -108,14 +108,3 @@ cdef class RegressionCriterion(Criterion): cdef float64_t[::1] sum_left # Same as above, but for the left side of the split cdef float64_t[::1] sum_right # Same as above, but for the right side of the split cdef float64_t[::1] sum_missing # Same as above, but for missing values in X - - -cdef class MAE(Criterion): - - cdef float64_t[::1] node_medians - cdef float64_t[:, ::1] left_abs_errors - cdef float64_t[:, ::1] right_abs_errors - cdef float64_t[::1] left_medians - cdef float64_t[::1] right_medians - cdef WeightedHeap above - cdef WeightedHeap below diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 3dfefd6635710..a1397dda980fa 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1191,6 +1191,13 @@ cdef class MAE(Criterion): It has almost nothing in common with other regression criterions so it doesn't inherit from RegressionCriterion """ + cdef float64_t[::1] node_medians + cdef float64_t[::1] left_abs_errors + cdef float64_t[::1] right_abs_errors + cdef float64_t[::1] left_medians + cdef float64_t[::1] right_medians + cdef WeightedHeap above + cdef WeightedHeap below def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. @@ -1219,8 +1226,8 @@ cdef class MAE(Criterion): # Note: this criterion has an important memory footprint, which is # fine as it's instantiated only once to build an entire tree - self.left_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) - self.right_abs_errors = np.empty((n_outputs, n_samples), dtype=np.float64) + self.left_abs_errors = np.empty(n_samples, dtype=np.float64) + self.right_abs_errors = np.empty(n_samples, dtype=np.float64) self.left_medians = np.empty(n_samples, dtype=np.float64) self.right_medians = np.empty(n_samples, dtype=np.float64) self.above = WeightedHeap(n_samples, True) # min-heap @@ -1245,14 +1252,13 @@ cdef class MAE(Criterion): """ cdef intp_t i cdef float64_t w = 1.0 - cdef intp_t n = end - start # Initialize fields self.y = y self.sample_weight = sample_weight self.sample_indices = sample_indices self.start = start self.end = end - self.n_node_samples = n + self.n_node_samples = end - start self.weighted_n_samples = weighted_n_samples self.weighted_n_node_samples = 0. @@ -1288,15 +1294,35 @@ cdef class MAE(Criterion): self.weighted_n_right = self.weighted_n_node_samples self.pos = self.start - for k in range(self.n_outputs - 1, -1, -1): + n_bytes = self.n_node_samples * sizeof(float64_t) + memset(&self.left_abs_errors[0], 0, n_bytes) + memset(&self.right_abs_errors[0], 0, n_bytes) + + # For each output (from last to first), precompute absolute errors and medians + # for both left and right splits. + # Precomputation is needed here and can't be done step-by-step in the update method + # like for other criterions. Indeed, we don't have efficient way to update right child + # statistics when removing samples from it. So we compute right child AEs/medians by + # traversing from right to left (and hence only adding samples). + for k in range(self.n_outputs): + # Note that at each iteration of this loop, we overwrite `self.left_medians` + # and `self.right_medians` which is fine. Those are used to check + # for monoticity constraints, which are allowed only with n_outputs=1. precompute_absolute_errors( - self.y, self.sample_weight, self.sample_indices, self.above, self.below, - k, self.start, self.end, self.left_abs_errors[k], self.left_medians + self.y, self.sample_weight, self.sample_indices, + self.above, self.below, k, self.start, self.end, + # left_abs_errors is incremented, left_medians is overwritten + self.left_abs_errors, self.left_medians ) + # For the right child, we consider samples from end-1 to start-1 + # i.e., reversed, and abs error & median are filled in reverse order to. precompute_absolute_errors( - self.y, self.sample_weight, self.sample_indices, self.above, self.below, - k, self.end - 1, self.start - 1, self.right_abs_errors[k], self.right_medians + self.y, self.sample_weight, self.sample_indices, + self.above, self.below, k, self.end - 1, self.start - 1, + # right_abs_errors is incremented, right_medians is overwritten + self.right_abs_errors, self.right_medians ) + # Store the median for the current node self.node_medians[k] = self.right_medians[0] return 0 @@ -1369,14 +1395,12 @@ cdef class MAE(Criterion): i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. - Time complexity: O(n_outputs) (precomputed in `.reset()`) + Time complexity: O(1) (precomputed in `.reset()`) """ - cdef float64_t impurity = 0.0 - - for k in range(self.n_outputs): - impurity += self.right_abs_errors[k, 0] - - return impurity / (self.weighted_n_node_samples * self.n_outputs) + return ( + self.right_abs_errors[0] + / (self.weighted_n_node_samples * self.n_outputs) + ) cdef void children_impurity(self, float64_t* p_impurity_left, float64_t* p_impurity_right) noexcept nogil: @@ -1385,24 +1409,21 @@ cdef class MAE(Criterion): i.e. the impurity of the left child (sample_indices[start:pos]) and the impurity the right child (sample_indices[pos:end]). - Time complexity: O(n_outputs) + Time complexity: O(1) """ cdef intp_t j = self.pos - self.start - cdef intp_t k cdef float64_t impurity_left = 0.0 cdef float64_t impurity_right = 0.0 # if pos == start, left child is empty, hence impurity is 0 if self.pos > self.start: - for k in range(self.n_outputs): - impurity_left += self.left_abs_errors[k, j - 1] + impurity_left += self.left_abs_errors[j - 1] p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs) # if pos == end, right child is empty, hence impurity is 0 if self.pos < self.end: - for k in range(self.n_outputs): - impurity_right += self.right_abs_errors[k, j] + impurity_right += self.right_abs_errors[j] p_impurity_right[0] = impurity_right / (self.weighted_n_right * self.n_outputs) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 7b42b526a18e3..4ce1aa8f4ebbe 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -332,7 +332,7 @@ cdef void precompute_absolute_errors( else: # above and below weight are almost exactly equals median = (above.top() + below.top()) / 2. medians[j] = median - abs_errors[j] = ( + abs_errors[j] += ( (below.total_weight - above.total_weight) * median - below.weighted_sum + above.weighted_sum From 7477f4cf598f38ff07c5d7a734a6974792f61435 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 9 Sep 2025 14:48:12 +0200 Subject: [PATCH 18/62] removed print --- sklearn/tree/tests/test_tree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 04adff64bb4c4..22ac42e10d761 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2903,7 +2903,6 @@ def compute_abs_error(y: np.ndarray, w: np.ndarray): wc = np.cumsum(w[sorter]) idx = np.searchsorted(wc, wc[-1] / 2) median = y[sorter[idx]] - print(y, median) # 2) compute the AE return (np.abs(y - median) * w).sum() From 8f035d05d9675511e4c795461c4d1b064a1eff5f Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 10 Sep 2025 18:28:49 +0200 Subject: [PATCH 19/62] heap methods docstring; test: split assertion --- sklearn/tree/_utils.pxd | 4 ++-- sklearn/tree/_utils.pyx | 10 ++++++---- sklearn/tree/tests/test_tree.py | 4 +++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 996daf0679c81..d89939fa178e2 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -63,8 +63,8 @@ cdef class WeightedHeap: cdef float64_t top_weight(self) noexcept nogil cdef float64_t top(self) noexcept nogil cdef void _swap(self, intp_t, intp_t) noexcept nogil - cdef void _perc_up(self, intp_t) noexcept nogil - cdef void _perc_down(self, intp_t) noexcept nogil + cdef void _heapify_up(self, intp_t) noexcept nogil + cdef void _heapify_down(self, intp_t) noexcept nogil cdef void precompute_absolute_errors( const float64_t[:, ::1] ys, diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 4ce1aa8f4ebbe..9eeecdf8b740c 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -166,7 +166,7 @@ cdef class WeightedHeap: self.total_weight += weight self.weighted_sum += value * weight - self._perc_up(n) + self._heapify_up(n) cdef void pop(self, float64_t* value, float64_t* weight) noexcept nogil: """Pop top element into pointers.""" @@ -189,7 +189,7 @@ cdef class WeightedHeap: if n > 0: self.heap[0] = self.heap[n] self.weights[0] = self.weights[n] - self._perc_down(0) + self._heapify_down(0) cdef float64_t top_weight(self) noexcept nogil: assert self.size > 0 @@ -212,7 +212,8 @@ cdef class WeightedHeap: self.weights[i] = self.weights[j] self.weights[j] = tmp - cdef inline void _perc_up(self, intp_t i) noexcept nogil: + cdef inline void _heapify_up(self, intp_t i) noexcept nogil: + """Move up the element at index i until heap invariant is restored.""" cdef intp_t p while i > 0: p = (i - 1) >> 1 @@ -222,7 +223,8 @@ cdef class WeightedHeap: else: break - cdef inline void _perc_down(self, intp_t i) noexcept nogil: + cdef inline void _heapify_down(self, intp_t i) noexcept nogil: + """Move down the element at index i until heap invariant is restored.""" cdef intp_t n = self.size cdef intp_t left, right, mc while True: diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 22ac42e10d761..5188d07bc6594 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1792,11 +1792,13 @@ def test_mae(): criterion="absolute_error", max_depth=1, # stop after one split ) + X = [[1], [2], [3], [4], [5]] dt_mae.fit( - X=[[1], [2], [3], [4], [5]], + X=X, y=[1, 1, 3, 1, 2], sample_weight=[3, 3, 2, 1, 2], ) + assert_allclose(dt_mae.predict(X), [1, 1, 2, 2, 2]) assert_allclose(dt_mae.tree_.impurity, [6 / 11, 0, 3 / 5]) assert_array_equal(dt_mae.tree_.value.flat, [1, 1, 2]) From e6bf43b9e88305df0bbb768cbc95adb2157f94b3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 10 Sep 2025 20:48:09 +0200 Subject: [PATCH 20/62] unit test for heap --- sklearn/tree/_utils.pyx | 16 +++++++++++++--- sklearn/tree/tests/test_heap.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 sklearn/tree/tests/test_heap.py diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 9eeecdf8b740c..5dbb8ee6b897b 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -200,9 +200,7 @@ cdef class WeightedHeap: cdef float64_t s = self.heap[0] return s if self.min_heap else -s - # ---------------------------- - # Internal helpers (nogil) - # ---------------------------- + # Internal helpers (nogil): cdef inline void _swap(self, intp_t i, intp_t j) noexcept nogil: cdef float64_t tmp = self.heap[i] @@ -241,6 +239,18 @@ cdef class WeightedHeap: else: return + # Python callable wrappers for unit tests: + def _py_push(self, double value, double weight): + self.push(value, weight) + + def _py_pop(self): + cdef double v, w + self.pop(&v, &w) + return v, w + + def _py_reset(self): + self.reset() + # ============================================================================= # MAE split precomputations algorithm # ============================================================================= diff --git a/sklearn/tree/tests/test_heap.py b/sklearn/tree/tests/test_heap.py new file mode 100644 index 0000000000000..1f7321630bd04 --- /dev/null +++ b/sklearn/tree/tests/test_heap.py @@ -0,0 +1,33 @@ +import random +from heapq import heappop, heappush + +import pytest + +from sklearn.tree._utils import WeightedHeap + + +@pytest.mark.parametrize("min_heap", [True, False]) +def test_weighted_heap(min_heap): + n = 200 + w_heap = WeightedHeap(n, min_heap=min_heap) + py_heap = [] + + def pop_from_heaps_and_compare(): + top, top_w = w_heap._py_pop() + top_, top_w_ = heappop(py_heap) + if not min_heap: + top_ = -top_ + assert top == top_ + assert top_w == top_w_ + + for _ in range(n): + if len(py_heap) > 0 and random.random() < 1 / 3: + pop_from_heaps_and_compare() + else: + y = random.random() + w = random.random() + heappush(py_heap, (y if min_heap else -y, w)) + w_heap._py_push(y, w) + + for _ in range(len(py_heap)): + pop_from_heaps_and_compare() From eb2ccf54d54654eceabf1f2cc4545d5d46a0b776 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 10 Sep 2025 21:58:19 +0200 Subject: [PATCH 21/62] fix comment --- sklearn/tree/_utils.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 5dbb8ee6b897b..644ba258d7472 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -239,7 +239,8 @@ cdef class WeightedHeap: else: return - # Python callable wrappers for unit tests: + # Wrappers callable from Python for tests: + def _py_push(self, double value, double weight): self.push(value, weight) From d13a2c54ea574b56ffa23a4869b954948ac68304 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Sat, 13 Sep 2025 11:28:01 +0200 Subject: [PATCH 22/62] Apply suggestions from code review Naming & comments Co-authored-by: Adam Li --- sklearn/tree/_criterion.pyx | 5 +++-- sklearn/tree/tests/test_heap.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index a1397dda980fa..ea0bac6f086a4 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1298,8 +1298,9 @@ cdef class MAE(Criterion): memset(&self.left_abs_errors[0], 0, n_bytes) memset(&self.right_abs_errors[0], 0, n_bytes) - # For each output (from last to first), precompute absolute errors and medians - # for both left and right splits. + # Precompute absolute errors (summed over each ouput) and medians (used only when n_outputs=1) + # of the right and left child of all possible splits + # for the current ordering of `sample_indices` # Precomputation is needed here and can't be done step-by-step in the update method # like for other criterions. Indeed, we don't have efficient way to update right child # statistics when removing samples from it. So we compute right child AEs/medians by diff --git a/sklearn/tree/tests/test_heap.py b/sklearn/tree/tests/test_heap.py index 1f7321630bd04..51f92d5486929 100644 --- a/sklearn/tree/tests/test_heap.py +++ b/sklearn/tree/tests/test_heap.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("min_heap", [True, False]) -def test_weighted_heap(min_heap): +def test_cython_weighted_heap_vs_heapq(min_heap): n = 200 w_heap = WeightedHeap(n, min_heap=min_heap) py_heap = [] From 4fc78f4450576d53b25b0ee1e95d20deec66d0a5 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 13 Sep 2025 11:43:41 +0200 Subject: [PATCH 23/62] comments & naming --- sklearn/tree/_criterion.pyx | 13 +++++++------ sklearn/tree/_utils.pyx | 10 +++++----- sklearn/tree/tests/test_heap.py | 8 ++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index ea0bac6f086a4..b12dc03d9e981 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1298,17 +1298,18 @@ cdef class MAE(Criterion): memset(&self.left_abs_errors[0], 0, n_bytes) memset(&self.right_abs_errors[0], 0, n_bytes) - # Precompute absolute errors (summed over each ouput) and medians (used only when n_outputs=1) + # Precompute absolute errors (summed over each output) + # and medians (used only when n_outputs=1) # of the right and left child of all possible splits # for the current ordering of `sample_indices` # Precomputation is needed here and can't be done step-by-step in the update method - # like for other criterions. Indeed, we don't have efficient way to update right child + # like for other criterions. Indeed, we don't have efficient ways to update right child # statistics when removing samples from it. So we compute right child AEs/medians by # traversing from right to left (and hence only adding samples). for k in range(self.n_outputs): # Note that at each iteration of this loop, we overwrite `self.left_medians` - # and `self.right_medians` which is fine. Those are used to check - # for monoticity constraints, which are allowed only with n_outputs=1. + # and `self.right_medians`. They are used to check for monoticity constraints, + # which are allowed only with n_outputs=1. precompute_absolute_errors( self.y, self.sample_weight, self.sample_indices, self.above, self.below, k, self.start, self.end, @@ -1330,7 +1331,7 @@ cdef class MAE(Criterion): cdef int reverse_reset(self) except -1 nogil: """For this class, this method is never called""" - return -1 + raise NotImplementedError("This method is not implemented for this subclass") cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left. @@ -1410,7 +1411,7 @@ cdef class MAE(Criterion): i.e. the impurity of the left child (sample_indices[start:pos]) and the impurity the right child (sample_indices[pos:end]). - Time complexity: O(1) + Time complexity: O(1) (precomputed in `.reset()`) """ cdef intp_t j = self.pos - self.start cdef float64_t impurity_left = 0.0 diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 644ba258d7472..325c2e0bf4612 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -239,18 +239,18 @@ cdef class WeightedHeap: else: return - # Wrappers callable from Python for tests: - def _py_push(self, double value, double weight): +cdef class PytestWeightedHeap(WeightedHeap): + """Used for testing only""" + + def py_push(self, double value, double weight): self.push(value, weight) - def _py_pop(self): + def py_pop(self): cdef double v, w self.pop(&v, &w) return v, w - def _py_reset(self): - self.reset() # ============================================================================= # MAE split precomputations algorithm diff --git a/sklearn/tree/tests/test_heap.py b/sklearn/tree/tests/test_heap.py index 51f92d5486929..00521afc83832 100644 --- a/sklearn/tree/tests/test_heap.py +++ b/sklearn/tree/tests/test_heap.py @@ -3,17 +3,17 @@ import pytest -from sklearn.tree._utils import WeightedHeap +from sklearn.tree._utils import PytestWeightedHeap @pytest.mark.parametrize("min_heap", [True, False]) def test_cython_weighted_heap_vs_heapq(min_heap): n = 200 - w_heap = WeightedHeap(n, min_heap=min_heap) + w_heap = PytestWeightedHeap(n, min_heap=min_heap) py_heap = [] def pop_from_heaps_and_compare(): - top, top_w = w_heap._py_pop() + top, top_w = w_heap.py_pop() top_, top_w_ = heappop(py_heap) if not min_heap: top_ = -top_ @@ -27,7 +27,7 @@ def pop_from_heaps_and_compare(): y = random.random() w = random.random() heappush(py_heap, (y if min_heap else -y, w)) - w_heap._py_push(y, w) + w_heap.py_push(y, w) for _ in range(len(py_heap)): pop_from_heaps_and_compare() From 220c34fa3303897b4ef34df3fedc65d46a72abea Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 13 Sep 2025 11:57:23 +0200 Subject: [PATCH 24/62] parameters docstring --- sklearn/tree/_utils.pyx | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 325c2e0bf4612..2360ce104aaa4 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -279,6 +279,28 @@ cdef void precompute_absolute_errors( Computes the "suffix" AEs/medians, i.e the AEs for each set of indices sample_indices[i:] with i in {0, ..., n-1} + Parameters + ---------- + ys : const float64_t[:, ::1] + Target values. Shape: (n_samples, n_outputs). + sample_weight : const float64_t[:] + Shape: (n_samples,) + sample_indices : const intp_t[:] + indices indicating which samples to use. Shape: (n_samples,) + above : WeightedHeap + below : WeightedHeap + k : intp_t + Dimension to consider in y. In [0, n_outputs - 1]. + start : intp_t + Start index in `sample_indices` + end : intp_t + End index (exclusive) in `sample_indices` + abs_errors : float64_t[::1] + array to store (increment) the computed absolute errors. Shape: (n,) + with n := end - start + medians : float64_t[::1] + array to store (overwrite) the computed medians. Shape: (n,) + Complexity: O(n log n) This algorithm is an adaptation of the two heaps solution of the "find median from a data stream" problem @@ -286,7 +308,7 @@ cdef void precompute_absolute_errors( But here, it's the weighted median and we also need to compute the AE, so: - instead of balancing the heaps based on their number of elements, - rebalance them based on the sum of the weights of their element + rebalance them based on the summed weights of their elements - rewrite the AE computation by splitting the sum between elements above and below the median, which allow to express it as a simple O(1) computation. From d9b3c35570cd6a505bf56c09b3bd7828a623e727 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 14 Sep 2025 17:16:44 +0200 Subject: [PATCH 25/62] Update doc about MAE criterion speed --- doc/modules/tree.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index ee36d9f6af1b2..8df9167ff1073 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -572,7 +572,7 @@ Mean Absolute Error: H(Q_m) = \frac{1}{n_m} \sum_{y \in Q_m} |y - median(y)_m| -Note that it fits much slower than the MSE criterion. +Note that it fits slower than the MSE criterion. .. _tree_missing_value_support: From 72e15b5c69ecf2570712878d6a16cf43551680c2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 14 Sep 2025 18:26:12 +0200 Subject: [PATCH 26/62] move precompute --- sklearn/tree/_criterion.pyx | 161 +++++++++++++++++++++++++++++++- sklearn/tree/_utils.pxd | 13 --- sklearn/tree/_utils.pyx | 161 -------------------------------- sklearn/tree/tests/test_tree.py | 2 +- 4 files changed, 160 insertions(+), 177 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index b12dc03d9e981..0e4b3358e5f89 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -3,7 +3,7 @@ from libc.string cimport memcpy from libc.string cimport memset -from libc.math cimport INFINITY +from libc.math cimport fabs, INFINITY import numpy as np cimport numpy as cnp @@ -13,7 +13,6 @@ from scipy.special.cython_special cimport xlogy from ._utils cimport log from ._utils cimport WeightedHeap -from ._utils cimport precompute_absolute_errors # EPSILON is used in the Poisson criterion cdef float64_t EPSILON = 10 * np.finfo('double').eps @@ -1182,6 +1181,164 @@ cdef class MSE(RegressionCriterion): impurity_right[0] /= self.n_outputs +# Helper for MAE criterion: + +cdef void precompute_absolute_errors( + const float64_t[:, ::1] ys, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + WeightedHeap above, + WeightedHeap below, + intp_t k, + intp_t start, + intp_t end, + float64_t[::1] abs_errors, + float64_t[::1] medians +) noexcept nogil: + """ + Fill `abs_errors` and `medians`. + + If start < end: + Computes the "prefix" AEs/medians, i.e the AEs for each set of indices + sample_indices[start:start + i] with i in {1, ..., n} + where n = end - start + Else: + Computes the "suffix" AEs/medians, i.e the AEs for each set of indices + sample_indices[i:] with i in {0, ..., n-1} + + Parameters + ---------- + ys : const float64_t[:, ::1] + Target values. Shape: (n_samples, n_outputs). + sample_weight : const float64_t[:] + Shape: (n_samples,) + sample_indices : const intp_t[:] + indices indicating which samples to use. Shape: (n_samples,) + above : WeightedHeap + below : WeightedHeap + k : intp_t + Dimension to consider in y. In [0, n_outputs - 1]. + start : intp_t + Start index in `sample_indices` + end : intp_t + End index (exclusive) in `sample_indices` + abs_errors : float64_t[::1] + array to store (increment) the computed absolute errors. Shape: (n,) + with n := end - start + medians : float64_t[::1] + array to store (overwrite) the computed medians. Shape: (n,) + + Complexity: O(n log n) + This algorithm is an adaptation of the two heaps solution of + the "find median from a data stream" problem + See for instance: https://www.geeksforgeeks.org/dsa/median-of-stream-of-integers-running-integers/ + + But here, it's the weighted median and we also need to compute the AE, so: + - instead of balancing the heaps based on their number of elements, + rebalance them based on the summed weights of their elements + - rewrite the AE computation by splitting the sum between elements + above and below the median, which allow to express it as a simple + O(1) computation. + See the maths in the PR desc: + https://github.com/scikit-learn/scikit-learn/pull/32100 + """ + cdef intp_t j, p, i, step, n + if start < end: + j = 0 + step = 1 + n = end - start + else: + n = start - end + step = -1 + j = n - 1 + + above.reset() + below.reset() + cdef float64_t y + cdef float64_t w = 1.0 + cdef float64_t top_val, top_weight + cdef float64_t median = 0.0 + cdef float64_t half_weight + + p = start + for _ in range(n): + i = sample_indices[p] + if sample_weight is not None: + w = sample_weight[i] + y = ys[i, k] + + # Insert into the appropriate heap + if below.is_empty(): + above.push(y, w) + elif y > below.top(): + above.push(y, w) + else: + below.push(y, w) + + half_weight = (above.total_weight + below.total_weight) / 2.0 + + # Rebalance heaps + while above.total_weight < half_weight and not below.is_empty(): + below.pop(&top_val, &top_weight) + above.push(top_val, top_weight) + while ( + not above.is_empty() + and (above.total_weight - above.top_weight()) >= half_weight + ): + above.pop(&top_val, &top_weight) + below.push(top_val, top_weight) + + # Current median + if above.total_weight > half_weight + 1e-5 * fabs(half_weight): + median = above.top() + else: # above and below weight are almost exactly equals + median = (above.top() + below.top()) / 2. + medians[j] = median + abs_errors[j] += ( + (below.total_weight - above.total_weight) * median + - below.weighted_sum + + above.weighted_sum + ) + p += step + j += step + + +def _py_precompute_absolute_errors( + const float64_t[:, ::1] ys, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + bint suffix=False +): + """ + Used for testing precompute_absolute_errors. + - If `suffix` is False: + Computes the "prefix" AEs, i.e the AEs for each set of indices + sample_indices[:i] with i in {1, ..., n} + - If `suffix` is True: + Computes the "suffix" AEs, i.e the AEs for each set of indices + sample_indices[i:] with i in {0, ..., n-1} + """ + cdef: + intp_t n = sample_weight.size + WeightedHeap above = WeightedHeap(n, True) + WeightedHeap below = WeightedHeap(n, False) + intp_t k = 0 + intp_t start = 0 + intp_t end = n + float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) + float64_t[::1] medians = np.zeros(n, dtype=np.float64) + + if suffix: + start = n - 1 + end = -1 + + precompute_absolute_errors( + ys, sample_weight, sample_indices, above, below, + k, start, end, abs_errors, medians + ) + return np.asarray(abs_errors) + + cdef class MAE(Criterion): r"""Mean absolute error impurity criterion. diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index d89939fa178e2..9f0d670a51223 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -66,17 +66,4 @@ cdef class WeightedHeap: cdef void _heapify_up(self, intp_t) noexcept nogil cdef void _heapify_down(self, intp_t) noexcept nogil -cdef void precompute_absolute_errors( - const float64_t[:, ::1] ys, - const float64_t[:] sample_weight, - const intp_t[:] sample_indices, - WeightedHeap above, - WeightedHeap below, - intp_t k, - intp_t start, - intp_t end, - float64_t[::1] abs_errors, - float64_t[::1] medians -) noexcept nogil - cdef float64_t log(float64_t x) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 2360ce104aaa4..0c7d55039e2d1 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -5,7 +5,6 @@ from libc.stdlib cimport free from libc.stdlib cimport realloc from libc.math cimport log as ln from libc.math cimport isnan -from libc.math cimport fabs import numpy as np cimport numpy as cnp @@ -250,163 +249,3 @@ cdef class PytestWeightedHeap(WeightedHeap): cdef double v, w self.pop(&v, &w) return v, w - - -# ============================================================================= -# MAE split precomputations algorithm -# ============================================================================= - -cdef void precompute_absolute_errors( - const float64_t[:, ::1] ys, - const float64_t[:] sample_weight, - const intp_t[:] sample_indices, - WeightedHeap above, - WeightedHeap below, - intp_t k, - intp_t start, - intp_t end, - float64_t[::1] abs_errors, - float64_t[::1] medians -) noexcept nogil: - """ - Fill `abs_errors` and `medians`. - - If start < end: - Computes the "prefix" AEs/medians, i.e the AEs for each set of indices - sample_indices[start:start + i] with i in {1, ..., n} - where n = end - start - Else: - Computes the "suffix" AEs/medians, i.e the AEs for each set of indices - sample_indices[i:] with i in {0, ..., n-1} - - Parameters - ---------- - ys : const float64_t[:, ::1] - Target values. Shape: (n_samples, n_outputs). - sample_weight : const float64_t[:] - Shape: (n_samples,) - sample_indices : const intp_t[:] - indices indicating which samples to use. Shape: (n_samples,) - above : WeightedHeap - below : WeightedHeap - k : intp_t - Dimension to consider in y. In [0, n_outputs - 1]. - start : intp_t - Start index in `sample_indices` - end : intp_t - End index (exclusive) in `sample_indices` - abs_errors : float64_t[::1] - array to store (increment) the computed absolute errors. Shape: (n,) - with n := end - start - medians : float64_t[::1] - array to store (overwrite) the computed medians. Shape: (n,) - - Complexity: O(n log n) - This algorithm is an adaptation of the two heaps solution of - the "find median from a data stream" problem - See for instance: https://www.geeksforgeeks.org/dsa/median-of-stream-of-integers-running-integers/ - - But here, it's the weighted median and we also need to compute the AE, so: - - instead of balancing the heaps based on their number of elements, - rebalance them based on the summed weights of their elements - - rewrite the AE computation by splitting the sum between elements - above and below the median, which allow to express it as a simple - O(1) computation. - See the maths in the PR desc: - https://github.com/scikit-learn/scikit-learn/pull/32100 - """ - cdef intp_t j, p, i, step, n - if start < end: - j = 0 - step = 1 - n = end - start - else: - n = start - end - step = -1 - j = n - 1 - - above.reset() - below.reset() - cdef float64_t y - cdef float64_t w = 1.0 - cdef float64_t top_val, top_weight - cdef float64_t median = 0.0 - cdef float64_t half_weight - - p = start - for _ in range(n): - i = sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] - y = ys[i, k] - - # Insert into the appropriate heap - if below.is_empty(): - above.push(y, w) - elif y > below.top(): - above.push(y, w) - else: - below.push(y, w) - - half_weight = (above.total_weight + below.total_weight) / 2.0 - - # Rebalance heaps - while above.total_weight < half_weight and not below.is_empty(): - below.pop(&top_val, &top_weight) - above.push(top_val, top_weight) - while ( - not above.is_empty() - and (above.total_weight - above.top_weight()) >= half_weight - ): - above.pop(&top_val, &top_weight) - below.push(top_val, top_weight) - - # Current median - if above.total_weight > half_weight + 1e-5 * fabs(half_weight): - median = above.top() - else: # above and below weight are almost exactly equals - median = (above.top() + below.top()) / 2. - medians[j] = median - abs_errors[j] += ( - (below.total_weight - above.total_weight) * median - - below.weighted_sum - + above.weighted_sum - ) - p += step - j += step - - -def _py_precompute_absolute_errors( - const float64_t[:, ::1] ys, - const float64_t[:] sample_weight, - const intp_t[:] sample_indices, - bint suffix=False -): - """ - Used for testing precompute_absolute_errors. - - If `suffix` is False: - Computes the "prefix" AEs, i.e the AEs for each set of indices - sample_indices[:i] with i in {1, ..., n} - - If `suffix` is True: - Computes the "suffix" AEs, i.e the AEs for each set of indices - sample_indices[i:] with i in {0, ..., n-1} - """ - cdef: - intp_t n = sample_weight.size - WeightedHeap above = WeightedHeap(n, True) - WeightedHeap below = WeightedHeap(n, False) - intp_t k = 0 - intp_t start = 0 - intp_t end = n - float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) - float64_t[::1] medians = np.zeros(n, dtype=np.float64) - - if suffix: - start = n - 1 - end = -1 - - precompute_absolute_errors( - ys, sample_weight, sample_indices, above, below, - k, start, end, abs_errors, medians - ) - return np.asarray(abs_errors) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 5a226336b4b35..f86e41a2723ca 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -36,6 +36,7 @@ DENSE_SPLITTERS, SPARSE_SPLITTERS, ) +from sklearn.tree._criterion import _py_precompute_absolute_errors from sklearn.tree._partitioner import _py_sort from sklearn.tree._tree import ( NODE_DTYPE, @@ -47,7 +48,6 @@ _check_value_ndarray, ) from sklearn.tree._tree import Tree as CythonTree -from sklearn.tree._utils import _py_precompute_absolute_errors from sklearn.utils import compute_sample_weight from sklearn.utils._array_api import xpx from sklearn.utils._testing import ( From aa91439d68c1d41e85dce5f640a333b38413c1d9 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 14 Sep 2025 22:58:46 +0200 Subject: [PATCH 27/62] doing typos is my signature move, sorry for taht --- sklearn/tree/tests/test_tree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index f86e41a2723ca..48f85d74423eb 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1663,7 +1663,7 @@ def test_no_sparse_y_support(name, csr_container): def test_mae(): - """Check MAE criterion produces correct results on small toys dataset: + """Check MAE criterion produces correct results on small toy datasets: ## First toy dataset ------------------ @@ -2888,7 +2888,7 @@ def test_sort_log2_build(): def test_absolute_errors_precomputation_function(): """ Test the main bit of logic of the MAE(RegressionCriterion) class - (used by DecisionTreeRegressor(criterion="asbolute_error")). + (used by DecisionTreeRegressor(criterion="absolute_error")). The implementation of the criterion relies on an efficient precomputation of left/right children absolute error for each split. This test verifies this From 450290adb7d0c00cc9cd5bc20f7c10de52465701 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 15 Sep 2025 10:43:31 +0200 Subject: [PATCH 28/62] Update doc/modules/tree.rst --- doc/modules/tree.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index 8df9167ff1073..0f86ac4b14893 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -572,7 +572,7 @@ Mean Absolute Error: H(Q_m) = \frac{1}{n_m} \sum_{y \in Q_m} |y - median(y)_m| -Note that it fits slower than the MSE criterion. +Note that it is 3–6× slower to fit than the MSE criterion as of version 1.8. .. _tree_missing_value_support: From bc7685ec8b1a8ec331b2049bdc0f427aea471223 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 15 Sep 2025 10:44:01 +0200 Subject: [PATCH 29/62] Add docstring for test_cython_weighted_heap_vs_heapq Co-authored-by: Adam Li --- sklearn/tree/tests/test_heap.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/tree/tests/test_heap.py b/sklearn/tree/tests/test_heap.py index 00521afc83832..6ea572c00d800 100644 --- a/sklearn/tree/tests/test_heap.py +++ b/sklearn/tree/tests/test_heap.py @@ -8,6 +8,10 @@ @pytest.mark.parametrize("min_heap", [True, False]) def test_cython_weighted_heap_vs_heapq(min_heap): + """Test Cython's weighted heap vs STL's heapq implementation. + + This unit-test first populates Cython Weighted Heap and STL's heap with weighted samples, and then compares values that are popped. + """ n = 200 w_heap = PytestWeightedHeap(n, min_heap=min_heap) py_heap = [] From 390731a2229012a13df5f408671328bc034bf7b7 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 15 Sep 2025 10:55:31 +0200 Subject: [PATCH 30/62] Update comment about mem footprint --- sklearn/tree/_criterion.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 0e4b3358e5f89..35338a6ebc079 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1381,7 +1381,7 @@ cdef class MAE(Criterion): self.node_medians = np.zeros(n_outputs, dtype=np.float64) - # Note: this criterion has an important memory footprint, which is + # Note: this criterion has a n_samples x 64 bytes memory footprint, which is # fine as it's instantiated only once to build an entire tree self.left_abs_errors = np.empty(n_samples, dtype=np.float64) self.right_abs_errors = np.empty(n_samples, dtype=np.float64) From 1153cb5e04e214d7316df6e91dbd3939a4c4c1dd Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 15 Sep 2025 09:27:32 +0200 Subject: [PATCH 31/62] PERF: Decision trees: improve prefs by ~20% with very simple changes (#32181) --- sklearn/tree/_criterion.pyx | 31 ++++++++++++------------------- sklearn/tree/_partitioner.pyx | 14 +++++--------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 35338a6ebc079..76d8d91bc84b9 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -490,10 +490,6 @@ cdef class ClassificationCriterion(Criterion): # self.sample_indices[-self.n_missing:] that is # self.sample_indices[end_non_missing:self.end]. cdef intp_t end_non_missing = self.end - self.n_missing - - cdef const intp_t[:] sample_indices = self.sample_indices - cdef const float64_t[:] sample_weight = self.sample_weight - cdef intp_t i cdef intp_t p cdef intp_t k @@ -509,10 +505,10 @@ cdef class ClassificationCriterion(Criterion): # of computations, i.e. from pos to new_pos or from end to new_po. if (new_pos - pos) <= (end_non_missing - new_pos): for p in range(pos, new_pos): - i = sample_indices[p] + i = self.sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] + if self.sample_weight is not None: + w = self.sample_weight[i] for k in range(self.n_outputs): self.sum_left[k, self.y[i, k]] += w @@ -523,10 +519,10 @@ cdef class ClassificationCriterion(Criterion): self.reverse_reset() for p in range(end_non_missing - 1, new_pos - 1, -1): - i = sample_indices[p] + i = self.sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] + if self.sample_weight is not None: + w = self.sample_weight[i] for k in range(self.n_outputs): self.sum_left[k, self.y[i, k]] -= w @@ -964,9 +960,6 @@ cdef class RegressionCriterion(Criterion): cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left.""" - cdef const float64_t[:] sample_weight = self.sample_weight - cdef const intp_t[:] sample_indices = self.sample_indices - cdef intp_t pos = self.pos # The missing samples are assumed to be in @@ -987,10 +980,10 @@ cdef class RegressionCriterion(Criterion): # of computations, i.e. from pos to new_pos or from end to new_pos. if (new_pos - pos) <= (end_non_missing - new_pos): for p in range(pos, new_pos): - i = sample_indices[p] + i = self.sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] + if self.sample_weight is not None: + w = self.sample_weight[i] for k in range(self.n_outputs): self.sum_left[k] += w * self.y[i, k] @@ -1000,10 +993,10 @@ cdef class RegressionCriterion(Criterion): self.reverse_reset() for p in range(end_non_missing - 1, new_pos - 1, -1): - i = sample_indices[p] + i = self.sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] + if self.sample_weight is not None: + w = self.sample_weight[i] for k in range(self.n_outputs): self.sum_left[k] -= w * self.y[i, k] diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index 7c342ed3a7d6b..5cec6073d74f1 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -171,13 +171,11 @@ cdef class DensePartitioner: The missing values are not included when iterating through the feature values. """ - cdef: - float32_t[::1] feature_values = self.feature_values - intp_t end_non_missing = self.end - self.n_missing + cdef intp_t end_non_missing = self.end - self.n_missing while ( p[0] + 1 < end_non_missing and - feature_values[p[0] + 1] <= feature_values[p[0]] + FEATURE_THRESHOLD + self.feature_values[p[0] + 1] <= self.feature_values[p[0]] + FEATURE_THRESHOLD ): p[0] += 1 @@ -398,9 +396,7 @@ cdef class SparsePartitioner: cdef inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: """Compute the next p_prev and p for iterating over feature values.""" - cdef: - intp_t p_next - float32_t[::1] feature_values = self.feature_values + cdef intp_t p_next if p[0] + 1 != self.end_negative: p_next = p[0] + 1 @@ -408,7 +404,7 @@ cdef class SparsePartitioner: p_next = self.start_positive while (p_next < self.end and - feature_values[p_next] <= feature_values[p[0]] + FEATURE_THRESHOLD): + self.feature_values[p_next] <= self.feature_values[p[0]] + FEATURE_THRESHOLD): p[0] = p_next if p[0] + 1 != self.end_negative: p_next = p[0] + 1 @@ -489,7 +485,7 @@ cdef class SparsePartitioner: """ cdef intp_t[::1] samples = self.samples cdef float32_t[::1] feature_values = self.feature_values - cdef intp_t indptr_start = self.X_indptr[feature], + cdef intp_t indptr_start = self.X_indptr[feature] cdef intp_t indptr_end = self.X_indptr[feature + 1] cdef intp_t n_indices = (indptr_end - indptr_start) cdef intp_t n_samples = self.end - self.start From 0f6d896ae3758d8147eb034201533f2d5cd910b6 Mon Sep 17 00:00:00 2001 From: scikit-learn-bot Date: Mon, 15 Sep 2025 10:26:00 +0200 Subject: [PATCH 32/62] :lock: :robot: CI Update lock files for main CI build(s) :lock: :robot: (#32187) Co-authored-by: Lock file bot --- ...ylatest_conda_forge_mkl_linux-64_conda.lock | 18 +++++++++--------- ...conda_forge_mkl_no_openmp_osx-64_conda.lock | 11 ++++++----- .../pylatest_conda_forge_mkl_osx-64_conda.lock | 11 ++++++----- ...est_pip_openblas_pandas_linux-64_conda.lock | 4 ++-- ...enblas_min_dependencies_linux-64_conda.lock | 6 +++--- ...ymin_conda_forge_openblas_win-64_conda.lock | 4 ++-- build_tools/circle/doc_linux-64_conda.lock | 6 +++--- .../doc_min_dependencies_linux-64_conda.lock | 8 ++++---- ...in_conda_forge_arm_linux-aarch64_conda.lock | 2 +- 9 files changed, 36 insertions(+), 34 deletions(-) diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index c943249bdb94b..63e3b5fff73e5 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -39,7 +39,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb9d3cd8_0.conda# https://conda.anaconda.org/conda-forge/linux-64/libntlm-1.8-hb9d3cd8_0.conda#7c7927b404672409d9917d49bff5f2d6 https://conda.anaconda.org/conda-forge/linux-64/libpciaccess-0.18-hb9d3cd8_0.conda#70e3400cbbfa03e96dcde7fc13e38c7b https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.1.0-h8f9b012_5.conda#4e02a49aaa9d5190cb630fa43528fbe6 -https://conda.anaconda.org/conda-forge/linux-64/libutf8proc-2.10.0-h202a827_0.conda#0f98f3e95272d118f7931b6bef69bfe5 +https://conda.anaconda.org/conda-forge/linux-64/libutf8proc-2.11.0-hb04c3b8_0.conda#34fb73fd2d5a613d8f17ce2eaa15a8a5 https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.41.1-he9a06e4_0.conda#af930c65e9a79a3423d6d36e265cef65 https://conda.anaconda.org/conda-forge/linux-64/libuv-1.51.0-hb03c661_1.conda#0f03292cc56bf91a077a134ea8747118 https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.6.0-hd42ef1d_0.conda#aea31d2e5b1091feca96fcfe945c3cf9 @@ -141,7 +141,7 @@ https://conda.anaconda.org/conda-forge/noarch/pip-25.2-pyh145f28c_0.conda#e7ab34 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/pybind11-global-3.0.1-pyhc7ab6ef_0.conda#fe10b422ce8b5af5dab3740e4084c3f9 https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda#88476ae6ebd24f39261e0854ac244f33 https://conda.anaconda.org/conda-forge/noarch/pytz-2025.2-pyhd8ed1ab_0.conda#bc8e3267d44011051f2eb14d22fb0960 https://conda.anaconda.org/conda-forge/linux-64/re2-2025.08.12-h5301d42_1.conda#4637c13ff87424af0f6a981ab6f5ffa5 @@ -221,30 +221,30 @@ https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-11.4.5-h15599e2_0.conda https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-35_hfdb39a5_mkl.conda#9fedd782400297fa574e739146f04e34 https://conda.anaconda.org/conda-forge/linux-64/mkl-devel-2024.2.2-ha770c72_17.conda#e67269e07e58be5672f06441316f05f2 https://conda.anaconda.org/conda-forge/linux-64/polars-1.33.1-default_h755bcc6_0.conda#1884a1a6acc457c8e4b59b0f6450e140 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-21.0.0-hb708d0b_2_cpu.conda#f602a99e9fbe7aa3952620c6ec979cbc +https://conda.anaconda.org/conda-forge/linux-64/libarrow-21.0.0-hb708d0b_3_cpu.conda#2d0305c8802fcba095d8d4e14e66ed3b https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-35_h372d94f_mkl.conda#25fab7e2988299928dea5939d9958293 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-35_hc41d3b0_mkl.conda#5b4f86e5bc48d347eaf1ca2d180780ad https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.9.2-h3fc9a0a_0.conda#70b5132b6e8a65198c2f9d5552c41126 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-compute-21.0.0-hebab434_2_cpu.conda#c04669cb6fbb4b000f281d327ca7d66c +https://conda.anaconda.org/conda-forge/linux-64/libarrow-compute-21.0.0-h8c2c5c3_3_cpu.conda#b0b73752adfcbe6b73ef9f2eb5d5cf03 https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-35_hbc6e62b_mkl.conda#426313fe1dc5ad3060efea56253fcd76 -https://conda.anaconda.org/conda-forge/linux-64/libparquet-21.0.0-h790f06f_2_cpu.conda#3ac1cb5e3b76f399d29a5542a64184eb +https://conda.anaconda.org/conda-forge/linux-64/libparquet-21.0.0-h790f06f_3_cpu.conda#0568ba99a1f6c0ef7a04ca23dc78905a https://conda.anaconda.org/conda-forge/linux-64/libtorch-2.7.1-cpu_mkl_hf38bc2d_103.conda#cc613cc921fe87d8ecda7a7c8fafc097 https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.3-py313hf6604e3_0.conda#3122d20dc438287e125fb5acff1df170 https://conda.anaconda.org/conda-forge/linux-64/pyside6-6.9.2-py313ha3f37dd_1.conda#e2ec46ec4c607b97623e7b691ad31c54 https://conda.anaconda.org/conda-forge/noarch/array-api-strict-2.4.1-pyhe01879c_0.conda#648e253c455718227c61e26f4a4ce701 https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-35_hcf00494_mkl.conda#bbbe147bcbe26b14cfbd5975dd45c79d https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.3.3-py313h7037e92_2.conda#6c8b4c12099023fcd85e520af74fd755 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-acero-21.0.0-h635bf11_2_cpu.conda#2d8a7987166fd16c22f1cfdc78c8fdb5 +https://conda.anaconda.org/conda-forge/linux-64/libarrow-acero-21.0.0-h635bf11_3_cpu.conda#12fe67afbd946adae49856b275478d0f https://conda.anaconda.org/conda-forge/linux-64/pandas-2.3.2-py313h08cd8bf_0.conda#5f4cc42e08d6d862b7b919a3c8959e0b https://conda.anaconda.org/conda-forge/linux-64/pyarrow-core-21.0.0-py313he109ebe_0_cpu.conda#3018b7f30825c21c47a7a1e061459f96 https://conda.anaconda.org/conda-forge/linux-64/pytorch-2.7.1-cpu_mkl_py313_h58dab0e_103.conda#14fd59c6195a9d61987cf42e138b1a92 -https://conda.anaconda.org/conda-forge/linux-64/scipy-1.16.1-py313h11c21cd_1.conda#270039a4640693aab11ee3c05385f149 +https://conda.anaconda.org/conda-forge/linux-64/scipy-1.16.2-py313h11c21cd_0.conda#85a80978a04be9c290b8fe6d9bccff1c https://conda.anaconda.org/conda-forge/noarch/scipy-doctest-2.0.1-pyhe01879c_0.conda#303ec962addf1b6016afd536e9db6bc6 https://conda.anaconda.org/conda-forge/linux-64/blas-2.135-mkl.conda#629ac47dbe946d9a709d4187baa6286d -https://conda.anaconda.org/conda-forge/linux-64/libarrow-dataset-21.0.0-h635bf11_2_cpu.conda#a510fbf01cf40904ccb4983110b901cb +https://conda.anaconda.org/conda-forge/linux-64/libarrow-dataset-21.0.0-h635bf11_3_cpu.conda#630dfffcaf67b800607164d4b5b08bf7 https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.10.6-py313h683a580_1.conda#0483ab1c5b6956442195742a5df64196 https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.3.0-py313hfaae9d9_1.conda#6d308eafec3de495f6b06ebe69c990ed https://conda.anaconda.org/conda-forge/linux-64/pytorch-cpu-2.7.1-cpu_mkl_hc60beec_103.conda#5832b21e4193b05a096a8db177b14031 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-substrait-21.0.0-h3f74fd7_2_cpu.conda#dea8c0e2c635238b52aafda31d935073 +https://conda.anaconda.org/conda-forge/linux-64/libarrow-substrait-21.0.0-h3f74fd7_3_cpu.conda#595ca398ad8dcac76a315f358e3312a6 https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.10.6-py313h78bf25f_1.conda#a2644c545b6afde06f4847defc1a2b27 https://conda.anaconda.org/conda-forge/linux-64/pyarrow-21.0.0-py313h78bf25f_0.conda#1580ddd94606ccb60270877cb8838562 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_no_openmp_osx-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_no_openmp_osx-64_conda.lock index f5645da7e6ec4..19038c88c0686 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_no_openmp_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_no_openmp_osx-64_conda.lock @@ -7,6 +7,7 @@ https://conda.anaconda.org/conda-forge/noarch/python_abi-3.13-8_cp313.conda#9430 https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda#4222072737ccff51314b5ece9c7d6f5a https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-h500dc9f_8.conda#97c4b3bd8a90722104798175a1bdddbf https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2025.8.3-hbd8a1cb_0.conda#74784ee3d225fc3dca89edb635b4e5cc +https://conda.anaconda.org/conda-forge/osx-64/icu-75.1-h120a0e1_0.conda#d68d48a3060eb5abdc1cdc8e2a3a5966 https://conda.anaconda.org/conda-forge/osx-64/libbrotlicommon-1.1.0-h1c43f85_4.conda#b8e1ee78815e0ba7835de4183304f96b https://conda.anaconda.org/conda-forge/osx-64/libcxx-21.1.1-h3d58e20_0.conda#7f5b7dfca71a5c165ce57f46e9e48480 https://conda.anaconda.org/conda-forge/osx-64/libdeflate-1.24-hcc1b750_0.conda#f0a46c359722a3e84deb05cd4072d153 @@ -31,7 +32,7 @@ https://conda.anaconda.org/conda-forge/osx-64/libgfortran5-15.1.0-hfa3c126_1.con https://conda.anaconda.org/conda-forge/osx-64/libpng-1.6.50-h84aeda2_1.conda#1fe32bb16991a24e112051cc0de89847 https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.50.4-h39a8b3b_0.conda#156bfb239b6a67ab4a01110e6718cbc4 https://conda.anaconda.org/conda-forge/osx-64/libxcb-1.17.0-hf1f96e2_0.conda#bbeca862892e2898bdb45792a61c4afc -https://conda.anaconda.org/conda-forge/osx-64/libxml2-16-2.14.6-h0ad03eb_1.conda#ef63fdd968a169e77caec7a0de620b2f +https://conda.anaconda.org/conda-forge/osx-64/libxml2-16-2.14.6-ha1d9b0f_2.conda#bce2f90c94826aaf5e9e170732d79fbc https://conda.anaconda.org/conda-forge/osx-64/ninja-1.13.1-h0ba0a54_0.conda#71576ca895305a20c73304fcb581ae1a https://conda.anaconda.org/conda-forge/osx-64/openssl-3.5.2-h6e31bce_0.conda#22f5d63e672b7ba467969e9f8b740ecd https://conda.anaconda.org/conda-forge/osx-64/qhull-2020.2-h3c5361c_5.conda#dd1ea9ff27c93db7c01a7b7656bd4ad4 @@ -42,7 +43,7 @@ https://conda.anaconda.org/conda-forge/osx-64/brotli-bin-1.1.0-h1c43f85_4.conda# https://conda.anaconda.org/conda-forge/osx-64/libfreetype6-2.14.0-h6912278_1.conda#ebfad8c56f5a71f57ec7c6fb2333458e https://conda.anaconda.org/conda-forge/osx-64/libgfortran-15.1.0-h5f6db21_1.conda#07cfad6b37da6e79349c6e3a0316a83b https://conda.anaconda.org/conda-forge/osx-64/libtiff-4.7.0-h59ddb5d_6.conda#1cb7b8054ffa9460ca3dd782062f3074 -https://conda.anaconda.org/conda-forge/osx-64/libxml2-2.14.6-h23bb396_1.conda#d9c72f0570422288880e1845b4c9bd9c +https://conda.anaconda.org/conda-forge/osx-64/libxml2-2.14.6-h7b7ecba_2.conda#191678d5ac5d2b30cb26458776b33900 https://conda.anaconda.org/conda-forge/osx-64/python-3.13.7-h5eba815_100_cp313.conda#1759e1c9591755521bd50489756a599d https://conda.anaconda.org/conda-forge/osx-64/brotli-1.1.0-h1c43f85_4.conda#1a0a37da4466d45c00fc818bb6b446b3 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda#962b9857ee8e7018c22f2776ffa0b2d7 @@ -54,7 +55,7 @@ https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.9-py313hb91e98b_1.c https://conda.anaconda.org/conda-forge/osx-64/lcms2-2.17-h72f5680_0.conda#bf210d0c63f2afb9e414a858b79f0eaa https://conda.anaconda.org/conda-forge/osx-64/libfreetype-2.14.0-h694c41f_1.conda#5b44e5691928a99306a20aa53afb86fd https://conda.anaconda.org/conda-forge/osx-64/libhiredis-1.0.2-h2beb688_0.tar.bz2#524282b2c46c9dedf051b3bc2ae05494 -https://conda.anaconda.org/conda-forge/osx-64/libhwloc-2.12.1-default_h094e1f9_1001.conda#75d7759422b200b38ccd24a2fc34ca55 +https://conda.anaconda.org/conda-forge/osx-64/libhwloc-2.12.1-default_h094e1f9_1002.conda#4d9e9610b6a16291168144842cd9cae2 https://conda.anaconda.org/conda-forge/noarch/meson-1.9.0-pyhcf101f3_0.conda#288989b6c775fa4181eb433114472274 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda#37293a85a0f4f77bbd9cf7aaefc62609 https://conda.anaconda.org/conda-forge/osx-64/openjpeg-2.5.3-h036ada5_1.conda#38f264b121a043cf379980c959fb2d75 @@ -62,7 +63,7 @@ https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda# https://conda.anaconda.org/conda-forge/noarch/pip-25.2-pyh145f28c_0.conda#e7ab34d5a93e0819b62563c78635d937 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda#88476ae6ebd24f39261e0854ac244f33 https://conda.anaconda.org/conda-forge/noarch/pytz-2025.2-pyhd8ed1ab_0.conda#bc8e3267d44011051f2eb14d22fb0960 https://conda.anaconda.org/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda#4de79c071274a53dcaf2a8c749d1499e @@ -96,7 +97,7 @@ https://conda.anaconda.org/conda-forge/osx-64/numpy-2.3.3-py313ha99c057_0.conda# https://conda.anaconda.org/conda-forge/osx-64/blas-devel-3.9.0-20_osx64_mkl.conda#cc3260179093918b801e373c6e888e02 https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.3.3-py313hc551f4f_2.conda#51eb4d5f1de7beda42425e430364165b https://conda.anaconda.org/conda-forge/osx-64/pandas-2.3.2-py313h366a99e_0.conda#31a66209f11793d320c1344f466d3d37 -https://conda.anaconda.org/conda-forge/osx-64/scipy-1.16.1-py313hf2e9e4d_1.conda#0acfa7f16b706fed7238e5b67d4e5abf +https://conda.anaconda.org/conda-forge/osx-64/scipy-1.16.2-py313h61f8160_0.conda#bce2603cfeb56dde6e7f1257975c8e03 https://conda.anaconda.org/conda-forge/osx-64/blas-2.120-mkl.conda#b041a7677a412f3d925d8208936cb1e2 https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.10.6-py313h4ad75b8_1.conda#ea88ae8e6f51e16c2b9353575a973a49 https://conda.anaconda.org/conda-forge/osx-64/pyamg-5.3.0-py313h7f78831_1.conda#1a6f985147e1a3ee3db88a56a7968fdb diff --git a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock index a7f3b13e3657c..33a705d143a2f 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock @@ -8,6 +8,7 @@ https://conda.anaconda.org/conda-forge/noarch/python_abi-3.13-8_cp313.conda#9430 https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda#4222072737ccff51314b5ece9c7d6f5a https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-h500dc9f_8.conda#97c4b3bd8a90722104798175a1bdddbf https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2025.8.3-hbd8a1cb_0.conda#74784ee3d225fc3dca89edb635b4e5cc +https://conda.anaconda.org/conda-forge/osx-64/icu-75.1-h120a0e1_0.conda#d68d48a3060eb5abdc1cdc8e2a3a5966 https://conda.anaconda.org/conda-forge/osx-64/libbrotlicommon-1.1.0-h1c43f85_4.conda#b8e1ee78815e0ba7835de4183304f96b https://conda.anaconda.org/conda-forge/osx-64/libcxx-21.1.1-h3d58e20_0.conda#7f5b7dfca71a5c165ce57f46e9e48480 https://conda.anaconda.org/conda-forge/osx-64/libdeflate-1.24-hcc1b750_0.conda#f0a46c359722a3e84deb05cd4072d153 @@ -35,7 +36,7 @@ https://conda.anaconda.org/conda-forge/osx-64/libgfortran5-15.1.0-hfa3c126_1.con https://conda.anaconda.org/conda-forge/osx-64/libpng-1.6.50-h84aeda2_1.conda#1fe32bb16991a24e112051cc0de89847 https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.50.4-h39a8b3b_0.conda#156bfb239b6a67ab4a01110e6718cbc4 https://conda.anaconda.org/conda-forge/osx-64/libxcb-1.17.0-hf1f96e2_0.conda#bbeca862892e2898bdb45792a61c4afc -https://conda.anaconda.org/conda-forge/osx-64/libxml2-16-2.14.6-h0ad03eb_1.conda#ef63fdd968a169e77caec7a0de620b2f +https://conda.anaconda.org/conda-forge/osx-64/libxml2-16-2.14.6-ha1d9b0f_2.conda#bce2f90c94826aaf5e9e170732d79fbc https://conda.anaconda.org/conda-forge/osx-64/ninja-1.13.1-h0ba0a54_0.conda#71576ca895305a20c73304fcb581ae1a https://conda.anaconda.org/conda-forge/osx-64/openssl-3.5.2-h6e31bce_0.conda#22f5d63e672b7ba467969e9f8b740ecd https://conda.anaconda.org/conda-forge/osx-64/qhull-2020.2-h3c5361c_5.conda#dd1ea9ff27c93db7c01a7b7656bd4ad4 @@ -48,7 +49,7 @@ https://conda.anaconda.org/conda-forge/osx-64/brotli-bin-1.1.0-h1c43f85_4.conda# https://conda.anaconda.org/conda-forge/osx-64/libfreetype6-2.14.0-h6912278_1.conda#ebfad8c56f5a71f57ec7c6fb2333458e https://conda.anaconda.org/conda-forge/osx-64/libgfortran-15.1.0-h5f6db21_1.conda#07cfad6b37da6e79349c6e3a0316a83b https://conda.anaconda.org/conda-forge/osx-64/libtiff-4.7.0-h59ddb5d_6.conda#1cb7b8054ffa9460ca3dd782062f3074 -https://conda.anaconda.org/conda-forge/osx-64/libxml2-2.14.6-h23bb396_1.conda#d9c72f0570422288880e1845b4c9bd9c +https://conda.anaconda.org/conda-forge/osx-64/libxml2-2.14.6-h7b7ecba_2.conda#191678d5ac5d2b30cb26458776b33900 https://conda.anaconda.org/conda-forge/osx-64/mpfr-4.2.1-haed47dc_3.conda#d511e58aaaabfc23136880d9956fa7a6 https://conda.anaconda.org/conda-forge/osx-64/python-3.13.7-h5eba815_100_cp313.conda#1759e1c9591755521bd50489756a599d https://conda.anaconda.org/conda-forge/osx-64/sigtool-0.1.3-h88f4db0_0.tar.bz2#fbfb84b9de9a6939cb165c02c69b1865 @@ -62,7 +63,7 @@ https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.9-py313hb91e98b_1.c https://conda.anaconda.org/conda-forge/osx-64/lcms2-2.17-h72f5680_0.conda#bf210d0c63f2afb9e414a858b79f0eaa https://conda.anaconda.org/conda-forge/osx-64/libfreetype-2.14.0-h694c41f_1.conda#5b44e5691928a99306a20aa53afb86fd https://conda.anaconda.org/conda-forge/osx-64/libhiredis-1.0.2-h2beb688_0.tar.bz2#524282b2c46c9dedf051b3bc2ae05494 -https://conda.anaconda.org/conda-forge/osx-64/libhwloc-2.12.1-default_h094e1f9_1001.conda#75d7759422b200b38ccd24a2fc34ca55 +https://conda.anaconda.org/conda-forge/osx-64/libhwloc-2.12.1-default_h094e1f9_1002.conda#4d9e9610b6a16291168144842cd9cae2 https://conda.anaconda.org/conda-forge/osx-64/libllvm19-19.1.7-h56e7563_2.conda#05a54b479099676e75f80ad0ddd38eff https://conda.anaconda.org/conda-forge/noarch/meson-1.9.0-pyhcf101f3_0.conda#288989b6c775fa4181eb433114472274 https://conda.anaconda.org/conda-forge/osx-64/mpc-1.3.1-h9d8efa1_1.conda#0520855aaae268ea413d6bc913f1384c @@ -72,7 +73,7 @@ https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda# https://conda.anaconda.org/conda-forge/noarch/pip-25.2-pyh145f28c_0.conda#e7ab34d5a93e0819b62563c78635d937 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda#88476ae6ebd24f39261e0854ac244f33 https://conda.anaconda.org/conda-forge/noarch/pytz-2025.2-pyhd8ed1ab_0.conda#bc8e3267d44011051f2eb14d22fb0960 https://conda.anaconda.org/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda#4de79c071274a53dcaf2a8c749d1499e @@ -120,7 +121,7 @@ https://conda.anaconda.org/conda-forge/osx-64/blas-devel-3.9.0-20_osx64_mkl.cond https://conda.anaconda.org/conda-forge/osx-64/clang_impl_osx-64-19.1.7-hc73cdc9_25.conda#76954503be09430fb7f4683a61ffb7b0 https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.3.3-py313hc551f4f_2.conda#51eb4d5f1de7beda42425e430364165b https://conda.anaconda.org/conda-forge/osx-64/pandas-2.3.2-py313h366a99e_0.conda#31a66209f11793d320c1344f466d3d37 -https://conda.anaconda.org/conda-forge/osx-64/scipy-1.16.1-py313hf2e9e4d_1.conda#0acfa7f16b706fed7238e5b67d4e5abf +https://conda.anaconda.org/conda-forge/osx-64/scipy-1.16.2-py313h61f8160_0.conda#bce2603cfeb56dde6e7f1257975c8e03 https://conda.anaconda.org/conda-forge/osx-64/blas-2.120-mkl.conda#b041a7677a412f3d925d8208936cb1e2 https://conda.anaconda.org/conda-forge/osx-64/clang_osx-64-19.1.7-h7e5c614_25.conda#a526ba9df7e7d5448d57b33941614dae https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.10.6-py313h4ad75b8_1.conda#ea88ae8e6f51e16c2b9353575a973a49 diff --git a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock index 72c3f48d1d093..d6c86a8d86921 100644 --- a/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock +++ b/build_tools/azure/pylatest_pip_openblas_pandas_linux-64_conda.lock @@ -57,7 +57,7 @@ https://conda.anaconda.org/conda-forge/linux-64/ccache-4.11.3-h80c52d3_0.conda#e # pip pillow @ https://files.pythonhosted.org/packages/d5/1c/a2a29649c0b1983d3ef57ee87a66487fdeb45132df66ab30dd37f7dbe162/pillow-11.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=13f87d581e71d9189ab21fe0efb5a23e9f28552d5be6979e84001d3b8505abe8 # pip pluggy @ https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl#sha256=e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 # pip pygments @ https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl#sha256=86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b -# pip pyparsing @ https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl#sha256=a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf +# pip pyparsing @ https://files.pythonhosted.org/packages/53/b8/fbab973592e23ae313042d450fc26fa24282ebffba21ba373786e1ce63b4/pyparsing-3.2.4-py3-none-any.whl#sha256=91d0fcde680d42cd031daf3a6ba20da3107e08a75de50da58360e7d94ab24d36 # pip pytz @ https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl#sha256=5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00 # pip roman-numerals-py @ https://files.pythonhosted.org/packages/53/97/d2cbbaa10c9b826af0e10fdf836e1bf344d9f0abb873ebc34d1f49642d3f/roman_numerals_py-3.1.0-py3-none-any.whl#sha256=9da2ad2fb670bcf24e81070ceb3be72f6c11c440d73bd579fbeca1e9f330954c # pip six @ https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl#sha256=4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 @@ -81,7 +81,7 @@ https://conda.anaconda.org/conda-forge/linux-64/ccache-4.11.3-h80c52d3_0.conda#e # pip pytest @ https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl#sha256=872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79 # pip python-dateutil @ https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl#sha256=a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # pip requests @ https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl#sha256=2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6 -# pip scipy @ https://files.pythonhosted.org/packages/e4/82/08e4076df538fb56caa1d489588d880ec7c52d8273a606bb54d660528f7c/scipy-1.16.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl#sha256=fedc2cbd1baed37474b1924c331b97bdff611d762c196fac1a9b71e67b813b1b +# pip scipy @ https://files.pythonhosted.org/packages/da/6a/1a927b14ddc7714111ea51f4e568203b2bb6ed59bdd036d62127c1a360c8/scipy-1.16.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl#sha256=c2275ff105e508942f99d4e3bc56b6ef5e4b3c0af970386ca56b777608ce95b7 # pip tifffile @ https://files.pythonhosted.org/packages/48/c5/0d57e3547add58285f401afbc421bd3ffeddbbd275a2c0b980b9067fda4a/tifffile-2025.9.9-py3-none-any.whl#sha256=239247551fa10b5679036ee030cdbeb7762bc1b3f11b1ddaaf50759ef8b4eb26 # pip lightgbm @ https://files.pythonhosted.org/packages/42/86/dabda8fbcb1b00bcfb0003c3776e8ade1aa7b413dff0a2c08f457dace22f/lightgbm-4.6.0-py3-none-manylinux_2_28_x86_64.whl#sha256=cb19b5afea55b5b61cbb2131095f50538bd608a00655f23ad5d25ae3e3bf1c8d # pip matplotlib @ https://files.pythonhosted.org/packages/e5/b8/9eea6630198cb303d131d95d285a024b3b8645b1763a2916fddb44ca8760/matplotlib-3.10.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl#sha256=84e82d9e0fd70c70bc55739defbd8055c54300750cbacf4740c9673a24d6933a diff --git a/build_tools/azure/pymin_conda_forge_openblas_min_dependencies_linux-64_conda.lock b/build_tools/azure/pymin_conda_forge_openblas_min_dependencies_linux-64_conda.lock index 3b2931d4a3705..0a040bbc5448f 100644 --- a/build_tools/azure/pymin_conda_forge_openblas_min_dependencies_linux-64_conda.lock +++ b/build_tools/azure/pymin_conda_forge_openblas_min_dependencies_linux-64_conda.lock @@ -108,7 +108,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-3.21.12-hfc55251_2.c https://conda.anaconda.org/conda-forge/linux-64/libthrift-0.18.1-h8fd135c_2.conda#bbf65f7688512872f063810623b755dc https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.0-h8261f1e_6.conda#b6093922931b535a7ba566b6f384fbe6 https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda#318b08df404f9c9be5712aaa5a6f0bb0 -https://conda.anaconda.org/conda-forge/linux-64/nss-3.115-hc3c8bcf_0.conda#c8873d2f90ad15aaec7be6926f11b53d +https://conda.anaconda.org/conda-forge/linux-64/nss-3.116-h445c969_0.conda#deaf54211251a125c27aff34871124c3 https://conda.anaconda.org/conda-forge/linux-64/python-3.10.18-hd6af730_0_cpython.conda#4ea0c77cdcb0b81813a0436b162d7316 https://conda.anaconda.org/conda-forge/linux-64/rdma-core-28.9-h59595ed_1.conda#aeffb7c06b5f65e55e6c637408dc4100 https://conda.anaconda.org/conda-forge/linux-64/re2-2023.03.02-h8c504da_0.conda#206f8fa808748f6e90599c3368a1114e @@ -157,7 +157,7 @@ https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7d https://conda.anaconda.org/conda-forge/noarch/ply-3.11-pyhd8ed1ab_3.conda#fd5062942bfa1b0bd5e0d2a4397b099e https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda#12c566707c80111f9799308d9e265aef https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha55dd90_7.conda#461219d1a5bd61342293efa2c0c90eac https://conda.anaconda.org/conda-forge/noarch/pytz-2025.2-pyhd8ed1ab_0.conda#bc8e3267d44011051f2eb14d22fb0960 https://conda.anaconda.org/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda#4de79c071274a53dcaf2a8c749d1499e @@ -230,7 +230,7 @@ https://conda.anaconda.org/conda-forge/noarch/urllib3-2.5.0-pyhd8ed1ab_0.conda#4 https://conda.anaconda.org/conda-forge/linux-64/aws-crt-cpp-0.20.2-h2a5cb19_18.conda#7313674073496cec938f73b71163bc31 https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-20_linux64_openblas.conda#9932a1d4e9ecf2d35fb19475446e361e https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.3.2-py310h3788b33_0.conda#b6420d29123c7c823de168f49ccdfe6a -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-11.4.5-h15599e2_0.conda#1276ae4aa3832a449fcb4253c30da4bc +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-11.5.0-h15599e2_0.conda#47599428437d622bfee24fbd06a2d0b4 https://conda.anaconda.org/conda-forge/linux-64/pandas-1.5.0-py310h769672d_0.tar.bz2#06efc4b5f4b418b78de14d1db4a65cad https://conda.anaconda.org/conda-forge/linux-64/polars-0.20.30-py310h031f9ce_0.conda#0743f5db9f978b6df92d412935ff8371 https://conda.anaconda.org/conda-forge/noarch/requests-2.32.5-pyhd8ed1ab_0.conda#db0c6b99149880c8ba515cf4abe93ee4 diff --git a/build_tools/azure/pymin_conda_forge_openblas_win-64_conda.lock b/build_tools/azure/pymin_conda_forge_openblas_win-64_conda.lock index fb1a5b635ee72..8eb95bfc313a5 100644 --- a/build_tools/azure/pymin_conda_forge_openblas_win-64_conda.lock +++ b/build_tools/azure/pymin_conda_forge_openblas_win-64_conda.lock @@ -62,7 +62,7 @@ https://conda.anaconda.org/conda-forge/noarch/execnet-2.1.1-pyhd8ed1ab_1.conda#a https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda#6837f3eff7dcea42ecd714ce1ac2b108 https://conda.anaconda.org/conda-forge/win-64/kiwisolver-1.4.9-py310h1e1005b_1.conda#a0695050d0379e201f0c40b89d3b58dd https://conda.anaconda.org/conda-forge/win-64/libcblas-3.9.0-35_h2a8eebe_openblas.conda#b319a1bffa6c2c8ba7f6c8f12a40d898 -https://conda.anaconda.org/conda-forge/win-64/libclang13-21.1.0-default_ha2db4b5_1.conda#9065d254995bd88bda60c77c77fcad3d +https://conda.anaconda.org/conda-forge/win-64/libclang13-21.1.1-default_ha2db4b5_0.conda#17f5b2e04b696f148b1b8ff1d5d55b75 https://conda.anaconda.org/conda-forge/win-64/libfreetype6-2.14.0-hdbac1cb_1.conda#10dd24f0c2a81775f09952badfb52019 https://conda.anaconda.org/conda-forge/win-64/libglib-2.84.3-h1c1036b_0.conda#2bcc00752c158d4a70e1eaccbf6fe8ae https://conda.anaconda.org/conda-forge/win-64/liblapack-3.9.0-35_hd232482_openblas.conda#e446e419a887c9e0a04fee684f9b0551 @@ -74,7 +74,7 @@ https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda#3 https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda#58335b26c38bf4a20f399384c33cbcf9 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda#4de79c071274a53dcaf2a8c749d1499e https://conda.anaconda.org/conda-forge/noarch/six-1.17.0-pyhe01879c_1.conda#3339e3b65d58accf4ca4fb8748ab16b3 https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda#9d64911b31d57ca443e9f1e36b04385f diff --git a/build_tools/circle/doc_linux-64_conda.lock b/build_tools/circle/doc_linux-64_conda.lock index 402e9d686db0c..7e799a7c51356 100644 --- a/build_tools/circle/doc_linux-64_conda.lock +++ b/build_tools/circle/doc_linux-64_conda.lock @@ -98,7 +98,7 @@ https://conda.anaconda.org/conda-forge/linux-64/icu-75.1-he02047a_0.conda#8b1893 https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.3-h659f571_0.conda#3f43953b7d3fb3aaa1d0d0723d91e368 https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.0-h73754d4_1.conda#df6bf113081fdea5b363eb5a7a5ceb69 https://conda.anaconda.org/conda-forge/linux-64/libglib-2.84.3-hf39c6af_0.conda#467f23819b1ea2b89c3fc94d65082301 -https://conda.anaconda.org/conda-forge/linux-64/libjxl-0.11.1-h0a47e8d_3.conda#509f4010a8345b36c81fa795dffcd25a +https://conda.anaconda.org/conda-forge/linux-64/libjxl-0.11.1-h6cb5226_4.conda#f2840d9c2afb19e303e126c9d3a04b36 https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.0-h8261f1e_6.conda#b6093922931b535a7ba566b6f384fbe6 https://conda.anaconda.org/conda-forge/linux-64/libzopfli-1.0.3-h9c3ff4c_0.tar.bz2#c66fe2d123249af7651ebde8984c51c2 https://conda.anaconda.org/conda-forge/linux-64/python-3.10.18-hd6af730_0_cpython.conda#4ea0c77cdcb0b81813a0436b162d7316 @@ -149,7 +149,7 @@ https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.2-py310h89163eb_1 https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda#592132998493b3ff25fd7479396e8351 https://conda.anaconda.org/conda-forge/noarch/meson-1.9.0-pyhcf101f3_0.conda#288989b6c775fa4181eb433114472274 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda#37293a85a0f4f77bbd9cf7aaefc62609 -https://conda.anaconda.org/conda-forge/noarch/narwhals-2.4.0-pyhcf101f3_0.conda#bc703ec04a2f051e89522821489fac26 +https://conda.anaconda.org/conda-forge/noarch/narwhals-2.5.0-pyhcf101f3_0.conda#c64dc3b3e0c804e0f1213abd46c1705d https://conda.anaconda.org/conda-forge/noarch/networkx-3.4.2-pyh267e887_2.conda#fd40bf7f7f4bc4b647dc8512053d9873 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h55fea9a_1.conda#01243c4aaf71bde0297966125aea4706 https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda#58335b26c38bf4a20f399384c33cbcf9 @@ -162,7 +162,7 @@ https://conda.anaconda.org/conda-forge/linux-64/psutil-7.0.0-py310h7c4b9e2_1.con https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd8ed1ab_1.conda#7d9daffbb8d8e0af0f769dbbcd173a54 https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda#12c566707c80111f9799308d9e265aef https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha55dd90_7.conda#461219d1a5bd61342293efa2c0c90eac https://conda.anaconda.org/conda-forge/noarch/python-fastjsonschema-2.21.2-pyhe01879c_0.conda#23029aae904a2ba587daba708208012f https://conda.anaconda.org/conda-forge/noarch/python-json-logger-2.0.7-pyhd8ed1ab_0.conda#a61bf9ec79426938ff785eb69dbb1960 diff --git a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock index e19944f705e5a..ffe916b1ac18b 100644 --- a/build_tools/circle/doc_min_dependencies_linux-64_conda.lock +++ b/build_tools/circle/doc_min_dependencies_linux-64_conda.lock @@ -113,10 +113,10 @@ https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-lib-1.11.1-hb9d3cd8_0. https://conda.anaconda.org/conda-forge/linux-64/libgettextpo-devel-0.25.1-h3f43e3d_1.conda#3f7a43b3160ec0345c9535a9f0d7908e https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-15.1.0-h69a702a_5.conda#41a5893c957ffed7f82b4005bc24866c https://conda.anaconda.org/conda-forge/linux-64/libglib-2.86.0-h1fed272_0.conda#b8e4c93f4ab70c3b6f6499299627dbdc -https://conda.anaconda.org/conda-forge/linux-64/libjxl-0.11.1-h0a47e8d_3.conda#509f4010a8345b36c81fa795dffcd25a +https://conda.anaconda.org/conda-forge/linux-64/libjxl-0.11.1-h6cb5226_4.conda#f2840d9c2afb19e303e126c9d3a04b36 https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.0-h8261f1e_6.conda#b6093922931b535a7ba566b6f384fbe6 https://conda.anaconda.org/conda-forge/linux-64/libzopfli-1.0.3-h9c3ff4c_0.tar.bz2#c66fe2d123249af7651ebde8984c51c2 -https://conda.anaconda.org/conda-forge/linux-64/nss-3.115-hc3c8bcf_0.conda#c8873d2f90ad15aaec7be6926f11b53d +https://conda.anaconda.org/conda-forge/linux-64/nss-3.116-h445c969_0.conda#deaf54211251a125c27aff34871124c3 https://conda.anaconda.org/conda-forge/linux-64/python-3.10.18-hd6af730_0_cpython.conda#4ea0c77cdcb0b81813a0436b162d7316 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.1-h4f16b4b_2.conda#fdc27cb255a7a2cc73b7919a968b48f0 https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.1-hb711507_0.conda#ad748ccca349aec3e91743e08b5e2b50 @@ -170,7 +170,7 @@ https://conda.anaconda.org/conda-forge/noarch/ply-3.11-pyhd8ed1ab_3.conda#fd5062 https://conda.anaconda.org/conda-forge/linux-64/psutil-7.0.0-py310h7c4b9e2_1.conda#165e1696a6859b5cd915f9486f171ace https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda#12c566707c80111f9799308d9e265aef https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha55dd90_7.conda#461219d1a5bd61342293efa2c0c90eac https://conda.anaconda.org/conda-forge/noarch/pytz-2025.2-pyhd8ed1ab_0.conda#bc8e3267d44011051f2eb14d22fb0960 https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.2-py310h89163eb_2.conda#fd343408e64cf1e273ab7c710da374db @@ -256,7 +256,7 @@ https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-3.8.0-pyhd8ed1ab_0.co https://conda.anaconda.org/conda-forge/noarch/towncrier-24.8.0-pyhd8ed1ab_1.conda#820b6a1ddf590fba253f8204f7200d82 https://conda.anaconda.org/conda-forge/noarch/urllib3-2.5.0-pyhd8ed1ab_0.conda#436c165519e140cb08d246a4472a9d6a https://conda.anaconda.org/conda-forge/linux-64/compilers-1.11.0-ha770c72_0.conda#fdcf2e31dd960ef7c5daa9f2c95eff0e -https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-11.4.5-h15599e2_0.conda#1276ae4aa3832a449fcb4253c30da4bc +https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-11.5.0-h15599e2_0.conda#47599428437d622bfee24fbd06a2d0b4 https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-35_hfdb39a5_mkl.conda#9fedd782400297fa574e739146f04e34 https://conda.anaconda.org/conda-forge/linux-64/mkl-devel-2024.2.2-ha770c72_17.conda#e67269e07e58be5672f06441316f05f2 https://conda.anaconda.org/conda-forge/noarch/requests-2.32.5-pyhd8ed1ab_0.conda#db0c6b99149880c8ba515cf4abe93ee4 diff --git a/build_tools/github/pymin_conda_forge_arm_linux-aarch64_conda.lock b/build_tools/github/pymin_conda_forge_arm_linux-aarch64_conda.lock index 78d0aeb19d706..b05a36f821507 100644 --- a/build_tools/github/pymin_conda_forge_arm_linux-aarch64_conda.lock +++ b/build_tools/github/pymin_conda_forge_arm_linux-aarch64_conda.lock @@ -102,7 +102,7 @@ https://conda.anaconda.org/conda-forge/linux-aarch64/openjpeg-2.5.3-h5da879a_1.c https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda#58335b26c38bf4a20f399384c33cbcf9 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda#4de79c071274a53dcaf2a8c749d1499e https://conda.anaconda.org/conda-forge/noarch/six-1.17.0-pyhe01879c_1.conda#3339e3b65d58accf4ca4fb8748ab16b3 https://conda.anaconda.org/conda-forge/noarch/threadpoolctl-3.6.0-pyhecae5ae_0.conda#9d64911b31d57ca443e9f1e36b04385f From f48a2a464a325fce15373529f716bf81b595ff37 Mon Sep 17 00:00:00 2001 From: scikit-learn-bot Date: Mon, 15 Sep 2025 10:26:33 +0200 Subject: [PATCH 33/62] :lock: :robot: CI Update lock files for array-api CI build(s) :lock: :robot: (#32186) Co-authored-by: Lock file bot --- ...a_forge_cuda_array-api_linux-64_conda.lock | 77 ++++++++++--------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock b/build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock index 71b04c3147b6c..97762fc25efb8 100644 --- a/build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock +++ b/build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock @@ -8,7 +8,9 @@ https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed3 https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2#4d59c254e01d9cde7957100457e2d5fb https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-h77eed37_3.conda#49023d73832ef61042f6a237cb2687e7 https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-4.18.0-he073ed8_8.conda#ff007ab0f0fdc53d245972bba8a6d40c +https://conda.anaconda.org/conda-forge/linux-64/libopentelemetry-cpp-headers-1.18.0-ha770c72_1.conda#4fb055f57404920a43b147031471e03b https://conda.anaconda.org/conda-forge/linux-64/mkl-include-2024.2.2-ha770c72_17.conda#c18fd07c02239a7eb744ea728db39630 +https://conda.anaconda.org/conda-forge/linux-64/nlohmann_json-3.12.0-h3f2d84a_0.conda#d76872d096d063e226482c99337209dc https://conda.anaconda.org/conda-forge/noarch/python_abi-3.13-8_cp313.conda#94305520c52a4aa3f6c2b1ff6008d9f8 https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda#4222072737ccff51314b5ece9c7d6f5a https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2025.8.3-hbd8a1cb_0.conda#74784ee3d225fc3dca89edb635b4e5cc @@ -23,7 +25,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libegl-1.7.0-ha4b6fd6_2.conda#c1 https://conda.anaconda.org/conda-forge/linux-64/libopengl-1.7.0-ha4b6fd6_2.conda#7df50d44d4a14d6c31a2c54f2cd92157 https://conda.anaconda.org/conda-forge/linux-64/libgcc-15.1.0-h767d61c_5.conda#264fbfba7fb20acf3b29cde153e345ce https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.14-hb9d3cd8_0.conda#76df83c2a9035c54df5d04ff81bcc02d -https://conda.anaconda.org/conda-forge/linux-64/aws-c-common-0.10.6-hb9d3cd8_0.conda#d7d4680337a14001b0e043e96529409b +https://conda.anaconda.org/conda-forge/linux-64/aws-c-common-0.12.0-hb9d3cd8_0.conda#f65c946f28f0518f41ced702f44c52b7 https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_8.conda#51a19bba1b8ebfb60df25cde030b7ebc https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.34.5-hb9d3cd8_0.conda#f7f0d6cc2dc986d42ac2689ec88192be https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.3-hb9d3cd8_0.conda#b38117a3c920364aff79f870c984b4a3 @@ -40,7 +42,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb9d3cd8_0.conda# https://conda.anaconda.org/conda-forge/linux-64/libntlm-1.8-hb9d3cd8_0.conda#7c7927b404672409d9917d49bff5f2d6 https://conda.anaconda.org/conda-forge/linux-64/libpciaccess-0.18-hb9d3cd8_0.conda#70e3400cbbfa03e96dcde7fc13e38c7b https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.1.0-h8f9b012_5.conda#4e02a49aaa9d5190cb630fa43528fbe6 -https://conda.anaconda.org/conda-forge/linux-64/libutf8proc-2.9.0-hb9d3cd8_1.conda#1e936bd23d737aac62a18e9a1e7f8b18 +https://conda.anaconda.org/conda-forge/linux-64/libutf8proc-2.10.0-h202a827_0.conda#0f98f3e95272d118f7931b6bef69bfe5 https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.41.1-he9a06e4_0.conda#af930c65e9a79a3423d6d36e265cef65 https://conda.anaconda.org/conda-forge/linux-64/libuv-1.51.0-hb03c661_1.conda#0f03292cc56bf91a077a134ea8747118 https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.6.0-hd42ef1d_0.conda#aea31d2e5b1091feca96fcfe945c3cf9 @@ -51,10 +53,10 @@ https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002. https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.2-hb9d3cd8_0.conda#fb901ff28063514abb6046c9ec2c4a45 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb9d3cd8_0.conda#f6ebe2cb3f82ba6c057dde5d9debe4f7 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.5-hb9d3cd8_0.conda#8035c64cb77ed555e3f150b7b3972480 -https://conda.anaconda.org/conda-forge/linux-64/aws-c-cal-0.8.1-h1a47875_3.conda#55a8561fdbbbd34f50f57d9be12ed084 -https://conda.anaconda.org/conda-forge/linux-64/aws-c-compression-0.3.0-h4e1184b_5.conda#3f4c1197462a6df2be6dc8241828fe93 -https://conda.anaconda.org/conda-forge/linux-64/aws-c-sdkutils-0.2.1-h4e1184b_4.conda#a5126a90e74ac739b00564a4c7ddcc36 -https://conda.anaconda.org/conda-forge/linux-64/aws-checksums-0.2.2-h4e1184b_4.conda#74e8c3e4df4ceae34aa2959df4b28101 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-cal-0.8.7-h043a21b_0.conda#4fdf835d66ea197e693125c64fbd4482 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-compression-0.3.1-h3870646_2.conda#17ccde79d864e6183a83c5bbb8fff34d +https://conda.anaconda.org/conda-forge/linux-64/aws-c-sdkutils-0.2.3-h3870646_2.conda#06008b5ab42117c89c982aa2a32a5b25 +https://conda.anaconda.org/conda-forge/linux-64/aws-checksums-0.2.3-h3870646_2.conda#303d9e83e0518f1dcb66e90054635ca6 https://conda.anaconda.org/conda-forge/linux-64/double-conversion-3.3.1-h5888daf_0.conda#bfd56492d8346d669010eccafe0ba058 https://conda.anaconda.org/conda-forge/linux-64/gflags-2.2.2-h5888daf_1005.conda#d411fc29e338efb48c5fd4576d71d881 https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.14-hecca717_2.conda#2cd94587f3a401ae05e03a6caf09539d @@ -78,14 +80,15 @@ https://conda.anaconda.org/conda-forge/linux-64/ninja-1.13.1-h171cf75_0.conda#65 https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.45-hc749103_0.conda#b90bece58b4c2bf25969b70f3be42d25 https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda#c01af13bdc553d1a8fbfff6e8db075f0 https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda#283b96675859b20a825f8fa30f311446 -https://conda.anaconda.org/conda-forge/linux-64/s2n-1.5.11-h072c03f_0.conda#5e8060d52f676a40edef0006a75c718f +https://conda.anaconda.org/conda-forge/linux-64/s2n-1.5.14-h6c98b2b_0.conda#efab4ad81ba5731b2fefa0ab4359e884 https://conda.anaconda.org/conda-forge/linux-64/sleef-3.9.0-ha0421bc_0.conda#e8a0b4f5e82ecacffaa5e805020473cb https://conda.anaconda.org/conda-forge/linux-64/snappy-1.2.2-h03e3b7b_0.conda#3d8da0248bdae970b4ade636a104b7f5 https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda#a0116df4f4ed05c303811a837d5b39d8 https://conda.anaconda.org/conda-forge/linux-64/wayland-1.24.0-h3e06ad9_0.conda#0f2ca7906bf166247d1d760c3422cb8a https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.6-he73a12e_0.conda#1c74ff8c35dcadf952a16f752ca5aa49 +https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda#c9f075ab2f33b3bbee9e62d4ad0a6cd8 https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda#6432cb5d4ac0046c3ac0a8a0f95842f9 -https://conda.anaconda.org/conda-forge/linux-64/aws-c-io-0.15.3-h173a860_6.conda#9a063178f1af0a898526cc24ba7be486 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-io-0.17.0-h3dad3f2_6.conda#3a127d28266cdc0da93384d1f59fe8df https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hb03c661_4.conda#ca4ed8015764937c81b830f7f5b68543 https://conda.anaconda.org/conda-forge/linux-64/cudatoolkit-11.8.0-h4ba93d1_13.conda#eb43f5f1f16e2fad2eba22219c3e499b https://conda.anaconda.org/conda-forge/linux-64/glog-0.7.1-hbabe93e_0.conda#ff862eebdfeb2fd048ae9dc92510baca @@ -97,7 +100,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.0-h73754d4_1.c https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-15.1.0-h69a702a_5.conda#41a5893c957ffed7f82b4005bc24866c https://conda.anaconda.org/conda-forge/linux-64/libglib-2.84.3-hf39c6af_0.conda#467f23819b1ea2b89c3fc94d65082301 https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.67.0-had1ee68_0.conda#b499ce4b026493a13774bcf0f4c33849 -https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-5.28.2-h5b01275_0.conda#ab0bff36363bec94720275a681af8b83 +https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-5.28.3-h6128344_1.conda#d8703f1ffe5a06356f06467f1d0b9464 https://conda.anaconda.org/conda-forge/linux-64/libre2-11-2024.07.02-hbbce691_2.conda#b2fede24428726dd867611664fb372e8 https://conda.anaconda.org/conda-forge/linux-64/libthrift-0.21.0-h0e7cc3e_0.conda#dcb95c0a98ba9ff737f7ae482aef7833 https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.0-h8261f1e_6.conda#b6093922931b535a7ba566b6f384fbe6 @@ -109,8 +112,8 @@ https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.1-hb711507_ https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.10-hb711507_0.conda#0e0cbe0564d03a99afd5fd7b362feecd https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.2-hb711507_0.conda#608e0ef8256b81d04456e8d211eee3e8 https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.12-h4f16b4b_0.conda#db038ce880f100acc74dba10302b5630 -https://conda.anaconda.org/conda-forge/linux-64/aws-c-event-stream-0.5.0-h7959bf6_11.conda#9b3fb60fe57925a92f399bc3fc42eccf -https://conda.anaconda.org/conda-forge/linux-64/aws-c-http-0.9.2-hefd7a92_4.conda#5ce4df662d32d3123ea8da15571b6f51 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-event-stream-0.5.4-h04a3f94_2.conda#81096a80f03fc2f0fb2a230f5d028643 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-http-0.9.4-hb9b18c6_4.conda#773c99d0dbe2b3704af165f97ff399e5 https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hb03c661_4.conda#eaf3fbd2aa97c212336de38a51fe404e https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda#962b9857ee8e7018c22f2776ffa0b2d7 https://conda.anaconda.org/conda-forge/noarch/cpython-3.13.7-py313hd8ed1ab_100.conda#c5623ddbd37c5dafa7754a83f97de01e @@ -139,12 +142,12 @@ https://conda.anaconda.org/conda-forge/noarch/mpmath-1.3.0-pyhd8ed1ab_1.conda#35 https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda#37293a85a0f4f77bbd9cf7aaefc62609 https://conda.anaconda.org/conda-forge/noarch/networkx-3.5-pyhe01879c_0.conda#16bff3d37a4f99e3aa089c36c2b8d650 https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h55fea9a_1.conda#01243c4aaf71bde0297966125aea4706 -https://conda.anaconda.org/conda-forge/linux-64/orc-2.0.3-h97ab989_1.conda#2f46eae652623114e112df13fae311cf +https://conda.anaconda.org/conda-forge/linux-64/orc-2.1.1-h2271f48_0.conda#67075ef2cb33079efee3abfe58127a3b https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda#58335b26c38bf4a20f399384c33cbcf9 https://conda.anaconda.org/conda-forge/noarch/pip-25.2-pyh145f28c_0.conda#e7ab34d5a93e0819b62563c78635d937 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 -https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d +https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.4-pyhcf101f3_0.conda#bf1f1292fc78307956289707e85cb1bf https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda#88476ae6ebd24f39261e0854ac244f33 https://conda.anaconda.org/conda-forge/noarch/pytz-2025.2-pyhd8ed1ab_0.conda#bc8e3267d44011051f2eb14d22fb0960 https://conda.anaconda.org/conda-forge/linux-64/re2-2024.07.02-h9925aae_2.conda#e84ddf12bde691e8ec894b00ea829ddf @@ -160,8 +163,8 @@ https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.45-hb9d3cd8_0 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.6-hb9d3cd8_0.conda#febbab7d15033c913d53c7a2c102309d https://conda.anaconda.org/conda-forge/linux-64/xorg-libxfixes-6.0.1-hb9d3cd8_0.conda#4bdb303603e9821baf5fe5fdff1dc8f8 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.12-hb9d3cd8_0.conda#96d57aba173e878a2089d5638016dc5e -https://conda.anaconda.org/conda-forge/linux-64/aws-c-auth-0.8.0-hb921021_15.conda#c79d50f64cffa5ad51ecc1a81057962f -https://conda.anaconda.org/conda-forge/linux-64/aws-c-mqtt-0.11.0-h11f4f37_12.conda#96c3e0221fa2da97619ee82faa341a73 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-auth-0.8.6-hd08a7f5_4.conda#f5a770ac1fd2cb34b21327fc513013a7 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-mqtt-0.12.2-h108da3e_2.conda#90e07c8bac8da6378ee1882ef0a9374a https://conda.anaconda.org/conda-forge/linux-64/azure-core-cpp-1.14.0-h5cfcd09_0.conda#0a8838771cc2e985cd295e01ae83baf1 https://conda.anaconda.org/conda-forge/linux-64/ccache-4.11.3-h80c52d3_0.conda#eb517c6a2b960c3ccb6f1db1005f063a https://conda.anaconda.org/conda-forge/linux-64/coverage-7.10.6-py313h3dea7bd_1.conda#7d28b9543d76f78ccb110a1fdf5a0762 @@ -172,7 +175,7 @@ https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.6-pyhd8ed1ab_0.conda#44 https://conda.anaconda.org/conda-forge/noarch/joblib-1.5.2-pyhd8ed1ab_0.conda#4e717929cfa0d49cef92d911e31d0e90 https://conda.anaconda.org/conda-forge/linux-64/libcudnn-dev-9.10.1.4-h0fdc2d1_0.conda#a0c0b44d26a4710e6ea577fcddbe09d1 https://conda.anaconda.org/conda-forge/linux-64/libgl-1.7.0-ha4b6fd6_2.conda#928b8be80851f5d8ffb016f9c81dae7a -https://conda.anaconda.org/conda-forge/linux-64/libgrpc-1.67.1-hc2c308b_0.conda#4606a4647bfe857e3cfe21ca12ac3afb +https://conda.anaconda.org/conda-forge/linux-64/libgrpc-1.67.1-h25350d4_2.conda#bfcedaf5f9b003029cc6abe9431f66bf https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.12.1-default_h3d81e11_1000.conda#d821210ab60be56dd27b5525ed18366d https://conda.anaconda.org/conda-forge/linux-64/libllvm20-20.1.8-hecd9e04_0.conda#59a7b967b6ef5d63029b1712f8dcf661 https://conda.anaconda.org/conda-forge/linux-64/libllvm21-21.1.0-hecd9e04_0.conda#9ad637a7ac380c442be142dfb0b1b955 @@ -181,6 +184,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libxslt-1.1.43-h7a3aeb2_0.conda# https://conda.anaconda.org/conda-forge/linux-64/mpc-1.3.1-h24ddda3_1.conda#aa14b9a5196a6d8dd364164b7ce56acf https://conda.anaconda.org/conda-forge/linux-64/openldap-2.6.10-he970967_0.conda#2e5bf4f1da39c0b32778561c3c4e5878 https://conda.anaconda.org/conda-forge/linux-64/pillow-11.3.0-py313hf46931b_1.conda#8c2259ea124159da6660cbc3e68e30a2 +https://conda.anaconda.org/conda-forge/linux-64/prometheus-cpp-1.3.0-ha5d0236_0.conda#a83f6a2fdc079e643237887a37460668 https://conda.anaconda.org/conda-forge/noarch/pyproject-metadata-0.9.1-pyhd8ed1ab_0.conda#22ae7c6ea81e0c8661ef32168dda929b https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0.post0-pyhe01879c_2.conda#5b8d21249ff20967101ffa321cab24e8 https://conda.anaconda.org/conda-forge/noarch/python-gil-3.13.7-h4df99d1_100.conda#47a123ca8e727d886a2c6d0c71658f8c @@ -192,7 +196,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libxi-1.8.2-hb9d3cd8_0.cond https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrandr-1.5.4-hb9d3cd8_0.conda#2de7f99d6581a4a7adbff607b5c278ca https://conda.anaconda.org/conda-forge/linux-64/xorg-libxxf86vm-1.1.6-hb9d3cd8_0.conda#5efa5fa6243a622445fdfd72aee15efa https://conda.anaconda.org/conda-forge/noarch/_python_abi3_support-1.0-hd8ed1ab_2.conda#aaa2a381ccc56eac91d63b6c1240312f -https://conda.anaconda.org/conda-forge/linux-64/aws-c-s3-0.7.7-hf454442_0.conda#947c82025693bebd557f782bb5d6b469 +https://conda.anaconda.org/conda-forge/linux-64/aws-c-s3-0.7.13-h822ba82_2.conda#9cf2c3c13468f2209ee814be2c88655f https://conda.anaconda.org/conda-forge/linux-64/azure-identity-cpp-1.10.0-h113e628_0.conda#73f73f60854f325a55f1d31459f2ab73 https://conda.anaconda.org/conda-forge/linux-64/azure-storage-common-cpp-12.8.0-h736e048_1.conda#13de36be8de3ae3f05ba127631599213 https://conda.anaconda.org/conda-forge/linux-64/cudnn-9.10.1.4-haad7af6_0.conda#8382d957333e0d3280dcbf5691516dc1 @@ -200,53 +204,54 @@ https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.15.0-h7e30c49_1.con https://conda.anaconda.org/conda-forge/linux-64/gmpy2-2.2.1-py313h86d8783_1.conda#c9bc12b70b0c422e937945694e7cf6c0 https://conda.anaconda.org/conda-forge/linux-64/libclang-cpp20.1-20.1.8-default_h99862b1_1.conda#d6ff2e232c817e377856130eaceb7d2d https://conda.anaconda.org/conda-forge/linux-64/libclang13-21.1.0-default_h746c552_1.conda#327c78a8ce710782425a89df851392f7 -https://conda.anaconda.org/conda-forge/linux-64/libgoogle-cloud-2.32.0-h804f50b_0.conda#3d96df4d6b1c88455e05b94ce8a14a53 +https://conda.anaconda.org/conda-forge/linux-64/libgoogle-cloud-2.36.0-h2b5623c_0.conda#c96ca58ad3352a964bfcb85de6cd1496 +https://conda.anaconda.org/conda-forge/linux-64/libopentelemetry-cpp-1.18.0-hfcad708_1.conda#1f5a5d66e77a39dc5bd639ec953705cf https://conda.anaconda.org/conda-forge/linux-64/libpq-17.6-h3675c94_1.conda#bcee8587faf5dce5050a01817835eaed https://conda.anaconda.org/conda-forge/noarch/meson-python-0.18.0-pyh70fd9c4_0.conda#576c04b9d9f8e45285fb4d9452c26133 https://conda.anaconda.org/conda-forge/noarch/pytest-8.4.2-pyhd8ed1ab_0.conda#1f987505580cb972cf28dc5f74a0f81b https://conda.anaconda.org/conda-forge/linux-64/tbb-2021.13.0-hb60516a_3.conda#aa15aae38fd752855ca03a68af7f40e2 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxtst-1.2.5-hb9d3cd8_3.conda#7bbe9a0cc0df0ac5f5a8ad6d6a11af2f -https://conda.anaconda.org/conda-forge/linux-64/aws-crt-cpp-0.29.7-hd92328a_7.conda#02b95564257d5c3db9c06beccf711f95 +https://conda.anaconda.org/conda-forge/linux-64/aws-crt-cpp-0.31.0-h55f77e1_4.conda#0627af705ed70681f5bede31e72348e5 https://conda.anaconda.org/conda-forge/linux-64/azure-storage-blobs-cpp-12.13.0-h3cf044e_1.conda#7eb66060455c7a47d9dcdbfa9f46579b https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.4-h3394656_0.conda#09262e66b19567aff4f592fb53b28760 -https://conda.anaconda.org/conda-forge/linux-64/libgoogle-cloud-storage-2.32.0-h0121fbd_0.conda#877a5ec0431a5af83bf0cd0522bfe661 +https://conda.anaconda.org/conda-forge/linux-64/libgoogle-cloud-storage-2.36.0-h0121fbd_0.conda#fc5efe1833a4d709953964037985bb72 https://conda.anaconda.org/conda-forge/linux-64/mkl-2024.2.2-ha770c72_17.conda#e4ab075598123e783b788b995afbdad0 https://conda.anaconda.org/conda-forge/linux-64/polars-default-1.33.1-py39hf521cc8_0.conda#900f486d119d5c83d14c812068a3ecad https://conda.anaconda.org/conda-forge/noarch/pytest-cov-6.3.0-pyhd8ed1ab_0.conda#50d191b852fccb4bf9ab7b59b030c99d https://conda.anaconda.org/conda-forge/noarch/pytest-xdist-3.8.0-pyhd8ed1ab_0.conda#8375cfbda7c57fbceeda18229be10417 https://conda.anaconda.org/conda-forge/noarch/sympy-1.14.0-pyh2585a3b_105.conda#8c09fac3785696e1c477156192d64b91 -https://conda.anaconda.org/conda-forge/linux-64/aws-sdk-cpp-1.11.458-hc430e4a_4.conda#aeefac461bea1f126653c1285cf5af08 +https://conda.anaconda.org/conda-forge/linux-64/aws-sdk-cpp-1.11.510-h37a5c72_3.conda#beb8577571033140c6897d257acc7724 https://conda.anaconda.org/conda-forge/linux-64/azure-storage-files-datalake-cpp-12.12.0-ha633028_1.conda#7c1980f89dd41b097549782121a73490 https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-11.4.5-h15599e2_0.conda#1276ae4aa3832a449fcb4253c30da4bc https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-35_hfdb39a5_mkl.conda#9fedd782400297fa574e739146f04e34 https://conda.anaconda.org/conda-forge/linux-64/mkl-devel-2024.2.2-ha770c72_17.conda#e67269e07e58be5672f06441316f05f2 https://conda.anaconda.org/conda-forge/linux-64/polars-1.33.1-default_h755bcc6_0.conda#1884a1a6acc457c8e4b59b0f6450e140 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-18.1.0-h44a453e_6_cpu.conda#2cf6d608d6e66506f69797d5c6944c35 +https://conda.anaconda.org/conda-forge/linux-64/libarrow-19.0.1-hc7b3859_3_cpu.conda#9ed3ded6da29dec8417f2e1db68798f2 https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-35_h372d94f_mkl.conda#25fab7e2988299928dea5939d9958293 https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-35_hc41d3b0_mkl.conda#5b4f86e5bc48d347eaf1ca2d180780ad https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.9.2-h3fc9a0a_0.conda#70b5132b6e8a65198c2f9d5552c41126 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-acero-18.1.0-hcb10f89_6_cpu.conda#143f9288b64759a6427563f058c62f2b +https://conda.anaconda.org/conda-forge/linux-64/libarrow-acero-19.0.1-hcb10f89_3_cpu.conda#8f8dc214d89e06933f1bc1dcd2310b9c https://conda.anaconda.org/conda-forge/linux-64/liblapacke-3.9.0-35_hbc6e62b_mkl.conda#426313fe1dc5ad3060efea56253fcd76 -https://conda.anaconda.org/conda-forge/linux-64/libmagma-2.8.0-h9ddd185_2.conda#8de40c4f75d36bb00a5870f682457f1d -https://conda.anaconda.org/conda-forge/linux-64/libparquet-18.1.0-h081d1f1_6_cpu.conda#68788df49ce7480187eb6387f15b2b67 +https://conda.anaconda.org/conda-forge/linux-64/libmagma-2.9.0-h45b15fe_0.conda#703a1ab01e36111d8bb40bc7517e900b +https://conda.anaconda.org/conda-forge/linux-64/libparquet-19.0.1-h081d1f1_3_cpu.conda#1d04307cdb1d8aeb5f55b047d5d403ea https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.3-py313hf6604e3_0.conda#3122d20dc438287e125fb5acff1df170 -https://conda.anaconda.org/conda-forge/linux-64/pyarrow-core-18.1.0-py313he5f92c8_0_cpu.conda#5380e12f4468e891911dbbd4248b521a +https://conda.anaconda.org/conda-forge/linux-64/pyarrow-core-19.0.1-py313he5f92c8_0_cpu.conda#7d8649531c807b24295c8f9a0a396a78 https://conda.anaconda.org/conda-forge/linux-64/pyside6-6.9.2-py313ha3f37dd_1.conda#e2ec46ec4c607b97623e7b691ad31c54 https://conda.anaconda.org/conda-forge/noarch/array-api-strict-2.4.1-pyhe01879c_0.conda#648e253c455718227c61e26f4a4ce701 https://conda.anaconda.org/conda-forge/linux-64/blas-devel-3.9.0-35_hcf00494_mkl.conda#bbbe147bcbe26b14cfbd5975dd45c79d https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.3.3-py313h7037e92_2.conda#6c8b4c12099023fcd85e520af74fd755 -https://conda.anaconda.org/conda-forge/linux-64/cupy-core-13.4.1-py313hc2a895b_0.conda#46dd595e816b278b178e3bef8a6acf71 -https://conda.anaconda.org/conda-forge/linux-64/libarrow-dataset-18.1.0-hcb10f89_6_cpu.conda#20ca46a6bc714a6ab189d5b3f46e66d8 -https://conda.anaconda.org/conda-forge/linux-64/libmagma_sparse-2.8.0-h9ddd185_0.conda#f4eb3cfeaf9d91e72d5b2b8706bf059f +https://conda.anaconda.org/conda-forge/linux-64/cupy-core-13.6.0-py313hc2a895b_2.conda#1b3207acc9af23dcfbccb4647df0838e +https://conda.anaconda.org/conda-forge/linux-64/libarrow-dataset-19.0.1-hcb10f89_3_cpu.conda#a28f04b6e68a1c76de76783108ad729d +https://conda.anaconda.org/conda-forge/linux-64/libmagma_sparse-2.9.0-h45b15fe_0.conda#beac0a5bbe0af75db6b16d3d8fd24f7e https://conda.anaconda.org/conda-forge/linux-64/pandas-2.3.2-py313h08cd8bf_0.conda#5f4cc42e08d6d862b7b919a3c8959e0b -https://conda.anaconda.org/conda-forge/linux-64/scipy-1.16.1-py313h11c21cd_1.conda#270039a4640693aab11ee3c05385f149 +https://conda.anaconda.org/conda-forge/linux-64/scipy-1.16.2-py313h11c21cd_0.conda#85a80978a04be9c290b8fe6d9bccff1c https://conda.anaconda.org/conda-forge/linux-64/blas-2.135-mkl.conda#629ac47dbe946d9a709d4187baa6286d -https://conda.anaconda.org/conda-forge/linux-64/cupy-13.4.1-py313h66a2ee2_0.conda#784d6bd149ef2b5d9c733ea3dd4d15ad -https://conda.anaconda.org/conda-forge/linux-64/libarrow-substrait-18.1.0-h3ee7192_6_cpu.conda#aa313b3168caf98d00b3753f5ba27650 -https://conda.anaconda.org/conda-forge/linux-64/libtorch-2.5.1-cuda118_hb34f2e8_303.conda#da799bf557ff6376a1a58f40bddfb293 +https://conda.anaconda.org/conda-forge/linux-64/cupy-13.6.0-py313h66a2ee2_2.conda#9d83bdb568a47daf7fc38117db17fe4e +https://conda.anaconda.org/conda-forge/linux-64/libarrow-substrait-19.0.1-h08228c5_3_cpu.conda#a58e4763af8293deaac77b63bc7804d8 +https://conda.anaconda.org/conda-forge/linux-64/libtorch-2.4.1-cuda118_mkl_hee7131c_306.conda#28b3b3da11973494ed0100aa50f47328 https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.10.6-py313h683a580_1.conda#0483ab1c5b6956442195742a5df64196 https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.3.0-py313hfaae9d9_1.conda#6d308eafec3de495f6b06ebe69c990ed https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.10.6-py313h78bf25f_1.conda#a2644c545b6afde06f4847defc1a2b27 -https://conda.anaconda.org/conda-forge/linux-64/pyarrow-18.1.0-py313h78bf25f_0.conda#a11d880ceedc33993c6f5c14a80ea9d3 -https://conda.anaconda.org/conda-forge/linux-64/pytorch-2.5.1-cuda118_py313h40cdc2d_303.conda#19ad990954a4ed89358d91d0a3e7016d -https://conda.anaconda.org/conda-forge/linux-64/pytorch-gpu-2.5.1-cuda126hf7c78f0_303.conda#afaf760e55725108ae78ed41198c49bb +https://conda.anaconda.org/conda-forge/linux-64/pyarrow-19.0.1-py313h78bf25f_0.conda#e8efe6998a383dd149787c83d3d6a92e +https://conda.anaconda.org/conda-forge/linux-64/pytorch-2.4.1-cuda118_mkl_py313_h909c4c2_306.conda#de6e45613bbdb51127e9ff483c31bf41 +https://conda.anaconda.org/conda-forge/linux-64/pytorch-gpu-2.4.1-cuda118_mkl_hf8a3b2d_306.conda#b1802a39f1ca7ebed5f8c35755bffec1 From b32df28056addbdc8a356130694b6fd24676433e Mon Sep 17 00:00:00 2001 From: scikit-learn-bot Date: Mon, 15 Sep 2025 10:26:59 +0200 Subject: [PATCH 34/62] :lock: :robot: CI Update lock files for free-threaded CI build(s) :lock: :robot: (#32185) Co-authored-by: Lock file bot --- build_tools/azure/pylatest_free_threaded_linux-64_conda.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/pylatest_free_threaded_linux-64_conda.lock b/build_tools/azure/pylatest_free_threaded_linux-64_conda.lock index c5f9e95f5efca..ecd8eb15a0572 100644 --- a/build_tools/azure/pylatest_free_threaded_linux-64_conda.lock +++ b/build_tools/azure/pylatest_free_threaded_linux-64_conda.lock @@ -59,4 +59,4 @@ https://conda.anaconda.org/conda-forge/noarch/meson-python-0.18.0-pyh70fd9c4_0.c https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.3-py314hc30c27a_0.conda#f4359762e05d99518f79b6db512165af https://conda.anaconda.org/conda-forge/noarch/pytest-8.4.2-pyhd8ed1ab_0.conda#1f987505580cb972cf28dc5f74a0f81b https://conda.anaconda.org/conda-forge/noarch/pytest-run-parallel-0.6.1-pyhd8ed1ab_0.conda#4bc53a42b6c9f9f9e89b478d05091743 -https://conda.anaconda.org/conda-forge/linux-64/scipy-1.16.1-py314hf5b80f4_1.conda#857ebbdc0884bc9bcde1a8bd2d5d842c +https://conda.anaconda.org/conda-forge/linux-64/scipy-1.16.2-py314hf5b80f4_0.conda#392a136bd42c5f4b3ec8417c5432da23 From 6cdacd1b38813d88ea57448af9c60aa6f14950c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Mon, 15 Sep 2025 10:27:54 +0200 Subject: [PATCH 35/62] TST Fix the error message in test_min_dependencies_readme (#32149) --- sklearn/tests/test_min_dependencies_readme.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_min_dependencies_readme.py b/sklearn/tests/test_min_dependencies_readme.py index d6e8e138c4fe8..289b395afd78c 100644 --- a/sklearn/tests/test_min_dependencies_readme.py +++ b/sklearn/tests/test_min_dependencies_readme.py @@ -60,7 +60,7 @@ def test_min_dependencies_readme(): min_version = parse_version(dependent_packages[package][0]) message = ( - f"{package} has inconsistent minimum versions in pyproject.toml and" + f"{package} has inconsistent minimum versions in README.rst and" f" _min_depencies.py: {version} != {min_version}" ) assert version == min_version, message From 3a85d5cdfeacb4d8d504ec5709fc7a79be42e3ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Mon, 15 Sep 2025 10:31:28 +0200 Subject: [PATCH 36/62] Revert "API make murmurhash3_32 private (#32103)" (#32131) --- doc/api_reference.py | 2 +- .../sklearn.utils/32103.api.rst | 3 - sklearn/utils/murmurhash.pyx | 41 -------------- sklearn/utils/tests/test_murmurhash.py | 56 ++++++++----------- 4 files changed, 25 insertions(+), 77 deletions(-) delete mode 100644 doc/whats_new/upcoming_changes/sklearn.utils/32103.api.rst diff --git a/doc/api_reference.py b/doc/api_reference.py index e9c4cbb65284d..896f28ff1cf26 100644 --- a/doc/api_reference.py +++ b/doc/api_reference.py @@ -1351,4 +1351,4 @@ def _get_submodule(module_name, submodule_name): } """ -DEPRECATED_API_REFERENCE = {"1.8.0": ["utils.murmurhash3_32"]} # type: ignore[var-annotated] +DEPRECATED_API_REFERENCE = {} # type: ignore[var-annotated] diff --git a/doc/whats_new/upcoming_changes/sklearn.utils/32103.api.rst b/doc/whats_new/upcoming_changes/sklearn.utils/32103.api.rst deleted file mode 100644 index 6ed761a3b5f37..0000000000000 --- a/doc/whats_new/upcoming_changes/sklearn.utils/32103.api.rst +++ /dev/null @@ -1,3 +0,0 @@ -- The function :function:`utils.murmurhash.murmurhash3_32` is now deprecated and will be - removed in version 1.10. - By :user:`François Paugam `. diff --git a/sklearn/utils/murmurhash.pyx b/sklearn/utils/murmurhash.pyx index c7112ae245f81..fee239acd98fb 100644 --- a/sklearn/utils/murmurhash.pyx +++ b/sklearn/utils/murmurhash.pyx @@ -17,8 +17,6 @@ from ..utils._typedefs cimport int32_t, uint32_t import numpy as np -from sklearn.utils.deprecation import deprecated - cdef extern from "src/MurmurHash3.h": void MurmurHash3_x86_32(void *key, int len, uint32_t seed, void *out) void MurmurHash3_x86_128(void *key, int len, uint32_t seed, void *out) @@ -81,18 +79,9 @@ def _murmurhash3_bytes_array_s32( return np.asarray(out) -# TODO(1.10): remove -@deprecated( - "Function `murmurhash3_32` was deprecated in 1.8 and will be " - "removed in 1.10." -) def murmurhash3_32(key, seed=0, positive=False): """Compute the 32bit murmurhash3 of key at seed. - .. deprecated:: 1.8 - Function `murmurhash3_32` was deprecated in 1.8.0 and will be - removed in 1.10.0. - The underlying implementation is MurmurHash3_x86_32 generating low latency 32bits hash suitable for implementing lookup tables, Bloom filters, count min sketch or feature hashing. @@ -117,36 +106,6 @@ def murmurhash3_32(key, seed=0, positive=False): >>> murmurhash3_32(b"Hello World!", seed=42) 3565178 """ - return _murmurhash3_32(key, seed, positive) - - -def _murmurhash3_32(key, seed=0, positive=False): - """Compute the 32bit murmurhash3 of key at seed. - - The underlying implementation is MurmurHash3_x86_32 generating low - latency 32bits hash suitable for implementing lookup tables, Bloom - filters, count min sketch or feature hashing. - - Parameters - ---------- - key : np.int32, bytes, unicode or ndarray of dtype=np.int32 - The physical object to hash. - - seed : int, default=0 - Integer seed for the hashing algorithm. - - positive : bool, default=False - True: the results is casted to an unsigned int - from 0 to 2 ** 32 - 1 - False: the results is casted to a signed int - from -(2 ** 31) to 2 ** 31 - 1 - - Examples - -------- - >>> from sklearn.utils.murmurhash import _murmurhash3_32 - >>> _murmurhash3_32(b"Hello World!", seed=42) - 3565178 - """ if isinstance(key, bytes): if positive: return murmurhash3_bytes_u32(key, seed) diff --git a/sklearn/utils/tests/test_murmurhash.py b/sklearn/utils/tests/test_murmurhash.py index b2b54829d5221..20721c6e98f52 100644 --- a/sklearn/utils/tests/test_murmurhash.py +++ b/sklearn/utils/tests/test_murmurhash.py @@ -2,24 +2,23 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np -import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal -from sklearn.utils.murmurhash import _murmurhash3_32, murmurhash3_32 +from sklearn.utils.murmurhash import murmurhash3_32 def test_mmhash3_int(): - assert _murmurhash3_32(3) == 847579505 - assert _murmurhash3_32(3, seed=0) == 847579505 - assert _murmurhash3_32(3, seed=42) == -1823081949 + assert murmurhash3_32(3) == 847579505 + assert murmurhash3_32(3, seed=0) == 847579505 + assert murmurhash3_32(3, seed=42) == -1823081949 - assert _murmurhash3_32(3, positive=False) == 847579505 - assert _murmurhash3_32(3, seed=0, positive=False) == 847579505 - assert _murmurhash3_32(3, seed=42, positive=False) == -1823081949 + assert murmurhash3_32(3, positive=False) == 847579505 + assert murmurhash3_32(3, seed=0, positive=False) == 847579505 + assert murmurhash3_32(3, seed=42, positive=False) == -1823081949 - assert _murmurhash3_32(3, positive=True) == 847579505 - assert _murmurhash3_32(3, seed=0, positive=True) == 847579505 - assert _murmurhash3_32(3, seed=42, positive=True) == 2471885347 + assert murmurhash3_32(3, positive=True) == 847579505 + assert murmurhash3_32(3, seed=0, positive=True) == 847579505 + assert murmurhash3_32(3, seed=42, positive=True) == 2471885347 def test_mmhash3_int_array(): @@ -28,38 +27,36 @@ def test_mmhash3_int_array(): keys = keys.reshape((3, 2, 1)) for seed in [0, 42]: - expected = np.array([_murmurhash3_32(int(k), seed) for k in keys.flat]) + expected = np.array([murmurhash3_32(int(k), seed) for k in keys.flat]) expected = expected.reshape(keys.shape) - assert_array_equal(_murmurhash3_32(keys, seed), expected) + assert_array_equal(murmurhash3_32(keys, seed), expected) for seed in [0, 42]: - expected = np.array( - [_murmurhash3_32(k, seed, positive=True) for k in keys.flat] - ) + expected = np.array([murmurhash3_32(k, seed, positive=True) for k in keys.flat]) expected = expected.reshape(keys.shape) - assert_array_equal(_murmurhash3_32(keys, seed, positive=True), expected) + assert_array_equal(murmurhash3_32(keys, seed, positive=True), expected) def test_mmhash3_bytes(): - assert _murmurhash3_32(b"foo", 0) == -156908512 - assert _murmurhash3_32(b"foo", 42) == -1322301282 + assert murmurhash3_32(b"foo", 0) == -156908512 + assert murmurhash3_32(b"foo", 42) == -1322301282 - assert _murmurhash3_32(b"foo", 0, positive=True) == 4138058784 - assert _murmurhash3_32(b"foo", 42, positive=True) == 2972666014 + assert murmurhash3_32(b"foo", 0, positive=True) == 4138058784 + assert murmurhash3_32(b"foo", 42, positive=True) == 2972666014 def test_mmhash3_unicode(): - assert _murmurhash3_32("foo", 0) == -156908512 - assert _murmurhash3_32("foo", 42) == -1322301282 + assert murmurhash3_32("foo", 0) == -156908512 + assert murmurhash3_32("foo", 42) == -1322301282 - assert _murmurhash3_32("foo", 0, positive=True) == 4138058784 - assert _murmurhash3_32("foo", 42, positive=True) == 2972666014 + assert murmurhash3_32("foo", 0, positive=True) == 4138058784 + assert murmurhash3_32("foo", 42, positive=True) == 2972666014 def test_no_collision_on_byte_range(): previous_hashes = set() for i in range(100): - h = _murmurhash3_32(" " * i, 0) + h = murmurhash3_32(" " * i, 0) assert h not in previous_hashes, "Found collision on growing empty string" @@ -68,14 +65,9 @@ def test_uniform_distribution(): bins = np.zeros(n_bins, dtype=np.float64) for i in range(n_samples): - bins[_murmurhash3_32(i, positive=True) % n_bins] += 1 + bins[murmurhash3_32(i, positive=True) % n_bins] += 1 means = bins / n_samples expected = np.full(n_bins, 1.0 / n_bins) assert_array_almost_equal(means / expected, np.ones(n_bins), 2) - - -def test_deprecation_warning(): - with pytest.warns(FutureWarning, match="`murmurhash3_32` was deprecated"): - murmurhash3_32(3) From 85b12c9cfcc73e31bb4794aaec41c35fe591b06f Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 15 Sep 2025 11:03:56 +0200 Subject: [PATCH 37/62] fix docstring --- sklearn/tree/tests/test_heap.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/tests/test_heap.py b/sklearn/tree/tests/test_heap.py index 6ea572c00d800..e54abb8b9fc90 100644 --- a/sklearn/tree/tests/test_heap.py +++ b/sklearn/tree/tests/test_heap.py @@ -8,10 +8,12 @@ @pytest.mark.parametrize("min_heap", [True, False]) def test_cython_weighted_heap_vs_heapq(min_heap): - """Test Cython's weighted heap vs STL's heapq implementation. - - This unit-test first populates Cython Weighted Heap and STL's heap with weighted samples, and then compares values that are popped. - """ + """ + Test Cython's weighted heap vs STL's heapq implementation. + + This unit-test first populates Cython Weighted Heap and STL's heap + with weighted samples, and then compares values that are popped. + """ n = 200 w_heap = PytestWeightedHeap(n, min_heap=min_heap) py_heap = [] From 22c843e3115cdd4083f2f8f90582d40b82d228eb Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 15 Sep 2025 21:15:14 +0200 Subject: [PATCH 38/62] addressed comments around test_absolute_errors_precomputation_function --- sklearn/tree/_criterion.pyx | 9 ++------- sklearn/tree/tests/test_tree.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 76d8d91bc84b9..c2a2af6051cee 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1300,7 +1300,8 @@ def _py_precompute_absolute_errors( const float64_t[:, ::1] ys, const float64_t[:] sample_weight, const intp_t[:] sample_indices, - bint suffix=False + const intp_t start, + const intp_t end, ): """ Used for testing precompute_absolute_errors. @@ -1316,15 +1317,9 @@ def _py_precompute_absolute_errors( WeightedHeap above = WeightedHeap(n, True) WeightedHeap below = WeightedHeap(n, False) intp_t k = 0 - intp_t start = 0 - intp_t end = n float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) float64_t[::1] medians = np.zeros(n, dtype=np.float64) - if suffix: - start = n - 1 - end = -1 - precompute_absolute_errors( ys, sample_weight, sample_indices, above, below, k, start, end, abs_errors, medians diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 48f85d74423eb..d64c25b3b5cd1 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2885,7 +2885,7 @@ def test_sort_log2_build(): assert_array_equal(samples, expected_samples) -def test_absolute_errors_precomputation_function(): +def test_absolute_errors_precomputation_function(global_random_seed): """ Test the main bit of logic of the MAE(RegressionCriterion) class (used by DecisionTreeRegressor(criterion="absolute_error")). @@ -2911,28 +2911,30 @@ def compute_abs_error(y: np.ndarray, w: np.ndarray): # 2) compute the AE return (np.abs(y - median) * w).sum() + rng = np.random.RandomState(global_random_seed) + for n in [3, 5, 10, 20, 50, 100]: - y = np.random.uniform(size=(n, 1)) - w = np.random.rand(n) + y = rng.uniform(size=(n, 1)) + w = rng.rand(n) indices = np.arange(n) - abs_errors = _py_precompute_absolute_errors(y, w, indices) + abs_errors = _py_precompute_absolute_errors(y, w, indices, 0, n) expected = compute_prefix_abs_errors_naive(y, w) assert np.allclose(abs_errors, expected) - abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) + abs_errors = _py_precompute_absolute_errors(y, w, indices, n - 1, -1) expected = compute_prefix_abs_errors_naive(y[::-1], w[::-1])[::-1] assert np.allclose(abs_errors, expected) - x = np.random.rand(n) + x = rng.rand(n) indices = np.argsort(x) w[:] = 1 y_sorted = y[indices] w_sorted = w[indices] - abs_errors = _py_precompute_absolute_errors(y, w, indices) + abs_errors = _py_precompute_absolute_errors(y, w, indices, 0, n) expected = compute_prefix_abs_errors_naive(y_sorted, w_sorted) assert np.allclose(abs_errors, expected) - abs_errors = _py_precompute_absolute_errors(y, w, indices, suffix=True) + abs_errors = _py_precompute_absolute_errors(y, w, indices, n - 1, -1) expected = compute_prefix_abs_errors_naive(y_sorted[::-1], w_sorted[::-1])[::-1] assert np.allclose(abs_errors, expected) From 319523ae430629da1c5642ee033999527f480442 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 15 Sep 2025 21:17:10 +0200 Subject: [PATCH 39/62] update docstring --- sklearn/tree/_criterion.pyx | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index c2a2af6051cee..436d37369631f 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1305,15 +1305,9 @@ def _py_precompute_absolute_errors( ): """ Used for testing precompute_absolute_errors. - - If `suffix` is False: - Computes the "prefix" AEs, i.e the AEs for each set of indices - sample_indices[:i] with i in {1, ..., n} - - If `suffix` is True: - Computes the "suffix" AEs, i.e the AEs for each set of indices - sample_indices[i:] with i in {0, ..., n-1} """ cdef: - intp_t n = sample_weight.size + intp_t n = end - start if start < end else start - end WeightedHeap above = WeightedHeap(n, True) WeightedHeap below = WeightedHeap(n, False) intp_t k = 0 From d7f51579a6863389b6ff9672559479ddff262fc7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 15 Sep 2025 21:17:37 +0200 Subject: [PATCH 40/62] update docstring; again --- sklearn/tree/_criterion.pyx | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 436d37369631f..485c476fc72d1 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1303,9 +1303,7 @@ def _py_precompute_absolute_errors( const intp_t start, const intp_t end, ): - """ - Used for testing precompute_absolute_errors. - """ + """Used for testing precompute_absolute_errors.""" cdef: intp_t n = end - start if start < end else start - end WeightedHeap above = WeightedHeap(n, True) From 77dcb198b1badee8a8bbed29c8f408426cd17f41 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 26 Sep 2025 16:59:12 +0200 Subject: [PATCH 41/62] new test and fix --- sklearn/tree/_criterion.pyx | 10 ++-- sklearn/tree/tests/test_tree.py | 86 +++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 40 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 485c476fc72d1..38d00e2e40ca3 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -3,7 +3,7 @@ from libc.string cimport memcpy from libc.string cimport memset -from libc.math cimport fabs, INFINITY +from libc.math cimport INFINITY import numpy as np cimport numpy as cnp @@ -1282,10 +1282,10 @@ cdef void precompute_absolute_errors( below.push(top_val, top_weight) # Current median - if above.total_weight > half_weight + 1e-5 * fabs(half_weight): - median = above.top() - else: # above and below weight are almost exactly equals + if above.total_weight == half_weight: median = (above.top() + below.top()) / 2. + else: + median = above.top() medians[j] = median abs_errors[j] += ( (below.total_weight - above.total_weight) * median @@ -1316,7 +1316,7 @@ def _py_precompute_absolute_errors( ys, sample_weight, sample_indices, above, below, k, start, end, abs_errors, medians ) - return np.asarray(abs_errors) + return np.asarray(abs_errors), np.asarray(medians) cdef class MAE(Criterion): diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index d64c25b3b5cd1..1b028f0cf1ba1 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -64,6 +64,7 @@ CSC_CONTAINERS, CSR_CONTAINERS, ) +from sklearn.utils.stats import _weighted_percentile from sklearn.utils.validation import check_random_state CLF_CRITERIONS = ("gini", "log_loss") @@ -2896,45 +2897,60 @@ def test_absolute_errors_precomputation_function(global_random_seed): it can be safely removed. """ - def compute_prefix_abs_errors_naive(y: np.ndarray, w: np.ndarray): - y = y.ravel() - return np.array([compute_abs_error(y[:i], w[:i]) for i in range(1, y.size + 1)]) - - def compute_abs_error(y: np.ndarray, w: np.ndarray): - # 1) compute the weighted median - # i.e. once ordered by y, search for i such that: - # sum(w[:i]) <= 1/2 and sum(w[i+1:]) <= 1/2 - sorter = np.argsort(y) - wc = np.cumsum(w[sorter]) - idx = np.searchsorted(wc, wc[-1] / 2) - median = y[sorter[idx]] - # 2) compute the AE - return (np.abs(y - median) * w).sum() + def compute_prefix_abs_errors_naive(y, w): + y = y.ravel().copy() + medians = [ + _weighted_percentile(y[:i], w[:i], 50, average=True) + for i in range(1, y.size + 1) + ] + errors = [ + (np.abs(y[:i] - m) * w[:i]).sum() + for i, m in zip(range(1, y.size + 1), medians) + ] + return np.array(errors), np.array(medians) - rng = np.random.RandomState(global_random_seed) + def assert_same_results(y, w, indices, reverse=False): + args = (n - 1, -1) if reverse else (0, n) + abs_errors, medians = _py_precompute_absolute_errors(y, w, indices, *args) + y_sorted = y[indices] + w_sorted = w[indices] + if reverse: + y_sorted = y_sorted[::-1] + w_sorted = w_sorted[::-1] + abs_errors_, medians_ = compute_prefix_abs_errors_naive(y_sorted, w_sorted) + if reverse: + abs_errors_ = abs_errors_[::-1] + medians_ = medians_[::-1] + assert_allclose(abs_errors, abs_errors_) + assert_allclose(medians, medians_) + + rng = np.random.default_rng(global_random_seed) for n in [3, 5, 10, 20, 50, 100]: y = rng.uniform(size=(n, 1)) w = rng.rand(n) + w *= np.pow(10, rng.uniform(-5, 5)) indices = np.arange(n) - abs_errors = _py_precompute_absolute_errors(y, w, indices, 0, n) - expected = compute_prefix_abs_errors_naive(y, w) - assert np.allclose(abs_errors, expected) - - abs_errors = _py_precompute_absolute_errors(y, w, indices, n - 1, -1) - expected = compute_prefix_abs_errors_naive(y[::-1], w[::-1])[::-1] - assert np.allclose(abs_errors, expected) - - x = rng.rand(n) - indices = np.argsort(x) - w[:] = 1 - y_sorted = y[indices] - w_sorted = w[indices] - - abs_errors = _py_precompute_absolute_errors(y, w, indices, 0, n) - expected = compute_prefix_abs_errors_naive(y_sorted, w_sorted) - assert np.allclose(abs_errors, expected) + assert_same_results(y, w, indices) + assert_same_results(y, np.ones(n), indices) + assert_same_results(y, w, indices, reverse=True) + indices = rng.permutation(n) + assert_same_results(y, w, indices) + assert_same_results(y, w, indices, reverse=True) + + +def test_absolute_error_accurately_predicts_weighted_median(global_random_seed): + rng = np.random.default_rng(global_random_seed) + n = int(1e5) + data = rng.lognormal(size=n) + # Large number of zeros and otherwise continuous weights: + weights = rng.integers(0, 3, size=n) * rng.uniform(0, 1, size=n) + + tree_leaf_weighted_median = ( + DecisionTreeRegressor(criterion="absolute_error", max_depth=1) + .fit(np.ones(shape=(data.shape[0], 1)), data, sample_weight=weights) + .tree_.value.ravel()[0] + ) + weighted_median = _weighted_percentile(data, weights, 50, average=True) - abs_errors = _py_precompute_absolute_errors(y, w, indices, n - 1, -1) - expected = compute_prefix_abs_errors_naive(y_sorted[::-1], w_sorted[::-1])[::-1] - assert np.allclose(abs_errors, expected) + assert_allclose(tree_leaf_weighted_median, weighted_median) From 14014f51926ce6991db72dd790c661c778e8ea4b Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 26 Sep 2025 17:32:15 +0200 Subject: [PATCH 42/62] fix typo --- sklearn/tree/tests/test_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 1b028f0cf1ba1..d172bb60aa4f5 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2928,7 +2928,7 @@ def assert_same_results(y, w, indices, reverse=False): for n in [3, 5, 10, 20, 50, 100]: y = rng.uniform(size=(n, 1)) - w = rng.rand(n) + w = rng.random(n) w *= np.pow(10, rng.uniform(-5, 5)) indices = np.arange(n) assert_same_results(y, w, indices) From ad16ae0286cadea9f491c606c28732f17f1fa8a3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 26 Sep 2025 21:00:50 +0200 Subject: [PATCH 43/62] remove np.pow --- sklearn/tree/tests/test_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index d172bb60aa4f5..d13e373885dff 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2929,7 +2929,7 @@ def assert_same_results(y, w, indices, reverse=False): for n in [3, 5, 10, 20, 50, 100]: y = rng.uniform(size=(n, 1)) w = rng.random(n) - w *= np.pow(10, rng.uniform(-5, 5)) + w *= 10.0 ** rng.uniform(-5, 5) indices = np.arange(n) assert_same_results(y, w, indices) assert_same_results(y, np.ones(n), indices) From 1e9c74ff9565920aa297480808614382f0f99696 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Fri, 26 Sep 2025 22:41:06 +0200 Subject: [PATCH 44/62] Apply suggestion from @ogrisel Co-authored-by: Olivier Grisel --- sklearn/tree/tests/test_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index d13e373885dff..8d186e316256f 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2921,7 +2921,7 @@ def assert_same_results(y, w, indices, reverse=False): if reverse: abs_errors_ = abs_errors_[::-1] medians_ = medians_[::-1] - assert_allclose(abs_errors, abs_errors_) + assert_allclose(abs_errors, abs_errors_, atol=1e-12) assert_allclose(medians, medians_) rng = np.random.default_rng(global_random_seed) From b21040e714ee91b994e4752d1842c0ada6bd7716 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Fri, 26 Sep 2025 22:41:29 +0200 Subject: [PATCH 45/62] Apply suggestion from @cakedev0 --- sklearn/tree/tests/test_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 8d186e316256f..834f0a21246d7 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2922,7 +2922,7 @@ def assert_same_results(y, w, indices, reverse=False): abs_errors_ = abs_errors_[::-1] medians_ = medians_[::-1] assert_allclose(abs_errors, abs_errors_, atol=1e-12) - assert_allclose(medians, medians_) + assert_allclose(medians, medians_, atol=1e-12) rng = np.random.default_rng(global_random_seed) From e557f9e10daa09e529b5730b1ce3098dced41e40 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 29 Sep 2025 23:12:25 +0200 Subject: [PATCH 46/62] added explanation test; more tests with integer weights --- sklearn/tree/tests/test_tree.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index d13e373885dff..f737066eb4d9d 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2933,6 +2933,7 @@ def assert_same_results(y, w, indices, reverse=False): indices = np.arange(n) assert_same_results(y, w, indices) assert_same_results(y, np.ones(n), indices) + assert_same_results(y, w.round() + 1, indices) assert_same_results(y, w, indices, reverse=True) indices = rng.permutation(n) assert_same_results(y, w, indices) @@ -2940,6 +2941,10 @@ def assert_same_results(y, w, indices, reverse=False): def test_absolute_error_accurately_predicts_weighted_median(global_random_seed): + """ + Test that the weighted-median computed under-the-hood when + building a tree with criterion="absolute_error" is correct. + """ rng = np.random.default_rng(global_random_seed) n = int(1e5) data = rng.lognormal(size=n) From 86dddc99324a3aee61a5369064491cf28ab2026e Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 8 Oct 2025 17:34:03 +0200 Subject: [PATCH 47/62] Added fenwick tree --- sklearn/tree/_criterion.pyx | 416 +++++++++++++++++++++++++++++++- sklearn/tree/_partitioner.pyx | 121 +--------- sklearn/tree/_sorting.pxd | 9 + sklearn/tree/_sorting.pyx | 120 +++++++++ sklearn/tree/_utils.pxd | 19 ++ sklearn/tree/_utils.pyx | 102 ++++++++ sklearn/tree/meson.build | 3 + sklearn/tree/tests/test_tree.py | 4 +- 8 files changed, 671 insertions(+), 123 deletions(-) create mode 100644 sklearn/tree/_sorting.pxd create mode 100644 sklearn/tree/_sorting.pyx diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index fa7925597b9b8..6ea5110989e9d 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -13,6 +13,8 @@ from scipy.special.cython_special cimport xlogy from sklearn.tree._utils cimport log from sklearn.tree._utils cimport WeightedHeap +from sklearn.tree._utils cimport WeightedFenwickTree +from sklearn.tree._sorting cimport sort # EPSILON is used in the Poisson criterion cdef float64_t EPSILON = 10 * np.finfo('double').eps @@ -1296,7 +1298,7 @@ cdef void precompute_absolute_errors( j += step -def _py_precompute_absolute_errors( +def _py_precompute_absolute_errors_old( const float64_t[:, ::1] ys, const float64_t[:] sample_weight, const intp_t[:] sample_indices, @@ -1319,7 +1321,7 @@ def _py_precompute_absolute_errors( return np.asarray(abs_errors), np.asarray(medians) -cdef class MAE(Criterion): +cdef class MAE_old(Criterion): r"""Mean absolute error impurity criterion. MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true @@ -1367,6 +1369,7 @@ cdef class MAE(Criterion): self.right_abs_errors = np.empty(n_samples, dtype=np.float64) self.left_medians = np.empty(n_samples, dtype=np.float64) self.right_medians = np.empty(n_samples, dtype=np.float64) + self.above = WeightedHeap(n_samples, True) # min-heap self.below = WeightedHeap(n_samples, False) # max-heap @@ -1578,6 +1581,415 @@ cdef class MAE(Criterion): dest[0] = upper_bound +# Helper for MAE criterion: + +cdef void precompute_absolute_errors_fenwick( + const float64_t[::1] sorted_y, + const intp_t[::1] ranks, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + WeightedFenwickTree tree, + intp_t start, + intp_t end, + float64_t[::1] abs_errors, + float64_t[::1] medians, +) noexcept nogil: + """ + Fill `abs_errors` and `medians`. + + If start < end: + Computes the "prefix" AEs/medians, i.e the AEs for each set of indices + sample_indices[start:start + i] with i in {1, ..., n} + where n = end - start + Else: + Computes the "suffix" AEs/medians, i.e the AEs for each set of indices + sample_indices[i:] with i in {0, ..., n-1} + + Parameters + ---------- + sorted_y : const float64_t[::1] + Target values, sorted + ranks : const intp_t[::1] + sample_weight : const float64_t[:] + sample_indices : const intp_t[:] + indices indicating which samples to use. Shape: (n_samples,) + tree : WeightedFenwickTree + pre-instanciated tree + start : intp_t + Start index in `sample_indices` + end : intp_t + End index (exclusive) in `sample_indices` + abs_errors : float64_t[::1] + array to store (increment) the computed absolute errors. Shape: (n,) + with n := end - start + medians : float64_t[::1] + array to store (overwrite) the computed medians. Shape: (n,) + + Complexity: O(n log n) + """ + cdef: + intp_t p, i, step, n, r, median_idx, median_prev_idx + float64_t w = 1. + float64_t half_weight, median + float64_t w_right, w_left, wy_left, wy_right + + if start < end: + step = 1 + n = end - start + else: + n = start - end + step = -1 + + tree.reset(n) + + p = start + for _ in range(n): + i = sample_indices[p] + if sample_weight is not None: + w = sample_weight[i] + # Activate sample i at its y-rank + r = ranks[p] + tree.add(r, sorted_y[r], w) + + # Weighted alpha-quantile by cumulative weight + half_weight = 0.5 * tree.total_w + median_idx = tree.search(half_weight, &w_left, &wy_left, inclusive=True) + if w_left == half_weight: + median_prev_idx = tree.search(half_weight, &w_right, &wy_right, inclusive=False) + median = (sorted_y[median_prev_idx] + sorted_y[median_idx]) / 2 + else: + median = sorted_y[median_idx] + + # Right-side aggregates include the quantile position + w_right = tree.total_w - w_left + wy_right = tree.total_wy - wy_left + + # O(1) pinball loss formula + medians[p] = median + abs_errors[p] += ( + (wy_right - median * w_right) + + (median * w_left - wy_left) + ) + p += step + + +cdef inline void compute_ranks( + float64_t* sorted_y, + intp_t* sorted_indices, + intp_t* ranks, + intp_t n +) noexcept nogil: + cdef intp_t i + for i in range(n): + sorted_indices[i] = i + sort(sorted_y, sorted_indices, n) + for i in range(n): + ranks[sorted_indices[i]] = i + + +def _py_precompute_absolute_errors( + const float64_t[:, ::1] ys, + const float64_t[:] sample_weight, + const intp_t[:] sample_indices, + const intp_t start, + const intp_t end, + const intp_t n, +): + """Used for testing precompute_absolute_errors.""" + cdef: + intp_t p, i + intp_t s = start + intp_t e = end + WeightedFenwickTree tree = WeightedFenwickTree(n) + float64_t[::1] sorted_y = np.empty(n, dtype=np.float64) + intp_t[::1] sorted_indices = np.empty(n, dtype=np.intp) + intp_t[::1] ranks = np.empty(n, dtype=np.intp) + float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) + float64_t[::1] medians = np.empty(n, dtype=np.float64) + + if start > end: + s = end + 1 + e = start + 1 + for p in range(s, e): + i = sample_indices[p] + sorted_y[p - s] = ys[i, 0] + compute_ranks(&sorted_y[0], &sorted_indices[0], &ranks[s], n) + + precompute_absolute_errors_fenwick( + sorted_y, ranks, sample_weight, sample_indices, tree, + start, end, abs_errors, medians + ) + return np.asarray(abs_errors)[s:e], np.asarray(medians)[s:e] + + +cdef class MAE(Criterion): + r"""Mean absolute error impurity criterion. + + MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true + value and f_i is the predicted value. + + It has almost nothing in common with other regression criterions + so it doesn't inherit from RegressionCriterion + """ + cdef float64_t[::1] node_medians + cdef float64_t[::1] left_abs_errors + cdef float64_t[::1] right_abs_errors + cdef float64_t[::1] left_medians + cdef float64_t[::1] right_medians + cdef float64_t[::1] sorted_y + cdef intp_t [::1] sorted_indices + cdef intp_t[::1] ranks + cdef WeightedFenwickTree tree + + def __cinit__(self, intp_t n_outputs, intp_t n_samples): + """Initialize parameters for this criterion. + + Parameters + ---------- + n_outputs : intp_t + The number of targets to be predicted + + n_samples : intp_t + The total number of samples to fit on + """ + # Default values + self.start = 0 + self.pos = 0 + self.end = 0 + + self.n_outputs = n_outputs + self.n_samples = n_samples + self.n_node_samples = 0 + self.weighted_n_node_samples = 0.0 + self.weighted_n_left = 0.0 + self.weighted_n_right = 0.0 + + self.node_medians = np.zeros(n_outputs, dtype=np.float64) + + # Note: this criterion has a n_samples x 64 bytes memory footprint, which is + # fine as it's instantiated only once to build an entire tree + self.left_abs_errors = np.empty(n_samples, dtype=np.float64) + self.right_abs_errors = np.empty(n_samples, dtype=np.float64) + self.left_medians = np.empty(n_samples, dtype=np.float64) + self.right_medians = np.empty(n_samples, dtype=np.float64) + self.tree = WeightedFenwickTree(n_samples) # 2 float64 arrays of size n_samples + 1 + + self.sorted_y = np.empty(n_samples, dtype=np.float64) + self.sorted_indices = np.empty(n_samples, dtype=np.intp) + self.ranks = np.empty(n_samples, dtype=np.intp) + + cdef int init( + self, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, + float64_t weighted_n_samples, + const intp_t[:] sample_indices, + intp_t start, + intp_t end, + ) except -1 nogil: + """Initialize the criterion. + + This initializes the criterion at node sample_indices[start:end] and children + sample_indices[start:start] and sample_indices[start:end]. + + WARNING: sample_indices will be modified in-place externally + after this method is called + """ + cdef: + intp_t i, p + intp_t n = end - start + float64_t w = 1.0 + + # Initialize fields + self.y = y + self.sample_weight = sample_weight + self.sample_indices = sample_indices + self.start = start + self.end = end + self.n_node_samples = n + self.weighted_n_samples = weighted_n_samples + self.weighted_n_node_samples = 0. + + for p in range(start, end): + i = sample_indices[p] + if sample_weight is not None: + w = sample_weight[i] + self.weighted_n_node_samples += w + + # Reset to pos=start + self.reset() + return 0 + + cdef void init_missing(self, intp_t n_missing) noexcept nogil: + """Raise error if n_missing != 0.""" + if n_missing == 0: + return + with gil: + raise ValueError("missing values is not supported for MAE.") + + cdef int reset(self) except -1 nogil: + """Reset the criterion at pos=start. + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + + Reset might be called after an external class has changed + inplace self.sample_indices[start:end], hence re-computing + the absolute errors is needed + """ + cdef intp_t k, p, i + + self.weighted_n_left = 0.0 + self.weighted_n_right = self.weighted_n_node_samples + self.pos = self.start + + n_bytes = self.n_node_samples * sizeof(float64_t) + memset(&self.left_abs_errors[self.start], 0, n_bytes) + memset(&self.right_abs_errors[self.start], 0, n_bytes) + + for k in range(self.n_outputs): + + for p in range(self.start, self.end): + i = self.sample_indices[p] + self.sorted_y[p - self.start] = self.y[i, k] + + compute_ranks( + &self.sorted_y[0], + &self.sorted_indices[0], + &self.ranks[self.start], + self.n_node_samples, + ) + + # Note that at each iteration of this loop, we overwrite `self.left_medians` + # and `self.right_medians`. They are used to check for monoticity constraints, + # which are allowed only with n_outputs=1. + precompute_absolute_errors_fenwick( + self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, + self.tree, self.start, self.end, + # left_abs_errors is incremented, left_medians is overwritten + self.left_abs_errors, self.left_medians + ) + # For the right child, we consider samples from end-1 to start-1 + # i.e., reversed, and abs error & median are filled in reverse order to. + precompute_absolute_errors_fenwick( + self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, + self.tree, self.end - 1, self.start - 1, + # right_abs_errors is incremented, right_medians is overwritten + self.right_abs_errors, self.right_medians + ) + # Store the median for the current node + self.node_medians[k] = self.right_medians[self.start] + + return 0 + + cdef int reverse_reset(self) except -1 nogil: + """For this class, this method is never called""" + raise NotImplementedError("This method is not implemented for this subclass") + + cdef int update(self, intp_t new_pos) except -1 nogil: + """Updated statistics by moving sample_indices[pos:new_pos] to the left. + new_pos is guaranteed to be greater than pos + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + + Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data) + """ + cdef intp_t pos = self.pos + cdef intp_t i, p + cdef float64_t w = 1.0 + + # Update statistics up to new_pos + for p in range(pos, new_pos): + i = self.sample_indices[p] + if self.sample_weight is not None: + w = self.sample_weight[i] + self.weighted_n_left += w + + self.weighted_n_right = (self.weighted_n_node_samples - + self.weighted_n_left) + self.pos = new_pos + return 0 + + cdef void node_value(self, float64_t* dest) noexcept nogil: + """Computes the node value of sample_indices[start:end] into dest.""" + cdef intp_t k + for k in range(self.n_outputs): + dest[k] = self.node_medians[k] + + cdef inline float64_t middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Monotonicity constraints are only supported for single-output trees we can safely assume + n_outputs == 1. + """ + return ( + self.left_medians[self.pos - 1] + + self.right_medians[self.pos] + ) / 2 + + cdef inline bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + float64_t lower_bound, + float64_t upper_bound, + ) noexcept nogil: + """Check monotonicity constraint is satisfied at the current regression split""" + return self._check_monotonicity( + monotonic_cst, lower_bound, upper_bound, + self.left_medians[self.pos - 1], self.right_medians[self.pos]) + + cdef float64_t node_impurity(self) noexcept nogil: + """Evaluate the impurity of the current node. + + Evaluate the MAE criterion as impurity of the current node, + i.e. the impurity of sample_indices[start:end]. The smaller the impurity the + better. + + Time complexity: O(1) (precomputed in `.reset()`) + """ + return ( + self.right_abs_errors[0] + / (self.weighted_n_node_samples * self.n_outputs) + ) + + cdef void children_impurity(self, float64_t* p_impurity_left, + float64_t* p_impurity_right) noexcept nogil: + """Evaluate the impurity in children nodes. + + i.e. the impurity of the left child (sample_indices[start:pos]) and the + impurity the right child (sample_indices[pos:end]). + + Time complexity: O(1) (precomputed in `.reset()`) + """ + cdef float64_t impurity_left = 0.0 + cdef float64_t impurity_right = 0.0 + + # if pos == start, left child is empty, hence impurity is 0 + if self.pos > self.start: + impurity_left += self.left_abs_errors[self.pos - 1] + p_impurity_left[0] = impurity_left / (self.weighted_n_left * + self.n_outputs) + + # if pos == end, right child is empty, hence impurity is 0 + if self.pos < self.end: + impurity_right += self.right_abs_errors[self.pos] + p_impurity_right[0] = impurity_right / (self.weighted_n_right * + self.n_outputs) + + # those 2 methods are copied from the RegressionCriterion abstract class: + def __reduce__(self): + return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) + + cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" + if dest[0] < lower_bound: + dest[0] = lower_bound + elif dest[0] > upper_bound: + dest[0] = upper_bound + + cdef class FriedmanMSE(MSE): """Mean squared error impurity criterion with improvement score by Friedman. diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index 5cec6073d74f1..f113ffaa6e6d2 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -18,6 +18,8 @@ from libc.string cimport memcpy import numpy as np from scipy.sparse import issparse +from sklearn.tree._sorting cimport sort + # Constant to switch between algorithm non zero value extract algorithm # in SparsePartitioner @@ -696,122 +698,3 @@ cdef inline void shift_missing_values_to_left_if_required( current_end = end - 1 - p samples[i], samples[current_end] = samples[current_end], samples[i] best.pos += best.n_missing - - -def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n): - """Used for testing sort.""" - sort(&feature_values[0], &samples[0], n) - - -# Sort n-element arrays pointed to by feature_values and samples, simultaneously, -# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: - if n == 0: - return - cdef intp_t maxd = 2 * log2(n) - introsort(feature_values, samples, n, maxd) - - -cdef inline void swap(float32_t* feature_values, intp_t* samples, - intp_t i, intp_t j) noexcept nogil: - # Helper for sort - feature_values[i], feature_values[j] = feature_values[j], feature_values[i] - samples[i], samples[j] = samples[j], samples[i] - - -cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil: - # Median of three pivot selection, after Bentley and McIlroy (1993). - # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. - cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] - if a < b: - if b < c: - return b - elif a < c: - return c - else: - return a - elif b < c: - if a < c: - return a - else: - return c - else: - return b - - -# Introsort with median of 3 pivot selection and 3-way partition function -# (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(float32_t* feature_values, intp_t *samples, - intp_t n, intp_t maxd) noexcept nogil: - cdef float32_t pivot - cdef intp_t i, l, r - - while n > 1: - if maxd <= 0: # max depth limit exceeded ("gone quadratic") - heapsort(feature_values, samples, n) - return - maxd -= 1 - - pivot = median3(feature_values, n) - - # Three-way partition. - i = l = 0 - r = n - while i < r: - if feature_values[i] < pivot: - swap(feature_values, samples, i, l) - i += 1 - l += 1 - elif feature_values[i] > pivot: - r -= 1 - swap(feature_values, samples, i, r) - else: - i += 1 - - introsort(feature_values, samples, l, maxd) - feature_values += r - samples += r - n -= r - - -cdef inline void sift_down(float32_t* feature_values, intp_t* samples, - intp_t start, intp_t end) noexcept nogil: - # Restore heap order in feature_values[start:end] by moving the max element to start. - cdef intp_t child, maxind, root - - root = start - while True: - child = root * 2 + 1 - - # find max of root, left child, right child - maxind = root - if child < end and feature_values[maxind] < feature_values[child]: - maxind = child - if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: - maxind = child + 1 - - if maxind == root: - break - else: - swap(feature_values, samples, root, maxind) - root = maxind - - -cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: - cdef intp_t start, end - - # heapify - start = (n - 2) / 2 - end = n - while True: - sift_down(feature_values, samples, start, end) - if start == 0: - break - start -= 1 - - # sort by shrinking the heap, putting the max element immediately after it - end = n - 1 - while end > 0: - swap(feature_values, samples, 0, end) - sift_down(feature_values, samples, 0, end) - end = end - 1 diff --git a/sklearn/tree/_sorting.pxd b/sklearn/tree/_sorting.pxd new file mode 100644 index 0000000000000..969294885f01e --- /dev/null +++ b/sklearn/tree/_sorting.pxd @@ -0,0 +1,9 @@ +from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t + + +ctypedef fused floating_t: + float32_t + float64_t + + +cdef inline void sort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil diff --git a/sklearn/tree/_sorting.pyx b/sklearn/tree/_sorting.pyx new file mode 100644 index 0000000000000..5ae7de3233ade --- /dev/null +++ b/sklearn/tree/_sorting.pyx @@ -0,0 +1,120 @@ +from libc.math cimport log2 + + +def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n): + """Used for testing sort.""" + sort(&feature_values[0], &samples[0], n) + + +# Sort n-element arrays pointed to by feature_values and samples, simultaneously, +# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). +cdef inline void sort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: + if n == 0: + return + cdef intp_t maxd = 2 * log2(n) + introsort(feature_values, samples, n, maxd) + + +cdef inline void swap(floating_t* feature_values, intp_t* samples, + intp_t i, intp_t j) noexcept nogil: + # Helper for sort + feature_values[i], feature_values[j] = feature_values[j], feature_values[i] + samples[i], samples[j] = samples[j], samples[i] + + +cdef inline floating_t median3(floating_t* feature_values, intp_t n) noexcept nogil: + # Median of three pivot selection, after Bentley and McIlroy (1993). + # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. + cdef floating_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] + if a < b: + if b < c: + return b + elif a < c: + return c + else: + return a + elif b < c: + if a < c: + return a + else: + return c + else: + return b + + +# Introsort with median of 3 pivot selection and 3-way partition function +# (robust to repeated elements, e.g. lots of zero features). +cdef void introsort(floating_t* feature_values, intp_t *samples, + intp_t n, intp_t maxd) noexcept nogil: + cdef floating_t pivot + cdef intp_t i, l, r + + while n > 1: + if maxd <= 0: # max depth limit exceeded ("gone quadratic") + heapsort(feature_values, samples, n) + return + maxd -= 1 + + pivot = median3(feature_values, n) + + # Three-way partition. + i = l = 0 + r = n + while i < r: + if feature_values[i] < pivot: + swap(feature_values, samples, i, l) + i += 1 + l += 1 + elif feature_values[i] > pivot: + r -= 1 + swap(feature_values, samples, i, r) + else: + i += 1 + + introsort(feature_values, samples, l, maxd) + feature_values += r + samples += r + n -= r + + +cdef inline void sift_down(floating_t* feature_values, intp_t* samples, + intp_t start, intp_t end) noexcept nogil: + # Restore heap order in feature_values[start:end] by moving the max element to start. + cdef intp_t child, maxind, root + + root = start + while True: + child = root * 2 + 1 + + # find max of root, left child, right child + maxind = root + if child < end and feature_values[maxind] < feature_values[child]: + maxind = child + if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: + maxind = child + 1 + + if maxind == root: + break + else: + swap(feature_values, samples, root, maxind) + root = maxind + + +cdef void heapsort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: + cdef intp_t start, end + + # heapify + start = (n - 2) / 2 + end = n + while True: + sift_down(feature_values, samples, start, end) + if start == 0: + break + start -= 1 + + # sort by shrinking the heap, putting the max element immediately after it + end = n - 1 + while end > 0: + swap(feature_values, samples, 0, end) + sift_down(feature_values, samples, 0, end) + end = end - 1 diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index bf634e3b0e45f..88d25a44517c1 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -67,3 +67,22 @@ cdef class WeightedHeap: cdef void _heapify_down(self, intp_t) noexcept nogil cdef float64_t log(float64_t x) noexcept nogil + + +cdef class WeightedFenwickTree: + cdef intp_t size # number of leaves (ranks) + cdef float64_t* tree_w # BIT for weights + cdef float64_t* tree_wy # BIT for weighted targets + cdef intp_t max_pow2 # highest power of two <= n + cdef float64_t total_w # running total weight + cdef float64_t total_wy # running total weighted target + + cdef void reset(self, intp_t size) noexcept nogil + cdef void add(self, intp_t idx, float64_t y, float64_t w) noexcept nogil + cdef intp_t search( + self, + float64_t t, + float64_t* cw_out, + float64_t* cwy_out, + bint inclusive + ) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 253971110d01d..58362f4bcd116 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -5,6 +5,7 @@ from libc.stdlib cimport free from libc.stdlib cimport realloc from libc.math cimport log as ln from libc.math cimport isnan +from libc.string cimport memset import numpy as np cimport numpy as cnp @@ -249,3 +250,104 @@ cdef class PytestWeightedHeap(WeightedHeap): cdef double v, w self.pop(&v, &w) return v, w + + +cdef class WeightedFenwickTree: + """ + Fenwick tree (Binary Indexed Tree) for maintaining: + - prefix sums of weights, and + - prefix sums of weight*value (targets), + indexed by the rank of y in sorted order (1-based internally). + + Supports: + - add(rank, w, y): point update at 'rank' + - search(t): find the smallest rank with cumulative weight > t, + also returns prefix aggregates excluding that rank. + """ + + def __cinit__(self, intp_t capacity): + self.tree_w = NULL + self.tree_wy = NULL + # safe_realloc can raise MemoryError -> __cinit__ may propagate + safe_realloc(&self.tree_w, capacity + 1) + safe_realloc(&self.tree_wy, capacity + 1) + + cdef void reset(self, intp_t size) noexcept nogil: + cdef intp_t p + cdef intp_t n_bytes = (size + 1) * sizeof(float64_t) + # +1 because 1-based + + self.size = size + memset(self.tree_w, 0, n_bytes) + memset(self.tree_wy, 0, n_bytes) + self.total_w = 0.0 + self.total_wy = 0.0 + + # highest power of two <= size + p = 1 + while p <= size: + p <<= 1 + self.max_pow2 = p >> 1 + + def __dealloc__(self): + if self.tree_w != NULL: + free(self.tree_w) + if self.tree_wy != NULL: + free(self.tree_wy) + + cdef void add(self, intp_t idx, float64_t y, float64_t w) noexcept nogil: + cdef float64_t wy = w * y + idx += 1 # 1-based + + while idx <= self.size: + self.tree_w[idx] += w + self.tree_wy[idx] += wy + idx += idx & -idx + + self.total_w += w + self.total_wy += wy + + cdef intp_t search( + self, + float64_t t, + float64_t* cw_out, + float64_t* cwy_out, + bint inclusive + ) noexcept nogil: + """ + Find the leaf such that + prefix_weight <= t < prefix_weight if inclusive + + and return: + - cw: prefix weight up to (rank-1) + - cwv: prefix weighted sum up to (rank-1) + - q: the y-value at 'rank' (the weighted alpha-quantile) + - prev_idx: if t == 0 at the end, the last index where we made a move + + Notes: + * Assumes there is at least one active (positive-weight) item. + * If t >= total weight (can happen with alpha ~ 1), we clamp t slightly. + """ + cdef: + intp_t idx = 0 + float64_t cw = 0.0 + float64_t cwy = 0.0 + intp_t bit = self.max_pow2 + float64_t w + + # Standard Fenwick lower-bound with simultaneous prefix accumulation + while bit != 0: + next_idx = idx + bit + if next_idx <= self.size: + w = self.tree_w[next_idx] + if (t > w) or (inclusive and t >= w): + t -= w + idx = next_idx + cw += w + cwy += self.tree_wy[next_idx] + bit >>= 1 + + cw_out[0] = cw + cwy_out[0] = cwy + + return idx diff --git a/sklearn/tree/meson.build b/sklearn/tree/meson.build index 87345a1e344bf..d92a1e9703727 100644 --- a/sklearn/tree/meson.build +++ b/sklearn/tree/meson.build @@ -14,6 +14,9 @@ tree_extension_metadata = { '_utils': {'sources': [cython_gen.process('_utils.pyx')], 'override_options': ['optimization=3']}, + '_sorting': + {'sources': [cython_gen.process('_sorting.pyx')], + 'override_options': ['optimization=3']}, } foreach ext_name, ext_dict : tree_extension_metadata diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index ba91267561abc..080380795daa1 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -42,7 +42,7 @@ SPARSE_SPLITTERS, ) from sklearn.tree._criterion import _py_precompute_absolute_errors -from sklearn.tree._partitioner import _py_sort +from sklearn.tree._sorting import _py_sort from sklearn.tree._tree import ( NODE_DTYPE, TREE_LEAF, @@ -2949,7 +2949,7 @@ def compute_prefix_abs_errors_naive(y, w): def assert_same_results(y, w, indices, reverse=False): args = (n - 1, -1) if reverse else (0, n) - abs_errors, medians = _py_precompute_absolute_errors(y, w, indices, *args) + abs_errors, medians = _py_precompute_absolute_errors(y, w, indices, *args, n) y_sorted = y[indices] w_sorted = w[indices] if reverse: From b87ea449038d52b69d5713a06ce4a73da7bf78b1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 9 Oct 2025 09:30:37 +0200 Subject: [PATCH 48/62] fixes --- sklearn/tree/_criterion.pxd | 1 - sklearn/tree/_criterion.pyx | 406 ----------------------------- sklearn/tree/_utils.pxd | 18 -- sklearn/tree/_utils.pyx | 195 ++------------ sklearn/tree/tests/test_fenwick.py | 31 +++ sklearn/tree/tests/test_heap.py | 39 --- 6 files changed, 54 insertions(+), 636 deletions(-) create mode 100644 sklearn/tree/tests/test_fenwick.py delete mode 100644 sklearn/tree/tests/test_heap.py diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 24ea34892db7b..84d2e800d6a87 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -3,7 +3,6 @@ # See _criterion.pyx for implementation details. from ..utils._typedefs cimport float64_t, int8_t, intp_t -from ._utils cimport WeightedHeap cdef class Criterion: diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 6ea5110989e9d..da1ef4035b43c 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -12,7 +12,6 @@ cnp.import_array() from scipy.special.cython_special cimport xlogy from sklearn.tree._utils cimport log -from sklearn.tree._utils cimport WeightedHeap from sklearn.tree._utils cimport WeightedFenwickTree from sklearn.tree._sorting cimport sort @@ -1176,411 +1175,6 @@ cdef class MSE(RegressionCriterion): impurity_right[0] /= self.n_outputs -# Helper for MAE criterion: - -cdef void precompute_absolute_errors( - const float64_t[:, ::1] ys, - const float64_t[:] sample_weight, - const intp_t[:] sample_indices, - WeightedHeap above, - WeightedHeap below, - intp_t k, - intp_t start, - intp_t end, - float64_t[::1] abs_errors, - float64_t[::1] medians -) noexcept nogil: - """ - Fill `abs_errors` and `medians`. - - If start < end: - Computes the "prefix" AEs/medians, i.e the AEs for each set of indices - sample_indices[start:start + i] with i in {1, ..., n} - where n = end - start - Else: - Computes the "suffix" AEs/medians, i.e the AEs for each set of indices - sample_indices[i:] with i in {0, ..., n-1} - - Parameters - ---------- - ys : const float64_t[:, ::1] - Target values. Shape: (n_samples, n_outputs). - sample_weight : const float64_t[:] - Shape: (n_samples,) - sample_indices : const intp_t[:] - indices indicating which samples to use. Shape: (n_samples,) - above : WeightedHeap - below : WeightedHeap - k : intp_t - Dimension to consider in y. In [0, n_outputs - 1]. - start : intp_t - Start index in `sample_indices` - end : intp_t - End index (exclusive) in `sample_indices` - abs_errors : float64_t[::1] - array to store (increment) the computed absolute errors. Shape: (n,) - with n := end - start - medians : float64_t[::1] - array to store (overwrite) the computed medians. Shape: (n,) - - Complexity: O(n log n) - This algorithm is an adaptation of the two heaps solution of - the "find median from a data stream" problem - See for instance: https://www.geeksforgeeks.org/dsa/median-of-stream-of-integers-running-integers/ - - But here, it's the weighted median and we also need to compute the AE, so: - - instead of balancing the heaps based on their number of elements, - rebalance them based on the summed weights of their elements - - rewrite the AE computation by splitting the sum between elements - above and below the median, which allow to express it as a simple - O(1) computation. - See the maths in the PR desc: - https://github.com/scikit-learn/scikit-learn/pull/32100 - """ - cdef intp_t j, p, i, step, n - if start < end: - j = 0 - step = 1 - n = end - start - else: - n = start - end - step = -1 - j = n - 1 - - above.reset() - below.reset() - cdef float64_t y - cdef float64_t w = 1.0 - cdef float64_t top_val, top_weight - cdef float64_t median = 0.0 - cdef float64_t half_weight - - p = start - for _ in range(n): - i = sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] - y = ys[i, k] - - # Insert into the appropriate heap - if below.is_empty(): - above.push(y, w) - elif y > below.top(): - above.push(y, w) - else: - below.push(y, w) - - half_weight = (above.total_weight + below.total_weight) / 2.0 - - # Rebalance heaps - while above.total_weight < half_weight and not below.is_empty(): - below.pop(&top_val, &top_weight) - above.push(top_val, top_weight) - while ( - not above.is_empty() - and (above.total_weight - above.top_weight()) >= half_weight - ): - above.pop(&top_val, &top_weight) - below.push(top_val, top_weight) - - # Current median - if above.total_weight == half_weight: - median = (above.top() + below.top()) / 2. - else: - median = above.top() - medians[j] = median - abs_errors[j] += ( - (below.total_weight - above.total_weight) * median - - below.weighted_sum - + above.weighted_sum - ) - p += step - j += step - - -def _py_precompute_absolute_errors_old( - const float64_t[:, ::1] ys, - const float64_t[:] sample_weight, - const intp_t[:] sample_indices, - const intp_t start, - const intp_t end, -): - """Used for testing precompute_absolute_errors.""" - cdef: - intp_t n = end - start if start < end else start - end - WeightedHeap above = WeightedHeap(n, True) - WeightedHeap below = WeightedHeap(n, False) - intp_t k = 0 - float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) - float64_t[::1] medians = np.zeros(n, dtype=np.float64) - - precompute_absolute_errors( - ys, sample_weight, sample_indices, above, below, - k, start, end, abs_errors, medians - ) - return np.asarray(abs_errors), np.asarray(medians) - - -cdef class MAE_old(Criterion): - r"""Mean absolute error impurity criterion. - - MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true - value and f_i is the predicted value. - - It has almost nothing in common with other regression criterions - so it doesn't inherit from RegressionCriterion - """ - cdef float64_t[::1] node_medians - cdef float64_t[::1] left_abs_errors - cdef float64_t[::1] right_abs_errors - cdef float64_t[::1] left_medians - cdef float64_t[::1] right_medians - cdef WeightedHeap above - cdef WeightedHeap below - - def __cinit__(self, intp_t n_outputs, intp_t n_samples): - """Initialize parameters for this criterion. - - Parameters - ---------- - n_outputs : intp_t - The number of targets to be predicted - - n_samples : intp_t - The total number of samples to fit on - """ - # Default values - self.start = 0 - self.pos = 0 - self.end = 0 - - self.n_outputs = n_outputs - self.n_samples = n_samples - self.n_node_samples = 0 - self.weighted_n_node_samples = 0.0 - self.weighted_n_left = 0.0 - self.weighted_n_right = 0.0 - - self.node_medians = np.zeros(n_outputs, dtype=np.float64) - - # Note: this criterion has a n_samples x 64 bytes memory footprint, which is - # fine as it's instantiated only once to build an entire tree - self.left_abs_errors = np.empty(n_samples, dtype=np.float64) - self.right_abs_errors = np.empty(n_samples, dtype=np.float64) - self.left_medians = np.empty(n_samples, dtype=np.float64) - self.right_medians = np.empty(n_samples, dtype=np.float64) - - self.above = WeightedHeap(n_samples, True) # min-heap - self.below = WeightedHeap(n_samples, False) # max-heap - - cdef int init( - self, - const float64_t[:, ::1] y, - const float64_t[:] sample_weight, - float64_t weighted_n_samples, - const intp_t[:] sample_indices, - intp_t start, - intp_t end, - ) except -1 nogil: - """Initialize the criterion. - - This initializes the criterion at node sample_indices[start:end] and children - sample_indices[start:start] and sample_indices[start:end]. - - WARNING: sample_indices will be modified in-place externally - after this method is called - """ - cdef intp_t i - cdef float64_t w = 1.0 - # Initialize fields - self.y = y - self.sample_weight = sample_weight - self.sample_indices = sample_indices - self.start = start - self.end = end - self.n_node_samples = end - start - self.weighted_n_samples = weighted_n_samples - self.weighted_n_node_samples = 0. - - for p in range(start, end): - i = sample_indices[p] - if sample_weight is not None: - w = sample_weight[i] - self.weighted_n_node_samples += w - - # Reset to pos=start - self.reset() - return 0 - - cdef void init_missing(self, intp_t n_missing) noexcept nogil: - """Raise error if n_missing != 0.""" - if n_missing == 0: - return - with gil: - raise ValueError("missing values is not supported for MAE.") - - cdef int reset(self) except -1 nogil: - """Reset the criterion at pos=start. - - Returns -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - - Reset might be called after an external class has changed - inplace self.sample_indices[start:end], hence re-computing - the absolute errors is needed - """ - - self.weighted_n_left = 0.0 - self.weighted_n_right = self.weighted_n_node_samples - self.pos = self.start - - n_bytes = self.n_node_samples * sizeof(float64_t) - memset(&self.left_abs_errors[0], 0, n_bytes) - memset(&self.right_abs_errors[0], 0, n_bytes) - - # Precompute absolute errors (summed over each output) - # and medians (used only when n_outputs=1) - # of the right and left child of all possible splits - # for the current ordering of `sample_indices` - # Precomputation is needed here and can't be done step-by-step in the update method - # like for other criterions. Indeed, we don't have efficient ways to update right child - # statistics when removing samples from it. So we compute right child AEs/medians by - # traversing from right to left (and hence only adding samples). - for k in range(self.n_outputs): - # Note that at each iteration of this loop, we overwrite `self.left_medians` - # and `self.right_medians`. They are used to check for monoticity constraints, - # which are allowed only with n_outputs=1. - precompute_absolute_errors( - self.y, self.sample_weight, self.sample_indices, - self.above, self.below, k, self.start, self.end, - # left_abs_errors is incremented, left_medians is overwritten - self.left_abs_errors, self.left_medians - ) - # For the right child, we consider samples from end-1 to start-1 - # i.e., reversed, and abs error & median are filled in reverse order to. - precompute_absolute_errors( - self.y, self.sample_weight, self.sample_indices, - self.above, self.below, k, self.end - 1, self.start - 1, - # right_abs_errors is incremented, right_medians is overwritten - self.right_abs_errors, self.right_medians - ) - # Store the median for the current node - self.node_medians[k] = self.right_medians[0] - - return 0 - - cdef int reverse_reset(self) except -1 nogil: - """For this class, this method is never called""" - raise NotImplementedError("This method is not implemented for this subclass") - - cdef int update(self, intp_t new_pos) except -1 nogil: - """Updated statistics by moving sample_indices[pos:new_pos] to the left. - new_pos is guaranteed to be greater than pos - - Returns -1 in case of failure to allocate memory (and raise MemoryError) - or 0 otherwise. - - Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data) - """ - cdef intp_t pos = self.pos - cdef intp_t i, p - cdef float64_t w = 1.0 - - # Update statistics up to new_pos - for p in range(pos, new_pos): - i = self.sample_indices[p] - if self.sample_weight is not None: - w = self.sample_weight[i] - self.weighted_n_left += w - - self.weighted_n_right = (self.weighted_n_node_samples - - self.weighted_n_left) - self.pos = new_pos - return 0 - - cdef void node_value(self, float64_t* dest) noexcept nogil: - """Computes the node value of sample_indices[start:end] into dest.""" - cdef intp_t k - for k in range(self.n_outputs): - dest[k] = self.node_medians[k] - - cdef inline float64_t middle_value(self) noexcept nogil: - """Compute the middle value of a split for monotonicity constraints as the simple average - of the left and right children values. - - Monotonicity constraints are only supported for single-output trees we can safely assume - n_outputs == 1. - """ - cdef intp_t j = self.pos - self.start - return ( - self.left_medians[j - 1] - + self.right_medians[j] - ) / 2 - - cdef inline bint check_monotonicity( - self, - cnp.int8_t monotonic_cst, - float64_t lower_bound, - float64_t upper_bound, - ) noexcept nogil: - """Check monotonicity constraint is satisfied at the current regression split""" - cdef intp_t j = self.pos - self.start - - return self._check_monotonicity( - monotonic_cst, lower_bound, upper_bound, - self.left_medians[j - 1], self.right_medians[j]) - - cdef float64_t node_impurity(self) noexcept nogil: - """Evaluate the impurity of the current node. - - Evaluate the MAE criterion as impurity of the current node, - i.e. the impurity of sample_indices[start:end]. The smaller the impurity the - better. - - Time complexity: O(1) (precomputed in `.reset()`) - """ - return ( - self.right_abs_errors[0] - / (self.weighted_n_node_samples * self.n_outputs) - ) - - cdef void children_impurity(self, float64_t* p_impurity_left, - float64_t* p_impurity_right) noexcept nogil: - """Evaluate the impurity in children nodes. - - i.e. the impurity of the left child (sample_indices[start:pos]) and the - impurity the right child (sample_indices[pos:end]). - - Time complexity: O(1) (precomputed in `.reset()`) - """ - cdef intp_t j = self.pos - self.start - cdef float64_t impurity_left = 0.0 - cdef float64_t impurity_right = 0.0 - - # if pos == start, left child is empty, hence impurity is 0 - if self.pos > self.start: - impurity_left += self.left_abs_errors[j - 1] - p_impurity_left[0] = impurity_left / (self.weighted_n_left * - self.n_outputs) - - # if pos == end, right child is empty, hence impurity is 0 - if self.pos < self.end: - impurity_right += self.right_abs_errors[j] - p_impurity_right[0] = impurity_right / (self.weighted_n_right * - self.n_outputs) - - # those 2 methods are copied from the RegressionCriterion abstract class: - def __reduce__(self): - return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) - - cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: - """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" - if dest[0] < lower_bound: - dest[0] = lower_bound - elif dest[0] > upper_bound: - dest[0] = upper_bound - - # Helper for MAE criterion: cdef void precompute_absolute_errors_fenwick( diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 88d25a44517c1..ac38f72818829 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -47,24 +47,6 @@ cdef intp_t rand_int(intp_t low, intp_t high, cdef float64_t rand_uniform(float64_t low, float64_t high, uint32_t* random_state) noexcept nogil -cdef class WeightedHeap: - cdef intp_t capacity - cdef intp_t size - cdef float64_t* heap - cdef float64_t* weights - cdef float64_t total_weight - cdef float64_t weighted_sum - cdef bint min_heap - - cdef void reset(self) noexcept nogil - cdef bint is_empty(self) noexcept nogil - cdef void push(self, float64_t value, float64_t weight) noexcept nogil - cdef void pop(self, float64_t* value, float64_t* weight) noexcept nogil - cdef float64_t top_weight(self) noexcept nogil - cdef float64_t top(self) noexcept nogil - cdef void _swap(self, intp_t, intp_t) noexcept nogil - cdef void _heapify_up(self, intp_t) noexcept nogil - cdef void _heapify_down(self, intp_t) noexcept nogil cdef float64_t log(float64_t x) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 58362f4bcd116..2f331e7d976db 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -87,181 +87,18 @@ def _any_isnan_axis0(const float32_t[:, :] X): # ============================================================================= -# WeightedHeap data structure +# WeightedFenwickTree data structure # ============================================================================= -cdef class WeightedHeap: - """Binary heap with per-item weights, supporting min-heap and max-heap modes. - - Values are stored sign-adjusted internally so that the ordering logic - is always "min-heap" on the stored buffer: - - if min_heap: store v - - else (max-heap): store -v - - Attributes (all should be treated as readonly attributes) - ---------- - capacity : intp_t - Allocated capacity for the heap arrays. - - size : intp_t - Current number of elements in the heap. - - heap : float64_t* - Array of (possibly sign-adjusted) values that determines ordering. - - weights : float64_t* - Parallel array of weights. - - total_weight : float64_t - Sum of all weights currently in the heap. - - weighted_sum : float64_t - Sum over items of (original_value * weight), i.e. without sign-adjustment. - - min_heap : bint - If True, behaves as a min-heap; if False, behaves as a max-heap. - """ - - def __cinit__(self, intp_t capacity, bint min_heap=True): - if capacity <= 0: - capacity = 1 - self.capacity = capacity - self.size = 0 - self.min_heap = min_heap - self.total_weight = 0.0 - self.weighted_sum = 0.0 - self.heap = NULL - self.weights = NULL - # safe_realloc can raise MemoryError -> __cinit__ may propagate - safe_realloc(&self.heap, capacity) - safe_realloc(&self.weights, capacity) - - def __dealloc__(self): - if self.heap != NULL: - free(self.heap) - if self.weights != NULL: - free(self.weights) - - cdef void reset(self) noexcept nogil: - """Reset to construction state (keeps capacity).""" - self.size = 0 - self.total_weight = 0.0 - self.weighted_sum = 0.0 - - cdef bint is_empty(self) noexcept nogil: - return self.size == 0 - - cdef void push(self, float64_t value, float64_t weight) noexcept nogil: - """Insert a (value, weight).""" - cdef intp_t n = self.size - cdef float64_t stored = value if self.min_heap else -value - - assert n < self.capacity - # ^ should never raise as capacity is set to the max possible size - - self.heap[n] = stored - self.weights[n] = weight - self.size = n + 1 - - self.total_weight += weight - self.weighted_sum += value * weight - - self._heapify_up(n) - - cdef void pop(self, float64_t* value, float64_t* weight) noexcept nogil: - """Pop top element into pointers.""" - cdef intp_t n = self.size - assert n > 0 - - cdef float64_t stored = self.heap[0] - cdef float64_t v = stored if self.min_heap else -stored - cdef float64_t w = self.weights[0] - value[0] = v - weight[0] = w - - # Update aggregates - self.total_weight -= w - self.weighted_sum -= v * w - - # Move last to root and sift down - n -= 1 - self.size = n - if n > 0: - self.heap[0] = self.heap[n] - self.weights[0] = self.weights[n] - self._heapify_down(0) - - cdef float64_t top_weight(self) noexcept nogil: - assert self.size > 0 - return self.weights[0] - - cdef float64_t top(self) noexcept nogil: - assert self.size > 0 - cdef float64_t s = self.heap[0] - return s if self.min_heap else -s - - # Internal helpers (nogil): - - cdef inline void _swap(self, intp_t i, intp_t j) noexcept nogil: - cdef float64_t tmp = self.heap[i] - self.heap[i] = self.heap[j] - self.heap[j] = tmp - tmp = self.weights[i] - self.weights[i] = self.weights[j] - self.weights[j] = tmp - - cdef inline void _heapify_up(self, intp_t i) noexcept nogil: - """Move up the element at index i until heap invariant is restored.""" - cdef intp_t p - while i > 0: - p = (i - 1) >> 1 - if self.heap[i] < self.heap[p]: - self._swap(i, p) - i = p - else: - break - - cdef inline void _heapify_down(self, intp_t i) noexcept nogil: - """Move down the element at index i until heap invariant is restored.""" - cdef intp_t n = self.size - cdef intp_t left, right, mc - while True: - left = (i << 1) + 1 - right = left + 1 - if left >= n: - return - mc = left - if right < n and self.heap[right] < self.heap[left]: - mc = right - if self.heap[i] > self.heap[mc]: - self._swap(i, mc) - i = mc - else: - return - - -cdef class PytestWeightedHeap(WeightedHeap): - """Used for testing only""" - - def py_push(self, double value, double weight): - self.push(value, weight) - - def py_pop(self): - cdef double v, w - self.pop(&v, &w) - return v, w - - cdef class WeightedFenwickTree: """ Fenwick tree (Binary Indexed Tree) for maintaining: - - prefix sums of weights, and - - prefix sums of weight*value (targets), - indexed by the rank of y in sorted order (1-based internally). + - prefix sums of weights + - prefix sums of weight*value (targets) Supports: - add(rank, w, y): point update at 'rank' - - search(t): find the smallest rank with cumulative weight > t, + - search(t): find the smallest rank with cumulative weight > t (or >= t), also returns prefix aggregates excluding that rank. """ @@ -315,14 +152,13 @@ cdef class WeightedFenwickTree: bint inclusive ) noexcept nogil: """ - Find the leaf such that + Find the smallest index such that prefix_weight <= t < prefix_weight if inclusive and return: - - cw: prefix weight up to (rank-1) - - cwv: prefix weighted sum up to (rank-1) - - q: the y-value at 'rank' (the weighted alpha-quantile) - - prev_idx: if t == 0 at the end, the last index where we made a move + - idx: + - cw (write in cw_out): prefix weight up to idx exclusive + - cwy (write in cwy_out): prefix weighted sum to idx exclusive Notes: * Assumes there is at least one active (positive-weight) item. @@ -351,3 +187,18 @@ cdef class WeightedFenwickTree: cwy_out[0] = cwy return idx + + +cdef class PytestWeightedFenwickTree(WeightedFenwickTree): + """Used for testing only""" + + def py_reset(self, intp_t n): + self.reset(n) + + def py_add(self, intp_t idx, float64_t y, float64_t w): + self.add(idx, y, w) + + def py_search(self, float64_t t, inclusive=True): + cdef float64_t w, wy + idx = self.search(t, &w, &wy, inclusive) + return idx, w, wy diff --git a/sklearn/tree/tests/test_fenwick.py b/sklearn/tree/tests/test_fenwick.py new file mode 100644 index 0000000000000..80b55f592f417 --- /dev/null +++ b/sklearn/tree/tests/test_fenwick.py @@ -0,0 +1,31 @@ +import numpy as np + +from sklearn.tree._utils import PytestWeightedFenwickTree + + +# @pytest.mark.parametrize("min_heap", [True, False]) +def test_cython_weighted_fenwick_tree(): + """ + Test Cython's weighted Fenwick tree implementation + """ + rng = np.random.default_rng() + + n = 100 + indices = rng.permutation(n) + y = rng.normal(size=n) + w = rng.integers(1, 4, size=n) + y_sorted = np.zeros_like(y) + w_sorted = np.zeros_like(w) + + tree = PytestWeightedFenwickTree(n) + tree.py_reset(n) + + for idx in indices: + tree.py_add(idx, y[idx], w[idx]) + y_sorted[idx] = y[idx] + w_sorted[idx] = w[idx] + t = rng.uniform(0, w_sorted.sum()) + t_idx, cw, cwy = tree.py_search(t) + assert np.isclose(cw, w_sorted[:t_idx].sum()) + assert np.isclose(cwy, (w_sorted[:t_idx] * y_sorted[:t_idx]).sum()) + assert cw <= t diff --git a/sklearn/tree/tests/test_heap.py b/sklearn/tree/tests/test_heap.py deleted file mode 100644 index e54abb8b9fc90..0000000000000 --- a/sklearn/tree/tests/test_heap.py +++ /dev/null @@ -1,39 +0,0 @@ -import random -from heapq import heappop, heappush - -import pytest - -from sklearn.tree._utils import PytestWeightedHeap - - -@pytest.mark.parametrize("min_heap", [True, False]) -def test_cython_weighted_heap_vs_heapq(min_heap): - """ - Test Cython's weighted heap vs STL's heapq implementation. - - This unit-test first populates Cython Weighted Heap and STL's heap - with weighted samples, and then compares values that are popped. - """ - n = 200 - w_heap = PytestWeightedHeap(n, min_heap=min_heap) - py_heap = [] - - def pop_from_heaps_and_compare(): - top, top_w = w_heap.py_pop() - top_, top_w_ = heappop(py_heap) - if not min_heap: - top_ = -top_ - assert top == top_ - assert top_w == top_w_ - - for _ in range(n): - if len(py_heap) > 0 and random.random() < 1 / 3: - pop_from_heaps_and_compare() - else: - y = random.random() - w = random.random() - heappush(py_heap, (y if min_heap else -y, w)) - w_heap.py_push(y, w) - - for _ in range(len(py_heap)): - pop_from_heaps_and_compare() From 53ae038f954277bc0ebce024fb49b01b391ec5ad Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 9 Oct 2025 10:23:09 +0200 Subject: [PATCH 49/62] find idx and prev_idx in a single search --- sklearn/tree/_criterion.pyx | 8 +--- sklearn/tree/_utils.pxd | 2 +- sklearn/tree/_utils.pyx | 64 ++++++++++++++++++++++++------ sklearn/tree/tests/test_fenwick.py | 2 +- 4 files changed, 55 insertions(+), 21 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index da1ef4035b43c..8208c035c0273 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1247,12 +1247,8 @@ cdef void precompute_absolute_errors_fenwick( # Weighted alpha-quantile by cumulative weight half_weight = 0.5 * tree.total_w - median_idx = tree.search(half_weight, &w_left, &wy_left, inclusive=True) - if w_left == half_weight: - median_prev_idx = tree.search(half_weight, &w_right, &wy_right, inclusive=False) - median = (sorted_y[median_prev_idx] + sorted_y[median_idx]) / 2 - else: - median = sorted_y[median_idx] + median_idx = tree.search(half_weight, &w_left, &wy_left, &median_prev_idx) + median = (sorted_y[median_prev_idx] + sorted_y[median_idx]) / 2 # Right-side aggregates include the quantile position w_right = tree.total_w - w_left diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index ac38f72818829..97f8d60645b04 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -66,5 +66,5 @@ cdef class WeightedFenwickTree: float64_t t, float64_t* cw_out, float64_t* cwy_out, - bint inclusive + intp_t* prev_idx_out, ) noexcept nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 2f331e7d976db..398e09e326e72 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -149,34 +149,71 @@ cdef class WeightedFenwickTree: float64_t t, float64_t* cw_out, float64_t* cwy_out, - bint inclusive + intp_t* prev_idx_out, ) noexcept nogil: """ - Find the smallest index such that - prefix_weight <= t < prefix_weight if inclusive + Find (prev_leaf, leaf) such that: + prefix_weight(prev_leaf) < t < prefix_weight(next_leaf(prev_leaf)) + prefix_weight(leaf) <= t < prefix_weight(next_leaf(leaf)) + Possibly, leaf == prev_leaf and return: - - idx: + - idx: leaf - cw (write in cw_out): prefix weight up to idx exclusive - cwy (write in cwy_out): prefix weighted sum to idx exclusive + - prev_idx (write in prev_idx_out): prev_leaf - Notes: - * Assumes there is at least one active (positive-weight) item. - * If t >= total weight (can happen with alpha ~ 1), we clamp t slightly. + Assumes: + * there is at least one active (positive-weight) item. + * 0 <= t <= total_weight """ cdef: intp_t idx = 0 + intp_t next_idx, prev_idx, bit_eq float64_t cw = 0.0 float64_t cwy = 0.0 intp_t bit = self.max_pow2 - float64_t w + float64_t w, t_eq # Standard Fenwick lower-bound with simultaneous prefix accumulation while bit != 0: next_idx = idx + bit if next_idx <= self.size: w = self.tree_w[next_idx] - if (t > w) or (inclusive and t >= w): + if t == w: + t_eq = t + bit_eq = bit + break + elif t > w: + t -= w + idx = next_idx + cw += w + cwy += self.tree_wy[next_idx] + bit >>= 1 + + if bit == 0: + cw_out[0] = cw + cwy_out[0] = cwy + prev_idx_out[0] = idx + return idx + + prev_idx = idx + while bit != 0: + next_idx = prev_idx + bit + if next_idx <= self.size: + w = self.tree_w[next_idx] + if t > w: + t -= w + prev_idx = next_idx + bit >>= 1 + + bit = bit_eq + t = t_eq + while bit != 0: + next_idx = idx + bit + if next_idx <= self.size: + w = self.tree_w[next_idx] + if t >= w: t -= w idx = next_idx cw += w @@ -185,7 +222,7 @@ cdef class WeightedFenwickTree: cw_out[0] = cw cwy_out[0] = cwy - + prev_idx_out[0] = prev_idx return idx @@ -198,7 +235,8 @@ cdef class PytestWeightedFenwickTree(WeightedFenwickTree): def py_add(self, intp_t idx, float64_t y, float64_t w): self.add(idx, y, w) - def py_search(self, float64_t t, inclusive=True): + def py_search(self, float64_t t): cdef float64_t w, wy - idx = self.search(t, &w, &wy, inclusive) - return idx, w, wy + cdef intp_t prev_idx + idx = self.search(t, &w, &wy, &prev_idx) + return prev_idx, idx, w, wy diff --git a/sklearn/tree/tests/test_fenwick.py b/sklearn/tree/tests/test_fenwick.py index 80b55f592f417..d205c8e7fdc3f 100644 --- a/sklearn/tree/tests/test_fenwick.py +++ b/sklearn/tree/tests/test_fenwick.py @@ -25,7 +25,7 @@ def test_cython_weighted_fenwick_tree(): y_sorted[idx] = y[idx] w_sorted[idx] = w[idx] t = rng.uniform(0, w_sorted.sum()) - t_idx, cw, cwy = tree.py_search(t) + t_idx_low, t_idx, cw, cwy = tree.py_search(t) assert np.isclose(cw, w_sorted[:t_idx].sum()) assert np.isclose(cwy, (w_sorted[:t_idx] * y_sorted[:t_idx]).sum()) assert cw <= t From acadfb7e5ec7d9a56a6c26e29dcc64dd636b6701 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 9 Oct 2025 11:40:50 +0200 Subject: [PATCH 50/62] attempt at fixing sorting --- sklearn/tree/_sorting.pxd | 6 ++---- sklearn/tree/_sorting.pyx | 16 ++++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_sorting.pxd b/sklearn/tree/_sorting.pxd index 969294885f01e..31faf657249d4 100644 --- a/sklearn/tree/_sorting.pxd +++ b/sklearn/tree/_sorting.pxd @@ -1,9 +1,7 @@ from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t -ctypedef fused floating_t: - float32_t - float64_t +from cython cimport floating -cdef inline void sort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil +cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil diff --git a/sklearn/tree/_sorting.pyx b/sklearn/tree/_sorting.pyx index 5ae7de3233ade..c26e7cbc33d1a 100644 --- a/sklearn/tree/_sorting.pyx +++ b/sklearn/tree/_sorting.pyx @@ -8,24 +8,24 @@ def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n): # Sort n-element arrays pointed to by feature_values and samples, simultaneously, # by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef inline void sort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: +cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil: if n == 0: return cdef intp_t maxd = 2 * log2(n) introsort(feature_values, samples, n, maxd) -cdef inline void swap(floating_t* feature_values, intp_t* samples, +cdef inline void swap(floating* feature_values, intp_t* samples, intp_t i, intp_t j) noexcept nogil: # Helper for sort feature_values[i], feature_values[j] = feature_values[j], feature_values[i] samples[i], samples[j] = samples[j], samples[i] -cdef inline floating_t median3(floating_t* feature_values, intp_t n) noexcept nogil: +cdef inline floating median3(floating* feature_values, intp_t n) noexcept nogil: # Median of three pivot selection, after Bentley and McIlroy (1993). # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. - cdef floating_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] + cdef floating a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] if a < b: if b < c: return b @@ -44,9 +44,9 @@ cdef inline floating_t median3(floating_t* feature_values, intp_t n) noexcept no # Introsort with median of 3 pivot selection and 3-way partition function # (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(floating_t* feature_values, intp_t *samples, +cdef void introsort(floating* feature_values, intp_t *samples, intp_t n, intp_t maxd) noexcept nogil: - cdef floating_t pivot + cdef floating pivot cdef intp_t i, l, r while n > 1: @@ -77,7 +77,7 @@ cdef void introsort(floating_t* feature_values, intp_t *samples, n -= r -cdef inline void sift_down(floating_t* feature_values, intp_t* samples, +cdef inline void sift_down(floating* feature_values, intp_t* samples, intp_t start, intp_t end) noexcept nogil: # Restore heap order in feature_values[start:end] by moving the max element to start. cdef intp_t child, maxind, root @@ -100,7 +100,7 @@ cdef inline void sift_down(floating_t* feature_values, intp_t* samples, root = maxind -cdef void heapsort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: +cdef void heapsort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil: cdef intp_t start, end # heapify From 39a15c17caa14cb6cfc03c14a509a2342d87010e Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 13 Oct 2025 18:26:22 +0200 Subject: [PATCH 51/62] cleanup --- sklearn/tree/_criterion.pyx | 26 ++-- sklearn/tree/_utils.pyx | 223 ++++++++++++++++++----------- sklearn/tree/tests/test_fenwick.py | 32 ++++- 3 files changed, 180 insertions(+), 101 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 8208c035c0273..6eb6ce6580269 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1177,7 +1177,7 @@ cdef class MSE(RegressionCriterion): # Helper for MAE criterion: -cdef void precompute_absolute_errors_fenwick( +cdef void precompute_absolute_errors( const float64_t[::1] sorted_y, const intp_t[::1] ranks, const float64_t[:] sample_weight, @@ -1248,9 +1248,13 @@ cdef void precompute_absolute_errors_fenwick( # Weighted alpha-quantile by cumulative weight half_weight = 0.5 * tree.total_w median_idx = tree.search(half_weight, &w_left, &wy_left, &median_prev_idx) - median = (sorted_y[median_prev_idx] + sorted_y[median_idx]) / 2 - # Right-side aggregates include the quantile position + if median_idx != median_prev_idx: + # Exact match for half_weight in the tree, take the middle point: + median = (sorted_y[median_prev_idx] + sorted_y[median_idx]) / 2 + else: + median = sorted_y[median_idx] + w_right = tree.total_w - w_left wy_right = tree.total_wy - wy_left @@ -1264,15 +1268,16 @@ cdef void precompute_absolute_errors_fenwick( cdef inline void compute_ranks( - float64_t* sorted_y, + float64_t* y, intp_t* sorted_indices, intp_t* ranks, intp_t n ) noexcept nogil: + """Sort `y` inplace and fill `ranks` accordingly""" cdef intp_t i for i in range(n): sorted_indices[i] = i - sort(sorted_y, sorted_indices, n) + sort(y, sorted_indices, n) for i in range(n): ranks[sorted_indices[i]] = i @@ -1305,7 +1310,7 @@ def _py_precompute_absolute_errors( sorted_y[p - s] = ys[i, 0] compute_ranks(&sorted_y[0], &sorted_indices[0], &ranks[s], n) - precompute_absolute_errors_fenwick( + precompute_absolute_errors( sorted_y, ranks, sample_weight, sample_indices, tree, start, end, abs_errors, medians ) @@ -1443,6 +1448,11 @@ cdef class MAE(Criterion): i = self.sample_indices[p] self.sorted_y[p - self.start] = self.y[i, k] + # Compute the ranks of the node-local values in sorted order. + # - self.sorted_y[0:n_node_samples] is sorted in-place (with indices). + # - self.sorted_indices is a buffer used internally by compute_ranks + # - self.ranks[p] receives the rank of self.y[self.samples_indices[p], k] + # in the sorted array, for p in [start, end) compute_ranks( &self.sorted_y[0], &self.sorted_indices[0], @@ -1453,7 +1463,7 @@ cdef class MAE(Criterion): # Note that at each iteration of this loop, we overwrite `self.left_medians` # and `self.right_medians`. They are used to check for monoticity constraints, # which are allowed only with n_outputs=1. - precompute_absolute_errors_fenwick( + precompute_absolute_errors( self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, self.tree, self.start, self.end, # left_abs_errors is incremented, left_medians is overwritten @@ -1461,7 +1471,7 @@ cdef class MAE(Criterion): ) # For the right child, we consider samples from end-1 to start-1 # i.e., reversed, and abs error & median are filled in reverse order to. - precompute_absolute_errors_fenwick( + precompute_absolute_errors( self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, self.tree, self.end - 1, self.start - 1, # right_abs_errors is incremented, right_medians is overwritten diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 398e09e326e72..b80fbba6fb197 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -92,28 +92,36 @@ def _any_isnan_axis0(const float32_t[:, :] X): cdef class WeightedFenwickTree: """ - Fenwick tree (Binary Indexed Tree) for maintaining: + Fenwick tree (Binary Indexed Tree) specialized for maintaining: - prefix sums of weights - - prefix sums of weight*value (targets) - - Supports: - - add(rank, w, y): point update at 'rank' - - search(t): find the smallest rank with cumulative weight > t (or >= t), - also returns prefix aggregates excluding that rank. + - prefix sums of weight * target (y) + + Notes: + - Implementation uses 1-based indexing internally for the Fenwick tree + arrays, hence the +1 sized buffers. + - Memory ownership: this class allocates and frees the underlying C buffers. + - Typical operations: + add(rank, y, w) -> O(log n) + search(t) -> O(log n), finds the smallest rank with + cumulative weight > t (see search for details). """ def __cinit__(self, intp_t capacity): self.tree_w = NULL self.tree_wy = NULL - # safe_realloc can raise MemoryError -> __cinit__ may propagate + + # Allocate arrays of length (capacity + 1) because indices are 1-based. safe_realloc(&self.tree_w, capacity + 1) safe_realloc(&self.tree_wy, capacity + 1) cdef void reset(self, intp_t size) noexcept nogil: + """ + Reset the tree to hold 'size' elements and clear all aggregates. + """ cdef intp_t p - cdef intp_t n_bytes = (size + 1) * sizeof(float64_t) - # +1 because 1-based + cdef intp_t n_bytes = (size + 1) * sizeof(float64_t) # +1 for 1-based storage + # Public size and zeroed aggregates. self.size = size memset(self.tree_w, 0, n_bytes) memset(self.tree_wy, 0, n_bytes) @@ -132,98 +140,141 @@ cdef class WeightedFenwickTree: if self.tree_wy != NULL: free(self.tree_wy) - cdef void add(self, intp_t idx, float64_t y, float64_t w) noexcept nogil: - cdef float64_t wy = w * y - idx += 1 # 1-based + cdef void add(self, intp_t idx, float64_t y_value, float64_t weight) noexcept nogil: + """ + Add a weighted observation to the Fenwick tree. + + Parameters + ---------- + idx : intp_t + The 0-based index where to add the observation + y_value : float64_t + The target value (y) of the observation + weight : float64_t + The sample weight + + Notes + ----- + Updates both weight sums and weighted target sums in O(log n) time. + """ + cdef float64_t weighted_y = weight * y_value + cdef intp_t fenwick_idx = idx + 1 # Convert to 1-based indexing - while idx <= self.size: - self.tree_w[idx] += w - self.tree_wy[idx] += wy - idx += idx & -idx + # Update Fenwick tree nodes by traversing up the tree + while fenwick_idx <= self.size: + self.tree_w[fenwick_idx] += weight + self.tree_wy[fenwick_idx] += weighted_y + # Move to next node using bit manipulation: add lowest set bit + fenwick_idx += fenwick_idx & -fenwick_idx - self.total_w += w - self.total_wy += wy + # Update global totals + self.total_w += weight + self.total_wy += weighted_y cdef intp_t search( self, - float64_t t, - float64_t* cw_out, - float64_t* cwy_out, + float64_t target_weight, + float64_t* cumul_weight_out, + float64_t* cumul_weighted_y_out, intp_t* prev_idx_out, ) noexcept nogil: """ - Find (prev_leaf, leaf) such that: - prefix_weight(prev_leaf) < t < prefix_weight(next_leaf(prev_leaf)) - prefix_weight(leaf) <= t < prefix_weight(next_leaf(leaf)) - Possibly, leaf == prev_leaf - - and return: - - idx: leaf - - cw (write in cw_out): prefix weight up to idx exclusive - - cwy (write in cwy_out): prefix weighted sum to idx exclusive - - prev_idx (write in prev_idx_out): prev_leaf - - Assumes: - * there is at least one active (positive-weight) item. - * 0 <= t <= total_weight + Binary search to find the position where cumulative weight reaches target. + + This method performs a binary search on the Fenwick tree to find indices + such that the cumulative weight at 'prev_idx' is < target_weight and + the cumulative weight at the returned index is >= target_weight. + + Parameters + ---------- + target_weight : float64_t + The target cumulative weight to search for + cumul_weight_out : float64_t* + Output pointer for cumulative weight up to returned index (exclusive) + cumul_weighted_y_out : float64_t* + Output pointer for cumulative weighted y-sum up to returned index (exclusive) + prev_idx_out : intp_t* + Output pointer for the previous index (largest index with cumul_weight < target) + + Returns + ------- + intp_t + The index where cumulative weight first reaches or exceeds target_weight + + Notes + ----- + - O(log n) complexity + - Ignores nodes with zero weights (corresponding to uninserted y-values) + - Assumes at least one active (positive-weight) item exists + - Assumes 0 <= target_weight <= total_weight """ cdef: - intp_t idx = 0 - intp_t next_idx, prev_idx, bit_eq - float64_t cw = 0.0 - float64_t cwy = 0.0 - intp_t bit = self.max_pow2 - float64_t w, t_eq - - # Standard Fenwick lower-bound with simultaneous prefix accumulation - while bit != 0: - next_idx = idx + bit + intp_t current_idx = 0 + intp_t next_idx, prev_idx, equal_bit + float64_t cumul_weight = 0.0 + float64_t cumul_weighted_y = 0.0 + intp_t search_bit = self.max_pow2 # Start from highest power of 2 + float64_t node_weight, equal_target + + # Phase 1: Standard Fenwick binary search with prefix accumulation + # Traverse down the tree, moving right when we can consume more weight + while search_bit != 0: + next_idx = current_idx + search_bit if next_idx <= self.size: - w = self.tree_w[next_idx] - if t == w: - t_eq = t - bit_eq = bit + node_weight = self.tree_w[next_idx] + if target_weight == node_weight: + # Exact match found - store state for later processing + equal_target = target_weight + equal_bit = search_bit break - elif t > w: - t -= w - idx = next_idx - cw += w - cwy += self.tree_wy[next_idx] - bit >>= 1 - - if bit == 0: - cw_out[0] = cw - cwy_out[0] = cwy - prev_idx_out[0] = idx - return idx - - prev_idx = idx - while bit != 0: - next_idx = prev_idx + bit + elif target_weight > node_weight: + # We can consume this node's weight - move right and accumulate + target_weight -= node_weight + current_idx = next_idx + cumul_weight += node_weight + cumul_weighted_y += self.tree_wy[next_idx] + search_bit >>= 1 + + # If no exact match, we're done with standard search + if search_bit == 0: + cumul_weight_out[0] = cumul_weight + cumul_weighted_y_out[0] = cumul_weighted_y + prev_idx_out[0] = current_idx + return current_idx + + # Phase 2: Handle exact match case - find prev_idx + # Search for the largest index with cumulative weight < original target + prev_idx = current_idx + while search_bit != 0: + next_idx = prev_idx + search_bit if next_idx <= self.size: - w = self.tree_w[next_idx] - if t > w: - t -= w + node_weight = self.tree_w[next_idx] + if target_weight > node_weight: + target_weight -= node_weight prev_idx = next_idx - bit >>= 1 - - bit = bit_eq - t = t_eq - while bit != 0: - next_idx = idx + bit + search_bit >>= 1 + + # Phase 3: Complete the exact match search + # Restore state and search for the largest index with + # cumulative weight <= original target (and this is case, we know we have ==) + search_bit = equal_bit + target_weight = equal_target + while search_bit != 0: + next_idx = current_idx + search_bit if next_idx <= self.size: - w = self.tree_w[next_idx] - if t >= w: - t -= w - idx = next_idx - cw += w - cwy += self.tree_wy[next_idx] - bit >>= 1 - - cw_out[0] = cw - cwy_out[0] = cwy + node_weight = self.tree_w[next_idx] + if target_weight >= node_weight: + target_weight -= node_weight + current_idx = next_idx + cumul_weight += node_weight + cumul_weighted_y += self.tree_wy[next_idx] + search_bit >>= 1 + + # Output results + cumul_weight_out[0] = cumul_weight + cumul_weighted_y_out[0] = cumul_weighted_y prev_idx_out[0] = prev_idx - return idx + return current_idx cdef class PytestWeightedFenwickTree(WeightedFenwickTree): diff --git a/sklearn/tree/tests/test_fenwick.py b/sklearn/tree/tests/test_fenwick.py index d205c8e7fdc3f..b71184d84cb6d 100644 --- a/sklearn/tree/tests/test_fenwick.py +++ b/sklearn/tree/tests/test_fenwick.py @@ -3,12 +3,11 @@ from sklearn.tree._utils import PytestWeightedFenwickTree -# @pytest.mark.parametrize("min_heap", [True, False]) -def test_cython_weighted_fenwick_tree(): +def test_cython_weighted_fenwick_tree(global_random_seed): """ Test Cython's weighted Fenwick tree implementation """ - rng = np.random.default_rng() + rng = np.random.default_rng(global_random_seed) n = 100 indices = rng.permutation(n) @@ -20,12 +19,31 @@ def test_cython_weighted_fenwick_tree(): tree = PytestWeightedFenwickTree(n) tree.py_reset(n) - for idx in indices: + for i in range(n): + idx = indices[i] tree.py_add(idx, y[idx], w[idx]) y_sorted[idx] = y[idx] w_sorted[idx] = w[idx] - t = rng.uniform(0, w_sorted.sum()) - t_idx_low, t_idx, cw, cwy = tree.py_search(t) + + target = rng.uniform(0, w_sorted.sum()) + t_idx_low, t_idx, cw, cwy = tree.py_search(target) + + # check the aggregates are consistent with the returned idx assert np.isclose(cw, w_sorted[:t_idx].sum()) assert np.isclose(cwy, (w_sorted[:t_idx] * y_sorted[:t_idx]).sum()) - assert cw <= t + + # check if the cumulative weight is less than or equal to the target + # depending on t_idx_low and t_idx + if t_idx_low == t_idx: + assert cw < target + else: + assert cw == target + + # check that if we add the next weight, we are above the target: + next_weights = w_sorted[t_idx:][w_sorted[t_idx:] > 0] + if next_weights.size > 0: + assert cw + next_weights[0] > target + # and not below the target for `t_idx_low`: + next_weights = w_sorted[t_idx_low:][w_sorted[t_idx_low:] > 0] + if next_weights.size > 0: + assert cw + next_weights[0] >= target From 75c27e11fd1d7e981da89c8f7b58f13e354bd762 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 20 Oct 2025 20:27:30 +0200 Subject: [PATCH 52/62] moved back sort function into partitioner --- sklearn/tree/_criterion.pyx | 2 +- sklearn/tree/_partitioner.pxd | 5 ++ sklearn/tree/_partitioner.pyx | 121 +++++++++++++++++++++++++++++++- sklearn/tree/_sorting.pxd | 7 -- sklearn/tree/_sorting.pyx | 120 ------------------------------- sklearn/tree/meson.build | 3 - sklearn/tree/tests/test_tree.py | 2 +- 7 files changed, 126 insertions(+), 134 deletions(-) delete mode 100644 sklearn/tree/_sorting.pxd delete mode 100644 sklearn/tree/_sorting.pyx diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 6eb6ce6580269..5c4ef8b2b0725 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -13,7 +13,7 @@ from scipy.special.cython_special cimport xlogy from sklearn.tree._utils cimport log from sklearn.tree._utils cimport WeightedFenwickTree -from sklearn.tree._sorting cimport sort +from sklearn.tree._partitioner cimport sort # EPSILON is used in the Poisson criterion cdef float64_t EPSILON = 10 * np.finfo('double').eps diff --git a/sklearn/tree/_partitioner.pxd b/sklearn/tree/_partitioner.pxd index 6aa92db088645..6590b8ed585f1 100644 --- a/sklearn/tree/_partitioner.pxd +++ b/sklearn/tree/_partitioner.pxd @@ -3,6 +3,8 @@ # See _partitioner.pyx for details. +from cython cimport floating + from sklearn.utils._typedefs cimport ( float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t ) @@ -176,3 +178,6 @@ cdef void shift_missing_values_to_left_if_required( intp_t[::1] samples, intp_t end, ) noexcept nogil + + +cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index f113ffaa6e6d2..a970834a58a50 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -18,8 +18,6 @@ from libc.string cimport memcpy import numpy as np from scipy.sparse import issparse -from sklearn.tree._sorting cimport sort - # Constant to switch between algorithm non zero value extract algorithm # in SparsePartitioner @@ -698,3 +696,122 @@ cdef inline void shift_missing_values_to_left_if_required( current_end = end - 1 - p samples[i], samples[current_end] = samples[current_end], samples[i] best.pos += best.n_missing + + +def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n): + """Used for testing sort.""" + sort(&feature_values[0], &samples[0], n) + + +# Sort n-element arrays pointed to by feature_values and samples, simultaneously, +# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). +cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil: + if n == 0: + return + cdef intp_t maxd = 2 * log2(n) + introsort(feature_values, samples, n, maxd) + + +cdef inline void swap(floating* feature_values, intp_t* samples, + intp_t i, intp_t j) noexcept nogil: + # Helper for sort + feature_values[i], feature_values[j] = feature_values[j], feature_values[i] + samples[i], samples[j] = samples[j], samples[i] + + +cdef inline floating median3(floating* feature_values, intp_t n) noexcept nogil: + # Median of three pivot selection, after Bentley and McIlroy (1993). + # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. + cdef floating a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] + if a < b: + if b < c: + return b + elif a < c: + return c + else: + return a + elif b < c: + if a < c: + return a + else: + return c + else: + return b + + +# Introsort with median of 3 pivot selection and 3-way partition function +# (robust to repeated elements, e.g. lots of zero features). +cdef void introsort(floating* feature_values, intp_t *samples, + intp_t n, intp_t maxd) noexcept nogil: + cdef floating pivot + cdef intp_t i, l, r + + while n > 1: + if maxd <= 0: # max depth limit exceeded ("gone quadratic") + heapsort(feature_values, samples, n) + return + maxd -= 1 + + pivot = median3(feature_values, n) + + # Three-way partition. + i = l = 0 + r = n + while i < r: + if feature_values[i] < pivot: + swap(feature_values, samples, i, l) + i += 1 + l += 1 + elif feature_values[i] > pivot: + r -= 1 + swap(feature_values, samples, i, r) + else: + i += 1 + + introsort(feature_values, samples, l, maxd) + feature_values += r + samples += r + n -= r + + +cdef inline void sift_down(floating* feature_values, intp_t* samples, + intp_t start, intp_t end) noexcept nogil: + # Restore heap order in feature_values[start:end] by moving the max element to start. + cdef intp_t child, maxind, root + + root = start + while True: + child = root * 2 + 1 + + # find max of root, left child, right child + maxind = root + if child < end and feature_values[maxind] < feature_values[child]: + maxind = child + if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: + maxind = child + 1 + + if maxind == root: + break + else: + swap(feature_values, samples, root, maxind) + root = maxind + + +cdef void heapsort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil: + cdef intp_t start, end + + # heapify + start = (n - 2) / 2 + end = n + while True: + sift_down(feature_values, samples, start, end) + if start == 0: + break + start -= 1 + + # sort by shrinking the heap, putting the max element immediately after it + end = n - 1 + while end > 0: + swap(feature_values, samples, 0, end) + sift_down(feature_values, samples, 0, end) + end = end - 1 diff --git a/sklearn/tree/_sorting.pxd b/sklearn/tree/_sorting.pxd deleted file mode 100644 index 31faf657249d4..0000000000000 --- a/sklearn/tree/_sorting.pxd +++ /dev/null @@ -1,7 +0,0 @@ -from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t - - -from cython cimport floating - - -cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil diff --git a/sklearn/tree/_sorting.pyx b/sklearn/tree/_sorting.pyx deleted file mode 100644 index c26e7cbc33d1a..0000000000000 --- a/sklearn/tree/_sorting.pyx +++ /dev/null @@ -1,120 +0,0 @@ -from libc.math cimport log2 - - -def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n): - """Used for testing sort.""" - sort(&feature_values[0], &samples[0], n) - - -# Sort n-element arrays pointed to by feature_values and samples, simultaneously, -# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil: - if n == 0: - return - cdef intp_t maxd = 2 * log2(n) - introsort(feature_values, samples, n, maxd) - - -cdef inline void swap(floating* feature_values, intp_t* samples, - intp_t i, intp_t j) noexcept nogil: - # Helper for sort - feature_values[i], feature_values[j] = feature_values[j], feature_values[i] - samples[i], samples[j] = samples[j], samples[i] - - -cdef inline floating median3(floating* feature_values, intp_t n) noexcept nogil: - # Median of three pivot selection, after Bentley and McIlroy (1993). - # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. - cdef floating a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] - if a < b: - if b < c: - return b - elif a < c: - return c - else: - return a - elif b < c: - if a < c: - return a - else: - return c - else: - return b - - -# Introsort with median of 3 pivot selection and 3-way partition function -# (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(floating* feature_values, intp_t *samples, - intp_t n, intp_t maxd) noexcept nogil: - cdef floating pivot - cdef intp_t i, l, r - - while n > 1: - if maxd <= 0: # max depth limit exceeded ("gone quadratic") - heapsort(feature_values, samples, n) - return - maxd -= 1 - - pivot = median3(feature_values, n) - - # Three-way partition. - i = l = 0 - r = n - while i < r: - if feature_values[i] < pivot: - swap(feature_values, samples, i, l) - i += 1 - l += 1 - elif feature_values[i] > pivot: - r -= 1 - swap(feature_values, samples, i, r) - else: - i += 1 - - introsort(feature_values, samples, l, maxd) - feature_values += r - samples += r - n -= r - - -cdef inline void sift_down(floating* feature_values, intp_t* samples, - intp_t start, intp_t end) noexcept nogil: - # Restore heap order in feature_values[start:end] by moving the max element to start. - cdef intp_t child, maxind, root - - root = start - while True: - child = root * 2 + 1 - - # find max of root, left child, right child - maxind = root - if child < end and feature_values[maxind] < feature_values[child]: - maxind = child - if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: - maxind = child + 1 - - if maxind == root: - break - else: - swap(feature_values, samples, root, maxind) - root = maxind - - -cdef void heapsort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil: - cdef intp_t start, end - - # heapify - start = (n - 2) / 2 - end = n - while True: - sift_down(feature_values, samples, start, end) - if start == 0: - break - start -= 1 - - # sort by shrinking the heap, putting the max element immediately after it - end = n - 1 - while end > 0: - swap(feature_values, samples, 0, end) - sift_down(feature_values, samples, 0, end) - end = end - 1 diff --git a/sklearn/tree/meson.build b/sklearn/tree/meson.build index d92a1e9703727..87345a1e344bf 100644 --- a/sklearn/tree/meson.build +++ b/sklearn/tree/meson.build @@ -14,9 +14,6 @@ tree_extension_metadata = { '_utils': {'sources': [cython_gen.process('_utils.pyx')], 'override_options': ['optimization=3']}, - '_sorting': - {'sources': [cython_gen.process('_sorting.pyx')], - 'override_options': ['optimization=3']}, } foreach ext_name, ext_dict : tree_extension_metadata diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 080380795daa1..804e43a7e5968 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -42,7 +42,7 @@ SPARSE_SPLITTERS, ) from sklearn.tree._criterion import _py_precompute_absolute_errors -from sklearn.tree._sorting import _py_sort +from sklearn.tree._partitioner import _py_sort from sklearn.tree._tree import ( NODE_DTYPE, TREE_LEAF, From 2c973995c8055d028a720fc9b571658d46b9fa0c Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 20 Oct 2025 20:29:07 +0200 Subject: [PATCH 53/62] Apply suggestion from @ogrisel Co-authored-by: Olivier Grisel --- sklearn/tree/_criterion.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 5c4ef8b2b0725..0b96dbb63c6e3 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1388,7 +1388,7 @@ cdef class MAE(Criterion): sample_indices[start:start] and sample_indices[start:end]. WARNING: sample_indices will be modified in-place externally - after this method is called + after this method is called. """ cdef: intp_t i, p From 0285c97fe60b4b0d7d21113e0b6c3e97d414b22f Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 20 Oct 2025 20:29:27 +0200 Subject: [PATCH 54/62] Update sklearn/tree/_criterion.pyx Co-authored-by: Olivier Grisel --- sklearn/tree/_criterion.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 0b96dbb63c6e3..bef0c2474bfaf 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1324,7 +1324,7 @@ cdef class MAE(Criterion): value and f_i is the predicted value. It has almost nothing in common with other regression criterions - so it doesn't inherit from RegressionCriterion + so it doesn't inherit from RegressionCriterion. """ cdef float64_t[::1] node_medians cdef float64_t[::1] left_abs_errors From 7523930d97e42306bda44682a7a3efb8c91d71c6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 20 Oct 2025 20:58:54 +0200 Subject: [PATCH 55/62] add link to report --- sklearn/tree/_criterion.pyx | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index bef0c2474bfaf..4bee4c9c91244 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1320,11 +1320,16 @@ def _py_precompute_absolute_errors( cdef class MAE(Criterion): r"""Mean absolute error impurity criterion. - MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true - value and f_i is the predicted value. - It has almost nothing in common with other regression criterions so it doesn't inherit from RegressionCriterion. + + MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true + value and f_i is the predicted value. + In decision trees, all samples within a node share the same f_i value, + which corresponds to the weighted median of the target values y_i. + For detailed explanations of the mathematics and algorithmic aspects of + this implementation, refer to + https://github.com/cakedev0/fast-mae-split/blob/main/report.ipynb """ cdef float64_t[::1] node_medians cdef float64_t[::1] left_abs_errors @@ -1430,7 +1435,7 @@ cdef class MAE(Criterion): Reset might be called after an external class has changed inplace self.sample_indices[start:end], hence re-computing - the absolute errors is needed + the absolute errors is needed. """ cdef intp_t k, p, i @@ -1483,17 +1488,17 @@ cdef class MAE(Criterion): return 0 cdef int reverse_reset(self) except -1 nogil: - """For this class, this method is never called""" + """For this class, this method is never called.""" raise NotImplementedError("This method is not implemented for this subclass") cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left. - new_pos is guaranteed to be greater than pos + new_pos is guaranteed to be greater than pos. Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. - Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data) + Time complexity: O(new_pos - pos) (which usually is O(1), at least for dense data). """ cdef intp_t pos = self.pos cdef intp_t i, p From 78d6cfc2793693b83903509576dcf03cc816b10d Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Tue, 21 Oct 2025 14:34:26 +0200 Subject: [PATCH 56/62] Update sklearn/tree/_utils.pyx Co-authored-by: Olivier Grisel --- sklearn/tree/_utils.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index b80fbba6fb197..18fb6cb5497b4 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -98,7 +98,9 @@ cdef class WeightedFenwickTree: Notes: - Implementation uses 1-based indexing internally for the Fenwick tree - arrays, hence the +1 sized buffers. + arrays, hence the +1 sized buffers. 1-based indexing is customary for this + data structure and makes the some index handling slightly more efficient and + natural. - Memory ownership: this class allocates and frees the underlying C buffers. - Typical operations: add(rank, y, w) -> O(log n) From 0fdab95e216b6b38b9f18dcf69e461adad3469bb Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Tue, 21 Oct 2025 16:27:23 +0200 Subject: [PATCH 57/62] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/tree/_criterion.pyx | 3 +++ sklearn/tree/tests/test_tree.py | 1 + 2 files changed, 4 insertions(+) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 4bee4c9c91244..0af82997565b4 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1204,6 +1204,9 @@ cdef void precompute_absolute_errors( sorted_y : const float64_t[::1] Target values, sorted ranks : const intp_t[::1] + Ranks of the node-local values of y for points in sample_indices such that: + sorted_y[ranks[p]] == y[sample_indices[p]] for any p in [start, end) or + (end, start]. sample_weight : const float64_t[:] sample_indices : const intp_t[:] indices indicating which samples to use. Shape: (n_samples,) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 804e43a7e5968..469d1efb569f2 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2948,6 +2948,7 @@ def compute_prefix_abs_errors_naive(y, w): return np.array(errors), np.array(medians) def assert_same_results(y, w, indices, reverse=False): + n = y.shape[0] args = (n - 1, -1) if reverse else (0, n) abs_errors, medians = _py_precompute_absolute_errors(y, w, indices, *args, n) y_sorted = y[indices] From f2219f884f7571e35ae7bc9515c52159d0dc7d69 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 21 Oct 2025 16:46:29 +0200 Subject: [PATCH 58/62] rename y -> sorted_y in compute_ranks --- sklearn/tree/_criterion.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 0af82997565b4..6e1996a681668 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1271,16 +1271,16 @@ cdef void precompute_absolute_errors( cdef inline void compute_ranks( - float64_t* y, + float64_t* sorted_y, intp_t* sorted_indices, intp_t* ranks, intp_t n ) noexcept nogil: - """Sort `y` inplace and fill `ranks` accordingly""" + """Sort `sorted_y` inplace and fill `ranks` accordingly""" cdef intp_t i for i in range(n): sorted_indices[i] = i - sort(y, sorted_indices, n) + sort(sorted_y, sorted_indices, n) for i in range(n): ranks[sorted_indices[i]] = i From 64a3516a219e6ad174afcb537b599f5a18d92673 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Mon, 27 Oct 2025 21:49:46 +0100 Subject: [PATCH 59/62] Update sklearn/tree/tests/test_fenwick.py Co-authored-by: Olivier Grisel --- sklearn/tree/tests/test_fenwick.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_fenwick.py b/sklearn/tree/tests/test_fenwick.py index b71184d84cb6d..c27c17f0ac945 100644 --- a/sklearn/tree/tests/test_fenwick.py +++ b/sklearn/tree/tests/test_fenwick.py @@ -39,7 +39,7 @@ def test_cython_weighted_fenwick_tree(global_random_seed): else: assert cw == target - # check that if we add the next weight, we are above the target: + # check that if we add the next non-null weight, we are above the target: next_weights = w_sorted[t_idx:][w_sorted[t_idx:] > 0] if next_weights.size > 0: assert cw + next_weights[0] > target From 70277d367fab6eccf49b43088db7717fbdf5fb7f Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 27 Oct 2025 21:50:21 +0100 Subject: [PATCH 60/62] adressed comments on test fenwick --- sklearn/tree/tests/test_fenwick.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/tests/test_fenwick.py b/sklearn/tree/tests/test_fenwick.py index b71184d84cb6d..efd9cbc1949f2 100644 --- a/sklearn/tree/tests/test_fenwick.py +++ b/sklearn/tree/tests/test_fenwick.py @@ -12,9 +12,9 @@ def test_cython_weighted_fenwick_tree(global_random_seed): n = 100 indices = rng.permutation(n) y = rng.normal(size=n) - w = rng.integers(1, 4, size=n) - y_sorted = np.zeros_like(y) - w_sorted = np.zeros_like(w) + w = rng.integers(0, 4, size=n) + y_included_so_far = np.zeros_like(y) + w_included_so_far = np.zeros_like(w) tree = PytestWeightedFenwickTree(n) tree.py_reset(n) @@ -22,15 +22,17 @@ def test_cython_weighted_fenwick_tree(global_random_seed): for i in range(n): idx = indices[i] tree.py_add(idx, y[idx], w[idx]) - y_sorted[idx] = y[idx] - w_sorted[idx] = w[idx] + y_included_so_far[idx] = y[idx] + w_included_so_far[idx] = w[idx] - target = rng.uniform(0, w_sorted.sum()) + target = rng.uniform(0, w_included_so_far.sum()) t_idx_low, t_idx, cw, cwy = tree.py_search(target) # check the aggregates are consistent with the returned idx - assert np.isclose(cw, w_sorted[:t_idx].sum()) - assert np.isclose(cwy, (w_sorted[:t_idx] * y_sorted[:t_idx]).sum()) + assert np.isclose(cw, np.sum(w_included_so_far[:t_idx])) + assert np.isclose( + cwy, np.sum(w_included_so_far[:t_idx] * y_included_so_far[:t_idx]) + ) # check if the cumulative weight is less than or equal to the target # depending on t_idx_low and t_idx @@ -40,10 +42,10 @@ def test_cython_weighted_fenwick_tree(global_random_seed): assert cw == target # check that if we add the next weight, we are above the target: - next_weights = w_sorted[t_idx:][w_sorted[t_idx:] > 0] + next_weights = w_included_so_far[t_idx:][w_included_so_far[t_idx:] > 0] if next_weights.size > 0: assert cw + next_weights[0] > target # and not below the target for `t_idx_low`: - next_weights = w_sorted[t_idx_low:][w_sorted[t_idx_low:] > 0] + next_weights = w_included_so_far[t_idx_low:][w_included_so_far[t_idx_low:] > 0] if next_weights.size > 0: assert cw + next_weights[0] >= target From 3e60afd8174ba13bda5498f8bbb85f886e599eeb Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 27 Oct 2025 23:36:49 +0100 Subject: [PATCH 61/62] a lot of doc/comments --- sklearn/tree/_criterion.pyx | 165 +++++++++++++++++++++++++++--------- 1 file changed, 126 insertions(+), 39 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 6e1996a681668..5e67bb8f1497f 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1192,12 +1192,13 @@ cdef void precompute_absolute_errors( Fill `abs_errors` and `medians`. If start < end: - Computes the "prefix" AEs/medians, i.e the AEs for each set of indices - sample_indices[start:start + i] with i in {1, ..., n} - where n = end - start + Forward pass: Computes the "prefix" AEs/medians + i.e the AEs for each set of indices sample_indices[start:start + i] + with i in {1, ..., n}, where n = end - start. Else: - Computes the "suffix" AEs/medians, i.e the AEs for each set of indices - sample_indices[i:] with i in {0, ..., n-1} + Backward pass: Computes the "suffix" AEs/medians + i.e the AEs for each set of indices sample_indices[start - i:start] + with i in {1, ..., n}, where n = start - end. Parameters ---------- @@ -1225,7 +1226,7 @@ cdef void precompute_absolute_errors( Complexity: O(n log n) """ cdef: - intp_t p, i, step, n, r, median_idx, median_prev_idx + intp_t p, i, step, n, rank, median_rank, median_prev_rank float64_t w = 1. float64_t half_weight, median float64_t w_right, w_left, wy_left, wy_right @@ -1240,29 +1241,44 @@ cdef void precompute_absolute_errors( tree.reset(n) p = start + # We iterate exactly `n` samples starting at absolute index `start` and + # move by `step` (+1 for the forward pass, -1 for the backward pass). for _ in range(n): i = sample_indices[p] if sample_weight is not None: w = sample_weight[i] - # Activate sample i at its y-rank - r = ranks[p] - tree.add(r, sorted_y[r], w) + # Activate sample i at its rank: + rank = ranks[p] + tree.add(rank, sorted_y[rank], w) - # Weighted alpha-quantile by cumulative weight + # Weighted median by cumulative weight: the median is where the + # cumulative weight crosses half of the total weight. half_weight = 0.5 * tree.total_w - median_idx = tree.search(half_weight, &w_left, &wy_left, &median_prev_idx) - - if median_idx != median_prev_idx: - # Exact match for half_weight in the tree, take the middle point: - median = (sorted_y[median_prev_idx] + sorted_y[median_idx]) / 2 + # find the smallest activated rank with cumulative weight > half_weight + # while returning the prefix sums (`w_left` and `wy_left`) + # up to (and excluding) that index: + median_rank = tree.search(half_weight, &w_left, &wy_left, &median_prev_rank) + + if median_rank != median_prev_rank: + # Exact match for half_weight fell between two consecutive ranks: + # cumulative weight up to `median_rank` excluded is exactly half_weight. + # In that case, `median_prev_rank` is the activated rank such that + # the cumulative weight up to it included is exactly half_weight. + # In this case we take the mid-point: + median = (sorted_y[median_prev_rank] + sorted_y[median_rank]) / 2 else: - median = sorted_y[median_idx] + # if there are no exact match for half_weight in the cumulative weights + # `median_rank == median_prev_rank` and the median is: + median = sorted_y[median_rank] + # Convert left prefix sums into right-hand complements. w_right = tree.total_w - w_left wy_right = tree.total_wy - wy_left - # O(1) pinball loss formula medians[p] = median + # Pinball-loss identity for absolute error at the current set: + # sum_{y_i >= m} w_i (y_i - m) = wy_right - m * w_right + # sum_{y_i < m} w_i (m - y_i) = m * w_left - wy_left abs_errors[p] += ( (wy_right - median * w_right) + (median * w_left - wy_left) @@ -1326,12 +1342,63 @@ cdef class MAE(Criterion): It has almost nothing in common with other regression criterions so it doesn't inherit from RegressionCriterion. - MAE = (1 / n)*(\sum_i |y_i - f_i|), where y_i is the true - value and f_i is the predicted value. - In decision trees, all samples within a node share the same f_i value, - which corresponds to the weighted median of the target values y_i. - For detailed explanations of the mathematics and algorithmic aspects of - this implementation, refer to + MAE = (1 / n)*(\sum_i |y_i - p_i|), where y_i is the true + value and p_i is the predicted value. + In a decision tree, that prediction is the (weighted) median + of the targets in the node. + + How this implementation works + ----------------------------- + This class precomputes in `reset`, for the current node, + the absolute-error values and corresponding medians for all + potential split positions: every p in [start, end). + + For that: + - We first compute the rank of each samples node-local sorted order of target values. + `self.ranks[p]` gives the rank of sample p. + - While iterating the segment of indices (p in [start, end)), we + * "activate" one sample at a time at its rank within a prefix sum tree, + the `WeightedFenwickTree`: `tree.add(rank, y, weight)` + The tree maintains cumulative sums of weights and of `weight * y` + * search for the half total weight in the tree: + `tree.search(current_total_weight / 2)`. + This alloww us to retrieve/compute: + * the current weighted median value + * the absolute-error contribution via the standard pinball-loss identity: + AE = (wy_right - median * w_right) + (median * w_left - wy_left) + - We perform two such passes: + * one forward from `start` to `end - 1` to fill `left_abs_errors[p]` and + `left_medians[p]` for left children. + * one backward from `end - 1` down to `start` to fill + `right_abs_errors[p]` and `right_medians[p]` for right children. + + Complexity: time complexity is O(n log n), indeed: + - computing ranks is based on sorting: O(n log n) + - add and search operations in the Fenwick tree are O(log n). + => the forward and backward passes are O(n log n). + + How the other methods use the precomputations + -------------------------------------------- + - `reset` performs the precomputation described above. + It also stores the node weighted median per output in + `node_medians` (prediction value of the node). + + - `update(new_pos)` only updates `weighted_n_left` and `weighted_n_right`; + no recomputation of errors is needed. + + - `children_impurity` reads the precomputed absolute errors at + `left_abs_errors[pos - 1]` and `right_abs_errors[pos]` and scales + them by the corresponding child weights and `n_outputs` to report the + impurity of each child. + + - `middle_value` and `check_monotonicity` use the precomputed + `left_medians[pos - 1]` and `right_medians[pos]` to derive the + mid-point value and to validate monotonic constraints when enabled. + + - Missing values are not supported for MAE: `init_missing` raises. + + For a complementary, in-depth discussion of the mathematics and design + choices, see the external report: https://github.com/cakedev0/fast-mae-split/blob/main/report.ipynb """ cdef float64_t[::1] node_medians @@ -1342,7 +1409,7 @@ cdef class MAE(Criterion): cdef float64_t[::1] sorted_y cdef intp_t [::1] sorted_indices cdef intp_t[::1] ranks - cdef WeightedFenwickTree tree + cdef WeightedFenwickTree prefix_sum_tree def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. @@ -1375,11 +1442,22 @@ cdef class MAE(Criterion): self.right_abs_errors = np.empty(n_samples, dtype=np.float64) self.left_medians = np.empty(n_samples, dtype=np.float64) self.right_medians = np.empty(n_samples, dtype=np.float64) - self.tree = WeightedFenwickTree(n_samples) # 2 float64 arrays of size n_samples + 1 - + self.ranks = np.empty(n_samples, dtype=np.intp) + # Important: The arrays declared above are indexed with + # the absolute position `p` in `sample_indices` (not with a 0-based offset). + # The forward and backward passes in `reset` method ensure that + # for any current split position `pos` we can read: + # - left child precomputed values at `p = pos - 1`, and + # - right child precomputed values at `p = pos`. + + self.prefix_sum_tree = WeightedFenwickTree(n_samples) + # used memory: 2 float64 arrays of size n_samples + 1 + # we reuse a single `WeightedFenwickTree` instance to build prefix + # and suffix aggregates over the node samples. + + # Work buffer arrays, used with 0-based offset: self.sorted_y = np.empty(n_samples, dtype=np.float64) self.sorted_indices = np.empty(n_samples, dtype=np.intp) - self.ranks = np.empty(n_samples, dtype=np.intp) cdef int init( self, @@ -1450,17 +1528,24 @@ cdef class MAE(Criterion): memset(&self.left_abs_errors[self.start], 0, n_bytes) memset(&self.right_abs_errors[self.start], 0, n_bytes) + # Multi-output handling: + # absolute errors are accumulated across outputs by + # incrementing `left_abs_errors` and `right_abs_errors` on each pass. + # The per-output medians arrays are overwritten at each output iteration + # as they are only used for monotonicity checks when `n_outputs == 1`. + for k in range(self.n_outputs): + # 1) Node-local ordering: + # for each output k, the values `y[sample_indices[p], k]` for p + # in [start, end) are copied into self.sorted_y[0:n_node_samples]` + # and ranked with `compute_ranks`. + # The resulting `self.ranks[p]` gives the rank of sample p in the + # node-local sorted order. for p in range(self.start, self.end): i = self.sample_indices[p] self.sorted_y[p - self.start] = self.y[i, k] - # Compute the ranks of the node-local values in sorted order. - # - self.sorted_y[0:n_node_samples] is sorted in-place (with indices). - # - self.sorted_indices is a buffer used internally by compute_ranks - # - self.ranks[p] receives the rank of self.y[self.samples_indices[p], k] - # in the sorted array, for p in [start, end) compute_ranks( &self.sorted_y[0], &self.sorted_indices[0], @@ -1468,23 +1553,25 @@ cdef class MAE(Criterion): self.n_node_samples, ) - # Note that at each iteration of this loop, we overwrite `self.left_medians` - # and `self.right_medians`. They are used to check for monoticity constraints, - # which are allowed only with n_outputs=1. + # 2) Forward pass + # from `start` to `end - 1` to fill `left_abs_errors[p]` and + # `left_medians[p]` for left children. precompute_absolute_errors( self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, - self.tree, self.start, self.end, + self.prefix_sum_tree, self.start, self.end, # left_abs_errors is incremented, left_medians is overwritten self.left_abs_errors, self.left_medians ) - # For the right child, we consider samples from end-1 to start-1 - # i.e., reversed, and abs error & median are filled in reverse order to. + # 3) Backward pass + # from `end - 1` down to `start` to fill `right_abs_errors[p]` + # and `right_medians[p]` for right children. precompute_absolute_errors( self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, - self.tree, self.end - 1, self.start - 1, + self.prefix_sum_tree, self.end - 1, self.start - 1, # right_abs_errors is incremented, right_medians is overwritten self.right_abs_errors, self.right_medians ) + # Store the median for the current node self.node_medians[k] = self.right_medians[self.start] From 10c7dde7fe6c6c8861901ba9944ea639d9a74b3c Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Tue, 28 Oct 2025 10:59:40 +0100 Subject: [PATCH 62/62] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/tree/_criterion.pyx | 6 ++++-- sklearn/tree/_utils.pyx | 4 ---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 5e67bb8f1497f..4124ee2c4e374 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1362,7 +1362,7 @@ cdef class MAE(Criterion): The tree maintains cumulative sums of weights and of `weight * y` * search for the half total weight in the tree: `tree.search(current_total_weight / 2)`. - This alloww us to retrieve/compute: + This allows us to retrieve/compute: * the current weighted median value * the absolute-error contribution via the standard pinball-loss identity: AE = (wy_right - median * w_right) + (median * w_left - wy_left) @@ -1572,7 +1572,9 @@ cdef class MAE(Criterion): self.right_abs_errors, self.right_medians ) - # Store the median for the current node + # Store the median for the current node: when p == self.start all the + # node's data points are sent to the right child, so the current node + # median value and the right child median value would be equal. self.node_medians[k] = self.right_medians[self.start] return 0 diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 18fb6cb5497b4..695a86e9a8f68 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -86,10 +86,6 @@ def _any_isnan_axis0(const float32_t[:, :] X): return np.asarray(isnan_out) -# ============================================================================= -# WeightedFenwickTree data structure -# ============================================================================= - cdef class WeightedFenwickTree: """ Fenwick tree (Binary Indexed Tree) specialized for maintaining: