【目标检测】NMS非极大值抑制代码示例

橙三吉。 提交于 2020-03-17 09:45:03
 1 import numpy as np
 2 
 3 def non_max_suppress(predicts_dict, threshold=0.2):
 4     """
 5     implement non-maximum supression on predict bounding boxes.
 6     Args:
 7         predicts_dict: {"stick": [[x1, y1, x2, y2, scores1], [...]]}.
 8         threshhold: iou threshold
 9     Return:
10         predicts_dict processed by non-maximum suppression
11     """
12     for object_name, bbox in predicts_dict.items(): #对每一个类别的目标分别进行NMS
13         bbox_array = np.array(bbox, dtype=np.float)
14  
15         ## 获取当前目标类别下所有矩形框(bounding box,下面简称bbx)的坐标和confidence,并计算所有bbx的面积
16         x1, y1, x2, y2, scores = bbox_array[:,0], bbox_array[:,1], bbox_array[:,2], bbox_array[:,3], bbox_array[:,4]
17         areas = (x2-x1+1) * (y2-y1+1)
18         #print("areas shape = ", areas.shape)
19  
20         ## 对当前类别下所有的bbx的confidence进行从高到低排序(order保存索引信息)
21         order = scores.argsort()[::-1]
22         print("类别%s的order = "%object_name, order)
23         keep = [] #用来存放最终保留的bbx的索引信息
24         k = 1 
25         ## 依次从按confidence从高到低遍历bbx,移除所有与该矩形框的IOU值大于threshold的矩形框
26         while order.size > 0:
27             print('第%d次遍历'%(k))
28             i = order[0]
29             keep.append(i) #保留当前最大confidence对应的bbx索引
30  
31             ## 获取所有与当前bbx的交集对应的左上角和右下角坐标,并计算IOU(注意这里是同时计算一个bbx与其他所有bbx的IOU)
32             xx1 = np.maximum(x1[i], x1[order[1:]]) #当order.size=1时,下面的计算结果都为np.array([]),不影响最终结果
33             yy1 = np.maximum(y1[i], y1[order[1:]])
34             xx2 = np.minimum(x2[i], x2[order[1:]])
35             yy2 = np.minimum(y2[i], y2[order[1:]])
36             inter = np.maximum(0.0, xx2-xx1+1) * np.maximum(0.0, yy2-yy1+1)
37             iou = inter/(areas[i]+areas[order[1:]]-inter)
38             print("iou =", iou)
39  
40             print(np.where(iou<=threshold)) #输出没有被移除的bbx索引(相对于iou向量的索引)
41             indexs = np.where(iou<=threshold)[0] + 1 #获取保留下来的索引(因为没有计算与自身的IOU,所以索引相差1,需要加上)
42             print("indexs = ", indexs)
43             order = order[indexs] #更新保留下来的索引, ( array([0, 1, 2]),)
44             print("order = ", order)
45             k+=1
46         bbox = bbox_array[keep]
47         predicts_dict[object_name] = bbox.tolist()
48         predicts_dict = predicts_dict
49     return predicts_dict
50     
if __name__ == "__main__":
    # Demo input: every box is [x1, y1, x2, y2, score] in corner format,
    # which is the layout non_max_suppress documents for its input.
    sample_boxes = [
        [647, 385, 789, 501, 0.6477274894714355],
        [648, 386, 792, 504, 0.9115888476371765],
        [772, 233, 832, 319, 0.6633416414260864],
        [767, 224, 828, 309, 0.7833416414260864],
        [723, 246, 823, 328, 0.8980304598808289],
    ]
    # Both classes get their own copy of the same sample detections.
    detections = {
        class_name: [list(box) for box in sample_boxes]
        for class_name in ("cup", "person")
    }
    detections = non_max_suppress(detections, threshold=0.2)
    print(detections)

运行结果:

 1 类别cup的order =  [1 4 3 2 0]
 2 第1次遍历
 3 iou = [0.         0.         0.         0.94050474]
 4 (array([0, 1, 2]),)
 5 indexs =  [1 2 3]
 6 order =  [4 3 2]
 7 第2次遍历
 8 iou = [0.36237211 0.39097744]
 9 (array([], dtype=int64),)
10 indexs =  []
11 order =  []
12 类别person的order =  [1 4 3 2 0]
13 第1次遍历
14 iou = [0.         0.         0.         0.94050474]
15 (array([0, 1, 2]),)
16 indexs =  [1 2 3]
17 order =  [4 3 2]
18 第2次遍历
19 iou = [0.36237211 0.39097744]
20 (array([], dtype=int64),)
21 indexs =  []
22 order =  []
23 {'cup': [[648.0, 386.0, 792.0, 504.0, 0.9115888476371765], [723.0, 246.0, 823.0, 328.0, 0.8980304598808289]], 'person': [[648.0, 386.0, 792.0, 504.0, 0.9115888476371765], [723.0, 246.0, 823.0, 328.0, 0.8980304598808289]]}

参考博客:

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!