import { SteppedLineChartData } from '@vault/SteppedLineChart';
import { Margin, Size } from '@vault/SteppedLineChart/types';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { GridRows } from '@visx/grid';
import createLinearScale from '@visx/scale/lib/scales/linear';
import { RefObject, useCallback, useEffect, useMemo } from 'react';

/**
 * Thins out bottom-axis tick labels so that visible labels do not collide.
 *
 * Measures every `g.visx-axis-tick` group inside `axis` and finds the smallest
 * stride `n` that spaces labels far enough apart. Tick `n` must start at least
 * `gap` pixels to the right of where the widest label would end if it were
 * placed at the first tick. Every `n`th tick is then shown (`opacity="1"`) and
 * the rest are hidden (`opacity="0"`).
 *
 * Assumes ticks are laid out left to right at roughly even spacing, so a
 * stride measured from the first tick holds across the whole axis.
 *
 * @param axis - Axis `<g>` element containing the rendered tick groups.
 * @param gap - Minimum horizontal space, in pixels, between visible labels.
 */
function hideOverlappingXTicks(axis: SVGGElement, gap = 12) {
  const ticks = axis.querySelectorAll('g.visx-axis-tick');
  // Nothing to measure. A single tick still falls through so that any
  // opacity left over from an earlier pass is reset to visible.
  if (ticks.length === 0) return;

  const rects = Array.from(ticks).map((tick) => tick.getBoundingClientRect());
  // Use the widest label so the chosen stride is safe for every tick.
  const maxWidth = Math.max(...rects.map((r) => r.width));
  const firstLabelEnd = rects[0].left + maxWidth + gap;

  // Smallest stride whose tick clears the first label. The stride may reach
  // rects.length: when every tick overlaps the first, only the first stays
  // visible instead of leaving a colliding tick on screen.
  let stride = 1;
  while (stride < rects.length && rects[stride].left < firstLabelEnd) {
    stride++;
  }

  // Only show every `stride`th tick.
  ticks.forEach((tick, i) => {
    tick.setAttribute('opacity', i % stride === 0 ? '1' : '0');
  });
}

/** Linear scale type shared by the index (x) and value (y) axes. */
type NumericLinearScale = ReturnType<typeof createLinearScale<number>>;

/**
 * Props for the `Axes` component.
 *
 * @typeParam Data - Point type of each series; must extend `SteppedLineChartData`.
 */
export interface AxesProps<Data extends SteppedLineChartData> {
  /** Series being plotted; only the first series is read when labelling x ticks. */
  data: Data[][];

  /** Scale for point indices, used to position the bottom axis ticks. */
  indexScale: NumericLinearScale;
  /** Scale for values, used by the left axis and the grid rows. */
  yScale: NumericLinearScale;

  /** Chart margins used to offset the axes and grid rows. */
  margin: Margin;
  /** Chart size used to place the bottom axis and size the grid rows. */
  size: Size;

  /** Builds an x label from a point's `x` value; omit to skip the bottom axis. */
  renderXLabel?: (x: number) => string;
  /** Builds a y label from a tick value; omit to skip the left axis and grid rows. */
  renderYLabel?: (y: number) => string;

  /** Ref to the bottom axis group; used to measure and hide overlapping labels. */
  xAxisRef?: RefObject<SVGGElement>;
  /** Ref forwarded to the left axis group. */
  yAxisRef?: RefObject<SVGGElement>;
}

/**
 * Renders the axes and grid rows for a stepped line chart.
 *
 * Bottom axis (rendered only when `renderXLabel` is provided):
 * - Tick positions come from `indexScale` and are treated as point indices.
 * - Each label is `renderXLabel` applied to the `x` of the point at that index
 *   in the first series of `data`.
 * - Ticks whose label is empty, or repeats an earlier label, are dropped.
 * - Labels that still overlap once laid out are hidden via their opacity.
 *
 * Left axis and grid rows (rendered only when `renderYLabel` is provided):
 * - Both share about five integer ticks taken from `yScale`.
 * - They may also show the domain maximum, which is not necessarily an integer.
 */
export function Axes<Data extends SteppedLineChartData>({
  xAxisRef,
  yAxisRef,
  data,
  indexScale,
  yScale,
  margin,
  size,
  renderXLabel,
  renderYLabel,
}: AxesProps<Data>) {
  const yTickFormat = useCallback(
    (tick: { valueOf(): number }) => {
      if (!renderYLabel) return undefined;
      return renderYLabel(tick.valueOf());
    },
    [renderYLabel]
  );

  const xTickFormat = useCallback(
    (tick: { valueOf(): number }) => {
      if (!renderXLabel) return undefined;
      const index = tick.valueOf();
      // `Array.prototype.at` treats negative indices as offsets from the end.
      // A negative tick (e.g. if the index domain extends below 0) would
      // otherwise be labelled with the last point's x value.
      if (index < 0) return undefined;
      const dataPt = data.at(0)?.at(index);
      if (!dataPt) return undefined;
      return renderXLabel(dataPt.x);
    },
    [renderXLabel, data]
  );

  // Show around 5 integer ticks
  const yTickValues = useMemo(() => {
    const ticks = yScale.ticks(5).filter((t) => t % 1 === 0);
    if (ticks.length <= 1) return ticks;

    // Include the max value if it's not already there, and if it's more than
    // half a step beyond the last tick. Note the max itself may be fractional.
    const lastTick = ticks[ticks.length - 1];
    const maxY = yScale.domain()[1];
    if (lastTick !== maxY) {
      const step = ticks[1] - ticks[0];
      const delta = maxY - lastTick;
      if (delta / step > 0.5) {
        ticks.push(maxY);
      }
    }

    return ticks;
  }, [yScale]);

  // Show as many unique, integer ticks as possible: fractional indices have no
  // data point, and ticks with empty or duplicate labels are dropped.
  const xTickValues = useMemo(() => {
    const uniqueValues = new Set<string>();
    const ticks = indexScale.ticks().filter((t) => {
      if (t % 1 !== 0) return false;

      const formatted = xTickFormat(t);
      if (!formatted) return false;

      if (uniqueValues.has(formatted)) return false;
      uniqueValues.add(formatted);
      return true;
    });
    return ticks;
  }, [xTickFormat, indexScale]);

  // Re-measure the rendered bottom-axis labels whenever the tick set, the
  // chart width, or the axis ref changes, and hide any that collide.
  // NOTE(review): useEffect runs after the browser paints, so overlapping
  // labels may flash for a frame; consider useLayoutEffect if that is visible
  // and the chart is never server-rendered.
  useEffect(() => {
    const xAxis = xAxisRef?.current;
    if (xAxis) hideOverlappingXTicks(xAxis);
  }, [xAxisRef, xTickValues, size.width]);

  return (
    <>
      {renderYLabel && (
        <GridRows
          scale={yScale}
          left={margin.left}
          width={size.width - margin.left - margin.right}
          tickValues={yTickValues}
          stroke="#E9E5F5"
          strokeDasharray="1,2.5"
          strokeLinecap="round"
        />
      )}

      {renderXLabel && (
        <AxisBottom
          innerRef={xAxisRef}
          top={size.height - margin.bottom}
          scale={indexScale}
          orientation="bottom"
          hideTicks
          hideAxisLine
          tickValues={xTickValues}
          tickFormat={xTickFormat}
          tickTransform="translate(0 -2)"
          tickLabelProps={{
            textAnchor: 'start',
            fontFamily: 'var(--font-family-base)',
            fill: 'var(--text-secondary)',
          }}
        />
      )}

      {renderYLabel && (
        <AxisLeft
          innerRef={yAxisRef}
          left={margin.left}
          scale={yScale}
          tickValues={yTickValues}
          hideTicks
          hideAxisLine
          tickFormat={yTickFormat}
          tickLabelProps={{
            fontFamily: 'var(--font-family-base)',
            fill: 'var(--text-secondary)',
          }}
        />
      )}
    </>
  );
}
