Description
vectorize_node implicitly assumes that whenever we vectorize a node, we return a new node whose outputs map 1-to-1 to the original outputs, but this is too restrictive. We may want to vectorize a single node into two variables that come from different nodes, or into a single output of a multi-output node. There's no reason to require a one node -> one node mapping.
pytensor/pytensor/graph/replace.py
Lines 208 to 211 in 79ff97a
@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
    # Default implementation is provided in pytensor.tensor.blockwise
    raise NotImplementedError
For backwards compatibility, we should check whether the returned object is an Apply. If it is, we should still use it, but issue a warning that this form is deprecated and that a list of outputs (like in the rewrites) should be returned instead. All our implementations in PyTensor should switch to returning a list of variables.
The catch/warning could be done here:
pytensor/pytensor/graph/replace.py
Lines 214 to 217 in 79ff97a
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
    """Returns vectorized version of node with new batched inputs."""
    op = node.op
    return _vectorize_node(op, node, *batched_inputs)
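A minimal sketch of what that check could look like. It is not the actual PyTensor change: `Apply` here is a tiny stand-in class so the snippet runs without PyTensor, `LegacyOp` and `NewOp` are hypothetical ops, plain strings stand in for variables, and the choice of `FutureWarning` and the warning text are assumptions.

```python
import warnings
from functools import singledispatch


class Apply:
    """Stand-in for pytensor.graph.basic.Apply (assumption, for illustration only)."""

    def __init__(self, op, inputs, outputs):
        self.op = op
        self.inputs = list(inputs)
        self.outputs = list(outputs)


@singledispatch
def _vectorize_node(op, node, *batched_inputs):
    raise NotImplementedError


def vectorize_node(node, *batched_inputs):
    """Return the vectorized outputs of node as a list of variables."""
    result = _vectorize_node(node.op, node, *batched_inputs)
    if isinstance(result, Apply):
        # Deprecated form: still accepted, but unwrapped into a list of outputs.
        warnings.warn(
            "Returning an Apply from _vectorize_node is deprecated; "
            "return a list of output variables instead.",
            FutureWarning,
            stacklevel=2,
        )
        return list(result.outputs)
    return list(result)


class LegacyOp:  # hypothetical op whose implementation still returns an Apply
    pass


class NewOp:  # hypothetical op whose implementation returns a list of variables
    pass


@_vectorize_node.register(LegacyOp)
def _vectorize_legacy(op, node, *batched_inputs):
    return Apply(op, batched_inputs, ["legacy_out"])


@_vectorize_node.register(NewOp)
def _vectorize_new(op, node, *batched_inputs):
    # The returned variables need not all belong to one node.
    return ["new_out_a", "new_out_b"]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_outputs = vectorize_node(Apply(LegacyOp(), [], ["out"]), "batched_x")

new_outputs = vectorize_node(Apply(NewOp(), [], ["a", "b"]), "batched_x")
```

With this shape, both kinds of implementation yield a plain list to the caller, and only the legacy one triggers the warning.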
Then everything that calls vectorize_node should now expect a list as output. Like here:
pytensor/pytensor/graph/replace.py
Lines 301 to 308 in 79ff97a
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs):
    if output in vect_vars:
        # This can happen when some outputs of a multi-output node are given a replacement,
        # while some of the remaining outputs are still needed in the graph.
        # We make sure we don't overwrite the provided replacement with the newly vectorized output
        continue
    vect_vars[output] = vect_output
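A rough sketch of how this call site might look once vectorize_node returns a list. The helper name `update_vect_vars` is hypothetical, strings stand in for graph variables, and the length check assumes the returned list still has one entry per original output, which the issue does not spell out.

```python
def update_vect_vars(node_outputs, vect_outputs, vect_vars):
    """Record vectorized outputs without overwriting user-provided replacements."""
    # Assumption: one vectorized variable per original output.
    if len(node_outputs) != len(vect_outputs):
        raise ValueError(
            f"Expected {len(node_outputs)} vectorized outputs, got {len(vect_outputs)}"
        )
    for output, vect_output in zip(node_outputs, vect_outputs):
        if output in vect_vars:
            # A replacement was already provided for this output; keep it.
            continue
        vect_vars[output] = vect_output
    return vect_vars


# "out_0" already has a user-provided replacement, "out_1" does not.
vect_vars = {"out_0": "user_replacement"}
update_vect_vars(["out_0", "out_1"], ["vect_0", "vect_1"], vect_vars)
```

The only change from the current loop is that it iterates over the returned list directly instead of over `vect_node.outputs`.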