Fix: improve speed of trees with MAE criterion from O(n^2) to O(n log n) #32100
Conversation
```cython
# MAE split precomputations algorithm
# =============================================================================

def _any_isnan_axis0(const float32_t[:, :] X):
```
I moved this one up, in the helpers section.
@adam2392 could you please have a look here?
adam2392 left a comment
First of all, thanks @cakedev0 for taking a look at this challenging but impactful issue, and for proposing a fix.
I took an initial glance. This overall looks like the right direction to me, so I want to make sure others take a look before we dive into the nitty-gritty of making the PR mergeable and maintainable.
I have an open question: for decision trees, we can imagine imposing a quantile-criterion split (e.g. the pinball loss). Naively, I think we can make the WeightedHeaps work to maintain any quantile, right?
Perhaps @thomasjpfan wants to take a look as well before we dive deeper into the code.
LGTM (besides nitpicks below)! Thanks for the great PR.
I'll let @adam2392 do the merge if he is still +1 for merging after the latest changes.
Possible follow-ups:
- generalize to regression for an arbitrary quantile;
- add support for missing values (if not overly complex).
Actually, this is very simple (it even simplifies the current code base), and has nothing to do with criteria. Criteria don't interact with feature values, just with the target values and their ordering.
I missed that PR despite the notification...
I resolved the conflict and looked at the change. I'm not super familiar with the old code, so I focused on how the new code looks. It seems fine to me. I learnt about Fenwick trees :D While the change is quite large (both in diff size and speed-up!), it does not change existing tests: it adds to an existing test and adds new tests. So I think we can be quite sure that we won't break existing users and will only deliver speed-ups.
I will attempt to look at it this week! It's in my queue that keeps queueing 😅
Let me merge to get this into 1.8, given that we already have 2 recent positive reviews. @adam2392 feel free to open a follow-up PR with incremental improvements or fixes if needed.
Thanks very much @cakedev0! This is great work!
We also need to review #32119, BTW ;)
Thank you very much, @cakedev0! 👏
Awesome. Thanks @cakedev0!
🎉 🎉 Thanks to everyone who spent time reviewing this!!
Congratulations, @cakedev0, this is very cool!!
@cakedev0 Thanks for clarifying. The whatsnew entry contains both the efficiency enhancement and the fix. That's perfect!
This PR re-implements how `DecisionTreeRegressor(criterion='absolute_error')` works under the hood, for optimization purposes. The current algorithm for calculating the AE of a split incurs an overall O(n^2) complexity for building a tree, which quickly becomes impractical. My implementation makes it O(n log n), which is tremendously faster. For instance, with d=2, n=100_000 and max_depth=1 (just one split), the execution time went from ~30s to ~100ms on my machine.
Referenced Issues
Fixes #9626 by reducing the complexity from O(n^2) to O(n log n).
Also fixes #32099 & #10725 (which are probably duplicates). But that's more of a side effect of completely re-implementing the criterion logic for MAE.
Supersedes #11649 (which was opened to fix #10725 7 years ago but never merged).
Explanation of my changes
The changes focus solely on the class `MAE(RegressionCriterion)`. The previous implementation had O(n^2) overall complexity, arising from several methods in this class:

- `update`: O(n) cost due to updating a data structure that keeps the data sorted (`WeightedMedianCalculator`/`WeightedPQueue`). Called O(n) times to find the best split => O(n^2) overall.
- `children_impurity`: O(n) due to looping over all the data points. Called O(n) times to find the best split => O(n^2) overall.

Those can't really be fixed by small local changes, as the algorithm is O(n^2) overall independently of how you implement it. Hence a complete rewrite was needed. As discussed in this technical report I made, there are several efficient algorithms to solve the problem (computing the absolute errors for all the possible splits along one feature).
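For intuition, here is a brute-force, unweighted baseline that recomputes each child's AE from scratch at every split position; O(n) work per split times O(n) splits is exactly the O(n^2) pattern described above (illustrative only, not the previous Cython code):

```python
import numpy as np

def all_split_abs_errors_bruteforce(y):
    """AE of left/right children for every split position, in O(n^2)."""
    n = len(y)
    left = np.empty(n - 1)
    right = np.empty(n - 1)
    for i in range(1, n):  # split between positions i-1 and i
        l, r = y[:i], y[i:]
        left[i - 1] = np.abs(l - np.median(l)).sum()   # O(n) per split
        right[i - 1] = np.abs(r - np.median(r)).sum()  # O(n) per split
    return left, right

left, right = all_split_abs_errors_bruteforce(np.array([1.0, 5.0, 2.0, 8.0]))
```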
The one I initially chose was an intuitive adaptation of the well-known two-heap solution to the "find median from a data stream" problem. But even though it has O(n log n) expected complexity, it can be O(n^2 log n) in some pathological cases. So, after some discussion, it was decided to implement another solution: the "Fenwick tree option". This solution is based on a Fenwick tree, a data structure specialized in efficient prefix-sum computations and updates.
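For readers unfamiliar with the data structure, here is a minimal pure-Python Fenwick (binary indexed) tree supporting O(log n) point updates and prefix sums. This is a generic sketch for illustration only; the PR's Cython version additionally tracks weighted target sums:

```python
class FenwickTree:
    """1-indexed Fenwick tree: point update and prefix sum, both O(log n)."""

    def __init__(self, n):
        self.n = n
        self.tree = [0.0] * (n + 1)

    def update(self, i, delta):
        """Add `delta` at position i (1-indexed)."""
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # jump to the next node whose range covers i

    def prefix_sum(self, i):
        """Return the sum of positions 1..i."""
        s = 0.0
        while i > 0:
            s += self.tree[i]
            i -= i & (-i)  # strip the lowest set bit
        return s

ft = FenwickTree(8)
ft.update(3, 2.0)
ft.update(5, 1.5)
print(ft.prefix_sum(4))  # sums positions 1..4, so only the 2.0 at position 3
```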
See the technical report for a detailed explanation of the algorithm, but in short, the main steps are:

- Iterate on the data from left to right to compute the AE for every possible left child, and iterate from right to left to compute the AE for every possible right child.
- For each candidate child, the AE around its weighted median m decomposes into 4 prefix/suffix sums (total weight and weighted target sum, below and above m): `AE = (m * W_below - S_below) + (S_above - m * W_above)`.
- The value of those 4 prefix/suffix sums can be found while searching for the median in the tree, and once you have those, the computation becomes O(1).
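To make the prefix/suffix-sum idea concrete, here is an unweighted sketch cross-checked against the direct O(n) computation: once the median m and the four sums (counts and target sums below/above m) are known, the child's AE is O(1). The helper name is hypothetical, not the PR's code:

```python
import numpy as np

def abs_error_from_sums(m, n_below, s_below, n_above, s_above):
    """AE around median m from 4 precomputed sums: O(1) once sums are known."""
    # points below m each contribute (m - y); points above contribute (y - m)
    return (m * n_below - s_below) + (s_above - m * n_above)

rng = np.random.default_rng(0)
y = rng.standard_normal(101)
m = float(np.median(y))  # odd n, so m is an actual sample contributing 0
below, above = y[y < m], y[y > m]
ae = abs_error_from_sums(m, below.size, below.sum(), above.size, above.sum())
```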
This logic is implemented in `tree/_criterion.pyx::precompute_absolute_errors`, as I wanted to be able to unit test it. After some research, I found a paper about the same problem. Their approach uses the two-heap idea and generalizes to arbitrary quantiles (as done in my follow-up PR), but it does not handle weighted samples. Also, the paper uses a more elaborate formula for the absolute error/loss computation than mine; TBH it looks unnecessarily complex.