Source code for toupy.restoration.unwraptools

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# standard packages
import contextlib
import heapq
import os
import sys
from collections import deque

# third party packages
from concurrent.futures import ProcessPoolExecutor
import functools
from matplotlib.widgets import Button
from ..utils.plot_utils import plt
import multiprocessing
import numpy as np
from scipy.fft import dctn, idctn
from ..utils import tqdm

# optional: SNAPHU (pip install snaphu)
try:
    import snaphu as _snaphu
    _SNAPHU_AVAILABLE = True
except ImportError:
    _SNAPHU_AVAILABLE = False

# local packages
from ..utils.plot_utils import _plotdelimiters
from ..utils import isnotebook

# Public API exported by ``from unwraptools import *``.
# NOTE(review): "unwrap_phase_2d" and "unwrapping_phase" are not defined in
# the portion of the file reviewed here; presumably they follow the
# "Public dispatcher" section further down — confirm.
__all__ = [
    "wraptopi",
    "wrap",
    "distance",
    "get_charge",
    "phaseresidues",
    "phaseresiduesStack",
    "phaseresiduesStack_parallel",
    "chooseregiontounwrap",
    "unwrap_phase_2d",
    "unwrapping_phase"
]


def wraptopi(phase, endpoint=True):
    """
    Wrap a scalar value or an entire array into a 2*pi-wide interval.

    Parameters
    ----------
    phase : float or array_like
        The value or signal to be wrapped (radians).
    endpoint : bool, optional
        Selects which end of the interval is closed. With
        ``endpoint=True`` (the default) the result lies in (-pi, pi];
        with ``endpoint=False`` it lies in [-pi, pi).

    Returns
    -------
    float or array
        Wrapped value or array.

    Examples
    --------
    >>> import numpy as np
    >>> wraptopi(np.linspace(-np.pi,np.pi,7),endpoint=True)
    array([ 3.14159265, -2.0943951 , -1.04719755, -0.        ,  1.04719755,
            2.0943951 ,  3.14159265])
    >>> wraptopi(np.linspace(-np.pi,np.pi,7),endpoint=False)
    array([-3.14159265, -2.0943951 , -1.04719755,  0.        ,  1.04719755,
            2.0943951 , -3.14159265])
    """
    period = 2.0 * np.pi
    if endpoint:
        # (-pi, pi]: mirror the input, wrap into [-pi, pi), mirror back.
        return -((np.pi - phase) % period - np.pi)
    # [-pi, pi): shift by pi, reduce modulo 2*pi, shift back.
    return (phase + np.pi) % period - np.pi
def wrap(phase):
    """
    Wrap a scalar value or an entire array into the interval (-0.5, 0.5].

    The input is expressed in units where one period equals 1. It is shifted
    by whatever integer number of periods brings it into (-0.5, 0.5], so
    values several periods away are wrapped too. ``-0.5`` maps to ``0.5``.

    Parameters
    ----------
    phase : float or array_like
        The value or signal to be wrapped. Lists and tuples are accepted.

    Returns
    -------
    float or ndarray
        Wrapped value or array. Array-like input is never modified in place;
        a new array is returned.

    Notes
    -----
    Based on ``wrap`` from PyMRR by Sebastian Theilenberg, available at
    https://github.com/theilen/PyMRR.git
    """
    if hasattr(phase, "__len__"):
        # Converting also makes list/tuple input work; the arithmetic below
        # always produces a new array, so the caller's data is untouched.
        phase = np.asarray(phase)
    # phase - ceil(phase - 0.5) lies in (-0.5, 0.5]. For inputs within one
    # period of the interval it equals the classic single +/-1 shift.
    return phase - np.ceil(phase - 0.5)
def distance(pixel1, pixel2):
    """
    Compute the Euclidean distance between two pixel coordinate vectors.

    Parameters
    ----------
    pixel1 : array_like
        Coordinates of the first pixel, e.g. ``(row, col)``.
    pixel2 : array_like
        Coordinates of the second pixel, e.g. ``(row, col)``.

    Returns
    -------
    float
        Euclidean distance between ``pixel1`` and ``pixel2``.

    Examples
    --------
    >>> distance(np.arange(1,10),np.arange(2,11))
    3.0
    """
    # np.asarray is a no-op for arrays, so lists/tuples and ndarrays can be
    # mixed freely.
    offset = np.asarray(pixel1) - np.asarray(pixel2)
    return np.sqrt(np.sum(offset ** 2))
def _get_charge(residues): """ Auxiliary function to get the residues charges Parameters ---------- residues : ndarray A 2-dimensional array containing the with residues Returns ------- posres : array_like Positions of the residues with positive charge negres : array_like Positions of the residues with negative charge """ posres = np.where(np.round(residues) == 1) respos = len(posres[0]) negres = np.where(np.round(residues) == -1) resneg = len(negres[0]) nres = respos + resneg return posres, negres, nres
def get_charge(residues):
    """
    Locate the positively and negatively charged residues and report the total.

    The total count is printed without a trailing newline.

    Parameters
    ----------
    residues : ndarray
        A 2-dimensional array containing the residue map.

    Returns
    -------
    posres : tuple of ndarray
        Positions of the residues with positive charge.
    negres : tuple of ndarray
        Positions of the residues with negative charge.
    """
    positive, negative, total = _get_charge(residues)
    print("Found {:>3.0f} residues".format(total), end="")
    return positive, negative
def phaseresidues(phimage):
    """
    Calculates the phase residues [1]_ for a given wrapped phase image.

    Parameters
    ----------
    phimage : ndarray
        A 2-dimensional array containing the phase-contrast images with
        gray-level in radians

    Returns
    -------
    residues : ndarray
        Residue map of shape ``(nr - 2, nc - 2)``. Each element is the sum of
        the four wrapped phase differences around a 2x2 loop divided by
        ``2*pi``. It is nominally integer-valued: +1 / -1 at residues and 0
        elsewhere. Element ``[i, j]`` refers to the loop whose top-left
        pixel is ``phimage[i + 1, j + 1]``, so add 1 to the indices to get
        image coordinates.
    residues_charge : dict
        ``{"pos": posres, "neg": negres}``. Each value is the ``np.where``
        index tuple (in ``residues`` coordinates) of the residues that round
        to +1 and -1 respectively.
    nres : int
        Total number of positive plus negative residues.

    Notes
    -----
    By convention the positions of the phase residues are marked on the
    top left corner of the 2 by 2 regions, as shown below:

    .. graphviz::

        graph g {
            node [shape=plaintext];
            active -- right [label=" res4 "];
            right -- belowright [label=" res3 "];
            below -- belowright [label=" res2 "];
            below -- active [label=" res1 "];
            { rank=same; active right }
            { rank=same; belowright below }
        }

    Inspired by PhaseResidues.m created by B.S. Spottiswoode on 07/10/2004
    and by find_residues.m created by Manuel Guizar - Sept 27, 2011

    References
    ----------
    .. [1] R. M. Goldstein, H. A. Zebker and C. L. Werner, Radio Science 23,
       713-720 (1988).
    """
    # Stack all four loop increments into a single (4, nr-2, nc-2) array so
    # that wraptopi is called once instead of four times, reducing Python
    # overhead and allowing NumPy to process the full batch in one pass.
    # Loop order: active -> below -> belowright -> right -> active.
    diffs = np.stack([
        phimage[2:, 1:-1] - phimage[1:-1, 1:-1],
        phimage[2:, 2:] - phimage[2:, 1:-1],
        phimage[1:-1, 2:] - phimage[2:, 2:],
        phimage[1:-1, 1:-1] - phimage[1:-1, 2:],
    ])  # (4, nr-2, nc-2)
    residues = np.sum(wraptopi(diffs), axis=0) / (2.0 * np.pi)
    respos, resneg, nres = _get_charge(residues)
    residues_charge = dict(pos=respos, neg=resneg)
    return residues, residues_charge, nres
def phaseresiduesStack(stack_array, threshold=5000):
    """
    Calculate the map of residues on the stack.

    Parameters
    ----------
    stack_array : ndarray
        A 3-dimensional array containing the stack of projections from
        which to calculate the phase residues.
    threshold : int, optional
        Maximum number of acceptable phase residues per projection.
        Projections with more residues than ``threshold`` are flagged as
        problematic. Default ``5000``.

    Returns
    -------
    resmap : ndarray
        2-D phase residue accumulation map (sum of absolute residue maps
        across all projections), in :func:`phaseresidues` coordinates.
    posres : tuple of ndarray
        Indices ``(yres, xres)`` of pixels where ``resmap >= 1``.
    nres : int
        Total number of residues found in the last processed projection.

    Raises
    ------
    ValueError
        If ``stack_array`` contains no projections, or a residue map
        contains NaN.
    """
    nproj = stack_array.shape[0]
    if nproj == 0:
        # Without this guard, the loop below never runs and nres would be
        # unbound when returned.
        raise ValueError("stack_array must contain at least one projection")

    resmap = 0  # first "+=" rebinds this to the first absolute residue map
    wrong = []
    for ii in tqdm(range(nproj), desc="Searching phase residues"):
        residues, _, nres = phaseresidues(stack_array[ii])
        if np.isnan(residues).any():
            raise ValueError(f"NaN found in projection {ii+1}")
        if nres > threshold:
            wrong.append(ii)
        resmap += np.abs(residues)
    print(". Done")
    posres = np.where(resmap >= 1.0)
    if wrong:
        print("The following projections are problematic: {}".format(wrong))
    return resmap, posres, nres
def phaseresiduesStack_parallel(stack_array, threshold=1000, ncores=2):
    """
    Calculate the map of residues on the stack using parallel processing.

    Parameters
    ----------
    stack_array : ndarray
        A 3-dimensional array containing the stack of projections from
        which to calculate the phase residues.
    threshold : int, optional
        Maximum number of acceptable phase residues per projection.
        Projections with more residues than ``threshold`` are flagged as
        problematic. Default ``1000``.
    ncores : int, optional
        Number of worker processes. Default ``2``.

    Returns
    -------
    resmap : ndarray
        2-D phase residue accumulation map (sum of absolute residue maps
        across all projections).
    posres : tuple of ndarray
        Indices ``(yres, xres)`` of pixels where ``resmap >= 1``.
    nres : tuple of int
        Tuple of per-projection residue counts.

    Raises
    ------
    ValueError
        If ``stack_array`` contains no projections.
    """
    nprojs = len(stack_array)
    if nprojs == 0:
        raise ValueError("stack_array must contain at least one projection")

    resmap = None
    counts = []
    with ProcessPoolExecutor(max_workers=ncores) as executor:
        results = executor.map(phaseresidues, stack_array)
        for residues, _charges, count in tqdm(
            results, total=nprojs, desc="Searching phase residues"
        ):
            # Accumulate as results arrive (executor.map yields them in
            # input order). Stacking every residue map into one 3-D array
            # first would hold the whole stack's residues twice.
            absres = np.abs(residues)
            if resmap is None:
                resmap = absres
            else:
                resmap += absres
            counts.append(count)
    print("Done")
    print("Creating the map of residues")
    nres = tuple(counts)
    posres = np.where(resmap >= 1.0)
    wrong = np.where(np.array(nres) > threshold)[0]
    if len(wrong) > 0:
        print("The following projections are problematic: \n {}".format(wrong))
    return resmap, posres, nres
class _RegionPicker:
    """
    Interactive GUI for selecting the unwrap region and an air pixel.

    Three sequential clicks on the image define the region:

    1. Top-left corner of the rectangular unwrap region.
    2. Bottom-right corner (blue rectangle drawn + residue count printed).
    3. A pixel in air/vacuum (must fall inside the rectangle).

    Then click **Confirm** to accept, or **Reset** to start over.

    ``resmap``, ``xres`` and ``yres`` are expected in the coordinate frame
    of :func:`phaseresidues`. Residue ``[i, j]`` belongs to image pixel
    ``(i + 1, j + 1)``, so the ``+1`` shift is applied both when plotting the
    residues and when counting them inside the selected rectangle. The
    selected ranges ``rx`` / ``ry`` are clipped to valid image indices.

    Terminal usage
    --------------
    This class is not meant to be used directly; use
    :func:`chooseregiontounwrap`.

    Jupyter two-cell workflow
    -------------------------
    The picker is returned immediately; access results in the **next** cell::

        # Cell 1
        picker = chooseregiontounwrap(stack_array, ...)

        # Cell 2 — run AFTER clicking Confirm
        rx, ry, airpix = picker.rx, picker.ry, picker.airpix
    """

    _STEPS = [
        "Step 1/3: click the TOP-LEFT corner of the unwrap region",
        "Step 2/3: click the BOTTOM-RIGHT corner of the unwrap region",
        "Step 3/3: click a pixel in AIR / vacuum (inside the rectangle)",
    ]

    def __init__(self, stack_array, resmap, xres, yres):
        self._stack = stack_array
        self._resmap = resmap
        self._xres = xres
        self._yres = yres
        # image size, used to clip the selected ranges to valid indices
        self._nrows, self._ncols = np.shape(stack_array[0])[:2]

        # public results
        self.rx = None
        self.ry = None
        self.airpix = None
        self._done = False

        # internal state
        self._step = 0          # 0=TL, 1=BR, 2=air, 3=ready
        self._corners = []      # (x, y) world-coord tuples
        self._rect_lines = []   # rectangle artist handles
        self._air_dot = None

        # ---- figure ----
        fig, ax = plt.subplots(figsize=(9, 7))
        self.fig = fig
        self.ax = ax

        ax.imshow(stack_array[0], cmap="bone", aspect="auto")
        # residue [i, j] sits at image pixel (i + 1, j + 1)
        ax.plot(
            np.asarray(xres) + 1, np.asarray(yres) + 1,
            "or", markersize=3, alpha=0.6, label="residues",
        )
        ax.axis("tight")
        ax.set_title(
            "First projection — red dots = phase residues\n"
            "Image size: {} rows × {} cols".format(
                stack_array[0].shape[0], stack_array[0].shape[1]
            ),
            fontsize=9,
        )

        # step-instruction banner (centred over full figure)
        self._status_txt = fig.text(
            0.5, 0.975, self._STEPS[0],
            ha="center", va="top", fontsize=10,
            color="steelblue", fontweight="bold",
        )
        # secondary info line (residue count, coordinates, warnings)
        self._info_txt = fig.text(
            0.5, 0.935, "",
            ha="center", va="top", fontsize=9, color="0.35",
        )

        # buttons
        ax_reset = fig.add_axes([0.25, 0.01, 0.2, 0.055])
        ax_confirm = fig.add_axes([0.55, 0.01, 0.2, 0.055])
        self._btn_reset = Button(ax_reset, "Reset")
        self._btn_confirm = Button(ax_confirm, "Confirm")
        self._btn_confirm.ax.set_facecolor("0.85")  # greyed out until ready
        self._btn_reset.on_clicked(self._on_reset)
        self._btn_confirm.on_clicked(self._on_confirm)

        self._cid = fig.canvas.mpl_connect("button_press_event", self._on_click)
        print(self._STEPS[0], flush=True)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _span(lo, hi, size):
        """Return ``range`` of rounded indices ``lo..hi`` clipped to ``[0, size-1]``."""
        start = min(max(int(round(lo)), 0), size - 1)
        stop = min(max(int(round(hi)), 0), size - 1)
        return range(start, stop + 1)

    def _region_ranges(self):
        """Return ``(rx, ry)`` for the ordered rectangle in ``self._corners``."""
        (x0, y0), (x1, y1) = self._corners
        return (
            self._span(x0, x1, self._ncols),
            self._span(y0, y1, self._nrows),
        )

    def _count_residues(self, rx, ry):
        """Sum ``resmap`` over residues whose image pixel lies in ``rx`` x ``ry``."""
        # Image row r holds residue row r - 1, so rows ry[0]..ry[-1] map to
        # residue slice [ry[0] - 1 : ry[-1]]; clamp the start so it never
        # becomes -1 (which would wrap around to the last row).
        r0 = max(ry[0] - 1, 0)
        c0 = max(rx[0] - 1, 0)
        return int(np.round(self._resmap[r0:ry[-1], c0:rx[-1]].sum()))

    def _clear_overlays(self):
        # Best-effort removal: an artist may already be gone from the axes.
        for ln in self._rect_lines:
            try:
                ln.remove()
            except Exception:
                pass
        self._rect_lines = []
        if self._air_dot is not None:
            try:
                self._air_dot.remove()
            except Exception:
                pass
            self._air_dot = None

    def _draw_rect(self, x0, y0, x1, y1):
        """Draw blue rectangle from top-left (x0,y0) to bottom-right (x1,y1)."""
        self._clear_overlays()
        kw = dict(color="b", linewidth=1.5, linestyle="-")
        for xs, ys in [
            ([x0, x1], [y0, y0]),
            ([x0, x1], [y1, y1]),
            ([x0, x0], [y0, y1]),
            ([x1, x1], [y0, y1]),
        ]:
            ln, = self.ax.plot(xs, ys, **kw)
            self._rect_lines.append(ln)

    def _set_info(self, msg):
        self._info_txt.set_text(msg)
        self.fig.canvas.draw_idle()

    def _set_status(self, msg):
        # No redraw here; every caller follows up with _set_info, which redraws.
        self._status_txt.set_text(msg)

    # ------------------------------------------------------------------ events

    def _on_click(self, event):
        if event.button != 1 or event.inaxes is not self.ax:
            return
        if event.xdata is None or event.ydata is None:
            return
        if self._done or self._step >= 3:
            return

        x, y = event.xdata, event.ydata

        if self._step == 0:
            # ---- step 1: top-left corner ----
            self._corners = [(x, y)]
            self._clear_overlays()
            dot, = self.ax.plot(x, y, "bs", markersize=7)
            self._rect_lines.append(dot)
            self._step = 1
            self._set_status(self._STEPS[1])
            self._set_info("Top-left: ({:.0f}, {:.0f})".format(x, y))
            print(" Top-left corner: ({:.0f}, {:.0f})".format(x, y), flush=True)
            print(self._STEPS[1], flush=True)

        elif self._step == 1:
            # ---- step 2: bottom-right corner — enforce ordering ----
            x0, y0 = self._corners[0]
            x0, x1 = (x0, x) if x0 <= x else (x, x0)
            y0, y1 = (y0, y) if y0 <= y else (y, y0)
            self._corners = [(x0, y0), (x1, y1)]
            self._draw_rect(x0, y0, x1, y1)

            rx, ry = self._region_ranges()
            nres = self._count_residues(rx, ry)

            self._step = 2
            self._set_status(self._STEPS[2])
            info = (
                "Region: x=[{}, {}] y=[{}, {}] {} residues inside"
            ).format(rx[0], rx[-1], ry[0], ry[-1], nres)
            self._set_info(info)
            print(" " + info, flush=True)
            print(self._STEPS[2], flush=True)

        elif self._step == 2:
            # ---- step 3: air pixel (must be inside the rectangle) ----
            rx, ry = self._region_ranges()

            if not (rx[0] <= x <= rx[-1] and ry[0] <= y <= ry[-1]):
                msg = "Air pixel must be inside the blue rectangle — try again."
                self._set_info(msg)
                print(" WARNING: " + msg, flush=True)
                return

            self._air_dot, = self.ax.plot(x, y, "ob", markersize=8)
            self.rx = rx
            self.ry = ry
            self.airpix = (int(round(x)), int(round(y)))
            self._step = 3  # waiting for Confirm

            self._set_status(
                "All set — click Confirm to accept | Reset to start over"
            )
            self._btn_confirm.ax.set_facecolor("limegreen")
            self._set_info("Air pixel: ({}, {})".format(*self.airpix))
            print(" Air pixel: ({}, {})".format(*self.airpix), flush=True)
            print(
                "All set — click Confirm to accept, or Reset to redo.",
                flush=True,
            )

    def _on_reset(self, event):
        self._clear_overlays()
        self._corners = []
        self._step = 0
        self.rx = None
        self.ry = None
        self.airpix = None
        self._btn_confirm.ax.set_facecolor("0.85")
        self._set_status(self._STEPS[0])
        self._set_info("")
        print("Reset — " + self._STEPS[0], flush=True)

    def _on_confirm(self, event):
        if self._step < 3:
            print(
                "Not ready yet — complete all 3 clicks first.", flush=True
            )
            return
        self._done = True
        try:
            self.fig.canvas.mpl_disconnect(self._cid)
        except Exception:
            pass
        print(
            "Confirmed — rx=[{}, {}] ry=[{}, {}] airpix={}".format(
                self.rx[0], self.rx[-1],
                self.ry[0], self.ry[-1],
                self.airpix,
            ),
            flush=True,
        )
        plt.close(self.fig)
def chooseregiontounwrap(stack_array, threshold=5000, parallel=False, ncores=1):
    """
    Choose the region to be unwrapped interactively.

    A GUI figure opens showing the first projection overlaid with
    phase-residue locations (red dots). Click three points in order:

    1. Top-left corner of the rectangular unwrap region.
    2. Bottom-right corner of the unwrap region.
    3. A pixel in air / vacuum (must be inside the rectangle).

    Then click **Confirm** to accept, or **Reset** to start over.

    Parameters
    ----------
    stack_array : ndarray
        A 3-dimensional array containing the stack of projections to be
        unwrapped.
    threshold : int, optional
        Threshold for the number of acceptable phase residues.
        (Default = 5000)
    parallel : bool, optional
        If ``True``, multiprocessing is used for residue computation.
        (Default = ``False``)
    ncores : int, optional
        Number of cores for parallel computation. ``1`` disables parallel
        computation. ``-1`` uses ``SLURM_JOB_CPUS_PER_NODE`` when it is set
        to a plain integer, otherwise ``multiprocessing.cpu_count()``.
        (Default = 1)

    Returns
    -------
    Terminal mode
        ``(rx, ry, airpix)`` — ranges and air-pixel tuple, ready to use.
    Jupyter mode (two-cell workflow)
        ``picker`` — :class:`_RegionPicker` instance. Access
        ``picker.rx``, ``picker.ry``, ``picker.airpix`` in the **next**
        cell after clicking Confirm.
    """
    # ---- residue computation ----
    print("Checking for phase residues …", flush=True)
    if ncores == 1:
        parallel = False
    if parallel:
        if ncores == -1:
            try:
                ncores = int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
            except (KeyError, ValueError):
                # Not running under SLURM, or a non-integer value such as
                # "4(x2)": fall back to every CPU of this machine.
                ncores = multiprocessing.cpu_count()
        if ncores == 1:
            print("{} core used: parallel calculations are not possible".format(ncores))
        resmap, posres, nres = phaseresiduesStack_parallel(
            stack_array, threshold, ncores
        )
    else:
        resmap, posres, nres = phaseresiduesStack(stack_array, threshold)
    yres, xres = posres
    print("→ residue map done.", flush=True)

    # ---- interactive picker ----
    plt.close("all")
    picker = _RegionPicker(stack_array, resmap, xres, yres)

    if isnotebook():
        # Best-effort backend probe: any failure is treated as non-interactive.
        try:
            import matplotlib as _mpl
            _interactive = "inline" not in _mpl.get_backend().lower()
        except Exception:
            _interactive = False

        if _interactive:
            plt.show(block=False)
            picker.fig.canvas.draw()
            print(
                "\nJupyter two-cell workflow:\n"
                " Interact with the figure above (3 clicks + Confirm),\n"
                " then run in the NEXT cell:\n"
                " rx, ry, airpix = picker.rx, picker.ry, picker.airpix",
                flush=True,
            )
        else:
            import warnings
            from IPython import display as _ipy_display
            _ipy_display.display(picker.fig)
            warnings.warn(
                "Interactive region picking requires a non-inline backend.\n"
                "Add %matplotlib widget as the first notebook cell\n"
                "and run pip install ipympl once, then restart the kernel.",
                UserWarning,
                stacklevel=2,
            )
        return picker

    plt.show(block=True)  # blocks until Confirm closes the figure
    return picker.rx, picker.ry, picker.airpix
# --------------------------------------------------------------------------- # Internal reliability-map helper # --------------------------------------------------------------------------- def _reliability_map(phase): """ Compute a per-pixel reliability map for Herraez's reliability-guided phase unwrapping algorithm. Reliability = 1 / D where D = sqrt(H^2 + V^2 + D1^2 + D2^2). H, V, D1, D2 are second-order phase differences computed via wraptopi. Border pixels receive reliability 0. Parameters ---------- phase : ndarray 2-D wrapped phase array (radians). Returns ------- rel : ndarray Float64 reliability array, same shape as ``phase``. """ nr, nc = phase.shape rel = np.zeros((nr, nc), dtype=np.float64) # Second differences (interior pixels only). # Stack all 8 first-order differences into a single (8, nr-2, nc-2) # array so wraptopi is called once instead of eight times. inner = phase[1:-1, 1:-1] all_diffs = np.stack([ phase[1:-1, 2:] - inner, # H forward inner - phase[1:-1, :-2], # H backward phase[2:, 1:-1] - inner, # V forward inner - phase[:-2, 1:-1], # V backward phase[2:, 2:] - inner, # D1 forward inner - phase[:-2, :-2], # D1 backward phase[2:, :-2] - inner, # D2 forward inner - phase[:-2, 2:], # D2 backward ]) # (8, nr-2, nc-2) w = wraptopi(all_diffs) # single call H = w[0] - w[1] V = w[2] - w[3] D1 = w[4] - w[5] D2 = w[6] - w[7] D = np.sqrt(H ** 2 + V ** 2 + D1 ** 2 + D2 ** 2) # Avoid division by zero; zero D gives maximum reliability (flat region) with np.errstate(divide="ignore", invalid="ignore"): rel[1:-1, 1:-1] = np.where(D == 0.0, np.finfo(np.float64).max, 1.0 / D) return rel # --------------------------------------------------------------------------- # Algorithm 1: Herraez reliability-guided BFS # --------------------------------------------------------------------------- def _unwrap_herraez(phase): """ Reliability-guided phase unwrapping (Herraez et al., 2002). 
Uses a max-heap (implemented as a min-heap with negated priorities) to process pixel edges in decreasing order of edge reliability, where edge reliability = min(reliability[p1], reliability[p2]). Compared with a naïve heap implementation this version reduces per- element overhead by: * Encoding each heap entry as a compact 3-tuple ``(-edge_rel, flat_dst_idx, flat_src_idx)`` instead of a 6-tuple with a monotone counter, saving the ``next(ctr)`` call and one int per push. * Inlining the neighbor-push logic to avoid a Python function call on every visited pixel (the hot path). * Caching local references to ``heapq.heappush`` / ``heapq.heappop`` to cut attribute-lookup time in the inner loop. * Using ``divmod(flat, nc)`` to recover ``(row, col)`` in a single C-level operation instead of two Python integer divisions. Parameters ---------- phase : ndarray 2-D wrapped phase array (radians). Returns ------- unwrapped : ndarray Float64 unwrapped phase array, same shape as ``phase``. """ nr, nc = phase.shape rel = _reliability_map(phase) unwrapped = phase.astype(np.float64).copy() visited = np.zeros((nr, nc), dtype=bool) # Seed from pixel (0, 0) visited[0, 0] = True # Cache hot-path references to avoid per-call attribute lookup _push = heapq.heappush _pop = heapq.heappop heap = [] # Seed neighbours for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)): r2, c2 = dr, dc if 0 <= r2 < nr and 0 <= c2 < nc: er = rel[r2, c2] if rel[r2, c2] < rel[0, 0] else rel[0, 0] _push(heap, (-er, r2 * nc + c2, 0)) while heap: neg_rel, flat2, flat1 = _pop(heap) r, c = divmod(flat2, nc) if visited[r, c]: continue visited[r, c] = True src_r, src_c = divmod(flat1, nc) d = wraptopi(phase[r, c] - unwrapped[src_r, src_c]) unwrapped[r, c] = unwrapped[src_r, src_c] + d # Inline neighbour push for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)): r2, c2 = r + dr, c + dc if 0 <= r2 < nr and 0 <= c2 < nc and not visited[r2, c2]: er = rel[r2, c2] if rel[r2, c2] < rel[r, c] else rel[r, c] _push(heap, (-er, r2 * nc + 
c2, flat2)) return unwrapped # --------------------------------------------------------------------------- # Algorithm 2: Goldstein branch-cut + BFS flood fill # --------------------------------------------------------------------------- def _unwrap_goldstein(phase): """ Goldstein branch-cut phase unwrapping (Goldstein et al., 1988). Locates phase residues, pairs them with a greedy nearest-neighbour strategy, draws L-shaped branch cuts, then flood-fills from the image centre while respecting the branch cuts. Parameters ---------- phase : ndarray 2-D wrapped phase array (radians). Returns ------- unwrapped : ndarray Float64 unwrapped phase array, same shape as ``phase``. Pixels unreachable due to branch cuts retain their wrapped value. """ nr, nc = phase.shape residues, residues_charge, nres = phaseresidues(phase) if nres == 0: return _unwrap_flynn(phase) # Convert residue indices to image coordinates (+1 offset) pos_rows = residues_charge["pos"][0] + 1 pos_cols = residues_charge["pos"][1] + 1 neg_rows = residues_charge["neg"][0] + 1 neg_cols = residues_charge["neg"][1] + 1 branch_cuts = np.zeros((nr, nc), dtype=bool) pos_used = [False] * len(pos_rows) neg_used = [False] * len(neg_rows) # Greedy nearest-neighbour pairing of positive to negative residues for pi in range(len(pos_rows)): r1, c1 = pos_rows[pi], pos_cols[pi] best_dist = np.inf best_ni = -1 for ni in range(len(neg_rows)): if neg_used[ni]: continue d = distance((r1, c1), (neg_rows[ni], neg_cols[ni])) if d < best_dist: best_dist = d best_ni = ni if best_ni >= 0: pos_used[pi] = True neg_used[best_ni] = True r2, c2 = neg_rows[best_ni], neg_cols[best_ni] # L-shaped branch cut: horizontal segment then vertical segment c_lo, c_hi = min(c1, c2), max(c1, c2) branch_cuts[r1, c_lo:c_hi + 1] = True r_lo, r_hi = min(r1, r2), max(r1, r2) branch_cuts[r_lo:r_hi + 1, c2] = True # Unmatched positives: vertical cut to top border for pi in range(len(pos_rows)): if not pos_used[pi]: r, c = pos_rows[pi], pos_cols[pi] 
branch_cuts[:r + 1, c] = True # Unmatched negatives: vertical cut to bottom border for ni in range(len(neg_rows)): if not neg_used[ni]: r, c = neg_rows[ni], neg_cols[ni] branch_cuts[r:, c] = True # BFS flood fill from centre unwrapped = phase.astype(np.float64).copy() visited = np.zeros((nr, nc), dtype=bool) start_r, start_c = nr // 2, nc // 2 # If start pixel is on a branch cut, search nearby for a free pixel if branch_cuts[start_r, start_c]: found = False for dr in range(nr): for dc in range(nc): for sr, sc in [(start_r + dr, start_c + dc), (start_r - dr, start_c - dc), (start_r + dr, start_c - dc), (start_r - dr, start_c + dc)]: if 0 <= sr < nr and 0 <= sc < nc and not branch_cuts[sr, sc]: start_r, start_c = sr, sc found = True break if found: break if found: break queue = deque() queue.append((start_r, start_c)) visited[start_r, start_c] = True # seed value is the wrapped phase itself (already set in unwrapped) while queue: r, c = queue.popleft() for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)): r2, c2 = r + dr, c + dc if 0 <= r2 < nr and 0 <= c2 < nc and not visited[r2, c2] and not branch_cuts[r2, c2]: visited[r2, c2] = True d = wraptopi(phase[r2, c2] - unwrapped[r, c]) unwrapped[r2, c2] = unwrapped[r, c] + d queue.append((r2, c2)) return unwrapped # --------------------------------------------------------------------------- # Algorithm 3: Flynn row-integration + median row correction # --------------------------------------------------------------------------- def _unwrap_flynn(phase): """ Flynn's integration-based phase unwrapping (Flynn, 1997). Fast O(N) approach: 1. Integrate wrapped horizontal differences row by row. 2. Fix row-to-row jumps using a median correction. Parameters ---------- phase : ndarray 2-D wrapped phase array (radians). Returns ------- unwrapped : ndarray Float64 unwrapped phase array, same shape as ``phase``. 
""" nr, nc = phase.shape two_pi = 2.0 * np.pi unwrapped = np.empty((nr, nc), dtype=np.float64) # Step 1: integrate wrapped horizontal differences row by row dh = wraptopi(np.diff(phase, axis=1)) unwrapped[:, 0] = phase[:, 0] unwrapped[:, 1:] = phase[:, :1] + np.cumsum(dh, axis=1) # Step 2: fix row-to-row jumps using median correction for r in range(1, nr): dv = unwrapped[r] - unwrapped[r - 1] correction = int(np.round(np.median(dv) / two_pi)) if correction != 0: unwrapped[r:] -= two_pi * correction return unwrapped # --------------------------------------------------------------------------- # Algorithm 4: DCT-based Weighted Least Squares (Ghiglia & Romero, 1994) # --------------------------------------------------------------------------- def _unwrap_wls(phase): """ DCT-based unweighted least-squares phase unwrapping (Ghiglia & Romero, 1994). Solves the 2-D discrete Poisson equation ∇²φ = ρ where ρ is the divergence of the wrapped phase gradients. The system is solved via a DCT-II transform (Neumann boundary conditions) in O(N log N) time. This gives a globally optimal solution under a flat (uniform) weight model, making it substantially more robust than the greedy Herraez or Goldstein methods for images that are smooth but wrapped many times. Parameters ---------- phase : ndarray 2-D wrapped phase array (radians). Returns ------- unwrapped : ndarray Float64 unwrapped phase array, same shape as ``phase``. Reference --------- D. C. Ghiglia and L. A. Romero, "Robust two-dimensional weighted and unweighted phase unwrapping that uses fast transforms and iterative methods", J. Opt. Soc. Am. A 11(1), 107-117 (1994). """ nr, nc = phase.shape # Wrapped forward differences (phase gradients) dy = wraptopi(np.diff(phase, axis=0)) # shape (nr-1, nc) dx = wraptopi(np.diff(phase, axis=1)) # shape (nr, nc-1) # Discrete divergence with Neumann (zero-flux) boundary conditions. 
# Interior: rho[i,j] = dx[i,j] - dx[i,j-1] + dy[i,j] - dy[i-1,j] rho = np.zeros((nr, nc), dtype=np.float64) rho[:, 1:-1] += dx[:, 1:] - dx[:, :-1] # interior x rho[:, 0] += dx[:, 0] # left boundary rho[:, -1] -= dx[:, -1] # right boundary rho[1:-1, :] += dy[1:, :] - dy[:-1, :] # interior y rho[0, :] += dy[0, :] # top boundary rho[-1, :] -= dy[-1, :] # bottom boundary # Residue correction: phase residues make the gradient field non- # conservative; the Poisson solver would otherwise spread that # inconsistency as a smooth large-scale gradient across the whole image. # We detect residue locations (2×2 loops where the wrapped path integral # ≠ 0) and zero-out their divergence contribution before solving, which # localises the residue error rather than distributing it globally. res_map, _, _ = phaseresidues(phase) # values ±1 at residue corners # phaseresidues uses a (nr-2)×(nc-2) interior grid with +1 offset res_mask = np.zeros((nr, nc), dtype=bool) res_mask[1:-1, 1:-1] = np.abs(np.round(res_map)) >= 1 # Dilate one pixel so the correction covers the full 2×2 residue cell from scipy.ndimage import binary_dilation res_mask = binary_dilation(res_mask) rho[res_mask] = 0.0 # Eigenvalues of the discrete Laplacian in DCT-II / Neumann basis: # mu[k, l] = 2*(cos(pi*k/nr) - 1) + 2*(cos(pi*l/nc) - 1) k = np.arange(nr, dtype=np.float64) l = np.arange(nc, dtype=np.float64) mu = ( 2.0 * (np.cos(np.pi * k / nr) - 1.0)[:, None] + 2.0 * (np.cos(np.pi * l / nc) - 1.0)[None, :] ) mu[0, 0] = 1.0 # avoid division by zero; DC set to 0 below # Solve in DCT domain rho_hat = dctn(rho, type=2, norm="ortho") phi_hat = rho_hat / mu phi_hat[0, 0] = 0.0 # DC is a free gauge (arbitrary global offset) phi = idctn(phi_hat, type=2, norm="ortho") return phi # --------------------------------------------------------------------------- # Algorithm 5: SNAPHU (Chen & Zebker, 2001) — optional dependency # --------------------------------------------------------------------------- @contextlib.contextmanager 
def _suppress_fd1():
    """
    Context manager that silences C-level stdout (file descriptor 1).

    ``contextlib.redirect_stdout`` only redirects Python's ``sys.stdout``
    object. SNAPHU's C library writes directly to fd 1, so the only reliable
    way to suppress it is to temporarily replace fd 1 with ``os.devnull`` at
    the OS level via ``os.dup2``. The original fd 1 is restored on exit, also
    when the ``with`` body raises, and every descriptor opened here is closed.

    Notes
    -----
    This function is a generator; it is used as a context manager through
    the ``contextlib.contextmanager`` decorator applied to it.
    """
    # Push pending Python-level output to the real stdout before redirecting.
    sys.stdout.flush()
    # Open the sink first: if any later call fails, the nested ``finally``
    # blocks close whatever has been opened so far (no descriptor leaks).
    devnull_fd = os.open(os.devnull, os.O_WRONLY)
    try:
        saved_fd = os.dup(1)  # keep a handle on the real fd 1
        try:
            os.dup2(devnull_fd, 1)  # redirect fd 1 -> /dev/null
            yield
        finally:
            # Anything Python buffered inside the block goes to /dev/null.
            sys.stdout.flush()
            try:
                os.dup2(saved_fd, 1)  # restore the real fd 1
            finally:
                os.close(saved_fd)
    finally:
        os.close(devnull_fd)


def _unwrap_snaphu(phase, verbose=False):
    """
    SNAPHU phase unwrapping (Chen & Zebker, 2001).

    Uses the Statistical-cost, Network-flow Algorithm for Phase Unwrapping
    (SNAPHU) via the ``snaphu-py`` package. The per-pixel reliability map is
    min-max normalised to [0, 1] and passed as the coherence estimate so that
    SNAPHU's statistical cost model is guided by local phase quality.

    Requires the optional dependency::

        pip install snaphu

    Parameters
    ----------
    phase : ndarray
        2-D wrapped phase array (radians).
    verbose : bool, optional
        If ``False`` (default) SNAPHU's C-level output on fd 1 is suppressed
        and only a single progress line is printed. Set to ``True`` to see
        the full SNAPHU log (useful for debugging).

    Returns
    -------
    unwrapped : ndarray
        Float64 unwrapped phase array, same shape as ``phase``.

    Raises
    ------
    ImportError
        If ``snaphu-py`` is not installed.

    Reference
    ---------
    C. W. Chen and H. A. Zebker, "Two-dimensional phase unwrapping with use
    of statistical models for cost functions in nonlinear optimization",
    J. Opt. Soc. Am. A 18(2), 338-351 (2001).
    """
    if not _SNAPHU_AVAILABLE:
        raise ImportError(
            "snaphu-py is required for method='snaphu'.\n"
            "Install it with: pip install snaphu"
        )
    phase = np.asarray(phase, dtype=np.float64)

    # Wrapped phase -> unit-magnitude complex interferogram exp(i*phase)
    igram = np.exp(1j * phase).astype(np.complex64)

    # Reliability map min-max normalised to [0, 1] as a coherence proxy.
    # A constant map (zero range) becomes uniform coherence 1.
    rel = _reliability_map(phase)
    rel_min = rel.min()
    rel_range = rel.max() - rel_min
    if rel_range == 0.0:
        corr = np.ones(phase.shape, dtype=np.float32)
    else:
        corr = ((rel - rel_min) / rel_range).astype(np.float32)

    # Run SNAPHU, optionally suppressing its C-level output
    print("Running SNAPHU ...", end=" ", flush=True)
    if verbose:
        print()  # newline before SNAPHU's own output
        unw, _ = _snaphu.unwrap(igram, corr, nlooks=1)
    else:
        with _suppress_fd1():
            unw, _ = _snaphu.unwrap(igram, corr, nlooks=1)
    print("done.", flush=True)

    return np.asarray(unw, dtype=np.float64)


# ---------------------------------------------------------------------------
# Public dispatcher
# ---------------------------------------------------------------------------
def unwrap_phase_2d(phase, method="herraez", verbose=False):
    """
    Unwrap a 2-D phase array using one of five algorithms.

    Parameters
    ----------
    phase : ndarray
        2-D wrapped phase array (radians).
    method : str, optional
        Algorithm to use (case-insensitive). One of:

        ``"herraez"`` (default)
            Reliability-guided BFS. Computes a per-pixel reliability from
            second-order phase differences and processes pixel edges in
            decreasing reliability order via a priority queue. Generally the
            most robust choice for noisy data.
            Reference: M. A. Herraez, D. R. Burton, M. J. Lalor, and
            M. A. Gdeisat, "Fast two-dimensional phase-unwrapping algorithm
            based on sorting by reliability following a noncontinuous path",
            Applied Optics 41(35), 2002.

        ``"goldstein"``
            Branch-cut algorithm. Locates phase residues, pairs them with a
            greedy nearest-neighbour strategy, draws L-shaped branch cuts,
            then unwraps via BFS flood fill from the image centre. Falls back
            to Flynn's method when no residues are found.
            Reference: R. M. Goldstein, H. A. Zebker, and C. L. Werner,
            "Satellite radar interferometry: Two-dimensional phase
            unwrapping", Radio Science 23(4), 1988.

        ``"flynn"``
            Row-integration + median row correction. Integrates wrapped
            horizontal differences row by row, then corrects accumulated
            row-to-row offsets using a median estimator. O(N) complexity,
            fastest but least robust to noise.
            Reference: T. J. Flynn, "Two-dimensional phase unwrapping with
            minimum weighted discontinuity", Journal of the Optical Society
            of America A 14(10), 1997.

        ``"wls"``
            DCT-based unweighted least-squares (Ghiglia & Romero, 1994).
            Solves the 2-D Poisson equation whose right-hand side is the
            divergence of the wrapped phase gradients, using a DCT-II
            transform (Neumann boundary conditions). Globally optimal under
            a flat weight model; O(N log N) complexity.
            Reference: D. C. Ghiglia and L. A. Romero, "Robust
            two-dimensional weighted and unweighted phase unwrapping that
            uses fast transforms and iterative methods", J. Opt. Soc. Am. A
            11(1), 107-117 (1994).

        ``"snaphu"``
            Statistical-cost network-flow algorithm (Chen & Zebker, 2001).
            The per-pixel reliability map is used as a coherence estimate.
            Requires the optional package ``snaphu-py``
            (``pip install snaphu``).
            Reference: C. W. Chen and H. A. Zebker, "Two-dimensional phase
            unwrapping with use of statistical models for cost functions in
            nonlinear optimization", J. Opt. Soc. Am. A 18(2), 338-351
            (2001).
    verbose : bool, optional
        Only used when ``method='snaphu'``. If ``False`` (default) SNAPHU's
        C-level log is suppressed and replaced by a single line. Set to
        ``True`` to see the full SNAPHU output.

    Returns
    -------
    unwrapped : ndarray
        Unwrapped phase array as returned by the selected algorithm.

    Raises
    ------
    ValueError
        If ``method`` is not a string naming one of the recognised
        algorithms.
    ImportError
        If ``method='snaphu'`` and ``snaphu-py`` is not installed.
    """
    # Non-string values (e.g. None) are reported as an unknown method
    # instead of failing with AttributeError on ``.lower()``.
    name = method.lower() if isinstance(method, str) else method

    if name == "herraez":
        return _unwrap_herraez(phase)
    if name == "goldstein":
        return _unwrap_goldstein(phase)
    if name == "flynn":
        return _unwrap_flynn(phase)
    if name == "wls":
        return _unwrap_wls(phase)
    if name == "snaphu":
        return _unwrap_snaphu(phase, verbose=verbose)
    raise ValueError(
        f"Unknown unwrapping method '{name}'. "
        "Choose one of: 'herraez', 'goldstein', 'flynn', 'wls', 'snaphu'."
    )
# ---------------------------------------------------------------------------
# Internal single-image unwrapping helper
# ---------------------------------------------------------------------------
def _roi_slices(rx, ry):
    """
    Convert ROI limits into ``(row_slice, col_slice)``.

    ``None`` or an empty sequence selects the full extent along that axis.
    Otherwise the region is ``ry[0]:ry[-1]`` (rows) by ``rx[0]:rx[-1]``
    (columns); the last value is exclusive, as in ordinary slicing.
    """
    cols = slice(None) if rx is None or len(rx) == 0 else slice(rx[0], rx[-1])
    rows = slice(None) if ry is None or len(ry) == 0 else slice(ry[0], ry[-1])
    return rows, cols


def _air_phase_offset(unwrapped_roi, airpix, rows, cols):
    """
    Return the multiple of 2*pi to subtract so the air pixel is near zero.

    Parameters
    ----------
    unwrapped_roi : ndarray
        2-D unwrapped region extracted with ``[rows, cols]``.
    airpix : tuple, list or None
        Air/vacuum pixel as ``(col, row)`` in full-image coordinates.
        ``None`` or empty disables the correction (returns ``0.0``).
    rows, cols : slice
        The slices used to extract ``unwrapped_roi`` from the full image.

    Raises
    ------
    IndexError
        If the air pixel lies outside the unwrapped region.
    """
    if airpix is None or len(airpix) == 0:
        return 0.0
    # Translate (col, row) image coordinates into ROI-local indices.
    row = airpix[1] - (rows.start or 0)
    col = airpix[0] - (cols.start or 0)
    nrow, ncol = unwrapped_roi.shape
    # Explicit check: a negative index would otherwise silently wrap around.
    if not (0 <= row < nrow and 0 <= col < ncol):
        raise IndexError(
            "air pixel (col={}, row={}) lies outside the unwrapped "
            "region".format(airpix[0], airpix[1])
        )
    air_val = unwrapped_roi[row, col]
    return 2 * np.pi * np.round(air_val / (2 * np.pi))


def _unwrapping_phase(img2unwrap, rx=None, ry=None, airpix=None,
                      method="herraez", verbose=False):
    """
    Unwrap the phases of a projection

    Parameters
    ----------
    img2unwrap : ndarray
        A 2-dimensional array containing the image to be unwrapped
    rx, ry : tuple or list of ints, optional
        Limits of the area to be unwrapped in x (columns) and y (rows); the
        region is ``ry[0]:ry[-1]`` by ``rx[0]:rx[-1]``. ``None`` or an empty
        list selects the full extent along that axis.
    airpix : tuple or list of ints, optional
        Position ``(col, row)`` of a pixel in the air/vacuum area. The
        unwrapped region is shifted by the multiple of 2*pi that brings this
        pixel closest to zero. ``None`` or empty skips the correction.
    method : str, optional
        Phase unwrapping algorithm. See ``unwrap_phase_2d`` for details.
        Default is ``"herraez"``.
    verbose : bool, optional
        Passed to ``unwrap_phase_2d`` (only relevant for
        ``method='snaphu'``). Default is ``False``.

    Returns
    -------
    unwrapped : ndarray
        Unwrapped image (new array — the input ``img2unwrap`` is never
        modified). Pixels outside the ROI keep their original values.
    """
    rows, cols = _roi_slices(rx, ry)

    if rows == slice(None) and cols == slice(None):
        # Whole image: the result of unwrap_phase_2d is returned directly,
        # shifted by a single (scalar) 2*pi multiple taken at the air pixel.
        unwrapped = unwrap_phase_2d(img2unwrap, method=method, verbose=verbose)
        return unwrapped - _air_phase_offset(unwrapped, airpix, rows, cols)

    # Explicit copy of the ROI so the original wrapped data is never
    # overwritten; repeated calls always start from the same wrapped phase.
    roi_wrapped = img2unwrap[rows, cols].copy()
    roi_unwrapped = unwrap_phase_2d(roi_wrapped, method=method, verbose=verbose)
    offset = _air_phase_offset(roi_unwrapped, airpix, rows, cols)

    # Start from a copy of the original so pixels outside the ROI are kept.
    unwrapped = img2unwrap.copy()
    unwrapped[rows, cols] = roi_unwrapped - offset
    return unwrapped


# ---------------------------------------------------------------------------
# Internal parallel unwrapping helper
# ---------------------------------------------------------------------------
def _unwrapping_phase_parallel(stack2unwrap, rx=None, ry=None, airpix=None,
                               ncores=1, method="herraez", verbose=False):
    """
    Unwrap the phases of a stack of projections in parallel.

    Parameters
    ----------
    stack2unwrap : ndarray, shape (n, nr, nc)
        Stack of 2-D wrapped phase images to be unwrapped.
    rx : list or range of int, optional
        Column limits of the ROI (``rx[0]:rx[-1]``). ``None`` or empty
        selects the full width.
    ry : list or range of int, optional
        Row limits of the ROI (``ry[0]:ry[-1]``). ``None`` or empty selects
        the full height.
    airpix : tuple or list of int, optional
        Position ``(col, row)`` of a pixel in the air/vacuum region used to
        set the absolute 2*pi offset of each projection. ``None`` or empty
        skips the correction.
    ncores : int, optional
        Number of worker processes. Pass ``-1`` to use
        ``SLURM_JOB_CPUS_PER_NODE`` when it is a plain integer, otherwise
        ``multiprocessing.cpu_count()``. Default ``1``.
    method : str, optional
        Phase unwrapping algorithm. See :func:`unwrap_phase_2d` for details.
        Default ``"herraez"``.
    verbose : bool, optional
        Passed to :func:`unwrap_phase_2d`; only relevant for
        ``method='snaphu'``. Default ``False``.

    Returns
    -------
    stack_out : ndarray, shape (n, nr, nc)
        Stack of unwrapped phase images (new array; pixels outside the ROI
        keep their original values).
    """
    if ncores == -1:
        try:
            ncpus = int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
        except (KeyError, ValueError):
            # Variable unset, or in a multi-node form such as "16(x2)".
            ncpus = multiprocessing.cpu_count()
    else:
        ncpus = ncores
    print("Parallel unwrapping using {} cpus".format(ncpus), flush=True)

    rows, cols = _roi_slices(rx, ry)
    # Copy the ROI so the caller's stack is never touched by the workers.
    roi_wrapped = stack2unwrap[:, rows, cols].copy()
    _worker = functools.partial(unwrap_phase_2d, method=method, verbose=verbose)
    nprojs = roi_wrapped.shape[0]
    with ProcessPoolExecutor(max_workers=ncpus) as executor:
        results = list(
            tqdm(
                executor.map(_worker, roi_wrapped),
                total=nprojs,
                desc="Unwrapping projections",
            )
        )

    # Build the output from a copy of the original (preserves pixels outside
    # the ROI and guarantees the caller's stack is never modified).
    stack_out = stack2unwrap.copy()
    print("Correcting for air values")
    for ii, roi_unwrapped in enumerate(results):
        # Offset read from the UNWRAPPED result, same as the sequential path.
        offset = _air_phase_offset(roi_unwrapped, airpix, rows, cols)
        stack_out[ii, rows, cols] = roi_unwrapped - offset
    return stack_out


# ---------------------------------------------------------------------------
# Public stack unwrapping function
# ---------------------------------------------------------------------------
def _show_unwrap_preview(img, rx, ry, airpix, vmin, vmax):
    """
    Display an unwrapped projection with the ROI / air-pixel delimiters.

    All open figures are closed first. Inside a notebook whose matplotlib
    backend name contains "inline" (or whose backend cannot be determined)
    the figure is rendered once through ``IPython.display`` and then closed;
    otherwise it is shown non-blocking.
    """
    plt.close("all")
    fig, ax1 = plt.subplots(figsize=(8, 6))
    ax1.imshow(img, cmap="bone", vmin=vmin, vmax=vmax)
    ax1 = _plotdelimiters(ax1, ry, rx, airpix)
    ax1.axis("tight")
    ax1.set_title(
        "First projection — unwrapped (vmin={}, vmax={})".format(vmin, vmax),
        fontsize=9,
    )
    print(
        "Displaying first unwrapped projection "
        "(vmin={}, vmax={}).".format(vmin, vmax),
        flush=True,
    )
    print(
        "If the color scale needs adjusting, stop here, update "
        "params['vmin'] / params['vmax'] and re-run.",
        flush=True,
    )
    if isnotebook():
        # Best-effort backend detection; any failure is treated as inline.
        try:
            import matplotlib as _mpl
            interactive = "inline" not in _mpl.get_backend().lower()
        except Exception:
            interactive = False
        if not interactive:
            from IPython import display as _ipy_display
            _ipy_display.display(fig)
            plt.close(fig)
            return
    plt.show(block=False)
    fig.canvas.draw()


def unwrapping_phase(stack_phasecorr, rx, ry, airpix, **params):
    """
    Unwrap the phase of the projections in a stack.

    Parameters
    ----------
    stack_phasecorr : ndarray
        A 3-dimensional array containing the stack of projections to be
        unwrapped
    rx, ry : tuple or list of ints
        Limits of the area to be unwrapped in x (columns) and y (rows)
    airpix : tuple or list of ints
        Position ``(col, row)`` of a pixel in the air/vacuum area
    **params
        Dictionary of additional parameters.
        vmin : float or None, optional
            Minimum value for the gray level at each display.
            Default ``None``.
        vmax : float or None, optional
            Maximum value for the gray level at each display.
            Default ``None``.
        unwrap_method : str, optional
            Phase unwrapping algorithm to use. One of ``"herraez"``
            (default), ``"goldstein"``, ``"flynn"``, ``"wls"``, or
            ``"snaphu"``. See :func:`unwrap_phase_2d` for details.
            ``"snaphu"`` requires the optional package ``snaphu-py``
            (``pip install snaphu``).
        snaphu_verbose : bool, optional
            Show SNAPHU's full log (only for ``"snaphu"``). Default
            ``False``.
        parallel : bool, optional
            If ``True``, use parallel processing. Default ``True``.
        n_cpus : int, optional
            Number of CPU cores for parallel computation. Pass ``-1`` to use
            ``SLURM_JOB_CPUS_PER_NODE`` when it is a plain integer, otherwise
            all available cores. Default ``-1``.

    Returns
    -------
    stack_unwrap : ndarray
        A 3-dimensional array containing the stack of unwrapped projections.
        An empty stack is returned unchanged in shape without any display.

    Notes
    -----
    Five algorithms are available. The default is the reliability-guided
    algorithm by Herraez et al. [#herraez]_. For the best robustness use
    ``"wls"`` (no extra dependency) or ``"snaphu"`` (requires
    ``pip install snaphu``).

    Before the full stack is processed, the first projection is unwrapped
    and displayed so the user can check the ROI and the gray-level range.

    References
    ----------
    .. [#herraez] M. A. Herraez et al., "Fast two-dimensional phase-unwrapping
       algorithm based on sorting by reliability following a noncontinuous
       path", Applied Optics 41(35), 7437 (2002).
    .. [#ghiglia] D. C. Ghiglia and L. A. Romero, "Robust two-dimensional
       weighted and unweighted phase unwrapping that uses fast transforms and
       iterative methods", J. Opt. Soc. Am. A 11(1), 107-117 (1994).
    .. [#chen] C. W. Chen and H. A. Zebker, "Two-dimensional phase unwrapping
       with use of statistical models for cost functions in nonlinear
       optimization", J. Opt. Soc. Am. A 18(2), 338-351 (2001).
    """
    # Defaults for every key read below (previously n_cpus/vmin/vmax raised
    # KeyError when omitted, despite n_cpus being documented as optional).
    params.setdefault("parallel", True)
    params.setdefault("n_cpus", -1)
    params.setdefault("unwrap_method", "herraez")
    params.setdefault("snaphu_verbose", False)
    params.setdefault("vmin", None)
    params.setdefault("vmax", None)
    method = params["unwrap_method"]
    verbose = params["snaphu_verbose"]

    if params["parallel"]:
        if params["n_cpus"] == -1:
            try:
                ncores = int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
            except (KeyError, ValueError):
                # Variable unset, or in a multi-node form such as "16(x2)".
                ncores = multiprocessing.cpu_count()
        else:
            ncores = params["n_cpus"]
    else:
        ncores = 1

    nprojs = stack_phasecorr.shape[0]
    if nprojs == 0:
        # No first projection to preview and nothing to unwrap.
        return np.empty_like(stack_phasecorr)

    # ---- test unwrap on the first projection and show the result ----
    print("Testing unwrapping on the first projection …", flush=True)
    img0_unwrap = _unwrapping_phase(
        stack_phasecorr[0], rx, ry, airpix, method=method, verbose=verbose,
    )
    print("→ done.", flush=True)

    # Let the user verify vmin/vmax before the full stack is processed. To
    # change the color scale, re-run with different params["vmin"] /
    # params["vmax"] values — no blocking input() prompt.
    _show_unwrap_preview(
        img0_unwrap, rx, ry, airpix, params["vmin"], params["vmax"]
    )

    # ---- unwrap the full stack ----
    if not params["parallel"] or ncores <= 1:
        stack_unwrap = np.empty_like(stack_phasecorr)
        for ii in tqdm(range(nprojs), desc="Unwrapping projections"):
            stack_unwrap[ii] = _unwrapping_phase(
                stack_phasecorr[ii], rx, ry, airpix,
                method=method, verbose=verbose,
            )
    else:
        stack_unwrap = _unwrapping_phase_parallel(
            stack_phasecorr, rx, ry, airpix,
            ncores=ncores, method=method, verbose=verbose,
        )
    return stack_unwrap