AI编辑:我是小将
本文作者: CPFLAME
https://zhuanlan.zhihu.com/p/323814368
本文已由原作者授权
笔者重构了一版centernet(objects as points)的代码,并加入了蒸馏,多模型蒸馏,转caffe,转onnx,转tensorRT,把后处理也做到了网络前向当中,对落地非常的友好。
放一个centerX多模型蒸馏出来的效果图,在蒸馏时没有用到数据集的标签,只用了两个teacher的model蒸馏同一个student网络。就用大家的老婆来做demo吧。
不感兴趣的童鞋可以收藏一下笔者的表情包,如果觉得表情包好玩,跪求去github点赞。
代码地址:https://github.com/CPFLAME/centerX/
centernet是我最喜欢的检测文章之一,没有anchor,没有nms,结构简单,可拓展性强,最主要的是:落地极其方便,选一个简单的backbone,可以没有bug的转成你想要的模型(caffe,onnx,tensorRT)。并且后处理也极其的方便。
毕竟代码写的不是打打杀杀,而是人情世故,真学东西还得看其他人的文章,看我的也就图一乐。
一般来说读文章的人点进来都会带着这样一个心理,我为什么要用centerX,明明我用别的框架用的很顺利了,转过来多麻烦你知道吗,你在教我做事?
这个方面没有什么好说的,也没有做到和其他框架的差异化,只是在detectron2上对基础的centernet进行了复现而已,而且大部分代码都是白嫖自centernet-better和centernet-better-plus,就直接上在COCO上的实验结果吧。
Backbone为resnet50
centerX_KD是用27.9的resnet18作为学生网络,33.2的resnet50作为老师网络蒸馏得到的结果,详细过程在在下面的章节会讲。
Backbone为resnet18
模型蒸馏
大嘎好,我是detection。我时常羡慕的看着隔壁村的classification,embedding等玩伴,他们在蒸馏上面都混得风生水起,什么logits蒸馏,什么KL散度,什么Overhaul of Feature Distillation。每天都有不同的家庭教师来指导他们,凭什么我detection的教育资源就很少,我detection什么时候才能站起来!
造成上述的原因主要是因为detection的范式比较复杂,并不像隔壁村的classification embedding等任务,开局一张图,输出一个vector:
我们再来回头看看centernet的范式,哦,我的上帝,多么简单明了的范式:
这让笔者看到了在detection上安排家庭教师的希望,于是我们仿照了centernet本来的loss的写法,仿照了一个蒸馏的loss。具体的实现可以去code里面看,这里就说一下简单的思想。
我们拉到实验的部分,上述的瞎比猜想得到验证。
看到蒸馏效果还可以,可以在不增加计算量的情况下无痛涨点,笔者高兴了好一阵子,直到笔者在实际项目场景上遇到了一个尴尬地问题:
因为数据集A里面可能会有大量的未标注的B,B里面也会有大量的未标注的A,直接放到一起训练肯定不行,网络会学傻
在笔者再次拍了拍脑袋后,发挥了我最擅长的技能:白嫖。想到了这样一个方案:
那么我们的多模型蒸馏就可以用现有的方案拼凑起来了。这相当于我同时白嫖了自己的代码,以及不完整标注的数据集,白嫖是真的让人快乐啊。和上述提到的操作进行一番比♂较,果然用了的多模型蒸馏的效果要好一些。又一个瞎比猜想被验证了。
骂完了爽归爽,问题还是要解决的,为了解决这个问题,笔者首先想到笔者的代码是不是哪里有bug,但是找了半天都没找到,笔者还尝试了如下的方式:
看来这个bug油盐不进,软硬不吃。训练期间总会出现某个时间段loss突然增大,然后网络全部从头开始训练的情况。
这让我想到了内卷加速,资本主义泡沫破裂,经济大危机后一切推倒重来。这个时候才想起共产主义的好,毛主席真是永远滴神。
既然如此,咱们一不做二不休,直接把蛋糕给loss们分好,让共产主义无产阶级的光照耀到它们身上,笔者一气之下把loss的大小给各个兔崽子head们给规定死,具体操作如下:
接下来就是实验部分看看管不管用了,于是笔者尝试了一下之前崩溃的lr,得益于共产主义的好处,换了几个数据集跑实验都没有出现mAP拉胯的情况了,期间有几次出现了loss飞涨的情况,但是在共产主义loss强大的调控能力之下迅速恢复到正常状态,看来社会主义确实优越。同时笔者也尝试了用合适的lr,跑baseline和共产主义loss的实验,发现两者在±0.3的mAP左右,影响不大。
笔者又为此高兴了好一段时间,并且发现了共产主义loss可以用在蒸馏当中,并且表现也比较稳定,在±0.2个mAP左右。这下蒸馏可以end2end训练了,再也不用人眼去看loss、算loss weight、停掉从头训了。
这个部分的代码都在code的projects/speedup中,注意网络中不能包含DCN,不然转码很难。
centerX中提供了转caffe,转onnx的代码,onnx转tensorRT只要装好环境后一行指令就可以转换了,笔者还提供了转换后不同框架的前向代码。
代码如下:
onnx中可视化如下:
然后我们可以来康康经历了骚操作之后的后处理代码,极其的简单,相信也可以在任何的框架上快速的实现:
值得注意的是上述骚操作在转caffe的时候会报错,所以不能加。如果非要添加上去,得在caffe的prototxt中自行添加scale层,elementwise层,relu层,这个笔者没有实现,大家感兴趣可以自行添加。
考虑到大家需要向上管理,笔者写几个可以涨点的东西
除了以上的在精度方面的优化之外,其实笔者还想到很多可以做的东西,咱们不在精度这个地方跟别人卷,因为卷不过别人,检测这个领域真是神仙打架,打不过打不过。我们想着把蛋糕做大,大家一起有肉吃
其实有太多的东西想加到centerX里面去了,里面有很多很好玩的以及非常具有实用价值的东西都可以去做,但是个人精力有限,而且刚开始做centerX完全是基于兴趣爱好去做的,本人也只是渣硕,无法full time扑到这个东西上面去,所以上述的优化方向看看在我有生之年能不能做出来,做不出来给大家提供一个可行性思路也是极好的。
非常感谢廖星宇,何凌霄对centerX代码,以及发展方向上的贡献,感谢郭聪,于万金,蒋煜襄,张建浩等同学对centerX加速模块的采坑指导。
再放一遍自己的github:https://github.com/CPFLAME/centerX
以及感谢如下杰出的工作
https://github.com/xingyizhou/CenterNet
https://github.com/facebookresearch/detectron2
https://github.com/FateScript/CenterNet-better
https://github.com/lbin/CenterNet-better-plus
https://github.com/JDAI-CV/fast-reid
https://github.com/daquexian/onnx-simplifier
https://github.com/CaoWGG/TensorRT-CenterNet
推荐阅读
PyTorch 源码解读之 torch.autograd
mmdetection最小复刻版(十一):概率Anchor分配机制PAA深入分析
MMDetection新版本V2.7发布,支持DETR,还有YOLOV4在路上!
CNN:我不是你想的那样
TF Object Detection 终于支持TF2了!
无需tricks,知识蒸馏提升ResNet50在ImageNet上准确度至80%+
不妨试试MoCo,来替换ImageNet上pretrain模型!
重磅!一文深入深度学习模型压缩和加速
从源码学习Transformer!
mmdetection最小复刻版(七):anchor-base和anchor-free差异分析