Skip to content

Commit e57828e

Browse files
authored
Various simplifications in cppdlr/utils.hpp (#14)
1 parent e59f710 commit e57828e

File tree

1 file changed

+48
-87
lines changed

1 file changed

+48
-87
lines changed

c++/cppdlr/utils.hpp

Lines changed: 48 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
namespace cppdlr {
2424
using dcomplex = std::complex<double>;
2525

26+
/**
27+
* Calculate the squared norm of a vector
28+
*
29+
* @param v The input vector
30+
* @return x The squared norm of the vector
31+
*/
32+
double normsq(nda::MemoryVector auto const &v) { return nda::real(nda::blas::dotc(v, v)); }
33+
2634
/**
2735
* Class constructor for barycheb: barycentric Lagrange interpolation at
2836
* Chebyshev nodes.
@@ -116,10 +124,10 @@ namespace cppdlr {
116124
// Compute norms of rows of input matrix, and rescale eps tolerance
117125
auto norms = nda::vector<double>(m);
118126
double epssq = eps * eps;
119-
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
127+
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }
120128

121129
// Begin pivoted double Gram-Schmidt procedure
122-
int jpiv = 0, jj = 0;
130+
int jpiv = 0;
123131
double nrm = 0;
124132
auto piv = nda::arange(m);
125133
auto tmp = nda::vector<S>(n);
@@ -137,38 +145,29 @@ namespace cppdlr {
137145
}
138146

139147
// Swap current row with chosen pivot row
140-
tmp = aa(j, _);
141-
aa(j, _) = aa(jpiv, _);
142-
aa(jpiv, _) = tmp;
143-
144-
nrm = norms(j);
145-
norms(j) = norms(jpiv);
146-
norms(jpiv) = nrm;
147-
148-
jj = piv(j);
149-
piv(j) = piv(jpiv);
150-
piv(jpiv) = jj;
148+
deep_swap(aa(j, _), aa(jpiv, _));
149+
std::swap(norms(j), norms(jpiv));
150+
std::swap(piv(j), piv(jpiv));
151151

152152
// Orthogonalize current rows (now the chosen pivot row) against all
153153
// previously chosen rows
154154
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }
155155

156156
// Get norm of current row
157-
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
158-
//nrm = nda::norm(aa(j, _));
157+
nrm = normsq(aa(j, _));
159158

160159
// Terminate if sufficiently small, and return previously selected rows
161160
// (not including current row)
162161
if (nrm <= epssq) { return {aa(nda::range(0, j), _), norms(nda::range(0, j)), piv(nda::range(0, j))}; };
163162

164163
// Normalize current row
165-
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
164+
aa(j, _) /= sqrt(nrm);
166165

167166
// Orthogonalize remaining rows against current row
168167
for (int k = j + 1; k < m; ++k) {
169168
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
170169
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
171-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
170+
norms(k) = normsq(aa(k, _));
172171
}
173172
}
174173

@@ -211,22 +210,21 @@ namespace cppdlr {
211210
if (m % 2 != 0) { throw std::runtime_error("Input matrix must have even number of rows."); }
212211

213212
// Copy input data, re-ordering rows to make symmetric rows adjacent.
214-
auto aa = typename T::regular_type(m, n);
213+
auto aa = typename T::regular_type(m, n);
215214
aa(nda::range(0, m, 2), _) = a(nda::range(0, m / 2), _);
216215
aa(nda::range(1, m, 2), _) = a(nda::range(m - 1, m / 2 - 1, -1), _);
217216

218217
// Compute norms of rows of input matrix, and rescale eps tolerance
219218
auto norms = nda::vector<double>(m);
220219
double epssq = eps * eps;
221-
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
220+
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }
222221

223222
// Begin pivoted double Gram-Schmidt procedure
224-
int jpiv = 0, jj = 0;
225-
double nrm = 0;
226-
auto piv = nda::arange(0, m);
223+
int jpiv = 0;
224+
double nrm = 0;
225+
auto piv = nda::arange(0, m);
227226
piv(nda::range(0, m, 2)) = nda::arange(0, m / 2); // Re-order pivots to match re-ordered input matrix
228227
piv(nda::range(1, m, 2)) = nda::arange(m - 1, m / 2 - 1, -1);
229-
auto tmp = nda::vector<S>(n);
230228

231229
if (maxrnk % 2 != 0) { // If n < m and n is odd, decrease maxrnk to maintain symmetry
232230
maxrnk -= 1;
@@ -245,61 +243,46 @@ namespace cppdlr {
245243
}
246244

247245
// Swap current row pair with chosen pivot row pair
248-
tmp = aa(j, _);
249-
aa(j, _) = aa(jpiv, _);
250-
aa(jpiv, _) = tmp;
251-
tmp = aa(j + 1, _);
252-
aa(j + 1, _) = aa(jpiv + 1, _);
253-
aa(jpiv + 1, _) = tmp;
254-
255-
nrm = norms(j);
256-
norms(j) = norms(jpiv);
257-
norms(jpiv) = nrm;
258-
nrm = norms(j + 1);
259-
norms(j + 1) = norms(jpiv + 1);
260-
norms(jpiv + 1) = nrm;
261-
262-
jj = piv(j);
263-
piv(j) = piv(jpiv);
264-
piv(jpiv) = jj;
265-
jj = piv(j + 1);
266-
piv(j + 1) = piv(jpiv + 1);
267-
piv(jpiv + 1) = jj;
246+
deep_swap(aa(j, _), aa(jpiv, _));
247+
deep_swap(aa(j + 1, _), aa(jpiv + 1, _));
248+
std::swap(norms(j), norms(jpiv));
249+
std::swap(norms(j + 1), norms(jpiv + 1));
250+
std::swap(piv(j), piv(jpiv));
251+
std::swap(piv(j + 1), piv(jpiv + 1));
268252

269253
// Orthogonalize current row (now the first chosen pivot row) against all
270254
// previously chosen rows
271255
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }
272256

273257
// Get norm of current row
274-
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
258+
nrm = normsq(aa(j, _));
275259

276260
// Terminate if sufficiently small, and return previously selected rows
277261
// (not including current row)
278262
if (nrm <= epssq) { return {aa(nda::range(0, j), _), norms(nda::range(0, j)), piv(nda::range(0, j))}; };
279263

280264
// Normalize current row
281-
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
265+
aa(j, _) /= sqrt(nrm);
282266

283267
// Orthogonalize remaining rows against current row
284268
for (int k = j + 1; k < m; ++k) {
285269
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
286270
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
287-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
271+
norms(k) = normsq(aa(k, _));
288272
}
289273

290274
// Orthogonalize current row (now the second chosen pivot row) against all
291275
// previously chosen rows
292276
for (int k = 0; k < j + 1; ++k) { aa(j + 1, _) = aa(j + 1, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j + 1, _)); }
293277

294278
// Normalize current row
295-
nrm = nda::real(nda::blas::dotc(aa(j + 1, _), aa(j + 1, _)));
296-
aa(j + 1, _) = aa(j + 1, _) * (1 / sqrt(nrm));
279+
aa(j + 1, _) /= sqrt(normsq(aa(j + 1, _)));
297280

298281
// Orthogonalize remaining rows against current row
299282
for (int k = j + 2; k < m; ++k) {
300283
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
301284
aa(k, _) = aa(k, _) - aa(j + 1, _) * nda::blas::dotc(aa(j + 1, _), aa(k, _));
302-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
285+
norms(k) = normsq(aa(k, _));
303286
}
304287
}
305288

@@ -352,18 +335,18 @@ namespace cppdlr {
352335
aa(nda::range(0, m, 2), _) = a(nda::range(0, m / 2), _);
353336
aa(nda::range(1, m, 2), _) = a(nda::range(m - 1, m / 2 - 1, -1), _);
354337
} else {
355-
aa(0, _) = a((m - 1) / 2, _);
338+
aa(0, _) = a((m - 1) / 2, _);
356339
aa(nda::range(1, m, 2), _) = a(nda::range(0, (m - 1) / 2), _);
357340
aa(nda::range(2, m, 2), _) = a(nda::range(m - 1, (m - 1) / 2, -1), _);
358341
//aa(m - 1, _) = a((m - 1) / 2, _);
359342
}
360343

361344
// Compute norms of rows of input matrix
362345
auto norms = nda::vector<double>(m);
363-
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
346+
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }
364347

365348
// Begin pivoted double Gram-Schmidt procedure
366-
int jpiv = 0, jj = 0;
349+
int jpiv = 0;
367350
double nrm = 0;
368351
auto piv = nda::arange(0, m);
369352
if (m % 2 == 0) {
@@ -375,23 +358,17 @@ namespace cppdlr {
375358
piv(nda::range(2, m, 2)) = nda::arange(m - 1, (m - 1) / 2, -1);
376359
//piv(m - 1) = (m - 1) / 2;
377360
}
378-
auto tmp = nda::vector<S>(n);
379361

380362
// If m odd, first choose middle row (now last row) as first pivot
381363

382364
if (m % 2 == 1) {
383-
//int j = 0; // Index of current row
384-
//jpiv = 0; // Index of pivot row
385-
386365
// Normalize
387-
nrm = nda::real(nda::blas::dotc(aa(0, _), aa(0, _)));
388-
aa(0, _) = aa(0, _) * (1 / sqrt(nrm));
389-
//aa(0, _) /= sqrt(nda::real(nda::blas::dotc(aa(0, _), aa(0, _))));
366+
aa(0, _) /= sqrt(normsq(aa(0, _)));
390367

391368
// Orthogonalize remaining rows against current row
392369
for (int k = 1; k < m; ++k) {
393370
aa(k, _) = aa(k, _) - aa(0, _) * nda::blas::dotc(aa(0, _), aa(k, _));
394-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
371+
norms(k) = normsq(aa(k, _));
395372
}
396373
}
397374

@@ -410,53 +387,37 @@ namespace cppdlr {
410387
}
411388

412389
// Swap current row pair with chosen pivot row pair
413-
tmp = aa(j, _);
414-
aa(j, _) = aa(jpiv, _);
415-
aa(jpiv, _) = tmp;
416-
tmp = aa(j + 1, _);
417-
aa(j + 1, _) = aa(jpiv + 1, _);
418-
aa(jpiv + 1, _) = tmp;
419-
420-
nrm = norms(j);
421-
norms(j) = norms(jpiv);
422-
norms(jpiv) = nrm;
423-
nrm = norms(j + 1);
424-
norms(j + 1) = norms(jpiv + 1);
425-
norms(jpiv + 1) = nrm;
426-
427-
jj = piv(j);
428-
piv(j) = piv(jpiv);
429-
piv(jpiv) = jj;
430-
jj = piv(j + 1);
431-
piv(j + 1) = piv(jpiv + 1);
432-
piv(jpiv + 1) = jj;
390+
deep_swap(aa(j, _), aa(jpiv, _));
391+
deep_swap(aa(j + 1, _), aa(jpiv + 1, _));
392+
std::swap(norms(j), norms(jpiv));
393+
std::swap(norms(j + 1), norms(jpiv + 1));
394+
std::swap(piv(j), piv(jpiv));
395+
std::swap(piv(j + 1), piv(jpiv + 1));
433396

434397
// Orthogonalize current row (now the first chosen pivot row) against all
435398
// previously chosen rows
436399
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }
437400

438401
// Normalize current row
439-
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
440-
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
402+
aa(j, _) /= sqrt(normsq(aa(j, _)));
441403

442404
// Orthogonalize remaining rows against current row
443405
for (int k = j + 1; k < m; ++k) {
444406
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
445-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
407+
norms(k) = normsq(aa(k, _));
446408
}
447409

448410
// Orthogonalize current row (now the second chosen pivot row) against all
449411
// previously chosen rows
450412
for (int k = 0; k < j + 1; ++k) { aa(j + 1, _) = aa(j + 1, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j + 1, _)); }
451413

452414
// Normalize current row
453-
nrm = nda::real(nda::blas::dotc(aa(j + 1, _), aa(j + 1, _)));
454-
aa(j + 1, _) = aa(j + 1, _) * (1 / sqrt(nrm));
415+
aa(j + 1, _) /= sqrt(normsq(aa(j + 1, _)));
455416

456417
// Orthogonalize remaining rows against current row
457418
for (int k = j + 2; k < m; ++k) {
458419
aa(k, _) = aa(k, _) - aa(j + 1, _) * nda::blas::dotc(aa(j + 1, _), aa(k, _));
459-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
420+
norms(k) = normsq(aa(k, _));
460421
}
461422
}
462423

@@ -551,7 +512,7 @@ namespace cppdlr {
551512
* @return Contraction of the inner dimensions of \p a and \p b
552513
*/
553514
template <nda::MemoryArray Ta, nda::MemoryArray Tb, nda::Scalar Sa = nda::get_value_t<Ta>, nda::Scalar Sb = nda::get_value_t<Tb>,
554-
nda::Scalar S = typename std::common_type<Sa, Sb>::type>
515+
nda::Scalar S = std::common_type_t<Sa, Sb>>
555516
nda::array<S, Ta::rank + Tb::rank - 2> arraymult(Ta const &a, Tb const &b) {
556517

557518
// Get ranks of input arrays

0 commit comments

Comments
 (0)