-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmetrics.py
More file actions
35 lines (25 loc) · 1.15 KB
/
metrics.py
File metadata and controls
35 lines (25 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
# Additive smoothing term used by DiceLoss and calculate_metrics. It is added
# to both the numerator and the denominator of each overlap ratio, so two empty
# masks yield a ratio of 1.0 instead of 0/0.
SMOOTH: float = 1.0e-8
class DiceLoss(nn.Module):
    """Dice loss, ``1 - Dice``, pooled over every element of the inputs.

    The overlap is summed over all elements at once (no per-sample or
    per-channel reduction), so a whole batch yields a single scalar loss.

    Args:
        smooth: Additive smoothing term applied to both the numerator and the
            denominator of the Dice ratio. ``None`` (the default) uses the
            module-level ``SMOOTH`` constant.
    """

    def __init__(self, smooth: Optional[float] = None) -> None:
        super().__init__()
        # The module constant is only consulted when no explicit value is given.
        self.smooth = SMOOTH if smooth is None else float(smooth)

    def forward(self, pred_mask: Any, true_mask: Any) -> torch.Tensor:
        """Return the Dice loss between ``pred_mask`` and ``true_mask``.

        Args:
            pred_mask: Predicted mask tensor; presumably per-element
                probabilities in [0, 1] -- TODO confirm with callers.
            true_mask: Ground-truth mask tensor, broadcast-compatible with
                ``pred_mask``.

        Returns:
            A scalar tensor. It is 0 for identical binary masks and for two
            all-zero masks, and approaches 1 when the masks do not overlap.
        """
        # Broadcast explicitly so the intersection and both per-mask sums
        # cover the same elements; otherwise mismatched (but broadcastable)
        # shapes push the Dice ratio above 1 and the loss below 0.
        pred_mask, true_mask = torch.broadcast_tensors(pred_mask, true_mask)

        intersection = torch.sum(pred_mask * true_mask)
        total = torch.sum(pred_mask) + torch.sum(true_mask)
        # Smoothing goes into the numerator AND the denominator: this avoids
        # 0/0 on empty masks and scores two empty masks as a perfect match.
        dice = (2.0 * intersection + self.smooth) / (total + self.smooth)
        return 1.0 - dice
def calculate_metrics(
    pred_mask: Any, true_mask: Any, smooth: Optional[float] = None
) -> Tuple[float, float, float]:
    """Return ``(iou, dice_coefficient, pixel_accuracy)`` for two masks.

    Both masks are cast to float and broadcast to a common shape. Every
    metric is then pooled over all elements at once (no per-sample or
    per-class averaging).

    NOTE(review): the union and pixel-accuracy formulas only make sense for
    hard masks with values in {0, 1}. Soft predictions presumably need to be
    thresholded by the caller first -- confirm against call sites.

    Args:
        pred_mask: Predicted mask tensor.
        true_mask: Ground-truth mask tensor, broadcast-compatible with
            ``pred_mask``.
        smooth: Additive smoothing term for IoU and Dice. ``None`` (the
            default) uses the module-level ``SMOOTH`` constant.

    Returns:
        A tuple of three Python floats. Two all-zero masks score an IoU and a
        Dice of 1.0. Pixel accuracy is undefined (0/0) when there are no
        elements to compare.
    """
    eps = SMOOTH if smooth is None else smooth
    # Broadcast explicitly so all three metrics count the same elements. The
    # per-mask sums and numel() below would otherwise ignore broadcasting,
    # letting Dice and pixel accuracy exceed 1 for mismatched shapes.
    pred_mask, true_mask = torch.broadcast_tensors(
        pred_mask.float(), true_mask.float()
    )

    intersection = torch.sum(pred_mask * true_mask)
    # For {0, 1} masks, pred + true > 0.5 flags pixels set in either mask.
    union = torch.sum((pred_mask + true_mask) > 0.5)
    # Smoothing in both numerator and denominator turns 0/0 on empty masks
    # into a score of 1.
    iou = (intersection + eps) / (union + eps)
    dice_coefficient = (2 * intersection + eps) / (
        torch.sum(pred_mask) + torch.sum(true_mask) + eps
    )
    pixel_accuracy = torch.sum(pred_mask == true_mask) / true_mask.numel()

    return iou.item(), dice_coefficient.item(), pixel_accuracy.item()