Source code for toupy.resolution.FSC

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
FOURIER SHELL CORRELATION modules
"""

# standard library
import os
import re
import time

# third party package
import h5py
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fftshift, ifftshift

# local packages
from ..utils import progbar
from ..utils.FFT_utils import fastfftn
from ..utils.funcutils import checkhostname
from ..utils.plot_utils import isnotebook

__all__ = ["FourierShellCorr", "FSCPlot"]


class FourierShellCorr:
    """
    Computes the Fourier Shell Correlation [1]_ between image1 and image2,
    and estimates the resolution based on the threshold function T of
    1 or 1/2 bit.

    Parameters
    ----------
    img1 : ndarray
        A 2- or 3-dimensional array containing the first image
    img2 : ndarray
        A 2- or 3-dimensional array containing the second image. It must
        have the same shape as ``img1``.
    threshold : str, optional
        The option ``onebit`` (or ``one-bit``) means 1 bit threshold with
        ``SNRt = 0.5``, which should be used for two independent
        measurements. The option ``halfbit`` (or ``half-bit``) means 1/2 bit
        threshold with ``SNRt = 0.2071``, which should be used for split
        tomogram. The default option is ``halfbit``.
    ring_thick : int, optional
        Thickness of the frequency rings. Normally the pixels get assigned
        to the closest integer pixel ring in Fourier Domain. With
        ring_thick, each ring gets more pixels and more statistics.
        The default value is ``1``.
    apod_width : int, optional
        Width in pixels of the tapered edge of the apodization windows
        (see :meth:`circle` and :meth:`transverse_apodization`), which are
        applied to the data before the Fourier transform calculations to
        attenuate the border effects. The default value is ``20``.

    Returns
    -------
    FSC : ndarray
        Fourier Shell correlation curve (returned by :meth:`fouriercorr`)
    T : ndarray
        Threshold curve (returned by :meth:`fouriercorr`)

    Raises
    ------
    ValueError
        If the images do not have the same shape or if ``threshold`` is
        not one of the accepted options.
    SystemExit
        If the images are neither 2- nor 3-dimensional.

    Note
    ----
    If 3D images, the first axis is the number of slices, ie.,
    ``[slices, rows, cols]``

    References
    ----------
    .. [1] M. van Heel, M. Schatz, `Fourier shell correlation threshold
      criteria,` Journal of Structural Biology 151, 250-262 (2005)
    """

    @checkhostname
    def __init__(self, img1, img2, threshold="halfbit", ring_thick=1, apod_width=20):
        print("Calling the class FourierShellCorr")
        self.img1 = np.array(img1)
        self.img2 = np.array(img2)
        if self.img1.shape != self.img2.shape:
            raise ValueError("Images must have the same size")

        # get dimensions and indices of the images
        self.n = self.img1.shape
        self.ndim = self.img1.ndim
        if self.ndim == 2:
            self.nr, self.nc = self.n
        elif self.ndim == 3:
            self.ns, self.nr, self.nc = self.n
        else:
            print("Number of dimensions is different from 2 or 3.Exiting...")
            raise SystemExit("Number of dimensions is different from 2 or 3.Exiting...")

        # pixel coordinates of one (rows, cols) plane, centered on the middle
        self.Y, self.X = np.indices((self.nr, self.nc))
        self.Y -= np.round(self.nr / 2).astype(int)
        self.X -= np.round(self.nc / 2).astype(int)

        self.ring_thick = ring_thick  # ring thickness
        print("Using ring thickness of {} pixels".format(ring_thick))
        self.apod_width = apod_width

        if threshold in ("halfbit", "half-bit"):
            print("Using half-bit threshold")
            self.snrt = 0.2071
        elif threshold in ("onebit", "one-bit"):
            print("Using 1-bit threshold")
            self.snrt = 0.5
        else:
            raise ValueError(
                "You need to choose a between 'halfbit' or 'onebit' threshold"
            )
        print("Using SNRt = {}".format(self.snrt))
        print("Input images have {} dimensions".format(self.img1.ndim))

    def nyquist(self):
        """
        Evaluate the Nyquist Frequency

        Returns
        -------
        f : ndarray
            Integer ring indices from ``0`` to ``fnyquist`` (inclusive)
        fnyquist : float
            Half of the largest image dimension, rounded down
        """
        nmax = np.max(self.n)
        fnyquist = np.floor(nmax / 2.0)
        # builtin ``int``: the ``np.int`` alias was removed from NumPy
        f = np.arange(0, fnyquist + 1).astype(int)
        return f, fnyquist

    @staticmethod
    def _scaled_axis(npix, half_max):
        """
        Centered pixel coordinates of an axis with ``npix`` samples, scaled
        so that its edge maps to ``half_max``, in FFT (unshifted) order.
        """
        coords = (
            np.arange(-np.fix(npix / 2.0), np.ceil(npix / 2.0))
            * half_max
            / np.floor(npix / 2.0)
        )
        # bring the central pixel to the corners (important for odd array dimensions)
        return ifftshift(coords)

    def ringthickness(self):
        """
        Define indexes for ring_thick

        Returns
        -------
        index : ndarray
            Integer array with the shape of the images giving, for each
            Fourier pixel, the rounded radial distance to the zero frequency
        """
        half_max = np.floor(np.max(self.n) / 2.0)
        x = self._scaled_axis(self.nc, half_max)
        y = self._scaled_axis(self.nr, half_max)
        if self.ndim == 2:
            # meshgriding
            grids = np.meshgrid(x, y)
        elif self.ndim == 3:
            z = self._scaled_axis(self.ns, half_max)
            # meshgriding; with the default 'xy' indexing the output has
            # shape (ns, nr, nc)
            grids = np.meshgrid(y, z, x)
        # sum of the squares independent of ndim
        sumsquares = np.zeros_like(grids[0])
        for grid in grids:
            sumsquares += grid ** 2
        # builtin ``int``: the ``np.int`` alias was removed from NumPy
        index = np.round(np.sqrt(sumsquares)).astype(int)
        return index

    def apodization(self):
        """
        Compute the Hanning window of the size of the data for the apodization

        Returns
        -------
        window : ndarray
            Separable Hanning window with the shape of the images

        Note
        ----
        This method does not depend on the parameter ``apod_width`` from
        the class
        """
        if self.ndim == 2:
            window = np.outer(np.hanning(self.nr), np.hanning(self.nc))
        elif self.ndim == 3:
            window1 = np.hanning(self.ns)
            windowaxial = np.outer(np.hanning(self.nr), np.hanning(self.nc))
            # window[s, r, c] = windowaxial[r, c] * window1[s], obtained by
            # broadcasting instead of stacking and swapping axes
            window = windowaxial[np.newaxis, :, :] * window1[:, np.newaxis, np.newaxis]
        else:
            print("Number of dimensions is different from 2 or 3. Exiting...")
            raise SystemExit(
                "Number of dimensions is different from 2 or 3. Exiting..."
            )
        return window

    def circle(self):
        """
        Create circle with apodized edges

        Returns
        -------
        t : ndarray
            Array of shape ``(nr, nc)`` equal to 1 inside the radius
            ``Rmax - apod_width``, 0 outside ``Rmax`` and following a
            raised-cosine taper in between, where ``Rmax`` is half of the
            largest side of the plane.
        """
        self.axial_apod = self.apod_width
        R = np.sqrt(self.X ** 2 + self.Y ** 2)
        Rmax = np.round(np.max(R.shape) / 2.0)
        maskout = R < Rmax
        if self.axial_apod == 0:
            # no taper requested: hard-edged disc. The taper formula below
            # would divide by zero and fill the outside of the disc with NaN
            return maskout.astype(float)
        t = (
            maskout
            * (1 - np.cos(np.pi * (R - Rmax - 2 * self.axial_apod) / self.axial_apod))
            / 2.0
        )
        t[R < (Rmax - self.axial_apod)] = 1
        return t

    def _tapered_window_1d(self, npix):
        """
        1D window of length ``npix`` equal to 1 in the middle and with
        cosine tapers of width ``self.transv_apod`` at both ends.
        """
        width = self.transv_apod
        positions = fftshift(np.arange(npix))
        window1D = (
            1.0
            + np.cos(
                2
                * np.pi
                * (positions - np.floor((npix - 2 * width - 1) / 2))
                / (1 + 2 * width)
            )
        ) / 2.0
        # flat top between the two tapers (an empty slice when width == 0)
        window1D[width:-width] = 1
        return window1D

    def transverse_apodization(self):
        """
        Compute a tapered Hanning-like window of the size of the data for
        the apodization

        Returns
        -------
        window : ndarray or list of ndarray
            For 2D data, the ``(nr, nc)`` window. For 3D data, a list with
            the ``(ns, nr)`` and the ``(ns, nc)`` windows.
        """
        print("Calculating the transverse apodization")
        self.transv_apod = self.apod_width
        if self.ndim == 2:
            window = np.outer(
                self._tapered_window_1d(self.nr), self._tapered_window_1d(self.nc)
            )
        elif self.ndim == 3:
            window1D1 = self._tapered_window_1d(self.ns)
            window1D2 = self._tapered_window_1d(self.nr)
            window1D3 = self._tapered_window_1d(self.nc)
            window = [np.outer(window1D1, window1D2), np.outer(window1D1, window1D3)]
        return window

    def fouriercorr(self):
        """
        Method to compute FSC and threshold

        The images are apodized, a figure with the apodized images (the
        central sagittal slice for 3D data) is displayed, and the
        correlation is accumulated ring by ring in Fourier space.

        Returns
        -------
        FSC : ndarray
            Fourier Shell correlation curve, one value per ring of
            :meth:`nyquist`
        T : ndarray
            Threshold curve, one value per ring of :meth:`nyquist`
        """
        # Apodization
        print("Performing the apodization")
        circular_region = self.circle()
        if self.ndim == 2:
            print("Apodization in 2D")
            if self.snrt == 0.2071:
                self.window = circular_region
            elif self.snrt == 0.5:
                self.window = self.transverse_apodization()
            img1_apod = self.img1 * self.window
            img2_apod = self.img2 * self.window
        elif self.ndim == 3:
            if self.apod_width == 0:
                self.window = 1
            else:
                print("Apodization in 3D. This takes time and memory...")
                p0 = time.time()
                window3D = self.transverse_apodization()
                # window[s, r, c] =
                #     circle[r, c] * window3D[0][s, r] * window3D[1][s, c]
                # NOTE(review): both window3D entries contain the slice-axis
                # taper, so it is applied twice; kept as is -- confirm
                # whether this is intended.
                self.window = (
                    circular_region[np.newaxis, :, :] * window3D[0][:, :, np.newaxis]
                ) * window3D[1][:, np.newaxis, :]
                print("Done. Time elapsed: {:.02f}s".format(time.time() - p0))
            # sagital slices
            slicenum = np.round(self.nr / 2).astype("int")
            img1_apod = (self.window * self.img1)[:, slicenum, :]
            img2_apod = (self.window * self.img2)[:, slicenum, :]

        # display image
        fig1 = plt.figure(1)
        ax1 = fig1.add_subplot(121)
        ax2 = fig1.add_subplot(122)
        ax1.imshow(img1_apod, cmap="bone", interpolation="none")
        ax1.set_title("image1")
        ax1.set_axis_off()
        ax2.imshow(img2_apod, cmap="bone", interpolation="none")
        ax2.set_title("image2")
        ax2.set_axis_off()
        if isnotebook():
            display.display(fig1)
            display.display(fig1.canvas)
        else:
            plt.show(block=False)

        # FSC computation
        print("Calling method fouriercorr from the class FourierShellCorr")
        p1 = time.time()
        F1 = fastfftn(self.img1 * self.window)  # FFT of the first image
        F2 = fastfftn(self.img2 * self.window)  # FFT of the second image
        index = self.ringthickness()  # index for the ring thickness
        f, fnyquist = self.nyquist()  # Frequency and Nyquist Frequency

        # initializing variables
        print("Initializing...")
        # builtin ``float``: the ``np.float`` alias was removed from NumPy
        C = np.empty(f.shape, dtype=float)
        C1 = np.empty(f.shape, dtype=float)
        C2 = np.empty(f.shape, dtype=float)
        npts = np.zeros_like(f)

        print("Calculating the correlation...")
        half_thick = self.ring_thick / 2
        for ii in f:
            strbar = "Normalized frequency: {:.2f}".format((ii + 1) / fnyquist)
            # boolean mask of the Fourier pixels belonging to this ring
            if self.ring_thick == 0 or self.ring_thick == 1:
                ring = index == ii
            else:
                ring = (index >= (ii - half_thick)) & (index <= (ii + half_thick))
            auxF1 = F1[ring]
            auxF2 = F2[ring]
            C[ii] = np.abs((auxF1 * np.conj(auxF2)).sum())
            C1[ii] = np.abs((auxF1 * np.conj(auxF1)).sum())
            C2[ii] = np.abs((auxF2 * np.conj(auxF2)).sum())
            npts[ii] = auxF1.shape[0]
            progbar(ii + 1, len(f), strbar)
        print("\r")

        # The correlation
        FSC = C / (np.sqrt(C1 * C2))

        # Threshold computation; the tiny offset keeps every term finite
        # for a ring without pixels
        sqrt_npts = np.sqrt(npts + np.spacing(1))
        Tnum = self.snrt + (2 * np.sqrt(self.snrt) / sqrt_npts) + 1 / sqrt_npts
        Tden = self.snrt + (2 * np.sqrt(self.snrt) / sqrt_npts) + 1
        T = Tnum / Tden
        print("Done. Time elapsed: {:.02f}s".format(time.time() - p1))
        return FSC, T
class FSCPlot(FourierShellCorr):
    """
    Upper level object to plot the FSC and threshold curves

    The FSC and threshold curves are computed when the object is created
    and stored in the attributes ``FSC`` and ``T``, together with the ring
    frequencies ``f`` and the Nyquist frequency ``fnyquist``. The curves are
    drawn and returned by :meth:`plot`.

    Parameters
    ----------
    img1 : ndarray
        A 2- or 3-dimensional array containing the first image
    img2 : ndarray
        A 2- or 3-dimensional array containing the second image
    threshold : str, optional
        The option ``onebit`` means 1 bit threshold with ``SNRt = 0.5``,
        which should be used for two independent measurements. The option
        ``halfbit`` means 1/2 bit threshold with ``SNRt = 0.2071``, which
        should be used for split tomogram. The default option is
        ``halfbit``.
    ring_thick : int, optional
        Thickness of the frequency rings. Normally the pixels get assigned
        to the closest integer pixel ring in Fourier Domain. With
        ring_thick, each ring gets more pixels and more statistics.
        The default value is ``1``.
    apod_width : int, optional
        Width in pixels of the edges apodization applied to the data
        before the Fourier transform calculations to attenuate the border
        effects. The default value is ``20``.
    """

    def __init__(self, img1, img2, threshold="halfbit", ring_thick=1, apod_width=20):
        print("calling the class FSCplot")
        super().__init__(img1, img2, threshold, ring_thick, apod_width)
        # the curves are computed once, at construction time
        self.FSC, self.T = FourierShellCorr.fouriercorr(self)
        self.f, self.fnyquist = FourierShellCorr.nyquist(self)

    def plot(self):
        """
        Plot the FSC and threshold curves and save the figure

        The figure is saved in the current working directory as
        ``FSC_2D.png`` or ``FSC_3D.png`` depending on the dimensionality
        of the images.

        Returns
        -------
        fn : ndarray
            A 1-dimensional array containing the frequencies normalized by
            the Nyquist frequency
        T : ndarray
            A 1-dimensional array containing the threshold curve
        FSC : ndarray
            A 1-dimensional array containing the Fourier Shell correlation
            curve
        """
        print("calling method plot from the class FSCplot")
        fn = self.f / self.fnyquist
        T = self.T
        FSC = self.FSC.real

        plt.figure(2)
        plt.clf()
        plt.plot(fn, FSC, "-b", label="FSC")
        plt.legend()
        if self.snrt == 0.2071:
            plt.plot(fn, T, "--r", label="1/2 bit threshold")
            plt.legend()
        elif self.snrt == 0.5:
            plt.plot(fn, T, "--r", label="1 bit threshold")
            plt.legend()
        else:
            plotT = plt.plot(fn, T)
            # labels must be a sequence of strings: a bare string would be
            # iterated character by character
            plt.legend(plotT, ["Threshold SNR = %g " % self.snrt], loc="center")
        plt.xlim(0, 1)
        plt.ylim(0, 1.1)
        plt.xlabel("Spatial frequency/Nyquist")
        plt.ylabel("Magnitude")
        if isnotebook():
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            plt.show(block=False)
        if self.img1.ndim == 2:
            plt.savefig("FSC_2D.png", bbox_inches="tight")
        elif self.img1.ndim == 3:
            plt.savefig("FSC_3D.png", bbox_inches="tight")
        return fn, T, FSC