Source code for contraction_net.training

import os

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from .losses import BCEDiceLoss, SmoothnessLoss
from .contraction_net import ContractionNet
from .utils import get_device

# select device
device = get_device()


class Trainer:
    """
    Class for training of ContractionNet.

    Creates Trainer object.

    Parameters
    ----------
    dataset
        Training data, object of PyTorch Dataset class
    num_epochs : int
        Number of training epochs
    network
        Network class (default ContractionNet)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    batch_size : int
        Batch size for training
    lr : float
        Learning rate
    n_filter : int
        Number of convolutional filters in the first layer
    val_split : float
        Fraction of the dataset held out for validation
    save_dir : str
        Path of directory to save trained networks
    save_name : str
        Base name for saving trained networks
    save_iter : bool
        If True, the network state is saved after each epoch
    loss_function : str
        Loss function (currently only 'BCEDice' is implemented)
    loss_params : Tuple[float, float]
        Parameters of the loss function; depend on the chosen loss function
    """

    def __init__(self, dataset, num_epochs, network=ContractionNet, in_channels=1, out_channels=2,
                 batch_size=16, lr=1e-3, n_filter=64, val_split=0.2, save_dir='./', save_name='model.pt',
                 save_iter=False, loss_function='BCEDice', loss_params=(1, 1)):
        self.network = network
        self.model = network(n_filter=n_filter, in_channels=in_channels, out_channels=out_channels).to(device)
        self.data = dataset
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.best_loss = torch.tensor(float('inf'))
        self.save_iter = save_iter
        self.loss_function = loss_function
        self.loss_params = loss_params
        self.n_filter = n_filter

        # split training and validation data
        num_val = int(len(dataset) * val_split)
        num_train = len(dataset) - num_val
        self.dim = dataset.input_len
        self.train_data, self.val_data = random_split(dataset, [num_train, num_val])
        self.train_loader = DataLoader(self.train_data, batch_size=self.batch_size, pin_memory=True, drop_last=True)
        self.val_loader = DataLoader(self.val_data, batch_size=self.batch_size, pin_memory=True, drop_last=True)

        if loss_function == 'BCEDice':
            self.criterion = BCEDiceLoss(loss_params)
        else:
            raise ValueError(f'Loss "{loss_function}" not defined!')
        self.smooth_loss = SmoothnessLoss(alpha=0.01)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=4, factor=0.1)
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.save_name = save_name

    def __iterate(self, epoch, mode):
        if mode == 'train':
            print('\nStarting training epoch %s ...' % epoch)
            for i, batch_i in tqdm(enumerate(self.train_loader), total=len(self.train_loader), unit='batch'):
                x_i = batch_i['input'].view(self.batch_size, self.in_channels, self.dim).to(device)
                y_i = batch_i['target'].view(self.batch_size, 1, self.dim).to(device)
                d_i = batch_i['distance'].view(self.batch_size, 1, self.dim).to(device)

                # Forward pass: compute predicted y by passing x to the model
                y_pred, y_logits = self.model(x_i)

                # Split the logits into two chunks along the channel dimension
                y_1, y_2 = torch.chunk(y_logits, chunks=2, dim=1)

                # Compute loss (contraction channel + distance channel)
                contr_loss = self.criterion(y_1, y_i)
                dist_loss = self.criterion(y_2, d_i)
                # Smoothness penalty on the distance logits; computed here but
                # not added to the total loss below.
                smooth_loss = self.smooth_loss(y_2)
                loss = contr_loss + dist_loss

                # Zero gradients, perform a backward pass, and update the weights
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            # the loss of the last training batch is returned
            return loss

        elif mode == 'val':
            loss_list = []
            print('\nStarting validation epoch %s ...' % epoch)
            with torch.no_grad():
                for i, batch_i in enumerate(self.val_loader):
                    x_i = batch_i['input'].view(self.batch_size, self.in_channels, self.dim).to(device)
                    y_i = batch_i['target'].view(self.batch_size, 1, self.dim).to(device)
                    d_i = batch_i['distance'].view(self.batch_size, 1, self.dim).to(device)

                    # Forward pass: compute predicted y by passing x to the model
                    y_pred, y_logits = self.model(x_i)

                    # Compute loss on both output channels
                    loss = self.criterion(y_logits[:, 0], y_i[:, 0]) + self.criterion(y_logits[:, 1], d_i[:, 0])
                    loss_list.append(loss.detach())
            val_loss = torch.stack(loss_list).mean()
            return val_loss
    def start(self):
        """
        Start network training.
        """
        train_loss = []
        val_loss = []
        for epoch in range(self.num_epochs):
            train_loss_i = self.__iterate(epoch, 'train')
            train_loss.append(train_loss_i)
            self.state = {
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_loss': self.best_loss,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'lr': self.lr,
                'loss_function': self.loss_function,
                'loss_params': self.loss_params,
                'in_channels': self.in_channels,
                'out_channels': self.out_channels,
                'n_filter': self.n_filter,
                'batch_size': self.batch_size,
                'augmentation': self.data.aug_factor,
                'noise_amp': self.data.noise_amp,
                'random_offset': self.data.random_offset,
                'random_drift': self.data.random_drift,
                'random_outlier': self.data.random_outlier,
                'random_subsampling': self.data.random_subsampling,
                'random_swap': self.data.random_swap,
            }
            with torch.no_grad():
                val_loss_i = self.__iterate(epoch, 'val')
            val_loss.append(val_loss_i)
            self.scheduler.step(val_loss_i)
            if val_loss_i < self.best_loss:
                print('\nValidation loss improved from %s to %s - saving model state' % (
                    round(self.best_loss.item(), 5), round(val_loss_i.item(), 5)))
                self.state['best_loss'] = self.best_loss = val_loss_i
                torch.save(self.state, os.path.join(self.save_dir, self.save_name))
            if self.save_iter:
                torch.save(self.state, os.path.join(self.save_dir, f'model_epoch_{epoch}.pt'))
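A minimal usage sketch, not part of the module itself. The Trainer arguments below are the ones defined in __init__ above; the dataset construction is an assumption, since any PyTorch Dataset works as long as each item is a dict with 'input', 'target' and 'distance' tensors of length dataset.input_len and the dataset exposes the augmentation attributes that start() stores in the checkpoint (aug_factor, noise_amp, random_offset, random_drift, random_outlier, random_subsampling, random_swap). The checkpoint keys used when reloading are those written by start().

    import torch

    from contraction_net.training import Trainer
    from contraction_net.contraction_net import ContractionNet

    # Placeholder: any Dataset meeting the requirements described above.
    dataset = ...

    trainer = Trainer(dataset,
                      num_epochs=50,
                      batch_size=16,
                      lr=1e-3,
                      n_filter=64,
                      val_split=0.2,
                      save_dir='./models',
                      save_name='contraction_net.pt',
                      loss_function='BCEDice',
                      loss_params=(1, 1))
    trainer.start()  # trains, validates, and saves the best state to save_dir/save_name

    # Restoring the best checkpoint afterwards; the dict keys match those
    # written by Trainer.start().
    state = torch.load('./models/contraction_net.pt', map_location='cpu')
    model = ContractionNet(n_filter=state['n_filter'],
                           in_channels=state['in_channels'],
                           out_channels=state['out_channels'])
    model.load_state_dict(state['state_dict'])
    model.eval()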