diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index ac17eeec8c..10159ecb1d 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -16,7 +16,8 @@ from devito.symbolics import estimate_cost from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype from devito.types import ( - CriticalRegion, Fence, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence + CriticalRegion, Fence, Indexed, TensorMove, ThreadArrive, ThreadCommit, + ThreadPoolSync, ThreadWait, WeakFence ) __all__ = ["Cluster", "ClusterGroup"] @@ -310,14 +311,47 @@ def is_critical_region(self): def is_thread_pool_sync(self): return self._is_type(ThreadPoolSync) + @cached_property + def is_shm_write(self): + return all(w._mem_shared for w in self.scope.writes) + @cached_property def is_thread_commit(self): return self._is_type(ThreadCommit) + @cached_property + def is_thread_arrive(self): + return self._is_type(ThreadArrive) + @cached_property def is_thread_wait(self): return self._is_type(ThreadWait) + @cached_property + def is_thread_sync(self): + return self.is_thread_pool_sync or self.is_thread_wait + + @cached_property + def is_word_move(self): + return (self._is_type(Indexed) and + all(e.rhs.function._mem_heap for e in self.exprs)) + + @cached_property + def is_tensor_move(self): + return self._is_type(TensorMove) + + @cached_property + def is_word_move_to_mem_shared(self): + return self.is_word_move and self.is_shm_write + + @cached_property + def is_tensor_move_to_mem_shared(self): + return self.is_tensor_move and self.is_shm_write + + @cached_property + def is_glb_load_to_mem_shared(self): + return self.is_word_move_to_mem_shared or self.is_tensor_move_to_mem_shared + @cached_property def is_async(self): """ @@ -557,6 +591,10 @@ def dspace(self): def is_halo_touch(self): return all(i.is_halo_touch for i in self) + @cached_property + def is_glb_load_to_mem_shared(self): + return all(i.is_glb_load_to_mem_shared for i in self) + @cached_property def dtype(self): """ diff --git a/devito/ir/iet/efunc.py b/devito/ir/iet/efunc.py index 1a17202140..e161e1936f 100644 --- a/devito/ir/iet/efunc.py +++ b/devito/ir/iet/efunc.py @@ -1,4 +1,6 @@ +from dataclasses import dataclass from functools import cached_property +from itertools import chain from devito.ir.iet.nodes import Call, Callable from devito.ir.iet.utils import derive_parameters @@ -11,6 +13,7 @@ 'CommCallable', 'DeviceCall', 'DeviceFunction', + 'EFuncMeta', 'ElementalCall', 'ElementalFunction', 'EntryFunction', @@ -21,6 +24,38 @@ ] +@dataclass(frozen=True) +class EFuncMeta: + + body: object = None + efuncs: tuple = () + includes: tuple = () + namespaces: tuple = () + libs: tuple = () + + @classmethod + def compose(cls, *items): + items = tuple(items) + + if not items: + return cls() + + return cls( + body=items[-1].body, + efuncs=tuple(chain.from_iterable(i.efuncs for i in items)), + includes=tuple(chain.from_iterable(i.includes for i in items)), + namespaces=tuple(chain.from_iterable(i.namespaces for i in items)), + libs=tuple(chain.from_iterable(i.libs for i in items)) + ) + + def __iter__(self): + yield self.body + yield self.efuncs + yield self.includes + yield self.namespaces + yield self.libs + + # ElementalFunction machinery class ElementalCall(Call): diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index 664105b704..9e787a8b9e 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -97,11 +97,6 @@ def __init__(self, name, val=None): A Dimension along which prefetching is feasible and beneficial. """ -PREFETCHABLE_SHM = Property('prefetchable-shm') -""" -A Dimension along which shared-memory prefetching is feasible and beneficial. -""" - INIT_CORE_SHM = Property('init-core-shm') """ A Dimension along which the shared-memory CORE data region is initialized. @@ -190,32 +185,6 @@ def update_properties(properties, exprs): if not exprs: return properties - # Auto-detect prefetchable Dimensions - dims = set() - flag = False - for e in as_tuple(exprs): - w, r = e.args - - # Ensure it's in the form `Indexed = Indexed` - try: - wf, rf = w.function, r.function - except AttributeError: - break - - if not rf or not wf._mem_shared: - break - dims.update({d.parent for d in wf.dimensions if d.parent in properties}) - - if not rf._mem_heap: - break - else: - flag = True - - if flag: - properties = properties.prefetchable_shm(dims) - else: - properties = properties.drop(properties=PREFETCHABLE_SHM) - # Remove properties that are trivially incompatible with `exprs` if not all(e.lhs.function._mem_shared for e in as_tuple(exprs)): drop = {INIT_CORE_SHM, INIT_HALO_LEFT_SHM, INIT_HALO_RIGHT_SHM} @@ -284,9 +253,6 @@ def prefetchable(self, dims, v=PREFETCHABLE): m[d] = self.get(d, set()) | {v} return Properties(m) - def prefetchable_shm(self, dims): - return self.prefetchable(dims, PREFETCHABLE_SHM) - def block(self, dims, kind='default'): if kind == 'default': p = TILABLE @@ -357,9 +323,6 @@ def _is_property_any(self, dims, v): def is_prefetchable(self, dims=None, v=PREFETCHABLE): return self._is_property_any(dims, PREFETCHABLE) - def is_prefetchable_shm(self, dims=None): - return self._is_property_any(dims, PREFETCHABLE_SHM) - def is_core_init(self, dims=None): return self._is_property_any(dims, INIT_CORE_SHM) diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index 21d4ba684a..a51de1f25c 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -139,6 +139,10 @@ def _aliases_from_clusters(self, cgroup, exclude, meta): # [Schedule]_m -> Schedule (s.t. best memory/flops trade-off) schedule, exprs = self._select(variants) + # Schedule -> Schedule (optimization) + if self.opt_maxpar: + schedule = optimize_schedule_maxpar(schedule) + # Schedule -> Schedule (optimization) if self.opt_rotate: schedule = optimize_schedule_rotations(schedule, self.sregistry) @@ -664,7 +668,6 @@ def lower_aliases(aliases, meta, maxpar): """ Create a Schedule from an AliasList. """ - stampcache = {} dmapper = {} processed = [] for a in aliases: @@ -704,12 +707,6 @@ def lower_aliases(aliases, meta, maxpar): # use `<1>` as stamp, which is what appears in `ispace` interval = interval.lift(i.stamp) - # We further bump the interval stamp if we were requested to trade - # fusion for more collapse-parallelism - if maxpar: - stamp = stampcache.setdefault(interval.dim, Stamp()) - interval = interval.lift(stamp) - writeto.append(interval) intervals.append(interval) @@ -853,6 +850,30 @@ def optimize_schedule_rotations(schedule, sregistry): return schedule.rebuild(*processed, rmapper=rmapper) +def optimize_schedule_maxpar(schedule): + """ + Bump the IterationSpace' stamp trading fusion for more collapse-parallelism. + """ + key = lambda i: (i.writeto, i.ispace) + + processed = [] + for (writeto0, ispace0), group in groupby(schedule, key=key): + g = list(group) + + stamp = Stamp() + dims = writeto0.itdims + + writeto = writeto0.lift(dims, stamp) + ispace = ispace0.lift(dims, stamp) + + processed.extend([ + ScheduledAlias(pivot, writeto, ispace, aliaseds, indicess) + for pivot, _, _, aliaseds, indicess in g + ]) + + return schedule.rebuild(*processed) + + def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype, opt_minmem): """ diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index a92c5495f1..494ebe7490 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -232,7 +232,7 @@ def _key(self, c): weak.append(c.properties.is_core_init()) # Prefetchable Clusters should get merged, if possible - weak.append(c.properties.is_prefetchable_shm()) + weak.append(c.is_glb_load_to_mem_shared) # Promoting adjacency of IndexDerivatives will maximize their reuse weak.append(any(search(c.exprs, IndexDerivative))) diff --git a/devito/types/dense.py b/devito/types/dense.py index 0de36d3093..5fcbe8d1d7 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -1593,15 +1593,6 @@ def _time_buffering(self): def _time_buffering_default(self): return self._time_buffering and not isinstance(self.save, Buffer) - def _evaluate(self, **kwargs): - retval = super()._evaluate(**kwargs) - if not self._time_buffering and not retval.is_Function: - # Saved TimeFunction might need streaming, expand interpolations - # for easier processing - return retval.evaluate - else: - return retval - def _arg_check(self, args, intervals, **kwargs): super()._arg_check(args, intervals, **kwargs) diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 8300d0e80c..c6aceb42e5 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -11,10 +11,11 @@ from functools import cached_property import numpy as np +from sympy import Expr from devito.exceptions import InvalidArgument from devito.parameters import configuration -from devito.symbolics import search +from devito.symbolics import Reserved, Terminal, search from devito.tools import as_list, as_tuple, is_integer from devito.types.array import Array, ArrayObject from devito.types.basic import Scalar, Symbol @@ -35,7 +36,9 @@ 'QueueID', 'SharedData', 'TBArray', + 'TensorMove', 'ThreadArray', + 'ThreadArrive', 'ThreadCommit', 'ThreadID', 'ThreadPoolSync', @@ -365,12 +368,24 @@ class ThreadCommit(Fence): pass +class ThreadArrive(Fence): + + """ + A generic arrive operation for a single thread, typically used to signal + the arrival at a certain point through a suitable synchronization object. + """ + + pass + + class ThreadWait(Fence): """ A generic wait operation for a single thread, typically used to synchronize - after a memory operation issued at a specific program point with a - ThreadCommit operation. + with other threads over: + + * a memory operation issued by a prior ThreadCommit operation. + * the consumption of a shared resource via a ThreadArrive operation. """ pass @@ -386,3 +401,18 @@ def __init_finalize__(self, *args, **kwargs): kwargs['liveness'] = 'eager' super().__init_finalize__(*args, **kwargs) + + +class TensorMove(Expr, Reserved, Terminal): + + """ + Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher + level of the memory hierarchy + """ + + func = Reserved._rebuild + + def _ccode(self, printer): + return str(self) + + _sympystr = _ccode diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 8909d5a9fa..7ed57b6c91 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -2,10 +2,11 @@ import pytest from sympy import Float, Symbol, diff, simplify, sympify +from conftest import assert_structure from devito import ( NODE, ConditionalDimension, Eq, Function, Grid, Operator, TensorFunction, - TimeFunction, VectorFunction, centered, cos, curl, div, grad, laplace, left, right, - sin + TensorTimeFunction, TimeFunction, VectorFunction, centered, cos, curl, div, grad, + laplace, left, right, sin ) from devito.finite_differences import Derivative, Differentiable, diffify from devito.finite_differences.differentiable import ( @@ -927,6 +928,17 @@ def test_param_stagg_add(self): # Addition should apply the same logic as above for each term assert simplify(eq2.evaluate.rhs - (expect1 + expect0)) == 0 + def test_unexpand_space_interp_w_saved_timefunc(self): + grid = Grid(shape=(3, 3, 3)) + + tau = TensorTimeFunction(name="tau", grid=grid, save=10) + + eq = Eq(tau[0, 1], tau[2, 2]) + + op = Operator(eq, opt=('advanced', {'expand': False})) + + assert_structure(op, ['t,x,y,z', 't,x,y,z,i1', 't,x,y,z,i1,i0']) + class TestTwoStageEvaluation: