Add DP supervised fine-tuning module for Gemma models #47

Open
copybara-service[bot] wants to merge 1 commit into main from cl/933779334

Add DP supervised fine-tuning module for Gemma models

Introduces dpsynth.text.model and dpsynth.text.dp_sft, a composable API for
differentially private supervised fine-tuning of Gemma language models with
LoRA adapters. The implementation is split into two files to separate concerns:

model.py (no DP logic, standard ML review):

  • SupportedModel: Enum of validated Gemma variants.
  • LoraConfig: LoRA adapter configuration.
  • load_gemma(): Load pretrained model with LoRA adapters applied.
  • sft_loss_fn(): Cross-entropy loss (forked from tunix peft_trainer).
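As a rough illustration of what a supervised fine-tuning cross-entropy loss computes, here is a minimal dependency-free sketch. The actual signature of sft_loss_fn() is not shown in this PR description, so the function names, argument layout, and the use of a 0/1 mask to exclude prompt or padding tokens are all assumptions made for illustration, not the module's API.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list:
    """Numerically stable log-softmax over one row of vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_cross_entropy(
    logits: Sequence[Sequence[float]],  # [num_tokens][vocab_size]
    targets: Sequence[int],             # [num_tokens] target token ids
    mask: Sequence[int],                # [num_tokens] 1 = count this token
) -> float:
    """Mean negative log-likelihood over the positions where mask == 1.

    Hypothetical helper: masked-out positions (for example prompt or
    padding tokens) contribute nothing to the loss.
    """
    total, count = 0.0, 0
    for row, target, keep in zip(logits, targets, mask):
        if keep:
            total -= log_softmax(row)[target]
            count += 1
    # Guard against an all-zero mask to avoid dividing by zero.
    return total / max(count, 1)
```

A real implementation would operate on batched JAX arrays rather than Python lists; the arithmetic per token is the same.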

dp_sft.py (DP-critical, needs privacy review):

  • DPSft: DPMechanism subclass wrapping DP-SGD via JAX Privacy.
  • calibrate(): Noise calibration from a zCDP budget.
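To make the two DP-critical pieces concrete, the sketch below shows one simple way to calibrate Gaussian noise from a zCDP budget and one clip-and-noise DP-SGD aggregation step. It is not the code in dp_sft.py and does not use JAX Privacy: all function names are hypothetical, the calibration assumes plain composition of the Gaussian mechanism over a fixed number of steps (each step costs 1 / (2 * sigma^2) in zCDP) and ignores any privacy amplification from subsampling, and the aggregation assumes a fixed batch size.

```python
import math
import random
from typing import List, Sequence


def noise_multiplier_for_zcdp(rho: float, steps: int) -> float:
    """Noise multiplier sigma such that `steps` Gaussian releases cost rho-zCDP.

    Assumption: each step has L2 sensitivity equal to the clip norm and adds
    noise with std sigma * clip_norm, so it costs 1 / (2 * sigma^2); costs
    add up across steps. Subsampling amplification is ignored.
    """
    if rho <= 0 or steps <= 0:
        raise ValueError("rho and steps must be positive")
    return math.sqrt(steps / (2.0 * rho))


def clip(grad: Sequence[float], clip_norm: float) -> List[float]:
    """Scale a gradient down so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def dp_sgd_aggregate(
    per_example_grads: Sequence[Sequence[float]],
    clip_norm: float,
    noise_multiplier: float,
    rng: random.Random,
) -> List[float]:
    """Clip each per-example gradient, sum, add Gaussian noise, then average."""
    dim = len(per_example_grads[0])
    summed = [0.0] * dim
    for grad in per_example_grads:
        for i, g in enumerate(clip(grad, clip_norm)):
            summed[i] += g
    sigma = noise_multiplier * clip_norm  # noise std on the clipped sum
    n = len(per_example_grads)            # assumed fixed batch size
    return [(s + rng.gauss(0.0, sigma)) / n for s in summed]
```

For example, under these assumptions a budget of rho = 0.5 spent over 100 steps gives a noise multiplier of 10. A production calibration would typically account for subsampling and would therefore arrive at a different value.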

dp_sft_test.py:

  • Unit tests for loss function, calibration, config, and enum wiring.
    All tests use a tiny mock model (no real checkpoint loading).

Composes Tunix (model), qwix (LoRA), and JAX Privacy (DP training).

PiperOrigin-RevId: 933779334