contraction_net.data

Classes

DataProcess

A Dataset class for creating training data objects for ContractionNet training.

Module Contents

class contraction_net.data.DataProcess(source_dir, input_len=512, normalize=False, val_split=0.2, aug_factor=10, aug_p=0.5, noise_amp=0.2, random_offset=0.25, random_outlier=0.5, random_drift=(0.01, 0.2), random_swap=0.5, random_subsampling=None)

Bases: torch.utils.data.Dataset

A Dataset class for creating training data objects for ContractionNet training.

Parameters:
  • source_dir (Tuple[str, str]) – Tuple containing paths to the directories of training data [images, labels]. Images should be in .tif format.

  • input_len (int, optional) – Length of the input sequences (default is 512).

  • normalize (bool, optional) – Whether to normalize each time-series (default is False).

  • aug_factor (int, optional) – Factor for image augmentation (default is 10).

  • val_split (float, optional) – Validation split for training (default is 0.2).

  • noise_amp (float, optional) – Amplitude of Gaussian noise for image augmentation (default is 0.2).

  • aug_p (float, optional) – Probability of applying augmentation (default is 0.5).

  • random_offset (float, optional) – Amplitude of random offset applied to the input sequences (default is 0.25).

  • random_outlier (float, optional) – Amplitude of random outliers added to the input sequences (default is 0.5).

  • random_drift (Tuple[float, float], optional) – Parameters for random drift: (frequency, amplitude) (default is (0.01, 0.2)).

  • random_swap (float, optional) – Probability of randomly swapping the sign of the input sequences (default is 0.5).

  • random_subsampling (Tuple[int, int], optional) – Range for random subsampling intervals (default is None).
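A minimal construction sketch based on the signature above. The keyword names and default values are copied from that signature; the directory paths, the helper name build_dataset and the DATASET_KWARGS dictionary are illustrative assumptions. The import is deferred into the function so the snippet can be loaded without the package installed.

```python
# Keyword arguments and defaults as listed in the documented signature.
DATASET_KWARGS = dict(
    input_len=512,
    normalize=False,
    val_split=0.2,
    aug_factor=10,
    aug_p=0.5,
    noise_amp=0.2,
    random_offset=0.25,
    random_outlier=0.5,
    random_drift=(0.01, 0.2),  # documented as (frequency, amplitude)
    random_swap=0.5,
    random_subsampling=None,
)


def build_dataset(images_dir, labels_dir, **overrides):
    """Create a DataProcess instance from two directories (hypothetical helper)."""
    # Deferred import: requires contraction_net (and torch) to be installed.
    from contraction_net.data import DataProcess

    kwargs = {**DATASET_KWARGS, **overrides}
    # source_dir is documented as a tuple of paths: (images, labels).
    return DataProcess((images_dir, labels_dir), **kwargs)


# Example call; the paths are placeholders, not part of the library:
# dataset = build_dataset("data/train/images", "data/train/labels", normalize=True)
```

Passing overrides as keyword arguments keeps the documented defaults in one place and makes any deviation from them explicit at the call site.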

source_dir
data = []
is_real = []
input_len = 512
val_split = 0.2
normalize = False
aug_factor = 10
aug_p = 0.5
noise_amp = 0.2
random_offset = 0.25
random_drift = (0.01, 0.2)
random_outlier = 0.5
random_subsampling = None
random_swap = 0.5
mode = 'train'

__load_and_edit()

Loads and preprocesses the input data files from the source directory.

__augment()

Applies data augmentation techniques to the loaded data.
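The augmentation code itself is not shown in this reference. As an illustration only, the sketch below applies three of the documented perturbations (Gaussian noise, a constant offset, and a sign swap) to a plain list of floats. How aug_p gates each step, how the amplitudes are sampled, and the function name augment_sequence are all assumptions, not the library's implementation.

```python
import random


def augment_sequence(seq, noise_amp=0.2, random_offset=0.25,
                     random_swap=0.5, aug_p=0.5, rng=None):
    """Return a perturbed copy of seq (illustrative, not DataProcess.__augment)."""
    rng = rng or random.Random()
    out = list(seq)

    # Assumption: each perturbation is applied independently with probability aug_p.
    if rng.random() < aug_p:
        # Additive Gaussian noise with standard deviation noise_amp.
        out = [x + rng.gauss(0.0, noise_amp) for x in out]

    if rng.random() < aug_p:
        # One constant offset drawn uniformly from [-random_offset, random_offset].
        shift = rng.uniform(-random_offset, random_offset)
        out = [x + shift for x in out]

    # Sign swap with probability random_swap, following the parameter description.
    if rng.random() < random_swap:
        out = [-x for x in out]

    return out


clean = [0.0, 0.5, 1.0, 0.5, 0.0]
augmented = augment_sequence(clean, rng=random.Random(0))
```

Passing a seeded random.Random instance makes a run reproducible, which helps when comparing augmentation settings.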

__len__()

Returns the total number of samples.

__getitem__(idx)

Generates one sample of data.

Parameters:
  • idx (int) – Index of the sample to retrieve.

Returns:
  • sample (dict) – Dictionary containing ‘input’, ‘target’, and ‘distance’ tensors.
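Because the class subclasses torch.utils.data.Dataset, it can presumably be passed to a standard DataLoader. The sketch below assumes so; the batch size, the helper names check_sample and iterate_batches, and the deferred import are illustrative choices rather than part of contraction_net. Only the three dictionary keys come from the description above.

```python
EXPECTED_KEYS = ("input", "target", "distance")


def check_sample(sample):
    """Raise KeyError if a sample lacks one of the documented keys."""
    missing = [key for key in EXPECTED_KEYS if key not in sample]
    if missing:
        raise KeyError(f"sample is missing keys: {missing}")
    return sample


def iterate_batches(dataset, batch_size=16):
    """Yield (input, target, distance) batches from a DataProcess-like dataset."""
    # Deferred import so check_sample stays usable without torch installed.
    from torch.utils.data import DataLoader

    check_sample(dataset[0])  # fail early if the sample layout is unexpected
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for batch in loader:
        # The default collate function batches dictionary values key by key.
        yield batch["input"], batch["target"], batch["distance"]
```

Checking the first sample before building the loader surfaces a layout mismatch immediately instead of partway through an epoch.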