// this file was primarily adapted from C++
import type { CartesianPose } from '@sb/geometry';
import { cartesianPoseToMatrix4 } from '@sb/geometry';
import type { Sixteen, Six } from '@sb/utilities';

import type { ArmJointLimits } from './ArmJointLimits';
import type { ArmJointPositions } from './ArmJointPositions';
import type { DHParams } from './DHParams';
import { DEFAULT_DH_PARAMS, DEFAULT_OFFSETS } from './DHParams';

/** Tolerance the solver uses for near-zero and near-equality tests. */
const ZERO_THRESH = 1e-8;

/**
 * Three-way sign of `x`: 1 when positive, -1 when negative, and 0 otherwise.
 * Unlike `Math.sign`, this returns plain 0 for both -0 and NaN.
 */
function sign(x: number): number {
  if (x > 0) return 1;
  return x < 0 ? -1 : 0;
}

/**
 * Analytic inverse kinematics for one flattened pose (adapted from C++).
 * `inverseKinematics` calls this once with the requested pose and, if that
 * yields nothing, again with slightly nudged translations.
 *
 * @param input 16 numbers, of which only indices 0-11 are read.
 *   NOTE(review): `inverseKinematics` passes its pose matrix transposed and
 *   flattened with `toArray()`; presumably that makes these the top three
 *   rows in row-major order — confirm against `cartesianPoseToMatrix4`.
 * @param defaultJoint5 value used for the last joint (q6) when sin(q5) is
 *   within ZERO_THRESH of zero, i.e. when q6 is not determined by the pose.
 *   NOTE(review): DEFAULT_OFFSETS[5] is still added to this value below; the
 *   caller's default is a seed angle, so confirm this cannot double-count an
 *   offset.
 * @param dhParams source of d1, a2, a3, d4, d5 and d6.
 * @returns up to 8 candidates (2 q1 branches x 2 q5 branches x 2 q2/q3/q4
 *   branches). Each raw angle is folded into [0, 2π] (except q6 when
 *   `defaultJoint5` is used) and then has DEFAULT_OFFSETS[jointIndex] added.
 *   Candidates containing NaN are dropped, and [] is returned early when q1
 *   has no real solution.
 */
function inverse(
  input: Sixteen<number>,
  defaultJoint5: number,
  dhParams: DHParams,
): Array<ArmJointPositions> {
  /* eslint-disable camelcase */
  /* eslint-disable spaced-comment */
  const { d1, a2, a3, d4, d5, d6 } = dhParams;

  const solutions: Array<ArmJointPositions> = [];
  // Unpack input[0..11] into the T-matrix names used by the original C++
  // code. The negations and the per-row read order (T?2, T?0, T?1, T?3) are
  // part of the ported algorithm and must be preserved exactly.
  let index = 0;
  const T02 = -input[index];
  index += 1;
  const T00 = input[index];
  index += 1;
  const T01 = input[index];
  index += 1;
  const T03 = -input[index];
  index += 1;
  const T12 = -input[index];
  index += 1;
  const T10 = input[index];
  index += 1;
  const T11 = input[index];
  index += 1;
  const T13 = -input[index];
  index += 1;
  const T22 = input[index];
  index += 1;
  const T20 = -input[index];
  index += 1;
  const T21 = -input[index];
  index += 1;
  const T23 = input[index];

  ///////////// shoulder rotate joint (q1) /////////////
  // Two q1 branches; the first matching case below fills both entries.
  const q1 = [0, 0];

  {
    const A = d6 * T12 - T13;
    const B = d6 * T02 - T03;
    let R = A * A + B * B;

    if (Math.abs(A) < ZERO_THRESH) {
      // A ≈ 0: solve with asin. When |d4| ≈ |B|, snap the ratio to exactly
      // ±1 so rounding cannot push it outside asin's domain.
      let div: number;

      if (Math.abs(Math.abs(d4) - Math.abs(B)) < ZERO_THRESH) {
        div = -sign(d4) * sign(B);
      } else {
        div = -d4 / B;
      }

      let arcsin = Math.asin(div);

      if (Math.abs(arcsin) < ZERO_THRESH) {
        arcsin = 0.0;
      }

      if (arcsin < 0.0) {
        q1[0] = arcsin + 2.0 * Math.PI;
      } else {
        q1[0] = arcsin;
      }

      // Second branch: the other angle with the same sine.
      q1[1] = Math.PI - arcsin;
    } else if (Math.abs(B) < ZERO_THRESH) {
      // B ≈ 0: solve with acos, snapping the ratio to ±1 the same way.
      let div: number;

      if (Math.abs(Math.abs(d4) - Math.abs(A)) < ZERO_THRESH) {
        div = sign(d4) * sign(A);
      } else {
        div = d4 / A;
      }

      const arccos = Math.acos(div);
      q1[0] = arccos;
      q1[1] = 2.0 * Math.PI - arccos;
    } else if (d4 * d4 - R >= ZERO_THRESH) {
      // No real q1: d4² exceeds A² + B² by more than the tolerance.
      // pbeeson: Added to handle
      // numerical errors
      return solutions;
    } else {
      // d4² may exceed R by less than the tolerance; raise R to d4² so the
      // acos argument below stays within [-1, 1].
      if (d4 * d4 > R) R = d4 * d4;
      const arccos = Math.acos(d4 / Math.sqrt(R));
      const arctan = Math.atan2(-B, A);
      let pos = arccos + arctan;
      let neg = -arccos + arctan;
      if (Math.abs(pos) < ZERO_THRESH) pos = 0.0;
      if (Math.abs(neg) < ZERO_THRESH) neg = 0.0;
      if (pos >= 0.0) q1[0] = pos;
      else q1[0] = 2.0 * Math.PI + pos;
      if (neg >= 0.0) q1[1] = neg;
      else q1[1] = 2.0 * Math.PI + neg;
    }
  }

  ///////////// wrist 2 joint (q5) /////////////
  // q5[i] holds the two q5 branches for q1[i]: acos of the ratio and its 2π
  // complement. The ratio is snapped to ±1 when |numer| ≈ |d6|; any other
  // out-of-range ratio makes acos return NaN, and such candidates are dropped
  // at the end.
  const q5 = [
    [0, 0],
    [0, 0],
  ];

  for (let i = 0; i < 2; i += 1) {
    const numer = T03 * Math.sin(q1[i]) - T13 * Math.cos(q1[i]) - d4;
    let div: number;

    if (Math.abs(Math.abs(numer) - Math.abs(d6)) < ZERO_THRESH) {
      div = sign(numer) * sign(d6);
    } else {
      div = numer / d6;
    }

    const arccos = Math.acos(div);
    q5[i][0] = arccos;
    q5[i][1] = 2.0 * Math.PI - arccos;
  }

  // Enumerate every (q1 branch i, q5 branch j) pair.
  for (let i = 0; i < 2; i += 1) {
    for (let j = 0; j < 2; j += 1) {
      const c1 = Math.cos(q1[i]);
      const s1 = Math.sin(q1[i]);
      const c5 = Math.cos(q5[i][j]);
      const s5 = Math.sin(q5[i][j]);
      let q6: number;

      ///////////// wrist 3 joint (q6) /////////////
      if (Math.abs(s5) < ZERO_THRESH) {
        // sin(q5) ≈ 0: q6 is not determined by the pose; use the default.
        q6 = defaultJoint5;
      } else {
        q6 = Math.atan2(
          sign(s5) * -(T01 * s1 - T11 * c1),
          sign(s5) * (T00 * s1 - T10 * c1),
        );

        if (Math.abs(q6) < ZERO_THRESH) {
          q6 = 0.0;
        }

        if (q6 < 0.0) {
          q6 += 2.0 * Math.PI;
        }
      }

      const q2 = [0, 0];
      const q3 = [0, 0];
      const q4 = [0, 0];

      ///////////// RRR joints (q2,q3,q4) /////////////
      // x04x/x04y and p13x/p13y are intermediate quantities of the ported
      // derivation for the three parallel joints.
      const c6 = Math.cos(q6);
      const s6 = Math.sin(q6);

      const x04x =
        -s5 * (T02 * c1 + T12 * s1) -
        c5 * (s6 * (T01 * c1 + T11 * s1) - c6 * (T00 * c1 + T10 * s1));

      const x04y = c5 * (T20 * c6 - T21 * s6) - T22 * s5;

      const p13x =
        d5 * (s6 * (T00 * c1 + T10 * s1) + c6 * (T01 * c1 + T11 * s1)) -
        d6 * (T02 * c1 + T12 * s1) +
        T03 * c1 +
        T13 * s1;

      const p13y = T23 - d1 - d6 * T22 + d5 * (T21 * c6 + T20 * s6);

      // cos(q3) from the law of cosines over link lengths a2 and a3.
      let c3 =
        (p13x * p13x + p13y * p13y - a2 * a2 - a3 * a3) / (2.0 * a2 * a3);

      if (Math.abs(Math.abs(c3) - 1.0) < ZERO_THRESH) {
        c3 = sign(c3);
      } else if (Math.abs(c3) > 1.0) {
        // |c3| > 1 beyond tolerance: no q3 exists for this (i, j) pair.
        // TODO NO SOLUTION
        continue;
      }

      // Two q3 branches (acos and its 2π complement); q2 and q4 are solved
      // separately for each, indexed by k below.
      const arccos = Math.acos(c3);
      q3[0] = arccos;
      q3[1] = 2.0 * Math.PI - arccos;
      const denom = a2 * a2 + a3 * a3 + 2 * a2 * a3 * c3;
      const s3 = Math.sin(arccos);
      const A = a2 + a3 * c3;
      const B = a3 * s3;

      q2[0] = Math.atan2(
        (A * p13y - B * p13x) / denom,
        (A * p13x + B * p13y) / denom,
      );

      q2[1] = Math.atan2(
        (A * p13y + B * p13x) / denom,
        (A * p13x - B * p13y) / denom,
      );

      const c23_0 = Math.cos(q2[0] + q3[0]);
      const s23_0 = Math.sin(q2[0] + q3[0]);
      const c23_1 = Math.cos(q2[1] + q3[1]);
      const s23_1 = Math.sin(q2[1] + q3[1]);

      q4[0] = Math.atan2(
        c23_0 * x04y - s23_0 * x04x,
        x04x * c23_0 + x04y * s23_0,
      );

      q4[1] = Math.atan2(
        c23_1 * x04y - s23_1 * x04x,
        x04x * c23_1 + x04y * s23_1,
      );

      for (let k = 0; k < 2; k += 1) {
        // Fold tiny q2/q4 values to 0 and negative ones into [0, 2π).
        if (Math.abs(q2[k]) < ZERO_THRESH) {
          q2[k] = 0.0;
        } else if (q2[k] < 0.0) {
          q2[k] += 2.0 * Math.PI;
        }

        if (Math.abs(q4[k]) < ZERO_THRESH) {
          q4[k] = 0.0;
        } else if (q4[k] < 0.0) {
          q4[k] += 2.0 * Math.PI;
        }

        // Apply the per-joint DEFAULT_OFFSETS to the raw angles.
        const solution = [q1[i], q2[k], q3[k], q4[k], q5[i][j], q6].map(
          (position, jointIndex) => position + DEFAULT_OFFSETS[jointIndex],
        ) as ArmJointPositions;

        // NaN can arise e.g. from asin/acos of an out-of-range ratio above.
        const valid = solution.findIndex((angle) => Number.isNaN(angle)) === -1;

        if (valid) {
          solutions.push(solution);
        }
      }
    }
  }

  return solutions;
}

/**
 * Sum of absolute per-joint differences between `solution` and `seed`; used to
 * rank candidate solutions by how far they are from the seed configuration.
 */
function distanceToSeed(
  solution: readonly number[],
  seed: readonly number[],
): number {
  return solution.reduce(
    (total, position, jointIndex) =>
      total + Math.abs(position - seed[jointIndex]),
    0,
  );
}

/**
 * Shift `position` by whole turns (multiples of 2π) into [min, max].
 * Returns null when no 2π-equivalent of `position` lies in that range, or when
 * `position` is not finite (an infinite value would never leave the loops).
 */
function wrapIntoLimits(
  position: number,
  min: number,
  max: number,
): number | null {
  if (!Number.isFinite(position)) {
    return null;
  }

  let proposed = position;

  while (proposed > max) {
    proposed -= 2 * Math.PI;
  }

  while (proposed < min) {
    proposed += 2 * Math.PI;
  }

  // The range may be narrower than a full turn, so the fold can overshoot.
  return proposed > max || proposed < min ? null : proposed;
}

/**
 * Move an in-limits `position` by whole turns toward `seed`, until it is
 * within π of the seed, without ever stepping outside [min, max].
 */
function moveTowardSeed(
  position: number,
  seed: number,
  min: number,
  max: number,
): number {
  if (!Number.isFinite(seed)) {
    return position;
  }

  let proposed = position;

  // Only take a step if the result is still within the limits.
  while (proposed - seed > Math.PI && proposed - 2 * Math.PI >= min) {
    proposed -= 2 * Math.PI;
  }

  while (proposed - seed < -Math.PI && proposed + 2 * Math.PI <= max) {
    proposed += 2 * Math.PI;
  }

  return proposed;
}

/**
 * Retry `inverse` with the entries it reads as translation (indices 3, 7 and
 * 11, i.e. T03/T13/T23) perturbed by ±epsilon in every non-zero combination.
 * Returns the first non-empty result, or [] if every nudge fails.
 */
function inverseWithNudges(
  rowMajorPose: Sixteen<number>,
  defaultJoint5: number,
  dhParams: DHParams,
  epsilon = 1e-7,
): Array<ArmJointPositions> {
  for (let xx = -1; xx <= 1; xx += 1) {
    for (let yy = -1; yy <= 1; yy += 1) {
      for (let zz = -1; zz <= 1; zz += 1) {
        if (xx === 0 && yy === 0 && zz === 0) {
          continue;
        }

        const nudgedPose: Sixteen<number> = [...rowMajorPose];
        nudgedPose[3] = rowMajorPose[3] + epsilon * xx;
        nudgedPose[7] = rowMajorPose[7] + epsilon * yy;
        nudgedPose[11] = rowMajorPose[11] + epsilon * zz;

        const solutions = inverse(nudgedPose, defaultJoint5, dhParams);

        // Return immediately: later failing nudges must not overwrite this.
        if (solutions.length > 0) {
          return solutions;
        }
      }
    }
  }

  return [];
}

/**
 * Solve inverse kinematics for `pose`.
 *
 * Every analytic solution is kept only if each joint has a 2π-equivalent
 * inside `limits`. Each joint is then shifted by whole turns toward
 * `seedAngles` (staying within limits). The list is sorted by ascending sum
 * of absolute joint differences from the seed, closest first.
 *
 * @param pose target pose
 * @param limits per-joint `{ min, max }` limits
 * @param seedAngles reference configuration for turn selection and sorting
 * @param defaultJoint5 value for joint index 5 when the pose does not
 *   determine it; defaults to the seed's last joint
 * @param dhParams DH parameters of the arm
 * @returns candidate joint positions, possibly empty
 */
export function inverseKinematics(
  pose: CartesianPose,
  limits: ArmJointLimits,
  seedAngles: ArmJointPositions,
  // a default value to give joint 5 if it doesn't matter
  defaultJoint5: number = seedAngles[seedAngles.length - 1],
  dhParams = DEFAULT_DH_PARAMS,
): Array<ArmJointPositions> {
  const matrix = cartesianPoseToMatrix4(pose).transpose();
  const asArray: Sixteen<number> = matrix.toArray();

  // Try the pose directly first; if that fails, try tiny translation nudges
  // to get past numerical edge cases.
  let solutions = inverse(asArray, defaultJoint5, dhParams);

  if (solutions.length === 0) {
    solutions = inverseWithNudges(asArray, defaultJoint5, dhParams);
  }

  // Put theoretical results into valid joint limit ranges, dropping any
  // solution with a joint that cannot be brought within its limits.
  const withinLimits: Array<ArmJointPositions> = [];

  for (const solution of solutions) {
    const limited = solution.map((position, jointIndex) => {
      const { min, max } = limits[jointIndex];
      return wrapIntoLimits(position, min, max);
    }) as Six<number | null>;

    if (limited.every((position) => position !== null)) {
      withinLimits.push(limited as ArmJointPositions);
    }
  }

  // Then move each joint closer to the seed without leaving its limits.
  const movedCloser = withinLimits.map(
    (solution) =>
      solution.map((position, jointIndex) => {
        const { min, max } = limits[jointIndex];
        return moveTowardSeed(position, seedAngles[jointIndex], min, max);
      }) as ArmJointPositions,
  );

  // Finally sort closest-to-seed first (Array#sort is stable, so ties keep
  // the solver's order).
  return movedCloser.sort(
    (solutionA, solutionB) =>
      distanceToSeed(solutionA, seedAngles) -
      distanceToSeed(solutionB, seedAngles),
  );
}
