From 39d9029a670a2c34526ff1e9f6c3197307272f9e Mon Sep 17 00:00:00 2001 From: Ryan McKenna Date: Tue, 23 Jun 2026 06:04:02 -0700 Subject: [PATCH] Add mechanism diagnostics infrastructure Add MechanismDiagnostics dataclass, timed() context manager, and clique_stats() helper to common.py. Wire diagnostics into all six discrete mechanisms (AIM, AIM-GDP, Direct, Independent, MST, SWIFT). Each mechanism now reports: - Per-phase wall-clock timing (selection, measurement, estimation) - Structural stats: clique count/sizes, junction tree node sizes - Round count for iterative mechanisms (AIM, AIM-GDP) clique_stats() takes the fitted MarkovRandomField directly, extracting cliques and domain from the model's potentials. In AIM and AIM-GDP, replaces manual t0/t1/t2/t3 timing with the timed() context manager. In MST, replaces the inline hypothetical_model_size call with the richer clique_stats helper. PiperOrigin-RevId: 936617333 --- dpsynth/discrete_mechanisms/__init__.py | 1 + dpsynth/discrete_mechanisms/aim.py | 80 +++++++++++---------- dpsynth/discrete_mechanisms/aim_gdp.py | 84 +++++++++++----------- dpsynth/discrete_mechanisms/common.py | 78 +++++++++++++++++++- dpsynth/discrete_mechanisms/direct.py | 35 +++++---- dpsynth/discrete_mechanisms/independent.py | 45 ++++++------ dpsynth/discrete_mechanisms/mst.py | 83 ++++++++++----------- dpsynth/discrete_mechanisms/swift.py | 63 ++++++++-------- 8 files changed, 284 insertions(+), 185 deletions(-) diff --git a/dpsynth/discrete_mechanisms/__init__.py b/dpsynth/discrete_mechanisms/__init__.py index 0d127ee..ca4531e 100644 --- a/dpsynth/discrete_mechanisms/__init__.py +++ b/dpsynth/discrete_mechanisms/__init__.py @@ -19,6 +19,7 @@ from dpsynth.discrete_mechanisms.aim import AIMMechanism from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanism from dpsynth.discrete_mechanisms.common import DiscreteMechanismResult +from dpsynth.discrete_mechanisms.common import MechanismDiagnostics from dpsynth.discrete_mechanisms.direct import DirectMechanism from dpsynth.discrete_mechanisms.independent import IndependentMechanism from dpsynth.discrete_mechanisms.mst import MSTMechanism diff --git a/dpsynth/discrete_mechanisms/aim.py b/dpsynth/discrete_mechanisms/aim.py index 7414254..e3504e6 100644 --- a/dpsynth/discrete_mechanisms/aim.py +++ b/dpsynth/discrete_mechanisms/aim.py @@ -16,7 +16,6 @@ from collections.abc import Iterable, Mapping import dataclasses -import time from typing import TypeAlias from absl import logging @@ -171,6 +170,7 @@ def __call__( raise ValueError('Must call calibrate() before using the mechanism.') logging.info('[AIM]: Starting Mechanism.') + phase_times = {} constraints = initial_potentials is not None marginal_oracle = common.default_oracle(self.marginal_oracle, constraints) @@ -223,23 +223,27 @@ def __call__( ######################################################################## # Select a marginal query worst approximated by the current model. # ######################################################################## - t0 = time.time() - rho_remaining -= rho_per_round - fraction = self.select_budget_fraction - sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round) - epsilon = accounting.zcdp_exponential_eps(fraction * rho_per_round) - size_limit = self.max_model_size * (zcdp_rho - rho_remaining) / zcdp_rho - small_candidates = _filter_candidates(candidates, model, size_limit) - - estimates = mbi.marginal_oracles.bulk_variable_elimination( - model.potentials, list(small_candidates), total=model.total - ) - marginal_query = _worst_approximated( - rng, small_candidates, answers, estimates, epsilon, sigma, data.domain - ) + with common.timed(phase_times, 'selection'): + rho_remaining -= rho_per_round + fraction = self.select_budget_fraction + sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round) + epsilon = accounting.zcdp_exponential_eps(fraction * rho_per_round) + size_limit = self.max_model_size * (zcdp_rho - rho_remaining) / zcdp_rho + small_candidates = _filter_candidates(candidates, model, size_limit) + + estimates = mbi.marginal_oracles.bulk_variable_elimination( + model.potentials, list(small_candidates), total=model.total + ) + marginal_query = _worst_approximated( + rng, + small_candidates, + answers, + estimates, + epsilon, + sigma, + data.domain, + ) - t1 = time.time() - logging.info('[AIM] Found worst-approximated candidate in %.2fs', t1 - t0) logging.info( '[AIM] Round %d, Budget used: %.4f, Measuring: %s, Candidates: %d', t, @@ -251,29 +255,28 @@ def __call__( ###################################################################### # Measure the marginal query privately using the Gaussian mechanism. # ###################################################################### - measurement = common.measure_marginals_with_noise( - rng, data, [marginal_query], sigma - )[0] - measurements.append(measurement) - old_estimate = model.project(marginal_query).datavector() + with common.timed(phase_times, 'measurement'): + measurement = common.measure_marginals_with_noise( + rng, data, [marginal_query], sigma + )[0] + measurements.append(measurement) + old_estimate = model.project(marginal_query).datavector() ##################################################### # Estimate the data distribution using Private-PGM. # ##################################################### - t2 = time.time() - callback_fn = mbi.callbacks.default(measurements) - measured_cliques = list(set(m.clique for m in measurements)) - warm_start = model.potentials.expand(measured_cliques) - model = mbi.estimation.mirror_descent( - data.domain, - measurements, - potentials=warm_start, - iters=self.pgm_iters, - callback_fn=callback_fn, - marginal_oracle=marginal_oracle, - ) - t3 = time.time() - logging.info('[AIM] Mirror descent took %.2fs', t3 - t2) + with common.timed(phase_times, 'estimation'): + callback_fn = mbi.callbacks.default(measurements) + measured_cliques = list(set(m.clique for m in measurements)) + warm_start = model.potentials.expand(measured_cliques) + model = mbi.estimation.mirror_descent( + data.domain, + measurements, + potentials=warm_start, + iters=self.pgm_iters, + callback_fn=callback_fn, + marginal_oracle=marginal_oracle, + ) new_estimate = model.project(marginal_query).datavector() @@ -288,6 +291,9 @@ def __call__( sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round) logging.info('[AIM] Reducing sigma: %.1f', sigma) + diagnostics = common.clique_stats(model) + diagnostics.phase_times = phase_times + diagnostics.num_rounds = t return common.DiscreteMechanismResult( - model=model, measurements=measurements + model=model, measurements=measurements, diagnostics=diagnostics ) diff --git a/dpsynth/discrete_mechanisms/aim_gdp.py b/dpsynth/discrete_mechanisms/aim_gdp.py index 86ca15a..6cf471f 100644 --- a/dpsynth/discrete_mechanisms/aim_gdp.py +++ b/dpsynth/discrete_mechanisms/aim_gdp.py @@ -16,7 +16,6 @@ from collections.abc import Iterable, Mapping import dataclasses -import time from typing import TypeAlias from absl import logging @@ -217,6 +216,7 @@ def __call__( raise ValueError('Must call calibrate() before using the mechanism.') logging.info('[AIM] Starting Mechanism.') + phase_times = {} constraints = initial_potentials is not None marginal_oracle = common.default_oracle(self.marginal_oracle, constraints) @@ -284,28 +284,26 @@ def __call__( ######################################################################## # Select a marginal query worst approximated by the current model. # ######################################################################## - t0 = time.time() - budget_remaining -= budget_per_round - measure_budget = budget_per_round * (1 - self.select_budget_fraction) - select_budget = budget_per_round * self.select_budget_fraction - measure_sigma = accounting.gdp_gaussian_sigma(measure_budget) - percent_used = (gdp_budget - budget_remaining) / gdp_budget - size_limit = self.max_model_size * percent_used - small_candidates = _filter_candidates(candidates, model, size_limit) - - marginal_query = _worst_approximated( - rng, - candidates=small_candidates, - errors=errors, - answers=answers, - model=model, - select_budget=select_budget, - measure_sigma=measure_sigma, - max_new_evals=self.max_candidates_per_round, - ) + with common.timed(phase_times, 'selection'): + budget_remaining -= budget_per_round + measure_budget = budget_per_round * (1 - self.select_budget_fraction) + select_budget = budget_per_round * self.select_budget_fraction + measure_sigma = accounting.gdp_gaussian_sigma(measure_budget) + percent_used = (gdp_budget - budget_remaining) / gdp_budget + size_limit = self.max_model_size * percent_used + small_candidates = _filter_candidates(candidates, model, size_limit) + + marginal_query = _worst_approximated( + rng, + candidates=small_candidates, + errors=errors, + answers=answers, + model=model, + select_budget=select_budget, + measure_sigma=measure_sigma, + max_new_evals=self.max_candidates_per_round, + ) - t1 = time.time() - logging.info('[AIM] Found worst candidate in %.2fs', t1 - t0) logging.info( '[AIM] Round %d, Budget used: %.4f, Measuring: %s, Candidates: %d', t, @@ -317,29 +315,28 @@ def __call__( ###################################################################### # Measure the marginal query privately using the Gaussian mechanism. # ###################################################################### - measurement = common.measure_marginals_with_noise( - rng, data, [marginal_query], measure_sigma - )[0] - measurements.append(measurement) - old_estimate = model.project(marginal_query).datavector() + with common.timed(phase_times, 'measurement'): + measurement = common.measure_marginals_with_noise( + rng, data, [marginal_query], measure_sigma + )[0] + measurements.append(measurement) + old_estimate = model.project(marginal_query).datavector() ##################################################### # Estimate the data distribution using Private-PGM. # ##################################################### - t2 = time.time() - callback_fn = mbi.callbacks.default(measurements) - measured_cliques = list(set(m.clique for m in measurements)) - warm_start = model.potentials.expand(measured_cliques) - model = mbi.estimation.mirror_descent( - domain, - measurements, - potentials=warm_start, - iters=self.pgm_iters, - callback_fn=callback_fn, - marginal_oracle=marginal_oracle, - ) - t3 = time.time() - logging.info('[AIM] Mirror descent took %.2fs', t3 - t2) + with common.timed(phase_times, 'estimation'): + callback_fn = mbi.callbacks.default(measurements) + measured_cliques = list(set(m.clique for m in measurements)) + warm_start = model.potentials.expand(measured_cliques) + model = mbi.estimation.mirror_descent( + domain, + measurements, + potentials=warm_start, + iters=self.pgm_iters, + callback_fn=callback_fn, + marginal_oracle=marginal_oracle, + ) new_estimate = model.project(marginal_query).datavector() @@ -358,6 +355,9 @@ def __call__( '[AIM] Increasing budget per round: %.5f', budget_per_round ) + diagnostics = common.clique_stats(model) + diagnostics.phase_times = phase_times + diagnostics.num_rounds = t return common.DiscreteMechanismResult( - model=model, measurements=measurements + model=model, measurements=measurements, diagnostics=diagnostics ) diff --git a/dpsynth/discrete_mechanisms/common.py b/dpsynth/discrete_mechanisms/common.py index 5c7145d..4e6f3b5 100644 --- a/dpsynth/discrete_mechanisms/common.py +++ b/dpsynth/discrete_mechanisms/common.py @@ -15,19 +15,93 @@ """Common utility functions for synthetic data mechanisms.""" from collections.abc import Iterable, Mapping +import contextlib import dataclasses import functools import itertools -from typing import Any, TypeAlias +import time +from typing import TypeAlias +from absl import logging from dpsynth import transformations import mbi +import mbi.junction_tree import more_itertools import numpy as np import scipy import scipy.special +@dataclasses.dataclass +class MechanismDiagnostics: + """Diagnostic info from a discrete mechanism run. + + Attributes: + phase_times: Wall-clock time in seconds for each named phase. + num_rounds: Number of select-measure rounds (iterative mechanisms only). + num_cliques: Number of cliques in the fitted model. + max_clique_size: Size of the largest clique. + total_clique_size: Sum of all clique sizes. + max_jtree_node_size: Size of the largest junction tree node. + total_jtree_size: Sum of all junction tree node sizes. + """ + + phase_times: dict[str, float] = dataclasses.field(default_factory=dict) + num_rounds: int = 0 + num_cliques: int = 0 + max_clique_size: int = 0 + total_clique_size: int = 0 + max_jtree_node_size: int = 0 + total_jtree_size: int = 0 + + +@contextlib.contextmanager +def timed(phase_times: dict[str, float], name: str): + """Context manager that logs and records wall-clock time for a phase.""" + start = time.monotonic() + yield + elapsed = time.monotonic() - start + phase_times[name] = phase_times.get(name, 0.0) + elapsed + logging.info('[%s] %.2fs', name, elapsed) + + +def clique_stats(model: mbi.MarkovRandomField) -> MechanismDiagnostics: + """Compute structural diagnostics from a fitted model and log them. + + Args: + model: The fitted graphical model. + + Returns: + A MechanismDiagnostics populated with clique and junction tree stats. + """ + domain = model.potentials.domain + cliques = model.potentials.cliques + sizes = [domain.size(c) for c in cliques] + jtree, _ = mbi.junction_tree.make_junction_tree(domain, cliques) + jtree_nodes = list(jtree.nodes) + jtree_sizes = [domain.size(n) for n in jtree_nodes] + diagnostics = MechanismDiagnostics( + num_cliques=len(cliques), + max_clique_size=max(sizes, default=0), + total_clique_size=sum(sizes), + max_jtree_node_size=max(jtree_sizes, default=0), + total_jtree_size=sum(jtree_sizes), + ) + logging.info( + 'Cliques: %d, max_size: %d, total_size: %d', + diagnostics.num_cliques, + diagnostics.max_clique_size, + diagnostics.total_clique_size, + ) + logging.info( + 'Junction tree: %d nodes, max_size: %d, total_size: %d', + len(jtree_nodes), + diagnostics.max_jtree_node_size, + diagnostics.total_jtree_size, + ) + return diagnostics + + @dataclasses.dataclass class DiscreteMechanismResult: """Result of running a discrete mechanism. @@ -42,7 +116,7 @@ class DiscreteMechanismResult: measurements: list[mbi.LinearMeasurement] = dataclasses.field( default_factory=list ) - diagnostics: Any | None = None + diagnostics: MechanismDiagnostics | None = None def exponential_mechanism( diff --git a/dpsynth/discrete_mechanisms/direct.py b/dpsynth/discrete_mechanisms/direct.py index f9e19a4..e971319 100644 --- a/dpsynth/discrete_mechanisms/direct.py +++ b/dpsynth/discrete_mechanisms/direct.py @@ -70,24 +70,29 @@ def __call__( constraints = initial_potentials is not None marginal_oracle = common.default_oracle(self.marginal_oracle, constraints) + phase_times = {} # measure_marginals_with_noise splits gdp_sigma across the queries # internally via weight normalization. - new_measurements = common.measure_marginals_with_noise( - rng, data, self.prespecified_marginal_queries, self.gdp_sigma - ) - if initial_measurements: - all_measurements = initial_measurements + new_measurements - else: - all_measurements = new_measurements + with common.timed(phase_times, 'measurement'): + new_measurements = common.measure_marginals_with_noise( + rng, data, self.prespecified_marginal_queries, self.gdp_sigma + ) + if initial_measurements: + all_measurements = initial_measurements + new_measurements + else: + all_measurements = new_measurements # fit a distribution to the noisy measurements - model = mbi.estimation.mirror_descent( - data.domain, - all_measurements, - iters=self.pgm_iters, - potentials=initial_potentials, - marginal_oracle=marginal_oracle, - ) + with common.timed(phase_times, 'estimation'): + model = mbi.estimation.mirror_descent( + data.domain, + all_measurements, + iters=self.pgm_iters, + potentials=initial_potentials, + marginal_oracle=marginal_oracle, + ) + diagnostics = common.clique_stats(model) + diagnostics.phase_times = phase_times return common.DiscreteMechanismResult( - model=model, measurements=all_measurements + model=model, measurements=all_measurements, diagnostics=diagnostics ) diff --git a/dpsynth/discrete_mechanisms/independent.py b/dpsynth/discrete_mechanisms/independent.py index 4eceb6c..b98fd21 100644 --- a/dpsynth/discrete_mechanisms/independent.py +++ b/dpsynth/discrete_mechanisms/independent.py @@ -71,29 +71,34 @@ def __call__( # per-query sigma = gdp_sigma * sqrt(d). attributes = len(data.domain) per_query_sigma = self.gdp_sigma * attributes**0.5 + phase_times = {} measurements = initial_measurements or [] existing_cliques = {m.clique for m in measurements} - for attr in data.domain: - clique = (attr,) - if clique in existing_cliques: - continue - marginal = data.project(clique).datavector() - noisy_marginal = ( - marginal + rng.normal(size=marginal.shape) * per_query_sigma - ) - measurements.append(mbi.LinearMeasurement(noisy_marginal, clique)) + with common.timed(phase_times, 'measurement'): + for attr in data.domain: + clique = (attr,) + if clique in existing_cliques: + continue + marginal = data.project(clique).datavector() + noisy_marginal = ( + marginal + rng.normal(size=marginal.shape) * per_query_sigma + ) + measurements.append(mbi.LinearMeasurement(noisy_marginal, clique)) - potentials = initial_potentials - if potentials is not None: - potentials = potentials.expand([m.clique for m in measurements]) + with common.timed(phase_times, 'estimation'): + potentials = initial_potentials + if potentials is not None: + potentials = potentials.expand([m.clique for m in measurements]) - model = mbi.estimation.mirror_descent( - data.domain, - measurements, - iters=self.pgm_iters, - potentials=potentials, - marginal_oracle=marginal_oracle, - ) + model = mbi.estimation.mirror_descent( + data.domain, + measurements, + iters=self.pgm_iters, + potentials=potentials, + marginal_oracle=marginal_oracle, + ) + diagnostics = common.clique_stats(model) + diagnostics.phase_times = phase_times return common.DiscreteMechanismResult( - model=model, measurements=measurements + model=model, measurements=measurements, diagnostics=diagnostics ) diff --git a/dpsynth/discrete_mechanisms/mst.py b/dpsynth/discrete_mechanisms/mst.py index 61920b0..1663ead 100644 --- a/dpsynth/discrete_mechanisms/mst.py +++ b/dpsynth/discrete_mechanisms/mst.py @@ -220,59 +220,62 @@ def __call__( if self.zcdp_rho is None: raise ValueError('Must call calibrate() before using the mechanism.') logging.info('[MST]: Starting MST mechanism.') + phase_times = {} constraints = initial_potentials is not None marginal_oracle = common.default_oracle(self.marginal_oracle, constraints) budget_remaining = self.zcdp_rho - if initial_measurements is None: - budget_remaining -= self.one_way_budget_fraction * self.zcdp_rho - one_way_rho = self.zcdp_rho * self.one_way_budget_fraction - one_way_sigma = accounting.zcdp_gaussian_sigma(one_way_rho) - one_way_measurements = common.measure_marginals_with_noise( - rng, - data, - marginal_queries=[(a,) for a in data.domain], - gdp_sigma=one_way_sigma, - ) - else: - one_way_measurements = initial_measurements + with common.timed(phase_times, 'measurement'): + if initial_measurements is None: + budget_remaining -= self.one_way_budget_fraction * self.zcdp_rho + one_way_rho = self.zcdp_rho * self.one_way_budget_fraction + one_way_sigma = accounting.zcdp_gaussian_sigma(one_way_rho) + one_way_measurements = common.measure_marginals_with_noise( + rng, + data, + marginal_queries=[(a,) for a in data.domain], + gdp_sigma=one_way_sigma, + ) + else: + one_way_measurements = initial_measurements exponential_rho = self.select_budget_fraction * self.zcdp_rho budget_remaining -= exponential_rho # Select and measure 2-way marginals using rho/3 budget for each step. - two_way_marginal_queries = _select_two_way_marginal_queries( - rng, - data, - exponential_rho, - one_way_measurements, - maximum_marginal_size=self.maximum_marginal_size, - ) - logging.info('[MST]: Selected two-way marginal queries.') - gaussian_rho = budget_remaining - sigma = accounting.zcdp_gaussian_sigma(gaussian_rho) - two_way_measurements = common.measure_marginals_with_noise( - rng, data, two_way_marginal_queries, sigma - ) - logging.info('[MST]: Measured two-way marginals.') + with common.timed(phase_times, 'selection'): + two_way_marginal_queries = _select_two_way_marginal_queries( + rng, + data, + exponential_rho, + one_way_measurements, + maximum_marginal_size=self.maximum_marginal_size, + ) + logging.info('[MST]: Selected two-way marginal queries.') + with common.timed(phase_times, 'measurement'): + gaussian_rho = budget_remaining + sigma = accounting.zcdp_gaussian_sigma(gaussian_rho) + two_way_measurements = common.measure_marginals_with_noise( + rng, data, two_way_marginal_queries, sigma + ) + logging.info('[MST]: Measured two-way marginals.') all_measurements = one_way_measurements + two_way_measurements # Fit a distribution to the noisy measurements using Private-PGM. potentials = initial_potentials if potentials is not None: potentials = potentials.expand([m.clique for m in all_measurements]) - model_size = mbi.junction_tree.hypothetical_model_size( - data.domain, [m.clique for m in all_measurements] - ) - logging.info('[MST]: Model size: %d MB', model_size) - model = mbi.estimation.mirror_descent( - data.domain, - all_measurements, - iters=self.pgm_iters, - potentials=potentials, - callback_fn=mbi.callbacks.default(all_measurements), - marginal_oracle=marginal_oracle, - ) - logging.info('[MST]: Fit distribution to the noisy measurements.') + with common.timed(phase_times, 'estimation'): + model = mbi.estimation.mirror_descent( + data.domain, + all_measurements, + iters=self.pgm_iters, + potentials=potentials, + callback_fn=mbi.callbacks.default(all_measurements), + marginal_oracle=marginal_oracle, + ) + logging.info('[MST]: Fit distribution to the noisy measurements.') + diagnostics = common.clique_stats(model) + diagnostics.phase_times = phase_times return common.DiscreteMechanismResult( - model=model, measurements=all_measurements + model=model, measurements=all_measurements, diagnostics=diagnostics ) diff --git a/dpsynth/discrete_mechanisms/swift.py b/dpsynth/discrete_mechanisms/swift.py index c276404..f385ef7 100644 --- a/dpsynth/discrete_mechanisms/swift.py +++ b/dpsynth/discrete_mechanisms/swift.py @@ -114,6 +114,7 @@ def __call__( raise ValueError('Must call calibrate() before using the mechanism.') logging.info('[SWIFT] Starting Mechanism.') + phase_times = {} constraints = initial_potentials is not None marginal_oracle = common.default_oracle(self.marginal_oracle, constraints) @@ -157,48 +158,52 @@ def __call__( ########################################### # Select subset of candidates to measure. # ########################################### - assert 0 < self.select_budget_frac < 1 - l1_error_budget = self.select_budget_frac * gdp_budget - budget_remaining -= l1_error_budget + with common.timed(phase_times, 'selection'): + assert 0 < self.select_budget_frac < 1 + l1_error_budget = self.select_budget_frac * gdp_budget + budget_remaining -= l1_error_budget - errors = _compute_initial_errors( - rng, answers, model, list(candidates), l1_error_budget - ) - logging.info('[SWIFT] Computed initial errors.') + errors = _compute_initial_errors( + rng, answers, model, list(candidates), l1_error_budget + ) + logging.info('[SWIFT] Computed initial errors.') - selected, jtree = select_queries( - errors, candidates, domain, self.max_clique_size, budget_remaining - ) + selected, jtree = select_queries( + errors, candidates, domain, self.max_clique_size, budget_remaining + ) ########################################## # Measure the selected marginal queries. # ########################################## - new_measurements, _ = _measure_selected_marginals( - rng, answers, selected, budget_remaining - ) - measurements.extend(new_measurements) + with common.timed(phase_times, 'measurement'): + new_measurements, _ = _measure_selected_marginals( + rng, answers, selected, budget_remaining + ) + measurements.extend(new_measurements) ######################################################## # Estimate the model using all measurements # ######################################################## + with common.timed(phase_times, 'estimation'): + closed_oracle = functools.partial( + mbi.marginal_oracles.message_passing_stable, jtree=jtree + ) - closed_oracle = functools.partial( - mbi.marginal_oracles.message_passing_stable, jtree=jtree - ) - - callback_fn = mbi.callbacks.default(measurements) - model = mbi.estimation.mirror_descent( - domain, - measurements, - iters=self.pgm_iters, - potentials=potentials, - marginal_oracle=closed_oracle, - callback_fn=callback_fn, - ) - logging.info('[SWIFT] Estimated final model.') + callback_fn = mbi.callbacks.default(measurements) + model = mbi.estimation.mirror_descent( + domain, + measurements, + iters=self.pgm_iters, + potentials=potentials, + marginal_oracle=closed_oracle, + callback_fn=callback_fn, + ) + logging.info('[SWIFT] Estimated final model.') + diagnostics = common.clique_stats(model) + diagnostics.phase_times = phase_times return common.DiscreteMechanismResult( - model=model, measurements=measurements + model=model, measurements=measurements, diagnostics=diagnostics )