@@ -137,8 +137,21 @@ typename sb_handle_t::event_t _trsv(
137
137
sb_handle_t & sb_handle, index_t _N, container_t0 _mA, index_t _lda,
138
138
container_t1 _vx, increment_t _incx,
139
139
typename sb_handle_t ::event_t _dependencies) {
140
- return blas::internal::_trsv_impl<4 , 2 , uplo, trn, diag>(
141
- sb_handle, _N, _mA, _lda, _vx, _incx, _dependencies);
140
+ const auto device = sb_handle.get_queue ().get_device ();
141
+ if (device.is_gpu ()) {
142
+ const std::string vendor =
143
+ device.template get_info <cl::sycl::info::device::vendor>();
144
+ if (vendor.find (" Intel" ) == vendor.npos ) {
145
+ return blas::internal::_trsv_impl<32 , 4 , uplo, trn, diag>(
146
+ sb_handle, _N, _mA, _lda, _vx, _incx, _dependencies);
147
+ } else {
148
+ throw std::runtime_error (
149
+ " Trsv operator currently not supported on Intel GPUs" );
150
+ }
151
+ } else {
152
+ return blas::internal::_trsv_impl<4 , 2 , uplo, trn, diag>(
153
+ sb_handle, _N, _mA, _lda, _vx, _incx, _dependencies);
154
+ }
142
155
}
143
156
} // namespace backend
144
157
} // namespace trsv
@@ -152,8 +165,21 @@ typename sb_handle_t::event_t _tbsv(
152
165
sb_handle_t & sb_handle, index_t _N, index_t _K, container_t0 _mA,
153
166
index_t _lda, container_t1 _vx, increment_t _incx,
154
167
const typename sb_handle_t ::event_t & _dependencies) {
155
- return blas::internal::_tbsv_impl<4 , 2 , uplo, trn, diag>(
156
- sb_handle, _N, _K, _mA, _lda, _vx, _incx, _dependencies);
168
+ const auto device = sb_handle.get_queue ().get_device ();
169
+ if (device.is_gpu ()) {
170
+ const std::string vendor =
171
+ device.template get_info <cl::sycl::info::device::vendor>();
172
+ if (vendor.find (" Intel" ) == vendor.npos ) {
173
+ return blas::internal::_tbsv_impl<32 , 4 , uplo, trn, diag>(
174
+ sb_handle, _N, _K, _mA, _lda, _vx, _incx, _dependencies);
175
+ } else {
176
+ throw std::runtime_error (
177
+ " Tbsv operator currently not supported on Intel GPUs" );
178
+ }
179
+ } else {
180
+ return blas::internal::_tbsv_impl<4 , 2 , uplo, trn, diag>(
181
+ sb_handle, _N, _K, _mA, _lda, _vx, _incx, _dependencies);
182
+ }
157
183
}
158
184
} // namespace backend
159
185
} // namespace tbsv
@@ -163,12 +189,24 @@ namespace backend {
163
189
template <uplo_type uplo, transpose_type trn, diag_type diag,
164
190
typename sb_handle_t , typename index_t , typename container_t0,
165
191
typename container_t1, typename increment_t >
166
- typename sb_handle_t ::event_t _tpsv (sb_handle_t & sb_handle, index_t _N,
167
- container_t0 _mA, container_t1 _vx,
168
- increment_t _incx,
169
- const typename sb_handle_t ::event_t & _dependencies) {
170
- return blas::internal::_tpsv_impl<4 , 2 , uplo, trn, diag>(sb_handle, _N, _mA,
171
- _vx, _incx, _dependencies);
192
+ typename sb_handle_t ::event_t _tpsv (
193
+ sb_handle_t & sb_handle, index_t _N, container_t0 _mA, container_t1 _vx,
194
+ increment_t _incx, const typename sb_handle_t ::event_t & _dependencies) {
195
+ const auto device = sb_handle.get_queue ().get_device ();
196
+ if (device.is_gpu ()) {
197
+ const std::string vendor =
198
+ device.template get_info <cl::sycl::info::device::vendor>();
199
+ if (vendor.find (" Intel" ) == vendor.npos ) {
200
+ return blas::internal::_tpsv_impl<32 , 4 , uplo, trn, diag>(
201
+ sb_handle, _N, _mA, _vx, _incx, _dependencies);
202
+ } else {
203
+ throw std::runtime_error (
204
+ " Tpsv operator currently not supported on Intel GPUs" );
205
+ }
206
+ } else {
207
+ return blas::internal::_tpsv_impl<4 , 2 , uplo, trn, diag>(
208
+ sb_handle, _N, _mA, _vx, _incx, _dependencies);
209
+ }
172
210
}
173
211
} // namespace backend
174
212
} // namespace tpsv
0 commit comments