From 1b35108586270cf3f13bbb6ebebfbe94646cc348 Mon Sep 17 00:00:00 2001 From: Ahmed Khaled Date: Sun, 26 Apr 2026 14:20:16 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 906019116 --- init2winit/trainer_lib/base_trainer.py | 61 ++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 2a253069..32d3d7a3 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -62,7 +62,7 @@ def __init__( early_stopping_mode=None, early_stopping_min_steps=0, eval_steps=None, - log_frequency=None, + log_frequency=100, log_steps=None, metrics_logger=None, init_logger=None, @@ -221,6 +221,8 @@ def __init__( # Initialized in train() when training starts. self._time_at_prev_eval_end = None self._prev_eval_step = None + self._time_at_prev_log_end = None + self._prev_log_step = None assert hps.batch_size % (jax.device_count()) == 0 assert eval_batch_size % (jax.device_count()) == 0 @@ -444,14 +446,55 @@ def _check_early_stopping(self, report): self._early_stopping_target_value) return early_stopping_condition - def _log_training_metrics(self): - """Log training_algorithm metrics.""" + def _build_training_report(self): + """Build a report dict with training metrics. + + Computes mean training cost (since the last eval) and training throughput + (since the last log or eval). Does not log or update internal timing state. + + Returns: + A dict containing global_step, train_cost, train_steps_per_sec, and + training algorithm eval report metrics. 
+ """ + time_since_last_log = time.time() - self._time_at_prev_log_end + steps_since_last_log = self._global_step - self._prev_log_step + + mean_train_cost = self._sum_train_cost / max( + 1, self._global_step - self._prev_eval_step + ) + train_steps_per_sec = steps_since_last_log / time_since_last_log + report = dict(global_step=self._global_step) report.update(self.training_algorithm.eval_report_metrics) + report.update( + train_cost=mean_train_cost, + train_steps_per_sec=train_steps_per_sec, + ) + return report + + def _log_training_metrics(self): + """Log training_algorithm metrics.""" + report = self._build_training_report() if jax.process_index() == 0: logging.info('Step %d, training metrics: %r', self._global_step, report) if self._metrics_logger: self._metrics_logger.append_scalar_metrics(report) + # Progress message with ETA estimate. + steps_remaining = self._num_train_steps - self._global_step + pct = self._global_step / self._num_train_steps * 100 + avg_step_time = 1.0 / report['train_steps_per_sec'] + eta = steps_remaining * avg_step_time + trainer_utils.log_message( + f'Step {self._global_step}/{self._num_train_steps}' + f' [{pct:.0f}%].' + f' Average step time is {avg_step_time:.2f}s.' 
+ ' Estimated time left is' + f' {trainer_utils.format_time(eta)}.', + self._logging_pool, + self._xm_work_unit, + ) + self._time_at_prev_log_end = time.time() + self._prev_log_step = self._global_step return report def _eval(self, start_step, start_time, eval_rng, save=True): @@ -505,18 +548,15 @@ def _eval(self, start_step, start_time, eval_rng, save=True): run_time = time.time() - self._time_at_prev_eval_end steps_per_sec = steps_since_last_eval / run_time - mean_train_cost = self._sum_train_cost / max( - 1, self._global_step - self._prev_eval_step - ) + report.update(self._build_training_report()) + self._sum_train_cost = 0.0 epoch = self._global_step * self._hps.batch_size // self._hps.train_size overall_steps_per_sec = self._get_step_frequency( self._global_step, start_step, start_time) report.update( - global_step=self._global_step, epoch=epoch, preemption_count=self._preemption_count, - train_cost=mean_train_cost, overall_steps_per_sec=overall_steps_per_sec, steps_per_sec_no_eval=steps_per_sec_no_eval, steps_per_sec=steps_per_sec, @@ -524,7 +564,6 @@ def _eval(self, start_step, start_time, eval_rng, save=True): run_time_no_eval=time_since_last_eval, run_time=run_time, ) - report.update(self.training_algorithm.eval_report_metrics) if jax.process_index() == 0: trainer_utils.log_eta( self._logging_pool, @@ -540,6 +579,8 @@ def _eval(self, start_step, start_time, eval_rng, save=True): self._time_at_prev_eval_end = time.time() self._prev_eval_step = self._global_step + self._time_at_prev_log_end = self._time_at_prev_eval_end + self._prev_log_step = self._global_step return report def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng): @@ -691,6 +732,8 @@ def train(self): # that we can increment as in the case of train_time. 
self._time_at_prev_eval_end = start_time self._prev_eval_step = self._global_step + self._time_at_prev_log_end = start_time + self._prev_log_step = self._global_step if self._global_step in self._checkpoint_steps: self._save(checkpoint_manager=self._orbax_checkpoint_manager_extra)