import {ReactNode, useState, useMemo, useRef, useEffect} from 'react';
import { If, Then } from 'react-if';
import {Stage, Layer, Image, Line} from 'react-konva';
import {SolverStep, SolverStepInput, SolverStepResult, ImageData, Point2D, Axis, Solver, Line2D} from '../../../types';
import { Palette, Point2, transformNotNull, isNone } from '../../../utils';

import VanishingPointControl from '../../../components/VanishingPointControl';
import Stack from '@shared/components/stack';
import Button from '@shared/components/button';
import VPSolver from "@shared/components/camera-solver/solvers/vp-solver/VPSolver";
import ReferenceLengthControl from "@shared/components/camera-solver/components/ReferenceLengthControl";
import ModelObject from "@shared/components/camera-solver/components/ModelObject";
import Scene from "@shared/components/camera-solver/components/Scene";

/** Input of the vanishing-point step (available as `this.input` on `VPStep`). */
interface Input extends SolverStepInput {
    /** Image to annotate; its `url` is used as the `<img>` source. May be undefined. */
    image: ImageData | undefined;
    /** Axis whose vanishing point guides the reference-length control. */
    referenceAxis: Axis;
}

/** Output of the vanishing-point step, handed to `this.complete` by `VPStep`. */
interface Result extends SolverStepResult {
    /** Principal point, image-relative (the component fixes it at `{ x: 0.5, y: 0.5 }`). */
    principalPoint: Point2D;
    /** Vanishing points for the X, Y and Z axes in that order, converted with `Point2.rel`. */
    vanishingPoints: [Point2D, Point2D, Point2D];
    /** Guide line pair reported by the "x" vanishing-point control, if any. */
    vp1Line?: [Line2D, Line2D];
    /** Guide line pair reported by the "z" vanishing-point control, if any. */
    vp2Line?: [Line2D, Line2D];
    /** Axis the reference segment is associated with (copied from the step input). */
    referenceAxis: Axis;
    /** `refLength / 100` — presumably a centimetre-to-metre conversion; TODO confirm. */
    referenceLength: number;
    /** Endpoints of `rl` converted with `Point2.rel`. */
    referencePoints: [Point2D, Point2D];
    /** Reference segment endpoints as reported by `ReferenceLengthControl`. */
    rl: Line2D;
}

/** Props of the vanishing-point editor component. */
interface ComponentProps {
    /** Step input: the image to annotate and the reference axis. */
    input: Input;
    /** Supplies initial guide lines / reference segment, the floor outline, and `preSolve` for the preview. */
    solver: VPSolver;
    /** Height copied into the previewed layout. */
    roomHeight: number;
    /** Reference length; divided by 100 before being reported as `referenceLength`. */
    refLength: number;
    /** Invoked with the collected result when the user presses "Continue". */
    onComplete?: (result: Result) => void;
    /** Invoked when the user presses "Back". */
    onCancel?: () => void;
}

/**
 * Picks the vanishing point that belongs to `axis`.
 *
 * `vanishingPoints` is ordered by axis: index 0 for `Axis.X`, 1 for `Axis.Y`
 * and 2 for `Axis.Z`.
 *
 * @param vanishingPoints - The three vanishing points, ordered X, Y, Z.
 * @param axis - The axis whose vanishing point is wanted.
 * @returns The vanishing point for `axis`.
 * @throws Error if `axis` is not one of the known axes. The type checker rules
 *   this out, so it only happens when a bad value slips in at runtime.
 */
function vanishingPointForAxis(vanishingPoints: [Point2D, Point2D, Point2D], axis: Axis): Point2D {
  switch (axis) {
    case Axis.X:
      return vanishingPoints[0];
    case Axis.Y:
      return vanishingPoints[1];
    case Axis.Z:
      return vanishingPoints[2];
    default: {
      // Exhaustiveness guard: stops compiling if Axis gains a member, and fails
      // loudly instead of returning undefined for an unexpected runtime value.
      const unknownAxis: never = axis;
      throw new Error(`Unknown axis: ${String(unknownAxis)}`);
    }
  }
}

/**
 * Interactive editor for the vanishing-point step.
 *
 * Draws the input image on a Konva stage with:
 * - the solver's floor outline;
 * - two vanishing-point controls (labelled "x" and "z");
 * - a reference-length control.
 *
 * It previews `solver.preSolve` on top of the image. When the user presses
 * "Continue", it reports the collected annotations through `onComplete`.
 *
 * The stage is scaled by `width / naturalWidth`, so control coordinates
 * (`vp1`, `vp2`, `rl`) live in natural-image pixel space. They are converted
 * with `Point2.rel` (using the natural size) before being reported.
 */
function Component({ input: { image, referenceAxis }, solver, roomHeight, refLength, onComplete, onCancel }: ComponentProps) {
  const imageRef = useRef<HTMLImageElement>(null);
  // Forwarded to VanishingPointControl, whose expected ref type is not visible
  // from this file, so it stays loosely typed.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const layerRef: any = useRef();

  const [[width, height, naturalWidth, naturalHeight], setSize] = useState([0, 0, 0, 0]);
  const [vp1, setVP1] = useState<Point2D>();
  const [vp2, setVP2] = useState<Point2D>();
  const [vp1Line, setVP1Line] = useState<[Line2D, Line2D] | undefined>(solver.vp1Line);
  const [vp2Line, setVP2Line] = useState<[Line2D, Line2D] | undefined>(solver.vp2Line);
  const [rl, setRl] = useState<Line2D | undefined>(solver.rl);

  // The overlay is only shown once the <img> has a displayed size.
  const isLoaded = width > 0 && height > 0;
  // The natural size is 0 until the image loads, but `rl` may already be
  // pre-filled from the solver. Skip the pixel -> relative conversions until the
  // size is known, rather than converting against a zero size (Point2.rel
  // presumably divides by it).
  const hasNaturalSize = naturalWidth > 0 && naturalHeight > 0;
  const naturalSize = useMemo(
    () => ({ width: naturalWidth, height: naturalHeight }),
    [naturalWidth, naturalHeight],
  );
  // Multiply line widths and handle sizes by the inverse of the stage scale
  // (width / naturalWidth) so they keep a constant on-screen size.
  const inverseScale = naturalWidth / width;
  const midpoint = { x: naturalWidth * 0.5, y: naturalHeight * 0.5 };

  // Principal point fixed at the image centre (image-relative coordinates).
  const principalPoint = useMemo<Point2D>(() => ({ x: 0.5, y: 0.5 }), []);

  // Ordered X, Y, Z (see vanishingPointForAxis). The Y entry is derived with
  // Point2.thirdVertex from the two user-placed points and the principal point
  // (in pixels).
  const vanishingPoints = useMemo<[Point2D, Point2D, Point2D] | null>(() => {
    if (!hasNaturalSize) {
      return null;
    }
    return transformNotNull((vp1, vp2, principalPoint) => [
      Point2.rel(vp1!, naturalSize),
      Point2.rel(Point2.thirdVertex(vp1!, vp2!, Point2.abs(principalPoint!, naturalSize)), naturalSize),
      Point2.rel(vp2!, naturalSize),
    ], [vp1, vp2, principalPoint]);
  }, [vp1?.x, vp1?.y, vp2?.x, vp2?.y, principalPoint, hasNaturalSize, naturalSize]);

  // Reference segment endpoints converted to image-relative coordinates.
  const referencePoints = useMemo<Line2D | null>(() => {
    if (!hasNaturalSize) {
      return null;
    }
    return transformNotNull((rl) => [
      Point2.rel(rl![0], naturalSize),
      Point2.rel(rl![1], naturalSize),
    ], [rl]);
  }, [rl, hasNaturalSize, naturalSize]);

  // Vanishing point of the reference axis, back in natural pixels, for the
  // reference-length control.
  const vanishingPoint = useMemo(() => (
    vanishingPoints
      ? Point2.abs(vanishingPointForAxis(vanishingPoints, referenceAxis), naturalSize)
      : null
  ), [vanishingPoints, referenceAxis, naturalSize]);

  // Single source for both the live preview (preSolve) and the onComplete
  // payload. It is null until every field required by `Result` is available.
  const stepResult = useMemo<Result | null>(() => {
    if (!vanishingPoints || !referencePoints || !rl) {
      return null;
    }
    return {
      principalPoint,
      vanishingPoints,
      vp1Line,
      vp2Line,
      referenceAxis,
      referenceLength: refLength / 100,
      referencePoints,
      rl,
    };
  }, [principalPoint, vanishingPoints, vp1Line, vp2Line, referenceAxis, refLength, referencePoints, rl]);

  // "Continue" requires both vanishing points and a reference segment, so the
  // reported result never carries null/undefined in required fields.
  const isReady = !isNone(vp1, vp2) && stepResult !== null;

  const handleImageLoad = () => {
    const el = imageRef.current;
    if (el) {
      setSize([el.clientWidth, el.clientHeight, el.naturalWidth, el.naturalHeight]);
    }
  };

  // Layout preview: preSolve's result with the room height added. Null when
  // there is nothing to preview.
  const interior = useMemo(() => {
    if (!stepResult) {
      return null;
    }
    const layout = solver.preSolve(stepResult);
    return layout ? { layout: { ...layout, height: roomHeight } } : null;
  }, [solver, stepResult, roomHeight]);

  return (
    <Stack direction="vertical" gap={ 1 }>
      <div
        style={ {
          position: 'relative',
          overflow: 'hidden',
          lineHeight: 0,
        } }
      >
        {/* Hidden <img>: its load event records displayed/natural size and the Konva layer draws it. */}
        <img
          ref={ imageRef }
          style={ {
            maxWidth: '100%',
            maxHeight: '100%',
            visibility: 'hidden',
          } }
          src={ image?.url ?? undefined }
          onLoad={ handleImageLoad }
        />
        <If condition={ isLoaded }>
          <Then>
            <Stage
              style={ {
                position: 'absolute',
                top: 0,
                left: 0,
                right: 0,
                bottom: 0,
              } }
              scaleX={ width / naturalWidth }
              scaleY={ height / naturalHeight }
              width={ width }
              height={ height }
            >
              <Layer ref={ layerRef }>
                <Image
                  x={ 0 }
                  y={ 0 }
                  image={ imageRef.current ?? undefined }
                  width={ naturalWidth }
                  height={ naturalHeight }
                />
                <Line
                  listening={ false }
                  points={ Point2.flatten(...(solver.floorPoints ?? []).map(p => ({
                    x: p.x * naturalWidth,
                    y: p.y * naturalHeight,
                  }))) }
                  stroke={ Palette.floorColor() }
                  strokeWidth={ inverseScale }
                  closed={ false }
                />
                <VanishingPointControl
                  index={ 1 }
                  label={ 'x' }
                  color={ Palette.colorForAxis(Axis.X) }
                  direction="horizontal"
                  margin={ { x: naturalWidth * 0.25, y: naturalHeight * 0.25 } }
                  midpoint={ midpoint }
                  line1={ vp1Line?.[0] }
                  line2={ vp1Line?.[1] }
                  onReady={ (result) => {
                    setVP1(result.vanishingPoint);
                    setVP1Line([result.line1, result.line2]);
                  } }
                  onEmpty={ () => {
                    setVP1(undefined);
                  } }
                  layer={ layerRef }
                  scale={ inverseScale }
                />
                <VanishingPointControl
                  index={ 2 }
                  label={ 'z' }
                  color={ Palette.colorForAxis(Axis.Z) }
                  direction="vertical"
                  margin={ { x: naturalWidth * 0.3, y: naturalHeight * 0.25 } }
                  midpoint={ midpoint }
                  line1={ vp2Line?.[0] }
                  line2={ vp2Line?.[1] }
                  onReady={ (result) => {
                    setVP2(result.vanishingPoint);
                    setVP2Line([result.line1, result.line2]);
                  } }
                  onEmpty={ () => {
                    setVP2(undefined);
                  } }
                  layer={ layerRef }
                  scale={ inverseScale }
                />
                {
                  (vanishingPoint || (rl?.[0] && rl?.[1])) && <ReferenceLengthControl
                    color={ Palette.colorForAxis(referenceAxis) }
                    midpoint={ midpoint }
                    margin={ Math.min(naturalWidth, naturalHeight) * 0.25 }
                    vanishingPoint={ vanishingPoint! }
                    onReady={ (rl) => setRl(rl) }
                    onEmpty={ () => setRl(undefined) }
                    point0={ rl?.[0] }
                    point1={ rl?.[1] }
                    scale={ inverseScale }
                  />
                }
              </Layer>
            </Stage>
            {
              interior && <Scene
                interior={ interior }
                width={ width }
                height={ height }
                style={ {
                  position: 'absolute',
                  top: 0,
                  left: 0,
                  right: 0,
                  bottom: 0,
                  pointerEvents: 'none',
                } }
              ><></>
              </Scene>
            }
          </Then>
        </If>
      </div>
      <Stack
        direction="row"
        gap={ 2 }
      >
        <Button onClick={ () => onCancel?.() }>Back</Button>
        <Button
          onClick={ () => {
            if (stepResult) {
              onComplete?.(stepResult);
            }
          } }
          variant="contained"
          disabled={ !isReady }
        >
          Continue
        </Button>
      </Stack>
    </Stack>
  );
}

/**
 * Solver step in which the user places vanishing points and a reference
 * segment on the input image.
 */
class VPStep extends SolverStep<Input, Result> {
  /** Title of this step. */
  get label() {
    const title = 'Define Vanishing Points';
    return title;
  }

  /**
   * Renders the editor for this step.
   *
   * @param roomHeight - Forwarded to the editor for the layout preview.
   * @param refLength - Forwarded to the editor as the reference length.
   * @returns The editor element. Its "Continue" and "Back" buttons are wired to
   *   `complete` and `cancel`.
   */
  component(roomHeight: number, refLength: number): ReactNode {
    const handleComplete = (result: Result) => this.complete(result);
    const handleCancel = () => this.cancel();

    return (
      <Component
        input={ this.input! }
        solver={ this.solver as VPSolver }
        refLength={ refLength }
        roomHeight={ roomHeight }
        onComplete={ handleComplete }
        onCancel={ handleCancel }
      />
    );
  }
}

export default VPStep;
