BowlerKernel
FaceDetectionTranslator.java
Go to the documentation of this file.
1 /*
2  * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5  * with the License. A copy of the License is located at
6  *
7  * http://aws.amazon.com/apache2.0/
8  *
9  * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10  * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11  * and limitations under the License.
12  */
13 package com.neuronrobotics.bowlerkernel.djl;
14 
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
34 
35 public class FaceDetectionTranslator implements Translator<Image, DetectedObjects> {
36 
37  private double confThresh;
38  private double nmsThresh;
39  private int topK;
40  private double[] variance;
41  private int[][] scales;
42  private int[] steps;
43  private int width;
44  private int height;
45 
47  double confThresh,
48  double nmsThresh,
49  double[] variance,
50  int topK,
51  int[][] scales,
52  int[] steps) {
53  this.confThresh = confThresh;
54  this.nmsThresh = nmsThresh;
55  this.variance = variance;
56  this.topK = topK;
57  this.scales = scales;
58  this.steps = steps;
59  }
60 
62  @Override
63  public NDList processInput(TranslatorContext ctx, Image input) {
64  width = input.getWidth();
65  height = input.getHeight();
66  NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
67  array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW RGB -> BGR
68  // The network by default takes float32
69  if (!array.getDataType().equals(DataType.FLOAT32)) {
70  array = array.toType(DataType.FLOAT32, false);
71  }
72  NDArray mean =
73  ctx.getNDManager().create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));
74  array = array.sub(mean);
75  return new NDList(array);
76  }
77 
79  @Override
80  public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
81  NDManager manager = ctx.getNDManager();
82  double scaleXY = variance[0];
83  double scaleWH = variance[1];
84 
85  NDArray prob = list.get(1).get(":, 1:");
86  prob =
87  NDArrays.stack(
88  new NDList(
89  prob.argMax(1).toType(DataType.FLOAT32, false),
90  prob.max(new int[] {1})));
91 
92  NDArray boxRecover = boxRecover(manager, width, height, scales, steps);
93  NDArray boundingBoxes = list.get(0);
94  NDArray bbWH = boundingBoxes.get(":, 2:").mul(scaleWH).exp().mul(boxRecover.get(":, 2:"));
95  NDArray bbXY =
96  boundingBoxes
97  .get(":, :2")
98  .mul(scaleXY)
99  .mul(boxRecover.get(":, 2:"))
100  .add(boxRecover.get(":, :2"))
101  .sub(bbWH.mul(0.5f));
102 
103  boundingBoxes = NDArrays.concat(new NDList(bbXY, bbWH), 1);
104 
105  NDArray landms = list.get(2);
106  landms = decodeLandm(landms, boxRecover, scaleXY);
107 
108  // filter the result below the threshold
109  NDArray cutOff = prob.get(1).gt(confThresh);
110  boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
111  landms = landms.transpose().booleanMask(cutOff, 1).transpose();
112  prob = prob.booleanMask(cutOff, 1);
113 
114  // start categorical filtering
115  long[] order = prob.get(1).argSort().get(":" + topK).toLongArray();
116  prob = prob.transpose();
117  List<String> retNames = new ArrayList<>();
118  List<Double> retProbs = new ArrayList<>();
119  List<BoundingBox> retBB = new ArrayList<>();
120 
121  Map<Integer, List<BoundingBox>> recorder = new ConcurrentHashMap<>();
122 
123  for (int i = order.length - 1; i >= 0; i--) {
124  long currMaxLoc = order[i];
125  float[] classProb = prob.get(currMaxLoc).toFloatArray();
126  int classId = (int) classProb[0];
127  double probability = classProb[1];
128 
129  double[] boxArr = boundingBoxes.get(currMaxLoc).toDoubleArray();
130  double[] landmsArr = landms.get(currMaxLoc).toDoubleArray();
131  Rectangle rect = new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);
132  List<BoundingBox> boxes = recorder.getOrDefault(classId, new ArrayList<>());
133  boolean belowIoU = true;
134  for (BoundingBox box : boxes) {
135  if (box.getIoU(rect) > nmsThresh) {
136  belowIoU = false;
137  break;
138  }
139  }
140  if (belowIoU) {
141  List<Point> keyPoints = new ArrayList<>();
142  for (int j = 0; j < 5; j++) { // 5 face landmarks
143  double x = landmsArr[j * 2];
144  double y = landmsArr[j * 2 + 1];
145  keyPoints.add(new Point(x * width, y * height));
146  }
147  Landmark landmark =
148  new Landmark(boxArr[0], boxArr[1], boxArr[2], boxArr[3], keyPoints);
149 
150  boxes.add(landmark);
151  recorder.put(classId, boxes);
152  String className = "Face"; // classes.get(classId)
153  retNames.add(className);
154  retProbs.add(probability);
155  retBB.add(landmark);
156  }
157  }
158 
159  return new DetectedObjects(retNames, retProbs, retBB);
160  }
161 
162  private NDArray boxRecover(
163  NDManager manager, int width, int height, int[][] scales, int[] steps) {
164  int[][] aspectRatio = new int[steps.length][2];
165  for (int i = 0; i < steps.length; i++) {
166  int wRatio = (int) Math.ceil((float) width / steps[i]);
167  int hRatio = (int) Math.ceil((float) height / steps[i]);
168  aspectRatio[i] = new int[] {hRatio, wRatio};
169  }
170 
171  List<double[]> defaultBoxes = new ArrayList<>();
172 
173  for (int idx = 0; idx < steps.length; idx++) {
174  int[] scale = scales[idx];
175  for (int h = 0; h < aspectRatio[idx][0]; h++) {
176  for (int w = 0; w < aspectRatio[idx][1]; w++) {
177  for (int i : scale) {
178  double skx = i * 1.0 / width;
179  double sky = i * 1.0 / height;
180  double cx = (w + 0.5) * steps[idx] / width;
181  double cy = (h + 0.5) * steps[idx] / height;
182  defaultBoxes.add(new double[] {cx, cy, skx, sky});
183  }
184  }
185  }
186  }
187 
188  double[][] boxes = new double[defaultBoxes.size()][defaultBoxes.get(0).length];
189  for (int i = 0; i < defaultBoxes.size(); i++) {
190  boxes[i] = defaultBoxes.get(i);
191  }
192  return manager.create(boxes).clip(0.0, 1.0);
193  }
194 
195  // decode face landmarks, 5 points per face
196  private NDArray decodeLandm(NDArray pre, NDArray priors, double scaleXY) {
197  NDArray point1 =
198  pre.get(":, :2").mul(scaleXY).mul(priors.get(":, 2:")).add(priors.get(":, :2"));
199  NDArray point2 =
200  pre.get(":, 2:4").mul(scaleXY).mul(priors.get(":, 2:")).add(priors.get(":, :2"));
201  NDArray point3 =
202  pre.get(":, 4:6").mul(scaleXY).mul(priors.get(":, 2:")).add(priors.get(":, :2"));
203  NDArray point4 =
204  pre.get(":, 6:8").mul(scaleXY).mul(priors.get(":, 2:")).add(priors.get(":, :2"));
205  NDArray point5 =
206  pre.get(":, 8:10").mul(scaleXY).mul(priors.get(":, 2:")).add(priors.get(":, :2"));
207  return NDArrays.concat(new NDList(point1, point2, point3, point4, point5), 1);
208  }
209 }
NDArray boxRecover(NDManager manager, int width, int height, int[][] scales, int[] steps)
DetectedObjects processOutput(TranslatorContext ctx, NDList list)
FaceDetectionTranslator(double confThresh, double nmsThresh, double[] variance, int topK, int[][] scales, int[] steps)
NDArray decodeLandm(NDArray pre, NDArray priors, double scaleXY)