13 package com.neuronrobotics.bowlerkernel.djl;
15 import ai.djl.modality.cv.Image;
16 import ai.djl.modality.cv.output.BoundingBox;
17 import ai.djl.modality.cv.output.DetectedObjects;
18 import ai.djl.modality.cv.output.Landmark;
19 import ai.djl.modality.cv.output.Point;
20 import ai.djl.modality.cv.output.Rectangle;
21 import ai.djl.ndarray.NDArray;
22 import ai.djl.ndarray.NDArrays;
23 import ai.djl.ndarray.NDList;
24 import ai.djl.ndarray.NDManager;
25 import ai.djl.ndarray.types.DataType;
26 import ai.djl.ndarray.types.Shape;
27 import ai.djl.translate.Translator;
28 import ai.djl.translate.TranslatorContext;
30 import java.util.ArrayList;
31 import java.util.List;
33 import java.util.concurrent.ConcurrentHashMap;
// Record the input image's dimensions in width/height (both declared outside this block).
64 width = input.getWidth();
65 height = input.getHeight();
// Convert the image to an NDArray using the COLOR flag.
66 NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
// Reorder the axes to (2, 0, 1), then reverse along the new first axis.
// NOTE(review): presumably HWC -> CHW followed by a channel-order swap -- confirm against
// the layout returned by Image.toNDArray.
67 array = array.transpose(2, 0, 1).flip(0);
// Cast to FLOAT32 only when the array is not already FLOAT32.
69 if (!array.getDataType().equals(DataType.FLOAT32)) {
70 array = array.toType(DataType.FLOAT32,
false);
// Constant of shape (3, 1, 1) holding 104, 117, 123.
// NOTE(review): presumably this is the `mean` subtracted below -- confirm.
73 ctx.getNDManager().create(
new float[] {104f, 117f, 123f},
new Shape(3, 1, 1));
// Subtract `mean` from the array.
74 array = array.sub(mean);
// Hand the single prepared array back wrapped in an NDList.
75 return new NDList(array);
/**
 * Turns the raw output list into a {@link DetectedObjects} result.
 *
 * <p>The list is read positionally: element 0 as bounding boxes, element 1 as class
 * probabilities (columns 1 onward), element 2 as landmarks. All three are filtered with the
 * {@code cutOff} mask, candidates are visited from the last entry of {@code order} to the
 * first, and every reported detection is named {@code "Face"}.
 *
 * @param ctx translator context; supplies the {@link NDManager}
 * @param list model outputs; elements 0, 1 and 2 are read
 * @return the collected class names, probabilities and bounding boxes
 */
80 public DetectedObjects
processOutput(TranslatorContext ctx, NDList list) {
81 NDManager manager = ctx.getNDManager();
// Keep columns 1 onward of the second output.
85 NDArray prob = list.get(1).get(
":, 1:");
// argMax over axis 1 (cast to FLOAT32) and max over axis 1; the loop below reads entry 0
// of each pair as the class id and entry 1 as the probability.
89 prob.argMax(1).toType(DataType.FLOAT32,
false),
90 prob.max(
new int[] {1})));
93 NDArray boundingBoxes = list.get(0);
// Columns 2 onward: scale by scaleWH, exponentiate, multiply by columns 2 onward of boxRecover.
94 NDArray bbWH = boundingBoxes.get(
":, 2:").mul(scaleWH).exp().mul(
boxRecover.get(
":, 2:"));
// Half of bbWH is subtracted here.
// NOTE(review): presumably converts a centre point into a corner for bbXY -- confirm.
101 .sub(bbWH.mul(0.5f));
// Rebuild the boxes as bbXY followed by bbWH along axis 1.
103 boundingBoxes = NDArrays.concat(
new NDList(bbXY, bbWH), 1);
105 NDArray landms = list.get(2);
// Apply the cutOff mask to boxes, landmarks and probabilities. The first two are transposed,
// masked along axis 1 and transposed back; prob is masked along axis 1 directly.
110 boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
111 landms = landms.transpose().booleanMask(cutOff, 1).transpose();
112 prob = prob.booleanMask(cutOff, 1);
// Sort indices of row 1 of prob and keep the first topK of them.
// NOTE(review): if argSort() sorts ascending, this slice keeps the topK lowest-scoring
// candidates rather than the highest -- verify the intended selection.
115 long[] order = prob.get(1).argSort().get(
":" +
topK).toLongArray();
// Transpose so that prob.get(index) yields one (classId, probability) pair.
116 prob = prob.transpose();
117 List<String> retNames =
new ArrayList<>();
118 List<Double> retProbs =
new ArrayList<>();
119 List<BoundingBox> retBB =
new ArrayList<>();
// Per-class-id list of boxes, looked up and stored inside the loop below.
121 Map<Integer, List<BoundingBox>> recorder =
new ConcurrentHashMap<>();
// Visit the candidates from the last entry of order to the first.
123 for (
int i = order.length - 1; i >= 0; i--) {
124 long currMaxLoc = order[i];
125 float[] classProb = prob.get(currMaxLoc).toFloatArray();
126 int classId = (int) classProb[0];
127 double probability = classProb[1];
129 double[] boxArr = boundingBoxes.get(currMaxLoc).toDoubleArray();
130 double[] landmsArr = landms.get(currMaxLoc).toDoubleArray();
131 Rectangle rect =
new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);
// getOrDefault does not insert the fresh list into recorder; the later put(...) does.
132 List<BoundingBox> boxes = recorder.getOrDefault(classId,
new ArrayList<>());
// NOTE(review): presumably an overlap (IoU) check against the boxes already recorded for
// this class -- confirm.
133 boolean belowIoU =
true;
134 for (BoundingBox box : boxes) {
141 List<Point> keyPoints =
new ArrayList<>();
// Five (x, y) pairs are read from consecutive entries of landmsArr.
142 for (
int j = 0; j < 5; j++) {
143 double x = landmsArr[j * 2];
144 double y = landmsArr[j * 2 + 1];
// The landmark reuses the same four box values as rect and carries keyPoints.
148 new Landmark(boxArr[0], boxArr[1], boxArr[2], boxArr[3], keyPoints);
151 recorder.put(classId, boxes);
// The reported name is always "Face"; classId is not used to pick a name.
152 String className =
"Face";
153 retNames.add(className);
154 retProbs.add(probability);
159 return new DetectedObjects(retNames, retProbs, retBB);
// One {hRatio, wRatio} pair per entry of steps: ceil(height / step) and ceil(width / step).
164 int[][] aspectRatio =
new int[
steps.length][2];
165 for (
int i = 0; i <
steps.length; i++) {
// The float cast forces a floating-point division before Math.ceil is applied.
166 int wRatio = (int) Math.ceil((
float)
width /
steps[i]);
167 int hRatio = (int) Math.ceil((
float)
height /
steps[i]);
// Stored height-first: index 0 is hRatio, index 1 is wRatio.
168 aspectRatio[i] =
new int[] {hRatio, wRatio};
171 List<double[]> defaultBoxes =
new ArrayList<>();
// For every step level, every (h, w) cell of that level's grid and every value in
// scales[idx], one {cx, cy, skx, sky} entry is appended.
173 for (
int idx = 0; idx <
steps.length; idx++) {
174 int[] scale =
scales[idx];
175 for (
int h = 0; h < aspectRatio[idx][0]; h++) {
176 for (
int w = 0; w < aspectRatio[idx][1]; w++) {
177 for (
int i : scale) {
// Scale value divided by the image width / height; "* 1.0" forces double division.
178 double skx = i * 1.0 /
width;
179 double sky = i * 1.0 /
height;
182 defaultBoxes.add(
new double[] {cx, cy, skx, sky});
// Copy the list into a 2-D array.
// NOTE(review): defaultBoxes.get(0) throws IndexOutOfBoundsException if no box was generated.
188 double[][] boxes =
new double[defaultBoxes.size()][defaultBoxes.get(0).length];
189 for (
int i = 0; i < defaultBoxes.size(); i++) {
190 boxes[i] = defaultBoxes.get(i);
// Create the NDArray and clip every value to the range [0.0, 1.0].
192 return manager.create(boxes).clip(0.0, 1.0);
/**
 * Decodes landmark predictions against the prior boxes.
 *
 * <p>Five two-column slices of {@code pre} (columns 0-1, 2-3, 4-5, 6-7 and 8-9) are each
 * multiplied by {@code scaleXY} and by columns 2 onward of {@code priors}, then offset by the
 * first two columns of {@code priors}.
 *
 * @param pre raw landmark predictions; columns 0 through 9 are read
 * @param priors prior boxes; columns 0-1 are added as the offset, columns 2 onward multiply
 * @param scaleXY factor applied to every slice of {@code pre}
 * @return {@code point1} through {@code point5} concatenated along axis 1
 */
196 private NDArray
decodeLandm(NDArray pre, NDArray priors,
double scaleXY) {
// Columns 0-1 of pre.
198 pre.get(
":, :2").mul(scaleXY).mul(priors.get(
":, 2:")).add(priors.get(
":, :2"));
// Columns 2-3 of pre, same scale and offset.
200 pre.get(
":, 2:4").mul(scaleXY).mul(priors.get(
":, 2:")).add(priors.get(
":, :2"));
// Columns 4-5 of pre, same scale and offset.
202 pre.get(
":, 4:6").mul(scaleXY).mul(priors.get(
":, 2:")).add(priors.get(
":, :2"));
// Columns 6-7 of pre, same scale and offset.
204 pre.get(
":, 6:8").mul(scaleXY).mul(priors.get(
":, 2:")).add(priors.get(
":, :2"));
// Columns 8-9 of pre, same scale and offset.
206 pre.get(
":, 8:10").mul(scaleXY).mul(priors.get(
":, 2:")).add(priors.get(
":, :2"));
// NOTE(review): point1..point5 are presumably the five expressions above, in order -- confirm.
207 return NDArrays.concat(
new NDList(point1, point2, point3, point4, point5), 1);
NDArray boxRecover(NDManager manager, int width, int height, int[][] scales, int[] steps)
NDList processInput(TranslatorContext ctx, Image input)
DetectedObjects processOutput(TranslatorContext ctx, NDList list)
FaceDetectionTranslator(double confThresh, double nmsThresh, double[] variance, int topK, int[][] scales, int[] steps)
NDArray decodeLandm(NDArray pre, NDArray priors, double scaleXY)