server-json/node_modules/tfjs-image-recognition-base/build/es6/metrics/averagePrecision.js
2024-11-01 08:00:42 +00:00

52 lines
No EOL
2.6 KiB
JavaScript

import { range } from '../utils';
import { assignFalseAndTruePositives } from './assignFalseAndTruePositives';
/**
 * Calculates the 11-point interpolated average precision (AP): the mean, over the
 * recall thresholds 0, 0.1, ..., 1.0, of the highest precision reached at any
 * recall greater than or equal to that threshold.
 *
 * @param inputs Array of ground truth and predicted boxes of each image
 *   (each entry exposes `groundTruth` and `predictions`).
 * @param iouThreshold IOU threshold at which a predicted box is considered to be a true positive.
 * @returns An object `{ averagePrec, accumulatedTps, accumulatedFps }`: the AP, plus the
 *   total number of true positive and false positive detections across all inputs.
 */
export function averagePrecision(inputs, iouThreshold) {
  // Flatten the per-image detections into one list ranked by descending score.
  // The explicit [] seed keeps an empty `inputs` array from throwing a TypeError.
  const sortedDetections = inputs
    .map((input) => categorizeBoxes(input, iouThreshold))
    .reduce((flat, boxes) => flat.concat(boxes), [])
    .sort((b1, b2) => b2.score - b1.score)
    .map((box) => box.isTruePositive);

  const numGroundTruth = inputs
    .map((input) => input.groundTruth.length)
    .reduce((sum, val) => sum + val, 0);

  // Walk down the ranking, recording recall and precision after each detection.
  let accumulatedTps = 0;
  let accumulatedFps = 0;
  const precisionRecallPairs = sortedDetections.map((isTruePositive) => {
    if (isTruePositive) {
      accumulatedTps += 1;
    } else {
      accumulatedFps += 1;
    }
    return {
      recall: accumulatedTps / numGroundTruth,
      precision: accumulatedTps / (accumulatedTps + accumulatedFps),
    };
  });

  // Thresholds are computed as i / (numIntervals - 1) instead of i * 0.1, because
  // 3 * 0.1 === 0.30000000000000004. With the division, the thresholds are the
  // correctly rounded decimals, so a recall of exactly 3/10 reaches the 0.3 bucket.
  const numIntervals = 11;
  const precisionsForRecallInterval = Array.from({ length: numIntervals }, (_, i) => {
    const recallThreshold = i / (numIntervals - 1);
    // Interpolated precision: best precision at any recall >= threshold (0 if none).
    return precisionRecallPairs
      .filter((pair) => pair.recall >= recallThreshold)
      .reduce((max, pair) => (max > pair.precision ? max : pair.precision), 0);
  });

  const averagePrec =
    precisionsForRecallInterval.reduce((sum, val) => sum + val, 0) / numIntervals;
  return { averagePrec, accumulatedTps, accumulatedFps };
}
/**
 * Splits one image's predictions into true and false positives and reduces each
 * box to `{ score, isTruePositive }`.
 *
 * @param input Object with `groundTruth` and `predictions` box arrays for one image.
 * @param iouThreshold IOU threshold forwarded to assignFalseAndTruePositives.
 * @returns Tagged true positives followed by tagged false positives.
 */
function categorizeBoxes(input, iouThreshold) {
  const { truePositives, falsePositives } = assignFalseAndTruePositives(
    input.groundTruth,
    input.predictions,
    iouThreshold
  );
  const tagAsTruePositive = categorize(true);
  const tagAsFalsePositive = categorize(false);
  return [
    ...truePositives.map(tagAsTruePositive),
    ...falsePositives.map(tagAsFalsePositive),
  ];
}
/**
 * Builds a mapper that keeps only a box's `score` and tags it with the given
 * true/false-positive flag.
 *
 * @param isTruePositive Flag copied onto every mapped box.
 * @returns Function mapping `{ score, ... }` to `{ score, isTruePositive }`.
 */
const categorize = (isTruePositive) => ({ score }) => ({ score, isTruePositive });
//# sourceMappingURL=averagePrecision.js.map