pytensor/scalar/basic.py: 84 changes (44 additions, 40 deletions)
@@ -13,7 +13,6 @@
import builtins
import math
from collections.abc import Callable
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Any, TypeAlias
@@ -779,9 +778,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
This caches objects to save allocation and run time.

"""
if dtype not in cache:
cache[dtype] = ScalarType(dtype=dtype)
return cache[dtype]
try:
return cache[dtype]
except KeyError:
cache[dtype] = res = ScalarType(dtype=dtype)
return res


# Register C code for ViewOp on Scalars.
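The caching change above replaces the explicit membership test with the EAFP idiom, so the common cache-hit path does a single dict lookup instead of two. A minimal standalone sketch of the same pattern, with illustrative names rather than pytensor code:

_cache: dict[str, object] = {}

def get_cached(key: str) -> object:
    try:
        # Cache hit: a single dict lookup.
        return _cache[key]
    except KeyError:
        # Cache miss: build the value once and remember it for next time.
        _cache[key] = value = object()
        return value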
@@ -987,25 +988,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:


def as_scalar(x: Any, name: str | None = None) -> ScalarVariable:
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.type import TensorType
if isinstance(x, ScalarVariable):
return x

if isinstance(x, Variable):
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.type import TensorType

if isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")

if isinstance(x, Apply):
# FIXME: Why do we support calling this with Apply?
# Also, if we do, why can't we support multiple outputs?
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output"
" Op has to be fetched.",
x,
)
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x, ScalarVariable):
return x
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError(f"Cannot convert {x} to a scalar type")
return as_scalar(x.outputs[0])

return constant(x)
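As reordered above, as_scalar handles the cheap cases first: a ScalarVariable passes through unchanged, a 0-d TensorVariable is converted with scalar_from_tensor, a single-output Apply is unwrapped recursively, and anything else is wrapped by constant. A hedged usage sketch, assuming the conventional pytensor.tensor alias and the behaviour shown in this hunk:

import pytensor.tensor as pt
from pytensor.scalar.basic import as_scalar

x = pt.scalar("x")        # 0-d TensorVariable, converted via scalar_from_tensor
s = as_scalar(x)
c = as_scalar(2.5)        # plain Python number, wrapped as a ScalarConstant
assert as_scalar(s) is s  # a ScalarVariable is returned untouched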

@@ -1329,32 +1333,26 @@ def supports_c_code(self, inputs, outputs):
the given Elemwise inputs, outputs.

"""
try:
tmp_s_input = []
# To keep the same aliasing between inputs
mapping = dict()
for ii in inputs:
if ii in mapping:
tmp_s_input.append(mapping[ii])
else:
tmp = get_scalar_type(ii.dtype).make_variable()
tmp_s_input.append(tmp)
mapping[ii] = tmp_s_input[-1]

with config.change_flags(compute_test_value="ignore"):
s_op = self(*tmp_s_input, return_list=True)
tmp_s_input = []
# To keep the same aliasing between inputs
mapping = {}
for ii in inputs:
if ii in mapping:
tmp_s_input.append(mapping[ii])
else:
tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable()
tmp_s_input.append(tmp)

# If the scalar_op doesn't have a C implementation,
# we skip its fusion to allow the fusion of the
# other ops.
try:
self.c_code(
s_op[0].owner,
self.make_node(*tmp_s_input),
"test_presence_of_c_code",
# FIXME: Shouldn't this be a unique name per unique variable?
["x" for x in inputs],
["z" for z in outputs],
{"fail": "%(fail)s"},
)
except (MethodNotDefined, NotImplementedError):
except (NotImplementedError, MethodNotDefined):
return False
return True
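The loop above keeps one scalar placeholder per distinct Elemwise input, so repeated inputs alias the same variable when the scalar Op is re-applied, and the subsequent c_code call merely probes for a C implementation. A standalone sketch of that deduplication pattern (illustrative helper with stand-in objects, not the pytensor implementation):

def dedupe_preserving_aliasing(names: list[str]) -> list[object]:
    # One stand-in object per distinct name; repeated names reuse it.
    mapping: dict[str, object] = {}
    rebuilt = []
    for name in names:
        if name not in mapping:
            mapping[name] = object()  # stand-in for make_variable()
        rebuilt.append(mapping[name])
    return rebuilt

args = dedupe_preserving_aliasing(["x", "y", "x"])
assert args[0] is args[2]  # aliasing between the first and third input survives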

@@ -4094,12 +4092,12 @@ def __init__(self, *args, **kwargs):
self.prepare_node_called = set()
super().__init__(*args, **kwargs)

def _cleanup_graph(self, inputs, outputs):
def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True):
# TODO: We could convert to TensorVariable, optimize graph,
# and then convert back to ScalarVariable.
# This would introduce rewrites like `log(1 + x) -> log1p`.

fgraph = FunctionGraph(copy(inputs), copy(outputs))
fgraph = FunctionGraph(inputs, outputs, clone=clone)

# Validate node types
for node in fgraph.apply_nodes:
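_cleanup_graph now forwards a clone flag to FunctionGraph instead of shallow-copying the input and output lists itself, which is why the `from copy import copy` import could be dropped. A hedged sketch of the two modes, assuming the usual FunctionGraph import path and illustrative variables:

from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import float64

x = float64("x")
out = x * 2
# clone=True (the default): FunctionGraph works on a copy of the graph.
fg_copy = FunctionGraph([x], [out], clone=True)
assert fg_copy.outputs[0] is not out

# clone=False: FunctionGraph takes ownership of the original variables, which
# is what Composite relies on once it has already cloned the graph itself.
y = float64("y")
out2 = y * 2
fg_owned = FunctionGraph([y], [out2], clone=False)
assert fg_owned.outputs[0] is out2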
@@ -4282,7 +4280,9 @@ class Composite(ScalarInnerGraphOp):

init_param: tuple[str, ...] = ("inputs", "outputs")

def __init__(self, inputs, outputs, name="Composite"):
def __init__(
self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
):
self.name = name
self._name = None
# We need to clone the graph as sometimes its nodes already
@@ -4300,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"):
if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs
):
# No inner Composite
inputs, outputs = clone(inputs, outputs)
if clone_graph:
inputs, outputs = clone(inputs, outputs)

else:
# Inner Composite that we need to flatten
# FIXME: There could be a composite in the middle of the graph, why is this here?
# If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
assert len(outputs) == 1
# 1. Create a new graph from inputs up to the
# Composite
@@ -4322,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"):
assert res[0] != inputs
inputs, outputs = res[0], res2[1]

self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
# We already cloned the graph, or the user told us there was no need for it
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs)
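With the new clone_graph argument, a caller that builds a throwaway scalar graph solely to wrap it in a Composite can skip the defensive clone; the Op then takes the graph as given and _cleanup_graph runs with clone=False. A hedged usage sketch with illustrative variables (the default clone_graph=True keeps the old behaviour):

from pytensor.scalar.basic import Composite, float64

x = float64("x")
y = float64("y")
out = x * y + x  # freshly built graph that nothing else references

# Safe to hand over without cloning: the caller will not reuse this graph.
op = Composite([x, y], [out], clone_graph=False)

# Callers that keep using the same variables elsewhere should stick to the default.
op_default = Composite([x, y], [x * y + x])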