Skip to content

Commit

Permalink
fix: Expression::get_row / get_col no copy #4314 (#4315)
Browse files Browse the repository at this point in the history
* fix: Expression::get_row / get_col no copy #4314

* fix: const keyword

* fix: suppress warnings
  • Loading branch information
martinjrobins authored Aug 8, 2024
1 parent 4cb4edd commit b1fc595
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ class Expression {
/**
* @brief Returns row indices in COO format (where the output data represents sparse matrix elements)
*/
virtual std::vector<expr_int> get_row() = 0;
virtual const std::vector<expr_int>& get_row() = 0;

/**
* @brief Returns column indices in COO format (where the output data represents sparse matrix elements)
*/
virtual std::vector<expr_int> get_col() = 0;
virtual const std::vector<expr_int>& get_col() = 0;

public: // data members
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ CasadiFunction::CasadiFunction(const BaseFunctionType &f) : Expression(), m_func
m_res.resize(sz_res, nullptr);
m_iw.resize(sz_iw, 0);
m_w.resize(sz_w, 0);

if (m_func.n_out() > 0) {
casadi::Sparsity casadi_sparsity = m_func.sparsity_out(0);
m_rows = casadi_sparsity.get_row();
m_cols = casadi_sparsity.get_col();
}
}

// only call this once m_arg and m_res have been set appropriately
Expand All @@ -45,24 +51,14 @@ expr_int CasadiFunction::nnz_out() {
return static_cast<expr_int>(m_func.nnz_out());
}

std::vector<expr_int> CasadiFunction::get_row() {
return get_row(0);
}

std::vector<expr_int> CasadiFunction::get_row(expr_int ind) {
const std::vector<expr_int>& CasadiFunction::get_row() {
DEBUG("CasadiFunction get_row(): " << m_func.name());
casadi::Sparsity casadi_sparsity = m_func.sparsity_out(ind);
return casadi_sparsity.get_row();
}

std::vector<expr_int> CasadiFunction::get_col() {
return get_col(0);
return m_rows;
}

std::vector<expr_int> CasadiFunction::get_col(expr_int ind) {
const std::vector<expr_int>& CasadiFunction::get_col() {
DEBUG("CasadiFunction get_col(): " << m_func.name());
casadi::Sparsity casadi_sparsity = m_func.sparsity_out(ind);
return casadi_sparsity.get_col();
return m_cols;
}

void CasadiFunction::operator()(const std::vector<realtype*>& inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ class CasadiFunction : public Expression
expr_int out_shape(int k) override;
expr_int nnz() override;
expr_int nnz_out() override;
std::vector<expr_int> get_row() override;
std::vector<expr_int> get_row(expr_int ind);
std::vector<expr_int> get_col() override;
std::vector<expr_int> get_col(expr_int ind);
const std::vector<expr_int>& get_row() override;
const std::vector<expr_int>& get_col() override;

public:
/*
Expand All @@ -43,6 +41,8 @@ class CasadiFunction : public Expression
private:
std::vector<expr_int> m_iw; // cppcheck-suppress unusedStructMember
std::vector<double> m_w; // cppcheck-suppress unusedStructMember
std::vector<expr_int> m_rows; // cppcheck-suppress unusedStructMember
std::vector<expr_int> m_cols; // cppcheck-suppress unusedStructMember
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class IREEFunction : public Expression
expr_int out_shape(int k) override;
expr_int nnz() override;
expr_int nnz_out() override;
std::vector<expr_int> get_col() override;
std::vector<expr_int> get_row() override;
const std::vector<expr_int>& get_col() override;
const std::vector<expr_int>& get_row() override;

/*
* @brief Evaluate the MLIR function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ expr_int IREEFunction::nnz_out() {
return m_func.nnz;
}

std::vector<expr_int> IREEFunction::get_row() {
const std::vector<expr_int>& IREEFunction::get_row() {
DEBUG("IreeFunction get_row" << m_func.row.size());
return m_func.row;
}

std::vector<expr_int> IREEFunction::get_col() {
const std::vector<expr_int>& IREEFunction::get_col() {
DEBUG("IreeFunction get_col" << m_func.col.size());
return m_func.col;
}
Expand Down
8 changes: 4 additions & 4 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ void IDAKLUSolverOpenMP<ExprSet>::CalcVarsSensitivities(
DEBUG("IDAKLUSolver::CalcVarsSensitivities");
// Calculate sensitivities
std::vector<realtype> dens_dvar_dp = std::vector<realtype>(number_of_parameters, 0);
for (size_t dvar_k=0; dvar_k<functions->dvar_dy_fcns.size(); dvar_k++) {
for (size_t dvar_k = 0; dvar_k < functions->dvar_dy_fcns.size(); dvar_k++) {
// Isolate functions
Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k];
Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k];
Expand All @@ -306,15 +306,15 @@ void IDAKLUSolverOpenMP<ExprSet>::CalcVarsSensitivities(
// Calculate dvar/dp and convert to dense array for indexing
(*dvar_dp)({tret, yval, functions->inputs.data()}, {&res_dvar_dp[0]});
for (int k=0; k<number_of_parameters; k++) {
dens_dvar_dp[k]=0;
dens_dvar_dp[k] = 0;
}
for (int k=0; k<dvar_dp->nnz_out(); k++) {
dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k];
}
// Calculate sensitivities
for (int paramk=0; paramk<number_of_parameters; paramk++) {
for (int paramk = 0; paramk < number_of_parameters; paramk++) {
yS_return[*ySk] = dens_dvar_dp[paramk];
for (int spk=0; spk<dvar_dy->nnz_out(); spk++) {
for (int spk = 0; spk < dvar_dy->nnz_out(); spk++) {
yS_return[*ySk] += res_dvar_dy[spk] * ySval[paramk][dvar_dy->get_col()[spk]];
}
(*ySk)++;
Expand Down

0 comments on commit b1fc595

Please sign in to comment.