Source code for toupy.registration.registration

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# standard libraries imports
import time

# third party packages
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, ifft, fft2, ifft2, fftshift, ifftshift
from scipy.ndimage import center_of_mass, interpolation
from scipy.ndimage.filters import gaussian_filter, gaussian_filter1d
from scipy.ndimage.fourier import fourier_shift
from skimage.feature import register_translation

# local packages
from ..restoration import derivatives, derivatives_sino
from .shift import ShiftFunc
from ..tomo import projector, tomo_recons
from ..utils import (

__all__ = [
    # "cc_align", #FIXME: not working

def register_2Darrays(image1, image2):
    """
    Register two images by phase cross-correlation.

    The user is asked interactively whether the registration should be
    done with pixel or subpixel precision. The second image is then
    shifted in Fourier space by the detected offset and multiplied by the
    global phase factor found by the registration.

    Parameters
    ----------
    image1 : array_like
        Reference image
    image2 : array_like
        Image to be shifted relative to image1

    Returns
    -------
    shift : list of floats
        Detected shifts, with the row shift in the 1st position and the
        column shift in the 2nd position.
    diffphase : float
        Difference of phase between the two images
    offset_image2 : array_like
        image2 shifted relative to image1 (complex valued)

    Raises
    ------
    SystemExit
        If the answer to the prompt is neither "1", "2" nor empty.
    """
    answer = input(
        "Do you want to use pixel(1) or subpixel(2) precision registration?[1] "
    )

    # Translate the answer into the messages to print and the extra
    # positional argument (upsampling factor) given to the registration.
    if answer in ("1", ""):
        banner = "\nCalculating the pixel precision image registration ..."
        report = "Detected pixel offset [y,x]: [{:g}, {:g}]"
        extra_args = ()
    elif answer == "2":
        banner = "\nCalculating the subpixel image registration ..."
        report = "Detected subpixel offset [y,x]: [{:g}, {:g}]"
        extra_args = (100,)
    else:
        print("You must choose between 1 and 2")
        raise SystemExit

    print(banner)
    tic = time.time()
    shift, _error, diffphase = register_translation(
        image1.copy(), image2.copy(), *extra_args
    )
    print(diffphase)
    toc = time.time()
    print("Time elapsed: {:g} s".format(toc - tic))
    print(report.format(shift[0], shift[1]))

    print("\nCorrecting the shift of image2 by using subpixel precision...")
    # TODO: check if we can use ShiftFunc here.
    spectrum = fft2(image2.copy())
    offset_image2 = ifft2(fourier_shift(spectrum, shift))
    offset_image2 *= np.exp(1j * diffphase)
    return shift, diffphase, offset_image2
def compute_aligned_stack(input_stack, shiftstack, shift_method="linear"):
    """
    Shift every image of a stack by its own (row, column) correction.

    Parameters
    ----------
    input_stack : array_like
        Stack of images to be shifted; the first axis runs over the images
    shiftstack : array_like
        Array of shifts indexed as ``shiftstack[0, i]`` (first component)
        and ``shiftstack[1, i]`` (second component) for image ``i``
    shift_method : str (default linear)
        Name of the shift method. Options: 'linear', 'fourier', 'spline'

    Return
    ------
    output_stack : array_like
        New array, same shape and dtype as ``input_stack``, holding the
        shifted images
    """
    shifter = ShiftFunc(shiftmeth=shift_method)
    nstack = input_stack.shape[0]
    print(
        "Using {} shift method (function {})".format(
            shift_method, shifter.shiftmeth.__name__
        )
    )

    output_stack = np.empty_like(input_stack)
    for ii in range(nstack):
        first_shift, second_shift = shiftstack[0, ii], shiftstack[1, ii]
        output_stack[ii] = shifter(input_stack[ii], (first_shift, second_shift))
        done = ii + 1
        progbar(done, nstack, "Image {} of {}".format(done, nstack))
    print("\r")
    return output_stack
def compute_aligned_stack_special(input_stack, shiftstack, shift_method="linear"):
    """
    Shift every image of a stack in place by its own correction.

    Unlike :py:func:`compute_aligned_stack`, no output array is allocated:
    each shifted image is written back into ``input_stack``.

    Parameters
    ----------
    input_stack : array_like
        Stack of images to be shifted. It is overwritten with the result.
    shiftstack : array_like
        Array of shifts indexed as ``shiftstack[0, i]`` and
        ``shiftstack[1, i]`` for image ``i``
    shift_method : str (default linear)
        Name of the shift method. Options: 'linear', 'fourier', 'spline'

    Return
    ------
    input_stack : array_like
        The same object that was passed in, now holding the shifted images
    """
    # Initialize shift class
    S = ShiftFunc(shiftmeth=shift_method)
    nstack = input_stack.shape[0]
    print(
        "Using {} shift method (function {})".format(shift_method, S.shiftmeth.__name__)
    )
    for ii in range(nstack):
        deltashift = (shiftstack[0, ii], shiftstack[1, ii])
        # written back in place to avoid allocating a second stack
        input_stack[ii] = S(input_stack[ii], deltashift)
        strbar = "Image {} of {}".format(ii + 1, nstack)
        progbar(ii + 1, nstack, strbar)
    print("\r")
    return input_stack


def compute_aligned_horizontal_special(input_stack, shiftstack, shift_method="linear"):
    """
    Shift a stack in place along the horizontal direction only.

    Parameters
    ----------
    input_stack : array_like
        Stack of images to be shifted. It is overwritten with the result.
    shiftstack : array_like
        Array of shifts (2,n). Only ``shiftstack[1]`` is used; the
        first-row (vertical) estimates are replaced by 0.
        ``shiftstack`` itself is not modified.
    shift_method : str (default linear)
        Name of the shift method. Options: 'linear', 'fourier', 'spline'

    Return
    ------
    output_stack : array_like
        The same object as ``input_stack``, holding the shifted images
    """
    deltashift = np.zeros_like(shiftstack)
    deltashift[1] = shiftstack[1]
    # Pass the horizontal-only shifts. Forwarding ``shiftstack`` here would
    # also apply the vertical shifts, contradicting the purpose of this
    # function and the behaviour of compute_aligned_horizontal.
    output_stack = compute_aligned_stack_special(
        input_stack, deltashift, shift_method=shift_method
    )
    return output_stack
def compute_aligned_sino(input_sino, shiftslice, shift_method="linear"):
    """
    Shift every projection (column) of a sinogram by its own correction.

    Parameters
    ----------
    input_sino : array_like
        Input sinogram; the second axis runs over the projections
    shiftslice : array_like
        Array of shifts indexed as ``shiftslice[0, i]`` for projection ``i``
    shift_method : str (default linear)
        Name of the shift method. Options: 'linear', 'fourier', 'spline'

    Return
    ------
    output_sino : array_like
        New array, same shape and dtype as ``input_sino``, holding the
        shifted sinogram
    """
    shifter = ShiftFunc(shiftmeth=shift_method)
    nprojs = input_sino.shape[1]
    print(
        "Using {} shift method (function {})".format(
            shift_method, shifter.shiftmeth.__name__
        )
    )

    output_sino = np.empty_like(input_sino)
    for ii in range(nprojs):
        column = input_sino[:, ii]
        output_sino[:, ii] = shifter(column, shiftslice[0, ii])
        # carriage return keeps the progress message on a single line
        print("Image {} of {}".format(ii + 1, nprojs), end="\r")
    print("\r")
    return output_sino
def compute_aligned_horizontal(input_stack, shiftstack, shift_method="linear"):
    """
    Shift a stack of images along the horizontal direction only.

    Parameters
    ----------
    input_stack : array_like
        Stack of images to be shifted
    shiftstack : array_like
        Array of shifts (2,n). Only ``shiftstack[1]`` is used; the
        first-row (vertical) estimates are replaced by 0.
        ``shiftstack`` itself is not modified.
    shift_method : str (default linear)
        Name of the shift method. Options: 'linear', 'fourier', 'spline'

    Return
    ------
    output_stack : array_like
        Stack of shifted images, as returned by ``compute_aligned_stack``
    """
    # same shape as shiftstack, but with the vertical row zeroed
    horizontal_only = np.zeros_like(shiftstack)
    horizontal_only[1] = shiftstack[1].copy()
    return compute_aligned_stack(
        input_stack, horizontal_only, shift_method=shift_method
    )
def center_of_mass_stack(input_stack, lims, shiftstack, shift_method="fourier"):
    """
    Compute the center of mass of every projection inside a window.

    Each projection is first shifted by its entry in ``shiftstack``; the
    center of mass is then evaluated in the window given by ``lims``, with
    coordinates measured from the (floored) middle of that window.
    Projections whose summed mass in the window is zero get a center of 0.

    Parameters
    ----------
    input_stack : array_like
        Stack of projections; the first axis runs over the projections
    lims : sequence
        ``(limrow, limcol)``; the first and last entries of each are used
        as start and stop of the window
    shiftstack : array_like
        Array of shifts (2,n) applied before the calculation
    shift_method : str (default fourier)
        Name of the shift method. Options: 'linear', 'fourier', 'spline'

    Returns
    -------
    centers : ndarray
        Array of shape (2, n): ``centers[0]`` holds the horizontal
        (column) centers and ``centers[1]`` the vertical (row) centers.
    """
    limrow, limcol = lims
    print("Calculating center-of-mass with pixel precision")
    shifter = ShiftFunc(shiftmeth=shift_method)

    # window of interest, reused for every projection
    row_window = slice(limrow[0], limrow[-1])
    col_window = slice(limcol[0], limcol[-1])

    # integer coordinate grids of the window, centred on its middle
    stack_roi = input_stack[0, row_window, col_window].copy()
    ind_roi = np.indices(stack_roi.shape)
    col_middle = np.floor(ind_roi[1].mean(axis=1))
    ind_roi[1] -= col_middle.reshape((ind_roi.shape[1], 1)).astype("int")
    Xp = ind_roi[1].astype("float")  # horizontal positions
    row_middle = np.floor(ind_roi[0].mean(axis=0))
    ind_roi[0] -= row_middle.reshape((ind_roi.shape[2], 1)).T.astype("int")
    Yp = ind_roi[0].astype("float")  # vertical positions

    # first moments and total mass of each shifted projection
    nproj = input_stack.shape[0]
    mass_sum = np.empty(nproj)
    centerx = np.empty(nproj)
    centery = np.empty(nproj)
    for ii in range(nproj):
        shifted = shifter(input_stack[ii], (shiftstack[0, ii], shiftstack[1, ii]))
        window = shifted[row_window, col_window]
        mass_sum[ii] = np.sum(window)
        centerx[ii] = np.sum(Xp * window)
        centery[ii] = np.sum(Yp * window)

    # normalise by the mass where it is non-zero, otherwise set to 0
    has_mass = mass_sum != 0
    for center in (centerx, centery):
        center[has_mass] = center[has_mass] / mass_sum[has_mass]
        center[~has_mass] = 0
    return np.asarray([centerx, centery])
def vertical_fluctuations(
    input_stack, lims, shiftstack, shift_method="fourier", polyorder=2
):
    """
    Calculate the vertical fluctuation function of every image of a stack.

    Each projection is cropped to the window given by ``lims`` (plus a
    vertical margin), shifted vertically by ``shiftstack[0, i]``, summed
    along the horizontal direction and passed through ``projectpoly1d``
    to remove the bias.

    Parameters
    ----------
    input_stack : array_like
        Stack of images, indexed as (projection, row, column)
    lims : list of ints
        Limits of rows and columns to be considered. lims=[limrow,limcol].
        ``limrow`` must support arithmetic with a scalar (e.g. an ndarray).
    shiftstack : array_like
        Array of estimates for object motion (2,n); only the first row is
        used
    shift_method : str, optional
        Name of the shift method. Options: 'linear', 'fourier', 'spline'.
        The default method is 'fourier'.
    polyorder : int, optional
        Order of the polynomial to remove bias from the mass fluctuation
        function. The default value is 2.

    Return
    ------
    vert_fluct : array_like
        2D array (projection, row) containing the mass fluctuation after
        shift and bias removal for the stack of images
    """
    shifter = ShiftFunc(shiftmeth=shift_method)
    nproj, nr, _ = input_stack.shape
    rows, cols = lims

    # vertical margin: largest requested shift plus one pixel, reduced to a
    # single pixel when the enlarged window would not fit in the image
    margin = int(np.ceil(np.max(np.abs(shiftstack[0, :])))) + 1
    if np.any((rows - margin) < 0) or np.any((rows + margin) > nr):
        margin = 1

    row_window = slice(rows[0] - margin, rows[-1] + margin)
    col_window = slice(cols[0], cols[-1])

    vert_fluct = np.empty((nproj, rows[-1] - rows[0]))
    for ii in range(nproj):
        cropped = input_stack[ii, row_window, col_window]
        shifted = shifter(cropped, (shiftstack[0, ii], 0.0))
        # drop the margin again, then integrate along the columns
        profile = shifted[margin:-margin].sum(axis=1)
        # to remove possible bias
        vert_fluct[ii] = projectpoly1d(profile, polyorder, 1)
        if not isnotebook():
            progbar(ii + 1, nproj, "Projection {}".format(ii + 1))
    print("\r")
    return vert_fluct
def vertical_shift(
    input_array, lims, vstep, maxshift, shift_method="linear", polyorder=2
):
    """
    Calculate the vertical fluctuation function of one shifted image.

    The image is cropped to the window given by ``lims`` (plus a vertical
    margin), shifted vertically by ``vstep``, summed along the horizontal
    direction and passed through ``projectpoly1d`` to remove the bias.

    Parameters
    ----------
    input_array : array_like
        2D image to be shifted
    lims : list of ints
        Limits of rows and columns to be considered. lims=[limrow,limcol].
        ``limrow`` must support arithmetic with a scalar (e.g. an ndarray).
    vstep : float
        Amount to shift the input_array vertically
    maxshift : float
        Maximum value of the shifts in order to avoid border problems
    shift_method : str, optional
        Name of the shift method. Options: 'linear', 'fourier', 'spline'.
        The default method is 'linear'.
    polyorder : int, optional
        Order of the polynomial to remove bias from the mass fluctuation
        function. The default value is 2.

    Returns
    -------
    shift_calc : array_like
        1D function containing the mass fluctuation after shift and bias
        removal
    """
    shifter = ShiftFunc(shiftmeth=shift_method)
    nr, _ = input_array.shape
    rows, cols = lims

    # vertical margin: maximum shift plus the integer part of the step,
    # reduced to a single pixel when the enlarged window would not fit
    margin = maxshift + int(np.abs(vstep))
    if np.any((rows - margin) < 0) or np.any((rows + margin) > nr):
        margin = 1

    cropped = input_array[rows[0] - margin : rows[-1] + margin, cols[0] : cols[-1]]
    shifted = shifter(cropped, (vstep, 0.0))
    # drop the margin again, then integrate along the columns
    profile = shifted[margin:-margin].sum(axis=1)
    # to remove possible bias
    return projectpoly1d(profile, polyorder, 1)
def _selectROI(stack_shape, **params):
    """
    Define the ROI for alignment

    Parameters
    ----------
    stack_shape : tuple of ints
        Shape of the stack; index 1 is taken as the number of rows and
        index 2 as the number of columns
    params : dict
        params["deltax"] : int
            Number of columns excluded on each side of the image
        params["limsy"] : sequence of ints, None or ""
            Row limits. ``None`` or an empty string selects all the rows.

    Returns
    -------
    limrow : ndarray
        Row limits
    limcol : ndarray
        Column limits ``(deltax, ncols - deltax)``
    """
    # defining the boundaries of the area to be used for the alignment
    deltax = params["deltax"]
    limcol = (deltax, stack_shape[2] - deltax)  # horizontal
    limrow = params["limsy"]
    # Identity/type checks instead of ``==``: comparing an ndarray with
    # ``None`` or ``""`` is elementwise and cannot be used as a condition.
    if limrow is None or (isinstance(limrow, str) and limrow == ""):
        limrow = [0, stack_shape[1]]
    return np.asarray(limrow), np.asarray(limcol)


def _alignprojections_vertical(
    input_stack, lims, shiftstack, metric_error, vert_fluct_init, RP, **params
):
    """
    Auxiliary function for align the projection vertically. It contains the
    wrapper for iteration during the alignement

    Parameters
    ----------
    input_stack : array_like
        Stack of projections
    lims : tuple
        ``(limrow, limcol)`` window of interest
    shiftstack : array_like
        Current shift estimates (2,n). Row 0 is updated in place at every
        iteration.
    metric_error : list
        History of the error metric; one value is appended per iteration
    vert_fluct_init : array_like
        Vertical fluctuation functions used for the first iteration
    RP : object
        Plotting helper providing ``plotsvertical``
    params : dict
        Uses the keys "shiftmeth", "polyorder", "subpixel", "pixtol" and
        "maxit"; the whole dict is forwarded to ``_search_vshift_stack``.

    Returns
    -------
    shiftstack : array_like
        Shift estimates after the last accepted iteration
    metric_error : list
        Updated history of the error metric
    """
    # Initialize the counter
    count = 0
    error_reg = np.zeros(vert_fluct_init.shape[0])
    while True:
        count += 1
        print("\n============================================")
        print("Iteration {}".format(count))
        it0 = time.time()
        # kept so that the iteration can be reverted if it is rejected
        deltaprev = shiftstack.copy()

        # Mass distribution registration in y
        if count == 1:
            vert_fluct = vert_fluct_init.copy()
        else:
            print("Updating the vertical fluctuations")
            vert_fluct = vertical_fluctuations(
                input_stack,
                lims,
                shiftstack,
                params["shiftmeth"],
                polyorder=params["polyorder"],
            )

        # Average the vertical fluctuation functions
        print("Calculating the average of the vertical fluctuation function")
        vert_fluct_mean = vert_fluct.mean(axis=0)

        # Search for shifts with respect to mean
        print("Search for the shifts with respect to the mean vertical fluctuations...")
        shiftstack_aux, vert_fluct_temp = _search_vshift_stack(
            input_stack, lims, shiftstack, vert_fluct_mean, **params
        )
        shiftstack[0] = shiftstack_aux[0].copy()
        shiftstack[0] -= shiftstack_aux[0].mean().round()  # recentering

        # Error calculation
        # keep temporarily the vertical fluctuations
        vert_fluct_mean_temp = vert_fluct_temp.mean(axis=0)
        print("\nCalculating the error metric")
        for ii in range(vert_fluct_temp.shape[0]):
            error_reg[ii] = np.sum(
                np.abs(vert_fluct_temp[ii] - vert_fluct_mean_temp) ** 2
            )
        print("Final error metric for y, E = {:.04e}".format(np.sum(error_reg)))
        metric_error.append(np.sum(error_reg))

        # Maximum changes in y
        print("Estimating the changes in y:")
        changey = np.abs(deltaprev[0] - shiftstack[0])
        print("Maximum correction in y = {:.02f} pixels".format(np.max(changey)))
        print("Elapsed time = {} s".format(time.time() - it0))

        # update figures
        RP.plotsvertical(
            input_stack[0],
            lims,
            vert_fluct_init,
            vert_fluct_temp,
            shiftstack,
            metric_error,
            count,
        )

        if params["subpixel"]:
            pixtol = params["pixtol"]
        else:
            pixtol = 1
        reason = _checkconditions(
            metric_error, changey, pixtol, count, params["maxit"], params["subpixel"]
        )
        if reason == 1:
            # iteration rejected: go back to the previous shifts and forget
            # the metric that was just appended
            shiftstack = deltaprev.copy()
            metric_error.pop()
            break
        elif reason >= 2:
            break
    return shiftstack, metric_error
def alignprojections_vertical(input_stack, shiftstack, **params):
    """
    Vertical alignment of projections using mass fluctuation
    approach [#massfluct]_, [#tomoalgosv]_.
    It relies on having air on both sides of the sample (non local
    tomography). It performs a local search in y, so convergence issues
    can be addressed by giving an approximate initial guess for a possible
    drift via shiftstack

    Parameters
    ----------
    input_stack : array_like
        Stack of projections
    shiftstack : array_like
        Array of initial estimates for object motion (2,n). It is updated
        in place during the alignment.
    params : dict
        Container with parameters for the registration
    params['deltax'] : int
        Number of columns excluded on each side of the window of interest
    params['limsy'] : list of ints or None
        Limits of window of interest in y (None selects all the rows)
    params['maxit'] : int
        Maximum number of iterations (10 if missing or not an int)
    params['pixtol'] : float
        Tolerance for alignment, which is also used as a search step
    params['polyorder'] : int
        Specify the polynomial order of bias removal.
        For example: polyorder = 1 -> mean, polyorder = 2 -> linear).
    params['alignx'] : bool
        True or False to activate align x using center of mass
        (default= False, which means align y only)
    params['shiftmeth'] : str
        Shift method: `linear` -> linear interpolation; `fourier` ->
        Fourier shift or `spline` -> spline interpolation.

    Returns
    -------
    shiftstack : array_like
        Corrected object positions
    output_stack : array_like
        Aligned stack of the projections

    References
    ----------
    .. [#massfluct] Guizar-Sicairos, M., et al. , "Phase tomography from
      x-ray coherent diffractive imaging projections," Opt. Express 19,
      21345-21357 (2011).

    .. [#tomoalgosv] da Silva, J. C., et al. "High energy near-and
      far-field ptychographic tomography at the ESRF," Proc. SPIE 10391,
      Developments in X-Ray Tomography XI, 1039106 (2017)
    """
    # defaults for optional parameters
    if not isinstance(params.get("maxit"), int):
        params["maxit"] = 10
    params.setdefault("alignx", False)

    limrow, limcol = _selectROI(input_stack.shape, **params)
    lims = (limrow, limcol)

    print("\n============================================")
    print("Vertical Mass fluctuation pixel alignment")
    print("Number of iterations: {}".format(params["maxit"]))

    # horizontal alignement with center of mass if requested
    if params["alignx"]:
        print("Estimating the changes in x using center-of-mass:")
        # snapshot of the shifts to report how much x is corrected
        deltaprev = shiftstack.copy()
        # row 0 of the result holds the horizontal centers of mass
        centerx = center_of_mass_stack(input_stack, lims, shiftstack)[0]
        # Correction with mass center
        shiftstack[1] = -centerx.round()
        shiftstack[1] -= shiftstack[1].mean().round()
        changex = np.max(np.abs(deltaprev[1] - shiftstack[1]))
        print(
            "Maximum correction of center of mass in x = {:.02f} pixels".format(changex)
        )

    # first iteration only correcting for the limrow and limcol and in
    # case shiftstack is already no zero
    vert_fluct_init = vertical_fluctuations(
        input_stack,
        (limrow, limcol),
        shiftstack,
        params["shiftmeth"],
        polyorder=params["polyorder"],
    )
    avg_init = vert_fluct_init.mean(axis=0)
    shiftstack_init = shiftstack.copy()

    # Store initial states
    metric_error = []
    error_init = np.array(
        [np.sum(np.abs(fluct - avg_init) ** 2) for fluct in vert_fluct_init]
    )
    print("Initial error metric for y, E = {:.02e}".format(np.sum(error_init)))
    metric_error.append(np.sum(error_init))

    # initializing display canvas for the figures
    plt.ion()
    RP = RegisterPlot(**params)
    RP.plotsvertical(
        input_stack[0],
        lims,
        vert_fluct_init,
        vert_fluct_init,
        shiftstack_init,
        metric_error,
        count=0,
    )

    # Single pixel precision
    print("\n================================================")
    print("Registration of projections with pixel precision")
    print("================================================")
    params["subpixel"] = False
    shiftstack, metric_error = _alignprojections_vertical(
        input_stack, lims, shiftstack, metric_error, vert_fluct_init, RP, **params
    )

    # the subpixel pass only makes sense for a non-integer tolerance
    if not isinstance(params["pixtol"], int) or np.mod(params["pixtol"], 1) != 0:
        # Subpixel precision
        print("\n================================================")
        print("Registration of projections with subpixel precision")
        print("================================================")
        params["subpixel"] = True
        shiftstack, metric_error = _alignprojections_vertical(
            input_stack, lims, shiftstack, metric_error, vert_fluct_init, RP, **params
        )

    # Compute the shifted images
    print("Computing aligned images")
    output_stack = compute_aligned_stack(
        input_stack, shiftstack.copy(), shift_method=params["shiftmeth"]
    )
    return shiftstack, output_stack
def _alignprojections_horizontal(
    sinogram, sino_orig, theta, circleROI, shiftslice, metric_error, RP, **params
):
    """
    Auxiliary function for align the projection horizontally. It contains the
    wrapper for iteration during the alignement

    Each iteration projects the current reconstruction into a synthetic
    sinogram, searches the shifts of ``sino_orig`` relative to it, rebuilds
    the aligned sinogram and the reconstruction, and appends the new error
    metric. The loop ends according to the code returned by
    ``_checkconditions``.

    Parameters
    ----------
    sinogram : array_like
        Sinogram used for the first reconstruction
    sino_orig : array_like
        Sinogram that is shifted at every iteration
    theta : array_like
        Reconstruction angles, forwarded to ``tomo_recons`` and ``projector``
    circleROI : array_like or scalar
        Mask multiplied with the reconstruction when ``params["circle"]``
        is true
    shiftslice : array_like
        Current shift estimates
    metric_error : list
        History of the error metric; one value is appended per iteration
    RP : object
        Plotting helper providing ``plotshorizontal``
    params : dict
        Uses the keys "circle", "derivatives", "calc_derivatives",
        "shiftmeth", "subpixel", "pixtol" and "maxit"; the whole dict is
        forwarded to the reconstruction, projection and search helpers.

    Returns
    -------
    shiftslice : array_like
        Shift estimates after the last accepted iteration
    metric_error : list
        Updated history of the error metric
    """
    # Compute tomogram with current sinogram
    print("Initializing tomographic slice...")
    t0 = time.time()
    recons = tomo_recons(sinogram, theta=theta, **params)
    # standard deviation is taken before clipping and masking
    recons_std = recons.std()
    # clipping gray level if needed
    recons = _clipping_tomo(recons, **params)
    if params["circle"]:
        recons = recons * circleROI
    print("Done. Time elapsed: {} s".format(time.time() - t0))
    print("Slice standard deviation = {:0.04e}".format(recons_std))

    # Initialize the counter
    count = 0
    while True:
        count += 1
        print("\nIteration {}".format(count))
        print("-------------------------------------")
        it0 = time.time()
        # NOTE(review): sinoprev is only referenced by the commented-out
        # restore further down, so this copy is currently unused.
        sinoprev = sinogram.copy()
        # keep deltaprev in case the iteration does not decrease the error
        deltaprev = shiftslice.copy()

        # Compute synthetic sinogram
        print("Computing synthetic sinogram...")
        sinogramcomp = projector(recons, theta, **params)
        if params["derivatives"] and not params["calc_derivatives"]:
            sinogramcomp = derivatives_sino(
                sinogramcomp, shift_method=params["shiftmeth"]
            )

        # Start searching for shift relative to synthetic sinogram
        sinotempreg, shiftslice = _search_hshift_sinogram(
            sino_orig, sinogramcomp, shiftslice, **params
        )

        # updating sinogram
        sinogram = compute_aligned_sino(
            sino_orig, shiftslice, shift_method=params["shiftmeth"]
        )

        # Compute tomogram with current sinogram
        print("Computing tomographic slice...")
        t0 = time.time()
        recons = tomo_recons(sinogram, theta=theta, **params)
        recons_std = recons.std()
        # clipping gray level if needed
        recons = _clipping_tomo(recons, **params)
        if params["circle"]:
            recons = recons * circleROI
        print("Done. Time elapsed: {} s".format(time.time() - t0))
        print("Slice standard deviation = {:0.04e}".format(recons_std))

        # Calculate the error:
        errorxreg = _sino_error_metric(sinogram, sinogramcomp, params)
        sumerrorxreg = errorxreg.sum()
        print("Final error metric for x, E = {:0.04e}".format(sumerrorxreg))
        metric_error.append(sumerrorxreg)

        # Estimate amount of changes
        print("Estimating the changes in x:")
        changex = np.abs(deltaprev - shiftslice)
        if params["subpixel"]:
            strprint = "Maximum correction in x = {:0.02f} pixels"
        else:
            strprint = "Maximum correction in x = {:0.02g} pixels"
        print(strprint.format(np.max(changex)))
        print("Elapsed time in the iteration= {:0.02f} s".format(time.time() - it0))

        # update figures
        # sinogram = _filter_sino(sinogram, **params)
        RP.plotshorizontal(
            recons, sino_orig, sinogram, sinogramcomp, shiftslice, metric_error, count
        )

        # in the pixel-precision pass the tolerance is fixed to one pixel
        if params["subpixel"]:
            pixtol = params["pixtol"]
        else:
            pixtol = 1
        reason = _checkconditions(
            metric_error, changex, pixtol, count, params["maxit"], params["subpixel"]
        )
        if reason == 1:
            # iteration rejected: go back to the previous shifts and forget
            # the metric that was just appended
            shiftslice = deltaprev.copy()
            # ~ sinogram = sinoprev.copy()
            metric_error.pop()
            break
        elif reason >= 2:
            break
    return shiftslice, metric_error


def _filter_sino(sinogram, **params):
    """
    Filter to the sinogram

    Builds a 1D window with ``hanning_apod1D``, repeats it for every column
    of the sinogram, multiplies it with ``fft(sinogram)`` and returns the
    real part of the inverse transform.

    Parameters
    ----------
    sinogram : array_like
        2D sinogram of shape (N, M)
    params : dict
        params["freqcutoff"] is used to compute the apodization width

    Returns
    -------
    array_like
        Real-valued filtered sinogram
    """
    N, M = sinogram.shape
    # NOTE(review): the next statement is incomplete in this copy of the
    # source (tokens are missing after "=" and the parentheses are
    # unbalanced); restore it from the upstream file before use.
    apod_width = * N * (params["freqcutoff"]))
    filteraux = hanning_apod1D(N, apod_width)
    # tile to (M, N) and transpose so the window varies along the first axis
    filteraux = np.tile(filteraux, (M, 1)).T
    # NOTE(review): fft/ifft act on their default (last) axis while the
    # window varies along the first axis -- confirm this is intended.
    return np.real(ifft(fft(sinogram) * filteraux))
def alignprojections_horizontal(sinogram, theta, shiftstack, **params):
    """
    Function to align projections by tomographic consistency
    [#tomoconsist]_, [#tomoalgosh]_. It relies on having already aligned
    the vertical direction. The code aligns using the consistency before
    and after tomographic combination of projections.

    Parameters
    ----------
    sinogram : array_like
        Sinogram derivative, the second index should be the angle
    theta : array_like
        Reconstruction angles (in degrees). Default: m angles evenly spaced
        between 0 and 180 (if the shape of `radon_image` is (N, M)).
    shiftstack : array_like
        Array with initial estimates of positions. ``shiftstack[1]`` is
        overwritten in place with the result.
    params : dict
        Container with parameters for the registration
    params["pixtol"] : float
        Tolerance for alignment, which is also used as a search step
    params["maxit"] : int
        Maximum number of iterations (10 if missing or not an int)
    params["shiftmeth"] : str
        Shift method: `linear` -> linear interpolation; `fourier` ->
        Fourier shift or `spline` -> spline interpolation.
    params["circle"] : bool
        Use a circular mask to eliminate corners of the tomogram
        (default True)
    params["filtertype"] : str
        Filter to use for FBP
    params["freqcutoff"] : float
        Frequency cutoff for tomography filter (between 0 and 1)
    params["cliplow"] : float
        Minimum value in tomogram (default None)
    params["cliphigh"] : float
        Maximum value in tomogram (default None)
    params["sinohigh"], params["sinolow"] : float
        Default to 0.6 and -0.6 when missing
    params["opencl"] : bool
        Defaults to False when missing
    params["derivatives"], params["calc_derivatives"] : bool
        When the first is true and the second false, the synthetic
        sinogram is differentiated with ``derivatives_sino``

    Returns
    -------
    shiftstack : array_like
        Corrected object positions. The aligned sinogram is only computed
        for display and is not returned.

    References
    ----------
    .. [#tomoconsist] Guizar-Sicairos, M., et al., "Quantitative interior
      x-ray nanotomography by a hybrid imaging technique," Optica 2,
      259-266 (2015).

    .. [#tomoalgosh] da Silva, J. C., et al., "High energy near-and
      far-field ptychographic tomography at the ESRF," Proc. SPIE 10391,
      Developments in X-Ray Tomography XI, 1039106 (2017).
    """
    # parsing of the parameters: fill in the defaults of missing keys
    params.setdefault("circle", True)
    params.setdefault("sinohigh", 0.6)
    params.setdefault("sinolow", -0.6)
    params.setdefault("opencl", False)
    if not isinstance(params.get("maxit"), int):
        params["maxit"] = 10
    params.setdefault("cliplow", None)
    params.setdefault("cliphigh", None)

    print("\nStarting the horizontal alignment")
    print("=====================================")
    print("Number of iterations: {}".format(params["maxit"]))
    print("Using a frequency cutoff of {}".format(params["freqcutoff"]))
    print("Low limit for tomo values = {}".format(params["cliplow"]))
    print("High limit for tomo values = {}".format(params["cliphigh"]))

    # appropriate keeping of variable
    original_sino = sinogram.copy()
    shiftslice = np.expand_dims(shiftstack[1], axis=0)

    # pad sinogram of derivatives
    # TODO: check if we only need this for derivative
    # (if params['derivatives']:) or not!
    padval = int(2 * np.round(1 / params["freqcutoff"]))
    sinogram = np.pad(
        sinogram, ((padval, padval), (0, 0)), "constant", constant_values=0
    ).copy()

    # applying a filter to the sinogram
    sino_orig = _filter_sino(sinogram, **params)

    # Shifting projection according to the initial shiftslice
    if not np.all(shiftslice == 0):
        print("Shifting sinogram.")
        sinogram = compute_aligned_sino(
            sino_orig, shiftslice, shift_method=params["shiftmeth"]
        )
        print("Done.")
    else:
        print("Initializing shiftslice with zeros")

    # initial reconstruction
    print("Computing initial tomographic slice...")
    # Filtered back projection
    print("Backprojecting")
    t0 = time.time()
    recons = tomo_recons(sinogram, theta=theta, **params)
    print("Done. Time elapsed: {} s".format(time.time() - t0))
    print("Slice standard deviation = {:0.04e}".format(recons.std()))

    # clipping gray level if needed
    recons = _clipping_tomo(recons, **params)
    if params["circle"]:
        circleROI = create_circle(recons)  # only need to calculate once
    else:
        circleROI = 1
    recons = recons * circleROI

    # initial synthetic sinogram
    print("Computing synthetic sinogram...")
    t0 = time.time()
    sinogramcomp = projector(recons, theta, **params)
    if params["derivatives"] and not params["calc_derivatives"]:
        sinogramcomp = derivatives_sino(sinogramcomp, shift_method=params["shiftmeth"])
    print("Done. Time elapsed: {:0.02f} s".format(time.time() - t0))

    # store initial error metric
    metric_error = []
    print("Store initial error metric")
    errorinit = _sino_error_metric(sinogram, sinogramcomp, params)
    sumerrorinit = np.sum(errorinit)
    print("Initial error metric, E= {:0.04e}".format(sumerrorinit))
    metric_error.append(sumerrorinit)

    # initializing display canvas for the figures
    plt.ion()
    RP = RegisterPlot(**params)
    RP.plotshorizontal(
        recons, sino_orig, sinogram, sinogramcomp, shiftslice, metric_error, count=0
    )

    # Single pixel precision
    print("\n===================================================")
    print("Registration of projections with pixel precision")
    print("===================================================")
    params["subpixel"] = False
    shiftslice, metric_error = _alignprojections_horizontal(
        sinogram, sino_orig, theta, circleROI, shiftslice, metric_error, RP, **params
    )

    # Subpixel precision
    # NOTE(review): this pass starts from the same ``sinogram`` as the
    # pixel pass, not from one re-aligned with the pixel-precision
    # shifts -- confirm this is intended.
    print("\n===================================================")
    print("Registration of projections with subpixel precision")
    print("===================================================")
    params["subpixel"] = True
    shiftslice, metric_error = _alignprojections_horizontal(
        sinogram, sino_orig, theta, circleROI, shiftslice, metric_error, RP, **params
    )

    # updating shiftstack
    shiftstack[1] = shiftslice

    # Compute the shifted images (unpadded, unfiltered sinogram)
    print("\nComputing aligned images")
    alignedsinogram = compute_aligned_sino(
        original_sino, shiftslice, shift_method=params["shiftmeth"]
    )
    print("Calculating aligned slice for display")
    _oneslicefordisplay(alignedsinogram, theta, **params)
    return shiftstack
def refine_horizontalalignment(input_stack, theta, shiftstack, **params):
    """
    Interactively refine the horizontal alignment.

    The user is asked repeatedly whether the alignment should be refined
    and, if so, whether some parameters should be changed first. Each
    refinement rebuilds the sinogram of slice ``params["slicenum"]`` and
    runs :py:meth:`alignprojections_horizontal` on it. Please, see the
    description of each parameter in :py:meth:`alignprojections_horizontal`.

    Returns
    -------
    shiftstack : array_like
        Object positions after the last refinement
    params : dict
        Parameters, including the values changed at the prompts
    """
    if "correct_bad" not in params:
        params["correct_bad"] = False

    # (prompt, key) pairs for the parameters that can be changed, in the
    # order in which they are asked
    questions = (
        ("Slice number (e.g. {}): ", "slicenum"),
        ("Pixel tolerance (e.g. {}): ", "pixtol"),
        ("Filter Tomo cutoff (e.g. {}): ", "freqcutoff"),
        ("Number of iterations (e.g. {}): ", "maxit"),
        ("Apply a circle (e.g. {}): ", "circle"),
        ("Clipping high (e.g. {}): ", "cliphigh"),
    )

    while True:
        answer = input("Do you want to refine further the alignment? ([y]/n): ").lower()
        if answer == "n":
            print("No further refinement done")
            break
        if answer not in ("", "y"):
            print("You should answer 'y' or 'n or accept default answer.")
            continue

        keep = input("Do you want to use the same parameters? ([y]/n): ").lower()
        if keep == "n":
            for prompt, key in questions:
                reply = input(prompt.format(params[key]))
                if reply != "":
                    # NOTE(review): eval() executes whatever is typed at
                    # the prompt; only use in trusted interactive sessions.
                    params[key] = eval(reply)

        # calculate again the sinogram with corrected bad projections
        sinogram = np.transpose(input_stack[:, params["slicenum"], :])
        if params["correct_bad"]:
            sinogram = replace_bad(
                sinogram, list_bad=params["bad_projs"], temporary=False
            )

        # actual alignment
        print("Starting the refinement of the alignment")
        shiftstack = alignprojections_horizontal(
            sinogram, theta, shiftstack, **params
        )
    return shiftstack, params
def oneslicefordisplay(sinogram, theta, **params):
    """
    Reconstruct and display one slice, optionally with new filter settings.

    Parameters
    ----------
    sinogram : array_like
        Sinogram derivative, the second index should be the angle
    theta : array_like
        Reconstruction angles (in degrees). Default: m angles evenly spaced
        between 0 and 180 (if the shape of `radon_image` is (N, M)).
    params : dict
        Container with parameters for the registration.
    params["filtertype"] : str
        Filter to use for FBP
    params["freqcutoff"] : float
        Frequency cutoff for tomography filter (between 0 and 1)
    """
    answer = input(
        "Do you want to reconstruct the slice with different parameters? ([y]/n) :"
    ).lower()
    if answer in ("", "y"):
        new_cutoff = input("freqcutoff (current: {}) = ".format(params["freqcutoff"]))
        if new_cutoff != "":
            # NOTE(review): eval() executes whatever is typed at the prompt
            params["freqcutoff"] = eval(new_cutoff)
        new_filter = str(
            input("filtertype (current: {}) = ".format(params["filtertype"])).lower()
        )
        if new_filter != "":
            params["filtertype"] = str(new_filter)

    # NOTE(review): the flattened source does not show whether the two
    # statements below belong to the branch above; they are placed after
    # it so that a slice is always displayed -- confirm against upstream.
    print("Calculating a tomographic slice")
    # display of the slice:
    _oneslicefordisplay(sinogram, theta, **params)
def _oneslicefordisplay(sinogram, theta, **params):
    """
    Reconstruct one slice and display it, without asking any question.

    The reconstruction is clipped with ``_clipping_tomo``, multiplied by a
    circular mask when ``params["circle"]`` is true and shown with
    ``display_slice`` using ``params["cliplow"]``/``params["cliphigh"]`` as
    display limits.
    """
    tic = time.time()
    recons = tomo_recons(sinogram, theta=theta, **params)
    # clipping gray level if needed
    recons = _clipping_tomo(recons, **params)
    # a factor of 1 leaves the slice unmasked
    mask = create_circle(recons) if params["circle"] else 1
    recons = recons * mask
    print("Done. Time elapsed: {} s".format(time.time() - tic))
    display_slice(
        recons, colormap="bone", vmin=params["cliplow"], vmax=params["cliphigh"]
    )
def tomoconsistency_multiple(input_stack, theta, shiftstack, **params):
    """
    Apply tomographic consistency alignement on multiple slices. By default
    is implemented over 10 slices.

    The slices ``slicenum - 5`` to ``slicenum + 4`` are aligned one after
    the other with :py:meth:`alignprojections_horizontal`; each slice
    starts from the shifts found for the previous one. The horizontal
    shifts of all slices are plotted together with their average, and the
    user chooses between the average and the shifts from before the call.

    Parameters
    ----------
    input_stack : array_like
        Stack of projections, indexed as (projection, row, column)
    theta : array_like
        Reconstruction angles (in degrees). Default: m angles evenly spaced
        between 0 and 180 (if the shape of `radon_image` is (N, M)).
    shiftstack : array_like
        Array with initial estimates of positions. ``shiftstack[1]`` is
        overwritten in place.
    params : dict
        Dictionary with additional parameters for the alignment. Please,
        see the description of each parameter in
        :py:meth:`alignprojections_horizontal`. ``params["slicenum"]`` is
        the central slice.

    Return
    ------
    shiftstack : array_like
        Object shifts whose second row is either the average over the
        slices or the values from before the call
    """
    print("Starting Tomographic consistency on multiple slices")
    # select the slices, which are typically +5 and -5 relative to slicenum
    slicenumorig = params["slicenum"]
    slices = np.arange(slicenumorig - 5, slicenumorig + 5)
    # horizontal shifts before the refinement, kept as the fallback
    shiftslice_prev = np.expand_dims(shiftstack[1], axis=0).copy()

    shiftxrefine = []
    for ii in slices:
        print("\nAligning slice {}".format(ii))
        params["slicenum"] = ii
        sinogram = np.transpose(input_stack[:, ii, :])  # create the sinogram
        shiftstack_aux = alignprojections_horizontal(
            sinogram, theta, shiftstack, **params
        )
        # alignprojections_horizontal writes its result into shiftstack[1]
        # and returns that same array, so a copy is stored; a plain
        # reference would be overwritten by the next slice.
        shiftxrefine.append(np.array(shiftstack_aux[1], copy=True))

    shiftxrefine = np.squeeze(shiftxrefine)
    shiftxrefine_avg = shiftxrefine.mean(axis=0)

    plt.close("all")
    fig = plt.figure(num=6, figsize=(14, 8))
    ax1 = fig.add_subplot(211)
    ax1.imshow(shiftxrefine, interpolation="none", cmap="jet")
    ax1.axis("tight")
    ax1.set_xlabel("Projection number")
    ax1.set_ylabel("Slice number")
    ax1.set_title("Displacements in x")
    ax2 = fig.add_subplot(212)
    ax2.plot(shiftxrefine_avg, "b-", label="average")
    ax2.plot(shiftslice_prev[0], "r--", label="previous")
    ax2.legend()
    ax2.axis("tight")
    ax2.set_xlim([0, len(shiftxrefine_avg)])
    ax2.set_title("Average displacements in x")
    ax2.set_xlabel("Projection number")
    plt.tight_layout()

    # isnotebook must be called: the bare function object is always true
    if isnotebook():
        display.display(fig)
        display.display(fig.canvas)

    # The question is asked in every environment so that the result never
    # depends on an undefined answer.
    a = input(
        "Are you happy with the tomographic consistency alignment of the multiples slices? ([y]/n) "
    ).lower()
    if a == "" or a == "y":
        shiftstack[1] = shiftxrefine_avg.copy()
        print("Using the average of all shiftstack")
    else:
        shiftstack[1] = shiftslice_prev[0].copy()
        print("Using the shiftstack before tomographic consisteny in multiple slices")
    return shiftstack
def _pixel_or_subpixel_settings(kwargs):
    """
    Select the search step and the shift method from the alignment options.

    Pixel precision (step of 1 and ``"linear"`` shifts) is used when
    ``kwargs["pixtol"]`` is a Python ``int`` or when ``kwargs["subpixel"]``
    is false. Otherwise ``kwargs["pixtol"]`` and ``kwargs["shiftmeth"]``
    are used.

    Returns
    -------
    pixtol : int or float
        Step used by the shift search
    shift_method : str
        Name of the shift method
    """
    if isinstance(kwargs["pixtol"], int) or not kwargs["subpixel"]:
        return 1, "linear"
    return kwargs["pixtol"], kwargs["shiftmeth"]


def _greedy_shift_search(shift_func, reference, shift_delta, pixtol):
    """
    Walk from ``shift_delta`` in steps of ``pixtol`` while the error decreases.

    The error is the sum of the squared absolute differences between
    ``shift_func(shift)`` and ``reference``. The current position and the
    positions one step forward and one step backward are evaluated first.
    The search then keeps moving in the best direction for as long as each
    new step strictly lowers the best error found so far.

    Parameters
    ----------
    shift_func : callable
        Function receiving a shift value and returning the shifted data
    reference : array_like
        Data the shifted data is compared to
    shift_delta : float
        Starting shift
    pixtol : float
        Step of the search

    Returns
    -------
    shift : float
        Shift with the smallest error found
    shifted : array_like
        Data shifted by ``shift``, i.e. ``shift_func(shift)``
    """

    def squared_error(data):
        return np.sum(np.abs(data - reference) ** 2)

    # looking both ways; "current" comes first so that it wins the ties
    candidates = {
        "current": shift_delta,
        "forward": shift_delta + pixtol,
        "backward": shift_delta - pixtol,
    }
    shifted = {key: shift_func(value) for key, value in candidates.items()}
    errors = {key: squared_error(data) for key, data in shifted.items()}

    # get the smallest shift error
    best = min(errors, key=errors.get)
    best_shift, best_data, best_error = candidates[best], shifted[best], errors[best]

    # calculate the increment to be shifted
    dir_inc = {"current": 0, "forward": pixtol, "backward": -pixtol}[best]
    if dir_inc == 0:
        return best_shift, best_data

    # keep shifting in the direction that minimizes errors. Each trial is
    # compared with the best error so far and the data returned always
    # corresponds to the shift returned.
    while True:
        trial_shift = best_shift + dir_inc
        trial_data = shift_func(trial_shift)
        trial_error = squared_error(trial_data)
        if not trial_error < best_error:
            break
        best_shift, best_data, best_error = trial_shift, trial_data, trial_error
    return best_shift, best_data


def _search_vshift_stack(input_stack, lims, input_delta, avg_vert_fluct, **kwargs):
    """
    Search for the shifts directions for the stack

    Parameters
    ----------
    input_stack : array_like
        Stack of projections, one projection per entry of the first axis
    lims : sequence
        Pair ``(rows, cols)`` with the limits of the region of interest
    input_delta : ndarray
        Current shifts. Only row 0 is read and searched here.
    avg_vert_fluct : array_like
        Reference each shifted projection is compared to
    kwargs : dict
        Must contain ``pixtol``, ``subpixel``, ``shiftmeth`` and
        ``polyorder``

    Returns
    -------
    output_shiftstack : ndarray
        Copy of ``input_delta`` with row 0 replaced by the shifts found
    vert_fluct_stack : ndarray
        Result of the vertical shift of each projection, one per row
    """
    pixtol, shift_method = _pixel_or_subpixel_settings(kwargs)

    # polynomial order to remove bias
    polyorder = kwargs["polyorder"]

    # separate the lims
    rows, _ = lims
    rows = np.asarray(rows)

    # convert before reading the shape so that array-likes are accepted
    input_stack = np.asarray(input_stack)
    nprojs, nr, _ = input_stack.shape

    # get the maximum shift value from input_delta plus 1 for a margin
    max_vshift = int(np.ceil(np.max(np.abs(input_delta[0, :])))) + 1
    if np.any((rows - max_vshift) < 0) or np.any((rows + max_vshift) > nr):
        max_vshift = 1  # at least one for a margin

    # initializing arrays. The output starts as a copy of the input so that
    # the rows which are not searched here hold defined values.
    vert_fluct_stack = np.empty((nprojs, rows[-1] - rows[0]))
    output_shiftstack = input_delta.copy()

    for ii, projection in enumerate(input_stack):
        print("Searching the shifts for projection: {}".format(ii + 1), end="\r")
        output_shiftstack[0, ii], vert_fluct_stack[ii] = _search_vshift_direction(
            projection,
            lims,
            input_delta[0, ii],
            avg_vert_fluct,
            pixtol,
            max_vshift,
            shift_method,
            polyorder,
        )
    print("\r")

    return output_shiftstack, vert_fluct_stack


def _search_vshift_direction(
    input_array,
    lims,
    shift_delta,
    avg_vert_fluct,
    pixtol,
    max_vshift,
    shift_method="linear",
    polyorder=2,
):
    """
    Search for the shifts directions for each image

    The candidate shifts are evaluated with ``vertical_shift`` and compared
    to ``avg_vert_fluct``.

    Returns
    -------
    shift : float
        Shift with the smallest error found
    shifted_stack : array_like
        Output of ``vertical_shift`` for that shift
    """

    def shift_func(shift):
        return vertical_shift(
            input_array, lims, shift, max_vshift, shift_method, polyorder
        )

    return _greedy_shift_search(shift_func, avg_vert_fluct, shift_delta, pixtol)


def _search_hshift_sinogram(sinogram, sinogramcomp, shiftslice, **kwargs):
    """
    Wrapper to search for the shifts in the sinogram

    Each column of ``sinogram`` is searched independently against the same
    column of ``sinogramcomp``, starting from ``shiftslice[0, column]``.

    Returns
    -------
    sino_out : ndarray
        Sinogram with each column shifted by the shift found
    shiftslice_out : ndarray
        Array like ``shiftslice`` with the shifts found in row 0
    """
    pixtol, shift_method = _pixel_or_subpixel_settings(kwargs)

    # get array dimensions
    _, nc = sinogram.shape

    # initializing arrays
    sino_out = np.zeros_like(sinogram)
    shiftslice_out = np.zeros_like(shiftslice)

    for ii in range(nc):
        print("Searching the shifts for projection: {}".format(ii + 1), end="\r")
        shiftslice_out[0, ii], sino_out[:, ii] = _search_hshift_direction(
            sinogram[:, ii], sinogramcomp[:, ii], shiftslice[0, ii], pixtol, shift_method
        )
    print("\r")

    return sino_out, shiftslice_out


def _search_hshift_direction(
    sinogram, sinogramcomp, shift_delta, pixtol, shift_method="linear"
):
    """
    Search for sinogram shift for each projection

    The candidate shifts are applied with ``ShiftFunc`` and compared to
    ``sinogramcomp``.

    Returns
    -------
    shift : float
        Shift with the smallest error found
    shifted_sino : array_like
        ``sinogram`` shifted by that shift
    """
    # Initialize shift class
    S = ShiftFunc(shiftmeth=shift_method)

    def shift_func(shift):
        return S(sinogram, shift)

    return _greedy_shift_search(shift_func, sinogramcomp, shift_delta, pixtol)


def _clipping_tomo(recons, **params):
    """
    Clip gray level of tomographic images

    Values below ``params["cliplow"]`` are set to ``cliplow`` and values
    above ``params["cliphigh"]`` are set to ``cliphigh``; a limit equal to
    ``None`` is skipped. When ``cliphigh`` is given it is also subtracted
    from the result, so that the largest value becomes zero.
    """
    cliplow = params["cliplow"]
    cliphigh = params["cliphigh"]
    if cliplow is not None:
        recons = recons * (recons >= cliplow) + cliplow * (recons < cliplow)
    if cliphigh is not None:
        recons = recons * (recons <= cliphigh) + cliphigh * (recons > cliphigh)
        # NOTE(review): kept inside this branch, the subtraction is not
        # defined for cliphigh=None - confirm against the upstream file
        recons = recons - cliphigh
    return recons


def _sino_error_metric(sinogramexp, sinogramcomp, params):
    """
    Estimate the error metric between the experimental sinogram and the
    synthetic one.

    Parameters
    ----------
    sinogramexp : ndarray
        Experimental sinogram
    sinogramcomp : ndarray
        Synthetic sinogram
    params : dict
        Not used

    Returns
    -------
    errorxreg : ndarray
        Sum of the squared absolute differences of each column

    @author: jdasilva
    """
    ncols = sinogramexp.shape[1]
    residual = sinogramexp[:, :ncols] - sinogramcomp[:, :ncols]
    return np.sum(np.abs(residual) ** 2, axis=0, dtype=float)


def _checkconditions(metric_error, changes, pixtol, count, maxit, subpixel=False):
    """
    Check if the registration conditions are satisfied

    Returns
    -------
    reason : int
        1 if the last value of ``metric_error`` is larger than the one
        before, 2 if the largest value of ``changes`` is smaller than the
        step (``pixtol`` if ``subpixel``, 1 otherwise), 3 if ``count`` has
        reached ``maxit``, and 0 if none of them applies. The conditions
        are tested in this order.
    """
    step = pixtol if subpixel else 1

    # We then check if the error increases: compare the last with the
    # before last value (only possible once two values are available)
    if len(metric_error) >= 2 and metric_error[-1] > metric_error[-2]:
        print("Last iteration increased error.")
        print(
            "Before -> {:.04e}, current -> {:.04e}".format(
                metric_error[-2], metric_error[-1]
            )
        )
        print("Keeping previous shifts.")
        return 1

    # We check if the changes is larger than 1 or pixtol
    if np.max(changes) < step:
        if step >= 1:
            print("Changes are smaller than one pixel.")
        else:
            print("Changes are smaller than {} pixel.".format(step))
        return 2

    # we check if the number of iteration is reached
    if count >= maxit:
        print("Maximum number of iterations reached.")
        return 3

    return 0


@deprecated
def _offset_sinogram_old(sinogram, offset):
    """
    Shift the sinogram for an initial guess of the rotation axis offset

    The first axis is zero-padded by ``2 * abs(offset)`` entries, at its
    end for a positive offset and at its beginning for a negative one. The
    sinogram is returned unchanged for a zero offset.
    """
    direction = np.sign(offset)
    if direction == +1:
        pad_width = ((0, 2 * abs(offset)), (0, 0))
    elif direction == -1:
        pad_width = ((2 * abs(offset), 0), (0, 0))
    else:
        return sinogram
    print("Initial guess of the rotation axis offset : {}".format(offset))
    return np.pad(sinogram, pad_width, "constant", constant_values=0)


def _offset_sinogram(sinogram, offset, shift_method="linear"):
    """
    Shift the sinogram for an initial guess of the rotation axis offset

    The sinogram is shifted with ``ShiftFunc`` by ``(offset, 0)`` using
    ``shift_method``.
    """
    S = ShiftFunc(shiftmeth=shift_method)
    return S(sinogram, (offset, 0))
def estimate_rot_axis(input_array, theta, **params):
    """
    Initial estimate of the rotation axis

    A sinogram is extracted at ``params["slicenum"]``, shifted by the
    current offset and reconstructed. The tomogram and the sinogram are
    displayed and the user is asked whether the rotation axis is fine. If
    not, a new offset is requested and the procedure is repeated.

    Parameters
    ----------
    input_array : array_like
        Stack of projections. The slice index is the second axis.
    theta : ndarray
        Reconstruction angles. Its minimum is subtracted in place, so the
        array of the caller is modified.
    params : dict
        Must contain ``slicenum``, ``rot_axis_offset``, ``colormap``,
        ``cliplow``, ``cliphigh``, ``sinolow`` and ``sinohigh``.
        ``sinocmap`` is optional and defaults to ``colormap``. The whole
        dictionary is also passed to ``tomo_recons``.

    Returns
    -------
    rot_axis_offset : int or float
        Offset of the rotation axis accepted by the user
    """
    if "sinocmap" not in params:
        params["sinocmap"] = params["colormap"]

    # Ensuring that theta starts at zero
    theta -= theta.min()

    # Inspection of a sinogram and a tomogram
    slicenum = params["slicenum"]
    rot_axis_offset = params["rot_axis_offset"]
    while True:
        sinogram = np.transpose(input_array[:, slicenum, :])
        sinogram = _offset_sinogram(sinogram, rot_axis_offset)

        # reconstruction
        print("Calculating a tomographic slice")
        p0 = time.time()
        tomogram = tomo_recons(sinogram, theta, **params)
        print("Time elapsed: {} s".format(time.time() - p0))

        # Display slice:
        plt.close("all")
        print("Slice: {}".format(slicenum))
        fig1 = plt.figure(num=5, figsize=(12, 4))
        ax1 = fig1.add_subplot(121)
        im1 = ax1.imshow(
            tomogram,
            cmap=params["colormap"],
            interpolation="none",
            vmin=params["cliplow"],
            vmax=params["cliphigh"],
        )
        # the placeholder is required for the slice number to be shown
        ax1.set_title("Slice {}".format(slicenum))
        fig1.colorbar(im1)
        ax2 = fig1.add_subplot(122)
        im2 = ax2.imshow(
            sinogram,
            cmap=params["sinocmap"],
            interpolation="none",
            vmin=params["sinolow"],
            vmax=params["sinohigh"],
        )
        ax2.axis("tight")
        ax2.set_title("Sinogram - Slice {}".format(slicenum))
        fig1.colorbar(im2)
        if isnotebook():
            display.display(fig1)
            display.display(fig1.canvas)
            display.clear_output(wait=True)

        # the question is asked in every environment: it is the only way
        # out of the loop
        a = input("Are you happy with the rotation axis?([y]/n)").lower()
        if a in ("", "y"):
            break
        # NOTE(review): eval executes whatever is typed by the operator.
        # Kept because arithmetic expressions are accepted today; consider
        # ast.literal_eval or float() if this can receive untrusted input.
        rot_axis_offset = eval(input("Enter new rotation axis estimate: "))

    print(
        "The initial estimate of the offset of the rotation axis is {}".format(
            rot_axis_offset
        )
    )

    return rot_axis_offset
# ~ @deprecated # ~ def cc_align(input_stack, limrow, limcol, params): # ~ """ # ~ Cross-correlation alignment (DEPRECATED) # ~ FIXME: IT IS NOT WORKING PROPERLY # ~ """ # ~ shift_values = np.empty((len(input_stack), 2)) # ~ # The cross-correlation compares to the first projections, which does not move # ~ shift_values[0] = np.array([0, 0]) # ~ for ii in range(1, len(input_stack)): # ~ print("\nCalculating the subpixel image registration...") # ~ print("Projection: {}".format(ii - 1)) # ~ image1 = input_stack[ii - 1, limrow[0] : limrow[-1], limcol[0] : limcol[-1]] # ~ print("Projection: {}".format(ii)) # ~ image2 = input_stack[ii, limrow[0] : limrow[-1], limcol[0] : limcol[-1]] # ~ start = time.time() # ~ if params["gaussian_filter"]: # ~ image1 = gaussian_filter(image1, params["gaussian_sigma"]) # ~ image2 = gaussian_filter(image2, params["gaussian_sigma"]) # ~ shift, error, diffphase = register_translation(image1, image2, 100) # ~ shift_values[ii] = shift # ~ print(diffphase) # ~ end = time.time() # ~ print("Time elapsed: {} s".format(end - start)) # ~ print("Detected subpixel offset [y,x]: [{}, {}]".format(shift[0], shift[1])) # ~ shift_vert_aux = np.array(shift_values)[:, 0] # ~ shift_hor_aux = np.array(shift_values)[:, 1] # ~ # Cumulative sum of the shifts minus the average # ~ shift_vert = np.cumsum(shift_vert_aux - shift_vert_aux.mean()) # ~ shift_hor = np.cumsum(shift_hor_aux - shift_hor_aux.mean()) # ~ # smoothing the shifts is needed # ~ if params["smooth_shifts"] is not None: # ~ shift_vert = snf.gaussian_filter1d(shift_vert, params["smooth_shifts"]) # ~ shift_hor = snf.gaussian_filter1d(shift_hor, params["smooth_shifts"]) # ~ # display shifts # ~ plt.close("all") # ~ fig1 = plt.figure(1) # ~ ax1 = fig1.add_subplot(211) # ~ ax1.plot(np.array(shift_vert), "ro-") # ~ ax1.set_title("Vertical shifts") # ~ ax2 = fig1.add_subplot(212) # ~ ax2.plot(np.array(shift_hor), "ro-") # ~ ax2.set_title("Horizontal shifts") # ~ # ~ # updating the shiftstack # ~ 
shiftstack = np.zeros((2, input_stack.shape[0])) # ~ shiftstack[0] = shift_vert # ~ shiftstack[1] = shift_hor # ~ # Compute the shifted images # ~ # print('Computing aligned images') # ~ # if not params['expshift']: # ~ # output_stack = compute_aligned_stack(input_stack,shiftstack.copy(),params) # ~ # else: # ~ # print('Computing aligned images in phase space') # ~ # output_stack = np.angle(compute_aligned_stack(np.exp(1j*input_stack),shiftstack.copy(),params)) # ~ # return shiftstack,output_stack # ~ plt.close("all") # ~ fig1 = plt.figure(1) # ~ ax1 = fig1.add_subplot(111) # (ncols=1, figsize=(14, 6)) # ~ im1 = ax1.imshow( # ~ stack_unwrap[1, limrow[0] : limrow[-1], limcol[0] : limcol[-1]], # ~ interpolation="none", # ~ cmap="bone", # ~ ) # ~ ax1.set_axis_off() # ~ ax1.set_title("Offset corrected image2") # ~ # offset_stack_unwrap = np.empty_like(stack_unwrap[:,80:-80,80:-80]) # ~ # aligned = np.empty_like(stack_unwrap[:,80:-80,80:-80]) # ~ aligned = compute_aligned_stack( # ~ input_stack, shiftstack.copy(), shift_method=params["shiftmeth"] # ~ ) # ~ plt.ion() # ~ for ii in range(0, len(stack_unwrap)): # ~ # img = stack_unwrap[ii,80:-80,80:-80] # ~ shift = np.array([shift_vert[ii], shift_hor[ii]]) # ~ print(shift) # ~ print( # ~ "\nCorrecting the shift of projection {} by using subpixel precision.".format( # ~ ii # ~ ) # ~ ) # ~ # offset_stack_unwrap[ii] = ifftn(fourier_shift(fftn(img),shift))# # ~ # aligned[ii] = ifftn(fourier_shift(fftn(img),shift))# # ~ # im1.set_data(offset_stack_unwrap[ii]) # ~ im1.set_data(aligned[ii]) # ~ ax1.set_title(u"Projection {}".format(ii)) # ~ fig1.canvas.draw() # ~ plt.pause(0.001) # ~ plt.ioff() # ~ # Display the images # ~ fig, (ax1, ax2, ax3) = plt.subplots(num=3, ncols=3, figsize=(14, 6)) # ~ ax1.imshow(image1, interpolation="none", cmap="bone") # ~ ax1.set_axis_off() # ~ ax1.set_title("Image 1 (ref.)") # ~ ax2.imshow(image2, interpolation="none", cmap="bone") # ~ ax2.set_axis_off() # ~ ax2.set_title("Image 2") # ~ # View the 
output of a cross-correlation to show what the algorithm is # ~ # doing behind the scenes # ~ image_product = fft2(image1) * fft2(image2).conj() # ~ cc_image = fftshift(ifft2(image_product)) # ~ ax3.imshow(cc_image.real) # ~ # ax3.set_axis_off() # ~ ax3.set_title("Cross-correlation") # ~ # ~ return shiftstack, aligned