diff --git a/firedrake/preconditioners/__init__.py b/firedrake/preconditioners/__init__.py index 491a73657b..aa9ba1f3d7 100644 --- a/firedrake/preconditioners/__init__.py +++ b/firedrake/preconditioners/__init__.py @@ -13,3 +13,5 @@ from firedrake.preconditioners.hiptmair import * # noqa: F401 from firedrake.preconditioners.facet_split import * # noqa: F401 from firedrake.preconditioners.bddc import * # noqa: F401 +from firedrake.preconditioners.fieldsplit_snes import * # noqa: F401 +from firedrake.preconditioners.auxiliary_snes import * # noqa: F401 diff --git a/firedrake/preconditioners/auxiliary_snes.py b/firedrake/preconditioners/auxiliary_snes.py new file mode 100644 index 0000000000..21b50beb02 --- /dev/null +++ b/firedrake/preconditioners/auxiliary_snes.py @@ -0,0 +1,79 @@ +from firedrake.preconditioners.base import SNESBase +from firedrake.petsc import PETSc +from firedrake.dmhooks import get_appctx, get_function_space + +__all__ = ("AuxiliaryOperatorSNES",) + + +class AuxiliaryOperatorSNES(SNESBase): + prefix = "aux_" + + @PETSc.Log.EventDecorator() + def initialize(self, snes): + from firedrake import ( # ImportError if this is at file level + NonlinearVariationalSolver, + NonlinearVariationalProblem, + Function, TestFunction, Cofunction) + + ctx = get_appctx(snes.dm) + V = get_function_space(snes.dm).collapse() + + appctx = ctx.appctx + fcp = appctx.get("form_compiler_parameters") + + u = Function(V) + v = TestFunction(V) + + F, bcs, self.u = self.form(snes, u, v) + + self.b = Cofunction(V.dual()) + F += self.b + + prefix = snes.getOptionsPrefix() + self.prefix + + self.solver = NonlinearVariationalSolver( + NonlinearVariationalProblem( + F, self.u, bcs=bcs, + form_compiler_parameters=fcp), + appctx=appctx, options_prefix=prefix) + outer_snes = snes + inner_snes = self.solver.snes + inner_snes.incrementTabLevel(1, parent=outer_snes) + inner_snes.ksp.incrementTabLevel(1, parent=outer_snes) + inner_snes.ksp.pc.incrementTabLevel(1, parent=outer_snes) + + def update(self, snes): + pass + + @PETSc.Log.EventDecorator() + def step(self, snes, x, f, y): + from firedrake import errornorm + with self.u.dat.vec_wo as vec: + x.copy(vec) + # PETSc.Sys.Print(f"{x.norm() = }") + if f is not None: + with self.b.dat.vec_wo as vec: + f.copy(vec) + else: + self.b.zero() + # self.b.zero() + + # PETSc.Sys.Print(f"Before: {errornorm(self.un, self.u) = :.5e}") + PETSc.Sys.Print(f"Before: {errornorm(self.un1, self.u) = :.5e}") + self.solver.solve() + # PETSc.Sys.Print(f"After: {errornorm(self.un, self.u) = :.5e}") + PETSc.Sys.Print(f"After: {errornorm(self.un1, self.u) = :.5e}") + with self.u.dat.vec_ro as vec: + # PETSc.Sys.Print(f"{vec.norm() = }") + vec.copy(y) + y.aypx(-1, x) + # PETSc.Sys.Print(f"{y.norm() = }") + + def form(self, snes, u, v): + raise NotImplementedError + + def view(self, snes, viewer=None): + super().view(snes, viewer) + if hasattr(self, "solver"): + viewer.printfASCII("SNES to apply auxiliary inverse\n") + self.solver.snes.view(viewer) diff --git a/firedrake/preconditioners/fieldsplit_snes.py b/firedrake/preconditioners/fieldsplit_snes.py new file mode 100644 index 0000000000..179ce8aa47 --- /dev/null +++ b/firedrake/preconditioners/fieldsplit_snes.py @@ -0,0 +1,85 @@ +from firedrake.preconditioners.base import SNESBase +from firedrake.petsc import PETSc +from firedrake.dmhooks import get_appctx, get_function_space +from firedrake.function import Function + +__all__ = ("FieldsplitSNES",) + + +class FieldsplitSNES(SNESBase): + prefix = "fieldsplit_" + + # TODO: + # - Allow setting field grouping/ordering like fieldsplit + + @PETSc.Log.EventDecorator() + def initialize(self, snes): + from firedrake.variational_solver import NonlinearVariationalSolver # ImportError if we do this at file level + ctx = get_appctx(snes.dm) + W = get_function_space(snes.dm) + self.sol = ctx._problem.u_restrict + + # buffer to save solution to outer problem during solve + self.sol_outer = Function(self.sol.function_space()) + + # buffers for shuffling solutions during solve + self.sol_current = Function(self.sol.function_space()) + self.sol_new = Function(self.sol.function_space()) + + # options for setting up the fieldsplit + snes_prefix = snes.getOptionsPrefix() + 'snes_' + self.prefix + # options for each field + sub_prefix = snes.getOptionsPrefix() + self.prefix + + snes_options = PETSc.Options(snes_prefix) + self.fieldsplit_type = snes_options.getString('type', 'additive') + if self.fieldsplit_type not in ('additive', 'multiplicative'): + raise ValueError( + 'FieldsplitSNES option snes_fieldsplit_type must be' + ' "additive" or "multiplicative"') + + split_ctxs = ctx.split([(i,) for i in range(len(W))]) + + self.solvers = tuple( + NonlinearVariationalSolver( + ctx._problem, appctx=ctx.appctx, + options_prefix=sub_prefix+str(i)) + for i, ctx in enumerate(split_ctxs) + ) + + def update(self, snes): + pass + + @PETSc.Log.EventDecorator() + def step(self, snes, x, f, y): + # store current value of outer solution + self.sol_outer.assign(self.sol) + + # the full form in ctx now has the most up to date solution + with self.sol_current.dat.vec_wo as vec: + x.copy(vec) + self.sol.assign(self.sol_current) + + # The current snes solution x is held in sol_current, and we + # will place the new solution in sol_new. + # The solvers evaluate forms containing sol, so for each + # splitting type sol needs to hold: + # - additive: all fields need to hold sol_current values + # - multiplicative: fields need to hold sol_current before + # they are are solved for, and keep the updated sol_new + # values afterwards. + for solver, u, ucurr, unew in zip(self.solvers, + self.sol.subfunctions, + self.sol_current.subfunctions, + self.sol_new.subfunctions): + solver.solve() + unew.assign(u) + if self.fieldsplit_type == 'additive': + u.assign(ucurr) + + with self.sol_new.dat.vec_ro as vec: + vec.copy(y) + y.aypx(-1, x) + + # restore outer solution + self.sol.assign(self.sol_outer) diff --git a/tests/firedrake/regression/test_fieldsplit_snes.py b/tests/firedrake/regression/test_fieldsplit_snes.py new file mode 100644 index 0000000000..14f9530dd1 --- /dev/null +++ b/tests/firedrake/regression/test_fieldsplit_snes.py @@ -0,0 +1,200 @@ +from firedrake import * + + +def test_fieldsplit_snes(): + re = Constant(100) + nu = Constant(1/re) + + nx = 50 + dt = Constant(0.1) # CFL = dt*nx + + mesh = PeriodicUnitIntervalMesh(nx) + x, = SpatialCoordinate(mesh) + + Vu = VectorFunctionSpace(mesh, "CG", 2) + Vq = FunctionSpace(mesh, "DG", 1) + W = Vu*Vq + + w0 = Function(W) + u0, q0 = w0.subfunctions + u0.project(as_vector([0.5 + 1.0*sin(2*pi*x)])) + q0.interpolate(cos(2*pi*x)) + + def M(u, v): + return inner(u, v)*dx + + def Aburgers(u, v, nu): + return ( + inner(dot(u, nabla_grad(u)), v)*dx + + nu*inner(grad(u), grad(v))*dx + ) + + def Ascalar(q, p, u): + n = FacetNormal(mesh) + un = 0.5*(dot(u, n) + abs(dot(u, n))) + return (- q*div(u*p)*dx + + jump(un*q)*jump(p)*dS) + + # current and next timestep + w = Function(W) + wn = Function(W) + + u, q = split(w) + un, qn = split(wn) + + v, p = TestFunctions(W) + + # Trapezium rule + F = ( + M(un - u, v) + 0.5*dt*(Aburgers(un, v, nu) + Aburgers(u, v, nu)) + + M(qn - q, p) + 0.5*dt*(Ascalar(qn, p, un) + Ascalar(q, p, u)) + ) + + common_params = { + 'snes_converged_reason': None, + 'snes_monitor': None, + 'snes_rtol': 1e-8, + 'snes_atol': 1e-12, + 'ksp_converged_reason': None, + 'ksp_monitor': None, + } + + newton_params = { + 'snes_type': 'newtonls', + 'mat_type': 'aij', + 'ksp_type': 'preonly', + 'pc_type': 'lu', + } + + uparams = common_params | newton_params + qparams = common_params | newton_params | {'snes_type': 'ksponly'} + + python_params = { + 'snes_type': 'nrichardson', + 'npc_snes_type': 'python', + 'npc_snes_python_type': 'firedrake.FieldsplitSNES', + 'npc_snes_fieldsplit_type': 'additive', + 'npc_fieldsplit_0': uparams, + 'npc_fieldsplit_1': qparams, + } + + params = common_params | python_params + + w.assign(w0) + wn.assign(w0) + u, q = w.subfunctions + un, qn = wn.subfunctions + solver = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, wn), + solver_parameters=params, + options_prefix="") + + nsteps = 2 + for i in range(nsteps): + w.assign(wn) + solver.solve() + + +def M(u, v): + return inner(u, v)*dx + + +def A(u, v, nu): + return ( + inner(dot(u, nabla_grad(u)), v)*dx + + nu*inner(grad(u), grad(v))*dx + ) + +class AuxiliaryBurgersSNES(AuxiliaryOperatorSNES): + def form(self, snes, u, v): + appctx = self.get_appctx(snes) + nu = appctx["nu"] + dt = appctx["dt"] + un = appctx["un"] + un1 = appctx["un1"] + uh = (u + un)/2 + F = M(u - un, v) + dt*A(uh, v, nu) + self.un = un + self.un1 = un1 + return F, None, u + + +def test_auxiliary_snes(): + re = Constant(100) + re_aux = Constant(50) + + nu = Constant(1/re) + nu_aux = Constant(1/re_aux) + + nx = 50 + dt = Constant(0.1) # CFL = dt*nx + + mesh = PeriodicUnitIntervalMesh(nx) + x, = SpatialCoordinate(mesh) + + V = VectorFunctionSpace(mesh, "CG", 2) + + # current and next timestep + ic = as_vector([1.0 + 0.5*sin(2*pi*x)]) + un = Function(V).project(ic) + un1 = Function(V).project(ic) + + v = TestFunction(V) + + # Implicit midpoint rule + uh = (un + un1)/2 + F = M(un1 - un, v) + dt*A(uh, v, nu) + + solver_parameters = { + 'snes': { + 'view': ':snes_view.log', + 'converged_reason': None, + 'monitor': None, + 'rtol': 1e-8, + 'atol': 0, + 'max_it': 3, + 'convergence_test': 'skip', + 'linesearch_type': 'l2', + 'linesearch_damping': 1.0, + 'linesearch_monitor': None, + }, + 'snes_type': 'nrichardson', + 'npc_snes_type': 'python', + 'npc_snes_python_type': f'{__name__}.AuxiliaryBurgersSNES', + 'npc_aux': { + 'snes': { + 'converged_reason': None, + 'monitor': None, + 'rtol': 1e-4, + 'atol': 0, + 'max_it': 2, + 'convergence_test': 'skip', + }, + 'snes_type': 'newtonls', + 'mat_type': 'aij', + 'ksp_type': 'preonly', + 'pc_type': 'lu', + 'pc_factor_mat_solver_type': 'petsc', + }, + } + + appctx = { + "nu": nu_aux, + "dt": dt, + "un": un, + "un1": un1, + } + + solver = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1), + solver_parameters=solver_parameters, + options_prefix="fd", appctx=appctx) + + nsteps = 1 + for i in range(nsteps): + solver.solve() + un.assign(un1) + + +if __name__ == "__main__": + test_auxiliary_snes()