CIFAR-10 ResNet18 Model Optimization: A 2-Month Project

This repository documents a comprehensive 2-month project focused on optimizing a ResNet18 model for the CIFAR-10 dataset. The project explores model compression through quantization (Post-Training Quantization and Quantization-Aware Training) to reduce model size and improve inference speed while maintaining high accuracy.

Project Overview & Goals

The primary objectives of this study were to:

  • Establish a robust full-precision (FP32) baseline for image classification on CIFAR-10.
  • Implement and evaluate model compression techniques:
    • Quantization: Reducing the precision of weights and activations (e.g., from 32-bit floating-point values to 8-bit integers); a brief numeric illustration follows this list.
      • Post-Training Quantization (PTQ)
      • Quantization-Aware Training (QAT)
  • Compare optimized models against the baseline across key metrics: accuracy, model size, and inference time.
  • Provide clear visualizations to illustrate the impact and trade-offs of each optimization strategy.
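As a brief illustration of the 8-bit mapping mentioned above, the sketch below applies the standard affine (scale and zero-point) quantization scheme to a small tensor. The values and the simple per-tensor scheme are illustrative only and are not taken from this project's code.

    import torch

    # Affine quantization: q = round(x / scale) + zero_point, clamped to int8.
    x = torch.tensor([-1.0, -0.25, 0.0, 0.5, 1.5])

    qmin, qmax = -128, 127                       # signed 8-bit integer range
    scale = (x.max() - x.min()) / (qmax - qmin)  # one scale for the whole tensor
    zero_point = qmin - torch.round(x.min() / scale)

    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    x_hat = (q.to(torch.float32) - zero_point) * scale  # approximate reconstruction

Storing q (int8) instead of x (float32) is what drives the roughly 4x size reduction reported in the results below.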

Features Implemented & Work Done

Quantization Techniques

  1. Full Precision (FP32) Model Training:

    • Trained a ResNet18 model from scratch on CIFAR-10.
    • Applied advanced data augmentation: RandomHorizontalFlip, RandomCrop (32x32 with 4-pixel padding), ColorJitter (brightness, contrast, saturation, hue), and RandomRotation (15 degrees).
    • Used mixed-precision training (AMP) on CUDA-enabled devices for faster training.
    • Saved the best-performing model as best_float_model.pth.
    • Outputs: Training/validation loss and accuracy curves, confusion matrix, and sample prediction visualizations.
  2. Post-Training Quantization (PTQ):

    • Applied static 8-bit integer quantization to best_float_model.pth.
    • Fused modules (Conv-BN-ReLU) for optimized quantized operations.
    • Calibrated activation ranges using a subset of the training dataset (8192 samples).
    • Saved the quantized model as final_ptq_quantized_model.pth.
    • Outputs: Confusion matrix and sample prediction visualizations.
  3. Quantization-Aware Training (QAT):

    • Fine-tuned best_float_model.pth with simulated quantization to adapt weights to quantization noise.
    • Saved the best QAT model as best_qat_model.pth.
    • Outputs: Confusion matrix and sample prediction visualizations.

Evaluation & Visualizations

  • Metrics: Compared Test Accuracy (%), Model Size (MB), and Inference Time (ms/sample) for FP32, PTQ, and QAT models.
  • Visualizations:
    • Accuracy Comparison: Bar chart comparing FP32, PTQ, and QAT model accuracies.
    • Model Size Comparison: Bar chart showing disk size of model variants.
    • Inference Time Comparison: Bar chart displaying average inference time per sample.
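The sketch below shows one straightforward way to obtain the model-size and inference-time numbers used in these comparisons; the helper names and the timing loop are illustrative assumptions rather than the repository's exact evaluation code.

    import os
    import time
    import torch

    def model_size_mb(checkpoint_path):
        """Disk size of a saved model checkpoint, in megabytes."""
        return os.path.getsize(checkpoint_path) / (1024 ** 2)

    @torch.no_grad()
    def avg_inference_time_ms(model, loader, device="cpu", max_batches=50):
        """Average forward-pass time per sample, in milliseconds."""
        model.eval().to(device)
        total_time, total_samples = 0.0, 0
        for i, (images, _) in enumerate(loader):
            if i >= max_batches:
                break
            images = images.to(device)
            start = time.perf_counter()
            model(images)
            total_time += time.perf_counter() - start
            total_samples += images.size(0)
        return 1000.0 * total_time / total_samples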

Setup and Prerequisites

  1. Clone the Repository:

    git clone https://github.com/Prasukj7-arch/IIT_Hyerabad_PTQ_QAT_Model_Training.git
    cd IIT_Hyerabad_PTQ_QAT_Model_Training
  2. Create a Virtual Environment (recommended):

    python3 -m venv venv
    source venv/bin/activate  # On Windows: .\venv\Scripts\activate
  3. Install Dependencies:

    pip install torch torchvision matplotlib seaborn scikit-learn numpy

    Install the PyTorch build that matches your system (especially for CUDA-enabled GPUs); check the PyTorch website for installation details.

Usage

Follow these steps to train models, apply quantization, and generate comparisons.

Step 1: Train the Full Precision (FP32) Model

Train the ResNet18 model and save the best checkpoint:

python train_fp32_model.py

Outputs:

  • best_float_model.pth
  • fp32_training_curves.png
  • fp32_confusion_matrix.png
  • fp32_sample_predictions.png
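A minimal sketch of what train_fp32_model.py does, based on the description above (the augmentation pipeline plus mixed-precision training on CUDA). The specific ColorJitter strengths, optimizer settings, and best-checkpoint logic are assumptions; consult the script for the actual values.

    import torch
    import torch.nn as nn
    from torchvision import datasets, transforms, models

    # Augmentations listed above: horizontal flip, padded crop, color jitter, rotation.
    train_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10("data", train=True, download=True, transform=train_tf)
    loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = models.resnet18(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))  # AMP only on CUDA

    model.train()
    for images, labels in loader:  # one epoch shown; the script trains for many
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            loss = criterion(model(images), labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    torch.save(model.state_dict(), "best_float_model.pth")  # the real script keeps the best epoch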

Step 2: Apply Post-Training Quantization (PTQ)

Load best_float_model.pth, apply PTQ, and save the quantized model:

python ptq_model.py

Outputs:

  • final_ptq_quantized_model.pth
  • ptq_confusion_matrix.png
  • ptq_sample_predictions.png
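A minimal sketch of the PTQ flow described in the features section (module fusion, calibration on roughly 8192 training samples, conversion to int8), using PyTorch eager-mode static quantization with torchvision's quantizable ResNet18. Whether the checkpoint keys line up with this model definition is an assumption; ptq_model.py uses the repository's own model.

    import torch
    from torch.ao.quantization import get_default_qconfig, prepare, convert
    from torch.utils.data import DataLoader, Subset
    from torchvision import datasets, transforms
    from torchvision.models.quantization import resnet18

    model = resnet18(num_classes=10)
    state = torch.load("best_float_model.pth", map_location="cpu")
    model.load_state_dict(state, strict=False)      # assumes compatible parameter names
    model.eval()

    model.fuse_model()                              # fuse Conv-BN-ReLU blocks
    model.qconfig = get_default_qconfig("fbgemm")   # static int8 quantization, x86 backend
    prepare(model, inplace=True)                    # insert observers

    # Calibrate activation ranges on a subset of the training set.
    calib_set = Subset(
        datasets.CIFAR10("data", train=True, download=True, transform=transforms.ToTensor()),
        range(8192),
    )
    with torch.no_grad():
        for images, _ in DataLoader(calib_set, batch_size=128):
            model(images)

    convert(model, inplace=True)                    # weights and activations become int8
    torch.save(model.state_dict(), "final_ptq_quantized_model.pth")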

Step 3: Perform Quantization-Aware Training (QAT)

Load best_float_model.pth, fine-tune with QAT, and save the result:

python qat_model.py

Outputs:

  • best_qat_model.pth
  • qat_confusion_matrix.png
  • qat_sample_predictions.png
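A minimal sketch of the QAT flow described in the features section: fake-quantization modules are inserted, the model is fine-tuned so the weights adapt to quantization noise, and the result is converted to a real int8 model. As with the PTQ sketch, the model definition, optimizer settings, and number of fine-tuning epochs are assumptions; see qat_model.py for the actual procedure.

    import torch
    import torch.nn as nn
    from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from torchvision.models.quantization import resnet18

    model = resnet18(num_classes=10)
    model.load_state_dict(torch.load("best_float_model.pth", map_location="cpu"), strict=False)

    model.train()
    model.fuse_model(is_qat=True)                    # fuse Conv-BN-ReLU for QAT
    model.qconfig = get_default_qat_qconfig("fbgemm")
    prepare_qat(model, inplace=True)                 # insert fake-quant modules

    train_loader = DataLoader(
        datasets.CIFAR10("data", train=True, download=True, transform=transforms.ToTensor()),
        batch_size=128, shuffle=True,
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

    for images, labels in train_loader:              # one fine-tuning pass shown
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    model.eval()
    quantized = convert(model)                       # fold fake-quant into real int8 ops
    torch.save(quantized.state_dict(), "best_qat_model.pth")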

Overall Evaluation Comparison

(Comparison charts: test accuracy, model size, and inference time for the FP32, PTQ, and QAT models.)

Key Results and Analysis

Quantization Results

Model Type    Accuracy (%)    Model Size (MB)    Inference Time (ms/sample)
FP32          94.02           42.70              27.67
PTQ           94.04           10.79               9.53
QAT           93.98           10.78               8.80

Observations

  • Model Size Reduction: PTQ and QAT reduced model size by roughly 75% (42.70 MB down to ~10.8 MB).
  • Inference Speedup: Quantized models cut per-sample inference time by roughly two-thirds (27.67 ms down to under 10 ms).
  • Accuracy Retention (PTQ): PTQ matched, and marginally exceeded, the FP32 baseline (94.04% vs 94.02%), indicating robust calibration.
  • Accuracy Retention (QAT): QAT reached 93.98%, a negligible ~0.04 percentage-point drop from the baseline, confirming effective adaptation to quantization noise.

These results highlight the efficacy of quantization for resource-constrained environments.

Acknowledgements

  • Built using the PyTorch deep learning framework.
  • Utilizes the ResNet architecture and CIFAR-10 dataset, standard benchmarks in computer vision.
