logging_optimizer.py

  1. """
  2. Optimization with logging via a hook.
  3. Author: Jing Wu and Frank Dellaert
  4. """
  5. # pylint: disable=invalid-name
  6. from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
  7. import gtsam
  8. def optimize(optimizer, check_convergence, hook):
  9. """ Given an optimizer and a convergence check, iterate until convergence.
  10. After each iteration, hook(optimizer, error) is called.
  11. After the function, use values and errors to get the result.
  12. Arguments:
  13. optimizer (T): needs an iterate and an error function.
  14. check_convergence: T * float * float -> bool
  15. hook -- hook function to record the error
  16. """
  17. # the optimizer is created with default values which incur the error below
  18. current_error = optimizer.error()
  19. hook(optimizer, current_error)
  20. # Iterative loop
  21. while True:
  22. # Do next iteration
  23. optimizer.iterate()
  24. new_error = optimizer.error()
  25. hook(optimizer, new_error)
  26. if check_convergence(optimizer, current_error, new_error):
  27. return
  28. current_error = new_error
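

# Example hook and convergence check (an illustrative sketch, not part of the
# original module): any callables with these signatures can be passed to
# optimize() above.
def example_hook(optimizer, error):
    """Print the current error; a hook may also append it to a list for plotting."""
    print("current error:", error)


def example_check_convergence(optimizer, current_error, new_error):
    """Declare convergence once the error improvement becomes negligible."""
    return current_error - new_error < 1e-9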


def gtsam_optimize(optimizer,
                   params,
                   hook):
    """ Given a GTSAM optimizer and its params, iterate until convergence.
        After each iteration, hook(optimizer, error) is called.
        After the function returns, use optimizer.values() and optimizer.error()
        to retrieve the result.
        Arguments:
            optimizer {NonlinearOptimizer} -- Nonlinear optimizer
            params {NonlinearOptimizerParams} -- Nonlinear optimizer parameters
            hook -- hook function to record the error
    """
    def check_convergence(optimizer, current_error, new_error):
        # Stop when the iteration cap is reached, when GTSAM's own convergence
        # test passes, or (for Levenberg-Marquardt) when lambda exceeds its
        # upper bound.
        return (optimizer.iterations() >= params.getMaxIterations()) or (
            gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(),
                                   params.getErrorTol(), current_error, new_error)) or (
            isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer)
            and optimizer.lambda_() > params.getlambdaUpperBound())

    optimize(optimizer, check_convergence, hook)
    return optimizer.values()
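

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): optimize a tiny factor graph
# and log the error after every iteration. The class and factory names below
# (Pose2, PriorFactorPose2, noiseModel.Unit, GaussNewtonOptimizer) assume the
# standard GTSAM 4.x Python wrapper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A single prior factor on a Pose2, so the optimum is the prior itself.
    graph = gtsam.NonlinearFactorGraph()
    prior = gtsam.Pose2(0.0, 0.0, 0.0)
    graph.add(gtsam.PriorFactorPose2(1, prior, gtsam.noiseModel.Unit.Create(3)))

    # Start away from the optimum so the optimizer has work to do.
    initial = gtsam.Values()
    initial.insert(1, gtsam.Pose2(1.0, 2.0, 0.3))

    params = gtsam.GaussNewtonParams()
    optimizer = gtsam.GaussNewtonOptimizer(graph, initial, params)

    # Collect the error trace via the hook and print it as we go.
    errors = []

    def print_hook(optimizer, error):
        errors.append(error)
        print("iteration {}: error = {:.6f}".format(optimizer.iterations(), error))

    result = gtsam_optimize(optimizer, params, print_hook)
    print("error trace:", errors)
    print("final estimate:\n", result.atPose2(1))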