61 changes: 52 additions & 9 deletions init2winit/trainer_lib/base_trainer.py
@@ -62,7 +62,7 @@ def __init__(
      early_stopping_mode=None,
      early_stopping_min_steps=0,
      eval_steps=None,
      log_frequency=None,
      log_frequency=100,
      log_steps=None,
      metrics_logger=None,
      init_logger=None,
@@ -221,6 +221,8 @@ def __init__(
    # Initialized in train() when training starts.
    self._time_at_prev_eval_end = None
    self._prev_eval_step = None
    self._time_at_prev_log_end = None
    self._prev_log_step = None

    assert hps.batch_size % (jax.device_count()) == 0
    assert eval_batch_size % (jax.device_count()) == 0
@@ -444,14 +446,55 @@ def _check_early_stopping(self, report):
        self._early_stopping_target_value)
    return early_stopping_condition

  def _log_training_metrics(self):
    """Log training_algorithm metrics."""
  def _build_training_report(self):
    """Build a report dict with training metrics.

    Computes the mean training cost since the last eval and the training
    throughput (steps per second) since the last log. Does not perform any
    logging or update internal timing state.

    Returns:
      A dict containing global_step, train_cost, train_steps_per_sec, and
      the training algorithm's eval report metrics.
    """
    time_since_last_log = time.time() - self._time_at_prev_log_end
    steps_since_last_log = self._global_step - self._prev_log_step

    mean_train_cost = self._sum_train_cost / max(
        1, self._global_step - self._prev_eval_step
    )
    train_steps_per_sec = steps_since_last_log / time_since_last_log

    report = dict(global_step=self._global_step)
    report.update(self.training_algorithm.eval_report_metrics)
    report.update(
        train_cost=mean_train_cost,
        train_steps_per_sec=train_steps_per_sec,
    )
    return report

  def _log_training_metrics(self):
    """Log training_algorithm metrics."""
    report = self._build_training_report()
    if jax.process_index() == 0:
      logging.info('Step %d, training metrics: %r', self._global_step, report)
      if self._metrics_logger:
        self._metrics_logger.append_scalar_metrics(report)
      # Progress message with ETA estimate.
      steps_remaining = self._num_train_steps - self._global_step
      pct = self._global_step / self._num_train_steps * 100
      avg_step_time = 1.0 / report['train_steps_per_sec']
      eta = steps_remaining * avg_step_time
      trainer_utils.log_message(
          f'Step {self._global_step}/{self._num_train_steps}'
          f' [{pct:.0f}%].'
          f' Average step time is {avg_step_time:.2f}s.'
          ' Estimated time left is'
          f' {trainer_utils.format_time(eta)}.',
          self._logging_pool,
          self._xm_work_unit,
      )
    self._time_at_prev_log_end = time.time()
    self._prev_log_step = self._global_step
    return report

  def _eval(self, start_step, start_time, eval_rng, save=True):
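Note: a minimal standalone sketch of the arithmetic introduced in this hunk (hypothetical names and numbers, not code from this PR). `_build_training_report` averages the summed cost over steps since the last eval, measures throughput over steps since the last log, and `_log_training_metrics` turns that throughput into an ETA.

```python
import time

def build_training_report(global_step, prev_log_step, prev_eval_step,
                          sum_train_cost, time_at_prev_log_end):
  """Standalone illustration of the report arithmetic."""
  time_since_last_log = time.time() - time_at_prev_log_end
  steps_since_last_log = global_step - prev_log_step
  # The cost sum is only reset at eval time, so the mean is taken over
  # steps since the last eval; max(1, ...) guards against step 0.
  mean_train_cost = sum_train_cost / max(1, global_step - prev_eval_step)
  train_steps_per_sec = steps_since_last_log / time_since_last_log
  return dict(
      global_step=global_step,
      train_cost=mean_train_cost,
      train_steps_per_sec=train_steps_per_sec,
  )

# ETA arithmetic from _log_training_metrics: at step 4000 of 10000 and
# 2 steps/sec, 6000 remaining steps * 0.5 s/step = 3000 s left.
steps_remaining = 10_000 - 4_000
avg_step_time = 1.0 / 2.0
eta = steps_remaining * avg_step_time  # 3000.0 seconds
```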
@@ -505,26 +548,22 @@ def _eval(self, start_step, start_time, eval_rng, save=True):
    run_time = time.time() - self._time_at_prev_eval_end
    steps_per_sec = steps_since_last_eval / run_time

    mean_train_cost = self._sum_train_cost / max(
        1, self._global_step - self._prev_eval_step
    )
    report.update(self._build_training_report())

    self._sum_train_cost = 0.0
    epoch = self._global_step * self._hps.batch_size // self._hps.train_size
    overall_steps_per_sec = self._get_step_frequency(
        self._global_step, start_step, start_time)
    report.update(
        global_step=self._global_step,
        epoch=epoch,
        preemption_count=self._preemption_count,
        train_cost=mean_train_cost,
        overall_steps_per_sec=overall_steps_per_sec,
        steps_per_sec_no_eval=steps_per_sec_no_eval,
        steps_per_sec=steps_per_sec,
        eval_time=eval_time,
        run_time_no_eval=time_since_last_eval,
        run_time=run_time,
    )
    report.update(self.training_algorithm.eval_report_metrics)
    if jax.process_index() == 0:
      trainer_utils.log_eta(
          self._logging_pool,
@@ -540,6 +579,8 @@ def _eval(self, start_step, start_time, eval_rng, save=True):

    self._time_at_prev_eval_end = time.time()
    self._prev_eval_step = self._global_step
    self._time_at_prev_log_end = self._time_at_prev_eval_end
    self._prev_log_step = self._global_step
    return report

  def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng):
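Note: resetting `_time_at_prev_log_end` and `_prev_log_step` at the end of `_eval` keeps evaluation time out of the next throughput window. A rough, runnable sketch of the intended cadence (the driver loop and all names here are assumptions for illustration, not code from this PR):

```python
import time

def run(num_train_steps=10, log_frequency=3, eval_steps=(5,)):
  """Toy loop showing why an eval must also reset the log timer."""
  time_at_prev_log_end = time.time()
  prev_log_step = 0
  for step in range(1, num_train_steps + 1):
    time.sleep(0.01)  # stand-in for one training step
    if step % log_frequency == 0:  # log_frequency now defaults to 100
      dt = time.time() - time_at_prev_log_end
      print(f'step {step}: {(step - prev_log_step) / dt:.1f} steps/sec')
      time_at_prev_log_end = time.time()  # log timer resets after logging
      prev_log_step = step
    if step in eval_steps:
      time.sleep(0.05)  # stand-in for a slow eval
      # Without this reset, the next steps/sec window would include the
      # eval time and under-report training throughput.
      time_at_prev_log_end = time.time()
      prev_log_step = step

run()
```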
@@ -691,6 +732,8 @@ def train(self):
    # that we can increment as in the case of train_time.
    self._time_at_prev_eval_end = start_time
    self._prev_eval_step = self._global_step
    self._time_at_prev_log_end = start_time
    self._prev_log_step = self._global_step

    if self._global_step in self._checkpoint_steps:
      self._save(checkpoint_manager=self._orbax_checkpoint_manager_extra)