Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,22 @@ python3 -m tokenlearn.featurize \
--dataset-split "train"
```

The output is a standard HuggingFace dataset saved to `--output-dir`. You can optionally push it to the Hub:

```bash
python3 -m tokenlearn.featurize --model-name "baai/bge-base-en-v1.5" --output-dir "data/c4_features" --push-to-hub "username/my-featurized-dataset"
```

To train a model on the featurized data, the `tokenlearn-train` CLI can be used:
```bash
python3 -m tokenlearn.train --model-name "baai/bge-base-en-v1.5" --data-path "data/c4_features" --save-path "<path-to-save-model>"
```

`--data-path` also accepts a HuggingFace Hub repo ID if you have a featurized dataset there:
```bash
python3 -m tokenlearn.train --model-name "baai/bge-base-en-v1.5" --data-path "username/my-featurized-dataset" --save-path "<path-to-save-model>"
```

Training will create two models:
- The base trained model.
- The base model with weighting applied. This is the model that should be used for downstream tasks.
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ ignore_missing_imports = true
[tool.setuptools]
packages = ["tokenlearn"]

[tool.setuptools.package-data]
tokenlearn = ["datacards/*.md"]

[tool.setuptools_scm]
# can be empty if no extra settings are needed, presence enables setuptools_scm

Expand Down
74 changes: 74 additions & 0 deletions tokenlearn/datacards/dataset_card_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
---
{{ card_data }}
---

# {{ repo_id or dataset_name }} Dataset Card

This dataset was created with [Tokenlearn](https://github.com/MinishLab/tokenlearn) for training [Model2Vec](https://github.com/MinishLab/model2vec) models. It contains mean token embeddings produced by a sentence transformer, used as training targets for static embedding distillation.

## Dataset Details

| Field | Value |
|---|---|
| **Source dataset** | [{{ source_dataset }}](https://huggingface.co/datasets/{{ source_dataset }}) |
| **Source split** | `{{ source_split }}` |
| **Embedding model** | [{{ model_name }}](https://huggingface.co/{{ model_name }}) |
| **Embedding dimension** | {{ embedding_dim }} |
| **Rows** | {{ num_rows }} |

## Dataset Structure

| Column | Type | Description |
|---|---|---|
| `text` | `string` | Truncated input text |
| `embedding` | `list[float32]` | Mean token embedding from `{{ model_name }}`, excluding BOS/EOS tokens |

## Usage

Load with the `datasets` library:

```python
from datasets import load_dataset

dataset = load_dataset("{{ repo_id or dataset_name }}")
```

Train a Model2Vec model on this dataset using Tokenlearn:

```bash
python -m tokenlearn.train \
--model-name "{{ model_name }}" \
--data-path "{{ repo_id or dataset_name }}" \
--save-path "<path-to-save-model>"
```

## Creation

This dataset was created using the `tokenlearn-featurize` CLI:

```bash
python -m tokenlearn.featurize \
--model-name "{{ model_name }}" \
--dataset-path "{{ source_dataset }}" \
--dataset-name "{{ source_name }}" \
--dataset-split "{{ source_split }}" \
--output-dir "<output-dir>"
```

## Library Authors

Tokenlearn was developed by the [Minish](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).

## Citation

```
@software{minishlab2024model2vec,
author = {Stephan Tulkens and {van Dongen}, Thomas},
title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
year = {2024},
publisher = {Zenodo},
doi = {10.5281/zenodo.17270888},
url = {https://github.com/MinishLab/model2vec},
license = {MIT}
}
```
157 changes: 136 additions & 21 deletions tokenlearn/featurize.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,110 @@
import argparse
import json
import logging
import shutil
from pathlib import Path
from typing import Iterator

import numpy as np
from datasets import load_dataset
from datasets import Dataset, Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
from huggingface_hub import DatasetCard, DatasetCardData
from more_itertools import batched
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer

_SAVE_EVERY = 32
_DATASET_CARD_TEMPLATE = Path(__file__).parent / "datacards" / "dataset_card_template.md"

_SAVE_EVERY = 1024 # rows (≈ 32 batches at the default batch size of 32)

_FEATURES = Features({"text": Value("string"), "embedding": Sequence(Value("float32"))})

logger = logging.getLogger(__name__)


def _save_checkpoint(checkpoints_dir: Path, texts: list[str], embeddings: list[np.ndarray], part_idx: int) -> None:
    """Write one checkpoint shard of (text, embedding) rows as a Parquet file.

    The shard is named ``shard_<part_idx:08d>.parquet`` so that a lexical sort
    of the directory recovers write order.
    """
    records = {
        "text": texts,
        "embedding": [vec.tolist() for vec in embeddings],
    }
    shard = Dataset.from_dict(records, features=_FEATURES)
    shard.to_parquet(str(checkpoints_dir / f"shard_{part_idx:08d}.parquet"))


def _compact_checkpoints(checkpoints_dir: Path, output_dir: Path, keep_checkpoints: bool) -> None:
"""Compact checkpoint shards into a single standard HuggingFace dataset."""
shard_files = sorted(checkpoints_dir.glob("shard_*.parquet"))
if not shard_files:
return

logger.info("Compacting checkpoints into final dataset...")
# Build the compacted dataset in a sibling temp dir, then replace output_dir.
tmp_dir = output_dir.parent / f"{output_dir.name}.tmp"
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
# Load all shards and concatenate them into a single dataset, then save to the temp dir.
dataset = concatenate_datasets([Dataset.from_parquet(str(f)) for f in shard_files])
dataset.save_to_disk(str(tmp_dir))
if output_dir.exists():
# Remove the old output dir before renaming the temp dir to avoid leaving stale Arrow files from previous runs.
shutil.rmtree(output_dir)
tmp_dir.rename(output_dir)
if not keep_checkpoints:
shutil.rmtree(checkpoints_dir)
logger.info(f"Dataset saved to {output_dir}")


def _create_dataset_card(
    output_dir: Path,
    model_name: str,
    source_dataset: str,
    source_name: str,
    source_split: str,
    num_rows: int,
    embedding_dim: int,
    repo_id: str | None = None,
) -> DatasetCard:
    """Create a dataset card, save it to the output directory, and return it.

    The card is rendered from the bundled Jinja template and written to
    ``output_dir / "README.md"``.
    """
    metadata = DatasetCardData(
        language="en",
        tags=["tokenlearn", "embeddings", "model2vec"],
    )
    # All template placeholders are passed through verbatim; repo_id may be
    # None, in which case the template falls back to the dataset name.
    template_kwargs = {
        "repo_id": repo_id,
        "dataset_name": output_dir.name,
        "model_name": model_name,
        "source_dataset": source_dataset,
        "source_name": source_name,
        "source_split": source_split,
        "num_rows": num_rows,
        "embedding_dim": embedding_dim,
    }
    card = DatasetCard.from_template(
        metadata,
        template_path=str(_DATASET_CARD_TEMPLATE),
        **template_kwargs,
    )
    card.save(output_dir / "README.md")
    return card


def featurize( # noqa C901
dataset: Iterator[dict[str, str]],
model: SentenceTransformer,
output_dir: str,
max_means: int,
max_rows: int,
batch_size: int,
text_key: str,
max_length: int | None = None,
keep_checkpoints: bool = False,
) -> None:
"""Make a directory and dump all kinds of data in it."""
output_dir_path = Path(output_dir)
output_dir_path.mkdir(parents=True, exist_ok=True)
checkpoints_dir = Path(str(output_dir_path) + ".checkpoints")
checkpoints_dir.mkdir(exist_ok=True)

# Ugly hack
largest_batch = max([int(x.stem.split("_")[1]) for x in list(output_dir_path.glob("*.json"))], default=0)
if largest_batch:
logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.")
shard_files = sorted(checkpoints_dir.glob("shard_*.parquet"))
part_idx = len(shard_files)
rows_done = sum(len(Dataset.from_parquet(str(f))) for f in shard_files)
if rows_done:
logger.info(f"Resuming from {rows_done} previously written rows ({part_idx} checkpoint shards).")

texts = []
embeddings = []
Expand All @@ -47,14 +118,17 @@ def featurize( # noqa C901
tokenizer.model_max_length = max_length
model.max_seq_length = max_length
logger.info(f"Set tokenizer maximum length to {max_length}.")
# Binding i in case the dataset is empty.
i = 0
for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))):
if i * batch_size >= max_means:
logger.info(f"Reached maximum number of means: {max_means}")
batch = list(batch)
rows_processed = i * batch_size
if rows_processed >= max_rows:
logger.info(f"Reached maximum number of rows: {max_rows}")
break
if largest_batch and i <= largest_batch:
# Skip batches fully covered by a previous run; trim those that straddle the boundary.
if rows_processed + len(batch) <= rows_done:
continue
if rows_processed < rows_done:
batch = batch[rows_done - rows_processed :]
batch = [x[text_key] for x in batch]

if not all(isinstance(x, str) for x in batch):
Expand All @@ -64,14 +138,15 @@ def featurize( # noqa C901
for text, embedding in zip(batch, batch_embeddings):
texts.append(_truncate_text(tokenizer, text))
embeddings.append(embedding[1:-1].float().mean(axis=0).cpu().numpy())
if i and i % _SAVE_EVERY == 0:
json.dump(texts, open(output_dir_path / f"feature_{i}.json", "w"), indent=4)
np.save(output_dir_path / f"feature_{i}.npy", embeddings)
if len(texts) >= _SAVE_EVERY:
_save_checkpoint(checkpoints_dir, texts, embeddings, part_idx)
part_idx += 1
texts = []
embeddings = []
if texts:
json.dump(texts, open(output_dir_path / f"feature_{i}.json", "w"), indent=4)
np.save(output_dir_path / f"feature_{i}.npy", embeddings)
_save_checkpoint(checkpoints_dir, texts, embeddings, part_idx)

_compact_checkpoints(checkpoints_dir, output_dir_path, keep_checkpoints)


def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
Expand Down Expand Up @@ -123,10 +198,10 @@ def main() -> None:
help="Disable streaming mode when loading the dataset.",
)
parser.add_argument(
"--max-means",
"--max-rows",
type=int,
default=1000000,
help="The maximum number of mean embeddings to generate.",
help="The maximum number of rows to featurize.",
)
parser.add_argument(
"--key",
Expand All @@ -141,6 +216,17 @@ def main() -> None:
help="Batch size to use for encoding the texts.",
)
parser.add_argument("--max-length", type=int, default=None, help="Maximum token length for the tokenizer.")
parser.add_argument(
"--keep-checkpoints",
action="store_true",
help="Keep checkpoint parts after compaction (default: delete them).",
)
parser.add_argument(
"--push-to-hub",
type=str,
default=None,
help="HuggingFace Hub repo ID to push the dataset to after featurizing (e.g., 'username/my-dataset').",
)

args = parser.parse_args()

Expand All @@ -159,7 +245,36 @@ def main() -> None:
streaming=args.no_streaming,
)

featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key, max_length=args.max_length)
featurize(
iter(dataset),
model,
output_dir,
args.max_rows,
args.batch_size,
args.key,
max_length=args.max_length,
keep_checkpoints=args.keep_checkpoints,
)

output_dir_path = Path(output_dir)
if (output_dir_path / "dataset_info.json").exists():
ds = load_from_disk(output_dir)
card = _create_dataset_card(
output_dir=output_dir_path,
model_name=args.model_name,
source_dataset=args.dataset_path,
source_name=args.dataset_name,
source_split=args.dataset_split,
num_rows=len(ds),
embedding_dim=len(ds[0]["embedding"]),
repo_id=args.push_to_hub,
)
if args.push_to_hub:
logger.info(f"Pushing dataset to Hub: {args.push_to_hub}")
ds.push_to_hub(args.push_to_hub)
card.push_to_hub(args.push_to_hub)
else:
logger.warning("No data was written — skipping dataset card and Hub push.")


if __name__ == "__main__":
Expand Down
20 changes: 16 additions & 4 deletions tokenlearn/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import logging
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -41,7 +40,19 @@ def main() -> None:
"--data-path",
type=str,
default="data/fineweb_bgebase",
help="Path to the directory containing the dataset.",
help="Path to a local HuggingFace dataset directory or a Hub repo ID.",
)
parser.add_argument(
"--data-split",
type=str,
default="train",
help="Dataset split to use when loading from the Hub (e.g., 'train', 'validation').",
)
parser.add_argument(
"--data-name",
type=str,
default=None,
help="Dataset configuration name when loading from the Hub (e.g., 'en' for C4).",
)
parser.add_argument(
"--save-path",
Expand Down Expand Up @@ -89,8 +100,9 @@ def main() -> None:
args = parser.parse_args()

# Collect paths for training data
paths = sorted(Path(args.data_path).glob("*.json"))
train_txt, train_vec = collect_means_and_texts(paths, args.limit_samples)
train_txt, train_vec = collect_means_and_texts(
args.data_path, args.limit_samples, split=args.data_split, name=args.data_name
)

pca_dims = args.pca_dims

Expand Down
Loading