From 9e9c9252e8618c1342707f0f94203b7e7131bec2 Mon Sep 17 00:00:00 2001 From: Daisuke Taniwaki Date: Wed, 22 Apr 2026 22:15:49 +0900 Subject: [PATCH 1/4] feat: add --threads option to parallelize report data fetching Add --threads CLI option (default 1) to `edr report` and `edr send-report`. When set to >1, independent dbt run-operations are executed concurrently using ThreadPoolExecutor with SubprocessDbtRunner. dbt's Python API (APIDbtRunner) is not thread-safe due to global mutable state (GLOBAL_FLAGS, adapter FACTORY, etc.), so parallel execution uses SubprocessDbtRunner which spawns independent dbt processes per call. The fetching is split into phases: - Phase 1: 14 independent operations run in parallel - Phase 2: exposures + test_results (depend on Phase 1) - Phase 3: lineage (depends on Phase 2) - Phase 4: pure computation (no dbt calls) With --threads=14, edr report time is expected to drop from ~3m40s to ~30-40s on adapters with high query latency (e.g. Athena). --- elementary/monitor/api/report/report.py | 388 +++++++++++++++++- elementary/monitor/cli.py | 18 + .../report/data_monitoring_report.py | 6 + 3 files changed, 400 insertions(+), 12 deletions(-) diff --git a/elementary/monitor/api/report/report.py b/elementary/monitor/api/report/report.py index 77850333e..5fd9e2a26 100644 --- a/elementary/monitor/api/report/report.py +++ b/elementary/monitor/api/report/report.py @@ -1,7 +1,9 @@ from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from typing import Dict, Iterable, List, Optional, Tuple, Union from elementary.clients.api.api_client import APIClient +from elementary.clients.dbt.subprocess_dbt_runner import SubprocessDbtRunner from elementary.monitor.api.filters.filters import FiltersAPI from elementary.monitor.api.groups.groups import GroupsAPI from elementary.monitor.api.groups.schema import GroupsSchema @@ -38,8 +40,11 @@ from elementary.monitor.api.totals_schema import TotalsSchema from elementary.monitor.data_monitoring.schema import SelectorFilterSchema from elementary.monitor.fetchers.tests.schema import NormalizedTestSchema +from elementary.utils.log import get_logger from elementary.utils.time import get_now_utc_iso_format +logger = get_logger(__name__) + class ReportAPI(APIClient): def _get_groups( @@ -68,6 +73,27 @@ def _get_exposures( ) -> Dict[str, NormalizedExposureSchema]: return models_api.get_exposures(upstream_node_ids=upstream_node_ids) + def _create_subprocess_runner(self) -> SubprocessDbtRunner: + """Create a SubprocessDbtRunner for thread-safe parallel execution. + + dbt's Python API (APIDbtRunner) is not thread-safe due to global + mutable state (GLOBAL_FLAGS, adapter FACTORY, etc.). + SubprocessDbtRunner spawns an independent dbt process per call, + making it safe to use from multiple threads. + """ + runner = self.dbt_runner + return SubprocessDbtRunner( + project_dir=runner.project_dir, + profiles_dir=runner.profiles_dir, + target=runner.target, + raise_on_failure=runner.raise_on_failure, + env_vars=getattr(runner, "env_vars", None), + vars=runner.vars, + secret_vars=runner.secret_vars, + allow_macros_without_package_prefix=runner.allow_macros_without_package_prefix, + run_deps_if_needed=False, + ) + def get_report_data( self, days_back: int = 7, @@ -79,6 +105,44 @@ def get_report_data( filter: SelectorFilterSchema = SelectorFilterSchema(), env: Optional[str] = None, warehouse_type: Optional[str] = None, + threads: int = 1, + ) -> Tuple[ReportDataSchema, Optional[Exception]]: + if threads > 1: + return self._get_report_data_parallel( + days_back=days_back, + test_runs_amount=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + exclude_elementary_models=exclude_elementary_models, + project_name=project_name, + disable_samples=disable_samples, + filter=filter, + env=env, + warehouse_type=warehouse_type, + threads=threads, + ) + return self._get_report_data_sequential( + days_back=days_back, + test_runs_amount=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + exclude_elementary_models=exclude_elementary_models, + project_name=project_name, + disable_samples=disable_samples, + filter=filter, + env=env, + warehouse_type=warehouse_type, + ) + + def _get_report_data_sequential( + self, + days_back: int = 7, + test_runs_amount: int = 720, + disable_passed_test_metrics: bool = False, + exclude_elementary_models: bool = False, + project_name: Optional[str] = None, + disable_samples: bool = False, + filter: SelectorFilterSchema = SelectorFilterSchema(), + env: Optional[str] = None, + warehouse_type: Optional[str] = None, ) -> Tuple[ReportDataSchema, Optional[Exception]]: try: tests_api = TestsAPI( @@ -170,6 +234,220 @@ def get_report_data( snapshots, ) + return self._build_report_data( + days_back=days_back, + project_name=project_name, + env=env, + warehouse_type=warehouse_type, + exclude_elementary_models=exclude_elementary_models, + seeds=seeds, + snapshots=snapshots, + models=models, + sources=sources, + exposures=exposures, + singular_tests=singular_tests, + groups=groups, + models_runs=models_runs, + coverages=coverages, + tests=tests, + test_invocation=test_invocation, + test_results=test_results, + source_freshness_results=source_freshness_results, + test_runs=test_runs, + source_freshness_runs=source_freshness_runs, + lineage=lineage, + filters=filters, + invocations_api=invocations_api, + ) + except Exception as error: + return ReportDataSchema(), error + + def _get_report_data_parallel( + self, + days_back: int = 7, + test_runs_amount: int = 720, + disable_passed_test_metrics: bool = False, + exclude_elementary_models: bool = False, + project_name: Optional[str] = None, + disable_samples: bool = False, + filter: SelectorFilterSchema = SelectorFilterSchema(), + env: Optional[str] = None, + warehouse_type: Optional[str] = None, + threads: int = 4, + ) -> Tuple[ReportDataSchema, Optional[Exception]]: + try: + parallel_runner = self._create_subprocess_runner() + logger.info( + "Fetching report data in parallel with %d threads", threads + ) + + # Phase 1: fetch all independent data in parallel + with ThreadPoolExecutor(max_workers=threads) as pool: + f_seeds = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_seeds + ) + f_snapshots = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_snapshots + ) + f_models = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_models, + exclude_elementary_models, + ) + f_sources = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_sources + ) + f_singular_tests = pool.submit( + TestsAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + ).get_singular_tests + ) + f_models_runs = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_models_runs, + days_back, + exclude_elementary_models, + ) + f_coverages = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_test_coverages + ) + f_tests = pool.submit( + TestsAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + ).get_tests + ) + f_test_invocation = pool.submit( + InvocationsAPI( + dbt_runner=parallel_runner + ).get_test_invocation_from_filter, + filter, + ) + f_freshness_results = pool.submit( + SourceFreshnessesAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + ).get_source_freshness_results + ) + f_test_runs = pool.submit( + TestsAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + ).get_test_runs + ) + f_freshness_runs = pool.submit( + SourceFreshnessesAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + ).get_source_freshness_runs + ) + f_latest_invocation = pool.submit( + InvocationsAPI( + dbt_runner=parallel_runner + ).get_models_latest_invocation + ) + f_invocations_data = pool.submit( + InvocationsAPI( + dbt_runner=parallel_runner + ).get_models_latest_invocations_data + ) + + seeds = f_seeds.result() + snapshots = f_snapshots.result() + models = f_models.result() + sources = f_sources.result() + singular_tests = f_singular_tests.result() + models_runs = f_models_runs.result() + coverages = f_coverages.result() + tests = f_tests.result() + test_invocation = f_test_invocation.result() + source_freshness_results = f_freshness_results.result() + test_runs = f_test_runs.result() + source_freshness_runs = f_freshness_runs.result() + models_latest_invocation = f_latest_invocation.result() + invocations_data = f_invocations_data.result() + + # Phase 2: fetch data that depends on Phase 1 results + lineage_node_ids: List[str] = list( + seeds.keys() + ) + list(snapshots.keys()) + list(models.keys()) + list(sources.keys()) + + with ThreadPoolExecutor(max_workers=threads) as pool: + f_exposures = pool.submit( + ModelsAPI(dbt_runner=parallel_runner).get_exposures, + upstream_node_ids=lineage_node_ids, + ) + f_test_results = pool.submit( + TestsAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, + ).get_test_results, + test_invocation.invocation_id, + disable_samples, + ) + + exposures = f_exposures.result() + test_results = f_test_results.result() + + lineage_node_ids.extend(exposures.keys()) + + # Phase 3: lineage depends on all node IDs + lineage = LineageAPI(dbt_runner=parallel_runner).get_lineage( + lineage_node_ids, exclude_elementary_models + ) + + # Phase 4: pure computation (no dbt calls) + groups = self._get_groups( + models.values(), + sources.values(), + exposures.values(), + seeds.values(), + snapshots.values(), + singular_tests, + ) + + union_test_results = { + x: test_results.get(x, []) + source_freshness_results.get(x, []) + for x in set(test_results).union(source_freshness_results) + } + test_results_totals = get_total_test_results(union_test_results) + + union_test_runs = dict() + for key in set(test_runs).union(source_freshness_runs): + test_run = test_runs.get(key, []) + source_freshness_run = ( + source_freshness_runs.get(key, []) if key is not None else [] + ) + union_test_runs[key] = test_run + source_freshness_run + test_runs_totals = get_total_test_runs(union_test_runs) + + filters = FiltersAPI(dbt_runner=parallel_runner).get_filters( + test_results_totals, + test_runs_totals, + models, + sources, + models_runs.runs, + seeds, + snapshots, + ) + + invocations_job_identification = defaultdict(list) + for invocation in invocations_data: + invocation_key = invocation.job_name or invocation.job_id + if invocation_key is not None: + invocations_job_identification[invocation_key].append( + invocation.invocation_id + ) + serializable_groups = groups.dict() serializable_models = self._serialize_models( models, sources, exposures, seeds, snapshots @@ -190,17 +468,6 @@ def get_report_data( serializable_filters = filters.dict() serializable_lineage = lineage.dict() - models_latest_invocation = invocations_api.get_models_latest_invocation() - invocations = invocations_api.get_models_latest_invocations_data() - - invocations_job_identification = defaultdict(list) - for invocation in invocations: - invocation_key = invocation.job_name or invocation.job_id - if invocation_key is not None: - invocations_job_identification[invocation_key].append( - invocation.invocation_id - ) - report_data = ReportDataSchema( creation_time=get_now_utc_iso_format(), days_back=days_back, @@ -217,7 +484,7 @@ def get_report_data( model_runs_totals=serializable_model_runs_totals, filters=serializable_filters, lineage=serializable_lineage, - invocations=invocations, + invocations=invocations_data, resources_latest_invocation=models_latest_invocation, invocations_job_identification=invocations_job_identification, env=ReportDataEnvSchema( @@ -228,6 +495,103 @@ def get_report_data( except Exception as error: return ReportDataSchema(), error + def _build_report_data( + self, + days_back, + project_name, + env, + warehouse_type, + exclude_elementary_models, + seeds, + snapshots, + models, + sources, + exposures, + singular_tests, + groups, + models_runs, + coverages, + tests, + test_invocation, + test_results, + source_freshness_results, + test_runs, + source_freshness_runs, + lineage, + filters, + invocations_api, + ) -> Tuple[ReportDataSchema, Optional[Exception]]: + union_test_results = { + x: test_results.get(x, []) + source_freshness_results.get(x, []) + for x in set(test_results).union(source_freshness_results) + } + test_results_totals = get_total_test_results(union_test_results) + + union_test_runs = dict() + for key in set(test_runs).union(source_freshness_runs): + test_run = test_runs.get(key, []) + source_freshness_run = ( + source_freshness_runs.get(key, []) if key is not None else [] + ) + union_test_runs[key] = test_run + source_freshness_run + test_runs_totals = get_total_test_runs(union_test_runs) + + serializable_groups = groups.dict() + serializable_models = self._serialize_models( + models, sources, exposures, seeds, snapshots + ) + serializable_model_runs = self._serialize_models_runs(models_runs.runs) + serializable_model_runs_totals = models_runs.dict(include={"totals"})[ + "totals" + ] + serializable_models_coverages = self._serialize_coverages(coverages) + serializable_tests = self._serialize_tests(tests) + serializable_test_results = self._serialize_test_results(union_test_results) + serializable_test_results_totals = self._serialize_totals( + test_results_totals + ) + serializable_test_runs = self._serialize_test_runs(union_test_runs) + serializable_test_runs_totals = self._serialize_totals(test_runs_totals) + serializable_invocation = test_invocation.dict() + serializable_filters = filters.dict() + serializable_lineage = lineage.dict() + + models_latest_invocation = invocations_api.get_models_latest_invocation() + invocations = invocations_api.get_models_latest_invocations_data() + + invocations_job_identification = defaultdict(list) + for invocation in invocations: + invocation_key = invocation.job_name or invocation.job_id + if invocation_key is not None: + invocations_job_identification[invocation_key].append( + invocation.invocation_id + ) + + report_data = ReportDataSchema( + creation_time=get_now_utc_iso_format(), + days_back=days_back, + models=serializable_models, + groups=serializable_groups, + tests=serializable_tests, + invocation=serializable_invocation, + test_results=serializable_test_results, + test_results_totals=serializable_test_results_totals, + test_runs=serializable_test_runs, + test_runs_totals=serializable_test_runs_totals, + coverages=serializable_models_coverages, + model_runs=serializable_model_runs, + model_runs_totals=serializable_model_runs_totals, + filters=serializable_filters, + lineage=serializable_lineage, + invocations=invocations, + resources_latest_invocation=models_latest_invocation, + invocations_job_identification=invocations_job_identification, + env=ReportDataEnvSchema( + project_name=project_name, env=env, warehouse_type=warehouse_type + ), + ) + return report_data, None + def _serialize_models( self, models: Dict[str, NormalizedModelSchema], diff --git a/elementary/monitor/cli.py b/elementary/monitor/cli.py index c86b4b992..825c25b73 100644 --- a/elementary/monitor/cli.py +++ b/elementary/monitor/cli.py @@ -449,6 +449,13 @@ def monitor( default=True, help="Whether to open the report in the browser.", ) +@click.option( + "--threads", + type=int, + default=1, + help="Number of threads for fetching report data in parallel. " + "When set to >1, independent dbt operations run concurrently using subprocess-based runners.", +) @click.pass_context def report( ctx, @@ -464,6 +471,7 @@ def report( file_path, disable_passed_test_metrics, open_browser, + threads, exclude_elementary_models, disable_samples, project_name, @@ -511,6 +519,7 @@ def report( exclude_elementary_models=exclude_elementary_models, should_open_browser=open_browser, project_name=project_name, + threads=threads, ) anonymous_tracking.track_cli_end( Command.REPORT, data_monitoring.properties(), ctx.command.name @@ -660,6 +669,13 @@ def report( default=None, help="Include additional information at the test results summary message.\nCurrently only --include descriptions is supported.", ) +@click.option( + "--threads", + type=int, + default=1, + help="Number of threads for fetching report data in parallel. " + "When set to >1, independent dbt operations run concurrently using subprocess-based runners.", +) @click.pass_context def send_report( ctx, @@ -701,6 +717,7 @@ def send_report( select, disable, include, + threads, target_path, quiet_logs, ssl_ca_bundle, @@ -784,6 +801,7 @@ def send_report( remote_file_path=bucket_file_path, disable_html_attachment=(disable == "html_attachment"), include_description=(include == "description"), + threads=threads, ) anonymous_tracking.track_cli_end( diff --git a/elementary/monitor/data_monitoring/report/data_monitoring_report.py b/elementary/monitor/data_monitoring/report/data_monitoring_report.py index 7493b96e7..005b41e92 100644 --- a/elementary/monitor/data_monitoring/report/data_monitoring_report.py +++ b/elementary/monitor/data_monitoring/report/data_monitoring_report.py @@ -61,6 +61,7 @@ def generate_report( should_open_browser: bool = True, exclude_elementary_models: bool = False, project_name: Optional[str] = None, + threads: int = 1, ) -> Tuple[bool, str]: html_path = self._get_report_file_path(file_path) output_data = self.get_report_data( @@ -69,6 +70,7 @@ def generate_report( disable_passed_test_metrics=disable_passed_test_metrics, exclude_elementary_models=exclude_elementary_models, project_name=project_name, + threads=threads, ) template_html_path = os.path.join(os.path.dirname(__file__), "index.html") @@ -110,6 +112,7 @@ def get_report_data( disable_passed_test_metrics: bool = False, exclude_elementary_models: bool = False, project_name: Optional[str] = None, + threads: int = 1, ): report_api = ReportAPI(self.internal_dbt_runner) report_data, error = report_api.get_report_data( @@ -122,6 +125,7 @@ def get_report_data( filter=self.selector_filter.to_selector_filter_schema(), env=self.config.env, warehouse_type=self.warehouse_info.type if self.warehouse_info else None, + threads=threads, ) self._add_report_tracking(report_data, error) if error: @@ -182,6 +186,7 @@ def send_report( remote_file_path: Optional[str] = None, disable_html_attachment: bool = False, include_description: bool = False, + threads: int = 1, ): # Generate the report generated_report_successfully, local_html_path = self.generate_report( @@ -192,6 +197,7 @@ def send_report( should_open_browser=should_open_browser, exclude_elementary_models=exclude_elementary_models, project_name=project_name, + threads=threads, ) if not generated_report_successfully: From 1221e71fe17a1787b7820da818600ee12c58193b Mon Sep 17 00:00:00 2001 From: Daisuke Taniwaki Date: Wed, 22 Apr 2026 22:32:02 +0900 Subject: [PATCH 2/4] refactor: extract _assemble_report_data and add unit tests - Extract shared serialization logic into _assemble_report_data, used by both sequential and parallel paths - Reduce API object construction duplication in parallel path with factory helpers - Add unit tests for subprocess runner creation, routing logic, thread pool usage, and error propagation --- elementary/monitor/api/report/report.py | 343 ++++++------------ tests/unit/monitor/api/report/__init__.py | 0 .../api/report/test_report_parallel.py | 121 ++++++ 3 files changed, 228 insertions(+), 236 deletions(-) create mode 100644 tests/unit/monitor/api/report/__init__.py create mode 100644 tests/unit/monitor/api/report/test_report_parallel.py diff --git a/elementary/monitor/api/report/report.py b/elementary/monitor/api/report/report.py index 5fd9e2a26..802679fd5 100644 --- a/elementary/monitor/api/report/report.py +++ b/elementary/monitor/api/report/report.py @@ -176,15 +176,6 @@ def _get_report_data_sequential( lineage_node_ids.extend(exposures.keys()) singular_tests = tests_api.get_singular_tests() - groups = self._get_groups( - models.values(), - sources.values(), - exposures.values(), - seeds.values(), - snapshots.values(), - singular_tests, - ) - models_runs = models_api.get_models_runs( days_back=days_back, exclude_elementary_models=exclude_elementary_models ) @@ -200,53 +191,24 @@ def _get_report_data_sequential( source_freshness_results = ( source_freshnesses_api.get_source_freshness_results() ) - - union_test_results = { - x: test_results.get(x, []) + source_freshness_results.get(x, []) - for x in set(test_results).union(source_freshness_results) - } - - test_results_totals = get_total_test_results(union_test_results) - test_runs = tests_api.get_test_runs() source_freshness_runs = source_freshnesses_api.get_source_freshness_runs() - union_test_runs = dict() - for key in set(test_runs).union(source_freshness_runs): - test_run = test_runs.get(key, []) - source_freshness_run = ( - source_freshness_runs.get(key, []) if key is not None else [] - ) - union_test_runs[key] = test_run + source_freshness_run - - test_runs_totals = get_total_test_runs(union_test_runs) - lineage = lineage_api.get_lineage( lineage_node_ids, exclude_elementary_models ) - filters = filters_api.get_filters( - test_results_totals, - test_runs_totals, - models, - sources, - models_runs.runs, - seeds, - snapshots, - ) - return self._build_report_data( + return self._assemble_report_data( days_back=days_back, project_name=project_name, env=env, warehouse_type=warehouse_type, - exclude_elementary_models=exclude_elementary_models, seeds=seeds, snapshots=snapshots, models=models, sources=sources, exposures=exposures, singular_tests=singular_tests, - groups=groups, models_runs=models_runs, coverages=coverages, tests=tests, @@ -256,8 +218,9 @@ def _get_report_data_sequential( test_runs=test_runs, source_freshness_runs=source_freshness_runs, lineage=lineage, - filters=filters, - invocations_api=invocations_api, + filters_api=filters_api, + models_latest_invocation=invocations_api.get_models_latest_invocation(), + invocations_data=invocations_api.get_models_latest_invocations_data(), ) except Exception as error: return ReportDataSchema(), error @@ -281,82 +244,58 @@ def _get_report_data_parallel( "Fetching report data in parallel with %d threads", threads ) - # Phase 1: fetch all independent data in parallel - with ThreadPoolExecutor(max_workers=threads) as pool: - f_seeds = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_seeds + def _new_models_api() -> ModelsAPI: + return ModelsAPI(dbt_runner=parallel_runner) + + def _new_tests_api() -> TestsAPI: + return TestsAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, + disable_passed_test_metrics=disable_passed_test_metrics, ) - f_snapshots = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_snapshots + + def _new_freshness_api() -> SourceFreshnessesAPI: + return SourceFreshnessesAPI( + dbt_runner=parallel_runner, + days_back=days_back, + invocations_per_test=test_runs_amount, ) + + def _new_invocations_api() -> InvocationsAPI: + return InvocationsAPI(dbt_runner=parallel_runner) + + # Phase 1: fetch all independent data in parallel + with ThreadPoolExecutor(max_workers=threads) as pool: + f_seeds = pool.submit(_new_models_api().get_seeds) + f_snapshots = pool.submit(_new_models_api().get_snapshots) f_models = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_models, - exclude_elementary_models, - ) - f_sources = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_sources - ) - f_singular_tests = pool.submit( - TestsAPI( - dbt_runner=parallel_runner, - days_back=days_back, - invocations_per_test=test_runs_amount, - disable_passed_test_metrics=disable_passed_test_metrics, - ).get_singular_tests + _new_models_api().get_models, exclude_elementary_models ) + f_sources = pool.submit(_new_models_api().get_sources) + f_singular_tests = pool.submit(_new_tests_api().get_singular_tests) f_models_runs = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_models_runs, + _new_models_api().get_models_runs, days_back, exclude_elementary_models, ) - f_coverages = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_test_coverages - ) - f_tests = pool.submit( - TestsAPI( - dbt_runner=parallel_runner, - days_back=days_back, - invocations_per_test=test_runs_amount, - disable_passed_test_metrics=disable_passed_test_metrics, - ).get_tests - ) + f_coverages = pool.submit(_new_models_api().get_test_coverages) + f_tests = pool.submit(_new_tests_api().get_tests) f_test_invocation = pool.submit( - InvocationsAPI( - dbt_runner=parallel_runner - ).get_test_invocation_from_filter, - filter, + _new_invocations_api().get_test_invocation_from_filter, filter ) f_freshness_results = pool.submit( - SourceFreshnessesAPI( - dbt_runner=parallel_runner, - days_back=days_back, - invocations_per_test=test_runs_amount, - ).get_source_freshness_results - ) - f_test_runs = pool.submit( - TestsAPI( - dbt_runner=parallel_runner, - days_back=days_back, - invocations_per_test=test_runs_amount, - disable_passed_test_metrics=disable_passed_test_metrics, - ).get_test_runs + _new_freshness_api().get_source_freshness_results ) + f_test_runs = pool.submit(_new_tests_api().get_test_runs) f_freshness_runs = pool.submit( - SourceFreshnessesAPI( - dbt_runner=parallel_runner, - days_back=days_back, - invocations_per_test=test_runs_amount, - ).get_source_freshness_runs + _new_freshness_api().get_source_freshness_runs ) f_latest_invocation = pool.submit( - InvocationsAPI( - dbt_runner=parallel_runner - ).get_models_latest_invocation + _new_invocations_api().get_models_latest_invocation ) f_invocations_data = pool.submit( - InvocationsAPI( - dbt_runner=parallel_runner - ).get_models_latest_invocations_data + _new_invocations_api().get_models_latest_invocations_data ) seeds = f_seeds.result() @@ -375,29 +314,26 @@ def _get_report_data_parallel( invocations_data = f_invocations_data.result() # Phase 2: fetch data that depends on Phase 1 results - lineage_node_ids: List[str] = list( - seeds.keys() - ) + list(snapshots.keys()) + list(models.keys()) + list(sources.keys()) + lineage_node_ids: List[str] = ( + list(seeds.keys()) + + list(snapshots.keys()) + + list(models.keys()) + + list(sources.keys()) + ) with ThreadPoolExecutor(max_workers=threads) as pool: f_exposures = pool.submit( - ModelsAPI(dbt_runner=parallel_runner).get_exposures, + _new_models_api().get_exposures, upstream_node_ids=lineage_node_ids, ) f_test_results = pool.submit( - TestsAPI( - dbt_runner=parallel_runner, - days_back=days_back, - invocations_per_test=test_runs_amount, - disable_passed_test_metrics=disable_passed_test_metrics, - ).get_test_results, + _new_tests_api().get_test_results, test_invocation.invocation_id, disable_samples, ) exposures = f_exposures.result() test_results = f_test_results.result() - lineage_node_ids.extend(exposures.keys()) # Phase 3: lineage depends on all node IDs @@ -406,109 +342,45 @@ def _get_report_data_parallel( ) # Phase 4: pure computation (no dbt calls) - groups = self._get_groups( - models.values(), - sources.values(), - exposures.values(), - seeds.values(), - snapshots.values(), - singular_tests, - ) - - union_test_results = { - x: test_results.get(x, []) + source_freshness_results.get(x, []) - for x in set(test_results).union(source_freshness_results) - } - test_results_totals = get_total_test_results(union_test_results) - - union_test_runs = dict() - for key in set(test_runs).union(source_freshness_runs): - test_run = test_runs.get(key, []) - source_freshness_run = ( - source_freshness_runs.get(key, []) if key is not None else [] - ) - union_test_runs[key] = test_run + source_freshness_run - test_runs_totals = get_total_test_runs(union_test_runs) - - filters = FiltersAPI(dbt_runner=parallel_runner).get_filters( - test_results_totals, - test_runs_totals, - models, - sources, - models_runs.runs, - seeds, - snapshots, - ) - - invocations_job_identification = defaultdict(list) - for invocation in invocations_data: - invocation_key = invocation.job_name or invocation.job_id - if invocation_key is not None: - invocations_job_identification[invocation_key].append( - invocation.invocation_id - ) - - serializable_groups = groups.dict() - serializable_models = self._serialize_models( - models, sources, exposures, seeds, snapshots - ) - serializable_model_runs = self._serialize_models_runs(models_runs.runs) - serializable_model_runs_totals = models_runs.dict(include={"totals"})[ - "totals" - ] - serializable_models_coverages = self._serialize_coverages(coverages) - serializable_tests = self._serialize_tests(tests) - serializable_test_results = self._serialize_test_results(union_test_results) - serializable_test_results_totals = self._serialize_totals( - test_results_totals - ) - serializable_test_runs = self._serialize_test_runs(union_test_runs) - serializable_test_runs_totals = self._serialize_totals(test_runs_totals) - serializable_invocation = test_invocation.dict() - serializable_filters = filters.dict() - serializable_lineage = lineage.dict() - - report_data = ReportDataSchema( - creation_time=get_now_utc_iso_format(), + return self._assemble_report_data( days_back=days_back, - models=serializable_models, - groups=serializable_groups, - tests=serializable_tests, - invocation=serializable_invocation, - test_results=serializable_test_results, - test_results_totals=serializable_test_results_totals, - test_runs=serializable_test_runs, - test_runs_totals=serializable_test_runs_totals, - coverages=serializable_models_coverages, - model_runs=serializable_model_runs, - model_runs_totals=serializable_model_runs_totals, - filters=serializable_filters, - lineage=serializable_lineage, - invocations=invocations_data, - resources_latest_invocation=models_latest_invocation, - invocations_job_identification=invocations_job_identification, - env=ReportDataEnvSchema( - project_name=project_name, env=env, warehouse_type=warehouse_type - ), + project_name=project_name, + env=env, + warehouse_type=warehouse_type, + seeds=seeds, + snapshots=snapshots, + models=models, + sources=sources, + exposures=exposures, + singular_tests=singular_tests, + models_runs=models_runs, + coverages=coverages, + tests=tests, + test_invocation=test_invocation, + test_results=test_results, + source_freshness_results=source_freshness_results, + test_runs=test_runs, + source_freshness_runs=source_freshness_runs, + lineage=lineage, + filters_api=FiltersAPI(dbt_runner=parallel_runner), + models_latest_invocation=models_latest_invocation, + invocations_data=invocations_data, ) - return report_data, None except Exception as error: return ReportDataSchema(), error - def _build_report_data( + def _assemble_report_data( self, days_back, project_name, env, warehouse_type, - exclude_elementary_models, seeds, snapshots, models, sources, exposures, singular_tests, - groups, models_runs, coverages, tests, @@ -518,9 +390,19 @@ def _build_report_data( test_runs, source_freshness_runs, lineage, - filters, - invocations_api, + filters_api, + models_latest_invocation, + invocations_data, ) -> Tuple[ReportDataSchema, Optional[Exception]]: + groups = self._get_groups( + models.values(), + sources.values(), + exposures.values(), + seeds.values(), + snapshots.values(), + singular_tests, + ) + union_test_results = { x: test_results.get(x, []) + source_freshness_results.get(x, []) for x in set(test_results).union(source_freshness_results) @@ -536,31 +418,18 @@ def _build_report_data( union_test_runs[key] = test_run + source_freshness_run test_runs_totals = get_total_test_runs(union_test_runs) - serializable_groups = groups.dict() - serializable_models = self._serialize_models( - models, sources, exposures, seeds, snapshots - ) - serializable_model_runs = self._serialize_models_runs(models_runs.runs) - serializable_model_runs_totals = models_runs.dict(include={"totals"})[ - "totals" - ] - serializable_models_coverages = self._serialize_coverages(coverages) - serializable_tests = self._serialize_tests(tests) - serializable_test_results = self._serialize_test_results(union_test_results) - serializable_test_results_totals = self._serialize_totals( - test_results_totals + filters = filters_api.get_filters( + test_results_totals, + test_runs_totals, + models, + sources, + models_runs.runs, + seeds, + snapshots, ) - serializable_test_runs = self._serialize_test_runs(union_test_runs) - serializable_test_runs_totals = self._serialize_totals(test_runs_totals) - serializable_invocation = test_invocation.dict() - serializable_filters = filters.dict() - serializable_lineage = lineage.dict() - - models_latest_invocation = invocations_api.get_models_latest_invocation() - invocations = invocations_api.get_models_latest_invocations_data() invocations_job_identification = defaultdict(list) - for invocation in invocations: + for invocation in invocations_data: invocation_key = invocation.job_name or invocation.job_id if invocation_key is not None: invocations_job_identification[invocation_key].append( @@ -570,20 +439,22 @@ def _build_report_data( report_data = ReportDataSchema( creation_time=get_now_utc_iso_format(), days_back=days_back, - models=serializable_models, - groups=serializable_groups, - tests=serializable_tests, - invocation=serializable_invocation, - test_results=serializable_test_results, - test_results_totals=serializable_test_results_totals, - test_runs=serializable_test_runs, - test_runs_totals=serializable_test_runs_totals, - coverages=serializable_models_coverages, - model_runs=serializable_model_runs, - model_runs_totals=serializable_model_runs_totals, - filters=serializable_filters, - lineage=serializable_lineage, - invocations=invocations, + models=self._serialize_models( + models, sources, exposures, seeds, snapshots + ), + groups=groups.dict(), + tests=self._serialize_tests(tests), + invocation=test_invocation.dict(), + test_results=self._serialize_test_results(union_test_results), + test_results_totals=self._serialize_totals(test_results_totals), + test_runs=self._serialize_test_runs(union_test_runs), + test_runs_totals=self._serialize_totals(test_runs_totals), + coverages=self._serialize_coverages(coverages), + model_runs=self._serialize_models_runs(models_runs.runs), + model_runs_totals=models_runs.dict(include={"totals"})["totals"], + filters=filters.dict(), + lineage=lineage.dict(), + invocations=invocations_data, resources_latest_invocation=models_latest_invocation, invocations_job_identification=invocations_job_identification, env=ReportDataEnvSchema( diff --git a/tests/unit/monitor/api/report/__init__.py b/tests/unit/monitor/api/report/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/monitor/api/report/test_report_parallel.py b/tests/unit/monitor/api/report/test_report_parallel.py new file mode 100644 index 000000000..80fa862fa --- /dev/null +++ b/tests/unit/monitor/api/report/test_report_parallel.py @@ -0,0 +1,121 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from elementary.monitor.api.report.report import ReportAPI + + +@pytest.fixture +def mock_dbt_runner(): + runner = MagicMock() + runner.project_dir = "/tmp/project" + runner.profiles_dir = "/tmp/profiles" + runner.target = "dev" + runner.raise_on_failure = True + runner.env_vars = {"KEY": "value"} + runner.vars = {} + runner.secret_vars = {} + runner.allow_macros_without_package_prefix = False + return runner + + +class TestCreateSubprocessRunner: + def test_creates_runner_with_correct_config(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch( + "elementary.monitor.api.report.report.SubprocessDbtRunner" + ) as mock_cls: + api._create_subprocess_runner() + mock_cls.assert_called_once_with( + project_dir="/tmp/project", + profiles_dir="/tmp/profiles", + target="dev", + raise_on_failure=True, + env_vars={"KEY": "value"}, + vars={}, + secret_vars={}, + allow_macros_without_package_prefix=False, + run_deps_if_needed=False, + ) + + def test_deps_not_run(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch( + "elementary.monitor.api.report.report.SubprocessDbtRunner" + ) as mock_cls: + api._create_subprocess_runner() + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["run_deps_if_needed"] is False + + +class TestGetReportDataRouting: + def test_threads_1_uses_sequential(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch.object(api, "_get_report_data_sequential") as mock_seq: + mock_seq.return_value = (MagicMock(), None) + api.get_report_data(threads=1) + mock_seq.assert_called_once() + + def test_threads_gt1_uses_parallel(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch.object(api, "_get_report_data_parallel") as mock_par: + mock_par.return_value = (MagicMock(), None) + api.get_report_data(threads=4) + mock_par.assert_called_once() + + def test_threads_passed_to_parallel(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with patch.object(api, "_get_report_data_parallel") as mock_par: + mock_par.return_value = (MagicMock(), None) + api.get_report_data(threads=8) + call_kwargs = mock_par.call_args[1] + assert call_kwargs["threads"] == 8 + + +class TestGetReportDataParallel: + def test_uses_thread_pool_executor(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + with ( + patch.object(api, "_create_subprocess_runner") as mock_create, + patch( + "elementary.monitor.api.report.report.ThreadPoolExecutor" + ) as mock_pool_cls, + patch( + "elementary.monitor.api.report.report.ModelsAPI" + ), + patch( + "elementary.monitor.api.report.report.TestsAPI" + ), + patch( + "elementary.monitor.api.report.report.SourceFreshnessesAPI" + ), + patch( + "elementary.monitor.api.report.report.InvocationsAPI" + ), + patch( + "elementary.monitor.api.report.report.LineageAPI" + ), + patch( + "elementary.monitor.api.report.report.FiltersAPI" + ), + patch.object(api, "_assemble_report_data") as mock_assemble, + ): + mock_create.return_value = MagicMock() + mock_pool = MagicMock() + mock_pool_cls.return_value.__enter__ = MagicMock(return_value=mock_pool) + mock_pool_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_pool.submit.return_value.result.return_value = {} + mock_assemble.return_value = (MagicMock(), None) + + api._get_report_data_parallel(threads=4) + + mock_pool_cls.assert_called_with(max_workers=4) + + def test_error_propagation(self, mock_dbt_runner): + api = ReportAPI(mock_dbt_runner) + error = RuntimeError("test error") + with patch.object( + api, "_create_subprocess_runner", side_effect=error + ): + result, err = api._get_report_data_parallel(threads=4) + assert err is error From 3ed15b9d05c2d1a7b4d1deb7c63acabab9b4a77b Mon Sep 17 00:00:00 2001 From: Daisuke Taniwaki Date: Wed, 22 Apr 2026 22:33:22 +0900 Subject: [PATCH 3/4] fix: use click.IntRange for --threads to reject invalid values --- elementary/monitor/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elementary/monitor/cli.py b/elementary/monitor/cli.py index 825c25b73..dbbdd4416 100644 --- a/elementary/monitor/cli.py +++ b/elementary/monitor/cli.py @@ -451,7 +451,7 @@ def monitor( ) @click.option( "--threads", - type=int, + type=click.IntRange(min=1), default=1, help="Number of threads for fetching report data in parallel. " "When set to >1, independent dbt operations run concurrently using subprocess-based runners.", @@ -671,7 +671,7 @@ def report( ) @click.option( "--threads", - type=int, + type=click.IntRange(min=1), default=1, help="Number of threads for fetching report data in parallel. " "When set to >1, independent dbt operations run concurrently using subprocess-based runners.", From 6048e939217610c898c27d24849085b383f2dcae Mon Sep 17 00:00:00 2001 From: Daisuke Taniwaki Date: Wed, 22 Apr 2026 22:41:44 +0900 Subject: [PATCH 4/4] fix: strengthen parallel test assertions and fix unused variable --- tests/unit/monitor/api/report/test_report_parallel.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/unit/monitor/api/report/test_report_parallel.py b/tests/unit/monitor/api/report/test_report_parallel.py index 80fa862fa..9f5ecec8b 100644 --- a/tests/unit/monitor/api/report/test_report_parallel.py +++ b/tests/unit/monitor/api/report/test_report_parallel.py @@ -104,12 +104,16 @@ def test_uses_thread_pool_executor(self, mock_dbt_runner): mock_pool = MagicMock() mock_pool_cls.return_value.__enter__ = MagicMock(return_value=mock_pool) mock_pool_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_pool.submit.return_value.result.return_value = {} + mock_pool.submit.return_value.result.return_value = MagicMock( + invocation_id="inv-1" + ) mock_assemble.return_value = (MagicMock(), None) - api._get_report_data_parallel(threads=4) + _, err = api._get_report_data_parallel(threads=4) mock_pool_cls.assert_called_with(max_workers=4) + assert err is None + mock_assemble.assert_called_once() def test_error_propagation(self, mock_dbt_runner): api = ReportAPI(mock_dbt_runner) @@ -117,5 +121,5 @@ def test_error_propagation(self, mock_dbt_runner): with patch.object( api, "_create_subprocess_runner", side_effect=error ): - result, err = api._get_report_data_parallel(threads=4) + _, err = api._get_report_data_parallel(threads=4) assert err is error