Skip to content

feat: refactor training#316

Open
stephantul wants to merge 6 commits intomainfrom
revamp-training
Open

feat: refactor training#316
stephantul wants to merge 6 commits intomainfrom
revamp-training

Conversation

@stephantul
Copy link
Copy Markdown
Contributor

This PR refactors a large part of the training logic. The base class is now much more feature complete, and takes care of everything that does not have labeled data. The classifier has become much smaller, and only takes care of data that has labels.

Here's a concise summary of the changes:

  • We can now finetune static models on vectors. This is called StaticModelForSimilarity. This currently only supports cosine similarity training.
  • The lightning modules for the classifier and multilabel classifier have been split up, and I added a new lightning module for similarity. The classifier lightning modules derive from the similarity ones
  • Much of the logic has been consolidated in the base class.
  • I added a normalize argument to the initializer. This controls whether the model encoding is normalized.
  • I changed the named arguments in from_pretrained to mimic those in StaticModel and added a deprecation warning to the old argument.
  • I added a new argument called "balanced" to class_weight, added support for dicts, and deprecated support for tensors.
  • I added a new argument called "validation_steps" to the fit method, which controls how many steps are between validation splits.
  • StaticModelPipeline now supports projection, backed by a scikit learn MLPRegressor. This allows us to create static models with additional dense layers and put them in sklear pipelines.

And maybe some other changes I forgot. Not everything is fully covered by tests, but most of it is.

@stephantul stephantul requested a review from Pringled March 27, 2026 15:37
@codecov
Copy link
Copy Markdown

codecov bot commented Mar 27, 2026

Codecov Report

❌ Patch coverage is 95.84570% with 14 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
model2vec/train/classifier.py 82.97% 8 Missing ⚠️
model2vec/train/base.py 96.00% 5 Missing ⚠️
model2vec/train/lightning_modules.py 98.41% 1 Missing ⚠️
Files with missing lines Coverage Δ
model2vec/inference/model.py 93.27% <100.00%> (+0.55%) ⬆️
model2vec/train/__init__.py 100.00% <100.00%> (ø)
model2vec/train/dataset.py 100.00% <100.00%> (ø)
model2vec/train/similarity.py 100.00% <100.00%> (ø)
model2vec/train/utils.py 100.00% <100.00%> (ø)
model2vec/train/lightning_modules.py 98.41% <98.41%> (ø)
model2vec/train/base.py 96.96% <96.00%> (-1.02%) ⬇️
model2vec/train/classifier.py 90.65% <82.97%> (-6.97%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant