Skip to content

Commit

Permalink
Merge pull request #80 from Huangzizhou/slim
Browse files Browse the repository at this point in the history
smooth_step after line search for slim
  • Loading branch information
Huangzizhou authored Sep 12, 2024
2 parents 0cd21ba + 5e8eb3d commit 68aacce
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/polysolve/nonlinear/Problem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ namespace polysolve::nonlinear
/// @param new_x New solution.
virtual void solution_changed(const TVector &new_x) {}

virtual bool after_line_search_custom_operation(const TVector &x0, const TVector &x1) { return false; }

/// @brief Callback function used to determine if the solver should stop.
/// @param state Current state of the solver.
/// @param x Current solution.
Expand Down
19 changes: 14 additions & 5 deletions src/polysolve/nonlinear/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,13 @@ namespace polysolve::nonlinear
continue;
}

x += rate * delta_x;
{
TVector x1 = x + rate * delta_x;
if (objFunc.after_line_search_custom_operation(x, x1))
objFunc.solution_changed(x1);
x = x1;
}

old_energy = energy;

// Reset this for the next iterations
Expand Down Expand Up @@ -574,6 +580,7 @@ namespace polysolve::nonlinear
void Solver::verify_gradient(Problem &objFunc, const TVector &x, const TVector &grad)
{
bool match = false;
double J = objFunc(x);

switch (gradient_fd_strategy)
{
Expand All @@ -591,16 +598,18 @@ namespace polysolve::nonlinear
objFunc.solution_changed(x1);
double J1 = objFunc(x1);

double fd = (J2 - J1) / 2 / gradient_fd_eps;
double fd_centered = (J2 - J1) / 2 / gradient_fd_eps;
double fd_right = (J2 - J) / gradient_fd_eps;
double fd_left = (J - J1) / gradient_fd_eps;
double analytic = direc.dot(grad);

match = abs(fd - analytic) < 1e-8 || abs(fd - analytic) < 1e-4 * abs(analytic);
match = abs(fd_centered - analytic) < 1e-8 || abs(fd_centered - analytic) < 1e-4 * abs(analytic);

// Log error in either case to make it more visible in the logs.
if (match)
m_logger.debug("step size: {}, finite difference: {}, derivative: {}", gradient_fd_eps, fd, analytic);
m_logger.debug("step size: {}, finite difference: {} {} {}, derivative: {}", gradient_fd_eps, fd_centered, fd_left, fd_right, analytic);
else
m_logger.error("step size: {}, finite difference: {}, derivative: {}", gradient_fd_eps, fd, analytic);
m_logger.error("step size: {}, finite difference: {} {} {}, derivative: {}", gradient_fd_eps, fd_centered, fd_left, fd_right, analytic);
}
break;
case FiniteDiffStrategy::FULL_FINITE_DIFF:
Expand Down

0 comments on commit 68aacce

Please sign in to comment.