Training of Z-band 3D U-Net
This tutorial explains the creation of training data and the training of a 3D U-Net neural network model for the prediction of sarcomere Z-bands in high-speed microscopy movies of contracting cardiomyocytes. SarcAsM uses our package bio-image-unet, see https://github.com/danihae/bio-image-unet. We strongly recommend using a GPU-equipped workstation or server for training and prediction. Make sure that the CUDA toolkit and the matching version of PyTorch are installed, and verify the installation with:
import torch
torch.cuda.is_available()
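This should return True. For a more detailed check, you can also print the PyTorch and CUDA versions and the name of the detected GPU (a minimal sketch using standard PyTorch calls):
[ ]:
import torch

print(torch.__version__)                   # PyTorch version
print(torch.version.cuda)                  # CUDA version PyTorch was built against
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # name of the first detected GPU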
Creation of the training data set
We recommend a training data set of 20-100 movie segments of 100-200 frames each. Since manually tracing Z-bands in thousands of images is not feasible, labels are generated in a two-step procedure:
1. Randomly select 20-50 single images from the set of movies, manually annotate them, and then train a 2D U-Net model. Alternatively, use our generalist model or another pretrained model.
2. Predict the movies with the 2D U-Net and then process the resulting labels by removing flickering artifacts and filtering out small spurious objects (see the processing steps below).
[ ]:
import os
import glob
# create folders for training data
dir_training_data = '../../training_data/unet3d/' # adjust path
dir_training_data_movies = dir_training_data + 'movies/'
dir_training_data_labels = dir_training_data + 'zbands/'
dir_training_data_prelim_labels = dir_training_data + 'prelim_zbands/'
os.makedirs(dir_training_data_movies, exist_ok=True)
os.makedirs(dir_training_data_labels, exist_ok=True)
os.makedirs(dir_training_data_prelim_labels, exist_ok=True)
Create a set of movie sequences and predict with the 2D U-Net
Here we assume that a 2D U-Net model already exists. For training a 2D U-Net, follow the instructions here.
[ ]:
import random
import tifffile

# randomly select 50 movie segments from the larger data set (alternatively, select movies manually)
n_movies = 50
len_sequence = 128
dir_movies = 'path/all_movies/'  # adjust
movies = glob.glob(dir_movies + '*/*.tif')  # adjust when necessary
movies_sample = random.sample(movies, n_movies)

for movie in movies_sample:
    name = os.path.basename(os.path.dirname(movie)) + '_' + os.path.basename(movie)[:-4]
    imgs = tifffile.imread(movie)
    start_frame_random = random.randint(0, imgs.shape[0] - len_sequence - 1)
    random_frames = imgs[start_frame_random: start_frame_random + len_sequence]
    tifffile.imwrite(dir_training_data_movies + name + f'_{start_frame_random}-{start_frame_random + len_sequence}.tif', random_frames)
[ ]:
import bio_image_unet.unet as unet

# predict sequences with the 2D U-Net
training_sequences = glob.glob(dir_training_data_movies + '*.tif')
model_params = 'path/to/2d_unet/model.pt'  # adjust path of model

for sequence in training_sequences:
    name = os.path.basename(sequence)
    unet.Predict(sequence, dir_training_data_prelim_labels + name, model_params, resize_dim=(256, 1024), show_progress=False)  # change parameters when needed
Process preliminary labels
The preliminary training data labels are processed by analyzing the sarcomere vectors for each frame (see details here) and creating binary masks of sarcomere regions in each frame. The mean of these sarcomere masks over time is then thresholded and dilated to form a refined mask, which is applied to the label images to exclude unwanted regions. The labels are then thresholded and connected components are labeled; the labeled objects are filtered by volume to remove small objects. Finally, the processed masks are saved as TIFF files for use in training the 3D U-Net model.
[ ]:
import os
import glob
import numpy as np
import tifffile
from scipy.ndimage import label, binary_dilation
from sarcasm import SarcAsM
import matplotlib.pyplot as plt


def filter_by_volume(labels, min_volume=None):
    """
    Filter labeled objects by their volume.

    Parameters:
    - labels: 3D numpy array of labeled objects.
    - min_volume: Minimum volume threshold for filtering.

    Returns:
    - filtered_labels: 3D numpy array of filtered labeled objects.
    """
    # Get unique objects and their volumes
    unique, counts = np.unique(labels, return_counts=True)
    volumes = dict(zip(unique[1:], counts[1:]))  # Exclude background (label 0)

    # Filter by volume if min_volume is specified
    filtered_labels = labels.copy()
    for obj_label, volume in volumes.items():
        if min_volume is not None and volume < min_volume:
            filtered_labels[filtered_labels == obj_label] = 0
    return filtered_labels


# Path to preliminary training data labels
dir_training_data_prelim_labels = '../training_data/unet3d/prelim_zbands/'
training_sequences_labels_prelim = glob.glob(dir_training_data_prelim_labels + '*.tif')

# Ensure the directory for final training data labels exists
dir_training_data_labels = '../training_data/unet3d/zbands/'
os.makedirs(dir_training_data_labels, exist_ok=True)

# Process each sequence of preliminary labels
for sequence_labels in training_sequences_labels_prelim:
    print(f'Processing {sequence_labels}')
    name = os.path.basename(sequence_labels)
    imgs_labels = tifffile.imread(sequence_labels)

    pixelsize = 0.065
    sarc_obj = SarcAsM(sequence_labels, pixelsize=pixelsize, restart=False)
    sarc_obj.analyze_sarcomere_length_orient(save_all=False, score_threshold=0.2)

    # Read the sarcomere masks
    masks = tifffile.imread(sarc_obj.file_sarcomere_mask)
    masks_mean = masks.mean(axis=0)
    mask_thres = masks_mean > 0.5

    # Dilate the thresholded mask
    masks_thres_dilated = binary_dilation(mask_thres, structure=np.ones((11, 11)))

    # Plot the mean masks and overlaid dilated masks
    plt.figure()
    plt.imshow(masks_mean)
    plt.title("Mean Masks")
    plt.show()

    plt.figure()
    plt.imshow(tifffile.imread(sarc_obj.file_sarcomeres)[0])
    plt.imshow(masks_thres_dilated, alpha=0.5)
    plt.title("Z-bands with Overlaid Masks")
    plt.show()

    # Threshold the labels
    imgs_labels_out = imgs_labels > 20

    # Apply the dilated mask to the labels
    imgs_labels_out[:, ~masks_thres_dilated] = 0

    # Label the connected components
    labels = label(imgs_labels_out)[0]
    filtered_labels = filter_by_volume(labels, min_volume=200)

    # Generate the final mask
    masks_out = filtered_labels > 0

    # Save the final mask as a TIFF file
    tifffile.imwrite(os.path.join(dir_training_data_labels, name), masks_out.astype('uint8') * 255)
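As an optional sanity check, the final labels can be overlaid on the corresponding raw movie frames. Below is a minimal sketch; it assumes the movie folder created in the first cell and that movie and label file names match (as they do in this pipeline), so adjust the paths when necessary:
[ ]:
# overlay a final label frame on the corresponding raw movie frame (adjust paths)
name_check = os.path.basename(training_sequences_labels_prelim[0])
movie_check = tifffile.imread(os.path.join('../training_data/unet3d/movies/', name_check))
labels_check = tifffile.imread(os.path.join(dir_training_data_labels, name_check))

plt.figure()
plt.imshow(movie_check[0], cmap='gray')
plt.imshow(labels_check[0], alpha=0.4)
plt.title('Frame 0: movie with final Z-band labels')
plt.show()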
Augmenting movies to include rapid high-frequency motion
We augment the movie sequences by simulating rapidly moving Z-bands, which are under-represented in the data set. The augmentation applies random sinusoidal shifts to both the image and the corresponding label sequences.
[ ]:
import os
import glob
import random
import numpy as np
import tifffile
from scipy import ndimage

# Directories for images and masks (adjust!)
dir_images = 'D:/git/SarcAsM/training_data/unet3d/movies/'
dir_masks = 'D:/git/SarcAsM/training_data/unet3d/zbands/'

# Directories for shifted (augmented) images and masks
dir_images_shifted = os.path.join(dir_images, 'shifted/')
dir_masks_shifted = os.path.join(dir_masks, 'shifted/')
os.makedirs(dir_images_shifted, exist_ok=True)
os.makedirs(dir_masks_shifted, exist_ok=True)

# Get the list of training sequences
training_sequences = glob.glob(dir_images + '*.tif')

# Process each image sequence
for img_seq in training_sequences:
    name = os.path.basename(img_seq)
    label_seq = os.path.join(dir_masks, name)

    # Load images and labels
    imgs = tifffile.imread(img_seq)
    labels = tifffile.imread(label_seq)

    # Random frequencies and amplitudes for sinusoidal shifts (adjust when necessary)
    freq_x, amp_x = random.uniform(0, 0.5), random.uniform(5, 25)
    freq_y, amp_y = random.uniform(0, 0.5), random.uniform(0, 5)
    z_range = np.arange(imgs.shape[0])
    x_shift = amp_x * np.sin(freq_x * z_range)
    y_shift = amp_y * np.sin(freq_y * z_range)

    shifted_imgs = np.zeros_like(imgs)
    shifted_labels = np.zeros_like(labels)

    # Apply the shifts to each frame in the sequence
    for t in range(imgs.shape[0]):
        shifted_imgs[t] = ndimage.shift(imgs[t], (y_shift[t], x_shift[t]), mode='constant', cval=0.0)
        shifted_labels[t] = ndimage.shift(labels[t], (y_shift[t], x_shift[t]), mode='constant', cval=0.0)

    name_shifted = name.replace('.tif', '_random_shift.tif')

    # Save the shifted images and labels
    tifffile.imwrite(os.path.join(dir_images_shifted, name_shifted), shifted_imgs)
    tifffile.imwrite(os.path.join(dir_masks_shifted, name_shifted), shifted_labels)
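To get an intuition for the simulated motion, you can plot the shift trajectories of the last processed sequence (a minimal sketch reusing the z_range, x_shift, and y_shift variables from the loop above):
[ ]:
import matplotlib.pyplot as plt

# visualize the sinusoidal x/y shift trajectories of the last sequence
plt.plot(z_range, x_shift, label='x shift (px)')
plt.plot(z_range, y_shift, label='y shift (px)')
plt.xlabel('frame')
plt.ylabel('shift (pixels)')
plt.legend()
plt.title('Simulated high-frequency motion')
plt.show()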
Training
Prepare and process training data
Prior to training, the training images and labels are processed and augmented. For the different processing and augmentation options (adding noise, blurring, adjusting contrast, …), see the docstring or the API reference.
[ ]:
import bio_image_unet.multi_output_unet3d as unet3d

# path to training data (images and labels with identical names in separate folders)
dir_images = 'D:/git/SarcAsM/training_data/unet3d/movies/'
dir_masks = 'D:/git/SarcAsM/training_data/unet3d/zbands/'

# path to directory for training data generation (created automatically; the drive should have enough storage)
data_dir = 'D:/git/SarcAsM/training_temp/20240430_unet3d_data/'

# generation and augmentation of the training data set
dataset = unet3d.DataProcess(volume_dir=dir_images, target_dirs=[dir_masks], data_dir=data_dir, create=True,
                             brightness_contrast=(0.15, 0.15), aug_factor=2, clip_threshold=(0., 99.98),
                             dim_out=(64, 128, 128))
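If DataProcess implements the PyTorch Dataset protocol (an assumption here; check the API reference), you can quickly verify how many training samples were generated after augmentation:
[ ]:
# number of samples after augmentation (assumes DataProcess provides __len__)
print(len(dataset))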
Set training parameters and train
For the different training parameters, check the docstring with print(unet3d.Trainer.__doc__) or the API reference.
[ ]:
import os

# temp folder
save_dir = 'path/training_temp/training_unet3d/'

# define output head
output_heads = {'zbands': {'channels': 1, 'activation': 'sigmoid', 'loss': 'BCEDiceTemporalLoss', 'weight': 1}}

# initialize Trainer object
training = unet3d.Trainer(dataset, output_heads=output_heads, save_dir=save_dir, num_epochs=100, batch_size=8,
                          n_filter=16, load_weights=None, lr=0.0005, save_iter=True, use_interpolation=True)

# start training
training.start()
After training is completed, the model parameters model.pt are stored in the save_dir.
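To verify the result, you can load the saved checkpoint with PyTorch (a minimal sketch; the exact contents of model.pt depend on how Trainer serializes the model, so inspect the keys):
[ ]:
import os
import torch

# load the trained checkpoint on CPU and inspect what it contains
checkpoint = torch.load(os.path.join(save_dir, 'model.pt'), map_location='cpu')
print(checkpoint.keys() if isinstance(checkpoint, dict) else type(checkpoint))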