功能性模块:(7)检测性能评估模块(precision,recall等)

功能性模块:(7)检测性能评估模块

一、模块介绍

其实每个算法的好坏都是有对应的评估标准的,如果你和老板说检测算法好或者不好,哈哈哈,那必然就是悲剧了。好或者不好是一个定性的说法,对于实际算法来说,到底怎么样算法算好?怎么样算法算不好?这些应该是有个定量的标准。对于检测来说,可能最常用的几个评价指标就是precision(查准率,就是你检测出来的目标有多少是真的目标),recall(查全率,就是实际的目标你的算法能检测出来多少),还有ap,map等。本篇博客其实就是让小伙伴们对自己的检测模型心里有一个底,换句话说这个模型你训练出来到底咋样?

二、代码实现

import numpy as np
import os

def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC AP given precision and recall. If use_07_metric is true, uses
    the VOC 07 11-point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))
        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]
        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def ComputeMAP(gt_root, predict_root, OVTHRESH=0.5):
    """

    :param gt_root: 生成gt文件的根目录
    :param predict_root: 算法跑出的根目录
    :param overthresh: 设置的阈值
    :return:
    """
    # 获取所有的文件
    files_gt = os.listdir(gt_root)
    files_pred = os.listdir(predict_root)
    files_gt.sort()
    # 这个变量的目的是什么?保存gt中真正的框的数量
    npos = 0
    class_recs = {}
    # 遍历所有gt文件
    for file_gt in files_gt:
        img_name = os.path.splitext(os.path.basename(file_gt))[0]
        file_gt = os.path.join(gt_root, os.path.basename(file_gt))
        print("*" * 80)
        print("img name is: ", img_name)
        print("gt file is: ", file_gt)
        # 处理gt文件
        with open(file_gt, 'r') as f:
            lines = f.readlines()
        splitlines = [x.strip().split(' ') for x in lines]
        bbox = np.array([[float(z) for z in x[:]] for x in splitlines])
        print("bbox is: \n", bbox)
        det = [False] * len(bbox)
        npos = npos + len(bbox)
        class_recs[img_name] = {'bbox': bbox, 'det': det}
    print("*" * 80)
    print("Total npos is: ", npos)

    # 遍历所有的检测结果
    img_ids = []
    confidence = []
    BB = []
    for file_pred in files_pred:
        img_name = os.path.splitext(os.path.basename(file_pred))[0]
        file_pred = os.path.join(pred_root, os.path.basename(file_pred))
        print("*" * 80)
        print("img_name is: ", img_name)
        print("pred file is: ", file_pred)
        with open(file_pred, 'r') as f:
            lines = f.readlines()
        splitlines = [x.strip().split(" ") for x in lines]
        confidence_p = [float(x[0]) for x in splitlines]
        bbox_p = [[float(z) for z in x[1:]] for x in splitlines]
        # 根据confidence_p的长度,复制对应的img_name的str,生成对应长度的list
        # ['20160220082030T28_H', '20160220082030T28_H', '20160220082030T28_H', '20160220082030T28_H']
        img_ids.extend([img_name] * len(confidence_p))
        confidence.extend(confidence_p)
        BB.extend(bbox_p)
        print(img_ids)
        print(confidence)
        print(BB)

    confidence = np.array(confidence)
    BB = np.array(BB)
    print("*" * 80)
    print("All files loaded!")

    # 按照confidence的降序进行排列
    sorted_idx = np.argsort(-confidence)
    print("sorted idx is: ", sorted_idx)
    BB = BB[sorted_idx, :]
    img_ids = [img_ids[x] for x in sorted_idx]

    # 计算对应的TPs 和 FPs
    nd = len(img_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    wrong_count = 0
    for d in range(nd):
        print("We are now test: ", img_ids[d])
        # 取出对应图像的gt
        R = class_recs[img_ids[d]]
        # 检测的结果
        bb = BB[d, :].astype(float)
        # 假设重叠面积初始为-inf
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)
        print("bb: \n ", bb)
        print("BBGT: \n", BBGT)
        print("BBGT size is: ", BBGT.size)

        if BBGT.size > 0:
            # 计算覆盖的部分
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])

            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            # 计算交叉的面积
            inters = iw * ih

            # 计算iou吧
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.)
                   + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
                   - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)
            print("overlaps is: ", overlaps)
            print("ovmax is: ", ovmax)
            print("jmax is: ", jmax)

        if ovmax > OVTHRESH:
            # 如果检测的这个标记还没有激活,默认是False
            if not R['det'][jmax]:
                tp[d] = 1.
                R['det'][jmax] = 1
            else:
                fp[d] = 1.
                wrong_count += 1
        else:
            fp[d] = 1.
            wrong_count += 1
    np.set_printoptions(threshold=np.inf)
    # 计算 precision 和 recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    print("fp is: ", fp)
    print("tp is: ", tp)
    # 召回率(查全率)
    rec = tp / float(npos)
    # 精确率(查准率)
    prec = tp / np.maximum(tp + fp, np.finfo(np.float).eps)

    ap = voc_ap(rec, prec, False)
    print("ap is: ", ap)
    print("*" * 80)
    print("RESULTS: \n")
    print("Total %d images, %d objects" % (len(files_gt), npos))
    print("Detected Correct: %d, Wrong: %d, Miss: %d under IOU: %f"
          % (nd - wrong_count, wrong_count, npos - (nd - wrong_count), OVTHRESH))
    print("Accuracy %f, Recall %f, Average Precision %f"
          % (float(nd - wrong_count) / (nd), float(nd - wrong_count) / (npos), ap))

    # 记录漏检的文件
    f = open('./lost.txt', 'w')
    for k, v in class_recs.items():
        if False in v['det']:
            f.write(str(k) + '.jpg' + '\n')
    f.close()


if __name__ == "__main__":
    gt_root = './mini_test/gt/'
    pred_root = './mini_test/res/'
    ComputeMAP(gt_root, pred_root)

LZ就不详细讲代码了,注释已经很详细了,主要是你的gt应该是什么样子的呢?

  • 命名标准:img_name.txt
  • gt格式:
# x1 y1 x2 y2
965 209 1040 329 
  • res格式:
# score x1 y1 x2 y2
0.9999481 962 222 1043 331
0.9999091 635 251 747 412
0.9783503 1795 340 1836 402
0.57386667 1730 305 1748 337

这个是结果展示,代码中LZ为了清晰加了非常多的打印,谁让云存储不稳定呢,动不动图片就被损坏了,哭唧唧。。。

在这里插入图片描述
ps:最近疫情反弹的厉害,谁能想到新冠肺炎居然坚持了一年,国外疫情也是指数性增长,这算是人类的灾难,也许多年后在看现在,又会有不一样的体会。珍惜当下,爱惜生命!

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: Age of Ai 设计师:meimeiellie 返回首页