Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit e80e698

Browse files
Add Kronecker benchmark
This commit adds a Kronecker decomposition benchmark and saves the best options found by the tuner for future reproducibility. Since Kronecker Recurrent Units are meant to replace large fully connected (FC) layers, we compare against that baseline.
1 parent 0dddfce commit e80e698

File tree

7 files changed

+1213
-13
lines changed

7 files changed

+1213
-13
lines changed

.jenkins/build.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,8 @@ WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREF
6969
python setup.py install
7070
./test_python/run_test.sh
7171

72-
FILTER_OUT=MLP_model ./test.sh
72+
FILTER_OUT="MLP_model kronecker" ./test.sh
73+
# 2LUT can OOM on smaller Maxwells on our CI machines
7374
./build/tc/benchmarks/MLP_model --gtest_filter=-*2LUT*
75+
# Kronecker xxxAsMatMul can OOM
76+
./build/tc/benchmarks/kronecker --gtest_filter=-*AsMatMul*

tc/benchmarks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ find_library(CUDA_CUDNN_LIBRARIES cudnn
1818
set(EXAMPLES_FILES
1919
batchmatmul
2020
group_convolution
21+
kronecker
2122
tmm
2223
MLP_model
2324
)

tc/benchmarks/benchmark_fixture.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ struct Benchmark : public ::testing::Test {
9696
}
9797
}
9898

99-
template <typename CheckFunction>
100-
void Check(
99+
using CheckFunction = std::function<bool(
100+
const std::vector<at::Tensor>& inputs,
101+
const std::vector<at::Tensor>& outputs)>;
102+
std::vector<at::Tensor> Check(
101103
const std::string& tc,
102104
const std::string& name,
103105
const tc::CudaMappingOptions& mappingOptions,
@@ -184,6 +186,8 @@ struct Benchmark : public ::testing::Test {
184186
std::cout << "\n\n";
185187

186188
#undef GET_US
189+
190+
return outputs;
187191
}
188192

189193
template <typename InitFunction, typename InplaceFunction>
@@ -230,7 +234,6 @@ struct Benchmark : public ::testing::Test {
230234
#undef GET_US
231235
}
232236

233-
template <typename CheckFunction>
234237
void validateProto(
235238
std::string cacheFilename,
236239
const std::string& tc,
@@ -342,8 +345,7 @@ struct Benchmark : public ::testing::Test {
342345
#undef GET_US
343346
}
344347

345-
template <typename CheckFunction>
346-
void autotune(
348+
std::vector<tc::CudaMappingOptions> autotune(
347349
std::string cacheFilename,
348350
std::string resultsFilename,
349351
std::string tc,
@@ -442,6 +444,10 @@ struct Benchmark : public ::testing::Test {
442444
<< "\n---------------------------------------------------------";
443445
std::cout << "\n\n";
444446
#undef GET_US
447+
448+
return {bestOptions};
445449
}
450+
451+
return {};
446452
}
447453
};

0 commit comments

Comments
 (0)