From 4e64c14d353e9067f4e2c03c27cacf081c50978c Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Fri, 23 May 2025 15:03:36 -0700 Subject: [PATCH 1/4] feat (modularize): First steps to modularizing Pan3D components Addressing issue https://github.com/Kitware/pan3d/issues/186 -- modulerized Pan3D color picker and scalar bar --- src/pan3d/explorers/analytics.py | 62 ++----- src/pan3d/explorers/contour.py | 44 +++-- src/pan3d/explorers/globe.py | 76 ++------ src/pan3d/explorers/slicer.py | 59 ++----- src/pan3d/ui/contour.py | 43 ++--- src/pan3d/ui/globe.py | 97 +++++------ src/pan3d/ui/preview.py | 43 ++--- src/pan3d/ui/slicer.py | 50 +++--- src/pan3d/ui/vtk_view.py | 47 ----- src/pan3d/utils/common.py | 143 +-------------- src/pan3d/viewers/preview.py | 60 ++----- src/pan3d/widgets/__init__.py | 0 src/pan3d/widgets/color.py | 291 +++++++++++++++++++++++++++++++ src/trame/widgets/pan3d.py | 3 + 14 files changed, 497 insertions(+), 521 deletions(-) create mode 100644 src/pan3d/widgets/__init__.py create mode 100644 src/pan3d/widgets/color.py create mode 100644 src/trame/widgets/pan3d.py diff --git a/src/pan3d/explorers/analytics.py b/src/pan3d/explorers/analytics.py index 2f2f8489..a880e7bc 100644 --- a/src/pan3d/explorers/analytics.py +++ b/src/pan3d/explorers/analytics.py @@ -1,5 +1,4 @@ import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 -from vtkmodules.vtkCommonCore import vtkLookupTable from vtkmodules.vtkFiltersGeometry import vtkDataSetSurfaceFilter # VTK factory initialization @@ -16,10 +15,10 @@ from pan3d.ui.analytics import Plotting from pan3d.ui.preview import RenderingSettings -from pan3d.ui.vtk_view import Pan3DScalarBar, Pan3DView +from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.utils.convert import to_float, to_image -from pan3d.utils.presets import set_preset +from pan3d.utils.convert import to_float +from pan3d.widgets.color import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import html @@ -64,15 +63,12 @@ def _setup_vtk(self): self.interactor.SetRenderWindow(self.render_window) self.interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera() - self.lut = vtkLookupTable() - # Need explicit geometry extraction when used with WASM self.geometry = vtkDataSetSurfaceFilter( input_connection=self.source.output_port ) self.mapper = vtkPolyDataMapper( input_connection=self.geometry.output_port, - lookup_table=self.lut, ) self.actor = vtkActor(mapper=self.mapper, visibility=0) @@ -128,6 +124,7 @@ def _build_ui(self, **kwargs): label="File path to save", v_model=("save_dataset_path", ""), hide_details=True, + change=self.property_changed, ) with v3.VCardActions(): v3.VSpacer() @@ -171,7 +168,7 @@ def _build_ui(self, **kwargs): ) # Scalar bar - Pan3DScalarBar( + ScalarBar( v_show="!control_expended", v_if="color_by", img_src="preset_img", @@ -194,7 +191,8 @@ def _build_ui(self, **kwargs): panel_label="Analytics Explorer", ).ui_content: self.ctrl.source_update_rendering_panel = RenderingSettings( - self.source, + self.retrieve_source, + self.retrieve_mapper, self.update_rendering, ).update_from_source @@ -210,50 +208,22 @@ def _build_ui(self, **kwargs): source=self.source, toggle="chart_expanded" ) + def retrieve_mapper(self): + """Used as a callback to retrieve the mapper.""" + return self.mapper + + def retrieve_source(self): + """Used as a callback to retrieve the source.""" + return self.source + # ----------------------------------------------------- # State change callbacks # ----------------------------------------------------- @change("color_by") - def _on_color_by(self, color_by, **__): - if self.source.input is None: - return - - ds = self.source() - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - self.state.color_min = min_value - self.state.color_max = max_value - - self.mapper.SelectColorArray(color_by) - self.mapper.SetScalarModeToUsePointFieldData() - self.mapper.InterpolateScalarsBeforeMappingOn() - self.mapper.SetScalarVisibility(1) - else: - self.mapper.SetScalarVisibility(0) - self.state.color_min = 0 - self.state.color_max = 1 - + def _on_color_by(self, **__): self.plotting.update_plot() - @change("color_preset", "color_min", "color_max", "nan_color") - def _on_color_preset( - self, nan_color, nan_colors, color_preset, color_min, color_max, **_ - ): - color_min = float(color_min) - color_max = float(color_max) - self.mapper.SetScalarRange(color_min, color_max) - - color = nan_colors[nan_color] - self.lut.SetNanColor(color) - - set_preset(self.lut, color_preset) - self.state.preset_img = to_image(self.lut, 255) - - self.ctrl.view_update() - @change("scale_x", "scale_y", "scale_z") def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.actor.SetScale( diff --git a/src/pan3d/explorers/contour.py b/src/pan3d/explorers/contour.py index 2d53162a..51702201 100644 --- a/src/pan3d/explorers/contour.py +++ b/src/pan3d/explorers/contour.py @@ -1,5 +1,4 @@ import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 -from vtkmodules.vtkCommonCore import vtkLookupTable from vtkmodules.vtkCommonDataModel import vtkDataObject, vtkDataSetAttributes from vtkmodules.vtkFiltersCore import ( vtkAssignAttribute, @@ -25,10 +24,11 @@ ) from pan3d.ui.contour import ContourRenderingSettings -from pan3d.ui.vtk_view import Pan3DScalarBar, Pan3DView +from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.utils.convert import to_float, to_image -from pan3d.utils.presets import set_preset +from pan3d.utils.convert import to_float +from pan3d.widgets.color import ScalarBar +from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import vuetify3 as v3 @@ -37,6 +37,12 @@ class ContourExplorer(Explorer): def __init__(self, xarray=None, source=None, server=None, local_rendering=None): super().__init__(xarray, source, server, local_rendering) + + if self.source is None: + self.source = vtkXArrayRectilinearSource( + input=self.xarray + ) # To initialize the pipeline + # setup self.last_field = None self.last_preset = None @@ -51,8 +57,6 @@ def __init__(self, xarray=None, source=None, server=None, local_rendering=None): def _setup_vtk(self): ds = self.source() - self.lut = vtkLookupTable() - self.renderer = vtkRenderer(background=(0.8, 0.8, 0.8)) self.interactor = vtkRenderWindowInteractor() self.render_window = vtkRenderWindow(off_screen_rendering=1) @@ -86,7 +90,6 @@ def _setup_vtk(self): input_connection=self.bands.output_port, scalar_visibility=1, interpolate_scalars_before_mapping=1, - lookup_table=self.lut, ) self.mapper.SetScalarModeToUsePointFieldData() self.actor = vtkActor(mapper=self.mapper) @@ -141,7 +144,7 @@ def _build_ui(self, **_): ) # Scalar bar - Pan3DScalarBar( + ScalarBar( v_show="!control_expended", v_if="color_by", img_src="preset_img", @@ -205,7 +208,8 @@ def _build_ui(self, **_): panel_label="Contour Explorer", ).ui_content: self.ctrl.source_update_rendering_panel = ContourRenderingSettings( - self.source, + self.retrieve_source, + self.retrieve_mapper, self.update_rendering, ).update_from_source @@ -217,14 +221,13 @@ def update_rendering(self, reset_camera=False): self.ctrl.view_reset_camera() - def reset_color_range(self): - if self.state.color_by is None: - return + def retrieve_mapper(self): + """Used as a callback to retrieve the mapper.""" + return self.mapper - field_array = self.source.input[self.state.color_by].values - with self.state: - self.state.color_min = float(field_array.min()) - self.state.color_max = float(field_array.max()) + def retrieve_source(self): + """Used as a callback to retrieve the source.""" + return self.source # ----------------------------------------------------- # State change callbacks @@ -266,20 +269,13 @@ def _on_update_data(self, color_by, time_idx, **_): # update range if self.last_field != color_by: self.last_field = color_by - self.reset_color_range() self.ctrl.view_update() - @change("color_min", "color_max", "color_preset", "nan_color", "nb_contours") + @change("nb_contours") def _on_update_color_range( self, nb_contours, color_min, color_max, color_preset, **_ ): - if self.last_preset != color_preset: - self.last_preset = color_preset - set_preset(self.lut, color_preset) - self.state.preset_img = to_image(self.lut, 255) - - self.mapper.SetScalarRange(color_min, color_max) self.bands.GenerateValues(nb_contours, [color_min, color_max]) self.ctrl.view_update() diff --git a/src/pan3d/explorers/globe.py b/src/pan3d/explorers/globe.py index d79a0076..ed593cc1 100644 --- a/src/pan3d/explorers/globe.py +++ b/src/pan3d/explorers/globe.py @@ -3,7 +3,7 @@ from pathlib import Path import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 -from vtkmodules.vtkCommonCore import vtkLookupTable, vtkObject +from vtkmodules.vtkCommonCore import vtkObject from vtkmodules.vtkFiltersGeometry import vtkDataSetSurfaceFilter # VTK factory initialization @@ -23,11 +23,12 @@ from pan3d.filters.globe import ProjectToSphere from pan3d.ui.globe import GlobeRenderingSettings -from pan3d.ui.vtk_view import Pan3DScalarBar, Pan3DView +from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.utils.convert import to_image, update_camera +from pan3d.utils.convert import update_camera from pan3d.utils.globe import get_continent_outlines, get_globe, get_globe_textures -from pan3d.utils.presets import set_preset +from pan3d.widgets.color import ScalarBar +from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.app import asynchronous from trame.decorators import change from trame.ui.vuetify3 import VAppLayout @@ -48,7 +49,8 @@ class GlobeExplorer(Explorer): def __init__(self, xarray=None, source=None, server=None, local_rendering=None): super().__init__(xarray, source, server, local_rendering) - + if self.source is None: + self.source = vtkXArrayRectilinearSource() # To initialize the pipeline self.textures = get_globe_textures() self.state.textures = list(self.textures.keys()) @@ -60,8 +62,6 @@ def __init__(self, xarray=None, source=None, server=None, local_rendering=None): # ------------------------------------------------------------------------- def _setup_vtk(self): - self.lut = vtkLookupTable() - self.renderer = vtkRenderer(background=(0.8, 0.8, 0.8)) self.interactor = vtkRenderWindowInteractor() self.render_window = vtkRenderWindow(off_screen_rendering=1) @@ -87,9 +87,7 @@ def _setup_vtk(self): input_connection=self.dglobe.output_port ) - self.mapper = vtkPolyDataMapper( - input_connection=self.geometry.output_port, lookup_table=self.lut - ) + self.mapper = vtkPolyDataMapper(input_connection=self.geometry.output_port) self.actor = vtkActor(mapper=self.mapper, visibility=0) # Camera @@ -140,7 +138,7 @@ def _build_ui(self, **kwargs): ) # Scalar bar - Pan3DScalarBar( + ScalarBar( v_show="!control_expended", v_if="color_by", img_src="preset_img", @@ -203,59 +201,23 @@ def _build_ui(self, **kwargs): panel_label="Globe Explorer", ).ui_content: self.ctrl.source_update_rendering_panel = GlobeRenderingSettings( - self.source, + self.retrieve_source, + self.retrieve_mapper, self.update_rendering, ).update_from_source + def retrieve_mapper(self): + """Used as a callback to retrieve the mapper.""" + return self.mapper + + def retrieve_source(self): + """Used as a callback to retrieve the source.""" + return self.source + # ----------------------------------------------------- # State change callbacks # ----------------------------------------------------- - @change("color_by") - def _on_color_by(self, color_by, **__): - if self.source.input is None: - return - - ds = self.source() - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - self.state.color_min = min_value - self.state.color_max = max_value - - self.mapper.SelectColorArray(color_by) - self.mapper.SetScalarModeToUsePointFieldData() - self.mapper.InterpolateScalarsBeforeMappingOn() - self.mapper.SetScalarVisibility(1) - else: - self.mapper.SetScalarVisibility(0) - self.state.color_min = 0 - self.state.color_max = 1 - - @change("color_preset", "color_min", "color_max", "nan_color") - def _on_color_preset( - self, - nan_color, - nan_colors, - color_preset, - color_min, - color_max, - opacity, - **_, - ): - color_min = float(color_min) - color_max = float(color_max) - self.mapper.SetScalarRange(color_min, color_max) - - set_preset(self.lut, color_preset) - self.state.preset_img = to_image(self.lut, 255) - - color = nan_colors[nan_color] - self.lut.SetNanColor(color) - - self.ctrl.view_update() - @change("opacity", "representation", "cell_size", "render_shadow") def _on_change_opacity( self, representation, opacity, cell_size, render_shadow, **_ diff --git a/src/pan3d/explorers/slicer.py b/src/pan3d/explorers/slicer.py index f6fe75e4..1c9ca406 100644 --- a/src/pan3d/explorers/slicer.py +++ b/src/pan3d/explorers/slicer.py @@ -1,5 +1,4 @@ import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 -from vtkmodules.vtkCommonCore import vtkLookupTable from vtkmodules.vtkCommonDataModel import ( vtkPlane, ) @@ -26,10 +25,10 @@ ) from pan3d.ui.slicer import SliceRenderingSettings -from pan3d.ui.vtk_view import Pan3DScalarBar, Pan3DView +from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.utils.convert import to_image -from pan3d.utils.presets import set_preset +from pan3d.widgets.color import ScalarBar +from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import html @@ -145,6 +144,8 @@ class SliceExplorer(Explorer): def __init__(self, xarray=None, source=None, server=None, local_rendering=None): super().__init__(xarray, source, server, local_rendering) + if self.source is None: + self.source = vtkXArrayRectilinearSource() # To initialize the pipeline self._setup_vtk() self._build_ui() @@ -159,9 +160,6 @@ def _setup_vtk(self): 0.5 * (bounds[4] + bounds[5]), ] - # Create lookup table - self.lut = vtkLookupTable() - # Build rendering pipeline self.renderer = vtkRenderer() self.interactor = vtkRenderWindowInteractor() @@ -174,7 +172,7 @@ def _setup_vtk(self): cutter.SetCutFunction(plane) cutter.input_connection = self.source.output_port slice_actor = vtkActor() - slice_mapper = vtkDataSetMapper(lookup_table=self.lut) + slice_mapper = vtkDataSetMapper() slice_mapper.SetInputConnection(cutter.GetOutputPort()) slice_mapper.SetScalarModeToUsePointFieldData() slice_mapper.InterpolateScalarsBeforeMappingOn() @@ -186,7 +184,7 @@ def _setup_vtk(self): outline = vtkOutlineFilter() outline_actor = vtkActor() - outline_mapper = vtkPolyDataMapper(lookup_table=self.lut) + outline_mapper = vtkPolyDataMapper() outline.input_connection = self.source.output_port outline_mapper.SetInputConnection(outline.GetOutputPort()) outline_actor.SetMapper(outline_mapper) @@ -196,7 +194,7 @@ def _setup_vtk(self): self.outline_mapper = outline_mapper data_actor = vtkActor() - data_mapper = vtkDataSetMapper(lookup_table=self.lut) + data_mapper = vtkDataSetMapper() data_mapper.input_connection = self.source.output_port data_actor.SetMapper(data_mapper) data_actor.GetProperty().SetOpacity(0.1) @@ -254,7 +252,7 @@ def _build_ui(self, *args, **kwargs): ) # Scalar bar - Pan3DScalarBar( + ScalarBar( v_show="!control_expended", v_if="color_by", img_src="preset_img", @@ -325,10 +323,19 @@ def _build_ui(self, *args, **kwargs): panel_label="Slice Explorer", ).ui_content: self.ctrl.source_update_rendering_panel = SliceRenderingSettings( - self.source, + self.retrieve_source, + self.retrieve_mapper, self.update_rendering, ).update_from_source + def retrieve_mapper(self): + """Used as a callback to retrieve the mapper.""" + return self.slice_mapper + + def retrieve_source(self): + """Used as a callback to retrieve the source.""" + return self.source + def update_rendering(self, reset_camera=False): self.renderer.ResetCamera() @@ -439,34 +446,6 @@ def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.on_view_mode_change(self.state.view_mode) - @change("color_by") - def _on_color_by_change(self, color_by, **_): - if color_by is None: - return - - color_min, color_max = self.source().point_data[color_by].GetRange() - - self.slice_mapper.SetScalarRange(color_min, color_max) - self.slice_mapper.SelectColorArray(color_by) - - self.state.color_min = color_min - self.state.color_max = color_max - - @change("color_min", "color_max", "color_preset", "nan_color") - def _on_update_color_range( - self, color_min, color_max, color_preset, nan_color, nan_colors, **_ - ): - set_preset(self.lut, color_preset) - self.state.preset_img = to_image(self.lut, 255) - - color = nan_colors[nan_color] - self.lut.SetNanColor(color) - - color_min = float(color_min) - color_max = float(color_max) - self.slice_mapper.SetScalarRange(color_min, color_max) - self.ctrl.view_update() - def _set_view_2D(self, axis): camera = self.renderer.GetActiveCamera() view_up = [0, 0, 1] if axis == 1 else [0, 1, 0] diff --git a/src/pan3d/ui/contour.py b/src/pan3d/ui/contour.py index 026d9c2c..e8fd5e3c 100644 --- a/src/pan3d/ui/contour.py +++ b/src/pan3d/ui/contour.py @@ -7,15 +7,10 @@ class ContourRenderingSettings(RenderingSettingsBasic): - def __init__(self, source, update_rendering): - super().__init__(source, update_rendering) + def __init__(self, retrieve_source, retrieve_mapper, update_rendering): + super().__init__(retrieve_source, retrieve_mapper, update_rendering) - self.source = source - fields = list(self.source.available_arrays) - active_field = fields[0] if len(fields) > 0 else None - nb_times = ( - self.source.input[active_field].shape[0] if active_field is not None else 0 - ) + self._retrieve_source = retrieve_source with self.content: # Actor scaling @@ -112,7 +107,7 @@ def __init__(self, source, update_rendering): prepend_icon="mdi-clock-outline", v_model=("time_idx", 0), min=0, - max=("slice_t_max", nb_times - 1), + max=("slice_t_max", 0), step=1, hide_details=True, density="compact", @@ -133,25 +128,25 @@ def __init__(self, source, update_rendering): ) def update_from_source(self, source=None): + state = self.state + source = source or self._retrieve_source() if source is None: - source = self.source + return with self.state: - self.state.data_arrays_available = source.available_arrays - self.state.data_arrays = source.arrays - self.state.color_by = None - self.state.axis_names = [source.x, source.y, source.z] - self.state.slice_extents = source.slice_extents + state.data_arrays_available = source.available_arrays + state.data_arrays = source.arrays + state.color_by = None + state.axis_names = [source.x, source.y, source.z] + state.slice_extents = source.slice_extents # Update time - self.state.slice_t = source.t_index - self.state.slice_t_max = source.t_size - 1 - self.state.t_labels = source.t_labels - self.state.max_time_width = math.ceil( - 0.58 * max_str_length(self.state.t_labels) - ) + state.slice_t = source.t_index + state.slice_t_max = source.t_size - 1 + state.t_labels = source.t_labels + state.max_time_width = math.ceil(0.58 * max_str_length(state.t_labels)) - if self.state.slice_t_max > 0: - self.state.max_time_index_width = math.ceil( - 0.6 + (math.log10(self.state.slice_t_max + 1) + 1) * 2 * 0.58 + if state.slice_t_max > 0: + state.max_time_index_width = math.ceil( + 0.6 + (math.log10(state.slice_t_max + 1) + 1) * 2 * 0.58 ) diff --git a/src/pan3d/ui/globe.py b/src/pan3d/ui/globe.py index fa629469..15df5b70 100644 --- a/src/pan3d/ui/globe.py +++ b/src/pan3d/ui/globe.py @@ -10,9 +10,10 @@ @TrameApp() class GlobeRenderingSettings(RenderingSettingsBasic): - def __init__(self, source, update_rendering): - super().__init__("Rendering", "show_rendering") - self.source = source + def __init__(self, retrieve_source, retrieve_mapper, update_rendering): + super().__init__(retrieve_source, retrieve_mapper, update_rendering) + self._retrieve_source = retrieve_source + with self.content: v3.VDivider() v3.VSelect( @@ -341,90 +342,81 @@ def __init__(self, source, update_rendering): ) def update_from_source(self, source=None): + state = self.state + source = source or self._retrieve_source() if source is None: - source = self.source + return with self.state: - self.state.data_arrays_available = source.available_arrays - self.state.data_arrays = source.arrays - self.state.color_by = None - self.state.axis_names = [source.x, source.y, source.z] - self.state.slice_extents = source.slice_extents + state.data_arrays_available = source.available_arrays + state.data_arrays = source.arrays + state.color_by = None + state.axis_names = [source.x, source.y, source.z] + state.slice_extents = source.slice_extents slices = source.slices for axis in XYZ: # default axis_extent = self.state.slice_extents.get(getattr(source, axis)) - self.state[f"slice_{axis}_range"] = axis_extent - self.state[f"slice_{axis}_cut"] = 0 - self.state[f"slice_{axis}_step"] = 1 - self.state[f"slice_{axis}_type"] = "range" + state[f"slice_{axis}_range"] = axis_extent + state[f"slice_{axis}_cut"] = 0 + state[f"slice_{axis}_step"] = 1 + state[f"slice_{axis}_type"] = "range" # use slice info if available axis_slice = slices.get(getattr(source, axis)) if axis_slice is not None: if isinstance(axis_slice, int): # cut - self.state[f"slice_{axis}_cut"] = axis_slice - self.state[f"slice_{axis}_type"] = "cut" + state[f"slice_{axis}_cut"] = axis_slice + state[f"slice_{axis}_type"] = "cut" else: # range - self.state[f"slice_{axis}_range"] = [ + state[f"slice_{axis}_range"] = [ axis_slice[0], axis_slice[1] - 1, ] # end is inclusive - self.state[f"slice_{axis}_step"] = axis_slice[2] + state[f"slice_{axis}_step"] = axis_slice[2] # Update time - self.state.slice_t = source.t_index - self.state.slice_t_max = source.t_size - 1 - self.state.t_labels = source.t_labels - self.state.max_time_width = math.ceil( - 0.58 * max_str_length(self.state.t_labels) - ) - if self.state.slice_t_max > 0: - self.state.max_time_index_width = math.ceil( - 0.6 + (math.log10(self.state.slice_t_max + 1) + 1) * 2 * 0.58 + state.slice_t = source.t_index + state.slice_t_max = source.t_size - 1 + state.t_labels = source.t_labels + state.max_time_width = math.ceil(0.58 * max_str_length(state.t_labels)) + if state.slice_t_max > 0: + state.max_time_index_width = math.ceil( + 0.6 + (math.log10(state.slice_t_max + 1) + 1) * 2 * 0.58 ) - def reset_color_range(self): - color_by = self.state.color_by - ds = self.source() - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - self.state.color_min = min_value - self.state.color_max = max_value - else: - self.state.color_min = 0 - self.state.color_max = 1 - @change("slice_t", *[var.format(axis) for axis in XYZ for var in SLICE_VARS]) def on_change(self, slice_t, **_): if self.state.import_pending: return + source = self._retrieve_source() + if source is None: + return - slices = {self.source.t: slice_t} + state = self.state + slices = {source.t: slice_t} for axis in XYZ: - axis_name = getattr(self.source, axis) + axis_name = getattr(source, axis) if axis_name is None: continue - if self.state[f"slice_{axis}_type"] == "range": - if self.state[f"slice_{axis}_range"] is None: + if state[f"slice_{axis}_type"] == "range": + if state[f"slice_{axis}_range"] is None: continue slices[axis_name] = [ - *self.state[f"slice_{axis}_range"], - int(self.state[f"slice_{axis}_step"]), + *state[f"slice_{axis}_range"], + int(state[f"slice_{axis}_step"]), ] slices[axis_name][1] += 1 # end is exclusive else: - slices[axis_name] = self.state[f"slice_{axis}_cut"] + slices[axis_name] = state[f"slice_{axis}_cut"] - self.source.slices = slices - ds = self.source() - self.state.dataset_bounds = ds.bounds + source.slices = slices + ds = source() + state.dataset_bounds = ds.bounds self.ctrl.view_reset_clipping_range() self.ctrl.view_update() @@ -433,6 +425,7 @@ def on_change(self, slice_t, **_): def _on_slice_t(self, slice_t, **_): if self.state.import_pending: return - - self.source.t_index = slice_t - self.ctrl.view_update() + source = self._retrieve_source() + if source is not None: + source.t_index = slice_t + self.ctrl.view_update() diff --git a/src/pan3d/ui/preview.py b/src/pan3d/ui/preview.py index 3ee377e6..053dae59 100644 --- a/src/pan3d/ui/preview.py +++ b/src/pan3d/ui/preview.py @@ -11,10 +11,10 @@ class RenderingSettings(RenderingSettingsBasic): - def __init__(self, source, update_rendering, **kwargs): - super().__init__(source, update_rendering, **kwargs) + def __init__(self, retrieve_source, retrieve_mapper, update_rendering, **kwargs): + super().__init__(retrieve_source, retrieve_mapper, update_rendering, **kwargs) - self.source = source + self._retrieve_source = retrieve_source self.state.setdefault("slice_extents", {}) self.state.setdefault("axis_names", []) self.state.setdefault("t_labels", []) @@ -311,7 +311,7 @@ def __init__(self, source, update_rendering, **kwargs): def update_from_source(self, source=None): if source is None: - source = self.source + source = self._retrieve_source() with self.state: self.state.data_arrays_available = source.available_arrays @@ -355,19 +355,6 @@ def update_from_source(self, source=None): 0.6 + (math.log10(self.state.slice_t_max + 1) + 1) * 2 * 0.58 ) - def reset_color_range(self): - color_by = self.state.color_by - ds = self.source() - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - self.state.color_min = min_value - self.state.color_max = max_value - else: - self.state.color_min = 0 - self.state.color_max = 1 - @change("data_origin_source") def _on_data_origin_source(self, data_origin_source, **kwargs): if self.state.import_pending: @@ -395,12 +382,16 @@ def _on_data_origin_id(self, data_origin_id, data_origin_source, **kwargs): @change("slice_t", *[var.format(axis) for axis in XYZ for var in SLICE_VARS]) def on_change(self, slice_t, **_): + source = self._retrieve_source() + if source is None: + return + if self.state.import_pending: return - slices = {self.source.t: slice_t} + slices = {source.t: slice_t} for axis in XYZ: - axis_name = getattr(self.source, axis) + axis_name = getattr(source, axis) if axis_name is None: continue @@ -415,8 +406,8 @@ def on_change(self, slice_t, **_): else: slices[axis_name] = self.state[f"slice_{axis}_cut"] - self.source.slices = slices - ds = self.source() + source.slices = slices + ds = source() self.state.dataset_bounds = ds.bounds self.ctrl.view_reset_clipping_range() @@ -424,10 +415,13 @@ def on_change(self, slice_t, **_): @change("slice_t") def _on_slice_t(self, slice_t, **_): + source = self._retrieve_source() + if source is None: + return if self.state.import_pending: return - self.source.t_index = slice_t + source.t_index = slice_t self.ctrl.view_update() @change("data_arrays") @@ -440,5 +434,6 @@ def _on_array_selection(self, data_arrays, **_): self.state.color_by = data_arrays[0] elif len(data_arrays) == 0: self.state.color_by = None - - self.source.arrays = data_arrays + source = self._retrieve_source() + if source is not None: + source.arrays = data_arrays diff --git a/src/pan3d/ui/slicer.py b/src/pan3d/ui/slicer.py index 24ed7a9b..9d1f1a17 100644 --- a/src/pan3d/ui/slicer.py +++ b/src/pan3d/ui/slicer.py @@ -7,9 +7,9 @@ class SliceRenderingSettings(RenderingSettingsBasic): - def __init__(self, source, update_rendering): - super().__init__(source, update_rendering) - self.source = source + def __init__(self, retrieve_source, retrieve_mapper, update_rendering): + super().__init__(retrieve_source, retrieve_mapper, update_rendering) + self._retrieve_source = retrieve_source style = {"density": "compact", "hide_details": True} with self.content: @@ -168,8 +168,10 @@ def __init__(self, source, update_rendering): ) def update_from_source(self, source=None): + state = self.state + source = source or self._retrieve_source() if source is None: - source = self.source + return ds = source() bounds = ds.bounds @@ -178,33 +180,31 @@ def update_from_source(self, source=None): 0.5 * (bounds[2] + bounds[3]), 0.5 * (bounds[4] + bounds[5]), ] - with self.state: - self.state.data_arrays_available = source.available_arrays - self.state.data_arrays = source.arrays + with state: + state.data_arrays_available = source.available_arrays + state.data_arrays = source.arrays - self.state.color_by = None - self.state.axis_names = [ + state.color_by = None + state.axis_names = [ x for x in [source.x, source.y, source.z] if x is not None ] - self.state.slice_extents = source.slice_extents + state.slice_extents = source.slice_extents # Update time - self.state.slice_t = source.t_index - self.state.slice_t_max = source.t_size - 1 - self.state.t_labels = source.t_labels - self.state.max_time_width = math.ceil( - 0.58 * max_str_length(self.state.t_labels) - ) + state.slice_t = source.t_index + state.slice_t_max = source.t_size - 1 + state.t_labels = source.t_labels + state.max_time_width = math.ceil(0.58 * max_str_length(state.t_labels)) - if self.state.slice_t_max > 0: - self.state.max_time_index_width = math.ceil( - 0.6 + (math.log10(self.state.slice_t_max + 1) + 1) * 2 * 0.58 + if state.slice_t_max > 0: + state.max_time_index_width = math.ceil( + 0.6 + (math.log10(state.slice_t_max + 1) + 1) * 2 * 0.58 ) # Update state from dataset - self.state.bounds = ds.bounds - self.state.cut_x = origin[0] - self.state.cut_y = origin[1] - self.state.cut_z = origin[2] - self.state.slice_axis = source.z if source.z is not None else source.y - self.state.slice_axes = self.state.axis_names + state.bounds = ds.bounds + state.cut_x = origin[0] + state.cut_y = origin[1] + state.cut_z = origin[2] + state.slice_axis = source.z if source.z is not None else source.y + state.slice_axes = state.axis_names diff --git a/src/pan3d/ui/vtk_view.py b/src/pan3d/ui/vtk_view.py index 31c51627..4fc12a0e 100644 --- a/src/pan3d/ui/vtk_view.py +++ b/src/pan3d/ui/vtk_view.py @@ -234,50 +234,3 @@ def _on_view_type_change(self, view_3d, **_): if not self.state[self._import_pending]: self.ctrl.view_reset_camera() - - -class Pan3DScalarBar(v3.VTooltip): - def __init__(self, img_src, color_min="color_min", color_max="color_max", **kwargs): - super().__init__(location="top") - - # Activate CSS - self.server.enable_module(base) - self.server.enable_module(vtk_view) - - self.state.setdefault("scalarbar_probe", []) - self.state.client_only("scalarbar_probe", "scalarbar_probe_available") - - with self: - # Content - with html.Template(v_slot_activator="{ props }"): - with html.Div( - classes="scalarbar", - rounded="pill", - v_bind="props", - **kwargs, - ): - html.Div( - f"{{{{ {color_min}.toFixed(6) }}}}", classes="scalarbar-left" - ) - html.Img( - src=(img_src, None), - style="height: 100%; width: 100%;", - classes="rounded-lg border-thin", - mousemove="scalarbar_probe = [$event.x, $event.target.getBoundingClientRect()]", - mouseenter="scalarbar_probe_available = 1", - mouseleave="scalarbar_probe_available = 0", - __events=["mousemove", "mouseenter", "mouseleave"], - ) - html.Div( - v_show=("scalarbar_probe_available", False), - classes="scalar-cursor", - style=( - "`left: ${scalarbar_probe?.[0] - scalarbar_probe?.[1]?.left}px`", - ), - ) - html.Div( - f"{{{{ {color_max}.toFixed(6) }}}}", classes="scalarbar-right" - ) - html.Span( - f"{{{{ (({color_max} - {color_min}) * (scalarbar_probe?.[0] - scalarbar_probe?.[1]?.left) / scalarbar_probe?.[1]?.width + {color_min}).toFixed(6) }}}}" - ) diff --git a/src/pan3d/utils/common.py b/src/pan3d/utils/common.py index 21711935..66c0d25c 100644 --- a/src/pan3d/utils/common.py +++ b/src/pan3d/utils/common.py @@ -6,7 +6,7 @@ from pan3d.ui.collapsible import CollapsableSection from pan3d.ui.css import base, preview from pan3d.utils.convert import update_camera -from pan3d.utils.presets import PRESETS +from pan3d.widgets.color import ColorBy from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.app import asynchronous, get_server from trame.decorators import TrameApp, change @@ -453,19 +453,6 @@ def __init__(self, load_dataset): flat=True, variant="solo", ) - - # v3.VDivider() - # v3.VSwitch( - # label=("`Order ${data_origin_order}`",), - # v_model=("data_origin_order", "C"), - # true_value="C", - # false_value="F", - # hide_details=True, - # density="compact", - # flat=True, - # variant="solo", - # classes="mx-6", - # ) v3.VDivider() v3.VBtn( "{{ load_button_text }}", @@ -711,10 +698,10 @@ def __init__( @TrameApp() class RenderingSettingsBasic(CollapsableSection): - def __init__(self, source, update_rendering): + def __init__(self, retrieve_source, retrieve_mapper, update_rendering): super().__init__("Rendering", "show_rendering") + self._retrieve_source = retrieve_source - self.source = source with self.content: v3.VSelect( placeholder="Data arrays", @@ -731,132 +718,16 @@ def __init__(self, source, update_rendering): variant="solo", ) v3.VDivider() - v3.VSelect( - placeholder="Color By", - prepend_inner_icon="mdi-format-color-fill", - v_model=("color_by", None), - items=("data_arrays", []), - clearable=True, - hide_details=True, - density="compact", - flat=True, - variant="solo", - ) - v3.VDivider() - with v3.VRow(no_gutters=True, classes="align-center mr-0"): - with v3.VCol(): - v3.VTextField( - prepend_inner_icon="mdi-water-minus", - v_model_number=("color_min", 0.45), - type="number", - hide_details=True, - density="compact", - flat=True, - variant="solo", - reverse=True, - ) - with v3.VCol(): - v3.VTextField( - prepend_inner_icon="mdi-water-plus", - v_model_number=("color_max", 5.45), - type="number", - hide_details=True, - density="compact", - flat=True, - variant="solo", - reverse=True, - ) - with html.Div(classes="flex-0"): - v3.VBtn( - icon="mdi-arrow-split-vertical", - size="sm", - density="compact", - flat=True, - variant="outlined", - classes="mx-2", - click=self.reset_color_range, - ) - # v3.VDivider() - with html.Div(classes="mx-2"): - html.Img( - src=("preset_img", None), - style="height: 0.75rem; width: 100%;", - classes="rounded-lg border-thin", - ) - v3.VSelect( - placeholder="Color Preset", - prepend_inner_icon="mdi-palette", - v_model=("color_preset", "Fast"), - items=("color_presets", list(PRESETS.keys())), - hide_details=True, - density="compact", - flat=True, - variant="solo", - ) - - with v3.VTooltip( - text=("`NaN Color (${nan_colors[nan_color]})`",), - ): - with html.Template(v_slot_activator="{ props }"): - with v3.VItemGroup( - v_model="nan_color", - v_bind="props", - classes="d-inline-flex ga-1 pa-2", - mandatory="force", - ): - v3.VIcon( - "mdi-eyedropper-variant", - classes="my-auto mx-1 text-medium-emphasis", - ) - with v3.VItem( - v_for="(color, i) in nan_colors", key="i", value=("i",) - ): - with v3.Template( - raw_attrs=['#default="{ isSelected, toggle }"'] - ): - with v3.VAvatar( - density="compact", - color=("isSelected ? 'primary': 'transparent'",), - ): - v3.VBtn( - "{{ color[3] < 0.1 ? 't' : '' }}", - density="compact", - border="md surface opacity-100", - color=( - "color[3] ? `rgb(${color[0] * 255}, ${color[1] * 255}, ${color[2] * 255})` : undefined", - ), - flat=True, - icon=True, - ripple=False, - size="small", - click="toggle", - ) + ColorBy(retrieve_source=retrieve_source, retrieve_mapper=retrieve_mapper) @change("data_arrays") def _on_array_selection(self, data_arrays, **_): if self.state.import_pending: return - self.state.dirty_data = True - if len(data_arrays) == 0: - self.state.color_by = None - elif self.state.color_by is None or self.state.color_by not in data_arrays: - self.state.color_by = data_arrays[0] - - self.source.arrays = data_arrays - - def reset_color_range(self): - color_by = self.state.color_by - ds = self.source() - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - self.state.color_min = min_value - self.state.color_max = max_value - else: - self.state.color_min = 0 - self.state.color_max = 1 + source = self._retrieve_source() + if source is not None: + source.arrays = data_arrays def update_from_source(self, source=None): raise NotImplementedError( diff --git a/src/pan3d/viewers/preview.py b/src/pan3d/viewers/preview.py index 4cbb1b8c..4803b4bf 100644 --- a/src/pan3d/viewers/preview.py +++ b/src/pan3d/viewers/preview.py @@ -1,5 +1,4 @@ import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 -from vtkmodules.vtkCommonCore import vtkLookupTable from vtkmodules.vtkFiltersGeometry import vtkDataSetSurfaceFilter # VTK factory initialization @@ -15,10 +14,10 @@ ) from pan3d.ui.preview import RenderingSettings -from pan3d.ui.vtk_view import Pan3DScalarBar, Pan3DView +from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.utils.convert import to_float, to_image -from pan3d.utils.presets import set_preset +from pan3d.utils.convert import to_float +from pan3d.widgets.color import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.decorators import change from trame.ui.vuetify3 import VAppLayout @@ -49,7 +48,6 @@ def _setup_vtk(self): self.interactor.SetRenderWindow(self.render_window) self.interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera() - self.lut = vtkLookupTable() self.source = vtkXArrayRectilinearSource(input=self.xarray) # Need explicit geometry extraction when used with WASM @@ -58,7 +56,6 @@ def _setup_vtk(self): ) self.mapper = vtkPolyDataMapper( input_connection=self.geometry.output_port, - lookup_table=self.lut, ) self.actor = vtkActor(mapper=self.mapper, visibility=0) @@ -76,6 +73,14 @@ def _setup_vtk(self): # UI # ------------------------------------------------------------------------- + def retrieve_mapper(self): + """Used as a callback to retrieve the mapper.""" + return self.mapper + + def retrieve_source(self): + """Used as a callback to retrieve the source.""" + return self.source + def _build_ui(self, **kwargs): self.state.trame__title = "XArray Viewer" @@ -90,7 +95,7 @@ def _build_ui(self, **kwargs): ) # Scalar bar - Pan3DScalarBar( + ScalarBar( v_show="!control_expended", v_if="color_by", img_src="preset_img", @@ -152,7 +157,8 @@ def _build_ui(self, **kwargs): xr_update_info="xr_update_info", ).ui_content: self.ctrl.source_update_rendering_panel = RenderingSettings( - self.source, + self.retrieve_source, + self.retrieve_mapper, self.update_rendering, ).update_from_source @@ -160,44 +166,6 @@ def _build_ui(self, **kwargs): # State change callbacks # ----------------------------------------------------- - @change("color_by") - def _on_color_by(self, color_by, **__): - if self.source.input is None: - return - - ds = self.source() - if color_by in ds.point_data.keys(): - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - self.state.color_min = min_value - self.state.color_max = max_value - - self.mapper.SelectColorArray(color_by) - self.mapper.SetScalarModeToUsePointFieldData() - self.mapper.InterpolateScalarsBeforeMappingOn() - self.mapper.SetScalarVisibility(1) - else: - self.mapper.SetScalarVisibility(0) - self.state.color_min = 0 - self.state.color_max = 1 - - @change("color_preset", "color_min", "color_max", "nan_color") - def _on_color_preset( - self, nan_color, nan_colors, color_preset, color_min, color_max, **_ - ): - color_min = float(color_min) - color_max = float(color_max) - self.mapper.SetScalarRange(color_min, color_max) - - color = nan_colors[nan_color] - self.lut.SetNanColor(color) - - set_preset(self.lut, color_preset) - self.state.preset_img = to_image(self.lut, 255) - - self.ctrl.view_update() - @change("scale_x", "scale_y", "scale_z") def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.actor.SetScale( diff --git a/src/pan3d/widgets/__init__.py b/src/pan3d/widgets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pan3d/widgets/color.py b/src/pan3d/widgets/color.py new file mode 100644 index 00000000..126e4eac --- /dev/null +++ b/src/pan3d/widgets/color.py @@ -0,0 +1,291 @@ +from vtkmodules.vtkCommonCore import vtkLookupTable +from vtkmodules.vtkRenderingCore import vtkMapper + +from pan3d.ui.css import base, vtk_view +from pan3d.utils.convert import to_image +from pan3d.utils.presets import PRESETS, set_preset +from trame.widgets import html +from trame.widgets import vuetify3 as v3 + + +class ColorBy(html.Div): + """Color settings for the XArray Explorers. + Arguments: + source: The source of the data to be colored. + color_by: The name of the data array to color by. + data_arrays: The list of available data arrays. + color_min: The minimum value for the color range. + color_max: The maximum value for the color range. + nan_color: The color to use for NaN values. + color_preset: The name of the color preset to use. + color_presets: The list of available color presets. + """ + + def __init__( + self, + retrieve_source=None, + retrieve_mapper=None, + color_by="color_by", + data_arrays="data_arrays", + color_min="color_min", + color_max="color_max", + nan_color="nan_color", + color_preset="color_preset", + color_presets="color_presets", + preset_img="preset_img", + **kwargs, + ): + super().__init__(**kwargs) + + self.lut = vtkLookupTable() + + # initialize component specific variables + self._retrieve_source = retrieve_source + self._retrieve_mapper = retrieve_mapper + self._color_by = color_by + self._data_arrays = data_arrays + self._color_min = color_min + self._color_max = color_max + self._nan_color = nan_color + self._color_preset = color_preset + self._color_presets = color_presets + self._preset_img = preset_img + + # Track state changes + self.state.change(data_arrays)(self._on_change_data_arrays) + self.state.change(color_by)(self._on_change_color_by) + self.state.change(color_min, color_max, color_preset, nan_color)( + self._on_change_properties + ) + + with self: + v3.VSelect( + placeholder="Color By", + prepend_inner_icon="mdi-format-color-fill", + v_model=(color_by, None), + items=(data_arrays, []), + clearable=True, + hide_details=True, + density="compact", + flat=True, + variant="solo", + ) + v3.VDivider() + with v3.VRow(no_gutters=True, classes="align-center mr-0"): + with v3.VCol(): + v3.VTextField( + prepend_inner_icon="mdi-water-minus", + v_model_number=(color_min, 0.45), + type="number", + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + ) + with v3.VCol(): + v3.VTextField( + prepend_inner_icon="mdi-water-plus", + v_model_number=(color_max, 5.45), + type="number", + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + ) + with html.Div(classes="flex-0"): + v3.VBtn( + icon="mdi-arrow-split-vertical", + size="sm", + density="compact", + flat=True, + variant="outlined", + classes="mx-2", + click=self.reset_color_range, + ) + # v3.VDivider() + with html.Div(classes="mx-2"): + html.Img( + src=("preset_img", None), + style="height: 0.75rem; width: 100%;", + classes="rounded-lg border-thin", + ) + v3.VSelect( + placeholder="Color Preset", + prepend_inner_icon="mdi-palette", + v_model=(color_preset, "Fast"), + items=(color_presets, list(PRESETS.keys())), + hide_details=True, + density="compact", + flat=True, + variant="solo", + ) + + with v3.VTooltip( + text=("`NaN Color (${nan_colors[nan_color]})`",), + ): + with html.Template(v_slot_activator="{ props }"): + with v3.VItemGroup( + v_model=nan_color, + v_bind="props", + classes="d-inline-flex ga-1 pa-2", + mandatory="force", + ): + v3.VIcon( + "mdi-eyedropper-variant", + classes="my-auto mx-1 text-medium-emphasis", + ) + with v3.VItem( + v_for="(color, i) in nan_colors", key="i", value=("i",) + ): + with v3.Template( + raw_attrs=['#default="{ isSelected, toggle }"'] + ): + with v3.VAvatar( + density="compact", + color=("isSelected ? 'primary': 'transparent'",), + ): + v3.VBtn( + "{{ color[3] < 0.1 ? 't' : '' }}", + density="compact", + border="md surface opacity-100", + color=( + "color[3] ? `rgb(${color[0] * 255}, ${color[1] * 255}, ${color[2] * 255})` : undefined", + ), + flat=True, + icon=True, + ripple=False, + size="small", + click="toggle", + ) + + def _on_change_data_arrays(self, **__): + state = self.state + data_arrays = self.state[self._data_arrays] + color_by = state[self._color_by] + if len(data_arrays) == 0: + state[self._color_by] = None + elif color_by is None or color_by not in data_arrays: + state[self._color_by] = data_arrays[0] + + def _on_change_color_by(self, **__): + state = self.state + source = self._retrieve_source() + mapper: vtkMapper = self._retrieve_mapper() + + if source is None: + return + + color_by = state[self._color_by] + + ds = source() + if color_by in ds.point_data.keys(): # vtk is missing in iter + array = ds.point_data[color_by] + min_value, max_value = array.GetRange() + + state[self._color_min] = min_value + state[self._color_max] = max_value + + if mapper is not None: + mapper.SetLookupTable(self.lut) + mapper.SelectColorArray(color_by) + mapper.SetScalarModeToUsePointFieldData() + mapper.InterpolateScalarsBeforeMappingOn() + mapper.SetScalarVisibility(1) + else: + if mapper is not None: + mapper.SetScalarVisibility(0) + state[self._color_min] = 0 + state[self._color_max] = 1 + + def _on_change_properties(self, **__): + """Change the color properties based on the selected data array,preset and range""" + state = self.state + mapper: vtkMapper = self._retrieve_mapper() + + color_min = state[self._color_min] + color_max = state[self._color_max] + color_min = float(color_min) + color_max = float(color_max) + if mapper is not None: + mapper.SetLookupTable(self.lut) + mapper.SetScalarRange(color_min, color_max) + + nan_colors = state.nan_colors + nan_color = state[self._nan_color] + color = nan_colors[nan_color] + self.lut.SetNanColor(color) + + preset = state[self._color_preset] + set_preset(self.lut, preset) + state.preset_img = to_image(self.lut, 255) + + self.ctrl.view_update() + + def reset_color_range(self): + """Reset the color range to the min and max values of the selected data array.""" + state = self.state + color_by = state[self._color_by] + source = self._retrieve_source() + ds = source() + + if color_by in ds.point_data.keys(): # vtk is missing in iter + array = ds.point_data[color_by] + min_value, max_value = array.GetRange() + + state[self._color_min] = min_value + state[self._color_max] = max_value + else: + state[self._color_min] = 0 + state[self._color_max] = 1 + + self.ctrl.view_update() + + +class ScalarBar(v3.VTooltip): + def __init__(self, img_src, color_min="color_min", color_max="color_max", **kwargs): + """Scalar bar for the XArray Explorers.""" + super().__init__(location="top") + + # Activate CSS + self.server.enable_module(base) + self.server.enable_module(vtk_view) + + self.state.setdefault("scalarbar_probe", []) + self.state.client_only("scalarbar_probe", "scalarbar_probe_available") + + with self: + # Content + with html.Template(v_slot_activator="{ props }"): + with html.Div( + classes="scalarbar", + rounded="pill", + v_bind="props", + **kwargs, + ): + html.Div( + f"{{{{ {color_min}.toFixed(6) }}}}", classes="scalarbar-left" + ) + html.Img( + src=(img_src, None), + style="height: 100%; width: 100%;", + classes="rounded-lg border-thin", + mousemove="scalarbar_probe = [$event.x, $event.target.getBoundingClientRect()]", + mouseenter="scalarbar_probe_available = 1", + mouseleave="scalarbar_probe_available = 0", + __events=["mousemove", "mouseenter", "mouseleave"], + ) + html.Div( + v_show=("scalarbar_probe_available", False), + classes="scalar-cursor", + style=( + "`left: ${scalarbar_probe?.[0] - scalarbar_probe?.[1]?.left}px`", + ), + ) + html.Div( + f"{{{{ {color_max}.toFixed(6) }}}}", classes="scalarbar-right" + ) + html.Span( + f"{{{{ (({color_max} - {color_min}) * (scalarbar_probe?.[0] - scalarbar_probe?.[1]?.left) / scalarbar_probe?.[1]?.width + {color_min}).toFixed(6) }}}}" + ) diff --git a/src/trame/widgets/pan3d.py b/src/trame/widgets/pan3d.py new file mode 100644 index 00000000..15132712 --- /dev/null +++ b/src/trame/widgets/pan3d.py @@ -0,0 +1,3 @@ +from pan3d.widgets.color import ColorBy, ScalarBar + +__all__ = ["ColorBy", "ScalarBar"] From 1145c6dcb7df7623cc213871a999c3910ecf7ef1 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Tue, 3 Jun 2025 16:09:01 -0700 Subject: [PATCH 2/4] feat (modularize): Adding scalar bar component --- src/pan3d/explorers/analytics.py | 2 +- src/pan3d/explorers/contour.py | 2 +- src/pan3d/explorers/globe.py | 2 +- src/pan3d/explorers/slicer.py | 2 +- src/pan3d/utils/common.py | 2 +- src/pan3d/viewers/preview.py | 13 ++- src/pan3d/widgets/{color.py => color_by.py} | 49 --------- src/pan3d/widgets/scalar_bar.py | 105 ++++++++++++++++++++ src/trame/widgets/pan3d.py | 2 +- 9 files changed, 121 insertions(+), 58 deletions(-) rename src/pan3d/widgets/{color.py => color_by.py} (81%) create mode 100644 src/pan3d/widgets/scalar_bar.py diff --git a/src/pan3d/explorers/analytics.py b/src/pan3d/explorers/analytics.py index a880e7bc..66ba2af6 100644 --- a/src/pan3d/explorers/analytics.py +++ b/src/pan3d/explorers/analytics.py @@ -18,7 +18,7 @@ from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.convert import to_float -from pan3d.widgets.color import ScalarBar +from src.pan3d.widgets.color_by import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import html diff --git a/src/pan3d/explorers/contour.py b/src/pan3d/explorers/contour.py index 51702201..f63c7f6b 100644 --- a/src/pan3d/explorers/contour.py +++ b/src/pan3d/explorers/contour.py @@ -27,8 +27,8 @@ from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.convert import to_float -from pan3d.widgets.color import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource +from src.pan3d.widgets.color_by import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import vuetify3 as v3 diff --git a/src/pan3d/explorers/globe.py b/src/pan3d/explorers/globe.py index ed593cc1..66616c37 100644 --- a/src/pan3d/explorers/globe.py +++ b/src/pan3d/explorers/globe.py @@ -27,8 +27,8 @@ from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.convert import update_camera from pan3d.utils.globe import get_continent_outlines, get_globe, get_globe_textures -from pan3d.widgets.color import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource +from src.pan3d.widgets.color_by import ScalarBar from trame.app import asynchronous from trame.decorators import change from trame.ui.vuetify3 import VAppLayout diff --git a/src/pan3d/explorers/slicer.py b/src/pan3d/explorers/slicer.py index 1c9ca406..424d039b 100644 --- a/src/pan3d/explorers/slicer.py +++ b/src/pan3d/explorers/slicer.py @@ -27,8 +27,8 @@ from pan3d.ui.slicer import SliceRenderingSettings from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.widgets.color import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource +from src.pan3d.widgets.color_by import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import html diff --git a/src/pan3d/utils/common.py b/src/pan3d/utils/common.py index 66c0d25c..cc620c25 100644 --- a/src/pan3d/utils/common.py +++ b/src/pan3d/utils/common.py @@ -6,7 +6,7 @@ from pan3d.ui.collapsible import CollapsableSection from pan3d.ui.css import base, preview from pan3d.utils.convert import update_camera -from pan3d.widgets.color import ColorBy +from pan3d.widgets.color_by import ColorBy from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.app import asynchronous, get_server from trame.decorators import TrameApp, change diff --git a/src/pan3d/viewers/preview.py b/src/pan3d/viewers/preview.py index 4803b4bf..fe746cd7 100644 --- a/src/pan3d/viewers/preview.py +++ b/src/pan3d/viewers/preview.py @@ -17,7 +17,7 @@ from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.convert import to_float -from pan3d.widgets.color import ScalarBar +from pan3d.widgets.scalar_bar import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.decorators import change from trame.ui.vuetify3 import VAppLayout @@ -95,10 +95,9 @@ def _build_ui(self, **kwargs): ) # Scalar bar - ScalarBar( + self.scalar_bar = ScalarBar( v_show="!control_expended", v_if="color_by", - img_src="preset_img", ) # Save dialog @@ -185,6 +184,14 @@ def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.ctrl.view_reset_camera() + @change("preset") + def _on_preset_change(self, preset, **_): + self.scalar_bar.set_preset(preset) + + @change("color_min", "color_max") + def _on_color_range_change(self, color_min, color_max, **_): + self.scalar_bar.set_color_range(color_min, color_max) + @change("data_origin_order") def _on_order_change(self, **_): if self.state.import_pending: diff --git a/src/pan3d/widgets/color.py b/src/pan3d/widgets/color_by.py similarity index 81% rename from src/pan3d/widgets/color.py rename to src/pan3d/widgets/color_by.py index 126e4eac..5d1f9561 100644 --- a/src/pan3d/widgets/color.py +++ b/src/pan3d/widgets/color_by.py @@ -1,7 +1,6 @@ from vtkmodules.vtkCommonCore import vtkLookupTable from vtkmodules.vtkRenderingCore import vtkMapper -from pan3d.ui.css import base, vtk_view from pan3d.utils.convert import to_image from pan3d.utils.presets import PRESETS, set_preset from trame.widgets import html @@ -241,51 +240,3 @@ def reset_color_range(self): state[self._color_max] = 1 self.ctrl.view_update() - - -class ScalarBar(v3.VTooltip): - def __init__(self, img_src, color_min="color_min", color_max="color_max", **kwargs): - """Scalar bar for the XArray Explorers.""" - super().__init__(location="top") - - # Activate CSS - self.server.enable_module(base) - self.server.enable_module(vtk_view) - - self.state.setdefault("scalarbar_probe", []) - self.state.client_only("scalarbar_probe", "scalarbar_probe_available") - - with self: - # Content - with html.Template(v_slot_activator="{ props }"): - with html.Div( - classes="scalarbar", - rounded="pill", - v_bind="props", - **kwargs, - ): - html.Div( - f"{{{{ {color_min}.toFixed(6) }}}}", classes="scalarbar-left" - ) - html.Img( - src=(img_src, None), - style="height: 100%; width: 100%;", - classes="rounded-lg border-thin", - mousemove="scalarbar_probe = [$event.x, $event.target.getBoundingClientRect()]", - mouseenter="scalarbar_probe_available = 1", - mouseleave="scalarbar_probe_available = 0", - __events=["mousemove", "mouseenter", "mouseleave"], - ) - html.Div( - v_show=("scalarbar_probe_available", False), - classes="scalar-cursor", - style=( - "`left: ${scalarbar_probe?.[0] - scalarbar_probe?.[1]?.left}px`", - ), - ) - html.Div( - f"{{{{ {color_max}.toFixed(6) }}}}", classes="scalarbar-right" - ) - html.Span( - f"{{{{ (({color_max} - {color_min}) * (scalarbar_probe?.[0] - scalarbar_probe?.[1]?.left) / scalarbar_probe?.[1]?.width + {color_min}).toFixed(6) }}}}" - ) diff --git a/src/pan3d/widgets/scalar_bar.py b/src/pan3d/widgets/scalar_bar.py new file mode 100644 index 00000000..6082031b --- /dev/null +++ b/src/pan3d/widgets/scalar_bar.py @@ -0,0 +1,105 @@ +from vtkmodules.vtkCommonCore import vtkLookupTable + +from pan3d.ui.css import base, vtk_view +from pan3d.utils.convert import to_image +from pan3d.utils.presets import PRESETS, set_preset +from trame.widgets import html +from trame.widgets import vuetify3 as v3 + + +class ScalarBar(v3.VTooltip): + """ + Scalar bar for the XArray Explorers. + """ + + _next_id = 0 + + def set_preset(self, preset_name): + """Set the color preset for the scalar bar.""" + if preset_name not in PRESETS: + err_msg = f"Preset '{preset_name}' not found." + raise ValueError(err_msg) + set_preset(self._lut, preset_name) + self.state[self.__preset_key] = to_image(self._lut) + + def set_color_range(self, color_min, color_max): + """Set the color range for the scalar bar.""" + self.state[self.__color_min_key] = color_min + self.state[self.__color_max_key] = color_max + # Update the lookup table range + self._lut.SetRange(color_min, color_max) + self.state[self.__preset_key] = to_image(self._lut) + + def __init__(self, preset=None, color_min=0.0, color_max=1.0, **kwargs): + """Scalar bar for the XArray Explorers.""" + super().__init__(location="top") + + print("Updating the scalar bar") + + # Activate CSS + self.server.enable_module(base) + self.server.enable_module(vtk_view) + + __ns = kwargs.get("namespace", "view") + if __ns == "view": + ScalarBar._next_id += 1 + if ScalarBar._next_id > 1: + __ns = f"view{ScalarBar._next_id}" + + self.__preset_key = f"{__ns}_preset" + self.__color_min_key = f"{__ns}_color_min" + self.__color_max_key = f"{__ns}_color_max" + self.__scalarbar_probe_key = f"{__ns}_scalarbar_probe" + self.__scalarbar_probe_key_available = f"{__ns}_scalarbar_probe_available" + + if preset is None: + preset = next(iter(PRESETS.keys())) + self._lut = vtkLookupTable() + set_preset(self._lut, preset) + + # Initialize state + self.state[self.__preset_key] = to_image(self._lut) + self.state[self.__color_min_key] = color_min + self.state[self.__color_max_key] = color_max + + self.state.setdefault(self.__scalarbar_probe_key, []) + self.state.client_only( + self.__scalarbar_probe_key, self.__scalarbar_probe_key_available + ) + + with self: + # Content + with html.Template(v_slot_activator="{ props }"): + with html.Div( + classes="scalarbar", + rounded="pill", + v_bind="props", + **kwargs, + ): + html.Div( + f"{{{{{self.__color_min_key}.toFixed(6) }}}}", + classes="scalarbar-left", + ) + html.Img( + src=(self.__preset_key, None), + style="height: 100%; width: 100%;", + classes="rounded-lg border-thin", + mousemove=f"{self.__scalarbar_probe_key} = [$event.x, $event.target.getBoundingClientRect()]", + mouseenter=f"{self.__scalarbar_probe_key_available} = 1", + mouseleave=f"{self.__scalarbar_probe_key_available} = 0", + __events=["mousemove", "mouseenter", "mouseleave"], + ) + html.Div( + v_show=(self.__scalarbar_probe_key_available, False), + classes="scalar-cursor", + style=( + f"`left: ${{{self.__scalarbar_probe_key}?.[0] - {self.__scalarbar_probe_key}?.[1]?.left}}px`", + ), + ) + html.Div( + f"{{{{ {self.__color_max_key}.toFixed(6) }}}}", + classes="scalarbar-right", + ) + html.Span( + f"{{{{ (({self.__color_max_key} - {self.__color_min_key}) * ({self.__scalarbar_probe_key}?.[0] - {self.__scalarbar_probe_key}?.[1]?.left) / {self.__scalarbar_probe_key}?.[1]?.width + {self.__color_min_key}).toFixed(6) }}}}" + ) diff --git a/src/trame/widgets/pan3d.py b/src/trame/widgets/pan3d.py index 15132712..1e05b377 100644 --- a/src/trame/widgets/pan3d.py +++ b/src/trame/widgets/pan3d.py @@ -1,3 +1,3 @@ -from pan3d.widgets.color import ColorBy, ScalarBar +from src.pan3d.widgets.color_by import ColorBy, ScalarBar __all__ = ["ColorBy", "ScalarBar"] From 1082e33818b06245bb71a9177207b4f4d8021596 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Wed, 4 Jun 2025 11:51:28 -0700 Subject: [PATCH 3/4] fix (scalar bar): Fixing scalar bar abstraction --- pyproject.toml | 2 + src/pan3d/explorers/analytics.py | 8 ++ src/pan3d/explorers/contour.py | 8 ++ src/pan3d/explorers/globe.py | 200 +------------------------------ src/pan3d/explorers/slicer.py | 8 ++ src/pan3d/viewers/preview.py | 6 +- src/pan3d/widgets/scalar_bar.py | 124 +++++++++++-------- 7 files changed, 111 insertions(+), 245 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ac2dde9..a920c1fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ classifiers = [ [project.optional-dependencies] viewer = [ "trame>=3.9", + "trame-client>=3.9.1", "trame-vtk>=2.8.15", "trame-vtklocal>=0.12.2", "trame-vuetify>=3.0.1", @@ -52,6 +53,7 @@ pangeo = [ all = [ # viewers/explorers "trame>=3.9", + "trame-client>=3.9.1", "trame-vtk>=2.8.15", "trame-vtklocal>=0.12.2", "trame-vuetify>=3.0.1", diff --git a/src/pan3d/explorers/analytics.py b/src/pan3d/explorers/analytics.py index 66ba2af6..49dca70c 100644 --- a/src/pan3d/explorers/analytics.py +++ b/src/pan3d/explorers/analytics.py @@ -224,6 +224,14 @@ def retrieve_source(self): def _on_color_by(self, **__): self.plotting.update_plot() + @change("color_preset") + def _on_preset_change(self, color_preset, **_): + self.scalar_bar.preset = color_preset + + @change("color_min", "color_max") + def _on_color_range_change(self, color_min, color_max, **_): + self.scalar_bar.set_color_range(color_min, color_max) + @change("scale_x", "scale_y", "scale_z") def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.actor.SetScale( diff --git a/src/pan3d/explorers/contour.py b/src/pan3d/explorers/contour.py index f63c7f6b..114e7515 100644 --- a/src/pan3d/explorers/contour.py +++ b/src/pan3d/explorers/contour.py @@ -279,6 +279,14 @@ def _on_update_color_range( self.bands.GenerateValues(nb_contours, [color_min, color_max]) self.ctrl.view_update() + @change("color_preset") + def _on_preset_change(self, color_preset, **_): + self.scalar_bar.preset = color_preset + + @change("color_min", "color_max") + def _on_color_range_change(self, color_min, color_max, **_): + self.scalar_bar.set_color_range(color_min, color_max) + def main(): app = ContourExplorer() diff --git a/src/pan3d/explorers/globe.py b/src/pan3d/explorers/globe.py index 66616c37..3ed1d7f3 100644 --- a/src/pan3d/explorers/globe.py +++ b/src/pan3d/explorers/globe.py @@ -1,7 +1,3 @@ -import json -import traceback -from pathlib import Path - import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 from vtkmodules.vtkCommonCore import vtkObject from vtkmodules.vtkFiltersGeometry import vtkDataSetSurfaceFilter @@ -25,11 +21,9 @@ from pan3d.ui.globe import GlobeRenderingSettings from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar -from pan3d.utils.convert import update_camera from pan3d.utils.globe import get_continent_outlines, get_globe, get_globe_textures from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from src.pan3d.widgets.color_by import ScalarBar -from trame.app import asynchronous from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import vuetify3 as v3 @@ -242,110 +236,6 @@ def _on_texture_preset(self, texture, **_): self.gactor.SetTexture(self.textures[texture]) self.ctrl.view_update() - @change("data_origin_order") - def _on_order_change(self, **_): - if self.state.import_pending: - return - - self.state.load_button_text = "Load" - self.state.can_load = True - - # ----------------------------------------------------- - # Triggers - # ----------------------------------------------------- - def _import_file_upload(self, files): - self.import_state(json.loads(files[0].get("content"))) - - def _process_cli(self, **_): - args, _ = self.server.cli.parse_known_args() - - # Skip if xarray provided - if self.source.input: - if not self.source.arrays: - self.source.arrays = self.source.available_arrays - self.ctrl.xr_update_info(self.source.input, self.source.available_arrays) - self.ctrl.source_update_rendering_panel(self.source) - self._update_rendering(reset_camera=True) - self.state.show_rendering = True - return - - # import state - if args.import_state: - self._import_file_from_path(args.import_state) - - # load xarray (file) - elif args.xarray_file: - self.state.import_pending = True - with self.state: - self._load_dataset("file", args.xarray_file) - self.state.data_origin_id = str(Path(args.xarray_file).resolve()) - self.state.import_pending = False - - # load xarray (url) - elif args.xarray_url: - self.state.import_pending = True - with self.state: - self._load_dataset("url", args.xarray_url) - self.state.data_origin_id = args.xarray_url - self.state.import_pending = False - - def _import_file_from_path(self, file_path): - if file_path is None: - return - - file_path = Path(file_path) - if file_path.exists(): - self.import_state(json.loads(file_path.read_text("utf-8"))) - - def _load_dataset(self, source, id, order="C", config=None): - self.state.data_origin_source = source - self.state.data_origin_id = id - self.state.load_button_text = "Loaded" - self.state.can_load = False - self.state.show_data_information = True - - if config is None: - config = { - "arrays": [], - "slices": {}, - } - - try: - self.source.load( - { - "data_origin": { - "source": source, - "id": id, - "order": order, - }, - "dataset_config": config, - } - ) - if self.actor.visibility: - self.renderer.RemoveActor(self.actor) - self.actor.visibility = 0 - - # Make sure arrays are available - if not self.source.arrays: - self.source.arrays = self.source.available_arrays - - # Extract UI - self.ctrl.xr_update_info(self.source.input, self.source.available_arrays) - self.ctrl.source_update_rendering_panel(self.source) - - # no error - self.state.data_origin_error = False - except Exception as e: - self.state.data_origin_error = ( - f"Error occurred while trying to load data. {e}" - ) - self.state.data_origin_id_error = True - self.state.load_button_text = "Load" - self.state.can_load = True - self.state.show_data_information = False - - print(traceback.format_exc()) - def update_rendering(self, reset_camera=False): self.state.dirty_data = False @@ -370,91 +260,13 @@ def update_rendering(self, reset_camera=False): else: self.ctrl.view_update() - # ----------------------------------------------------- - # Public API - # ----------------------------------------------------- + @change("color_preset") + def _on_preset_change(self, color_preset, **_): + self.scalar_bar.preset = color_preset - def export_state(self): - """Return a json dump of the reader and viewer state""" - camera = self.renderer.active_camera - state_to_export = { - **self.source.state, - "xr-globe": { - "view_3d": self.state.view_3d, - "color_by": self.state.color_by, - "color_preset": self.state.color_preset, - "color_min": self.state.color_min, - "color_max": self.state.color_max, - "scale_x": self.state.scale_x, - "scale_y": self.state.scale_y, - "scale_z": self.state.scale_z, - }, - "camera": { - "position": camera.position, - "view_up": camera.view_up, - "focal_point": camera.focal_point, - "parallel_projection": camera.parallel_projection, - "parallel_scale": camera.parallel_scale, - }, - } - return json.dumps(state_to_export, indent=2) - - def import_state(self, data_state): - """ - Read the current state to load the data and visualization setup if any. - - Parameters: - - data_state (dict): reader (+viewer) state to reset to - """ - self.state.import_pending = True - try: - data_origin = data_state.get("data_origin") - source = data_origin.get("source") - id = data_origin.get("id") - order = data_origin.get("order", "C") - config = data_state.get("dataset_config") - globe_state = data_state.get("xr-globe", {}) - camera_state = data_state.get("camera", {}) - - # load data and initial rendering setup - with self.state: - self._load_dataset(source, id, order, config) - self.state.update(globe_state) - - # override computed color range using state values - with self.state: - self.state.update(globe_state) - - # update camera and render - update_camera(self.renderer.active_camera, camera_state) - self._update_rendering() - finally: - self.state.import_pending = False - - async def _save_dataset(self, file_path): - output_path = Path(file_path).resolve() - self.source.input.to_netcdf(output_path) - - def save_dataset(self, file_path): - """ - Write XArray data into a file using a background task. - So when used programmatically, make sure you await the returned task. - - Parameters: - - file_path (str): path to use for writing the file - - Returns: - writing task - """ - self.state.show_save_dialog = False - return asynchronous.create_task(self._save_dataset(file_path)) - - async def _async_display(self): - await self.ui.ready - self.ui._ipython_display_() - - def _ipython_display_(self): - asynchronous.create_task(self._async_display()) + @change("color_min", "color_max") + def _on_color_range_change(self, color_min, color_max, **_): + self.scalar_bar.set_color_range(color_min, color_max) # ----------------------------------------------------------------------------- diff --git a/src/pan3d/explorers/slicer.py b/src/pan3d/explorers/slicer.py index 424d039b..5333e969 100644 --- a/src/pan3d/explorers/slicer.py +++ b/src/pan3d/explorers/slicer.py @@ -344,6 +344,14 @@ def update_rendering(self, reset_camera=False): self.ctrl.view_reset_camera() + @change("color_preset") + def _on_preset_change(self, color_preset, **_): + self.scalar_bar.preset = color_preset + + @change("color_min", "color_max") + def _on_color_range_change(self, color_min, color_max, **_): + self.scalar_bar.set_color_range(color_min, color_max) + # ------------------------------------------------------------------------- # Property API # ------------------------------------------------------------------------- diff --git a/src/pan3d/viewers/preview.py b/src/pan3d/viewers/preview.py index fe746cd7..5fd99270 100644 --- a/src/pan3d/viewers/preview.py +++ b/src/pan3d/viewers/preview.py @@ -184,9 +184,9 @@ def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.ctrl.view_reset_camera() - @change("preset") - def _on_preset_change(self, preset, **_): - self.scalar_bar.set_preset(preset) + @change("color_preset") + def _on_preset_change(self, color_preset, **_): + self.scalar_bar.preset = color_preset @change("color_min", "color_max") def _on_color_range_change(self, color_min, color_max, **_): diff --git a/src/pan3d/widgets/scalar_bar.py b/src/pan3d/widgets/scalar_bar.py index 6082031b..676432ae 100644 --- a/src/pan3d/widgets/scalar_bar.py +++ b/src/pan3d/widgets/scalar_bar.py @@ -14,57 +14,85 @@ class ScalarBar(v3.VTooltip): _next_id = 0 - def set_preset(self, preset_name): - """Set the color preset for the scalar bar.""" - if preset_name not in PRESETS: - err_msg = f"Preset '{preset_name}' not found." - raise ValueError(err_msg) - set_preset(self._lut, preset_name) - self.state[self.__preset_key] = to_image(self._lut) - def set_color_range(self, color_min, color_max): """Set the color range for the scalar bar.""" - self.state[self.__color_min_key] = color_min - self.state[self.__color_max_key] = color_max - # Update the lookup table range - self._lut.SetRange(color_min, color_max) - self.state[self.__preset_key] = to_image(self._lut) + self.state[self.__color_min] = color_min + self.state[self.__color_max] = color_max - def __init__(self, preset=None, color_min=0.0, color_max=1.0, **kwargs): - """Scalar bar for the XArray Explorers.""" - super().__init__(location="top") + @property + def preset(self): + return self._preset + + @preset.setter + def preset(self, value): + if value not in PRESETS: + err_msg = f"Preset '{value}' not found." + raise ValueError(err_msg) + self._preset = value + set_preset(self._lut, value) + with self.state: + self.state[self.__preset_image] = to_image(self._lut) - print("Updating the scalar bar") + @property + def preset_image_name(self): + return self.__preset_image - # Activate CSS - self.server.enable_module(base) - self.server.enable_module(vtk_view) + @property + def color_min(self): + return self.state[self.__color_min] + + @color_min.setter + def color_min(self, value): + with self.state: + self.state[self.__color_min] = value - __ns = kwargs.get("namespace", "view") - if __ns == "view": - ScalarBar._next_id += 1 - if ScalarBar._next_id > 1: - __ns = f"view{ScalarBar._next_id}" + @property + def color_min_name(self): + return self.__color_min - self.__preset_key = f"{__ns}_preset" - self.__color_min_key = f"{__ns}_color_min" - self.__color_max_key = f"{__ns}_color_max" - self.__scalarbar_probe_key = f"{__ns}_scalarbar_probe" - self.__scalarbar_probe_key_available = f"{__ns}_scalarbar_probe_available" + @property + def color_max(self): + return self.state[self.__color_max] - if preset is None: - preset = next(iter(PRESETS.keys())) + @color_max.setter + def color_max(self, value): + with self.state: + self.state[self.__color_max] = value + + @property + def color_max_name(self): + return self.__color_max + + @classmethod + def next_id(cls): + """Get the next unique ID for the scalar bar.""" + cls._next_id += 1 + return f"pan3d_scalarbar{cls._next_id}" + + def __init__(self, preset="Fast", color_min=0.0, color_max=1.0, **kwargs): + """Scalar bar for the XArray Explorers.""" self._lut = vtkLookupTable() - set_preset(self._lut, preset) + super().__init__(location="top") + # Activate CSS + self.server.enable_module(base) + self.server.enable_module(vtk_view) - # Initialize state - self.state[self.__preset_key] = to_image(self._lut) - self.state[self.__color_min_key] = color_min - self.state[self.__color_max_key] = color_max + ns = self.next_id() + self.__preset_image = f"{ns}_preset" + self.__color_min = f"{ns}_color_min" + self.__color_max = f"{ns}_color_max" + # Probe enables mouse events for scalar bar + self.__probe_location = f"{ns}_probe_location" + self.__probe_enabled = f"{ns}_probe_enabled" - self.state.setdefault(self.__scalarbar_probe_key, []) + # Initialize state + self.preset = preset + self.set_color_range(color_min, color_max) + self.state[self.__probe_location] = [] + self.state[self.__probe_enabled] = 0 self.state.client_only( - self.__scalarbar_probe_key, self.__scalarbar_probe_key_available + self.__probe_location, + self.__probe_enabled, ) with self: @@ -77,29 +105,29 @@ def __init__(self, preset=None, color_min=0.0, color_max=1.0, **kwargs): **kwargs, ): html.Div( - f"{{{{{self.__color_min_key}.toFixed(6) }}}}", + f"{{{{{self.__color_min}.toFixed(6) }}}}", classes="scalarbar-left", ) html.Img( - src=(self.__preset_key, None), + src=(self.__preset_image, None), style="height: 100%; width: 100%;", classes="rounded-lg border-thin", - mousemove=f"{self.__scalarbar_probe_key} = [$event.x, $event.target.getBoundingClientRect()]", - mouseenter=f"{self.__scalarbar_probe_key_available} = 1", - mouseleave=f"{self.__scalarbar_probe_key_available} = 0", + mousemove=f"{self.__probe_location} = [$event.x, $event.target.getBoundingClientRect()]", + mouseenter=f"{self.__probe_enabled} = 1", + mouseleave=f"{self.__probe_enabled} = 0", __events=["mousemove", "mouseenter", "mouseleave"], ) html.Div( - v_show=(self.__scalarbar_probe_key_available, False), + v_show=(self.__probe_enabled, False), classes="scalar-cursor", style=( - f"`left: ${{{self.__scalarbar_probe_key}?.[0] - {self.__scalarbar_probe_key}?.[1]?.left}}px`", + f"`left: ${{{self.__probe_location}?.[0] - {self.__probe_location}?.[1]?.left}}px`", ), ) html.Div( - f"{{{{ {self.__color_max_key}.toFixed(6) }}}}", + f"{{{{ {self.__color_max}.toFixed(6) }}}}", classes="scalarbar-right", ) html.Span( - f"{{{{ (({self.__color_max_key} - {self.__color_min_key}) * ({self.__scalarbar_probe_key}?.[0] - {self.__scalarbar_probe_key}?.[1]?.left) / {self.__scalarbar_probe_key}?.[1]?.width + {self.__color_min_key}).toFixed(6) }}}}" + f"{{{{ (({self.__color_max} - {self.__color_min}) * ({self.__probe_location}?.[0] - {self.__probe_location}?.[1]?.left) / {self.__probe_location}?.[1]?.width + {self.__color_min}).toFixed(6) }}}}" ) From 617ca335da49e9030a90c3f7f6848fbd9fefdd30 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Fri, 6 Jun 2025 14:05:15 -0700 Subject: [PATCH 4/4] fix (modularize): Redo color_by component --- src/pan3d/explorers/analytics.py | 42 ++-- src/pan3d/explorers/contour.py | 55 +---- src/pan3d/explorers/globe.py | 30 +-- src/pan3d/explorers/slicer.py | 47 +--- src/pan3d/ui/analytics.py | 4 + src/pan3d/ui/collapsible.py | 4 +- src/pan3d/ui/contour.py | 12 +- src/pan3d/ui/globe.py | 60 +---- src/pan3d/ui/preview.py | 143 +++--------- src/pan3d/ui/slicer.py | 11 +- src/pan3d/utils/common.py | 135 +++++++++-- src/pan3d/viewers/preview.py | 37 +-- src/pan3d/widgets/color_by.py | 388 +++++++++++++++++++++---------- src/pan3d/widgets/scalar_bar.py | 104 +++++---- src/trame/widgets/pan3d.py | 3 +- 15 files changed, 538 insertions(+), 537 deletions(-) diff --git a/src/pan3d/explorers/analytics.py b/src/pan3d/explorers/analytics.py index 49dca70c..3a3f7372 100644 --- a/src/pan3d/explorers/analytics.py +++ b/src/pan3d/explorers/analytics.py @@ -18,7 +18,8 @@ from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.convert import to_float -from src.pan3d.widgets.color_by import ScalarBar +from pan3d.widgets.scalar_bar import ScalarBar +from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import html @@ -46,6 +47,11 @@ def __init__(self, xarray=None, source=None, server=None, local_rendering=None): """Create an instance of the AnalyticsExplorer class.""" super().__init__(xarray, source, server, local_rendering) + if self.source is None: + self.source = vtkXArrayRectilinearSource( + input=self.xarray + ) # To initialize the pipeline + self.ui = None self._setup_vtk() self._build_ui() @@ -70,6 +76,7 @@ def _setup_vtk(self): self.mapper = vtkPolyDataMapper( input_connection=self.geometry.output_port, ) + self.actor = vtkActor(mapper=self.mapper, visibility=0) self.interactor.Initialize() @@ -124,7 +131,6 @@ def _build_ui(self, **kwargs): label="File path to save", v_model=("save_dataset_path", ""), hide_details=True, - change=self.property_changed, ) with v3.VCardActions(): v3.VSpacer() @@ -169,9 +175,9 @@ def _build_ui(self, **kwargs): # Scalar bar ScalarBar( + ctx_name="scalar_bar", v_show="!control_expended", v_if="color_by", - img_src="preset_img", ) # # Summary toolbar @@ -190,11 +196,11 @@ def _build_ui(self, **kwargs): xr_update_info="xr_update_info", panel_label="Analytics Explorer", ).ui_content: - self.ctrl.source_update_rendering_panel = RenderingSettings( - self.retrieve_source, - self.retrieve_mapper, - self.update_rendering, - ).update_from_source + RenderingSettings( + ctx_name="rendering", + source=self.source, + update_rendering=self.update_rendering, + ) with v3.VNavigationDrawer( disable_resize_watcher=True, @@ -208,30 +214,15 @@ def _build_ui(self, **kwargs): source=self.source, toggle="chart_expanded" ) - def retrieve_mapper(self): - """Used as a callback to retrieve the mapper.""" - return self.mapper - - def retrieve_source(self): - """Used as a callback to retrieve the source.""" - return self.source - # ----------------------------------------------------- # State change callbacks # ----------------------------------------------------- @change("color_by") - def _on_color_by(self, **__): + def _on_color_by_change_on(self, **kwargs): + super()._on_color_properties_change(**kwargs) self.plotting.update_plot() - @change("color_preset") - def _on_preset_change(self, color_preset, **_): - self.scalar_bar.preset = color_preset - - @change("color_min", "color_max") - def _on_color_range_change(self, color_min, color_max, **_): - self.scalar_bar.set_color_range(color_min, color_max) - @change("scale_x", "scale_y", "scale_z") def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.actor.SetScale( @@ -273,7 +264,6 @@ def update_rendering(self, reset_camera=False): def main(): - print("Launching analytics Explorer") app = AnalyticsExplorer() app.start() diff --git a/src/pan3d/explorers/contour.py b/src/pan3d/explorers/contour.py index 114e7515..98fb6367 100644 --- a/src/pan3d/explorers/contour.py +++ b/src/pan3d/explorers/contour.py @@ -27,8 +27,8 @@ from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.convert import to_float +from pan3d.widgets.scalar_bar import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource -from src.pan3d.widgets.color_by import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import vuetify3 as v3 @@ -145,9 +145,9 @@ def _build_ui(self, **_): # Scalar bar ScalarBar( + ctx_name="scalar_bar", v_show="!control_expended", v_if="color_by", - img_src="preset_img", ) # Save dialog @@ -207,11 +207,11 @@ def _build_ui(self, **_): xr_update_info="xr_update_info", panel_label="Contour Explorer", ).ui_content: - self.ctrl.source_update_rendering_panel = ContourRenderingSettings( - self.retrieve_source, - self.retrieve_mapper, - self.update_rendering, - ).update_from_source + ContourRenderingSettings( + ctx_name="rendering", + source=self.source, + update_rendering=self.update_rendering, + ) def update_rendering(self, reset_camera=False): self.renderer.ResetCamera() @@ -221,14 +221,6 @@ def update_rendering(self, reset_camera=False): self.ctrl.view_reset_camera() - def retrieve_mapper(self): - """Used as a callback to retrieve the mapper.""" - return self.mapper - - def retrieve_source(self): - """Used as a callback to retrieve the source.""" - return self.source - # ----------------------------------------------------- # State change callbacks # ----------------------------------------------------- @@ -252,40 +244,17 @@ def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.ctrl.view_reset_camera() - @change("color_by", "time_idx") - def _on_update_data(self, color_by, time_idx, **_): - if self.source.input is None: - return - - self.source.t_index = time_idx - self.source.arrays = [color_by] + @change("color_by", "nb_contours", "color_min", "color_max") + def _on_color_by_change( + self, color_by, nb_contours, color_min, color_max, **kwargs + ): self.assign.Assign( color_by, vtkDataSetAttributes.SCALARS, vtkDataObject.FIELD_ASSOCIATION_POINTS, ) - self.mapper.SelectColorArray(color_by) - self.mapper.Update() - # update range - if self.last_field != color_by: - self.last_field = color_by - - self.ctrl.view_update() - - @change("nb_contours") - def _on_update_color_range( - self, nb_contours, color_min, color_max, color_preset, **_ - ): self.bands.GenerateValues(nb_contours, [color_min, color_max]) - self.ctrl.view_update() - - @change("color_preset") - def _on_preset_change(self, color_preset, **_): - self.scalar_bar.preset = color_preset - - @change("color_min", "color_max") - def _on_color_range_change(self, color_min, color_max, **_): - self.scalar_bar.set_color_range(color_min, color_max) + super()._on_color_properties_change(**kwargs) def main(): diff --git a/src/pan3d/explorers/globe.py b/src/pan3d/explorers/globe.py index 3ed1d7f3..b020113f 100644 --- a/src/pan3d/explorers/globe.py +++ b/src/pan3d/explorers/globe.py @@ -22,8 +22,8 @@ from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar from pan3d.utils.globe import get_continent_outlines, get_globe, get_globe_textures +from pan3d.widgets.scalar_bar import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource -from src.pan3d.widgets.color_by import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import vuetify3 as v3 @@ -133,9 +133,9 @@ def _build_ui(self, **kwargs): # Scalar bar ScalarBar( + ctx_name="scalar_bar", v_show="!control_expended", v_if="color_by", - img_src="preset_img", ) # Save dialog @@ -194,19 +194,11 @@ def _build_ui(self, **kwargs): xr_update_info="xr_update_info", panel_label="Globe Explorer", ).ui_content: - self.ctrl.source_update_rendering_panel = GlobeRenderingSettings( - self.retrieve_source, - self.retrieve_mapper, - self.update_rendering, - ).update_from_source - - def retrieve_mapper(self): - """Used as a callback to retrieve the mapper.""" - return self.mapper - - def retrieve_source(self): - """Used as a callback to retrieve the source.""" - return self.source + GlobeRenderingSettings( + ctx_name="rendering", + source=self.source, + update_rendering=self.update_rendering, + ) # ----------------------------------------------------- # State change callbacks @@ -260,14 +252,6 @@ def update_rendering(self, reset_camera=False): else: self.ctrl.view_update() - @change("color_preset") - def _on_preset_change(self, color_preset, **_): - self.scalar_bar.preset = color_preset - - @change("color_min", "color_max") - def _on_color_range_change(self, color_min, color_max, **_): - self.scalar_bar.set_color_range(color_min, color_max) - # ----------------------------------------------------------------------------- # Main executable diff --git a/src/pan3d/explorers/slicer.py b/src/pan3d/explorers/slicer.py index 5333e969..876557f4 100644 --- a/src/pan3d/explorers/slicer.py +++ b/src/pan3d/explorers/slicer.py @@ -27,8 +27,8 @@ from pan3d.ui.slicer import SliceRenderingSettings from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar +from pan3d.widgets.scalar_bar import ScalarBar from pan3d.xarray.algorithm import vtkXArrayRectilinearSource -from src.pan3d.widgets.color_by import ScalarBar from trame.decorators import change from trame.ui.vuetify3 import VAppLayout from trame.widgets import html @@ -180,7 +180,7 @@ def _setup_vtk(self): self.plane = plane self.cutter = cutter self.slice_actor = slice_actor - self.slice_mapper = slice_mapper + self.mapper = slice_mapper outline = vtkOutlineFilter() outline_actor = vtkActor() @@ -253,9 +253,9 @@ def _build_ui(self, *args, **kwargs): # Scalar bar ScalarBar( + ctx_name="scalar_bar", v_show="!control_expended", v_if="color_by", - img_src="preset_img", ) # Save dialog @@ -322,19 +322,11 @@ def _build_ui(self, *args, **kwargs): xr_update_info="xr_update_info", panel_label="Slice Explorer", ).ui_content: - self.ctrl.source_update_rendering_panel = SliceRenderingSettings( - self.retrieve_source, - self.retrieve_mapper, - self.update_rendering, - ).update_from_source - - def retrieve_mapper(self): - """Used as a callback to retrieve the mapper.""" - return self.slice_mapper - - def retrieve_source(self): - """Used as a callback to retrieve the source.""" - return self.source + SliceRenderingSettings( + ctx_name="rendering", + source=self.source, + update_rendering=self.update_rendering, + ) def update_rendering(self, reset_camera=False): self.renderer.ResetCamera() @@ -344,14 +336,6 @@ def update_rendering(self, reset_camera=False): self.ctrl.view_reset_camera() - @change("color_preset") - def _on_preset_change(self, color_preset, **_): - self.scalar_bar.preset = color_preset - - @change("color_min", "color_max") - def _on_color_range_change(self, color_min, color_max, **_): - self.scalar_bar.set_color_range(color_min, color_max) - # ------------------------------------------------------------------------- # Property API # ------------------------------------------------------------------------- @@ -424,21 +408,6 @@ def scale_axis(self, sfac): self.outline_actor.SetScale(*sfac) self.on_view_mode_change(s.view_mode) - @property - def color_map(self): - """ - Returns the color map currently used for visualization - """ - return self.state.cmap - - @color_map.setter - def color_map(self, cmap): - """ - Sets the color map used for visualization - """ - with self.state: - self.state.cmap = cmap - # ------------------------------------------------------------------------- # UI triggers # ------------------------------------------------------------------------- diff --git a/src/pan3d/ui/analytics.py b/src/pan3d/ui/analytics.py index 73fe2769..bc61800b 100644 --- a/src/pan3d/ui/analytics.py +++ b/src/pan3d/ui/analytics.py @@ -6,7 +6,11 @@ from trame.widgets import vuetify3 as v3 try: + import logging + import xcdat # noqa: F401 + + logging.getLogger().setLevel(logging.CRITICAL + 1) except ModuleNotFoundError as e: print( f""" diff --git a/src/pan3d/ui/collapsible.py b/src/pan3d/ui/collapsible.py index 166e8e2b..a0b8ae9a 100644 --- a/src/pan3d/ui/collapsible.py +++ b/src/pan3d/ui/collapsible.py @@ -6,8 +6,8 @@ class CollapsableSection(AbstractElement): id_count = 0 - def __init__(self, title, var_name=None, expended=False): - super().__init__(None) + def __init__(self, title, var_name=None, expended=False, **kwargs): + super().__init__(None, **kwargs) CollapsableSection.id_count += 1 show = var_name or f"show_section_{CollapsableSection.id_count}" with v3.VCardSubtitle( diff --git a/src/pan3d/ui/contour.py b/src/pan3d/ui/contour.py index e8fd5e3c..f6ccd773 100644 --- a/src/pan3d/ui/contour.py +++ b/src/pan3d/ui/contour.py @@ -7,10 +7,10 @@ class ContourRenderingSettings(RenderingSettingsBasic): - def __init__(self, retrieve_source, retrieve_mapper, update_rendering): - super().__init__(retrieve_source, retrieve_mapper, update_rendering) + def __init__(self, source, update_rendering, **kwargs): + super().__init__(source, update_rendering, **kwargs) - self._retrieve_source = retrieve_source + self.source = source with self.content: # Actor scaling @@ -105,7 +105,7 @@ def __init__(self, retrieve_source, retrieve_mapper, update_rendering): ): v3.VSlider( prepend_icon="mdi-clock-outline", - v_model=("time_idx", 0), + v_model=("slice_t", 0), min=0, max=("slice_t_max", 0), step=1, @@ -128,12 +128,10 @@ def __init__(self, retrieve_source, retrieve_mapper, update_rendering): ) def update_from_source(self, source=None): - state = self.state - source = source or self._retrieve_source() if source is None: return - with self.state: + with self.state as state: state.data_arrays_available = source.available_arrays state.data_arrays = source.arrays state.color_by = None diff --git a/src/pan3d/ui/globe.py b/src/pan3d/ui/globe.py index 15df5b70..7ffbedef 100644 --- a/src/pan3d/ui/globe.py +++ b/src/pan3d/ui/globe.py @@ -1,18 +1,17 @@ import math from pan3d.utils.common import RenderingSettingsBasic -from pan3d.utils.constants import SLICE_VARS, XYZ +from pan3d.utils.constants import XYZ from pan3d.utils.convert import max_str_length -from trame.decorators import TrameApp, change from trame.widgets import html from trame.widgets import vuetify3 as v3 -@TrameApp() class GlobeRenderingSettings(RenderingSettingsBasic): - def __init__(self, retrieve_source, retrieve_mapper, update_rendering): - super().__init__(retrieve_source, retrieve_mapper, update_rendering) - self._retrieve_source = retrieve_source + def __init__(self, source, update_rendering, **kwargs): + super().__init__(source, update_rendering, **kwargs) + + self.source = source with self.content: v3.VDivider() @@ -342,12 +341,10 @@ def __init__(self, retrieve_source, retrieve_mapper, update_rendering): ) def update_from_source(self, source=None): - state = self.state - source = source or self._retrieve_source() if source is None: return - with self.state: + with self.state as state: state.data_arrays_available = source.available_arrays state.data_arrays = source.arrays state.color_by = None @@ -356,7 +353,7 @@ def update_from_source(self, source=None): slices = source.slices for axis in XYZ: # default - axis_extent = self.state.slice_extents.get(getattr(source, axis)) + axis_extent = state.slice_extents.get(getattr(source, axis)) state[f"slice_{axis}_range"] = axis_extent state[f"slice_{axis}_cut"] = 0 state[f"slice_{axis}_step"] = 1 @@ -386,46 +383,3 @@ def update_from_source(self, source=None): state.max_time_index_width = math.ceil( 0.6 + (math.log10(state.slice_t_max + 1) + 1) * 2 * 0.58 ) - - @change("slice_t", *[var.format(axis) for axis in XYZ for var in SLICE_VARS]) - def on_change(self, slice_t, **_): - if self.state.import_pending: - return - source = self._retrieve_source() - if source is None: - return - - state = self.state - slices = {source.t: slice_t} - for axis in XYZ: - axis_name = getattr(source, axis) - if axis_name is None: - continue - - if state[f"slice_{axis}_type"] == "range": - if state[f"slice_{axis}_range"] is None: - continue - - slices[axis_name] = [ - *state[f"slice_{axis}_range"], - int(state[f"slice_{axis}_step"]), - ] - slices[axis_name][1] += 1 # end is exclusive - else: - slices[axis_name] = state[f"slice_{axis}_cut"] - - source.slices = slices - ds = source() - state.dataset_bounds = ds.bounds - - self.ctrl.view_reset_clipping_range() - self.ctrl.view_update() - - @change("slice_t") - def _on_slice_t(self, slice_t, **_): - if self.state.import_pending: - return - source = self._retrieve_source() - if source is not None: - source.t_index = slice_t - self.ctrl.view_update() diff --git a/src/pan3d/ui/preview.py b/src/pan3d/ui/preview.py index 053dae59..651a05e1 100644 --- a/src/pan3d/ui/preview.py +++ b/src/pan3d/ui/preview.py @@ -1,20 +1,17 @@ import math -from pathlib import Path -from pan3d import catalogs as pan3d_catalogs from pan3d.utils.common import RenderingSettingsBasic -from pan3d.utils.constants import SLICE_VARS, XYZ +from pan3d.utils.constants import XYZ from pan3d.utils.convert import max_str_length -from trame.decorators import change from trame.widgets import html from trame.widgets import vuetify3 as v3 class RenderingSettings(RenderingSettingsBasic): - def __init__(self, retrieve_source, retrieve_mapper, update_rendering, **kwargs): - super().__init__(retrieve_source, retrieve_mapper, update_rendering, **kwargs) + def __init__(self, source, update_rendering, **kwargs): + super().__init__(source, update_rendering, **kwargs) - self._retrieve_source = retrieve_source + self.source = source self.state.setdefault("slice_extents", {}) self.state.setdefault("axis_names", []) self.state.setdefault("t_labels", []) @@ -310,130 +307,44 @@ def __init__(self, retrieve_source, retrieve_mapper, update_rendering, **kwargs) ) def update_from_source(self, source=None): - if source is None: - source = self._retrieve_source() + self.source = source or self.source - with self.state: - self.state.data_arrays_available = source.available_arrays - self.state.data_arrays = source.arrays - self.state.color_by = None - self.state.axis_names = [source.x, source.y, source.z] - self.state.slice_extents = source.slice_extents + with self.state as state: + state.data_arrays_available = source.available_arrays + state.data_arrays = source.arrays + # state.color_by = None + state.axis_names = [source.x, source.y, source.z] + state.slice_extents = source.slice_extents slices = source.slices for axis in XYZ: # default - axis_extent = self.state.slice_extents.get(getattr(source, axis)) - self.state[f"slice_{axis}_range"] = axis_extent - self.state[f"slice_{axis}_cut"] = 0 - self.state[f"slice_{axis}_step"] = 1 - self.state[f"slice_{axis}_type"] = "range" + axis_extent = state.slice_extents.get(getattr(source, axis)) + state[f"slice_{axis}_range"] = axis_extent + state[f"slice_{axis}_cut"] = 0 + state[f"slice_{axis}_step"] = 1 + state[f"slice_{axis}_type"] = "range" # use slice info if available axis_slice = slices.get(getattr(source, axis)) if axis_slice is not None: if isinstance(axis_slice, int): # cut - self.state[f"slice_{axis}_cut"] = axis_slice - self.state[f"slice_{axis}_type"] = "cut" + state[f"slice_{axis}_cut"] = axis_slice + state[f"slice_{axis}_type"] = "cut" else: # range - self.state[f"slice_{axis}_range"] = [ + state[f"slice_{axis}_range"] = [ axis_slice[0], axis_slice[1] - 1, ] # end is inclusive - self.state[f"slice_{axis}_step"] = axis_slice[2] + state[f"slice_{axis}_step"] = axis_slice[2] # Update time - self.state.slice_t = source.t_index - self.state.slice_t_max = source.t_size - 1 - self.state.t_labels = source.t_labels - self.state.max_time_width = math.ceil( - 0.58 * max_str_length(self.state.t_labels) - ) - if self.state.slice_t_max > 0: - self.state.max_time_index_width = math.ceil( - 0.6 + (math.log10(self.state.slice_t_max + 1) + 1) * 2 * 0.58 + state.slice_t = source.t_index + state.slice_t_max = source.t_size - 1 + state.t_labels = source.t_labels + state.max_time_width = math.ceil(0.58 * max_str_length(state.t_labels)) + if state.slice_t_max > 0: + state.max_time_index_width = math.ceil( + 0.6 + (math.log10(state.slice_t_max + 1) + 1) * 2 * 0.58 ) - - @change("data_origin_source") - def _on_data_origin_source(self, data_origin_source, **kwargs): - if self.state.import_pending: - return - - self.state.data_origin_id = "" - results, *_ = pan3d_catalogs.search(data_origin_source) - self.state.data_origin_ids = [v["name"] for v in results] - self.state.data_origin_id_to_desc = { - v["name"]: v["description"] for v in results - } - - @change("data_origin_id") - def _on_data_origin_id(self, data_origin_id, data_origin_source, **kwargs): - if self.state.import_pending: - return - - self.state.load_button_text = "Load" - self.state.can_load = True - - if data_origin_source == "file": - self.state.data_origin_id_error = not Path(data_origin_id).exists() - elif self.state.data_origin_id_error: - self.state.data_origin_id_error = False - - @change("slice_t", *[var.format(axis) for axis in XYZ for var in SLICE_VARS]) - def on_change(self, slice_t, **_): - source = self._retrieve_source() - if source is None: - return - - if self.state.import_pending: - return - - slices = {source.t: slice_t} - for axis in XYZ: - axis_name = getattr(source, axis) - if axis_name is None: - continue - - if self.state[f"slice_{axis}_type"] == "range": - if self.state[f"slice_{axis}_range"] is None: - continue - slices[axis_name] = [ - *self.state[f"slice_{axis}_range"], - int(self.state[f"slice_{axis}_step"]), - ] - slices[axis_name][1] += 1 # end is exclusive - else: - slices[axis_name] = self.state[f"slice_{axis}_cut"] - - source.slices = slices - ds = source() - self.state.dataset_bounds = ds.bounds - - self.ctrl.view_reset_clipping_range() - self.ctrl.view_update() - - @change("slice_t") - def _on_slice_t(self, slice_t, **_): - source = self._retrieve_source() - if source is None: - return - if self.state.import_pending: - return - - source.t_index = slice_t - self.ctrl.view_update() - - @change("data_arrays") - def _on_array_selection(self, data_arrays, **_): - if self.state.import_pending: - return - - self.state.dirty_data = True - if len(data_arrays) == 1: - self.state.color_by = data_arrays[0] - elif len(data_arrays) == 0: - self.state.color_by = None - source = self._retrieve_source() - if source is not None: - source.arrays = data_arrays diff --git a/src/pan3d/ui/slicer.py b/src/pan3d/ui/slicer.py index 9d1f1a17..ac0dfe93 100644 --- a/src/pan3d/ui/slicer.py +++ b/src/pan3d/ui/slicer.py @@ -7,9 +7,10 @@ class SliceRenderingSettings(RenderingSettingsBasic): - def __init__(self, retrieve_source, retrieve_mapper, update_rendering): - super().__init__(retrieve_source, retrieve_mapper, update_rendering) - self._retrieve_source = retrieve_source + def __init__(self, source, update_rendering, **kwargs): + super().__init__(source, update_rendering, **kwargs) + + self.source = source style = {"density": "compact", "hide_details": True} with self.content: @@ -168,8 +169,6 @@ def __init__(self, retrieve_source, retrieve_mapper, update_rendering): ) def update_from_source(self, source=None): - state = self.state - source = source or self._retrieve_source() if source is None: return @@ -180,7 +179,7 @@ def update_from_source(self, source=None): 0.5 * (bounds[2] + bounds[3]), 0.5 * (bounds[4] + bounds[5]), ] - with state: + with self.state as state: state.data_arrays_available = source.available_arrays state.data_arrays = source.arrays diff --git a/src/pan3d/utils/common.py b/src/pan3d/utils/common.py index cc620c25..9ec3470b 100644 --- a/src/pan3d/utils/common.py +++ b/src/pan3d/utils/common.py @@ -2,20 +2,22 @@ import traceback from pathlib import Path +import numpy as np + from pan3d import catalogs as pan3d_catalogs from pan3d.ui.collapsible import CollapsableSection from pan3d.ui.css import base, preview +from pan3d.utils.constants import SLICE_VARS, XYZ from pan3d.utils.convert import update_camera from pan3d.widgets.color_by import ColorBy from pan3d.xarray.algorithm import vtkXArrayRectilinearSource -from trame.app import asynchronous, get_server -from trame.decorators import TrameApp, change +from trame.app import TrameApp, asynchronous +from trame.decorators import change from trame.widgets import html from trame.widgets import vuetify3 as v3 -@TrameApp() -class Explorer: +class Explorer(TrameApp): def __init__(self, xarray=None, source=None, server=None, local_rendering=None): """ Parameters: @@ -32,7 +34,7 @@ def __init__(self, xarray=None, source=None, server=None, local_rendering=None): - `--wasm`: Use WASM for local rendering - `--vtkjs`: Use vtk.js for local rendering """ - self.server = get_server(server, client_type="vue3") + super().__init__(server, client_type="vue3") parser = self.server.cli explorer = parser.add_argument_group("Explorer Properties") @@ -129,7 +131,7 @@ def _process_cli(self, **_): elif self.xarray is not None: self.state.show_data_information = True self.ctrl.xr_update_info(self.source.input, self.source.available_arrays) - self.ctrl.source_update_rendering_panel(self.source) + self.ctx.rendering.update_from_source(self.source) def start(self, **kwargs): """Initialize the UI and start the server for XArray Viewer.""" @@ -183,6 +185,60 @@ def _on_data_origin_id(self, data_origin_id, data_origin_source, **kwargs): elif self.state.data_origin_id_error: self.state.data_origin_id_error = False + # ----------------------------------------------------- + # UI Components + # ----------------------------------------------------- + @change("color_by", "color_preset", "color_min", "color_max", "nan_color") + def _on_color_properties_change(self, **_): + if self.mapper: + self.ctx.rendering.color_by.configure_mapper(self.mapper) + self.ctx.scalar_bar.preset = self.state.color_preset + self.ctx.scalar_bar.set_color_range( + self.state.color_min, self.state.color_max + ) + self.ctrl.view_update() + + @change("slice_t", *[var.format(axis) for axis in XYZ for var in SLICE_VARS]) + def on_change(self, slice_t, **_): + source = self.source + if source is None: + return + + if self.state.import_pending: + return + + slices = {source.t: slice_t} + for axis in XYZ: + axis_name = getattr(source, axis) + if axis_name is None: + continue + + if self.state[f"slice_{axis}_type"] == "range": + if self.state[f"slice_{axis}_range"] is None: + continue + slices[axis_name] = [ + *self.state[f"slice_{axis}_range"], + int(self.state[f"slice_{axis}_step"]), + ] + slices[axis_name][1] += 1 # end is exclusive + else: + slices[axis_name] = self.state[f"slice_{axis}_cut"] + + source.slices = slices + ds = source() + self.state.dataset_bounds = ds.bounds + + self.ctrl.view_reset_clipping_range() + self.ctrl.view_update() + + @change("slice_t") + def _on_slice_t(self, slice_t, **_): + if self.state.import_pending: + return + + self.source.t_index = slice_t + self.ctrl.view_update() + # ----------------------------------------------------- # Triggers # ----------------------------------------------------- @@ -225,7 +281,7 @@ def load_dataset(self, source, id, order="C", config=None): # Extract UI self.ctrl.xr_update_info(self.source.input, self.source.available_arrays) - self.ctrl.source_update_rendering_panel(self.source) + self.ctx.rendering.update_from_source(self.source) # no error self.state.data_origin_error = False @@ -696,11 +752,10 @@ def __init__( self.ctrl[xr_update_info] = DataInformation().update_information -@TrameApp() class RenderingSettingsBasic(CollapsableSection): - def __init__(self, retrieve_source, retrieve_mapper, update_rendering): - super().__init__("Rendering", "show_rendering") - self._retrieve_source = retrieve_source + def __init__(self, source=None, update_rendering=None, **kwargs): + super().__init__(self, "Rendering", "show_rendering", **kwargs) + self.source = source with self.content: v3.VSelect( @@ -718,16 +773,62 @@ def __init__(self, retrieve_source, retrieve_mapper, update_rendering): variant="solo", ) v3.VDivider() - ColorBy(retrieve_source=retrieve_source, retrieve_mapper=retrieve_mapper) + self.color_by = ColorBy( + color_by_name="color_by", + preset_name="color_preset", + color_min_name="color_min", + color_max_name="color_max", + nan_color_name="nan_color", + reset_color_range=self.reset_color_range, + ) + + def reset_color_range(self): + """Reset the color range to the min and max values of the selected data array.""" + color_by = self.color_by.color_by + ds = self.source() + array = ( + ds.point_data[color_by] + if color_by in ds.point_data.keys() + else ds.cell_data[color_by] + if color_by in ds.cell_data.keys() + else None + ) + if array is not None: + self.color_by.color_min = float(np.min(array)) + self.color_by.color_max = float(np.max(array)) + else: + self.color_by.color_min = 0.0 + self.color_by.color_max = 1.0 + + self.ctrl.view_update() + + def _get_array_info(self): + if self.source is None or self.source.input is None: + return [] + ds = self.source() + array_info = [] + for association in ["point_data", "cell_data", "field_data"]: + arrays = getattr(ds, association, None) + if arrays is not None: + for array in arrays: + array_info.append( + { + "name": array.GetName(), + "min": np.min(array), + "max": np.max(array), + "assoc": association, + } + ) + return array_info @change("data_arrays") def _on_array_selection(self, data_arrays, **_): - if self.state.import_pending: - return + # if self.state.import_pending: + # return self.state.dirty_data = True - source = self._retrieve_source() - if source is not None: - source.arrays = data_arrays + if self.source is not None: + self.source.arrays = data_arrays + self.color_by.data_arrays = self._get_array_info() def update_from_source(self, source=None): raise NotImplementedError( diff --git a/src/pan3d/viewers/preview.py b/src/pan3d/viewers/preview.py index 5fd99270..495c000b 100644 --- a/src/pan3d/viewers/preview.py +++ b/src/pan3d/viewers/preview.py @@ -73,14 +73,6 @@ def _setup_vtk(self): # UI # ------------------------------------------------------------------------- - def retrieve_mapper(self): - """Used as a callback to retrieve the mapper.""" - return self.mapper - - def retrieve_source(self): - """Used as a callback to retrieve the source.""" - return self.source - def _build_ui(self, **kwargs): self.state.trame__title = "XArray Viewer" @@ -95,7 +87,8 @@ def _build_ui(self, **kwargs): ) # Scalar bar - self.scalar_bar = ScalarBar( + ScalarBar( + ctx_name="scalar_bar", v_show="!control_expended", v_if="color_by", ) @@ -155,11 +148,11 @@ def _build_ui(self, **kwargs): export_file_download=self.export_state, xr_update_info="xr_update_info", ).ui_content: - self.ctrl.source_update_rendering_panel = RenderingSettings( - self.retrieve_source, - self.retrieve_mapper, - self.update_rendering, - ).update_from_source + RenderingSettings( + ctx_name="rendering", + source=self.source, + update_rendering=self.update_rendering, + ) # ----------------------------------------------------- # State change callbacks @@ -184,22 +177,6 @@ def _on_scale_change(self, scale_x, scale_y, scale_z, **_): self.ctrl.view_reset_camera() - @change("color_preset") - def _on_preset_change(self, color_preset, **_): - self.scalar_bar.preset = color_preset - - @change("color_min", "color_max") - def _on_color_range_change(self, color_min, color_max, **_): - self.scalar_bar.set_color_range(color_min, color_max) - - @change("data_origin_order") - def _on_order_change(self, **_): - if self.state.import_pending: - return - - self.state.load_button_text = "Load" - self.state.can_load = True - # ----------------------------------------------------- # Triggers # ----------------------------------------------------- diff --git a/src/pan3d/widgets/color_by.py b/src/pan3d/widgets/color_by.py index 5d1f9561..2f9e5c5f 100644 --- a/src/pan3d/widgets/color_by.py +++ b/src/pan3d/widgets/color_by.py @@ -1,68 +1,129 @@ +from typing import Optional + from vtkmodules.vtkCommonCore import vtkLookupTable -from vtkmodules.vtkRenderingCore import vtkMapper from pan3d.utils.convert import to_image from pan3d.utils.presets import PRESETS, set_preset from trame.widgets import html from trame.widgets import vuetify3 as v3 +POINT_DATA = "point_data" +CELL_DATA = "cell_data" +FIELD_DATA = "field_data" + class ColorBy(html.Div): - """Color settings for the XArray Explorers. - Arguments: - source: The source of the data to be colored. - color_by: The name of the data array to color by. - data_arrays: The list of available data arrays. - color_min: The minimum value for the color range. - color_max: The maximum value for the color range. - nan_color: The color to use for NaN values. - color_preset: The name of the color preset to use. - color_presets: The list of available color presets. + """ + Color settings for the XArray Explorers. """ + _next_id = 0 + + @classmethod + def next_id(cls): + """Get the next unique ID for the scalar bar.""" + cls._next_id += 1 + return f"pan3d_scalarbar{cls._next_id}" + def __init__( self, - retrieve_source=None, - retrieve_mapper=None, - color_by="color_by", - data_arrays="data_arrays", - color_min="color_min", - color_max="color_max", - nan_color="nan_color", - color_preset="color_preset", - color_presets="color_presets", - preset_img="preset_img", + data_arrays: Optional[list[dict]] = None, + color_by=None, + preset="Fast", + color_min=0.0, + color_max=1.0, + nan_color=0, + color_by_name=None, + preset_name=None, + color_min_name=None, + color_max_name=None, + nan_color_name=None, + reset_color_range=None, **kwargs, ): + """ + Initialize the ColorBy UI component. + + Parameters + ---------- + data_arrays : List[Dict], optional + A list of dictionaries representing available arrays to color by. + Each dictionary should contain keys such for name, association, and array min and max. + The association can be 'point_data', 'cell_data', or 'field_data' + e.g. [{'name' : 'temperature', 'assoc' : 'point_data', 'min' : 0.0, 'max' : 100.0}] + + color_by : str or None, optional + The name of the currently selected array to color by. + + preset : str, optional + Name of the colormap preset to use. Defaults to "Fast". + + color_min : float, optional + Minimum value of the color mapping range. Defaults to 0.0. + + color_max : float, optional + Maximum value of the color mapping range. Defaults to 1.0. + + nan_color : int or tuple, optional + Color to use for NaN values (index into colormap). Defaults to 0. + + color_by_name : str or None, optional + Name of the UI variable to bind the `color_by` selection to. + + preset_name : str or None, optional + Name of the UI variable to bind the `preset` selection to. + + color_min_name : str or None, optional + Name of the UI variable to bind the `color_min` value to. + + color_max_name : str or None, optional + Name of the UI variable to bind the `color_max` value to. + + nan_color_name : str or None, optional + Name of the UI variable to bind the `nan_color` value to. + + reset_color_range : callable or None, optional + Optional callback function to reset the color range to the array's min/max. + + **kwargs : dict + Additional keyword arguments passed to the parent `html.Div` component. + """ + + self._lut = vtkLookupTable() super().__init__(**kwargs) - self.lut = vtkLookupTable() - - # initialize component specific variables - self._retrieve_source = retrieve_source - self._retrieve_mapper = retrieve_mapper - self._color_by = color_by - self._data_arrays = data_arrays - self._color_min = color_min - self._color_max = color_max - self._nan_color = nan_color - self._color_preset = color_preset - self._color_presets = color_presets - self._preset_img = preset_img - - # Track state changes - self.state.change(data_arrays)(self._on_change_data_arrays) - self.state.change(color_by)(self._on_change_color_by) - self.state.change(color_min, color_max, color_preset, nan_color)( - self._on_change_properties - ) + ns = self.next_id() + # Variables that serve and input/output (interactive) can be user specified + self.__color_by = color_by_name or f"{ns}_color_by" + self.__color_preset = preset_name or f"{ns}_preset" + self.__color_min = color_min_name or f"{ns}_color_min" + self.__color_max = color_max_name or f"{ns}_color_max" + self.__nan_color = nan_color_name or f"{ns}_nan_color" + + # Variables that are only input or only output do not need user specification + self.__data_arrays = f"{ns}_data_arrays" + self.__preset_image = f"{ns}_preset_img" + self.__color_presets = f"{ns}_color_presets" + + # Register changes based on state update within the widget + self.state.change(self.__color_preset)(self.set_preset) + self.state.change(self.__color_by)(self.set_color_by) + self.state.change(self.__nan_color)(self.set_nan_color) + + self.__array_infos = None + self.data_arrays = data_arrays + self.color_by = color_by + self.preset = preset + self.color_min = color_min + self.color_max = color_max + self.nan_color = nan_color with self: v3.VSelect( placeholder="Color By", prepend_inner_icon="mdi-format-color-fill", - v_model=(color_by, None), - items=(data_arrays, []), + v_model=(self.__color_by, None), + items=(self.__data_arrays, []), clearable=True, hide_details=True, density="compact", @@ -74,7 +135,7 @@ def __init__( with v3.VCol(): v3.VTextField( prepend_inner_icon="mdi-water-minus", - v_model_number=(color_min, 0.45), + v_model_number=(self.__color_min, 0.45), type="number", hide_details=True, density="compact", @@ -85,7 +146,7 @@ def __init__( with v3.VCol(): v3.VTextField( prepend_inner_icon="mdi-water-plus", - v_model_number=(color_max, 5.45), + v_model_number=(self.__color_max, 5.45), type="number", hide_details=True, density="compact", @@ -101,20 +162,20 @@ def __init__( flat=True, variant="outlined", classes="mx-2", - click=self.reset_color_range, + click=reset_color_range, ) # v3.VDivider() with html.Div(classes="mx-2"): html.Img( - src=("preset_img", None), + src=(self.__preset_image, None), style="height: 0.75rem; width: 100%;", classes="rounded-lg border-thin", ) v3.VSelect( placeholder="Color Preset", prepend_inner_icon="mdi-palette", - v_model=(color_preset, "Fast"), - items=(color_presets, list(PRESETS.keys())), + v_model=(self.__color_preset, "Fast"), + items=(self.__color_presets, list(PRESETS.keys())), hide_details=True, density="compact", flat=True, @@ -126,7 +187,7 @@ def __init__( ): with html.Template(v_slot_activator="{ props }"): with v3.VItemGroup( - v_model=nan_color, + v_model=self.__nan_color, v_bind="props", classes="d-inline-flex ga-1 pa-2", mandatory="force", @@ -159,84 +220,165 @@ def __init__( click="toggle", ) - def _on_change_data_arrays(self, **__): - state = self.state - data_arrays = self.state[self._data_arrays] - color_by = state[self._color_by] - if len(data_arrays) == 0: - state[self._color_by] = None + @property + def data_arrays(self): + """ + Returns the arrays available to color the data by based on the list of dictionaries representing array metadata + Each dictionary contains keys such for name, association, and array min and max. + The association can be either of 'point_data', 'cell_data', or 'field_data' + e.g. [{'name' : 'temperature', 'assoc' : 'point_data', 'min' : 0.0, 'max' : 100.0}] + """ + return self.__array_infos + + @data_arrays.setter + def data_arrays(self, array_info: list[dict]): + """ + Controls the arrays available to color the data by based on the list of dictionaries representing array metadata + Each dictionary should contain keys such for name, association, and array min and max. + The association can be either of 'point_data', 'cell_data', or 'field_data' + e.g. [{'name' : 'temperature', 'assoc' : 'point_data', 'min' : 0.0, 'max' : 100.0}] + """ + if array_info is None: + return + self.__array_infos = array_info + data_arrays = [info["name"] for info in array_info] + self.state[self.__data_arrays] = data_arrays + + color_by = self.color_by + # If the data arrays are empty, set color_by to None + if array_info is None or len(data_arrays) == 0: + self.color_by = None + # If the color_by is not in the new data arrays, reset it to the first available elif color_by is None or color_by not in data_arrays: - state[self._color_by] = data_arrays[0] + self.color_by = data_arrays[0] - def _on_change_color_by(self, **__): - state = self.state - source = self._retrieve_source() - mapper: vtkMapper = self._retrieve_mapper() + @property + def color_by(self): + return self.state[self.__color_by] - if source is None: - return + def set_color_by(self, **kwargs): + self.color_by = self.state[self.__color_by] + + @color_by.setter + def color_by(self, value): + """Set the array to color by.""" + with self.state: + self.state[self.__color_by] = value + if self.__array_infos: + info = next( + (info for info in self.__array_infos if info["name"] == value), None + ) + if info is not None: + self.color_min = float(info.get("min", self.color_min)) + self.color_max = float(info.get("max", self.color_max)) + else: + self.color_min = 0.0 + self.color_max = 1.0 + + @property + def color_by_name(self): + return self.__color_by + + def set_preset(self, **_): + """Set the color preset for the scalar bar.""" + self.preset = self.state[self.__color_preset] + + @property + def preset(self): + return self._preset + + @preset.setter + def preset(self, value): + """Set the color preset to color the data by.""" + if value not in PRESETS: + err_msg = f"Preset '{value}' not found." + raise ValueError(err_msg) + self._preset = value + set_preset(self._lut, value) + with self.state: + self.state[self.__preset_image] = to_image(self._lut) + + @property + def preset_image_name(self): + return self.__preset_image - color_by = state[self._color_by] + def set_color_range(self, color_min, color_max): + """Set the color range for the scalar bar.""" + self.state[self.__color_min] = color_min + self.state[self.__color_max] = color_max - ds = source() - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() + @property + def color_min(self): + return self.state[self.__color_min] - state[self._color_min] = min_value - state[self._color_max] = max_value + @color_min.setter + def color_min(self, value): + """Set the minimum value of the color mapping range.""" + with self.state: + self.state[self.__color_min] = value - if mapper is not None: - mapper.SetLookupTable(self.lut) - mapper.SelectColorArray(color_by) - mapper.SetScalarModeToUsePointFieldData() - mapper.InterpolateScalarsBeforeMappingOn() - mapper.SetScalarVisibility(1) - else: - if mapper is not None: + @property + def color_min_name(self): + return self.__color_min + + @property + def color_max(self): + return self.state[self.__color_max] + + @color_max.setter + def color_max(self, value): + """Set the maximum value of the color mapping range.""" + with self.state: + self.state[self.__color_max] = value + + @property + def color_max_name(self): + return self.__color_max + + @property + def nan_color(self): + nan_colors = self.state.nan_colors + nan_color = self.state[self.__nan_color] + return nan_colors.get(nan_color, [0, 0, 0, 0]) + + @nan_color.setter + def nan_color(self, value: int): + """Set the color for NaN values.""" + with self.state: + self.state[self.__nan_color] = value + nan_color = self.state.nan_colors[value] + self._lut.SetNanColor(nan_color) + + def set_nan_color(self, **_): + nan_colors = self.state.nan_colors + nan_color = nan_colors[self.state[self.__nan_color]] + self._lut.SetNanColor(nan_color) + + def configure_mapper(self, mapper): + """Configure the color mapper with the current settings for any data association.""" + # Find the association type for the selected array + assoc = None + if self.__array_infos: + info = next( + (info for info in self.__array_infos if info["name"] == self.color_by), + None, + ) + if info is not None: + assoc = info.get("assoc", POINT_DATA).lower() + else: mapper.SetScalarVisibility(0) - state[self._color_min] = 0 - state[self._color_max] = 1 - - def _on_change_properties(self, **__): - """Change the color properties based on the selected data array,preset and range""" - state = self.state - mapper: vtkMapper = self._retrieve_mapper() - - color_min = state[self._color_min] - color_max = state[self._color_max] - color_min = float(color_min) - color_max = float(color_max) - if mapper is not None: - mapper.SetLookupTable(self.lut) - mapper.SetScalarRange(color_min, color_max) - - nan_colors = state.nan_colors - nan_color = state[self._nan_color] - color = nan_colors[nan_color] - self.lut.SetNanColor(color) - - preset = state[self._color_preset] - set_preset(self.lut, preset) - state.preset_img = to_image(self.lut, 255) - - self.ctrl.view_update() - - def reset_color_range(self): - """Reset the color range to the min and max values of the selected data array.""" - state = self.state - color_by = state[self._color_by] - source = self._retrieve_source() - ds = source() - - if color_by in ds.point_data.keys(): # vtk is missing in iter - array = ds.point_data[color_by] - min_value, max_value = array.GetRange() - - state[self._color_min] = min_value - state[self._color_max] = max_value - else: - state[self._color_min] = 0 - state[self._color_max] = 1 - - self.ctrl.view_update() + return + + # Set the color array and scalar mode based on association + if assoc == POINT_DATA: + mapper.SetScalarModeToUsePointFieldData() + elif assoc == CELL_DATA: + mapper.SetScalarModeToUseCellFieldData() + elif assoc == FIELD_DATA: + mapper.SetScalarModeToUseFieldData() + + set_preset(self._lut, self.preset) + mapper.SetLookupTable(self._lut) + mapper.SelectColorArray(self.color_by) + mapper.SetScalarRange(self.color_min, self.color_max) + mapper.SetScalarVisibility(1) diff --git a/src/pan3d/widgets/scalar_bar.py b/src/pan3d/widgets/scalar_bar.py index 676432ae..adc30499 100644 --- a/src/pan3d/widgets/scalar_bar.py +++ b/src/pan3d/widgets/scalar_bar.py @@ -14,65 +14,18 @@ class ScalarBar(v3.VTooltip): _next_id = 0 - def set_color_range(self, color_min, color_max): - """Set the color range for the scalar bar.""" - self.state[self.__color_min] = color_min - self.state[self.__color_max] = color_max - - @property - def preset(self): - return self._preset - - @preset.setter - def preset(self, value): - if value not in PRESETS: - err_msg = f"Preset '{value}' not found." - raise ValueError(err_msg) - self._preset = value - set_preset(self._lut, value) - with self.state: - self.state[self.__preset_image] = to_image(self._lut) - - @property - def preset_image_name(self): - return self.__preset_image - - @property - def color_min(self): - return self.state[self.__color_min] - - @color_min.setter - def color_min(self, value): - with self.state: - self.state[self.__color_min] = value - - @property - def color_min_name(self): - return self.__color_min - - @property - def color_max(self): - return self.state[self.__color_max] - - @color_max.setter - def color_max(self, value): - with self.state: - self.state[self.__color_max] = value - - @property - def color_max_name(self): - return self.__color_max - @classmethod def next_id(cls): """Get the next unique ID for the scalar bar.""" cls._next_id += 1 return f"pan3d_scalarbar{cls._next_id}" - def __init__(self, preset="Fast", color_min=0.0, color_max=1.0, **kwargs): + def __init__( + self, preset="Fast", color_min=0.0, color_max=1.0, ctx_name=None, **kwargs + ): """Scalar bar for the XArray Explorers.""" self._lut = vtkLookupTable() - super().__init__(location="top") + super().__init__(location="top", ctx_name=ctx_name) # Activate CSS self.server.enable_module(base) self.server.enable_module(vtk_view) @@ -131,3 +84,52 @@ def __init__(self, preset="Fast", color_min=0.0, color_max=1.0, **kwargs): html.Span( f"{{{{ (({self.__color_max} - {self.__color_min}) * ({self.__probe_location}?.[0] - {self.__probe_location}?.[1]?.left) / {self.__probe_location}?.[1]?.width + {self.__color_min}).toFixed(6) }}}}" ) + + @property + def preset(self): + return self._preset + + @preset.setter + def preset(self, value): + if value not in PRESETS: + err_msg = f"Preset '{value}' not found." + raise ValueError(err_msg) + self._preset = value + set_preset(self._lut, value) + with self.state: + self.state[self.__preset_image] = to_image(self._lut) + + @property + def preset_image_name(self): + return self.__preset_image + + @property + def color_min(self): + return self.state[self.__color_min] + + @color_min.setter + def color_min(self, value): + with self.state: + self.state[self.__color_min] = value + + @property + def color_min_name(self): + return self.__color_min + + @property + def color_max(self): + return self.state[self.__color_max] + + @color_max.setter + def color_max(self, value): + with self.state: + self.state[self.__color_max] = value + + @property + def color_max_name(self): + return self.__color_max + + def set_color_range(self, color_min, color_max): + """Set the color range for the scalar bar.""" + self.state[self.__color_min] = color_min + self.state[self.__color_max] = color_max diff --git a/src/trame/widgets/pan3d.py b/src/trame/widgets/pan3d.py index 1e05b377..09deb223 100644 --- a/src/trame/widgets/pan3d.py +++ b/src/trame/widgets/pan3d.py @@ -1,3 +1,4 @@ -from src.pan3d.widgets.color_by import ColorBy, ScalarBar +from pan3d.widgets.color_by import ColorBy +from pan3d.widgets.scalar_bar import ScalarBar __all__ = ["ColorBy", "ScalarBar"]