Skip to content

Commit

Permalink
[Fix] Fix DICE format and a bug (#1630)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyutang authored Dec 15, 2021
1 parent a2a43b8 commit 73f029f
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions paddleseg/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,29 +74,33 @@ def auc_roc(logits, label, num_classes, ignore_index=None):
Returns:
auc_roc(float): The area under roc curve
"""
if ignore_index or len(np.unique(label))>num_classes:
if ignore_index or len(np.unique(label)) > num_classes:
raise RuntimeError('labels with ignore_index is not supported yet.')

if len(label.shape) != 4:
raise ValueError('The shape of label is not 4 dimension as (N, C, H, W), it is {}'.format(label.shape))
raise ValueError(
'The shape of label is not 4 dimension as (N, C, H, W), it is {}'.
format(label.shape))

if len(logits.shape) != 4:
raise ValueError('The shape of logits is not 4 dimension as (N, C, H, W), it is {}'.format(logits.shape))

N, C, H, W = logits.shape
raise ValueError(
'The shape of logits is not 4 dimension as (N, C, H, W), it is {}'.
format(logits.shape))

N, C, H, W = logits.shape
logits = np.transpose(logits, (1, 0, 2, 3))
logits = logits.reshape([C, N*H*W]).transpose([1,0])
logits = logits.reshape([C, N * H * W]).transpose([1, 0])

label = np.transpose(label, (1, 0, 2, 3))
label = label.reshape([1, N*H*W]).squeeze()
label = label.reshape([1, N * H * W]).squeeze()

if not logits.shape[0] == label.shape[0]:
raise ValueError('length of `logit` and `label` should be equal, '
'but they are {} and {}.'.format(
pred.shape[0], label.shape[0]))
logits.shape[0], label.shape[0]))

if num_classes == 2:
auc = skmetrics.roc_auc_score(label, logits[:,1])
auc = skmetrics.roc_auc_score(label, logits[:, 1])
else:
auc = skmetrics.roc_auc_score(label, logits, multi_class='ovr')

Expand Down Expand Up @@ -156,7 +160,7 @@ def dice(intersect_area, pred_area, label_area):
dice = (2 * intersect_area[i]) / union[i]
class_dice.append(dice)
mdice = np.mean(class_dice)
return np.array(class_dice), mdice
return np.array(class_dice), mdice


def accuracy(intersect_area, pred_area):
Expand Down

0 comments on commit 73f029f

Please sign in to comment.