import vtkPolyData from "@kitware/vtk.js/Common/DataModel/PolyData";
import { vtkUtil } from "./vtkUtils";
import { Mesh, NormalMesh } from "../model/Mesh";
import { OccluUtil, WasmCmdType } from "./OccluUtil";
import WasmWorkerClient from "../workers/WasmWorkerClient";

/**
 * Dispatches alignment-matrix requests to WASM worker clients, keeping
 * separate worker pools for upper-jaw and lower-jaw meshes.
 */
class MeshAligner {
  private upperJawWorkers: WasmWorkerClient[] = [];
  private lowerJawWorkers: WasmWorkerClient[] = [];

  /**
   * Replaces both worker pools. The arrays are stored by reference, not copied.
   *
   * @param upperJawWorkers - pool used when `isUpperJaw` is true
   * @param lowerJawWorkers - pool used when `isUpperJaw` is false
   */
  setWorkers(
    upperJawWorkers: WasmWorkerClient[],
    lowerJawWorkers: WasmWorkerClient[]
  ) {
    this.upperJawWorkers = upperJawWorkers;
    this.lowerJawWorkers = lowerJawWorkers;
  }

  /**
   * Picks a worker from the selected jaw's pool, cycling through it by `index`.
   *
   * Negative indices wrap from the end of the pool (e.g. -1 → last worker).
   * With an empty pool, or a non-integer/non-finite index, the lookup yields
   * `undefined`, which `compute` reports as "worker not ready".
   *
   * @param isUpperJaw - selects the upper- or lower-jaw pool
   * @param index - selector reduced modulo the pool size
   */
  getWorkerForJaw(isUpperJaw: boolean, index: number) {
    const pool = isUpperJaw ? this.upperJawWorkers : this.lowerJawWorkers;
    const size = pool.length;
    // JS `%` keeps the sign of the dividend, so fold negatives back into [0, size).
    return pool[((index % size) + size) % size];
  }

  /**
   * Requests an alignment matrix for the (targetPd, referPd) pair by posting a
   * `WasmCmdType.ComputeAlignmentMatrix` message to the selected worker.
   *
   * @param targetPd - poly data sent as the first mesh of the pair
   * @param referPd - poly data sent as the second mesh of the pair
   * @param isUpperJaw - selects the upper- or lower-jaw worker pool
   * @param index - worker selector, see `getWorkerForJaw`
   * @returns a promise for the worker reply's `data` payload. It resolves to `[]`
   *   when either mesh fails `OccluUtil.isMeshDataValid` or the reply carries no
   *   data. It rejects with the string "worker not ready" (kept as a string for
   *   existing callers) when no worker is available or either poly data is missing.
   *
   * NOTE(review): there is no timeout; if the worker never invokes the callback,
   * the returned promise stays pending.
   */
  compute(
    targetPd: vtkPolyData,
    referPd: vtkPolyData,
    isUpperJaw: boolean,
    index: number
  ): Promise<number[]> {
    const worker = this.getWorkerForJaw(isUpperJaw, index);
    if (!worker || !targetPd || !referPd) {
      return Promise.reject("worker not ready");
    }
    return new Promise((resolve) => {
      // A throw here (e.g. from toMeshPairT) is turned into a rejection by the
      // Promise constructor.
      const data = OccluUtil.toMeshPairT(targetPd, referPd);
      if (
        !OccluUtil.isMeshDataValid(data.first) ||
        !OccluUtil.isMeshDataValid(data.second)
      ) {
        resolve([]);
        return;
      }
      worker.postMessage(
        { type: WasmCmdType.ComputeAlignmentMatrix, data },
        (msg: any) => {
          // Optional chaining: the callback may run outside the executor, where
          // a TypeError on a null/undefined reply would escape the promise and
          // leave it pending forever.
          resolve(msg?.data || []);
        }
      );
    });
  }
}

/**
 * Computes the alignment matrix for a pair of meshes using the shared
 * `meshAligner` singleton.
 *
 * The reference mesh's `isUpperJaw()` selects the worker pool. Its poly data is
 * sent first (as `compute`'s `targetPd`), and the target mesh's poly data is
 * sent second (as `referPd`).
 * NOTE(review): the ref/target naming is swapped relative to `compute`'s
 * parameter names — confirm this ordering is intended.
 *
 * Being `async`, any synchronous throw from `vtkUtil.getPolyData` or
 * `isUpperJaw()` surfaces as a rejection, as it did with the previous
 * promise-constructor version.
 *
 * @param refMesh - mesh whose poly data is sent first and whose jaw picks the pool
 * @param tarMesh - mesh whose poly data is sent second
 * @param index - worker selector forwarded to `meshAligner.compute`
 * @returns the matrix resolved by `meshAligner.compute`. It rejects with the
 *   string "Invalid input" when either mesh is missing, and otherwise propagates
 *   any rejection from `meshAligner.compute`.
 */
export const computeAlignmentMatrix = async (
  refMesh: NormalMesh,
  tarMesh: NormalMesh,
  index: number
): Promise<number[]> => {
  if (!refMesh || !tarMesh) {
    // Rejection value kept as a string for compatibility with existing callers.
    return Promise.reject("Invalid input");
  }
  const refPd = vtkUtil.getPolyData(refMesh.actor);
  const tarPd = vtkUtil.getPolyData(tarMesh.actor);
  return meshAligner.compute(refPd, tarPd, refMesh.isUpperJaw(), index);
};

/**
 * Module-level `MeshAligner` singleton used by `computeAlignmentMatrix`;
 * worker pools are supplied through `setWorkers`. Exported both by name and
 * as the module's default export.
 */
const meshAligner = new MeshAligner();
export { meshAligner, meshAligner as default };
