
Commit 3d20616

Fix JAX GPU CI and make formatter happy (keras-team#20749)
* Fix JAX GPU CI
* Makes formatter happy
* Makes formatter happy - 2
1 parent 97c1c00 commit 3d20616

36 files changed, +64 -82 lines changed
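
Note: most hunks below apply one of two mechanical rewrites, with no behavior change. The sketch below is not from the Keras codebase; the function names are illustrative and only show the before/after shape of each pattern.

# A minimal sketch (not from this commit) of the two formatting patterns
# the diff applies. All names here are illustrative, not real Keras symbols.

def check_rank_old(rank):
    # Old layout: the assert condition is wrapped in parentheses and the
    # message trails the closing parenthesis.
    assert (
        rank > 0
    ), "rank must be greater than zero"


def check_rank_new(rank):
    # New layout: the condition stays on the assert line and only the
    # message is parenthesized for wrapping.
    assert rank > 0, (
        "rank must be greater than zero"
    )


def warn_deprecated_old(name):
    # Old layout: one logical message split across two adjacent literals
    # via implicit string concatenation.
    raise ValueError(
        "Argument `alpha` is deprecated. "
        f"Use `{name}` instead."
    )


def warn_deprecated_new(name):
    # New layout: a single literal, merged because it fits on one line.
    raise ValueError(f"Argument `alpha` is deprecated. Use `{name}` instead.")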

examples/demo_custom_torch_workflow.py (+2 -2)

@@ -74,8 +74,8 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
             # Print loss statistics
             if (batch_idx + 1) % 10 == 0:
                 print(
-                    f"Epoch [{epoch+1}/{num_epochs}], "
-                    f"Batch [{batch_idx+1}/{len(train_loader)}], "
+                    f"Epoch [{epoch + 1}/{num_epochs}], "
+                    f"Batch [{batch_idx + 1}/{len(train_loader)}], "
                     f"Loss: {running_loss / 10}"
                 )
                 running_loss = 0.0
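
Note: the only change in the two demo scripts is whitespace inside the f-string replacement fields; the rendered text is identical. A quick check, not part of the commit:

# Not part of the commit: both spellings of the replacement field render
# the same string, so the change is purely cosmetic.
epoch, num_epochs = 0, 5
assert f"Epoch [{epoch+1}/{num_epochs}]" == f"Epoch [{epoch + 1}/{num_epochs}]"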

examples/demo_torch_multi_gpu.py (+2 -2)

@@ -104,8 +104,8 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
             # Print loss statistics
             if (batch_idx + 1) % 10 == 0:
                 print(
-                    f"Epoch [{epoch+1}/{num_epochs}], "
-                    f"Batch [{batch_idx+1}/{len(train_loader)}], "
+                    f"Epoch [{epoch + 1}/{num_epochs}], "
+                    f"Batch [{batch_idx + 1}/{len(train_loader)}], "
                     f"Loss: {running_loss / 10}"
                 )
                 running_loss = 0.0

keras/src/backend/openvino/core.py (+7 -7)

@@ -242,9 +242,9 @@ def __getitem__(self, indices):
     def __len__(self):
         ov_output = self.output
         ov_shape = ov_output.get_partial_shape()
-        assert (
-            ov_shape.rank.is_static and ov_shape.rank.get_length() > 0
-        ), "rank must be static and greater than zero"
+        assert ov_shape.rank.is_static and ov_shape.rank.get_length() > 0, (
+            "rank must be static and greater than zero"
+        )
         assert ov_shape[0].is_static, "the first dimension must be static"
         return ov_shape[0].get_length()

@@ -425,10 +425,10 @@ def convert_to_numpy(x):
             x = x.value
         else:
             return x.value.data
-    assert isinstance(
-        x, OpenVINOKerasTensor
-    ), "unsupported type {} for `convert_to_numpy` in openvino backend".format(
-        type(x)
+    assert isinstance(x, OpenVINOKerasTensor), (
+        "unsupported type {} for `convert_to_numpy` in openvino backend".format(
+            type(x)
+        )
     )
     try:
         ov_result = x.output
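
Note: the assert rewrites throughout this commit parenthesize only the message, never the condition and message together, which would change semantics. A small reminder, not part of the commit:

# `assert (cond, msg)` builds a two-element tuple, which is always truthy,
# so the check never fails (CPython even emits a SyntaxWarning for it).
assert (1 == 2, "never raised")  # passes: the tuple itself is truthy

try:
    assert 1 == 2, ("raised as expected")  # parenthesizing only the message is safe
except AssertionError as err:
    print(err)  # -> raised as expected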

keras/src/backend/openvino/nn.py (+3 -3)

@@ -325,9 +325,9 @@ def depthwise_conv(
     data_format = backend.standardize_data_format(data_format)
     num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2

-    assert (
-        data_format == "channels_last"
-    ), "`depthwise_conv` is supported only for channels_last data_format"
+    assert data_format == "channels_last", (
+        "`depthwise_conv` is supported only for channels_last data_format"
+    )

     strides = _adjust_strides_dilation(strides, num_spatial_dims)
     dilation_rate = _adjust_strides_dilation(dilation_rate, num_spatial_dims)

keras/src/backend/openvino/numpy.py (+6 -6)

@@ -81,9 +81,9 @@ def mean(x, axis=None, keepdims=False):


 def max(x, axis=None, keepdims=False, initial=None):
-    assert (
-        initial is None
-    ), "`max` with not None initial is not supported by openvino backend"
+    assert initial is None, (
+        "`max` with not None initial is not supported by openvino backend"
+    )
     x = get_ov_output(x)
     reduce_axis = ov_opset.constant(axis, Type.i32).output(0)
     return OpenVINOKerasTensor(

@@ -260,9 +260,9 @@ def bincount(x, weights=None, minlength=0, sparse=False):


 def broadcast_to(x, shape):
-    assert isinstance(
-        shape, (tuple, list)
-    ), "`broadcast_to` is supported only for tuple and list `shape`"
+    assert isinstance(shape, (tuple, list)), (
+        "`broadcast_to` is supported only for tuple and list `shape`"
+    )
     target_shape = ov_opset.constant(list(shape), Type.i32).output(0)
     x = get_ov_output(x)
     return OpenVINOKerasTensor(ov_opset.broadcast(x, target_shape).output(0))

keras/src/backend/tensorflow/linalg.py (+1 -2)

@@ -203,8 +203,7 @@ def lstsq(a, b, rcond=None):
         b = b[:, None]
     if a.ndim != 2:
         raise TypeError(
-            f"{a.ndim}-dimensional array given. "
-            "Array must be two-dimensional"
+            f"{a.ndim}-dimensional array given. Array must be two-dimensional"
         )
     if b.ndim != 2:
         raise TypeError(

keras/src/backend/tensorflow/numpy.py (+1 -2)

@@ -2469,8 +2469,7 @@ def squeeze(x, axis=None):
         for a in axis:
             if static_shape[a] != 1:
                 raise ValueError(
-                    f"Cannot squeeze axis={a}, because the "
-                    "dimension is not 1."
+                    f"Cannot squeeze axis={a}, because the dimension is not 1."
                 )
         axis = sorted([canonicalize_axis(a, len(static_shape)) for a in axis])
     if isinstance(x, tf.SparseTensor):

keras/src/callbacks/csv_logger.py (+1 -1)

@@ -59,7 +59,7 @@ def handle_value(k):
                 isinstance(k, collections.abc.Iterable)
                 and not is_zero_dim_ndarray
             ):
-                return f"\"[{', '.join(map(str, k))}]\""
+                return f'"[{", ".join(map(str, k))}]"'
             else:
                 return k

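Note: the csv_logger change flips the outer quotes so the escaped inner quotes disappear; the output is unchanged. A quick check, not part of the commit:

# Not part of the commit: both quoting styles produce the same CSV field.
k = [1, 2, 3]
old = f"\"[{', '.join(map(str, k))}]\""
new = f'"[{", ".join(map(str, k))}]"'
assert old == new == '"[1, 2, 3]"'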

keras/src/callbacks/swap_ema_weights_test.py (+1 -1)

@@ -53,7 +53,7 @@ def test_swap_ema_weights_with_invalid_optimizer(self):
         model = self._get_compiled_model(use_ema=False)
         with self.assertRaisesRegex(
             ValueError,
-            ("SwapEMAWeights must be used when " "`use_ema=True` is set"),
+            ("SwapEMAWeights must be used when `use_ema=True` is set"),
         ):
             model.fit(
                 self.x_train,

keras/src/export/export_utils.py (+1 -1)

@@ -65,7 +65,7 @@ def make_input_spec(x):
     if isinstance(x, layers.InputSpec):
         if x.shape is None or x.dtype is None:
             raise ValueError(
-                "The `shape` and `dtype` must be provided. " f"Received: x={x}"
+                f"The `shape` and `dtype` must be provided. Received: x={x}"
             )
         input_spec = x
     elif isinstance(x, backend.KerasTensor):

keras/src/export/onnx.py (+1 -2)

@@ -116,8 +116,7 @@ def _check_jax_kwargs(kwargs):
     }
     if kwargs["is_static"] is not True:
         raise ValueError(
-            "`is_static` must be `True` in `kwargs` when using the jax "
-            "backend."
+            "`is_static` must be `True` in `kwargs` when using the jax backend."
         )
     if kwargs["jax2tf_kwargs"]["enable_xla"] is not False:
         raise ValueError(

keras/src/layers/activations/leaky_relu.py (+1 -2)

@@ -39,8 +39,7 @@ def __init__(self, negative_slope=0.3, **kwargs):
         if "alpha" in kwargs:
             negative_slope = kwargs.pop("alpha")
             warnings.warn(
-                "Argument `alpha` is deprecated. "
-                "Use `negative_slope` instead."
+                "Argument `alpha` is deprecated. Use `negative_slope` instead."
             )
         super().__init__(**kwargs)
         if negative_slope is None or negative_slope < 0:

keras/src/layers/attention/grouped_query_attention.py (+1 -2)

@@ -112,8 +112,7 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads
         if num_query_heads % num_key_value_heads != 0:
             raise ValueError(
-                "`num_query_heads` must be divisible"
-                " by `num_key_value_heads`."
+                "`num_query_heads` must be divisible by `num_key_value_heads`."
             )
         self.num_repeats = num_query_heads // num_key_value_heads
         self.dropout = dropout

keras/src/layers/attention/multi_head_attention_test.py (+2 -2)

@@ -644,7 +644,7 @@ def test_multi_head_attention_output_shape_as_int(self):
         output = mha(query=query, value=value)

         assert output.shape == (2, 4, 8), (
-            f"Expected shape (2, 4, 8)," f" got {output.shape}"
+            f"Expected shape (2, 4, 8), got {output.shape}"
         )

     def test_multi_head_attention_output_shape_as_tuple(self):

@@ -657,7 +657,7 @@ def test_multi_head_attention_output_shape_as_tuple(self):
         output = mha(query=query, value=value)

         assert output.shape == (2, 4, 8, 8), (
-            f"Expected shape (2, 4, 8, 8)," f" got {output.shape}"
+            f"Expected shape (2, 4, 8, 8), got {output.shape}"
         )

     def test_multi_head_attention_output_shape_error(self):

keras/src/layers/convolutional/base_conv.py (+1 -2)

@@ -282,8 +282,7 @@ def enable_lora(
         )
         if self.lora_enabled:
             raise ValueError(
-                "lora is already enabled. "
-                "This can only be done once per layer."
+                "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
         self.lora_kernel_a = self.add_weight(

keras/src/layers/core/dense.py (+1 -2)

@@ -168,8 +168,7 @@ def enable_lora(
         )
         if self.lora_enabled:
             raise ValueError(
-                "lora is already enabled. "
-                "This can only be done once per layer."
+                "lora is already enabled. This can only be done once per layer."
            )
         self._tracker.unlock()
         self.lora_kernel_a = self.add_weight(

keras/src/layers/core/einsum_dense.py (+1 -2)

@@ -224,8 +224,7 @@ def enable_lora(
         )
         if self.lora_enabled:
             raise ValueError(
-                "lora is already enabled. "
-                "This can only be done once per layer."
+                "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
         self.lora_kernel_a = self.add_weight(

keras/src/layers/core/embedding.py (+1 -2)

@@ -163,8 +163,7 @@ def enable_lora(
         )
         if self.lora_enabled:
             raise ValueError(
-                "lora is already enabled. "
-                "This can only be done once per layer."
+                "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
         self.lora_embeddings_a = self.add_weight(

keras/src/layers/layer.py (+1 -2)

@@ -1430,8 +1430,7 @@ def _build_by_run_for_kwargs(self, shapes_dict):

     def __repr__(self):
         return (
-            f"<{self.__class__.__name__} "
-            f"name={self.name}, built={self.built}>"
+            f"<{self.__class__.__name__} name={self.name}, built={self.built}>"
         )

     def __str__(self):

keras/src/layers/normalization/layer_normalization_test.py (+1 -4)

@@ -87,10 +87,7 @@ def test_ln_basics(self):
     def test_invalid_axis(self):
         with self.assertRaisesRegex(
             TypeError,
-            (
-                "Expected an int or a list/tuple of ints for the argument "
-                "'axis'"
-            ),
+            ("Expected an int or a list/tuple of ints for the argument 'axis'"),
         ):
             layers.LayerNormalization(axis={"axis": -1})

keras/src/layers/preprocessing/hashed_crossing.py (+1 -1)

@@ -90,7 +90,7 @@ def __init__(
         super().__init__(name=name, dtype=dtype)
         if sparse and backend.backend() != "tensorflow":
             raise ValueError(
-                "`sparse=True` can only be used with the " "TensorFlow backend."
+                "`sparse=True` can only be used with the TensorFlow backend."
             )

         argument_validation.validate_string_arg(

keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py (+1 -2)

@@ -49,8 +49,7 @@ def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs):
         super().__init__(**kwargs)
         if factor < 0 or factor > 1:
             raise ValueError(
-                "`factor` should be between 0 and 1. "
-                f"Received: factor={factor}"
+                f"`factor` should be between 0 and 1. Received: factor={factor}"
             )
         self.factor = factor
         self.data_format = backend.standardize_data_format(data_format)

keras/src/layers/preprocessing/integer_lookup.py (+1 -1)

@@ -328,7 +328,7 @@ def __init__(
         )
         if sparse and backend.backend() != "tensorflow":
             raise ValueError(
-                "`sparse=True` can only be used with the " "TensorFlow backend."
+                "`sparse=True` can only be used with the TensorFlow backend."
             )
         if vocabulary_dtype != "int64":
             raise ValueError(

keras/src/layers/preprocessing/string_lookup.py (+1 -1)

@@ -314,7 +314,7 @@ def __init__(
         )
         if sparse and backend.backend() != "tensorflow":
             raise ValueError(
-                "`sparse=True` can only be used with the " "TensorFlow backend."
+                "`sparse=True` can only be used with the TensorFlow backend."
             )
         self.encoding = encoding
         super().__init__(

keras/src/layers/preprocessing/text_vectorization.py (+2 -2)

@@ -226,11 +226,11 @@ def __init__(
         )
         if sparse and backend.backend() != "tensorflow":
             raise ValueError(
-                "`sparse=True` can only be used with the " "TensorFlow backend."
+                "`sparse=True` can only be used with the TensorFlow backend."
             )
         if ragged and backend.backend() != "tensorflow":
             raise ValueError(
-                "`ragged=True` can only be used with the " "TensorFlow backend."
+                "`ragged=True` can only be used with the TensorFlow backend."
             )

         # 'standardize' must be one of

keras/src/metrics/metric.py (+1 -1)

@@ -247,7 +247,7 @@ def _check_super_called(self):
             )

     def __repr__(self):
-        return f"<{self.__class__.__name__} " f"name={self.name}>"
+        return f"<{self.__class__.__name__} name={self.name}>"

     def __str__(self):
         return self.__repr__()

keras/src/ops/image.py (+3 -6)

@@ -978,8 +978,7 @@ def _pad_images(
         )
     if left_padding < 0:
         raise ValueError(
-            "left_padding must be >= 0. "
-            f"Received: left_padding={left_padding}"
+            f"left_padding must be >= 0. Received: left_padding={left_padding}"
         )
     if right_padding < 0:
         raise ValueError(

@@ -1198,8 +1197,7 @@ def _crop_images(

     if top_cropping < 0:
         raise ValueError(
-            "top_cropping must be >= 0. "
-            f"Received: top_cropping={top_cropping}"
+            f"top_cropping must be >= 0. Received: top_cropping={top_cropping}"
         )
     if target_height < 0:
         raise ValueError(

@@ -1213,8 +1211,7 @@ def _crop_images(
         )
     if target_width < 0:
         raise ValueError(
-            "target_width must be >= 0. "
-            f"Received: target_width={target_width}"
+            f"target_width must be >= 0. Received: target_width={target_width}"
         )

     # Compute start_indices and shape

keras/src/ops/linalg.py (+3 -5)

@@ -596,12 +596,11 @@ def call(self, a, b):
     def compute_output_spec(self, a, b):
         if len(a.shape) != 2:
             raise ValueError(
-                "Expected a to have rank 2. " f"Received: a.shape={a.shape}"
+                f"Expected a to have rank 2. Received: a.shape={a.shape}"
             )
         if len(b.shape) not in (1, 2):
             raise ValueError(
-                "Expected b to have rank 1 or 2. "
-                f"Received: b.shape={b.shape}"
+                f"Expected b to have rank 1 or 2. Received: b.shape={b.shape}"
             )
         m, n = a.shape
         if b.shape[0] != m:

@@ -666,8 +665,7 @@ def _assert_1d(*arrays):
     for a in arrays:
         if a.ndim < 1:
             raise ValueError(
-                "Expected input to have rank >= 1. "
-                "Received scalar input {a}."
+                "Expected input to have rank >= 1. Received scalar input {a}."
             )


keras/src/ops/numpy.py (+1 -1)

@@ -6754,7 +6754,7 @@ class Argpartition(Operation):
     def __init__(self, kth, axis=-1):
         super().__init__()
         if not isinstance(kth, int):
-            raise ValueError("kth must be an integer. Received:" f"kth = {kth}")
+            raise ValueError(f"kth must be an integer. Received:kth = {kth}")
         self.kth = kth
         self.axis = axis

keras/src/random/seed_generator.py (+1 -1)

@@ -76,7 +76,7 @@ def __init__(self, seed=None, name=None, **kwargs):

         if not isinstance(seed, int):
             raise ValueError(
-                "Argument `seed` must be an integer. " f"Received: seed={seed}"
+                f"Argument `seed` must be an integer. Received: seed={seed}"
             )

         def seed_initializer(*args, **kwargs):

keras/src/saving/file_editor.py (+1 -1)

@@ -500,7 +500,7 @@ def _generate_metadata_info(self, rich_style=False):
         if rich_style:
             version = f"{summary_utils.highlight_symbol(version)}"
             date = f"{summary_utils.highlight_symbol(date)}"
-        return f"Saved with Keras {version} " f"- date: {date}"
+        return f"Saved with Keras {version} - date: {date}"

     def _print_weights_structure(
         self, weights_dict, indent=0, is_first=True, prefix="", inner_path=""
