import React, { useState } from "react";
import { z } from "zod";
import { chunk } from "~/lib/std";
import type { Record } from "~/services/datastore";
import { FlipDirection } from "../../../panels";
import { mapSegmentationClassToOklchColor } from "./utils";

/**
 * A single closed outline belonging to a segmentation.
 *
 * The overlay renders `points` verbatim as the `points` attribute of an SVG
 * `<polygon>`, so coordinates are expressed in the overlay's `viewBox`
 * coordinate system (`0 0 imageWidth imageHeight`).
 */
export interface Polygon {
  /** Vertices of the outline, in drawing order. */
  points: ReadonlyArray<{
    x: number;
    y: number;
  }>;
}

/**
 * One annotated object: its class, its bounding box and the polygons that
 * outline it. Both supported COCO layouts are normalized into this shape.
 */
export interface Segmentation {
  /** Numeric category identifier; also used to group segmentations on hover. */
  classId: number;
  /** Category name shown as the label and matched against hidden class names. */
  className: string;
  /** Axis-aligned box given by its top-left corner and its size. */
  boundingBox: {
    topLeftX: number;
    topLeftY: number;
    width: number;
    height: number;
  };
  /** Outlines of the object; a single annotation may carry several. */
  polygons: ReadonlyArray<Polygon>;
}

/**
 * Normalized segmentation data for a single image.
 *
 * `imageWidth` and `imageHeight` define the `viewBox` of the overlay, so all
 * segmentation coordinates are interpreted relative to them.
 */
export interface ImageSegmentations {
  segmentations: ReadonlyArray<Segmentation>;
  imageWidth: number;
  imageHeight: number;
}

// Bounding box as a 4-number tuple. The transforms below read it as
// `[topLeftX, topLeftY, width, height]`.
const cocoBboxSchema = z.tuple([
  z.number(),
  z.number(),
  z.number(),
  z.number(),
]);

// One polygon encoded as a flat list of numbers that is later split into
// consecutive `(x, y)` pairs, hence the requirement that the length is even.
const cocoSegmentationSchema = z
  .array(z.number())
  .refine((flatCoordinates) => {
    const hasOnlyCompletePairs = flatCoordinates.length % 2 === 0;

    return hasOnlyCompletePairs;
  }, "Expected an even number of points");

/**
 * Converts a flat coordinate list into a `Polygon`.
 *
 * The list is split into groups of two with `chunk`; the first entry of each
 * group becomes `x` and the second becomes `y`.
 */
function transformSegmentationToPolygon(
  segmentation: ReadonlyArray<number>,
): Polygon {
  const coordinatePairs = chunk(segmentation, 2);

  const points = coordinatePairs.map((pair) => {
    const [x, y] = pair;

    return { x, y };
  });

  return { points };
}

// Schema for the standard COCO layout: top-level `images`, `categories` and
// `annotations` lists, where each annotation points at its category through
// `category_id`. A successful parse yields normalized `ImageSegmentations`.
const standardCocoDataSchema = z
  .object({
    // At least one image is required because its size defines the overlay
    images: z
      .array(
        z.object({
          width: z.number(),
          height: z.number(),
        }),
      )
      .nonempty(),
    categories: z.array(
      z.object({
        id: z.number(),
        name: z.string(),
      }),
    ),
    annotations: z.array(
      z.object({
        bbox: cocoBboxSchema,
        category_id: z.number(),
        // Each annotation may consist of several polygons
        segmentation: z.array(cocoSegmentationSchema),
      }),
    ),
  })
  .transform((coco, ctx): ImageSegmentations => {
    // Assuming there's only 1 image, so only the first entry is used
    const [firstImage] = coco.images;

    // Lookup table from category ID to category name. When an ID occurs more
    // than once the last entry wins.
    const classNamesById = new Map<number, string>();
    for (const category of coco.categories) {
      classNamesById.set(category.id, category.name);
    }

    const segmentations: Array<Segmentation> = [];
    // A plain loop (rather than `.map()`) so parsing can be aborted with
    // `z.NEVER` as soon as an annotation references an unknown category
    for (const annotation of coco.annotations) {
      const classId = annotation.category_id;
      const className = classNamesById.get(classId);

      if (className === undefined) {
        ctx.addIssue({
          code: z.ZodIssueCode.custom,
          message: `No category found with ID ${classId}`,
          fatal: true,
        });

        return z.NEVER;
      }

      const [topLeftX, topLeftY, width, height] = annotation.bbox;

      segmentations.push({
        classId,
        className,
        boundingBox: { topLeftX, topLeftY, width, height },
        polygons: annotation.segmentation.map((flatCoordinates) =>
          transformSegmentationToPolygon(flatCoordinates),
        ),
      });
    }

    return {
      imageWidth: firstImage.width,
      imageHeight: firstImage.height,
      segmentations,
    };
  });

// Schema for the alternative ("CRL") COCO layout: the image size lives in
// `img_attributes`, every annotation embeds its own `category`, and polygons
// are wrapped in objects under `segmentations`. A successful parse yields
// normalized `ImageSegmentations`.
const crlCocoDataSchema = z
  .object({
    img_attributes: z.object({
      width: z.number(),
      height: z.number(),
    }),
    annotations: z.array(
      z.object({
        bbox: cocoBboxSchema,
        category: z.object({
          id: z.number(),
          name: z.string(),
        }),
        segmentations: z.array(
          z.object({
            segmentation: cocoSegmentationSchema,
          }),
        ),
      }),
    ),
  })
  .transform((crl): ImageSegmentations => {
    const normalized = crl.annotations.map((annotation): Segmentation => {
      const [topLeftX, topLeftY, width, height] = annotation.bbox;

      // Unwrap each `{ segmentation }` object into a polygon
      const polygons = annotation.segmentations.map((wrapper) =>
        transformSegmentationToPolygon(wrapper.segmentation),
      );

      return {
        classId: annotation.category.id,
        className: annotation.category.name,
        boundingBox: { topLeftX, topLeftY, width, height },
        polygons,
      };
    });

    return {
      imageWidth: crl.img_attributes.width,
      imageHeight: crl.img_attributes.height,
      segmentations: normalized,
    };
  });

// Accepts either supported layout; both branches transform their input into
// the same `ImageSegmentations` shape.
const cocoRecordsSchema = z.union([standardCocoDataSchema, crlCocoDataSchema]);

/**
 * Parses the query data of a record into normalized segmentation data.
 *
 * @param record - Record whose `queryData` holds COCO data in one of the
 *   supported layouts.
 * @returns The normalized segmentations, or `null` when the record has no
 *   query data or the data cannot be parsed.
 */
export function normalizeCocoRecords(
  record: Record,
): ImageSegmentations | null {
  const { queryData } = record;

  if (queryData === null) {
    return null;
  }

  let normalized: ImageSegmentations | null;
  try {
    normalized = cocoRecordsSchema.parse(queryData);
  } catch {
    // TODO: Present user with error message
    normalized = null;
  }

  return normalized;
}

/**
 * SVG overlay that draws the segmentations of one image.
 *
 * Every visible segmentation is rendered as a group containing its bounding
 * box, its polygons and (optionally) a class-name label. While the pointer is
 * over a segmentation, only segmentations of that class stay rendered and the
 * class name is shown as a large label anchored to the image instead of the
 * per-segmentation labels.
 *
 * @param imageSegmentations - Normalized segmentations plus the image size,
 *   which defines the SVG `viewBox`.
 * @param rotationDeg - Rotation passed to `calculateTextTransform` for labels.
 * @param flipDirection - Flip passed to `calculateTextTransform` for labels.
 * @param showBoundingBoxes - Whether the bounding-box outline is stroked.
 * @param showClassNames - Whether per-segmentation labels are rendered.
 * @param hiddenClassNames - Class names whose segmentations are not rendered.
 */
export function SegmentationsOverlay({
  imageSegmentations,
  rotationDeg,
  flipDirection,
  showBoundingBoxes,
  showClassNames,
  hiddenClassNames,
}: {
  imageSegmentations: ImageSegmentations;
  rotationDeg: number;
  flipDirection: FlipDirection | null;
  showBoundingBoxes: boolean;
  showClassNames: boolean;
  hiddenClassNames: ReadonlyArray<string>;
}) {
  const [hoveredClassId, setHoveredClassId] = useState<number | null>(null);

  function createPointerEnterHandler(classId: number) {
    return function handlePointerEnter() {
      setHoveredClassId(classId);
    };
  }

  function handlePointerLeave() {
    setHoveredClassId(null);
  }

  // Only honour the hover while a *visible* segmentation of that class
  // exists. Hovering unmounts every other class, so if the hovered class
  // vanishes (new image data) or becomes hidden, no element is left to emit
  // `pointerleave` and the overlay would otherwise stay empty.
  const hoveredClass = imageSegmentations.segmentations.find(
    (segmentation) =>
      segmentation.classId === hoveredClassId &&
      !hiddenClassNames.includes(segmentation.className),
  );
  const activeHoveredClassId =
    hoveredClass === undefined ? null : hoveredClass.classId;

  // Drop the stale hover so it cannot come back to life when a later data set
  // happens to contain the same class ID.
  const isHoverStale = hoveredClassId !== null && hoveredClass === undefined;
  React.useEffect(() => {
    if (isHoverStale) {
      setHoveredClassId(null);
    }
  }, [isHoverStale]);

  return (
    <svg
      viewBox={`0 0 ${imageSegmentations.imageWidth} ${imageSegmentations.imageHeight}`}
      data-image-segmentations
    >
      {/* Large label for the hovered class, positioned relative to the image */}
      {hoveredClass !== undefined && (
        <g>
          <text
            style={{
              transform: calculateTextTransform({
                bboxWidth: imageSegmentations.imageWidth,
                bboxHeight: imageSegmentations.imageHeight,
                rotationDeg,
                flipDirection,
              }),
            }}
            color={mapSegmentationClassToOklchColor(hoveredClass)}
            fill="currentcolor"
            stroke="currentcolor"
            x={0}
            y={0}
            dx={3}
            dy={24}
            fontSize={24}
            fontFamily="monospace"
          >
            {hoveredClass.className}
          </text>
        </g>
      )}
      {imageSegmentations.segmentations.flatMap((segmentation, index) => {
        const isNotHoveredClass =
          activeHoveredClassId !== null &&
          segmentation.classId !== activeHoveredClassId;

        if (
          isNotHoveredClass ||
          hiddenClassNames.includes(segmentation.className)
        ) {
          return [];
        }

        return (
          <g
            key={index}
            color={mapSegmentationClassToOklchColor(segmentation)}
            fill="currentcolor"
            stroke="currentcolor"
            data-class-id={segmentation.classId}
            onPointerEnter={createPointerEnterHandler(segmentation.classId)}
            onPointerLeave={handlePointerLeave}
          >
            {/* The box is always present; only its stroke is toggled */}
            <rect
              x={segmentation.boundingBox.topLeftX}
              y={segmentation.boundingBox.topLeftY}
              width={segmentation.boundingBox.width}
              height={segmentation.boundingBox.height}
              strokeWidth={3}
              stroke={showBoundingBoxes ? undefined : "none"}
              fill="none"
            />
            {segmentation.polygons.map((polygon, polygonIndex) => (
              <polygon
                key={polygonIndex}
                fillOpacity={0.25}
                points={polygon.points
                  .map(({ x, y }) => `${x},${y}`)
                  .join(" ")}
              />
            ))}
            {/* The small label is replaced by the large one while hovered */}
            {showClassNames &&
              segmentation.classId !== activeHoveredClassId && (
                <text
                  style={{
                    transformOrigin: [
                      `${segmentation.boundingBox.topLeftX}px`,
                      `${segmentation.boundingBox.topLeftY}px`,
                    ].join(" "),
                    transform: calculateTextTransform({
                      bboxWidth: segmentation.boundingBox.width,
                      bboxHeight: segmentation.boundingBox.height,
                      rotationDeg,
                      flipDirection,
                    }),
                  }}
                  x={segmentation.boundingBox.topLeftX}
                  y={segmentation.boundingBox.topLeftY}
                  dx={3}
                  dy={15}
                  fontSize={14}
                  fontFamily="monospace"
                >
                  {segmentation.className}
                </text>
              )}
          </g>
        );
      })}
    </svg>
  );
}

/**
 * Builds the CSS `transform` value applied to a label.
 *
 * The result is a translation to one corner of the given box, followed by a
 * rotation and a scale. The corner depends on the flip direction and on the
 * number of quarter turns in `rotationDeg`; the rotation sign depends on
 * whether a flip is applied.
 *
 * @param bboxWidth - Width of the box the label belongs to.
 * @param bboxHeight - Height of the box the label belongs to.
 * @param rotationDeg - Rotation in degrees; any value that is not an exact
 *   0/90/180 degree turn (after wrapping) is treated like 270 degrees when
 *   picking the corner.
 * @param flipDirection - Flip to account for, or `null` for none.
 * @returns A string of the form `translate(..) rotate(..) scale(..)`.
 */
function calculateTextTransform({
  bboxWidth,
  bboxHeight,
  rotationDeg,
  flipDirection,
}: {
  bboxWidth: number;
  bboxHeight: number;
  rotationDeg: number;
  flipDirection: FlipDirection | null;
}): string {
  // `[shift by width, shift by height]` for one rotation quadrant
  type CornerFlags = readonly [boolean, boolean];
  // Corners for 0, 1, 2 and "anything else" quarter turns, in that order
  type CornersByQuadrant = readonly [
    CornerFlags,
    CornerFlags,
    CornerFlags,
    CornerFlags,
  ];

  const isFlippedX = flipDirection === FlipDirection.X;
  const isFlippedY = !isFlippedX && flipDirection === FlipDirection.Y;
  const isFlipped = isFlippedX || isFlippedY;

  const scaleX = isFlippedX ? -1 : 1;
  const scaleY = isFlippedY ? -1 : 1;

  let corners: CornersByQuadrant;
  if (flipDirection === null) {
    corners = [
      [false, false],
      [false, true],
      [true, true],
      [true, false],
    ];
  } else if (isFlippedX) {
    corners = [
      [true, false],
      [true, true],
      [false, true],
      [false, false],
    ];
  } else {
    // Any remaining flip direction uses the vertical-flip corners
    corners = [
      [false, true],
      [false, false],
      [true, false],
      [true, true],
    ];
  }

  // Effective number of 90 degree clockwise rotations; only exactly 0, 1 and
  // 2 are told apart, everything else shares the last table entry
  const quarterTurns = modulo(rotationDeg, 360) / 90;

  let corner: CornerFlags;
  if (quarterTurns === 0) {
    corner = corners[0];
  } else if (quarterTurns === 1) {
    corner = corners[1];
  } else if (quarterTurns === 2) {
    corner = corners[2];
  } else {
    corner = corners[3];
  }

  const [shiftByWidth, shiftByHeight] = corner;
  const translateX = shiftByWidth ? bboxWidth : 0;
  const translateY = shiftByHeight ? bboxHeight : 0;

  // The rotation is negated unless a flip is applied
  const rotationSign = isFlipped ? 1 : -1;

  const translatePart = `translate(${translateX}px, ${translateY}px)`;
  const rotatePart = `rotate(${rotationSign * rotationDeg}deg)`;
  const scalePart = `scale(${scaleX}, ${scaleY})`;

  return `${translatePart} ${rotatePart} ${scalePart}`;
}

/**
 * Remainder that takes the sign of the divisor rather than the dividend, so
 * for a positive divisor the result is never negative (unlike the `%`
 * operator).
 */
function modulo(dividend: number, divisor: number): number {
  // May be negative when the dividend is negative
  const signedRemainder = dividend % divisor;

  // Shift by one divisor and reduce again to land in the divisor's range
  return (signedRemainder + divisor) % divisor;
}
