diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 0ef291a5..9db79ccc 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -517,7 +517,7 @@ def get_optimizer(hps, model=None, batch_axis_name=None): opt_init, opt_update = ( sharpness_aware_minimization.sharpness_aware_minimization( rho=hps.opt_hparams['rho'], - grad_clip=hps.get('grad_clip', None), + grad_clip=hps.opt_hparams.get('grad_clip', None), base_opt_init_fn=opt_init, base_opt_update_fn=opt_update, )