import * as ml5 from "ml5";

/**
 * Wrapper around an ml5 classification network used to recognise hand
 * gestures from hand-landmark samples.
 *
 * A sample is expected to look like
 * `{ label, results: { multiHandLandmarks: [[{ x, y, z }, ...], ...] } }`
 * (presumably the output of a hand-tracking library — confirm with callers).
 * Only the first detected hand of each sample is used.
 */
class NeuralNet {
  /** Active ml5 network; stays null until a model finished loading or training. */
  nn = null;

  /** Set to true once train() has completed; save() refuses to run before that. */
  trained = false;

  /** Callback handed to the most recent load() call (may be null/undefined). */
  nnLoadedCallback = null;

  /**
   * Classifies a single sample with the active network.
   *
   * @param {object} sample - Sample in the shape described on the class.
   * @param {function(*, *): void} callback - Node-style `(error, result)`
   *   callback. `error` is the string "Net not ready or no data" when there is
   *   no active network or the sample holds no hand landmarks; otherwise it is
   *   whatever the network reports.
   */
  classify(sample, callback) {
    const [dataPoint] = NeuralNet.formatData([sample]);
    if (dataPoint === undefined || this.nn === null) {
      callback("Net not ready or no data", null);
      return;
    }

    // The label is an output, not an input feature, so strip it off.
    const { label, ...inputs } = dataPoint;
    this.nn.classify(inputs, (error, result) => {
      if (error) {
        console.error("ERROR Classifying gesture", error);
        callback(error, null);
        return;
      }
      callback(error, result);
    });
  }

  /**
   * Flattens samples into plain feature objects.
   *
   * For every sample with at least one detected hand, the landmarks of the
   * first hand become the keys `<index>_x`, `<index>_y` and `<index>_z`, and
   * the sample's label is copied to `label`. Samples that are null, lack
   * `results`, or contain no hand are skipped instead of throwing.
   *
   * Handedness (`results.multiHandedness[0].label` / `.score`) is currently
   * not included as a feature.
   *
   * @param {Array<object>} samples - Samples to convert; null/undefined is
   *   treated as an empty list.
   * @returns {Array<object>} One feature object per usable sample.
   */
  static formatData(samples) {
    const data = [];
    for (const sample of samples ?? []) {
      const firstHand = sample?.results?.multiHandLandmarks?.[0];
      if (!firstHand) {
        continue;
      }
      const dataPoint = {};
      firstHand.forEach((landmark, idx) => {
        dataPoint[`${idx}_x`] = landmark.x;
        dataPoint[`${idx}_y`] = landmark.y;
        dataPoint[`${idx}_z`] = landmark.z;
      });
      dataPoint.label = sample.label;
      data.push(dataPoint);
    }
    return data;
  }

  /**
   * Loads a previously saved model and makes it the active network.
   *
   * The network only becomes active (`this.nn`) once loading has finished, so
   * classify() keeps reporting "not ready" (or keeps using the previous
   * network) while the files are still being fetched.
   *
   * @param {function(*): void} [callback] - Invoked with the loader's result
   *   once the model is ready.
   * @param {string} [modelDir="data/gestureModel04"] - Directory containing
   *   `model.json`, `model_meta.json` and `model.weights.bin`. Earlier models
   *   referenced by this file: "data/gestureModel02", "data/gesturesmodel".
   */
  load(callback, modelDir = "data/gestureModel04") {
    this.nnLoadedCallback = callback;
    const nn = this.createNet();
    const modelInfo = {
      model: `${modelDir}/model.json`,
      metadata: `${modelDir}/model_meta.json`,
      weights: `${modelDir}/model.weights.bin`,
    };
    nn.load(modelInfo, (result) => {
      this.nn = nn;
      this.modelLoadedCallback(result);
    });
  }

  /**
   * Logs that the model finished loading and forwards the result to the
   * callback given to load(), if there was one.
   *
   * @param {*} result - Value passed through from the network's load callback.
   */
  modelLoadedCallback(result) {
    console.log("Neural Net Model Loaded");
    this.nnLoadedCallback?.(result);
  }

  /**
   * Saves the active network under the name "handGestureModel".
   * Only networks trained in this session can be saved; otherwise an error is
   * logged and nothing happens.
   */
  save() {
    if (this.nn === null || !this.trained) {
      console.error("ERROR: Can't save net. Net not trained yet!");
      return;
    }
    this.nn.save("handGestureModel", () => {
      console.log("Model Saved!");
    });
  }

  /**
   * Creates a fresh, untrained ml5 classification network with two hidden
   * dense ReLU layers (24 and 16 units) and a softmax output layer. The
   * output layer declares no unit count (presumably ml5 derives it from the
   * training data — confirm against the ml5 docs).
   *
   * @returns {object} The new ml5 neural network.
   */
  createNet() {
    const layers = [
      { type: "dense", units: 24, activation: "relu" },
      { type: "dense", units: 16, activation: "relu" },
      { type: "dense", activation: "softmax" },
    ];
    return ml5.neuralNetwork({
      task: "classification",
      debug: true,
      layers,
    });
  }

  /**
   * Trains a new network on the given samples and, once training has
   * finished, makes it the active network and marks it as trained.
   *
   * If no sample contains hand landmarks, nothing is trained: an error is
   * logged and the current network is left untouched.
   *
   * @param {Array<object>} samples - Labelled samples (see class description).
   * @param {function(*): void} [callback] - Optional; called with `null` when
   *   training finished, or with an error string when there was no usable
   *   training data.
   */
  train(samples, callback) {
    const data = NeuralNet.formatData(samples);
    if (data.length === 0) {
      console.error("ERROR: Can't train net. No usable training data!");
      callback?.("No usable training data");
      return;
    }

    const nn = this.createNet();
    for (const { label, ...inputs } of data) {
      nn.addData(inputs, { label });
    }
    nn.normalizeData();

    const trainingOptions = {
      epochs: 200,
      batchSize: 12,
    };
    nn.train(trainingOptions, () => {
      console.log("Finished Training");
      // Publish the network only now, so classify() never sees a half-trained net.
      this.nn = nn;
      this.trained = true;
      callback?.(null);
    });
  }
}

export default NeuralNet;
