Skip to content

Commit 68e2f74

Browse files
Merge pull request #2935 from devitocodes/wdr-restart
compiler: Fix ComponentAccess.dtype
2 parents 4c886eb + 166f4e1 commit 68e2f74

3 files changed

Lines changed: 23 additions & 2 deletions

File tree

devito/arch/archinfo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,12 @@ class Device(Platform):
10081008
"warp" in NVidia GPUs and a "wavefront" in AMD GPUs.
10091009
"""
10101010

1011+
thread_group_slots = None
1012+
"""
1013+
Number of thread-group issue/execution slots per compute engine (e.g.,
1014+
a Streaming Multiprocessor (SM) on NVidia GPUs, a Compute Unit (CU) on AMD GPUs).
1015+
"""
1016+
10111017
def __init__(self, name, cores_logical=None, cores_physical=None, isa='cpp',
10121018
max_threads_per_block=1024, max_threads_dimx=1024,
10131019
max_threads_dimy=1024, max_threads_dimz=64,
@@ -1091,6 +1097,7 @@ def march(self):
10911097
class NvidiaDevice(Device):
10921098

10931099
thread_group_size = 32
1100+
thread_group_slots = 4
10941101

10951102
max_mem_trans_nbytes = 128
10961103

@@ -1182,6 +1189,7 @@ class Blackwell(Hopper):
11821189
class AmdDevice(Device):
11831190

11841191
thread_group_size = 64
1192+
thread_group_slots = 4
11851193

11861194
max_mem_trans_nbytes = 256
11871195

devito/types/array.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,13 @@ def indices(self):
653653

654654
@property
655655
def dtype(self):
656-
return self.function.dtype
656+
try:
657+
return self.function.c0.dtype
658+
except AttributeError:
659+
# Vector-component access on a Symbol of vector dtype, e.g. a float4 register.
660+
if self.function.is_Symbol:
661+
return dtypes_vector_mapper.get_base_dtype(self.function.dtype)
662+
raise
657663

658664
@cacheit
659665
def sort_key(self, order=None):

tests/test_symbolics.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
2121
retrieve_indexed, uxreplace
2222
)
23-
from devito.tools import CustomDtype, as_tuple
23+
from devito.tools import CustomDtype, as_tuple, dtypes_vector_mapper
2424
from devito.types import (
2525
Array, Bundle, ComponentAccess, FIndexed, LocalObject, Object, StencilDimension
2626
)
@@ -575,6 +575,13 @@ def test_component_access():
575575
assert cf2 == cf1
576576

577577

578+
def test_component_access_symbol_printing():
579+
acc = dSymbol(name='acc', dtype=dtypes_vector_mapper[(np.float32, 4)])
580+
expr = ComponentAccess(acc, 0)
581+
582+
assert ccode(sympy.Float('1.25')*expr, dtype=expr.dtype) == '1.250F*acc.x'
583+
584+
578585
def test_vector_access():
579586
grid = Grid(shape=(3, 3, 3))
580587

0 commit comments

Comments
 (0)