1 package com.neuronrobotics.bowlerkernel.djl;
3 import java.awt.image.BufferedImage;
4 import java.awt.image.DataBufferByte;
5 import java.awt.image.WritableRaster;
7 import java.io.IOException;
9 import com.google.gson.Gson;
10 import com.google.gson.GsonBuilder;
11 import com.google.gson.reflect.TypeToken;
12 import com.neuronrobotics.bowlerstudio.opencv.OpenCVManager;
13 import com.neuronrobotics.bowlerstudio.scripting.ScriptingEngine;
14 import com.neuronrobotics.sdk.common.DeviceManager;
15 import com.neuronrobotics.sdk.common.NonBowlerDevice;
17 import ai.djl.MalformedModelException;
18 import ai.djl.inference.Predictor;
19 import ai.djl.modality.cv.Image;
20 import ai.djl.modality.cv.ImageFactory;
21 import ai.djl.pytorch.jni.JniUtils;
22 import ai.djl.repository.zoo.ModelNotFoundException;
23 import javafx.application.Platform;
24 import javafx.embed.swing.SwingFXUtils;
25 import javafx.event.ActionEvent;
26 import javafx.event.EventHandler;
27 import javafx.scene.control.Label;
28 import javafx.scene.control.TextField;
29 import javafx.scene.image.ImageView;
30 import javafx.scene.image.WritableImage;
31 import javafx.scene.layout.HBox;
32 import javafx.scene.layout.VBox;
34 import java.lang.reflect.Type;
35 import java.nio.file.Files;
36 import java.nio.file.Path;
37 import java.nio.file.Paths;
38 import java.util.ArrayList;
39 import java.util.HashMap;
40 import java.util.HashSet;
41 import java.util.List;
43 import org.opencv.core.Mat;
44 import org.opencv.core.Point;
45 import org.opencv.core.Rect;
46 import org.opencv.core.Size;
47 import org.opencv.imgproc.Imgproc;
51 Type TT_mapStringString =
new TypeToken<HashSet<UniquePerson>>() {
53 Gson gson =
new GsonBuilder().disableHtmlEscaping().setPrettyPrinting().create();
56 Predictor<Image, float[]> features;
63 private boolean run =
true;
71 HBox box =
new HBox();
72 TextField name =
new TextField();
73 Label percent =
new Label();
75 box.getChildren().addAll(name);
76 box.getChildren().addAll(percent);
106 }
catch (ModelNotFoundException e) {
109 }
catch (MalformedModelException e) {
112 }
catch (IOException e) {
116 factory = ImageFactory.getInstance();
124 }
catch (InterruptedException e) {
139 String jsonString = gson.toJson(
longTermMemory, TT_mapStringString);
140 Path path = Paths.get(
getDatabase().getAbsolutePath());
141 byte[] strToBytes = jsonString.getBytes();
144 Files.write(path, strToBytes);
145 }
catch (IOException e) {
151 Platform.runLater(() -> {
152 getUI(up).name.setText(up.name);
218 public void addFace(Mat matrix, Rect crop, ai.djl.modality.cv.output.Point nose) {
224 Mat tmpImg =
new Mat(matrix, crop);
225 Mat image_roi =
new Mat();
226 int fixedWidth = 100;
227 double scale = ((double) tmpImg.cols()) / ((
double) fixedWidth);
230 Imgproc.cvtColor(tmpImg, image_roi, Imgproc.COLOR_RGB2GRAY);
231 int rows = (int) (image_roi.rows() / scale);
232 Mat image_roi_resized =
new Mat(rows, 100, Imgproc.COLOR_RGB2GRAY);
233 Imgproc.resize(image_roi, image_roi_resized, image_roi_resized.size(), 0, 0, Imgproc.INTER_NEAREST);
234 BufferedImage image =
new BufferedImage(image_roi_resized.width(), image_roi_resized.height(),
235 BufferedImage.TYPE_BYTE_GRAY);
236 WritableRaster raster = image.getRaster();
237 DataBufferByte dataBuffer = (DataBufferByte) raster.getDataBuffer();
238 byte[] data = dataBuffer.getData();
239 image_roi_resized.get(0, 0, data);
240 local.put(image,
new Point(nose.getX(), nose.getY()));
241 }
catch (Throwable tr) {
247 JniUtils.setGraphExecutorOptimize(
false);
252 }
catch (InterruptedException e) {
264 System.out.println(
"Saving new Face to database " + up.name);
271 HashMap<
UniquePerson, org.opencv.core.Point> tmpPersons =
new HashMap<>();
272 for (BufferedImage imgBuff : local.keySet()) {
273 ai.djl.modality.cv.Image cmp =
factory.fromImage(imgBuff);
274 Point point = local.get(imgBuff);
276 if (imgBuff.getHeight() < imgBuff.getWidth())
280 id = features.predict(cmp);
281 }
catch (Throwable ex) {
282 System.out.println(
"Image failed h=" + imgBuff.getHeight() +
" w=" + imgBuff.getWidth());
286 boolean found =
false;
287 ArrayList<UniquePerson> duplicates =
new ArrayList<UniquePerson>();
290 found =
processMemory(tmpPersons, imgBuff, point,
id, found, duplicates, pp);
293 found =
processMemory(tmpPersons, imgBuff, point,
id, found, duplicates, pp);
307 Platform.runLater(() -> {
315 if (found ==
false) {
321 String tmpDirsLocation = System.getProperty(
"java.io.tmpdir") +
"/idFiles/" + p.
name +
".jpeg";
336 }
catch (Throwable tr) {
337 tr.printStackTrace();
345 String newName =
getUI(p).name.getText();
346 System.out.println(
"Renaming " + p.
name +
" to " + newName);
348 new Thread(() ->
save()).start();
352 private boolean processMemory(HashMap<UniquePerson, org.opencv.core.Point> tmpPersons, BufferedImage imgBuff,
353 Point point,
float[]
id,
boolean found, ArrayList<UniquePerson> duplicates,
UniquePerson pp) {
361 if (result > p.confidenceTarget) {
370 tmpPersons.put(p, point);
375 UI.name.setText(p.
name);
376 WritableImage tmpImg = SwingFXUtils.toFXImage(imgBuff,
null);
377 UI.box.getChildren().addAll(
new ImageView(tmpImg));
380 Platform.runLater(() -> {
384 p.
time = System.currentTimeMillis();
389 double perc = percent;
392 Platform.runLater(() -> {
393 UI.percent.setText(
" : Trained " + perc +
"%");
401 Platform.runLater(() -> {
402 UI.box.getChildren().addAll(
new Label(
" Done! "));
425 HashMap<UniquePerson, Point> tmp =
new HashMap<UniquePerson, Point>();
440 if (this.workingMemory !=
null) {
441 this.workingMemory.getChildren().clear();
443 ui.box.getChildren().clear();
481 }
catch (IOException e1) {
483 e1.printStackTrace();
488 jsonString =
new String(Files.readAllBytes(Paths.get(
database.getAbsolutePath())));
491 }
catch (IOException e) {
521 this.processFlag =
true;
static float calculSimilarFaceFeature(float[] feature1, ArrayList< float[]> people)
static Predictor< Image, float[]> faceFeatureFactory()
static double getConfidence()
HashMap< UniquePerson, Point > getCurrentPersons()
static int numberOfTrainingHashes
UniquePersonFactory(File database)
void disconnectDeviceImp()
HashMap< UniquePerson, UniquePersonUI > uiElelments
HashMap< BufferedImage, Point > factoryFromImageTMp
boolean processMemory(HashMap< UniquePerson, org.opencv.core.Point > tmpPersons, BufferedImage imgBuff, Point point, float[] id, boolean found, ArrayList< UniquePerson > duplicates, UniquePerson pp)
ArrayList< UniquePerson > shortTermMemory
HashMap< BufferedImage, Point > localMailbox
UniquePersonUI getUI(UniquePerson p)
HashMap< UniquePerson, Point > currentPersons
static void setTimeout(long timeout)
static void setNumberOfTrainingHashes(int numberOfTrainingHashes)
HashSet< UniquePerson > longTermMemory
EventHandler< ActionEvent > setAction(UniquePerson p)
boolean connectDeviceImp()
static int getNumberOfTrainingHashes()
void setWorkingMemory(VBox workingMemory)
ArrayList< String > getNamespacesImp()
static void setConfidence(double confidence)
void setDatabase(File database)
void addFace(Mat matrix, Rect crop, ai.djl.modality.cv.output.Point nose)
ArrayList< float[]> features
String referenceImageLocation
static File getWorkspace()
static Object getSpecificDevice(String name, IDeviceProvider provider)