# Source code for toupy.resolution.decorr

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Single-image resolution estimation via decorrelation analysis.
"""

# standard library
import time
import warnings

# third party
import numpy as np
from scipy.fft import fft2, ifft2, fftfreq

# local
from ..utils.plot_utils import plt, isnotebook

# Public API of this module: the only name exported by ``import *``.
__all__ = ["ImageDecorr"]


class ImageDecorr:
    """
    Single-image resolution estimation via decorrelation analysis [1]_.

    For each radial spatial frequency *r* the image is correlated with its
    phase-normalised, ring-filtered self. In the signal-dominated band the
    Fourier phases are coherent and the correlation is high; beyond the
    resolution limit the phases are noise-dominated and the correlation
    drops. The highest spatial frequency at which the normalised correlation
    reaches ``threshold`` is taken as the resolution limit.

    No second image or half-dataset is required — the estimate is obtained
    from a **single** 2-D image (or from a set of 2-D slices sampled from a
    3-D volume).

    Parameters
    ----------
    image : ndarray
        A 2-dimensional array (single image) **or** a 3-dimensional array
        (tomographic volume). When a 3-D volume is supplied the algorithm is
        applied slice-by-slice along ``axis`` and summary statistics are
        reported.
    pixel_size : float, optional
        Physical size of one pixel (in any consistent unit, e.g. nm). Used to
        convert the resolution from pixels to physical units. Default ``1.0``
        (result in pixels).
    n_r : int, optional
        Number of radial frequency bins between the lowest non-zero frequency
        and the Nyquist limit (0.5 cycles/pixel). Must be at least 2.
        Default ``100``.
    threshold : float, optional
        Correlation threshold used to define the resolution limit.
        Default ``0.15``, a fixed threshold of the same order as the 1/7
        criterion often used for FSC/FRC.
    apod_width : int, optional
        Width in pixels of the Hanning taper applied to each image edge
        before computing the FFT. It is clipped to half of the corresponding
        image dimension so the two tapers of an axis never overlap. Set to
        ``0`` to disable apodization. Default ``20``.
    axis : int, optional
        Axis along which slices are taken when ``image`` is 3-D. Default
        ``0`` (first axis, i.e. slices are ``image[i, :, :]``).
    n_slices : int, optional
        Number of evenly-spaced slices to sample from the 3-D volume (or from
        the sub-range defined by ``slice_range``). Must be at least 1 for 3-D
        input; ignored when ``image`` is 2-D. Default ``10``.
    slice_range : tuple of int or None, optional
        ``(start, stop)`` slice indices along ``axis`` that define the
        sub-range from which slices are sampled, following the standard
        Python / NumPy half-open convention (``start`` inclusive, ``stop``
        exclusive). Negative indices are resolved against the axis length
        before use. ``None`` (default) samples the full extent of the volume.

    Attributes
    ----------
    r_values : ndarray
        Radial spatial frequencies (cycles/pixel) at which the correlation
        was evaluated. For 3-D input this is taken from the last slice
        processed.
    A : ndarray
        Normalised ring correlation A(r). For 3-D input: result of the last
        slice processed.
    d : ndarray
        Decorrelation function d(r) = 1 − A(r). For 3-D input: result of the
        last slice processed.
    r_res : float
        Estimated resolution spatial frequency (cycles/pixel). For 3-D input:
        ``1 / resolution_px_median``.
    resolution_px : float
        Estimated resolution in pixels (= 1 / r_res). For 3-D input: median
        over all sampled slices.
    resolution : float
        Estimated resolution in physical units (= pixel_size / r_res). For
        3-D input: median over all sampled slices.
    resolutions_px : ndarray or None
        Per-slice resolution estimates in pixels. ``None`` for 2-D input.
    resolution_px_mean : float or None
        Mean per-slice resolution in pixels. ``None`` for 2-D input.
    resolution_px_median : float or None
        Median per-slice resolution in pixels. ``None`` for 2-D input.
    resolution_px_std : float or None
        Standard deviation of per-slice resolutions in pixels. ``None`` for
        2-D input.

    Raises
    ------
    ValueError
        If ``image`` is not 2-D or 3-D, ``n_r < 2``, ``n_slices < 1`` for a
        3-D volume, or ``slice_range`` is malformed or empty.

    References
    ----------
    .. [1] A. Descloux, K. S. Grußmayer, and A. Radenovic, "Parameter-free
       image resolution estimation based on decorrelation analysis",
       Nature Methods 16, 918-924 (2019).
       https://doi.org/10.1038/s41592-019-0515-7
    """

    def __init__(
        self,
        image,
        pixel_size=1.0,
        n_r=100,
        threshold=0.15,
        apod_width=20,
        axis=0,
        n_slices=10,
        slice_range=None,
    ):
        print("Calling the class ImageDecorr")
        self.image = np.asarray(image, dtype=np.float64)
        if self.image.ndim not in (2, 3):
            raise ValueError("ImageDecorr requires a 2-D or 3-D array.")
        self.pixel_size = float(pixel_size)
        self.n_r = int(n_r)
        if self.n_r < 2:
            # _compute derives the ring width from two consecutive bins.
            raise ValueError("n_r must be at least 2.")
        self.threshold = float(threshold)
        self.apod_width = int(apod_width)
        self.axis = int(axis)
        self.n_slices = int(n_slices)
        if self.image.ndim == 3 and self.n_slices < 1:
            raise ValueError("n_slices must be at least 1 for 3-D input.")
        self.slice_range = slice_range  # validated in _resolve_slice_range

        # Per-slice attributes (populated only for 3-D input)
        self.resolutions_px = None
        self.resolution_px_mean = None
        self.resolution_px_median = None
        self.resolution_px_std = None

        # Warning flag: the tomography warning fires at most once per instance
        self._tomo_warned = False

        p0 = time.time()
        if self.image.ndim == 2:
            self._run_2d(self.image)
        else:
            self._run_3d()
        print(f"Done. Time elapsed: {time.time() - p0:.2f}s")

    # ------------------------------------------------------------------
    # Top-level dispatch
    # ------------------------------------------------------------------

    def _warn_if_tomo(self, img):
        """
        Emit a ``UserWarning`` if *img* looks like tomographic data.

        The check tests whether the fraction of negative-valued pixels
        exceeds 5 %. Tomographic reconstructions (especially filtered
        back-projection) typically contain many negative values, whereas raw
        2-D microscopy images (for which the decorrelation algorithm was
        designed) rarely do. The warning is emitted at most **once per
        instance**; later calls are ignored once it has fired.

        Parameters
        ----------
        img : ndarray
            Array to test (a 2-D slice or a whole 3-D volume).

        Warns
        -----
        UserWarning
            If the negative-pixel fraction exceeds 0.05 and the warning has
            not already been issued for this instance.
        """
        if self._tomo_warned:
            return
        neg_frac = np.count_nonzero(img < 0) / img.size
        if neg_frac > 0.05:
            warnings.warn(
                "ImageDecorr: the supplied image appears to be a tomographic "
                "reconstruction slice (negative-pixel fraction = "
                f"{neg_frac:.1%} > 5 %). "
                "The decorrelation algorithm was designed for raw 2-D "
                "microscopy images; applied to reconstructed slices it gives "
                "unreliable results because correlated noise biases the "
                "ring-correlation toward low spatial frequencies. "
                "For tomographic resolution estimation use FSC/SSNR or "
                "LocalFSC instead.",
                UserWarning,
                stacklevel=3,
            )
            self._tomo_warned = True

    def _run_2d(self, img2d):
        """Run the analysis on a single 2-D image and store the results."""
        self._warn_if_tomo(img2d)
        self.nr, self.nc = img2d.shape
        print(f" Image size : {self.nr} × {self.nc} pixels")
        print(f" Pixel size : {self.pixel_size}")
        print(f" Threshold : {self.threshold}")

        # _apodize reads self.image, so point it at the 2-D slice for the
        # duration of the computation and always restore the original array,
        # even if the computation raises.
        saved = self.image
        self.image = img2d
        try:
            self.r_values, self.A, self.d = self._compute()
            (
                self.r_res,
                self.resolution_px,
                self.resolution,
            ) = self._find_resolution()
        finally:
            self.image = saved

        print(
            f" Resolution : {self.resolution_px:.1f} px "
            f"({self.resolution:.4g} in physical units)"
        )

    def _resolve_slice_range(self, n_ax):
        """
        Resolve ``self.slice_range`` into a half-open ``(start, stop)`` pair.

        Negative indices are counted from the end of the axis, and both ends
        are then clipped to ``[0, n_ax]``.

        Parameters
        ----------
        n_ax : int
            Length of the slicing axis.

        Returns
        -------
        i_start, i_stop : int
            Resolved half-open range, with ``i_start < i_stop``.

        Raises
        ------
        ValueError
            If ``slice_range`` is not a pair of integers, or the resolved
            range is empty.
        """
        if self.slice_range is None:
            return 0, n_ax
        try:
            i_start, i_stop = (int(i) for i in self.slice_range)
        except (TypeError, ValueError):
            raise ValueError(
                "slice_range must be a (start, stop) tuple of integers."
            ) from None
        if i_start < 0:
            i_start += n_ax
        if i_stop < 0:
            i_stop += n_ax
        i_start = min(max(i_start, 0), n_ax)
        i_stop = min(max(i_stop, 0), n_ax)
        if i_start >= i_stop:
            raise ValueError(
                f"slice_range ({i_start}, {i_stop}) is empty after "
                f"resolving against axis length {n_ax}."
            )
        return i_start, i_stop

    def _run_3d(self):
        """
        Apply the 2-D algorithm to evenly-spaced slices along ``self.axis``.

        The candidate slice indices come from ``self.slice_range`` when it is
        provided, otherwise from the full axis. ``self.n_slices`` slices are
        sampled evenly from that range. Summary statistics (mean, median,
        std) are stored as instance attributes. ``self.r_res``,
        ``self.resolution_px`` and ``self.resolution`` are set from the
        median so that the scalar interface stays consistent with 2-D input.
        """
        vol = self.image
        n_ax = vol.shape[self.axis]
        i_start, i_stop = self._resolve_slice_range(n_ax)

        # Evenly-spaced indices within [i_start, i_stop); np.unique removes
        # duplicates produced by rounding on very thin sub-ranges.
        n = min(self.n_slices, i_stop - i_start)
        indices = np.unique(
            np.round(np.linspace(i_start, i_stop - 1, n)).astype(int)
        )

        range_str = (
            f"[{i_start}:{i_stop}]"
            if self.slice_range is not None
            else "full range"
        )
        print(
            f" Volume shape : {vol.shape} — axis={self.axis}, "
            f"{range_str}, sampling {len(indices)} slice(s)"
        )
        print(f" Pixel size : {self.pixel_size}")
        print(f" Threshold : {self.threshold}")

        # Judge the whole volume once, up front. If it warns, the flag set by
        # _warn_if_tomo suppresses the per-slice checks done in _run_2d.
        self._warn_if_tomo(vol)

        res_px_list = []
        for idx in indices:
            sl = np.take(vol, idx, axis=self.axis)  # always 2-D
            self._run_2d(sl)  # sets self.r_values, self.A, self.d, …
            res_px_list.append(self.resolution_px)
            print(
                f" Slice {idx:4d} : {self.resolution_px:.1f} px "
                f"({self.resolution:.4g} in physical units)"
            )

        self.resolutions_px = np.array(res_px_list)
        self.resolution_px_mean = float(np.mean(self.resolutions_px))
        self.resolution_px_median = float(np.median(self.resolutions_px))
        self.resolution_px_std = float(np.std(self.resolutions_px))

        # The scalar summary is the median over slices.
        self.resolution_px = self.resolution_px_median
        self.resolution = self.resolution_px_median * self.pixel_size
        self.r_res = 1.0 / self.resolution_px_median

        print(
            f" Resolution (median) : {self.resolution_px_median:.1f} ± "
            f"{self.resolution_px_std:.1f} px "
            f"(mean {self.resolution_px_mean:.1f} px) "
            f"({self.resolution:.4g} in physical units)"
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _edge_taper(n, width):
        """
        Return a length-*n* window of ones with Hanning-tapered edges.

        The taper width is clipped to ``n // 2`` so that the rising and
        falling halves never overlap on small images.
        """
        w = np.ones(n)
        aw = min(width, n // 2)
        if aw > 0:
            half = np.hanning(2 * aw)
            w[:aw] = half[:aw]  # rising half on the leading edge
            w[-aw:] = half[aw:]  # falling half on the trailing edge
        return w

    def _apodize(self):
        """
        Apply a Hanning edge apodization window and subtract the mean.

        Uses ``self.image``, ``self.nr``, ``self.nc`` and
        ``self.apod_width``.

        Returns
        -------
        img : ndarray
            Zero-mean apodized copy of the image.
        """
        img = self.image.copy()
        if self.apod_width > 0:
            # Separable Tukey-like window: cosine taper of width apod_width
            # on each side of each axis.
            img *= np.outer(
                self._edge_taper(self.nr, self.apod_width),
                self._edge_taper(self.nc, self.apod_width),
            )
        img -= img.mean()
        return img

    def _compute(self):
        """
        Compute the phase-normalised ring correlation A(r).

        For each ring at radius *r* (cycles/pixel):

        1. Divide the FFT by its modulus (keep only the phase):
           ``F_n = F / |F|``.
        2. Apply a ring mask of width ``dr`` centred on *r*.
        3. Back-transform to real space: ``I_n_r = Re(IFFT(F_n * ring))``.
        4. Compute the normalised correlation
           ``mean(img * I_n_r) / (rms(img) * rms(I_n_r))`` between the
           zero-mean apodized image and ``I_n_r``.

        Returns
        -------
        r_values : ndarray
            Radial frequency axis (cycles/pixel).
        A : ndarray
            Normalised ring correlation (0 for empty or null rings, and
            everywhere if the apodized image is identically zero).
        d : ndarray
            Decorrelation function ``d = 1 − A``.
        """
        img = self._apodize()

        # Phase-normalised FFT; `eps` guards the division at zero modulus.
        eps = np.finfo(np.float64).tiny
        F = fft2(img)
        F_n = F / (np.abs(F) + eps)

        # Radial frequency map (cycles/pixel) in the unshifted layout
        # returned by fft2/fftfreq (DC at index [0, 0]).
        FY, FX = np.meshgrid(
            fftfreq(self.nr), fftfreq(self.nc), indexing="ij"
        )
        R = np.sqrt(FX ** 2 + FY ** 2)

        # Frequency axis from the lowest non-zero frequency to Nyquist.
        r_min = 1.0 / min(self.nr, self.nc)
        r_max = 0.5  # Nyquist
        r_values = np.linspace(r_min, r_max, self.n_r)
        dr = r_values[1] - r_values[0]  # n_r >= 2 is validated in __init__

        # Image statistics (the image is already zero-mean from _apodize).
        img_flat = img.ravel()
        img_rms = np.sqrt(np.mean(img_flat ** 2))
        if img_rms < eps:
            return r_values, np.zeros(self.n_r), np.ones(self.n_r)

        A = np.zeros(self.n_r)
        for i, r in enumerate(r_values):
            ring_mask = (R >= r - 0.5 * dr) & (R < r + 0.5 * dr)
            if not np.any(ring_mask):
                continue
            I_n_r = np.real(ifft2(F_n * ring_mask))
            ring_rms = np.sqrt(np.mean(I_n_r ** 2))
            if ring_rms < eps:
                continue
            A[i] = np.mean(img_flat * I_n_r.ravel()) / (img_rms * ring_rms)

        return r_values, A, 1.0 - A

    def _find_resolution(self):
        """
        Find the resolution spatial frequency from the correlation curve.

        The resolution is the **highest** spatial frequency r for which
        A(r) ≥ ``self.threshold``. If the correlation never reaches the
        threshold, no frequency is resolved. In that case the lowest sampled
        frequency (i.e. the coarsest resolution) is returned as a
        conservative (worst-case) estimate and a ``UserWarning`` is emitted.

        Returns
        -------
        r_res : float
            Resolution frequency (cycles/pixel).
        resolution_px : float
            Resolution in pixels (= 1 / r_res).
        resolution : float
            Resolution in physical units (= pixel_size * resolution_px).
        """
        above = self.A >= self.threshold
        if np.any(above):
            r_res = self.r_values[above][-1]
        else:
            # Returning Nyquist here would report the *best* possible
            # resolution (2 px) for an image with no resolvable signal.
            warnings.warn(
                "ImageDecorr: the ring correlation never reaches the "
                f"threshold {self.threshold}; reporting the lowest sampled "
                "frequency (coarsest resolution) as a conservative estimate.",
                UserWarning,
                stacklevel=2,
            )
            r_res = self.r_values[0]
        resolution_px = 1.0 / r_res
        resolution = resolution_px * self.pixel_size
        return r_res, resolution_px, resolution

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def plot(self, filename="ImageDecorr.png"):
        """
        Plot the decorrelation analysis result.

        For a **2-D** input the decorrelation curve A(r) is shown together
        with the threshold line and the resolution estimate marker.

        For a **3-D** input a histogram of the per-slice resolution estimates
        is shown next to the A(r) curve of the last processed slice, so that
        the spread across slices is visible.

        Parameters
        ----------
        filename : str or path-like or None, optional
            File the figure is saved to. Default ``"ImageDecorr.png"``
            (relative to the current working directory). ``None`` skips
            saving.

        Returns
        -------
        r_values : ndarray
            Radial spatial frequencies (cycles/pixel).
        A : ndarray
            Normalised ring correlation A(r).
        d : ndarray
            Decorrelation function d(r) = 1 − A(r).
        resolution_px : float
            Estimated resolution in pixels (median for 3-D input).
        """
        print("Calling method plot from the class ImageDecorr")
        r = self.r_values
        A = self.A
        d = self.d
        fn = r / 0.5  # normalise to [0, 1] (Nyquist = 1)
        in_notebook = isnotebook()
        is3d = self.resolutions_px is not None

        if is3d:
            # Two-panel figure: A(r) curve on the left, histogram on the right
            fig, (ax, ax2) = plt.subplots(
                1, 2, figsize=(14, 5) if in_notebook else None
            )
            ax2.hist(
                self.resolutions_px,
                bins="auto",
                color="steelblue",
                edgecolor="white",
            )
            ax2.axvline(
                self.resolution_px_median,
                color="k",
                linestyle="-",
                label=f"Median = {self.resolution_px_median:.1f} px",
            )
            ax2.axvline(
                self.resolution_px_mean,
                color="r",
                linestyle="--",
                label=f"Mean = {self.resolution_px_mean:.1f} px",
            )
            ax2.set_xlabel("Resolution (pixels)")
            ax2.set_ylabel("Count")
            ax2.set_title(
                f"Per-slice resolution (n={len(self.resolutions_px)}, "
                f"std={self.resolution_px_std:.1f} px)"
            )
            ax2.legend()
            ax2.grid(True, linestyle="--", alpha=0.5)
            title_suffix = " [last slice]"
        else:
            fig = plt.figure(figsize=(8, 6) if in_notebook else None)
            plt.clf()
            ax = fig.add_subplot(111)
            title_suffix = ""

        ax.plot(fn, A, "-b", label="A(r) (ring correlation)")
        ax.plot(fn, d, "-g", label="d(r) = 1 − A(r)")
        ax.axhline(
            self.threshold,
            color="r",
            linestyle="--",
            label=f"Threshold = {self.threshold}",
        )
        ax.axvline(
            self.r_res / 0.5,
            color="k",
            linestyle=":",
            label=f"Resolution ≈ {self.resolution_px:.1f} px",
        )
        ax.legend()
        ax.set_xlim(0, 1)
        ax.set_ylim(-0.1, 1.1)
        ax.set_xlabel("Spatial frequency / Nyquist")
        ax.set_ylabel("Normalised correlation")
        ax.set_title(f"Image Decorrelation Analysis{title_suffix}")
        ax.grid(True, linestyle="--", alpha=0.5)

        fig.tight_layout()
        if filename is not None:
            fig.savefig(filename, bbox_inches="tight")
        if in_notebook:
            from IPython import display

            display.display(fig)
            plt.close(fig)
        else:
            plt.show(block=False)
        return r, A, d, self.resolution_px