2024年1月17日

[转载]ZLPR: A NOVEL LOSS FOR MULTI-LABEL CLASSIFICATION

为防失效,转载原文,来自 https://kexue.fm/archives/9064

def multilabel_categorical_crossentropy(y_true, y_pred):
    """多标签分类的交叉熵
    说明:y_true和y_pred的shape一致,y_true的元素非0即1,
         1表示对应的类为目标类,0表示对应的类为非目标类。
    警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred
         不用加激活函数,尤其是不能加sigmoid或者softmax!预测
         阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解
         本文。
    """
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = K.zeros_like(y_pred[..., :1])
    y_pred_neg = K.concatenate([y_pred_neg, zeros], axis=-1)
    y_pred_pos = K.concatenate([y_pred_pos, zeros], axis=-1)
    neg_loss = K.logsumexp(y_pred_neg, axis=-1)
    pos_loss = K.logsumexp(y_pred_pos, axis=-1)
    return neg_loss + pos_loss

多标签分类问题的统一loss,能媲美精调权重下的二分类方案,这个损失函数有着单标签分类中“Softmax+交叉熵”的优点,即便在正负类不平衡的依然能够有效工作。但从这个损失函数的形式我们可以看到,它只适用于“硬标签”,这就意味着label smoothing、mixup等技巧就没法用了。本文则尝试解决这个问题,提出上述损失函数的一个软标签版本。

def multilabel_categorical_crossentropy(y_true, y_pred):
    """多标签分类的交叉熵
    """
    inf = torch.tensor(np.inf, dtype=torch.float32, device=y_pred.device)
    y_mask = torch.ne(y_pred, -inf)
    y_neg = torch.where(y_mask, y_pred, -inf) + torch.log(1 - y_true)
    y_pos = torch.where(y_mask, -y_pred, -inf) + torch.log(y_true)
    zeros = torch.zeros_like(y_pred[..., :1])
    y_neg = torch.cat([y_neg, zeros], dim=-1)
    y_pos = torch.cat([y_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pos, dim=-1)
    return neg_loss + pos_loss
Share

You may also like...

发表评论

您的电子邮箱地址不会被公开。