import { notifications } from "@mantine/notifications"
import _ from "lodash"
import { PropsWithChildren, createContext, useContext, useEffect, useMemo, useState } from "react"
import { dt, useShallowDeps, useShallowMemo } from "../../../dfi-utils"
import { interpolateColor, updateOpacity } from "../colors"
import { ChartType } from "../config"
import { colorPalette, tileIndicatorsCode } from "../utils"
import { useChart } from "./context"

/**
 * Requested reassignment of table columns to the time / tiles / dimensions roles,
 * as accepted by `DataContextType.updateColumns`. Omitted roles are derived from
 * the current chart configuration.
 */
export interface UpdateColumnsType {
    /** Column codes for the time role; the current time codes are kept when omitted. */
    time?: string[]
    /** Column codes for the tiles role; the current tile codes are kept when omitted. */
    tiles?: string[]
    /** Column codes for the dimensions role; defaults to the leftover columns when omitted. */
    dimensions?: string[]
}

/**
 * A categorical column (or set of joined columns) aligned with the table rows.
 */
export interface DataContextDimension<T extends string | number> {
    /** One value per table row, indexed by the row index. */
    values: T[]
    /** Distinct value -> display label. */
    value_labels: Map<T, string>
    /** Number of distinct values. */
    nunique: number
}

/** Numeric range of a column; both bounds are undefined when there is no data. */
export interface DataContextIndicator {
    min?: number
    max?: number
}

/** Ranges of the numeric (indicator) roles, computed over the filtered rows. */
export interface DataContextIndicatorColumns {
    xAxis: DataContextIndicator
    yAxis: DataContextIndicator
    size: DataContextIndicator
}

/** The categorical roles of the chart. */
export interface DataContextDimensionColumns {
    tiles: DataContextDimension<string>
    dimensions: DataContextDimension<string>
    colors: DataContextDimension<string>
}

/** Every column role, including time, which is both categorical and ranged. */
export interface DataContextColumns
    extends DataContextIndicatorColumns,
        DataContextDimensionColumns {
    time: DataContextDimension<number> & DataContextIndicator
}

/** Value exposed by `DataContextProvider` and returned by `useDataContext()`. */
export interface DataContextType extends DataContextColumns {
    /** Table row indices ordered according to the chart's sort settings. */
    sortedIndices: number[]
    /** Subset of `sortedIndices` that passes the time / tile / dimension filters. */
    filteredIndices: number[]
    /** Legend entry key -> colour. */
    legendData: Record<string, string>

    /** Reassigns table columns to roles; omitted roles are derived from the current chart. */
    updateColumns: (update: UpdateColumnsType) => void
    /** Explicit axis limit from the chart config, or the data limit when the axis is aligned. */
    getAxisLimit: (
        axisName: keyof DataContextIndicatorColumns,
        type: "min" | "max",
    ) => number | undefined
    /** Human-readable label for `value` in the given column role. */
    getDisplayLabel: (colname: keyof DataContextColumns, value: string) => string
}

/** Carries the provider's `DataContextType`; holds `null` when no provider is mounted above. */
const DataContext = createContext<DataContextType | null>(null)

/**
 * Reads the nearest `DataContextProvider` value.
 *
 * @returns the current data context.
 * @throws Error when used outside of a `<DataContextProvider>`.
 */
export function useDataContext(): DataContextType {
    const context = useContext(DataContext)
    if (context == null) throw new Error("Wrap component in <DataContextProvider>")
    return context
}

/**
 * Derives the chart's data state (see `DataContextType`) from `table` and the
 * chart form values returned by `useChart()`, and exposes it through
 * `DataContext` to all descendants.
 *
 * Derived pieces (row orderings, joined dimension columns, filtered rows, axis
 * ranges) are memoised on shallow-compared dependency lists. The context object
 * itself is rebuilt on every render.
 */
export const DataContextProvider: React.FC<PropsWithChildren> = ({ children }) => {
    // When true, the effect at the bottom fills in default tile/dimension
    // selections (the first 4 values) if nothing is selected yet.
    const [defaults, setDefaults] = useState(true)

    const { table, form } = useChart()
    const { chart } = form.values

    // Filled in field by field below; every field is assigned before rendering.
    const dctx = {} as DataContextType

    const sortDeps = useShallowDeps([chart.sort, chart.dimensions.customSort])
    const timeDeps = useShallowDeps([chart.time.codes])
    const tilesDeps = useShallowDeps([chart.tiles.codes])
    const dimensionsDeps = useShallowDeps([chart.dimensions.codes])
    const colorsDeps = useShallowDeps([chart.colors?.codes])

    // Row indices ordered by the display labels of the `chart.sort.codes`
    // columns (left unsorted when `chart.sort.codes` is not set).
    dctx.sortedIndices = useMemo(() => {
        const indices = _.range(table.num_rows)
        if (!chart.sort.codes) return indices

        const sort_func = []
        for (const code of chart.sort.codes) {
            const values = table.data[code]
            sort_func.push((i: number) => dt.getValueLabel(table, code, values[i]))
        }
        return _.orderBy(indices, sort_func, chart.sort.order)
    }, sortDeps)

    dctx.tiles = useShallowMemo(
        () => dt.joinColumns(table, chart.tiles.codes, dctx.sortedIndices),
        [tilesDeps, sortDeps],
    )

    dctx.dimensions = useShallowMemo(
        () => dt.joinColumns(table, chart.dimensions.codes, dctx.sortedIndices),
        [dimensionsDeps, sortDeps],
    )

    dctx.colors = useShallowMemo(
        () => dt.joinColumns(table, chart.colors?.codes ?? [], dctx.sortedIndices),
        [colorsDeps, sortDeps],
    )

    // Time role: the first time column parsed to integers per row, with its
    // range and the original string of each distinct value as its label.
    dctx.time = useMemo(() => {
        const col = chart.time.codes.length ? table.data[chart.time.codes[0]] : []
        const values = col.map((v) => parseInt(v, 10))
        const value_labels = new Map(values.map((v, i) => [v, col[i]]))
        return {
            values,
            value_labels,
            min: _.min(values),
            max: _.max(values),
            nunique: _.size(value_labels),
        }
    }, timeDeps)

    const timeSelectedDeps = useShallowDeps([chart.time.min_time, chart.time.max_time])
    const tilesSelectedDeps = useShallowDeps([chart.tiles.selected])
    const dimensionsSelectedDeps = useShallowDeps([chart.dimensions.selected])

    const yAxisSelectedDeps = useShallowDeps([chart.yAxis?.selected])
    const xAxisSelectedDeps = useShallowDeps([chart.xAxis?.selected])
    const sizeSelectedDeps = useShallowDeps([chart.size?.selected])

    const filterDeps = useShallowDeps([
        timeDeps,
        tilesDeps,
        dimensionsDeps,
        timeSelectedDeps,
        tilesSelectedDeps,
        dimensionsSelectedDeps,
        sortDeps,
        chart.type,
    ])

    // Sorted rows that pass every active filter: the selected time range (or
    // only the latest selected time for non-line charts), the selected tiles,
    // and the selected dimensions (not applied to scatter plots).
    dctx.filteredIndices = useMemo(() => {
        const filters: ((i: number) => boolean)[] = []
        if (dctx.time.values.length) {
            const time_max = chart.time.max_time ?? dctx.time.max!
            const time_min = chart.time.min_time ?? dctx.time.min!

            if (chart.type === ChartType.LineChart) {
                filters.push((i: number) => {
                    const value = dctx.time.values[i]
                    return value >= time_min && value <= time_max
                })
            } else {
                // Non-line charts should only show the last time value
                filters.push((i: number) => dctx.time.values[i] === time_max)
            }
        }

        if (dctx.tiles.values.length) {
            const tilesSelected = new Set(chart.tiles.selected)
            filters.push((i: number) => tilesSelected.has(dctx.tiles.values[i]))
        }

        if (dctx.dimensions.values.length && chart.type !== ChartType.ScatterPlot) {
            const dimSelected = new Set(chart.dimensions.selected)
            filters.push((i: number) => dimSelected.has(dctx.dimensions.values[i]))
        }

        const indices = dctx.sortedIndices.filter((i: number) => filters.every((f) => f(i)))
        if (chart.dimensions.customSort) {
            return customSortDimensions(indices)
        }
        return indices
    }, filterDeps)

    // Reorders `indices` so their dimension values follow `customSort`, which
    // holds positions into `chart.dimensions.selected`. Values not listed get an
    // undefined key, which `_.sortBy` places after the listed ones.
    function customSortDimensions(indices: number[]) {
        const sortOrder = chart.dimensions.customSort ?? []
        const selectedDimensionsMap = new Map()

        for (let i = 0; i < sortOrder.length; i++) {
            selectedDimensionsMap.set(chart.dimensions.selected[sortOrder[i]], i)
        }

        const sortedIndices = _.sortBy(indices, (i) => {
            const value = dctx.dimensions.values[i]
            return selectedDimensionsMap.get(value)
        })
        return sortedIndices
    }

    // Set colors.selected to the unique color values among the filtered rows.
    // NOTE(review): this writes into the form values object directly during
    // render instead of going through form.setFieldValue — presumably to avoid
    // a form update on every render; confirm this is intended.
    if (dctx.colors.values.length) {
        chart.colors!.selected = _.uniq(_.at(dctx.colors.values, dctx.filteredIndices))
    }

    // Min/max over the filtered rows of all the given numeric columns. Cells
    // that convert to NaN are skipped: a single NaN would otherwise make every
    // later Math.min/Math.max return NaN. Both bounds stay undefined when no
    // usable value exists.
    function getMinMax(codes: string[]) {
        let min: number | undefined = undefined
        let max: number | undefined = undefined
        for (const code of codes) {
            const column = table.data[code]
            for (const i of dctx.filteredIndices) {
                const value = Number(column[i])
                if (Number.isNaN(value)) continue
                min = min === undefined ? value : Math.min(min, value)
                max = max === undefined ? value : Math.max(max, value)
            }
        }
        return { min, max }
    }

    dctx.xAxis = useShallowMemo(
        () => getMinMax(chart.xAxis?.selected ?? []),
        [xAxisSelectedDeps, filterDeps],
    )
    dctx.yAxis = useShallowMemo(
        () => getMinMax(chart.yAxis?.selected ?? []),
        [yAxisSelectedDeps, filterDeps],
    )
    dctx.size = useShallowMemo(
        () => getMinMax(chart.size?.selected ?? []),
        [sizeSelectedDeps, filterDeps],
    )

    // Reassigns columns to the time / tiles / dimensions roles. The change is
    // applied only if the combined columns still identify rows uniquely;
    // otherwise an error notification is shown and the form is left untouched.
    dctx.updateColumns = (update: UpdateColumnsType) => {
        // Note: might include tileIndicatorsCode
        const old_dimensions = [
            ...chart.tiles.codes,
            ...chart.dimensions.codes,
            ...chart.time.codes,
        ]

        // Resolve the requested roles into locals rather than writing them back
        // onto `update`, so the caller's object is not mutated.
        const time = update.time ?? chart.time.codes
        const tiles = _.difference(update.tiles ?? chart.tiles.codes, time)
        // dimensions, if not provided, are the leftover columns
        const dimensions = _.difference(update.dimensions ?? old_dimensions, [
            ...tiles,
            ...time,
            tileIndicatorsCode,
        ])

        // new_index is checked for uniqueness (note: dt.joinColumns will check
        // for and filter out tileIndicatorsCode from the joined columns if it is
        // present, so no need for us to do it here). Containing every column of
        // table.index is treated as sufficient; otherwise distinct joined
        // values are counted against the number of rows.
        const new_index = [...dimensions, ...time, ...tiles]
        const isUnique = table.index.every((i) => new_index.includes(i))
            ? true
            : dt.joinColumns(table, new_index).nunique === table.num_rows

        if (isUnique) {
            if (!_.isEqual(time, chart.time.codes)) {
                form.setFieldValue("chart.time", {
                    codes: time,
                    min_time: undefined,
                    max_time: undefined,
                })
            }

            // Swapping tiles and dimensions carries the other role's
            // selection/info across instead of resetting it.
            if (_.isEqual(tiles, chart.dimensions.codes)) {
                form.setFieldValue("chart.tiles", chart.dimensions)
            } else if (!_.isEqual(tiles, chart.tiles.codes)) {
                form.setFieldValue("chart.tiles", {
                    codes: tiles,
                    selected: [],
                    info: {},
                })
            }

            if (_.isEqual(dimensions, chart.tiles.codes)) {
                form.setFieldValue("chart.dimensions", chart.tiles)
            } else if (!_.isEqual(dimensions, chart.dimensions.codes)) {
                form.setFieldValue("chart.dimensions", {
                    codes: dimensions,
                    selected: [],
                    info: {},
                    customSort: null,
                })
            }

            setDefaults(true)
        } else {
            notifications.show({ message: "Dimensions would not be unique", color: "red" })
        }
    }

    // Explicit limit from the chart config wins; otherwise the data limit is
    // used only when the axis is aligned and the chart has tiles.
    dctx.getAxisLimit = (axisName: "xAxis" | "yAxis" | "size", minmax: "min" | "max") => {
        const chartAxis = chart[axisName]
        const chartMinMax = chartAxis?.[minmax]
        const dataMinMax = dctx[axisName][minmax]
        return (
            chartMinMax ?? (chartAxis?.align && chart.tiles.codes.length ? dataMinMax : undefined)
        )
    }

    // Label lookup order: user-configured label, then the data's own value
    // label, then the raw value. Time values are returned unchanged.
    dctx.getDisplayLabel = (colname: keyof DataContextColumns, value: string) => {
        switch (colname) {
            case "tiles":
            case "dimensions":
            case "colors":
                return (
                    chart[colname]?.info?.[value]?.label ??
                    dctx[colname].value_labels.get(value) ??
                    value
                )
            case "xAxis":
            case "yAxis":
            case "size":
                return chart[colname]?.info?.[value]?.label ?? dt.getLabel(table, value)
            case "time":
                return value
        }
    }

    dctx.legendData = getLegendData()

    // Maps each legend entry to a colour. Scatter plots colour by `colors`;
    // line charts colour by `dimensions` when a single indicator is shown;
    // everything else colours by `yAxis`. A configured colour wins; otherwise
    // one is interpolated from the palette, using the categorical "default"
    // palette for at most 5 entries and "extended" for more.
    function getLegendData() {
        const colorByDimensions =
            chart.type === ChartType.LineChart &&
            (chart.yAxis.codes.length === 1 ||
                chart.yAxis.singleSelect ||
                chart.tiles.codes.includes(tileIndicatorsCode))

        const colname =
            chart.type === ChartType.ScatterPlot
                ? "colors"
                : colorByDimensions
                  ? "dimensions"
                  : "yAxis"

        const valueList =
            colname === "yAxis"
                ? chart.yAxis.codes
                : colname === "dimensions"
                  ? chart.dimensions.selected
                  : Array.from(dctx.colors.value_labels.keys())

        const factor = Math.max(valueList.length - 1, 1)
        const categorical = factor < 5
        const palette = categorical ? "default" : "extended"
        return Object.fromEntries(
            valueList.map((key, index) => [
                key,
                updateOpacity(
                    chart[colname]?.info?.[key]?.color ??
                        interpolateColor(colorPalette[palette], index / factor, categorical),
                ),
            ]),
        )
    }

    // After a (re)configuration, preselect the first 4 tile/dimension values
    // when nothing is selected yet.
    useEffect(() => {
        if (defaults) {
            setDefaults(false)
            if (chart.tiles.selected.length === 0 && dctx.tiles.nunique > 0) {
                form.setFieldValue(
                    "chart.tiles.selected",
                    Array.from(dctx.tiles.value_labels.keys()).slice(0, 4),
                )
            }
            if (chart.dimensions.selected.length === 0 && dctx.dimensions.nunique > 0) {
                form.setFieldValue(
                    "chart.dimensions.selected",
                    Array.from(dctx.dimensions.value_labels.keys()).slice(0, 4),
                )
            }
        }
    }, [defaults])

    return <DataContext.Provider value={dctx}>{children}</DataContext.Provider>
}
