Merged
84 changes: 59 additions & 25 deletions init2winit/hyperparameters.py
@@ -21,6 +21,8 @@
from init2winit.dataset_lib import datasets
from init2winit.init_lib import initializers
from init2winit.model_lib import models
from init2winit.trainer_lib import training_algorithm
from init2winit.trainer_lib import training_algorithms
from ml_collections.config_dict import config_dict


@@ -95,18 +97,24 @@ def expand_dot_keys(d):
return expanded_dict


def build_hparams(model_name,
initializer_name,
dataset_name,
hparam_overrides,
input_pipeline_hps=None,
allowed_unrecognized_hparams=None):
def build_hparams(
model_name,
initializer_name,
dataset_name,
training_algorithm_name,
hparam_overrides,
input_pipeline_hps=None,
allowed_unrecognized_hparams=None,
):
"""Build experiment hyperparameters.

Args:
model_name: the string model name.
initializer_name: the string initializer name.
dataset_name: the string dataset name.
training_algorithm_name: the string name of the training algorithm. Used to
look up algorithm-specific default training hyperparameters (optimizer,
opt_hparams, lr_hparams).
hparam_overrides: a dict of hyperparameter override names/values, or a JSON
string encoding of this hyperparameter override dict. Note that this is
applied after the hyperparameter file overrides.
@@ -125,22 +133,49 @@ def build_hparams(model_name,
initializer_hps = initializers.get_initializer_hparams(initializer_name)
dataset_hps = datasets.get_dataset_hparams(dataset_name)
input_pipeline_hps = input_pipeline_hps or config_dict.ConfigDict()
overrides_dict = hparam_overrides or {}
if isinstance(overrides_dict, str):
overrides_dict = json.loads(overrides_dict)

# Training hparams come from the training algorithm.
algo_cls = training_algorithms.get_training_algorithm(training_algorithm_name)

# For OptaxTrainingAlgorithm, pass optimizer_name (if overridden) and
# model_name so it can resolve defaults using the 3-tier hierarchy.
if issubclass(algo_cls, training_algorithm.OptaxTrainingAlgorithm):
optimizer_name = (
None if not overrides_dict else overrides_dict.get('optimizer')
)
training_hps = algo_cls.get_default_training_hparams(
optimizer_name=optimizer_name,
model_name=model_name,
)
else:
training_hps = algo_cls.get_default_training_hparams()

merged_dict = {}

hps_dicts = [
hps.to_dict()
for hps in [model_hps, initializer_hps, dataset_hps, input_pipeline_hps]
for hps in [
training_hps,
model_hps,
initializer_hps,
dataset_hps,
input_pipeline_hps,
]
]

total_hps = 0
# Check that all provided hps have no overlap.
seen_keys = set()
for hps_dict in hps_dicts:
merged_dict.update(hps_dict)
total_hps += len(hps_dict.keys())
overlap = seen_keys.intersection(hps_dict.keys())
if overlap:
raise ValueError(f'There is overlap in the provided hparams: {overlap}')
seen_keys.update(hps_dict.keys())

# Check that all provided have no overlap.
if total_hps != len(merged_dict.keys()):
raise ValueError('There is overlap in the provided hparams.')
for hps_dict in hps_dicts:
merged_dict.update(hps_dict)

# Convert to the Shallue and Lee label smoothing style.
if merged_dict.get('use_shallue_label_smoothing', False):
@@ -159,34 +194,33 @@
for key in ['opt_hparams', 'lr_hparams']:
merged[key].unlock()

if hparam_overrides:
if isinstance(hparam_overrides, str):
hparam_overrides = json.loads(hparam_overrides)

if overrides_dict:
# If the user is changing the learning rate schedule or optimizer. We must
# wipe all of the keys from the old dictionary.
merged_schedule = None
if merged.get('lr_hparams'):
merged_schedule = merged['lr_hparams'].get('schedule')
overrides_schedule = None
if hparam_overrides.get('lr_hparams'):
overrides_schedule = hparam_overrides['lr_hparams'].get('schedule')
if overrides_dict.get('lr_hparams'):
overrides_schedule = overrides_dict['lr_hparams'].get('schedule')
if overrides_schedule and merged_schedule != overrides_schedule:
merged['lr_hparams'] = {}
if ('optimizer' in hparam_overrides and
merged['optimizer'] != hparam_overrides['optimizer']):
if (
'optimizer' in overrides_dict
and merged['optimizer'] != overrides_dict['optimizer']
):
merged['opt_hparams'] = {}
hparam_overrides = expand_dot_keys(hparam_overrides)
overrides_dict = expand_dot_keys(overrides_dict)
if allowed_unrecognized_hparams:
new_keys = [k for k in hparam_overrides if k not in merged]
new_keys = [k for k in overrides_dict if k not in merged]
if new_keys:
logging.warning('Unrecognized top-level hparams: %s', new_keys)
if any(k not in allowed_unrecognized_hparams for k in new_keys):
raise ValueError(
f'Unrecognized top-level hparams not in allowlist: {new_keys}')
with merged.unlocked():
merged.update(hparam_overrides)
merged.update(overrides_dict)
else:
merged.update(hparam_overrides)
merged.update(overrides_dict)

return merged
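
(For orientation, a minimal sketch of the contract build_hparams now depends on. The trainer_lib registry and base classes are not part of this diff, so the class below is hypothetical and only illustrates the expected shape of get_default_training_hparams; the real training_algorithms API may differ.)

# Hypothetical sketch -- names other than get_default_training_hparams are assumptions.
from ml_collections.config_dict import config_dict

class MomentumTrainingAlgorithm:  # assumed to be registered under some name in training_algorithms
  """Supplies the training hparams that previously lived in each model's defaults."""

  @classmethod
  def get_default_training_hparams(cls):
    # These keys must not collide with model/initializer/dataset hparams,
    # or build_hparams raises the overlap ValueError above.
    return config_dict.ConfigDict(dict(
        optimizer='momentum',
        opt_hparams={'momentum': 0.9},
        lr_hparams={'schedule': 'constant', 'base_lr': 0.2},
        batch_size=128,
    ))

For OptaxTrainingAlgorithm subclasses, build_hparams instead calls get_default_training_hparams(optimizer_name=..., model_name=...), so the algorithm can resolve optimizer- and model-specific defaults via the 3-tier hierarchy mentioned above.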
4 changes: 3 additions & 1 deletion init2winit/main.py
@@ -165,9 +165,11 @@ def _run(
model_name=model_name,
initializer_name=initializer_name,
dataset_name=dataset_name,
training_algorithm_name=training_algorithm_name,
hparam_overrides=hparam_overrides,
input_pipeline_hps=input_pipeline_hps,
allowed_unrecognized_hparams=allowed_unrecognized_hparams)
allowed_unrecognized_hparams=allowed_unrecognized_hparams,
)

# Note that one should never tune an RNG seed!!! The seed is only included in
# the hparams for convenience of running hparam trials with multiple seeds per
19 changes: 2 additions & 17 deletions init2winit/model_lib/adabelief_densenet.py
@@ -48,27 +48,12 @@
# results in a large Dense matrix in the readout layer and unstable
# training.
use_kernel_size_as_stride_in_pooling=True,
layer_rescale_factors={},
lr_hparams={
'schedule': 'constant',
'base_lr': 0.2,
},
normalizer='batch_norm',
optimizer='momentum',
opt_hparams={
'momentum': 0.9,
},
batch_size=128,
l2_decay_factor=0.0001,
l2_decay_rank_threshold=2,
label_smoothing=None,
rng_seed=-1,
use_shallue_label_smoothing=False,
model_dtype='float32',
grad_clip=None,
normalize_classifier_input='none',
classification_scale_factor=1.0,
))
)
)


class BottleneckBlock(nn.Module):
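(The momentum/constant-LR defaults removed from this model, and from the other model_lib files below, now have to come from the chosen training algorithm or be passed explicitly. A minimal usage sketch of the latter; 'adabelief_densenet', 'noop', 'cifar10', and 'sgd' are placeholder registry names, not taken from this diff.)

from init2winit import hyperparameters

hps = hyperparameters.build_hparams(
    model_name='adabelief_densenet',   # placeholder registry names throughout
    initializer_name='noop',
    dataset_name='cifar10',
    training_algorithm_name='sgd',
    hparam_overrides={
        # Dotted keys are expanded by expand_dot_keys before merging.
        'optimizer': 'momentum',
        'opt_hparams.momentum': 0.9,
        'lr_hparams': {'schedule': 'constant', 'base_lr': 0.2},
        'batch_size': 128,
    },
)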
17 changes: 0 additions & 17 deletions init2winit/model_lib/adabelief_resnet.py
@@ -49,29 +49,12 @@
dict(
num_filters=16,
num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200]
layer_rescale_factors={},
lr_hparams={
'schedule': 'constant',
'base_lr': 0.2,
},
optimizer='momentum',
opt_hparams={
'momentum': 0.9,
},
batch_size=128,
l2_decay_factor=0.0001,
l2_decay_rank_threshold=2,
label_smoothing=None,
rng_seed=-1,
use_shallue_label_smoothing=False,
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-5,
# Make this a string to avoid having to import jnp into the configs.
model_dtype='float32',
virtual_batch_size=None,
total_accumulated_batch_size=None,
data_format='NHWC',
grad_clip=None,
))


19 changes: 2 additions & 17 deletions init2winit/model_lib/adabelief_vgg.py
@@ -39,25 +39,10 @@
DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
num_layers=11, # Must be one of [11, 13, 16, 19]
layer_rescale_factors={},
lr_hparams={
'schedule': 'constant',
'base_lr': 0.2,
},
normalizer='none',
optimizer='momentum',
opt_hparams={
'momentum': 0.9,
},
batch_size=128,
l2_decay_factor=0.0001,
l2_decay_rank_threshold=2,
label_smoothing=None,
rng_seed=-1,
use_shallue_label_smoothing=False,
model_dtype='float32',
grad_clip=None,
))
)
)


def classifier(x, num_outputs, dropout_rate, deterministic):
21 changes: 0 additions & 21 deletions init2winit/model_lib/autoencoder.py
@@ -37,27 +37,7 @@
hid_sizes=[128, 64, 32, 64, 128],
activation_function=['relu', 'relu', 'relu', 'relu', 'relu'],
kernel_scales=[1.0] * 6,
lr_hparams={
'base_lr': 0.1,
'schedule': 'constant'
},
layer_rescale_factors={},
optimizer='hessian_free',
opt_hparams={
'cg_max_iter': 250,
'cg_iter_tracking_method': 'back_tracking',
'use_line_search': True,
'init_damping': 50.0,
'damping_ub': 10 ** 2,
'damping_lb': 10 ** -6,
},
batch_size=128,
label_smoothing=None,
rng_seed=-1,
use_shallue_label_smoothing=False,
model_dtype='float32',
l2_decay_factor=2e-5,
l2_decay_rank_threshold=1,
))


@@ -82,4 +62,3 @@ def get_fake_inputs(self, hps):
jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
]
return dummy_inputs

36 changes: 0 additions & 36 deletions init2winit/model_lib/conformer.py
@@ -46,25 +46,8 @@
MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
activation_function='swish',
optimizer='adam',
opt_hparams={
'beta1': .9,
'beta2': .98,
'epsilon': 1e-9,
'weight_decay': 0.0
},
lr_hparams={
'base_lr': 0.1,
'schedule': 'constant'
},
batch_size=256,
eval_batch_size=128,
l2_decay_factor=1e-6,
l2_decay_rank_threshold=0,
use_shallue_label_smoothing=False,
rng_seed=-1,
model_dtype='float32',
grad_clip=5.0,
encoder_dim=512,
num_attention_heads=8,
num_encoder_layers=16,
@@ -84,33 +67,15 @@
enable_decoder_pre_layer_norm=True,
enable_conformer_post_layer_norm=True,
use_lingvo_attention=False,
total_accumulated_batch_size=None,
attn_temperature=1.0,
))


DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
activation_function='swish',
optimizer='adam',
opt_hparams={
'beta1': .9,
'beta2': .98,
'epsilon': 1e-9,
'weight_decay': 0.0
},
lr_hparams={
'base_lr': 0.1,
'schedule': 'constant'
},
batch_size=256,
eval_batch_size=128,
l2_decay_factor=1e-6,
l2_decay_rank_threshold=0,
use_shallue_label_smoothing=False,
rng_seed=-1,
model_dtype='float32',
grad_clip=5.0,
encoder_dim=512,
num_attention_heads=8,
num_encoder_layers=16,
Expand All @@ -128,7 +93,6 @@
enable_decoder_pre_layer_norm=True,
enable_conformer_post_layer_norm=True,
use_lingvo_attention=False,
total_accumulated_batch_size=None,
attn_temperature=1.0))


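(Both conformer hparam dicts above lose their adam settings (beta1, beta2, epsilon, weight_decay) along with the batch-size, decay, and grad-clip defaults. Under the new scheme those are expected to come from the training algorithm; for the Optax path the lookup sketched in hyperparameters.py would look roughly like the call below. The keyword names match the diff; the concrete return contents are assumptions.)

from init2winit.trainer_lib import training_algorithm

training_hps = training_algorithm.OptaxTrainingAlgorithm.get_default_training_hparams(
    optimizer_name='adam',    # taken from the user's override, if one was given
    model_name='conformer',   # hypothetical registry name
)
# Expected to carry the algorithm-level defaults, e.g. optimizer, opt_hparams
# (roughly {'beta1': 0.9, 'beta2': 0.98, ...}), and lr_hparams.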
17 changes: 0 additions & 17 deletions init2winit/model_lib/convolutional_autoencoder.py
@@ -49,25 +49,8 @@
'paddings': ['SAME', ((1, 0), (1, 0)), 'SAME', 'SAME'],
'activations': ['relu', 'relu', 'relu', 'id'],
},

activation_function='relu',
lr_hparams={
'base_lr': 0.02,
'schedule': 'constant'
},
layer_rescale_factors={},
optimizer='momentum',
opt_hparams={
'momentum': 0,
},
batch_size=128,
l2_decay_factor=None,
l2_decay_rank_threshold=0,
label_smoothing=None,
rng_seed=-1,
use_shallue_label_smoothing=False,
model_dtype='float32',
grad_clip=None,
))

