import { simd } from "wasm-feature-detect";

import { constructPipeStreamAdapter, StreamAdapter, Frame } from "../../../../streamAdapters";

import { AbstractStreamTransformation } from "./AbstractStreamTransformation";
import { WebGLJointBilateral } from './webglUtils';



/**
 * Runs a TFLite segmentation model (`selfie_segmentation_landscape.tflite`)
 * over video frames.
 *
 * The WASM TFLite module and model weights are loaded once and shared by every
 * instance through a static promise. If loading fails, the cached promise is
 * evicted so a later call retries instead of failing forever.
 */
class BackgroundSegmentationTflite {
  private static tflite?: Promise<TFLite>;
  private canvas: StreamAdapter;
  private ctx: CanvasRenderingContext2D;
  private modelShape: [number, number];

  constructor() {
    // Model input resolution as [width, height].
    this.modelShape = [256, 144];
    // Scratch canvas used to downscale incoming frames to the model resolution.
    this.canvas = constructPipeStreamAdapter();
    this.canvas.setShape(this.modelShape);
    this.ctx = this.canvas.get2DContext();
    // Start loading eagerly so the first segment() call waits less.
    void this.getTflite();
  }

  /**
   * Returns the shared TFLite module, starting a load if none is cached.
   * A failed load is removed from the cache so the next call can retry.
   */
  private getTflite(): Promise<TFLite> {
    const cached = BackgroundSegmentationTflite.tflite;
    if (cached) return cached;
    const pending = this.buildTflite();
    BackgroundSegmentationTflite.tflite = pending;
    // Attaching this handler also keeps an unconsumed failure from surfacing as
    // an unhandled rejection; awaiters of `pending` still see the error.
    void pending.catch(() => {
      // Only evict if no newer load has replaced this one in the meantime.
      if (BackgroundSegmentationTflite.tflite === pending) {
        BackgroundSegmentationTflite.tflite = undefined;
      }
    });
    return pending;
  }

  /**
   * Instantiates the TFLite WASM module (SIMD build when supported), fetches
   * the segmentation model, copies it into the module heap and loads it.
   *
   * @throws Error if the model request does not return a 2xx status.
   */
  private async buildTflite(): Promise<TFLite> {
    const simdSupported = await simd();
    const tflite = simdSupported ? await createTFLiteSIMDModule() : await createTFLiteModule();
    const modelBufferOffset = tflite._getModelBufferMemoryOffset();
    const modelResponse = await fetch("../assets/model/selfie_segmentation_landscape.tflite");
    // fetch() only rejects on network failures; an HTTP error page must not be
    // copied into the heap and loaded as model weights.
    if (!modelResponse.ok) {
      throw new Error(
        `Failed to fetch segmentation model: ${modelResponse.status} ${modelResponse.statusText}`
      );
    }
    const model = await modelResponse.arrayBuffer();
    tflite.HEAPU8.set(new Uint8Array(model), modelBufferOffset);
    tflite._loadModel(model.byteLength);
    return tflite;
  }

  /**
   * Segments a frame.
   *
   * @param frame - Source frame to segment.
   * @param frameShape - Source frame size as [width, height].
   * @returns The frame downscaled to the model resolution, with each pixel's
   *   alpha channel replaced by the model's per-pixel output scaled by 255
   *   (presumably the foreground probability — confirm against the model).
   */
  public segment(frame: Frame, frameShape: [number, number]): Promise<ImageData> {
    const [frameWidth, frameHeight] = frameShape;
    const [modelWidth, modelHeight] = this.modelShape;
    const pixelCount = modelWidth * modelHeight;
    // Read the scratch canvas synchronously, before any await, so this call's
    // pixels are captured even if another segment() call runs meanwhile.
    this.ctx.drawImage(frame, 0, 0, frameWidth, frameHeight, 0, 0, modelWidth, modelHeight);
    const mask = this.ctx.getImageData(0, 0, modelWidth, modelHeight);
    return this.getTflite().then(tflite => {
      // The memory-offset getters appear to return byte offsets; HEAPF32 is
      // indexed in 4-byte floats, hence the division.
      const inputMemoryOffset = tflite._getInputMemoryOffset() / 4;
      // Pack RGBA bytes into interleaved RGB floats normalized to [0, 1].
      for (let i = 0; i < pixelCount; i++) {
        tflite.HEAPF32[inputMemoryOffset + (i * 3) + 0] = mask.data[(i * 4) + 0] / 255;
        tflite.HEAPF32[inputMemoryOffset + (i * 3) + 1] = mask.data[(i * 4) + 1] / 255;
        tflite.HEAPF32[inputMemoryOffset + (i * 3) + 2] = mask.data[(i * 4) + 2] / 255;
      }
      tflite._runInference();
      const outputMemoryOffset = tflite._getOutputMemoryOffset() / 4;
      // One output float per pixel, written into the alpha channel
      // (ImageData.data is a Uint8ClampedArray, so values clamp to 0-255).
      for (let i = 0; i < pixelCount; i++) {
        mask.data[(i * 4) + 3] = 255 * tflite.HEAPF32[outputMemoryOffset + i];
      }
      return mask;
    });
  }
}

/**
 * Composites a video stream as background + foreground layers.
 *
 * Each frame is snapshotted into `stillCanvas`, segmented by a TFLite model,
 * and the resulting mask is refined by a WebGL joint-bilateral pass. Then
 * `bgTransform` (reading the still frame) and `fgTransform` (reading the
 * bilateral pass's output) both render into the same sink.
 */
export class BackgroundStreamTransformation extends AbstractStreamTransformation {
  private segmentation: BackgroundSegmentationTflite;
  private stillCanvas: StreamAdapter;
  private stillContext: CanvasRenderingContext2D;
  private bgTransform: AbstractStreamTransformation;
  private fgTransform: AbstractStreamTransformation;
  private context?: CanvasRenderingContext2D;
  private maskBlur?: WebGLJointBilateral;
  // NOTE(review): not referenced anywhere in this class — confirm whether it
  // should be passed to WebGLJointBilateral or removed.
  private maskBlurRadius = 5;

  /**
   * @param bgTransform - Transformation rendering the background layer; its
   *   source is the per-frame still snapshot.
   * @param fgTransform - Transformation rendering the foreground layer; its
   *   source is the joint-bilateral pass's output.
   */
  constructor(bgTransform: AbstractStreamTransformation, fgTransform: AbstractStreamTransformation) {
    super();
    this.segmentation = new BackgroundSegmentationTflite();
    this.stillCanvas = constructPipeStreamAdapter();
    this.stillContext = this.stillCanvas.get2DContext();
    this.bgTransform = bgTransform;
    this.fgTransform = fgTransform;
  }

  /**
   * Sizes the internal canvases to the source and wires both sub-transformations
   * to the shared sink.
   */
  public async setup(source: StreamAdapter, sink: StreamAdapter): Promise<void> {
    const shape = await source.getShape();
    this.context = sink.get2DContext();
    this.maskBlur = new WebGLJointBilateral(shape);
    this.stillCanvas.setShape(shape);
    const fgCanvas = this.maskBlur.getSink();
    await this.bgTransform.setup(this.stillCanvas, sink);
    await this.fgTransform.setup(fgCanvas, sink);
    this.setStreams(source, sink);
  }

  /** Renders one composited frame into the sink; a no-op before setup completes. */
  public async render(): Promise<void> {
    await this.setupPromise;
    if (!this.source || !this.maskBlur) return;
    const frame = await this.source.getFrame();
    const shape = await this.source.getShape();
    // Snapshot the current frame so segmentation and both layers see the same pixels.
    this.stillContext.drawImage(frame, 0, 0);
    if (!this.context) return;
    this.context.clearRect(0, 0, ...shape);
    const stillFrame = await this.stillCanvas.getFrame();
    const mask = await this.segmentation.segment(stillFrame, shape);
    await this.maskBlur.bilateral(stillFrame, mask);
    await this.bgTransform.render();
    await this.fgTransform.render();
  }

  /**
   * Stops both sub-transformations, then releases the WebGL mask filter and the
   * streams. Teardown runs even if a sub-transformation fails to stop; that
   * failure is still propagated to the caller.
   */
  public async stop(): Promise<void> {
    await this.setupPromise;
    try {
      // Wait for both layers to finish stopping before disposing maskBlur, whose
      // sink is the foreground transformation's source.
      await Promise.all([this.bgTransform.stop(), this.fgTransform.stop()]);
    } finally {
      if (this.maskBlur) this.maskBlur.dispose();
      this.stopStreams();
    }
  }
}
