Skip to content

Commit a844bac

Browse files
Merge pull request #2909 from devitocodes/fix-break-clusters
compiler: catch corner case read after write
2 parents 6ec492d + d2cbc85 commit a844bac

5 files changed

Lines changed: 41 additions & 7 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,12 @@ def _break_for_parallelism(self, scope, dim, timestamp):
228228
# Would break a dependence on storage
229229
return False
230230

231-
if any(dep.is_carried(i) for i in candidates):
231+
if any(dep.as_logical.is_carried(i) for i in candidates):
232+
# If, from a semantic viewpoint, `i` is a purely sequential
233+
# Dimension, give up
232234
test0 = dep.is_flow and dep.is_lex_negative
233235
test1 = dep.is_anti and dep.is_lex_positive
234236
if test0 or test1:
235-
# Would break a data dependence
236237
return False
237238

238239
test = test or (bool(dep.cause & candidates) and not dep.is_lex_equal)

devito/ir/equations/equation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class IREq(sympy.Eq, Pickable):
3030
__rargs__ = ('lhs', 'rhs')
3131
__rkwargs__ = ('ispace', 'conditionals', 'implicit_dims', 'operation')
3232

33+
def _hashable_content(self):
34+
return (*super()._hashable_content(),
35+
*tuple(getattr(self, i) for i in self.__rkwargs__))
36+
3337
@property
3438
def is_Scalar(self):
3539
return self.lhs.is_Symbol
@@ -302,7 +306,7 @@ def __new__(cls, *args, **kwargs):
302306
setattr(expr, f'_{i}', v)
303307
else:
304308
expr._ispace = kwargs['ispace']
305-
expr._conditionals = kwargs.get('conditionals', frozendict())
309+
expr._conditionals = kwargs.get('conditionals', {})
306310
expr._implicit_dims = input_expr.implicit_dims
307311
expr._operation = Operation.detect(input_expr)
308312
elif len(args) == 2:
@@ -313,6 +317,10 @@ def __new__(cls, *args, **kwargs):
313317
else:
314318
raise ValueError(f"Cannot construct ClusterizedEq from args={str(args)} "
315319
f"and kwargs={str(kwargs)}")
320+
321+
# Immutability (and thus hashability, etc)
322+
expr._conditionals = frozendict(expr._conditionals)
323+
316324
return expr
317325

318326
func = IREq._rebuild

devito/ir/support/basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,10 @@ def is_iaw(self):
698698
def is_reduction(self):
699699
return self.source.is_reduction or self.sink.is_reduction
700700

701+
@cached_property
702+
def as_logical(self):
703+
return LogicalDependence(self.source, self.sink)
704+
701705
@memoized_meth
702706
def is_const(self, dim):
703707
"""
@@ -1105,6 +1109,7 @@ def d_flow_gen(self):
11051109
continue
11061110

11071111
distance = dependence.distance
1112+
11081113
try:
11091114
is_flow = distance > 0 or (r.lex_ge(w) and distance == 0)
11101115
except TypeError:

devito/ir/support/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def detect_io(exprs, relax=False):
278278
if rule(f):
279279
writes.append(f)
280280

281-
return filter_sorted(reads), filter_sorted(writes)
281+
return tuple(filter_sorted(reads)), tuple(filter_sorted(writes))
282282

283283

284284
def pull_dims(exprs, flag=True):

tests/test_fission.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from conftest import assert_structure
44
from devito import (
5-
Buffer, Eq, Function, Grid, Inc, Operator, SubDimension, SubDomain, TimeFunction,
6-
solve
5+
NODE, Buffer, Eq, Function, Grid, Inc, Operator, SubDimension, SubDomain,
6+
TimeFunction, solve
77
)
88
from devito.ir.iet import retrieve_iteration_tree
99
from devito.ir.support.properties import PARALLEL
@@ -131,7 +131,7 @@ def test_issue_1921():
131131
assert np.all(g.data == g1.data)
132132

133133

134-
def test_buffer1_fissioning():
134+
def test_buffer1_v0():
135135
"""
136136
Tests an edge case whereby inability to spot the equivalence of
137137
`f.forward`/`backward` and `f` when using `Buffer(1)` would cause
@@ -196,3 +196,23 @@ def define(self, dimensions):
196196
# Two loop nests: free-surface-like and update-like
197197
assert_structure(op, ['t,x,y,z', 't,x0_blk0,y0_blk0,x,y,z'],
198198
't,x,y,z,x0_blk0,y0_blk0,x,y,z')
199+
200+
201+
def test_buffer1_v1():
202+
grid = Grid((11, 11, 11))
203+
x, y, z = grid.dimensions
204+
205+
image_vs = Function(name='image_vs', grid=grid, space_order=1, staggered=NODE)
206+
p_back_xy = TimeFunction(name='p_back_xy', grid=grid, staggered=(x, y),
207+
space_order=4, time_order=1, save=Buffer(1))
208+
209+
eqns = [Eq(image_vs, p_back_xy + image_vs),
210+
Eq(p_back_xy.backward, p_back_xy)]
211+
212+
op = Operator(eqns)
213+
214+
assert_structure(
215+
op,
216+
['t,x0_blk0,y0_blk0,x,y,z', 't,x1_blk0,y1_blk0,x,y,z'],
217+
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz'
218+
)

0 commit comments

Comments
 (0)