diff --git a/tomobar/methodsDIR.py b/tomobar/methodsDIR.py index 5d5f8122a..1ab1d1686 100644 --- a/tomobar/methodsDIR.py +++ b/tomobar/methodsDIR.py @@ -5,6 +5,7 @@ * :func:`RecToolsDIR.FBP` Filtered Back Projection 2D/3D (ASTRA with the custom built filter). """ +from typing import Literal import numpy as np import scipy.fftpack @@ -36,33 +37,36 @@ def __init__( CenterRotOffset, # Centre of Rotation (CoR) scalar or a vector AnglesVec, # Array of angles in radians ObjSize, # A scalar to define reconstructed object dimensions + projector: Literal["fourier", "astra"] = "astra", device_projector="gpu", # Choose the device to be 'cpu' or 'gpu' OR provide a GPU index (integer) of a specific device ): device_projector, GPUdevice_index = _parse_device_argument(device_projector) if DetectorsDimV == 0 or DetectorsDimV is None: self.geom = "2D" - self.Atools = AstraTools2D( - DetectorsDimH, - DetectorsDimH_pad, - AnglesVec, - CenterRotOffset, - ObjSize, - device_projector, - GPUdevice_index, - ) + if projector == "astra": + self.Atools = AstraTools2D( + DetectorsDimH, + DetectorsDimH_pad, + AnglesVec, + CenterRotOffset, + ObjSize, + device_projector, + GPUdevice_index, + ) else: self.geom = "3D" - self.Atools = AstraTools3D( - DetectorsDimH, - DetectorsDimH_pad, - DetectorsDimV, - AnglesVec, - CenterRotOffset, - ObjSize, - device_projector, - GPUdevice_index, - ) + if projector == "astra": + self.Atools = AstraTools3D( + DetectorsDimH, + DetectorsDimH_pad, + DetectorsDimV, + AnglesVec, + CenterRotOffset, + ObjSize, + device_projector, + GPUdevice_index, + ) def FORWPROJ(self, data: np.ndarray, **kwargs) -> np.ndarray: """Module to perform forward projection of 2d/3d data numpy array diff --git a/tomobar/methodsDIR_CuPy.py b/tomobar/methodsDIR_CuPy.py index f2f1259b5..e1ed765b0 100644 --- a/tomobar/methodsDIR_CuPy.py +++ b/tomobar/methodsDIR_CuPy.py @@ -5,12 +5,14 @@ * :func:`RecToolsDIRCuPy.FOURIER_INV` - Fourier direct reconstruction on unequally spaced grids (interpolation in image space), aka log-polar method [NIKITIN2017]_. """ +from typing import Literal, Tuple import numpy as np -from numpy import float32 import math import cupy as cp from cupyx.scipy.fft import fft, ifft2, rfftfreq, rfft, irfft +from cupyx.scipy.fftpack import get_fft_plan +from tomobar.supp.memory_estimator_helpers import DeviceMemStack from tomobar.supp.suppTools import check_kwargs, _apply_horiz_detector_padding from tomobar.supp.funcs import _data_dims_swapper from tomobar.fourier import _filtersinc3D_cupy, calc_filter @@ -42,8 +44,14 @@ def __init__( CenterRotOffset, # The Centre of Rotation scalar or a vector AnglesVec, # Array of projection angles in radians ObjSize, # Reconstructed object dimensions (scalar) + projector: Literal["fourier", "astra"] = "astra", device_projector=0, # set an index (integer) of a specific GPU device ): + self.detectors_x_pad = DetectorsDimH_pad + self.centre_of_rotation = CenterRotOffset + self.angles_vec = AnglesVec + self.recon_size = ObjSize + super().__init__( DetectorsDimH, DetectorsDimH_pad, @@ -51,6 +59,7 @@ def __init__( CenterRotOffset, AnglesVec, ObjSize, + projector, device_projector, ) # if DetectorsDimV == 0 or DetectorsDimV is None: @@ -138,7 +147,9 @@ def FBP(self, data: cp.ndarray, **kwargs) -> cp.ndarray: reconstruction = self.Atools._backprojCuPy(data) # 3d backprojecting return check_kwargs(reconstruction, **kwargs) - def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: + def FOURIER_INV( + self, data: cp.ndarray | Tuple[int, int, int], **kwargs + ) -> cp.ndarray: """Fourier direct inversion in 3D on unequally spaced (also called as NonUniform FFT/NUFFT) grids using CuPy array as an input, see more in [NIKITIN2017]_. This implementation is originated from V. Nikitin's CUDA-C implementation: https://github.com/nikitinvv/radonusfft and TomoCuPy package. @@ -237,8 +248,14 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: unpadding_mul_phi = module.get_function("unpadding_mul_phi") # initialisation - [nz, nproj, data_n] = data.shape - recon_size = self.Atools.recon_size + mem_stack = DeviceMemStack().instance() + if mem_stack: + mem_stack.malloc(np.prod(data) * kwargs["data_dtype"].itemsize) + [nz, nproj, data_n] = data + else: + [nz, nproj, data_n] = data.shape + + recon_size = self.recon_size if recon_size > data_n: raise ValueError( "The reconstruction size {} should not be larger than the size of the horizontal detector {}".format( @@ -253,13 +270,16 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: nz += odd_vert if odd_horiz or odd_vert: - data_p = cp.zeros((nz, nproj, data_n), dtype=cp.float32) - data_p[: nz - odd_vert, :, : data_n - odd_horiz] = data - data_p[: nz - odd_vert, :, -odd_horiz] = data[..., -odd_horiz] - data = data_p - del data_p - - n = data_n + self.Atools.detectors_x_pad * 2 + padding * 2 + if mem_stack: + mem_stack.malloc(np.prod((nz, nproj, data_n)) * cp.float32().itemsize) + else: + data_p = cp.zeros((nz, nproj, data_n), dtype=cp.float32) + data_p[: nz - odd_vert, :, : data_n - odd_horiz] = data + data_p[: nz - odd_vert, :, -odd_horiz] = data[..., -odd_horiz] + data = data_p + del data_p + + n = data_n + self.detectors_x_pad * 2 + padding * 2 if power_of_2_cropping: n_pow2 = 2 ** math.ceil(math.log2(n)) if 0.9 < n / n_pow2: @@ -268,79 +288,232 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: # Limit the center size parameter center_size = min(center_size, n * 2) - rotation_axis = self.Atools.centre_of_rotation + 0.5 - theta = cp.array(-self.Atools.angles_vec, dtype=cp.float32) + if mem_stack: + mem_stack.malloc(np.prod(self.angles_vec.shape) * cp.float32().itemsize) + + theta = cp.array(-self.angles_vec, dtype=cp.float32) + if center_size >= _CENTER_SIZE_MIN: - sorted_theta_indices = cp.argsort(theta) - sorted_theta = theta[sorted_theta_indices] - sorted_theta_cpu = sorted_theta.get() + if mem_stack: + mem_stack.malloc(np.prod(self.angles_vec.shape) * np.int64().itemsize) + mem_stack.malloc(np.prod(self.angles_vec.shape) * np.float32().itemsize) + + sorted_theta_cpu = cp.sort(theta).get() + else: + sorted_theta_indices = cp.argsort(theta) + sorted_theta = theta[sorted_theta_indices] + sorted_theta_cpu = sorted_theta.get() theta_full_range = abs(sorted_theta_cpu[nproj - 1] - sorted_theta_cpu[0]) angle_range_pi_count = 1 + int(np.ceil(theta_full_range / math.pi)) - angle_range = cp.zeros( - [center_size, center_size, 1 + angle_range_pi_count * 2], - dtype=cp.uint16, - ) + + if mem_stack: + mem_stack.malloc( + np.prod((center_size, center_size, (1 + angle_range_pi_count * 2))) + * cp.uint16().itemsize + ) + else: + angle_range = cp.zeros( + [center_size, center_size, 1 + angle_range_pi_count * 2], + dtype=cp.uint16, + ) # usfft parameters eps = 1e-4 # accuracy of usfft mu = -np.log(eps) / (2 * n * n) - m = int( - np.ceil( - 2 * n * 1 / np.pi * np.sqrt(-mu * np.log(eps) + (mu * n) * (mu * n) / 4) + + # STEP0: FBP filtering + if mem_stack: + tmp_p = self._fbp_filtering_estimator( + data_n, + n, + nproj, + nz, + power_of_2_oversampling, + oversampling_level, + filter_vol_chunk_count, + filter_proj_chunk_count, + min_mem_usage_filter, + ) + + if odd_horiz or odd_vert: + mem_stack.free(np.prod((nz, nproj, data_n)) * cp.float32().itemsize) + else: + tmp_p = self._fbp_filtering( + data, + data_n, + n, + nproj, + nz, + power_of_2_oversampling, + oversampling_level, + filter_type, + cutoff_freq, + filter_vol_chunk_count, + filter_proj_chunk_count, + min_mem_usage_filter, + ) + + del data + + # Memory clean up of interpolation extra arrays + if mem_stack: + datac, fde = self._setup_backprojection_input_estimator(n, nproj, nz) + (tmp_p_shape, tmp_p_dtype) = tmp_p + mem_stack.free(np.prod(tmp_p_shape) * tmp_p_dtype.itemsize) + else: + (datac, fde) = self._setup_backprojection_input( + tmp_p, n, nproj, nz, center_size, r2c_c1dfftshift ) + + del tmp_p + + # BACKPROJECTION + if mem_stack: + self._fft_and_interpolation_estimator(datac) + (datac_shape, datac_dtype) = datac + mem_stack.free(np.prod(datac_shape) * datac_dtype.itemsize) + else: + self._fft_and_interpolation( + datac, + fde, + n, + nproj, + nz, + center_size, + block_dim, + block_dim_center, + theta, + sorted_theta if center_size >= _CENTER_SIZE_MIN else None, + sorted_theta_indices if center_size >= _CENTER_SIZE_MIN else None, + angle_range_pi_count if center_size >= _CENTER_SIZE_MIN else None, + angle_range if center_size >= _CENTER_SIZE_MIN else None, + eps, + mu, + c1dfftshift, + gather_kernel_partial, + gather_kernel_center_angle_based_prune, + gather_kernel_center, + gather_kernel, + ) + + del datac + + # STEP3: ifft 2d + if mem_stack: + self.ifft_gathered_projections_estimator( + fde, nz, chunk_count, min_mem_usage_ifft2 + ) + else: + self.ifft_gathered_projections( + fde, n, nz, chunk_count, min_mem_usage_ifft2, c2dfftshift + ) + + # Unpadded recon output size + if mem_stack: + recon_up_shape = self.unpad_reconstructed_data_estimator( + n, nz, odd_horiz, odd_vert, recon_size + ) + + (fde_shape, fde_dtype) = fde + mem_stack.free(np.prod(fde_shape) * fde_dtype.itemsize) + else: + recon_up = self.unpad_reconstructed_data( + fde, + n, + nproj, + nz, + odd_horiz, + odd_vert, + recon_size, + mu, + unpadding_mul_phi, + ) + + del fde + + if mem_stack: + mem_stack.malloc(np.prod(recon_up_shape[1:]) * cp.float32().itemsize) + mem_stack.malloc(np.prod(recon_up_shape) * cp.float32().itemsize) + mem_stack.free(np.prod(recon_up_shape) * cp.float32().itemsize) + mem_stack.free(np.prod(recon_up_shape[1:]) * cp.float32().itemsize) + return recon_up_shape + + return check_kwargs( + recon_up, + **kwargs, ) + def _fbp_filtering( + self, + data: cp.ndarray, + raw_detector_width: int, + detector_width: int, + projection_count: int, + detector_height: int, + power_of_2_oversampling: bool, + oversampling_level: int, + filter_type: str, + cutoff_freq, + filter_vol_chunk_count: int, + filter_proj_chunk_count: int, + min_mem_usage_filter: bool, + ) -> cp.ndarray: # init filter if power_of_2_oversampling: - ne = 2 ** math.ceil(math.log2(data_n * 3)) - if n > ne: - ne = 2 ** math.ceil(math.log2(n)) + oversampled_detector_width = 2 ** math.ceil( + math.log2(raw_detector_width * 3) + ) + if detector_width > oversampled_detector_width: + oversampled_detector_width = 2 ** math.ceil(math.log2(detector_width)) else: - ne = int(oversampling_level * data_n) - ne = max(ne, n) + oversampled_detector_width = int(oversampling_level * raw_detector_width) + oversampled_detector_width = max(oversampled_detector_width, detector_width) - padding_m = ne // 2 - data_n // 2 - unpad_m = ne // 2 - n // 2 - unpad_p = ne // 2 + n // 2 + padding_m = oversampled_detector_width // 2 - raw_detector_width // 2 + unpad_m = oversampled_detector_width // 2 - detector_width // 2 + unpad_p = oversampled_detector_width // 2 + detector_width // 2 - wfilter = calc_filter(ne, filter_type, cutoff_freq) + rotation_axis = self.centre_of_rotation + 0.5 - # STEP0: FBP filtering - t = rfftfreq(ne).astype(cp.float32) + wfilter = calc_filter(oversampled_detector_width, filter_type, cutoff_freq) + t = rfftfreq(oversampled_detector_width).astype(cp.float32) w = wfilter * cp.exp(-2 * cp.pi * 1j * t * (rotation_axis)) # FBP filtering output - tmp_p = cp.empty((nz, nproj, n), dtype=cp.float32) + tmp_p = cp.empty( + (detector_height, projection_count, detector_width), dtype=cp.float32 + ) if min_mem_usage_filter: - filter_vol_chunk_count = nz - - if min_mem_usage_ifft2: - chunk_count = nz // 2 + filter_vol_chunk_count = detector_height - slice_count_per_chunk = np.ceil(nz / filter_vol_chunk_count) + slice_count_per_chunk = int(np.ceil(detector_height / filter_vol_chunk_count)) + projection_count_per_projection_chunk = int( + np.ceil(projection_count / filter_proj_chunk_count) + ) # Loop over the chunks for chunk_index in range(0, filter_vol_chunk_count): - slice_start_index = min(chunk_index * slice_count_per_chunk, nz) - slice_end_index = min((chunk_index + 1) * slice_count_per_chunk, nz) + slice_start_index = min( + chunk_index * slice_count_per_chunk, detector_height + ) + slice_end_index = min( + (chunk_index + 1) * slice_count_per_chunk, detector_height + ) if slice_start_index >= slice_end_index: break # processing by chunks over the second dimension # to avoid increased data sizes due to oversampling - projection_count_per_projection_chunk = np.ceil( - nproj / filter_proj_chunk_count - ) for projection_chunk_index in range(filter_proj_chunk_count): projection_start_index = min( projection_chunk_index * projection_count_per_projection_chunk, - nproj, + projection_count, ) projection_end_index = min( (projection_chunk_index + 1) * projection_count_per_projection_chunk, - nproj, + projection_count, ) if projection_start_index >= projection_end_index: break @@ -365,55 +538,229 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: del tmp - # Memory clean up of filter and input data - del data, t, wfilter, w + # Memory clean up of filter data + del t, wfilter, w + return tmp_p - # BACKPROJECTION + def _fbp_filtering_estimator( + self, + raw_detector_width: int, + detector_width: int, + projection_count: int, + detector_height: int, + power_of_2_oversampling: bool, + oversampling_level: int, + filter_vol_chunk_count: int, + filter_proj_chunk_count: int, + min_mem_usage_filter: bool, + ) -> tuple: + # init filter + if power_of_2_oversampling: + oversampled_detector_width = 2 ** math.ceil( + math.log2(raw_detector_width * 3) + ) + if detector_width > oversampled_detector_width: + oversampled_detector_width = 2 ** math.ceil(math.log2(detector_width)) + else: + oversampled_detector_width = int(oversampling_level * raw_detector_width) + oversampled_detector_width = max(oversampled_detector_width, detector_width) + + mem_stack = DeviceMemStack.instance() + mem_stack.malloc((oversampled_detector_width // 2 + 1) * np.float32().itemsize) + mem_stack.malloc((oversampled_detector_width // 2 + 1) * np.float32().itemsize) + mem_stack.malloc( + (oversampled_detector_width // 2 + 1) * np.complex64().itemsize + ) + + # FBP filtering output + tmp_p_shape = (detector_height, projection_count, detector_width) + tmp_p_dtype = cp.float32() + mem_stack.malloc(np.prod(tmp_p_shape) * tmp_p_dtype.itemsize) + + if min_mem_usage_filter: + filter_vol_chunk_count = detector_height + + slice_count_per_chunk = int(np.ceil(detector_height / filter_vol_chunk_count)) + projection_count_per_projection_chunk = int( + np.ceil(projection_count / filter_proj_chunk_count) + ) + # Loop over the chunks + rfft_input = cp.empty( + ( + slice_count_per_chunk, + projection_count_per_projection_chunk, + oversampled_detector_width, + ), + cp.float32, + ) + rfft_input_nbytes = rfft_input.nbytes + mem_stack.malloc(rfft_input_nbytes) + + rfft_plan = get_fft_plan(rfft_input, axes=(2), value_type="R2C") + rfft_plan_work_area_mem_size = rfft_plan.work_area.mem.size + del rfft_input, rfft_plan + mem_stack.malloc(rfft_plan_work_area_mem_size) + + rfft_output = cp.empty( + ( + slice_count_per_chunk, + projection_count_per_projection_chunk, + oversampled_detector_width // 2 + 1, + ), + cp.complex64, + ) + rfft_output_nbytes = rfft_output.nbytes + mem_stack.malloc(rfft_output_nbytes) + mem_stack.free(rfft_output_nbytes) + + irfft_plan = get_fft_plan(rfft_output, axes=(2), value_type="C2R") + irfft_plan_work_area_mem_size = irfft_plan.work_area.mem.size + del rfft_output, irfft_plan + mem_stack.malloc(irfft_plan_work_area_mem_size) + + irfft_output_size = ( + np.prod( + ( + slice_count_per_chunk, + projection_count_per_projection_chunk, + oversampled_detector_width, + ) + ) + * cp.float32().itemsize + ) + + mem_stack.malloc(irfft_output_size) + mem_stack.free(irfft_output_size) + mem_stack.free(rfft_input_nbytes) + + # Memory clean up of filter data + mem_stack.free((oversampled_detector_width // 2 + 1) * np.float32().itemsize) + mem_stack.free((oversampled_detector_width // 2 + 1) * np.float32().itemsize) + mem_stack.free((oversampled_detector_width // 2 + 1) * np.complex64().itemsize) + + return (tmp_p_shape, tmp_p_dtype) + + def _setup_backprojection_input( + self, + tmp_p: cp.ndarray, + detector_width: int, + projection_count: int, + detector_height: int, + center_size: int, + r2c_c1dfftshift: cp.RawKernel, + ) -> tuple[cp.ndarray, cp.ndarray]: # input data - datac = cp.empty((nz // 2, nproj, n), dtype=cp.complex64) + datac = cp.empty( + (detector_height // 2, projection_count, detector_width), + dtype=cp.complex64, + ) # fft, reusable by chunks if center_size >= _CENTER_SIZE_MIN: - fde = cp.empty([nz // 2, 2 * n, 2 * n], dtype=cp.complex64) + fde = cp.empty( + [detector_height // 2, 2 * detector_width, 2 * detector_width], + dtype=cp.complex64, + ) else: - fde = cp.zeros([nz // 2, 2 * n, 2 * n], dtype=cp.complex64) + fde = cp.zeros( + [detector_height // 2, 2 * detector_width, 2 * detector_width], + dtype=cp.complex64, + ) # STEP1: fft 1d r2c_c1dfftshift( ( - int(np.ceil(n / 32)), - int(np.ceil(nproj / 32)), - np.int32(nz // 2), + int(np.ceil(detector_width / 32)), + int(np.ceil(projection_count / 32)), + np.int32(detector_height // 2), ), (32, 32, 1), - (tmp_p, datac, n, nproj, nz // 2), + (tmp_p, datac, detector_width, projection_count, detector_height // 2), ) - # Memory clean up of interpolation extra arrays - del tmp_p + return (datac, fde) + def _setup_backprojection_input_estimator( + self, + detector_width: int, + projection_count: int, + detector_height: int, + ): + mem_stack = DeviceMemStack.instance() + datac_shape = (detector_height // 2, projection_count, detector_width) + datac_dtype = cp.complex64() + fde_shape = (detector_height // 2, 2 * detector_width, 2 * detector_width) + fde_dtype = cp.complex64() + mem_stack.malloc(np.prod(datac_shape) * datac_dtype.itemsize) + mem_stack.malloc(np.prod(fde_shape) * fde_dtype.itemsize) + + return ((datac_shape, datac_dtype), (fde_shape, fde_dtype)) + + def _fft_and_interpolation( + self, + datac: cp.ndarray, + fde: cp.ndarray, + detector_width: int, + projection_count: int, + detector_height: int, + center_size: int, + block_dim, + block_dim_center, + theta: cp.ndarray, + sorted_theta: cp.ndarray | None, + sorted_theta_indices: cp.ndarray | None, + angle_range_pi_count: int | None, + angle_range: cp.ndarray | None, + eps: float, + mu: float, + c1dfftshift: cp.RawKernel, + gather_kernel_partial: cp.RawKernel, + gather_kernel_center_angle_based_prune: cp.RawKernel, + gather_kernel_center: cp.RawKernel, + gather_kernel: cp.RawKernel, + ): + # STEP1: fft 1d datac = fft(datac) + m = int( + np.ceil( + 2 + * detector_width + * 1 + / np.pi + * np.sqrt( + -mu * np.log(eps) + + (mu * detector_width) * (mu * detector_width) / 4 + ) + ) + ) + c1dfftshift( ( - int(np.ceil(n / 32)), - int(np.ceil(nproj / 32)), - np.int32(nz // 2), + int(np.ceil(detector_width / 32)), + int(np.ceil(projection_count / 32)), + np.int32(detector_height // 2), ), (32, 32, 1), - (datac, np.float32(4 / n), n, nproj, nz // 2), + ( + datac, + np.float32(4 / detector_width), + detector_width, + projection_count, + detector_height // 2, + ), ) # STEP2: interpolation (gathering) in the frequency domain - # Use original one kernel at low dimension. + # Use original kernel at low dimension. if center_size >= _CENTER_SIZE_MIN: - if center_size != (n * 2): + if center_size != (detector_width * 2): gather_kernel_partial( ( - int(np.ceil(n / block_dim[0])), - int(np.ceil(nproj / block_dim[1])), - nz // 2, + int(np.ceil(detector_width / block_dim[0])), + int(np.ceil(projection_count / block_dim[1])), + detector_height // 2, ), (block_dim[0], block_dim[1], 1), ( @@ -423,9 +770,9 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: np.int32(m), np.float32(mu), np.int32(center_size), - np.int32(n), - np.int32(nproj), - np.int32(nz // 2), + np.int32(detector_width), + np.int32(projection_count), + np.int32(detector_height // 2), ), ) @@ -438,8 +785,8 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: sorted_theta, np.int32(m), np.int32(center_size), - np.int32(n), - np.int32(nproj), + np.int32(detector_width), + np.int32(projection_count), ), ) @@ -447,7 +794,7 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: ( int(np.ceil(center_size / block_dim_center[0])), int(np.ceil(center_size / block_dim_center[1])), - nz // 2, + detector_height // 2, ), (block_dim_center[0], block_dim_center[1], 1), ( @@ -460,17 +807,17 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: np.int32(m), np.float32(mu), np.int32(center_size), - np.int32(n), - np.int32(nproj), - np.int32(nz // 2), + np.int32(detector_width), + np.int32(projection_count), + np.int32(detector_height // 2), ), ) else: gather_kernel( ( - int(np.ceil(n / block_dim[0])), - int(np.ceil(nproj / block_dim[1])), - nz // 2, + int(np.ceil(detector_width / block_dim[0])), + int(np.ceil(projection_count / block_dim[1])), + detector_height // 2, ), (block_dim[0], block_dim[1], 1), ( @@ -479,54 +826,113 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: theta, np.int32(m), np.float32(mu), - np.int32(n), - np.int32(nproj), - np.int32(nz // 2), + np.int32(detector_width), + np.int32(projection_count), + np.int32(detector_height // 2), ), ) - del datac + def _fft_and_interpolation_estimator( + self, + datac: tuple, + ): + mem_stack = DeviceMemStack.instance() + (datac_shape, datac_dtype) = datac + fft_input = cp.empty(datac_shape, datac_dtype) + mem_stack.malloc(fft_input.nbytes) - # STEP3: ifft 2d + fft_plan = get_fft_plan(fft_input, axes=(-1)) + mem_stack.malloc(fft_plan.work_area.mem.size) + mem_stack.free(fft_plan.work_area.mem.size) + del fft_input, fft_plan + + def ifft_gathered_projections( + self, + fde: cp.ndarray, + detector_width: int, + detector_height: int, + chunk_count: int, + min_mem_usage_ifft2: bool, + c2dfftshift: cp.RawKernel, + ): c2dfftshift( ( - int(np.ceil((2 * n) / 32)), - int(np.ceil((2 * n) / 8)), - np.int32(nz // 2), + int(np.ceil((2 * detector_width) / 32)), + int(np.ceil((2 * detector_width) / 8)), + np.int32(detector_height // 2), ), (32, 8, 1), - (fde, n, nz // 2), + (fde, detector_width, detector_height // 2), ) - slice_count_per_chunk = np.ceil(nz // 2 / chunk_count) + if min_mem_usage_ifft2: + chunk_count = detector_height // 2 + + slice_count_per_chunk = int(np.ceil(detector_height // 2 / chunk_count)) # Loop over the chunks for chunk_index in range(0, chunk_count): - start_index = min(chunk_index * slice_count_per_chunk, nz // 2) - end_index = min((chunk_index + 1) * slice_count_per_chunk, nz // 2) + start_index = min(chunk_index * slice_count_per_chunk, detector_height // 2) + end_index = min( + (chunk_index + 1) * slice_count_per_chunk, detector_height // 2 + ) if start_index >= end_index: break tmp = fde[start_index:end_index, :, :] tmp = ifft2(tmp, axes=(-2, -1), overwrite_x=True) fde[start_index:end_index, :, :] = tmp - del tmp c2dfftshift( ( - int(np.ceil((2 * n) / 32)), - int(np.ceil((2 * n) / 8)), - np.int32(nz // 2), + int(np.ceil((2 * detector_width) / 32)), + int(np.ceil((2 * detector_width) / 8)), + np.int32(detector_height // 2), ), (32, 8, 1), - (fde, n, nz // 2), + (fde, detector_width, detector_height // 2), ) - # Unpadded recon output size + def ifft_gathered_projections_estimator( + self, + fde: tuple, + detector_height: int, + chunk_count: int, + min_mem_usage_ifft2: bool, + ): + mem_stack = DeviceMemStack.instance() + + if min_mem_usage_ifft2: + chunk_count = detector_height // 2 + + slice_count_per_chunk = int(np.ceil(detector_height // 2 / chunk_count)) + (fde_shape, fde_dtype) = fde + ifft2_input = cp.empty((slice_count_per_chunk, *fde_shape[1:]), fde_dtype) + + ifft2_plan = get_fft_plan(ifft2_input, axes=(-2, -1)) + mem_stack.malloc(ifft2_plan.work_area.mem.size) + mem_stack.malloc(ifft2_input.nbytes) + mem_stack.free(ifft2_input.nbytes) + del ifft2_input, ifft2_plan + + def unpad_reconstructed_data( + self, + fde: cp.ndarray, + detector_width: int, + projection_count: int, + detector_height: int, + odd_horiz: bool, + odd_vert: bool, + recon_size: int, + mu, + unpadding_mul_phi: cp.RawKernel, + ) -> cp.ndarray: odd_recon_size = bool(recon_size % 2) - unpad_z = nz - odd_vert - unpad_recon_m = (n - odd_horiz) // 2 - recon_size // 2 - unpad_recon_p = (n - odd_horiz) // 2 + (recon_size + odd_recon_size) // 2 + unpad_z = detector_height - odd_vert + unpad_recon_m = (detector_width - odd_horiz) // 2 - recon_size // 2 + unpad_recon_p = (detector_width - odd_horiz) // 2 + ( + recon_size + odd_recon_size + ) // 2 unpad_recon_size = unpad_recon_p - unpad_recon_m # memory for recon @@ -539,25 +945,43 @@ def FOURIER_INV(self, data: cp.ndarray, **kwargs) -> cp.ndarray: ( int(np.ceil(unpad_recon_size / 32)), int(np.ceil(unpad_recon_size / 32)), - np.int32(nz // 2), + np.int32(detector_height // 2), ), (32, 32, 1), ( recon_up, fde, np.float32(mu), - nproj, + projection_count, unpad_recon_p, unpad_z, unpad_recon_m, - n, - nz // 2, + detector_width, + detector_height // 2, ), ) - del fde + return recon_up - return check_kwargs( - recon_up, - **kwargs, - ) + def unpad_reconstructed_data_estimator( + self, + detector_width: int, + detector_height: int, + odd_horiz: bool, + odd_vert: bool, + recon_size: int, + ): + odd_recon_size = bool(recon_size % 2) + unpad_z = detector_height - odd_vert + unpad_recon_m = (detector_width - odd_horiz) // 2 - recon_size // 2 + unpad_recon_p = (detector_width - odd_horiz) // 2 + ( + recon_size + odd_recon_size + ) // 2 + unpad_recon_size = unpad_recon_p - unpad_recon_m + recon_up_shape = (unpad_z, unpad_recon_size, unpad_recon_size) + + # memory for recon + mem_stack = DeviceMemStack.instance() + mem_stack.malloc(np.prod(recon_up_shape) * cp.float32().itemsize) + + return recon_up_shape diff --git a/tomobar/supp/funcs.py b/tomobar/supp/funcs.py index 865c98b88..66d389fd6 100644 --- a/tomobar/supp/funcs.py +++ b/tomobar/supp/funcs.py @@ -138,8 +138,12 @@ def _swap_data_axes_to_accepted(data_axes_labels, required_labels_order): return [swap_tuple1, swap_tuple2] +def swap_tuple_elements(tup: Tuple[int, int, int], idx1: int, idx2: int): + items = list(tup) + items[idx1], items[idx2] = items[idx2], items[idx1] + return tuple(items) -def _data_swap(data: xp.ndarray, data_swap_list: list) -> xp.ndarray: +def _data_swap(data: xp.ndarray | Tuple[int, int, int], data_swap_list: list) -> xp.ndarray: """Swap data labels based on the provided list of tuples Args: @@ -151,17 +155,14 @@ def _data_swap(data: xp.ndarray, data_swap_list: list) -> xp.ndarray: """ for swap_tuple in data_swap_list: if swap_tuple is not None: - if type(data) is tuple: - data = list(data) - tmp = data[swap_tuple[0]] - data[swap_tuple[0]] = data[swap_tuple[1]] - data[swap_tuple[1]] = tmp - data = tuple(data) - elif cupy_enabled: - xpp = xp.get_array_module(data) - data = xpp.swapaxes(data, swap_tuple[0], swap_tuple[1]) + if isinstance(data, tuple): + data = swap_tuple_elements(data, swap_tuple[0], swap_tuple[1]) else: - data = np.swapaxes(data, swap_tuple[0], swap_tuple[1]) + if cupy_enabled: + xpp = xp.get_array_module(data) + data = xpp.swapaxes(data, swap_tuple[0], swap_tuple[1]) + else: + data = np.swapaxes(data, swap_tuple[0], swap_tuple[1]) return data diff --git a/tomobar/supp/memory_estimator_helpers.py b/tomobar/supp/memory_estimator_helpers.py index 4c63a188d..e03ea3c23 100644 --- a/tomobar/supp/memory_estimator_helpers.py +++ b/tomobar/supp/memory_estimator_helpers.py @@ -1,22 +1,42 @@ ALLOCATION_UNIT_SIZE = 512 -class _DeviceMemStack: +class DeviceMemStack: + _instance = None + _stack_count = 0 + + def __enter__(self): + if DeviceMemStack._stack_count == 0: + DeviceMemStack._instance = self + + DeviceMemStack._stack_count += 1 + return self + + def __exit__(self, exc_type, exc_value, traceback): + DeviceMemStack._stack_count -= 1 + + if DeviceMemStack._stack_count == 0: + DeviceMemStack._instance = None + + @classmethod + def instance(cls): + return cls._instance + def __init__(self) -> None: self.allocations = [] self.current = 0 self.highwater = 0 - def malloc(self, bytes): - self.allocations.append(bytes) - allocated = self._round_up(bytes) + def malloc(self, byte_count): + self.allocations.append(byte_count) + allocated = self._round_up(byte_count) self.current += allocated self.highwater = max(self.current, self.highwater) - def free(self, bytes): - assert bytes in self.allocations - self.allocations.remove(bytes) - self.current -= self._round_up(bytes) + def free(self, byte_count): + assert byte_count in self.allocations + self.allocations.remove(byte_count) + self.current -= self._round_up(byte_count) assert self.current >= 0 def _round_up(self, size):