Refactor botorch/sampling/pathwise and add support for product kernels #2838
base: main
Conversation
Codecov Report
❌ Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##              main    #2838      +/-   ##
===========================================
- Coverage   100.00%   99.93%    -0.07%
===========================================
  Files          216      219        +3
  Lines        20211    20722      +511
===========================================
+ Hits         20211    20709      +498
- Misses           0       13       +13

View full report in Codecov by Sentry.
Thanks @seashoo for the PR - this is a big one! It'll take me a bit of time to review this in detail, I plan to do a first higher-level pass this week.
What exactly does this mean?
Hi @Balandat, Thanks for the response! We've included a more detailed Project Overview section in the pull request description to clarify our validation approach. Specifically, we utilized the existing unit test files, which cover prior, updates, and posterior sampling, and ensured that all tests passed as part of this rebase. While these tests are comprehensive, we welcome any additional guidance you might have on further validating the code's robustness.
I went over the main code in the PR in detail; overall this looks great, thanks for the effort and patching some gaps (e.g. _gaussian_update_ModelListGP). I have not reviewed the testing code in detail, but can do that after the next pass.
The key things to address are:
- Some additions in the patch file were not included here - curious to understand why (and if this was an oversight let's add them in - I pointed out which ones).
- Currently the tests still have some coverage gaps based on the codecov report here. Please add some test cases to also cover the currently uncovered lines.
task_index = (
    num_inputs + model._task_feature
    if model._task_feature < 0
    else model._task_feature
)
It would be better to do this so we can always assume that it's positive and we don't have to do this custom handling. But I think this is ok for now as is, that change is beyond the scope of this PR.
Suggested change:
+ # TODO: Change `MultiTaskGP` to normalize the task feature in its constructor.
  task_index = (
      num_inputs + model._task_feature
      if model._task_feature < 0
      else model._task_feature
  )
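For reference, a hedged sketch of what that constructor-side normalization might look like; normalize_indices is BoTorch's existing helper for mapping negative indices to non-negative ones, while _normalized_task_feature is a hypothetical wrapper used only for illustration:

# Illustrative sketch only - not code from this PR.
from botorch.utils.transforms import normalize_indices

def _normalized_task_feature(task_feature: int, num_total_inputs: int) -> int:
    # normalize_indices maps negative indices (counted from the end) to their
    # non-negative equivalents, so callers never need the `if < 0` branch above.
    return normalize_indices([task_feature], d=num_total_inputs)[0]

# Example: with 5 input columns, a task feature of -1 refers to column 4.
assert _normalized_task_feature(-1, 5) == 4
assert _normalized_task_feature(2, 5) == 2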
Great suggestion- I would be willing to come back to this in a separate PR!
Should be resolved in #2908!
@Balandat, thank you for the comprehensive review and detailed feedback on Wilson's product_kernel.diff implementation! My apologies for the time it's taken to get back to you - some of the implementations took quite some time to fully realize, and I've been balancing the work alongside my current internship. I'm excited to collaborate with you at a faster pace now that I've freed up time! If you have questions regarding any of my specific implementations, feel free to ask in a reply to any of the comments here - I'll be able to communicate much more swiftly now. I've gone ahead and resolved some of the upstream merge conflicts that appeared while I was away, and I've also filled in the code coverage gaps as you've asked.

Here's a quick summary of the major changes implemented to address your concerns:

Mathematical Issues Resolved
- Product Kernel Implementation: completely redesigned the ...
- Transform System Overhaul: fixed scaling and coordination issues across the transform pipeline

Architectural Improvements
- Feature Map Redesign: built a comprehensive feature map architecture
- Code Organization: restructured from a monolithic ...

Code Quality Enhancements
- Modern Python Standards
- Type Safety: enhanced type annotations and return type specificity for better IDE support

Technical details regarding the issues and the approaches taken in response to your suggestions are further addressed in my replies!
Thanks for all the updates, @seashoo. Sorry for the delay on the review; I've been traveling for a conference and taking some time off. I should be able to provide a detailed review next week.
Thanks @seashoo for the PR! I added some comments and questions too.
    num_random_features: int,
    num_inputs: int | None = None,
    random_feature_scale: float | None = None,
    cosine_only: bool = False,
I'm not aware of the ProductKernel dynamic with cosines here. @seashoo do you have a reference to share on it? It'd be much appreciated!
Hi @hvarfner, I believe @seashoo followed the direct implementation from the original patch provided by Balandat, likely implemented by Wilson to make computation tractable for products of infinite-dimensional kernels. By enforcing cosine_only == True, we ensure that the feature map takes the element-wise product of the features. Otherwise, we would need to take the tensor product of each pair of sine and cosine features.
I don't think we have a specific reference that comments on this edge case, but as you already know, cosine_only specifies whether to use cosine features with a random phase instead of paired sine and cosine features, as described in Rahimi & Recht (2007) on Random Fourier Features (RFF) and Sutherland (2015) on RFF error bounds. I'll let you know if I come across any further references that clarify this implementation.
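For readers following along, here is a hedged toy sketch (not the PR's code) of the cosine-only construction being described; the function name and the extra sqrt(m) rescaling of the product are illustrative assumptions:

# Toy sketch of cosine-only random Fourier features with a random phase,
# in the spirit of Rahimi & Recht (2007). Not the PR's implementation.
import math
import torch

def cosine_only_features(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: (n, d) inputs; freqs: (d, m) frequencies drawn from the kernel's spectral density.
    m = freqs.shape[-1]
    phases = 2 * math.pi * torch.rand(m)
    # sqrt(2/m) * cos(x @ W + b) yields an unbiased estimator of a stationary kernel.
    return math.sqrt(2.0 / m) * torch.cos(x @ freqs + phases)

x = torch.randn(8, 3)
phi1 = cosine_only_features(x, torch.randn(3, 16))  # features for kernel 1
phi2 = cosine_only_features(x, torch.randn(3, 16))  # features for kernel 2
# With same-sized cosine-only blocks, the product kernel's feature map can be the
# element-wise product (rescaled by sqrt(m) so the estimator stays unbiased);
# with paired sine/cosine blocks one would need the full tensor product instead.
phi_product = math.sqrt(16) * phi1 * phi2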
        self.input_transform = input_transform
        self.output_transform = output_transform

    def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
This forward looks really impressive, but I honestly cannot tell what is going on. Can you add some comments? Specifically L133 onward.
Thanks for the feedback - I've added in-line comments to DirectSumFeatureMap.forward (from the tiling logic through the final concatenation) to clarify:
- Why I collect/scale individual feature blocks.
- How the tiling & broadcasting works for lower-dimensional feature maps.
- The rationale behind the rescaling.
- How the multi_index trick avoids extra allocations.

A toy sketch of the tiling-and-concatenation idea is included below. Let me know if any sections still feel unclear or if you'd like similar commentary elsewhere.
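To make the pattern concrete, here is a hedged toy sketch of the broadcast-then-concatenate idea described above; this is not DirectSumFeatureMap.forward itself, and the helper name and shapes are illustrative:

# Toy sketch: align the batch dimensions of each sub-map's feature block, then
# concatenate along the trailing feature dimension ("direct sum" of features).
import torch

def direct_sum_features(blocks: list[torch.Tensor]) -> torch.Tensor:
    # Each block has shape (*batch_shape_i, num_features_i).
    batch_shape = torch.broadcast_shapes(*(b.shape[:-1] for b in blocks))
    expanded = [b.expand(*batch_shape, b.shape[-1]) for b in blocks]
    return torch.cat(expanded, dim=-1)

# Example: a per-batch block (4, 2, 16) and a shared block (2, 8) -> (4, 2, 24).
phi = direct_sum_features([torch.randn(4, 2, 16), torch.randn(2, 8)])
assert phi.shape == (4, 2, 24)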
    @property
    def raw_output_shape(self) -> Size:
        if not self:
When does this occur?
not self only happens when the DirectSumFeatureMap is empty. That can be because you:
- Purposely start with an empty container and plan to append feature maps later, or
- Deleted the last entry and the list is now length-zero.

We hit this in unit tests that build DirectSumFeatureMap([]) to make sure the class doesn't crash in that edge case. Returning Size([]) just keeps the object in a sane, queryable state (so output_shape, batch_shape, etc. still work) until real feature maps are added (see the toy example below). I've added a clarifying comment!
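A minimal stand-in illustrating that edge case (the class below is purely illustrative, not the PR's DirectSumFeatureMap):

from torch import Size

class ToyDirectSumFeatureMap(list):
    # List-like container of feature maps; illustration only.
    @property
    def raw_output_shape(self) -> Size:
        if not self:  # empty container: nothing to broadcast or concatenate yet
            return Size([])
        raise NotImplementedError  # normal path: broadcast batch dims, sum feature dims

fmap = ToyDirectSumFeatureMap([])
assert not fmap                           # length zero -> falsy
assert fmap.raw_output_shape == Size([])  # stays queryable instead of crashing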
        return torch.concat(blocks, dim=-1)

    @property
    def raw_output_shape(self) -> Size:
So if I'm reading this correctly, raw_output_shape returns the largest broadcastable shape across all sub-kernels? How does it differ from torch.broadcast_shapes?
raw_output_shape does two things in one go:
- For every sub-map it first figures out the shape you would get after tiling/broadcasting inside forward (e.g. when a 1-D feature map has to be expanded to align with a 2-D one).
- It then concatenates those shapes along the feature dimension, so the final size of the last axis is the sum of the sub-maps' feature counts.

torch.broadcast_shapes by itself only answers "what common shape can all these tensors be viewed as without copying?". It never alters the size of any dimension - especially not the last one. In our case we need to grow the last dimension because we're gluing feature vectors together; that's why raw_output_shape can't be expressed as a single torch.broadcast_shapes call. A small sketch of the computation is included below.
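As a concrete illustration of that difference, a hedged sketch of the shape computation (the helper name is hypothetical):

import torch
from torch import Size

def direct_sum_output_shape(sub_shapes: list[Size]) -> Size:
    if not sub_shapes:
        return Size([])
    # Broadcast the batch dimensions, but *sum* the trailing feature dimension.
    batch = torch.broadcast_shapes(*(s[:-1] for s in sub_shapes))
    num_features = sum(s[-1] for s in sub_shapes)
    return Size([*batch, num_features])

# torch.broadcast_shapes alone would reject trailing sizes 16 and 8 (neither is 1);
# here they are concatenated instead: (4, 2, 16) and (2, 8) -> (4, 2, 24).
assert direct_sum_output_shape([Size([4, 2, 16]), Size([2, 8])]) == Size([4, 2, 24])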
    block = block.to_dense() if isinstance(block, LinearOperator) else block
    block = block if block.is_sparse else block.to_sparse()
else:
    multi_index = (
A comment or two would be really helpful here too!
Added comments to the sparse branch!
    # containing data_covar_module * task_covar_module
    from gpytorch.kernels import ProductKernel

    if isinstance(model.covar_module, ProductKernel):
The MTGPs (including the fully Bayesian variant) now use a ProductKernel by definition, so this check and the conditional logic should be obsolete.
And if there are some edge cases where we don't adhere to the standard ProductKernel(IndexKernel, SomeOtherKernel) where both have the expected active_dims, I think it's best to raise an error!
Thanks for the nudge - both spots (prior_samplers.py and update_strategies.py) now follow the MTGP contract precisely. If a user hand-builds a ProductKernel and forgets active_dims on the data part, we add them so downstream helpers don't error.
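A hedged sketch of the kind of check and fix-up being described (not the PR's exact code; the helper name and the fix-up policy are illustrative assumptions):

import torch
from gpytorch.kernels import IndexKernel, Kernel, ProductKernel

def ensure_mtgp_kernel_contract(kernel: Kernel, task_index: int, num_inputs: int) -> None:
    # Expect the MTGP contract: ProductKernel(data_kernel, IndexKernel).
    if not isinstance(kernel, ProductKernel):
        raise ValueError("Expected ProductKernel(data_kernel, IndexKernel) for a multi-task GP.")
    data_dims = [i for i in range(num_inputs) if i != task_index]
    for sub_kernel in kernel.kernels:
        if isinstance(sub_kernel, IndexKernel):
            continue  # the task kernel already acts on the task column
        if sub_kernel.active_dims is None:
            # User forgot active_dims on the data kernel: restrict it to the
            # non-task columns so downstream pathwise helpers don't error.
            sub_kernel.active_dims = torch.tensor(data_dims, dtype=torch.long)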
task_index = (
    num_inputs + model._task_feature
    if model._task_feature < 0
    else model._task_feature
)
Should be resolved in #2908!
    # containing data_covar_module * task_covar_module
    from gpytorch.kernels import ProductKernel

    if isinstance(model.covar_module, ProductKernel):
Once again, this conditional logic should not be needed anymore!
INF_DIM_KERNELS: Tuple[Type[Kernel], ...] = (
    kernels.MaternKernel,
    kernels.RBFKernel,
    kernels.MultitaskKernel,
Why is MultiTaskKernel necessarily in here? Shouldn't IndexKernel be here too in that case?
INF_DIM_KERNELS is the “shortcut” list for kernels that are always treated as infinite-dimensional when deciding whether we need random Fourier features, etc. RBFKernel and MaternKernel are on the list because their exact feature maps are infinite-dimensional. IndexKernel itself is finite-dimensional (exactly num_tasks features), so we deliberately keep it out of the list.
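To illustrate the distinction in isolation (a simplified sketch, not the PR's dispatch logic; it covers only the kernels discussed above and uses its own constant name):

from gpytorch import kernels
from gpytorch.kernels import Kernel

# Kernels whose exact feature maps are infinite-dimensional, so pathwise sampling
# needs a random (Fourier) feature approximation for them.
INFINITE_DIM_KERNELS: tuple[type[Kernel], ...] = (
    kernels.RBFKernel,
    kernels.MaternKernel,
)

def needs_random_features(kernel: Kernel) -> bool:
    if isinstance(kernel, kernels.IndexKernel):
        return False  # exact, finite feature map with num_tasks columns
    return isinstance(kernel, INFINITE_DIM_KERNELS)

assert needs_random_features(kernels.RBFKernel())
assert not needs_random_features(kernels.IndexKernel(num_tasks=3))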
Motivation
Hi! I'm Sahran Ashoor, an undergraduate research assistant working for the Uncertainty Quantification Lab at the University of Houston. I work under Dr. Ruda Zhang and Taiwo Adebiyi, both of whom have already spoken with Max Balandat regarding incorporating a rebase of botorch/sampling/pathwise (largely written by James T. Wilson). The changes included in this pull request are my best attempt at faithfully completing the change logs I was provided (product_kernel_diff.txt).
Have you read the Contributing Guidelines on pull requests?
Yes!
Project Overview
The primary goal was to make the original codebase by Wilson compatible with the latest BoTorch version. To achieve this, we used the original source code and test suites, which initially revealed several incompatibility issues. Our main contribution involved carefully rebasing Wilson's code while preserving the logic for pathwise sampling. Importantly, all changes were confined to the botorch/sampling/pathwise directory to ensure a seamless integration, passing both the local pathwise test suites and BoTorch's global test suites via GitHub workflows.
In terms of code logic, we relied on Wilson's unit tests for prior, updates, and posterior sampling, which we believe are sufficient to validate the correctness of the implementation. However, we welcome your feedback on this approach, and would appreciate any suggestions for additional tests or example scripts to further confirm the robustness of the changes. We are open to collaborating further on this effort.
Test Plan
The entire testing suite was run through pytest. Through additional verification we've found that the logic may be offset in places, but we're hoping to work with you all to further validate these changes at the discretion of Dr. Zhang. Expect further communications directly from my lab that will provide more insight into the rebase.
Related PRs
N/A