
Commit 62c1536 (parent aa57ad6)

perf: pad to max length for TPU

4 files changed: 29 additions & 8 deletions


bsmetadata/experiments/with_metadata.py
Lines changed: 11 additions & 3 deletions
```diff
@@ -1,9 +1,10 @@
 import functools
 import logging
 
+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import default_data_collator, DataCollatorWithPadding
 
 from bsmetadata.metadata_utils import add_metadata_and_chunk_examples
 
@@ -124,15 +125,22 @@ def create_labels_column(examples):
     val_dataset = lm_datasets["validation"]
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(
+            tokenizer,
+            padding="max_length",
+            max_length=args.max_seq_len
+        )
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
```
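Why the collator changes on TPU: XLA compiles a separate program for every distinct tensor shape it encounters, so batches whose padded length varies from step to step force repeated recompilation. Padding every batch to one fixed max_length keeps the input shape static across steps. A minimal sketch of the behavior, assuming an illustrative "gpt2" checkpoint and max_length=16 (neither value comes from this commit):

```python
# Sketch: DataCollatorWithPadding with padding="max_length" pads every batch
# to the same static shape, which is what XLA on TPU needs to avoid
# recompiling for each new sequence length. "gpt2" and 16 are illustrative.
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=16)

short = tokenizer("hi")
longer = tokenizer("a noticeably longer example sentence")
batch = collator([short, longer])

# Both rows are padded out to max_length, so every batch has the same shape.
print(batch["input_ids"].shape)       # torch.Size([2, 16])
print(batch["attention_mask"].shape)  # torch.Size([2, 16])
```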

bsmetadata/experiments/without_metadata.py
Lines changed: 11 additions & 3 deletions
```diff
@@ -1,8 +1,9 @@
 import logging
 
+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import default_data_collator, DataCollatorWithPadding
 
 
 logger = logging.getLogger(__name__)
@@ -157,15 +158,22 @@ def group_texts(examples):
     val_dataset = lm_datasets["validation"]
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(
+            tokenizer,
+            padding="max_length",
+            max_length=args.max_seq_len
+        )
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
```
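The same conditional appears verbatim in both experiment scripts. Read as a standalone helper, the shared pattern is roughly the sketch below (pick_collator is a hypothetical name, not a function in this repo):

```python
# Hedged sketch of the collator selection shared by with_metadata.py and
# without_metadata.py; pick_collator itself does not exist in the repo.
from accelerate import DistributedType
from transformers import DataCollatorWithPadding, default_data_collator

def pick_collator(distributed_type, tokenizer, max_seq_len):
    if distributed_type == DistributedType.TPU:
        # Static shapes on TPU: pad everything to max_seq_len.
        return DataCollatorWithPadding(tokenizer, padding="max_length", max_length=max_seq_len)
    # GPUs and CPUs tolerate varying batch shapes; keep the cheaper default.
    return default_data_collator
```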

bsmetadata/input_pipeline.py
Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import List, Optional
 
+from accelerate import DistributedType
+
 
 @dataclass
 class DataConfig:
@@ -62,6 +64,7 @@ class DataConfig:
     block_size: Optional[int] = field(
         default=None, metadata={"help": "Optional input sequence length after tokenization."}
     )
+    distributed_type: DistributedType = field(default=DistributedType.NO)
 
 
 def get_dataloaders(tokenizer, cfg: DataConfig):
```
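With the new field, the data config itself records how the run is distributed, so get_dataloaders can branch on it without reaching into the Accelerator. A trimmed sketch of the resulting dataclass, keeping only the fields visible in this diff:

```python
# Trimmed sketch of DataConfig after this commit; the real class has many
# more fields than the two shown here.
from dataclasses import dataclass, field
from typing import Optional

from accelerate import DistributedType

@dataclass
class DataConfig:
    block_size: Optional[int] = field(
        default=None, metadata={"help": "Optional input sequence length after tokenization."}
    )
    # Defaults to single-process; train.main() overwrites it with the value
    # reported by the Accelerator before the dataloaders are built.
    distributed_type: DistributedType = field(default=DistributedType.NO)
```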

bsmetadata/train.py
Lines changed: 4 additions & 2 deletions
```diff
@@ -104,12 +104,14 @@ def loss_fn(batch, outputs, metadata_mask=None):
     return loss
 
 
-@hydra.main(config_name="config")
+@hydra.main(config_path=None, config_name="config")
 def main(args: CFG) -> None:
+    accelerator = Accelerator()
+    args.data_config.distributed_type = accelerator.distributed_type
+
     print(OmegaConf.to_yaml(args))
 
     set_seed(args.seed)
-    accelerator = Accelerator()
     is_local_main_process = accelerator.is_local_main_process
     tqdm = partial(original_tqdm, disable=not is_local_main_process)
 
```
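Moving Accelerator() to the top of main() is what makes the new config field usable: the distributed type is only known once the Accelerator exists, and it must be copied into the data config before the dataloaders are built. A minimal runnable sketch of that ordering, with data_config standing in for args.data_config:

```python
# Sketch of the reordering in train.py: create the Accelerator first, then
# stamp its distributed_type onto the config that dataloader creation reads.
# data_config is a stand-in for args.data_config in the repo.
from types import SimpleNamespace

from accelerate import Accelerator, DistributedType

accelerator = Accelerator()
data_config = SimpleNamespace(distributed_type=DistributedType.NO)

# Mirrors the commit: copy the runtime value in before any dataloaders are
# built, so collator selection can branch on it.
data_config.distributed_type = accelerator.distributed_type
print(data_config.distributed_type)  # DistributedType.NO on a plain CPU run
```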