Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit a6d044a

Browse files
authored
Add Axpy_batch implementation (#479)
Added axpy_batch extension BLAS operator with benchmarks.
1 parent ece6336 commit a6d044a

23 files changed

+1099
-4
lines changed

benchmark/portblas/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ set(sources
6969
extension/omatcopy_batched.cpp
7070
extension/omatadd.cpp
7171
extension/omatadd_batched.cpp
72+
extension/axpy_batch.cpp
7273
)
7374

7475
if(${BLAS_ENABLE_EXTENSIONS})
+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/***************************************************************************
2+
*
3+
* @license
4+
* Copyright (C) Codeplay Software Limited
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* For your convenience, a copy of the License has been included in this
12+
* repository.
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*
20+
* portBLAS: BLAS implementation using SYCL
21+
*
22+
* @filename axpy_batch.cpp
23+
*
24+
**************************************************************************/
25+
26+
#include "../utils.hpp"
27+
28+
constexpr blas_benchmark::utils::ExtensionOp benchmark_op =
29+
blas_benchmark::utils::ExtensionOp::axpy_batch;
30+
31+
template <typename scalar_t, blas::helper::AllocType mem_alloc>
32+
void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, index_t size,
33+
scalar_t alpha, index_t inc_x, index_t inc_y, index_t stride_x_mul,
34+
index_t stride_y_mul, index_t batch_size, bool* success) {
35+
// initialize the state label
36+
blas_benchmark::utils::set_benchmark_label<scalar_t>(
37+
state, sb_handle_ptr->get_queue());
38+
39+
// Google-benchmark counters are double.
40+
blas_benchmark::utils::init_extension_counters<benchmark_op, scalar_t>(
41+
state, size, batch_size);
42+
43+
blas::SB_Handle& sb_handle = *sb_handle_ptr;
44+
auto q = sb_handle.get_queue();
45+
46+
const auto stride_x{size * std::abs(inc_x) * stride_x_mul};
47+
const auto stride_y{size * std::abs(inc_y) * stride_y_mul};
48+
49+
const index_t size_x{stride_x * batch_size};
50+
const index_t size_y{stride_y * batch_size};
51+
// Create data
52+
std::vector<scalar_t> vx =
53+
blas_benchmark::utils::random_data<scalar_t>(size_x);
54+
std::vector<scalar_t> vy =
55+
blas_benchmark::utils::random_data<scalar_t>(size_y);
56+
57+
auto inx = blas::helper::allocate<mem_alloc, scalar_t>(size_x, q);
58+
auto iny = blas::helper::allocate<mem_alloc, scalar_t>(size_y, q);
59+
60+
auto copy_x =
61+
blas::helper::copy_to_device<scalar_t>(q, vx.data(), inx, size_x);
62+
auto copy_y =
63+
blas::helper::copy_to_device<scalar_t>(q, vy.data(), iny, size_y);
64+
65+
sb_handle.wait({copy_x, copy_y});
66+
67+
#ifdef BLAS_VERIFY_BENCHMARK
68+
// Run a first time with a verification of the results
69+
std::vector<scalar_t> y_ref = vy;
70+
for (auto i = 0; i < batch_size; ++i) {
71+
reference_blas::axpy(size, static_cast<scalar_t>(alpha),
72+
vx.data() + i * stride_x, inc_x,
73+
y_ref.data() + i * stride_y, inc_y);
74+
}
75+
std::vector<scalar_t> y_temp = vy;
76+
{
77+
auto y_temp_gpu = blas::helper::allocate<mem_alloc, scalar_t>(size_y, q);
78+
auto copy_temp = blas::helper::copy_to_device<scalar_t>(q, y_temp.data(),
79+
y_temp_gpu, size_y);
80+
sb_handle.wait(copy_temp);
81+
auto axpy_batch_event =
82+
_axpy_batch(sb_handle, size, alpha, inx, inc_x, stride_x, y_temp_gpu,
83+
inc_y, stride_y, batch_size);
84+
sb_handle.wait(axpy_batch_event);
85+
auto copy_output =
86+
blas::helper::copy_to_host(q, y_temp_gpu, y_temp.data(), size_y);
87+
sb_handle.wait(copy_output);
88+
89+
blas::helper::deallocate<mem_alloc>(y_temp_gpu, q);
90+
}
91+
92+
std::ostringstream err_stream;
93+
if (!utils::compare_vectors(y_temp, y_ref, err_stream, "")) {
94+
const std::string& err_str = err_stream.str();
95+
state.SkipWithError(err_str.c_str());
96+
*success = false;
97+
};
98+
#endif
99+
100+
auto blas_method_def = [&]() -> std::vector<cl::sycl::event> {
101+
auto event = _axpy_batch(sb_handle, size, alpha, inx, inc_x, stride_x, iny,
102+
inc_y, stride_y, batch_size);
103+
sb_handle.wait(event);
104+
return event;
105+
};
106+
107+
// Warmup
108+
blas_benchmark::utils::warmup(blas_method_def);
109+
sb_handle.wait();
110+
111+
blas_benchmark::utils::init_counters(state);
112+
113+
// Measure
114+
for (auto _ : state) {
115+
// Run
116+
std::tuple<double, double> times =
117+
blas_benchmark::utils::timef(blas_method_def);
118+
119+
// Report
120+
blas_benchmark::utils::update_counters(state, times);
121+
}
122+
123+
state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]);
124+
state.SetBytesProcessed(state.iterations() *
125+
state.counters["bytes_processed"]);
126+
127+
blas_benchmark::utils::calc_avg_counters(state);
128+
129+
blas::helper::deallocate<mem_alloc>(inx, q);
130+
blas::helper::deallocate<mem_alloc>(iny, q);
131+
}
132+
133+
template <typename scalar_t, blas::helper::AllocType mem_alloc>
134+
void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
135+
std::string mem_type,
136+
std::vector<axpy_batch_param_t<scalar_t>> params) {
137+
for (auto p : params) {
138+
index_t n, inc_x, inc_y, stride_x_mul, stride_y_mul, batch_size;
139+
scalar_t alpha;
140+
std::tie(n, alpha, inc_x, inc_y, stride_x_mul, stride_y_mul, batch_size) =
141+
p;
142+
auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr,
143+
index_t size, scalar_t alpha, index_t inc_x,
144+
index_t inc_y, index_t stride_x_mul,
145+
index_t stride_y_mul, index_t batch_size,
146+
bool* success) {
147+
run<scalar_t, mem_alloc>(st, sb_handle_ptr, size, alpha, inc_x, inc_y,
148+
stride_x_mul, stride_y_mul, batch_size, success);
149+
};
150+
benchmark::RegisterBenchmark(
151+
blas_benchmark::utils::get_name<benchmark_op, scalar_t, index_t>(
152+
n, alpha, inc_x, inc_y, stride_x_mul, stride_y_mul, batch_size,
153+
mem_type)
154+
.c_str(),
155+
BM_lambda, sb_handle_ptr, n, alpha, inc_x, inc_y, stride_x_mul,
156+
stride_y_mul, batch_size, success)
157+
->UseRealTime();
158+
}
159+
}
160+
161+
template <typename scalar_t>
162+
void register_benchmark(blas_benchmark::Args& args,
163+
blas::SB_Handle* sb_handle_ptr, bool* success) {
164+
auto axpy_batch_params =
165+
blas_benchmark::utils::get_axpy_batch_params<scalar_t>(args);
166+
167+
register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
168+
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
169+
axpy_batch_params);
170+
#ifdef SB_ENABLE_USM
171+
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
172+
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM,
173+
axpy_batch_params);
174+
#endif
175+
}
176+
177+
namespace blas_benchmark {
178+
void create_benchmark(blas_benchmark::Args& args,
179+
blas::SB_Handle* sb_handle_ptr, bool* success) {
180+
BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success);
181+
}
182+
} // namespace blas_benchmark

benchmark/rocblas/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ set(sources
7474
# Extension blas
7575
extension/omatcopy.cpp
7676
extension/omatadd.cpp
77-
77+
extension/axpy_batch.cpp
7878
)
7979

8080
# Operators supporting COMPLEX types benchmarking
+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/***************************************************************************
2+
*
3+
* @license
4+
* Copyright (C) Codeplay Software Limited
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* For your convenience, a copy of the License has been included in this
12+
* repository.
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*
20+
* portBLAS: BLAS implementation using SYCL
21+
*
22+
* @filename axpy_batch.cpp
23+
*
24+
**************************************************************************/
25+
26+
#include "../utils.hpp"
27+
#include "common/common_utils.hpp"
28+
29+
constexpr blas_benchmark::utils::ExtensionOp benchmark_op =
30+
blas_benchmark::utils::ExtensionOp::axpy_batch;
31+
32+
/**
 * @brief Forwards to the rocBLAS strided-batched axpy routine matching
 *        @p scalar_t.
 *
 * float selects rocblas_saxpy_strided_batched and double selects
 * rocblas_daxpy_strided_batched; the returned status goes through
 * CHECK_ROCBLAS_STATUS. For any other scalar type the call does nothing.
 *
 * @tparam scalar_t  Scalar type used to pick the rocBLAS routine.
 * @tparam args_t    Types of the forwarded arguments.
 * @param args       Arguments perfectly forwarded to the rocBLAS routine.
 */
template <typename scalar_t, typename... args_t>
static inline void rocblas_axpy_strided_batched_f(args_t&&... args) {
  constexpr bool is_single_precision = std::is_same_v<scalar_t, float>;
  constexpr bool is_double_precision = std::is_same_v<scalar_t, double>;

  if constexpr (is_single_precision) {
    CHECK_ROCBLAS_STATUS(
        rocblas_saxpy_strided_batched(std::forward<args_t>(args)...));
  } else if constexpr (is_double_precision) {
    CHECK_ROCBLAS_STATUS(
        rocblas_daxpy_strided_batched(std::forward<args_t>(args)...));
  }
}
43+
44+
template <typename scalar_t>
45+
void run(benchmark::State& state, rocblas_handle& rb_handle, index_t size,
46+
scalar_t alpha, index_t inc_x, index_t inc_y, index_t stride_x_mul,
47+
index_t stride_y_mul, index_t batch_size, bool* success) {
48+
// initialize the state label
49+
blas_benchmark::utils::set_benchmark_label<scalar_t>(state);
50+
51+
// Google-benchmark counters are double.
52+
blas_benchmark::utils::init_extension_counters<benchmark_op, scalar_t>(
53+
state, size, batch_size);
54+
55+
const auto stride_x{size * std::abs(inc_x) * stride_x_mul};
56+
const auto stride_y{size * std::abs(inc_y) * stride_y_mul};
57+
58+
const index_t size_x{stride_x * batch_size};
59+
const index_t size_y{stride_y * batch_size};
60+
// Create data
61+
std::vector<scalar_t> vx =
62+
blas_benchmark::utils::random_data<scalar_t>(size_x);
63+
std::vector<scalar_t> vy =
64+
blas_benchmark::utils::random_data<scalar_t>(size_y);
65+
66+
blas_benchmark::utils::HIPVector<scalar_t> inx(size_x, vx.data());
67+
blas_benchmark::utils::HIPVector<scalar_t> iny(size_y, vy.data());
68+
69+
#ifdef BLAS_VERIFY_BENCHMARK
70+
// Run a first time with a verification of the results
71+
std::vector<scalar_t> y_ref = vy;
72+
for (auto i = 0; i < batch_size; ++i) {
73+
reference_blas::axpy(size, static_cast<scalar_t>(alpha),
74+
vx.data() + i * stride_x, inc_x,
75+
y_ref.data() + i * stride_y, inc_y);
76+
}
77+
std::vector<scalar_t> y_temp = vy;
78+
{
79+
blas_benchmark::utils::HIPVector<scalar_t, true> y_temp_gpu(size_y,
80+
y_temp.data());
81+
rocblas_axpy_strided_batched_f<scalar_t>(rb_handle, size, &alpha, inx,
82+
inc_x, stride_x, y_temp_gpu, inc_y,
83+
stride_y, batch_size);
84+
}
85+
86+
std::ostringstream err_stream;
87+
if (!utils::compare_vectors(y_temp, y_ref, err_stream, "")) {
88+
const std::string& err_str = err_stream.str();
89+
state.SkipWithError(err_str.c_str());
90+
*success = false;
91+
};
92+
#endif
93+
94+
auto blas_warmup = [&]() -> void {
95+
rocblas_axpy_strided_batched_f<scalar_t>(rb_handle, size, &alpha, inx,
96+
inc_x, stride_x, iny, inc_y,
97+
stride_y, batch_size);
98+
return;
99+
};
100+
101+
hipEvent_t start, stop;
102+
CHECK_HIP_ERROR(hipEventCreate(&start));
103+
CHECK_HIP_ERROR(hipEventCreate(&stop));
104+
105+
auto blas_method_def = [&]() -> std::vector<hipEvent_t> {
106+
CHECK_HIP_ERROR(hipEventRecord(start, NULL));
107+
rocblas_axpy_strided_batched_f<scalar_t>(rb_handle, size, &alpha, inx,
108+
inc_x, stride_x, iny, inc_y,
109+
stride_y, batch_size);
110+
CHECK_HIP_ERROR(hipEventRecord(stop, NULL));
111+
CHECK_HIP_ERROR(hipEventSynchronize(stop));
112+
return std::vector{start, stop};
113+
};
114+
115+
// Warmup
116+
blas_benchmark::utils::warmup(blas_method_def);
117+
CHECK_HIP_ERROR(hipStreamSynchronize(NULL));
118+
119+
blas_benchmark::utils::init_counters(state);
120+
121+
// Measure
122+
for (auto _ : state) {
123+
// Run
124+
std::tuple<double, double> times =
125+
blas_benchmark::utils::timef_hip(blas_method_def);
126+
127+
// Report
128+
blas_benchmark::utils::update_counters(state, times);
129+
}
130+
131+
state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]);
132+
state.SetBytesProcessed(state.iterations() *
133+
state.counters["bytes_processed"]);
134+
135+
blas_benchmark::utils::calc_avg_counters(state);
136+
137+
CHECK_HIP_ERROR(hipEventDestroy(start));
138+
CHECK_HIP_ERROR(hipEventDestroy(stop));
139+
}
140+
141+
template <typename scalar_t>
142+
void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
143+
bool* success) {
144+
auto axpy_batch_params =
145+
blas_benchmark::utils::get_axpy_batch_params<scalar_t>(args);
146+
147+
for (auto p : axpy_batch_params) {
148+
index_t n, inc_x, inc_y, stride_x_mul, stride_y_mul, batch_size;
149+
scalar_t alpha;
150+
std::tie(n, alpha, inc_x, inc_y, stride_x_mul, stride_y_mul, batch_size) =
151+
p;
152+
auto BM_lambda =
153+
[&](benchmark::State& st, rocblas_handle rb_handle, index_t size,
154+
scalar_t alpha, index_t inc_x, index_t inc_y, index_t stride_x_mul,
155+
index_t stride_y_mul, index_t batch_size, bool* success) {
156+
run<scalar_t>(st, rb_handle, size, alpha, inc_x, inc_y, stride_x_mul,
157+
stride_y_mul, batch_size, success);
158+
};
159+
benchmark::RegisterBenchmark(
160+
blas_benchmark::utils::get_name<benchmark_op, scalar_t, index_t>(
161+
n, alpha, inc_x, inc_y, stride_x_mul, stride_y_mul, batch_size,
162+
blas_benchmark::utils::MEM_TYPE_USM)
163+
.c_str(),
164+
BM_lambda, rb_handle, n, alpha, inc_x, inc_y, stride_x_mul,
165+
stride_y_mul, batch_size, success)
166+
->UseRealTime();
167+
}
168+
}
169+
170+
namespace blas_benchmark {
171+
void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
172+
bool* success) {
173+
BLAS_REGISTER_BENCHMARK(args, rb_handle, success);
174+
}
175+
} // namespace blas_benchmark

cmake/CmakeFunctionHelper.cmake

+2-1
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,8 @@ function (build_library LIB_NAME ENABLE_EXTENSIONS)
701701
$<TARGET_OBJECTS:matcopy_batch>
702702
$<TARGET_OBJECTS:transpose>
703703
$<TARGET_OBJECTS:omatadd>
704-
$<TARGET_OBJECTS:omatadd_batch>)
704+
$<TARGET_OBJECTS:omatadd_batch>
705+
$<TARGET_OBJECTS:axpy_batch>)
705706

706707
if (${ENABLE_EXTENSIONS})
707708
list(APPEND LIB_SRCS $<TARGET_OBJECTS:reduction>)

0 commit comments

Comments
 (0)