BowlerKernel
PredictorFactory.java
Go to the documentation of this file.
1 package com.neuronrobotics.bowlerkernel.djl;
2 
3 import java.io.IOException;
4 import java.util.ArrayList;
5 import java.util.HashMap;
6 
7 import ai.djl.MalformedModelException;
8 import ai.djl.engine.Engine;
9 import ai.djl.inference.Predictor;
10 import ai.djl.modality.cv.Image;
11 import ai.djl.modality.cv.output.DetectedObjects;
12 import ai.djl.modality.cv.translator.YoloV5TranslatorFactory;
13 import ai.djl.pytorch.jni.JniUtils;
14 import ai.djl.repository.zoo.Criteria;
15 import ai.djl.repository.zoo.ModelNotFoundException;
16 import ai.djl.repository.zoo.ZooModel;
17 import ai.djl.training.util.ProgressBar;
18 
19 public class PredictorFactory {
20  static {
21  Engine.getEngine("PyTorch"); // Make sure PyTorch engine is loaded
22  Engine.getEngine("OnnxRuntime"); // Make sure PyTorch engine is loaded
23  }
24  private static HashMap<ImagePredictorType, Predictor<Image, DetectedObjects>> preloaded = new HashMap<>();
25  private static Predictor<Image, float[]> features = null;
26 
27  public static Predictor<Image, DetectedObjects> imageContentsFactory(ImagePredictorType type)
28  throws ModelNotFoundException, MalformedModelException, IOException {
29  JniUtils.setGraphExecutorOptimize(false);
30 
31  if (preloaded.get(type) == null) {
32 
33  switch (type) {
34  case retinaface:
35  double confThreshretinaface = 0.85f;
36  double nmsThreshretinaface = 0.45f;
37  double[] varianceretinaface = { 0.1f, 0.2f };
38  int topKretinaface = 5000;
39  int[][] scalesretinaface = { { 16, 32 }, { 64, 128 }, { 256, 512 } };
40  int[] stepsretinaface = { 8, 16, 32 };
41  FaceDetectionTranslator translatorretinaface = new FaceDetectionTranslator(confThreshretinaface,
42  nmsThreshretinaface, varianceretinaface, topKretinaface, scalesretinaface, stepsretinaface);
43 
44  Criteria<Image, DetectedObjects> criteriaretinaface = Criteria.builder()
45  .setTypes(Image.class, DetectedObjects.class)
46  .optModelUrls("https://resources.djl.ai/test-models/pytorch/retinaface.zip")
47  // Load model from local file, e.g:
48  .optModelName("retinaface") // specify model file prefix
49  .optTranslator(translatorretinaface).optProgress(new ProgressBar()).optEngine("PyTorch") // Use
50  // PyTorch
51  // engine
52  .build();
53 
54  preloaded.put(type, criteriaretinaface.loadModel().newPredictor());
55  break;
56  case ultranet:
57  double confThresh = 0.85f;
58  double nmsThresh = 0.45f;
59  double[] variance = { 0.1f, 0.2f };
60  int topK = 5000;
61  int[][] scales = { { 10, 16, 24 }, { 32, 48 }, { 64, 96 }, { 128, 192, 256 } };
62  int[] steps = { 8, 16, 32, 64 };
63  FaceDetectionTranslator translator = new FaceDetectionTranslator(confThresh, nmsThresh, variance, topK,
64  scales, steps);
65 
66  Criteria<Image, DetectedObjects> criteria = Criteria.builder()
67  .setTypes(Image.class, DetectedObjects.class)
68  .optModelUrls("https://resources.djl.ai/test-models/pytorch/ultranet.zip")
69  .optTranslator(translator).optProgress(new ProgressBar()).optEngine("PyTorch") // Use PyTorch
70  // engine
71  .build();
72 
73  preloaded.put(type, criteria.loadModel().newPredictor());
74  break;
75  case yolov5:
76  String MODEL_URL = "https://mlrepo.djl.ai/model/cv/object_detection/ai/djl/onnxruntime/yolo5s/0.0.1/yolov5s.zip";
77 
78  Criteria<Image, DetectedObjects> criteria2 = Criteria.builder()
79  .setTypes(Image.class, DetectedObjects.class).optModelUrls(MODEL_URL)
80  .optEngine("OnnxRuntime")
81  .optTranslatorFactory(new YoloV5TranslatorFactory()).build();
82  preloaded.put(type, criteria2.loadModel().newPredictor());
83  break;
84  default:
85  throw new RuntimeException("No Model availible of type " + type);
86 
87  }
88  }
89  return preloaded.get(type);
90  }
91 
92  public static Predictor<Image, float[]> faceFeatureFactory()
93  throws ModelNotFoundException, MalformedModelException, IOException {
94  JniUtils.setGraphExecutorOptimize(false);
95  if (features == null) {
96  Criteria<Image, float[]> criteria = Criteria.builder().setTypes(Image.class, float[].class)
97  .optModelUrls("https://resources.djl.ai/test-models/pytorch/face_feature.zip")
98  .optModelName("face_feature") // specify model file prefix
99  .optTranslator(new FaceFeatureTranslator()).optProgress(new ProgressBar()).optEngine("PyTorch") // Use
100  // PyTorch
101  // engine
102  .build();
103  ZooModel<Image, float[]> model = criteria.loadModel();
104  features = model.newPredictor();
105  }
106  return features;
107  }
108  public static float calculSimilarFaceFeature(float[] feature1, ArrayList<float[]> people) {
109  float ret = 0.0f;
110  float mod1 = 0.0f;
111  float mod2 = 0.0f;
112  int length = feature1.length;
113  for(int j=0;j<people.size();j++) {
114  float[] feature2 = people.get(j);
115  for (int i = 0; i < length; ++i) {
116  ret += feature1[i] * feature2[i];
117  mod1 += feature1[i] * feature1[i];
118  mod2 += feature2[i] * feature2[i];
119  }
120  }
121  return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
122  }
123  public static float calculSimilarFaceFeature(float[] feature1, float[] feature2) {
124  float ret = 0.0f;
125  float mod1 = 0.0f;
126  float mod2 = 0.0f;
127  int length = feature1.length;
128  for (int i = 0; i < length; ++i) {
129  ret += feature1[i] * feature2[i];
130  mod1 += feature1[i] * feature1[i];
131  mod2 += feature2[i] * feature2[i];
132  }
133  return (float) ((ret / Math.sqrt(mod1) / Math.sqrt(mod2) + 1) / 2.0f);
134  }
135 
136 }
static Predictor< Image, DetectedObjects > imageContentsFactory(ImagePredictorType type)
static float calculSimilarFaceFeature(float[] feature1, ArrayList< float[]> people)
static float calculSimilarFaceFeature(float[] feature1, float[] feature2)
static Predictor< Image, float[]> faceFeatureFactory()
static HashMap< ImagePredictorType, Predictor< Image, DetectedObjects > > preloaded