contraction_net.training

Attributes

device

Classes

Trainer

Class for training of ContractionNet. Creates Trainer object.

Module Contents

contraction_net.training.device
class contraction_net.training.Trainer(dataset, num_epochs, network=ContractionNet, in_channels=1, out_channels=2, batch_size=16, lr=0.001, n_filter=64, val_split=0.2, save_dir='./', save_name='model.pt', save_iter=False, loss_function='BCEDice', loss_params=(1, 1))

Class for training of ContractionNet. Creates Trainer object.

Parameters:
  • dataset – Training data, a PyTorch Dataset object

  • num_epochs (int) – Number of training epochs

  • network – Network class (default: ContractionNet)

  • in_channels (int) – Number of input channels

  • out_channels (int) – Number of output channels

  • batch_size (int) – Batch size for training

  • lr (float) – Learning rate

  • n_filter (int) – Number of convolutional filters in the first layer

  • val_split (float) – Fraction of the dataset held out for validation

  • save_dir (str) – Path of directory to save trained networks

  • save_name (str) – Base name for saving trained networks

  • save_iter (bool) – If True, the network state is saved after each epoch

  • load_weights (str, optional) – If not None, network state is loaded before training

  • loss_function (str) – Loss function (‘BCEDice’, ‘Tversky’ or ‘logcoshTversky’)

  • loss_params (Tuple[float, float]) – Parameters of the loss function; their meaning depends on the chosen loss function
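The three loss_function options are easier to reason about with their usual definitions in hand. The following is a minimal pure-Python sketch on flattened probability maps, not the library's implementation. Reading loss_params as the (BCE weight, Dice weight) pair for 'BCEDice' and as Tversky's (alpha, beta) weights on false positives and false negatives is an assumption, as are the helper names and the smoothing constant.

```python
import math
from typing import Sequence


def _clip(p: float, eps: float = 1e-7) -> float:
    # Keep probabilities away from 0 and 1 so log() stays finite
    return min(max(p, eps), 1.0 - eps)


def bce_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean binary cross-entropy over all pixels."""
    total = 0.0
    for p, t in zip(pred, target):
        p = _clip(p)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(pred)


def dice_loss(pred: Sequence[float], target: Sequence[float], smooth: float = 1.0) -> float:
    """1 minus the soft Dice coefficient."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)


def bce_dice_loss(pred, target, params=(1.0, 1.0)) -> float:
    # Assumption: params = (weight of the BCE term, weight of the Dice term)
    w_bce, w_dice = params
    return w_bce * bce_loss(pred, target) + w_dice * dice_loss(pred, target)


def tversky_loss(pred, target, params=(0.5, 0.5), smooth: float = 1.0) -> float:
    # Assumption: params = (alpha, beta), the weights on false positives and false negatives
    alpha, beta = params
    tp = sum(p * t for p, t in zip(pred, target))
    fp = sum(p * (1.0 - t) for p, t in zip(pred, target))
    fn = sum((1.0 - p) * t for p, t in zip(pred, target))
    return 1.0 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)


def logcosh_tversky_loss(pred, target, params=(0.5, 0.5)) -> float:
    # log(cosh(x)) smooths the Tversky loss; it behaves like x**2 / 2 near zero
    return math.log(math.cosh(tversky_loss(pred, target, params)))


# Hypothetical mapping from the loss_function strings to the sketches above
LOSSES = {
    "BCEDice": bce_dice_loss,
    "Tversky": tversky_loss,
    "logcoshTversky": logcosh_tversky_loss,
}
```

With alpha = beta = 0.5 and no smoothing, the Tversky loss reduces to the Dice loss; raising beta above alpha penalizes missed foreground pixels more heavily, which is the usual reason to prefer Tversky on imbalanced segmentation targets.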

network
model
data
in_channels = 1
out_channels = 2
num_epochs
batch_size = 16
lr = 0.001
best_loss
save_iter = False
loss_function = 'BCEDice'
loss_params = (1, 1)
n_filter = 64
dim
train_loader
val_loader
smooth_loss
optimizer
scheduler
save_dir = './'
save_name = 'model.pt'
__iterate(epoch, mode)
start()

Start network training.
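The attribute listing shows best_loss, save_iter, save_dir and save_name, but not how start() uses them. A common pattern for such a loop is sketched below in plain Python: one training pass and one validation pass per epoch, the main checkpoint overwritten whenever the validation loss improves, and an optional per-epoch copy. This is a hypothetical illustration, not ContractionNet's code; run_training, its callback arguments, and the _epoch{N} file naming are invented for the example.

```python
import math
import os
from typing import Callable


def run_training(
    num_epochs: int,
    train_epoch: Callable[[int], float],
    validate_epoch: Callable[[int], float],
    save: Callable[[str], None],
    save_name: str = "model.pt",
    save_iter: bool = False,
) -> float:
    """Hypothetical epoch loop; returns the best validation loss seen."""
    best_loss = math.inf
    stem, ext = os.path.splitext(save_name)
    for epoch in range(num_epochs):
        train_epoch(epoch)                # one pass over the training split
        val_loss = validate_epoch(epoch)  # one pass over the validation split
        if save_iter:
            # Hypothetical naming scheme: keep one checkpoint per epoch
            save(f"{stem}_epoch{epoch}{ext}")
        if val_loss < best_loss:
            # Overwrite the main checkpoint only when validation improves
            best_loss = val_loss
            save(save_name)
    return best_loss
```

In a PyTorch setting the save callback would typically wrap torch.save(model.state_dict(), os.path.join(save_dir, name)); passing callbacks keeps the sketch runnable without a model.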