diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index b145599d..0dffe467 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -21,6 +21,8 @@ from init2winit.dataset_lib import datasets from init2winit.init_lib import initializers from init2winit.model_lib import models +from init2winit.trainer_lib import training_algorithm +from init2winit.trainer_lib import training_algorithms from ml_collections.config_dict import config_dict @@ -95,18 +97,24 @@ def expand_dot_keys(d): return expanded_dict -def build_hparams(model_name, - initializer_name, - dataset_name, - hparam_overrides, - input_pipeline_hps=None, - allowed_unrecognized_hparams=None): +def build_hparams( + model_name, + initializer_name, + dataset_name, + training_algorithm_name, + hparam_overrides, + input_pipeline_hps=None, + allowed_unrecognized_hparams=None, +): """Build experiment hyperparameters. Args: model_name: the string model name. initializer_name: the string initializer name. dataset_name: the string dataset name. + training_algorithm_name: the string name of the training algorithm. Used to + look up algorithm-specific default training hyperparameters (optimizer, + opt_hparams, lr_hparams). hparam_overrides: a dict of hyperparameter override names/values, or a JSON string encoding of this hyperparameter override dict. Note that this is applied after the hyperparameter file overrides. @@ -125,22 +133,49 @@ def build_hparams(model_name, initializer_hps = initializers.get_initializer_hparams(initializer_name) dataset_hps = datasets.get_dataset_hparams(dataset_name) input_pipeline_hps = input_pipeline_hps or config_dict.ConfigDict() + overrides_dict = hparam_overrides or {} + if isinstance(overrides_dict, str): + overrides_dict = json.loads(overrides_dict) + + # Training hparams come from the training algorithm. + algo_cls = training_algorithms.get_training_algorithm(training_algorithm_name) + + # For OptaxTrainingAlgorithm, pass optimizer_name (if overridden) and + # model_name so it can resolve defaults using the 3-tier hierarchy. + if issubclass(algo_cls, training_algorithm.OptaxTrainingAlgorithm): + optimizer_name = ( + None if not overrides_dict else overrides_dict.get('optimizer') + ) + training_hps = algo_cls.get_default_training_hparams( + optimizer_name=optimizer_name, + model_name=model_name, + ) + else: + training_hps = algo_cls.get_default_training_hparams() merged_dict = {} hps_dicts = [ hps.to_dict() - for hps in [model_hps, initializer_hps, dataset_hps, input_pipeline_hps] + for hps in [ + training_hps, + model_hps, + initializer_hps, + dataset_hps, + input_pipeline_hps, + ] ] - total_hps = 0 + # Check that all provided hps have no overlap. + seen_keys = set() for hps_dict in hps_dicts: - merged_dict.update(hps_dict) - total_hps += len(hps_dict.keys()) + overlap = seen_keys.intersection(hps_dict.keys()) + if overlap: + raise ValueError(f'There is overlap in the provided hparams: {overlap}') + seen_keys.update(hps_dict.keys()) - # Check that all provided have no overlap. - if total_hps != len(merged_dict.keys()): - raise ValueError('There is overlap in the provided hparams.') + for hps_dict in hps_dicts: + merged_dict.update(hps_dict) # Convert to the Shallue and Lee label smoothing style. 
if merged_dict.get('use_shallue_label_smoothing', False): @@ -159,34 +194,33 @@ def build_hparams(model_name, for key in ['opt_hparams', 'lr_hparams']: merged[key].unlock() - if hparam_overrides: - if isinstance(hparam_overrides, str): - hparam_overrides = json.loads(hparam_overrides) - + if overrides_dict: # If the user is changing the learning rate schedule or optimizer. We must # wipe all of the keys from the old dictionary. merged_schedule = None if merged.get('lr_hparams'): merged_schedule = merged['lr_hparams'].get('schedule') overrides_schedule = None - if hparam_overrides.get('lr_hparams'): - overrides_schedule = hparam_overrides['lr_hparams'].get('schedule') + if overrides_dict.get('lr_hparams'): + overrides_schedule = overrides_dict['lr_hparams'].get('schedule') if overrides_schedule and merged_schedule != overrides_schedule: merged['lr_hparams'] = {} - if ('optimizer' in hparam_overrides and - merged['optimizer'] != hparam_overrides['optimizer']): + if ( + 'optimizer' in overrides_dict + and merged['optimizer'] != overrides_dict['optimizer'] + ): merged['opt_hparams'] = {} - hparam_overrides = expand_dot_keys(hparam_overrides) + overrides_dict = expand_dot_keys(overrides_dict) if allowed_unrecognized_hparams: - new_keys = [k for k in hparam_overrides if k not in merged] + new_keys = [k for k in overrides_dict if k not in merged] if new_keys: logging.warning('Unrecognized top-level hparams: %s', new_keys) if any(k not in allowed_unrecognized_hparams for k in new_keys): raise ValueError( f'Unrecognized top-level hparams not in allowlist: {new_keys}') with merged.unlocked(): - merged.update(hparam_overrides) + merged.update(overrides_dict) else: - merged.update(hparam_overrides) + merged.update(overrides_dict) return merged diff --git a/init2winit/main.py b/init2winit/main.py index 946a708f..e13451d7 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -165,9 +165,11 @@ def _run( model_name=model_name, initializer_name=initializer_name, dataset_name=dataset_name, + training_algorithm_name=training_algorithm_name, hparam_overrides=hparam_overrides, input_pipeline_hps=input_pipeline_hps, - allowed_unrecognized_hparams=allowed_unrecognized_hparams) + allowed_unrecognized_hparams=allowed_unrecognized_hparams, + ) # Note that one should never tune an RNG seed!!! The seed is only included in # the hparams for convenience of running hparam trials with multiple seeds per diff --git a/init2winit/model_lib/adabelief_densenet.py b/init2winit/model_lib/adabelief_densenet.py index af64b028..9b06ac5e 100644 --- a/init2winit/model_lib/adabelief_densenet.py +++ b/init2winit/model_lib/adabelief_densenet.py @@ -48,27 +48,12 @@ # results in a large Dense matrix in the readout layer and unstable # training. 
use_kernel_size_as_stride_in_pooling=True, - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, normalizer='batch_norm', - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, normalize_classifier_input='none', classification_scale_factor=1.0, - )) + ) +) class BottleneckBlock(nn.Module): diff --git a/init2winit/model_lib/adabelief_resnet.py b/init2winit/model_lib/adabelief_resnet.py index 72d6282b..6954e771 100644 --- a/init2winit/model_lib/adabelief_resnet.py +++ b/init2winit/model_lib/adabelief_resnet.py @@ -49,29 +49,12 @@ dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, # Make this a string to avoid having to import jnp into the configs. model_dtype='float32', virtual_batch_size=None, - total_accumulated_batch_size=None, data_format='NHWC', - grad_clip=None, )) diff --git a/init2winit/model_lib/adabelief_vgg.py b/init2winit/model_lib/adabelief_vgg.py index d7e1b143..95422ea1 100644 --- a/init2winit/model_lib/adabelief_vgg.py +++ b/init2winit/model_lib/adabelief_vgg.py @@ -39,25 +39,10 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( num_layers=11, # Must be one of [11, 13, 16, 19] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, normalizer='none', - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, - )) + ) +) def classifier(x, num_outputs, dropout_rate, deterministic): diff --git a/init2winit/model_lib/autoencoder.py b/init2winit/model_lib/autoencoder.py index 36c803d7..9e6c7bb0 100644 --- a/init2winit/model_lib/autoencoder.py +++ b/init2winit/model_lib/autoencoder.py @@ -37,27 +37,7 @@ hid_sizes=[128, 64, 32, 64, 128], activation_function=['relu', 'relu', 'relu', 'relu', 'relu'], kernel_scales=[1.0] * 6, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='hessian_free', - opt_hparams={ - 'cg_max_iter': 250, - 'cg_iter_tracking_method': 'back_tracking', - 'use_line_search': True, - 'init_damping': 50.0, - 'damping_ub': 10 ** 2, - 'damping_lb': 10 ** -6, - }, - batch_size=128, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - l2_decay_factor=2e-5, - l2_decay_rank_threshold=1, )) @@ -82,4 +62,3 @@ def get_fake_inputs(self, hps): jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) ] return dummy_inputs - \ No newline at end of file diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py index c732d372..146ef857 100644 --- a/init2winit/model_lib/conformer.py +++ b/init2winit/model_lib/conformer.py @@ -46,25 +46,8 @@ MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation_function='swish', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 
'weight_decay': 0.0 - }, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=5.0, encoder_dim=512, num_attention_heads=8, num_encoder_layers=16, @@ -84,7 +67,6 @@ enable_decoder_pre_layer_norm=True, enable_conformer_post_layer_norm=True, use_lingvo_attention=False, - total_accumulated_batch_size=None, attn_temperature=1.0, )) @@ -92,25 +74,8 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation_function='swish', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=5.0, encoder_dim=512, num_attention_heads=8, num_encoder_layers=16, @@ -128,7 +93,6 @@ enable_decoder_pre_layer_norm=True, enable_conformer_post_layer_norm=True, use_lingvo_attention=False, - total_accumulated_batch_size=None, attn_temperature=1.0)) diff --git a/init2winit/model_lib/convolutional_autoencoder.py b/init2winit/model_lib/convolutional_autoencoder.py index f04fc459..607b9bf5 100644 --- a/init2winit/model_lib/convolutional_autoencoder.py +++ b/init2winit/model_lib/convolutional_autoencoder.py @@ -49,25 +49,8 @@ 'paddings': ['SAME', ((1, 0), (1, 0)), 'SAME', 'SAME'], 'activations': ['relu', 'relu', 'relu', 'id'], }, - activation_function='relu', - lr_hparams={ - 'base_lr': 0.02, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0, - }, - batch_size=128, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, )) diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index 3c4fba49..89b33b8b 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -47,21 +47,8 @@ MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation='relu', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=10.0, num_lstm_layers=4, num_ffn_layers=3, encoder_dim=512, @@ -79,7 +66,6 @@ enable_residual_connections=False, enable_decoder_layer_norm=False, bidirectional=True, - total_accumulated_batch_size=None, enable_subsampling_batchnorm=False, enable_synced_batchnorm=False, layernorm_everywhere=False)) @@ -88,21 +74,8 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation='relu', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=10.0, num_lstm_layers=4, num_ffn_layers=3, encoder_dim=512, @@ -119,7 +92,6 @@ enable_residual_connections=False, enable_decoder_layer_norm=False, bidirectional=True, - total_accumulated_batch_size=None, enable_subsampling_batchnorm=False, enable_synced_batchnorm=False, layernorm_everywhere=False)) diff --git 
a/init2winit/model_lib/dlrm.py b/init2winit/model_lib/dlrm.py index b340ab56..786b37f5 100644 --- a/init2winit/model_lib/dlrm.py +++ b/init2winit/model_lib/dlrm.py @@ -33,7 +33,6 @@ dict( activation_function='relu', embedding_init_multiplier=None, - rng_seed=-1, model_dtype='float32', vocab_size=32 * 128 * 1024, mlp_bottom_dims=[128, 128], @@ -41,22 +40,7 @@ output_shape=(1,), embed_dim=64, keep_diags=True, - optimizer='adam', - batch_size=128, num_dense_features=13, - lr_hparams={ - 'base_lr': 0.01, - 'schedule': 'constant' - }, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - }, - l2_decay_factor=1e-5, - l2_decay_rank_threshold=2, - total_accumulated_batch_size=None, - grad_clip=None, dropout_rate=0.0, normalizer='none', # dropout will exist only if there are at least two top mlp layers diff --git a/init2winit/model_lib/fully_connected.py b/init2winit/model_lib/fully_connected.py index 641e057e..4bd60ef9 100644 --- a/init2winit/model_lib/fully_connected.py +++ b/init2winit/model_lib/fully_connected.py @@ -30,26 +30,10 @@ dict( hid_sizes=[20, 10], kernel_scales=[1.0, 1.0, 1.0], - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - total_accumulated_batch_size=None, activation_function='relu', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, - )) + ) +) class FullyConnected(nn.Module): diff --git a/init2winit/model_lib/gnn.py b/init2winit/model_lib/gnn.py index 33ff062e..9f7da0b7 100644 --- a/init2winit/model_lib/gnn.py +++ b/init2winit/model_lib/gnn.py @@ -37,32 +37,13 @@ # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - rng_seed=-1, model_dtype='float32', latent_dim=256, - optimizer='adam', hidden_dims=(256,), - batch_size=256, - lr_hparams={ - 'base_lr': 0.01, - 'schedule': 'constant' - }, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, activation_function='relu', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, num_message_passing_steps=5, normalizer='layer_norm', dropout_rate=0.1, - total_accumulated_batch_size=None, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, )) diff --git a/init2winit/model_lib/local_attention_transformer.py b/init2winit/model_lib/local_attention_transformer.py index 9709da74..b489cd01 100644 --- a/init2winit/model_lib/local_attention_transformer.py +++ b/init2winit/model_lib/local_attention_transformer.py @@ -79,31 +79,8 @@ feedforward_dropout=0.0, feedforward_depths=[4096, 1032], model_dtype='float32', - batch_size=8, - grad_clip=None, - lr_hparams={ - 'base_lr': 0.01, - 'defer_steps': 10000, - 'schedule': 't2t_rsqrt_normalized_decay', - }, - optimizer='adafactor', - opt_hparams={ - 'adafactor_decay_rate': 0.8, - 'clipping_threshold': 1.0, - 'factored': True, - 'min_dim_size_to_factor': 128, - # The 2 hyperparameters cause errors with optax.inject_hyperparams - # In this case it is not relevant since the default - # adafactors values are needed - # 'adafactor_momentum': 0.0, - # 'multiply_by_parameter_scale': True, - }, - # Below hyperparameters needed only to make the model - # compatible with init2winit library - rng_seed=-1, - label_smoothing=None, - weight_decay=None, - l2_decay_factor=None,)) + ) +) Tensor = Union[np.array, jnp.ndarray] diff --git 
a/init2winit/model_lib/lstm_lm.py b/init2winit/model_lib/lstm_lm.py index f4276131..5bc2b554 100644 --- a/init2winit/model_lib/lstm_lm.py +++ b/init2winit/model_lib/lstm_lm.py @@ -35,9 +35,6 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - # training params - batch_size=256, - rng_seed=-1, # model architecture params model_dtype='float32', bidirectional=False, @@ -49,20 +46,6 @@ recurrent_dropout_rate=0.1, tie_embeddings=False, projection_layer=False, - # optimizer params - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 1e-3, - }, - l2_decay_factor=None, - grad_clip=None, - optimizer='adam', - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0, - }, ) ) diff --git a/init2winit/model_lib/max_pooling_cnn.py b/init2winit/model_lib/max_pooling_cnn.py index 1520c7a9..d73d0a64 100644 --- a/init2winit/model_lib/max_pooling_cnn.py +++ b/init2winit/model_lib/max_pooling_cnn.py @@ -38,26 +38,9 @@ window_paddings=['SAME', 'SAME', 'SAME'], strides=[2, 2, 2], num_dense_units=[512, 256], - lr_hparams={ - 'base_lr': 0.001, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, activation_fn='relu', normalizer='none', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, - total_accumulated_batch_size=None, )) diff --git a/init2winit/model_lib/mdlm_rope_nanodo.py b/init2winit/model_lib/mdlm_rope_nanodo.py index 97815764..74c3d80b 100644 --- a/init2winit/model_lib/mdlm_rope_nanodo.py +++ b/init2winit/model_lib/mdlm_rope_nanodo.py @@ -37,23 +37,8 @@ num_heads=8, num_layers=12, mlp_dim=2048, - rng_seed=-1, computation_dtype='bfloat16', model_dtype='float32', - optimizer='adam', - batch_size=256, - lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, - l2_decay_factor=0.0005, - l2_decay_rank_threshold=2, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, normalization='rmsnorm', mlp_activation='glu', qk_norm=True, diff --git a/init2winit/model_lib/mlperf_resnet.py b/init2winit/model_lib/mlperf_resnet.py index 59c8d562..5bbf85f3 100644 --- a/init2winit/model_lib/mlperf_resnet.py +++ b/init2winit/model_lib/mlperf_resnet.py @@ -30,33 +30,10 @@ FAKE_MODEL_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'batch_size': 128, - 'base_lr': 10.0, - 'decay_end': -1, - 'end_lr': 1e-4, - 'power': 2.0, - 'schedule': 'mlperf_polynomial', - 'start_lr': 0.0, - 'steps_per_epoch': 10009.250000000002, - 'warmup_steps': 18, - }, - optimizer='mlperf_lars_resnet', - opt_hparams={ - 'weight_decay': 2e-4, - 'beta': 0.9 - }, - batch_size=128, - l2_decay_factor=None, - l2_decay_rank_threshold=2, - label_smoothing=.1, - use_shallue_label_smoothing=False, model_dtype='float32', virtual_batch_size=64, data_format='NHWC', activation_function='relu', - grad_clip=None, dropout_rate=0.0, )) @@ -66,30 +43,13 @@ num_filters=16, # We set default to 18 for faster unit tests. 
num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, bn_output_scale=0.0, - l2_decay_factor=None, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, model_dtype='float32', virtual_batch_size=64, - total_accumulated_batch_size=None, data_format='NHWC', activation_function='relu', - grad_clip=None, dropout_rate=0.0, )) diff --git a/init2winit/model_lib/nanodo.py b/init2winit/model_lib/nanodo.py index 99fb0a0e..dac53a55 100644 --- a/init2winit/model_lib/nanodo.py +++ b/init2winit/model_lib/nanodo.py @@ -41,23 +41,8 @@ num_heads=8, # num attention heads num_layers=6, # number of transformer block layers mlp_dim=2048, # FF inner dimension - rng_seed=-1, computation_dtype='bfloat16', model_dtype='float32', - optimizer='adam', - batch_size=256, - lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, - l2_decay_factor=0.0005, - l2_decay_rank_threshold=2, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, ) ) diff --git a/init2winit/model_lib/nqm.py b/init2winit/model_lib/nqm.py index 818bddf7..baa8a350 100644 --- a/init2winit/model_lib/nqm.py +++ b/init2winit/model_lib/nqm.py @@ -27,22 +27,11 @@ # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - optimizer='momentum', - opt_hparams={ - 'momentum': 0.0, - }, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - batch_size=128, - rng_seed=-1, # Note the dimension is set by input_shape. hessian_decay_power=1, noise_decay_power=1, nqm_mode='diagH_diagC', model_dtype='float32', - l2_decay_factor=None, )) diff --git a/init2winit/model_lib/resnet.py b/init2winit/model_lib/resnet.py index ae312232..4198259f 100644 --- a/init2winit/model_lib/resnet.py +++ b/init2winit/model_lib/resnet.py @@ -29,33 +29,16 @@ DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, # Make this a string to avoid having to import jnp into the configs. 
model_dtype='float32', virtual_batch_size=64, - total_accumulated_batch_size=None, data_format='NHWC', block_type='post_activation', # either pre_activation or post_activation bn_relu_conv=True, # only used for block_type='pre_activation' use_bn=True, dropout_rate=0.0, - grad_clip=None, activation_function='relu', extra_norm_on_residual=False, )) diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py index 0b43ac8d..803cfff9 100644 --- a/init2winit/model_lib/rope_nanodo.py +++ b/init2winit/model_lib/rope_nanodo.py @@ -42,23 +42,8 @@ num_heads=8, # num attention heads num_layers=12, # number of transformer block layers mlp_dim=2048, # FF inner dimension - rng_seed=-1, computation_dtype='bfloat16', model_dtype='float32', - optimizer='adam', - batch_size=256, - lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, - l2_decay_factor=0.0005, - l2_decay_rank_threshold=2, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, normalization='rmsnorm', mlp_activation='glu', qk_norm=True, diff --git a/init2winit/model_lib/simple_cnn.py b/init2winit/model_lib/simple_cnn.py index 9cbdb5e6..e1d35b53 100644 --- a/init2winit/model_lib/simple_cnn.py +++ b/init2winit/model_lib/simple_cnn.py @@ -29,22 +29,7 @@ DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=[20, 10], kernel_sizes=[3, 3], - lr_hparams={ - 'base_lr': 0.001, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, activation_function='relu', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', )) diff --git a/init2winit/model_lib/test_models.py b/init2winit/model_lib/test_models.py index 5f7155c0..0a19e153 100644 --- a/init2winit/model_lib/test_models.py +++ b/init2winit/model_lib/test_models.py @@ -28,6 +28,7 @@ from init2winit.init_lib import initializers from init2winit.model_lib import model_utils from init2winit.model_lib import models +from init2winit.trainer_lib import training_algorithm import jax from jax.experimental import mesh_utils from jax.flatten_util import ravel_pytree @@ -425,7 +426,10 @@ def _get_fake_inputs_for_initialization(model, hps): def _initialize_model(model_str, model_dtype): """Initialize a model given a registry name and dtype.""" model_cls = models.get_model(model_str) - hps = models.get_model_hparams(model_str) + hps = copy.deepcopy( + training_algorithm.OptaxTrainingAlgorithm.get_default_training_hparams() + ) + hps.update(models.get_model_hparams(model_str)) hps.update(DATA_HPS[model_str]) if 'input_edge_shape' in hps and 'input_node_shape' in hps: hps.input_shape = (hps.input_node_shape, hps.input_edge_shape) @@ -462,7 +466,10 @@ def test_classification_models(self, model_str): model_hps = models.get_model_hparams(model_str) loss = 'cross_entropy' metrics = 'classification_metrics' - hps = copy.copy(model_hps) + hps = copy.deepcopy( + training_algorithm.OptaxTrainingAlgorithm.get_default_training_hparams() + ) + hps.update(model_hps) hps.update({'output_shape': OUTPUT_SHAPE['classification']}) rng = jax.random.PRNGKey(0) dropout_rng, params_rng = jax.random.split(rng) diff --git a/init2winit/model_lib/transformer_lm.py b/init2winit/model_lib/transformer_lm.py index 0e310ff8..f2fb4b3d 100644 --- a/init2winit/model_lib/transformer_lm.py +++ 
b/init2winit/model_lib/transformer_lm.py @@ -37,7 +37,6 @@ # These reproduce the flax example. DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=512, emb_dim=128, num_heads=8, num_layers=6, @@ -45,28 +44,8 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.1 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.0016, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup' - }, - label_smoothing=None, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, decode=False, normalize_attention=False, )) diff --git a/init2winit/model_lib/transformer_stu_lm.py b/init2winit/model_lib/transformer_stu_lm.py index 07e41acf..2a5d2fcd 100644 --- a/init2winit/model_lib/transformer_stu_lm.py +++ b/init2winit/model_lib/transformer_stu_lm.py @@ -40,7 +40,6 @@ # These reproduce the flax example. DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=512, emb_dim=128, num_heads=8, num_layers=6, @@ -48,28 +47,8 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.1 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.0016, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup' - }, - label_smoothing=None, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, decode=False, normalize_attention=False, input_len=128, diff --git a/init2winit/model_lib/transformer_stu_tensordot_lm.py b/init2winit/model_lib/transformer_stu_tensordot_lm.py index 42a1ae52..351cde4d 100644 --- a/init2winit/model_lib/transformer_stu_tensordot_lm.py +++ b/init2winit/model_lib/transformer_stu_tensordot_lm.py @@ -40,7 +40,6 @@ # These reproduce the flax example. DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=512, emb_dim=128, num_heads=8, num_layers=6, @@ -48,28 +47,8 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.1 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.0016, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup' - }, - label_smoothing=None, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, decode=False, normalize_attention=False, input_len=128, diff --git a/init2winit/model_lib/unet.py b/init2winit/model_lib/unet.py index 3a4e3741..b12b4982 100644 --- a/init2winit/model_lib/unet.py +++ b/init2winit/model_lib/unet.py @@ -33,43 +33,6 @@ from ml_collections import config_dict -# NOTE(dsuo): We use the Kitchen Sink optimizer to match the RMSProp -# implementation found in the reference FastMRI U-Net code. Specifically, -# epsilon in optax's scale_by_rms places its epsilon inside the square root, -# whereas the reference code epsilon outside. -opt_hparams = { - 'weight_decay': 0.0, - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, -} - -# NOTE(dsuo): This lives here because decay_events / decay_factors is too large -# to pass via the config file. 
-_FASTMRI_TRAIN_SIZE = 34742 -_FASTMRI_VALID_SIZE = 7135 - -batch_size = 8 -num_epochs = 50 -steps_per_epoch = int(_FASTMRI_TRAIN_SIZE / batch_size) -num_train_steps = num_epochs * steps_per_epoch -lr_gamma = 0.1 -lr_step_size = 40 * steps_per_epoch -decay_events = list(range(lr_step_size, num_train_steps, lr_step_size)) -decay_factors = [lr_gamma] * len(decay_events) -decay_factors = [ - decay_factor**i - for decay_factor, i in zip(decay_factors, range(1, - len(decay_events) + 1)) -] - -lr_hparams = { - 'schedule': 'piecewise_constant', - 'base_lr': 1e-3, - 'decay_events': decay_events, - 'decay_factors': decay_factors -} - DEFAULT_HPARAMS = config_dict.ConfigDict( dict( out_chans=1, @@ -77,15 +40,7 @@ num_pool_layers=4, dropout_rate=0.0, activation='leaky_relu', - optimizer='adam', - opt_hparams=opt_hparams, - lr_hparams=lr_hparams, - l2_decay_factor=None, - batch_size=batch_size, - rng_seed=-1, model_dtype='float32', - grad_clip=None, - total_accumulated_batch_size=None, normalizer='unet_instance_norm', )) diff --git a/init2winit/model_lib/vit.py b/init2winit/model_lib/vit.py index 6bcad8f9..275f9dcc 100644 --- a/init2winit/model_lib/vit.py +++ b/init2winit/model_lib/vit.py @@ -43,27 +43,8 @@ pool_type='gap', posemb='sincos2d', head_zeroinit=True, - lr_hparams={ - 'base_lr': 1e-3, - 'schedule': 'cosine_warmup', - }, - optimizer='adam', - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 1e-1, - }, - l2_decay_factor=None, - l2_decay_rank_threshold=2, - batch_size=1024, - rng_seed=-1, model_dtype='float32', - grad_clip=None, - total_accumulated_batch_size=None, dropout_rate=0.0, - label_smoothing=0.0, - use_shallue_label_smoothing=False, normalizer='pre_layer_norm', activation='gelu', resnet_style_residual=False, diff --git a/init2winit/model_lib/wide_resnet.py b/init2winit/model_lib/wide_resnet.py index 8df134a9..f9379b1f 100644 --- a/init2winit/model_lib/wide_resnet.py +++ b/init2winit/model_lib/wide_resnet.py @@ -30,31 +30,14 @@ dict( blocks_per_group=3, channel_multiplier=2, - lr_hparams={ - 'base_lr': 0.001, - 'schedule': 'cosine' - }, normalizer='batch_norm', - layer_rescale_factors={}, conv_kernel_scale=1.0, dense_kernel_scale=1.0, dropout_rate=0.0, conv_kernel_init='lecun_normal', dense_kernel_init='lecun_normal', - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, virtual_batch_size=None, - total_accumulated_batch_size=None, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, activation_function='relu', group_strides=[(1, 1), (2, 2), (2, 2)]) ) diff --git a/init2winit/model_lib/xformer_translate.py b/init2winit/model_lib/xformer_translate.py index 348117b8..ff7e5984 100644 --- a/init2winit/model_lib/xformer_translate.py +++ b/init2winit/model_lib/xformer_translate.py @@ -42,7 +42,6 @@ MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -56,40 +55,18 @@ dropout_rate=0.1, aux_dropout_rate=0.1, tie_dropouts=False, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound' - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - 
use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', decode=False, - total_accumulated_batch_size=None, normalize_attention=False, )) DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -102,33 +79,12 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound' - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', decode=False, - total_accumulated_batch_size=None, normalize_attention=False, )) diff --git a/init2winit/model_lib/xformer_translate_binary.py b/init2winit/model_lib/xformer_translate_binary.py index bd5b4fbd..ace8d5e4 100644 --- a/init2winit/model_lib/xformer_translate_binary.py +++ b/init2winit/model_lib/xformer_translate_binary.py @@ -43,7 +43,6 @@ def _default_binarize_hparams(): DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -54,33 +53,12 @@ def _default_binarize_hparams(): mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.98, - 'epsilon': 1e-9, - 'weight_decay': 0.0, - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound', - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', decode=False, - total_accumulated_batch_size=None, binarize_hparams=_default_binarize_hparams(), quant_steps={ # training step at which model is partially binarized 'ff_weights': 90e3, diff --git a/init2winit/model_lib/xformer_translate_mlc_variant.py b/init2winit/model_lib/xformer_translate_mlc_variant.py index 9e825328..6cea66c1 100644 --- a/init2winit/model_lib/xformer_translate_mlc_variant.py +++ b/init2winit/model_lib/xformer_translate_mlc_variant.py @@ -42,7 +42,6 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -56,34 +55,13 @@ dropout_rate=0.1, aux_dropout_rate=0.1, tie_dropouts=False, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound' - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, 
enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', attn_kernel_scale=1.0, decode=False, - total_accumulated_batch_size=None, normalize_attention=False, glu=False, ffn_activation='relu', diff --git a/init2winit/optimizer_lib/test_optimizers.py b/init2winit/optimizer_lib/test_optimizers.py index 1d8bb6a8..e65f9720 100644 --- a/init2winit/optimizer_lib/test_optimizers.py +++ b/init2winit/optimizer_lib/test_optimizers.py @@ -72,6 +72,7 @@ def test_generic_multi_optimizer_init(self): 'loss': 'cross_entropy', 'metrics': 'classification_metrics', 'initializer': 'noop', + 'training_algorithm': 'optax_training_algorithm', 'hparam_overrides': config_dict.ConfigDict({ 'optimizer': 'generic_multi_optimizer', 'l2_decay_factor': None, @@ -96,7 +97,9 @@ def test_generic_multi_optimizer_init(self): experiment_config.model, experiment_config.initializer, experiment_config.dataset, - hparam_overrides=experiment_config.hparam_overrides) + experiment_config.training_algorithm, + hparam_overrides=experiment_config.hparam_overrides, + ) model = model_cls( merged_hps, diff --git a/init2winit/test_hyperparameters.py b/init2winit/test_hyperparameters.py index 0b85286e..645a5750 100644 --- a/init2winit/test_hyperparameters.py +++ b/init2winit/test_hyperparameters.py @@ -38,6 +38,7 @@ def test_override(self): model_name='transformer', initializer_name='noop', dataset_name='lm1b_v2', + training_algorithm_name='optax_training_algorithm', hparam_overrides=hps_overrides, ) @@ -60,6 +61,7 @@ def test_unrecognized_override(self): model_name='transformer', initializer_name='noop', dataset_name='lm1b_v2', + training_algorithm_name='optax_training_algorithm', hparam_overrides=hps_overrides, allowed_unrecognized_hparams=[], ) @@ -67,6 +69,7 @@ def test_unrecognized_override(self): model_name='transformer', initializer_name='noop', dataset_name='lm1b_v2', + training_algorithm_name='optax_training_algorithm', hparam_overrides=hps_overrides, allowed_unrecognized_hparams=['lr_hparamsTYPO'], ) @@ -81,6 +84,7 @@ def test_dot_override(self): model_name='transformer', initializer_name='noop', dataset_name='lm1b_v2', + training_algorithm_name='optax_training_algorithm', hparam_overrides=hps_overrides, ) @@ -95,7 +99,7 @@ def test_dot_override(self): } self.assertEqual( set(merged_hps.lr_hparams.keys()), - set(['schedule', 'warmup_steps', 'base_lr', 'squash_steps']), + set(['schedule', 'base_lr', 'warmup_steps', 'squash_steps']), ) self.assertEqual(merged_hps.lr_hparams.to_dict(), expected_lr_hparams) @@ -119,6 +123,7 @@ def test_optimizer_override(self): model_name='transformer', initializer_name='noop', dataset_name='lm1b_v2', + training_algorithm_name='optax_training_algorithm', hparam_overrides=hps_overrides, ) diff --git a/init2winit/tools/inspect_dataset.py b/init2winit/tools/inspect_dataset.py index 5dc15279..d9697503 100644 --- a/init2winit/tools/inspect_dataset.py +++ b/init2winit/tools/inspect_dataset.py @@ -62,6 +62,7 @@ def main(unused_argv): num_batches = FLAGS.num_batches dataset_name = FLAGS.dataset model_name = FLAGS.model + training_algorithm_name = 'optax_training_algorithm' initializer_name = 'noop' hparam_overrides = { @@ -72,7 +73,9 @@ def main(unused_argv): model_name=model_name, initializer_name=initializer_name, dataset_name=dataset_name, - hparam_overrides=hparam_overrides) + hparam_overrides=hparam_overrides, + training_algorithm_name=training_algorithm_name, + ) rng = jax.random.PRNGKey(0) rng, data_rng = jax.random.split(rng) 
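# ---------------------------------------------------------------------------
# Illustrative call-site sketch (not part of the patch). A minimal, hedged
# example of how build_hparams is invoked after this change, mirroring the
# updated tests and tools above. The import path is assumed from the file
# location init2winit/hyperparameters.py; the override values below are
# placeholders for illustration, while 'transformer', 'lm1b_v2', 'noop', and
# 'optax_training_algorithm' are names that appear in the updated tests.
# ---------------------------------------------------------------------------
from init2winit import hyperparameters

merged_hps = hyperparameters.build_hparams(
    model_name='transformer',
    initializer_name='noop',
    dataset_name='lm1b_v2',
    training_algorithm_name='optax_training_algorithm',
    hparam_overrides={
        # Overriding 'optimizer' selects that optimizer's default opt_hparams
        # (Tier 1 in OptaxTrainingAlgorithm.get_default_training_hparams).
        'optimizer': 'nadam',
        # Nested overrides merge into the resolved lr_hparams; keeping the
        # schedule unchanged avoids wiping the schedule-specific defaults.
        'lr_hparams': {'base_lr': 3e-4},
    },
)
# Training defaults (optimizer, opt_hparams, lr_hparams, batch_size, ...) now
# come from the training algorithm rather than each model's DEFAULT_HPARAMS.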
diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py index ed126a15..6ca00d1b 100644 --- a/init2winit/trainer_lib/test_trainer.py +++ b/init2winit/trainer_lib/test_trainer.py @@ -36,6 +36,7 @@ from init2winit.model_lib import models from init2winit.trainer_lib import trainer from init2winit.trainer_lib import trainer_utils +from init2winit.trainer_lib import training_algorithm import jax from jax.experimental import mesh_utils import jax.numpy as jnp @@ -389,7 +390,10 @@ def test_graph_model_trainer(self): rng = jax.random.PRNGKey(1337) model_str = 'gnn' model_cls = models.get_model(model_str) - hps = models.get_model_hparams(model_str) + hps = ( + training_algorithm.OptaxTrainingAlgorithm.get_default_training_hparams() + ) + hps.update(models.get_model_hparams(model_str)) hps.update({ 'batch_size': 2, 'input_edge_shape': (7,), @@ -455,11 +459,14 @@ def test_dlrm_model_trainer(self): model_str = 'dlrm' dataset_str = 'criteo1tb' model_cls = models.get_model(model_str) - model_hps = models.get_model_hparams(model_str) + model_hps = ( + training_algorithm.OptaxTrainingAlgorithm.get_default_training_hparams() + ) + model_hps.update(models.get_model_hparams(model_str)) model_hps.vocab_size = 1024 dataset_hps = datasets.get_dataset_hparams(dataset_str) dataset_hps.update({ - 'batch_size': model_hps.batch_size, + 'batch_size': 128, 'num_dense_features': model_hps.num_dense_features, 'vocab_size': model_hps.vocab_size, }) @@ -479,6 +486,7 @@ def test_dlrm_model_trainer(self): 'l2_decay_factor': 1e-4, 'l2_decay_rank_threshold': 2, 'num_device_prefetches': 0, + 'batch_size': dataset_hps.batch_size, }) model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) initializer = initializers.get_initializer('noop') @@ -678,8 +686,10 @@ def test_trainer(self): model_name, initializer_name, dataset_name, + training_algorithm_name='optax_training_algorithm', hparam_overrides=hparam_overrides, - input_pipeline_hps=input_pipeline_hps) + input_pipeline_hps=input_pipeline_hps, + ) eval_batch_size = 16 num_examples = 256 @@ -961,9 +971,10 @@ def test_early_stopping(self, min_steps): initializer = initializers.get_initializer(initializer_name) dataset_builder = datasets.get_dataset(dataset_name) hparam_overrides = { - 'lr_hparams': { - 'base_lr': 0.1, - 'schedule': 'cosine' + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'cosine'}, + 'optimizer': 'momentum', + 'opt_hparams': { + 'momentum': 0.9, }, 'batch_size': 8, 'train_size': 160, @@ -979,8 +990,10 @@ def test_early_stopping(self, min_steps): model_name, initializer_name, dataset_name, + training_algorithm_name='optax_training_algorithm', hparam_overrides=hparam_overrides, - input_pipeline_hps=input_pipeline_hps) + input_pipeline_hps=input_pipeline_hps, + ) eval_batch_size = 16 num_examples = 256 diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py index 948e80fb..f509b1f0 100644 --- a/init2winit/trainer_lib/training_algorithm.py +++ b/init2winit/trainer_lib/training_algorithm.py @@ -18,12 +18,14 @@ import abc import collections +from absl import logging from init2winit import schedules from init2winit.model_lib import model_utils from init2winit.optimizer_lib import gradient_accumulator from init2winit.optimizer_lib import optimizers import jax import jax.numpy as jnp +from ml_collections.config_dict import config_dict import optax @@ -111,6 +113,27 @@ def __init__(self, hps, model, num_train_steps): self.hps = hps self.eval_report_metrics = 
collections.defaultdict() + @classmethod + def get_default_training_hparams(cls): + """Returns default training hyperparameters. + + The base class provides infrastructure-level defaults. Subclasses should + call super() and merge in their own optimizer/lr defaults. + + Returns: + A ConfigDict of default training hyperparameters. + """ + return config_dict.ConfigDict({ + 'batch_size': None, + 'total_accumulated_batch_size': None, + 'l2_decay_factor': None, + 'l2_decay_rank_threshold': 2, + 'label_smoothing': None, + 'rng_seed': -1, + 'use_shallue_label_smoothing': False, + 'layer_rescale_factors': {}, + }) + @abc.abstractmethod def update_params( self, @@ -175,9 +198,649 @@ def init_optimizer_state( """ +# Per-optimizer default opt_hparams for OptaxTrainingAlgorithm. +# These consolidate all the inline defaults from get_optimizer() in +# optimizer_lib/optimizers.py so that configs don't need to redundantly +# specify values that match the defaults. +_OPTAX_OPTIMIZER_DEFAULTS = { + 'adam': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'sgd': { + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'momentum': { + 'momentum': 0.9, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'nesterov': { + 'momentum': 0.9, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'nadam': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'debias': True, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'generalized_adam': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'nesterov': True, + 'power': 2.0, + 'disable_preconditioning': False, + 'epsilon_root': 0.0, + 'debias': True, + 'weight_decay': 0.0, + 'disable_multiply_wd_by_base_lr': False, + 'grad_clip': None, + }, + 'nadamp': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'debias': True, + 'power': 2.0, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'adaprop': { + 'beta1': 0.9, + 'beta2': 0.999, + 'beta3': 1.0, + 'beta4': 0.999, + 'epsilon': 1e-8, + 'power': 2.0, + 'nesterov': True, + 'quantized_dtype': 'float32', + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'adafactor': { + 'min_dim_size_to_factor': 128, + 'adafactor_decay_rate': 0.8, + 'decay_offset': 0, + 'multiply_by_parameter_scale': True, + 'clipping_threshold': 1.0, + 'momentum': None, + 'epsilon': 1e-30, + 'factored': True, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'lamb': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-6, + 'epsilon_root': 0.0, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'lars': { + 'trust_coefficient': 0.001, + 'epsilon': 1e-6, + 'momentum': 0.9, + 'nesterov': False, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'adabelief': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'radam': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'adam_relative_update': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'herolion': { + 'beta1': 0.9, + 'beta2': 0.99, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'bubbles': { + 'weight_decay': 0.01, + 'beta1': 0.9, + 'beta2': 0.999, + 'nesterov': True, + 'min_steps': 100, + 'grad_rms_threshold': 10.0, + 'precond_grad_clip': None, + 'bias_correction': True, + 'grad_clip': None, + }, + 'lora_bubbles': { + 'weight_decay': 0.01, + 'beta1': 0.9, + 
'beta2': 0.999, + 'nesterov': True, + 'eps': 1e-7, + 'lora_min_steps': 100, + 'lora_update_steps': 20, + 'lora_rank': 64, + 'grad_rms_threshold': 10.0, + 'precond_grad_clip': None, + 'bias_correction': True, + 'grad_clip': None, + }, + 'diag_bubbles': { + 'beta1': None, + 'beta2': 0.999, + 'eps': 1e-8, + 'precond_grad_clip': None, + 'nesterov': False, + 'bias_correction': True, + 'weight_decay': 1e-4, + 'grad_clip': None, + }, + 'decoupled_adam': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'weight_decay': 0.0, + 'grad_clip': None, + }, + 'adan': { + 'beta1': 0.98, + 'beta2': 0.92, + 'beta3': 0.99, + 'epsilon': 1e-8, + 'epsilon_root': 0.0, + 'weight_decay': 0.0, + 'use_adamw_wd': True, + 'tie_b1_b2': False, + 'grad_clip': None, + }, + 'sm3': { + 'beta1': 0.9, + 'beta2': 0.999, + 'diagonal_epsilon': 1e-10, + 'weight_decay': 0.0, + 'normalize_grads': False, + 'grad_clip': None, + }, +} + +# UNet piecewise_constant schedule constants (originally in unet.py). +_FASTMRI_TRAIN_SIZE = 34742 +_UNET_BATCH_SIZE = 8 +_UNET_NUM_EPOCHS = 50 +_UNET_STEPS_PER_EPOCH = int(_FASTMRI_TRAIN_SIZE / _UNET_BATCH_SIZE) +_UNET_NUM_TRAIN_STEPS = _UNET_NUM_EPOCHS * _UNET_STEPS_PER_EPOCH +_UNET_LR_GAMMA = 0.1 +_UNET_LR_STEP_SIZE = 40 * _UNET_STEPS_PER_EPOCH +_UNET_DECAY_EVENTS = list( + range(_UNET_LR_STEP_SIZE, _UNET_NUM_TRAIN_STEPS, _UNET_LR_STEP_SIZE) +) +_UNET_DECAY_FACTORS = [ + _UNET_LR_GAMMA**i for i in range(1, len(_UNET_DECAY_EVENTS) + 1) +] + +# Per-model training defaults, preserving the historical optimizer configuration +# that each model was originally tuned with. These are used as Tier 2 fallback +# when no explicit optimizer is specified in hparam_overrides. +_MODEL_TRAINING_DEFAULTS = { + # Vision models with momentum + 'adabelief_densenet': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.2, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0001, + }, + 'adabelief_resnet': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.2, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0001, + }, + 'adabelief_vgg': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.2, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0001, + }, + 'resnet': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.2, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0001, + }, + 'fully_connected': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0005, + }, + 'simple_cnn': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.001, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0005, + }, + 'max_pooling_cnn': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.001, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0005, + }, + 'wide_resnet': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.9}, + 'lr_hparams': {'base_lr': 0.001, 'schedule': 'cosine'}, + 'batch_size': 128, + 'l2_decay_factor': 0.0001, + }, + 'convolutional_autoencoder': { + 'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.0}, + 'lr_hparams': {'base_lr': 0.02, 'schedule': 'constant'}, + 'batch_size': 128, + }, + 'nqm': { + 
'optimizer': 'momentum', + 'opt_hparams': {'momentum': 0.0}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 128, + }, + # Vision models with adam + 'vit': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.1}, + 'lr_hparams': {'base_lr': 1e-3, 'schedule': 'cosine_warmup'}, + 'batch_size': 1024, + }, + # Speech models + 'conformer': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0, 'grad_clip': 5.0}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 1e-6, + 'l2_decay_rank_threshold': 0, + }, + 'mlcommons_conformer': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0, 'grad_clip': 5.0}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 1e-6, + 'l2_decay_rank_threshold': 0, + }, + 'deepspeech': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0, 'grad_clip': 10.0}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 1e-6, + 'l2_decay_rank_threshold': 0, + }, + 'mlcommons_deepspeech': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0, 'grad_clip': 10.0}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 1e-6, + 'l2_decay_rank_threshold': 0, + }, + 'transformer': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.1, + }, + 'lr_hparams': { + 'base_lr': 0.0016, + 'warmup_steps': 1000, + 'squash_steps': 1000, + 'schedule': 'rsqrt_normalized_decay_warmup', + }, + 'batch_size': 512, + 'l2_decay_rank_threshold': 0, + }, + 'performer': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.1, + }, + 'lr_hparams': { + 'base_lr': 0.0016, + 'warmup_steps': 1000, + 'squash_steps': 1000, + 'schedule': 'rsqrt_normalized_decay_warmup', + }, + 'batch_size': 512, + 'l2_decay_rank_threshold': 0, + }, + 'transformer_stu': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.1, + }, + 'lr_hparams': { + 'base_lr': 0.0016, + 'warmup_steps': 1000, + 'squash_steps': 1000, + 'schedule': 'rsqrt_normalized_decay_warmup', + }, + 'batch_size': 512, + 'l2_decay_rank_threshold': 0, + }, + 'transformer_stu_tensordot': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.1, + }, + 'lr_hparams': { + 'base_lr': 0.0016, + 'warmup_steps': 1000, + 'squash_steps': 1000, + 'schedule': 'rsqrt_normalized_decay_warmup', + }, + 'batch_size': 512, + 'l2_decay_rank_threshold': 0, + }, + 'xformer_translate': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.0, + }, + 'lr_hparams': { + 'base_lr': 0.05, + 'warmup_steps': 8000, + 'factors': 'constant * linear_warmup * rsqrt_decay', + 'schedule': 'compound', + }, + 'batch_size': 64, + 'l2_decay_rank_threshold': 0, + }, + 'mlcommons_xformer_translate': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.0, + }, + 'lr_hparams': { + 'base_lr': 0.05, + 'warmup_steps': 8000, + 'factors': 'constant * linear_warmup * rsqrt_decay', + 'schedule': 'compound', + }, + 'batch_size': 64, + 'l2_decay_rank_threshold': 0, + }, + 'xformer_translate_binary': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 
'weight_decay': 0.0, + }, + 'lr_hparams': { + 'base_lr': 0.05, + 'warmup_steps': 8000, + 'factors': 'constant * linear_warmup * rsqrt_decay', + 'schedule': 'compound', + }, + 'batch_size': 64, + 'l2_decay_rank_threshold': 0, + }, + 'xformer_translate_mlc_variant': { + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.98, + 'epsilon': 1e-9, + 'weight_decay': 0.0, + }, + 'lr_hparams': { + 'base_lr': 0.05, + 'warmup_steps': 8000, + 'factors': 'constant * linear_warmup * rsqrt_decay', + 'schedule': 'compound', + }, + 'batch_size': 64, + 'l2_decay_rank_threshold': 0, + }, + 'lstm': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0}, + 'lr_hparams': {'base_lr': 1e-3, 'schedule': 'constant'}, + 'batch_size': 256, + }, + 'local_attention_transformer': { + 'optimizer': 'adafactor', + 'opt_hparams': {}, + 'lr_hparams': { + 'base_lr': 0.01, + 'defer_steps': 10000, + 'schedule': 't2t_rsqrt_normalized_decay', + }, + 'batch_size': 8, + }, + # GNN / tabular + 'gnn': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0}, + 'lr_hparams': {'base_lr': 0.01, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 0.0005, + }, + 'dlrm': { + 'optimizer': 'adam', + 'opt_hparams': {}, + 'lr_hparams': {'base_lr': 0.01, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 1e-5, + }, + 'dlrm_resnet': { + 'optimizer': 'adam', + 'opt_hparams': {}, + 'lr_hparams': {'base_lr': 0.01, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 1e-5, + }, + # Nanodo family + 'nanodo': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0}, + 'lr_hparams': {'base_lr': 0.01, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 0.0005, + }, + 'rope_nanodo': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0}, + 'lr_hparams': {'base_lr': 0.01, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 0.0005, + }, + 'mdlm_rope_nanodo': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0}, + 'lr_hparams': {'base_lr': 0.01, 'schedule': 'constant'}, + 'batch_size': 256, + 'l2_decay_factor': 0.0005, + }, + # Special + 'autoencoder': { + 'optimizer': 'hessian_free', + 'opt_hparams': {'damping_lb': 1e-6}, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'constant'}, + 'batch_size': 128, + 'l2_decay_factor': 2e-5, + 'l2_decay_rank_threshold': 1, + }, + 'mlperf_resnet': { + 'optimizer': 'lars', + 'opt_hparams': {'weight_decay': 2e-4, 'momentum': 0.9}, + 'lr_hparams': { + 'base_lr': 10.0, + 'schedule': 'mlperf_polynomial', + 'start_lr': 0.0, + 'end_lr': 1e-4, + 'power': 2.0, + 'decay_end': -1, + 'warmup_steps': 18, + 'warmup_power': 1, + }, + 'batch_size': 128, + }, + 'fake_resnet': { + 'optimizer': 'lars', + 'opt_hparams': {'weight_decay': 2e-4, 'momentum': 0.9}, + 'lr_hparams': { + 'base_lr': 10.0, + 'schedule': 'mlperf_polynomial', + 'start_lr': 0.0, + 'end_lr': 1e-4, + 'power': 2.0, + 'decay_end': -1, + 'warmup_steps': 18, + 'warmup_power': 1, + }, + 'batch_size': 128, + }, + 'unet': { + 'optimizer': 'adam', + 'opt_hparams': {'weight_decay': 0.0}, + 'lr_hparams': { + 'schedule': 'piecewise_constant', + 'base_lr': 1e-3, + 'decay_events': _UNET_DECAY_EVENTS, + 'decay_factors': _UNET_DECAY_FACTORS, + }, + 'batch_size': 8, + }, +} + + class OptaxTrainingAlgorithm(TrainingAlgorithm): """Class for training algorithms implemented with optax and defined in optimizer_lib.optimizers.py.""" + @classmethod + def get_default_training_hparams(cls, optimizer_name=None, model_name=None): + """Returns default training hparams 
for optax-based training. + + Resolution hierarchy: + 1. If optimizer_name is provided, use optimizer-specific defaults. + 2. Else if model_name is in _MODEL_TRAINING_DEFAULTS, use model-specific + defaults (which include the model's historical optimizer choice). + 3. Else fall back to 'adam' defaults. + + Args: + optimizer_name: Optional name of the optimizer to get defaults for. When + provided, takes precedence over model_name. + model_name: Optional name of the model. Used to look up historical + per-model training defaults when no optimizer_name is specified. + + Returns: + A ConfigDict of default training hyperparameters including + optimizer-specific opt_hparams looked up from the per-optimizer defaults + table. + """ + training_hparams = super().get_default_training_hparams() + + if optimizer_name is not None: + # Tier 1: explicit optimizer override. + logging.info( + 'Using optimizer-specific defaults for optimizer=%s', + optimizer_name, + ) + opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get(optimizer_name, {})) + training_hparams.update({ + 'optimizer': optimizer_name, + 'opt_hparams': opt_defaults, + 'lr_hparams': { + 'base_lr': 0.001, + 'schedule': 'cosine', + }, + }) + elif model_name is not None and model_name in _MODEL_TRAINING_DEFAULTS: + # Tier 2: model-specific historical defaults. + logging.info( + 'Using model-specific training defaults for model=%s', model_name + ) + model_defaults = dict(_MODEL_TRAINING_DEFAULTS[model_name]) + model_optimizer = model_defaults.pop('optimizer', 'adam') + # Start with the optimizer's own defaults, then overlay model-specific. + opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get(model_optimizer, {})) + opt_defaults.update(model_defaults.pop('opt_hparams', {})) + training_hparams.update({ + 'optimizer': model_optimizer, + 'opt_hparams': opt_defaults, + 'lr_hparams': model_defaults.pop( + 'lr_hparams', {'base_lr': 0.001, 'schedule': 'cosine'} + ), + }) + # Apply remaining model-level overrides (batch_size, l2_decay, etc.) + training_hparams.update(model_defaults) + else: + # Tier 3: generic adam fallback. + opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get('adam', {})) + training_hparams.update({ + 'optimizer': 'adam', + 'opt_hparams': opt_defaults, + 'lr_hparams': { + 'base_lr': 0.001, + 'schedule': 'cosine', + }, + }) + + return training_hparams + def __init__(self, hps, model, num_train_steps): super().__init__(hps, model, num_train_steps) self._optimizer_state = None
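# ---------------------------------------------------------------------------
# Illustrative resolution sketch (not part of the patch). A minimal, hedged
# walk through the 3-tier default lookup documented in the docstring above;
# the expected values in the asserts are read off the defaults tables in this
# file, not verified against a running checkout.
# ---------------------------------------------------------------------------
from init2winit.trainer_lib import training_algorithm

algo = training_algorithm.OptaxTrainingAlgorithm

# Tier 1: an explicit optimizer name takes precedence and pulls that
# optimizer's defaults from _OPTAX_OPTIMIZER_DEFAULTS (plus the generic
# cosine lr schedule), regardless of model_name.
hps1 = algo.get_default_training_hparams(optimizer_name='nadam',
                                         model_name='resnet')
assert hps1.optimizer == 'nadam'

# Tier 2: no optimizer given, known model -> historical per-model defaults
# from _MODEL_TRAINING_DEFAULTS (optimizer, opt_hparams, lr_hparams, etc.).
hps2 = algo.get_default_training_hparams(model_name='resnet')
assert hps2.optimizer == 'momentum' and hps2.batch_size == 128

# Tier 3: neither given -> generic adam fallback.
hps3 = algo.get_default_training_hparams()
assert hps3.optimizer == 'adam'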