import React, {
  useEffect,
  useMemo,
  useState,
  useRef,
  forwardRef,
  Fragment,
} from 'react'
import { Row, Col } from 'react-bootstrap'
import { ResponsiveLine } from '@nivo/line'
import { useTranslation } from 'react-i18next'
import { useQuery } from 'react-query'
import * as d3 from 'd3'
import { animated } from '@react-spring/web'
import { useAnimatedPath } from '@nivo/core'

import YearlyLayer from './YearlyLayer'
import { getMMMDataColumnInfo } from '../../utility/model'
import { dateParams, integerParams, nivoLineProps } from '../../utility/model'
import { toast } from 'react-toastify'
import { getMMMStatistics } from '../../services/model'
import { useAuth } from '../../providers/AuthProvider'
import { placeholderPrediction } from './placeholders'
import UpgradeToPro from './UpgradeToPro'

// Lazily created 2D context shared by every getTextWidth call, so measuring a
// label does not allocate a fresh <canvas> element each time.
let textMeasureContext = null

/**
 * Measures the rendered pixel width of `text` with a canvas 2D context.
 *
 * @param {string} text - Text to measure.
 * @param {string|number} [fontSize='13'] - Font size in px.
 * @param {string} [fontFace='monospace'] - CSS font family.
 * @returns {number} Width in px as reported by `measureText`.
 */
function getTextWidth(text, fontSize = '13', fontFace = 'monospace') {
  if (!textMeasureContext) {
    textMeasureContext = document.createElement('canvas').getContext('2d')
  }
  // The context is shared, so the font must be set on every call.
  textMeasureContext.font = `${fontSize}px ${fontFace}`
  return textMeasureContext.measureText(text).width
}

function Label({ x, y, id, value, color, width, widthId }) {
  return (
    <>
      <rect
        x={x - 11 - width - widthId}
        y={y - 8}
        width={widthId + 5}
        height={18}
        stroke="transparent"
        fill="var(--mmm-secondary-color)"
        strokeWidth="1"
      />
      <text
        className="label-media-contribution"
        x={x - 11 - width - widthId + 5}
        y={y + 5}
        fill="var(--mmm-white-color)"
      >
        {id}
      </text>

      <rect
        x={x - 12 - width - widthId}
        y={y - 9}
        width={widthId + width + 7}
        height={20}
        fill={'transparent'}
        stroke={color}
        strokeWidth={2}
        rx={3}
        ry={3}
      />
    </>
  )
}

function Anchor({ x, y, targetX, top, color }) {
  const startX = x
  const startY = y
  const endX = targetX
  const endY = top

  return (
    <>
      <path
        d={`M${startX},${startY} L${endX},${endY}`}
        className="ant-trail"
        stroke={color}
        stroke-lineca="butt"
        strokeWidth="2"
        fill="transparent"
        strokeLinejoin="bevel"
        strokeDasharray="3 6"
      />
      <polyline
        points={`${x} ${y} ${x - 5} ${y + 4} ${x - 5} ${y - 4}`}
        stroke="#ffffff44"
        fill={color}
        strokeWidth="1"
      />
    </>
  )
}

function LabelL({ x, y, id, value, color, width, widthId }) {
  return (
    <g>
      <Label
        id={id}
        value={value}
        color={color}
        width={width}
        widthId={widthId}
        x={x}
        y={y}
      />
    </g>
  )
}

/**
 * Nivo layer that labels up to five series.
 *
 * For each series one sample position is picked; the positions are spread
 * evenly across [minX, maxX], one bucket per series. Of those, the five
 * non-'Baseline' series with the largest sampled y values are kept. Their
 * labels are stacked vertically (y = 8 + i * 50, ordered by the point's screen
 * y), each joined to its point by a dashed Anchor.
 *
 * @param {object} props - Nivo layer props plus:
 * @param {Array} props.slices - Nivo slices, indexed by the sample position.
 * @param {Array} props.points - Nivo computed points.
 * @param {Array} props.data - Series objects ({ id, label, data: [{ y }] }).
 * @param {number} props.minX - Start of the visible range.
 * @param {number} props.maxX - End of the visible range.
 * @returns Two SVG groups (anchors, then labels), or an empty fragment if
 *   anything throws.
 */
function CustomLabel({ slices, points, data, minX, maxX }) {
  try {
    // Width of the x-range bucket assigned to each series.
    const span = (maxX - minX) / data.length
    // NOTE(review): despite the name these are sampled points, not maxima.
    const hightestPt = data.map((d, idx) => {
      // Middle of bucket `idx`, nudged inward by one for the first/last series.
      // NOTE(review): assumes minX/maxX are data indices so `pos` indexes
      // d.data directly; an out-of-range `pos` throws and hides every label.
      const pos =
        Math.floor(minX + span * 0.5 + span * idx) +
        (idx === 0 ? 1 : idx === data.length - 1 ? -1 : 0)
      return {
        id: d.id,
        label: d.label,
        data: {
          index: pos,
          value: d.data[pos].y,
        },
      }
    })
    const labelAnchors = hightestPt
      .map(({ id, label, data }, i) => {
        // Find this series' point in the slice at the sampled position.
        // NOTE(review): assumes slices are indexed like the data arrays — confirm.
        const dex = slices[data.index].points.find(
          (d) => d.serieId === id,
        ).index
        const res = {
          id,
          label,
          top: points[dex].y,
          color: points[dex].serieColor,
          // Shifted left 10px per series index, but never below x = 200.
          x: Math.max(200, points[dex].x - i * 10),
          targetX: points[dex].x,
          numValue: data.value,
          // No value text is shown, so the value part has zero width.
          value: ``,
          width: 0,
          // Badge width from the measured label text plus padding.
          widthId: getTextWidth(label) + 4,
        }
        res.totalWidth = res.width + res.widthId + 4
        return res
      })
      .filter((d) => d.id !== 'Baseline')
      // Keep the five largest sampled values...
      .sort((a, b) => b.numValue - a.numValue)
      .slice(0, 5)
      // ...then order them top-to-bottom by the point's screen y.
      .sort((a, b) => a.top - b.top)
      .map((v, i) => {
        v.y = 8 + i * 50
        return v
      })

    return (
      <>
        <g className="pe-none">
          {labelAnchors.map(
            ({ x, y, id, value, color, width, widthId, top, targetX }) => {
              // NOTE(review): `value` is always '' above, so parseInt yields NaN
              // and this branch never fires on it; Baseline is already filtered.
              if (Number.parseInt(value) === 0 || id === 'Baseline')
                return <Fragment key={id}></Fragment>
              return (
                <Anchor
                  key={id}
                  targetX={targetX}
                  x={x}
                  y={y}
                  top={top}
                  id={id}
                  value={value}
                  color={color}
                  width={width}
                  widthId={widthId}
                />
              )
            },
          )}
        </g>
        <g className="pe-none">
          {labelAnchors.map(
            ({ x, y, id, label, value, color, width, widthId, key, top }) => {
              // Same (effectively dead) filter as the anchor group above.
              if (Number.parseInt(value) === 0 || id === 'Baseline')
                return <Fragment key={id}></Fragment>
              return (
                <LabelL
                  key={id}
                  x={x}
                  y={y}
                  top={top}
                  id={label}
                  value={value}
                  color={color}
                  width={width}
                  widthId={widthId}
                />
              )
            },
          )}
        </g>
      </>
    )
  } catch (e) {
    // Best effort: any failure (e.g. a missing point) drops all labels.
    return <></>
  }
}

const CustomBrush = ({ graphRef, minX, maxX, ...props }) => {
  const brushRef = useRef(null)

  useEffect(() => {
    d3.select(brushRef.current).call(
      d3
        .brushX()
        .extent([
          [0, 0],
          [props.width, props.innerHeight],
        ])
        .on('brush', (e) => {
          const event = new CustomEvent('brush_update', {
            detail: {
              minX: minX + ((maxX - minX) * e.selection[0]) / props.width,
              maxX: minX + ((maxX - minX) * e.selection[1]) / props.width,
            },
          })
          graphRef?.current?.dispatchEvent(event)
        })
        .on('end', ({ selection }) => {
          if (!selection) graphRef.current.dispatchEvent(new CustomEvent('end'))
        }),
    )
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  return <g ref={brushRef} />
}

const CustomLineLayer = ({
  series,
  lineWidth,
  lineID,
  dashed = false,
  ...props
}) => {
  const lineGenerator = d3
    .line()
    .x((d) => d.position.x)
    .y((d) => d.position.y)
    .curve(d3.curveMonotoneX)

  const line = lineID ? series.find((s) => s.id === lineID) : series[0]
  if (!line) return null
  return (
    <animated.path
      d={useAnimatedPath(lineGenerator(line.data))}
      fill="none"
      filter="url(#visible)"
      stroke={line.color}
      strokeWidth={lineWidth}
      strokeDasharray={dashed ? '4 8' : '0'}
    />
  )
}

const CustomSplitter = ({
  series,
  at,
  text,
  data,
  xScale,
  color = '#4141b3',
  height,
  forceHeight,
  ...props
}) => {
  if (!data?.[0]?.data?.[at]?.x) return null
  const x_pos = xScale(new Date(data?.[0]?.data[at]?.x))
  return (
    <>
      <rect
        width={1}
        height={forceHeight || height - 100}
        x={x_pos}
        y={0}
        fill={'white'}
      />
      <g transform={`translate(${x_pos - 80}, 30)`}>
        <text
          transform="rotate(0)"
          fill="white"
          stroke={color}
          strokeWidth={0.5}
          fontSize={15}
        >
          {text}
        </text>
      </g>
    </>
  )
}

const CustomAreaLayer = ({ series, innerWidth, ...props }) => {
  const areaGenerator = d3
    .area()
    .x((d) => d.position.x)
    .y0((d) => d.position.y)
    .y1((d) => d.position.y_up)
    .curve(d3.curveMonotoneX)
  const pred_lower = series.find((line) => line.id === 'pred_lower')?.data
  const pred_upper = series.find((line) => line.id === 'pred_upper')?.data

  const animatedPath = useAnimatedPath(
    areaGenerator(
      pred_lower.map((d, i) => ({
        ...d,
        position: { ...d.position, y_up: pred_upper[i].position.y },
      })),
    ),
  )
  return (
    <animated.path
      key="custom-area"
      fillRule="even-odd"
      filter="url(#visible)"
      d={animatedPath}
      fill={'#df997f'}
      fillOpacity={'0.3'}
    />
  )
}

const MainPredictionChart = forwardRef(function MainPredictionChart(
  {
    shownData,
    minX,
    maxX,
    maxY,
    minY,
    target,
    nivoProps,
    model,
    customDates,
    ...props
  },
  ref,
) {
  const [brushMinX, setBrushMinX] = useState(minX)
  const [brushMaxX, setBrushMaxX] = useState(maxX)
  const { t } = useTranslation()
  const columnInfo = useMemo(
    () => getMMMDataColumnInfo(model, '', { customDates }),
    [model],
  )
  const { minDate, maxDate } = useMemo(() => {
    const colData =
      model?.column_statistics?.[model?.dataslayer_training_config?.time_column]
    return {
      minDate: new Date(colData?.min),
      maxDate: new Date(colData?.max),
    }
  }, [model])

  const [splitShowData, splitTrain, splitTest] = useMemo(() => {
    try {
      const pred = shownData?.find((v) => v.id === 'pred')
      if (
        pred &&
        minDate &&
        maxDate &&
        model?.training_config?.test_train_split
      ) {
        const split =
          (maxDate.getTime() - minDate.getTime()) *
            model?.training_config?.test_train_split +
          minDate.getTime()
        const index = pred.data.findIndex(
          (v) => new Date(v.x).getTime() >= split,
        )
        const predTrain = pred.data.slice(0, index)
        const predTest = [
          ...predTrain,
          ...pred.data.slice(Math.max(0, index - 1)),
        ]

        return [
          [
            ...shownData.filter((v) => v.id !== 'pred'),
            {
              id: 'pred',
              label: t('Train data'),
              data: predTrain,
              color: '#198754',
            },
            {
              id: 'pred_test',
              label: t('Test data'),
              data: predTest,
              color: '#4141b3aa',
            },
          ],
          index,
          pred.data.length - 1,
        ]
      }
    } catch (e) {}
    return [shownData, null, null]
  }, [shownData])

  const customYearLayer = useMemo(() => {
    return (props) => (
      <YearlyLayer
        timeOffset={true}
        start={brushMinX}
        end={brushMaxX}
        {...props}
      />
    )
  }, [brushMinX, brushMaxX])

  useEffect(() => {
    //add event to ref listener for brush_update
    if (ref.current) {
      const updateBrush = (e) => {
        if (e.detail) {
          setBrushMinX(e.detail.minX)
          setBrushMaxX(e.detail.maxX)
        }
      }
      const end = () => {
        setBrushMinX(minX)
        setBrushMaxX(maxX)
      }
      ref.current.addEventListener('brush_update', updateBrush)
      ref.current.addEventListener('end', end)
      const current = ref.current
      return () => {
        current?.removeEventListener('brush_update', updateBrush)
        current?.removeEventListener('end', end)
      }
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [ref.current])
  return (
    <Col ref={ref} style={{ minHeight: '350px', maxHeight: '350px' }} xs={12}>
      <ResponsiveLine
        {...nivoLineProps}
        {...nivoProps}
        data={splitShowData}
        margin={{
          top: 30,
          bottom: 70,
          left: 100,
          right: 50,
        }}
        xScale={{
          ...nivoProps?.xScale,
          min: columnInfo.map(brushMinX),
          max: columnInfo.map(brushMaxX),
        }}
        yScale={{ type: 'linear', min: minY, max: maxY }}
        enablePoints={false}
        enableGridX={false}
        enableGridY={false}
        useMesh={true}
        enableSlices={'x'}
        legends={[]}
        axisLeft={{
          ...nivoLineProps.axisLeft,
          legendOffset: -90,
          legend: target ?? t('Outcome'),
        }}
        layers={[
          CustomAreaLayer,
          ...nivoLineProps.layers.filter(
            (x) => !['lines', 'slices'].includes(x),
          ),
          (props) => <CustomLineLayer {...props} lineID={'pred_test'} />,
          (props) => <CustomLineLayer {...props} lineID={'pred'} />,
          (props) => <CustomLineLayer {...props} lineID={'pred_lower'} />,
          (props) => <CustomLineLayer {...props} lineID={'pred_upper'} />,
          (props) => (
            <CustomLineLayer
              dashed={true}
              {...props}
              lineWidth={2}
              lineID={'original'}
            />
          ),
          (props) => (
            <CustomSplitter
              {...props}
              at={splitTrain - 1}
              color="#198754"
              text={t('Train {{perc}}', {
                perc: `${Number.parseInt(model?.training_config?.test_train_split * 100)}%`,
              })}
            />
          ),
          (props) => (
            <CustomSplitter
              {...props}
              at={splitTest}
              text={t('Test {{perc}}', {
                perc: `${100 - Number.parseInt(model?.training_config?.test_train_split * 100)}%`,
              })}
            />
          ),
          (props) => (
            <defs>
              <filter id="visible">
                <feFlood
                  x="0"
                  y="0"
                  width={props.innerWidth}
                  height={props.innerHeight}
                  result="visible"
                />
                <feComposite
                  operator="in"
                  in2="visible"
                  in="SourceGraphic"
                  result="compos"
                />
              </filter>
            </defs>
          ),
          'slices',
          (props) => (
            <CustomLabel {...props} minX={brushMinX} maxX={brushMaxX} />
          ),
          customYearLayer,
        ]}
      />
    </Col>
  )
})

export default function PredictionChart({
  model,
  height = 400,
  target = null,
  isInView = false,
  ...props
}) {
  const { t } = useTranslation()
  const { token, isEssential } = useAuth()
  const requirePro = isEssential && !model?.read_only

  const { data: _data } = useQuery(
    ['mmm-model-statistics', model.id],
    async () => {
      const response = await getMMMStatistics({
        modelId: model.id,
        token,
      })

      if (!response?.ok) toast.error(t('Failed to retrieve original forecast'))
      const res = await response.json()
      return res
    },
    { staleTime: 60 * 1000 },
  )
  const data = requirePro ? placeholderPrediction : _data

  const [processedData, noData, minX, maxX, minY, maxY, nivoProps] =
    useMemo(() => {
      if (!data) return [null, 0, 0]
      const options = data?.media_data
        ? { customDates: data?.media_data?.map((d) => new Date(`${d}Z`)) }
        : {}
      const columnInfo = getMMMDataColumnInfo(model, '', options)
      const nivoProps =
        columnInfo.mode === 'datetime' ? dateParams : integerParams

      const weeks = data?.media_data?.map((_, i) => i) ?? []
      const maxY = Math.max(
        data.y.reduce((a, b) => Math.max(a, b), -Infinity),
        data.upper_bound.reduce((a, b, idx) => Math.max(a, b), -Infinity),
      )

      const minY = Math.min(
        data.y.reduce((a, b) => Math.min(a, b), Infinity),
        data.lower_bound.reduce((a, b, idx) => Math.min(a, b), Infinity),
      )

      return [
        [
          {
            id: 'original',
            label: t('original'),
            data: data.y.map((d, i) => ({ x: columnInfo.map(weeks[i]), y: d })),
            color: '#59b2f5',
          },
          {
            id: 'pred',
            label: t('predicted'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: d,
            })),
            color: '#198754',
          },
          {
            id: 'pred_upper',
            label: t('upper error margin'),
            data: data.upper_bound.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: d,
            })),
            color: '#750a03c8',
          },
          {
            id: 'pred_lower',
            label: t('lower error margin'),
            data: data.lower_bound.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: d,
            })),
            color: '#e66a6088',
          },
        ],
        [
          {
            id: 'original',
            label: t('original'),
            data: data.y.map((d, i) => ({ x: columnInfo.map(weeks[i]), y: 0 })),
            color: 'var(--mmm-tables-blue-graph-bar-color)',
          },
          {
            id: 'pred',
            label: t('predicted'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: 0,
            })),
            color: '#198754',
          },
          {
            id: 'pred_upper',
            label: t('upper error margin'),
            data: data.upper_bound.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: 0,
            })),
            color: '#750a03c8',
          },
          {
            id: 'pred_lower',
            label: t('lower error margin'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: 0,
            })),
            color: '#e66a6088',
          },
        ],
        weeks[0],
        weeks[weeks.length - 1],
        minY * 0.9,
        maxY * 1.1,
        nivoProps,
      ]
    }, [data, model, t])

  const graphRef = useRef(null)

  const shownData = isInView ? processedData : noData

  const csvData = processedData?.[0]?.data
    ? [
        [
          t('week'),
          t('original'),
          t('predicted'),
          t('upper error margin'),
          t('lower error margin'),
        ],
        ...processedData[0].data.map((x, idx) => [
          x?.x,
          x?.y,
          processedData[1].data?.[idx]?.y,
          processedData[2].data?.[idx]?.y,
          processedData[3].data?.[idx]?.y,
        ]),
      ]
    : []

  const { minDate, maxDate } = useMemo(() => {
    const colData =
      model?.column_statistics?.[model?.dataslayer_training_config?.time_column]
    return {
      minDate: new Date(colData?.min),
      maxDate: new Date(colData?.max),
    }
  }, [model])

  const [splitTrain] = useMemo(() => {
    try {
      const pred = processedData?.find((v) => v.id === 'pred')
      if (
        pred &&
        minDate &&
        maxDate &&
        model?.training_config?.test_train_split
      ) {
        const split =
          (maxDate.getTime() - minDate.getTime()) *
            model?.training_config?.test_train_split +
          minDate.getTime()
        const index = pred.data.findIndex(
          (v) => new Date(v.x).getTime() >= split,
        )
        return [index, pred.data.length - 1]
      }
    } catch (e) {}
    return [null, null]
  }, [processedData])

  if (!shownData) return <></>

  const predData = shownData.find((x) => x.id === 'pred')
  const brushData = [
    {
      id: 'train',
      data: predData?.data?.slice(0, splitTrain),
      color: '#198754',
    },
    {
      id: 'test',
      data: predData?.data,
      color: '#4141b3aa',
    },
  ]

  return (
    <Row
      className={`h-100 data-holder relative ${requirePro ? 'blur-under' : ''}`}
      data-csv={encodeURIComponent(JSON.stringify(csvData))}
      data-filename={`response_model_accuracy__${model.id}`}
      {...props}
    >
      {requirePro && <UpgradeToPro />}
      <MainPredictionChart
        ref={graphRef}
        shownData={shownData}
        minX={minX}
        maxX={maxX}
        minY={minY}
        maxY={maxY}
        target={target}
        model={model}
        nivoProps={nivoProps}
        customDates={data?.media_data?.map((d) => new Date(`${d}Z`))}
      />
      <Col xs={12}>
        <Row className="!ms-[100px] !me-[50px] relative">
          <Col
            className=""
            style={{ minHeight: '100px', maxHeight: '100px' }}
            xs={12}
          >
            <ResponsiveLine
              {...nivoLineProps}
              {...nivoProps}
              data={brushData}
              margin={{
                ...nivoLineProps.margin,
                bottom: 50,
                left: 0,
                right: 0,
              }}
              enablePoints={false}
              enableGridX={false}
              enableGridY={false}
              axisLeft={null}
              axisBottom={{
                ...nivoProps.axisBottom,
                legendOffset: 45,
                legend: t(`Week`),
              }}
              layers={[
                ...nivoLineProps.layers,
                (props) => (
                  <CustomSplitter
                    {...props}
                    at={splitTrain - 1}
                    text={''}
                    forceHeight={40}
                  />
                ),
                (props) => (
                  <CustomBrush
                    {...props}
                    graphRef={graphRef}
                    minX={minX}
                    maxX={maxX}
                  />
                ),
              ]}
            />
          </Col>
        </Row>
      </Col>
    </Row>
  )
}
