# Source code for toupy.resolution.localres

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Single-map local resolution estimation via a ResMap-inspired statistical test.
"""

# standard library
import time
import warnings

# third party
import numpy as np
from scipy.fft import fftn, ifftn, fftfreq
from scipy.ndimage import (
    gaussian_filter,
    binary_fill_holes,
    binary_erosion,
    generate_binary_structure,
)
from scipy.stats import norm as _norm

# local
from ..utils import tqdm
from ..utils.plot_utils import show_resolution_map

__all__ = ["LocalResolution"]


class LocalResolution:
    """
    Single-map local resolution estimation via a ResMap-inspired statistical
    test.

    For each voxel and each spatial-frequency band the local signal energy is
    compared with the noise floor via a z-score derived from the chi-squared
    distribution. The local resolution is the finest spatial frequency at
    which the test passes at significance level ``significance``. A single
    reconstruction is sufficient — no half-datasets are required.

    Three noise estimation strategies are available (see ``noise_method``),
    making the estimator usable even for **local tomography** datasets where
    the sample fills the entire field of view and the volume corners are not
    empty.

    Parameters
    ----------
    vol : ndarray
        A 2-dimensional image or 3-dimensional reconstruction. No half-maps
        are needed. All values must be finite and every axis must have at
        least 3 voxels.
    pixel_size : float, optional
        Physical size of one voxel (any consistent unit, e.g. nm).
        Default ``1.0`` (result in pixels).
    significance : float, optional
        Significance level α for the one-sided z-test. Smaller values require
        stronger evidence of signal and yield more conservative (coarser)
        local resolution estimates. Default ``0.05``.
    n_freq : int, optional
        Number of frequency bands tested between the lowest resolvable
        frequency and the Nyquist limit. Default ``20``.
    window_sigma : float, optional
        Standard deviation (in voxels) of the Gaussian window used to estimate
        the local signal energy. Larger values produce a smoother but
        spatially lower-resolution map. Default ``7.0``.
    mask : ndarray of bool or None, optional
        Binary mask identifying the **sample** region. Voxels outside the mask
        are set to ``NaN`` in the output map. If ``None`` (default) a mask is
        estimated automatically via Gaussian blurring and thresholding.
    noise_method : {'corners', 'highfreq', 'mask'}, optional
        Strategy used to estimate the noise standard deviation:

        ``'corners'`` *(default)*
            Noise is estimated from small cubic sub-volumes at the eight
            (four for 2-D) corners of the reconstruction. Requires the
            corners to be free of sample signal. **Not suitable for local
            tomography.**
        ``'highfreq'``
            Noise is estimated from the high-frequency tail of the power
            spectrum (radial frequencies above
            ``0.5 * (1 - highfreq_fraction)`` cycles/pixel). Works even when
            the sample fills the entire field of view, provided the signal
            power falls well below the noise floor near the Nyquist limit.
            Suitable for **local tomography** and any dataset where empty
            corners are unavailable.
        ``'mask'``
            Noise is estimated from the voxels indicated by ``noise_mask``.
            The user must supply a boolean array marking known empty
            (noise-only) regions anywhere in the volume.
    noise_mask : ndarray of bool or None, optional
        Boolean array (same shape as *vol*) whose ``True`` entries mark voxels
        that contain **only noise** (i.e. no sample signal). Required when
        ``noise_method='mask'``; ignored otherwise.
    sigma_noise : float or None, optional
        Directly supply the noise standard deviation, bypassing all automatic
        estimation. When provided, ``noise_method``, ``noise_mask``, and
        ``corner_fraction`` are all ignored.
    corner_fraction : float, optional
        Fraction of each edge length used to define the corner noise regions.
        Must be in ``(0, 0.5)``. Only used when ``noise_method='corners'``.
        Default ``0.10``.
    highfreq_fraction : float, optional
        Fraction of the frequency range (counting down from Nyquist) used to
        estimate noise power. Must be in ``(0, 0.5)``. Only used when
        ``noise_method='highfreq'``. Default ``0.10``, i.e. every radial
        frequency above 0.45 cycles/pixel. In 2-D/3-D the radial frequency
        exceeds 0.5 along the diagonals of Fourier space; those bins are
        included as well.

    Attributes
    ----------
    resolution_map : ndarray
        Local full-period resolution in pixels (van Heel convention, ``NaN``
        outside the sample mask), same shape as *vol*.
    resolution_map_phys : ndarray
        Local full-period resolution in physical units
        (``resolution_map * pixel_size``).
    resolution_map_half : ndarray
        Local half-period resolution in pixels (Rayleigh convention):
        ``resolution_map / 2``.
    resolution_map_phys_half : ndarray
        Local half-period resolution in physical units:
        ``resolution_map_phys / 2``.
    sigma_noise : float
        Noise standard deviation used for the hypothesis test.
    var_noise : float
        Noise variance, ``sigma_noise ** 2``.
    freq_bands : ndarray
        Centre frequencies (cycles/pixel) of each tested band.
    mask : ndarray of bool
        Binary sample mask (estimated or user-supplied).
    resolution_median : float
        Median full-period local resolution in pixels (within the mask).
    resolution_mean : float
        Mean full-period local resolution in pixels (within the mask).
    resolution_std : float
        Standard deviation of full-period local resolution in pixels (within
        the mask).
    resolution_median_half : float
        Median half-period local resolution in pixels (within the mask).
    resolution_mean_half : float
        Mean half-period local resolution in pixels (within the mask).
    resolution_std_half : float
        Standard deviation of half-period local resolution in pixels (within
        the mask).

    Notes
    -----
    The implementation follows the methodology of ResMap [1]_ but uses a
    Gaussian bandpass filter and a z-score threshold in place of the original
    steerable-filter chi-squared test. For most practical purposes the
    results are equivalent.

    The key advantage over the block-wise :class:`LocalFSC` approach is that
    only a single reconstruction is required.

    For **local (region-of-interest) tomography**, use
    ``noise_method='highfreq'`` or supply ``sigma_noise`` directly:

    .. code-block:: python

        # Local tomography — corners are not empty
        lr = LocalResolution(vol, pixel_size=28.6e-9, noise_method='highfreq')

        # Or if you already know sigma_noise from the sinogram
        lr = LocalResolution(vol, pixel_size=28.6e-9, sigma_noise=0.003)

    References
    ----------
    .. [1] A. Kucukelbir, F. J. Sigworth, and H. D. Tagare, "Quantifying the
       local resolution of cryo-EM density maps", Nature Methods 11, 63-65
       (2014). https://doi.org/10.1038/nmeth.2727
    """

    _VALID_NOISE_METHODS = ("corners", "highfreq", "mask")

    def __init__(
        self,
        vol,
        pixel_size=1.0,
        significance=0.05,
        n_freq=20,
        window_sigma=7.0,
        mask=None,
        noise_method="corners",
        noise_mask=None,
        sigma_noise=None,
        corner_fraction=0.10,
        highfreq_fraction=0.10,
    ):
        """
        Initialise, validate inputs, estimate noise, build mask, and compute.

        Parameters
        ----------
        vol : ndarray
            2-D or 3-D reconstruction volume (finite values, >= 3 voxels per
            axis).
        pixel_size : float, optional
            Physical voxel size. Default ``1.0``.
        significance : float, optional
            One-sided test significance level. Default ``0.05``.
        n_freq : int, optional
            Number of frequency bands. Default ``20``.
        window_sigma : float, optional
            Gaussian smoothing sigma (voxels) for local energy.
            Default ``7.0``.
        mask : ndarray of bool or None, optional
            Pre-computed sample mask. ``None`` triggers auto-masking.
        noise_method : {'corners', 'highfreq', 'mask'}, optional
            Noise estimation strategy. Default ``'corners'``.
        noise_mask : ndarray of bool or None, optional
            Noise-only voxel mask. Required when ``noise_method='mask'``.
        sigma_noise : float or None, optional
            Directly supply noise standard deviation. Overrides
            ``noise_method`` when provided.
        corner_fraction : float, optional
            Fraction of edge used for corner noise regions. Default ``0.10``.
        highfreq_fraction : float, optional
            Fraction of the frequency range (from Nyquist down) used for
            high-frequency noise estimation. Default ``0.10``.

        Raises
        ------
        ValueError
            If inputs are invalid or the noise variance cannot be estimated.
        """
        vol = np.asarray(vol, dtype=np.float64)
        if vol.ndim not in (2, 3):
            raise ValueError(f"vol must be 2-D or 3-D, got {vol.ndim}-D array.")
        # The lowest tested band is 1 / min(shape); with fewer than 3 voxels
        # along an axis it reaches Nyquist (0.5), the band spacing collapses
        # to zero and the Gaussian bandpass divides by zero.
        if min(vol.shape) < 3:
            raise ValueError(
                f"vol must have at least 3 voxels along every axis, "
                f"got shape {vol.shape}."
            )
        # NaN/Inf would propagate through every FFT and silently produce an
        # all-NaN resolution map.
        if not np.all(np.isfinite(vol)):
            raise ValueError("vol contains non-finite values (NaN or Inf).")
        if pixel_size <= 0:
            raise ValueError(f"pixel_size must be positive, got {pixel_size}.")
        if not (0.0 < significance < 1.0):
            raise ValueError(f"significance must be in (0, 1), got {significance}.")
        if n_freq < 2:
            raise ValueError(f"n_freq must be >= 2, got {n_freq}.")
        if window_sigma <= 0:
            raise ValueError(f"window_sigma must be positive, got {window_sigma}.")
        if not (0.0 < corner_fraction < 0.5):
            raise ValueError(
                f"corner_fraction must be in (0, 0.5), got {corner_fraction}."
            )
        if not (0.0 < highfreq_fraction < 0.5):
            raise ValueError(
                f"highfreq_fraction must be in (0, 0.5), got {highfreq_fraction}."
            )
        if noise_method not in self._VALID_NOISE_METHODS:
            raise ValueError(
                f"noise_method must be one of {self._VALID_NOISE_METHODS!r}, "
                f"got {noise_method!r}."
            )

        self.vol = vol
        self.pixel_size = float(pixel_size)
        self.significance = float(significance)
        self.n_freq = int(n_freq)
        self.window_sigma = float(window_sigma)
        self.corner_fraction = float(corner_fraction)
        self.highfreq_fraction = float(highfreq_fraction)
        self.noise_method = noise_method

        # Validate user-supplied sample mask
        if mask is not None:
            mask = np.asarray(mask, dtype=bool)
            if mask.shape != vol.shape:
                raise ValueError(
                    f"mask shape {mask.shape} does not match vol shape {vol.shape}."
                )
            self.mask = mask
        else:
            self.mask = None  # set by _make_mask

        # Validate noise_mask
        if noise_mask is not None:
            noise_mask = np.asarray(noise_mask, dtype=bool)
            if noise_mask.shape != vol.shape:
                raise ValueError(
                    f"noise_mask shape {noise_mask.shape} does not match "
                    f"vol shape {vol.shape}."
                )
        self._noise_mask = noise_mask

        # Step 1: noise estimation
        if sigma_noise is not None:
            # User-supplied directly — skip all estimation
            sigma_noise = float(sigma_noise)
            if sigma_noise <= 0:
                raise ValueError(
                    f"sigma_noise must be positive, got {sigma_noise}."
                )
            self.sigma_noise = sigma_noise
            self.var_noise = sigma_noise ** 2
            print(f"[LocalResolution] Using supplied sigma_noise = {self.sigma_noise:.6g}")
            # No noise sample available: _make_mask falls back to mean(|vol|)
            self._corner_vals = None
        else:
            self._estimate_noise()

        # Step 2: auto-mask if none supplied
        if self.mask is None:
            self._make_mask()

        # Step 3–6: main computation
        self._compute()

        ps = self.pixel_size
        print(
            f"[LocalResolution] Median resolution (full period) = "
            f"{self.resolution_median:.3f} px "
            f"({self.resolution_median * ps:.4g} physical units)"
        )
        print(
            f"[LocalResolution] Median resolution (half period) = "
            f"{self.resolution_median_half:.3f} px "
            f"({self.resolution_median_half * ps:.4g} physical units)"
        )
        print(
            f"[LocalResolution] Mean = {self.resolution_mean:.3f} px "
            f"| Std = {self.resolution_std:.3f} px (full period)"
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _extract_corners(self):
        """
        Extract and pool corner voxels; return a 1-D flat array of values.

        Each corner is a square (2-D) or cube (3-D) of side
        ``max(1, round(corner_fraction * min(vol.shape)))`` voxels; there are
        four corners in 2-D and eight in 3-D.
        """
        vol = self.vol
        ndim = vol.ndim
        cs = max(1, round(self.corner_fraction * min(vol.shape)))

        if ndim == 2:
            ny, nx = vol.shape
            corners = [
                vol[:cs, :cs],
                vol[:cs, nx - cs:],
                vol[ny - cs:, :cs],
                vol[ny - cs:, nx - cs:],
            ]
        else:  # ndim == 3
            nz, ny, nx = vol.shape
            corners = [
                vol[:cs, :cs, :cs],
                vol[:cs, :cs, nx - cs:],
                vol[:cs, ny - cs:, :cs],
                vol[:cs, ny - cs:, nx - cs:],
                vol[nz - cs:, :cs, :cs],
                vol[nz - cs:, :cs, nx - cs:],
                vol[nz - cs:, ny - cs:, :cs],
                vol[nz - cs:, ny - cs:, nx - cs:],
            ]

        return np.concatenate([c.ravel() for c in corners])

    def _estimate_noise(self):
        """
        Dispatch noise estimation to the appropriate strategy.

        Calls one of :meth:`_noise_from_corners`, :meth:`_noise_from_highfreq`,
        or :meth:`_noise_from_mask` according to ``self.noise_method``, then
        sets ``self.sigma_noise`` and ``self.var_noise``.
        """
        if self.noise_method == "corners":
            self._noise_from_corners()
        elif self.noise_method == "highfreq":
            self._noise_from_highfreq()
        elif self.noise_method == "mask":
            self._noise_from_mask()

    def _noise_from_corners(self):
        """
        Estimate noise from the corner sub-volumes of the reconstruction.

        Pools voxels from the 8 (3-D) or 4 (2-D) corners — each of side
        ``round(corner_fraction * min(vol.shape))`` voxels — and computes
        their standard deviation.

        Raises
        ------
        ValueError
            If the corner variance is zero, which typically means the corners
            contain sample material (local tomography) rather than empty
            background. Switch to ``noise_method='highfreq'`` or provide
            ``sigma_noise`` directly in that case.

        Warns
        -----
        UserWarning
            If the absolute mean of the corner values exceeds
            ``2 * sigma_noise``, suggesting the corners may not be empty
            background (soft indicator of local tomography).
        """
        corner_vals = self._extract_corners()
        self._corner_vals = corner_vals  # cache for _make_mask

        sigma = float(np.std(corner_vals))
        var = sigma ** 2

        if var == 0.0:
            raise ValueError(
                "Corner regions have zero variance — the sample likely fills "
                "the entire field of view (local tomography) or the corners "
                "are identically zero. Use noise_method='highfreq' or supply "
                "sigma_noise directly."
            )

        # Soft check: if |mean| > 2σ the corners may contain signal
        corner_mean = float(np.abs(np.mean(corner_vals)))
        if corner_mean > 2.0 * sigma:
            warnings.warn(
                f"[LocalResolution] Corner mean ({corner_mean:.4g}) is more than "
                f"2× the corner std ({sigma:.4g}), suggesting the corners may "
                "contain sample signal rather than empty background. "
                "This is common in local (region-of-interest) tomography. "
                "Consider using noise_method='highfreq' or supplying sigma_noise "
                "directly to avoid a biased noise estimate.",
                UserWarning,
                # user -> __init__ -> _estimate_noise -> here
                stacklevel=4,
            )

        self.sigma_noise = sigma
        self.var_noise = var
        print(
            f"[LocalResolution] Noise estimated from corners: "
            f"sigma = {self.sigma_noise:.6g}"
        )

    def _noise_from_highfreq(self):
        """
        Estimate noise from the high-frequency tail of the power spectrum.

        For a well-sampled reconstruction the signal power rolls off well
        before the Nyquist limit, so the power spectrum near Nyquist is
        dominated by noise. Assuming white noise, the noise variance is:

        .. math::

            \\sigma^2_{\\text{noise}} =
            \\frac{\\langle |\\hat{V}(\\mathbf{k})|^2 \\rangle_{R > f_{\\text{thr}}}}{N}

        where :math:`N` is the total number of voxels and
        :math:`f_{\\text{thr}}` is the lower bound of the high-frequency band
        (``0.5 * (1 - highfreq_fraction)``). The band includes every bin with
        radial frequency above that bound, including the Fourier-space
        corners beyond 0.5 cycles/pixel in 2-D/3-D.

        This method is suitable for **local tomography** because it does not
        require empty corners.

        Raises
        ------
        ValueError
            If no frequency bins lie above the threshold, or the
            high-frequency power is zero.
        """
        R = self._build_radial_grid()
        F = fftn(self.vol)
        N = self.vol.size

        f_thr = 0.5 * (1.0 - self.highfreq_fraction)
        hf_mask = R > f_thr
        n_hf = int(np.sum(hf_mask))

        if n_hf == 0:
            raise ValueError(
                f"No frequency bins found above {f_thr:.3f} cycles/pixel. "
                "Decrease highfreq_fraction or check the volume shape."
            )

        # scipy.fft.fftn is unnormalised, so E|F(k)|^2 = N * sigma^2 for
        # white noise of variance sigma^2.
        mean_power = float(np.mean(np.abs(F[hf_mask]) ** 2))
        var = mean_power / N
        sigma = float(np.sqrt(var))

        if var == 0.0:
            raise ValueError(
                "High-frequency power is zero — the volume may be uniformly "
                "zero. Check the input data."
            )

        self._corner_vals = None  # not available for auto-mask fallback
        self.sigma_noise = sigma
        self.var_noise = var
        print(
            f"[LocalResolution] Noise estimated from high frequencies "
            f"(R > {f_thr:.3f} cycles/px, {n_hf} bins): "
            f"sigma = {self.sigma_noise:.6g}"
        )

    def _noise_from_mask(self):
        """
        Estimate noise from a user-supplied noise-only mask.

        Uses the voxels where ``noise_mask`` is ``True`` to compute the noise
        standard deviation.

        Raises
        ------
        ValueError
            If ``noise_mask`` was not provided, marks no voxels, or selects a
            zero-variance region.
        """
        if self._noise_mask is None:
            raise ValueError(
                "noise_method='mask' requires a noise_mask array to be "
                "supplied."
            )
        vals = self.vol[self._noise_mask]
        if vals.size == 0:
            raise ValueError(
                "noise_mask contains no True entries — nothing to estimate "
                "noise from."
            )

        sigma = float(np.std(vals))
        var = sigma ** 2

        if var == 0.0:
            raise ValueError(
                "noise_mask region has zero variance. Make sure the mask "
                "selects voxels that contain background noise, not a flat "
                "signal region."
            )

        self._corner_vals = vals  # reuse for auto-mask threshold
        self.sigma_noise = sigma
        self.var_noise = var
        print(
            f"[LocalResolution] Noise estimated from noise_mask "
            f"({vals.size} voxels): sigma = {self.sigma_noise:.6g}"
        )

    def _make_mask(self):
        """
        Auto-generate a binary sample mask from the volume; set ``self.mask``.

        ``|vol|`` is smoothed with a Gaussian of sigma ``window_sigma / 2`` and
        thresholded at ``baseline + 2 * sigma_noise``. Holes are then filled
        and the result is eroded twice with a full-connectivity structuring
        element. ``baseline`` is the mean *magnitude* of the noise sample
        (corner or ``noise_mask`` voxels) when one is available, otherwise the
        mean magnitude of the whole volume.

        Warns
        -----
        RuntimeWarning
            If the eroded mask is empty; an all-``True`` mask is used instead.
        """
        vol = self.vol
        sigma_noise = self.sigma_noise

        # Smooth |vol| then threshold
        smoothed = gaussian_filter(np.abs(vol), sigma=self.window_sigma / 2.0)

        # The baseline must be on the same magnitude scale as ``smoothed``
        # (which is built from |vol|), so take the mean *absolute* value of
        # the noise sample. A signed mean would put the threshold below the
        # background level whenever the background has a negative offset and
        # select the entire volume. The fallback (highfreq / supplied sigma)
        # already uses mean(|vol|).
        if self._corner_vals is not None and self._corner_vals.size > 0:
            baseline = float(np.mean(np.abs(self._corner_vals)))
        else:
            baseline = float(np.mean(np.abs(vol)))
        threshold = baseline + 2.0 * sigma_noise

        binary = smoothed > threshold

        # Fill holes
        filled = binary_fill_holes(binary)

        # Erode twice with a full-connectivity (3x3 / 3x3x3) structuring
        # element to trim the blurred boundary.
        ndim = vol.ndim
        struct = generate_binary_structure(ndim, ndim)
        eroded = binary_erosion(filled, structure=struct, iterations=2)

        if not np.any(eroded):
            warnings.warn(
                "Auto-mask is empty after erosion. "
                "Falling back to an all-True mask. "
                "Consider decreasing `window_sigma` or supplying an explicit mask.",
                RuntimeWarning,
                # user -> __init__ -> here
                stacklevel=3,
            )
            eroded = np.ones(vol.shape, dtype=bool)

        self.mask = eroded

    def _build_radial_grid(self):
        """
        Build a radial frequency grid (cycles/pixel); return array same shape as vol.

        Returns
        -------
        R : ndarray
            Radial frequency magnitude at each voxel position, in
            cycles/pixel, laid out in unshifted FFT order (``fftfreq``).
        """
        vol = self.vol
        grids = [fftfreq(s) for s in vol.shape]
        R2 = np.zeros(vol.shape, dtype=np.float64)
        for i, g in enumerate(grids):
            # Broadcast g along axis i
            shape = [1] * vol.ndim
            shape[i] = len(g)
            R2 += g.reshape(shape) ** 2
        return np.sqrt(R2)

    def _compute(self):
        """
        Run the frequency-band loop; set resolution map and summary statistics.

        Sets
        ----
        resolution_map : ndarray
            Local resolution in pixels (NaN outside mask).
        resolution_map_phys : ndarray
            Local resolution in physical units.
        freq_bands : ndarray
            Tested centre frequencies (cycles/pixel).
        resolution_median, resolution_mean, resolution_std : float
            Summary statistics over masked voxels.
        """
        vol = self.vol
        ndim = vol.ndim
        eps = np.finfo(np.float64).tiny

        # Frequency bands
        f_min = 1.0 / min(vol.shape)
        f_max = 0.50
        freq_bands = np.linspace(f_min, f_max, self.n_freq)
        sigma_f = (f_max - f_min) / max(self.n_freq - 1, 1) * 0.6
        self.freq_bands = freq_bands

        # z threshold (one-sided)
        z_alpha = float(_norm.ppf(1.0 - self.significance))

        # Precompute
        R = self._build_radial_grid()
        VOL_F = fftn(vol)  # FFT of the full volume (computed once)

        # Resolution map initialised to NaN
        resolution_map = np.full(vol.shape, np.nan, dtype=np.float64)

        # n_eff: effective independent samples in the Gaussian window,
        # 1 / sum(w^2) for a normalised Gaussian w of std window_sigma.
        n_eff = (2.0 * np.sqrt(np.pi) * self.window_sigma) ** ndim

        # Iterate from low to high frequency so the final assignment is the
        # finest (highest-frequency) band where the test passes.
        for f_i in tqdm(freq_bands, desc="LocalResolution bands"):
            # Gaussian bandpass filter in Fourier space
            H_f = np.exp(-0.5 * ((R - f_i) / sigma_f) ** 2)

            # Bandpass-filtered volume (real part)
            vol_f = np.real(ifftn(VOL_F * H_f))

            # Expected noise power fraction captured by this band (Parseval)
            bp_energy = float(np.mean(H_f ** 2))

            # Local mean-squared signal via Gaussian smoothing
            E_f = gaussian_filter(vol_f ** 2, sigma=self.window_sigma)

            # Null-hypothesis statistics (scaled chi-squared, n_eff d.o.f.)
            mu0 = self.var_noise * bp_energy
            std0 = mu0 * np.sqrt(2.0 / n_eff)

            # Z-score
            z = (E_f - mu0) / (std0 + eps)

            # Decision: passes test AND inside mask
            decision = (z > z_alpha) & self.mask

            # Overwrite resolution map where test passes (resolution = 1/f_i px)
            resolution_map[decision] = 1.0 / f_i

        self.resolution_map = resolution_map
        self.resolution_map_phys = resolution_map * self.pixel_size
        self.resolution_map_half = resolution_map / 2.0
        self.resolution_map_phys_half = self.resolution_map_phys / 2.0

        # Summary statistics (masked voxels only, ignoring NaN)
        valid = resolution_map[self.mask]
        valid = valid[~np.isnan(valid)]
        if valid.size > 0:
            self.resolution_median = float(np.median(valid))
            self.resolution_mean = float(np.mean(valid))
            self.resolution_std = float(np.std(valid))
        else:
            self.resolution_median = float("nan")
            self.resolution_mean = float("nan")
            self.resolution_std = float("nan")

        self.resolution_median_half = self.resolution_median / 2.0
        self.resolution_mean_half = self.resolution_mean / 2.0
        self.resolution_std_half = self.resolution_std / 2.0

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def plot(self, slice_idx=None, axis=0, cmap="viridis_r", vmin=None, vmax=None):
        """
        Visualise the local resolution map and save to ``LocalResolution_map.png``.

        For a 3-D volume a single slice is shown; for a 2-D image the full map
        is displayed.

        Parameters
        ----------
        slice_idx : int or None, optional
            Index of the slice to display along *axis* (3-D only). Defaults to
            the central slice when ``None``.
        axis : int, optional
            Axis along which to take the slice (3-D only). Default ``0``.
        cmap : str, optional
            Matplotlib colormap name. Default ``'viridis_r'`` so that higher
            resolution (smaller value) appears bright.
        vmin : float or None, optional
            Lower bound for the colormap. Defaults to the map minimum.
        vmax : float or None, optional
            Upper bound for the colormap. Defaults to the map maximum.

        Returns
        -------
        resolution_map : ndarray
            The ``resolution_map`` attribute (unchanged).
        """
        show_resolution_map(
            self.resolution_map,
            self.vol.ndim,
            title="LocalResolution map",
            filename="LocalResolution_map.png",
            slice_idx=slice_idx,
            axis=axis,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
        return self.resolution_map