Skip to content

Commit

Permalink
maint: Merge default to bytecode-interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
arungiridhar committed Jan 17, 2024
2 parents 5b1703b + eebbfe1 commit aa7b15e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 100 deletions.
117 changes: 80 additions & 37 deletions libinterp/corefcn/sqrtm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,51 +40,94 @@

OCTAVE_BEGIN_NAMESPACE(octave)

template <typename Matrix>
template <typename T>
static void
sqrtm_utri_inplace (Matrix& T)
sqrtm_utri_inplace (T& m)
{
typedef typename Matrix::element_type element_type;
typedef typename T::element_type element_type;
typedef typename T::real_matrix_type real_matrix_type;
typedef typename T::real_elt_type real_elt_type;

const element_type zero = element_type ();

bool singular = false;
bool diagonal = true;

// The following code is equivalent to this triple loop:
//
// n = rows (T);
// for j = 1:n
// T(j,j) = sqrt (T(j,j));
// for i = j-1:-1:1
// if T(i,j) != 0
// T(i,j) /= (T(i,i) + T(j,j));
// endif
// k = 1:i-1;
// T(k,j) -= T(k,i) * T(i,j);
// endfor
// endfor
//
// this is an in-place, cache-aligned variant of the code
// given in Higham's paper.

const octave_idx_type n = T.rows ();
element_type *Tp = T.rwdata ();
for (octave_idx_type j = 0; j < n; j++)
// The Schur matrix of Hermitian matrices is diagonal.
// check for off-diagonal elements above tolerance
const octave_idx_type n = m.rows ();
real_matrix_type abs_m = m.abs ();

real_elt_type max_abs_diag = 0;
for (octave_idx_type i = 0; i < n; i++)
max_abs_diag = std::max (max_abs_diag, abs_m(i,i));

const real_elt_type tol = n * max_abs_diag
* std::numeric_limits<real_elt_type>::epsilon ();

for (octave_idx_type j = 0; j < n; j++)
{
for (octave_idx_type i = j-1; i >= 0; i--)
{
if (abs_m(i,j) > tol)
{
diagonal = false;
break;
}
}
if (! diagonal)
break;
}

element_type *mp = m.fortran_vec ();
if (diagonal)
{
element_type *colj = Tp + n*j;
if (colj[j] != zero)
colj[j] = sqrt (colj[j]);
else
singular = true;
// shortcut for diagonal Schur matrices
for (octave_idx_type i = 0; i < n; i++)
{
octave_idx_type idx_diag = i*(n+1);
if (mp[idx_diag] != zero)
mp[idx_diag] = sqrt (mp[idx_diag]);
else
singular = true;
}
}
else
{
// The following code is equivalent to this triple loop:
//
// n = rows (m);
// for j = 1:n
// m(j,j) = sqrt (m(j,j));
// for i = j-1:-1:1
// if m(i,j) != 0
// m(i,j) /= (m(i,i) + m(j,j));
// endif
// k = 1:i-1;
// m(k,j) -= m(k,i) * m(i,j);
// endfor
// endfor
//
// this is an in-place, cache-aligned variant of the code
// given in Higham's paper.

for (octave_idx_type i = j-1; i >= 0; i--)
for (octave_idx_type j = 0; j < n; j++)
{
const element_type *coli = Tp + n*i;
if (colj[i] != zero)
colj[i] /= (coli[i] + colj[j]);
const element_type colji = colj[i];
for (octave_idx_type k = 0; k < i; k++)
colj[k] -= coli[k] * colji;
element_type *colj = mp + n*j;
if (colj[j] != zero)
colj[j] = sqrt (colj[j]);
else
singular = true;

for (octave_idx_type i = j-1; i >= 0; i--)
{
const element_type *coli = mp + n*i;
if (colj[i] != zero)
colj[i] /= (coli[i] + colj[j]);
const element_type colji = colj[i];
for (octave_idx_type k = 0; k < i; k++)
colj[k] -= coli[k] * colji;
}
}
}

Expand Down Expand Up @@ -186,11 +229,11 @@ do_sqrtm (const octave_value& arg)
x = schur_fact.schur_matrix ();
u = schur_fact.unitary_schur_matrix ();
}
while (0); // schur no longer needed.
while (0); // schur_fact no longer needed.

sqrtm_utri_inplace (x);

x = u * x; // original x no longer needed.
x = u * x; // original x no longer needed.
ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans);

if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
Expand Down
93 changes: 30 additions & 63 deletions libinterp/parse-tree/oct-parse.yy
Original file line number Diff line number Diff line change
Expand Up @@ -5498,34 +5498,6 @@ OCTAVE_BEGIN_NAMESPACE(octave)
m_lexer.m_allow_command_syntax = false;
}

// FIXME: this function partially duplicates do_dbtype in debug.cc.
static std::string
get_file_line (const std::string& name, int line)
{
// NAME should be an absolute file name and the file should exist.

std::ifstream fs = sys::ifstream (name.c_str (), std::ios::in);

std::string text;

if (fs)
{
int i = 1;

do
{
if (! std::getline (fs, text))
{
text = "";
break;
}
}
while (i++ < line);
}

return text;
}

void
base_parser::bison_error (const std::string& str)
{
Expand All @@ -5543,56 +5515,51 @@ OCTAVE_BEGIN_NAMESPACE(octave)
{
std::ostringstream output_buf;

if (m_lexer.m_reading_fcn_file || m_lexer.m_reading_script_file
|| m_lexer.m_reading_classdef_file)
output_buf << "parse error near line " << err_line
<< " of file " << m_lexer.m_fcn_file_full_name;
else
output_buf << "parse error:";

if (str != "parse error")
output_buf << "\n\n " << str;

output_buf << "\n\n";

std::string curr_line;

if (m_lexer.m_reading_fcn_file || m_lexer.m_reading_script_file
|| m_lexer.m_reading_classdef_file)
curr_line = get_file_line (m_lexer.m_fcn_file_full_name, err_line);
else
curr_line = m_lexer.m_current_input_line;
bool in_file = (m_lexer.m_reading_fcn_file || m_lexer.m_reading_script_file
|| m_lexer.m_reading_classdef_file);

// Adjust the error column for display because it is 1-based in the
// lexer for easier reporting.
err_col--;

if (! curr_line.empty ())
if (in_file)
{
// FIXME: we could do better if we just cached lines from the
// input file in a list. See also functions for managing input
// buffers in lex.ll.
output_buf << str
<< " near line " << err_line << ", column " << err_col << "\n"
<< "error: called from\n"
<< " " << m_lexer.m_fcn_file_name
<< " at line " << err_line << " column " << err_col << "\n";
}
else
{
// On command line, point directly to error
output_buf << str << "\n\n";
std::string curr_line = m_lexer.m_current_input_line;

std::size_t len = curr_line.length ();
if (! curr_line.empty ())
{
// FIXME: we could do better if we just cached lines from the
// input file in a list. See also functions for managing input
// buffers in lex.ll.
std::size_t len = curr_line.length ();

if (curr_line[len-1] == '\n')
curr_line.resize (len-1);
if (curr_line[len-1] == '\n')
curr_line.resize (len-1);

// Print the line, maybe with a pointer near the error token.
// Print the line, maybe with a pointer near the error token.
output_buf << ">>> " << curr_line << "\n";

output_buf << ">>> " << curr_line << "\n";
if (err_col == 0)
err_col = len;

if (err_col == 0)
err_col = len;
for (int i = 0; i < err_col + 3; i++)
output_buf << " ";

for (int i = 0; i < err_col + 3; i++)
output_buf << " ";
output_buf << "^" << "\n";
}

output_buf << "^";
}

output_buf << "\n";

m_parse_error_msg = output_buf.str ();
}

Expand Down

0 comments on commit aa7b15e

Please sign in to comment.