diff --git a/model2vec/inference/model.py b/model2vec/inference/model.py index 74023b6..0677e92 100644 --- a/model2vec/inference/model.py +++ b/model2vec/inference/model.py @@ -2,6 +2,7 @@ import re from collections.abc import Sequence +from enum import Enum from pathlib import Path from tempfile import TemporaryDirectory from typing import TypeVar, cast @@ -23,25 +24,26 @@ LabelType = TypeVar("LabelType", list[str], list[list[str]]) +class HeadType(str, Enum): + CLASSIFIER = "classifier" + PROJECTOR = "projector" + MULTILABEL = "multilabel" + + class StaticModelPipeline: def __init__(self, model: StaticModel, head: Pipeline) -> None: """Create a pipeline with a StaticModel encoder.""" self.model = model self.head = head - classifier = self.head[-1] - # Check if the classifier is a multilabel classifier. - # NOTE: this doesn't look robust, but it is. - # Different classifiers, such as OVR wrappers, support multilabel output natively, so we - # can just use predict. - self.multilabel = False - if isinstance(classifier, MLPClassifier): - if classifier.out_activation_ == "logistic": - self.multilabel = True - - @property - def classes_(self) -> np.ndarray: - """The classes of the classifier.""" - return self.head.classes_ + + last_head = self.head[-1] + self.classes_: None | np.ndarray = None + if isinstance(last_head, MLPClassifier): + activation = last_head.out_activation_ + self.classifier_type = HeadType.MULTILABEL if activation == "logistic" else HeadType.CLASSIFIER + self.classes_ = self.head.classes_ + else: + self.classifier_type = HeadType.PROJECTOR @classmethod def from_pretrained( @@ -138,7 +140,8 @@ def predict( multiprocessing_threshold=multiprocessing_threshold, ) - if self.multilabel: + if self.classifier_type == HeadType.MULTILABEL: + assert self.classes_ is not None out_labels = [] proba = self.head.predict_proba(encoded) for vector in proba: @@ -166,7 +169,10 @@ def predict_proba( :param use_multiprocessing: Whether to use multiprocessing for 
encoding. Defaults to True. :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. :return: The predicted labels or probabilities. + :raises ValueError: If the classifier type is projector. """ + if self.classifier_type == HeadType.PROJECTOR: + raise ValueError("You are using predict_proba on a projector model. This is not supported.") encoded = self._encode_and_coerce_to_2d( X, show_progress_bar=show_progress_bar, @@ -190,7 +196,10 @@ def evaluate( :param threshold: The threshold for multilabel classification. :param output_dict: Whether to output the classification report as a dictionary. :return: A classification report. + :raises ValueError: If the classifier type is projector. """ + if self.classifier_type == HeadType.PROJECTOR: + raise ValueError("You are using evaluate on a projector model. This is not supported.") predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold) report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict) diff --git a/model2vec/train/__init__.py b/model2vec/train/__init__.py index c70f803..8766516 100644 --- a/model2vec/train/__init__.py +++ b/model2vec/train/__init__.py @@ -6,5 +6,6 @@ importable(extra_dependency, _REQUIRED_EXTRA) from model2vec.train.classifier import StaticModelForClassification +from model2vec.train.similarity import StaticModelForSimilarity -__all__ = ["StaticModelForClassification"] +__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"] diff --git a/model2vec/train/base.py b/model2vec/train/base.py index 3966c8f..b2d0dd8 100644 --- a/model2vec/train/base.py +++ b/model2vec/train/base.py @@ -1,48 +1,69 @@ from __future__ import annotations import logging +from collections.abc import Sequence +from tempfile import TemporaryDirectory from typing import Any, TypeVar +import lightning.pytorch as pl import numpy as np import torch +from lightning.pytorch import 
LightningModule +from lightning.pytorch.callbacks import Callback, EarlyStopping from tokenizers import Encoding, Tokenizer from torch import nn from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset +from tqdm import trange from model2vec import StaticModel -from model2vec.train.utils import get_probable_pad_token_id +from model2vec.inference import StaticModelPipeline +from model2vec.train.dataset import TextDataset +from model2vec.train.utils import get_probable_pad_token_id, to_pipeline, train_test_split logger = logging.getLogger(__name__) -class FinetunableStaticModel(nn.Module): +class _BaseFinetuneable(nn.Module): + val_metric = "val_loss" + early_stopping_direction = "min" + def __init__( self, *, vectors: torch.Tensor, tokenizer: Tokenizer, + hidden_dim: int = 256, + n_layers: int = 0, out_dim: int = 2, pad_id: int = 0, token_mapping: list[int] | None = None, weights: torch.Tensor | None = None, freeze: bool = False, + normalize: bool = True, ) -> None: """ Initialize a trainable StaticModel from a StaticModel. :param vectors: The embeddings of the staticmodel. :param tokenizer: The tokenizer. + :param hidden_dim: The hidden dimension of the head. + :param n_layers: The number of layers in the head. :param out_dim: The output dimension of the head. :param pad_id: The padding id. This is set to 0 in almost all model2vec models :param token_mapping: The token mapping. If None, the token mapping is set to the range of the number of vectors. :param weights: The weights of the model. If None, the weights are initialized to zeros. :param freeze: Whether to freeze the embeddings. This should be set to False in most cases. + :param normalize: Whether to normalize the embeddings. + :raises ValueError: If the vectors are not a 2D tensor. + :raises ValueError: If the token_mapping is not None and the length does not match the number of vectors. 
""" super().__init__() self.pad_id = pad_id self.out_dim = out_dim self.embed_dim = vectors.shape[1] + self.hidden_dim = hidden_dim + self.n_layers = n_layers + self.normalize = normalize self.vectors = vectors if self.vectors.dtype != torch.float32: @@ -53,6 +74,8 @@ def __init__( self.vectors = vectors.float() if token_mapping is not None: + if len(token_mapping) != len(vectors): + raise ValueError("token_mapping must have the same length as vectors") self.token_mapping = torch.tensor(token_mapping, dtype=torch.int64) else: self.token_mapping = torch.arange(len(vectors), dtype=torch.int64) @@ -70,20 +93,56 @@ def construct_weights(self) -> nn.Parameter: return nn.Parameter(weights, requires_grad=not self.freeze) def construct_head(self) -> nn.Sequential: - """Method should be overridden for various other classes.""" - return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim)) + """Constructs a simple classifier head.""" + modules: list[nn.Module] = [] + if self.n_layers == 0: + modules.append(nn.Linear(self.embed_dim, self.out_dim)) + else: + # If we have a hidden layer, we should first project to hidden_dim + modules = [ + nn.Linear(self.embed_dim, self.hidden_dim), + nn.ReLU(), + ] + for _ in range(self.n_layers - 1): + modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) + # We always have a layer mapping from hidden to out. 
+ modules.append(nn.Linear(self.hidden_dim, self.out_dim)) + + linear_modules = [module for module in modules if isinstance(module, nn.Linear)] + if linear_modules: + *initial, last = linear_modules + for module in initial: + nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") + nn.init.zeros_(module.bias) + # Final layer does not kaiming + nn.init.xavier_uniform_(last.weight) + nn.init.zeros_(last.bias) + + return nn.Sequential(*modules) + + def _initialize(self) -> None: + """Initialize the classifier for training.""" + self.head = self.construct_head() + self.embeddings = nn.Embedding.from_pretrained( + self.vectors.clone(), freeze=self.freeze, padding_idx=self.pad_id + ) + self.w = self.construct_weights() + self.train() @classmethod def from_pretrained( - cls: type[ModelType], *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any + cls: type[ModelType], path: str = "minishlab/potion-base-32m", *, token: str | None = None, **kwargs: Any ) -> ModelType: """Load the model from a pretrained model2vec model.""" - model = StaticModel.from_pretrained(model_name) - return cls.from_static_model(model=model, out_dim=out_dim, **kwargs) + if model_name := kwargs.pop("model_name", None): + logger.warning("The 'model_name' argument is deprecated. 
Use 'path' instead.") + path = model_name + model = StaticModel.from_pretrained(path, token=token) + return cls.from_static_model(model=model, **kwargs) @classmethod def from_static_model( - cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, pad_token: str | None = None, **kwargs: Any + cls: type[ModelType], *, model: StaticModel, pad_token: str | None = None, **kwargs: Any ) -> ModelType: """Load the model from a static model.""" model.embedding = np.nan_to_num(model.embedding) @@ -100,7 +159,6 @@ def from_static_model( return cls( vectors=embeddings_converted, pad_id=pad_id, - out_dim=out_dim, tokenizer=model.tokenizer, token_mapping=token_mapping, weights=weights, @@ -130,7 +188,23 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor: # Mean pooling by dividing by the length embedded = embedded / length[:, None] - return nn.functional.normalize(embedded) + if self.normalize: + return nn.functional.normalize(embedded) + return embedded + + @torch.no_grad() + def _encode_single_batch(self, X: list[str]) -> torch.Tensor: + input_ids = self.tokenize(X) + return self.head(self._encode(input_ids)) + + def encode(self, X: list[str], batch_size: int = 1024, show_progress_bar: bool = False) -> np.ndarray: + """Encode the texts in batches and return the head outputs as a numpy array.""" + pred = [] + for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): + logits = self._encode_single_batch(X[batch : batch + batch_size]) + pred.append(logits.cpu().numpy()) + + return np.concatenate(pred, axis=0) def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the mean, and a classifier layer after.""" @@ -163,45 +237,156 @@ def to_static_model(self) -> StaticModel: token_mapping = self.token_mapping.numpy() return StaticModel( - vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=self.normalize, token_mapping=token_mapping ) + 
def to_pipeline(self) -> StaticModelPipeline: + """Convert the model to an sklearn pipeline.""" + return to_pipeline(self) + + def _determine_batch_size(self, batch_size: int | None, train_length: int) -> int: + if batch_size is None: + # Set to a multiple of 32 + base_number = int(min(max(1, (train_length / 30) // 32), 16)) + batch_size = int(base_number * 32) + logger.info("Batch size automatically set to %d.", batch_size) + + return batch_size + + def _check_val_split( + self, X: list[str], y: list, X_val: list[str] | None, y_val: list | None, test_size: float + ) -> tuple[list[str], list[str], Sequence, Sequence]: + if (X_val is not None) != (y_val is not None): + raise ValueError("Both X_val and y_val must be provided together, or neither.") + + if X_val is not None and y_val is not None: + # Additional check to ensure y_val is of the same type as y + if type(y_val[0]) != type(y[0]): + raise ValueError("X_val and y_val must be of the same type as X and y.") + + train_texts = X + train_labels = y + validation_texts = X_val + validation_labels = y_val + else: + train_texts, validation_texts, train_labels, validation_labels = train_test_split(X, y, test_size=test_size) -class TextDataset(Dataset): - def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None: - """ - A dataset of texts. + return train_texts, validation_texts, train_labels, validation_labels - :param tokenized_texts: The tokenized texts. Each text is a list of token ids. - :param targets: The targets. - :raises ValueError: If the number of labels does not match the number of texts. 
- """ - if len(targets) != len(tokenized_texts): - raise ValueError("Number of labels does not match number of texts.") - self.tokenized_texts = tokenized_texts - self.targets = targets + def _train( + self, + module: LightningModule, + train_dataset: TextDataset, + val_dataset: TextDataset, + batch_size: int, + early_stopping_patience: int | None, + min_epochs: int | None, + max_epochs: int | None, + device: str, + validation_steps: int | None, + ) -> None: + callbacks: list[Callback] = [] + if early_stopping_patience is not None: + callback = EarlyStopping( + monitor=self.val_metric, + mode=self.early_stopping_direction, + patience=early_stopping_patience, + min_delta=0.001, + ) + callbacks.append(callback) + + val_check_interval: int | None = None + check_val_every_epoch: int | None = 1 + if validation_steps is None: + n_train_batches = len(train_dataset) // batch_size + target_checks_per_epoch = 4 + min_train_steps_between_val = 250 + + # If we have more than 250 batches, smoothly interpolate + if n_train_batches > min_train_steps_between_val: + val_check_interval = max( + min_train_steps_between_val, + n_train_batches // target_checks_per_epoch, + ) + check_val_every_epoch = None + else: + val_check_interval = validation_steps + check_val_every_epoch = None + + with TemporaryDirectory() as tempdir: + trainer = pl.Trainer( + min_epochs=min_epochs, + max_epochs=max_epochs, + callbacks=callbacks, + val_check_interval=val_check_interval, + check_val_every_n_epoch=check_val_every_epoch, + accelerator=device, + default_root_dir=tempdir, + ) + + trainer.fit( + module, + train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size), + val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size), + ) + best_model_path = trainer.checkpoint_callback.best_model_path # type: ignore + best_model_weights = torch.load(best_model_path, weights_only=True) - def __len__(self) -> int: - """Return the length of the dataset.""" - return 
len(self.tokenized_texts) + state_dict = {} + for weight_name, weight in best_model_weights["state_dict"].items(): + if "loss_function" in weight_name: + # Skip the loss function class weight as its not needed for predictions + continue + state_dict[weight_name.removeprefix("model.")] = weight - def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]: - """Gets an item.""" - return self.tokenized_texts[index], self.targets[index] + self.load_state_dict(state_dict) + self.eval() - @staticmethod - def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]: - """Collate function.""" - texts, targets = zip(*batch) + def _prepare_dataset(self, X: list[str], y: torch.Tensor, max_length: int = 512) -> TextDataset: + """ + Prepare a dataset. - tensors: list[torch.Tensor] = [torch.LongTensor(x) for x in texts] - padded = pad_sequence(tensors, batch_first=True, padding_value=0) + :param X: The texts. + :param y: The labels. + :param max_length: The maximum length of the input. + :return: A TextDataset. + """ + # This is a speed optimization. + # assumes a mean token length of 10, which is really high, so safe. 
+ truncate_length = max_length * 10 + batch_size = 1024 + tokenized: list[list[int]] = [] + for batch_idx in trange(0, len(X), 1024, desc="Tokenizing data"): + batch = [x[:truncate_length] for x in X[batch_idx : batch_idx + batch_size]] + encoded = self.tokenizer.encode_batch_fast(batch, add_special_tokens=False) + tokenized.extend([encoding.ids[:max_length] for encoding in encoded]) + + return TextDataset(tokenized, y) + + def _labels_to_tensor(self, labels: Any) -> torch.Tensor: + """Turn the labels into a tensor.""" + return labels + + def _create_datasets( + self, + X: list[str], + y: Any, + X_val: list[str] | None, + y_val: Any | None, + test_size: float, + ) -> tuple[TextDataset, TextDataset]: + train_texts, validation_texts, train_labels, validation_labels = self._check_val_split( + X, y, X_val, y_val, test_size + ) + y_tensor = self._labels_to_tensor(train_labels) + y_val_tensor = self._labels_to_tensor(validation_labels) - return padded, torch.stack(targets) + logger.info("Preparing train dataset.") + train_dataset = self._prepare_dataset(train_texts, y_tensor) + logger.info("Preparing validation dataset.") + val_dataset = self._prepare_dataset(validation_texts, y_val_tensor) - def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader: - """Convert the dataset to a DataLoader.""" - return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size) + return train_dataset, val_dataset -ModelType = TypeVar("ModelType", bound=FinetunableStaticModel) +ModelType = TypeVar("ModelType", bound=_BaseFinetuneable) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index c675d17..c715c07 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -3,32 +3,28 @@ import logging from collections import Counter from itertools import chain -from tempfile import TemporaryDirectory -from typing import TypeVar, cast +from typing import Any, Literal, cast import lightning as pl import numpy 
as np import torch -from lightning.pytorch.callbacks import Callback, EarlyStopping -from lightning.pytorch.utilities.types import OptimizerLRScheduler -from sklearn.metrics import jaccard_score -from sklearn.model_selection import train_test_split -from sklearn.neural_network import MLPClassifier -from sklearn.pipeline import make_pipeline from tokenizers import Tokenizer -from torch import nn from tqdm import trange -from model2vec.inference import StaticModelPipeline, evaluate_single_or_multi_label -from model2vec.train.base import FinetunableStaticModel, TextDataset +from model2vec.inference import evaluate_single_or_multi_label +from model2vec.train.base import _BaseFinetuneable +from model2vec.train.lightning_modules import ClassifierLightningModule, MultiLabelClassifierLightningModule logger = logging.getLogger(__name__) + _DEFAULT_RANDOM_SEED = 42 +LabelType = list[str] | list[list[str]] -LabelType = TypeVar("LabelType", list[str], list[list[str]]) +class StaticModelForClassification(_BaseFinetuneable): + val_metric = "val_accuracy" + early_stopping_direction = "max" -class StaticModelForClassification(FinetunableStaticModel): def __init__( self, *, @@ -41,10 +37,9 @@ def __init__( token_mapping: list[int] | None = None, weights: torch.Tensor | None = None, freeze: bool = False, + normalize: bool = True, ) -> None: """Initialize a standard classifier model.""" - self.n_layers = n_layers - self.hidden_dim = hidden_dim # Alias: Follows scikit-learn. Set to dummy classes self.classes_: list[str] = [str(x) for x in range(out_dim)] # multilabel flag will be set based on the type of `y` passed to fit. 
@@ -57,6 +52,9 @@ def __init__( token_mapping=token_mapping, weights=weights, freeze=freeze, + hidden_dim=hidden_dim, + n_layers=n_layers, + normalize=normalize, ) @property @@ -64,34 +62,6 @@ def classes(self) -> np.ndarray: """Return all clasess in the correct order.""" return np.array(self.classes_) - def construct_head(self) -> nn.Sequential: - """Constructs a simple classifier head.""" - modules: list[nn.Module] = [] - if self.n_layers == 0: - modules.append(nn.Linear(self.embed_dim, self.out_dim)) - else: - # If we have a hidden layer, we should first project to hidden_dim - modules = [ - nn.Linear(self.embed_dim, self.hidden_dim), - nn.ReLU(), - ] - for _ in range(self.n_layers - 1): - modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) - # We always have a layer mapping from hidden to out. - modules.append(nn.Linear(self.hidden_dim, self.out_dim)) - - linear_modules = [module for module in modules if isinstance(module, nn.Linear)] - if linear_modules: - *initial, last = linear_modules - for module in initial: - nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") - nn.init.zeros_(module.bias) - # Final layer does not kaiming - nn.init.xavier_uniform_(last.weight) - nn.init.zeros_(last.bias) - - return nn.Sequential(*modules) - def predict( self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5 ) -> np.ndarray: @@ -109,7 +79,7 @@ def predict( """ pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): - logits = self._predict_single_batch(X[batch : batch + batch_size]) + logits = self._encode_single_batch(X[batch : batch + batch_size]) if self.multilabel: probs = torch.sigmoid(logits) mask = (probs > threshold).cpu().numpy() @@ -122,12 +92,6 @@ def predict( else: return np.array(pred) - @torch.no_grad() - def _predict_single_batch(self, X: list[str]) -> torch.Tensor: - input_ids = self.tokenize(X) - vectors, _ = self.forward(input_ids) - return vectors - 
def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: """ Predict probabilities for each class. @@ -137,14 +101,14 @@ def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_siz """ pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): - logits = self._predict_single_batch(X[batch : batch + batch_size]) + logits = self._encode_single_batch(X[batch : batch + batch_size]) if self.multilabel: pred.append(torch.sigmoid(logits).cpu().numpy()) else: pred.append(torch.softmax(logits, dim=1).cpu().numpy()) return np.concatenate(pred, axis=0) - def fit( # noqa: C901 # Complexity is bad. + def fit( self, X: list[str], y: LabelType, @@ -157,7 +121,8 @@ def fit( # noqa: C901 # Complexity is bad. device: str = "auto", X_val: list[str] | None = None, y_val: LabelType | None = None, - class_weight: torch.Tensor | None = None, + class_weight: Literal["balanced"] | dict[str, float] | torch.Tensor | None = None, + validation_steps: int | None = None, random_seed: int = _DEFAULT_RANDOM_SEED, ) -> StaticModelForClassification: """ @@ -188,6 +153,7 @@ def fit( # noqa: C901 # Complexity is bad. :param y_val: The labels to be used for validation. :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must have the same length as the number of classes. + :param validation_steps: The number of steps to run validation for. If None, validation steps are estimated from the data. :param random_seed: The random seed to use. Defaults to 42. :return: The fitted model. :raises ValueError: If either X_val or y_val are provided, but not both. @@ -196,89 +162,63 @@ def fit( # noqa: C901 # Complexity is bad. logger.info("Re-initializing model.") # Determine whether the task is multilabel based on the type of y. 
- self._initialize(y) - - if (X_val is not None) != (y_val is not None): - raise ValueError("Both X_val and y_val must be provided together, or neither.") - - if X_val is not None and y_val is not None: - # Additional check to ensure y_val is of the same type as y - if type(y_val[0]) != type(y[0]): - raise ValueError("X_val and y_val must be of the same type as X and y.") - - train_texts = X - train_labels = y - validation_texts = X_val - validation_labels = y_val - else: - train_texts, validation_texts, train_labels, validation_labels = self._train_test_split( - X, - y, - test_size=test_size, - ) - - if batch_size is None: - # Set to a multiple of 32 - base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16)) - batch_size = int(base_number * 32) - logger.info("Batch size automatically set to %d.", batch_size) + self._initialize_on_labels(y) + self._initialize() if class_weight is not None: - if len(class_weight) != len(self.classes_): - raise ValueError("class_weight must have the same length as the number of classes.") - - logger.info("Preparing train dataset.") - train_dataset = self._prepare_dataset(train_texts, train_labels) - logger.info("Preparing validation dataset.") - val_dataset = self._prepare_dataset(validation_texts, validation_labels) - - c = _ClassifierLightningModule(self, learning_rate=learning_rate, class_weight=class_weight) - - n_train_batches = len(train_dataset) // batch_size - callbacks: list[Callback] = [] - if early_stopping_patience is not None: - callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience) - callbacks.append(callback) - - # If the dataset is small, we check the validation set every epoch. - # If the dataset is large, we check the validation set every 250 batches. - if n_train_batches < 250: - val_check_interval = None - check_val_every_epoch = 1 + if isinstance(class_weight, torch.Tensor): + logger.warning("You are passing a tensor as class weight. 
This will be removed in an upcoming version.") + if len(class_weight) != len(self.classes_): + raise ValueError("class_weight must have the same length as the number of classes.") + class_weight = {self.classes_[idx]: w for idx, w in enumerate(class_weight.tolist())} + resolved_class_weight = self._determine_class_weight(class_weight, y) else: - val_check_interval = max(250, 2 * len(val_dataset) // batch_size) - check_val_every_epoch = None - - with TemporaryDirectory() as tempdir: - trainer = pl.Trainer( - min_epochs=min_epochs, - max_epochs=max_epochs, - callbacks=callbacks, - val_check_interval=val_check_interval, - check_val_every_n_epoch=check_val_every_epoch, - accelerator=device, - default_root_dir=tempdir, - ) + resolved_class_weight = None - trainer.fit( - c, - train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size), - val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size), - ) - best_model_path = trainer.checkpoint_callback.best_model_path # type: ignore - best_model_weights = torch.load(best_model_path, weights_only=True) + train_dataset, val_dataset = self._create_datasets(X, y, X_val, y_val, test_size) + batch_size = self._determine_batch_size(batch_size, len(train_dataset)) - state_dict = {} - for weight_name, weight in best_model_weights["state_dict"].items(): - if "loss_function" in weight_name: - # Skip the loss function class weight as its not needed for predictions - continue - state_dict[weight_name.removeprefix("model.")] = weight + c: pl.LightningModule + if self.multilabel: + c = MultiLabelClassifierLightningModule( + self, learning_rate=learning_rate, class_weight=resolved_class_weight + ) + else: + c = ClassifierLightningModule(self, learning_rate=learning_rate, class_weight=resolved_class_weight) + + self._train( + module=c, + train_dataset=train_dataset, + val_dataset=val_dataset, + batch_size=batch_size, + early_stopping_patience=early_stopping_patience, + min_epochs=min_epochs, + 
max_epochs=max_epochs, + device=device, + validation_steps=validation_steps, + ) - self.load_state_dict(state_dict) - self.eval() return self + def _determine_class_weight( + self, class_weight: dict[str, float] | Literal["balanced"], y: LabelType + ) -> torch.Tensor: + """Determine the class weight for the classifier.""" + if class_weight == "balanced": + if self.multilabel: + y = cast(list[list[str]], y) + counts = Counter(chain.from_iterable(y)) + else: + y = cast(list[str], y) + counts = Counter(y) + total = sum(counts.values()) + n_classes = len(counts) + # Reciprocal weight: upweight rare classes, downweight frequent ones + weights = [total / (n_classes * counts[c]) for c in self.classes_] + else: + weights = [class_weight[c] for c in self.classes_] + return torch.tensor(weights, dtype=torch.float32) + def evaluate( self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False ) -> str | dict[str, dict[str, float]]: @@ -298,7 +238,7 @@ def evaluate( return report - def _initialize(self, y: LabelType) -> None: + def _initialize_on_labels(self, y: LabelType) -> None: """ Sets the output dimensionality, the classes, and initializes the head. @@ -306,12 +246,14 @@ def _initialize(self, y: LabelType) -> None: :raises ValueError: If the labels are inconsistent. """ if isinstance(y[0], (str, int)): + y = cast(list[str], y) # Check if all labels are strings or integers. if not all(isinstance(label, (str, int)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings or integers.") self.multilabel = False classes = sorted(set(y)) else: + y = cast(list[list[str]], y) # Check if all labels are lists or tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. 
All labels must be lists or tuples.") @@ -319,145 +261,21 @@ def _initialize(self, y: LabelType) -> None: classes = sorted(set(chain.from_iterable(y))) self.classes_ = classes - self.out_dim = len(self.classes_) # Update output dimension - self.head = self.construct_head() - self.embeddings = nn.Embedding.from_pretrained( - self.vectors.clone(), freeze=self.freeze, padding_idx=self.pad_id - ) - self.w = self.construct_weights() - self.train() + self.out_dim = len(self.classes_) - def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> TextDataset: - """ - Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector. - - :param X: The texts. - :param y: The labels. - :param max_length: The maximum length of the input. - :return: A TextDataset. - """ - # This is a speed optimization. - # assumes a mean token length of 10, which is really high, so safe. - truncate_length = max_length * 10 - X = [x[:truncate_length] for x in X] - tokenized: list[list[int]] = [ - encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False) - ] + def _labels_to_tensor(self, labels: Any) -> torch.Tensor: + """Convert a list or list of list of labels to a tensor.""" if self.multilabel: # Convert labels to multi-hot vectors num_classes = len(self.classes_) - labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float) + labels_tensor = torch.zeros(len(labels), num_classes, dtype=torch.float) mapping = {label: idx for idx, label in enumerate(self.classes_)} - for i, sample_labels in enumerate(y): + for i, sample_labels in enumerate(labels): indices = [mapping[label] for label in sample_labels] labels_tensor[i, indices] = 1.0 else: - labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) - return TextDataset(tokenized, labels_tensor) - - def _train_test_split( - self, - X: list[str], - y: list[str] | list[list[str]], - 
test_size: float, - ) -> tuple[list[str], list[str], LabelType, LabelType]: - """ - Split the data. - - For single-label classification, stratification is attempted (if possible). - For multilabel classification, a random split is performed. - """ - if not self.multilabel: - label_counts = Counter(y) - if min(label_counts.values()) < 2: - logger.info("Some classes have less than 2 samples. Stratification is disabled.") - return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) - return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y) - else: - # Multilabel classification does not support stratification. - return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) - - def to_pipeline(self) -> StaticModelPipeline: - """Convert the model to an sklearn pipeline.""" - static_model = self.to_static_model() - - random_state = np.random.RandomState(_DEFAULT_RANDOM_SEED) - n_items = len(self.classes) - X = random_state.randn(n_items, static_model.dim) - y = self.classes - - mlp_head = MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers) - mlp_head.fit(X, y) - - for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]): - mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T - mlp_head.intercepts_[index] = layer.bias.detach().cpu().numpy() - # Below is necessary to ensure that the converted model works correctly. - # In scikit-learn, a binary classifier only has a single vector of output coefficients - # and a single intercept. We use two output vectors. - # To convert correctly, we need to set the outputs correctly, and fix the activation function. - # Make sure n_outputs is set to > 1. 
- mlp_head.n_outputs_ = self.out_dim - # Set to softmax or sigmoid - mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax" - - pipeline = make_pipeline(mlp_head) - return StaticModelPipeline(static_model, pipeline) - - -class _ClassifierLightningModule(pl.LightningModule): - def __init__( - self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None - ) -> None: - """Initialize the LightningModule.""" - super().__init__() - self.model = model - self.learning_rate = learning_rate - self.loss_function = ( - nn.CrossEntropyLoss(weight=class_weight) - if not model.multilabel - else nn.BCEWithLogitsLoss(pos_weight=class_weight) - ) + labels_tensor = torch.tensor( + [self.classes_.index(label) for label in cast(list[str], labels)], dtype=torch.long + ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Simple forward pass.""" - return self.model(x) - - def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training.""" - x, y = batch - head_out, _ = self.model(x) - loss = self.loss_function(head_out, y) - self.log("train_loss", loss) - return loss - - def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Validation step computing loss and accuracy.""" - x, y = batch - head_out, _ = self.model(x) - loss = self.loss_function(head_out, y) - accuracy: float - if self.model.multilabel: - preds = (torch.sigmoid(head_out) > 0.5).float() - # Multilabel accuracy is defined as the Jaccard score averaged over samples. 
- accuracy = cast(float, jaccard_score(y.cpu(), preds.cpu(), average="samples")) - else: - accuracy = (head_out.argmax(dim=1) == y).float().mean() - self.log("val_loss", loss) - self.log("val_accuracy", accuracy, prog_bar=True) - - return loss - - def configure_optimizers(self) -> OptimizerLRScheduler: - """Configure optimizer and learning rate scheduler.""" - optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="min", - factor=0.5, - patience=3, - min_lr=1e-6, - threshold=0.03, - threshold_mode="rel", - ) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} + return labels_tensor diff --git a/model2vec/train/dataset.py b/model2vec/train/dataset.py new file mode 100644 index 0000000..fc608c5 --- /dev/null +++ b/model2vec/train/dataset.py @@ -0,0 +1,40 @@ +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset + + +class TextDataset(Dataset): + def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None: + """ + A dataset of texts. + + :param tokenized_texts: The tokenized texts. Each text is a list of token ids. + :param targets: The targets. + :raises ValueError: If the number of targets does not match the number of texts. 
+ """ + if len(targets) != len(tokenized_texts): + raise ValueError("Number of targets does not match number of texts.") + self.tokenized_texts = tokenized_texts + self.targets = targets + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.tokenized_texts) + + def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]: + """Gets an item.""" + return self.tokenized_texts[index], self.targets[index] + + @staticmethod + def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]: + """Collate function.""" + texts, targets = zip(*batch) + + tensors: list[torch.Tensor] = [torch.LongTensor(x) for x in texts] + padded = pad_sequence(tensors, batch_first=True, padding_value=0) + + return padded, torch.stack(targets) + + def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader: + """Convert the dataset to a DataLoader.""" + return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size) diff --git a/model2vec/train/lightning_modules.py b/model2vec/train/lightning_modules.py new file mode 100644 index 0000000..93a7c82 --- /dev/null +++ b/model2vec/train/lightning_modules.py @@ -0,0 +1,99 @@ +from typing import cast + +import lightning.pytorch as pl +import torch +from lightning.pytorch.utilities.types import OptimizerLRScheduler +from sklearn.metrics import jaccard_score +from torch import nn + + +class StaticLightningModule(pl.LightningModule): + def __init__(self, model: nn.Module, learning_rate: float) -> None: + """Initialize the LightningModule.""" + super().__init__() + self.model = model + self.learning_rate = learning_rate + self.loss_function = self.cosine_distance + + def cosine_distance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Returns the cosine distance loss function.""" + x = torch.nn.functional.normalize(x, dim=1) + y = torch.nn.functional.normalize(y, dim=1) + return (1 - torch.sum(x * y, dim=1)).mean() + + def 
forward(self, x: torch.Tensor) -> torch.Tensor: + """Simple forward pass.""" + return self.model(x) + + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training.""" + x, y = batch + head_out, _ = self.model(x) + loss = self.loss_function(head_out, y) + self.log("train_loss", loss, prog_bar=True) + return loss + + def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Validation step computing loss and accuracy.""" + x, y = batch + head_out, _ = self.model(x) + loss = self.loss_function(head_out, y) + self.log("val_loss", loss, prog_bar=True) + + return loss + + def configure_optimizers(self) -> OptimizerLRScheduler: + """Configure optimizer and learning rate scheduler.""" + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=0.5, + patience=3, + min_lr=1e-6, + threshold=0.03, + threshold_mode="rel", + ) + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} + + +class ClassifierLightningModule(StaticLightningModule): + def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None: + """Initialize the LightningModule.""" + super().__init__(model, learning_rate) + self.model = model + self.learning_rate = learning_rate + self.loss_function = nn.CrossEntropyLoss(weight=class_weight) + + def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Validation step computing loss and accuracy.""" + x, y = batch + head_out, _ = self.model(x) + loss = self.loss_function(head_out, y) + accuracy = (head_out.argmax(dim=1) == y).float().mean() + self.log("val_loss", loss) + self.log("val_accuracy", accuracy, prog_bar=True) + + 
return loss + + +class MultiLabelClassifierLightningModule(StaticLightningModule): + def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None: + """Initialize the LightningModule.""" + super().__init__(model, learning_rate) + self.model = model + self.learning_rate = learning_rate + self.loss_function = nn.BCEWithLogitsLoss(weight=class_weight) + + def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Validation step computing loss and accuracy.""" + x, y = batch + head_out, _ = self.model(x) + loss = self.loss_function(head_out, y) + preds = (torch.sigmoid(head_out) > 0.5).float() + # Multilabel accuracy is defined as the Jaccard score averaged over samples. + accuracy = cast(float, jaccard_score(y.cpu(), preds.cpu(), average="samples")) + self.log("val_loss", loss) + self.log("val_accuracy", accuracy, prog_bar=True) + + return loss diff --git a/model2vec/train/similarity.py b/model2vec/train/similarity.py new file mode 100644 index 0000000..6aea754 --- /dev/null +++ b/model2vec/train/similarity.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import logging + +import lightning as pl +import torch +from tokenizers import Tokenizer + +from model2vec.train.base import _BaseFinetuneable +from model2vec.train.lightning_modules import StaticLightningModule + +logger = logging.getLogger(__name__) + +_DEFAULT_RANDOM_SEED = 42 + + +class StaticModelForSimilarity(_BaseFinetuneable): + def __init__( + self, + *, + vectors: torch.Tensor, + tokenizer: Tokenizer, + n_layers: int = 1, + hidden_dim: int = 512, + out_dim: int = 2, + pad_id: int = 0, + token_mapping: list[int] | None = None, + weights: torch.Tensor | None = None, + freeze: bool = False, + normalize: bool = True, + ) -> None: + """Initialize a standard similarity model.""" + super().__init__( + vectors=vectors, + out_dim=out_dim, + pad_id=pad_id, + tokenizer=tokenizer, + token_mapping=token_mapping, + 
weights=weights, + freeze=freeze, + hidden_dim=hidden_dim, + n_layers=n_layers, + normalize=normalize, + ) + + def fit( + self, + X: list[str], + y: torch.Tensor, + learning_rate: float = 1e-3, + batch_size: int | None = None, + min_epochs: int | None = None, + max_epochs: int | None = -1, + early_stopping_patience: int | None = 5, + test_size: float = 0.1, + device: str = "auto", + X_val: list[str] | None = None, + y_val: torch.Tensor | None = None, + validation_steps: int | None = None, + random_seed: int = _DEFAULT_RANDOM_SEED, + ) -> StaticModelForSimilarity: + """ + Fit a model. + + This function creates a Lightning Trainer object and fits the model to the data. + We use early stopping. After training, the weights of the best model are loaded back into the model. + + This function seeds everything with a seed of 42, so the results are reproducible. + It also splits the data into a train and validation set, again with a random seed. + + If `X_val` and `y_val` are not provided, the function will automatically + split the training data into a train and validation set using `test_size`. + + :param X: The texts to train on. + :param y: The vectors to train on. + :param learning_rate: The learning rate. + :param batch_size: The batch size. If None, a good batch size is chosen automatically. + :param min_epochs: The minimum number of epochs to train for. + :param max_epochs: The maximum number of epochs to train for. + If this is -1, the model trains until early stopping is triggered. + :param early_stopping_patience: The patience for early stopping. + If this is None, early stopping is disabled. + :param test_size: The test size for the train-test split. + :param device: The device to train on. If this is "auto", the device is chosen automatically. + :param X_val: The texts to be used for validation. + :param y_val: The vectors to be used for validation. + :param validation_steps: The number of steps to run validation for. 
If None, validation steps are estimated from the data. + :param random_seed: The random seed to use. Defaults to 42. + :return: The fitted model. + """ + pl.seed_everything(random_seed) + logger.info("Re-initializing model.") + + train_dataset, val_dataset = self._create_datasets(X, y, X_val, y_val, test_size) + batch_size = self._determine_batch_size(batch_size, len(train_dataset)) + + self.out_dim = train_dataset.targets.shape[1] + self._initialize() + + c = StaticLightningModule(self, learning_rate=learning_rate) + + self._train( + module=c, + train_dataset=train_dataset, + val_dataset=val_dataset, + batch_size=batch_size, + early_stopping_patience=early_stopping_patience, + min_epochs=min_epochs, + max_epochs=max_epochs, + device=device, + validation_steps=validation_steps, + ) + + return self diff --git a/model2vec/train/utils.py b/model2vec/train/utils.py index 4d6b95b..0429c9e 100644 --- a/model2vec/train/utils.py +++ b/model2vec/train/utils.py @@ -1,6 +1,22 @@ +from __future__ import annotations + import logging +from collections import Counter +from typing import TYPE_CHECKING +import numpy as np +from sklearn.model_selection import train_test_split as sklearn_split +from sklearn.neural_network import MLPClassifier, MLPRegressor +from sklearn.pipeline import make_pipeline from tokenizers import Tokenizer +from torch import nn + +from model2vec.inference import StaticModelPipeline + +if TYPE_CHECKING: + from model2vec.train.base import _BaseFinetuneable + from model2vec.train.classifier import StaticModelForClassification + logger = logging.getLogger(__name__) @@ -19,3 +35,56 @@ def get_probable_pad_token_id(tokenizer: Tokenizer) -> int: logger.warning("No known pad token found, using 0 as default") return 0 + + +def to_pipeline(model: "_BaseFinetuneable | StaticModelForClassification") -> StaticModelPipeline: + """Convert the model to an sklearn pipeline.""" + from model2vec.train.classifier import StaticModelForClassification + + static_model = 
model.to_static_model() + + random_state = np.random.RandomState(42) + n_items = model.out_dim + X = random_state.randn(n_items, static_model.dim) + y: np.ndarray | list[str] + if isinstance(model, StaticModelForClassification): + y = model.classes_ + mlp_head = MLPClassifier(hidden_layer_sizes=(model.hidden_dim,) * model.n_layers) + activation = "logistic" if model.multilabel else "softmax" + else: + y = random_state.randn(n_items, n_items) + mlp_head = MLPRegressor(hidden_layer_sizes=(model.hidden_dim,) * model.n_layers) + activation = "identity" + mlp_head.fit(X, y) + + for index, layer in enumerate([module for module in model.head if isinstance(module, nn.Linear)]): + mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T + mlp_head.intercepts_[index] = layer.bias.detach().cpu().numpy() + + mlp_head.n_outputs_ = model.out_dim + mlp_head.out_activation_ = activation + pipeline = make_pipeline(mlp_head) + + return StaticModelPipeline(static_model, pipeline) + + +def train_test_split( + X: list[str], + y: list, + test_size: float, +) -> tuple[list[str], list[str], list, list]: + """ + Split the data. + + For single-label classification, stratification is attempted (if possible). + For multilabel classification, a random split is performed. + """ + stratify_data = None + if isinstance(y, list) and isinstance(y[0], (str, int)): + label_counts = Counter(y) + if min(label_counts.values()) < 2: + logger.info("Some classes have fewer than 2 samples. 
Stratification is disabled.") + stratify_data = None + else: + stratify_data = y + return sklearn_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=stratify_data) # type: ignore diff --git a/tests/conftest.py b/tests/conftest.py index 210d56b..de60201 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ from model2vec.inference import StaticModelPipeline from model2vec.model import StaticModel -from model2vec.train import StaticModelForClassification +from model2vec.train import StaticModelForClassification, StaticModelForSimilarity _TOKENIZER_TYPES = ["wordpiece", "bpe", "unigram"] @@ -182,3 +182,26 @@ def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClass model.fit(X, y) # type: ignore return model + + +@pytest.fixture(scope="session") +def mock_trained_similarity_pipeline() -> StaticModelForSimilarity: + """Mock StaticModelForSimilarity.""" + tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer + torch.random.manual_seed(42) + vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) + model = StaticModelForSimilarity(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") + + X = ["dog", "cat"] + y = torch.randn(2, 32) + model.fit(X, y) + + return model + + +@pytest.fixture(scope="session") +def mock_inference_pipeline_projector( + mock_trained_similarity_pipeline: StaticModelForSimilarity, +) -> StaticModelPipeline: + """Mock pipeline.""" + return mock_trained_similarity_pipeline.to_pipeline() diff --git a/tests/test_inference.py b/tests/test_inference.py index bae4732..fea74ad 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,18 +5,20 @@ import pytest -from model2vec.inference import StaticModelPipeline +from model2vec.inference.model import HeadType, StaticModelPipeline def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None: """Test successful init and predict with StaticModelPipeline.""" 
target: list[str] | list[list[str]] - if mock_inference_pipeline.multilabel: + if mock_inference_pipeline.classifier_type == HeadType.MULTILABEL: + assert mock_inference_pipeline.classes_ is not None if isinstance(mock_inference_pipeline.classes_[0], str): target = [["a", "b"]] else: target = [[0, 1]] # type: ignore else: + assert mock_inference_pipeline.classes_ is not None if isinstance(mock_inference_pipeline.classes_[0], str): target = ["b"] else: @@ -25,6 +27,19 @@ def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None: assert mock_inference_pipeline.predict(["dog"]).tolist() == target +def test_init_predict_projector(mock_inference_pipeline_projector: StaticModelPipeline) -> None: + """Test successful init and predict with StaticModelPipeline.""" + assert mock_inference_pipeline_projector.classifier_type == HeadType.PROJECTOR + assert mock_inference_pipeline_projector.classes_ is None + with pytest.raises(ValueError): + mock_inference_pipeline_projector.predict_proba(["dog"]) + with pytest.raises(ValueError): + mock_inference_pipeline_projector.evaluate(["dog"], ["a"]) + + prediction = mock_inference_pipeline_projector.predict(["dog"]) + assert prediction.shape == (1, 32) + + def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None: """Test successful init and predict_proba with StaticModelPipeline.""" assert mock_inference_pipeline.predict_proba("dog").argmax() == 1 @@ -34,12 +49,14 @@ def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> Non def test_init_evaluate(mock_inference_pipeline: StaticModelPipeline) -> None: """Test successful init and evaluate with StaticModelPipeline.""" target: list[str] | list[list[str]] - if mock_inference_pipeline.multilabel: + if mock_inference_pipeline.classifier_type == HeadType.MULTILABEL: + assert mock_inference_pipeline.classes_ is not None if isinstance(mock_inference_pipeline.classes_[0], str): target = [["a", "b"]] else: target = [[0, 1]] # type: 
ignore else: + assert mock_inference_pipeline.classes_ is not None if isinstance(mock_inference_pipeline.classes_[0], str): target = ["b"] else: @@ -53,12 +70,14 @@ def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None: mock_inference_pipeline.save_pretrained(temp_dir) loaded = StaticModelPipeline.from_pretrained(temp_dir) target: list[str] | list[list[str]] - if mock_inference_pipeline.multilabel: + if mock_inference_pipeline.classifier_type == HeadType.MULTILABEL: + assert mock_inference_pipeline.classes_ is not None if isinstance(mock_inference_pipeline.classes_[0], str): target = [["a", "b"]] else: target = [[0, 1]] # type: ignore else: + assert mock_inference_pipeline.classes_ is not None if isinstance(mock_inference_pipeline.classes_[0], str): target = ["b"] else: diff --git a/tests/test_trainable.py b/tests/test_trainable.py index 0b0179d..1002e36 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -10,8 +10,10 @@ from model2vec.model import StaticModel from model2vec.train import StaticModelForClassification -from model2vec.train.base import FinetunableStaticModel, TextDataset -from model2vec.train.utils import get_probable_pad_token_id +from model2vec.train.base import _BaseFinetuneable +from model2vec.train.dataset import TextDataset +from model2vec.train.similarity import StaticModelForSimilarity +from model2vec.train.utils import get_probable_pad_token_id, train_test_split @pytest.mark.parametrize("n_layers", [0, 1, 2, 3]) @@ -34,7 +36,9 @@ def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: T def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: """Test successful initialization of the base class.""" vectors_torched = torch.from_numpy(mock_vectors) - s = FinetunableStaticModel(vectors=vectors_torched, tokenizer=mock_tokenizer) + s = _BaseFinetuneable( + vectors=vectors_torched, tokenizer=mock_tokenizer, hidden_dim=256, out_dim=2, n_layers=0, pad_id=0 + 
) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0] @@ -42,16 +46,43 @@ def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> assert head[0].in_features == mock_vectors.shape[1] +def test_init_base_class_weights(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: + """Test successful initialization of the base class.""" + vectors_torched = torch.from_numpy(mock_vectors) + s = _BaseFinetuneable( + vectors=vectors_torched, + tokenizer=mock_tokenizer, + hidden_dim=256, + out_dim=2, + n_layers=0, + pad_id=0, + token_mapping=torch.randint(0, mock_vectors.shape[0], (mock_vectors.shape[0],)).tolist(), + ) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + + with pytest.raises(ValueError): + _BaseFinetuneable( + vectors=vectors_torched, + tokenizer=mock_tokenizer, + hidden_dim=256, + out_dim=2, + n_layers=0, + pad_id=0, + token_mapping=torch.randint(0, mock_vectors.shape[0], (10,)).tolist(), + ) + + def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: """Test initializion from a static model.""" model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) - s = FinetunableStaticModel.from_static_model(model=model) + s = _BaseFinetuneable.from_static_model(model=model) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0] with TemporaryDirectory() as temp_dir: model.save_pretrained(temp_dir) - s = FinetunableStaticModel.from_pretrained(model_name=temp_dir) + s = _BaseFinetuneable.from_pretrained(model_name=temp_dir) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0] @@ -164,9 +195,25 @@ def test_convert_to_pipeline(mock_trained_pipeline: StaticModelForClassification assert np.allclose(p1, p2) -def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -> None: +def 
test_convert_to_pipeline_similarity(mock_trained_similarity_pipeline: StaticModelForSimilarity) -> None: + """Convert a model to a pipeline.""" + mock_trained_similarity_pipeline.eval() + pipeline = mock_trained_similarity_pipeline.to_pipeline() + encoded_pipeline = pipeline.model.encode(["dog cat", "dog"]) + encoded_model = ( + mock_trained_similarity_pipeline(mock_trained_similarity_pipeline.tokenize(["dog cat", "dog"]))[1] + .detach() + .numpy() + ) + assert np.allclose(encoded_pipeline, encoded_model) + a = pipeline.predict(["dog cat", "dog"]).tolist() + b = mock_trained_similarity_pipeline.encode(["dog cat", "dog"]).tolist() + assert np.allclose(a, b) + + +def test_train_test_split() -> None: """Test the train test split function.""" - a, b, c, d = mock_trained_pipeline._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) + a, b, c, d = train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) assert len(a) == 2 assert len(b) == 2 assert len(c) == len(a)