# Source code for toupy.restoration.ring_correction

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Ring artifact correction for tomographic sinograms.

Two methods are provided:

**Wavelet-FFT method (Münch et al., 2009)**
    Decomposes the sinogram with a Haar wavelet along the detector-pixel
    axis, then — for every sub-band including the final approximation —
    estimates the ring contribution from the per-pixel angular mean and
    removes it by median-filter background subtraction.  The approach
    follows the multi-scale stripe filter described by Münch et al. but
    omits its FFT-based DC suppression, which would also remove real
    mean-projection signal; despite the method name, no FFT is performed.

    Reference:
        Münch, B., Trtik, P., Marone, F., & Stampanoni, M. (2009).
        Stripe and ring artifact removal with combined wavelet — Fourier
        filtering. Optics Express, 17(10), 8567–8591.
        https://doi.org/10.1364/OE.17.008567

**Titarenko method (Titarenko et al., 2010)**
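    Lightweight single-scale correction: the per-pixel angular mean is
    median-filtered along the detector axis to estimate the smooth
    background, and the residual is subtracted from every angle.  Works
    well for narrow, isolated rings.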

    Reference:
        Titarenko, V., Titarenko, S., Withers, P. J., De Carlo, F., &
        Xiao, X. (2010).  Applied Physics Letters, 97(7), 073905.
        https://doi.org/10.1063/1.3480961

GPU notes
---------
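All public functions accept ``cuda=True`` and fall back to the CPU with a
warning when CuPy is not installed.  The Haar decomposition and
reconstruction are fully vectorised across projection angles and map
directly onto CuPy array operations; the median filtering that estimates
the smooth background always runs on the CPU (SciPy).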
"""

import warnings
import numpy as np
from scipy.ndimage import gaussian_filter1d, median_filter

# ---------------------------------------------------------------------------
# Optional CuPy import — graceful CPU fallback
# ---------------------------------------------------------------------------
try:
    import cupy as cp
except ImportError:
    # CuPy missing: GPU requests fall back to NumPy (see _get_xp).
    CUDA_AVAILABLE = False
else:
    CUDA_AVAILABLE = True

# Names exported by ``from ... import *``.
__all__ = [
    "remove_rings_wavelet_fft",
    "remove_rings_titarenko",
    "remove_rings_stack",
]


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _get_xp(cuda):
    """Return the array module to compute with: ``cupy`` or ``numpy``.

    Parameters
    ----------
    cuda : bool
        Whether GPU execution was requested.

    Returns
    -------
    module
        ``cupy`` when *cuda* is truthy and CuPy imported successfully,
        otherwise ``numpy``.  If *cuda* is requested but CuPy is missing,
        a warning is emitted before returning ``numpy``.
    """
    if not cuda:
        return np
    if CUDA_AVAILABLE:
        return cp
    # stacklevel=3 attributes the warning to whoever called the public
    # function that invoked this helper.
    warnings.warn(
        "CuPy not available — falling back to CPU ring correction.",
        stacklevel=3,
    )
    return np


def _haar_dec(cH, xp):
    """One level of the (mean/half-difference) Haar transform along axis 0.

    Rows are paired as (0, 1), (2, 3), ...; each pair yields its mean
    (approximation) and half of its difference (detail).  Expects an even
    number of rows.  *xp* is unused and kept for symmetry with
    :func:`_haar_rec`.

    Returns
    -------
    (approx, detail) : tuple of ndarray
        Each has half as many rows as *cH*.
    """
    even_rows = cH[0::2, :]
    odd_rows = cH[1::2, :]
    approx = (even_rows + odd_rows) * 0.5
    detail = (even_rows - odd_rows) * 0.5
    return approx, detail


def _haar_rec(a, d, xp):
    """Invert :func:`_haar_dec` for one level along axis 0.

    Parameters
    ----------
    a, d : ndarray, shape (n, m)
        Approximation and detail coefficients of identical shape.
    xp : numpy or cupy module
        Array module used to allocate the output.

    Returns
    -------
    ndarray, shape (2 * n, m), dtype of *a*
        Even rows hold ``a + d`` and odd rows hold ``a - d``, which
        restores the rows split by :func:`_haar_dec` (up to rounding).
    """
    n_pairs = a.shape[0]
    n_cols = a.shape[1]
    merged = xp.empty((2 * n_pairs, n_cols), dtype=a.dtype)
    merged[0::2, :] = a + d
    merged[1::2, :] = a - d
    return merged


def _correct_band(band, win_pix, xp):
    """Subtract ring stripes from a wavelet sub-band (angular-mean method).

    Ring artifacts are additive offsets that are **constant across all
    projection angles** at a fixed detector pixel.  The stripe is therefore
    estimated from the band's angular-mean profile ``m``::

        m[p]      = mean(band[p, :])                 # over angles
        stripe[p] = m[p] - median_filter(m, win)[p]

    The wide median filter along the *pixel* axis follows the smooth
    object profile but ignores narrow ring spikes, so ``stripe`` isolates
    the rings; it is then subtracted from every angle.

    Because the subtracted quantity is constant along the angle axis, the
    angularly varying part of the signal is left untouched — unlike
    Fourier DC suppression, which leaks into low (but non-zero) angular
    frequencies.  Real features of the angular-mean profile that the
    median baseline cannot follow can, however, be removed with the rings.

    Parameters
    ----------
    band : ndarray, shape (n_pix_band, n_angles)
        Sub-band to correct (a CuPy array when *xp* is CuPy).
    win_pix : int
        Median-filter window in band-pixel units.  Truncated to int,
        raised to at least 3 and rounded up to an odd value so the window
        is centred.  Must be >> ring width but < object scale.
    xp : numpy or cupy module
        Array module of *band*.  The median filter always runs on the host
        (SciPy), so for CuPy the profile is copied to the host and the
        stripe copied back to the device.

    Returns
    -------
    ndarray, same shape as *band*
        New array ``band - stripe[:, None]``; *band* is not modified.
    """
    # Angular mean of each band pixel: object profile plus ring spikes.
    profile = band.mean(axis=1)
    if xp is np:
        profile_host = np.asarray(profile)
    else:
        profile_host = cp.asnumpy(profile)

    # Centred (odd) median window of at least three samples.
    window = max(3, int(win_pix))
    window = window if window % 2 else window + 1
    baseline = median_filter(profile_host, size=window)

    rings = profile_host - baseline
    if xp is not np:
        rings = xp.asarray(rings)
    return band - rings[:, np.newaxis]


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def remove_rings_wavelet_fft(
    sinogram,
    level=3,
    sigma=3,
    wname="haar",
    cuda=False,
):
    """
    Remove ring artifacts using multi-scale wavelet stripe subtraction.

    The sinogram is decomposed with a Haar wavelet along the detector-pixel
    axis (axis 0) for *level* levels.  At each scale — including the final
    approximation band — the per-pixel angular mean is computed, a
    median-filtered copy of it is taken as the real-object baseline, and
    the residual (the ring contribution) is subtracted from every angle.
    The cleaned bands are then reconstructed.  Despite the historical
    ``_fft`` suffix, no Fourier transform is performed.

    Parameters
    ----------
    sinogram : ndarray, shape (n_pixels, n_angles)
        Input sinogram.  axis 0 = detector pixels; axis 1 = angles.
    level : int, optional
        Number of Haar decomposition levels.  A level is skipped once the
        current approximation band has fewer than 4 rows.  Default 3.
    sigma : float, optional
        Expected maximum ring width in **original detector-pixel** units.
        The median-baseline window is
        ``max(int(target / 10), int(sigma) * 10)`` pixels, where
        ``target`` is ``n_pixels`` rounded up to the next power of two
        (the padded length).  The window is scaled proportionally for each
        wavelet band so that the separation of ring spikes from
        real-signal background is consistent across all scales.
        Default 3.
    wname : str, optional
        Wavelet family.  Only ``'haar'`` (equivalently ``'db1'``) is
        supported; any other value emits a warning and Haar is used.
        Default ``'haar'``.
    cuda : bool, optional
        Use GPU via CuPy; falls back to the CPU with a warning when CuPy
        is unavailable.  Default False.

    Returns
    -------
    corrected : ndarray, shape (n_pixels, n_angles)
        Ring-corrected sinogram (CPU numpy array, float64).

    Raises
    ------
    ValueError
        If *sinogram* is not two-dimensional or has no detector pixels.

    Notes
    -----
    The method suppresses **additive** stripes (constant angular offset at
    specific detector pixels).  Multiplicative non-uniformity should be
    removed by flat-field normalisation first.  Because every correction
    is constant along the angle axis, only the angle-independent part of
    the sinogram is modified; narrow real features of the angular-mean
    profile that the median baseline does not follow can be attenuated
    together with the rings.

    Examples
    --------
    >>> import numpy as np
    >>> from toupy.restoration import remove_rings_wavelet_fft
    >>> sino = np.random.rand(363, 180)
    >>> sino_clean = remove_rings_wavelet_fft(sino, level=3, sigma=3)
    """
    xp = _get_xp(cuda)
    use_gpu = cuda and CUDA_AVAILABLE

    if str(wname).lower() not in ("haar", "db1"):
        warnings.warn(
            "Only the 'haar' wavelet is supported; ignoring "
            "wname={!r}.".format(wname),
            stacklevel=2,
        )

    sino_cpu = np.asarray(sinogram, dtype=float)
    if sino_cpu.ndim != 2:
        raise ValueError(
            "sinogram must be 2-D (n_pixels, n_angles), got shape "
            "{}".format(sino_cpu.shape)
        )
    nrow = sino_cpu.shape[0]
    if nrow == 0:
        raise ValueError("sinogram has no detector pixels")

    # Every step below returns new arrays, so *sinogram* is never modified
    # in place and no defensive copy is needed.
    arr = xp.asarray(sino_cpu) if use_gpu else sino_cpu

    # Pad to the next power of two along the pixel axis so every level
    # splits an even number of rows.
    target = int(2 ** np.ceil(np.log2(max(nrow, 2))))
    cH = xp.pad(arr, ((0, target - nrow), (0, 0)), mode="reflect")

    # Forward transform.  None marks a level skipped because the band got
    # too small; once one level is skipped, all later ones are too.
    details = []
    for _ in range(level):
        if cH.shape[0] < 4:
            details.append(None)
            continue
        cH, d = _haar_dec(cH, xp)
        details.append(d)

    # Median window fixed in original-pixel units, rescaled to each
    # (decimated) band length so ring/object separation is scale-consistent.
    win_full = max(int(target / 10), int(sigma) * 10)

    def band_window(n_band):
        return max(3, int(win_full * n_band / target))

    # Correct the final approximation band and every detail band; the
    # correction is angle-constant, so it is safe on all of them.
    cH = _correct_band(cH, band_window(cH.shape[0]), xp)
    cleaned = [
        None if d is None else _correct_band(d, band_window(d.shape[0]), xp)
        for d in details
    ]

    # Inverse transform, coarsest level first.
    result = cH
    for d in reversed(cleaned):
        if d is not None:
            result = _haar_rec(result, d, xp)

    # Drop the reflect padding.
    out = result[:nrow, :]
    return cp.asnumpy(out) if use_gpu else np.asarray(out)
def remove_rings_titarenko(sinogram, size=31, cuda=False):
    """
    Remove ring artifacts using median-filter background subtraction.

    Lightweight single-scale method: compute the per-pixel angular mean,
    median-filter it along the detector axis to estimate the smooth
    background, and subtract the residual (ring) from every angle.  This is
    the same per-band stripe estimate used by the wavelet method, applied
    once to the full-resolution sinogram with a user-chosen window.

    Parameters
    ----------
    sinogram : ndarray, shape (n_pixels, n_angles)
        Input sinogram.  axis 0 = detector pixels; axis 1 = angles.
    size : int, optional
        Median filter window size in pixels.  Truncated to int, raised to
        at least 3 and, if even, rounded up to the next odd value so the
        window is centred on each pixel.  Should be >> ring width but <<
        real-signal variation scale.  Typical range: 21–51.  Default 31.
    cuda : bool, optional
        Use GPU via CuPy; falls back to the CPU with a warning when CuPy
        is unavailable.  The median filter itself always runs on the CPU.
        Default False.

    Returns
    -------
    corrected : ndarray, shape (n_pixels, n_angles)
        Ring-corrected sinogram (CPU numpy array, float64).

    Raises
    ------
    ValueError
        If *sinogram* is not two-dimensional.

    Examples
    --------
    >>> import numpy as np
    >>> from toupy.restoration import remove_rings_titarenko
    >>> sino = np.random.rand(363, 180)
    >>> sino_clean = remove_rings_titarenko(sino, size=31)
    """
    xp = _get_xp(cuda)
    use_gpu = cuda and CUDA_AVAILABLE

    sino_cpu = np.asarray(sinogram, dtype=float)
    if sino_cpu.ndim != 2:
        raise ValueError(
            "sinogram must be 2-D (n_pixels, n_angles), got shape "
            "{}".format(sino_cpu.shape)
        )

    # _correct_band returns a new array, so no defensive copy is needed.
    arr = xp.asarray(sino_cpu) if use_gpu else sino_cpu

    # Shared stripe estimator: it moves the angular mean to the host with
    # cp.asnumpy before SciPy's median filter (np.asarray cannot convert
    # CuPy arrays) and uses a centred odd window.
    out = _correct_band(arr, size, xp)
    return cp.asnumpy(out) if use_gpu else np.asarray(out)
def remove_rings_stack(stack, method="wavelet_fft", cuda=False, **kwargs):
    """
    Apply ring correction to a full 3-D sinogram stack, slice by slice.

    Parameters
    ----------
    stack : ndarray, shape (n_pixels, n_slices, n_angles)
        Sinogram stack; each ``stack[:, s, :]`` is corrected independently.
    method : {'wavelet_fft', 'titarenko'}
        Correction routine: ``remove_rings_wavelet_fft`` or
        ``remove_rings_titarenko``.
    cuda : bool, optional
        Forwarded to the chosen correction function.
    **kwargs : forwarded to the chosen correction function.

    Returns
    -------
    corrected : ndarray, same shape as *stack*
        Floating-point stacks keep their dtype.  Integer/boolean stacks
        are returned as float64 so the (fractional, possibly negative)
        corrected values are not truncated or wrapped.

    Raises
    ------
    ValueError
        If *stack* is not three-dimensional or *method* is unknown.
    """
    if np.ndim(stack) != 3:
        raise ValueError(
            "stack must be 3-D (n_pixels, n_slices, n_angles), got shape "
            "{}".format(np.shape(stack))
        )

    if method == "wavelet_fft":
        correct = remove_rings_wavelet_fft
    elif method == "titarenko":
        correct = remove_rings_titarenko
    else:
        raise ValueError("Unknown method '{}'".format(method))

    # np.empty_like would inherit an integer dtype from raw detector data
    # and silently truncate the float corrections on assignment.
    if np.issubdtype(stack.dtype, np.floating):
        out_dtype = stack.dtype
    else:
        out_dtype = np.float64
    out = np.empty(stack.shape, dtype=out_dtype)

    for s in range(stack.shape[1]):
        out[:, s, :] = correct(stack[:, s, :], cuda=cuda, **kwargs)
    return out