yolo v5 损失函数分析
与 yolo v1 类似,v5 损失函数由 3 个部分组成,分别为 bbox 回归损失、目标置信度损失和类别损失。网络每个将特征图分为若干个 cell,每个 cell 输出一个 [ t x , t y , t w , t h , p o , c 1 , c 2 , . . . ] [t_x, t_y, t_w, t_h, p_o, c_1, c_2, ...] [tx,ty,tw,th,po,c1,c2,...] 的向量,其中 t x , t y t_x,t_y tx,ty 用于计算预测框和对应 anchor box (也就是所在 cell) 两者中心的偏移量, t w , t h t_w,t_h tw,th 用于计算预测框的宽高, p o p_o po 是该 cell (预测框) 含有目标的概率, c 1 , c 2 , . . . c_1, c_2, ... c1,c2,... 为对应类别的预测值。
三个部分的损失均是通过匹配到的正样本对来计算,每一个输出特征图相互独立,直接相加得到最终每一部分的损失值。先给出整体的计算公式:
L v 5 ( t p , t gt ) = ∑ k = 0 K [ α k balance α box ∑ i = 0 S 2 ∑ j = 0 B I k i j obj L CIoU + α obj ∑ i = 0 S 2 ∑ j = 0 B I k i j obj L obj + α cls ∑ i = 0 S 2 ∑ j = 0 B I k i j obj L cls ] \mathcal{L}_{\text{v}5}\left( \boldsymbol{t}_{\text{p}},\boldsymbol{t}_{\text{gt}} \right) =\sum_{k=0}^K{\left[ \alpha _{k}^{\text{balance}}\alpha _{\text{box}}\sum_{i=0}^{S^2}{\sum_{j=0}^B{\mathbb{I}_{kij}^{\text{obj}}\mathcal{L}_{\text{CIoU}}}}+\alpha _{\text{obj}}\sum_{i=0}^{S^2}{\sum_{j=0}^B{\mathbb{I}_{kij}^{\text{obj}}\mathcal{L}_{\text{obj}}}}+\alpha _{\text{cls}}\sum_{i=0}^{S^2}{\sum_{j=0}^B{\mathbb{I}_{kij}^{\text{obj}}\mathcal{L}_{\text{cls}}}} \right]} Lv5(tp,tgt)=k=0∑K⎣ ⎡αkbalanceαboxi=0∑S2j=0∑BIkijobjLCIoU+αobji=0∑S2j=0∑BIkijobjLobj+αclsi=0∑S2j=0∑BIkijobjLcls⎦ ⎤
其中, K , S 2 , B K,S^2,B K,S2,B 分别为输出特征图、cell 和 每个 cell 上 anchor 的数量; α ⋆ \alpha_\star α⋆ 为对应项的权重,在 hyp.scratch-high.yaml 中默认取值为 α box = 0.05 , α cls = 0.3 , α obj = 0.7 \alpha_\text{box}=0.05,\alpha_\text{cls}=0.3,\alpha_\text{obj}=0.7 αbox=0.05,αcls=0.3,αobj=0.7; I k i j obj \mathbb{I}_{kij}^{\text{obj}} Ikijobj 表示第 k k k 个输出特征图,第 i i i 个 cell, 第 j j j 个 anchor box 是否是正样本,如果是正样本则为 1,反之为 0; t p , t p \boldsymbol{t}_{\text{p}},\boldsymbol{t}_{\text{p}} tp,tp 是预测向量和 ground-truth 向量; α k balance \alpha _{k}^{\text{balance}} αkbalance 用于平衡每个尺度的输出特征图的权重,默认取值为 [ 4.0 , 1.0 , 0.4 ] [4.0, 1.0, 0.4] [4.0,1.0,0.4], 依次对应 80 × 80 , 40 × 40 , 20 × 20 80\times80,40\times40,20\times20 80×80,40×40,20×20 的输出特征图。
1. bbox 回归损失
v5 使用的是 CIoU Loss。
yolo v5 中正样本匹配策略和 bbox 回归如下图所示。
具体 CIoU Loss 分析可以参考 基于IOU的损失函数合集。
iou_term = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
lbox += (1.0 - iou_term).mean()
2. 目标置信度损失
目标置信度损失由正样本匹配得到的样本对计算,一是预测框中的目标置信度分数 p o p_o po;二是预测框和与之对应的目标框的 iou 值,其作为 ground-truth。两者计算二进制交叉熵得到最终的目标置信度损失。公式如下:
L obj ( p o , p iou ) = BCE obj sig ( p o , p iou ; w obj ) \mathcal{L}_{\text{obj}}\left( p_o,p_{\text{iou}} \right) =\text{BCE}_{\text{obj}}^\text{sig}\left( p_o,p_{\text{iou}};w_{\text{obj}} \right) Lobj(po,piou)=BCEobjsig(po,piou;wobj)
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
obji = self.BCEobj(pi[..., 4], tobj)
3. 类别损失
类别损失与置信度损失类似,通过预测框的类别分数和目标框类别的 one-hot 表现来计算类别损失,公式如下:
L cls ( c p , c gt ) = BCE cls sig ( c p , c gt ; w cls ) \mathcal{L}_{\text{cls}}\left( \boldsymbol{c}_{\text{p}},\boldsymbol{c}_{\text{gt}} \right) =\text{BCE}_{\text{cls}}^{\text{sig}}\left( \boldsymbol{c}_{\text{p}},\boldsymbol{c}_{\text{gt}};w_{\text{cls}} \right) Lcls(cp,cgt)=BCEclssig(cp,cgt;wcls)
这里目标置信度损失和类别损失使用的是带 sigmoid 的二进制交叉熵函数 BCEWithLogitsLoss。如果要使用 Focal Loss 在其基础上改动即可。
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
lcls += self.BCEcls(pi[..., 5:], t_cls)
源程序分析下次再说。