import * as math from 'mathjs';
import { Matrix } from 'mathjs';
import { Point2, Point3, Plane, Euler, Perspective } from '../../utils';
import {
  Solver,
  SolverResult,
  Point2D,
  Point3D,
  Line3D,
  Plane3D,
  Axis,
  ZKFieldOfView,
  ZKEulerAngles,
  ZKPosition, Line2D, SolverOptions,
} from '../../types';

/**
 * Result type produced by {@link VPSolver}.
 *
 * It currently adds nothing to the shared `SolverResult` shape, so it is a
 * plain alias rather than an empty interface (which lint rules flag and
 * which reads as an unfinished extension point).
 */
type VPSolverResult = SolverResult;

/**
 * Camera solver driven by vanishing points.
 *
 * Work is organised in three steps: vanishing points, then the reference
 * measurement, then floor points. For each step, `prepare*` builds that
 * step's input, `complete*` stores its output and calls `forward()`, and
 * `cancel*` calls `backward()`. Step navigation itself comes from the
 * `Solver` base class.
 *
 * `solve()` combines the stored state into a focal length, field of view,
 * camera orientation, camera height and the 3D floor positions of the marked
 * floor points.
 */
class VPSolver extends Solver<VPSolverResult> {

  /** Value reported as `type` in every result produced by `solve()`. */
  type = -1;

  /** Two 2D lines for the first vanishing point. Stored and echoed in results; not read by the solving math. */
  vp1Line: [Line2D, Line2D] | undefined;

  /** Two 2D lines for the second vanishing point. Stored and echoed in results; not read by the solving math. */
  vp2Line: [Line2D, Line2D] | undefined;

  /** Principal point in image coordinates (converted with `Point2.abs` / `Point2.ndc` before use). */
  principalPoint: Point2D | undefined;

  /** Vanishing points in image coordinates. Only entries 0 and 2 are read by the solving math. */
  vanishingPoints: [Point2D, Point2D, Point2D] | undefined;

  /** World axis along which the reference segment is measured (see `computeCameraTransform`). */
  referenceAxis: Axis | undefined;

  /** Target length of the reference segment; the solution is scaled so the measured segment has this length. */
  referenceLength: number | undefined;

  /** Image-space endpoints of the reference segment: `[anchor, endpoint]`. */
  referencePoints: [Point2D, Point2D] | undefined;

  /** Reference line from the reference step. Stored and echoed in results; not read by the solving math. */
  rl: Line2D | undefined;

  /** Image-space points that `computeFloorPoints` projects onto the world floor plane. */
  floorPoints: Point2D[] | undefined;

  /** Points from the floor step. Stored and echoed in results; not read by the solving math. */
  points: Point2D[] | undefined;

  /**
   * @param options - Base solver options. When `options.layout` is present,
   *   previously saved step state is restored from it.
   */
  constructor(options: SolverOptions<VPSolverResult>) {
    super(options);
    const layout = options.layout;
    this.vp1Line = layout?.vp1Line;
    this.vp2Line = layout?.vp2Line;
    this.principalPoint = layout?.principalPoint;
    this.vanishingPoints = layout?.vanishingPoints;
    this.rl = layout?.rl;
    this.referenceAxis = layout?.referenceAxis;
    // NOTE(review): the layout stores this value as `refLength`, while step
    // results and `preSolve` use `referenceLength`. Confirm the two names are
    // intentionally different.
    this.referenceLength = layout?.refLength;
    this.referencePoints = layout?.referencePoints;
    this.points = layout?.points;
    this.floorPoints = layout?.floorPoints;
  }

  /** Shared input for the vanishing-point and reference steps. */
  private buildStepInput(referenceAxis: Axis) {
    return {
      image: this.image,
      principalPoint: this.principalPoint!,
      vanishingPoints: this.vanishingPoints!,
      referenceAxis,
    };
  }

  /** Copies the reference-step fields from a step result. */
  private applyReferenceResult(result: any): void {
    this.rl = result.rl;
    this.referenceAxis = result.referenceAxis;
    this.referenceLength = result.referenceLength;
    this.referencePoints = result.referencePoints;
  }

  /** Copies the vanishing-point fields, then the reference-step fields, from a step result. */
  private applyVanishingPointResult(result: any): void {
    this.vp1Line = result.vp1Line;
    this.vp2Line = result.vp2Line;
    this.principalPoint = result.principalPoint;
    this.vanishingPoints = result.vanishingPoints;
    this.applyReferenceResult(result);
  }

  /** Builds the input for the vanishing-point step. */
  prepareVPStep(referenceAxis: Axis) {
    return this.buildStepInput(referenceAxis);
  }

  /** Builds the input for the reference step (same shape as the vanishing-point step input). */
  prepareReferenceStep(referenceAxis: Axis) {
    return this.buildStepInput(referenceAxis);
  }

  /** Builds the input for the floor step. */
  prepareFloorStep() {
    return { image: this.image };
  }

  /** Steps back from the vanishing-point step. */
  cancelVPStep() {
    this.backward();
  }

  /** Discards the reference line and steps back. */
  cancelReferenceStep() {
    delete this.rl;
    this.backward();
  }

  /** Steps back from the floor step. */
  cancelFloorStep() {
    this.backward();
  }

  /** Stores the vanishing-point step output (including reference fields) and advances. */
  completeVPStep(result: any) {
    this.applyVanishingPointResult(result);
    this.forward();
  }

  /** Stores the reference step output and advances. */
  completeReferenceStep(result: any) {
    this.applyReferenceResult(result);
    this.forward();
  }

  /** Stores the floor step output and advances. */
  completeFloorStep(result: any) {
    this.points = result.points;
    this.floorPoints = result.floorPoints;
    this.forward();
  }

  /**
   * Estimates the focal length from vanishing points 0 and 2 and the
   * principal point. The result is in the units produced by `Point2.abs`
   * (presumably pixels; confirm).
   *
   * Let `Puv` be the orthogonal projection of the principal point P onto the
   * line Fu–Fv. Then f² = |Fv Puv| · |Fu Puv| − |P Puv|².
   *
   * @returns The focal length, or `null` when f² is not positive (including NaN).
   */
  computeFocalLength(): number | null {
    const { image: { width, height }, vanishingPoints, principalPoint } = this;
    const imageSize = { width, height } as any;
    const Fu = Point2.abs(vanishingPoints![0], imageSize);
    const Fv = Point2.abs(vanishingPoints![2], imageSize);
    const P = Point2.abs(principalPoint!, imageSize);

    // Project P onto the line through Fu and Fv.
    const dirFuFv = Point2.normalized(Point2.substract(Fu, Fv));
    const FvP = Point2.substract(P, Fv);
    const Puv = Point2.add(Point2.multiply(dirFuFv, Point2.dot(dirFuFv, FvP)), Fv);

    const FvPuv = Point2.distance(Fv, Puv);
    const FuPuv = Point2.distance(Fu, Puv);
    const PPuv = Point2.distance(P, Puv);
    const f2 = FvPuv * FuPuv - PPuv * PPuv;

    // `f2 > 0` is false for NaN as well, so degenerate input yields null.
    return f2 > 0 ? Math.sqrt(f2) : null;
  }

  /**
   * Horizontal (`x`) and vertical (`y`) field of view, in radians, for the
   * current image size.
   */
  computeFieldOfView(focalLength: number): ZKFieldOfView {
    const { image: { width, height } } = this;
    const x = 2 * Math.atan(0.5 * width! / focalLength);
    const y = 2 * Math.atan(0.5 * height! / focalLength);

    return { x, y };
  }

  /**
   * Builds the camera rotation from vanishing points 0 and 2, then removes
   * its yaw. The yaw is the `y` component of a 'YXZ' Euler decomposition.
   *
   * @param focalLength - Focal length as returned by `computeFocalLength`.
   * @param delta - Sign of the first world axis (1, or -1 for the mirrored solution).
   */
  computeRotationMatrix(focalLength: number, delta = 1): Matrix {
    const { image: { width, height }, vanishingPoints, principalPoint } = this;
    const imageSize = { width, height } as any;
    const Fu = Point2.ndc(vanishingPoints![0], imageSize);
    const Fv = Point2.ndc(vanishingPoints![2], imageSize);
    const P = Point2.ndc(principalPoint!, imageSize);
    // Focal length rescaled by 2 / width (assumes NDC maps the image width to a span of 2).
    const s = 2 * focalLength / width!;

    // Directions from the camera centre towards each vanishing point.
    const OFu = Point3.v(Fu.x - P.x, Fu.y - P.y, -s);
    const OFv = Point3.v(Fv.x - P.x, Fv.y - P.y, -s);

    const l1 = Point3.length(OFu);
    const upRc = Point3.normalized(OFu);

    const l2 = Point3.length(OFv);
    const vpRc = Point3.normalized(OFv);

    // Third column completes the frame.
    const wpRc = Point3.cross(upRc, vpRc);

    const M = math.matrix([
      [OFu.x / l1, OFv.x / l2, wpRc.x],
      [OFu.y / l1, OFv.y / l2, wpRc.y],
      [-s / l1, -s / l2, wpRc.z],
    ]);

    // Maps the vanishing-point frame onto world axes; `delta` flips the first axis.
    const row1 = [delta, 0, 0];
    const row2 = [0, 0, -1];
    const row3 = math.cross(row1, row2) as number[];

    const axisMatrix = math.matrix([
      row1,
      row2,
      row3,
    ]);

    const rotationMatrix = math.multiply(M, axisMatrix);
    const yaw = Euler.fromMatrix(rotationMatrix, 'YXZ')!.y;
    const alignMatrix = math.inv(Euler.toMatrix(Point3.p(0, yaw, 0), 'YXZ')!);

    return math.multiply(rotationMatrix, alignMatrix);
  }

  /**
   * Builds the camera-to-world transform and fixes the scene scale from the
   * reference segment.
   *
   * The anchor (`referencePoints[0]`) is first placed at an arbitrary
   * distance (10) along its view ray. The plane through the origin that is
   * perpendicular to the reference axis is intersected with the anchor ray,
   * and a line parallel to the reference axis is drawn through that hit. The
   * closest points of both reference rays to that line give the provisional
   * segment length. The translation is then rescaled so that length equals
   * `referenceLength`.
   *
   * The returned transform keeps only the vertical (row 1) translation: the
   * X and Z translation are zeroed.
   *
   * @throws Error when `referenceAxis` is not one of `Axis.X`, `Axis.Y`, `Axis.Z`.
   */
  computeCameraTransform(focalLength: number, rotationMatrix: Matrix): Matrix {
    const {
      image: { width, height },
      principalPoint,
      referenceAxis,
      referencePoints,
      referenceLength,
    } = this;

    const imageSize = { width, height } as any;
    const viewProps = { width, height, focalLength, principalPoint };
    const origin = Point2.ndc(referencePoints![0], imageSize);
    const P = Point2.ndc(principalPoint!, imageSize);
    const k = 0.5 * width! / focalLength;

    // Provisional translation: the anchor's direction, pushed out to distance 10.
    const provisional = Point3.multiply(Point3.p(k * (origin.x - P.x), k * (origin.y - P.y), -1), 10);

    // `u`, `v` span the plane perpendicular to the reference axis `w`.
    let u: Point3D;
    let v: Point3D;
    let w: Point3D;
    switch (referenceAxis) {
      case Axis.X:
        u = Point3.p(0, 1, 0);
        v = Point3.p(0, 0, 1);
        w = Point3.p(1, 0, 0);
        break;
      case Axis.Y:
        u = Point3.p(1, 0, 0);
        v = Point3.p(0, 0, 1);
        w = Point3.p(0, 1, 0);
        break;
      case Axis.Z:
        u = Point3.p(1, 0, 0);
        v = Point3.p(0, 1, 0);
        w = Point3.p(0, 0, 1);
        break;
      default:
        throw new Error(`Unsupported reference axis: ${String(referenceAxis)}`);
    }

    const viewTransform = math.matrix([
      [rotationMatrix.get([0, 0]), rotationMatrix.get([0, 1]), rotationMatrix.get([0, 2]), provisional.x],
      [rotationMatrix.get([1, 0]), rotationMatrix.get([1, 1]), rotationMatrix.get([1, 2]), provisional.y],
      [rotationMatrix.get([2, 0]), rotationMatrix.get([2, 1]), rotationMatrix.get([2, 2]), provisional.z],
      [0, 0, 0, 1],
    ]);

    // Ray through an image point, via two unprojected depths (1 and 2) under the provisional transform.
    const rayThrough = (imagePoint: Point2D) => {
      const ndc = Point2.ndc(imagePoint, imageSize);
      return <Line3D>[
        Perspective.unproject(Point3.p(ndc.x, ndc.y, 1), viewProps as any, viewTransform),
        Perspective.unproject(Point3.p(ndc.x, ndc.y, 2), viewProps as any, viewTransform),
      ];
    };

    const anchorRay = rayThrough(referencePoints![0]);
    const pointRay = rayThrough(referencePoints![1]);

    const planeIntersection = Plane.intersection([Point3.zero(), u, v], anchorRay);
    const axisLine: Line3D = [planeIntersection, Point3.add(planeIntersection, w)];

    // Element [0] is taken from each shortest connecting segment (presumably the point on the ray; confirm).
    const p0 = Plane.shortestLineBetweenLines(anchorRay, axisLine)[0];
    const p1 = Plane.shortestLineBetweenLines(pointRay, axisLine)[0];
    const scale = referenceLength! / Point3.distance(p0, p1);

    const scaled = Point3.multiply(provisional, scale);
    viewTransform.set([0, 3], scaled.x);
    viewTransform.set([1, 3], scaled.y);
    viewTransform.set([2, 3], scaled.z);

    const cameraTransform = math.inv(viewTransform);

    // Keep only the camera height.
    cameraTransform.set([0, 3], 0);
    cameraTransform.set([2, 3], 0);

    return cameraTransform;
  }

  /**
   * Projects each stored floor image point onto the world plane y = 0 (the
   * plane through the origin spanned by X and Z).
   *
   * @returns One 3D position per floor point, or `[]` when no floor points are set.
   */
  computeFloorPoints(focalLength: number, cameraTransform: Matrix): ZKPosition[] {
    const {
      image: { width, height },
      principalPoint,
      floorPoints,
    } = this;

    const imageSize = { width, height } as any;
    const viewProps = { width, height, focalLength, principalPoint };
    const viewTransform = math.inv(cameraTransform);

    const floorPlane = <Plane3D>[
      Point3.zero(),
      Point3.p(1, 0, 0),
      Point3.p(0, 0, 1),
    ];

    // The floor step may not have run yet (e.g. when called via `preSolve`).
    return (floorPoints ?? []).map((floorPoint) => {
      const point = Point2.ndc(floorPoint, imageSize);
      const pointRay = <Line3D>[
        Perspective.unproject(Point3.p(point.x, point.y, 1), viewProps as any, viewTransform),
        Perspective.unproject(Point3.p(point.x, point.y, 2), viewProps as any, viewTransform),
      ];

      return Plane.intersection(floorPlane, pointRay);
    });
  }

  /**
   * Decomposes the camera transform into Euler angles. The `x` component is
   * reported as pitch, `y` as yaw and `z` as roll.
   *
   * @throws Error when `Euler.fromMatrix` yields no angles.
   */
  computeEulerAngles(cameraTransform: Matrix, order: Euler.Order = 'YXZ'): ZKEulerAngles {
    const eulerAngles = Euler.fromMatrix(cameraTransform, order);
    if (!eulerAngles) {
      throw new Error('Unable to extract Euler angles from the camera transform');
    }

    return {
      pitch: eulerAngles.x,
      roll: eulerAngles.z,
      yaw: eulerAngles.y,
    };
  }

  /** Reads the translation column of the camera transform. */
  computePosition(cameraTransform: Matrix): ZKPosition {
    return {
      x: cameraTransform.get([0, 3]),
      y: cameraTransform.get([1, 3]),
      z: cameraTransform.get([2, 3]),
    };
  }

  /**
   * Solves for the camera from the stored step state.
   *
   * If the first solution puts the camera below the floor plane
   * (`position.y < 0`), the solve is retried once with the first world axis
   * mirrored (`delta = -1` in `computeRotationMatrix`). The retried result is
   * returned as-is.
   *
   * @param negative - Internal retry flag; `true` forces the mirrored solution.
   * @returns Viewport (field of view, Euler angles, position), the 3D floor
   *   positions, and the stored step state echoed back.
   * @throws Error when the vanishing points yield no real focal length or
   *   `referenceAxis` is unsupported. Errors from the geometry helpers are
   *   propagated.
   */
  solve(negative = false): VPSolverResult {
    const focalLength = this.computeFocalLength();
    if (focalLength === null) {
      throw new Error('Cannot solve: vanishing points and principal point do not yield a real focal length');
    }

    const rotationMatrix = this.computeRotationMatrix(focalLength, negative ? -1 : 1);
    const cameraTransform = this.computeCameraTransform(focalLength, rotationMatrix);
    const position = this.computePosition(cameraTransform);

    // Decide on the retry before doing the remaining work, which would otherwise be discarded.
    if (!negative && position.y < 0) {
      return this.solve(true);
    }

    const floorPoints = this.computeFloorPoints(focalLength, cameraTransform);
    const fieldOfView = this.computeFieldOfView(focalLength);
    const eulerAngles = this.computeEulerAngles(cameraTransform);
    const viewport = { fieldOfView, eulerAngles, position };

    return {
      viewport,
      floorPoints,
      type: this.type,
      vp1Line: this.vp1Line,
      vp2Line: this.vp2Line,
      principalPoint: this.principalPoint,
      vanishingPoints: this.vanishingPoints,
      rl: this.rl,
      referencePoints: this.referencePoints,
      referenceAxis: this.referenceAxis,
      points: this.points,
    };
  }

  /**
   * Loads vanishing-point and reference state from `result` and attempts a
   * solve. Floor-step state (`points`, `floorPoints`) is left untouched.
   *
   * @returns The solve result, or `null` if solving throws for any reason.
   */
  preSolve(result: any) {
    this.applyVanishingPointResult(result);
    try {
      return this.solve();
    } catch {
      // Best-effort: any failure is reported to the caller as `null`.
      return null;
    }
  }
}

// The vanishing-point solver class is this module's only export.
export default VPSolver;
