diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..2ea7362 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,51 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + test: + name: Test (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install package + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,onnx]" + + - name: Run tests + run: pytest -q + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + + - name: Install ruff + run: pip install ruff + + - name: Ruff check + run: ruff check . diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..b29625d --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,63 @@ +# Changelog + +All notable changes to Comprexx are documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- GitHub Actions CI workflow running `pytest` on Python 3.10, 3.11, 3.12 plus a + `ruff check` lint job. +- `CHANGELOG.md` with history for v0.1.0 and v0.2.0. + +### Changed +- Silenced the noisy `torch.ao.quantization is deprecated` warning inside the + PTQ dynamic and static stages. The underlying API is still used, with a + `TODO(v0.3)` marking the upcoming migration to `torchao.quantization`. +- Fixed the package `__version__` to report `0.2.0` instead of the stale + `0.1.0` that shipped on PyPI. +- Tightened the codebase against `ruff check` and added a per-file ignore + for `E741` in tests. + +## [0.2.0] - 2026-04-07 + +### Added +- **Unstructured pruning** stage: magnitude or random element-wise pruning + with global/local scope and optional gradual cubic schedule. +- **N:M sparsity** stage: structured N-of-M sparsity (default 2:4) for + NVIDIA Ampere sparse tensor cores. +- **Weight-only quantization** stage: group-wise INT4/INT8 with symmetric + or asymmetric scaling for Linear and Conv2d layers. +- **Low-rank decomposition** stage: truncated SVD factorization of Linear + layers, with fixed rank-ratio or energy-threshold selection modes. +- **Operator fusion** stage: Conv2d + BatchNorm2d folding via `torch.fx` + with graceful fallback on non-traceable models. +- **Weight clustering** stage: per-layer k-means codebook clustering. +- **`cx.analyze_sensitivity()`**: per-layer sensitivity probing via prune + or noise perturbation. Returns a `SensitivityReport` that ranks layers + by metric drop and can suggest `exclude_layers` above a threshold. +- New techniques are wired through the recipe schema and loader, and + exported from `comprexx.stages`. + +### Tests +- 163 passing (up from 91). + +## [0.1.0] - 2026-04-06 + +Initial release. + +### Added +- Model analysis and profiling via `cx.analyze()`. +- Structured pruning with L1/L2/random criteria and global/local scope. +- Post-training dynamic and static INT8 quantization. +- ONNX export with manifest and optional `onnxruntime` validation. +- Recipe-driven pipelines (YAML) validated via Pydantic. +- CLI commands: `comprexx analyze`, `compress`, `export`. +- Accuracy guards with halt/warn actions. +- Per-stage compression reports persisted under `comprexx_runs/`. + +[Unreleased]: https://github.com/cachevector/comprexx/compare/v0.2.0...HEAD +[0.2.0]: https://github.com/cachevector/comprexx/compare/v0.1.0...v0.2.0 +[0.1.0]: https://github.com/cachevector/comprexx/releases/tag/v0.1.0 diff --git a/comprexx/__init__.py b/comprexx/__init__.py index 26c59ef..fa15f19 100644 --- a/comprexx/__init__.py +++ b/comprexx/__init__.py @@ -1,7 +1,8 @@ """Comprexx — ML Model Compression Toolkit.""" -__version__ = "0.1.0" +__version__ = "0.2.0" +from comprexx import stages from comprexx.analysis.profiler import ModelProfile, analyze from comprexx.analysis.sensitivity import ( LayerSensitivity, @@ -23,8 +24,6 @@ from comprexx.export.onnx import ONNXExporter from comprexx.recipe.loader import load_recipe -from comprexx import stages - __all__ = [ "AccuracyGuard", "AccuracyGuardTriggered", diff --git a/comprexx/analysis/flops.py b/comprexx/analysis/flops.py index a0b2d10..81cec47 100644 --- a/comprexx/analysis/flops.py +++ b/comprexx/analysis/flops.py @@ -11,7 +11,9 @@ def _conv_flops(module: nn.Conv2d, input: torch.Tensor, output: torch.Tensor) -> int: batch_size = output.shape[0] out_h, out_w = output.shape[2], output.shape[3] - kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups) + kernel_ops = ( + module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups) + ) # 2 ops per multiply-add flops = 2 * batch_size * module.out_channels * out_h * out_w * kernel_ops if module.bias is not None: @@ -35,7 +37,10 @@ def _bn_flops(module: nn.BatchNorm2d, input: torch.Tensor, output: torch.Tensor) _FLOP_HANDLERS: dict[type, Any] = { nn.Conv2d: _conv_flops, - nn.Conv1d: lambda m, i, o: 2 * o.shape[0] * m.out_channels * o.shape[2] * m.kernel_size[0] * (m.in_channels // m.groups), + nn.Conv1d: lambda m, i, o: ( + 2 * o.shape[0] * m.out_channels * o.shape[2] + * m.kernel_size[0] * (m.in_channels // m.groups) + ), nn.Linear: _linear_flops, nn.BatchNorm2d: _bn_flops, nn.BatchNorm1d: lambda m, i, o: 4 * i.numel(), diff --git a/comprexx/analysis/profiler.py b/comprexx/analysis/profiler.py index abc3304..56115fa 100644 --- a/comprexx/analysis/profiler.py +++ b/comprexx/analysis/profiler.py @@ -3,11 +3,10 @@ from __future__ import annotations import json -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional -import torch import torch.nn as nn from comprexx.analysis.flops import count_flops @@ -52,7 +51,7 @@ def size_mb(self) -> float: return self.size_bytes / (1024 * 1024) def compressible_layers(self) -> list[LayerInfo]: - return [l for l in self.layers if l.is_compressible] + return [layer for layer in self.layers if layer.is_compressible] def summary(self) -> str: lines = [ @@ -76,7 +75,7 @@ def to_dict(self) -> dict: "size_bytes": self.size_bytes, "size_mb": self.size_mb, "architecture_category": self.architecture_category, - "layers": [l.to_dict() for l in self.layers], + "layers": [layer.to_dict() for layer in self.layers], } def to_json(self) -> str: diff --git a/comprexx/analysis/sensitivity.py b/comprexx/analysis/sensitivity.py index 1a49e41..47a3157 100644 --- a/comprexx/analysis/sensitivity.py +++ b/comprexx/analysis/sensitivity.py @@ -13,7 +13,6 @@ from __future__ import annotations -import copy from dataclasses import asdict, dataclass, field from typing import Callable, Literal @@ -48,16 +47,16 @@ class SensitivityReport: def most_sensitive(self, n: int = 5) -> list[LayerSensitivity]: """Layers with the largest metric drops (in descending order).""" - return sorted(self.layers, key=lambda l: l.metric_drop, reverse=True)[:n] + return sorted(self.layers, key=lambda x: x.metric_drop, reverse=True)[:n] def most_tolerant(self, n: int = 5) -> list[LayerSensitivity]: """Layers with the smallest metric drops.""" - return sorted(self.layers, key=lambda l: l.metric_drop)[:n] + return sorted(self.layers, key=lambda x: x.metric_drop)[:n] def recommend_exclusions(self, threshold: float) -> list[str]: - """Names of layers whose drop exceeds `threshold` — candidates for + """Names of layers whose drop exceeds `threshold`: candidates for `exclude_layers` in a pruning/quantization stage.""" - return [l.name for l in self.layers if l.metric_drop > threshold] + return [x.name for x in self.layers if x.metric_drop > threshold] def to_dict(self) -> dict: return { @@ -65,7 +64,7 @@ def to_dict(self) -> dict: "perturbation": self.perturbation, "intensity": self.intensity, "baseline_metric": self.baseline_metric, - "layers": [l.to_dict() for l in self.layers], + "layers": [layer.to_dict() for layer in self.layers], } def summary(self) -> str: @@ -75,14 +74,16 @@ def summary(self) -> str: f" {len(self.layers)} layer(s) analyzed", " most sensitive:", ] - for l in self.most_sensitive(5): + for layer in self.most_sensitive(5): lines.append( - f" {l.name:40s} drop={l.metric_drop:+.4f} params={l.num_params:,}" + f" {layer.name:40s} drop={layer.metric_drop:+.4f} " + f"params={layer.num_params:,}" ) lines.append(" most tolerant:") - for l in self.most_tolerant(5): + for layer in self.most_tolerant(5): lines.append( - f" {l.name:40s} drop={l.metric_drop:+.4f} params={l.num_params:,}" + f" {layer.name:40s} drop={layer.metric_drop:+.4f} " + f"params={layer.num_params:,}" ) return "\n".join(lines) diff --git a/comprexx/cli/main.py b/comprexx/cli/main.py index a662248..335ecb4 100644 --- a/comprexx/cli/main.py +++ b/comprexx/cli/main.py @@ -3,8 +3,6 @@ from __future__ import annotations import importlib -import json -import sys from pathlib import Path from typing import Optional @@ -171,7 +169,7 @@ def export_cmd( out_path = Path(output_dir) / "model.onnx" exporter = ONNXExporter() with console.status("Exporting to ONNX..."): - manifest = exporter.export(model, input_shape=shape, output_path=str(out_path)) + exporter.export(model, input_shape=shape, output_path=str(out_path)) console.print(f"[green]Exported to {out_path}[/green]") console.print(f"Manifest: {Path(output_dir) / 'comprexx_manifest.json'}") else: diff --git a/comprexx/core/pipeline.py b/comprexx/core/pipeline.py index 1d10905..68beb36 100644 --- a/comprexx/core/pipeline.py +++ b/comprexx/core/pipeline.py @@ -3,7 +3,7 @@ from __future__ import annotations import time -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Callable, Optional diff --git a/comprexx/core/report.py b/comprexx/core/report.py index e287ccf..d1fa205 100644 --- a/comprexx/core/report.py +++ b/comprexx/core/report.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional diff --git a/comprexx/export/manifest.py b/comprexx/export/manifest.py index 478868d..cdcf62b 100644 --- a/comprexx/export/manifest.py +++ b/comprexx/export/manifest.py @@ -4,7 +4,7 @@ import hashlib import json -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Optional diff --git a/comprexx/stages/decomposition/low_rank.py b/comprexx/stages/decomposition/low_rank.py index 224555a..c1f672f 100644 --- a/comprexx/stages/decomposition/low_rank.py +++ b/comprexx/stages/decomposition/low_rank.py @@ -129,7 +129,8 @@ def apply( skipped_no_gain += 1 continue - first, second = _svd_factorize(w, module.bias.data if module.bias is not None else None, rank) + bias = module.bias.data if module.bias is not None else None + first, second = _svd_factorize(w, bias, rank) replacement = nn.Sequential(first, second) _replace_module(model, name, replacement) decomposed += 1 diff --git a/comprexx/stages/pruning/structured.py b/comprexx/stages/pruning/structured.py index 9d03eec..fc135f7 100644 --- a/comprexx/stages/pruning/structured.py +++ b/comprexx/stages/pruning/structured.py @@ -140,7 +140,8 @@ def _prune_global( mask[idx] = 0.0 # Apply structured mask along dim 0 (output channels / filters) - prune.custom_from_mask(module, "weight", mask.view(-1, 1, 1, 1).expand_as(module.weight)) + full_mask = mask.view(-1, 1, 1, 1).expand_as(module.weight) + prune.custom_from_mask(module, "weight", full_mask) notes.append(f"Global pruning: zeroed {n_prune} filters across {len(to_prune)} layers.") diff --git a/comprexx/stages/pruning/unstructured.py b/comprexx/stages/pruning/unstructured.py index 246d776..cd699a6 100644 --- a/comprexx/stages/pruning/unstructured.py +++ b/comprexx/stages/pruning/unstructured.py @@ -6,7 +6,6 @@ import time from typing import Literal -import torch import torch.nn as nn import torch.nn.utils.prune as prune from pydantic import BaseModel, Field diff --git a/comprexx/stages/quantization/ptq_dynamic.py b/comprexx/stages/quantization/ptq_dynamic.py index 6eb30d7..2902df6 100644 --- a/comprexx/stages/quantization/ptq_dynamic.py +++ b/comprexx/stages/quantization/ptq_dynamic.py @@ -4,6 +4,7 @@ import copy import time +import warnings from typing import Literal import torch @@ -14,6 +15,10 @@ from comprexx.core.report import StageReport from comprexx.stages.base import CompressionStage, StageContext +# TODO(v0.3): migrate to torchao.quantization. torch.ao.quantization is +# scheduled for removal in torch 2.10. Tracked at pytorch/ao#2259. +_TORCH_AO_WARN = r"torch\.ao\.quantization is deprecated" + class PTQDynamicConfig(BaseModel): """Configuration for dynamic quantization.""" @@ -47,13 +52,19 @@ def apply( # Profile before profile_before = analyze(model, context.input_shape, context.device) - # Apply dynamic quantization + # Apply dynamic quantization. Silence the torch.ao.quantization + # deprecation warning — it's noise for our users, and we track the + # migration in the TODO above. dtype = torch.qint8 - quantized_model = torch.quantization.quantize_dynamic( - model, - qconfig_spec={nn.Linear, nn.LSTM}, - dtype=dtype, - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=_TORCH_AO_WARN, category=DeprecationWarning + ) + quantized_model = torch.quantization.quantize_dynamic( + model, + qconfig_spec={nn.Linear, nn.LSTM}, + dtype=dtype, + ) # Profile after — size calculation for quantized models size_after = _quantized_model_size(quantized_model) diff --git a/comprexx/stages/quantization/ptq_static.py b/comprexx/stages/quantization/ptq_static.py index cbb9c82..738be4a 100644 --- a/comprexx/stages/quantization/ptq_static.py +++ b/comprexx/stages/quantization/ptq_static.py @@ -4,6 +4,7 @@ import copy import time +import warnings from typing import Literal import torch @@ -15,6 +16,10 @@ from comprexx.core.report import StageReport from comprexx.stages.base import CompressionStage, StageContext +# TODO(v0.3): migrate to torchao.quantization. torch.ao.quantization is +# scheduled for removal in torch 2.10. Tracked at pytorch/ao#2259. +_TORCH_AO_WARN = r"torch\.ao\.quantization is deprecated" + class PTQStaticConfig(BaseModel): """Configuration for static quantization.""" @@ -68,25 +73,28 @@ def apply( # Fuse common patterns if possible model = _try_fuse(model) - # Prepare: insert observers - prepared = torch.quantization.prepare(model, inplace=False) - - # Calibration: run forward passes to collect statistics - samples_seen = 0 - with torch.no_grad(): - for batch in context.calibration_data: - if samples_seen >= self.config.calibration_samples: - break - if isinstance(batch, (list, tuple)): - x = batch[0] - else: - x = batch - x = x.to(context.device) - prepared(x) - samples_seen += x.shape[0] - - # Convert: replace observers with quantized ops - quantized = torch.quantization.convert(prepared, inplace=False) + # Prepare → calibrate → convert. Silence the torch.ao.quantization + # deprecation warning — it's noise for users, tracked as a v0.3 TODO. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=_TORCH_AO_WARN, category=DeprecationWarning + ) + prepared = torch.quantization.prepare(model, inplace=False) + + samples_seen = 0 + with torch.no_grad(): + for batch in context.calibration_data: + if samples_seen >= self.config.calibration_samples: + break + if isinstance(batch, (list, tuple)): + x = batch[0] + else: + x = batch + x = x.to(context.device) + prepared(x) + samples_seen += x.shape[0] + + quantized = torch.quantization.convert(prepared, inplace=False) # Estimate quantized size size_after = _estimate_quantized_size(quantized, profile_before.size_bytes) diff --git a/pyproject.toml b/pyproject.toml index 92964f7..3ed1def 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ target-version = "py310" [tool.ruff.lint] select = ["E", "F", "I", "W"] +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["E741"] + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ diff --git a/tests/conftest.py b/tests/conftest.py index 5f677bb..f57f4ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import DataLoader, TensorDataset -from tests.fixtures.models import tiny_cnn, tiny_transformer, tiny_rnn +from tests.fixtures.models import tiny_cnn, tiny_rnn, tiny_transformer @pytest.fixture diff --git a/tests/unit/test_low_rank.py b/tests/unit/test_low_rank.py index 26883a8..2d5a5b5 100644 --- a/tests/unit/test_low_rank.py +++ b/tests/unit/test_low_rank.py @@ -6,9 +6,9 @@ from comprexx.stages.base import StageContext from comprexx.stages.decomposition.low_rank import ( LowRankDecomposition, + LowRankDecompositionConfig, _choose_rank, _svd_factorize, - LowRankDecompositionConfig, ) diff --git a/tests/unit/test_onnx_export.py b/tests/unit/test_onnx_export.py index c80181b..733a010 100644 --- a/tests/unit/test_onnx_export.py +++ b/tests/unit/test_onnx_export.py @@ -1,7 +1,6 @@ """Tests for ONNX export.""" import json -from pathlib import Path import pytest import torch diff --git a/tests/unit/test_ptq_dynamic.py b/tests/unit/test_ptq_dynamic.py index 706cd15..2be2be1 100644 --- a/tests/unit/test_ptq_dynamic.py +++ b/tests/unit/test_ptq_dynamic.py @@ -49,4 +49,4 @@ def test_original_model_unchanged(self): stage.apply(model, ctx) # Original model layers should not be replaced - assert type(list(model.modules())[1]) == original_type + assert type(list(model.modules())[1]) is original_type diff --git a/tests/unit/test_recipe.py b/tests/unit/test_recipe.py index 4c171e9..3215c9c 100644 --- a/tests/unit/test_recipe.py +++ b/tests/unit/test_recipe.py @@ -6,7 +6,6 @@ from comprexx.recipe.loader import load_recipe, recipe_to_pipeline from comprexx.recipe.schema import RecipeV1 - VALID_RECIPE_YAML = """\ name: test-recipe version: "1.0" diff --git a/tests/unit/test_weight_only_quant.py b/tests/unit/test_weight_only_quant.py index a3f49ca..62e5a8f 100644 --- a/tests/unit/test_weight_only_quant.py +++ b/tests/unit/test_weight_only_quant.py @@ -1,7 +1,6 @@ """Tests for weight-only quantization stage.""" import torch -import torch.nn as nn from comprexx.stages.base import StageContext from comprexx.stages.quantization.weight_only import (