Training of Z-band 3D U-Net

This tutorial explains how to create training data and train a 3D U-Net neural network model for predicting sarcomere Z-bands from high-speed microscopy movies of contracting cardiomyocytes. SarcAsM uses our package bio-image-unet, see https://github.com/danihae/bio-image-unet. We strongly recommend using a GPU-equipped workstation or server for training and prediction. Make sure that the CUDA toolkit and a matching version of PyTorch are installed, and verify the installation with:

import torch
torch.cuda.is_available()
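For a slightly more informative check, the following sketch also reports the PyTorch version, the CUDA version PyTorch was built with, and the name of the first GPU. The helper name describe_torch_setup is our own choice for illustration and is not part of SarcAsM or bio-image-unet.

```python
import importlib.util


def describe_torch_setup():
    """Return a one-line summary of the local PyTorch / CUDA installation."""
    # Check for the package first so the function also works without PyTorch
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch

    if not torch.cuda.is_available():
        return f"PyTorch {torch.__version__}: no CUDA device available"
    return (f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}, "
            f"GPU: {torch.cuda.get_device_name(0)}")


print(describe_torch_setup())
```

If the summary reports that no CUDA device is available, training will still run on the CPU but will likely be much slower.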

Creation of training data set

We recommend a training data set of 20-100 movie segments of 100-200 frames each. Since manually tracing Z-bands in thousands of images is not feasible, labels are generated in a two-step procedure:

  1. Randomly select 20-50 single images from a set of movies, manually annotate them, and then train a 2D U-Net model. Alternatively, use our generalist model or another pretrained model.

  2. Predict the movies with the 2D U-Net and then process the resulting labels by removing flickering artifacts and small objects, as described under "Process preliminary labels".

[ ]:
import os
import glob

# create folders for training data
dir_training_data = '../../training_data/unet3d/'  # adjust path
dir_training_data_movies = dir_training_data + 'movies/'
dir_training_data_labels = dir_training_data + 'zbands/'
dir_training_data_prelim_labels = dir_training_data + 'prelim_zbands/'
os.makedirs(dir_training_data_movies, exist_ok=True)
os.makedirs(dir_training_data_labels, exist_ok=True)
os.makedirs(dir_training_data_prelim_labels, exist_ok=True)
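
Step 1 of the procedure above needs a random set of single frames for manual annotation. A minimal sketch of such a selection, assuming only the number of frames per movie is known. The function sample_frames and the example movie lengths are hypothetical and not part of SarcAsM.

```python
import random


def sample_frames(movie_lengths, n_frames, seed=None):
    """Draw distinct (movie_index, frame_index) pairs uniformly over all frames."""
    rng = random.Random(seed)
    # Enumerate every frame of every movie once, then sample without replacement
    all_frames = [(m, f) for m, n in enumerate(movie_lengths) for f in range(n)]
    return sorted(rng.sample(all_frames, min(n_frames, len(all_frames))))


# Example: three movies with 300, 500 and 250 frames, 30 frames to annotate
picks = sample_frames([300, 500, 250], n_frames=30, seed=0)
print(picks[:5])
```

Sampling uniformly over all frames favors longer movies. To weight every movie equally, one could instead draw a movie first and then a frame within it.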

Create set of movie sequences and predict with 2D U-Net

Here we assume that a 2D U-Net model already exists. To train a 2D U-Net, follow the instructions here.

[ ]:
import random
import tifffile

# randomly select 50 movie segments from larger data set (alternatively manually select movies)
n_movies = 50
len_sequence = 128
dir_movies = 'path/all_movies/'  # adjust
movies = glob.glob(dir_movies + '*/*.tif')  # adjust when necessary
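movies_sample = random.sample(movies, min(n_movies, len(movies)))  # never request more movies than exist

for movie in movies_sample:
    name = os.path.basename(os.path.dirname(movie)) + '_' + os.path.splitext(os.path.basename(movie))[0]
    imgs = tifffile.imread(movie)
    if imgs.shape[0] < len_sequence:
        continue  # skip movies that are shorter than one sequence
    # random.randint includes both end points, so the last valid start frame is n_frames - len_sequence
    start_frame_random = random.randint(0, imgs.shape[0] - len_sequence)
    random_frames = imgs[start_frame_random: start_frame_random + len_sequence]
    tifffile.imwrite(dir_training_data_movies + name + f'_{start_frame_random}-{start_frame_random+len_sequence}.tif', random_frames)

[ ]: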
import bio_image_unet.unet as unet

# predict sequences with 2D U-Net
training_sequences = glob.glob(dir_training_data_movies + '*.tif')
model_params = 'path/to/2d_unet/model.pt'  # adjust path of model

for sequence in training_sequences:
    name = os.path.basename(sequence)
    unet.Predict(sequence, dir_training_data_prelim_labels + name, model_params, resize_dim=(256, 1024), show_progress=False)  # change parameters when needed
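
Since movies and labels are matched by file name, it can help to confirm that every movie sequence has a preliminary label file before continuing. A minimal sketch using only the standard library. The helper find_unpaired is hypothetical and not part of bio-image-unet.

```python
import glob
import os


def find_unpaired(dir_a, dir_b, pattern='*.tif'):
    """Return the file names that exist in only one of the two folders."""
    names_a = {os.path.basename(p) for p in glob.glob(os.path.join(dir_a, pattern))}
    names_b = {os.path.basename(p) for p in glob.glob(os.path.join(dir_b, pattern))}
    return sorted(names_a - names_b), sorted(names_b - names_a)


# Usage with the folders defined earlier in this tutorial:
# only_movies, only_labels = find_unpaired(dir_training_data_movies,
#                                          dir_training_data_prelim_labels)
# print('movies without labels:', only_movies)
```

Both returned lists should be empty. The same check can be repeated for the final labels before the training data set is generated.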

Process preliminary labels

The preliminary training data labels are processed by analyzing the sarcomere vectors for each frame, see details here, and creating binary masks of sarcomeres, i.e., regions of sarcomeres in each frame. The mean mask of these sarcomere masks is then thresholded and dilated to form a refined mask. This dilated mask is applied to the label images to exclude unwanted regions. The labels are then thresholded and connected components are labeled. The labeled objects are filtered by volume to remove small objects. Finally, the processed masks are saved as TIFF files for use in training the 3D U-Net model.

[ ]:
import os
import glob
import numpy as np
import tifffile
from scipy.ndimage import label, binary_dilation
from sarcasm import SarcAsM
import matplotlib.pyplot as plt

def filter_by_volume(labels, min_volume=None):
    """
    Filter labeled objects by their volume.

    Parameters:
    - labels: 3D numpy array of labeled objects.
    - min_volume: Minimum volume threshold for filtering.

    Returns:
    - filtered_labels: 3D numpy array of filtered labeled objects.
    """
    # Nothing to filter if no threshold is given
    if min_volume is None:
        return labels.copy()

    # Voxel count per label; index 0 is the background
    volumes = np.bincount(labels.ravel())

    # Keep objects with at least min_volume voxels, set all others to background (0)
    keep = volumes >= min_volume
    return np.where(keep[labels], labels, 0)

# Path to preliminary training data labels (same folder as created above, adjust path)
dir_training_data_prelim_labels = '../../training_data/unet3d/prelim_zbands/'
training_sequences_labels_prelim = glob.glob(dir_training_data_prelim_labels + '*.tif')

# Ensure the directory for final training data labels exists
dir_training_data_labels = '../../training_data/unet3d/zbands/'
os.makedirs(dir_training_data_labels, exist_ok=True)

# Process each sequence of preliminary labels
for sequence_labels in training_sequences_labels_prelim:
    print(f'Processing {sequence_labels}')
    name = os.path.basename(sequence_labels)
    imgs_labels = tifffile.imread(sequence_labels)

    pixelsize = 0.065  # pixel size in µm (adjust to your data)
    sarc_obj = SarcAsM(sequence_labels, pixelsize=pixelsize, restart=False)

    sarc_obj.analyze_sarcomere_length_orient(save_all=False, score_threshold=0.2)

    # Read the sarcomere masks
    masks = tifffile.imread(sarc_obj.file_sarcomere_mask)
    masks_mean = masks.mean(axis=0)
    mask_thres = masks_mean > 0.5

    # Dilate the thresholded masks
    masks_thres_dilated = binary_dilation(mask_thres, structure=np.ones((11, 11)))

    # Plot the mean masks and overlaid dilated masks
    plt.figure()
    plt.imshow(masks_mean)
    plt.title("Mean Masks")
    plt.show()

    plt.figure()
    plt.imshow(tifffile.imread(sarc_obj.file_sarcomeres)[0])
    plt.imshow(masks_thres_dilated, alpha=0.5)
    plt.title("Z-bands with Overlaid Masks")
    plt.show()

    # Threshold the labels
    imgs_labels_out = imgs_labels > 20

    # Apply the dilated mask to the labels
    imgs_labels_out[:, ~masks_thres_dilated] = 0

    # Label the connected components
    labels = label(imgs_labels_out)[0]
    filtered_labels = filter_by_volume(labels, min_volume=200)

    # Generate the final mask
    masks_out = filtered_labels > 0

    # Save the final mask as a TIFF file
    tifffile.imwrite(os.path.join(dir_training_data_labels, name), masks_out.astype('uint8') * 255)
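
The volume filter removes flickering artifacts because the connected components are labeled in 3D, that is, across frames: an object that appears in only one or two frames has a small volume, whereas a Z-band that persists through the movie has a large one. The toy example below illustrates this on a synthetic stack. All sizes are made up for illustration.

```python
import numpy as np
from scipy.ndimage import label

# Synthetic binary label stack with shape (frames, height, width)
stack = np.zeros((20, 32, 32), dtype=bool)
stack[:, 10:20, 15:17] = True  # persistent Z-band-like object: 20 * 10 * 2 = 400 voxels
stack[7, 2:5, 2:5] = True      # flickering artifact in a single frame: 9 voxels

# 3D connected components (default structure: face connectivity)
labels, n_objects = label(stack)

# Voxel count per label; index 0 is the background
volumes = np.bincount(labels.ravel())

# Keep objects with at least 200 voxels and never keep the background
keep = volumes >= 200
keep[0] = False
cleaned = keep[labels]

print(n_objects, volumes[1:].tolist(), int(cleaned.sum()))  # → 2 [400, 9] 400
```

The single-frame artifact is removed, while the persistent object is kept unchanged.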

Augmenting movies to include rapid high-frequency motion

We augment the movie sequences by simulating rapidly moving Z-bands, which are under-represented in the data set. To this end, random sinusoidal shifts are applied to both the image sequences and the corresponding label sequences.

[ ]:
import os
import glob
import random
import numpy as np
import tifffile
from scipy import ndimage

# Directories for images and masks (adjust!)
dir_images = 'D:/git/SarcAsM/training_data/unet3d/movies/'
dir_masks = 'D:/git/SarcAsM/training_data/unet3d/zbands/'

# Directories for shifted (augmented) images and masks
dir_images_shifted = os.path.join(dir_images, 'shifted/')
dir_masks_shifted = os.path.join(dir_masks, 'shifted/')

os.makedirs(dir_images_shifted, exist_ok=True)
os.makedirs(dir_masks_shifted, exist_ok=True)

# Get the list of training sequences
training_sequences = glob.glob(dir_images + '*.tif')

# Process each image sequence
for img_seq in training_sequences:
    name = os.path.basename(img_seq)
    label_seq = os.path.join(dir_masks, name)

    # Load images and labels
    imgs = tifffile.imread(img_seq)
    labels = tifffile.imread(label_seq)

    # Random frequencies (radians per frame) and amplitudes (pixels) for sinusoidal shifts (adjust when necessary)
    freq_x, amp_x = random.uniform(0, 0.5), random.uniform(5, 25)
    freq_y, amp_y = random.uniform(0, 0.5), random.uniform(0, 5)

    z_range = np.arange(imgs.shape[0])
    x_shift = amp_x * np.sin(freq_x * z_range)
    y_shift = amp_y * np.sin(freq_y * z_range)

    shifted_imgs = np.zeros_like(imgs)
    shifted_labels = np.zeros_like(labels)

    # Apply the shifts to each frame in the sequence
    for t in range(imgs.shape[0]):
        shifted_imgs[t] = ndimage.shift(imgs[t], (y_shift[t], x_shift[t]), mode='constant', cval=0.0)
        shifted_labels[t] = ndimage.shift(labels[t], (y_shift[t], x_shift[t]), mode='constant', cval=0.0)

    name_shifted = name.replace('.tif', '_random_shift.tif')

    # Save the shifted images and labels
    tifffile.imwrite(os.path.join(dir_images_shifted, name_shifted), shifted_imgs)
    tifffile.imwrite(os.path.join(dir_masks_shifted, name_shifted), shifted_labels)
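
To check that images and labels stay aligned after such a shift, the same operation can be tried on a synthetic sequence. In this sketch, the parameter values are hypothetical, and order=1 (linear interpolation) is chosen only to keep the toy values easy to reason about. The cell above uses the SciPy default.

```python
import numpy as np
from scipy import ndimage

n_frames = 16
frames = np.arange(n_frames)

# Hypothetical parameters: amplitude in pixels, frequency in radians per frame
amp_x, freq_x = 10.0, 0.4
x_shift = amp_x * np.sin(freq_x * frames)

# Synthetic sequence: one bright vertical line (a "Z-band") at column 32
imgs = np.zeros((n_frames, 16, 64), dtype=np.float32)
imgs[:, :, 32] = 1.0
labels = (imgs > 0.5).astype(np.uint8) * 255

shifted_imgs = np.zeros_like(imgs)
shifted_labels = np.zeros_like(labels)
for t in range(n_frames):
    shift = (0.0, x_shift[t])  # (y, x) shift in pixels for frame t
    shifted_imgs[t] = ndimage.shift(imgs[t], shift, order=1, mode='constant', cval=0.0)
    shifted_labels[t] = ndimage.shift(labels[t], shift, order=1, mode='constant', cval=0.0)

# Column of maximum intensity in each frame, for image and label
col_img = shifted_imgs[:, 0, :].argmax(axis=1)
col_lab = shifted_labels[:, 0, :].argmax(axis=1)
print(col_img - 32)  # displacement of the line in pixels, frame by frame
```

The printed displacements follow the rounded sinusoid, and the line sits in the same column in the image and in the label for every frame.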

Training

Prepare and process training data

Prior to training, the training images and labels are processed and augmented. For the available processing and augmentation options (add noise, blur, adjust contrast, …), see the docstring or the API reference.

[ ]:
import bio_image_unet.multi_output_unet3d as unet3d

# path to training data (images and labels with identical names in separate folders, adjust!)
dir_images = 'D:/git/SarcAsM/training_data/unet3d/movies/'
dir_masks = 'D:/git/SarcAsM/training_data/unet3d/zbands/'

# path to directory for training data generation (is created automatically, drive should have enough storage)
data_dir = 'D:/git/SarcAsM/training_temp/20240430_unet3d_data/'

# generation of training data set and augmentation
dataset = unet3d.DataProcess(volume_dir=dir_images, target_dirs=[dir_masks], data_dir=data_dir, create=True,
                             brightness_contrast=(0.15, 0.15), aug_factor=2, clip_threshold=(0., 99.98),
                             dim_out=(64, 128, 128))
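
The sketch below illustrates one plausible reading of two of these parameters. It assumes that clip_threshold gives lower and upper intensity percentiles and that dim_out is the patch size in the order (frames, height, width). Both are assumptions, the functions are hypothetical, and the actual processing in bio-image-unet may differ.

```python
import numpy as np


def clip_and_normalize(volume, clip_threshold=(0.0, 99.98)):
    """Clip intensities to the given percentiles and rescale them to [0, 1]."""
    lower, upper = np.percentile(volume, clip_threshold)
    clipped = np.clip(volume.astype(np.float64), lower, upper)
    return (clipped - lower) / max(float(upper - lower), 1e-12)


def random_patch(volume, dim_out=(64, 128, 128), rng=None):
    """Cut a random patch of size dim_out from a larger 3D volume."""
    rng = np.random.default_rng() if rng is None else rng
    # The upper bound of rng.integers is exclusive, so the patch always fits
    starts = [rng.integers(0, s - d + 1) for s, d in zip(volume.shape, dim_out)]
    return volume[tuple(slice(st, st + d) for st, d in zip(starts, dim_out))]


rng = np.random.default_rng(0)
movie = rng.poisson(100, size=(80, 160, 160)).astype(np.uint16)
movie[0, 0, 0] = 60000  # a single hot pixel that would otherwise dominate the scaling

normalized = clip_and_normalize(movie)
patch = random_patch(normalized, rng=rng)
print(patch.shape, float(patch.min()), float(patch.max()))
```

Clipping at a high percentile instead of the maximum keeps isolated hot pixels from compressing the remaining intensity range.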

Set training parameters and train

For the available training parameters, see the docstring, e.g. via print(unet3d.Trainer.__doc__), or the API reference.

[ ]:
import os

# temp folder
save_dir = 'path/training_temp/training_unet3d/'

# define output head
output_heads = {'zbands': {'channels': 1, 'activation': 'sigmoid', 'loss': 'BCEDiceTemporalLoss', 'weight': 1}}

# initialize Trainer object
training = unet3d.Trainer(dataset, output_heads=output_heads, save_dir=save_dir, num_epochs=100, batch_size=8,
                          n_filter=16, load_weights=None, lr=0.0005, save_iter=True, use_interpolation=True)

# start training
training.start()
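
The loss name 'BCEDiceTemporalLoss' in the output head suggests a combination of binary cross-entropy (BCE) and Dice loss with an additional temporal term. The NumPy sketch below only illustrates the first two components. It is not the bio-image-unet implementation, and the temporal term is left out because its definition is not given here.

```python
import numpy as np


def bce_dice_loss(pred, target, eps=1e-7):
    """Sum of binary cross-entropy and Dice loss for predictions in (0, 1)."""
    pred = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    bce = -np.mean(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))
    intersection = np.sum(pred * target)
    dice = (2.0 * intersection + eps) / (np.sum(pred) + np.sum(target) + eps)
    return bce + (1.0 - dice)


# Toy target: a thin line that is present in all 4 frames
target = np.zeros((4, 8, 8))
target[:, 2:6, 3] = 1.0

good = np.where(target == 1, 0.95, 0.05)  # confident and correct prediction
bad = np.full_like(target, 0.5)           # uninformative prediction
print(bce_dice_loss(good, target), bce_dice_loss(bad, target))
```

The Dice term measures overlap and is therefore less sensitive to the strong class imbalance between thin Z-bands and background than BCE alone.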

After training has completed, the model parameters (model.pt) are stored in save_dir.