Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpsynth/discrete_mechanisms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dpsynth.discrete_mechanisms.aim import AIMMechanism
from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanism
from dpsynth.discrete_mechanisms.common import DiscreteMechanismResult
from dpsynth.discrete_mechanisms.common import MechanismDiagnostics
from dpsynth.discrete_mechanisms.direct import DirectMechanism
from dpsynth.discrete_mechanisms.independent import IndependentMechanism
from dpsynth.discrete_mechanisms.mst import MSTMechanism
Expand Down
80 changes: 43 additions & 37 deletions dpsynth/discrete_mechanisms/aim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from collections.abc import Iterable, Mapping
import dataclasses
import time
from typing import TypeAlias

from absl import logging
Expand Down Expand Up @@ -171,6 +170,7 @@ def __call__(
raise ValueError('Must call calibrate() before using the mechanism.')

logging.info('[AIM]: Starting Mechanism.')
phase_times = {}
constraints = initial_potentials is not None
marginal_oracle = common.default_oracle(self.marginal_oracle, constraints)

Expand Down Expand Up @@ -223,23 +223,27 @@ def __call__(
########################################################################
# Select a marginal query worst approximated by the current model. #
########################################################################
t0 = time.time()
rho_remaining -= rho_per_round
fraction = self.select_budget_fraction
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
epsilon = accounting.zcdp_exponential_eps(fraction * rho_per_round)
size_limit = self.max_model_size * (zcdp_rho - rho_remaining) / zcdp_rho
small_candidates = _filter_candidates(candidates, model, size_limit)

estimates = mbi.marginal_oracles.bulk_variable_elimination(
model.potentials, list(small_candidates), total=model.total
)
marginal_query = _worst_approximated(
rng, small_candidates, answers, estimates, epsilon, sigma, data.domain
)
with common.timed(phase_times, 'selection'):
rho_remaining -= rho_per_round
fraction = self.select_budget_fraction
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
epsilon = accounting.zcdp_exponential_eps(fraction * rho_per_round)
size_limit = self.max_model_size * (zcdp_rho - rho_remaining) / zcdp_rho
small_candidates = _filter_candidates(candidates, model, size_limit)

estimates = mbi.marginal_oracles.bulk_variable_elimination(
model.potentials, list(small_candidates), total=model.total
)
marginal_query = _worst_approximated(
rng,
small_candidates,
answers,
estimates,
epsilon,
sigma,
data.domain,
)

t1 = time.time()
logging.info('[AIM] Found worst-approximated candidate in %.2fs', t1 - t0)
logging.info(
'[AIM] Round %d, Budget used: %.4f, Measuring: %s, Candidates: %d',
t,
Expand All @@ -251,29 +255,28 @@ def __call__(
######################################################################
# Measure the marginal query privately using the Gaussian mechanism. #
######################################################################
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()
with common.timed(phase_times, 'measurement'):
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()

#####################################################
# Estimate the data distribution using Private-PGM. #
#####################################################
t2 = time.time()
callback_fn = mbi.callbacks.default(measurements)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.mirror_descent(
data.domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
marginal_oracle=marginal_oracle,
)
t3 = time.time()
logging.info('[AIM] Mirror descent took %.2fs', t3 - t2)
with common.timed(phase_times, 'estimation'):
callback_fn = mbi.callbacks.default(measurements)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.mirror_descent(
data.domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
marginal_oracle=marginal_oracle,
)

new_estimate = model.project(marginal_query).datavector()

Expand All @@ -288,6 +291,9 @@ def __call__(
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
logging.info('[AIM] Reducing sigma: %.1f', sigma)

diagnostics = common.clique_stats(model)
diagnostics.phase_times = phase_times
diagnostics.num_rounds = t
return common.DiscreteMechanismResult(
model=model, measurements=measurements
model=model, measurements=measurements, diagnostics=diagnostics
)
84 changes: 42 additions & 42 deletions dpsynth/discrete_mechanisms/aim_gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from collections.abc import Iterable, Mapping
import dataclasses
import time
from typing import TypeAlias

from absl import logging
Expand Down Expand Up @@ -217,6 +216,7 @@ def __call__(
raise ValueError('Must call calibrate() before using the mechanism.')

logging.info('[AIM] Starting Mechanism.')
phase_times = {}
constraints = initial_potentials is not None
marginal_oracle = common.default_oracle(self.marginal_oracle, constraints)

Expand Down Expand Up @@ -284,28 +284,26 @@ def __call__(
########################################################################
# Select a marginal query worst approximated by the current model. #
########################################################################
t0 = time.time()
budget_remaining -= budget_per_round
measure_budget = budget_per_round * (1 - self.select_budget_fraction)
select_budget = budget_per_round * self.select_budget_fraction
measure_sigma = accounting.gdp_gaussian_sigma(measure_budget)
percent_used = (gdp_budget - budget_remaining) / gdp_budget
size_limit = self.max_model_size * percent_used
small_candidates = _filter_candidates(candidates, model, size_limit)

marginal_query = _worst_approximated(
rng,
candidates=small_candidates,
errors=errors,
answers=answers,
model=model,
select_budget=select_budget,
measure_sigma=measure_sigma,
max_new_evals=self.max_candidates_per_round,
)
with common.timed(phase_times, 'selection'):
budget_remaining -= budget_per_round
measure_budget = budget_per_round * (1 - self.select_budget_fraction)
select_budget = budget_per_round * self.select_budget_fraction
measure_sigma = accounting.gdp_gaussian_sigma(measure_budget)
percent_used = (gdp_budget - budget_remaining) / gdp_budget
size_limit = self.max_model_size * percent_used
small_candidates = _filter_candidates(candidates, model, size_limit)

marginal_query = _worst_approximated(
rng,
candidates=small_candidates,
errors=errors,
answers=answers,
model=model,
select_budget=select_budget,
measure_sigma=measure_sigma,
max_new_evals=self.max_candidates_per_round,
)

t1 = time.time()
logging.info('[AIM] Found worst candidate in %.2fs', t1 - t0)
logging.info(
'[AIM] Round %d, Budget used: %.4f, Measuring: %s, Candidates: %d',
t,
Expand All @@ -317,29 +315,28 @@ def __call__(
######################################################################
# Measure the marginal query privately using the Gaussian mechanism. #
######################################################################
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], measure_sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()
with common.timed(phase_times, 'measurement'):
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], measure_sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()

#####################################################
# Estimate the data distribution using Private-PGM. #
#####################################################
t2 = time.time()
callback_fn = mbi.callbacks.default(measurements)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.mirror_descent(
domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
marginal_oracle=marginal_oracle,
)
t3 = time.time()
logging.info('[AIM] Mirror descent took %.2fs', t3 - t2)
with common.timed(phase_times, 'estimation'):
callback_fn = mbi.callbacks.default(measurements)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.mirror_descent(
domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
marginal_oracle=marginal_oracle,
)

new_estimate = model.project(marginal_query).datavector()

Expand All @@ -358,6 +355,9 @@ def __call__(
'[AIM] Increasing budget per round: %.5f', budget_per_round
)

diagnostics = common.clique_stats(model)
diagnostics.phase_times = phase_times
diagnostics.num_rounds = t
return common.DiscreteMechanismResult(
model=model, measurements=measurements
model=model, measurements=measurements, diagnostics=diagnostics
)
78 changes: 76 additions & 2 deletions dpsynth/discrete_mechanisms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,93 @@
"""Common utility functions for synthetic data mechanisms."""

from collections.abc import Iterable, Mapping
import contextlib
import dataclasses
import functools
import itertools
from typing import Any, TypeAlias
import time
from typing import TypeAlias

from absl import logging
from dpsynth import transformations
import mbi
import mbi.junction_tree
import more_itertools
import numpy as np
import scipy
import scipy.special


@dataclasses.dataclass
class MechanismDiagnostics:
  """Timing and model-structure summary produced by a discrete mechanism.

  Every field has a default, so an instance can be created empty and filled in
  incrementally. `phase_times` gets a fresh dict per instance.

  Attributes:
    phase_times: Maps a phase name to its wall-clock duration in seconds.
    num_rounds: How many select-measure rounds were run (only meaningful for
      iterative mechanisms).
    num_cliques: How many cliques the fitted model contains.
    max_clique_size: Size of the single largest clique.
    total_clique_size: Clique sizes summed over all cliques.
    max_jtree_node_size: Size of the single largest junction tree node.
    total_jtree_size: Node sizes summed over all junction tree nodes.
  """

  # Timing information.
  phase_times: dict[str, float] = dataclasses.field(default_factory=dict)
  num_rounds: int = 0

  # Clique statistics of the fitted model.
  num_cliques: int = 0
  max_clique_size: int = 0
  total_clique_size: int = 0

  # Junction tree statistics of the fitted model.
  max_jtree_node_size: int = 0
  total_jtree_size: int = 0


@contextlib.contextmanager
def timed(phase_times: dict[str, float], name: str):
  """Context manager that logs and records wall-clock time for a phase.

  The elapsed time is added to whatever is already stored under `name`, so
  entering the same phase repeatedly accumulates a running total.

  The duration is recorded and logged even when the wrapped block raises, so a
  failing phase still shows up in `phase_times`. The exception itself is not
  suppressed and propagates to the caller.

  Args:
    phase_times: Mutable mapping from phase name to accumulated seconds. It is
      updated in place.
    name: Key under which the elapsed time is accumulated; also used as the
      log prefix.

  Yields:
    None.
  """
  start = time.monotonic()
  try:
    yield
  finally:
    # Runs on both normal exit and exceptions; without it an error raised in
    # the timed block would skip the bookkeeping below.
    elapsed = time.monotonic() - start
    phase_times[name] = phase_times.get(name, 0.0) + elapsed
    logging.info('[%s] %.2fs', name, elapsed)


def clique_stats(model: mbi.MarkovRandomField) -> MechanismDiagnostics:
  """Summarize the clique and junction tree sizes of a fitted model.

  The summary is also written to the log: one line for the cliques and one
  line for the junction tree.

  Args:
    model: The fitted graphical model.

  Returns:
    A MechanismDiagnostics whose clique and junction tree fields are filled
    in. `phase_times` and `num_rounds` are left at their defaults.
  """
  potentials = model.potentials
  domain = potentials.domain
  cliques = potentials.cliques
  clique_sizes = [domain.size(clique) for clique in cliques]

  # Only the tree itself is needed; the second return value is discarded.
  tree, _ = mbi.junction_tree.make_junction_tree(domain, cliques)
  tree_nodes = list(tree.nodes)
  node_sizes = [domain.size(node) for node in tree_nodes]

  # `default=0` keeps the maxima well defined when there is nothing to measure.
  result = MechanismDiagnostics(
      num_cliques=len(cliques),
      max_clique_size=max(clique_sizes, default=0),
      total_clique_size=sum(clique_sizes),
      max_jtree_node_size=max(node_sizes, default=0),
      total_jtree_size=sum(node_sizes),
  )

  logging.info(
      'Cliques: %d, max_size: %d, total_size: %d',
      result.num_cliques,
      result.max_clique_size,
      result.total_clique_size,
  )
  logging.info(
      'Junction tree: %d nodes, max_size: %d, total_size: %d',
      len(tree_nodes),
      result.max_jtree_node_size,
      result.total_jtree_size,
  )
  return result


@dataclasses.dataclass
class DiscreteMechanismResult:
"""Result of running a discrete mechanism.
Expand All @@ -42,7 +116,7 @@ class DiscreteMechanismResult:
measurements: list[mbi.LinearMeasurement] = dataclasses.field(
default_factory=list
)
diagnostics: Any | None = None
diagnostics: MechanismDiagnostics | None = None


def exponential_mechanism(
Expand Down
35 changes: 20 additions & 15 deletions dpsynth/discrete_mechanisms/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,29 @@ def __call__(
constraints = initial_potentials is not None
marginal_oracle = common.default_oracle(self.marginal_oracle, constraints)

phase_times = {}
# measure_marginals_with_noise splits gdp_sigma across the queries
# internally via weight normalization.
new_measurements = common.measure_marginals_with_noise(
rng, data, self.prespecified_marginal_queries, self.gdp_sigma
)
if initial_measurements:
all_measurements = initial_measurements + new_measurements
else:
all_measurements = new_measurements
with common.timed(phase_times, 'measurement'):
new_measurements = common.measure_marginals_with_noise(
rng, data, self.prespecified_marginal_queries, self.gdp_sigma
)
if initial_measurements:
all_measurements = initial_measurements + new_measurements
else:
all_measurements = new_measurements

# fit a distribution to the noisy measurements
model = mbi.estimation.mirror_descent(
data.domain,
all_measurements,
iters=self.pgm_iters,
potentials=initial_potentials,
marginal_oracle=marginal_oracle,
)
with common.timed(phase_times, 'estimation'):
model = mbi.estimation.mirror_descent(
data.domain,
all_measurements,
iters=self.pgm_iters,
potentials=initial_potentials,
marginal_oracle=marginal_oracle,
)
diagnostics = common.clique_stats(model)
diagnostics.phase_times = phase_times
return common.DiscreteMechanismResult(
model=model, measurements=all_measurements
model=model, measurements=all_measurements, diagnostics=diagnostics
)
Loading
Loading