import * as d3 from 'd3'
import { useCallback, useEffect, useRef, useState } from 'react'
import { ChartTooltip, ChartTooltipItem } from './ChartTooltip'
import { cleanFloatingPointErrors, round } from './utils/utils'
import { AxisCalculator } from './utils/AxisCalculator'

/**
 * One bucket of histogram input.
 *
 * `value` is either a single bar height or a per-category breakdown drawn as a stacked
 * bar; `color` mirrors that shape (a single colour, or one colour per category key).
 */
export interface HistogramDataItem {
  /** Lower edge of the bucket; the bucket covers `[label, label + bucketSize)`. */
  label: number
  /** Bar height, or per-category heights for a stacked bar. */
  value: number | Record<string, number>
  /** Fill colour, or per-category fill colours keyed like `value`. */
  color: string | Record<string, string>
  /** Sum of `value` (across categories when stacked); filled in by `processData`. */
  totalValue?: number
  /** `[label, label + bucketSize]`; filled in by `processData`. */
  bucketRange?: [number, number]
}

/**
 * A single drawable bar segment produced by `standardizeData`.
 *
 * A numeric item maps to one segment. A stacked item maps to one segment per non-zero
 * category, where `valueBeneath` is the total of the segments stacked below it.
 */
export type StandardizedHistogramDataItem = {
  /** Bucket lower edge (the source item's `label`). */
  label: number
  /** Height of this segment. */
  value: number
  /** Fill colour of this segment. */
  color: string
  /** Sum of the segments below this one in a stacked bar; absent for plain bars. */
  valueBeneath?: number
  /** Category key this segment came from; absent for plain bars. */
  category?: string
  /** The item this segment was derived from; passed to the tooltip on hover. */
  originalDataItem?: HistogramDataItem
}

/**
 * Flattens histogram items into one drawable segment per bar (or per stacked layer).
 *
 * - A numeric `value` yields a single segment.
 * - A per-category `value` yields one segment per non-zero category, in `Object.entries`
 *   order; `valueBeneath` is the running total of the segments before it.
 * - A per-category `value` whose categories are all zero (or that has none) yields one
 *   zero-value placeholder segment. Without it, the bucket would get no baseline bar
 *   and no hover target, unlike empty buckets padded in by `processData`.
 *
 * @param data Items to flatten (normally the output of `processData`)
 * @returns The segments, each pointing back at its source item via `originalDataItem`
 */
const standardizeData = (data: HistogramDataItem[]): StandardizedHistogramDataItem[] =>
  data.flatMap((item): StandardizedHistogramDataItem[] => {
    const label = +item.label

    if (typeof item.value === 'number') {
      return [{ label, value: item.value, color: item.color as string, originalDataItem: item }]
    }

    const colors = item.color as { [key: string]: string }
    const segments: StandardizedHistogramDataItem[] = []
    let valueBeneath = 0
    for (const [category, value] of Object.entries(item.value)) {
      // Zero-height layers would only add invisible rects and duplicate hitboxes.
      if (value === 0) continue
      segments.push({
        label,
        category,
        value,
        valueBeneath,
        originalDataItem: item,
        color: colors[category],
      })
      valueBeneath += value
    }

    if (segments.length === 0) {
      // Same placeholder colour processData uses for buckets that have no item at all.
      return [{ label, value: 0, color: '#f0f0f0', originalDataItem: item }]
    }
    return segments
  })

/** Tolerance used when converting a span into a whole number of bucket steps. */
const STEP_EPSILON = 1e-9

/**
 * Normalises the raw items and fills gaps so the chart shows a contiguous run of buckets.
 *
 * Each bucket covers `[label, label + bucketSize)`. The result:
 * - coerces every `label` to a number and fills in `totalValue` (the sum of the categories
 *   for stacked items) and `bucketRange`;
 * - runs in steps of `bucketSize` on a grid anchored at the smallest label, widened to at
 *   least 8 steps: leftwards towards `xmax - 8 * bucketSize` (never below 0, in whole steps)
 *   and rightwards to `xmin + 8 * bucketSize`;
 * - inserts a placeholder (value 0, colour `#f0f0f0`) for every grid step without an item;
 * - keeps the first item when several share a bucket label.
 * Items whose label is not on that grid are not included.
 *
 * @param data Raw histogram items
 * @param bucketSize Bucket width; must be a positive, finite number
 * @returns Items ordered by bucket, or `[]` for empty input or an unusable `bucketSize`
 */
const processData = (data: HistogramDataItem[], bucketSize: number): HistogramDataItem[] => {
  // A zero, negative or NaN bucket size would never advance through the grid below.
  if (data.length === 0 || !(bucketSize > 0) || !Number.isFinite(bucketSize)) return []

  const preppedData = data.map((item): HistogramDataItem => {
    const label = +item.label
    const totalValue =
      typeof item.value === 'number'
        ? item.value
        : Object.values(item.value).reduce((acc, curr) => acc + curr, 0)
    return { ...item, label, totalValue, bucketRange: [label, label + bucketSize] }
  })

  // O(1) lookup keyed by the cleaned label, so float noise such as 0.30000000000000004
  // still matches the 0.3 grid step. The first item for a label wins.
  const itemsByBucket = new Map<number, HistogramDataItem>()
  for (const item of preppedData) {
    const key = cleanFloatingPointErrors(item.label)
    if (!itemsByBucket.has(key)) itemsByBucket.set(key, item)
  }

  const labels = preppedData.map((item) => item.label)
  const xmin = Math.min(...labels)
  const xmax = Math.max(...labels)
  const minSpan = bucketSize * 8

  // Extend left in whole steps from xmin (rather than jumping straight to the target),
  // so the grid stays aligned with the data even when labels are not multiples of bucketSize.
  const leftTarget = Math.max(0, xmax - minSpan)
  const leftSteps =
    leftTarget < xmin ? Math.floor((xmin - leftTarget) / bucketSize + STEP_EPSILON) : 0
  const newXMin = cleanFloatingPointErrors(xmin - leftSteps * bucketSize)
  const newXMax = cleanFloatingPointErrors(Math.max(xmax, xmin + minSpan))
  const bucketCount = Math.floor((newXMax - newXMin) / bucketSize + STEP_EPSILON) + 1

  const fullData: HistogramDataItem[] = []
  for (let i = 0; i < bucketCount; i++) {
    // Derive each step from the start instead of accumulating, so rounding error cannot drift.
    const bucket = cleanFloatingPointErrors(newXMin + i * bucketSize)
    fullData.push(
      itemsByBucket.get(bucket) ?? {
        label: bucket,
        value: 0,
        color: '#f0f0f0',
        totalValue: 0,
        bucketRange: [bucket, cleanFloatingPointErrors(bucket + bucketSize)],
      },
    )
  }
  return fullData
}

/**
 * Draws the histogram into `element` with d3, replacing any SVG previously rendered there.
 *
 * Gaps between buckets are filled by `processData`, and stacked values are split into
 * segments by `standardizeData`. X-axis ticks sit on bucket edges, so the x domain has
 * one extra entry for the right edge of the last bucket.
 *
 * @param element Container that receives the `<svg>`
 * @param data Histogram items to plot
 * @param bucketSize Width of each bucket, in label units
 * @param xTitle Optional x-axis title (omitted when empty)
 * @param yTitle Optional y-axis title (omitted when empty)
 * @param fullWidth Total SVG width in pixels; the height is fixed at 200px
 * @param updateTooltip Called with `show: true` when a bucket is hovered and `show: false` when it is left
 * @returns `null` when there is nothing to draw
 */
export const MakeHistogramSvg = (
  element: HTMLElement,
  data: HistogramDataItem[],
  bucketSize: number,
  xTitle = '',
  yTitle = '',
  fullWidth = 300,
  updateTooltip?: (args: { show: boolean; data: HistogramDataItem }) => void,
) => {
  // Dimensions and margins of the plot area
  const fullHeight = 200
  const margin = { top: 15, right: 40, bottom: 35, left: 50 }
  const width = fullWidth - margin.left - margin.right
  const height = fullHeight - margin.top - margin.bottom

  // Clear the previous render before the early return, so an empty update leaves no stale chart.
  d3.select(element).selectAll('svg').remove()

  const buckets = processData(data, bucketSize)
  if (buckets.length === 0) return null
  const yCalculator = new AxisCalculator([...buckets.map((d) => d.totalValue ?? 0), 0])

  const standardizedData = standardizeData(buckets)

  const svg = d3
    .select(element)
    .append('svg')
    .attr('version', '1.1')
    .attr('xmlns', 'http://www.w3.org/2000/svg')
    .attr('width', fullWidth)
    .attr('height', fullHeight)
    .attr('font-family', 'Inter')
    .attr('font-size', '8pt')
    .attr('viewBox', `0 0 ${fullWidth} ${fullHeight}`)
    .attr('color', '#8392a1')
    .append('g')
    .attr('transform', `translate(${margin.left},${margin.top})`)

  // Bands are keyed by the raw label string, so the domain and the bar/hitbox lookups always
  // agree. The localized text (which may be "0,5" in some locales, or collide after 2-decimal
  // rounding) is used for tick labels only.
  const bandKey = (label: number) => String(label)
  const formatTick = (key: string) =>
    (+key).toLocaleString(undefined, {
      maximumFractionDigits: 2,
      useGrouping: false,
    })

  const labels = buckets.map((d) => d.label)
  // One extra entry so the right edge of the last bucket gets its own tick.
  const xDomain = [...labels, Math.max(...labels) + bucketSize]
  const xAxisScale = d3
    .scaleBand()
    .range([0, width + width / (xDomain.length - 1)])
    .domain(xDomain.map(bandKey))
    .padding(0.2)
  const bandwidth = xAxisScale.bandwidth()
  const bandX = (label: number) => xAxisScale(bandKey(label)) ?? 0

  const yAxisScale = d3.scaleLinear().domain(yCalculator.domain).range([height, 0])

  // X axis. Shifting by -0.625 * bandwidth moves each tick from its band centre to the middle
  // of the gap before the band (half a band plus half of the 0.25-band gap that padding 0.2
  // produces), i.e. onto the bucket edge.
  const xTicks = svg
    .append('g')
    .attr('transform', `translate(${-bandwidth * 0.625}, ${height})`)
    .call(
      d3
        .axisBottom(xAxisScale)
        .tickSize(3)
        .tickValues(xAxisScale.domain())
        .tickFormat(formatTick),
    )
    .selectAll('g')

  // Thin out tick labels as bands narrow: every other one below 20px, every third below 12px.
  // (The original conditions left bandwidth === 12 unthinned.)
  const keepTick = (index: number) => {
    if (bandwidth < 12) return index % 3 === 0
    if (bandwidth < 20) return index % 2 === 1
    return true
  }
  xTicks.each((_, i, nodes) => {
    if (!keepTick(i)) d3.select(nodes[i]).remove()
  })
  // NOTE(review): `i` only ranges over 0..domain().length - 1, so this filter never matches and
  // the transform is never applied. `length - 1` (the right-edge tick) was presumably intended;
  // confirm visually before changing it.
  xTicks
    .filter((_, i) => i === xAxisScale.domain().length)
    .attr('transform', `translate(${width + bandwidth * 0.5}, 0)`)

  // Y axis: tick lines are stretched across the plot as grid lines; the first (bottom) one is dropped.
  svg
    .append('g')
    .call(
      d3
        .axisLeft(yAxisScale)
        .ticks(yCalculator.tickCount)
        .tickSize(0)
        .tickSizeOuter(0)
        .tickFormat((t) => yCalculator.roundValue(t.valueOf())),
    )
    .selectAll('.tick line')
    .attr('x2', width)
    .attr('x1', 1)
    .attr('stroke', '#e0e4e8')
    .attr('stroke-width', 1)
    .filter((_, i) => i === 0)
    .remove()
  svg.selectAll('.domain').attr('stroke', '#53687e').attr('stroke-width', 0)

  // Hover indicator; positioned under the hovered bucket by the mouseover handler.
  const selectionBox = svg
    .append('rect')
    .attr('id', 'selection-box')
    .attr('x', 0)
    .attr('y', 0)
    .attr('width', width)
    .attr('height', height)
    .attr('fill', 'grey')
    .attr('opacity', 0)
    .attr('rx', 1)
    .attr('ry', 1)
    .style('transition', 'all 0.1s ease-in-out')

  // Every segment's rect reaches down to the baseline, so painting in reverse order puts the
  // lower layers of a stack on top of the taller layers behind them.
  const paintOrder = standardizedData.toReversed()

  // Bars
  svg
    .selectAll('mybar')
    .data<StandardizedHistogramDataItem>(paintOrder)
    .join('rect')
    .attr('x', (d) => bandX(d.label))
    .attr('y', (d) => yAxisScale(d.value + (d.valueBeneath ?? 0)))
    // Empty buckets span a full step (band + gap), so their 1px baselines join up.
    .attr('width', (d) => (d.value === 0 ? bandwidth * 1.25 : bandwidth))
    .attr('height', (d) => Math.max(1, height - yAxisScale(d.value + (d.valueBeneath ?? 0))))
    .attr('fill', (d) => (d.value === 0 ? '#e0e4e8' : d.color))
    .attr('rx', 4)
    .attr('ry', 4)

  // Invisible full-height hover targets, one step wide and centred on each band.
  svg
    .selectAll('hitbox')
    .data<StandardizedHistogramDataItem>(paintOrder)
    .join('rect')
    .attr('x', (d) => bandX(d.label) - bandwidth * 0.125)
    .attr('y', 0)
    .attr('width', bandwidth * 1.25)
    .attr('height', height)
    .attr('fill', 'transparent')
    .on('mouseover', (_e, d) => {
      if (updateTooltip && d.originalDataItem) {
        updateTooltip({ show: true, data: d.originalDataItem })
        selectionBox
          .attr('x', bandX(d.label) + 4)
          // Clamp: bands narrower than 8px would otherwise yield a negative (invalid) SVG width.
          .attr('width', Math.max(0, bandwidth - 8))
          .attr('opacity', 0.5)
          .attr('fill', d.color)
          .attr('height', 2)
          .attr('y', height + 2)
      }
    })
    .on('mouseout', (_e, d) => {
      if (updateTooltip && d.originalDataItem) {
        updateTooltip({ show: false, data: d.originalDataItem })
        selectionBox.attr('opacity', 0)
      }
    })

  // X axis title
  if (xTitle)
    svg
      .append('text')
      .attr('fill', '#8392a1')
      .attr('text-anchor', 'middle')
      .attr('x', width / 2)
      .attr('y', height + margin.top + margin.bottom - 20)
      .text(xTitle)

  // Y axis title
  if (yTitle)
    svg
      .append('text')
      .attr('fill', '#8392a1')
      .attr('text-anchor', 'middle')
      .attr('transform', 'rotate(-90)')
      .attr('y', -margin.left + 10)
      .attr('x', -height / 2)
      .text(yTitle)
}

/** Props accepted by the `Histogram` component. */
interface BarchartProps {
  /** Buckets to plot; an empty array renders nothing. */
  data: HistogramDataItem[]
  /** Width of each bucket, in the same units as `label`. */
  bucketSize: number
  /** Optional x-axis title. */
  xTitle?: string
  /** Optional y-axis title. */
  yTitle?: string
  /** Builds the tooltip rows for a hovered bucket; by default one row per value. */
  tooltipFormatter?: (item: HistogramDataItem) => ChartTooltipItem[]
  /**
   * Builds the tooltip header from the bucket bounds (rounded to 2 decimal places);
   * by default the header reads "<start> to <end>".
   */
  tooltipHeaderFormatter?: (range: { bucketRangeStart: number; bucketRangeEnd: number }) => string
}

/**
 * Prepares a hovered item for display in the tooltip.
 *
 * A numeric `value` is rounded to 1 decimal place. A per-category `value` is rounded to
 * 3 decimal places, and zero-valued categories are dropped from both `value` and `color`.
 * NOTE(review): the two branches use different precisions — confirm this is intentional.
 *
 * @param data The item that was hovered
 * @returns A shallow copy with the rounded/filtered value (and colours)
 */
const cleanDataItem = (data: HistogramDataItem): HistogramDataItem => {
  if (typeof data.value === 'number') {
    return { ...data, value: Math.round(data.value * 10) / 10 }
  }

  const values = data.value as { [key: string]: number }
  const colors = data.color as { [key: string]: string }
  const cleanedValue: { [key: string]: number } = {}
  const cleanedColor: { [key: string]: string } = {}
  for (const key of Object.keys(values)) {
    // Empty categories are left out so the tooltip doesn't list them.
    if (+values[key] === 0) continue
    cleanedValue[key] = Math.round(values[key] * 1000) / 1000
    cleanedColor[key] = colors[key]
  }
  return { ...data, value: cleanedValue, color: cleanedColor }
}

/**
 * Responsive histogram with a hover tooltip.
 *
 * The container width is tracked with a `ResizeObserver` (rounded to the nearest 50px), and
 * the chart is redrawn whenever the data, titles, bucket size or width change.
 * Renders nothing when `data` is empty.
 */
export const Histogram = ({
  data,
  bucketSize,
  xTitle,
  yTitle,
  tooltipFormatter,
  tooltipHeaderFormatter,
}: BarchartProps) => {
  const ref = useRef<HTMLDivElement>(null)
  const [width, setWidth] = useState<number | null>(null)
  const [showTooltip, setShowTooltip] = useState(false)
  const [tooltipHeader, setTooltipHeader] = useState('')
  const [tooltipItems, setTooltipItems] = useState<ChartTooltipItem[]>([])
  const isEmpty = data.length === 0

  const updateTooltip = useCallback(
    ({ show, data: hoveredItem }: { show: boolean; data: HistogramDataItem }) => {
      const item = cleanDataItem(hoveredItem)
      setShowTooltip(show)

      const [rangeStart, rangeEnd] = item.bucketRange ?? []
      const formatBound = (bound: number | undefined, locale?: string) =>
        bound?.toLocaleString(locale, {
          maximumFractionDigits: 2,
          useGrouping: false,
        })
      const bucketRangeStart = formatBound(rangeStart)
      const bucketRangeEnd = formatBound(rangeEnd)

      if (tooltipHeaderFormatter) {
        // Parse an en-US rendering: the user's locale may use ',' as the decimal separator,
        // which `+` cannot parse (e.g. +"0,5" is NaN).
        setTooltipHeader(
          tooltipHeaderFormatter({
            bucketRangeStart: +(formatBound(rangeStart, 'en-US') ?? 0),
            bucketRangeEnd: +(formatBound(rangeEnd, 'en-US') ?? 0),
          }),
        )
      } else {
        setTooltipHeader(`${bucketRangeStart} to ${bucketRangeEnd}`)
      }

      if (tooltipFormatter) {
        setTooltipItems(tooltipFormatter(item))
      } else if (typeof item.value === 'number') {
        setTooltipItems([
          {
            text: round(item.value).toLocaleString(),
            color: item.color as string,
          },
        ])
      } else {
        // Listed in reverse key order.
        const values = item.value
        const colors = item.color as { [key: string]: string }
        setTooltipItems(
          Object.keys(values)
            .toReversed()
            .map((key) => ({
              text: round(values[key]).toLocaleString(),
              color: colors[key],
            })),
        )
      }
    },
    [tooltipFormatter, tooltipHeaderFormatter],
  )

  useEffect(() => {
    // The container div is only rendered when there is data, so re-attach whenever that flips.
    // Otherwise a chart that mounts empty never gets a width and never draws.
    const container = ref.current
    if (!container) return
    const resizeObserver = new ResizeObserver((entries) => {
      const entry = entries[0]
      if (!entry) return
      // Fall back to contentRect where contentBoxSize is missing or not an array.
      const inlineSize = entry.contentBoxSize?.[0]?.inlineSize ?? entry.contentRect.width
      // Round width to nearest 50 so small resizes leave the state (and the chart) unchanged.
      setWidth(Math.round(inlineSize / 50) * 50)
    })

    resizeObserver.observe(container)
    // Stop observing on unmount / re-run so observers don't leak.
    return () => resizeObserver.disconnect()
  }, [isEmpty])

  useEffect(() => {
    if (!ref.current || width == null) return
    MakeHistogramSvg(ref.current, data, +bucketSize, xTitle, yTitle, width, updateTooltip)
  }, [data, xTitle, yTitle, width, updateTooltip, bucketSize])

  if (isEmpty) return null
  return (
    <ChartTooltip header={tooltipHeader} show={showTooltip} items={tooltipItems}>
      <div ref={ref} css={{ minHeight: '200px' }}></div>
    </ChartTooltip>
  )
}
