Merged
51 changes: 51 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,51 @@
name: CI

on:
push:
branches: [master]
pull_request:
branches: [master]

jobs:
test:
name: Test (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip

- name: Install package
run: |
python -m pip install --upgrade pip
pip install -e ".[dev,onnx]"

- name: Run tests
run: pytest -q

lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: pip

- name: Install ruff
run: pip install ruff

- name: Ruff check
run: ruff check .
63 changes: 63 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,63 @@
# Changelog

All notable changes to Comprexx are documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added
- GitHub Actions CI workflow running `pytest` on Python 3.10, 3.11, and 3.12,
  plus a `ruff check` lint job.
- `CHANGELOG.md` with history for v0.1.0 and v0.2.0.

### Changed
- Silenced the noisy `torch.ao.quantization is deprecated` warning inside the
PTQ dynamic and static stages. The underlying API is still used, with a
`TODO(v0.3)` marking the upcoming migration to `torchao.quantization`.
- Fixed the package `__version__` to report `0.2.0` instead of the stale
`0.1.0` that shipped on PyPI.
- Tightened the codebase against `ruff check` and added a per-file ignore
for `E741` in tests.

## [0.2.0] - 2026-04-07

### Added
- **Unstructured pruning** stage: magnitude or random element-wise pruning
with global/local scope and optional gradual cubic schedule.
- **N:M sparsity** stage: structured N-of-M sparsity (default 2:4) for
NVIDIA Ampere sparse tensor cores.
- **Weight-only quantization** stage: group-wise INT4/INT8 with symmetric
or asymmetric scaling for Linear and Conv2d layers.
- **Low-rank decomposition** stage: truncated SVD factorization of Linear
layers, with fixed rank-ratio or energy-threshold selection modes.
- **Operator fusion** stage: Conv2d + BatchNorm2d folding via `torch.fx`
with graceful fallback on non-traceable models.
- **Weight clustering** stage: per-layer k-means codebook clustering.
- **`cx.analyze_sensitivity()`**: per-layer sensitivity probing via prune
or noise perturbation. Returns a `SensitivityReport` that ranks layers
by metric drop and can suggest `exclude_layers` above a threshold.
- New techniques are wired through the recipe schema and loader, and
exported from `comprexx.stages`.
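The group-wise weight-only quantization listed above can be illustrated with a minimal symmetric INT8 sketch. The flat row-major grouping and `group_size=64` here are illustrative assumptions, not Comprexx's actual layout:

```python
import torch

def quantize_groupwise_int8(w: torch.Tensor, group_size: int = 64):
    """Symmetric INT8 quantization with one scale per group of weights."""
    groups = w.reshape(-1, group_size)            # assumes numel % group_size == 0
    scale = groups.abs().amax(dim=1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-8)                 # avoid division by zero
    q = torch.clamp(torch.round(groups / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor, shape) -> torch.Tensor:
    return (q.float() * scale).reshape(shape)

w = torch.randn(128, 64)
q, scale = quantize_groupwise_int8(w)
w_hat = dequantize(q, scale, w.shape)
# per-element reconstruction error is bounded by half a quantization step
```

Asymmetric scaling would add a per-group zero point; INT4 would clamp to [-7, 7] and pack two values per byte.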

### Tests
- 163 passing (up from 91).
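The gradual cubic schedule mentioned for unstructured pruning is presumably the standard cubic sparsity ramp; a sketch under that assumption (the `s_init`, `s_final`, and step values are illustrative):

```python
def cubic_sparsity(step: int, total_steps: int,
                   s_init: float = 0.0, s_final: float = 0.9) -> float:
    """Sparsity ramps from s_init to s_final along a cubic curve."""
    t = min(max(step / total_steps, 0.0), 1.0)
    return s_final + (s_init - s_final) * (1.0 - t) ** 3

# sparsity grows quickly at first, then flattens as it approaches s_final
schedule = [cubic_sparsity(s, 100) for s in (0, 25, 50, 100)]
print(schedule)
```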

## [0.1.0] - 2026-04-06

Initial release.

### Added
- Model analysis and profiling via `cx.analyze()`.
- Structured pruning with L1/L2/random criteria and global/local scope.
- Post-training dynamic and static INT8 quantization.
- ONNX export with manifest and optional `onnxruntime` validation.
- Recipe-driven pipelines (YAML) validated via Pydantic.
- CLI commands: `comprexx analyze`, `compress`, `export`.
- Accuracy guards with halt/warn actions.
- Per-stage compression reports persisted under `comprexx_runs/`.

[Unreleased]: https://github.com/cachevector/comprexx/compare/v0.2.0...HEAD
[0.2.0]: https://github.com/cachevector/comprexx/compare/v0.1.0...v0.2.0
[0.1.0]: https://github.com/cachevector/comprexx/releases/tag/v0.1.0
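The N:M sparsity stage above keeps at most N nonzero weights in every group of M; a minimal standalone sketch of the default 2:4 pattern (the flat row-major grouping is an assumption for illustration):

```python
import torch

def apply_2_4_sparsity(weight: torch.Tensor) -> torch.Tensor:
    """Zero the 2 smallest-magnitude values in each group of 4 (row-major)."""
    orig_shape = weight.shape
    groups = weight.reshape(-1, 4)                 # assumes numel divisible by 4
    # indices of the 2 smallest |w| per group
    _, idx = groups.abs().topk(2, dim=1, largest=False)
    mask = torch.ones_like(groups)
    mask.scatter_(1, idx, 0.0)
    return (groups * mask).reshape(orig_shape)

w = torch.randn(8, 16)
sparse = apply_2_4_sparsity(w)
# every group of 4 now contains exactly 2 zeros, the layout Ampere
# sparse tensor cores can exploit
```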
5 changes: 2 additions & 3 deletions comprexx/__init__.py
@@ -1,7 +1,8 @@
"""Comprexx — ML Model Compression Toolkit."""

__version__ = "0.1.0"
__version__ = "0.2.0"

from comprexx import stages
from comprexx.analysis.profiler import ModelProfile, analyze
from comprexx.analysis.sensitivity import (
LayerSensitivity,
@@ -23,8 +24,6 @@
from comprexx.export.onnx import ONNXExporter
from comprexx.recipe.loader import load_recipe

from comprexx import stages

__all__ = [
"AccuracyGuard",
"AccuracyGuardTriggered",
9 changes: 7 additions & 2 deletions comprexx/analysis/flops.py
@@ -11,7 +9,2 @@
def _conv_flops(module: nn.Conv2d, input: torch.Tensor, output: torch.Tensor) -> int:
batch_size = output.shape[0]
out_h, out_w = output.shape[2], output.shape[3]
kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups)
kernel_ops = (
module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups)
)
# 2 ops per multiply-add
flops = 2 * batch_size * module.out_channels * out_h * out_w * kernel_ops
if module.bias is not None:
@@ -35,7 +37,10 @@ def _bn_flops(module: nn.BatchNorm2d, input: torch.Tensor, output: torch.Tensor)

_FLOP_HANDLERS: dict[type, Any] = {
nn.Conv2d: _conv_flops,
nn.Conv1d: lambda m, i, o: 2 * o.shape[0] * m.out_channels * o.shape[2] * m.kernel_size[0] * (m.in_channels // m.groups),
nn.Conv1d: lambda m, i, o: (
2 * o.shape[0] * m.out_channels * o.shape[2]
* m.kernel_size[0] * (m.in_channels // m.groups)
),
nn.Linear: _linear_flops,
nn.BatchNorm2d: _bn_flops,
nn.BatchNorm1d: lambda m, i, o: 4 * i.numel(),
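The `_conv_flops` arithmetic above can be sanity-checked standalone; this mirrors the hook's formula (2 ops per multiply-add, plus one add per bias element) on a small Conv2d:

```python
import torch
import torch.nn as nn

def conv_flops(m: nn.Conv2d, out: torch.Tensor) -> int:
    """Same formula as the diff's _conv_flops hook."""
    kernel_ops = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels // m.groups)
    flops = 2 * out.shape[0] * m.out_channels * out.shape[2] * out.shape[3] * kernel_ops
    if m.bias is not None:
        flops += out.numel()  # one add per output element
    return flops

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
y = conv(torch.randn(1, 3, 32, 32))
# 2 * 1 * 16 * 32 * 32 * (3*3*3) + 16*32*32 = 884736 + 16384 = 901120
print(conv_flops(conv, y))
```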
7 changes: 3 additions & 4 deletions comprexx/analysis/profiler.py
@@ -3,11 +3,10 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn

from comprexx.analysis.flops import count_flops
@@ -52,7 +51,7 @@ def size_mb(self) -> float:
return self.size_bytes / (1024 * 1024)

def compressible_layers(self) -> list[LayerInfo]:
return [l for l in self.layers if l.is_compressible]
return [layer for layer in self.layers if layer.is_compressible]

def summary(self) -> str:
lines = [
@@ -76,7 +75,7 @@ def to_dict(self) -> dict:
"size_bytes": self.size_bytes,
"size_mb": self.size_mb,
"architecture_category": self.architecture_category,
"layers": [l.to_dict() for l in self.layers],
"layers": [layer.to_dict() for layer in self.layers],
}

def to_json(self) -> str:
21 changes: 11 additions & 10 deletions comprexx/analysis/sensitivity.py
@@ -13,7 +12,6 @@

from __future__ import annotations

import copy
from dataclasses import asdict, dataclass, field
from typing import Callable, Literal

@@ -48,24 +47,24 @@ class SensitivityReport:

def most_sensitive(self, n: int = 5) -> list[LayerSensitivity]:
"""Layers with the largest metric drops (in descending order)."""
return sorted(self.layers, key=lambda l: l.metric_drop, reverse=True)[:n]
return sorted(self.layers, key=lambda x: x.metric_drop, reverse=True)[:n]

def most_tolerant(self, n: int = 5) -> list[LayerSensitivity]:
"""Layers with the smallest metric drops."""
return sorted(self.layers, key=lambda l: l.metric_drop)[:n]
return sorted(self.layers, key=lambda x: x.metric_drop)[:n]

def recommend_exclusions(self, threshold: float) -> list[str]:
"""Names of layers whose drop exceeds `threshold` candidates for
"""Names of layers whose drop exceeds `threshold`: candidates for
`exclude_layers` in a pruning/quantization stage."""
return [l.name for l in self.layers if l.metric_drop > threshold]
return [x.name for x in self.layers if x.metric_drop > threshold]

def to_dict(self) -> dict:
return {
"metric_name": self.metric_name,
"perturbation": self.perturbation,
"intensity": self.intensity,
"baseline_metric": self.baseline_metric,
"layers": [l.to_dict() for l in self.layers],
"layers": [layer.to_dict() for layer in self.layers],
}

def summary(self) -> str:
@@ -75,14 +74,16 @@ def summary(self) -> str:
f" {len(self.layers)} layer(s) analyzed",
" most sensitive:",
]
for l in self.most_sensitive(5):
for layer in self.most_sensitive(5):
lines.append(
f" {l.name:40s} drop={l.metric_drop:+.4f} params={l.num_params:,}"
f" {layer.name:40s} drop={layer.metric_drop:+.4f} "
f"params={layer.num_params:,}"
)
lines.append(" most tolerant:")
for l in self.most_tolerant(5):
for layer in self.most_tolerant(5):
lines.append(
f" {l.name:40s} drop={l.metric_drop:+.4f} params={l.num_params:,}"
f" {layer.name:40s} drop={layer.metric_drop:+.4f} "
f"params={layer.num_params:,}"
)
return "\n".join(lines)

4 changes: 1 addition & 3 deletions comprexx/cli/main.py
@@ -3,8 +3,6 @@
from __future__ import annotations

import importlib
import json
import sys
from pathlib import Path
from typing import Optional

@@ -171,7 +169,7 @@ def export_cmd(
out_path = Path(output_dir) / "model.onnx"
exporter = ONNXExporter()
with console.status("Exporting to ONNX..."):
manifest = exporter.export(model, input_shape=shape, output_path=str(out_path))
exporter.export(model, input_shape=shape, output_path=str(out_path))
console.print(f"[green]Exported to {out_path}[/green]")
console.print(f"Manifest: {Path(output_dir) / 'comprexx_manifest.json'}")
else:
2 changes: 1 addition & 1 deletion comprexx/core/pipeline.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import time
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable, Optional
2 changes: 1 addition & 1 deletion comprexx/core/report.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional

2 changes: 1 addition & 1 deletion comprexx/export/manifest.py
@@ -4,7 +4,7 @@

import hashlib
import json
from dataclasses import dataclass, field, asdict
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
3 changes: 2 additions & 1 deletion comprexx/stages/decomposition/low_rank.py
@@ -129,7 +129,8 @@ def apply(
skipped_no_gain += 1
continue

first, second = _svd_factorize(w, module.bias.data if module.bias is not None else None, rank)
bias = module.bias.data if module.bias is not None else None
first, second = _svd_factorize(w, bias, rank)
replacement = nn.Sequential(first, second)
_replace_module(model, name, replacement)
decomposed += 1
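`_svd_factorize` itself is outside this hunk; a plausible standalone sketch of truncated-SVD factorization of a Linear layer (the helper name and rank handling here are assumptions based on the call site, not the actual implementation):

```python
import torch
import torch.nn as nn

def svd_factorize(linear: nn.Linear, rank: int) -> nn.Sequential:
    """Replace a Linear with two smaller Linears via truncated SVD."""
    W = linear.weight.data                        # shape (out, in)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    first = nn.Linear(linear.in_features, rank, bias=False)
    second = nn.Linear(rank, linear.out_features, bias=linear.bias is not None)
    first.weight.data = Vh[:rank]                 # (rank, in)
    second.weight.data = U[:, :rank] * S[:rank]   # (out, rank) = U @ diag(S)
    if linear.bias is not None:
        second.bias.data = linear.bias.data.clone()
    return nn.Sequential(first, second)

lin = nn.Linear(64, 32)
approx = svd_factorize(lin, rank=32)  # full rank: exact up to float error
x = torch.randn(4, 64)
```

At `rank < min(out, in)` the pair stores `rank * (in + out)` weights instead of `in * out`, which is where the compression comes from.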
3 changes: 2 additions & 1 deletion comprexx/stages/pruning/structured.py
@@ -140,7 +140,8 @@ def _prune_global(
mask[idx] = 0.0

# Apply structured mask along dim 0 (output channels / filters)
prune.custom_from_mask(module, "weight", mask.view(-1, 1, 1, 1).expand_as(module.weight))
full_mask = mask.view(-1, 1, 1, 1).expand_as(module.weight)
prune.custom_from_mask(module, "weight", full_mask)

notes.append(f"Global pruning: zeroed {n_prune} filters across {len(to_prune)} layers.")
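The mask broadcast used by `_prune_global` can be exercised in isolation; this sketch zeroes two output filters of a small Conv2d via `prune.custom_from_mask`, with hypothetical filter indices:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 8, kernel_size=3)

# zero filters 0 and 5 (output channels), as in the diff's _prune_global
mask = torch.ones(conv.out_channels)
mask[[0, 5]] = 0.0
# broadcast the per-filter mask over (out, in, kh, kw)
full_mask = mask.view(-1, 1, 1, 1).expand_as(conv.weight)
prune.custom_from_mask(conv, "weight", full_mask)

# conv.weight is now weight_orig * mask; the pruned filters are all-zero
print(conv.weight[0].abs().sum().item())
```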

1 change: 0 additions & 1 deletion comprexx/stages/pruning/unstructured.py
@@ -6,7 +6,6 @@
import time
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from pydantic import BaseModel, Field
23 changes: 17 additions & 6 deletions comprexx/stages/quantization/ptq_dynamic.py
@@ -4,6 +4,7 @@

import copy
import time
import warnings
from typing import Literal

import torch
@@ -14,6 +15,10 @@
from comprexx.core.report import StageReport
from comprexx.stages.base import CompressionStage, StageContext

# TODO(v0.3): migrate to torchao.quantization. torch.ao.quantization is
# scheduled for removal in torch 2.10. Tracked at pytorch/ao#2259.
_TORCH_AO_WARN = r"torch\.ao\.quantization is deprecated"


class PTQDynamicConfig(BaseModel):
"""Configuration for dynamic quantization."""
@@ -47,13 +52,19 @@ def apply(
# Profile before
profile_before = analyze(model, context.input_shape, context.device)

# Apply dynamic quantization
# Apply dynamic quantization. Silence the torch.ao.quantization
# deprecation warning — it's noise for our users, and we track the
# migration in the TODO above.
dtype = torch.qint8
quantized_model = torch.quantization.quantize_dynamic(
model,
qconfig_spec={nn.Linear, nn.LSTM},
dtype=dtype,
)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=_TORCH_AO_WARN, category=DeprecationWarning
)
quantized_model = torch.quantization.quantize_dynamic(
model,
qconfig_spec={nn.Linear, nn.LSTM},
dtype=dtype,
)

# Profile after — size calculation for quantized models
size_after = _quantized_model_size(quantized_model)
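The warning-suppression pattern above can be verified against a synthetic warning; this standalone sketch uses the same message regex:

```python
import warnings

def noisy_api():
    # stand-in for the deprecated torch.ao.quantization call path
    warnings.warn("torch.ao.quantization is deprecated", DeprecationWarning)
    return 42

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # filterwarnings prepends, so the ignore rule is matched first
    warnings.filterwarnings(
        "ignore",
        message=r"torch\.ao\.quantization is deprecated",
        category=DeprecationWarning,
    )
    result = noisy_api()
print(result, len(caught))
```

Because the filter matches on message and category, unrelated DeprecationWarnings still surface, which is why this is safer than a blanket ignore.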