正负样本不平衡处理方法总结

首先给大家推荐一下我老师大神的人工智能教学网站。教学不仅零基础,通俗易懂,而且非常风趣幽默,还时不时有内涵黄段子!点这里可以跳转到网站

1, Bootstrapping,hard negative mining
最原始的一种方法,主要使用在传统的机器学习方法中。
比如,训练cascade类型分类模型的时候,可以将每一级分类错误的样本继续添加进下一层进行训练。

比如,SVM分类中去掉那些离分界线较远的样本,只保留离分界线较近的样本。

2, heuristic sampling

标准的faster rcnn中,假设正样本IOU(0.7~1.0)。负样本IOU(0.0~0.3)。比如实际的RPN网络中,实际最后的anchor经过NMS处理后的负样本是很多的,假如有30000个。这里面只有少数的正样本,大部分都是负样本。在RPN模块中,仅仅对其中大约1/3的hard的proposal进行了2分类和回归的loss计算。即只从3000个proposal里面只选择那些hard负样本,这样实际训练出来的效果最好。但是后续的roi pooling部分,只需要传递256个proposal进行计算既可。即正的proposal128个,负的proposal128个。假设正的不够,则使用负的进行补齐。

3, online hard example mining(OHEM)

出自,Training Region-based Object Detectors with Online Hard Example Mining这篇文章,

在fast RCNN这样的框架下,在原始的网络基础上,经过Selective-search后,新接入了一个绿色的Read-only Layer,该网络对所有的ROI进行前向传播,并计算最后的loss,然后红色的网络对其中loss较大的ROI进行前向和后向传播,可以说是一种动态的选择性的传播梯度。优势也就显而易见,比原始的faster RCNN可以节省很大的运算量,训练速度回提升,最终模型准确性也提升。
其中,一个trick就是,在绿色的网络进行前向传播完,其中出来的好多ROI会存在一些Loss较高,但是这些ROI有很大的IOU的情况,这样就会使得梯度重复计算和传播,因此,这里,作者加入了NMS进行IOU的过滤。


4,Focal Loss
出自Focal Loss for Dense Object Detection这篇文章,
文章重点就是提出了focal loss这个cross entropy (CE) loss的改进版,实现了对于正负样本不平衡的调整。具体思路就是其公式,

从这个公式就可以分析出,
假设r=2,pt分数为0.9,那么这个easy example的loss将会被缩小0.01a倍
假设r=2,pt分数为0.968,那么这个easy example的loss将会被缩小0.001a倍
假设r=2,pt分数为0.1,那么这个hard example的loss将会被缩小0.81a倍
同样所有样本的loss都会缩小,但是hard example要不easy example缩小的小,从而取得好的训练效果。

从上图也可以反映出,r>0的曲线的loss要比r=0的曲线的更低,loss更小。

当然文章还提出了一个RetinaNet

RetinaNet以Resnet为基础结构,通过Feature Pyramid Network (FPN)产生多尺度的输出特征图,然后分别级联一个分类和回归的子网络。这点和faster RCNN有点区别,在faster中是只使用一个网络进行分类和回归操作,RetinaNet将2个任务分离后,也许会在精度上有一定提高吧,也更容易训练。
这里一个trick是RetinaNet的初始化,
(1)除了分类子网络的最后一层,其余层w全部初始化为u=0, σ = 0:01的高斯分布,b初始化为0。
(2)最后一个分类的卷积层,b初始化为- log((1 – π)/π),文中使用π = 0.01,这样初始化使得每个anchor被标记为前景的概率为0.01,
这里b的这个公式是怎么得出的呢?
最终的分类得分是一个逻辑回归,公式为,

这里的z=wx+b,由于w初始化为u=0, σ = 0:01的高斯分布,所以,z=b,最终的概率设为π,从而得出公式,

从而解出,b=- log((1 -π)/π)

这个初始化对于focal loss的训练很重要。

5,class balanced cross-entropy

出自Holistically-nested edge detection这篇文章, 主要用于FCN,U-net等分割,边缘检测的网络,用于对像素级别的2分类样本不平衡进行优化。

sigmoid_cross_entropy公式:

-y_hat* log(sigmoid(y)) – (1 – y_hat) * log(1 – sigmoid(y))

class_balanced_sigmoid_cross_entropy公式:

-β*y_hat* log(sigmoid(y)) -(1-β) * (1 – y_hat) * log(1 – sigmoid(y))

思想就是引入新的权值β,实现正负样本loss的平衡,从而实现对不同正负样本的平衡。

6, local rank,PISA,ISR

References:

https://github.com/abhi2610/ohem

Libra R-CNN: Towards balanced learning for object detection

Prime Sample Attention in Object Detection
 

点这里可以跳转到人工智能网站

发表评论