Source code for sarcasm.training_data_generation

# -*- coding: utf-8 -*-
# Copyright (c) 2025 University Medical Center Göttingen, Germany.
# All rights reserved.
#
# Patent Pending: DE 10 2024 112 939.5
# SPDX-License-Identifier: LicenseRef-Proprietary-See-LICENSE
#
# This software is licensed under a custom license. See the LICENSE file
# in the root directory for full details.
#
# **Commercial use is prohibited without a separate license.**
# Contact MBM ScienceBridge GmbH (https://sciencebridge.de/en/) for licensing.


import os.path
from typing import Union

import numpy as np
import skimage
import tifffile
import torch
import torch.nn.functional as F
from bio_image_unet.unet import Predict
from matplotlib import pyplot as plt
from scipy import ndimage
from scipy.interpolate import griddata
from scipy.ndimage import label
from skimage.draw import line
from skimage.morphology import skeletonize

from sarcasm import PlotUtils, Structure, Utils


class TrainingDataGenerator:
    """
    Generator of training targets for a single microscopy image:

    - Z-band mask
    - M-band mask
    - Sarcomere orientation field
    - Sarcomere mask
    """

    def __init__(self, image_path: str, output_dirs: dict, pixelsize: float = None) -> None:
        """
        Load one tiff image and remember where its training targets are stored.

        Parameters
        ----------
        image_path : str
            Path to microscopy image tiff file.
        output_dirs : dict
            Dictionary mapping each target name to its output directory, e.g.
            {'zbands': 'D:/training_data/zbands/', 'mbands': 'D:/training_data/mbands/',
            'orientation': 'D:/training_data/orientation/',
            'sarcomere_mask': 'D:/training_data/sarcomere_mask/'}.
            File names are appended to these strings directly, so each entry
            has to end with a path separator.
        pixelsize : float
            Pixel size of image in µm. If None, the first element returned by
            ``self.get_pixel_size(image_path)`` is used.
        """
        # input image and its file name (targets are stored under the same name)
        self.image_path = image_path
        self.basename = os.path.basename(self.image_path)
        self.image = tifffile.imread(image_path)
        self.shape = self.image.shape

        # output locations; must be set before any dynamic attribute lookup
        self.output_dirs = output_dirs

        # filled by wavelet_analysis()
        self.wavelet_dict = None

        # NOTE(review): get_pixel_size is not defined in the visible part of this class;
        # confirm it is provided elsewhere, otherwise pixelsize=None raises AttributeError.
        self.pixelsize = self.get_pixel_size(image_path)[0] if pixelsize is None else pixelsize
[docs] def __getattr__(self, attr): if attr in self.output_dirs: value = tifffile.imread(self.output_dirs[attr] + self.basename) return value raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")
[docs] def __dir__(self): """Include dynamic attributes in autocompletion""" return super().__dir__() + list(self.output_dirs.keys())
[docs] def predict_zbands(self, model_path: str, network: str = 'Unet_v0', patch_size: tuple[int, int] = (1024, 1024)): """ Predict sarcomere Z-bands using pre-trained U-Net model. This is optional, alternatively manually annotated Z-band masks in 'zbands' directory can be used. Parameters ---------- model_path : str Path of U-Net model for sarcomere Z-band detection. network : str Model type, choose from models in bio-image-unet package. Defaults to 'UNet_v0'. patch_size : Tuple[int, int] Patch size for prediction. Sizes should be multiples of 16. """ Predict(self.image, result_name=self.output_dirs['zbands'] + self.basename, model_params=model_path, network=network, resize_dim=patch_size)
[docs] def wavelet_analysis(self, kernel: str = 'half_gaussian', size: float = 3.0, minor: float = 0.33, major: float = 1.0, len_lims: tuple[float, float] = (1.45, 2.7), len_step: float = 0.05, orient_lims: tuple[float, float] = (-90, 90), orient_step: float = 10, add_negative_center_kernel: bool = False, patch_size: int = 1024, score_threshold: float = 0.25, abs_threshold: bool = True, gating: bool = True, load_mbands: bool = False, dtype: Union[torch.dtype, str] = 'auto', save_memory: bool = False, device: torch.device = torch.device('cpu')): """ AND-gated double wavelet analysis of sarcomere length and orientation. Parameters ---------- kernel : str, optional Filter kernel - 'gaussian' for bivariate Gaussian kernel - 'half_gaussian' for univariate Gaussian in minor axis direction and step function in major axis direction - 'binary' for binary step function in both directions Defaults to 'half_gaussian'. size : float, optional Size of wavelet filters (in µm), needs to be larger than the upper limit of len_lims. Defaults to 3.0. minor : float, optional Minor axis width in µm, quantified by full width at half-maximum (FWHM, 2.33 * sigma in our paper), should match the thickness of Z-bands, for kernel='gaussian' and kernel='half_gaussian'. Defaults to 0.33. major : float, optional Major axis width (parameter 'w' in our paper) in µm, should match the width of Z-bands. Full width at half-maximum (FWHM) for kernel='gaussian' and full width for kernel='half_gaussian'. Defaults to 1.0. len_lims : tuple(float, float), optional Limits of lengths / wavelet distances in µm, range of sarcomere lengths. Defaults to (1.3, 2.6). len_step : float, optional Step size of sarcomere lengths in µm. Defaults to 0.05. orient_lims : tuple(float, float), optional Limits of sarcomere orientation angles in degrees. Defaults to (-90, 90). orient_step : float, optional Step size of orientation angles in degrees. Defaults to 10. 
add_negative_center_kernel : bool, optional Whether to add a negative kernel in the middle of the two wavelets, to avoid detection of two Z-bands two sarcomeres apart as sarcomere, only for kernel='gaussian'. Defaults to False. patch_size : int, optional Patch size for wavelet analysis, default is 1024 pixels. Adapt to GPU storage. Defaults to 1024. score_threshold : float, optional Threshold score for clipping of length and orientation map (if abs_threshold=False, score_threshold is percentile (e.g., 90) for adaptive thresholding). Defaults to 0.25. abs_threshold : bool, optional If True, absolute threshold value is applied; if False, adaptive threshold based on percentile. Defaults to True. gating : bool, optional If True, AND-gated wavelet filtering is used. If False, both wavelet filters are applied jointly. Defaults to True. load_mbands : bool, optional If True, manually curated M-band mask is loaded. dtype : torch.dtype or str, optional Specify torch data type (torch.float32 or torch.float16), 'auto' chooses float16 for cuda and mps, and float32 for cpu. Defaults to 'auto'. 
device : torch.device Device for 2D convolutions (torch.device('cuda') for GPU, torch.device('mps') for Apple Silicon, torch.device('cpu') for CPU) """ assert size > 1.1 * len_lims[1], (f"The size of wavelet filter {size} is too small for the maximum sarcomere " f"length {len_lims[1]}") # select precision if device.type == 'cpu': dtype = torch.float32 elif device.type == 'cuda' or device.type == 'mps': dtype = torch.float16 zbands = tifffile.imread(self.output_dirs['zbands'] + self.basename) # create filter bank bank, len_range, orient_range = TrainingDataGenerator.create_wavelet_bank(pixelsize=self.pixelsize, kernel=kernel, size=size, minor=minor, major=major, len_lims=len_lims, len_step=len_step, orient_lims=orient_lims, orient_step=orient_step, add_negative_center_kernel=add_negative_center_kernel) len_range_tensor = torch.from_numpy(len_range).to(device).to(dtype=dtype) orient_range_tensor = torch.from_numpy(np.radians(orient_range)).to(device).to(dtype=dtype) # convolve zbands with wavelet kernels result = TrainingDataGenerator.convolve_image_with_bank(zbands, bank, device=device, gating=gating, dtype=dtype, save_memory=save_memory, patch_size=patch_size) # argmax (wavelet_sarcomere_length, wavelet_sarcomere_orientation, wavelet_max_score) = TrainingDataGenerator.argmax_wavelets(result, len_range_tensor, orient_range_tensor) # evaluate wavelet results at sarcomere mbands if load_mbands: mbands = self.mbands else: mbands = None (pos_vectors_px, mband_id_vectors, mband_length_vectors, sarcomere_length_vectors, sarcomere_orientation_vectors, max_score_vectors, mbands, mbands_labels, score_threshold) = self.get_sarcomere_vectors_wavelet( wavelet_sarcomere_length, wavelet_sarcomere_orientation, wavelet_max_score, len_range=len_range, mbands=mbands, score_threshold=score_threshold, abs_threshold=abs_threshold) # empty memory del result if torch.cuda.is_available(): torch.cuda.empty_cache() self.wavelet_dict = { 'sarcomere_length': wavelet_sarcomere_length, 
'sarcomere_orientation': wavelet_sarcomere_orientation, 'max_score': wavelet_max_score, 'pos_vectors_px': pos_vectors_px, 'sarcomere_length_vectors': sarcomere_length_vectors, 'mband_length_vectors': mband_length_vectors, 'mband_id_vectors': mband_id_vectors, 'sarcomere_orientation_vectors': sarcomere_orientation_vectors, 'max_score_vectors': max_score_vectors, 'mbands': mbands, 'mbands_labels': mbands_labels} # save mbands as tiff tifffile.imwrite(self.output_dirs['mbands'] + self.basename, mbands)
[docs] def create_sarcomere_mask(self, dilation_radius): """ Create binary sarcomere mask from sarcomere vectors. Parameters ---------- dilation_radius : float Dilation radius to dilate sarcomere mask and close small gaps, in µm. """ # get sarcomere vectors pos_vectors_px = self.wavelet_dict['pos_vectors_px'] sarcomere_orientation_vectors = self.wavelet_dict['sarcomere_orientation_vectors'] sarcomere_length_vectors = self.wavelet_dict['sarcomere_length_vectors'] # calculate sarcomere mask if len(pos_vectors_px) > 0: sarcomere_mask = Structure.sarcomere_mask(pos_vectors_px * self.pixelsize, -sarcomere_orientation_vectors, sarcomere_length_vectors, shape=self.shape, pixelsize=self.pixelsize, dilation_radius=dilation_radius) else: sarcomere_mask = np.zeros(self.shape, dtype='bool') # save sarcomere mask as tiff tifffile.imwrite(self.output_dirs['sarcomere_mask'] + self.basename, sarcomere_mask)
[docs] def create_orientation_map(self): """ Creates 2D sarcomere orientation map from sarcomere vectors. The 2D shows shows the directions of unit vectors pointing from M-bands to Z-bands. Undefined regions have np.nan values. Returns ------- orientation_map : numpy.ndarray A 2D array with values reflecting the local sarcomere orientation angle. """ # Extract data from wavelet_dict pos_vectors = self.wavelet_dict['pos_vectors_px'] sarcomere_orientation_vectors = self.wavelet_dict['sarcomere_orientation_vectors'] orientation_vectors = np.asarray([np.sin(sarcomere_orientation_vectors), -np.cos(sarcomere_orientation_vectors)]) sarcomere_length_vectors = self.wavelet_dict['sarcomere_length_vectors'] / self.pixelsize # Calculate endpoints of each vector based on orientation and length ends_0 = pos_vectors.T + orientation_vectors * sarcomere_length_vectors / 2 # End point 1 ends_1 = pos_vectors.T - orientation_vectors * sarcomere_length_vectors / 2 # End point 2 # Initialize output array orientation_map = np.full(self.image.shape, np.nan, dtype='float32') def orientation_angle_line(o, len_line): """ Creates an array with: - First half: o + π - Second half: o """ if len_line < 2: raise ValueError("Length must be at least 2.") midpoint = len_line // 2 return np.concatenate([ np.full(midpoint, o + np.pi), np.full(len_line - midpoint, o) ]) # Populate orientation array for each sarcomere for e0, e1, o in zip(ends_0.T.astype('int'), ends_1.T.astype('int'), sarcomere_orientation_vectors): rr, cc = line(*e0, *e1) # Get pixel coordinates for the line # Check for out-of-bounds coordinates if (np.any(rr < 0) or np.any(cc < 0) or np.any(rr >= self.image.shape[0]) or np.any(cc >= self.image.shape[1])): continue orientation_map[rr, cc] = orientation_angle_line(o, len(cc)) tifffile.imwrite(self.output_dirs['orientation'] + self.basename, orientation_map)
[docs] def smooth_orientation_map(self, window_size: int = 3): """ Smooth orientation angle map using a nanmedian filter. To handle the angle discontinuity from 2 pi -> 0, the orientation angle map is converted to a 2D orientation field, both components are smoothed, and then converted back Parameters ---------- window_size : int Size of smoothing kernel, must be odd integer. """ # load orientation map orientation_map = tifffile.imread(self.output_dirs['orientation'] + self.basename) # convert to orientation field orientation_field = np.stack((np.cos(orientation_map), np.sin(orientation_map))) # smooth both components with custom nanmedian filter orientation_field_smoothed = orientation_field.copy() for i, comp_i in enumerate(orientation_field): orientation_field_smoothed[i] = Utils.nanmedian_filter_numba(comp_i, window_size=window_size) # convert back to orientation map orientation_map_smoothed = np.arctan2(orientation_field_smoothed[1], orientation_field_smoothed[0]) # save smoothed orientation map as tiff tifffile.imwrite(self.output_dirs['orientation'] + self.basename, orientation_map_smoothed)
[docs] def plot_results(self, save_path=None, xlim=None, ylim=None): mosaic = """ ABC DEF """ fig, axs = plt.subplot_mosaic(figsize=(PlotUtils.width_1p5cols, 4), mosaic=mosaic, constrained_layout=True, dpi=300) # image image_i = tifffile.imread(self.image_path) image_i = np.clip(image_i, a_min=np.percentile(image_i, 0.1), a_max=np.percentile(image_i, 99.9)) axs['A'].imshow(image_i, cmap='gray') axs['A'].set_title('Image') # zbands zbands_i = tifffile.imread(self.output_dirs['zbands'] + self.basename) axs['B'].imshow(zbands_i, cmap='gray') axs['B'].set_title('Z-bands') # mbands zbands_i = tifffile.imread(self.output_dirs['mbands'] + self.basename) axs['C'].imshow(zbands_i, cmap='gray') axs['C'].set_title('M-bands') # sarcomere vectors pos_vectors = self.wavelet_dict['pos_vectors_px'] sarcomere_orientation_vectors = self.wavelet_dict['sarcomere_orientation_vectors'] sarcomere_length_vectors = self.wavelet_dict['sarcomere_length_vectors'] / self.pixelsize orientation_vectors = np.asarray( [np.cos(sarcomere_orientation_vectors), np.sin(sarcomere_orientation_vectors)]) # adjust sarcomere lengths to appear correct in quiver plot half_length = sarcomere_length_vectors * 0.5 headaxislength = 4 linewidths = 0.5 color_arrows = 'k' color_points = 'darkgreen' s_points = 5 axs['D'].imshow(self.zbands, cmap='Purples') axs['D'].quiver(pos_vectors[:, 1], pos_vectors[:, 0], -orientation_vectors[0] * half_length, orientation_vectors[1] * half_length, width=linewidths, headaxislength=headaxislength, units='xy', angles='xy', scale_units='xy', scale=1, color=color_arrows, alpha=0.5, label='Sarcomere vectors') axs['D'].quiver(pos_vectors[:, 1], pos_vectors[:, 0], orientation_vectors[0] * half_length, -orientation_vectors[1] * half_length, headaxislength=headaxislength, units='xy', angles='xy', scale_units='xy', scale=1, color=color_arrows, alpha=0.5, width=linewidths) axs['D'].scatter(pos_vectors[:, 1], pos_vectors[:, 0], marker='.', c=color_points, edgecolors='none', s=s_points * 0.5, 
label='Midline pos_vectors') axs['D'].set_title('Sarcomere vectors') # sarcomere mask sarcomere_mask_i = tifffile.imread(self.output_dirs['sarcomere_mask'] + self.basename) axs['E'].imshow(sarcomere_mask_i, cmap='gray') axs['E'].set_title('Sarcomere mask') # sarcomere orientation angle field orientation_i = tifffile.imread(self.output_dirs['orientation'] + self.basename) plot = axs['F'].imshow(orientation_i, cmap='hsv', vmin=-np.pi, vmax=np.pi) axs['F'].set_title('Orientation angle') colorbar = plt.colorbar(ax=axs['F'], mappable=plot, orientation='horizontal', shrink=0.7) colorbar.set_ticks([-np.pi, 0, np.pi]) colorbar.set_ticklabels([r'-$\pi$', '0', r'$\pi$']) [PlotUtils.remove_ticks(axs[key]) for key in axs.keys()] if xlim: [axs[key].set_xlim(xlim) for key in axs.keys()] if ylim: [axs[key].set_ylim(ylim[::-1]) for key in axs.keys()] if save_path: fig.savefig(save_path, dpi=500) plt.show()
[docs] @staticmethod def binary_kernel(d: float, sigma: float, width: float, orient: float, size: float, pixelsize: float, mode: str = 'both') -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]: """ Returns binary kernel pair for AND-gated double wavelet analysis. Parameters ---------- d : float Distance between two wavelets. sigma : float Minor axis width of single wavelets. width : float Major axis width of single wavelets. orient : float Rotation orientation in degrees. size : Tuple[float, float] Size of kernel in µm. pixelsize : float Pixelsize in µm. mode : str, optional 'separate' returns two separate kernels, 'both' returns a single kernel. Defaults to 'both'. Returns ------- Union[np.ndarray, Tuple[np.ndarray, np.ndarray]] The generated binary kernel(s). """ # meshgrid size_pixel = TrainingDataGenerator.round_up_to_odd(size / pixelsize) _range = np.linspace(-size / 2, size / 2, size_pixel, dtype='float32') x_mesh, y_mesh = np.meshgrid(_range, _range) # build kernel kernel0 = np.zeros_like(x_mesh) kernel0[np.abs((-x_mesh - d / 2)) < sigma / 2] = 1 kernel0[np.abs(y_mesh) > width / 2] = 0 kernel1 = np.zeros_like(x_mesh) kernel1[np.abs((x_mesh - d / 2)) < sigma / 2] = 1 kernel1[np.abs(y_mesh) > width / 2] = 0 # Normalize the kernels kernel0 /= np.sum(kernel0) kernel1 /= np.sum(kernel1) kernel0 = ndimage.rotate(kernel0, orient, reshape=False, order=3) kernel1 = ndimage.rotate(kernel1, orient, reshape=False, order=3) if mode == 'separate': return kernel0, kernel1 elif mode == 'both': return kernel0 + kernel1 else: raise ValueError(f'Kernel mode {mode} not defined.')
[docs] @staticmethod def gaussian_kernel(dist: float, minor: float, major: float, orient: float, size: float, pixelsize: float, mode: str = 'both', add_negative_center_kernel: bool = False) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: """ Returns gaussian kernel pair for AND-gated double wavelet analysis Parameters ---------- dist : float Distance between two wavelets minor : float Minor axis width of single wavelets in µm major : float Major axis width of single wavelets in µm orient : float Rotation orientation in degree size : float Size of kernel in µm pixelsize : float Pixelsize in µm mode : str, optional 'separate' returns two separate kernels, 'both' returns single kernel add_negative_center_kernel : bool, optional Whether to add a negative kernel in the middle of the two wavelets, to avoid detection of two Z-bands two sarcomeres apart as sarcomere Returns ------- tuple[np.ndarray, np.ndarray] Gaussian kernel pair """ # Transform FWHM to sigma minor_sigma = minor / 2.355 major_sigma = major / 2.355 # Calculate the size of the kernel in pixels and create meshgrid size_pixel = TrainingDataGenerator.round_up_to_odd(size / pixelsize) _range = np.linspace(-size / 2, size / 2, size_pixel, dtype='float32') x_mesh, y_mesh = np.meshgrid(_range, _range) # Create the first Gaussian kernel kernel0 = (1 / (2 * np.pi * minor_sigma * major_sigma) * np.exp( -((x_mesh - dist / 2) ** 2 / (2 * minor_sigma ** 2) + y_mesh ** 2 / (2 * major_sigma ** 2)))) # Create the second Gaussian kernel kernel1 = (1 / (2 * np.pi * minor_sigma * major) * np.exp( -((x_mesh + dist / 2) ** 2 / (2 * minor_sigma ** 2) + y_mesh ** 2 / (2 * major_sigma ** 2)))) # Create the middle Gaussian kernel kernelmid = (1 / (2 * np.pi * minor_sigma * major_sigma) * np.exp( -(x_mesh ** 2 / (2 * minor_sigma ** 2) + y_mesh ** 2 / (2 * major_sigma ** 2)))) # Normalize the kernels kernel0 /= np.sum(kernel0) kernel1 /= np.sum(kernel1) kernelmid /= np.sum(kernelmid) kernelmid *= -1 if 
add_negative_center_kernel: kernel0 += kernelmid kernel1 += kernelmid # Rotate the kernels kernel0 = ndimage.rotate(kernel0, orient, reshape=False, order=2) kernel1 = ndimage.rotate(kernel1, orient, reshape=False, order=2) # Return the kernels based on the mode if mode == 'separate': return kernel0, kernel1 elif mode == 'both': return kernel0 + kernel1 else: raise ValueError(f'Kernel model {mode} not defined!')
[docs] @staticmethod def half_gaussian_kernel(dist: float, minor: float, major: float, orient: float, size: float, pixelsize: float, mode: str = 'both', add_negative_center_kernel: bool = False) -> Union[ tuple[np.ndarray, np.ndarray], np.ndarray]: """ Returns kernel pair for AND-gated double wavelet analysis with univariate Gaussian profile in longitudinal minor axis direction and step function in lateral major axis direction Parameters ---------- dist : float Distance between two wavelets minor : float Minor axis width, in full width at half maximum (FWHM), of single wavelets in µm major : float Major axis width of single wavelets in µm. orient : float Rotation orientation in degree size : float Size of kernel in µm pixelsize : float Pixelsize in µm mode : str, optional 'separate' returns two separate kernels, 'both' returns single kernel add_negative_center_kernel : bool, optional Whether to add a negative kernel in the middle of the two wavelets, to avoid detection of two Z-bands two sarcomeres apart as sarcomere Returns ------- tuple[np.ndarray, np.ndarray] Gaussian kernel pair """ # Transform FWHM to sigma minor_sigma = minor / 2.355 major_sigma = major / 2.355 # Calculate the size of the kernel in pixels and create meshgrid size_pixel = TrainingDataGenerator.round_up_to_odd(size / pixelsize) _range = np.linspace(-size / 2, size / 2, size_pixel, dtype='float32') x_mesh, y_mesh = np.meshgrid(_range, _range) # Create the first Gaussian kernel kernel0 = 1 / (np.sqrt(2 * np.pi) * minor_sigma) * np.exp(-((x_mesh - dist / 2) ** 2 / (2 * minor_sigma ** 2))) # Create the second Gaussian kernel kernel1 = 1 / (np.sqrt(2 * np.pi) * minor_sigma) * np.exp(-((x_mesh + dist / 2) ** 2 / (2 * minor_sigma ** 2))) # Create the middle Gaussian kernel kernelmid = 1 / (np.sqrt(2 * np.pi) * minor_sigma) * np.exp(-(x_mesh ** 2 / (2 * minor_sigma ** 2))) # set to 0 where wider than major axis kernel0[np.abs(y_mesh) > major / 2] = 0 kernel1[np.abs(y_mesh) > major / 2] = 0 
kernelmid[np.abs(y_mesh) > major / 2] = 0 # Normalize the kernels kernel0 /= np.sum(kernel0) kernel1 /= np.sum(kernel1) kernelmid /= np.sum(kernelmid) kernelmid *= -1 if add_negative_center_kernel: kernel0 += kernelmid kernel1 += kernelmid # Rotate the kernels kernel0 = ndimage.rotate(kernel0, orient, reshape=False, order=2) kernel1 = ndimage.rotate(kernel1, orient, reshape=False, order=2) # Return the kernels based on the mode if mode == 'separate': return kernel0, kernel1 elif mode == 'both': return kernel0 + kernel1 else: raise ValueError(f'Kernel mode {mode} not defined!')
[docs] @staticmethod def round_up_to_odd(f: float) -> int: """ Rounds float up to the next odd integer. Parameters ---------- f : float The input float number. Returns ------- int The next odd integer. """ return int(np.ceil(f) // 2 * 2 + 1)
[docs] @staticmethod def create_wavelet_bank(pixelsize: float, kernel: str = 'half_gaussian', size: float = 3, minor: float = 0.15, major: float = 0.5, len_lims: tuple[float, float] = (1.3, 2.5), len_step: float = 0.025, orient_lims: tuple[float, float] = (-90, 90), orient_step: float = 5, add_negative_center_kernel: bool = False) -> list[np.ndarray]: """ Returns bank of double wavelets. Parameters ---------- pixelsize : float Pixel size in µm. kernel : str, optional Filter kernel ('gaussian' for double Gaussian kernel, 'binary' for binary double-line, 'half_gaussian' for half Gaussian kernel). Defaults to 'half_gaussian'. size : float, optional Size of kernel in µm. Defaults to 3. minor : float, optional Minor axis width of single wavelets. Defaults to 0.15. major : float, optional Major axis width of single wavelets. Defaults to 0.5. len_lims : Tuple[float, float], optional Limits of lengths / wavelet distances in µm. Defaults to (1.3, 2.5). len_step : float, optional Step size in µm. Defaults to 0.025. orient_lims : Tuple[float, float], optional Limits of orientation angle in degrees. Defaults to (-90, 90). orient_step : float, optional Step size in degrees. Defaults to 5. add_negative_center_kernel : bool, optional Whether to add a negative kernel in the middle of the two wavelets, to avoid detection of two Z-bands two sarcomeres apart as sarcomere, only for kernel=='gaussian' or 'half_gaussian. Defaults to False. Returns ------- List[np.ndarray] Bank of double wavelets. 
""" len_range = np.arange(len_lims[0] - len_step, len_lims[1] + len_step, len_step, dtype='float32') orient_range = np.arange(orient_lims[0], orient_lims[1], orient_step, dtype='float32') size_pixel = TrainingDataGenerator.round_up_to_odd(size / pixelsize) bank = np.zeros((len_range.shape[0], orient_range.shape[0], 2, size_pixel, size_pixel)) for i, d in enumerate(len_range): for j, orient in enumerate(orient_range): if kernel == 'gaussian': bank[i, j] = TrainingDataGenerator.gaussian_kernel(d, minor, major, orient=orient, size=size, pixelsize=pixelsize, mode='separate', add_negative_center_kernel=add_negative_center_kernel) elif kernel == 'half_gaussian': bank[i, j] = TrainingDataGenerator.half_gaussian_kernel(d, minor, major, orient=orient, size=size, pixelsize=pixelsize, mode='separate', add_negative_center_kernel=add_negative_center_kernel) elif kernel == 'binary': bank[i, j] = TrainingDataGenerator.binary_kernel(d, minor, major, orient, size, pixelsize, mode='separate') else: raise ValueError("Unsupported kernel type. Choose from 'gaussian', 'binary', or 'half_gaussian'.") return bank, len_range, orient_range
[docs] @staticmethod def convolve_image_with_bank(image: np.ndarray, bank: np.ndarray, device: torch.device, gating: bool = True, dtype: torch.dtype = torch.float16, save_memory: bool = False, patch_size: int = 512) -> torch.Tensor: """ AND-gated double-wavelet convolution of image using kernels from filter bank, with merged functionality. Processes the image in smaller overlapping patches to manage GPU memory usage and avoid edge effects. Parameters ---------- image : np.ndarray Input image to be convolved. bank : np.ndarray Filter bank containing the wavelet kernels. device : torch.device Device on which to perform the computation (e.g., 'cuda', 'mps' or 'cpu'). gating : bool, optional Whether to use AND-gated double-wavelet convolution. Default is True. dtype : torch.dtype, optional Data type for the tensors. Default is torch.float16. save_memory : bool, optional Whether to save memory by moving intermediate results to CPU. Default is False. patch_size : int, optional Size of the patches to process the image in. Default is 512. Returns ------- torch.Tensor The result of the convolution, reshaped to match the input image dimensions. 
""" # Convert image to dtype and normalize image_torch = torch.from_numpy((image)).to(dtype=dtype).to(device).view(1, 1, image.shape[0], image.shape[1]) kernel_size = bank.shape[3] margin = kernel_size // 2 def process_patch(patch: torch.Tensor) -> torch.Tensor: with torch.no_grad(): if gating: # Convert filters to float32 bank_0, bank_1 = bank[:, :, 0], bank[:, :, 1] filters_torch_0 = torch.from_numpy(bank_0).to(dtype=dtype).to(device).view( bank_0.shape[0] * bank_0.shape[1], 1, bank_0.shape[2], bank_0.shape[3]) filters_torch_1 = torch.from_numpy(bank_1).to(dtype=dtype).to(device).view( bank_1.shape[0] * bank_1.shape[1], 1, bank_1.shape[2], bank_1.shape[3]) # Perform convolutions if save_memory: res0 = F.conv2d(patch, filters_torch_0, padding='same').to('cpu') del filters_torch_0 res1 = F.conv2d(patch, filters_torch_1, padding='same').to('cpu') del filters_torch_1 else: res0 = F.conv2d(patch, filters_torch_0, padding='same') del filters_torch_0 res1 = F.conv2d(patch, filters_torch_1, padding='same') del filters_torch_1 del patch # Multiply results as torch tensors result = res0 * res1 del res0, res1 else: # Combine filters combined_filters = bank[:, :, 0] + bank[:, :, 1] filters_torch = torch.from_numpy(combined_filters).to(dtype=dtype).to(device).view( combined_filters.shape[0] * combined_filters.shape[1], 1, combined_filters.shape[2], combined_filters.shape[3]) # Perform convolution if save_memory: result = F.conv2d(patch, filters_torch, padding='same').to('cpu') else: result = F.conv2d(patch, filters_torch, padding='same') return result # Process image in patches with overlap if image.shape[0] <= patch_size and image.shape[1] <= patch_size: return process_patch(image_torch).view(bank.shape[0], bank.shape[1], image.shape[0], image.shape[1]) output = torch.zeros(bank.shape[0], bank.shape[1], image.shape[0], image.shape[1], dtype=dtype, device=device) for i in range(0, image.shape[0], patch_size - 2 * margin): for j in range(0, image.shape[1], patch_size - 2 * 
margin): patch = image_torch[:, :, max(i - margin, 0):min(i + patch_size + margin, image.shape[0]), max(j - margin, 0):min(j + patch_size + margin, image.shape[1])] patch_result = process_patch(patch).view(bank.shape[0], bank.shape[1], patch.shape[2], patch.shape[3]) # Determine the region to place the patch result start_i = i end_i = min(i + patch_size, image.shape[0]) start_j = j end_j = min(j + patch_size, image.shape[1]) # Calculate the corresponding region in the patch result patch_start_i = 0 if i == 0 else margin patch_end_i = (end_i - start_i) + patch_start_i patch_start_j = 0 if j == 0 else margin patch_end_j = (end_j - start_j) + patch_start_j output[:, :, start_i:end_i, start_j:end_j] = patch_result[:, :, patch_start_i:patch_end_i, patch_start_j:patch_end_j] return output.view(bank.shape[0], bank.shape[1], image.shape[0], image.shape[1])
[docs] @staticmethod def argmax_wavelets(result: torch.Tensor, len_range: torch.Tensor, orient_range: torch.Tensor) -> tuple[ np.ndarray, np.ndarray, np.ndarray]: """ Compute the argmax of wavelet convolution results to extract length, orientation, and maximum score map. This function processes the result of a wavelet convolution operation to determine the optimal length and orientation for each position in the input image. It leverages GPU acceleration for efficient computation and returns the results as NumPy arrays. Parameters ---------- result : torch.Tensor The result tensor from a wavelet convolution operation, expected to be on a GPU device. Shape is expected to be (num_orientations, num_lengths, height, width). len_range : torch.Tensor A tensor containing the different lengths used in the wavelet bank. Shape: (num_lengths,). orient_range : torch.Tensor A tensor containing the different orientation angles used in the wavelet bank, in degrees. Shape: (num_orientations,). Returns ------- length_np : np.ndarray A 2D array of the optimal length for each position in the input image. Shape: (height, width). orient_np : np.ndarray A 2D array of the optimal orientation (in radians) for each position in the input image. Shape: (height, width). max_score_np : np.ndarray A 2D array of the maximum convolution score for each position in the input image. Shape: (height, width). 
""" # Keep the reshaping and max operation on the GPU result_reshaped = result.permute(2, 3, 0, 1).view(result.shape[2] * result.shape[3], -1) max_score, argmax = torch.max(result_reshaped, 1) max_score = max_score.view(result.shape[2], result.shape[3]) # Calculate indices for lengths and orientations using PyTorch len_indices = argmax // result.shape[1] orient_indices = argmax % result.shape[1] length = len_range[len_indices].view(result.shape[2], result.shape[3]) orient = orient_range[orient_indices].view(result.shape[2], result.shape[3]) return length.cpu().numpy(), orient.cpu().numpy(), max_score.cpu().numpy()
@staticmethod
def get_sarcomere_vectors_wavelet(length: np.ndarray, orientation: np.ndarray, max_score: np.ndarray,
                                  len_range: np.ndarray, mbands: Union[np.ndarray, None] = None,
                                  score_threshold: float = 0.2, abs_threshold: bool = True) -> tuple:
    """
    Extract vector positions on sarcomere M-bands and sample sarcomere length, orientation and
    wavelet score at these positions.

    Steps:

    1. **Threshold:** Resolve the score cutoff from `score_threshold` / `abs_threshold`.
    2. **Binarization:** Threshold `max_score` to obtain an M-band mask (skipped if `mbands` is given).
    3. **Skeletonization:** Thin the M-band mask to one-pixel-wide lines.
    4. **Labeling:** Label connected skeleton components (8-connectivity).
    5. **Point extraction:** Collect the pixel coordinates of every labeled skeleton.
    6. **Sampling and filtering:** Read `length`, `orientation` and `max_score` at each point and
       keep only points whose length lies within ``[len_range[1], len_range[-2]]``.

    Parameters
    ----------
    length : np.ndarray
        Sarcomere length map obtained from wavelet analysis.
    orientation : np.ndarray
        Sarcomere orientation angle map obtained from wavelet analysis.
    max_score : np.ndarray
        Map of maximal wavelet scores.
    len_range : np.ndarray
        Lengths used in the wavelet bank. Only elements ``[1]`` and ``[-2]`` are read, so at least
        two entries are required.
    mbands : np.ndarray or None, optional
        Manually curated / corrected M-band mask. If given, it is used instead of thresholding
        `max_score`, and the score threshold has no effect on the result.
    score_threshold : float, optional
        Threshold on `max_score`. Used as an absolute value if `abs_threshold` is True, otherwise
        as a fraction of ``max_score.max()``. Default is 0.2.
    abs_threshold : bool, optional
        Selects how `score_threshold` is interpreted (see above). Default is True.

    Returns
    -------
    tuple
        * **pos_vectors_px** (np.ndarray): (row, col) pixel coordinates of each M-band point, shape (n, 2).
        * **mband_id_vectors** (np.ndarray): Label of the M-band each point belongs to (float dtype).
        * **mband_length_vectors** (np.ndarray): Max. Feret diameter of the M-band of each point, in pixels.
        * **sarcomere_length_vectors** (np.ndarray): Values of `length` at each point.
        * **sarcomere_orientation_vectors** (np.ndarray): Values of `orientation` at each point.
        * **max_score_vectors** (np.ndarray): Values of `max_score` at each point.
        * **mbands** (np.ndarray): The M-band mask that was skeletonized.
        * **mbands_labels** (np.ndarray): Label image of the skeletonized M-bands.
        * **score_threshold** (float): The `score_threshold` argument, returned unchanged
          (not the resolved absolute cutoff).

        If no M-band is found, the six per-point entries are empty lists instead of arrays.
    """
    # resolve the absolute score cutoff
    if abs_threshold:
        score_threshold_val = score_threshold
    else:
        score_threshold_val = max_score.max() * score_threshold

    # binarize (unless a curated mask was supplied) and skeletonize M-bands
    if mbands is None:
        mbands = max_score >= score_threshold_val
    mbands_skel = skeletonize(mbands, method='lee') > 0

    # label M-bands with 8-connectivity so diagonal skeleton pixels stay in one component
    mbands_labels, n_mbands = ndimage.label(mbands_skel, ndimage.generate_binary_structure(2, 2))

    # per-M-band label, pixel coordinates and length (approximated by max. Feret diameter)
    props = skimage.measure.regionprops_table(mbands_labels,
                                              properties=['label', 'coords', 'feret_diameter_max'])
    list_labels, coords_mbands, length_mbands = props['label'], props['coords'], props['feret_diameter_max']

    if n_mbands > 0:
        # broadcast each M-band's label and length to all of its skeleton pixels
        n_points = [coords_i.shape[0] for coords_i in coords_mbands]
        pos_vectors_px = np.concatenate(list(coords_mbands), axis=0)
        mband_id_vectors = np.repeat(np.asarray(list_labels, dtype=float), n_points)
        mband_length_vectors = np.repeat(np.asarray(length_mbands, dtype=float), n_points)

        # sample the wavelet maps at the vector positions
        rows, cols = pos_vectors_px[:, 0], pos_vectors_px[:, 1]
        sarcomere_length_vectors = length[rows, cols]
        sarcomere_orientation_vectors = orientation[rows, cols]
        max_score_vectors = max_score[rows, cols]

        # remove vectors outside range of sarcomere lengths in wavelet bank
        ids_in = (sarcomere_length_vectors >= len_range[1]) & (sarcomere_length_vectors <= len_range[-2])
        pos_vectors_px = pos_vectors_px[ids_in]
        mband_length_vectors = mband_length_vectors[ids_in]
        mband_id_vectors = mband_id_vectors[ids_in]
        sarcomere_length_vectors = sarcomere_length_vectors[ids_in]
        sarcomere_orientation_vectors = sarcomere_orientation_vectors[ids_in]
        max_score_vectors = max_score_vectors[ids_in]
    else:
        pos_vectors_px, mband_id_vectors, mband_length_vectors = [], [], []
        sarcomere_length_vectors, sarcomere_orientation_vectors, max_score_vectors = [], [], []

    return (pos_vectors_px, mband_id_vectors, mband_length_vectors, sarcomere_length_vectors,
            sarcomere_orientation_vectors, max_score_vectors, mbands, mbands_labels, score_threshold)
[docs] @staticmethod def interpolate_distance_map(image, N=50, method='linear'): """ Interpolates NaN regions in a 2D image, filling only those regions whose size is less than or equal to a specified threshold. Parameters ---------- image : numpy.ndarray A 2D array representing the input image. NaN values represent gaps to be filled. N : int The maximum size (in pixels) of connected NaN regions to interpolate. Regions larger than this threshold will remain unaltered. method : str, optional The interpolation method to use. Options are 'linear', 'nearest', and 'cubic'. Default is 'linear'. Returns ------- numpy.ndarray A 2D array with the same shape as the input `image`, where small NaN regions (size <= N) have been interpolated. Larger NaN regions are left unchanged. """ # Get indices and mask valid points x, y = np.indices(image.shape) valid_points = ~np.isnan(image) valid_coords = np.array((x[valid_points], y[valid_points])).T valid_values = image[valid_points] # Label connected NaN regions nan_mask = np.isnan(image) labeled_nan_regions, num_features = label(nan_mask) # Combine masks for all small regions combined_small_nan_mask = np.zeros_like(image, dtype=bool) for region_label in range(1, num_features + 1): region_mask = labeled_nan_regions == region_label region_size = np.sum(region_mask) if region_size <= N: combined_small_nan_mask |= region_mask # Interpolate all small NaN regions at once if np.any(combined_small_nan_mask): invalid_coords = np.array((x[combined_small_nan_mask], y[combined_small_nan_mask])).T interpolated_values = griddata(valid_coords, valid_values, invalid_coords, method=method) image[combined_small_nan_mask] = interpolated_values return image
@staticmethod
def get_pixel_size(file_path):
    """
    Retrieves pixel size (x, y) from the first page of a TIFF file.

    If the file carries ImageJ metadata, the resolution tags are scaled to micrometers using the
    ImageJ 'unit' entry (unknown units are scaled by 1). Otherwise the TIFF 'ResolutionUnit' tag
    is used (2 = inch, 3 = cm; any other value, or a missing tag, is scaled by 1).

    Parameters
    ----------
    file_path : str
        Path to the TIFF file.

    Returns
    -------
    tuple
        Pixel size in x and y direction.

    Raises
    ------
    ValueError
        If the 'XResolution' or 'YResolution' tag is missing, equals (1, 1), or has a zero
        numerator, i.e. the pixel size cannot be determined.
    """

    def _read_tag(page, name, default=None):
        # The tag lookup yields a tag object (or nothing), so the payload may only be
        # dereferenced via `.value` when the tag is actually present.
        tag = page.tags.get(name)
        return default if tag is None else tag.value

    def _pixel_size(x_res, y_res, scale, source):
        for res in (x_res, y_res):
            # (1, 1) is treated as "resolution not set"; a zero numerator cannot be inverted
            if res is None or res == (1, 1) or res[0] == 0:
                raise ValueError(f"Could not determine pixel size from {source}")
        # resolution is stored as (numerator, denominator) in pixels per unit -> invert it
        return (
            (x_res[1] / x_res[0]) * scale,
            (y_res[1] / y_res[0]) * scale,
        )

    with tifffile.TiffFile(file_path) as tif:
        page = tif.pages[0]
        x_res = _read_tag(page, 'XResolution')
        y_res = _read_tag(page, 'YResolution')

        # Handle ImageJ metadata case
        if ij_meta := tif.imagej_metadata:
            unit_conversion = {
                'm': 1e6,
                'cm': 1e4,
                'mm': 1e3,
                'um': 1,
                'µm': 1,
                'nm': 1e-3
            }.get(ij_meta.get('unit', 'um').lower(), 1)
            return _pixel_size(x_res, y_res, unit_conversion, "ImageJ metadata")

        # Handle standard TIFF case
        unit = _read_tag(page, 'ResolutionUnit', 1)  # 1=None, 2=inch, 3=cm
        if unit == 2:  # Convert inches to µm
            conversion = 25400
        elif unit == 3:  # Convert cm to µm
            conversion = 10000
        else:
            conversion = 1
        return _pixel_size(x_res, y_res, conversion, "TIFF tags")