Skip to content

Commit 03fe057

Browse files
authored
Ignore shard_map attr error in mypy. (#97)
1 parent 91422d5 commit 03fe057

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
3232
FeatureConfig = config.FeatureConfig
33-
shard_map = jax.experimental.shard_map.shard_map
33+
shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
3434

3535

3636
def _get_partition_spec(

keras_rs/src/layers/embedding/jax/embedding_lookup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from keras_rs.src.layers.embedding.jax import embedding_utils
1717

1818
ShardedCooMatrix = embedding_utils.ShardedCooMatrix
19-
shard_map = jax.experimental.shard_map.shard_map
19+
shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
2020

2121
T = TypeVar("T")
2222
Nested = Union[T, Sequence[T], Mapping[str, T]]

0 commit comments

Comments
 (0)