From c74bfb64cd95cc0c9d5b53ccb1ef445941e9fcd5 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 3 Feb 2026 16:49:25 +0000 Subject: [PATCH 1/4] MG: prolong and inject UFL expressions --- firedrake/mg/interface.py | 23 ++++++++++++++++------- firedrake/mg/kernels.py | 13 ++++++++----- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/firedrake/mg/interface.py b/firedrake/mg/interface.py index a700b7eacc..9801d953c8 100644 --- a/firedrake/mg/interface.py +++ b/firedrake/mg/interface.py @@ -5,6 +5,7 @@ from firedrake.cofunction import Cofunction from firedrake.petsc import PETSc from ufl.duals import is_dual +from ufl.algorithms.analysis import extract_coefficients from . import utils from . import kernels @@ -13,6 +14,10 @@ def check_arguments(coarse, fine, needs_dual=False): + if coarse.ufl_shape != fine.ufl_shape: + raise ValueError("Mismatching function space shapes") + coarse, = extract_coefficients(coarse) + fine, = extract_coefficients(fine) if is_dual(coarse) != needs_dual: expected_type = Cofunction if needs_dual else Function raise TypeError("Coarse argument is a %s, not a %s" % (type(coarse).__name__, expected_type.__name__)) @@ -29,13 +34,13 @@ def check_arguments(coarse, fine, needs_dual=False): raise ValueError("Coarse argument must be from coarser space") if hierarchy is not fhierarchy: raise ValueError("Can't transfer between functions from different hierarchies") - if coarse.ufl_shape != fine.ufl_shape: - raise ValueError("Mismatching function space shapes") @PETSc.Log.EventDecorator() def prolong(coarse, fine): check_arguments(coarse, fine) + coarse_expr = coarse + coarse, = extract_coefficients(coarse_expr) Vc = coarse.function_space() Vf = fine.function_space() if len(Vc) > 1: @@ -78,7 +83,7 @@ def prolong(coarse, fine): coarse_coords = get_coordinates(Vc) fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc) fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space()) - kernel = kernels.prolong_kernel(coarse, Vf) + kernel = kernels.prolong_kernel(coarse_expr, Vf) # XXX: Should be able to figure out locations by pushing forward # reference cell node locations to physical space. @@ -100,6 +105,7 @@ def prolong(coarse, fine): new_fine = finest if j == repeat-1 else Function(Vfinest.reconstruct(mesh=meshes[next_level])) fine = new_fine.interpolate(fine) coarse = fine + coarse_expr = coarse return fine @@ -174,6 +180,8 @@ def restrict(fine_dual, coarse_dual): @PETSc.Log.EventDecorator() def inject(fine, coarse): check_arguments(coarse, fine) + fine_expr = fine + fine, = extract_coefficients(fine) Vf = fine.function_space() Vc = coarse.function_space() if len(Vc) > 1: @@ -212,14 +220,14 @@ def inject(fine, coarse): # Introduce an intermediate quadrature target space Vc = Vc.quadrature_space() - kernel, dg = kernels.inject_kernel(Vf, Vc) - if dg and not hierarchy.nested: - raise NotImplementedError("Sorry, we can't do supermesh projections yet!") - coarsest = coarse.zero() Vcoarsest = coarsest.function_space() meshes = hierarchy._meshes for j in range(repeat): + kernel, dg = kernels.inject_kernel(fine_expr, Vc) + if dg and not hierarchy.nested: + raise NotImplementedError("Sorry, we can't do supermesh projections yet!") + next_level -= 1 if j == repeat - 1 and not needs_quadrature: coarse = coarsest @@ -264,6 +272,7 @@ def inject(fine, coarse): new_coarse = coarsest if j == repeat - 1 else Function(Vcoarsest.reconstruct(mesh=meshes[next_level])) coarse = new_coarse.interpolate(coarse) fine = coarse + fine_expr = fine return coarse diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index d1ee220129..ef1fd9a572 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -8,6 +8,7 @@ from firedrake.mg import utils from ufl.algorithms import estimate_total_polynomial_degree +from ufl.algorithms.analysis import extract_coefficients from ufl.domain import extract_unique_domain import loopy as lp @@ -163,7 +164,8 @@ def _make_element_key(element): def prolong_kernel(expression, Vf): - Vc = expression.ufl_function_space() + coarse, = extract_coefficients(expression) + Vc = coarse.ufl_function_space() hierarchy, levelf = utils.get_level(Vf.mesh()) hierarchy, levelc = utils.get_level(Vc.mesh()) if Vc.mesh().extruded: @@ -191,7 +193,7 @@ def prolong_kernel(expression, Vf): evaluate_code = compile_element(expression, ufl.TestFunction(Vf.dual())) to_reference_kernel = to_reference_coordinates(coordinates.ufl_element()) coords_element = create_element(coordinates.ufl_element()) - element = create_element(expression.ufl_element()) + element = create_element(coarse.ufl_element()) my_kernel = """#include %(to_reference)s @@ -338,8 +340,10 @@ def restrict_kernel(Vf, Vc): return cache.setdefault(key, op2.Kernel(my_kernel, name="pyop2_kernel_restrict")) -def inject_kernel(Vf, Vc): +def inject_kernel(expression, Vc): if Vc.finat_element.is_dg(): + fine, = extract_coefficients(expression) + Vf = fine.ufl_function_space() hierarchy, level = utils.get_level(Vc.mesh()) if Vf.extruded: assert Vc.extruded @@ -359,9 +363,8 @@ def inject_kernel(Vf, Vc): return cache[key] except KeyError: ncandidate = hierarchy.coarse_to_fine_cells[level].shape[1] * level_ratio - return cache.setdefault(key, (dg_injection_kernel(Vf, Vc, ncandidate), True)) + return cache.setdefault(key, (dg_injection_kernel(expression, Vc, ncandidate), True)) else: - expression = ufl.Coefficient(Vf) return (prolong_kernel(expression, Vc), False) From 1c3a613a71be88c9303a6265d54127a37e70fb10 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 4 Feb 2026 11:14:15 +0000 Subject: [PATCH 2/4] fixes --- firedrake/mg/interface.py | 16 +++++-- firedrake/mg/kernels.py | 90 ++++++++++++++++++++++++--------------- firedrake/supermeshing.py | 9 ++-- 3 files changed, 73 insertions(+), 42 deletions(-) diff --git a/firedrake/mg/interface.py b/firedrake/mg/interface.py index 9801d953c8..f3850d182c 100644 --- a/firedrake/mg/interface.py +++ b/firedrake/mg/interface.py @@ -1,3 +1,4 @@ +import ufl from pyop2 import op2 from firedrake import ufl_expr, dmhooks @@ -83,8 +84,6 @@ def prolong(coarse, fine): coarse_coords = get_coordinates(Vc) fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc) fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space()) - kernel = kernels.prolong_kernel(coarse_expr, Vf) - # XXX: Should be able to figure out locations by pushing forward # reference cell node locations to physical space. # x = \sum_i c_i \phi_i(x_hat) @@ -94,6 +93,10 @@ def prolong(coarse, fine): for d in [coarse, coarse_coords]: d.dat.global_to_local_begin(op2.READ) d.dat.global_to_local_end(op2.READ) + + fine_dual = ufl.TestFunction(Vf.dual()) + ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual) + kernel = kernels.prolong_kernel(ufl_interpolate) op2.par_loop(kernel, fine.node_set, fine.dat(op2.WRITE), coarse.dat(op2.READ, fine_to_coarse), @@ -167,7 +170,10 @@ def restrict(fine_dual, coarse_dual): for d in [coarse_coords]: d.dat.global_to_local_begin(op2.READ) d.dat.global_to_local_end(op2.READ) - kernel = kernels.restrict_kernel(Vf, Vc) + + coarse_expr = ufl.TestFunction(Vc.dual()) + ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual) + kernel = kernels.restrict_kernel(ufl_interpolate) op2.par_loop(kernel, fine_dual.node_set, coarse_dual.dat(op2.INC, fine_to_coarse), fine_dual.dat(op2.READ), @@ -224,7 +230,9 @@ def inject(fine, coarse): Vcoarsest = coarsest.function_space() meshes = hierarchy._meshes for j in range(repeat): - kernel, dg = kernels.inject_kernel(fine_expr, Vc) + + ufl_interpolate = ufl.Interpolate(fine_expr, ufl.TestFunction(Vc.dual())) + kernel, dg = kernels.inject_kernel(ufl_interpolate) if dg and not hierarchy.nested: raise NotImplementedError("Sorry, we can't do supermesh projections yet!") diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index ef1fd9a572..c67d941adf 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -93,22 +93,26 @@ def to_reference_coordinates(ufl_coordinate_element, parameters=None): return evaluate_template_c % code -def compile_element(operand, dual_arg, parameters=None, - name="evaluate"): +def compile_element(ufl_interpolate, parameters=None, name="evaluate"): """Generate code for point evaluations. Parameters ---------- - operand: ufl.Expr - A primal expression - dual_arg: ufl.Coargument | ufl.Cofunction - A dual argument or coefficient + ufl_interpolate: ufl.Interpolate + A symbolic Interpolate expression + + parameters: dict + The form compiler parameters + + name: str + The name of the kernel Returns ------- str The generated code """ + dual_arg, operand = ufl_interpolate.argument_slots() domain = extract_unique_domain(operand) cell = domain.ufl_cell() dim = cell.topological_dimension @@ -129,7 +133,7 @@ def compile_element(operand, dual_arg, parameters=None, dual_arg = ufl.Cofunction(target_space.dual()) else: dual_arg = ufl.Coargument(target_space.dual(), number=dual_arg.number()) - expression = ufl.Interpolate(operand, dual_arg) + expression = ufl_interpolate._ufl_expr_reconstruct_(operand, dual_arg) # Create a runtime Quadrature element to_element = create_element(ufl_element) @@ -163,9 +167,11 @@ def _make_element_key(element): return entity_dofs_key(element.complex.get_topology()) + entity_dofs_key(element.entity_dofs()) -def prolong_kernel(expression, Vf): - coarse, = extract_coefficients(expression) +def prolong_kernel(ufl_interpolate): + dual_arg, operand = ufl_interpolate.argument_slots() + coarse, = extract_coefficients(ufl_interpolate) Vc = coarse.ufl_function_space() + Vf = dual_arg.ufl_function_space().dual() hierarchy, levelf = utils.get_level(Vf.mesh()) hierarchy, levelc = utils.get_level(Vc.mesh()) if Vc.mesh().extruded: @@ -190,7 +196,7 @@ def prolong_kernel(expression, Vf): try: return cache[key] except KeyError: - evaluate_code = compile_element(expression, ufl.TestFunction(Vf.dual())) + evaluate_code = compile_element(ufl_interpolate) to_reference_kernel = to_reference_coordinates(coordinates.ufl_element()) coords_element = create_element(coordinates.ufl_element()) element = create_element(coarse.ufl_element()) @@ -257,7 +263,11 @@ def prolong_kernel(expression, Vf): return cache.setdefault(key, op2.Kernel(my_kernel, name="pyop2_kernel_prolong")) -def restrict_kernel(Vf, Vc): +def restrict_kernel(ufl_interpolate): + dual_arg, operand = ufl_interpolate.argument_slots() + primal_arg, = ufl_interpolate.arguments() + Vc = primal_arg.ufl_function_space().dual() + Vf = dual_arg.ufl_function_space() hierarchy, levelf = utils.get_level(Vf.mesh()) if Vf.mesh().extruded: assert Vc.mesh().extruded @@ -273,7 +283,7 @@ def restrict_kernel(Vf, Vc): return cache[key] except KeyError: assert isinstance(Vc, FiredrakeDualSpace) and isinstance(Vf, FiredrakeDualSpace) - evaluate_code = compile_element(ufl.TestFunction(Vc.dual()), ufl.Cofunction(Vf)) + evaluate_code = compile_element(ufl_interpolate) to_reference_kernel = to_reference_coordinates(coordinates.ufl_element()) coords_element = create_element(coordinates.ufl_element()) element = create_element(Vc.ufl_element()) @@ -340,9 +350,11 @@ def restrict_kernel(Vf, Vc): return cache.setdefault(key, op2.Kernel(my_kernel, name="pyop2_kernel_restrict")) -def inject_kernel(expression, Vc): +def inject_kernel(ufl_interpolate): + dual_arg, operand = ufl_interpolate.argument_slots() + Vc = dual_arg.ufl_function_space().dual() if Vc.finat_element.is_dg(): - fine, = extract_coefficients(expression) + fine, = extract_coefficients(ufl_interpolate) Vf = fine.ufl_function_space() hierarchy, level = utils.get_level(Vc.mesh()) if Vf.extruded: @@ -363,9 +375,9 @@ def inject_kernel(expression, Vc): return cache[key] except KeyError: ncandidate = hierarchy.coarse_to_fine_cells[level].shape[1] * level_ratio - return cache.setdefault(key, (dg_injection_kernel(expression, Vc, ncandidate), True)) + return cache.setdefault(key, (dg_injection_kernel(ufl_interpolate, ncandidate), True)) else: - return (prolong_kernel(expression, Vc), False) + return (prolong_kernel(ufl_interpolate), False) class MacroKernelBuilder(firedrake_interface.KernelBuilderBase): @@ -410,18 +422,23 @@ def _coefficient(self, coefficient, name): return funarg -def dg_injection_kernel(Vf, Vc, ncell): +def dg_injection_kernel(ufl_interpolate, ncell): from firedrake import Tensor, AssembledVector, TestFunction, TrialFunction from firedrake.slate.slac import compile_expression if complex_mode: raise NotImplementedError("In complex mode we are waiting for Slate") + + dual_arg, operand = ufl_interpolate.argument_slots() + coefficients = extract_coefficients(ufl_interpolate) + source_mesh = extract_unique_domain(operand) macro_builder = MacroKernelBuilder(ScalarType, ncell) - macro_builder._domain_integral_type_map = {Vf.mesh(): "cell"} - macro_builder._entity_ids = {Vf.mesh(): (0,)} - f = ufl.Coefficient(Vf) - macro_builder.set_coefficients([f]) - macro_builder.set_coordinates(Vf.mesh()) + macro_builder._domain_integral_type_map = {source_mesh: "cell"} + macro_builder._entity_ids = {source_mesh: (0,)} + macro_builder.set_coefficients(coefficients) + macro_builder.set_coordinates(source_mesh) + fine, = coefficients + Vf = fine.ufl_function_space() Vfe = create_element(Vf.ufl_element()) ref_complex = Vfe.complex variant = Vf.ufl_element().variant() or "default" @@ -429,40 +446,43 @@ def dg_injection_kernel(Vf, Vc, ncell): from FIAT import macro ref_complex = macro.PowellSabinSplit(Vfe.cell) - macro_quadrature_rule = make_quadrature(ref_complex, estimate_total_polynomial_degree(ufl.inner(f, f))) + quadrature_degree = estimate_total_polynomial_degree(ufl.inner(operand, ufl.TestFunction(Vf))) + macro_quadrature_rule = make_quadrature(ref_complex, quadrature_degree) index_cache = {} parameters = default_parameters() integration_dim, _ = lower_integral_type(Vfe.cell, "cell") macro_cfg = dict(interface=macro_builder, - ufl_cell=Vf.ufl_cell(), + ufl_cell=source_mesh.ufl_cell(), integration_dim=integration_dim, index_cache=index_cache, quadrature_rule=macro_quadrature_rule, scalar_type=parameters["scalar_type"]) macro_context = fem.PointSetContext(**macro_cfg) - fexpr, = fem.compile_ufl(f, macro_context) - X = ufl.SpatialCoordinate(Vf.mesh()) + fexpr, = fem.compile_ufl(operand, macro_context) + X = ufl.SpatialCoordinate(source_mesh) C_a, = fem.compile_ufl(X, macro_context) - detJ = ufl_utils.preprocess_expression(abs(ufl.JacobianDeterminant(extract_unique_domain(f))), + detJ = ufl_utils.preprocess_expression(abs(ufl.JacobianDeterminant(source_mesh)), complex_mode=complex_mode) macro_detJ, = fem.compile_ufl(detJ, macro_context) + Vc = dual_arg.ufl_function_space().dual() Vce = create_element(Vc.ufl_element()) + target_mesh = Vc.mesh() - info = TSFCIntegralDataInfo(domain=Vc.mesh(), + info = TSFCIntegralDataInfo(domain=target_mesh, integral_type="cell", subdomain_id=("otherwise",), domain_number=0, - domain_integral_type_map={Vc.mesh(): "cell"}, + domain_integral_type_map={target_mesh: "cell"}, arguments=(ufl.TestFunction(Vc), ), coefficients=(), coefficient_split={}, coefficient_numbers=()) coarse_builder = firedrake_interface.KernelBuilder(info, parameters["scalar_type"]) - coarse_builder.set_coordinates([Vc.mesh()]) - coarse_builder.set_entity_numbers([Vc.mesh()]) + coarse_builder.set_coordinates([target_mesh]) + coarse_builder.set_entity_numbers([target_mesh]) argument_multiindices = coarse_builder.argument_multiindices argument_multiindex, = argument_multiindices return_variable, = coarse_builder.return_variables @@ -472,14 +492,14 @@ def dg_injection_kernel(Vf, Vc, ncell): quadrature_rule = make_quadrature(Vce.cell, 0) coarse_cfg = dict(interface=coarse_builder, - ufl_cell=Vc.ufl_cell(), + ufl_cell=target_mesh.ufl_cell(), integration_dim=integration_dim, index_cache=index_cache, quadrature_rule=quadrature_rule, scalar_type=parameters["scalar_type"]) - X = ufl.SpatialCoordinate(Vc.mesh()) - K = ufl_utils.preprocess_expression(ufl.JacobianInverse(Vc.mesh()), + X = ufl.SpatialCoordinate(target_mesh) + K = ufl_utils.preprocess_expression(ufl.JacobianInverse(target_mesh), complex_mode=complex_mode) coarse_context = fem.PointSetContext(**coarse_cfg) C_0, = fem.compile_ufl(X, coarse_context) @@ -505,7 +525,7 @@ def dg_injection_kernel(Vf, Vc, ncell): # Coarse basis function evaluated at fine quadrature points phi_c = fem.fiat_to_ufl(Vce.point_evaluation(0, X_a, (Vce.cell.get_dimension(), 0)), 0) - index_shape = f.ufl_element().reference_value_shape + index_shape = Vf.ufl_element().reference_value_shape tensor_indices = tuple(gem.Index(extent=d) for d in index_shape) phi_c = gem.Indexed(phi_c, argument_multiindex + tensor_indices) diff --git a/firedrake/supermeshing.py b/firedrake/supermeshing.py index e9b3a8c38f..32717d7eea 100644 --- a/firedrake/supermeshing.py +++ b/firedrake/supermeshing.py @@ -213,9 +213,12 @@ def likely(cell_A): V_S_A = FunctionSpace(reference_mesh, V_A.ufl_element()) V_S_B = FunctionSpace(reference_mesh, V_B.ufl_element()) - evaluate_kernel_A = compile_element(ufl.Coefficient(V_A), ufl.TestFunction(V_S_A.dual()), name="evaluate_kernel_A") - evaluate_kernel_B = compile_element(ufl.Coefficient(V_B), ufl.TestFunction(V_S_B.dual()), name="evaluate_kernel_B") - evaluate_kernel_S = compile_element(ufl.Coefficient(V_S), ufl.TestFunction(V_S.dual()), name="evaluate_kernel_S") + interp_A = ufl.Interpolate(ufl.Coefficient(V_A), ufl.TestFunction(V_S_A.dual())) + interp_B = ufl.Interpolate(ufl.Coefficient(V_B), ufl.TestFunction(V_S_B.dual())) + interp_S = ufl.Interpolate(ufl.Coefficient(V_S), ufl.TestFunction(V_S.dual())) + evaluate_kernel_A = compile_element(interp_A, name="evaluate_kernel_A") + evaluate_kernel_B = compile_element(interp_B, name="evaluate_kernel_B") + evaluate_kernel_S = compile_element(interp_S, name="evaluate_kernel_S") M_SS = assemble(inner(TrialFunction(V_S_A), TestFunction(V_S_B)) * dx) M_SS = M_SS.petscmat[:, :] From c778cf63ea4b5a6d674e9f484ad6e064e56c2603 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 4 Feb 2026 17:12:58 +0000 Subject: [PATCH 3/4] unify parloop --- firedrake/mg/interface.py | 117 +++++++++++++++++++------------------- firedrake/mg/kernels.py | 8 ++- 2 files changed, 62 insertions(+), 63 deletions(-) diff --git a/firedrake/mg/interface.py b/firedrake/mg/interface.py index f3850d182c..78cdd93a69 100644 --- a/firedrake/mg/interface.py +++ b/firedrake/mg/interface.py @@ -1,4 +1,5 @@ import ufl +from itertools import repeat from pyop2 import op2 from firedrake import ufl_expr, dmhooks @@ -7,6 +8,7 @@ from firedrake.petsc import PETSc from ufl.duals import is_dual from ufl.algorithms.analysis import extract_coefficients +from ufl.domain import extract_unique_domain from . import utils from . import kernels @@ -37,6 +39,53 @@ def check_arguments(coarse, fine, needs_dual=False): raise ValueError("Can't transfer between functions from different hierarchies") +def multigrid_transfer(ufl_interpolate, tensor=None): + if tensor is None: + tensor = Function(ufl_interpolate.ufl_function_space()) + + coefficients = extract_coefficients(ufl_interpolate) + if is_dual(ufl_interpolate.ufl_function_space()): + kernel = kernels.restrict_kernel(ufl_interpolate) + access = op2.INC + source, = ufl_interpolate.arguments() + else: + kernel = kernels.prolong_kernel(ufl_interpolate) + access = op2.WRITE + source, = coefficients + + dual_arg, operand = ufl_interpolate.argument_slots() + Vtarget = dual_arg.ufl_function_space().dual() + source_mesh = extract_unique_domain(operand) + target_mesh = Vtarget.mesh() + + # XXX: Should be able to figure out locations by pushing forward + # reference cell node locations to physical space. + # x = \sum_i c_i \phi_i(x_hat) + target_coords = utils.physical_node_locations(Vtarget) + source_coords = get_coordinates(source.ufl_function_space()) + if utils.get_level(target_mesh)[1] > utils.get_level(source_mesh)[1]: + node_map = utils.fine_node_to_coarse_node_map + else: + node_map = utils.coarse_node_to_fine_node_map + + # Have to do this, because the node set core size is not right for + # this expanded stencil + for d in [target_coords, *coefficients]: + if d.function_space().mesh() is target_mesh: + continue + d.dat.global_to_local_begin(op2.READ) + d.dat.global_to_local_end(op2.READ) + + def parloop_arg(c, access): + m_ = None if c.function_space().mesh() is target_mesh else node_map(Vtarget, c.function_space()) + return c.dat(access, m_) + + op2.par_loop(kernel, Vtarget.node_set, + parloop_arg(tensor, access), + *map(parloop_arg, (*coefficients, target_coords, source_coords), repeat(op2.READ))) + return tensor + + @PETSc.Log.EventDecorator() def prolong(coarse, fine): check_arguments(coarse, fine) @@ -81,27 +130,9 @@ def prolong(coarse, fine): Vf = fine.function_space() Vc = coarse.function_space() - coarse_coords = get_coordinates(Vc) - fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc) - fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space()) - # XXX: Should be able to figure out locations by pushing forward - # reference cell node locations to physical space. - # x = \sum_i c_i \phi_i(x_hat) - node_locations = utils.physical_node_locations(Vf) - # Have to do this, because the node set core size is not right for - # this expanded stencil - for d in [coarse, coarse_coords]: - d.dat.global_to_local_begin(op2.READ) - d.dat.global_to_local_end(op2.READ) - fine_dual = ufl.TestFunction(Vf.dual()) ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual) - kernel = kernels.prolong_kernel(ufl_interpolate) - op2.par_loop(kernel, fine.node_set, - fine.dat(op2.WRITE), - coarse.dat(op2.READ, fine_to_coarse), - node_locations.dat(op2.READ), - coarse_coords.dat(op2.READ, fine_to_coarse_coords)) + multigrid_transfer(ufl_interpolate, tensor=fine) if needs_quadrature: # Transfer to the actual target space @@ -157,28 +188,9 @@ def restrict(fine_dual, coarse_dual): Vf = fine_dual.function_space() Vc = coarse_dual.function_space() - # XXX: Should be able to figure out locations by pushing forward - # reference cell node locations to physical space. - # x = \sum_i c_i \phi_i(x_hat) - node_locations = utils.physical_node_locations(Vf.dual()) - - coarse_coords = get_coordinates(Vc.dual()) - fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc) - fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space()) - # Have to do this, because the node set core size is not right for - # this expanded stencil - for d in [coarse_coords]: - d.dat.global_to_local_begin(op2.READ) - d.dat.global_to_local_end(op2.READ) - coarse_expr = ufl.TestFunction(Vc.dual()) ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual) - kernel = kernels.restrict_kernel(ufl_interpolate) - op2.par_loop(kernel, fine_dual.node_set, - coarse_dual.dat(op2.INC, fine_to_coarse), - fine_dual.dat(op2.READ), - node_locations.dat(op2.READ), - coarse_coords.dat(op2.READ, fine_to_coarse_coords)) + multigrid_transfer(ufl_interpolate, tensor=coarse_dual) fine_dual = coarse_dual return coarse_dual @@ -230,12 +242,6 @@ def inject(fine, coarse): Vcoarsest = coarsest.function_space() meshes = hierarchy._meshes for j in range(repeat): - - ufl_interpolate = ufl.Interpolate(fine_expr, ufl.TestFunction(Vc.dual())) - kernel, dg = kernels.inject_kernel(ufl_interpolate) - if dg and not hierarchy.nested: - raise NotImplementedError("Sorry, we can't do supermesh projections yet!") - next_level -= 1 if j == repeat - 1 and not needs_quadrature: coarse = coarsest @@ -243,23 +249,14 @@ def inject(fine, coarse): coarse = Function(Vc.reconstruct(mesh=meshes[next_level])) Vc = coarse.function_space() Vf = fine.function_space() - if not dg: - fine_coords = get_coordinates(Vf) - coarse_to_fine = utils.coarse_node_to_fine_node_map(Vc, Vf) - coarse_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space()) - node_locations = utils.physical_node_locations(Vc) - # Have to do this, because the node set core size is not right for - # this expanded stencil - for d in [fine, fine_coords]: - d.dat.global_to_local_begin(op2.READ) - d.dat.global_to_local_end(op2.READ) - op2.par_loop(kernel, coarse.node_set, - coarse.dat(op2.WRITE), - fine.dat(op2.READ, coarse_to_fine), - node_locations.dat(op2.READ), - fine_coords.dat(op2.READ, coarse_to_fine_coords)) + ufl_interpolate = ufl.Interpolate(fine_expr, ufl.TestFunction(Vc.dual())) + if not Vf.finat_element.is_dg(): + multigrid_transfer(ufl_interpolate, tensor=coarse) else: + kernel, dg = kernels.inject_kernel(ufl_interpolate) + if dg and not hierarchy.nested: + raise NotImplementedError("Sorry, we can't do supermesh projections yet!") coarse_coords = get_coordinates(Vc) fine_coords = get_coordinates(Vf) coarse_cell_to_fine_nodes = utils.coarse_cell_to_fine_node_map(Vc, Vf) diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index c67d941adf..a9bd1cf139 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -196,6 +196,7 @@ def prolong_kernel(ufl_interpolate): try: return cache[key] except KeyError: + name = "pyop2_kernel_prolong" evaluate_code = compile_element(ufl_interpolate) to_reference_kernel = to_reference_coordinates(coordinates.ufl_element()) coords_element = create_element(coordinates.ufl_element()) @@ -205,7 +206,7 @@ def prolong_kernel(ufl_interpolate): %(to_reference)s %(evaluate)s __attribute__((noinline)) /* Clang bug */ - static void pyop2_kernel_prolong(PetscScalar *R, PetscScalar *f, const PetscScalar *X, const PetscScalar *Xc) + static void %(name)s(PetscScalar *R, PetscScalar *f, const PetscScalar *X, const PetscScalar *Xc) { PetscScalar Xref[%(tdim)d]; int cell = -1; @@ -249,7 +250,8 @@ def prolong_kernel(ufl_interpolate): } pyop2_kernel_evaluate(%(kernel_args)s); } - """ % {"to_reference": str(to_reference_kernel), + """ % {"name": name, + "to_reference": str(to_reference_kernel), "evaluate": evaluate_code, "kernel_args": _make_kernel_args(element, "R", "Xci", "fi", "Xref"), "ncandidate": ncandidate, @@ -260,7 +262,7 @@ def prolong_kernel(ufl_interpolate): "coarse_cell_inc": element.space_dimension(), "tdim": element.cell.get_spatial_dimension()} - return cache.setdefault(key, op2.Kernel(my_kernel, name="pyop2_kernel_prolong")) + return cache.setdefault(key, op2.Kernel(my_kernel, name=name)) def restrict_kernel(ufl_interpolate): From 24ce3de93fb6fa1d146c7fd5fed3386f2be46d82 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 4 Feb 2026 17:33:44 +0000 Subject: [PATCH 4/4] tidy --- firedrake/mg/interface.py | 51 ++++++++++++++++++--------------------- firedrake/mg/kernels.py | 4 +-- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/firedrake/mg/interface.py b/firedrake/mg/interface.py index 78cdd93a69..5742d1ce6a 100644 --- a/firedrake/mg/interface.py +++ b/firedrake/mg/interface.py @@ -57,24 +57,22 @@ def multigrid_transfer(ufl_interpolate, tensor=None): Vtarget = dual_arg.ufl_function_space().dual() source_mesh = extract_unique_domain(operand) target_mesh = Vtarget.mesh() + if utils.get_level(target_mesh)[1] > utils.get_level(source_mesh)[1]: + node_map = utils.fine_node_to_coarse_node_map + else: + node_map = utils.coarse_node_to_fine_node_map # XXX: Should be able to figure out locations by pushing forward # reference cell node locations to physical space. # x = \sum_i c_i \phi_i(x_hat) target_coords = utils.physical_node_locations(Vtarget) source_coords = get_coordinates(source.ufl_function_space()) - if utils.get_level(target_mesh)[1] > utils.get_level(source_mesh)[1]: - node_map = utils.fine_node_to_coarse_node_map - else: - node_map = utils.coarse_node_to_fine_node_map - # Have to do this, because the node set core size is not right for # this expanded stencil - for d in [target_coords, *coefficients]: - if d.function_space().mesh() is target_mesh: - continue - d.dat.global_to_local_begin(op2.READ) - d.dat.global_to_local_end(op2.READ) + for d in [source_coords, *coefficients]: + if d.function_space().mesh() is not target_mesh: + d.dat.global_to_local_begin(op2.READ) + d.dat.global_to_local_end(op2.READ) def parloop_arg(c, access): m_ = None if c.function_space().mesh() is target_mesh else node_map(Vtarget, c.function_space()) @@ -111,7 +109,7 @@ def prolong(coarse, fine): hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse)) _, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine)) refinements_per_level = hierarchy.refinements_per_level - repeat = (fine_level - coarse_level)*refinements_per_level + refine = (fine_level - coarse_level)*refinements_per_level next_level = coarse_level * refinements_per_level if needs_quadrature := not Vf.finat_element.has_pointwise_dual_basis: @@ -121,23 +119,22 @@ def prolong(coarse, fine): finest = fine Vfinest = finest.function_space() meshes = hierarchy._meshes - for j in range(repeat): + for j in range(refine): next_level += 1 - if j == repeat - 1 and not needs_quadrature: - fine = finest + if j == refine - 1 and not needs_quadrature: + tensor = finest else: - fine = Function(Vf.reconstruct(mesh=meshes[next_level])) - Vf = fine.function_space() - Vc = coarse.function_space() + tensor = None - fine_dual = ufl.TestFunction(Vf.dual()) + fine_dual = ufl.TestFunction(Vf.reconstruct(mesh=meshes[next_level]).dual()) ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual) - multigrid_transfer(ufl_interpolate, tensor=fine) + fine = multigrid_transfer(ufl_interpolate, tensor=tensor) if needs_quadrature: # Transfer to the actual target space - new_fine = finest if j == repeat-1 else Function(Vfinest.reconstruct(mesh=meshes[next_level])) + new_fine = finest if j == refine-1 else Function(Vfinest.reconstruct(mesh=meshes[next_level])) fine = new_fine.interpolate(fine) + coarse = fine coarse_expr = coarse return fine @@ -166,7 +163,7 @@ def restrict(fine_dual, coarse_dual): hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse_dual)) _, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine_dual)) refinements_per_level = hierarchy.refinements_per_level - repeat = (fine_level - coarse_level)*refinements_per_level + refine = (fine_level - coarse_level)*refinements_per_level next_level = fine_level * refinements_per_level if needs_quadrature := not Vf.finat_element.has_pointwise_dual_basis: @@ -175,13 +172,13 @@ def restrict(fine_dual, coarse_dual): coarsest = coarse_dual.zero() meshes = hierarchy._meshes - for j in range(repeat): + for j in range(refine): if needs_quadrature: # Transfer to the quadrature source space fine_dual = Function(Vq.reconstruct(mesh=meshes[next_level])).interpolate(fine_dual) next_level -= 1 - if j == repeat - 1: + if j == refine - 1: coarse_dual = coarsest else: coarse_dual = Function(Vc.reconstruct(mesh=meshes[next_level])) @@ -231,7 +228,7 @@ def inject(fine, coarse): hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse)) _, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine)) refinements_per_level = hierarchy.refinements_per_level - repeat = (fine_level - coarse_level)*refinements_per_level + refine = (fine_level - coarse_level)*refinements_per_level next_level = fine_level * refinements_per_level if needs_quadrature := not Vc.finat_element.has_pointwise_dual_basis: @@ -241,9 +238,9 @@ def inject(fine, coarse): coarsest = coarse.zero() Vcoarsest = coarsest.function_space() meshes = hierarchy._meshes - for j in range(repeat): + for j in range(refine): next_level -= 1 - if j == repeat - 1 and not needs_quadrature: + if j == refine - 1 and not needs_quadrature: coarse = coarsest else: coarse = Function(Vc.reconstruct(mesh=meshes[next_level])) @@ -274,7 +271,7 @@ def inject(fine, coarse): if needs_quadrature: # Transfer to the actual target space - new_coarse = coarsest if j == repeat - 1 else Function(Vcoarsest.reconstruct(mesh=meshes[next_level])) + new_coarse = coarsest if j == refine - 1 else Function(Vcoarsest.reconstruct(mesh=meshes[next_level])) coarse = new_coarse.interpolate(coarse) fine = coarse fine_expr = fine diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index a9bd1cf139..620a7d06c9 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -168,7 +168,7 @@ def _make_element_key(element): def prolong_kernel(ufl_interpolate): - dual_arg, operand = ufl_interpolate.argument_slots() + dual_arg, _ = ufl_interpolate.argument_slots() coarse, = extract_coefficients(ufl_interpolate) Vc = coarse.ufl_function_space() Vf = dual_arg.ufl_function_space().dual() @@ -353,7 +353,7 @@ def restrict_kernel(ufl_interpolate): def inject_kernel(ufl_interpolate): - dual_arg, operand = ufl_interpolate.argument_slots() + dual_arg, _ = ufl_interpolate.argument_slots() Vc = dual_arg.ufl_function_space().dual() if Vc.finat_element.is_dg(): fine, = extract_coefficients(ufl_interpolate)