contraction_net

Submodules

Classes

ContractionNet

ContractionNet model for detecting contraction intervals from time-series data of individual Z-band positions

Trainer

Class for training ContractionNet.

DataProcess

A Dataset class for creating training data objects for ContractionNet training.

Functions

predict_contractions(data, model[, network])

Predict contraction intervals from a time series with a neural network.

Package Contents

class contraction_net.ContractionNet(n_filter=64, in_channels=1, out_channels=2, dropout_rate=0.5)[source]

Bases: torch.nn.Module

ContractionNet model for detecting contraction intervals from time-series data of individual Z-band positions and sarcomere lengths of beating cardiomyocytes.

This neural network is designed to handle noisy data and to distinguish between contracting and non-contracting intervals. The network first extracts features from a single input time series using two convolutional layers with kernel size 5, followed by a dilated convolution in the third layer to capture broader temporal patterns. Each convolution is followed by instance normalization and ReLU activation. A self-attention layer enhances focus on salient features. The processed signal then passes through two further convolutions, and the result is output through a sigmoid activation function.
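
Because forward(x) returns tensors with the same sequence_length as its input, one common way to build such a stack is to pad each convolution so that it preserves length. The sketch below works through that arithmetic with the output-length formula given in the torch.nn.Conv1d documentation. The dilation value of 3 is an assumption for illustration only (the dilation rate is not stated here), and the helper names are hypothetical.

```python
def conv1d_out_len(length, kernel_size, dilation=1, padding=0, stride=1):
    """Output length of a 1D convolution (formula from the torch.nn.Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size, dilation=1):
    """Padding per side that preserves length for stride 1 and an odd kernel size."""
    return dilation * (kernel_size - 1) // 2


seq_len = 512  # default input_len of DataProcess

# Plain convolution with kernel size 5, as in the first two layers.
pad_plain = same_padding(5)
len_plain = conv1d_out_len(seq_len, 5, padding=pad_plain)

# Dilated convolution; dilation=3 is an assumed value for illustration only.
pad_dilated = same_padding(5, dilation=3)
len_dilated = conv1d_out_len(seq_len, 5, dilation=3, padding=pad_dilated)

# Samples spanned by one dilated kernel: dilation * (kernel_size - 1) + 1.
span_dilated = 3 * (5 - 1) + 1

print(pad_plain, len_plain, pad_dilated, len_dilated, span_dilated)
```

With the same kernel size, the dilated layer covers a wider window of the time series per output sample, which is what lets it capture broader temporal patterns without extra parameters.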

conv1
in1
conv2
bn2
conv3
bn3
attention
norm1
dropout_attention
conv4
bn4
dropout_pre_output
conv_out
forward(x)[source]

Forward pass through the network.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, sequence_length).

Returns:

  • torch.Tensor – Output tensor of shape (batch_size, out_channels, sequence_length) after sigmoid activation.

  • torch.Tensor – Raw output tensor of shape (batch_size, out_channels, sequence_length).
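
To illustrate how the two return values relate, the sketch below applies a sigmoid to raw values and then groups above-threshold samples into intervals. It is plain Python on a single channel with made-up numbers. The 0.5 threshold and the helper intervals_from_probs are assumptions for illustration, not part of the documented API.

```python
import math


def sigmoid(z):
    """Logistic function mapping a raw score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def intervals_from_probs(probs, threshold=0.5):
    """Return (start, end) index pairs, end exclusive, where probs > threshold.

    Hypothetical post-processing helper; the threshold is an assumed value.
    """
    intervals, start = [], None
    for i, p in enumerate(probs):
        if p > threshold and start is None:
            start = i                      # interval opens
        elif p <= threshold and start is not None:
            intervals.append((start, i))   # interval closes
            start = None
    if start is not None:                  # interval runs to the end
        intervals.append((start, len(probs)))
    return intervals


raw = [-3.0, -1.0, 2.0, 4.0, 1.5, -2.0, -4.0, 3.0]  # made-up raw outputs
probs = [sigmoid(z) for z in raw]
contractions = intervals_from_probs(probs)
print(contractions)
```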

class contraction_net.Trainer(dataset, num_epochs, network=ContractionNet, in_channels=1, out_channels=2, batch_size=16, lr=0.001, n_filter=64, val_split=0.2, save_dir='./', save_name='model.pt', save_iter=False, loss_function='BCEDice', loss_params=(1, 1))[source]

Class for training ContractionNet.

Parameters:
  • dataset – Training data, object of PyTorch Dataset class

  • num_epochs (int) – Number of training epochs

  • network – Network class (default is ContractionNet)

  • in_channels (int) – Number of input channels

  • out_channels (int) – Number of output channels

  • batch_size (int) – Batch size for training

  • lr (float) – Learning rate

  • n_filter (int) – Number of convolutional filters in first layer

  • val_split (float) – Validation split

  • save_dir (str) – Path of directory to save trained networks

  • save_name (str) – Base name for saving trained networks

  • save_iter (bool) – If True, the network state is saved after each epoch

  • load_weights (str, optional) – If not None, network state is loaded before training

  • loss_function (str) – Loss function (‘BCEDice’, ‘Tversky’ or ‘logcoshTversky’)

  • loss_params (Tuple[float, float]) – Parameters of the loss function; their meaning depends on the chosen loss function
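
As an illustration of how a combined loss such as 'BCEDice' could be weighted by a two-element loss_params, here is a minimal plain-Python sketch. It assumes that the two parameters are weights for the binary cross-entropy and Dice terms. That interpretation is not confirmed here, and bce_dice_loss is a hypothetical helper, not the package's implementation.

```python
import math


def bce_dice_loss(probs, targets, loss_params=(1, 1), eps=1e-7):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    Assumption: loss_params = (weight_bce, weight_dice).
    probs are predicted probabilities in [0, 1]; targets are 0 or 1.
    """
    w_bce, w_dice = loss_params

    # Binary cross-entropy, averaged over samples (probabilities clipped for log).
    bce = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        bce -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    bce /= len(probs)

    # Soft Dice loss: 1 - 2 * sum(p * t) / (sum(p) + sum(t)).
    intersection = sum(p * t for p, t in zip(probs, targets))
    dice = 1.0 - (2.0 * intersection + eps) / (sum(probs) + sum(targets) + eps)

    return w_bce * bce + w_dice * dice


targets = [0, 0, 1, 1, 1, 0]
good = bce_dice_loss([0.1, 0.1, 0.9, 0.9, 0.9, 0.1], targets)
bad = bce_dice_loss([0.9, 0.9, 0.1, 0.1, 0.1, 0.9], targets)
print(good < bad)
```

The cross-entropy term penalizes each sample independently, while the Dice term scores the overlap of the predicted and labeled intervals as a whole, which helps when contracting samples are a minority.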

network
model
data
in_channels = 1
out_channels = 2
num_epochs
batch_size = 16
lr = 0.001
best_loss
save_iter = False
loss_function = 'BCEDice'
loss_params = (1, 1)
n_filter = 64
dim
train_loader
val_loader
smooth_loss
optimizer
scheduler
save_dir = './'
save_name = 'model.pt'
__iterate(epoch, mode)
start()[source]

Start network training.

class contraction_net.DataProcess(source_dir, input_len=512, normalize=False, val_split=0.2, aug_factor=10, aug_p=0.5, noise_amp=0.2, random_offset=0.25, random_outlier=0.5, random_drift=(0.01, 0.2), random_swap=0.5, random_subsampling=None)[source]

Bases: torch.utils.data.Dataset

A Dataset class for creating training data objects for ContractionNet training.

Parameters:
  • source_dir (Tuple[str, str]) – Tuple containing paths to the directories of training data [images, labels]. Images should be in .tif format.

  • input_len (int, optional) – Length of the input sequences (default is 512).

  • normalize (bool, optional) – Whether to normalize each time-series (default is False).

  • aug_factor (int, optional) – Factor for image augmentation (default is 10).

  • val_split (float, optional) – Validation split for training (default is 0.2).

  • noise_amp (float, optional) – Amplitude of Gaussian noise for image augmentation (default is 0.2).

  • aug_p (float, optional) – Probability of applying augmentation (default is 0.5).

  • random_offset (float, optional) – Amplitude of random offset applied to the input sequences (default is 0.25).

  • random_outlier (float, optional) – Amplitude of random outliers added to the input sequences (default is 0.5).

  • random_drift (Tuple[float, float], optional) – Parameters for random drift: (frequency, amplitude) (default is (0.01, 0.2)).

  • random_swap (float, optional) – Probability of randomly swapping the sign of the input sequences (default is 0.5).

  • random_subsampling (Tuple[int, int], optional) – Range for random subsampling intervals (default is None).
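
The augmentation parameters above can be pictured with a small sketch that uses only the standard library. It assumes that noise, offset and drift are each applied independently with probability aug_p, that the drift is a sinusoid with the given (frequency, amplitude), and that the offset is drawn uniformly. How DataProcess actually combines these steps is not stated here, and augment_series is a hypothetical helper. Outliers and subsampling are left out for brevity.

```python
import math
import random


def augment_series(series, aug_p=0.5, noise_amp=0.2, random_offset=0.25,
                   random_drift=(0.01, 0.2), random_swap=0.5, rng=None):
    """Return an augmented copy of a 1D time series (list of floats).

    Assumptions: noise, offset and drift are each applied with probability
    aug_p; the sign swap is applied with probability random_swap.
    """
    if rng is None:
        rng = random.Random()
    out = list(series)

    if rng.random() < aug_p:  # additive Gaussian noise
        out = [x + rng.gauss(0.0, noise_amp) for x in out]

    if rng.random() < aug_p:  # constant offset in [-random_offset, random_offset]
        offset = rng.uniform(-random_offset, random_offset)
        out = [x + offset for x in out]

    if rng.random() < aug_p:  # slow sinusoidal drift with random phase
        freq, amp = random_drift
        phase = rng.uniform(0.0, 2.0 * math.pi)
        out = [x + amp * math.sin(2.0 * math.pi * freq * i + phase)
               for i, x in enumerate(out)]

    if rng.random() < random_swap:  # flip the sign of the whole series
        out = [-x for x in out]

    return out


signal = [math.sin(0.1 * i) for i in range(512)]
augmented = augment_series(signal, rng=random.Random(0))
print(len(augmented))
```

Passing a seeded random.Random instance makes a given augmentation reproducible, which is convenient when inspecting individual augmented samples.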

source_dir
data = []
is_real = []
input_len = 512
val_split = 0.2
normalize = False
aug_factor = 10
aug_p = 0.5
noise_amp = 0.2
random_offset = 0.25
random_drift = (0.01, 0.2)
random_outlier = 0.5
random_subsampling = None
random_swap = 0.5
mode = 'train'
__load_and_edit()

Loads and preprocesses the input data files from the source directory.

__augment()

Applies data augmentation techniques to the loaded data.

__len__()[source]

Returns the total number of samples.

__getitem__(idx)[source]

Generates one sample of data.

Parameters:

idx (int) – Index of the sample to retrieve.

Returns:

sample – Dictionary containing ‘input’, ‘target’, and ‘distance’ tensors.

Return type:

dict

contraction_net.predict_contractions(data, model, network=ContractionNet)[source]

Predict contraction intervals from a time series with a neural network.

Parameters:
  • data (ndarray) – 1D array containing a contraction time series

  • model (str) – Path to the trained model weights (.pt file)

  • network (nn.Module) – Network to predict contractions from time-series

  • standard_normalizer (bool, ndarray) – If False, each time series is normalized by its own mean and standard deviation. If True, the mean and standard deviation of the training data set are applied. If an ndarray, the data are normalized with the given values [[input_mean, input_std], [vel_mean, vel_std]]
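
The normalization options described for standard_normalizer can be sketched as follows. This is a plain-Python illustration of mean and standard-deviation normalization, not the package's code. The helper normalize_series and the example statistics are hypothetical, and using the population standard deviation is an assumed choice.

```python
import statistics


def normalize_series(series, stats=None):
    """Normalize a 1D series to zero mean and unit standard deviation.

    stats: optional (mean, std) pair, e.g. taken from a training set.
    If None, the series' own mean and std are used.
    """
    if stats is None:
        mean = statistics.fmean(series)
        std = statistics.pstdev(series)  # population std; an assumed choice
    else:
        mean, std = stats
    if std == 0:
        raise ValueError("standard deviation is zero; cannot normalize")
    return [(x - mean) / std for x in series]


data = [1.8, 1.9, 2.0, 1.7, 1.6, 1.9]  # made-up values for illustration

own = normalize_series(data)                      # per-series statistics
fixed = normalize_series(data, stats=(1.8, 0.2))  # hypothetical training-set statistics
print(round(statistics.fmean(own), 6), round(statistics.pstdev(own), 6))
```

Per-series statistics make each trace comparable regardless of its absolute scale, whereas fixed statistics keep amplitude differences between traces visible to the network.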