Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions model2vec/inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from collections.abc import Sequence
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypeVar, cast
Expand All @@ -23,25 +24,26 @@
LabelType = TypeVar("LabelType", list[str], list[list[str]])


class HeadType(str, Enum):
CLASSIFIER = "classifier"
PROJECTOR = "projector"
MULTILABEL = "multilabel"


class StaticModelPipeline:
def __init__(self, model: StaticModel, head: Pipeline) -> None:
"""Create a pipeline with a StaticModel encoder."""
self.model = model
self.head = head
classifier = self.head[-1]
# Check if the classifier is a multilabel classifier.
# NOTE: this doesn't look robust, but it is.
# Different classifiers, such as OVR wrappers, support multilabel output natively, so we
# can just use predict.
self.multilabel = False
if isinstance(classifier, MLPClassifier):
if classifier.out_activation_ == "logistic":
self.multilabel = True

@property
def classes_(self) -> np.ndarray:
"""The classes of the classifier."""
return self.head.classes_

last_head = self.head[-1]
self.classes_: None | np.ndarray = None
if isinstance(last_head, MLPClassifier):
activation = last_head.out_activation_
self.classifier_type = HeadType.MULTILABEL if activation == "logistic" else HeadType.CLASSIFIER
self.classes_ = self.head.classes_
else:
self.classifier_type = HeadType.PROJECTOR

@classmethod
def from_pretrained(
Expand Down Expand Up @@ -138,7 +140,8 @@ def predict(
multiprocessing_threshold=multiprocessing_threshold,
)

if self.multilabel:
if self.classifier_type == HeadType.MULTILABEL:
assert self.classes_ is not None
out_labels = []
proba = self.head.predict_proba(encoded)
for vector in proba:
Expand Down Expand Up @@ -166,7 +169,10 @@ def predict_proba(
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
:return: The predicted labels or probabilities.
:raises ValueError: If the classifier type is projector.
"""
if self.classifier_type == HeadType.PROJECTOR:
raise ValueError("You are using evaluate on a projector model. This is not supported.")
encoded = self._encode_and_coerce_to_2d(
X,
show_progress_bar=show_progress_bar,
Expand All @@ -190,7 +196,10 @@ def evaluate(
:param threshold: The threshold for multilabel classification.
:param output_dict: Whether to output the classification report as a dictionary.
:return: A classification report.
:raises ValueError: If the classifier type is projector.
"""
if self.classifier_type == HeadType.PROJECTOR:
raise ValueError("You are using evaluate on a projector model. This is not supported.")
predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

Expand Down
3 changes: 2 additions & 1 deletion model2vec/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
importable(extra_dependency, _REQUIRED_EXTRA)

from model2vec.train.classifier import StaticModelForClassification
from model2vec.train.similarity import StaticModelForSimilarity

__all__ = ["StaticModelForClassification"]
__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"]
Loading
Loading