import { curveStepAfter } from '@visx/curve';
import { localPoint } from '@visx/event';
import { LinearGradient } from '@visx/gradient';
import { scaleLinear } from '@visx/scale';
import { AreaClosed, Line, LinePath } from '@visx/shape';
import { useTooltip, useTooltipInPortal } from '@visx/tooltip';
import Color from 'color';
import {
  Fragment,
  MouseEvent,
  Ref,
  useCallback,
  useId,
  useLayoutEffect,
  useMemo,
  useRef,
  useState,
} from 'react';
import css from './styles.module.css';
import {
  RenderSteppedLineChartTooltipFn,
  Tooltip,
  TooltipData,
} from '@vault/SteppedLineChart/Tooltip';
import { HoverHighlight } from '@vault/SteppedLineChart/HoverHighlight';
import composeRefs from '@seznam/compose-react-refs';
import { cn } from '@vault/utilities';
import { Axes } from '@vault/SteppedLineChart/Axes';
import { useMargins } from '@vault/SteppedLineChart/utilities';
import { SteppedLineChartSkeleton } from '@vault/SteppedLineChart/Skeleton';

/** Width (px) of an annotation's hover hit area and of its hover highlight. */
const ANNOTATION_WIDTH = 10;

/** Stroke width (px) shared by the series lines and the dashed annotation lines. */
const PATH_STROKE_WIDTH = 1;

/**
 * Half of the stroke width. The path scales inset their ranges by this amount
 * so a stroke centred on the plot edge stays inside the plotting area.
 */
const PATH_STROKE_DELTA = PATH_STROKE_WIDTH / 2;

/**
 * Minimum shape of one point in a series. Points are drawn by their position
 * in the series; `x` is also the value annotations are positioned against.
 */
export interface SteppedLineChartData {
  /** Horizontal value of the point. */
  x: number;
  /** Vertical value of the point; the chart's y domain starts at 0. */
  y: number;
}

/**
 * Minimum shape of an annotation: a vertical dashed marker drawn at `x`,
 * on the same horizontal scale as the data points' `x` values.
 */
export interface SteppedLineChartAnnotation {
  /** Horizontal value at which the marker is drawn. */
  x: number;
}

/**
 * Props accepted by `SteppedLineChart`.
 *
 * @typeParam Data - Shape of each plotted point.
 * @typeParam Annotation - Shape of each annotation marker.
 */
export interface SteppedLineChartProps<
  Data extends SteppedLineChartData,
  Annotation extends SteppedLineChartAnnotation,
> {
  /**
   * The series to plot, one inner array per series. Each series should be the
   * same length.
   */
  data: Data[][];
  /**
   * Colors applied to the series by index; series cycle through the list when
   * there are more series than colors.
   */
  colors: string[];

  /** Optional class name added to the chart's root `<svg>`. */
  className?: string;
  /** Annotation markers, one inner array per series. */
  annotations?: Annotation[][];
  /** Formats x-axis labels. The x-axis is hidden when omitted. */
  renderXLabel?: (x: number) => string;
  /** Formats y-axis labels. The y-axis is hidden when omitted. */
  renderYLabel?: (y: number) => string;
  /**
   * Renders the tooltip shown while hovering a data step or an annotation.
   * No tooltip is shown when omitted.
   */
  renderTooltip?: RenderSteppedLineChartTooltipFn<Data, Annotation>;
}

/**
 * Returns the smallest and largest value selected by `accessor` across every
 * point of every series, or `undefined` when `data` contains no points.
 *
 * A plain loop is used instead of `Math.min(...values)` so very long series
 * cannot exceed the engine's maximum argument count.
 */
function getExtent<T>(
  data: readonly T[][],
  accessor: (point: T) => number
): readonly [number, number] | undefined {
  let min = Infinity;
  let max = -Infinity;
  for (const series of data) {
    for (const point of series) {
      const value = accessor(point);
      if (value < min) min = value;
      if (value > max) max = value;
    }
  }
  return min <= max ? ([min, max] as const) : undefined;
}

/**
 * Stepped line chart. Each series is drawn as a gradient-filled area plus a
 * stroke using a step-after curve. Optional dashed vertical annotations,
 * axes, hover highlights and a portal-rendered tooltip are also drawn.
 *
 * Data points are laid out by their index within the series. Annotations are
 * laid out by their `x` value on a scale spanning the data's x range plus one
 * extra step. Annotations therefore only line up with the steps when the data's
 * `x` values are evenly spaced.
 */
function SteppedLineChart<
  Data extends SteppedLineChartData,
  Annotation extends SteppedLineChartAnnotation,
>({
  className,
  data,
  colors,
  annotations,
  renderXLabel,
  renderYLabel,
  renderTooltip,
}: SteppedLineChartProps<Data, Annotation>) {
  const id = useId();

  // Resize observer to update the size of the chart
  const svgRef = useRef<SVGElement>(null);
  const [size, setSize] = useState({ width: 0, height: 0 });
  useLayoutEffect(() => {
    const svg = svgRef.current;
    if (!svg) return;

    const observer = new ResizeObserver((entries) => {
      const entry = entries[0];
      if (!entry) return;
      setSize({
        width: entry.contentRect.width,
        height: entry.contentRect.height,
      });
    });
    observer.observe(svg);
    return () => observer.disconnect();
  }, []);

  // Resize observer to update margins of the chart
  const xAxisRef = useRef<SVGSVGElement>(null);
  const yAxisRef = useRef<SVGSVGElement>(null);
  const margin = useMargins(
    { top: 12 },
    renderXLabel ? xAxisRef : null,
    renderYLabel ? yAxisRef : null
  );

  // Index range of the longest series. The `1` seed keeps maxIndex >= 0 so
  // empty `data` (or only empty series) still yields a usable domain.
  const [minIndex, maxIndex] = useMemo(() => {
    const longest = data.reduce(
      (acc, series) => Math.max(acc, series.length),
      1
    );
    return [0, longest - 1] as const;
  }, [data]);

  // X range of the data, extended by one step so the last step has width
  // (mirroring the index scales' `maxIndex + 1` upper bound).
  const [minX, maxX] = useMemo(() => {
    const extent = getExtent(data, (d) => d.x);
    const min = extent?.[0] ?? 0;
    const max = extent?.[1] ?? 0;
    // With a single index (or a zero-width range) the step size is undefined
    // (0 / 0 = NaN), so fall back to a unit step to keep the scale finite.
    const step = maxIndex > 0 && max > min ? (max - min) / maxIndex : 1;
    return [min, max + step] as const;
  }, [data, maxIndex]);

  // Y range of the data; the baseline is always 0.
  const [minY, maxY] = useMemo(() => {
    const min = 0;
    const max = getExtent(data, (d) => d.y)?.[1] ?? min;
    // Ensure there is a difference between min and max
    return [min, max === min ? max + 1 : max] as const;
  }, [data]);

  // One gradient per palette entry, or a single fallback when `colors` is empty.
  const paletteSize = Math.max(colors.length, 1);

  // Calculate the colors for each series
  const getSeriesColors = useCallback(
    (seriesIndex: number) => {
      // Series cycle through the palette when there are more series than colors.
      const paletteIndex = colors.length > 0 ? seriesIndex % colors.length : 0;
      const color = Color(colors[paletteIndex] ?? '#000000');
      return {
        // Keyed by palette index rather than the color string. Values such as
        // `rgb(0, 0, 0)` or repeated colors would otherwise produce an invalid
        // `url(#…)` reference or duplicate element ids.
        gradientId: `${id}-gradient-${paletteIndex}`,
        annotation: color.alpha(0.6).hexa(),
        stroke: color.alpha(0.3).hexa(),
        fill: {
          from: color.alpha(0.1).hexa(),
          to: color.alpha(0.05).hexa(),
        },
      };
    },
    [colors, id]
  );

  // Calculate the scales for the area
  const [areaIndexScale, areaYScale] = useMemo(() => {
    return [
      scaleLinear({
        domain: [minIndex, maxIndex + 1],
        range: [margin.left, size.width - margin.right],
      }),
      scaleLinear({
        domain: [minY, maxY],
        range: [size.height - margin.bottom, margin.top],
      }),
    ] as const;
  }, [minIndex, maxIndex, minY, maxY, size, margin]);

  // Calculate the scales for the path (inset by half a stroke on each side)
  const [pathIndexScale, pathXScale, pathYScale] = useMemo(() => {
    return [
      scaleLinear({
        domain: [minIndex, maxIndex + 1],
        range: [
          margin.left + PATH_STROKE_DELTA,
          size.width - margin.right - PATH_STROKE_DELTA,
        ],
      }),
      scaleLinear({
        domain: [minX, maxX],
        range: [
          margin.left + PATH_STROKE_DELTA,
          size.width - margin.right - PATH_STROKE_DELTA,
        ],
      }),
      scaleLinear({
        domain: [minY, maxY],
        range: [
          size.height - margin.bottom - PATH_STROKE_DELTA,
          margin.top + PATH_STROKE_DELTA,
        ],
      }),
    ] as const;
  }, [minIndex, maxIndex, minX, maxX, minY, maxY, size, margin]);

  // Bottom edge of the plotting area, clamped so SVG sizes derived from it are
  // never negative before the first ResizeObserver measurement.
  const chartBottom = Math.max(0, size.height - margin.bottom);

  const {
    tooltipOpen,
    tooltipData,
    showTooltip,
    hideTooltip,
    tooltipLeft,
    tooltipTop,
  } = useTooltip<TooltipData>();
  const { containerRef, TooltipInPortal } = useTooltipInPortal({
    detectBounds: true,
    scroll: true,
  });
  const isTooltipVisible = !!tooltipOpen && !!tooltipData && !!renderTooltip;

  const getAnnotationX = useCallback(
    (seriesIndex: number, index: number) => {
      return annotations?.at(seriesIndex)?.at(index)?.x ?? 0;
    },
    [annotations]
  );

  const handleTooltipMove = useCallback(
    (event: MouseEvent<SVGRectElement>) => {
      const point = localPoint(event) ?? { x: 0, y: 0 };

      const index = Math.floor(pathIndexScale.invert(point.x));
      if (index < 0 || index > maxIndex) {
        hideTooltip();
        return;
      }

      // Find the closest annotation to the point (per series, then overall)
      const nearbyAnnotations = annotations?.flatMap((series, seriesIndex) => {
        const candidates = series.flatMap((annotation, annotationIndex) => {
          const distance = Math.abs(pathXScale(annotation.x) - point.x);
          return distance < ANNOTATION_WIDTH / 2
            ? { seriesIndex, index: annotationIndex, distance }
            : [];
        });
        return candidates.sort((a, b) => a.distance - b.distance).at(0) ?? [];
      });
      const closestAnnotation = nearbyAnnotations
        ?.sort((a, b) => a.distance - b.distance)
        .at(0);

      // Always prioritize an annotation over a data point
      const nextTooltipData: TooltipData = closestAnnotation
        ? {
            type: 'annotation',
            seriesIndex: closestAnnotation.seriesIndex,
            index: closestAnnotation.index,
          }
        : { type: 'data', index };

      showTooltip({
        tooltipData: nextTooltipData,
        tooltipLeft: point.x,
        tooltipTop: point.y,
      });
    },
    [
      showTooltip,
      hideTooltip,
      maxIndex,
      pathIndexScale,
      pathXScale,
      annotations,
    ]
  );

  return (
    <>
      <svg
        ref={composeRefs(svgRef, containerRef) as Ref<SVGSVGElement>}
        className={cn(css.svg, className)}
      >
        {/* Define one gradient per palette entry */}
        <defs>
          {Array.from({ length: paletteSize }, (_, paletteIndex) => {
            const config = getSeriesColors(paletteIndex);
            return (
              <LinearGradient
                key={config.gradientId}
                id={config.gradientId}
                from={config.fill.from}
                to={config.fill.to}
              />
            );
          })}
        </defs>

        {/* Render hover highlights */}
        {isTooltipVisible && tooltipData.type === 'annotation' && (
          <HoverHighlight
            x={
              pathXScale(
                getAnnotationX(tooltipData.seriesIndex, tooltipData.index)
              ) -
              PATH_STROKE_DELTA -
              ANNOTATION_WIDTH / 2
            }
            y={0}
            width={ANNOTATION_WIDTH + PATH_STROKE_WIDTH}
            height={chartBottom}
            radius={2}
          />
        )}
        {isTooltipVisible && tooltipData.type === 'data' && (
          <HoverHighlight
            x={pathIndexScale(tooltipData.index) - PATH_STROKE_DELTA}
            y={0}
            width={
              (size.width - margin.left - margin.right) / (maxIndex + 1) +
              PATH_STROKE_WIDTH
            }
            height={chartBottom}
            radius={size.width < 400 ? 3 : 4}
          />
        )}

        {/* Render axes */}
        <Axes
          xAxisRef={xAxisRef}
          yAxisRef={yAxisRef}
          data={data}
          indexScale={pathIndexScale}
          yScale={pathYScale}
          margin={margin}
          size={size}
          renderXLabel={renderXLabel}
          renderYLabel={renderYLabel}
        />

        {/* Render data series */}
        {data.map((series, seriesIndex) => {
          const lastPoint = series.at(-1);
          // An empty series has nothing to draw and no last point to repeat.
          if (!lastPoint) return null;

          // Add a duplicate of the last point to the end of the series
          // to ensure the last step is visible
          const seriesData = [...series, lastPoint];
          const seriesColors = getSeriesColors(seriesIndex);
          return (
            <Fragment key={seriesIndex}>
              <AreaClosed
                data={seriesData}
                x={(_, i) => areaIndexScale(i)}
                y={(d) => areaYScale(d.y)}
                yScale={areaYScale}
                curve={curveStepAfter}
                fill={`url(#${seriesColors.gradientId})`}
              />
              <LinePath
                data={seriesData}
                x={(_, i) => pathIndexScale(i)}
                y={(d) => pathYScale(d.y)}
                curve={curveStepAfter}
                stroke={seriesColors.stroke}
                strokeWidth={PATH_STROKE_WIDTH}
                strokeLinejoin="round"
                strokeLinecap="round"
              />
            </Fragment>
          );
        })}

        {/* Render annotations */}
        {annotations?.map((series, seriesIndex) => {
          const seriesColors = getSeriesColors(seriesIndex);
          return (
            <Fragment key={seriesIndex}>
              {series.map((annotation, annotationIndex) => (
                <Line
                  // Index-based key: several annotations may share an `x`.
                  key={annotationIndex}
                  x1={pathXScale(annotation.x)}
                  y1={PATH_STROKE_WIDTH}
                  x2={pathXScale(annotation.x)}
                  y2={size.height - margin.bottom - PATH_STROKE_WIDTH}
                  stroke={seriesColors.annotation}
                  strokeWidth={PATH_STROKE_WIDTH}
                  strokeDasharray="1.5,3"
                  strokeLinecap="round"
                />
              ))}
            </Fragment>
          );
        })}

        {/* Hover area for tooltip */}
        <rect
          x={margin.left}
          y={0}
          width={Math.max(0, size.width - margin.left)}
          height={chartBottom}
          fill="transparent"
          onPointerMove={handleTooltipMove}
          onPointerLeave={hideTooltip}
        />
      </svg>

      {isTooltipVisible && (
        <TooltipInPortal
          // Remount on every render so `detectBounds` re-measures against the
          // current pointer position (the pattern used in the visx examples).
          key={Math.random()}
          unstyled
          className={css.tooltipPortal}
          top={tooltipTop}
          left={tooltipLeft}
        >
          <Tooltip
            tooltip={tooltipData}
            data={data}
            annotations={annotations}
            renderTooltip={renderTooltip}
          />
        </TooltipInPortal>
      )}
    </>
  );
}

// Expose the skeleton imported from '@vault/SteppedLineChart/Skeleton' as a
// static member, so callers can reach it as `SteppedLineChart.Skeleton`.
// This is a direct property assignment on the function declaration, which
// TypeScript recognises as an expando property on `SteppedLineChart`'s type.
SteppedLineChart.Skeleton = SteppedLineChartSkeleton;

export { SteppedLineChart };
