
Commit cd64b9d

connernilsen authored and facebook-github-bot committed
upgrade pyre version in fbcode/torchrec - batch 1 (#2516)
Summary:
Pull Request resolved: #2516

Reviewed By: PaulZhang12

Differential Revision: D64846836

fbshipit-source-id: 96f4563d1c3e1cddc92071db886c593fdd96f03e
1 parent 0b938db · commit cd64b9d

18 files changed: +65 -23 lines
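
This batch is largely mechanical: the upgraded Pyre reports pre-existing type errors at new locations (often on a function's signature rather than in its body), so suppression comments are added, moved, or deleted to match. As a quick reference for the comment syntax, a minimal sketch (illustrative, not code from this commit):

# pyre-fixme[7] suppresses error code 7 on the line that follows and marks
# it for eventual cleanup; pyre-ignore is the same mechanism minus the
# cleanup intent, and pyre-ignore-all-errors[N] applies file-wide.
# pyre-fixme[7]: Expected `int` but got implicit return value of `None`.
def todo() -> int:
    pass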

.pyre_configuration (+1 -1)

@@ -13,5 +13,5 @@
     }
   ],
   "strict": true,
-  "version": "0.0.101703592829"
+  "version": "0.0.101729681899"
 }

torchrec/distributed/batched_embedding_kernel.py (+22)

@@ -674,13 +674,24 @@ def __init__(
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

+        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
+        #  `ShardedEmbeddingTable`.
         for idx, config in enumerate(self._config.embedding_tables):
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
             self._local_rows.append(config.local_rows)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_min`.
             self._weight_init_mins.append(config.get_weight_init_min())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_max`.
             self._weight_init_maxs.append(config.get_weight_init_max())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `num_embeddings`.
             self._num_embeddings.append(config.num_embeddings)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
             self._local_cols.append(config.local_cols)
             self._feature_table_map.extend([idx] * config.num_features())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
             if config.name not in self.table_name_to_count:
                 self.table_name_to_count[config.name] = 0
             self.table_name_to_count[config.name] += 1

@@ -1080,13 +1091,24 @@ def __init__(
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

+        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
+        #  `ShardedEmbeddingTable`.
         for idx, config in enumerate(self._config.embedding_tables):
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
             self._local_rows.append(config.local_rows)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_min`.
             self._weight_init_mins.append(config.get_weight_init_min())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `get_weight_init_max`.
             self._weight_init_maxs.append(config.get_weight_init_max())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
+            #  `num_embeddings`.
             self._num_embeddings.append(config.num_embeddings)
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
             self._local_cols.append(config.local_cols)
             self._feature_table_map.extend([idx] * config.num_features())
+            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
             if config.name not in self.table_name_to_count:
                 self.table_name_to_count[config.name] = 0
             self.table_name_to_count[config.name] += 1

torchrec/distributed/benchmark/benchmark_utils.py (+2 -1)

@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
+# pyre-ignore-all-errors[16]

 #!/usr/bin/env python3

@@ -431,6 +432,7 @@ def transform_module(
     compile_mode: CompileMode,
     world_size: int,
     batch_size: int,
+    # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
     ctx: ContextManager,
     benchmark_unsharded_module: bool = False,
 ) -> torch.nn.Module:

@@ -1051,7 +1053,6 @@ def benchmark_module(
     for compile_mode in compile_modes:
         if not benchmark_unsharded:
             # Test sharders should have a singular sharding_type
-            # pyre-ignore [16]
             sharder._sharding_type = sharding_type.value
             # pyre-ignore [6]
             benchmark_type = benchmark_type_name(compile_mode, sharding_type)
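
Code [24] flags a generic type used without its parameters: current typeshed makes `typing.ContextManager` generic in the value it yields (and `contextlib.AbstractContextManager`, which appears later in torchrec/distributed/utils.py, generic in two parameters). A hedged sketch of the parameterized spelling the suppression sidesteps:

from contextlib import contextmanager
from typing import ContextManager, Iterator


@contextmanager
def _noop() -> Iterator[None]:
    yield  # manages no value


def run_in(ctx: ContextManager[None]) -> None:  # fully parameterized: no [24]
    with ctx:
        pass


run_in(_noop())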

torchrec/distributed/keyed_jagged_tensor_pool.py (+2 -2)

@@ -667,9 +667,9 @@ def _update_local(
     ) -> None:
         raise NotImplementedError("Inference does not support update")

+    # pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value of
+    #  `None`.
     def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor:
-        # pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value
-        #  of `None`.
         pass
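
Much of the remaining churn follows one pattern: a stub whose body is just `pass` implicitly returns `None`, conflicting with the annotated return type (code [7]), and the upgraded Pyre appears to report this on the signature rather than inside the body, so the suppressions move above the `def`. A minimal sketch of the pattern:

import torch


# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
def _update_preproc(values: torch.Tensor) -> torch.Tensor:
    pass  # stub: control falls off the end and returns None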

torchrec/distributed/object_pool.py (+5 -5)

@@ -131,16 +131,16 @@ def input_dist(
         *input,
         # pyre-ignore[2]
         **kwargs,
+        # pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit return
+        #  value of `None`.
     ) -> Awaitable[Awaitable[torch.Tensor]]:
-        # pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit
-        #  return value of `None`.
         pass

+    # pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
     def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
-        # pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
         pass

+    # pyre-fixme[7]: Expected `LazyAwaitable[Out]` but got implicit return value of
+    #  `None`.
     def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
-        # pyre-fixme[7]: Expected `LazyAwaitable[Variable[Out]]` but got implicit
-        #  return value of `None`.
         pass

torchrec/distributed/planner/tests/test_partitioners.py (+4)

@@ -773,6 +773,8 @@ def test_different_sharding_plan(self) -> None:
             for shard in sharding_option.shards:
                 if shard.storage and shard.rank is not None:
                     greedy_perf_hbm_uses[
+                        # pyre-fixme[6]: For 1st argument expected `SupportsIndex`
+                        #  but got `Optional[int]`.
                         shard.rank
                     ] += shard.storage.hbm  # pyre-ignore[16]

@@ -796,6 +798,8 @@ def test_different_sharding_plan(self) -> None:
         for sharding_option in sharding_options:
             for shard in sharding_option.shards:
                 if shard.storage and shard.rank:
+                    # pyre-fixme[6]: For 1st argument expected `SupportsIndex` but
+                    #  got `Optional[int]`.
                     memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm

         self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses))
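
Code [6] is an argument-type mismatch: `shard.rank` is `Optional[int]`, and Pyre does not reliably narrow an attribute access across statements, so using it as a list index is flagged even under a guard. Binding the attribute to a local first narrows cleanly; a sketch:

from typing import List, Optional

ranks: List[Optional[int]] = [0, None, 1]
hbm_uses: List[int] = [0, 0]

for rank in ranks:
    # Locals narrow under an explicit None check; an attribute like
    # `shard.rank` may not, which is what the suppressions paper over.
    if rank is not None:
        hbm_uses[rank] += 1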

torchrec/distributed/shards_wrapper.py (+2 -2)

@@ -27,8 +27,6 @@
 aten = torch.ops.aten  # pyre-ignore[5]


-# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
-# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
 class LocalShardsWrapper(torch.Tensor):
     """
     A wrapper class to hold local shards of a DTensor.

@@ -37,7 +35,9 @@ class LocalShardsWrapper(torch.Tensor):
     """

     __slots__ = ["_local_shards", "_storage_meta"]
+    # pyre-fixme[13]: Attribute `_local_shards` is never initialized.
     _local_shards: List[torch.Tensor]
+    # pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
     _storage_meta: TensorStorageMetadata

     @staticmethod
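
Code [13] means a declared attribute is never assigned in `__init__`; for a `torch.Tensor` subclass like this one, initialization typically happens in `__new__`, which Pyre does not credit, and the new version anchors the error on the declaration lines rather than the class line. An illustrative sketch (not the real class):

from typing import List


class Wrapper:
    __slots__ = ["_items"]
    # pyre-fixme[13]: Attribute `_items` is never initialized.
    _items: List[int]

    def __new__(cls, items: List[int]) -> "Wrapper":
        obj = super().__new__(cls)
        obj._items = items  # initialized here, outside __init__
        return obj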

torchrec/distributed/tensor_pool.py (+1 -1)

@@ -459,8 +459,8 @@ def _update_local(
         deduped_ids, dedup_permutation = deterministic_dedup(ids)
         shard.update(deduped_ids, values[dedup_permutation])

+    # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     def _update_preproc(self, values: torch.Tensor) -> torch.Tensor:
-        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
         pass

     def update(self, ids: torch.Tensor, values: torch.Tensor) -> None:

torchrec/distributed/tests/test_awaitable.py (+4)

@@ -24,13 +24,17 @@ def _wait_impl(self) -> torch.Tensor:
 class AwaitableTests(unittest.TestCase):
     def test_callback(self) -> None:
         awaitable = AwaitableInstance()
+        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
+        #  `(ret: Any) -> int`.
         awaitable.callbacks.append(lambda ret: 2 * ret)
         self.assertTrue(
             torch.allclose(awaitable.wait(), torch.FloatTensor([2.0, 4.0, 6.0]))
         )

     def test_callback_chained(self) -> None:
         awaitable = AwaitableInstance()
+        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
+        #  `(ret: Any) -> int`.
         awaitable.callbacks.append(lambda ret: 2 * ret)
         awaitable.callbacks.append(lambda ret: ret**2)
         self.assertTrue(
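
Here [6] fires because the callbacks list expects `Callable[[Tensor], Tensor]` while Pyre infers the unannotated lambda's `2 * ret` as returning `int`. Lambdas cannot carry parameter annotations, so the non-suppression fix would be a small named function; a sketch:

from typing import Callable, List

import torch

callbacks: List[Callable[[torch.Tensor], torch.Tensor]] = []


def double(ret: torch.Tensor) -> torch.Tensor:
    return 2 * ret  # Tensor in, Tensor out: matches the declared element type


callbacks.append(double)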

torchrec/distributed/tests/test_embedding_types.py (+6 -7)

@@ -29,9 +29,9 @@ def __init__(self) -> None:
             torch.nn.Module(),
         ]

+    # pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
+    #  return value of `None`.
     def create_context(self) -> ShrdCtx:
-        # pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
-        #  return value of `None`.
         pass

     def input_dist(

@@ -41,19 +41,18 @@ def input_dist(
         *input,
         # pyre-ignore[2]
         **kwargs,
-    ) -> Awaitable[Awaitable[CompIn]]:
         # pyre-fixme[7]: Expected `Awaitable[Awaitable[KJTList]]` but got implicit
         #  return value of `None`.
+    ) -> Awaitable[Awaitable[CompIn]]:
         pass

+    # pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of `None`.
     def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
-        # pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of
-        #  `None`.
         pass

+    # pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got implicit
+    #  return value of `None`.
     def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
-        # pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got
-        #  implicit return value of `None`.
         pass

torchrec/distributed/tests/test_lazy_awaitable.py (-2)

@@ -244,8 +244,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

     tempFile = None
     with tempfile.NamedTemporaryFile(delete=False) as f:
-        # pyre-fixme[6]: For 2nd argument expected `SupportsWrite[bytes]` but
-        #  got `_TemporaryFileWrapper[bytes]`.
         pickle.dump(gm, f)
         tempFile = f

torchrec/distributed/train_pipeline/train_pipelines.py (+1)

@@ -1613,6 +1613,7 @@ def __init__(

     def get_compiled_autograd_ctx(
         self,
+        # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
     ) -> ContextManager:
         # this allows for pipelining
         # to avoid doing a sum on None

torchrec/distributed/types.py (+3 -1)

@@ -44,7 +44,6 @@
     # other metaclasses (i.e. AwaitableMeta) for customized
    # behaviors, as Generic is non-trival metaclass in
     # python 3.6 and below
-    # pyre-fixme[21]: Could not find name `GenericMeta` in `typing` (stubbed).
     from typing import GenericMeta
 except ImportError:
     # In python 3.7+, GenericMeta doesn't exist as it's no

@@ -975,6 +974,9 @@ def __init__(
         torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
         self._qcomm_codecs_registry = qcomm_codecs_registry

+    # pyre-fixme[56]: Pyre doesn't yet support decorators with ParamSpec applied to
+    #  generic functions. Consider using a context manager instead of a decorator, if
+    #  possible.
     @abc.abstractclassmethod
     # pyre-ignore [3]
     def shard(
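
Two notes on this file: the [21] suppression on the `GenericMeta` import is simply dropped, presumably because the upgraded Pyre no longer reports it there, and [56] marks a decorator Pyre cannot analyze. `abc.abstractclassmethod` has also been deprecated since Python 3.3 in favor of stacking the two decorators, which a future cleanup could adopt; a sketch:

import abc


class Sharder(abc.ABC):
    @classmethod
    @abc.abstractmethod
    def shard(cls) -> None:
        ...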

torchrec/distributed/utils.py (+5 -1)

@@ -476,7 +476,11 @@ def maybe_reset_parameters(m: nn.Module) -> None:


 def maybe_annotate_embedding_event(
-    event: EmbeddingEvent, module_fqn: Optional[str], sharding_type: Optional[str]
+    event: EmbeddingEvent,
+    module_fqn: Optional[str],
+    sharding_type: Optional[str],
+    # pyre-fixme[24]: Generic type `AbstractContextManager` expects 2 type parameters,
+    #  received 1.
 ) -> AbstractContextManager[None]:
     if module_fqn and sharding_type:
         annotation = f"[{event.value}]_[{module_fqn}]_[{sharding_type}]"

torchrec/inference/inference_legacy/__init__.py (+2)

@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-ignore-all-errors[0, 21]
+
 """Torchrec Inference

 Torchrec inference provides a Torch.Deploy based library for GPU inference.

torchrec/linter/module_linter.py (+2)

@@ -37,7 +37,9 @@ def print_error_message(
     """
     lint_item = {
         "path": python_path,
+        # pyre-fixme[16]: `AST` has no attribute `lineno`.
         "line": node.lineno,
+        # pyre-fixme[16]: `AST` has no attribute `col_offset`.
         "char": node.col_offset + 1,
         "severity": severity,
         "name": name,

torchrec/metrics/tests/test_metric_module.py (+1)

@@ -528,6 +528,7 @@ def _test_adjust_compute_interval(
         )
         mock_time.time = MagicMock(return_value=0.0)

+        # pyre-fixme[53]: Captured variable `batch` is not annotated.
         def _train(metric_module: RecMetricModule) -> float:
             for _ in range(metric_module.compute_interval_steps):
                 metric_module.update(batch)
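
Code [53] warns when a nested function captures an unannotated variable from the enclosing scope. Annotating the captured name in the enclosing function is the non-suppression fix; a sketch:

from typing import List


def outer() -> int:
    batch: List[int] = [1, 2, 3]  # annotated, so the capture below is typed

    def _train() -> int:
        return sum(batch)  # `batch` is captured; [53] fires if unannotated

    return _train()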

torchrec/modules/utils.py (+2)

@@ -133,6 +133,8 @@ def convert_list_of_modules_to_modulelist(
         #  `Iterable[torch.nn.Module]`.
         len(modules)
         == sizes[0]
+        # pyre-fixme[6]: For 1st argument expected `pyre_extensions.PyreReadOnly[Sized]`
+        #  but got `Iterable[Module]`.
     ), f"the counts of modules ({len(modules)}) do not match with the required counts {sizes}"
     if len(sizes) == 1:
         return torch.nn.ModuleList(modules)
