@@ -59,9 +59,36 @@ struct gemm_emul_tinysq
 
 
 
+struct gemm_emul_large_mp_helper
+  {
+  template<typename eT>
+  arma_hot
+  inline
+  static
+  void
+  copy_row(eT* out_mem, const Mat<eT>& in, const uword row)
+    {
+    const uword n_rows = in.n_rows;
+    const uword n_cols = in.n_cols;
+
+    const eT* in_mem_row = in.memptr() + row;
+
+    for(uword i=0; i < n_cols; ++i)
+      {
+      out_mem[i] = (*in_mem_row);
+
+      in_mem_row += n_rows;
+      }
+    }
+  };
+
+
+
+#if defined(ARMA_USE_OPENMP)
 
 //! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes)
+//! parallelised version
 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
-struct gemm_emul_large
+struct gemm_emul_large_mp
   {
   template<typename eT, typename TA, typename TB>
   arma_hot
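
A note on the copy_row helper above: Armadillo matrices are stored column-major, so consecutive elements of a row sit n_rows apart in memory. copy_row gathers one row into a contiguous buffer, letting the dot products in the parallel kernel below run over unit-stride data. A minimal standalone sketch of the same gather in plain C++ (gather_row and its parameters are illustrative names, not part of the patch):

#include <cstddef>

// illustrative stand-in for copy_row(): gather row 'row' of a column-major
// n_rows x n_cols matrix 'in' into the contiguous buffer 'out';
// element (r,c) of a column-major matrix lives at in[r + c*n_rows]
void gather_row(double* out, const double* in, std::size_t n_rows, std::size_t n_cols, std::size_t row)
  {
  const double* p = in + row;   // first element of the requested row

  for(std::size_t c = 0; c < n_cols; ++c)
    {
    out[c] = (*p);

    p += n_rows;                // same row, next column
    }
  }
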
@@ -78,13 +105,151 @@ struct gemm_emul_large
     )
     {
     arma_debug_sigprint();
+
+    const uword A_n_rows = A.n_rows;
+    const uword A_n_cols = A.n_cols;
+
+    const uword B_n_rows = B.n_rows;
+    const uword B_n_cols = B.n_cols;
+
+    if( (do_trans_A == false) && (do_trans_B == false) )
+      {
+      const uword n_threads = uword(mp_thread_limit::get());
+
+      podarray<eT> tmp(A_n_cols * n_threads, arma_nozeros_indicator());
+
+      eT* tmp_mem = tmp.memptr();
+
+      #pragma omp parallel for schedule(static) num_threads(int(n_threads))
+      for(uword row_A=0; row_A < A_n_rows; ++row_A)
+        {
+        const uword thread_id = uword(omp_get_thread_num());
+
+        eT* A_rowdata = tmp_mem + (A_n_cols * thread_id);
+
+        gemm_emul_large_mp_helper::copy_row(A_rowdata, A, row_A);
+
+        for(uword col_B=0; col_B < B_n_cols; ++col_B)
+          {
+          const eT acc = op_dot::direct_dot(B_n_rows, A_rowdata, B.colptr(col_B));
+
+               if( (use_alpha == false) && (use_beta == false) )  { C.at(row_A,col_B) = acc;                              }
+          else if( (use_alpha == true ) && (use_beta == false) )  { C.at(row_A,col_B) = alpha*acc;                        }
+          else if( (use_alpha == false) && (use_beta == true ) )  { C.at(row_A,col_B) = acc       + beta*C.at(row_A,col_B); }
+          else if( (use_alpha == true ) && (use_beta == true ) )  { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); }
+          }
+        }
+      }
+    else
+    if( (do_trans_A == true ) && (do_trans_B == false) )
+      {
+      const int n_threads = mp_thread_limit::get();
+
+      #pragma omp parallel for schedule(static) num_threads(n_threads)
+      for(uword col_A=0; col_A < A_n_cols; ++col_A)
+        {
+        // col_A is interpreted as row_A when storing the results in matrix C
+
+        const eT* A_coldata = A.colptr(col_A);
+
+        for(uword col_B=0; col_B < B_n_cols; ++col_B)
+          {
+          const eT acc = op_dot::direct_dot(B_n_rows, A_coldata, B.colptr(col_B));
+
+               if( (use_alpha == false) && (use_beta == false) )  { C.at(col_A,col_B) = acc;                              }
+          else if( (use_alpha == true ) && (use_beta == false) )  { C.at(col_A,col_B) = alpha*acc;                        }
+          else if( (use_alpha == false) && (use_beta == true ) )  { C.at(col_A,col_B) = acc       + beta*C.at(col_A,col_B); }
+          else if( (use_alpha == true ) && (use_beta == true ) )  { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); }
+          }
+        }
+      }
+    else
+    if( (do_trans_A == false) && (do_trans_B == true ) )
+      {
+      Mat<eT> BB;
+      op_strans::apply_mat_noalias(BB, B);
+
+      gemm_emul_large_mp<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
+      }
+    else
+    if( (do_trans_A == true ) && (do_trans_B == true ) )
+      {
+      // using trans(A)*trans(B) = trans(B*A) equivalency; assuming no hermitian transpose
+
+      const uword n_threads = uword(mp_thread_limit::get());
+
+      podarray<eT> tmp(B_n_cols * n_threads, arma_nozeros_indicator());
+
+      eT* tmp_mem = tmp.memptr();
+
+      #pragma omp parallel for schedule(static) num_threads(int(n_threads))
+      for(uword row_B=0; row_B < B_n_rows; ++row_B)
+        {
+        const uword thread_id = uword(omp_get_thread_num());
+
+        eT* B_rowdata = tmp_mem + (B_n_cols * thread_id);
+
+        gemm_emul_large_mp_helper::copy_row(B_rowdata, B, row_B);
+
+        for(uword col_A=0; col_A < A_n_cols; ++col_A)
+          {
+          const eT acc = op_dot::direct_dot(A_n_rows, B_rowdata, A.colptr(col_A));
+
+               if( (use_alpha == false) && (use_beta == false) )  { C.at(col_A,row_B) = acc;                              }
+          else if( (use_alpha == true ) && (use_beta == false) )  { C.at(col_A,row_B) = alpha*acc;                        }
+          else if( (use_alpha == false) && (use_beta == true ) )  { C.at(col_A,row_B) = acc       + beta*C.at(col_A,row_B); }
+          else if( (use_alpha == true ) && (use_beta == true ) )  { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); }
+          }
+        }
+      }
+    }
+
+  };
+#endif
+
+
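
In the no-transpose and double-transpose branches above, the shared scratch array tmp holds one row buffer per thread, and each thread addresses only its own slice via omp_get_thread_num(), so threads never write to overlapping memory; this is why tmp is sized A_n_cols * n_threads in the first branch and B_n_cols * n_threads in the last. A minimal sketch of that scratch-slicing pattern outside of Armadillo (all names here are illustrative, and the real code obtains the thread count from mp_thread_limit::get() rather than omp_get_max_threads()):

#include <omp.h>
#include <cstddef>
#include <vector>

void per_thread_row_buffers(std::size_t n_rows, std::size_t n_cols)
  {
  const int n_threads = omp_get_max_threads();

  // one contiguous allocation, logically split into n_threads buffers of n_cols elements
  std::vector<double> tmp(n_cols * std::size_t(n_threads));
  double* tmp_mem = tmp.data();

  #pragma omp parallel for schedule(static) num_threads(n_threads)
  for(std::ptrdiff_t row = 0; row < std::ptrdiff_t(n_rows); ++row)
    {
    const std::size_t thread_id = std::size_t(omp_get_thread_num());

    double* row_buf = tmp_mem + (n_cols * thread_id);  // slice private to this thread

    (void)row_buf;  // real code would gather row 'row' here, then form dot products against each column of B
    }
  }
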
 
+//! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes)
+template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
+struct gemm_emul_large
+  {
+  template<typename eT, typename TA, typename TB>
+  arma_hot
+  inline
+  static
+  void
+  apply
+    (
+          Mat<eT>& C,
+    const TA&      A,
+    const TB&      B,
+    const eT       alpha = eT(1),
+    const eT       beta  = eT(0)
+    )
+    {
+    arma_debug_sigprint();
+
     const uword A_n_rows = A.n_rows;
     const uword A_n_cols = A.n_cols;
 
     const uword B_n_rows = B.n_rows;
     const uword B_n_cols = B.n_cols;
 
+    #if defined(ARMA_USE_OPENMP)
+      {
+      // TODO: replace with more sophisticated threshold mechanism
+
+      constexpr uword threshold = uword(30);
+
+      if( (A_n_rows >= threshold) && (A_n_cols >= threshold) && (B_n_rows >= threshold) && (B_n_cols >= threshold) && (mp_thread_limit::in_parallel() == false) )
+        {
+        gemm_emul_large_mp<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
+
+        return;
+        }
+      }
+    #endif
+
     if( (do_trans_A == false) && (do_trans_B == false) )
       {
       arma_aligned podarray<eT> tmp(A_n_cols);
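
The double-transpose branch of gemm_emul_large_mp relies on trans(A)*trans(B) being equal to trans(B*A), which holds for plain (non-hermitian) transposes. A quick sanity check of that identity through Armadillo's public API, assuming the library is installed (matrix sizes are arbitrary):

#include <armadillo>

int main()
  {
  arma::mat A = arma::randu<arma::mat>(40, 30);
  arma::mat B = arma::randu<arma::mat>(50, 40);

  // trans(A)*trans(B) is 30x50; trans(B*A) is also 30x50
  const arma::mat X = arma::trans(A) * arma::trans(B);
  const arma::mat Y = arma::trans(B * A);

  return arma::approx_equal(X, Y, "absdiff", 1e-10) ? 0 : 1;
  }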