Source code for toupy.registration.registration

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
registration_gd.py
==================
Tomographic data alignment using gradient descent optimisation.

Vertical alignment  : damped Newton descent (finite-difference gradient and
                      curvature) on the L2 mass-fluctuation cost.
Horizontal alignment: damped Newton descent on the L2 sinogram-consistency cost.

Both replace the original discrete line-search / parabolic-fit approach
(_search_vshift_direction / _search_hshift_direction) with a proper
gradient-based update rule.  All other helper functions are unchanged.
"""

# standard libraries imports
import contextlib
import io
import os
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from matplotlib.widgets import Button, TextBox

# third party packages
from ..utils.plot_utils import plt
import numpy as np
from ..utils import tqdm
from scipy.fft import fft, ifft, fft2, ifft2, fftfreq, fftshift, ifftshift
from scipy.ndimage import center_of_mass, interpolation, gaussian_filter, gaussian_filter1d, fourier_shift
from skimage.registration import phase_cross_correlation

# local packages
from ..restoration import derivatives_sino
from .shift import ShiftFunc
from ..tomo import projector, tomo_recons
from ..utils import (
    deprecated,
    isnotebook,
    projectpoly1d,
    RegisterPlot,
    replace_bad,
    display_slice,
    create_circle,
    hanning_apod1D,
)

__all__ = [
    "alignprojections_vertical",
    "alignprojections_horizontal",
    "center_of_mass_stack",
    "compute_aligned_stack",
    "compute_aligned_sino",
    "compute_aligned_horizontal",
    "estimate_rot_axis",
    "oneslicefordisplay",
    "refine_horizontalalignment",
    "register_2Darrays",
    "tomoconsistency_multiple",
    "vertical_fluctuations",
    "vertical_shift",
]

# ---------------------------------------------------------------------------
# Gradient descent hyperparameters (module-level — easy to tune)
# ---------------------------------------------------------------------------
_GD_MAX_ITER  = 50   # maximum GD iterations per projection
_FD_H_VERT    = 0.1  # FD step for vertical gradient (pixels)
_FD_H_HORIZ   = 0.5  # FD step for horizontal gradient (pixels); wide enough
                      # for an accurate Newton step even far from the minimum
_H_MAX_STEP   = 50.0 # max horizontal Newton step (pixels); generous so the
                      # Newton step can cross a large initial offset in one
                      # iteration — the halving fallback handles overshoots


# ============================================================================
# Unchanged public helpers
# ============================================================================

def register_2Darrays(image1, image2, subpixel=False):
    """
    Image registration using phase cross-correlations.

    Parameters
    ----------
    image1 : array_like
        Reference image.
    image2 : array_like
        Image to be shifted relative to image1.
    subpixel : bool, optional
        If ``False`` (default) use pixel-precision registration.
        If ``True`` use sub-pixel precision (upsample factor = 100).

    Returns
    -------
    shift : ndarray
        Detected ``[row_shift, col_shift]``.
    diffphase : float
        Global phase difference between the two images.
    offset_image2 : ndarray (complex)
        image2 shifted by ``shift`` in Fourier space and multiplied by
        ``exp(1j * diffphase)``.
    """
    if subpixel:
        print("\nCalculating the subpixel image registration ...")
        upsample_factor = 100
        precision_label = "subpixel"
    else:
        print("\nCalculating the pixel precision image registration ...")
        upsample_factor = 1  # the scikit-image default: pixel precision
        precision_label = "pixel"

    start = time.time()
    # ``upsample_factor`` must be passed by keyword: every argument after
    # ``moving_image`` is keyword-only in current scikit-image releases, so
    # the former positional ``100`` raised a TypeError.
    shift, _error, diffphase = phase_cross_correlation(
        image1.copy(), image2.copy(), upsample_factor=upsample_factor
    )
    print(diffphase)
    print("Time elapsed: {:g} s".format(time.time() - start))
    print(
        "Detected {} offset [y,x]: [{:g}, {:g}]".format(
            precision_label, shift[0], shift[1]
        )
    )

    print("\nCorrecting the shift of image2 by using subpixel precision...")
    offset_image2 = ifft2(fourier_shift(fft2(image2.copy()), shift))
    offset_image2 *= np.exp(1j * diffphase)
    return shift, diffphase, offset_image2
def compute_aligned_stack(input_stack, shiftstack, shift_method="linear"):
    """
    Shift every image of a stack by its own (vertical, horizontal) correction.

    Parameters
    ----------
    input_stack : array_like
        Stack of images to be shifted; the first axis indexes the images.
    shiftstack : array_like
        Corrections of shape (2, n): row 0 vertical, row 1 horizontal.
    shift_method : str
        'linear', 'fourier', or 'spline'.

    Returns
    -------
    output_stack : array_like
        New array (same shape/dtype as ``input_stack``) holding the aligned
        images; the input is left untouched.
    """
    shifter = ShiftFunc(shiftmeth=shift_method)
    method_name = shifter.shiftmeth.__name__
    print("Using {} shift method (function {})".format(shift_method, method_name))

    aligned = np.empty_like(input_stack)
    for idx in tqdm(range(input_stack.shape[0]), desc="Aligning images"):
        correction = (shiftstack[0, idx], shiftstack[1, idx])
        aligned[idx] = shifter(input_stack[idx], correction)
    return aligned
def compute_aligned_stack_special(input_stack, shiftstack, shift_method="linear"):
    """
    In-place variant of ``compute_aligned_stack``.

    Each image ``input_stack[ii]`` is overwritten with its shifted version.

    Parameters
    ----------
    input_stack : array_like
        Stack of images; modified in place.
    shiftstack : array_like
        Corrections of shape (2, n): row 0 vertical, row 1 horizontal.
    shift_method : str
        'linear', 'fourier', or 'spline'.

    Returns
    -------
    input_stack : array_like
        The same (now aligned) array object that was passed in.
    """
    S = ShiftFunc(shiftmeth=shift_method)
    nstack = input_stack.shape[0]
    print("Using {} shift method (function {})".format(shift_method, S.shiftmeth.__name__))
    for ii in tqdm(range(nstack), desc="Aligning images"):
        deltashift = (shiftstack[0, ii], shiftstack[1, ii])
        input_stack[ii] = S(input_stack[ii], deltashift)
    return input_stack


def compute_aligned_horizontal_special(input_stack, shiftstack, shift_method="linear", **kwargs):
    """
    Horizontal-only in-place alignment.

    Only row 1 (horizontal) of ``shiftstack`` is applied; the vertical
    corrections are ignored.

    ``**kwargs`` absorbs any extra keys when called as
    ``compute_aligned_horizontal_special(..., **params)``.  If it contains
    ``"shiftmeth"``, that value takes precedence over ``shift_method``.

    Returns
    -------
    input_stack : array_like
        The same array object, horizontally aligned in place.
    """
    shift_method = kwargs.get("shiftmeth", shift_method)
    deltashift = np.zeros_like(shiftstack)
    deltashift[1] = shiftstack[1].copy()
    # Pass the horizontal-only shifts; passing ``shiftstack`` here (as before)
    # silently applied the vertical corrections too.
    return compute_aligned_stack_special(input_stack, deltashift, shift_method=shift_method)
def compute_aligned_sino(input_sino, shiftslice, shift_method="linear"):
    """
    Shift each column (projection) of a sinogram by its horizontal correction.

    Parameters
    ----------
    input_sino : array_like
        Sinogram to be shifted; columns index the projections.
    shiftslice : array_like
        Per-projection horizontal shifts, shape (1, n).
    shift_method : str
        'linear', 'fourier', or 'spline'.

    Returns
    -------
    output_sino : array_like
        New aligned sinogram; the input is left untouched.
    """
    shifter = ShiftFunc(shiftmeth=shift_method)
    banner = "Using {} shift method (function {})".format(
        shift_method, shifter.shiftmeth.__name__
    )
    print(banner)

    aligned = np.empty_like(input_sino)
    n_columns = input_sino.shape[1]
    for col in tqdm(range(n_columns), desc="Aligning sinogram"):
        aligned[:, col] = shifter(input_sino[:, col], shiftslice[0, col])
    return aligned
def compute_aligned_horizontal(input_stack, shiftstack, shift_method="linear", **kwargs):
    """
    Horizontal-only alignment (copy variant).

    Only row 1 (horizontal) of ``shiftstack`` is applied; the vertical
    corrections are replaced by zeros.  The input stack is not modified.

    ``**kwargs`` absorbs any extra keys when called as
    ``compute_aligned_horizontal(..., **params)``.  If it contains
    ``"shiftmeth"``, that value takes precedence over ``shift_method`` (even
    when ``shift_method`` is passed explicitly).

    Returns
    -------
    output_stack : array_like
        Horizontally aligned copy of ``input_stack``.
    """
    shift_method = kwargs.get("shiftmeth", shift_method)
    deltashift = np.zeros_like(shiftstack)
    deltashift[1] = shiftstack[1].copy()
    return compute_aligned_stack(input_stack, deltashift, shift_method=shift_method)
def center_of_mass_stack(input_stack, lims, shiftstack, shift_method="fourier"):
    """
    Compute the centre-of-mass position for each projection in the stack.

    Parameters
    ----------
    input_stack : ndarray, shape (n, nr, nc)
        Stack of projection images.
    lims : tuple of array_like
        ``(limrow, limcol)``; the ROI spans rows ``limrow[0]:limrow[-1]``
        and columns ``limcol[0]:limcol[-1]``.
    shiftstack : ndarray, shape (2, n)
        Current shift estimates applied to each projection before the
        centre of mass is measured.
    shift_method : str, optional
        Interpolation method for the shift operation. Default ``'fourier'``.

    Returns
    -------
    ndarray, shape (2, n)
        ``[centerx, centery]``: horizontal and vertical centre-of-mass
        offsets (pixels) relative to the (floored) middle of the ROI.
        Projections whose ROI sums to exactly zero get ``0``.
    """
    limrow, limcol = lims
    print("Calculating center-of-mass with pixel precision")
    shifter = ShiftFunc(shiftmeth=shift_method)
    row_sel = slice(limrow[0], limrow[-1])
    col_sel = slice(limcol[0], limcol[-1])

    # Integer coordinate grids of the ROI, each re-centred on the floored
    # mean index along its own axis so that the result is an offset.
    roi_shape = input_stack[0, row_sel, col_sel].shape
    rr, cc = np.indices(roi_shape)
    cc = cc - np.floor(cc.mean(axis=1))[:, np.newaxis].astype("int")
    rr = rr - np.floor(rr.mean(axis=0))[np.newaxis, :].astype("int")
    xgrid = cc.astype("float")
    ygrid = rr.astype("float")

    n_proj = input_stack.shape[0]
    mass = np.empty(n_proj)
    xmoment = np.empty(n_proj)
    ymoment = np.empty(n_proj)
    for idx in range(n_proj):
        shifted = shifter(input_stack[idx], (shiftstack[0, idx], shiftstack[1, idx]))
        window = shifted[row_sel, col_sel]
        mass[idx] = np.sum(window)
        xmoment[idx] = np.sum(xgrid * window)
        ymoment[idx] = np.sum(ygrid * window)

    # Normalise by the mass where it is non-zero; empty ROIs report 0.
    has_mass = mass != 0
    xmoment[has_mass] /= mass[has_mass]
    xmoment[~has_mass] = 0
    ymoment[has_mass] /= mass[has_mass]
    ymoment[~has_mass] = 0
    return np.asarray([xmoment, ymoment])
def vertical_fluctuations(input_stack, lims, shiftstack, shift_method="fourier", polyorder=2):
    """
    Compute the vertical mass-fluctuation signal for each projection.

    Each projection is shifted vertically by ``shiftstack[0, ii]``, summed
    over the horizontal ROI axis, and a polynomial baseline is removed with
    ``projectpoly1d``.

    Parameters
    ----------
    input_stack : ndarray, shape (n, nr, nc)
        Stack of projection images.
    lims : tuple of array_like
        ``(rows, cols)``; the ROI spans rows ``rows[0]:rows[-1]`` and
        columns ``cols[0]:cols[-1]``.
    shiftstack : ndarray, shape (2, n)
        Current shift estimates; only row 0 (vertical) is applied.
    shift_method : str, optional
        Interpolation method for the shift operation. Default ``'fourier'``.
    polyorder : int, optional
        Polynomial order used for baseline removal. Default ``2``.

    Returns
    -------
    vert_fluct : ndarray, shape (n, rows[-1] - rows[0])
        Vertical fluctuation signals, one row per projection.
    """
    S = ShiftFunc(shiftmeth=shift_method)
    nproj, nr, nc = input_stack.shape
    rows, cols = lims
    max_vshift = int(np.ceil(np.max(np.abs(shiftstack[0, :])))) + 1
    if np.any((rows - max_vshift) < 0) or np.any((rows + max_vshift) > nr):
        max_vshift = 1

    # The shifted window is the ROI plus ``max_vshift`` rows on each side.
    # When even that margin leaves the image (e.g. limsy=None -> rows=[0, nr])
    # a negative start index would silently wrap around; clamp the slice and
    # replicate the edge rows instead.  In-range windows are unchanged.
    row_lo = rows[0] - max_vshift
    row_hi = rows[-1] + max_vshift
    pad_top = max(0, -row_lo)
    pad_bot = max(0, row_hi - nr)

    vert_fluct = np.empty((nproj, rows[-1] - rows[0]))
    for ii in tqdm(range(nproj), desc="Computing vertical fluctuations"):
        proj = input_stack[ii, max(row_lo, 0):min(row_hi, nr), cols[0]:cols[-1]]
        if pad_top or pad_bot:
            proj = np.pad(proj, ((pad_top, pad_bot), (0, 0)), mode="edge")
        stack_shift = S(proj, (shiftstack[0, ii], 0.0))
        # Strip the safety margin again (explicit end index, not -max_vshift).
        shift_calc = stack_shift[max_vshift:stack_shift.shape[0] - max_vshift].sum(axis=1)
        vert_fluct[ii] = projectpoly1d(shift_calc, polyorder, 1)
    return vert_fluct
def vertical_shift(input_array, lims, vstep, maxshift, shift_method="linear", polyorder=2):
    """
    Compute the vertical mass-fluctuation signal for a single projection
    shifted by a given amount.

    Parameters
    ----------
    input_array : ndarray, shape (nr, nc)
        Single projection image.
    lims : tuple of array_like
        ``(rows, cols)``; the ROI spans rows ``rows[0]:rows[-1]`` and
        columns ``cols[0]:cols[-1]``.
    vstep : float
        Vertical shift to apply (pixels).
    maxshift : int
        Safety margin (pixels) around the ROI to absorb border effects;
        ``int(abs(vstep))`` is added to it.
    shift_method : str, optional
        Interpolation method for the shift operation. Default ``'linear'``.
    polyorder : int, optional
        Polynomial order used for baseline removal. Default ``2``.

    Returns
    -------
    shift_calc : ndarray, shape (rows[-1] - rows[0],)
        Vertical fluctuation signal after shifting and polynomial baseline
        subtraction.
    """
    S = ShiftFunc(shiftmeth=shift_method)
    nr, nc = input_array.shape
    max_vshift = maxshift + int(np.abs(vstep))
    rows, cols = lims
    if np.any((rows - max_vshift) < 0) or np.any((rows + max_vshift) > nr):
        max_vshift = 1

    # Clamp the window to the image and edge-pad the missing rows, so a
    # margin that still does not fit cannot produce a wrapped (negative)
    # start index.  In-range windows are extracted exactly as before.
    row_lo = rows[0] - max_vshift
    row_hi = rows[-1] + max_vshift
    pad_top = max(0, -row_lo)
    pad_bot = max(0, row_hi - nr)
    window = input_array[max(row_lo, 0):min(row_hi, nr), cols[0]:cols[-1]]
    if pad_top or pad_bot:
        window = np.pad(window, ((pad_top, pad_bot), (0, 0)), mode="edge")

    stack_shift = S(window, (vstep, 0.0))
    # Explicit end index: with a zero margin ``[0:-0]`` would be empty.
    shift_calc = stack_shift[max_vshift:stack_shift.shape[0] - max_vshift].sum(axis=1)
    shift_calc = projectpoly1d(shift_calc, polyorder, 1)
    return shift_calc
# ============================================================================ # Private helpers (ROI, clipping, error metrics, convergence) # ============================================================================ def _selectROI(stack_shape, **params): """ Derive region-of-interest row and column limits from params. Parameters ---------- stack_shape : tuple of int Shape ``(n, nr, nc)`` of the projection stack. **params Must contain ``'deltax'`` (int, horizontal margin) and ``'limsy'`` (list of int or None, explicit row limits). Returns ------- limrow : ndarray of int Row limits ``[row_start, row_end]``. limcol : ndarray of int Column limits ``[col_start, col_end]``. """ deltax = params["deltax"] limcol = (deltax, stack_shape[2] - deltax) limrow = params["limsy"] if limrow is None or limrow == "": limrow = [0, stack_shape[1]] return np.asarray(limrow), np.asarray(limcol) def _clipping_tomo(recons, **params): """ Apply low and high clipping to a tomographic reconstruction. Parameters ---------- recons : ndarray Reconstructed slice array. **params Must contain ``'cliplow'`` (float or None) and ``'cliphigh'`` (float or None). When a value is not ``None`` the reconstruction is clipped to that threshold and, for ``cliphigh``, also shifted so that the clipped maximum maps to zero. Returns ------- recons : ndarray Clipped (and optionally shifted) reconstruction. """ if params["cliplow"] is not None: recons = recons * (recons >= params["cliplow"]) + params["cliplow"] * ( recons < params["cliplow"] ) if params["cliphigh"] is not None: recons = recons * (recons <= params["cliphigh"]) + params["cliphigh"] * ( recons > params["cliphigh"] ) recons = recons - params["cliphigh"] return recons def _sino_error_metric(sinogramexp, sinogramcomp, params): """ Compute the per-column L2 error between experimental and synthetic sinograms. Parameters ---------- sinogramexp : ndarray, shape (nr, nc) Experimental (measured) sinogram. 
sinogramcomp : ndarray, shape (nr, nc) Synthetic sinogram computed from the current reconstruction. params : dict Unused; reserved for future weighting options. Returns ------- errorxreg : ndarray, shape (nc,) Per-column sum of squared differences between the two sinograms. """ errorxreg = np.zeros(sinogramexp.shape[1]) for ii in range(sinogramexp.shape[1]): errorxreg[ii] = np.sum(np.abs(sinogramexp[:, ii] - sinogramcomp[:, ii]) ** 2) return errorxreg def _checkconditions(metric_error, changes, pixtol, count, maxit, subpixel=False, rtol=0.0): """ Evaluate stopping conditions for the alignment loop. Returns ------- 0 : continue 1 : diverged (2 consecutive error increases) — caller should roll back 2 : shifts converged (max change < pixtol) 3 : maximum iterations reached 4 : relative error improvement below rtol (only when rtol > 0) """ step = pixtol if subpixel else 1 eps = np.spacing(1) # Require 2 consecutive error increases before declaring divergence. # A single increase is often a transient fluctuation; stopping on it # forces the user to restart manually even though the algorithm would # recover by itself in the next iteration. if (len(metric_error) >= 3 and metric_error[-1] > metric_error[-2] and metric_error[-2] > metric_error[-3]): print("Error increased for 2 consecutive iterations.") print( "{:.04e} -> {:.04e} -> {:.04e}".format( metric_error[-3], metric_error[-2], metric_error[-1] ) ) print("Keeping previous shifts.") return 1 # Relative improvement below threshold — useful for warm-started refinement # where the solution is already close to optimal. if rtol > 0 and len(metric_error) >= 2: rel_improv = (metric_error[-2] - metric_error[-1]) / (abs(metric_error[-2]) + eps) if metric_error[-1] <= metric_error[-2] and rel_improv < rtol: print("Relative improvement {:.2e} < rtol {:.2e}. 
Converged.".format( rel_improv, rtol)) return 4 if np.max(changes) < step: if step >= 1: print("Changes are smaller than one pixel.") else: print("Changes are smaller than {} pixel.".format(step)) return 2 elif count >= maxit: print("Maximum number of iterations reached.") return 3 return 0 def _filter_sino(sinogram, **params): """ Apply a Hanning low-pass filter to a sinogram along the detector axis. Parameters ---------- sinogram : ndarray, shape (nr, nc) Input sinogram (rows = detector pixels, columns = projections). **params Must contain ``'freqcutoff'`` (float in ``(0, 1]``), which sets the half-width of the apodisation window as a fraction of ``nr``. Returns ------- ndarray, shape (nr, nc) Real-valued filtered sinogram. """ N, M = sinogram.shape apod_width = np.int32(0.5 * N * params["freqcutoff"]) filteraux = hanning_apod1D(N, apod_width) filteraux = np.tile(filteraux, (M, 1)).T return np.real(ifft(fft(sinogram) * filteraux)) # ============================================================================ # Gradient descent core — vertical shifts (replaces _search_vshift_direction) # ============================================================================ def _search_vshift_direction( input_array, lims, shift_delta, avg_vert_fluct, pixtol, max_vshift, shift_method="linear", polyorder=2, ): """ Find the optimal vertical shift for one projection using gradient descent with a numerically estimated gradient and an adaptive step size. The cost function is: C(s) = ||f(s) - avg_vert_fluct||² where f(s) is the mass-fluctuation signal after shifting by s pixels. Algorithm --------- At each iteration: 1. Estimate the gradient of C with a central finite difference: g = [C(s+h) - C(s-h)] / (2h) 2. Compute a Newton-like step normalised by the second derivative (curvature), estimated from the same three evaluations: s_new = s - g / |C''(s)| clipped to [-max_step, +max_step] 3. Accept s_new only if C(s_new) < C(s). 
Otherwise halve the step up to 8 times (guarantees descent without an explicit line search). This is equivalent to a damped Newton step on the 1-D cost, which converges quadratically near the minimum and is globally convergent because of the fallback halving. Parameters ---------- input_array : array_like Single projection image. lims : tuple (rows, cols) region-of-interest limits. shift_delta : float Current shift estimate (pixels). avg_vert_fluct : array_like Reference (mean) vertical fluctuation signal. pixtol : float Convergence tolerance (pixels). max_vshift : int Safety margin for border effects. shift_method : str Interpolation method passed to vertical_shift. polyorder : int Polynomial order for bias removal in vertical_shift. Returns ------- current_shift : float Optimised shift (pixels). final_signal : array_like Mass-fluctuation signal evaluated at current_shift. """ def cost_and_sig(s): sig = vertical_shift(input_array, lims, s, max_vshift, shift_method, polyorder) return np.sum((sig - avg_vert_fluct) ** 2), sig h = _FD_H_VERT current_shift = float(shift_delta) current_cost, current_sig = cost_and_sig(current_shift) for _ in range(_GD_MAX_ITER): c_plus, _ = cost_and_sig(current_shift + h) c_minus, _ = cost_and_sig(current_shift - h) # 1st derivative (gradient) and 2nd derivative (curvature) grad = (c_plus - c_minus) / (2.0 * h) curvature = (c_plus - 2.0 * current_cost + c_minus) / (h ** 2) if grad == 0.0: break # Newton step; fall back to steepest descent if curvature ≤ 0 if curvature > 0.0: raw_step = -grad / curvature else: # Steepest descent with unit step in the right direction raw_step = -np.sign(grad) # Clip to avoid wild jumps (max 2 pixels per iteration) step = float(np.clip(raw_step, -2.0, 2.0)) # Halving fallback: ensure the step actually decreases the cost for _ in range(8): candidate = current_shift + step candidate_cost, candidate_sig = cost_and_sig(candidate) if candidate_cost < current_cost: break step *= 0.5 else: # No improvement found 
— already at a local minimum break current_shift = candidate current_cost = candidate_cost current_sig = candidate_sig if abs(step) < pixtol / 10.0: break return current_shift, current_sig # ============================================================================ # Gradient descent core — horizontal shifts (replaces _search_hshift_direction) # ============================================================================ def _search_hshift_direction( sinogram_col, sinogramcomp_col, shift_delta, pixtol, shift_method="linear", S=None, sino_fft_col=None, N_fft_col=None, ): """ Find the optimal horizontal shift for a single sinogram column using iterative gradient descent with Newton steps. Unlike the vertical case, this function iterates to full convergence within a single call. The outer loop in ``_alignprojections_horizontal`` is responsible for updating the synthetic sinogram (expensive FBP) between calls; within one synthetic sinogram the shift minimisation should be fully resolved before returning. Algorithm --------- The cost is C(s) = ||T_s(sino_col) - sinogramcomp_col||². At each iteration, three evaluations at {s-h, s, s+h} give gradient and curvature via central finite differences: grad = [C(s+h) - C(s-h)] / (2h) curvature = [C(s+h) - 2C(s) + C(s-h)] / h² If curvature > 0: Newton step Δs = -grad/curvature (jumps directly to the parabola minimum; handles large offsets in one step when the cost landscape is locally quadratic). If curvature ≤ 0: steepest-descent step of size `h` downhill (safe fallback when far from the minimum or on a flat landscape). The step is always verified: if it does not decrease the cost it is halved up to 6 times. If no improvement is found the loop terminates (already at a local minimum for this sinogramcomp). Parameters ---------- sinogram_col : array_like Experimental sinogram column (one projection angle, unshifted). sinogramcomp_col : array_like Synthetic sinogram column. 
shift_delta : float Current accumulated horizontal shift estimate (pixels). pixtol : float Convergence tolerance (pixels): loop stops when |Δs| < pixtol. shift_method : str Interpolation method for ShiftFunc. S : ShiftFunc or None Pre-instantiated ShiftFunc to avoid repeated construction. If None, a new one is created. sino_fft_col : complex ndarray or None Pre-computed FFT of the padded sinogram column (optimisation E). When provided together with N_fft_col, reused across all Newton steps — avoids O(n log n) pad + FFT per cost evaluation. N_fft_col : ndarray or None Frequency-coordinate array paired with sino_fft_col. Returns ------- current_shift : float Optimised horizontal shift (pixels). final_sino : array_like sinogram_col shifted by current_shift. """ h = _FD_H_HORIZ col_len = len(sinogram_col) # E: fast Fourier-shift path — reuse precomputed FFT if sino_fft_col is not None and N_fft_col is not None: def _shift_col(s): H = np.exp(1j * 2.0 * np.pi * s * N_fft_col) return ifft(sino_fft_col * H).real[:col_len] else: if S is None: S = ShiftFunc(shiftmeth=shift_method) def _shift_col(s): return S(sinogram_col, s) def cost(s): return np.sum((_shift_col(s) - sinogramcomp_col) ** 2) current_shift = float(shift_delta) current_cost = cost(current_shift) for _ in range(_GD_MAX_ITER): c_plus = cost(current_shift + h) c_minus = cost(current_shift - h) grad = (c_plus - c_minus) / (2.0 * h) curvature = (c_plus - 2.0 * current_cost + c_minus) / (h ** 2) if grad == 0.0: break # Newton step if convex, else steepest-descent step of size h if curvature > 0.0: step = float(np.clip(-grad / curvature, -_H_MAX_STEP, _H_MAX_STEP)) else: step = -np.sign(grad) * h # Halving fallback: ensure cost strictly decreases for _ in range(6): candidate = current_shift + step candidate_cost = cost(candidate) if candidate_cost < current_cost: break step *= 0.5 else: # No improvement possible: already at local minimum break current_shift = candidate current_cost = candidate_cost if abs(step) < 
pixtol: break final_sino = _shift_col(current_shift) return current_shift, final_sino # ============================================================================ # Anderson acceleration helper (optimisation B) # ============================================================================ class _AndersonAccelerator: """ Anderson acceleration (Anderson mixing) for fixed-point iterations. At step k the caller provides the current iterate x_k and the fixed-point image g_k = g(x_k). The accelerator keeps a rolling window of the last m+1 pairs and returns a linear combination that minimises the norm of the stacked residuals — typically cutting outer iterations 2-5×. Reference --------- Walker & Ni, "Anderson Acceleration for Fixed-Point Iterations", SIAM J. Numer. Anal. 49(4), 2011. Parameters ---------- m : int History depth. Default 3. """ def __init__(self, m=3): self.m = m self._G = [] # history of g(x_k), flattened self._F = [] # history of f_k = g(x_k) - x_k, flattened self._shape = None def reset(self): """Clear history (call after a rejected step).""" self._G.clear() self._F.clear() def step(self, x_k, g_k): """ Compute the Anderson-mixed next iterate. Parameters ---------- x_k : ndarray Current iterate (before the fixed-point step). g_k : ndarray Fixed-point image g(x_k). Returns ------- x_next : ndarray Anderson-mixed next iterate, same shape as x_k. 
""" self._shape = x_k.shape self._G.append(g_k.ravel().copy()) self._F.append((g_k - x_k).ravel()) # Rolling window: keep at most m+1 entries if len(self._G) > self.m + 1: self._G.pop(0) self._F.pop(0) if len(self._G) < 2: return g_k # no history yet — plain fixed-point step # Stack into (n_x, n_hist) matrices F = np.column_stack(self._F) # residuals G = np.column_stack(self._G) # g(x) images # Unconstrained reformulation of the constrained LS: # min_{alpha} || F[:,-1] + dF @ alpha ||^2 # where dF = F[:,:-1] - F[:,-1:] # x_{k+1} = G[:,-1] + dG @ alpha dF = F[:, :-1] - F[:, -1:] dG = G[:, :-1] - G[:, -1:] alpha, _, _, _ = np.linalg.lstsq(dF, -F[:, -1], rcond=None) x_next = G[:, -1] + dG @ alpha return x_next.reshape(self._shape) # ============================================================================ # Wrappers (unchanged logic; they call the new _search_*_direction above) # ============================================================================ def _search_vshift_stack(input_stack, lims, input_delta, avg_vert_fluct, **kwargs): """ Search optimal vertical shifts for all projections in the stack. Applies :func:`_search_vshift_direction` independently to each projection using gradient descent. Parameters ---------- input_stack : ndarray, shape (n, nr, nc) Stack of projection images. lims : tuple of array_like ``(rows, cols)`` — ROI limits. input_delta : ndarray, shape (2, n) Current shift estimates; only row 0 (vertical) is updated. avg_vert_fluct : ndarray, shape (n_rows_roi,) Reference mean vertical fluctuation signal. **kwargs Must contain ``'pixtol'``, ``'shiftmeth'``, and ``'polyorder'``. Returns ------- output_shiftstack : ndarray, shape (2, n) Updated shift array (row 0 = optimised vertical shifts). vert_fluct_stack : ndarray, shape (n, n_rows_roi) Vertical fluctuation signals at the optimised shifts. 
""" pixtol = kwargs["pixtol"] shift_method = kwargs["shiftmeth"] polyorder = kwargs["polyorder"] rows, cols = lims nprojs, nr, nc = input_stack.shape max_vshift = int(np.ceil(np.max(np.abs(input_delta[0, :])))) + 1 if np.any((rows - max_vshift) < 0) or np.any((rows + max_vshift) > nr): max_vshift = 1 vert_fluct_stack = np.empty((input_stack.shape[0], rows[-1] - rows[0])) output_shiftstack = np.empty_like(input_delta) if not isinstance(input_stack, np.ndarray): input_stack = np.asarray(input_stack).copy() for ii in tqdm(range(nprojs), desc="Searching vertical shifts"): output_shiftstack[0, ii], vert_fluct_stack[ii] = _search_vshift_direction( input_stack[ii], lims, input_delta[0, ii], avg_vert_fluct, pixtol, max_vshift, shift_method, polyorder, ) return output_shiftstack, vert_fluct_stack def _search_hshift_sinogram(sinogram, sinogramcomp, shiftslice, **kwargs): """ Search horizontal shifts for all sinogram columns. Accelerations applied --------------------- A — Columns processed in parallel via ThreadPoolExecutor. E — When shiftmeth=='fourier', the FFT of every sinogram column is pre-computed once per outer iteration and reused across all Newton steps, avoiding a redundant pad + FFT per cost evaluation. Parameters ---------- sinogram : ndarray, shape (nr, nc) sinogramcomp : ndarray, shape (nr, nc) shiftslice : ndarray, shape (1, nc) **kwargs Must contain 'pixtol' and 'shiftmeth'. 
""" pixtol = kwargs["pixtol"] shift_method = kwargs["shiftmeth"] nr, nc = sinogram.shape sino_out = np.empty_like(sinogram) shiftslice_out = np.empty_like(shiftslice) # E: batch-precompute FFT of all sinogram columns (Fourier shift only) _sino_fft = None _N_fft = None if shift_method == "fourier": padw = int(2 ** np.ceil(np.log2(nr))) - nr # pad to next power-of-2 _padded = np.pad(sinogram, ((0, padw), (0, 0)), mode="reflect") _N_fft = fftfreq(nr + padw) # frequency coordinates _sino_fft = fft(_padded, axis=0) # (nr+padw, nc), batch FFT def _process_col(ii): # A: each thread gets its own ShiftFunc (ShiftFunc stores state in self) S_local = ShiftFunc(shiftmeth=shift_method) fft_col = _sino_fft[:, ii] if _sino_fft is not None else None s, col = _search_hshift_direction( sinogram[:, ii], sinogramcomp[:, ii], shiftslice[0, ii], pixtol, shift_method, S=S_local, sino_fft_col=fft_col, N_fft_col=_N_fft, ) return ii, s, col # A: parallel execution over columns n_workers = min(nc, os.cpu_count() or 1) with ThreadPoolExecutor(max_workers=n_workers) as executor: results = list(tqdm( executor.map(_process_col, range(nc)), total=nc, desc="Searching horizontal shifts", )) for ii, s, col in results: shiftslice_out[0, ii] = s sino_out[:, ii] = col return sino_out, shiftslice_out # ============================================================================ # Vertical alignment (outer loop — unchanged) # ============================================================================ def _alignprojections_vertical( input_stack, lims, shiftstack, metric_error, vert_fluct_init, RP, **params ): """ Iterative outer loop for vertical projection alignment. At each iteration the mean vertical fluctuation is recomputed from the currently shifted stack and :func:`_search_vshift_stack` refines the shifts via gradient descent. Stops when the error increases for two consecutive iterations, when shifts converge within ``pixtol``, or when ``params['maxit']`` iterations are reached. 
Parameters ---------- input_stack : ndarray, shape (n, nr, nc) Stack of projection images. lims : tuple of array_like ``(rows, cols)`` ROI limits. shiftstack : ndarray, shape (2, n) Initial shift estimates; modified in-place. metric_error : list of float Running list of error values; a new value is appended each iteration. vert_fluct_init : ndarray, shape (n, n_rows_roi) Vertical fluctuations computed before the first iteration. RP : RegisterPlot Plot helper for live visualisation. **params Algorithm parameters including ``'pixtol'``, ``'maxit'``, ``'shiftmeth'``, ``'polyorder'``, ``'subpixel'``. Returns ------- shiftstack : ndarray, shape (2, n) Optimised shift array. metric_error : list of float Updated error history. """ count = 0 error_reg = np.zeros(vert_fluct_init.shape[0]) while True: count += 1 # Capture iteration prints *and* tqdm stderr output to a buffer so # they can be shown in-place via DisplayHandle.update(). # ExitStack lets us conditionally redirect both stdout and stderr # without duplicating the computation block. 
_buf = io.StringIO() _use_dh = hasattr(RP, '_dh_verbose') and RP._dh_verbose is not None with contextlib.ExitStack() as _stack: if _use_dh: _stack.enter_context(contextlib.redirect_stdout(_buf)) _stack.enter_context(contextlib.redirect_stderr(_buf)) print("\n============================================") print("Iteration {}".format(count)) it0 = time.time() deltaprev = shiftstack.copy() if count == 1: vert_fluct = vert_fluct_init.copy() else: print("Updating the vertical fluctuations") vert_fluct = vertical_fluctuations( input_stack, lims, shiftstack, params["shiftmeth"], polyorder=params["polyorder"] ) vert_fluct_mean = vert_fluct.mean(axis=0) print("Gradient descent search for vertical shifts...") shiftstack_aux, vert_fluct_temp = _search_vshift_stack( input_stack, lims, shiftstack, vert_fluct_mean, **params ) shiftstack[0] = shiftstack_aux[0].copy() shiftstack[0] -= shiftstack_aux[0].mean().round() vert_fluct_mean_temp = vert_fluct_temp.mean(axis=0) print("\nCalculating the error metric") for ii in range(vert_fluct_temp.shape[0]): error_reg[ii] = np.sum(np.abs(vert_fluct_temp[ii] - vert_fluct_mean_temp) ** 2) print("Final error metric for y, E = {:.04e}".format(np.sum(error_reg))) metric_error.append(np.sum(error_reg)) changey = np.abs(deltaprev[0] - shiftstack[0]) print("Estimating the changes in y:") print("Maximum correction in y = {:.02f} pixels".format(np.max(changey))) print("Elapsed time = {} s".format(time.time() - it0)) pixtol = params["pixtol"] if params["subpixel"] else 1 reason = _checkconditions( metric_error, changey, pixtol, count, params["maxit"], params["subpixel"], rtol=params.get("rtol", 0.0), ) if reason == 1: shiftstack = deltaprev.copy() metric_error.pop() # Update verbose text in-place (DisplayHandle.update — no clear_output). if _use_dh: RP._verbose_update(_buf.getvalue()) # Update figure (also uses DisplayHandle.update — no clear_output). 
RP.plotsvertical( input_stack[0], lims, vert_fluct_init, vert_fluct_temp, shiftstack, metric_error, count, max_correction=float(np.max(changey)), ) if reason >= 1: break return shiftstack, metric_error
def alignprojections_vertical(input_stack, shiftstack, **params):
    """
    Vertical alignment of projections using the mass-fluctuation approach.

    The vertical shifts are optimised by :func:`_alignprojections_vertical`,
    which minimises the L2 distance between each projection's vertical
    mass-fluctuation signal and the stack mean.

    Parameters
    ----------
    input_stack : ndarray, shape (n, nr, nc)
        Stack of projection images to be aligned.
    shiftstack : ndarray, shape (2, n)
        Initial shift estimates ``[vertical_shifts, horizontal_shifts]``.
        Modified in-place; on return it holds exactly the values of the
        returned shift array.
    **params
        Algorithm parameters. Required keys:

        pixtol : float
            Convergence tolerance in pixels.
        deltax : int
            Horizontal margin (pixels) to exclude from the ROI.
        limsy : list of int or None
            Explicit ``[row_start, row_end]`` row limits. ``None`` uses the
            full image height.
        shiftmeth : str
            Interpolation method (``'linear'``, ``'fourier'``, ``'spline'``).
        polyorder : int
            Polynomial order for baseline removal in the fluctuation signal.

        Optional keys:

        maxit : int
            Maximum number of outer iterations. Python and NumPy integers are
            accepted; any other value (or a missing key) is replaced by ``10``.
        alignx : bool
            If ``True``, estimate horizontal shifts from the centre-of-mass
            before the vertical loop. Default ``False``.
        subpixel : bool
            Ignored: it is forced to ``True`` (single sub-pixel pass).

    Returns
    -------
    shiftstack : ndarray, shape (2, n)
        Optimised shift array (the same object that was passed in).
    output_stack : ndarray, shape (n, nr, nc)
        Vertically aligned projection stack.
    """
    # Accept NumPy integers (e.g. np.int64) as well as Python ints; the old
    # ``isinstance(..., int)`` check silently reset NumPy ints to 10.
    maxit = params.get("maxit")
    params["maxit"] = int(maxit) if isinstance(maxit, (int, np.integer)) else 10
    params.setdefault("alignx", False)

    limrow, limcol = _selectROI(input_stack.shape, **params)
    lims = (limrow, limcol)

    print("\n============================================")
    print("Vertical Mass fluctuation alignment — Adam gradient descent")
    print("Number of iterations: {}".format(params["maxit"]))

    if params["alignx"]:
        print("Estimating changes in x using center-of-mass:")
        centerx = center_of_mass_stack(input_stack, lims, shiftstack=shiftstack)[0]
        shiftstack[1] = -centerx.round()
        # Keep the horizontal shifts centred around zero (integer mean).
        shiftstack[1] -= shiftstack[1].mean().round()

    vert_fluct_init = vertical_fluctuations(
        input_stack, lims, shiftstack, params["shiftmeth"], polyorder=params["polyorder"]
    )
    avg_init = vert_fluct_init.mean(axis=0)
    shiftstack_init = shiftstack.copy()

    # Per-projection squared L2 distance to the mean fluctuation signal.
    metric_error = []
    error_init = np.array(
        [
            np.sum(np.abs(vert_fluct_init[ii] - avg_init) ** 2)
            for ii in range(vert_fluct_init.shape[0])
        ]
    )
    print("Initial error metric for y, E = {:.02e}".format(np.sum(error_init)))
    metric_error.append(np.sum(error_init))

    plt.ion()
    RP = RegisterPlot(**params)
    RP.plotsvertical(
        input_stack[0],
        lims,
        vert_fluct_init,
        vert_fluct_init,
        shiftstack_init,
        metric_error,
        count=0,
    )

    print("\n================================================")
    print("Vertical alignment (Newton GD, pixtol={})".format(params["pixtol"]))
    print("================================================")

    # A single sub-pixel pass is run; a coarse pixel-precision warm-up is
    # considered redundant.
    params["subpixel"] = True
    shiftstack_opt, metric_error = _alignprojections_vertical(
        input_stack, lims, shiftstack, metric_error, vert_fluct_init, RP, **params
    )
    # When the inner loop rejects its last iteration it rebinds its local
    # ``shiftstack`` to a copy of the previous shifts, leaving the caller's
    # array with the rejected values. Copy the accepted shifts back so the
    # documented in-place contract holds.
    if shiftstack_opt is not shiftstack:
        shiftstack[...] = shiftstack_opt

    print("Computing aligned images")
    output_stack = compute_aligned_stack(
        input_stack, shiftstack.copy(), shift_method=params["shiftmeth"]
    )
    return shiftstack, output_stack
# ============================================================================
# Horizontal alignment (outer loop — unchanged)
# ============================================================================


def _alignprojections_horizontal(
    sinogram, sino_orig, theta, circleROI, shiftslice, metric_error, RP, **params
):
    """
    Iterative outer loop for horizontal projection alignment.

    Before the loop, a reconstruction is computed from ``sinogram``.

    Each iteration then does the following:

    1. Project the current reconstruction into a synthetic sinogram with
       :func:`projector`. The result is differentiated with
       :func:`derivatives_sino` when ``derivatives`` is set and
       ``calc_derivatives`` is not.
    2. Update the horizontal shifts with :func:`_search_hshift_sinogram`.
    3. Realign ``sino_orig`` with the new shifts.
    4. Reconstruct a new slice, clip it and, if ``circle`` is set, mask it.
    5. Append the sinogram error metric to ``metric_error``.
    6. Ask :func:`_checkconditions` whether to stop.

    The loop stops once ``_checkconditions`` returns a value ``>= 1``. A
    value of exactly ``1`` also rejects the last iteration: its shifts are
    reverted to the previous ones and its error entry is popped.

    The original implementation notes list two accelerations. Both are
    implemented inside :func:`_search_hshift_sinogram`, which is not part of
    this function:

    * parallel column processing;
    * pre-computed FFTs of the sinogram columns in Fourier-shift mode.

    Parameters
    ----------
    sinogram : ndarray, shape (nr, nc)
        Current aligned sinogram (used for the initial reconstruction).
    sino_orig : ndarray, shape (nr, nc)
        Low-pass-filtered original sinogram; every realignment shifts this.
    theta : ndarray, shape (nc,)
        Projection angles in radians.
    circleROI : ndarray or int
        Circular mask multiplied into the reconstruction when
        ``params['circle']`` is true.
    shiftslice : ndarray, shape (1, nc)
        Current horizontal shift estimates.
    metric_error : list of float
        Running error history; one value is appended per accepted iteration.
    RP : RegisterPlot or None
        Live-plot helper; ``None`` in silent mode (no plotting and no
        output redirection).
    **params
        Algorithm parameters including ``'pixtol'``, ``'maxit'``,
        ``'shiftmeth'``, ``'subpixel'``, ``'circle'``, ``'cliplow'``,
        ``'cliphigh'``, ``'derivatives'``, ``'calc_derivatives'`` and
        optionally ``'rtol'`` (default ``0.0``). All keys are also forwarded
        to :func:`tomo_recons`, :func:`projector` and
        :func:`_search_hshift_sinogram`.

    Returns
    -------
    shiftslice : ndarray, shape (1, nc)
        Optimised horizontal shift array.
    metric_error : list of float
        Updated error history.
    """
    print("Initializing tomographic slice...")
    t0 = time.time()
    recons = tomo_recons(sinogram, theta=theta, **params)
    # Standard deviation is measured before clipping/masking (for reporting).
    recons_std = recons.std()
    recons = _clipping_tomo(recons, **params)
    if params["circle"]:
        recons = recons * circleROI
    print("Done. Time elapsed: {} s".format(time.time() - t0))
    print("Slice standard deviation = {:0.04e}".format(recons_std))
    count = 0
    while True:
        count += 1
        # Capture iteration prints *and* tqdm stderr to a buffer so they can
        # be shown in-place via DisplayHandle.update() — no clear_output.
        # RP may be None in silent / warm-start mode.
        _buf = io.StringIO()
        # Redirect only when RP exposes a live verbose display handle.
        _use_dh = (RP is not None and hasattr(RP, '_dh_verbose')
                   and RP._dh_verbose is not None)
        with contextlib.ExitStack() as _stack:
            if _use_dh:
                _stack.enter_context(contextlib.redirect_stdout(_buf))
                _stack.enter_context(contextlib.redirect_stderr(_buf))
            print("\nIteration {}".format(count))
            print("-------------------------------------")
            it0 = time.time()
            # Snapshot of the shifts so a rejected iteration can be undone.
            deltaprev = shiftslice.copy()
            print("Computing synthetic sinogram...")
            sinogramcomp = projector(recons, theta, **params)
            if params["derivatives"] and not params["calc_derivatives"]:
                sinogramcomp = derivatives_sino(sinogramcomp, shift_method=params["shiftmeth"])
            print("Gradient descent search for horizontal shifts...")
            # The first return value (sinotempreg) is not used in this loop.
            sinotempreg, shiftslice = _search_hshift_sinogram(
                sino_orig, sinogramcomp, shiftslice, **params
            )
            sinogram = compute_aligned_sino(sino_orig, shiftslice, shift_method=params["shiftmeth"])
            print("Computing tomographic slice...")
            t0 = time.time()
            recons = tomo_recons(sinogram, theta=theta, **params)
            recons_std = recons.std()
            recons = _clipping_tomo(recons, **params)
            if params["circle"]:
                recons = recons * circleROI
            print("Done. Time elapsed: {} s".format(time.time() - t0))
            print("Slice standard deviation = {:0.04e}".format(recons_std))
            # Error of the newly aligned sinogram against the synthetic
            # sinogram projected from the *previous* reconstruction.
            errorxreg = _sino_error_metric(sinogram, sinogramcomp, params)
            sumerrorxreg = errorxreg.sum()
            print("Final error metric for x, E = {:0.04e}".format(sumerrorxreg))
            metric_error.append(sumerrorxreg)
            changex = np.abs(deltaprev - shiftslice)
            # NOTE(review): the ``{:0.02g}`` branch (subpixel=False) switches
            # to scientific notation for corrections >= 100 px; confirm intent.
            strprint = "Maximum correction in x = {:0.02f} pixels" if params["subpixel"] \
                else "Maximum correction in x = {:0.02g} pixels"
            print("Estimating the changes in x:")
            print(strprint.format(np.max(changex)))
            print("Elapsed time in the iteration= {:0.02f} s".format(time.time() - it0))
            # Integer-precision runs use a fixed 1-pixel tolerance.
            pixtol = params["pixtol"] if params["subpixel"] else 1
            reason = _checkconditions(
                metric_error, changex, pixtol, count, params["maxit"], params["subpixel"],
                rtol=params.get("rtol", 0.0),
            )
            if reason == 1:
                # Reject this iteration: restore the previous shifts and drop
                # its error entry. Note that ``recons``/``sinogram`` still hold
                # the rejected iteration's state and are what gets plotted below.
                shiftslice = deltaprev.copy()
                metric_error.pop()
        # Update verbose text in-place (DisplayHandle.update — no clear_output).
        if _use_dh:
            RP._verbose_update(_buf.getvalue())
        # Update figure (also uses DisplayHandle.update — no clear_output).
        if RP is not None:
            RP.plotshorizontal(
                recons, sino_orig, sinogram, sinogramcomp, shiftslice, metric_error, count
            )
        # Any reason >= 1 is a stop signal.
        if reason >= 1:
            break
    return shiftslice, metric_error
def alignprojections_horizontal(sinogram, theta, shiftstack, **params):
    """
    Horizontal alignment of projections by tomographic consistency.

    Shifts are optimised iteratively by :func:`_alignprojections_horizontal`.
    The optimisation minimises the L2 distance between the measured sinogram
    and a synthetic sinogram projected from a reconstruction. Two
    warm-starting options are available:

    * a multi-stage frequency-cutoff schedule;
    * an optional spatial multiresolution warm-start at the first stage.

    Parameters
    ----------
    sinogram : ndarray, shape (nr, nc)
        Measured sinogram (rows = detector pixels, columns = projections).
    theta : ndarray, shape (nc,)
        Projection angles in radians.
    shiftstack : ndarray, shape (2, n)
        Current shift estimates ``[vertical_shifts, horizontal_shifts]``.
        Only the horizontal row (index 1) is updated, in-place.
    **params
        Algorithm parameters. Required keys:

        pixtol : float
            Convergence tolerance in pixels.
        freqcutoff : float
            Fraction of Nyquist used as the sinogram low-pass cut-off (used
            when no ``freqcutoff_schedule`` is given).
        shiftmeth : str
            Interpolation method (``'linear'``, ``'fourier'``, ``'spline'``).
        derivatives : bool
            Whether projections are phase-gradient images.
        calc_derivatives : bool
            Whether to compute derivatives on the synthetic sinogram.

        Optional keys (defaults in brackets):

        maxit : int [10]
            Maximum number of outer iterations per stage. Python and NumPy
            integers are accepted; any other value is replaced by ``10``.
        circle : bool [True]
            Apply a circular mask to the reconstruction.
        cliplow, cliphigh : float or None [None]
            Clipping thresholds for the reconstruction values.
        sinolow, sinohigh : float [-0.6, 0.6]
            Sinogram display limits.
        opencl : bool [False]
        freqcutoff_schedule : list of float [``[freqcutoff]``]
            Sequence of ``freqcutoff`` values (coarsest first). Each stage
            warm-starts the next.
        multiresolution : bool [False]
            Spatial multiresolution warm-start at stage 0.
        mr_factor : int [2]
            Detector-axis down-sampling factor for the warm-start.
        n_coarse_iter : int [5]
            Number of coarse iterations in the warm-start.
        rtol : float [0.0]
            Relative improvement threshold for early stopping.
        silent : bool [False]
            If ``True``, no RegisterPlot is created and the final slice is
            not displayed (needed in sub-process workers).

    Returns
    -------
    shiftstack : ndarray, shape (2, n)
        The input array with optimised horizontal shifts in row 1.
    """
    params.setdefault("circle", True)
    params.setdefault("sinohigh", 0.6)
    params.setdefault("sinolow", -0.6)
    params.setdefault("opencl", False)
    # Accept NumPy integers (e.g. np.int64) as well as Python ints; the old
    # ``isinstance(..., int)`` check silently reset NumPy ints to 10.
    maxit = params.get("maxit")
    params["maxit"] = int(maxit) if isinstance(maxit, (int, np.integer)) else 10
    params.setdefault("cliplow", None)
    params.setdefault("cliphigh", None)
    params.setdefault("rtol", 0.0)
    # silent=True: suppress all matplotlib calls (used for parallel workers).
    params.setdefault("silent", False)
    # Spatial multiresolution: downsample at stage 0 for a cheap coarse warm-start.
    params.setdefault("multiresolution", False)
    params.setdefault("mr_factor", 2)
    params.setdefault("n_coarse_iter", 5)

    # Frequency-cutoff schedule (coarsest first); falls back to a single
    # pass at params["freqcutoff"].
    schedule = params.get("freqcutoff_schedule", None)
    if schedule is None or len(schedule) == 0:
        schedule = [params["freqcutoff"]]
    n_stages = len(schedule)

    print("\nStarting the horizontal alignment (Adam gradient descent)")
    print("=====================================")
    print("Number of iterations per stage: {}".format(params["maxit"]))
    if n_stages > 1:
        print("Frequency-cutoff schedule: {}".format(schedule))
    else:
        print("Using a frequency cutoff of {}".format(schedule[0]))
    if params["multiresolution"]:
        print("Spatial multiresolution warm-start at stage 0 (factor={}, {} iterations)".format(
            params["mr_factor"], params["n_coarse_iter"]))
    if params["rtol"] > 0:
        print("Relative tolerance (rtol) = {}".format(params["rtol"]))
    print("Low limit for tomo values = {}".format(params["cliplow"]))
    print("High limit for tomo values = {}".format(params["cliphigh"]))

    original_sino = sinogram.copy()
    shiftslice = np.expand_dims(shiftstack[1], axis=0)

    if not params["silent"]:
        plt.ion()
        RP = RegisterPlot(**params)
    else:
        RP = None

    # Each stage inherits shiftslice from the previous one.
    for stage_idx, fc in enumerate(schedule):
        is_last = (stage_idx == n_stages - 1)
        params_s = dict(params, freqcutoff=fc, subpixel=True)
        # Let RegisterPlot show the current stage in its title.
        if RP is not None:
            RP.stage_info = (stage_idx + 1, n_stages, fc)
        if n_stages > 1:
            print("\n╔══════════════════════════════════════════════════════╗")
            print("║ Stage {}/{} — freqcutoff = {:<30}║".format(
                stage_idx + 1, n_stages, str(fc) + " "))
            print("╚══════════════════════════════════════════════════════╝")

        # Zero-pad the detector axis; the pad width grows as the cut-off drops.
        padval = int(2 * np.round(1 / fc))
        sinogram = np.pad(
            original_sino, ((padval, padval), (0, 0)), "constant", constant_values=0
        ).copy()
        sino_orig = _filter_sino(sinogram, **params_s)
        if not np.all(shiftslice == 0):
            print("Shifting sinogram.")
            sinogram = compute_aligned_sino(
                sino_orig, shiftslice, shift_method=params_s["shiftmeth"]
            )
        else:
            print("Initializing shiftslice with zeros")

        if stage_idx == 0 and params["multiresolution"]:
            shiftslice, sinogram = _horizontal_spatial_warmstart(
                original_sino, sino_orig, theta, shiftslice, padval, params_s, RP,
                params["mr_factor"], params["n_coarse_iter"],
            )

        print("Computing initial tomographic slice...")
        t0 = time.time()
        recons = tomo_recons(sinogram, theta=theta, **params_s)
        print("Done. Time elapsed: {:.02f} s".format(time.time() - t0))
        print("Slice standard deviation = {:0.04e}".format(recons.std()))
        recons = _clipping_tomo(recons, **params_s)
        circleROI = create_circle(recons) if params_s["circle"] else 1
        recons = recons * circleROI

        print("Computing synthetic sinogram...")
        t0 = time.time()
        sinogramcomp = projector(recons, theta, **params_s)
        if params_s["derivatives"] and not params_s["calc_derivatives"]:
            sinogramcomp = derivatives_sino(sinogramcomp, shift_method=params_s["shiftmeth"])
        print("Done. Time elapsed: {:.02f} s".format(time.time() - t0))

        metric_error = []
        errorinit = _sino_error_metric(sinogram, sinogramcomp, params_s)
        print("Initial error metric, E= {:0.04e}".format(np.sum(errorinit)))
        metric_error.append(np.sum(errorinit))
        if RP is not None:
            RP.plotshorizontal(
                recons, sino_orig, sinogram, sinogramcomp, shiftslice, metric_error, count=0
            )

        print("\n===================================================")
        print("Horizontal alignment (Newton GD, pixtol={})".format(params_s["pixtol"]))
        print("===================================================")
        shiftslice, metric_error = _alignprojections_horizontal(
            sinogram, sino_orig, theta, circleROI, shiftslice, metric_error, RP, **params_s
        )
        if not is_last:
            print("Stage {} converged. Handing shifts to next stage.\n".format(stage_idx + 1))

    shiftstack[1] = shiftslice

    if not params["silent"]:
        print("\nComputing aligned images")
        alignedsinogram = compute_aligned_sino(
            original_sino, shiftslice, shift_method=params["shiftmeth"]
        )
        print("Calculating aligned slice for display")
        _oneslicefordisplay(alignedsinogram, theta, **params)
    return shiftstack


def _horizontal_spatial_warmstart(
    original_sino, sino_orig, theta, shiftslice, padval, params_s, RP, factor, n_coarse_iter
):
    """
    Run a few cheap alignment iterations on a detector-downsampled sinogram.

    The detector axis of ``original_sino`` is averaged in blocks of
    ``factor`` rows. The result is padded and filtered like the full-
    resolution sinogram, then aligned for ``n_coarse_iter`` iterations.

    Returns
    -------
    shiftslice : ndarray, shape (1, nc)
        Coarse shifts scaled back to full resolution.
    sinogram : ndarray
        ``sino_orig`` realigned with those shifts.
    """
    print("\n--- Spatial warm-start (factor={}, {} coarse iterations) ---".format(
        factor, n_coarse_iter))
    nr, nc_sino = original_sino.shape
    # Drop trailing rows that do not fill a complete block of ``factor``.
    nr_trim = (nr // factor) * factor
    sino_pool = original_sino[:nr_trim, :].reshape(
        nr_trim // factor, factor, nc_sino
    ).mean(axis=1)
    sino_pool_padded = np.pad(
        sino_pool, ((padval, padval), (0, 0)), "constant", constant_values=0
    ).copy()
    sino_orig_c = _filter_sino(sino_pool_padded, **params_s)

    # Shifts are expressed in pixels, so they shrink with the resolution.
    shiftslice_c = shiftslice / factor
    if not np.all(shiftslice_c == 0):
        sinogram_c = compute_aligned_sino(
            sino_orig_c, shiftslice_c, shift_method=params_s["shiftmeth"]
        )
    else:
        sinogram_c = sino_orig_c.copy()

    print("Computing initial coarse tomographic slice...")
    t_c = time.time()
    recons_c = tomo_recons(sinogram_c, theta=theta, **params_s)
    print("Done. Time elapsed: {:.02f} s".format(time.time() - t_c))
    recons_c = _clipping_tomo(recons_c, **params_s)
    circleROI_c = create_circle(recons_c) if params_s["circle"] else 1
    recons_c = recons_c * circleROI_c

    sinogramcomp_c = projector(recons_c, theta, **params_s)
    if params_s["derivatives"] and not params_s["calc_derivatives"]:
        sinogramcomp_c = derivatives_sino(
            sinogramcomp_c, shift_method=params_s["shiftmeth"]
        )

    metric_error_c = []
    errorinit_c = _sino_error_metric(sinogram_c, sinogramcomp_c, params_s)
    metric_error_c.append(np.sum(errorinit_c))
    print("Initial coarse error, E= {:0.04e}".format(metric_error_c[0]))

    params_c = dict(params_s, maxit=n_coarse_iter)
    shiftslice_c, _ = _alignprojections_horizontal(
        sinogram_c, sino_orig_c, theta, circleROI_c, shiftslice_c, metric_error_c, RP,
        **params_c
    )

    shiftslice = shiftslice_c * factor
    sinogram = compute_aligned_sino(
        sino_orig, shiftslice, shift_method=params_s["shiftmeth"]
    )
    print("--- Spatial warm-start done. Proceeding at full resolution ---\n")
    return shiftslice, sinogram
# ============================================================================ # Remaining public functions (unchanged) # ============================================================================
def refine_horizontalalignment(input_stack, theta, shiftstack, **params):
    """
    Perform a single horizontal-alignment refinement pass on one slice.

    First, the sinogram of detector row ``params["slicenum"]`` is extracted.
    If ``params["correct_bad"]`` is true, it is cleaned with
    :func:`replace_bad` using ``params["bad_projs"]``.

    The sinogram is then passed to :func:`alignprojections_horizontal`. The
    incoming shifts are treated as a good warm start, so:

    * any ``freqcutoff_schedule`` is dropped;
    * ``multiresolution`` is forced off.

    Only the current ``freqcutoff`` is therefore used.

    Set ``params["rtol"]`` (e.g. ``1e-3``) to stop early once the
    per-iteration improvement becomes negligible. For several passes, or to
    change settings between passes, edit ``params`` and call again.

    Parameters
    ----------
    input_stack : ndarray, shape (n, nr, nc)
        Stack of derivative projections.
    theta : ndarray
        Projection angles.
    shiftstack : ndarray, shape (2, n)
        Current shift array (updated by the horizontal alignment).
    **params
        Alignment parameters forwarded to :func:`alignprojections_horizontal`.

    Returns
    -------
    shiftstack : ndarray
        Result of :func:`alignprojections_horizontal`.
    params : dict
        The keyword-argument dict, with ``correct_bad`` defaulted to
        ``False`` when it was absent.
    """
    params.setdefault("correct_bad", False)

    row = params["slicenum"]
    sinogram = input_stack[:, row, :].T
    if params["correct_bad"]:
        sinogram = replace_bad(sinogram, list_bad=params["bad_projs"], temporary=False)

    # Single-stage, full-resolution pass: the shifts already give a warm start.
    refine_params = {
        key: value for key, value in params.items() if key != "freqcutoff_schedule"
    }
    refine_params["multiresolution"] = False

    print("\n================================================")
    print("Starting the refinement of the alignment")
    print("================================================")

    refined = alignprojections_horizontal(sinogram, theta, shiftstack, **refine_params)
    return refined, params
def oneslicefordisplay(sinogram, theta, **params):
    """
    Reconstruct one tomographic slice from *sinogram* and display it.

    To try other reconstruction settings, change ``params`` (for example
    ``params["freqcutoff"] = 0.5``) and call this function again.
    """
    cutoff_label = params.get("freqcutoff", "?")
    filter_label = params.get("filtertype", "?")
    message = "Reconstructing slice with freqcutoff={}, filtertype='{}' …".format(
        cutoff_label, filter_label
    )
    print(message, flush=True)
    _oneslicefordisplay(sinogram, theta, **params)
def _oneslicefordisplay(sinogram, theta, **params):
    """
    Reconstruct and display a single tomographic slice.

    Parameters
    ----------
    sinogram : ndarray, shape (nr, nc)
        Sinogram to reconstruct.
    theta : ndarray, shape (nc,)
        Projection angles in radians.
    **params
        Must contain ``'circle'`` (bool), ``'cliplow'`` (float or None) and
        ``'cliphigh'`` (float or None). ``'colormap'`` (str) selects the
        display colour map and defaults to ``'bone'`` when absent. All keys
        are forwarded to :func:`tomo_recons` and :func:`_clipping_tomo`.
    """
    p0 = time.time()
    recons = tomo_recons(sinogram, theta=theta, **params)
    recons = _clipping_tomo(recons, **params)
    circleROI = create_circle(recons) if params["circle"] else 1
    recons = recons * circleROI
    print("Done. Time elapsed: {} s".format(time.time() - p0))
    # Honour the documented 'colormap' parameter (previously hard-coded).
    display_slice(
        recons,
        colormap=params.get("colormap", "bone"),
        vmin=params["cliplow"],
        vmax=params["cliphigh"],
    )


def _tc_worker(args):
    """
    Worker for :func:`tomoconsistency_multiple`: align one sinogram slice.

    The function is used in two ways:

    * in a process pool;
    * directly in the parent process, in the sequential fallback.

    All per-slice console output is sent to ``os.devnull``, so only the
    parent's overall progress bar stays visible. The parent builds
    ``params_dict`` with ``silent=True``, so no figures are created.

    Parameters
    ----------
    args : tuple
        ``(slice_index, sinogram_2d, theta, shiftstack_copy, params_dict)``.

    Returns
    -------
    tuple
        ``(slice_index, shift_array)`` where ``shift_array`` is row 1 of the
        aligned shift stack.
    """
    import contextlib
    import multiprocessing
    import os

    ii, sinogram, theta, shiftstack_copy, params_w = args

    # Force the non-interactive Agg backend only inside a pool worker. In the
    # sequential fallback this runs in the parent process, where
    # matplotlib.use("Agg") would switch the user's session backend (closing
    # open figures and preventing the later results figure from showing).
    if multiprocessing.current_process().name != "MainProcess":
        import matplotlib

        matplotlib.use("Agg")

    # contextlib restores stdout/stderr afterwards, so the parent's tqdm bar
    # is not disrupted in the sequential (same-process) case.
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            shiftstack_aux = alignprojections_horizontal(
                sinogram, theta, shiftstack_copy, **params_w
            )
    return ii, shiftstack_aux[1]
def tomoconsistency_multiple(input_stack, theta, shiftstack, **params):
    """
    Tomographic consistency alignment over multiple sinogram slices.

    Each slice is aligned independently, using the same ``shiftstack`` as a
    warm start. The resulting per-slice horizontal shifts are averaged to
    give a robust final estimate.

    Slices can be processed in parallel via
    :class:`~concurrent.futures.ProcessPoolExecutor`. Slice indices that fall
    outside ``[0, nr)`` are skipped with a warning.

    Parameters
    ----------
    input_stack : ndarray, shape (n, nr, nc)
        Full 3-D projection stack.
    theta : ndarray, shape (n,)
        Projection angles in radians.
    shiftstack : ndarray, shape (2, n)
        Current shift estimates used as a warm start for every slice. Row 1
        is updated in-place.
    **params
        Algorithm parameters forwarded to
        :func:`alignprojections_horizontal`. Extra keys:

        slicenum : int
            Central slice index.
        n_slices_tc : int, optional
            Number of slices to process (centred on ``slicenum``).
            Default ``10``.
        n_workers_tc : int, optional
            Number of parallel worker processes. Default
            ``max(1, cpu_count // 2)``. ``1`` runs sequentially.
        use_average : bool, optional
            If ``True`` (default), row 1 of ``shiftstack`` is replaced by the
            averaged shifts. Otherwise the previous shifts are kept.

    Returns
    -------
    shiftstack : ndarray, shape (2, n)
        The input array with the selected horizontal shifts in row 1.

    Raises
    ------
    ValueError
        If no requested slice lies inside the stack.
    """
    print("Starting Tomographic consistency on multiple slices")
    slicenumorig = params["slicenum"]
    n_slices_tc = params.get("n_slices_tc", 10)
    half = n_slices_tc // 2
    slices = np.arange(slicenumorig - half, slicenumorig - half + n_slices_tc)

    # Negative indices would silently wrap to the far end of the stack and
    # indices >= nr would raise IndexError, so keep only existing slices.
    n_rows = input_stack.shape[1]
    in_range = (slices >= 0) & (slices < n_rows)
    if not np.all(in_range):
        print(
            "Warning: skipping out-of-range slice indices {} "
            "(valid range 0..{})".format(slices[~in_range].tolist(), n_rows - 1)
        )
        slices = slices[in_range]
    if slices.size == 0:
        raise ValueError(
            "No valid slices to align: slicenum={}, n_slices_tc={}, "
            "number of rows={}".format(slicenumorig, n_slices_tc, n_rows)
        )

    shiftslice_prev = np.expand_dims(shiftstack[1], axis=0).copy()

    # Single-pass params: strip schedule and spatial MR (warm-start is good).
    # silent=True suppresses matplotlib — required for subprocess workers.
    params_tc = dict(params)
    params_tc.pop("freqcutoff_schedule", None)
    params_tc["multiresolution"] = False
    params_tc["silent"] = True

    n_cpu = os.cpu_count() or 1
    n_workers_tc = params.get("n_workers_tc", max(1, n_cpu // 2))
    # At least one worker, and never more workers than slices.
    n_workers_tc = max(1, min(n_workers_tc, len(slices)))
    print("Slices: {} (n={})".format(slices.tolist(), len(slices)))
    print("Parallel workers: {}".format(n_workers_tc))

    # One argument tuple per slice; only the 2-D sinogram is sent to each
    # worker, not the full input_stack.
    task_args = [
        (
            int(ii),
            np.transpose(input_stack[:, ii, :]),
            theta,
            shiftstack.copy(),
            dict(params_tc, slicenum=int(ii)),
        )
        for ii in slices
    ]

    if n_workers_tc == 1:
        # Sequential fallback — useful for debugging or single-core machines.
        results = [_tc_worker(args) for args in tqdm(task_args, desc="Tomographic consistency")]
    else:
        with ProcessPoolExecutor(max_workers=n_workers_tc) as executor:
            results = list(
                tqdm(
                    executor.map(_tc_worker, task_args),
                    total=len(task_args),
                    desc="Tomographic consistency",
                )
            )

    results.sort(key=lambda r: r[0])
    # Stack explicitly to (n_slices, n): np.squeeze would collapse a single
    # slice to 1-D, turning the average into a scalar and breaking imshow.
    shiftxrefine = np.vstack([np.ravel(r[1]) for r in results])
    shiftxrefine_avg = shiftxrefine.mean(axis=0)

    _plot_tc_results(shiftxrefine, shiftxrefine_avg, shiftslice_prev[0])

    use_average = params.get("use_average", True)
    if use_average:
        shiftstack[1] = shiftxrefine_avg.copy()
        print(
            "Averaged shifts applied.\n"
            "Tip: to keep the previous shifts instead, add\n"
            " params['use_average'] = False\n"
            "to your params dict before calling tomoconsistency_multiple.\n"
            "Note: a backup copy is not made automatically — save\n"
            " shiftstack_backup = shiftstack.copy()\n"
            "before this call if you want to be able to revert.",
            flush=True,
        )
    else:
        shiftstack[1] = shiftslice_prev[0].copy()
        print("Keeping previous shiftstack (use_average=False).", flush=True)
    return shiftstack


def _plot_tc_results(shiftxrefine, shiftxrefine_avg, shift_prev):
    """
    Plot per-slice shifts (image) and the averaged vs previous shifts.

    In a notebook the figure is built with the object-oriented API on an
    explicitly attached Agg canvas and shown as a PNG. This bypasses pyplot's
    display machinery and any canvas manager. Outside a notebook, a regular
    pyplot figure is shown non-blocking.
    """
    plt.close("all")
    in_notebook = isnotebook()
    if in_notebook:
        import matplotlib.figure as _mf_tc
        from matplotlib.backends.backend_agg import FigureCanvasAgg as _FCAgg_tc
        from IPython import display as _disp_tc

        fig = _mf_tc.Figure(figsize=(14, 8))
        _FCAgg_tc(fig)  # attach Agg canvas — no manager needed
        ax1, ax2 = fig.subplots(2, 1)
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8))

    ax1.imshow(shiftxrefine, interpolation="none", cmap="jet")
    ax1.axis("tight")
    ax1.set_xlabel("Projection number")
    ax1.set_ylabel("Slice number")
    ax1.set_title("Displacements in x per slice")

    ax2.plot(shiftxrefine_avg, "b-", label="average")
    ax2.plot(shift_prev, "r--", label="previous")
    ax2.legend()
    ax2.axis("tight")
    ax2.set_xlim([0, len(shiftxrefine_avg)])
    ax2.set_title("Average displacements in x — blue=new average, red=previous")
    ax2.set_xlabel("Projection number")
    fig.tight_layout()

    if in_notebook:
        _buf_tc = io.BytesIO()
        fig.savefig(_buf_tc, format="png", bbox_inches="tight", dpi=100)
        _buf_tc.seek(0)
        _disp_tc.display(_disp_tc.Image(_buf_tc.read()))
    else:
        plt.show(block=False)
class _RotAxisPicker:
    """
    Interactive GUI for estimating the rotation-axis offset.

    Displays the reconstructed slice (left) and sinogram (right) for the
    current offset. A text box lets you try a new integer offset value.
    Clicking **Update** recomputes and refreshes both panels. Clicking
    **Confirm** accepts the current value and closes the figure.

    Terminal usage
    --------------
    Used internally by :func:`estimate_rot_axis` — not meant to be called
    directly.

    Jupyter two-cell workflow
    -------------------------
    ::

        # Cell 1
        picker = estimate_rot_axis(valigndiff, theta, **params)

        # Cell 2 — after clicking Confirm
        params["rot_axis_offset"] = picker.rot_axis_offset

    Attributes
    ----------
    rot_axis_offset
        Offset currently displayed. It is initially
        ``params["rot_axis_offset"]`` and becomes an ``int`` after any
        successful Update.
    fig : matplotlib.figure.Figure
        The picker figure.
    """

    def __init__(self, input_array, theta, **params):
        self._array = input_array
        self._theta = theta
        self._params = params
        self.rot_axis_offset = params["rot_axis_offset"]
        self._done = False
        params.setdefault("sinocmap", params.get("colormap", "bone"))

        # ---- initial reconstruction ----
        slicenum = params["slicenum"]
        sinogram = np.transpose(input_array[:, slicenum, :])
        sinogram = _offset_sinogram(sinogram, self.rot_axis_offset)
        print("Computing initial slice (offset={}) …".format(self.rot_axis_offset), flush=True)
        p0 = time.time()
        tomogram = tomo_recons(sinogram, theta, **params)
        print("Done in {:.1f} s.".format(time.time() - p0), flush=True)

        # ---- figure ----
        plt.close("all")
        fig = plt.figure(num=5, figsize=(12, 5))
        # leave room at the bottom for the TextBox + buttons
        fig.subplots_adjust(bottom=0.22)

        ax1 = fig.add_subplot(121)
        self._im1 = ax1.imshow(
            tomogram,
            cmap=params["colormap"],
            interpolation="none",
            vmin=params["cliplow"],
            vmax=params["cliphigh"],
        )
        self._ax1 = ax1
        ax1.set_title("Slice {} — offset {}".format(slicenum, self.rot_axis_offset))
        fig.colorbar(self._im1, ax=ax1)

        ax2 = fig.add_subplot(122)
        self._im2 = ax2.imshow(
            sinogram,
            cmap=params["sinocmap"],
            interpolation="none",
            vmin=params["sinolow"],
            vmax=params["sinohigh"],
        )
        ax2.axis("tight")
        ax2.set_title("Sinogram — slice {}".format(slicenum))
        fig.colorbar(self._im2, ax=ax2)

        # status text
        self._status = fig.text(
            0.5,
            0.97,
            "Enter a new offset → Update, or Confirm to accept current value.",
            ha="center",
            va="top",
            fontsize=9,
            color="steelblue",
        )

        # TextBox for offset value + Update / Confirm buttons
        ax_box = fig.add_axes([0.25, 0.04, 0.20, 0.07])
        ax_update = fig.add_axes([0.50, 0.04, 0.12, 0.07])
        ax_conf = fig.add_axes([0.65, 0.04, 0.12, 0.07])
        self._textbox = TextBox(ax_box, "Offset: ", initial=str(self.rot_axis_offset))
        self._btn_update = Button(ax_update, "Update")
        self._btn_confirm = Button(ax_conf, "Confirm")
        self._btn_update.on_clicked(self._on_update)
        self._btn_confirm.on_clicked(self._on_confirm)

        self.fig = fig
        print(
            "Inspect the slice. Enter a new offset value in the text box "
            "and click Update to recompute, then Confirm to accept.",
            flush=True,
        )

    # ------------------------------------------------------------------ events

    def _on_update(self, event):
        """Parse the text box, recompute the slice and refresh both panels."""
        try:
            new_offset = int(float(self._textbox.text.strip()))
        except (ValueError, OverflowError):
            # ValueError: not a number or NaN; OverflowError: +/-inf.
            self._status.set_text("Invalid offset — enter an integer.")
            self.fig.canvas.draw_idle()
            return

        self.rot_axis_offset = new_offset
        slicenum = self._params["slicenum"]
        sinogram = np.transpose(self._array[:, slicenum, :])
        sinogram = _offset_sinogram(sinogram, new_offset)

        self._status.set_text("Computing slice (offset={}) …".format(new_offset))
        self.fig.canvas.draw_idle()
        print("Computing slice (offset={}) …".format(new_offset), flush=True)
        p0 = time.time()
        tomogram = tomo_recons(sinogram, self._theta, **self._params)
        print("Done in {:.1f} s.".format(time.time() - p0), flush=True)

        self._im1.set_data(tomogram)
        # autoscale() alone would discard the cliplow/cliphigh limits used for
        # the initial display; re-apply them so that only unspecified (None)
        # limits follow the new data, matching imshow(vmin=..., vmax=...).
        self._im1.autoscale()
        self._im1.set_clim(self._params.get("cliplow"), self._params.get("cliphigh"))
        self._ax1.set_title("Slice {} — offset {}".format(slicenum, new_offset))
        self._im2.set_data(sinogram)
        self._status.set_text(
            "Offset {} — click Confirm to accept or enter another value.".format(new_offset)
        )
        self.fig.canvas.draw_idle()

    def _on_confirm(self, event):
        """Accept the current offset and close the figure."""
        self._done = True
        print(
            "Confirmed rotation-axis offset: {}".format(self.rot_axis_offset),
            flush=True,
        )
        plt.close(self.fig)
def estimate_rot_axis(input_array, theta, **params):
    """
    Interactively pick the rotation-axis offset.

    A :class:`_RotAxisPicker` figure shows the reconstructed slice and the
    sinogram for ``params["rot_axis_offset"]``. Type another integer offset
    and press **Update** to recompute; press **Confirm** once satisfied.
    The angles are shifted so that their minimum is zero before
    reconstruction.

    Parameters
    ----------
    input_array : ndarray, shape (n, nr, nc)
        Stack of derivative projections.
    theta : ndarray
        Projection angles (degrees).
    **params
        Must contain: ``slicenum``, ``rot_axis_offset``, ``colormap``,
        ``cliplow``, ``cliphigh``, ``sinolow``, ``sinohigh``,
        ``filtertype``, ``freqcutoff``, ``circle``, ``algorithm``,
        ``derivatives``, ``calc_derivatives``.

    Returns
    -------
    Terminal mode
        ``rot_axis_offset`` (int) — the confirmed offset, available after
        the blocking window is closed.
    Jupyter mode (two-cell workflow)
        ``picker`` (:class:`_RotAxisPicker`) — read
        ``picker.rot_axis_offset`` in the **next** cell after clicking
        Confirm.
    """
    params.setdefault("sinocmap", params.get("colormap", "bone"))
    picker = _RotAxisPicker(input_array, theta - theta.min(), **params)

    if not isnotebook():
        # Terminal: block until the window is closed, then report.
        plt.show(block=True)
        print(
            "Initial estimate of rotation axis offset: {}".format(
                picker.rot_axis_offset
            )
        )
        return picker.rot_axis_offset

    # Notebook: widgets only work with a non-inline backend. If the backend
    # cannot be queried, treat it as inline (static display).
    try:
        import matplotlib as _mpl

        backend_is_inline = "inline" in _mpl.get_backend().lower()
    except Exception:
        backend_is_inline = True

    if backend_is_inline:
        from IPython import display as _ipy_display

        _ipy_display.display(picker.fig)
    else:
        plt.show(block=False)
        picker.fig.canvas.draw()
        print(
            "\nJupyter two-cell workflow:\n"
            " Interact with the figure, then in the NEXT cell:\n"
            " params['rot_axis_offset'] = picker.rot_axis_offset",
            flush=True,
        )
    return picker
@deprecated
def _offset_sinogram_old(sinogram, offset):
    """
    Pad a sinogram to offset the rotation axis (deprecated).

    Parameters
    ----------
    sinogram : ndarray, shape (nr, nc)
        Input sinogram.
    offset : int
        Rotation-axis offset in pixels. Positive values pad
        ``2 * offset`` zero rows at the bottom; negative values pad
        ``2 * |offset|`` rows at the top. Zero leaves the sinogram unchanged.

    Returns
    -------
    sinogram : ndarray
        Zero-padded sinogram.
    """
    if np.sign(offset) == +1:
        print("Initial guess of the rotation axis offset : {}".format(offset))
        sinogram = np.pad(sinogram, ((0, 2 * abs(offset)), (0, 0)), "constant", constant_values=0)
    elif np.sign(offset) == -1:
        print("Initial guess of the rotation axis offset : {}".format(offset))
        sinogram = np.pad(sinogram, ((2 * abs(offset), 0), (0, 0)), "constant", constant_values=0)
    return sinogram


def _offset_sinogram(sinogram, offset, shift_method="linear"):
    """
    Shift a sinogram along its first (detector) axis to offset the rotation axis.

    Parameters
    ----------
    sinogram : ndarray, shape (nr, nc)
        Input sinogram.
    offset : float
        Rotation-axis offset in pixels, applied to axis 0 (the original
        documentation describes positive values as shifting down).
    shift_method : str, optional
        Interpolation method passed to :class:`ShiftFunc` as ``shiftmeth``.
        Default ``'linear'``. (Previously this argument was ignored and
        ``'linear'`` was always used.)

    Returns
    -------
    ndarray, shape (nr, nc)
        Shifted sinogram.
    """
    S = ShiftFunc(shiftmeth=shift_method)
    return S(sinogram, (offset, 0))