import { Mesh, NormalMesh } from "../model/Mesh";
import { OccluUtil, MeshPairT, DataPairT, WasmCmdType } from "./OccluUtil";
import WasmWorkerClient from "../workers/WasmWorkerClient";
import { vtkUtil } from "./vtkUtils";
import vtkActor from "@kitware/vtk.js/Rendering/Core/Actor";

/**
 * Computes per-vertex distance-mapping scalars for an upper and a lower mesh
 * pair on two WASM workers in parallel, and colors the first actor of each
 * pair with the result.
 *
 * Worker output is memoized by a key built from the bounds and point counts of
 * the four input meshes, so a repeated request for unchanged geometry
 * re-applies the cached scalars without posting to the workers.
 */
class DistanceMapping {
  private worker!: WasmWorkerClient;
  private secondWorker!: WasmWorkerClient;
  /** True while a worker round-trip is in flight; overlapping requests are skipped. */
  private isProcessing = false;
  /** Inputs of the most recent worker computation (not read within this class). */
  private meshPair = new DataPairT<DataPairT<Mesh>>();
  private readonly lut = OccluUtil.createLutForDistMapping();
  /** Set once a worker computation has finished. */
  private computed = false;
  /** Raw worker output keyed by `generateCacheKey`. */
  private readonly cache = new Map<
    string,
    { upperScalars: Float32Array; lowerScalars: Float32Array }
  >();

  /** Whether at least one worker computation has completed. */
  isComputed = () => this.computed;

  /**
   * Assigns the workers used for the upper pair (`worker`) and the lower pair
   * (`secondWorker`).
   */
  setWorker(worker: WasmWorkerClient, secondWorker: WasmWorkerClient) {
    this.worker = worker;
    this.secondWorker = secondWorker;
  }

  /**
   * Builds a cache key from the bounds and point counts of all four meshes.
   *
   * @returns the key, or `null` when any of the four actors is missing (such
   *   inputs are never cached).
   */
  private generateCacheKey(
    upperPair: DataPairT<NormalMesh>,
    lowerPair: DataPairT<NormalMesh>
  ): string | null {
    const actors = [
      upperPair?.first?.actor,
      upperPair?.second?.actor,
      lowerPair?.first?.actor,
      lowerPair?.second?.actor,
    ];
    const parts: string[] = [];
    for (const actor of actors) {
      if (!actor) {
        return null;
      }
      const pd = vtkUtil.getPolyData(actor);
      parts.push(`${pd.getBounds().join(",")}-${pd.getNumberOfPoints()}`);
    }
    // Same "<bounds>-<points>-..." layout as the per-pair concatenation.
    return parts.join("-");
  }

  /** Applies distance scalars to `actor` using the shared LUT; no-op when empty. */
  private applyCachedScalars(actor: vtkActor, scalars: Float32Array) {
    if (scalars?.length) {
      OccluUtil.setMappingScalars(
        actor,
        scalars,
        OccluUtil.distMappingRange(),
        this.lut
      );
    }
  }

  /**
   * Colors `upperPair.first.actor` and `lowerPair.first.actor` by their
   * distance to the respective `second` meshes.
   *
   * - Cache hit: the cached scalars are applied before the first `await`, and
   *   no worker is used.
   * - A computation is already in flight: the actors are returned unchanged.
   * - Otherwise both pairs are sent to their workers in parallel. The promise
   *   settles once both have finished. Worker failures are logged and leave
   *   that actor uncolored.
   *
   * @returns the (possibly recolored) first actors of both pairs.
   */
  async compute(
    upperPair: DataPairT<NormalMesh>,
    lowerPair: DataPairT<NormalMesh>
  ): Promise<{ upperActor: vtkActor; lowerActor: vtkActor }> {
    const cacheKey = this.generateCacheKey(upperPair, lowerPair);
    const cached = cacheKey === null ? undefined : this.cache.get(cacheKey);
    if (cached) {
      this.applyCachedScalars(upperPair.first.actor, cached.upperScalars);
      this.applyCachedScalars(lowerPair.first.actor, cached.lowerScalars);
      return {
        upperActor: upperPair.first.actor,
        lowerActor: lowerPair.first.actor,
      };
    }

    if (this.isProcessing) {
      return {
        upperActor: upperPair.first.actor,
        lowerActor: lowerPair.first.actor,
      };
    }

    this.meshPair.first = upperPair;
    this.meshPair.second = lowerPair;
    this.isProcessing = true;
    console.time("DistanceMapping");
    try {
      const [upperScalars, lowerScalars] = await Promise.all([
        this.doCompute(upperPair, this.worker),
        this.doCompute(lowerPair, this.secondWorker),
      ]);
      this.computed = true;

      // Cache only complete results so a failed side is retried next time.
      if (cacheKey !== null && upperScalars && lowerScalars) {
        this.cache.set(cacheKey, { upperScalars, lowerScalars });
      }

      return {
        upperActor: upperPair.first.actor,
        lowerActor: lowerPair.first.actor,
      };
    } finally {
      // Always release the guard and stop the timer, even on failure, so one
      // error cannot block every later computation.
      this.isProcessing = false;
      console.timeEnd("DistanceMapping");
    }
  }

  /**
   * Sends a distance-mapping request to `worker`.
   *
   * @returns the `scalars` field of the worker's reply (may be undefined).
   * Rejects with an `Error` when the worker or either mesh is missing.
   */
  private postToWorker(
    msg: MeshPairT,
    worker: WasmWorkerClient
  ): Promise<Float32Array | undefined> {
    if (!worker || !msg.first || !msg.second) {
      return Promise.reject(
        new Error("DistanceMapping: worker or mesh data unavailable")
      );
    }
    const type = WasmCmdType.ComputeDistanceMapping;
    return new Promise((resolve) => {
      worker.postMessage({ type, data: msg }, (e) => {
        resolve(e?.data?.scalars);
      });
    });
  }

  /**
   * Computes distances from `pair.first` to `pair.second` on `worker` and
   * colors `pair.first.actor` with them.
   *
   * @returns the scalars that were applied, or `undefined` when an actor is
   *   missing, the worker failed, or it returned no data. Worker failures are
   *   logged, not thrown.
   */
  private async doCompute(
    pair: DataPairT<Mesh>,
    worker: WasmWorkerClient
  ): Promise<Float32Array | undefined> {
    if (!pair?.first?.actor || !pair?.second?.actor) {
      return undefined;
    }
    const first = OccluUtil.toMeshDataT(vtkUtil.getPolyData(pair.first.actor));
    const second = OccluUtil.toMeshDataT(
      vtkUtil.getPolyData(pair.second.actor)
    );

    let scalars: Float32Array | undefined;
    try {
      scalars = await this.postToWorker({ first, second }, worker);
    } catch (err: unknown) {
      // Best effort: an unavailable worker leaves this actor uncolored
      // instead of failing the whole computation.
      console.warn("DistanceMapping: distance computation failed", err);
      return undefined;
    }
    if (!scalars?.length) {
      return undefined;
    }

    OccluUtil.setMappingScalars(
      pair.first.actor,
      scalars,
      OccluUtil.distMappingRange(),
      this.lut
    );
    return scalars;
  }
}

/**
 * Module-wide shared instance. Call `setWorker` before `compute`: without
 * workers, no distance scalars can be produced.
 */
const distanceMapping = new DistanceMapping();
export { distanceMapping as default };
