BowlerKernel
UniquePersonFactory.java
Go to the documentation of this file.
1 package com.neuronrobotics.bowlerkernel.djl;
2 
3 import java.awt.image.BufferedImage;
4 import java.awt.image.DataBufferByte;
5 import java.awt.image.WritableRaster;
6 import java.io.File;
7 import java.io.IOException;
8 
9 import com.google.gson.Gson;
10 import com.google.gson.GsonBuilder;
11 import com.google.gson.reflect.TypeToken;
12 import com.neuronrobotics.bowlerstudio.opencv.OpenCVManager;
13 import com.neuronrobotics.bowlerstudio.scripting.ScriptingEngine;
14 import com.neuronrobotics.sdk.common.DeviceManager;
15 import com.neuronrobotics.sdk.common.NonBowlerDevice;
16 
17 import ai.djl.MalformedModelException;
18 import ai.djl.inference.Predictor;
19 import ai.djl.modality.cv.Image;
20 import ai.djl.modality.cv.ImageFactory;
21 import ai.djl.pytorch.jni.JniUtils;
22 import ai.djl.repository.zoo.ModelNotFoundException;
23 import javafx.application.Platform;
24 import javafx.embed.swing.SwingFXUtils;
25 import javafx.event.ActionEvent;
26 import javafx.event.EventHandler;
27 import javafx.scene.control.Label;
28 import javafx.scene.control.TextField;
29 import javafx.scene.image.ImageView;
30 import javafx.scene.image.WritableImage;
31 import javafx.scene.layout.HBox;
32 import javafx.scene.layout.VBox;
33 
34 import java.lang.reflect.Type;
35 import java.nio.file.Files;
36 import java.nio.file.Path;
37 import java.nio.file.Paths;
38 import java.util.ArrayList;
39 import java.util.HashMap;
40 import java.util.HashSet;
41 import java.util.List;
42 
43 import org.opencv.core.Mat;
44 import org.opencv.core.Point;
45 import org.opencv.core.Rect;
46 import org.opencv.core.Size;
47 import org.opencv.imgproc.Imgproc;
48 
49 public class UniquePersonFactory extends NonBowlerDevice {
50  private File database;
51  Type TT_mapStringString = new TypeToken<HashSet<UniquePerson>>() {
52  }.getType();
53  Gson gson = new GsonBuilder().disableHtmlEscaping().setPrettyPrinting().create();
54  private HashSet<UniquePerson> longTermMemory = new HashSet<>();;
55  private ArrayList<UniquePerson> shortTermMemory = new ArrayList<>();;
56  Predictor<Image, float[]> features;
57 
58  private static double confidence = 0.8;
59  private static long timeout = 30000;
60  private static long countPeople = 1;
61  private static int numberOfTrainingHashes = 30;
62  private Thread processor;
63  private boolean run = true;
64  private HashMap<BufferedImage, Point> factoryFromImageTMp = null;
65  private HashMap<UniquePerson, Point> currentPersons = null;
66  private ImageFactory factory;
67  private VBox workingMemory = new VBox();
68  private boolean processFlag = false;
69 
// One row of the working-memory panel: an editable name field followed by a
// training-progress label, laid out horizontally. processMemory() appends the
// person's face thumbnail to this box the first time the row is displayed.
// NOTE(review): source line 74 (the opening of the block that runs the two
// addAll calls below, presumably a constructor or instance initializer) is
// not visible in this listing.
70  private class UniquePersonUI {
// Container that gets added to the working-memory VBox.
71  HBox box = new HBox();
// Person's name; its action handler (setAction) renames the person.
72  TextField name = new TextField();
// Shows the " : Trained N%" progress text.
73  Label percent = new Label();
75  box.getChildren().addAll(name);
76  box.getChildren().addAll(percent);
77  }
78  }
79 
80  private HashMap<UniquePerson, UniquePersonUI> uiElelments = new HashMap<UniquePerson, UniquePersonFactory.UniquePersonUI>();
81  private HashMap<BufferedImage, Point> localMailbox;
82 
// Body of the per-person UI accessor. Its declaration (source line 83) is not
// visible here; it is presumably getUI(UniquePerson p), which other methods call.
// Creates the UniquePersonUI row for p on first use and caches it in uiElelments.
// NOTE(review): uiElelments is a plain HashMap and getUI is reached both from
// the processor thread and from Platform.runLater callbacks / UI event
// handlers — confirm access is effectively single-threaded.
84  if (uiElelments.get(p) == null) {
85  uiElelments.put(p, new UniquePersonUI());
86  }
87  return uiElelments.get(p);
88  }
89 
90  public static UniquePersonFactory get() {
91  return get(new File(ScriptingEngine.getWorkspace().getAbsolutePath() + "/face_memory.json"));
92 
93  }
94 
95  public static UniquePersonFactory get(File index) {
96  return (UniquePersonFactory) DeviceManager.getSpecificDevice("UniquePersonFactory_" + index.getName(),
97  () -> new UniquePersonFactory(index));
98 
99  }
100 
// Constructor body. Its declaration (source line 101) is not visible here; the
// constructor is invoked as new UniquePersonFactory(index) from get(File).
// Loads (or creates) the long-term memory file.
// NOTE(review): setDatabase is a public, overridable method called from the
// constructor.
102  this.setDatabase(database);
103 
// NOTE(review): source line 105, the statement guarded by this try, is not
// visible. Given the exceptions caught it presumably loads the DJL model into
// `features` (e.g. via PredictorFactory.faceFeatureFactory()) — confirm.
// Failures are only printed, so construction continues even if loading failed.
104  try {
106  } catch (ModelNotFoundException e) {
107  // TODO Auto-generated catch block
108  e.printStackTrace();
109  } catch (MalformedModelException e) {
110  // TODO Auto-generated catch block
111  e.printStackTrace();
112  } catch (IOException e) {
113  // TODO Auto-generated catch block
114  e.printStackTrace();
115  }
// DJL image factory, used by the processor to convert BufferedImages.
116  factory = ImageFactory.getInstance();
117 
118  }
119 
120  public void clear() {
121  while(isProcessFlag()) {
122  try {
123  Thread.sleep(10);
124  } catch (InterruptedException e) {
125  return;
126  }
127  }
128  if (currentPersons != null)
129  synchronized (currentPersons) {
130  currentPersons.clear();
131  }
132  if (factoryFromImageTMp != null) {
133  factoryFromImageTMp.clear();
134  factoryFromImageTMp = null;
135  }
136  }
137 
138  public String save() {
139  String jsonString = gson.toJson(longTermMemory, TT_mapStringString);
140  Path path = Paths.get(getDatabase().getAbsolutePath());
141  byte[] strToBytes = jsonString.getBytes();
142 
143  try {
144  Files.write(path, strToBytes);
145  } catch (IOException e) {
146  // TODO Auto-generated catch block
147  e.printStackTrace();
148  }
149  if (workingMemory != null)
150  for (UniquePerson up : shortTermMemory) {
151  Platform.runLater(() -> {
152  getUI(up).name.setText(up.name);
153  });
154  }
155  return jsonString;
156  }
157 
161  public static double getConfidence() {
162  return confidence;
163  }
164 
// Sets the similarity threshold that fully trained people receive.
// NOTE(review): the body (source line 169, presumably assigning the static
// `confidence` field) is not visible in this listing.
168  public static void setConfidence(double confidence) {
170  }
171 
175  public static long getTimeout() {
176  return timeout;
177  }
178 
// Sets the unseen-time limit, in milliseconds, for under-trained people.
// NOTE(review): the body (source line 183, presumably assigning the static
// `timeout` field) is not visible in this listing.
182  public static void setTimeout(long timeout) {
184  }
185 
189  public static int getNumberOfTrainingHashes() {
190  return numberOfTrainingHashes;
191  }
192 
// NOTE(review): closing brace of a method whose declaration and body (source
// lines 193-197) are not visible here — presumably setNumberOfTrainingHashes(int).
198  }
199 
200  @Override
201  public void disconnectDeviceImp() {
202  run = false;
203  if (processor != null)
204  processor.interrupt();
205  }
206 
// Starts a background thread that runs processBlocking() and reports success.
// NOTE(review): source line 209 is not visible in this listing.
// NOTE(review): each call starts a new processor thread; presumably callers
// disconnect first so two processors never run at once — confirm.
207  @Override
208  public boolean connectDeviceImp() {
210  run = true;
211  processor = new Thread(() -> {
212  processBlocking();
213  });
214  processor.start();
215  return run;
216  }
217 
218  public void addFace(Mat matrix, Rect crop, ai.djl.modality.cv.output.Point nose) {
219  if (factoryFromImageTMp == null) {
220  factoryFromImageTMp = new HashMap<BufferedImage, Point>();
221  }
222  HashMap<BufferedImage, Point> local = factoryFromImageTMp;
223  try {
224  Mat tmpImg = new Mat(matrix, crop);
225  Mat image_roi = new Mat();
226  int fixedWidth = 100;
227  double scale = ((double) tmpImg.cols()) / ((double) fixedWidth);
228 
229  // Converting the image to grey scale
230  Imgproc.cvtColor(tmpImg, image_roi, Imgproc.COLOR_RGB2GRAY);
231  int rows = (int) (image_roi.rows() / scale);
232  Mat image_roi_resized = new Mat(rows, 100, Imgproc.COLOR_RGB2GRAY);
233  Imgproc.resize(image_roi, image_roi_resized, image_roi_resized.size(), 0, 0, Imgproc.INTER_NEAREST);
234  BufferedImage image = new BufferedImage(image_roi_resized.width(), image_roi_resized.height(),
235  BufferedImage.TYPE_BYTE_GRAY);
236  WritableRaster raster = image.getRaster();
237  DataBufferByte dataBuffer = (DataBufferByte) raster.getDataBuffer();
238  byte[] data = dataBuffer.getData();
239  image_roi_resized.get(0, 0, data);
240  local.put(image, new Point(nose.getX(), nose.getY()));
241  } catch (Throwable tr) {
242  // tr.printStackTrace();
243  }
244  }
245 
246  private void processBlocking() {
247  JniUtils.setGraphExecutorOptimize(false);
248  while (run) {
249  if (!isProcessFlag()) {
250  try {
251  Thread.sleep(16);
252  } catch (InterruptedException e) {
253  break;
254  }
255  continue;
256  }
257 
258  try {
259 
260  for (UniquePerson up : shortTermMemory) {
261  if (up.features.size() >= numberOfTrainingHashes) {
262  if (!longTermMemory.contains(up)) {
263  longTermMemory.add(up);
264  System.out.println("Saving new Face to database " + up.name);
265  save();
266  }
267  }
268  }
269  HashMap<BufferedImage, Point> local = localMailbox;
270  localMailbox = null;
271  HashMap<UniquePerson, org.opencv.core.Point> tmpPersons = new HashMap<>();
272  for (BufferedImage imgBuff : local.keySet()) {
273  ai.djl.modality.cv.Image cmp = factory.fromImage(imgBuff);
274  Point point = local.get(imgBuff);
275  // println "Processing new image "
276  if (imgBuff.getHeight() < imgBuff.getWidth())
277  continue;
278  float[] id;
279  try {
280  id = features.predict(cmp);
281  } catch (Throwable ex) {
282  System.out.println("Image failed h=" + imgBuff.getHeight() + " w=" + imgBuff.getWidth());
283  // ex.printStackTrace();
284  continue;
285  }
286  boolean found = false;
287  ArrayList<UniquePerson> duplicates = new ArrayList<UniquePerson>();
288  // check long term first.
289  for (UniquePerson pp : longTermMemory) {
290  found = processMemory(tmpPersons, imgBuff, point, id, found, duplicates, pp);
291  }
292  for (UniquePerson pp : shortTermMemory) {
293  found = processMemory(tmpPersons, imgBuff, point, id, found, duplicates, pp);
294  }
295  for (int i = 0; i < shortTermMemory.size(); i++) {
296  UniquePerson p = shortTermMemory.get(i);
297  if ((System.currentTimeMillis() - p.time) > timeout && p.timesSeen < numberOfTrainingHashes) {
298  duplicates.add(p);
299  }
300  }
301  for (UniquePerson p : duplicates) {
302  if (!longTermMemory.contains(p)) {
303  shortTermMemory.remove(p);
304  if (workingMemory != null) {
305  UniquePersonUI UI = getUI(p);
306  uiElelments.remove(p);
307  Platform.runLater(() -> {
308  workingMemory.getChildren().remove(UI.box);
309  });
310  }
311  }
312  // println "Removing "+p.name
313  }
314 
315  if (found == false) {
316  resetHash();
317  UniquePerson p = new UniquePerson();
318  p.features.add(id);
319  p.name = "Person " + (countPeople);
320  p.UUID = countPeople;
321  String tmpDirsLocation = System.getProperty("java.io.tmpdir") + "/idFiles/" + p.name + ".jpeg";
322  p.referenceImageLocation = tmpDirsLocation;
323  // println "New person found! "+tmpDirsLocation
324  shortTermMemory.add(p);
325  }
326  }
327  local.clear();
328  local = null;
329  if (currentPersons != null) {
330  synchronized (currentPersons) {
331  currentPersons.clear();
332  currentPersons = tmpPersons;
333  }
334  } else
335  currentPersons = tmpPersons;
336  } catch (Throwable tr) {
337  tr.printStackTrace(); // run=false;
338  }
339  processFlag = false;
340  }
341  }
342 
343  private EventHandler<ActionEvent> setAction(UniquePerson p) {
344  return event -> {
345  String newName = getUI(p).name.getText();
346  System.out.println("Renaming " + p.name + " to " + newName);
347  p.name = newName;
348  new Thread(() -> save()).start();
349  };
350  }
351 
352  private boolean processMemory(HashMap<UniquePerson, org.opencv.core.Point> tmpPersons, BufferedImage imgBuff,
353  Point point, float[] id, boolean found, ArrayList<UniquePerson> duplicates, UniquePerson pp) {
354  UniquePerson p = pp;
355 
356  int count = 0;
357  // for(int i=0;i<p.features.size();i++) {
358  // float[] featureFloats =p.features.get(i);
359  float result = PredictorFactory.calculSimilarFaceFeature(id, p.features);
360  // println "Difference from "+p.name+" is "+result
361  if (result > p.confidenceTarget) {
362  if (found) {
363  duplicates.add(p);
364  } else {
365  count++;
366 
367  p.timesSeen++;
368  found = true;
369  if (p.timesSeen > 2)
370  tmpPersons.put(p, point);
371  UniquePersonUI UI = getUI(p);
372  if (workingMemory != null)
373  if (p.timesSeen > 3 && !workingMemory.getChildren().contains(UI.box)) {
374  // on the third seen, display
375  UI.name.setText(p.name);
376  WritableImage tmpImg = SwingFXUtils.toFXImage(imgBuff, null);
377  UI.box.getChildren().addAll(new ImageView(tmpImg));
378 
379  UI.name.setOnAction(setAction(p));
380  Platform.runLater(() -> {
381  workingMemory.getChildren().add(UI.box);
382  });
383  }
384  p.time = System.currentTimeMillis();
385  // if(result<(confidence+0.01))
386  int percent = (int) (((double) p.features.size()) / ((double) numberOfTrainingHashes) * 100);
387  if (percent > 100)
388  percent = 100;
389  double perc = percent;
390  // println "Trained "+percent;
391  if (workingMemory != null)
392  Platform.runLater(() -> {
393  UI.percent.setText(" : Trained " + perc + "%");
394  });
395 
396  if (p.features.size() < numberOfTrainingHashes && result < 0.95) {
397  p.features.add(id);
398  if (p.features.size() == numberOfTrainingHashes) {
399  // println " Trained "+p.name;
400  if (workingMemory != null)
401  Platform.runLater(() -> {
402  UI.box.getChildren().addAll(new Label(" Done! "));
403  });
404  }
405  }
406  if (p.features.size() >= numberOfTrainingHashes && p.confidenceTarget != confidence) {
407  p.confidenceTarget = confidence;
408  save();
409  }
410  }
411  }
412  return found;
413  }
414 
415  @Override
416  public ArrayList<String> getNamespacesImp() {
417  // TODO Auto-generated method stub
418  return null;
419  }
420 
424  public HashMap<UniquePerson, Point> getCurrentPersons() {
425  HashMap<UniquePerson, Point> tmp = new HashMap<UniquePerson, Point>();
426  if (currentPersons == null)
427  return null;
428  else
429  synchronized (currentPersons) {
430  tmp.putAll(currentPersons);
431  }
432 
433  return tmp;
434  }
435 
439  public void setWorkingMemory(VBox workingMemory) {
440  if (this.workingMemory != null) {
441  this.workingMemory.getChildren().clear();
442  for(UniquePersonUI ui:uiElelments.values()) {
443  ui.box.getChildren().clear();
444  }
445  uiElelments.clear();
446  }
447  clear();
448  this.workingMemory = workingMemory;
449  }
450 
454  public File getDatabase() {
455  return database;
456  }
457 
458  private void resetHash() {
459  countPeople = 0;
460  for (UniquePerson u : longTermMemory) {
461  if (u.UUID >= countPeople)
462  countPeople = u.UUID + 1;
463 
464  }
465  for (UniquePerson u : shortTermMemory) {
466  if (u.UUID >= countPeople)
467  countPeople = u.UUID + 1;
468 
469  }
470  }
471 
475  public void setDatabase(File database) {
476  this.database = database;
477  if (!database.exists())
478  try {
479  database.createNewFile();
480  save();
481  } catch (IOException e1) {
482  // TODO Auto-generated catch block
483  e1.printStackTrace();
484  }
485  else {
486  String jsonString;
487  try {
488  jsonString = new String(Files.readAllBytes(Paths.get(database.getAbsolutePath())));
489  longTermMemory = gson.fromJson(jsonString, TT_mapStringString);
490  resetHash();
491  } catch (IOException e) {
492  // TODO Auto-generated catch block
493  e.printStackTrace();
494  }
495  }
496  }
497 
501  public boolean isProcessFlag() {
502  return processFlag;
503  }
504 
// Hands the faces queued by addFace() to the processor thread and raises
// processFlag. If a pass is still running, the newly queued faces are
// discarded instead. If nothing was queued, no pass is started.
// NOTE(review): the busy branch dereferences factoryFromImageTMp without a null
// check. It is null initially, after each hand-off and after clear(), so this
// throws NullPointerException if addFace() was not called in between.
// NOTE(review): source line 517 (the statement guarded by the `if` below,
// presumably moving factoryFromImageTMp's entries into localMailbox) is not
// visible in this listing.
508  public void setProcessFlag() {
509  if (isProcessFlag()) {
510  // discard this set of faces to process.
511  factoryFromImageTMp.clear();
512  factoryFromImageTMp = null;
513  return;// processor is running, do not interrupt it.
514  }
515  localMailbox = new HashMap<>();
516  if (factoryFromImageTMp != null)
518  else
519  return;
520  factoryFromImageTMp = null;
// Raise the flag last, after the mailbox has been assigned.
521  this.processFlag = true;
522 
523  }
// End of class UniquePersonFactory.
524 }
static float calculSimilarFaceFeature(float[] feature1, ArrayList< float[]> people)
static Predictor< Image, float[]> faceFeatureFactory()
HashMap< UniquePerson, UniquePersonUI > uiElelments
boolean processMemory(HashMap< UniquePerson, org.opencv.core.Point > tmpPersons, BufferedImage imgBuff, Point point, float[] id, boolean found, ArrayList< UniquePerson > duplicates, UniquePerson pp)
static void setNumberOfTrainingHashes(int numberOfTrainingHashes)
EventHandler< ActionEvent > setAction(UniquePerson p)
void addFace(Mat matrix, Rect crop, ai.djl.modality.cv.output.Point nose)
static Object getSpecificDevice(String name, IDeviceProvider provider)