Skip to content

Commit

Permalink
sqrtm: Use more efficient calculation for Hermitian matrices (bug #60…
Browse files Browse the repository at this point in the history
…797).

* libinterp/corefcn/sqrtm.cc (sqrtm_utri_inplace): Check if Schur matrix is
(nearly) diagonal and use more efficient calculation in that case.
  • Loading branch information
mmuetzel committed Aug 1, 2021
1 parent d14c89c commit 3c9c7f5
Showing 1 changed file with 80 additions and 37 deletions.
117 changes: 80 additions & 37 deletions libinterp/corefcn/sqrtm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,51 +40,94 @@

OCTAVE_BEGIN_NAMESPACE(octave)

template <typename Matrix>
template <typename T>
static void
sqrtm_utri_inplace (Matrix& T)
sqrtm_utri_inplace (T& m)
{
typedef typename Matrix::element_type element_type;
typedef typename T::element_type element_type;
typedef typename T::real_matrix_type real_matrix_type;
typedef typename T::real_elt_type real_elt_type;

const element_type zero = element_type ();

bool singular = false;
bool diagonal = true;

// The following code is equivalent to this triple loop:
//
// n = rows (T);
// for j = 1:n
// T(j,j) = sqrt (T(j,j));
// for i = j-1:-1:1
// if T(i,j) != 0
// T(i,j) /= (T(i,i) + T(j,j));
// endif
// k = 1:i-1;
// T(k,j) -= T(k,i) * T(i,j);
// endfor
// endfor
//
// this is an in-place, cache-aligned variant of the code
// given in Higham's paper.

const octave_idx_type n = T.rows ();
element_type *Tp = T.rwdata ();
for (octave_idx_type j = 0; j < n; j++)
// The Schur matrix of Hermitian matrices is diagonal.
// check for off-diagonal elements above tolerance
const octave_idx_type n = m.rows ();
real_matrix_type abs_m = m.abs ();

real_elt_type max_abs_diag = 0;
for (octave_idx_type i = 0; i < n; i++)
max_abs_diag = std::max (max_abs_diag, abs_m(i,i));

const real_elt_type tol = n * max_abs_diag
* std::numeric_limits<real_elt_type>::epsilon ();

for (octave_idx_type j = 0; j < n; j++)
{
for (octave_idx_type i = j-1; i >= 0; i--)
{
if (abs_m(i,j) > tol)
{
diagonal = false;
break;
}
}
if (! diagonal)
break;
}

element_type *mp = m.fortran_vec ();
if (diagonal)
{
element_type *colj = Tp + n*j;
if (colj[j] != zero)
colj[j] = sqrt (colj[j]);
else
singular = true;
// shortcut for diagonal Schur matrices
for (octave_idx_type i = 0; i < n; i++)
{
octave_idx_type idx_diag = i*(n+1);
if (mp[idx_diag] != zero)
mp[idx_diag] = sqrt (mp[idx_diag]);
else
singular = true;
}
}
else
{
// The following code is equivalent to this triple loop:
//
// n = rows (m);
// for j = 1:n
// m(j,j) = sqrt (m(j,j));
// for i = j-1:-1:1
// if m(i,j) != 0
// m(i,j) /= (m(i,i) + m(j,j));
// endif
// k = 1:i-1;
// m(k,j) -= m(k,i) * m(i,j);
// endfor
// endfor
//
// this is an in-place, cache-aligned variant of the code
// given in Higham's paper.

for (octave_idx_type i = j-1; i >= 0; i--)
for (octave_idx_type j = 0; j < n; j++)
{
const element_type *coli = Tp + n*i;
if (colj[i] != zero)
colj[i] /= (coli[i] + colj[j]);
const element_type colji = colj[i];
for (octave_idx_type k = 0; k < i; k++)
colj[k] -= coli[k] * colji;
element_type *colj = mp + n*j;
if (colj[j] != zero)
colj[j] = sqrt (colj[j]);
else
singular = true;

for (octave_idx_type i = j-1; i >= 0; i--)
{
const element_type *coli = mp + n*i;
if (colj[i] != zero)
colj[i] /= (coli[i] + colj[j]);
const element_type colji = colj[i];
for (octave_idx_type k = 0; k < i; k++)
colj[k] -= coli[k] * colji;
}
}
}

Expand Down Expand Up @@ -186,11 +229,11 @@ do_sqrtm (const octave_value& arg)
x = schur_fact.schur_matrix ();
u = schur_fact.unitary_schur_matrix ();
}
while (0); // schur no longer needed.
while (0); // schur_fact no longer needed.

sqrtm_utri_inplace (x);

x = u * x; // original x no longer needed.
x = u * x; // original x no longer needed.
ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans);

if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
Expand Down

0 comments on commit 3c9c7f5

Please sign in to comment.