Skip to content

Commit 1e949b4

Browse files
committed
1 parent 9bc49f5 commit 1e949b4

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

keras/src/layers/core/composite_layer_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,31 @@ def replace_fn(layer, *args, **kwargs):
11981198

11991199
self.assertAllClose(model(data), model2(data))
12001200

1201+
def test_clone_recursive(self):
1202+
def layer_fn1(inputs):
1203+
return layers.Dense(32, name="dense_2")(inputs)
1204+
1205+
layer1 = CompositeLayer(layer_fn1, name="sub")
1206+
1207+
def layer_fn2(inputs):
1208+
return layer1(inputs)
1209+
1210+
slayer = CompositeLayer(layer_fn2, name="subfunc")
1211+
1212+
inputs = layers.Input(shape=(16, 32))
1213+
outputs = slayer(inputs)
1214+
model = Model(inputs, outputs)
1215+
1216+
def call_fn(layer, *args, **kwargs):
1217+
if isinstance(layer, layers.Dense):
1218+
new_layer = layers.Dense(layer.units, name="dense_modified")
1219+
return new_layer(*args, **kwargs)
1220+
return layer(*args, **kwargs)
1221+
1222+
new_model = clone_model(model, call_function=call_fn, recursive=True)
1223+
sub = new_model.get_layer("subfunc").get_layer("sub")
1224+
self.assertEqual(sub.layers[1].name, "dense_modified")
1225+
12011226
def test_build_twice(self):
12021227
def layer_fn(inputs):
12031228
return layers.Dense(5)(inputs)

keras/src/models/cloning.py

+1
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def wrapped_clone_function(layer):
276276
clone_function=clone_function,
277277
call_function=call_function,
278278
cache=cache,
279+
recursive=True
279280
)
280281
cache[id(layer)] = clone
281282
return clone

keras/src/models/cloning_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,29 @@ def clone_function(layer):
279279
self.assertFalse(hasattr(l1, "flag"))
280280
self.assertTrue(hasattr(l2, "flag"))
281281

282+
def test_recursive_level_2(self):
283+
inputs = layers.Input(shape=(16, 32))
284+
outputs = layers.Dense(32, name="dense_2")(inputs)
285+
layer1 = models.Model(inputs, outputs, name="sub")
286+
287+
inputs = layers.Input(shape=(16, 32))
288+
outputs = layer1(inputs)
289+
slayer = models.Model(inputs, outputs, name="subfunc")
290+
291+
inputs = layers.Input(shape=(16, 32))
292+
outputs = slayer(inputs)
293+
model = models.Model(inputs, outputs)
294+
295+
def call_fn(layer, *args, **kwargs):
296+
if isinstance(layer, layers.Dense):
297+
new_layer = layers.Dense(layer.units, name="dense_modified")
298+
return new_layer(*args, **kwargs)
299+
return layer(*args, **kwargs)
300+
301+
new_model = clone_model(model, call_function=call_fn, recursive=True)
302+
sub = new_model.get_layer("subfunc").get_layer("sub")
303+
self.assertEqual(sub.layers[1].name, "dense_modified")
304+
282305
def test_compiled_model_cloning(self):
283306
model = models.Sequential()
284307
model.add(layers.Input((3,)))

0 commit comments

Comments
 (0)