Skip to content

Commit 224797b

Browse files
authored
Update bn_folding_test.py
1 parent 2a787ff commit 224797b

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

tests/bn_folding_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -462,17 +462,13 @@ def test_same_training_and_prediction(model_name):
462462
if model_name == "conv2d":
463463
x_shape = (2, 2, 1)
464464
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
465-
gamma = np.array([2., 1.])
466-
beta = np.array([0., 1.])
467-
moving_mean = np.array([1., 1.])
468-
moving_variance = np.array([1., 2.])
469465
elif model_name == "dense":
470466
x_shape = (4,)
471467
kernel = np.array([[1., 1.], [1., 0.], [1., 1.], [0., 1.]])
472-
gamma = np.array([2., 1.])
473-
beta = np.array([0., 1.])
474-
moving_mean = np.array([1., 1.])
475-
moving_variance = np.array([1., 2.])
468+
gamma = np.array([2., 1.])
469+
beta = np.array([0., 1.])
470+
moving_mean = np.array([1., 1.])
471+
moving_variance = np.array([1., 2.])
476472
iteration = np.array(-1)
477473

478474
train_ds = generate_dataset(train_size=10, batch_size=10, input_shape=x_shape,

0 commit comments

Comments
 (0)