Source code for contraction_net.losses

import torch
from torch import nn


class BCEDiceLoss(nn.Module):
    """
    A combination of Binary Cross Entropy (BCE) and Dice Loss for binary
    segmentation tasks.

    Parameters
    ----------
    loss_params : tuple, optional
        A tuple containing the weights for the BCE and Dice losses,
        respectively. Default is (1, 1).

    Methods
    -------
    dice_loss(inputs, targets, epsilon=1e-6)
        Computes the Dice loss.
    forward(inputs, targets)
        Computes the combined BCE and Dice loss.
    """

    def __init__(self, loss_params=(1, 1)):
        super(BCEDiceLoss, self).__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.loss_params = loss_params
    def dice_loss(self, inputs, targets, epsilon=1e-6):
        # Convert raw logits to probabilities before measuring overlap.
        inputs = torch.sigmoid(inputs)
        intersection = (inputs * targets).sum()
        dice_coeff = (2. * intersection + epsilon) / (inputs.sum() + targets.sum() + epsilon)
        # Dice loss is one minus the (smoothed) Dice coefficient.
        return 1 - dice_coeff
    def forward(self, inputs, targets):
        bce = self.bce_loss(inputs, targets)
        dice = self.dice_loss(inputs, targets)
        return self.loss_params[0] * bce + self.loss_params[1] * dice
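
# --- Usage sketch (not part of the original module) ---
# A minimal example of BCEDiceLoss on a dummy batch; the tensor shapes and the
# _demo_* helper name are illustrative assumptions. Note that the loss expects
# raw logits, since both BCEWithLogitsLoss and dice_loss apply the sigmoid
# internally.
def _demo_bce_dice_loss():
    criterion = BCEDiceLoss(loss_params=(1, 1))
    logits = torch.randn(4, 1, 64, 64)                      # raw model outputs
    targets = torch.randint(0, 2, (4, 1, 64, 64)).float()   # binary ground-truth masks
    loss = criterion(logits, targets)
    return loss  # scalar tensor: weighted sum of BCE and Dice terms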
class SmoothnessLoss(nn.Module):
    """
    Computes the smoothness loss for a sequence of predictions.

    Parameters
    ----------
    alpha : float, optional
        Weight of the smoothness loss component. Default is 10.

    Methods
    -------
    forward(predictions)
        Computes the smoothness loss for a sequence of predictions.
    """

    def __init__(self, alpha=10):
        super(SmoothnessLoss, self).__init__()
        self.alpha = alpha
    def forward(self, predictions):
        if predictions.dim() < 3:
            raise ValueError("The input tensor must be 3-dimensional.")
        # Differences between consecutive predictions along the sequence axis.
        diffs = predictions[:, :, 1:] - predictions[:, :, :-1]
        # Sum of squared differences, averaged over the batch dimension.
        loss = torch.sum(diffs ** 2) / predictions.size(0)
        return self.alpha * loss
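
# --- Usage sketch (not part of the original module) ---
# SmoothnessLoss penalizes squared differences between consecutive entries
# along the third dimension, so the input must be at least 3-D, e.g.
# (batch, channels, sequence length). The shapes and the _demo_* helper name
# below are illustrative assumptions.
def _demo_smoothness_loss():
    criterion = SmoothnessLoss(alpha=10)
    predictions = torch.randn(4, 1, 100)   # (batch, channels, sequence length)
    loss = criterion(predictions)          # alpha * sum(diff^2) / batch size
    return loss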
def f1_score(logits, true_labels, threshold=0.5, epsilon=1e-7):
    """
    Computes the F1 score for binary classification.

    Parameters
    ----------
    logits : torch.Tensor
        The raw output from the model (before applying sigmoid).
    true_labels : torch.Tensor
        The ground truth binary labels.
    threshold : float, optional
        The threshold to convert probabilities to binary predictions. Default is 0.5.
    epsilon : float, optional
        A small value to avoid division by zero. Default is 1e-7.

    Returns
    -------
    float
        The computed F1 score.
    """
    probabilities = torch.sigmoid(logits)
    predictions = (probabilities > threshold).float()
    true_labels = true_labels.float()

    tp = (predictions * true_labels).sum().item()
    fp = ((1 - true_labels) * predictions).sum().item()
    fn = (true_labels * (1 - predictions)).sum().item()

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)
    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    return f1
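
# --- Usage sketch (not part of the original module) ---
# f1_score takes raw logits, thresholds the sigmoid probabilities (0.5 by
# default), and returns a plain float. The dummy tensors and the _demo_*
# helper name are illustrative assumptions.
def _demo_f1_score():
    logits = torch.randn(16)                          # raw model outputs
    true_labels = torch.randint(0, 2, (16,)).float()  # ground-truth binary labels
    return f1_score(logits, true_labels, threshold=0.5)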