Skip to content

Commit 84252b5

Browse files
authored
Fix C-cache bug related to input order of nominal variables (#1673)
1 parent 49f76da commit 84252b5

File tree

2 files changed

+139
-8
lines changed

2 files changed

+139
-8
lines changed

pytensor/link/c/basic.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,7 +1392,8 @@ def in_sig(i, topological_pos, i_idx):
13921392

13931393
# It is important that a variable (i)
13941394
# yield a 'position' that reflects its role in code_gen()
1395-
if isinstance(i, AtomicVariable): # orphans
1395+
inp_sig = isig = fgraph_inputs_dict.get(i, False) # inputs
1396+
if isinstance(i, AtomicVariable): # orphans or constant inputs
13961397
if id(i) not in constant_ids:
13971398
isig = (i.signature(), topological_pos, i_idx)
13981399
# If the PyTensor constant provides a strong hash
@@ -1412,11 +1413,7 @@ def in_sig(i, topological_pos, i_idx):
14121413
constant_ids[id(i)] = isig
14131414
else:
14141415
isig = constant_ids[id(i)]
1415-
# print 'SIGNATURE', i.signature()
1416-
# return i.signature()
1417-
elif i in fgraph_inputs_dict: # inputs
1418-
isig = fgraph_inputs_dict[i]
1419-
else:
1416+
elif inp_sig is None:
14201417
if i.owner is None:
14211418
assert all(all(out is not None for out in o.outputs) for o in order)
14221419
assert all(input.owner is None for input in fgraph.inputs)
@@ -1432,7 +1429,7 @@ def in_sig(i, topological_pos, i_idx):
14321429
)
14331430
else:
14341431
isig = (op_pos[i.owner], i.owner.outputs.index(i)) # temps
1435-
return (isig, i in no_recycling)
1432+
return (inp_sig, isig, i in no_recycling)
14361433

14371434
version = []
14381435
for node_pos, node in enumerate(order):

tests/link/c/test_basic.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import numpy as np
22
import pytest
33

4+
from pytensor import Out
45
from pytensor.compile import shared
56
from pytensor.compile.function import function
67
from pytensor.compile.mode import Mode
78
from pytensor.configdefaults import config
8-
from pytensor.graph.basic import Apply, Constant, Variable
9+
from pytensor.graph.basic import Apply, Constant, NominalVariable, Variable
910
from pytensor.graph.fg import FunctionGraph
1011
from pytensor.link.basic import PerformLinker
1112
from pytensor.link.c.basic import CLinker, DualLinker, OpWiseCLinker
1213
from pytensor.link.c.op import COp
1314
from pytensor.link.c.type import CType
15+
from pytensor.link.vm import VMLinker
1416
from pytensor.tensor.type import iscalar, matrix, vector
1517
from tests.link.test_link import make_function
1618

@@ -135,6 +137,19 @@ def impl(self, x, y):
135137
add = Add()
136138

137139

140+
class Sub(Binary):
141+
def c_code(self, node, name, inp, out, sub):
142+
x, y = inp
143+
(z,) = out
144+
return f"{z} = {x} - {y};"
145+
146+
def impl(self, x, y):
147+
return x - y
148+
149+
150+
sub = Sub()
151+
152+
138153
class BadSub(Binary):
139154
def c_code(self, node, name, inp, out, sub):
140155
x, y = inp
@@ -260,6 +275,125 @@ def test_clinker_single_node():
260275
assert fn(2.0, 7.0) == 9
261276

262277

278+
@pytest.mark.skipif(
279+
not config.cxx, reason="G++ not available, so we need to skip this test."
280+
)
281+
@pytest.mark.parametrize(
282+
"linker", [CLinker(), VMLinker(use_cloop=True)], ids=["C", "CVM"]
283+
)
284+
@pytest.mark.parametrize("atomic_type", ["constant", "nominal"])
285+
def test_clinker_atomic_inputs(linker, atomic_type):
286+
"""Test that compiling variants of the same graph with different order of atomic inputs works correctly
287+
288+
Indirect regression test for https://github.com/pymc-devs/pytensor/issues/1670
289+
"""
290+
291+
def call(thunk_out, args):
292+
thunk, input_storage, output_storage = thunk_out
293+
assert len(input_storage) == len(args)
294+
for i, arg in zip(input_storage, args):
295+
i.data = arg
296+
thunk()
297+
assert len(output_storage) == 1, "Helper function assumes one output"
298+
return output_storage[0].data
299+
300+
if atomic_type == "constant":
301+
# Put large value to make sure we don't forget to specify it
302+
x = Constant(tdouble, 999, name="x")
303+
one = Constant(tdouble, 1.0)
304+
two = Constant(tdouble, 2.0)
305+
else:
306+
x = NominalVariable(0, tdouble, name="x")
307+
one = NominalVariable(1, tdouble, name="one")
308+
two = NominalVariable(1, tdouble, name="two")
309+
310+
sub_one = sub(x, one)
311+
sub_two = sub(x, two)
312+
313+
# It may seem strange to have a constant as an input,
314+
# but that's exactly how C_Ops define a single node FunctionGraph
315+
# to be compiled by the CLinker.
316+
# FunctionGraph(node.inputs, node.outputs)
317+
fg1 = FunctionGraph(inputs=[x, one], outputs=[sub_one])
318+
thunk1 = linker.accept(fg1).make_thunk()
319+
assert call(thunk1, [10, 1]) == 9
320+
# Technically, passing a wrong constant is undefined behavior,
321+
# just checking the current behavior, NOT ENFORCING IT
322+
assert call(thunk1, [10, 0]) == 10
323+
324+
# The old code didn't use to handle a swap of atomic inputs correctly
325+
# Because it didn't expect Atomic variables to be in the inputs list
326+
# This reordering doesn't usually happen, because C_Ops pass the inputs in the order of the node.
327+
# What can happen is that we compile the same FunctionGraph with CLinker and CVMLinker,
328+
# The CLinker takes the whole FunctionGraph as is, with the required inputs specified by the user
329+
# While the CVMLinker will call the CLinker on its one Op with all inputs (required and constants)
330+
# This difference in input signature used to be ignored by the cache key,
331+
# but the generated code cared about the number of explicit inputs.
332+
# Changing the order of inputs is a smoke test to make sure we pay attention to the input signature.
333+
# The fg4 below tests the actual number of inputs changing.
334+
fg2 = FunctionGraph(inputs=[one, x], outputs=[sub_one])
335+
thunk2 = linker.accept(fg2).make_thunk()
336+
assert call(thunk2, [1, 10]) == 9
337+
# Again, technically undefined behavior
338+
assert call(thunk2, [0, 10]) == 10
339+
340+
fg3 = FunctionGraph(inputs=[x, two], outputs=[sub_two])
341+
thunk3 = linker.accept(fg3).make_thunk()
342+
assert call(thunk3, [10, 2]) == 8
343+
344+
# For completeness, confirm the CLinker cmodule_key are all different
345+
key1 = CLinker().accept(fg1).cmodule_key()
346+
key2 = CLinker().accept(fg2).cmodule_key()
347+
key3 = CLinker().accept(fg3).cmodule_key()
348+
349+
if atomic_type == "constant":
350+
# Case that only makes sense for constant atomic inputs
351+
352+
# This used to complain that an extra imaginary argument didn't have the right dtype
353+
# Because it used to reuse the codegen from the previous examples incorrectly
354+
fg4 = FunctionGraph(inputs=[x], outputs=[sub_one])
355+
thunk4 = linker.accept(fg4).make_thunk()
356+
assert call(thunk4, [10]) == 9
357+
358+
# Note that fg1 and fg3 are structurally identical, but have distinct constants
359+
# Therefore they have distinct module keys.
360+
# This behavior could change in the future, to enable more caching reuse:
361+
# https://github.com/pymc-devs/pytensor/issues/1672
362+
key4 = CLinker().accept(fg4).cmodule_key()
363+
assert len({key1, key2, key3, key4}) == 4
364+
else:
365+
# With nominal inputs, fg1 and fg3 are identical
366+
assert key1 != key2
367+
assert key1 == key3
368+
369+
370+
@pytest.mark.skipif(
371+
not config.cxx, reason="G++ not available, so we need to skip this test."
372+
)
373+
def test_clinker_cvm_same_function():
374+
# Direct regression test for
375+
# https://github.com/pymc-devs/pytensor/issues/1670
376+
x1 = NominalVariable(0, vector("x", shape=(10,), dtype="float64").type)
377+
y1 = NominalVariable(1, vector("y", shape=(10,), dtype="float64").type)
378+
const1 = np.arange(10)
379+
out = x1 + const1 * y1
380+
381+
# Without borrow the C / CVM code is different
382+
fn = function(
383+
[x1, y1], [Out(out, borrow=True)], mode=Mode(linker="c", optimizer="fast_run")
384+
)
385+
fn(np.zeros(10), np.zeros(10))
386+
387+
fn = function(
388+
[x1, y1],
389+
[Out(out, borrow=True)],
390+
mode=Mode(linker="cvm", optimizer="fast_run"),
391+
)
392+
fn(
393+
np.zeros(10), np.zeros(10)
394+
) # Used to raise ValueError: expected an ndarray, not None
395+
396+
263397
@pytest.mark.skipif(
264398
not config.cxx, reason="G++ not available, so we need to skip this test."
265399
)

0 commit comments

Comments
 (0)