From d3187975b2d5b03516113525cda4affbcfc74f86 Mon Sep 17 00:00:00 2001 From: maskedsyntax Date: Thu, 9 Apr 2026 22:01:38 +0530 Subject: [PATCH 1/4] clean up lint and bump __version__ to 0.2.0 - ruff check is now clean across the repo (0 errors) - Added a per-file ignore for E741 in tests so short loop vars like 'l' stay allowed there - Fixed a handful of long lines, an unused local in the CLI, and a type comparison in a PTQ dynamic test - comprexx.__version__ was still pinned at 0.1.0 after the v0.2 release; bumped to match pyproject.toml --- comprexx/__init__.py | 5 ++--- comprexx/analysis/flops.py | 9 +++++++-- comprexx/analysis/profiler.py | 7 +++---- comprexx/analysis/sensitivity.py | 21 +++++++++++---------- comprexx/cli/main.py | 4 +--- comprexx/core/pipeline.py | 2 +- comprexx/core/report.py | 2 +- comprexx/export/manifest.py | 2 +- comprexx/stages/decomposition/low_rank.py | 3 ++- comprexx/stages/pruning/structured.py | 3 ++- comprexx/stages/pruning/unstructured.py | 1 - pyproject.toml | 3 +++ tests/conftest.py | 2 +- tests/unit/test_low_rank.py | 2 +- tests/unit/test_onnx_export.py | 1 - tests/unit/test_ptq_dynamic.py | 2 +- tests/unit/test_recipe.py | 1 - tests/unit/test_weight_only_quant.py | 1 - 18 files changed, 37 insertions(+), 34 deletions(-) diff --git a/comprexx/__init__.py b/comprexx/__init__.py index 26c59ef..fa15f19 100644 --- a/comprexx/__init__.py +++ b/comprexx/__init__.py @@ -1,7 +1,8 @@ """Comprexx — ML Model Compression Toolkit.""" -__version__ = "0.1.0" +__version__ = "0.2.0" +from comprexx import stages from comprexx.analysis.profiler import ModelProfile, analyze from comprexx.analysis.sensitivity import ( LayerSensitivity, @@ -23,8 +24,6 @@ from comprexx.export.onnx import ONNXExporter from comprexx.recipe.loader import load_recipe -from comprexx import stages - __all__ = [ "AccuracyGuard", "AccuracyGuardTriggered", diff --git a/comprexx/analysis/flops.py b/comprexx/analysis/flops.py index a0b2d10..81cec47 100644 --- a/comprexx/analysis/flops.py +++ b/comprexx/analysis/flops.py @@ -11,7 +11,9 @@ def _conv_flops(module: nn.Conv2d, input: torch.Tensor, output: torch.Tensor) -> int: batch_size = output.shape[0] out_h, out_w = output.shape[2], output.shape[3] - kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups) + kernel_ops = ( + module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups) + ) # 2 ops per multiply-add flops = 2 * batch_size * module.out_channels * out_h * out_w * kernel_ops if module.bias is not None: @@ -35,7 +37,10 @@ def _bn_flops(module: nn.BatchNorm2d, input: torch.Tensor, output: torch.Tensor) _FLOP_HANDLERS: dict[type, Any] = { nn.Conv2d: _conv_flops, - nn.Conv1d: lambda m, i, o: 2 * o.shape[0] * m.out_channels * o.shape[2] * m.kernel_size[0] * (m.in_channels // m.groups), + nn.Conv1d: lambda m, i, o: ( + 2 * o.shape[0] * m.out_channels * o.shape[2] + * m.kernel_size[0] * (m.in_channels // m.groups) + ), nn.Linear: _linear_flops, nn.BatchNorm2d: _bn_flops, nn.BatchNorm1d: lambda m, i, o: 4 * i.numel(), diff --git a/comprexx/analysis/profiler.py b/comprexx/analysis/profiler.py index abc3304..56115fa 100644 --- a/comprexx/analysis/profiler.py +++ b/comprexx/analysis/profiler.py @@ -3,11 +3,10 @@ from __future__ import annotations import json -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional -import torch import torch.nn as nn from comprexx.analysis.flops import count_flops @@ -52,7 +51,7 @@ def size_mb(self) -> float: return self.size_bytes / (1024 * 1024) def compressible_layers(self) -> list[LayerInfo]: - return [l for l in self.layers if l.is_compressible] + return [layer for layer in self.layers if layer.is_compressible] def summary(self) -> str: lines = [ @@ -76,7 +75,7 @@ def to_dict(self) -> dict: "size_bytes": self.size_bytes, "size_mb": self.size_mb, "architecture_category": self.architecture_category, - "layers": [l.to_dict() for l in self.layers], + "layers": [layer.to_dict() for layer in self.layers], } def to_json(self) -> str: diff --git a/comprexx/analysis/sensitivity.py b/comprexx/analysis/sensitivity.py index 1a49e41..47a3157 100644 --- a/comprexx/analysis/sensitivity.py +++ b/comprexx/analysis/sensitivity.py @@ -13,7 +13,6 @@ from __future__ import annotations -import copy from dataclasses import asdict, dataclass, field from typing import Callable, Literal @@ -48,16 +47,16 @@ class SensitivityReport: def most_sensitive(self, n: int = 5) -> list[LayerSensitivity]: """Layers with the largest metric drops (in descending order).""" - return sorted(self.layers, key=lambda l: l.metric_drop, reverse=True)[:n] + return sorted(self.layers, key=lambda x: x.metric_drop, reverse=True)[:n] def most_tolerant(self, n: int = 5) -> list[LayerSensitivity]: """Layers with the smallest metric drops.""" - return sorted(self.layers, key=lambda l: l.metric_drop)[:n] + return sorted(self.layers, key=lambda x: x.metric_drop)[:n] def recommend_exclusions(self, threshold: float) -> list[str]: - """Names of layers whose drop exceeds `threshold` — candidates for + """Names of layers whose drop exceeds `threshold`: candidates for `exclude_layers` in a pruning/quantization stage.""" - return [l.name for l in self.layers if l.metric_drop > threshold] + return [x.name for x in self.layers if x.metric_drop > threshold] def to_dict(self) -> dict: return { @@ -65,7 +64,7 @@ def to_dict(self) -> dict: "perturbation": self.perturbation, "intensity": self.intensity, "baseline_metric": self.baseline_metric, - "layers": [l.to_dict() for l in self.layers], + "layers": [layer.to_dict() for layer in self.layers], } def summary(self) -> str: @@ -75,14 +74,16 @@ def summary(self) -> str: f" {len(self.layers)} layer(s) analyzed", " most sensitive:", ] - for l in self.most_sensitive(5): + for layer in self.most_sensitive(5): lines.append( - f" {l.name:40s} drop={l.metric_drop:+.4f} params={l.num_params:,}" + f" {layer.name:40s} drop={layer.metric_drop:+.4f} " + f"params={layer.num_params:,}" ) lines.append(" most tolerant:") - for l in self.most_tolerant(5): + for layer in self.most_tolerant(5): lines.append( - f" {l.name:40s} drop={l.metric_drop:+.4f} params={l.num_params:,}" + f" {layer.name:40s} drop={layer.metric_drop:+.4f} " + f"params={layer.num_params:,}" ) return "\n".join(lines) diff --git a/comprexx/cli/main.py b/comprexx/cli/main.py index a662248..335ecb4 100644 --- a/comprexx/cli/main.py +++ b/comprexx/cli/main.py @@ -3,8 +3,6 @@ from __future__ import annotations import importlib -import json -import sys from pathlib import Path from typing import Optional @@ -171,7 +169,7 @@ def export_cmd( out_path = Path(output_dir) / "model.onnx" exporter = ONNXExporter() with console.status("Exporting to ONNX..."): - manifest = exporter.export(model, input_shape=shape, output_path=str(out_path)) + exporter.export(model, input_shape=shape, output_path=str(out_path)) console.print(f"[green]Exported to {out_path}[/green]") console.print(f"Manifest: {Path(output_dir) / 'comprexx_manifest.json'}") else: diff --git a/comprexx/core/pipeline.py b/comprexx/core/pipeline.py index 1d10905..68beb36 100644 --- a/comprexx/core/pipeline.py +++ b/comprexx/core/pipeline.py @@ -3,7 +3,7 @@ from __future__ import annotations import time -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Callable, Optional diff --git a/comprexx/core/report.py b/comprexx/core/report.py index e287ccf..d1fa205 100644 --- a/comprexx/core/report.py +++ b/comprexx/core/report.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional diff --git a/comprexx/export/manifest.py b/comprexx/export/manifest.py index 478868d..cdcf62b 100644 --- a/comprexx/export/manifest.py +++ b/comprexx/export/manifest.py @@ -4,7 +4,7 @@ import hashlib import json -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Optional diff --git a/comprexx/stages/decomposition/low_rank.py b/comprexx/stages/decomposition/low_rank.py index 224555a..c1f672f 100644 --- a/comprexx/stages/decomposition/low_rank.py +++ b/comprexx/stages/decomposition/low_rank.py @@ -129,7 +129,8 @@ def apply( skipped_no_gain += 1 continue - first, second = _svd_factorize(w, module.bias.data if module.bias is not None else None, rank) + bias = module.bias.data if module.bias is not None else None + first, second = _svd_factorize(w, bias, rank) replacement = nn.Sequential(first, second) _replace_module(model, name, replacement) decomposed += 1 diff --git a/comprexx/stages/pruning/structured.py b/comprexx/stages/pruning/structured.py index 9d03eec..fc135f7 100644 --- a/comprexx/stages/pruning/structured.py +++ b/comprexx/stages/pruning/structured.py @@ -140,7 +140,8 @@ def _prune_global( mask[idx] = 0.0 # Apply structured mask along dim 0 (output channels / filters) - prune.custom_from_mask(module, "weight", mask.view(-1, 1, 1, 1).expand_as(module.weight)) + full_mask = mask.view(-1, 1, 1, 1).expand_as(module.weight) + prune.custom_from_mask(module, "weight", full_mask) notes.append(f"Global pruning: zeroed {n_prune} filters across {len(to_prune)} layers.") diff --git a/comprexx/stages/pruning/unstructured.py b/comprexx/stages/pruning/unstructured.py index 246d776..cd699a6 100644 --- a/comprexx/stages/pruning/unstructured.py +++ b/comprexx/stages/pruning/unstructured.py @@ -6,7 +6,6 @@ import time from typing import Literal -import torch import torch.nn as nn import torch.nn.utils.prune as prune from pydantic import BaseModel, Field diff --git a/pyproject.toml b/pyproject.toml index 92964f7..3ed1def 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ target-version = "py310" [tool.ruff.lint] select = ["E", "F", "I", "W"] +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["E741"] + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ diff --git a/tests/conftest.py b/tests/conftest.py index 5f677bb..f57f4ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import DataLoader, TensorDataset -from tests.fixtures.models import tiny_cnn, tiny_transformer, tiny_rnn +from tests.fixtures.models import tiny_cnn, tiny_rnn, tiny_transformer @pytest.fixture diff --git a/tests/unit/test_low_rank.py b/tests/unit/test_low_rank.py index 26883a8..2d5a5b5 100644 --- a/tests/unit/test_low_rank.py +++ b/tests/unit/test_low_rank.py @@ -6,9 +6,9 @@ from comprexx.stages.base import StageContext from comprexx.stages.decomposition.low_rank import ( LowRankDecomposition, + LowRankDecompositionConfig, _choose_rank, _svd_factorize, - LowRankDecompositionConfig, ) diff --git a/tests/unit/test_onnx_export.py b/tests/unit/test_onnx_export.py index c80181b..733a010 100644 --- a/tests/unit/test_onnx_export.py +++ b/tests/unit/test_onnx_export.py @@ -1,7 +1,6 @@ """Tests for ONNX export.""" import json -from pathlib import Path import pytest import torch diff --git a/tests/unit/test_ptq_dynamic.py b/tests/unit/test_ptq_dynamic.py index 706cd15..2be2be1 100644 --- a/tests/unit/test_ptq_dynamic.py +++ b/tests/unit/test_ptq_dynamic.py @@ -49,4 +49,4 @@ def test_original_model_unchanged(self): stage.apply(model, ctx) # Original model layers should not be replaced - assert type(list(model.modules())[1]) == original_type + assert type(list(model.modules())[1]) is original_type diff --git a/tests/unit/test_recipe.py b/tests/unit/test_recipe.py index 4c171e9..3215c9c 100644 --- a/tests/unit/test_recipe.py +++ b/tests/unit/test_recipe.py @@ -6,7 +6,6 @@ from comprexx.recipe.loader import load_recipe, recipe_to_pipeline from comprexx.recipe.schema import RecipeV1 - VALID_RECIPE_YAML = """\ name: test-recipe version: "1.0" diff --git a/tests/unit/test_weight_only_quant.py b/tests/unit/test_weight_only_quant.py index a3f49ca..62e5a8f 100644 --- a/tests/unit/test_weight_only_quant.py +++ b/tests/unit/test_weight_only_quant.py @@ -1,7 +1,6 @@ """Tests for weight-only quantization stage.""" import torch -import torch.nn as nn from comprexx.stages.base import StageContext from comprexx.stages.quantization.weight_only import ( From 16ab97adaa2207fe9bb9812c9098a69f84a080b7 Mon Sep 17 00:00:00 2001 From: maskedsyntax Date: Thu, 9 Apr 2026 22:01:38 +0530 Subject: [PATCH 2/4] silence torch.ao.quantization deprecation warning in PTQ stages The torch.ao.quantization API prints a DeprecationWarning on every call, announcing its removal in torch 2.10. Our PTQ dynamic and static stages still rely on it, so the warning is unavoidable for now and just adds noise to user output. Wrap the prepare/convert/quantize_dynamic calls in warnings.catch_warnings and filter the specific message. Leaves a TODO(v0.3) marking the migration to torchao.quantization so it's visible from the source. Test-suite warning count drops from 25 to 5. --- comprexx/stages/quantization/ptq_dynamic.py | 23 ++++++++--- comprexx/stages/quantization/ptq_static.py | 46 ++++++++++++--------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/comprexx/stages/quantization/ptq_dynamic.py b/comprexx/stages/quantization/ptq_dynamic.py index 6eb30d7..2902df6 100644 --- a/comprexx/stages/quantization/ptq_dynamic.py +++ b/comprexx/stages/quantization/ptq_dynamic.py @@ -4,6 +4,7 @@ import copy import time +import warnings from typing import Literal import torch @@ -14,6 +15,10 @@ from comprexx.core.report import StageReport from comprexx.stages.base import CompressionStage, StageContext +# TODO(v0.3): migrate to torchao.quantization. torch.ao.quantization is +# scheduled for removal in torch 2.10. Tracked at pytorch/ao#2259. +_TORCH_AO_WARN = r"torch\.ao\.quantization is deprecated" + class PTQDynamicConfig(BaseModel): """Configuration for dynamic quantization.""" @@ -47,13 +52,19 @@ def apply( # Profile before profile_before = analyze(model, context.input_shape, context.device) - # Apply dynamic quantization + # Apply dynamic quantization. Silence the torch.ao.quantization + # deprecation warning — it's noise for our users, and we track the + # migration in the TODO above. dtype = torch.qint8 - quantized_model = torch.quantization.quantize_dynamic( - model, - qconfig_spec={nn.Linear, nn.LSTM}, - dtype=dtype, - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=_TORCH_AO_WARN, category=DeprecationWarning + ) + quantized_model = torch.quantization.quantize_dynamic( + model, + qconfig_spec={nn.Linear, nn.LSTM}, + dtype=dtype, + ) # Profile after — size calculation for quantized models size_after = _quantized_model_size(quantized_model) diff --git a/comprexx/stages/quantization/ptq_static.py b/comprexx/stages/quantization/ptq_static.py index cbb9c82..738be4a 100644 --- a/comprexx/stages/quantization/ptq_static.py +++ b/comprexx/stages/quantization/ptq_static.py @@ -4,6 +4,7 @@ import copy import time +import warnings from typing import Literal import torch @@ -15,6 +16,10 @@ from comprexx.core.report import StageReport from comprexx.stages.base import CompressionStage, StageContext +# TODO(v0.3): migrate to torchao.quantization. torch.ao.quantization is +# scheduled for removal in torch 2.10. Tracked at pytorch/ao#2259. +_TORCH_AO_WARN = r"torch\.ao\.quantization is deprecated" + class PTQStaticConfig(BaseModel): """Configuration for static quantization.""" @@ -68,25 +73,28 @@ def apply( # Fuse common patterns if possible model = _try_fuse(model) - # Prepare: insert observers - prepared = torch.quantization.prepare(model, inplace=False) - - # Calibration: run forward passes to collect statistics - samples_seen = 0 - with torch.no_grad(): - for batch in context.calibration_data: - if samples_seen >= self.config.calibration_samples: - break - if isinstance(batch, (list, tuple)): - x = batch[0] - else: - x = batch - x = x.to(context.device) - prepared(x) - samples_seen += x.shape[0] - - # Convert: replace observers with quantized ops - quantized = torch.quantization.convert(prepared, inplace=False) + # Prepare → calibrate → convert. Silence the torch.ao.quantization + # deprecation warning — it's noise for users, tracked as a v0.3 TODO. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=_TORCH_AO_WARN, category=DeprecationWarning + ) + prepared = torch.quantization.prepare(model, inplace=False) + + samples_seen = 0 + with torch.no_grad(): + for batch in context.calibration_data: + if samples_seen >= self.config.calibration_samples: + break + if isinstance(batch, (list, tuple)): + x = batch[0] + else: + x = batch + x = x.to(context.device) + prepared(x) + samples_seen += x.shape[0] + + quantized = torch.quantization.convert(prepared, inplace=False) # Estimate quantized size size_after = _estimate_quantized_size(quantized, profile_before.size_bytes) From 9d75327856f984aa44286d38863dedf477273871 Mon Sep 17 00:00:00 2001 From: maskedsyntax Date: Thu, 9 Apr 2026 22:01:38 +0530 Subject: [PATCH 3/4] add github actions ci workflow Runs pytest on Python 3.10, 3.11, and 3.12 via a matrix, plus a separate ruff lint job. Triggers on push and pull_request against master. The project had no CI before this. --- .github/workflows/ci.yml | 51 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..2ea7362 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,51 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + test: + name: Test (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install package + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,onnx]" + + - name: Run tests + run: pytest -q + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + + - name: Install ruff + run: pip install ruff + + - name: Ruff check + run: ruff check . From 37b5c4b58e298f8522d6d0ab2bc244cfb07b3932 Mon Sep 17 00:00:00 2001 From: maskedsyntax Date: Thu, 9 Apr 2026 22:01:38 +0530 Subject: [PATCH 4/4] add changelog for v0.1.0 and v0.2.0 Keep-a-Changelog format with entries for the initial release, the v0.2 feature drop (new compression techniques and sensitivity analysis), and an Unreleased section tracking this stabilize-0.2 work. --- CHANGELOG.md | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..b29625d --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,63 @@ +# Changelog + +All notable changes to Comprexx are documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- GitHub Actions CI workflow running `pytest` on Python 3.10, 3.11, 3.12 plus a + `ruff check` lint job. +- `CHANGELOG.md` with history for v0.1.0 and v0.2.0. + +### Changed +- Silenced the noisy `torch.ao.quantization is deprecated` warning inside the + PTQ dynamic and static stages. The underlying API is still used, with a + `TODO(v0.3)` marking the upcoming migration to `torchao.quantization`. +- Fixed the package `__version__` to report `0.2.0` instead of the stale + `0.1.0` that shipped on PyPI. +- Tightened the codebase against `ruff check` and added a per-file ignore + for `E741` in tests. + +## [0.2.0] - 2026-04-07 + +### Added +- **Unstructured pruning** stage: magnitude or random element-wise pruning + with global/local scope and optional gradual cubic schedule. +- **N:M sparsity** stage: structured N-of-M sparsity (default 2:4) for + NVIDIA Ampere sparse tensor cores. +- **Weight-only quantization** stage: group-wise INT4/INT8 with symmetric + or asymmetric scaling for Linear and Conv2d layers. +- **Low-rank decomposition** stage: truncated SVD factorization of Linear + layers, with fixed rank-ratio or energy-threshold selection modes. +- **Operator fusion** stage: Conv2d + BatchNorm2d folding via `torch.fx` + with graceful fallback on non-traceable models. +- **Weight clustering** stage: per-layer k-means codebook clustering. +- **`cx.analyze_sensitivity()`**: per-layer sensitivity probing via prune + or noise perturbation. Returns a `SensitivityReport` that ranks layers + by metric drop and can suggest `exclude_layers` above a threshold. +- New techniques are wired through the recipe schema and loader, and + exported from `comprexx.stages`. + +### Tests +- 163 passing (up from 91). + +## [0.1.0] - 2026-04-06 + +Initial release. + +### Added +- Model analysis and profiling via `cx.analyze()`. +- Structured pruning with L1/L2/random criteria and global/local scope. +- Post-training dynamic and static INT8 quantization. +- ONNX export with manifest and optional `onnxruntime` validation. +- Recipe-driven pipelines (YAML) validated via Pydantic. +- CLI commands: `comprexx analyze`, `compress`, `export`. +- Accuracy guards with halt/warn actions. +- Per-stage compression reports persisted under `comprexx_runs/`. + +[Unreleased]: https://github.com/cachevector/comprexx/compare/v0.2.0...HEAD +[0.2.0]: https://github.com/cachevector/comprexx/compare/v0.1.0...v0.2.0 +[0.1.0]: https://github.com/cachevector/comprexx/releases/tag/v0.1.0