Skip to content

Commit 8c58412

Browse files
Console messaging optimization and addition of SNMFOptimizer.objective_log (#195)
* Added convergence flag to fit() * [pre-commit.ci] auto fixes from pre-commit hooks * Added news * Update snmf_class.py Was giving me errors due to use of python 3.11 * Update snmf_class.py Optimized print messages Added verbose option Added objective_log attribute to track objective function updates, with associated time stamps and specifying what matrix has been updated * Update snmf_class.py * Update snmf_class.py Fixed wrong setting of objective_log * Update snmf_class.py remove commented print messages * Update snmf_class.py Added "iteration" key to objective_log Removed redundant _objective_history attribute * Create optimize-messages.rst * Update snmf_class.py * Update snmf_class.py * Update snmf_class.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8dfb08d commit 8c58412

2 files changed

Lines changed: 130 additions & 57 deletions

File tree

news/optimize-messages.rst

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
**Added:**
2+
3+
* 'SNMFOptimizer.objective_log' attribute: a list of dictionaries that tracks the optimization
4+
process, recording the step, iteration, objective, and timestamp at each update.
5+
Uses the 'step', 'iteration', 'objective' and 'timestamp' keys.
6+
* 'SNMFOptimizer(verbose : Optional[bool])' option and SNMFOptimizer.verbose
7+
attribute to allow users to toggle diagnostic console output.
8+
9+
**Changed:**
10+
11+
* Modified all print messages for improved readability and tied them to the new
12+
verbose flag.
13+
* Refactored convergence checks and step-size calculations to pull objective
14+
values directly from objective_log instead of relying on a separate history
15+
array.
16+
17+
**Deprecated:**
18+
19+
* <news item>
20+
21+
**Removed:**
22+
23+
* Removed the 'SNMFOptimizer._objective_history' list, which was made redundant
24+
by the comprehensive 'SNMFOptimizer.objective_log' tracking system.
25+
26+
**Fixed:**
27+
28+
* <news item>
29+
30+
**Security:**
31+
32+
* <news item>

src/diffpy/stretched_nmf/snmf_class.py

Lines changed: 98 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import time
2+
13
import cvxpy as cp
24
import numpy as np
35
from scipy.optimize import minimize
@@ -82,6 +84,7 @@ def __init__(
8284
eta=0,
8385
random_state=None,
8486
show_plots=False,
87+
verbose=False,
8588
):
8689
"""Initialize an instance of sNMF with estimator
8790
hyperparameters.
@@ -126,6 +129,7 @@ def __init__(
126129
self.eta = eta
127130
self.random_state = random_state
128131
self.show_plots = show_plots
132+
self.verbose = verbose
129133

130134
self._rng = np.random.default_rng(self.random_state)
131135
self._plotter = SNMFPlotter() if self.show_plots else None
@@ -214,6 +218,7 @@ def _initialize_factors(
214218
[1, -2, 1],
215219
offsets=[0, 1, 2],
216220
shape=(self.n_signals_ - 2, self.n_signals_),
221+
dtype=float,
217222
)
218223

219224
def fit(
@@ -303,7 +308,14 @@ def fit(
303308
self.stretch_.copy(),
304309
]
305310
self.objective_difference_ = None
306-
self._objective_history = [self.objective_function_]
311+
self.objective_log = [
312+
{
313+
"step": "start",
314+
"iteration": 0,
315+
"objective": self.objective_function_,
316+
"timestamp": time.time(),
317+
}
318+
]
307319

308320
# Set up tracking variables for _update_components()
309321
self._prev_components = None
@@ -322,13 +334,15 @@ def fit(
322334
sparsity_term = self.eta * np.sum(
323335
np.sqrt(self.components_)
324336
) # Square root penalty
325-
objective_without_penalty = (
337+
base_obj = (
326338
self.objective_function_ - regularization_term - sparsity_term
327339
)
328-
print(
329-
f"Start, Objective function: {self.objective_function_:.5e}"
330-
f", Obj - reg/sparse: {objective_without_penalty:.5e}"
331-
)
340+
if self.verbose:
341+
print(
342+
f"\n--- Start ---"
343+
f"\nTotal Objective : {self.objective_function_:.5e}"
344+
f"\nBase Obj (No Reg) : {base_obj:.5e}"
345+
)
332346

333347
# Main optimization loop
334348
for outiter in range(self.max_iter):
@@ -347,35 +361,23 @@ def fit(
347361
sparsity_term = self.eta * np.sum(
348362
np.sqrt(self.components_)
349363
) # Square root penalty
350-
objective_without_penalty = (
364+
base_obj = (
351365
self.objective_function_ - regularization_term - sparsity_term
352366
)
353-
print(
354-
f"Obj fun: {self.objective_function_:.5e}, "
355-
f"Obj - reg/sparse: {objective_without_penalty:.5e}, "
356-
f"Iter: {self._outer_iter}"
357-
)
358-
obj_diff = (
359-
self.objective_function - regularization_term - sparsity_term
360-
)
361-
print(
362-
f"Obj fun: {self.objective_function:.5e}, "
363-
f", Obj - reg/sparse: {obj_diff:.5e}"
364-
f"Iter: {self.outiter}"
365-
)
366-
367+
convergence_threshold = self.objective_function_ * self.tol
367368
# Convergence check: Stop if objective_difference_ is small
368369
# and at least min_iter iterations have passed
369-
print(
370-
"Checking if ",
371-
self.objective_difference_,
372-
" < ",
373-
self.objective_function_ * self.tol,
374-
)
370+
if self.verbose:
371+
print(
372+
f"\n--- Iteration {self._outer_iter} ---"
373+
f"\nTotal Objective : {self.objective_function_:.5e}"
374+
f"\nBase Obj (No Reg) : {base_obj:.5e}"
375+
"\nConvergence Check : Δ "
376+
f"({self.objective_difference_:.2e})"
377+
f" < Threshold ({convergence_threshold:.2e})\n"
378+
)
375379
if (
376-
self.objective_difference_ is not None
377-
and self.objective_difference_
378-
< self.objective_function_ * self.tol
380+
self.objective_difference_ < convergence_threshold
379381
and outiter >= self.min_iter
380382
):
381383
self.converged_ = True
@@ -387,6 +389,8 @@ def fit(
387389
return self
388390

389391
def _normalize_results(self):
392+
if self.verbose:
393+
print("\nNormalizing results after convergence...")
390394
# Select our best results for normalization
391395
self.components_ = self.best_matrices_[0]
392396
self.weights_ = self.best_matrices_[1]
@@ -420,13 +424,17 @@ def _normalize_results(self):
420424
self._update_components()
421425
self.residuals_ = self._get_residual_matrix()
422426
self.objective_function_ = self._get_objective_function()
423-
print(
424-
f"Objective function after normalize_components: "
425-
f"{self.objective_function_:.5e}"
427+
self.objective_log.append(
428+
{
429+
"step": "c_norm",
430+
"iteration": outiter,
431+
"objective": self.objective_function_,
432+
"timestamp": time.time(),
433+
}
426434
)
427-
self._objective_history.append(self.objective_function_)
428435
self.objective_difference_ = (
429-
self._objective_history[-2] - self._objective_history[-1]
436+
self.objective_log[-2]["objective"]
437+
- self.objective_log[-1]["objective"]
430438
)
431439
if self._plotter is not None:
432440
self._plotter.update(
@@ -435,27 +443,41 @@ def _normalize_results(self):
435443
stretch=self.stretch_,
436444
update_tag="normalize components",
437445
)
446+
convergence_threshold = self.objective_function_ * self.tol
447+
if self.verbose:
448+
print(
449+
f"\n--- Iteration {outiter} after normalization---"
450+
f"\nTotal Objective : {self.objective_function_:.5e}"
451+
"\nConvergence Check : Δ "
452+
f"({self.objective_difference_:.2e})"
453+
f" < Threshold ({convergence_threshold:.2e})\n"
454+
)
438455
if (
439-
self.objective_difference_
440-
< self.objective_function_ * self.tol
456+
self.objective_difference_ < convergence_threshold
441457
and outiter >= 7
442458
):
443459
break
444460

445461
def _outer_loop(self):
462+
if self.verbose:
463+
print("Updating components and weights...")
446464
for inner_iter in range(4):
447465
self._inner_iter = inner_iter
448466
self._prev_grad_components = self._grad_components.copy()
449467
self._update_components()
450468
self.residuals_ = self._get_residual_matrix()
451469
self.objective_function_ = self._get_objective_function()
452-
print(
453-
f"Objective function after _update_components: "
454-
f"{self.objective_function_:.5e}"
470+
self.objective_log.append(
471+
{
472+
"step": "c",
473+
"iteration": self._outer_iter,
474+
"objective": self.objective_function_,
475+
"timestamp": time.time(),
476+
}
455477
)
456-
self._objective_history.append(self.objective_function_)
457478
self.objective_difference_ = (
458-
self._objective_history[-2] - self._objective_history[-1]
479+
self.objective_log[-2]["objective"]
480+
- self.objective_log[-1]["objective"]
459481
)
460482
if self.objective_function_ < self.best_objective_:
461483
self.best_objective_ = self.objective_function_
@@ -475,13 +497,18 @@ def _outer_loop(self):
475497
self._update_weights()
476498
self.residuals_ = self._get_residual_matrix()
477499
self.objective_function_ = self._get_objective_function()
478-
print(
479-
f"Objective function after _update_weights: "
480-
f"{self.objective_function_:.5e}"
500+
self.objective_log.append(
501+
{
502+
"step": "w",
503+
"iteration": self._outer_iter,
504+
"objective": self.objective_function_,
505+
"timestamp": time.time(),
506+
}
481507
)
482-
self._objective_history.append(self.objective_function_)
508+
483509
self.objective_difference_ = (
484-
self._objective_history[-2] - self._objective_history[-1]
510+
self.objective_log[-2]["objective"]
511+
- self.objective_log[-1]["objective"]
485512
)
486513
if self.objective_function_ < self.best_objective_:
487514
self.best_objective_ = self.objective_function_
@@ -499,10 +526,11 @@ def _outer_loop(self):
499526
)
500527

501528
self.objective_difference_ = (
502-
self._objective_history[-2] - self._objective_history[-1]
529+
self.objective_log[-2]["objective"]
530+
- self.objective_log[-1]["objective"]
503531
)
504532
if (
505-
self._objective_history[-3] - self.objective_function_
533+
self.objective_log[-3]["objective"] - self.objective_function_
506534
< self.objective_difference_ * 1e-3
507535
):
508536
break
@@ -512,13 +540,17 @@ def _outer_loop(self):
512540
self._update_stretch()
513541
self.residuals_ = self._get_residual_matrix()
514542
self.objective_function_ = self._get_objective_function()
515-
print(
516-
f"Objective function after _update_stretch: "
517-
f"{self.objective_function_:.5e}"
543+
self.objective_log.append(
544+
{
545+
"step": "s",
546+
"iteration": self._outer_iter,
547+
"objective": self.objective_function_,
548+
"timestamp": time.time(),
549+
}
518550
)
519-
self._objective_history.append(self.objective_function_)
520551
self.objective_difference_ = (
521-
self._objective_history[-2] - self._objective_history[-1]
552+
self.objective_log[-2]["objective"]
553+
- self.objective_log[-1]["objective"]
522554
)
523555
if self.objective_function_ < self.best_objective_:
524556
self.best_objective_ = self.objective_function_
@@ -804,7 +836,12 @@ def _solve_quadratic_program(self, t, m):
804836

805837
# Solve using a QP solver
806838
prob = cp.Problem(objective, constraints)
807-
prob.solve(solver=cp.OSQP, verbose=False)
839+
prob.solve(
840+
solver=cp.OSQP,
841+
verbose=False,
842+
polish=False, # TODO keep? removes polish message
843+
# solver_verbose=False
844+
)
808845

809846
# Get the solution
810847
return np.maximum(
@@ -814,6 +851,7 @@ def _solve_quadratic_program(self, t, m):
814851
def _update_components(self):
815852
"""Updates `components` using gradient-based optimization with
816853
adaptive step size."""
854+
817855
# Compute stretched components using the interpolation function
818856
stretched_components, _, _ = (
819857
self._compute_stretched_components()
@@ -878,8 +916,8 @@ def _update_components(self):
878916
)
879917
self.components_ = mask * self.components_
880918

881-
objective_improvement = self._objective_history[
882-
-1
919+
objective_improvement = self.objective_log[-1][
920+
"objective"
883921
] - self._get_objective_function(
884922
residuals=self._get_residual_matrix()
885923
)
@@ -962,6 +1000,9 @@ def _update_stretch(self):
9621000
"""Updates stretching matrix using constrained optimization
9631001
(equivalent to fmincon in MATLAB)."""
9641002

1003+
if self.verbose:
1004+
print("Updating stretch factors...")
1005+
9651006
# Flatten stretch for compatibility with the optimizer
9661007
# (since SciPy expects 1D input)
9671008
stretch_flat_initial = self.stretch_.flatten()

0 commit comments

Comments
 (0)