Skip to content

Commit 3535601

Browse files
zhuzilin authored and facebook-github-bot committed
Migrate naive id transformer (pytorch#661)
Summary: This PR migrates `NaiveIDTransformer` to the top-level folder. The idea behind the naive transformer is simple: a hash map (`ska::flat_hash_map`) stores the mapping from global id to cache id, and a bitmap records the empty slots in the cache. cc reyoung Pull Request resolved: pytorch#661 Reviewed By: zyan0 Differential Revision: D39857301 Pulled By: colin2328 fbshipit-source-id: a7681826827d84d51dcf5d91b3a0eca7583f391e
1 parent f206180 commit 3535601

File tree

14 files changed

+461
-5
lines changed

14 files changed

+461
-5
lines changed

.github/workflows/unittest_ci_cpu.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ jobs:
121121
working-directory: ./cpp-build
122122
shell: bash
123123
run: |
124-
conda run -n build_binary cmake -DBUILD_TEST=ON -DCMAKE_CXX_COMPILER=$(which x86_64-conda-linux-gnu-g++) ..
124+
conda run -n build_binary cmake \
125+
-DBUILD_TEST=ON \
126+
-DCMAKE_PREFIX_PATH=/home/ec2-user/miniconda/envs/build_binary/lib/python3.7/site-packages/torch/share/cmake ..
125127
conda run -n build_binary make -j
126128
- name: Test
127129
working-directory: ./cpp-build

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@ cmake_minimum_required(VERSION 3.11.0 FATAL_ERROR)
99
project(TorchRec
1010
LANGUAGES CXX C)
1111

12+
find_package(Torch REQUIRED)
13+
1214
set(CMAKE_CXX_STANDARD 20)
1315

1416
include(FetchContent)
1517

1618
option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" OFF)
1719

20+
add_definitions("-D_GLIBCXX_USE_CXX11_ABI=0")
21+
1822
add_subdirectory(torchrec/csrc)
1923

2024
if (BUILD_TEST)
@@ -34,5 +38,6 @@ if (BUILD_TEST)
3438
FetchContent_MakeAvailable(googletest google_benchmark)
3539

3640
enable_testing()
41+
add_subdirectory(benchmarks/cpp)
3742
add_subdirectory(test/cpp)
3843
endif()

benchmarks/cpp/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Descend into the dynamic_embedding subdirectory and process its
# CMakeLists.txt.
add_subdirectory(dynamic_embedding)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# add_tde_benchmark(NAME <sources...>)
#
# Declares an executable target NAME built from the remaining arguments (ARGN)
# and links it against tde_cpp_objs plus the benchmark::benchmark_main and
# benchmark::benchmark targets.
function(add_tde_benchmark NAME)
8+
add_executable(${NAME} ${ARGN})
9+
target_link_libraries(${NAME} tde_cpp_objs benchmark::benchmark_main benchmark::benchmark)
10+
endfunction()
11+
12+
# Register the NaiveIDTransformer benchmark binary.
add_tde_benchmark(naive_id_transformer_benchmark naive_id_transformer_benchmark.cpp)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <benchmark/benchmark.h>
10+
#include <torch/torch.h>
11+
#include <torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h>
12+
13+
namespace torchrec {
14+
15+
/**
 * Benchmarks NaiveIDTransformer::transform on a 1024 x 1024 tensor of int64
 * global ids.
 *
 * Benchmark arguments:
 *   range(0) - lower bound passed to Tensor::random_ when refilling the ids.
 *   range(1) - upper bound passed to Tensor::random_ when refilling the ids.
 *
 * Refilling the input happens while timing is paused, so only the call to
 * transform() is measured.
 */
static void BM_NaiveIDTransformer(benchmark::State& state) {
  using Tag = int32_t;
  NaiveIDTransformer<Tag> transformer(2e8);
  torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
  torch::Tensor cache_ids = torch::empty_like(global_ids);

  for (auto _ : state) {
    // Draw a fresh batch of ids for this iteration; excluded from the timing.
    state.PauseTiming();
    const auto rand_from = state.range(0);
    const auto rand_to = state.range(1);
    global_ids.random_(rand_from, rand_to);
    state.ResumeTiming();

    // View both tensors as flat int64 spans for the transformer.
    int64_t* const src_ptr = global_ids.template data_ptr<int64_t>();
    const size_t src_len = static_cast<size_t>(global_ids.numel());
    int64_t* const dst_ptr = cache_ids.template data_ptr<int64_t>();
    const size_t dst_len = static_cast<size_t>(cache_ids.numel());

    transformer.transform(
        std::span{src_ptr, src_len}, std::span{dst_ptr, dst_len});
  }
}
33+
34+
// Run 100 iterations per configuration and report in milliseconds. Two id
// ranges are exercised: [1e10, 2e10) and [1e6, 2e6), written as exact integer
// literals.
BENCHMARK(BM_NaiveIDTransformer)
    ->Iterations(100)
    ->Unit(benchmark::kMillisecond)
    ->ArgNames({"rand_from", "rand_to"})
    ->Args({10'000'000'000LL, 20'000'000'000LL})
    ->Args({1'000'000LL, 2'000'000LL});
40+
41+
} // namespace torchrec

test/cpp/dynamic_embedding/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ function(add_tde_test NAME)
1111
endfunction()
1212

1313
add_tde_test(bits_op_test bits_op_test.cpp)
14+
add_tde_test(naive_id_transformer_test naive_id_transformer_test.cpp)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
#include <torchrec/csrc/dynamic_embedding/details/naive_id_transformer.h>
11+
12+
namespace torchrec {
13+
14+
// With capacity to spare, every global id gets a cache id: the first
// occurrence of each id receives the next cache id counting up from 0, and a
// repeated global id maps to the cache id it was given earlier.
TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
  using Tag = int32_t;
  NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(16);

  const int64_t global_ids[5] = {100, 101, 100, 102, 101};
  int64_t cache_ids[5];
  ASSERT_TRUE(transformer.transform(global_ids, cache_ids));

  const int64_t expected_cache_ids[5] = {0, 1, 0, 2, 1};
  size_t idx = 0;
  for (const int64_t expected : expected_cache_ids) {
    ASSERT_EQ(expected, cache_ids[idx]);
    ++idx;
  }
}
25+
26+
// Five distinct ids do not fit into a transformer of capacity 4, so
// transform() must report failure. Only the first four outputs are compared;
// the trailing -1 in the expectation table is not checked.
TEST(tde, NaiveThreadedIDTransformer_Full) {
  using Tag = int32_t;
  NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(4);

  const int64_t global_ids[5] = {100, 101, 102, 103, 104};
  int64_t expected_cache_ids[5] = {0, 1, 2, 3, -1};
  int64_t cache_ids[5];

  ASSERT_FALSE(transformer.transform(global_ids, cache_ids));

  size_t idx = 0;
  while (idx < 4) {
    EXPECT_EQ(expected_cache_ids[idx], cache_ids[idx]);
    ++idx;
  }
}
38+
39+
// Overfill a capacity-4 transformer, evict two ids, then check that the freed
// cache slots are handed out again and that surviving ids keep their slots.
TEST(tde, NaiveThreadedIDTransformer_Evict) {
  using Tag = int32_t;
  NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(4);

  // Step 1: the fifth distinct id cannot be placed, so the call fails.
  const int64_t global_ids[5] = {100, 101, 102, 103, 104};
  int64_t cache_ids[5];
  ASSERT_FALSE(transformer.transform(global_ids, cache_ids));

  // Step 2: release the slots held by ids 100 and 102.
  const int64_t evict_global_ids[2] = {100, 102};
  transformer.evict(evict_global_ids);

  // Step 3: all four ids now fit; 102 and 104 take the freed slots.
  const int64_t new_global_ids[4] = {101, 102, 103, 104};
  int64_t new_cache_ids[4];
  ASSERT_TRUE(transformer.transform(new_global_ids, new_cache_ids));

  int64_t expected_cache_ids[4] = {1, 0, 3, 2};
  size_t idx = 0;
  for (const int64_t expected : expected_cache_ids) {
    EXPECT_EQ(expected, new_cache_ids[idx]);
    ++idx;
  }
}
61+
62+
// After transforming a batch that contains three distinct global ids, the
// callable returned by iterator() must yield a value exactly three times and
// then report exhaustion with an empty optional.
TEST(tde, NaiveThreadedIDTransformer_Iterator) {
  using Tag = int32_t;
  NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(16);
  const int64_t global_ids[5] = {100, 101, 100, 102, 101};
  int64_t cache_ids[5];
  ASSERT_TRUE(transformer.transform(global_ids, cache_ids));

  auto iterator = transformer.iterator();
  // One entry per distinct global id: 100, 101 and 102.
  for (size_t i = 0; i < 3; i++) {
    EXPECT_TRUE(iterator().has_value());
  }
  // A fourth call must signal that nothing is left.
  EXPECT_FALSE(iterator().has_value());
}
76+
77+
} // namespace torchrec

torchrec/csrc/dynamic_embedding/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ add_library(tde_cpp_objs
1010
details/ctz_impl.cpp)
1111

1212
target_include_directories(tde_cpp_objs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
13+
target_include_directories(tde_cpp_objs PUBLIC ${TORCH_INCLUDE_DIRS})
14+
target_link_libraries(tde_cpp_objs PUBLIC ${TORCH_LIBRARIES})
15+
target_compile_options(tde_cpp_objs PUBLIC -fPIC)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <stdint.h>

#include <algorithm>
#include <memory>
11+
12+
namespace torchrec {
13+
14+
/**
 * Bitmap
 *
 * Records which of `num_bits` slots are occupied or free, packing one bit per
 * slot into words of type `T`. A set bit marks a free slot; a cleared bit
 * marks an occupied one.
 */
template <typename T = uint32_t>
struct Bitmap {
  /**
   * Creates a bitmap covering `num_bits` slots, all of them free.
   */
  explicit Bitmap(int64_t num_bits);
  Bitmap(const Bitmap&) = delete;
  Bitmap(Bitmap&&) noexcept = default;
  // The const members below already make assignment impossible; spell that
  // out so the full set of special members is explicit.
  Bitmap& operator=(const Bitmap&) = delete;
  Bitmap& operator=(Bitmap&&) = delete;

  /**
   * Claims the lowest free slot: returns its position and marks it as
   * occupied.
   * When the bitmap is full the intended result is `num_total_bits_`;
   * callers should check `full()` before relying on the returned position.
   */
  int64_t next_free_bit();

  /**
   * Sets the slot at position `offset` to free.
   */
  void free_bit(int64_t offset);

  /**
   * Returns whether all slots in the bitmap are occupied.
   */
  bool full() const;

  /** Number of slots tracked by a single word of type `T`. */
  static constexpr int64_t num_bits_per_value = sizeof(T) * 8;

  /** Number of slots requested at construction. */
  const int64_t num_total_bits_;
  /** Number of `T` words needed to hold `num_total_bits_` bits. */
  const int64_t num_values_;
  /** Backing storage: `num_values_` words, one bit per slot. */
  std::unique_ptr<T[]> values_;

  /**
   * Position of the lowest free slot; greater than or equal to
   * `num_total_bits_` when no slot is free.
   */
  int64_t next_free_bit_;
};
50+
51+
} // namespace torchrec
52+
53+
#include <torchrec/csrc/dynamic_embedding/details/bitmap_impl.h>
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <stdint.h>
11+
#include <torchrec/csrc/dynamic_embedding/details/bits_op.h>
12+
13+
namespace torchrec {
14+
15+
template <typename T>
16+
inline Bitmap<T>::Bitmap(int64_t num_bits)
17+
: num_total_bits_(num_bits),
18+
num_values_((num_bits + num_bits_per_value - 1) / num_bits_per_value),
19+
values_(new T[num_values_]),
20+
next_free_bit_(0) {
21+
std::fill(values_.get(), values_.get() + num_values_, -1);
22+
}
23+
24+
template <typename T>
25+
inline int64_t Bitmap<T>::next_free_bit() {
26+
int64_t result = next_free_bit_;
27+
int64_t offset = result / num_bits_per_value;
28+
T value = values_[offset];
29+
// set the last 1 bit to zero
30+
values_[offset] = value & (value - 1);
31+
while (values_[offset] == 0 && offset < num_values_) {
32+
offset++;
33+
}
34+
value = values_[offset];
35+
if (C10_LIKELY(value)) {
36+
next_free_bit_ = offset * num_bits_per_value + ctz(value);
37+
} else {
38+
next_free_bit_ = num_total_bits_;
39+
}
40+
41+
return result;
42+
}
43+
44+
template <typename T>
45+
inline void Bitmap<T>::free_bit(int64_t offset) {
46+
int64_t mask_offset = offset / num_bits_per_value;
47+
int64_t bit_offset = offset % num_bits_per_value;
48+
values_[mask_offset] |= 1 << bit_offset;
49+
next_free_bit_ = std::min(offset, next_free_bit_);
50+
}
51+
template <typename T>
52+
inline bool Bitmap<T>::full() const {
53+
return next_free_bit_ >= num_total_bits_;
54+
}
55+
56+
} // namespace torchrec

torchrec/csrc/dynamic_embedding/details/clz_impl.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ template <typename T>
2020
struct ClzImpl {
2121
/**
2222
* Naive implementation for no __builtin_clz
23-
* @param v
24-
* @return
2523
*/
2624
int operator()(T v) const {
2725
int result = 0;

torchrec/csrc/dynamic_embedding/details/ctz_impl.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ template <typename T>
1414
struct CtzImpl {
1515
/**
1616
* Naive implementation for no __builtin_ctz
17-
* @param v
18-
* @return
1917
*/
2018
int operator()(T v) const {
2119
if (v == 0)

0 commit comments

Comments
 (0)