import { ChartData, ChartDataset, Point } from "chart.js"
import _ from "lodash"
import { DFIBubbleDataPoint } from "./charts/scatterplot"
import * as colors from "./colors"
import { updateOpacity } from "./colors"
import { DataContextType } from "./components/chart-data-context"
import { DFIChartContextType } from "./components/context"
import { AxisConfig, ChartType } from "./config"

import {
    calculateChange,
    calculateCumulative,
    calculateGrowth,
    colorPalette,
    linearScale,
    tileIndicatorsCode,
} from "./utils"

/**
 * Builds bar-chart data (also used for Marimekko charts and dot plots): one dataset per
 * indicator, one bar per row in `indices`, labelled by that row's dimension display label
 * (or a single empty label when there are no dimension values).
 *
 * - Values are multiplied by `yAxis.multiplier` (default 1).
 * - With `yAxis.relative === "relative"`, each value becomes its percentage (rounded to
 *   2 decimals) of the row total taken over ALL `yAxis.codes`, not just the selected ones.
 * - Bar colour is light grey for dot plots, the palette's first colour when tiling by
 *   indicator, and the indicator's legend colour otherwise. When any indicator is
 *   highlighted, the non-highlighted ones are desaturated.
 * - Dot plots additionally get one scatter dataset per indicator prepended.
 *
 * @returns Chart.js bar data plus the given `title`.
 */
function formatBarData(
    chartContext: DFIChartContextType,
    dctx: DataContextType,
    indicators: string[],
    indices: number[],
    title: string,
): ChartData<"bar", number[]> & { title: string } {
    const { form, table } = chartContext
    const { data } = table
    const { chart } = form.values
    const { yAxis, type } = chart
    const palette = colorPalette[chart.palette]
    const singleIndicator = chart.tiles.codes.includes(tileIndicatorsCode)
    const isDotPlot = type === ChartType.Dotplot

    // Bar geometry shared by every dataset: Marimekko bars fill their whole category,
    // dot plots draw thin 3px stems, everything else uses the standard 90% widths.
    const barDatasetConfig: Partial<ChartDataset<"bar">> =
        type === ChartType.Marimekko
            ? { barPercentage: 1, categoryPercentage: 1, borderWidth: 1, barThickness: "flex" }
            : isDotPlot
              ? { barPercentage: 0.9, categoryPercentage: 0.9, borderWidth: 0, barThickness: 3 }
              : { barPercentage: 0.9, categoryPercentage: 0.9, borderWidth: 0 }

    const multiplier = yAxis.multiplier || 1
    const scaleRow = (code: string, row: number) => data[code][row] * multiplier

    // Denominator for relative mode: per-row total over ALL indicators, selected or not.
    const rowTotals: number[] | false =
        yAxis.relative === "relative" &&
        indices.map((row) => _.sum(yAxis.codes.map((code) => scaleRow(code, row))))

    const isHighlighted = Object.values(yAxis.info).some((info) => info.highlighted)

    const datasets = indicators.map((ind: string) => {
        const scaled = _.map(indices, (row) => scaleRow(ind, row))
        const indicatorData = rowTotals
            ? scaled.map((val, pos) => _.round((val / rowTotals[pos]) * 100, 2))
            : scaled

        let baseColor
        if (isDotPlot) baseColor = "#eee"
        else if (singleIndicator) baseColor = updateOpacity(palette[0])
        else baseColor = dctx.legendData[ind]

        // Dim this indicator only when something else is highlighted and it is not.
        const dimmed = isHighlighted && !yAxis.info[ind]?.highlighted

        return {
            ...barDatasetConfig,
            label: ind,
            data: indicatorData,
            backgroundColor: dimmed ? colors.desaturate(baseColor) : baseColor,
            hoverBackgroundColor: baseColor,
            color: baseColor,
        }
    })

    const dimensionValues = dctx.dimensions.values
    const labels = dimensionValues.length
        ? _.map(indices, (row) => dctx.getDisplayLabel("dimensions", dimensionValues[row]))
        : [""]

    if (isDotPlot) {
        addDotPlotDatasets(dctx, datasets, labels)
    }

    return { title, labels, datasets }
}

/**
 * Turns bar datasets into a dot plot, mutating `datasets` in place.
 *
 * Every existing (bar) dataset gets `order = 2`. For each one, a scatter dataset with
 * `order = 1` is built in its legend colour (looked up by the dataset's `label`). Each
 * bar value `v` at position `k` becomes the point `{ x: v, y: labels[k] }`. The scatter
 * datasets are then inserted at the front of `datasets`, keeping the original order.
 *
 * @param dctx - provides `legendData`, the label-to-colour lookup.
 * @param datasets - bar datasets; each needs `label` and a `data` array of numbers.
 * @param labels - category labels, aligned by position with each dataset's `data`.
 */
function addDotPlotDatasets(dctx: DataContextType, datasets: any, labels: string[]) {
    const scatterLayers = datasets.map((barDataset: any) => {
        barDataset.order = 2
        const dotColor = dctx.legendData[barDataset.label]
        const points = barDataset.data.map((value: number, position: number) => ({
            x: value,
            y: labels[position],
        }))
        return {
            type: "scatter",
            order: 1,
            label: barDataset.label,
            backgroundColor: dotColor,
            hoverBackgroundColor: dotColor,
            color: dotColor,
            borderWidth: 0,
            pointRadius: 7,
            hoverRadius: 7,
            data: points,
        }
    })
    // Scatter layers go in front of the bar datasets.
    datasets.unshift(...scatterLayers)
}

/**
 * Builds line-chart datasets: one line per (indicator, dimension value) pair.
 *
 * For each indicator, the rows in `indices` are grouped by dimension value (`""` when a row
 * has none). Each group becomes one line whose points are
 * `{ x: time value, y: indicator value * yAxis.multiplier }`. Rows with a missing (null or
 * undefined) value are skipped. The points are then sorted by x and passed through the
 * y-axis transform.
 *
 * Lines are coloured by dimension value when only one indicator is shown per chart (a single
 * y-axis code, single-select, or tiling by indicator), and by indicator otherwise. If any
 * entry of that colouring axis is highlighted, the non-highlighted lines are desaturated.
 * Segments starting at or after `chart.projectionStart` are dashed.
 *
 * @returns Chart.js line data plus the given `title`.
 */
function formatLineData(
    chartContext: DFIChartContextType,
    dctx: DataContextType,
    indicators: string[],
    indices: number[],
    title: string,
): ChartData<"line", Point[]> & { title: string } {
    const { form, table } = chartContext
    const { data } = table
    const { chart } = form.values
    const { yAxis, dimensions } = chart

    const colorByDimensions =
        yAxis.codes.length === 1 ||
        yAxis.singleSelect ||
        chart.tiles.codes.includes(tileIndicatorsCode)

    const colorInfo = colorByDimensions ? dimensions.info : yAxis.info
    const isHighlighted = Object.values(colorInfo).some((info) => info.highlighted)

    const multiplier = yAxis.multiplier || 1
    // Fallback for dimension values without a legend colour. Use the chart's selected palette,
    // as the bar and scatter formatters do, instead of always the default one.
    const fallbackPalette = colorPalette[chart.palette] ?? colorPalette.default

    // indicator -> dimension value -> points
    const datasets: { [ind: string]: { [label: string]: Point[] } } = {}
    for (const ind of indicators) {
        const indicatorCol = data[ind]
        const byDimension: { [label: string]: Point[] } = {}
        datasets[ind] = byDimension

        for (const i of indices) {
            const dimValue = dctx.dimensions.values[i] ?? ""
            // Create the group before the missing-value check, so a dimension value whose
            // points are all missing still yields an (empty) line.
            const points = byDimension[dimValue] ?? (byDimension[dimValue] = [])
            const value = indicatorCol[i]
            // `!= null` skips undefined as well as null; undefined * multiplier would be NaN.
            if (value != null) {
                points.push({ x: dctx.time.values[i] ?? null, y: value * multiplier })
            }
        }
    }

    const transform = yAxis.transform ?? "absolute"
    const stacked = yAxis.grouped === "stack"

    return {
        title,
        datasets: _.flatMap(indicators, (ind) => {
            const indLabel = dctx.getDisplayLabel("yAxis", ind)
            return Object.entries(datasets[ind]).map(([dimValue, values]) => {
                const color = colorByDimensions
                    ? dctx.legendData[dimValue] ?? updateOpacity(fallbackPalette[0])
                    : dctx.legendData[ind]

                // When tiled by all dimensions, each line represents an indicator.
                const tileByDimensions = dimValue === ind
                const dimLabel = tileByDimensions
                    ? dctx.getDisplayLabel("yAxis", ind)
                    : dctx.getDisplayLabel("dimensions", dimValue)

                const lineLabel =
                    dimLabel && indicators.length > 1
                        ? `${dimLabel} - ${indLabel}`
                        : dimLabel || indLabel

                // Stacked (filled) lines use a fully opaque colour.
                const lineColor = stacked
                    ? colors.updateOpacity(color, 1)
                    : colors.updateOpacity(color)
                const colorVal = colorByDimensions ? dimValue : ind
                const highlightLine = !isHighlighted || colorInfo[colorVal]?.highlighted

                return {
                    label: lineLabel,
                    data: transformData(values, transform),
                    backgroundColor: highlightLine ? lineColor : colors.desaturate(lineColor),
                    color: lineColor,
                    fill: stacked,
                    borderColor: updateOpacity(highlightLine ? color : colors.desaturate(color)),
                    pointRadius: 0,
                    pointHoverRadius: 2.5,
                    pointHitRadius: 2.5,
                    borderWidth: 1.75,
                    segment: {
                        // Dash segments whose starting point lies at or after the projection
                        // start. `chart.projectionStart` is read when the callback runs, not
                        // captured when the data is formatted.
                        borderDash: (ctx: any) => {
                            if (chart.projectionStart && ctx.p0.raw.x >= chart.projectionStart)
                                return [7, 4]
                            return undefined
                        },
                    },
                }
            })
        }),
    }
}

/**
 * Builds bubble datasets for a scatter plot.
 *
 * - y values come from `indicators[0]`; x values come from the first selected x-axis code.
 * - Bubble radius comes from the first selected size code, if any.
 * - Values are multiplied by the respective axis multipliers (default 1).
 * - Rows with a null x or y value are skipped.
 * - The radius is `linearScale` from sqrt(size min)..sqrt(size max) onto 3..20, applied to
 *   sqrt(size). A missing size value gives `Math.sqrt(null) === 0`.
 *
 * Rows whose dimension value is in `dimensions.selected` form the first dataset (order 1).
 * Those points have radius + 2 and are coloured from the legend via the row's colour
 * value, falling back to the palette's first colour. All other rows form a second, grey
 * dataset (order 2).
 *
 * @returns Chart.js bubble data plus the given `title`, with no datasets when the x or
 *   y code is missing.
 */
function formatScatterData(
    chartContext: DFIChartContextType,
    dctx: DataContextType,
    indicators: string[],
    indices: number[],
    title: string,
): ChartData<"bubble", DFIBubbleDataPoint[]> & { title: string } {
    const { form, table } = chartContext
    const { data } = table
    const { chart } = form.values
    const { xAxis, yAxis, size, dimensions } = chart
    const yAxisCode = indicators[0]
    const xAxisCode = xAxis?.selected[0]
    const sizeCode = size?.selected[0]
    const palette = colorPalette[chart.palette]

    if (!xAxisCode || !yAxisCode) {
        return { title, datasets: [] }
    }

    // Bind the narrowed code to a fresh const so the helper below sees the non-undefined
    // type without a `!` assertion.
    const xCode = xAxisCode
    const xMultiplier = xAxis?.multiplier || 1
    const yMultiplier = yAxis?.multiplier || 1

    // Linear scale from 3 to 20 on the square root of the size value.
    const minSize = Math.sqrt(dctx.size.min ?? 0)
    const maxSize = Math.sqrt(dctx.size.max ?? 0)
    const sizeScale = linearScale(minSize, maxSize, 3, 20)

    // Split rows by whether their dimension value is selected.
    const selectedDimensions = new Set(dimensions.selected)
    const [selectedRows, otherRows] = _.partition(indices, (i) =>
        selectedDimensions.has(dctx.dimensions.values[i]),
    )

    // Converts rows to bubble points, skipping rows with a null x or y value.
    const toDataPoints = (rows: number[], emphasised: boolean) => {
        const points = []
        for (const row of rows) {
            const xVal = data[xCode][row]
            const yVal = data[yAxisCode][row]
            if (xVal === null || yVal === null) continue

            const sizeVal = sizeCode ? data[sizeCode][row] : null
            const radius = sizeScale(Math.sqrt(sizeVal)) // Math.sqrt(null) === 0
            points.push({
                x: xVal * xMultiplier,
                y: yVal * yMultiplier,
                r: emphasised ? radius + 2 : radius,
                i: row,
            })
        }
        return points
    }

    return {
        title,
        datasets: [
            {
                order: 1,
                label: yAxisCode,
                data: toDataPoints(selectedRows, true),
                hoverRadius: 0,
                backgroundColor: (context) => {
                    const point = context.raw as DFIBubbleDataPoint
                    const colorKey = dctx.colors.values[point?.i]
                    return dctx.legendData[colorKey] ?? updateOpacity(palette[0])
                },
            },
            {
                order: 2,
                label: yAxisCode,
                data: toDataPoints(otherRows, false),
                hoverRadius: 0,
                backgroundColor: "#eee",
            },
        ],
    }
}

/**
 * Returns the points sorted by ascending `x`, with the requested y-axis transform applied.
 *
 * `_.orderBy` produces a new array; the matching `calculate*` helper is then invoked on that
 * sorted array and its return value is ignored. Presumably the helpers update the points in
 * place — confirm in `./utils`. Any other transform value leaves the sorted points untouched.
 */
function transformData(data: Point[], type: AxisConfig["transform"]) {
    const points = _.orderBy(data, "x")
    switch (type) {
        case "change":
            calculateChange(points)
            break
        case "growth":
            calculateGrowth(points)
            break
        case "cumulative":
            calculateCumulative(points)
            break
    }
    return points
}

/**
 * Builds the choropleth payload: a lookup from each row's region code to its indicator value.
 *
 * Only the first indicator and the first choropleth code are used.
 *
 * @param chartContext - supplies the raw data columns and the chart form values.
 * @param dctx - unused here (all formatters share the same parameter list).
 * @param indicators - indicator codes; only `indicators[0]` is mapped.
 * @param indices - row indices to include.
 * @param title - passed through unchanged.
 * @returns `{ title, values }`. `values[regionCode]` is the indicator value times
 *   `yAxis.multiplier` (default 1); when several rows share a region code, the last one
 *   wins. `values` is empty when there is no indicator or no choropleth code.
 */
function formatMapData(
    chartContext: DFIChartContextType,
    dctx: DataContextType,
    indicators: string[],
    indices: number[],
    title: string,
) {
    const { form, table } = chartContext
    const { data } = table
    const { chart } = form.values
    const mapCode = chart.choropleth?.codes?.[0]

    // Both inputs are required: without an indicator there is no value column to read
    // (data[undefined][i] would throw), and without a map code there is no key column.
    if (!indicators.length || !mapCode) return { title, values: {} }

    const indicator = indicators[0]
    const multiplier = chart.yAxis.multiplier || 1
    const values: any = {}
    for (const i of indices) {
        values[data[mapCode][i]] = data[indicator][i] * multiplier
    }

    return { title, values }
}

/**
 * Formats one chart's data by dispatching on `chart.type`.
 *
 * Line, scatter and map charts have dedicated formatters. Every other type (bar, Marimekko,
 * dot plot, ...) goes through `formatBarData`.
 */
function formatChartData(
    chartContext: DFIChartContextType,
    dctx: DataContextType,
    indicators: string[],
    indices: number[],
    title: string,
) {
    const chartType = chartContext.form.values.chart.type

    if (chartType === ChartType.LineChart) {
        return formatLineData(chartContext, dctx, indicators, indices, title)
    }
    if (chartType === ChartType.ScatterPlot) {
        return formatScatterData(chartContext, dctx, indicators, indices, title)
    }
    if (chartType === ChartType.Map) {
        return formatMapData(chartContext, dctx, indicators, indices, title)
    }
    return formatBarData(chartContext, dctx, indicators, indices, title)
}

/**
 * Builds the list of chart payloads to render: one per tile, or a single untiled chart.
 *
 * The tiling mode is driven by `tiles.codes`:
 * - no tile codes: one chart with all selected indicators over all filtered rows, titled "";
 * - tile codes without the indicator tile code: one chart per selected tile value,
 *   each with all selected indicators;
 * - only the indicator tile code: one chart per selected indicator over all filtered rows;
 * - the indicator tile code plus others: one chart per (tile value, indicator) pair,
 *   ordered tile-major and titled "<tile label> - <indicator label>".
 */
export function generateDatasets(chartContext: DFIChartContextType, dctx: DataContextType) {
    const { tiles, yAxis } = chartContext.form.values.chart

    if (!tiles.codes.length) {
        return [formatChartData(chartContext, dctx, yAxis.selected, dctx.filteredIndices, "")]
    }

    // Filtered rows whose tile value equals `tile`.
    const rowsInTile = (tile: string) =>
        dctx.filteredIndices.filter((i) => dctx.tiles.values[i] === tile)

    if (!tiles.codes.includes(tileIndicatorsCode)) {
        // Tile by tile dimensions only.
        return tiles.selected.map((tile: string) => {
            const indices = rowsInTile(tile)
            const tileLabel = dctx.getDisplayLabel("tiles", tile)
            return formatChartData(chartContext, dctx, yAxis.selected, indices, tileLabel)
        })
    }

    if (tiles.codes.length === 1) {
        // Tile by indicators only.
        return yAxis.selected.map((ind: string) => {
            const indLabel = dctx.getDisplayLabel("yAxis", ind)
            return formatChartData(chartContext, dctx, [ind], dctx.filteredIndices, indLabel)
        })
    }

    // Tile by indicators AND tile dimensions: one chart per (tile, indicator) pair.
    // Typed `any[]` so the function's inferred return type is unchanged.
    const datasets: any[] = []
    for (const tile of tiles.selected) {
        const indices = rowsInTile(tile)
        // Same for every indicator in this tile, so compute it once.
        const tileLabel = dctx.getDisplayLabel("tiles", tile)
        for (const ind of yAxis.selected) {
            const indLabel = dctx.getDisplayLabel("yAxis", ind)
            datasets.push(
                formatChartData(chartContext, dctx, [ind], indices, `${tileLabel} - ${indLabel}`),
            )
        }
    }
    return datasets
}
