Commit fd4496d

Chenyu Zhang authored and facebook-github-bot committed
kvzch use new operator in model publish (#3108)
Summary:
Pull Request resolved: #3108

Publish change to enable KVEmbeddingInference when use_virtual_table is set to true.

Reviewed By: emlin

Differential Revision: D75321284

fbshipit-source-id: 07f128b9ed8fc024a267b661f18766ed6609e374
1 parent d7c7098 commit fd4496d

File tree

3 files changed: +82 -40 lines changed
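
Note on the flag this change keys off: a table opts in by setting use_virtual_table on its embedding config, and any quantized module that groups such a table then builds a KVEmbeddingInference kernel instead of IntNBitTableBatchedEmbeddingBagsCodegen. The sketch below is hypothetical and not part of this commit; it assumes a torchrec build where EmbeddingBagConfig exposes the use_virtual_table field that the diffs below read.

# Hypothetical usage sketch, not from this commit. Assumes EmbeddingBagConfig
# exposes the use_virtual_table field read by the quantized modules below.
from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType

virtual_table = EmbeddingBagConfig(
    name="t1",
    embedding_dim=16,
    num_embeddings=10_000_000,  # large KV-backed table
    feature_names=["f1"],
    pooling=PoolingType.SUM,
    use_virtual_table=True,  # routes this table's group to KVEmbeddingInference
)

regular_table = EmbeddingBagConfig(
    name="t2",
    embedding_dim=16,
    num_embeddings=1_000,
    feature_names=["f2"],
    pooling=PoolingType.SUM,
    # use_virtual_table defaults to False -> IntNBitTableBatchedEmbeddingBagsCodegen
)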

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 42 additions & 30 deletions
@@ -20,6 +20,7 @@
     PoolingMode,
     rounded_row_size_in_bytes,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
     BaseBatchedEmbeddingBag,
@@ -237,13 +238,16 @@ def __init__(
         super().__init__(config, pg, device)
 
         managed: List[EmbeddingLocation] = []
+        is_virtual_table: bool = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._is_weighted: Optional[bool] = config.is_weighted
@@ -284,6 +288,8 @@ def __init__(
 
         if self.lengths_to_tbe:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
+        elif is_virtual_table:
+            tbe_clazz = KVEmbeddingInference
         else:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen
 
@@ -448,13 +454,16 @@ def __init__(
         super().__init__(config, pg, device)
 
         managed: List[EmbeddingLocation] = []
+        is_virtual_table = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -465,37 +474,40 @@ def __init__(
         )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
-            IntNBitTableBatchedEmbeddingBagsCodegen(
-                embedding_specs=[
+        embedding_clazz = (
+            KVEmbeddingInference
+            if is_virtual_table
+            else IntNBitTableBatchedEmbeddingBagsCodegen
+        )
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
+            embedding_specs=[
+                (
+                    table.name,
+                    local_rows,
                     (
-                        table.name,
-                        local_rows,
-                        (
-                            local_cols
-                            if self._quant_state_dict_split_scale_bias
-                            else table.embedding_dim
-                        ),
-                        data_type_to_sparse_type(table.data_type),
-                        location,
-                    )
-                    for local_rows, local_cols, table, location in zip(
-                        self._local_rows,
-                        self._local_cols,
-                        config.embedding_tables,
-                        managed,
-                    )
-                ],
-                device=device,
-                pooling_mode=PoolingMode.NONE,
-                feature_table_map=self._feature_table_map,
-                row_alignment=self._tbe_row_alignment,
-                uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-                feature_names_per_table=[
-                    table.feature_names for table in config.embedding_tables
-                ],
-                **(tbe_fused_params(fused_params) or {}),
-            )
+                        local_cols
+                        if self._quant_state_dict_split_scale_bias
+                        else table.embedding_dim
+                    ),
+                    data_type_to_sparse_type(table.data_type),
+                    location,
+                )
+                for local_rows, local_cols, table, location in zip(
+                    self._local_rows,
+                    self._local_cols,
+                    config.embedding_tables,
+                    managed,
+                )
+            ],
+            device=device,
+            pooling_mode=PoolingMode.NONE,
+            feature_table_map=self._feature_table_map,
+            row_alignment=self._tbe_row_alignment,
+            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            feature_names_per_table=[
+                table.feature_names for table in config.embedding_tables
+            ],
+            **(tbe_fused_params(fused_params) or {}),
         )
         if device is not None:
             self._emb_module.initialize_weights()
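
Both constructors above compute an is_virtual_table flag over the grouped tables and use it to pick the TBE class: a single virtual table in the group routes the whole group to KVEmbeddingInference. A simplified, self-contained sketch of that selection logic (illustrative only; _Table is a hypothetical stand-in for the real table config objects):

# Simplified sketch of the kernel selection added above (illustrative only).
from dataclasses import dataclass
from typing import List

@dataclass
class _Table:  # hypothetical minimal stand-in for an embedding table config
    name: str
    use_virtual_table: bool = False

def pick_tbe_class(tables: List[_Table], lengths_to_tbe: bool = False) -> str:
    # Mirrors the tbe_clazz selection in the first constructor above.
    is_virtual_table = any(t.use_virtual_table for t in tables)
    if lengths_to_tbe:
        return "IntNBitTableBatchedEmbeddingBagsCodegenWithLength"
    elif is_virtual_table:
        return "KVEmbeddingInference"
    return "IntNBitTableBatchedEmbeddingBagsCodegen"

assert pick_tbe_class([_Table("t1", use_virtual_table=True)]) == "KVEmbeddingInference"
assert pick_tbe_class([_Table("t1"), _Table("t2")]) == "IntNBitTableBatchedEmbeddingBagsCodegen"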

torchrec/quant/embedding_modules.py

Lines changed: 16 additions & 10 deletions
@@ -30,6 +30,7 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
     PoolingMode,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torch import Tensor
 from torchrec.distributed.utils import none_throws
 from torchrec.modules.embedding_configs import (
@@ -357,7 +358,7 @@ def __init__(
         self._is_weighted = is_weighted
         self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
         self._key_to_tables: Dict[
-            Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig]
+            Tuple[PoolingType, bool], List[EmbeddingBagConfig]
         ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
@@ -383,15 +384,13 @@ def __init__(
                 key = (table.pooling, table.use_virtual_table)
             else:
                 key = (table.pooling, False)
-            # pyre-ignore
             self._key_to_tables[key].append(table)
 
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )
 
-        for key, emb_configs in self._key_to_tables.items():
-            pooling = key[0]
+        for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -420,7 +419,12 @@ def __init__(
                 )
                 feature_table_map.extend([idx] * table.num_features())
 
-            emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_clazz = (
+                KVEmbeddingInference
+                if use_virtual_table
+                else IntNBitTableBatchedEmbeddingBagsCodegen
+            )
+            emb_module = embedding_clazz(
                 embedding_specs=embedding_specs,
                 pooling_mode=pooling_type_to_pooling_mode(pooling),
                 weight_lists=weight_lists,
@@ -790,8 +794,7 @@ def __init__( # noqa C901
                 key = (table.data_type, False)
             self._key_to_tables[key].append(table)
         self._feature_splits: List[int] = []
-        for key, emb_configs in self._key_to_tables.items():
-            data_type = key[0]
+        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -816,10 +819,13 @@ def __init__( # noqa C901
                     table_name_to_quantized_weights[table.name]
                 )
                 feature_table_map.extend([idx] * table.num_features())
-                # move to here to make sure feature_names order is consistent with the embedding groups
                 self._feature_names.extend(table.feature_names)
-
-            emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_clazz = (
+                KVEmbeddingInference
+                if use_virtual_table
+                else IntNBitTableBatchedEmbeddingBagsCodegen
+            )
+            emb_module = embedding_clazz(
                 embedding_specs=embedding_specs,
                 pooling_mode=PoolingMode.NONE,
                 weight_lists=weight_lists,
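
The key change above drops DataType from the grouping key and unpacks (pooling, use_virtual_table) directly in the loop, so each group of tables can build its own TBE class. A small, runnable sketch of that grouping pattern, using a hypothetical _Cfg stand-in for EmbeddingBagConfig and plain strings in place of PoolingType:

# Sketch of the new grouping key (illustrative; _Cfg stands in for EmbeddingBagConfig).
from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, List, Tuple

@dataclass
class _Cfg:
    name: str
    pooling: str  # stands in for torchrec's PoolingType
    use_virtual_table: bool = False

key_to_tables: DefaultDict[Tuple[str, bool], List[_Cfg]] = defaultdict(list)
for table in [_Cfg("t1", "SUM"), _Cfg("t2", "SUM", use_virtual_table=True), _Cfg("t3", "MEAN")]:
    key_to_tables[(table.pooling, table.use_virtual_table)].append(table)

# One TBE module is built per group; a True flag routes the group to KVEmbeddingInference.
for (pooling, use_virtual_table), cfgs in key_to_tables.items():
    print(pooling, use_virtual_table, [c.name for c in cfgs])
# ('SUM', False) -> ['t1'], ('SUM', True) -> ['t2'], ('MEAN', False) -> ['t3']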

torchrec/quant/tests/test_embedding_modules.py

Lines changed: 24 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import logging
 import unittest
 from dataclasses import replace
 from typing import Dict, List, Optional, Type
@@ -44,6 +45,19 @@
     KeyedTensor,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def load_required_dram_kv_embedding_libraries() -> bool:
+    try:
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference"
+        )
+        return True
+    except Exception as e:
+        logger.error(f"Failed to load dram_kv_embedding libraries, skipping test: {e}")
+        return False
+
 
 class EmbeddingBagCollectionTest(unittest.TestCase):
     def _asserting_same_embeddings(
@@ -260,6 +274,11 @@ def test_multiple_features(self) -> None:
         )
         self._test_ebc([eb1_config, eb2_config], features)
 
+    # pyre-ignore: Invalid decoration [56]
+    @unittest.skipIf(
+        not load_required_dram_kv_embedding_libraries(),
+        "Skip when required libraries are not available",
+    )
     def test_multiple_kernels_per_ebc_table(self) -> None:
         class TestModule(torch.nn.Module):
             def __init__(self, m: torch.nn.Module) -> None:
@@ -780,6 +799,11 @@ def __init__(self, m: torch.nn.Module) -> None:
         self.assertEqual(config.name, "t2")
         self.assertEqual(config.data_type, DataType.INT8)
 
+    # pyre-ignore: Invalid decoration [56]
+    @unittest.skipIf(
+        not load_required_dram_kv_embedding_libraries(),
+        "Skip when required libraries are not available",
+    )
     def test_multiple_kernels_per_ec_table(self) -> None:
         class TestModule(torch.nn.Module):
             def __init__(self, m: torch.nn.Module) -> None:
