import { Button, Col, Form, Row } from 'react-bootstrap'
import { useTranslation } from 'react-i18next'
import DataTable from 'react-data-table-component'
import { IoMdRefresh } from 'react-icons/io'
import { FaArrowLeft, FaPen } from 'react-icons/fa'
import { useMemo, useState } from 'react'
import { useDebouncedCallback } from 'use-debounce'
import CustomSelect from './CustomSelect'
import { useQuery } from 'react-query'
import { ResponsiveLine } from '@nivo/line'
import { toast } from 'react-toastify'

// Catalogue of the distribution families offered in the editor. Each entry
// lists the family's variables, in display order, together with the value a
// variable receives when the user switches to that family.
const DISTRIBUTIONS = [
  ['Normal', [['mu', 0.2], ['sigma', 2]]],
  ['HalfNormal', [['sigma', 2]]],
  ['Gamma', [['alpha', 3], ['beta', 1]]],
  ['Laplace', [['mu', 2], ['b', 0.2]]],
  ['Beta', [['alpha', 2], ['beta', 2]]],
].map(([familyId, variableDefaults]) => ({
  id: familyId,
  variables: variableDefaults.map(([variableId, defaultValue]) => ({
    id: variableId,
    default: defaultValue,
  })),
}))

/**
 * Builds the default prior for every editable hyperparameter.
 *
 * A brand-new object tree is produced on each call, so callers may store or
 * modify the result without affecting later calls.
 *
 * @returns {Object<string, {type: string, values: Object<string, {value: number}>}>}
 *   Map from hyperparameter name to its default distribution type and values.
 */
export function DEFAULT_PRIORS() {
  // Wraps each raw number as `{ value }`, the shape used throughout the editor.
  const makePrior = (type, numbers = {}) => {
    const values = {}
    for (const [name, value] of Object.entries(numbers)) {
      values[name] = { value }
    }
    return { type, values }
  }

  return {
    intercept: makePrior('Normal', { mu: 0.2, sigma: 2 }),
    // No explicit values are stored for this prior.
    saturation_beta: makePrior('HalfNormal'),
    saturation_lam: makePrior('Gamma', { alpha: 3, beta: 1 }),
    gamma_control: makePrior('Laplace', { mu: 2, b: 0.2 }),
    gamma_fourier: makePrior('Normal', { mu: 0, sigma: 0.3 }),
    // Currently disabled:
    // likelihood: makePrior('Normal'),
    adstock_alpha: makePrior('Beta', { alpha: 2, beta: 2 }),
    // Currently disabled:
    // peak_effect_delay: makePrior('Beta', { alpha: 1, beta: 4 }),
    // coef_trend: makePrior('Normal', { mu: 0, sigma: 1.5 }),
  }
}

// Display glyph for each distribution variable id.
const nameToSymbol = Object.fromEntries([
  ['alpha', 'α'],
  ['beta', 'β'],
  ['mu', 'μ'],
  ['sigma', 'σ'],
  ['b', 'b'],
])

// Distribution families that may be selected for each hyperparameter.
// Every key owns a separate Set instance.
const ALLOWED_DISTRIBUTIONS = (() => {
  const everyFamily = ['Normal', 'HalfNormal', 'Gamma', 'Laplace', 'Beta']
  const anyFamily = () => new Set(everyFamily)

  return {
    intercept: anyFamily(),
    saturation_beta: anyFamily(),
    saturation_lam: anyFamily(),
    gamma_control: anyFamily(),
    gamma_fourier: anyFamily(),
    likelihood: new Set(),
    adstock_alpha: new Set(['Beta']),
    peak_effect_delay: anyFamily(),
    coef_trend: anyFamily(),
  }
})()

/**
 * Formats a prior's values as a comma-separated list such as "μ=0.2, σ=2".
 * Variable ids with a known glyph are shown with it; others are shown as-is.
 *
 * @param {{values: Object<string, {value: *}>}} prior - Prior to describe.
 * @returns {string} One "name=value" fragment per variable, joined by ", ".
 */
function PrintValues(prior) {
  const fragments = []
  for (const variableId of Object.keys(prior.values)) {
    const label = nameToSymbol[variableId] ?? variableId
    fragments.push(`${label}=${prior.values[variableId].value}`)
  }
  return fragments.join(', ')
}

function Symbol({ name }) {
  return (
    <div className="absolute left-[25px] top-[calc(50%-13px)] max-h-[10px] min-h-[10px] font-strong">
      {nameToSymbol[name] ?? ''}
    </div>
  )
}

/**
 * Samples the Normal(mu, sigma) density at `numPoints` evenly spaced abscissas
 * starting at mu - 3*sigma with spacing 6*sigma/numPoints.
 *
 * @param {number} mu - Mean.
 * @param {number} sigma - Standard deviation.
 * @param {number} [numPoints=100] - Number of samples.
 * @returns {{x: number, y: number}[]} Sampled points in increasing x order.
 */
function generateNormalDistribution(mu, sigma, numPoints = 100) {
  const spacing = (6 * sigma) / numPoints
  const leftEdge = mu - 3 * sigma

  const density = (x) =>
    (1 / (sigma * Math.sqrt(2 * Math.PI))) *
    Math.exp(-0.5 * Math.pow((x - mu) / sigma, 2))

  const samples = []
  let index = 0
  while (index < numPoints) {
    const x = leftEdge + index * spacing
    samples.push({ x, y: density(x) })
    index += 1
  }
  return samples
}

/**
 * Samples the HalfNormal(sigma) density at `numPoints` evenly spaced abscissas
 * starting at 0 with spacing 3*sigma/numPoints.
 *
 * @param {number} sigma - Scale parameter.
 * @param {number} [numPoints=100] - Number of samples.
 * @returns {{x: number, y: number}[]} Sampled points in increasing x order.
 */
function generateHalfNormalDistribution(sigma, numPoints = 100) {
  const spacing = (3 * sigma) / numPoints
  const origin = 0

  const density = (x) =>
    (Math.sqrt(2) / (sigma * Math.sqrt(Math.PI))) *
    Math.exp(-0.5 * Math.pow(x / sigma, 2))

  const samples = []
  let index = 0
  while (index < numPoints) {
    const x = origin + index * spacing
    samples.push({ x, y: density(x) })
    index += 1
  }
  return samples
}

/**
 * Samples the Gamma(alpha, beta) density (shape `alpha`, rate `beta`) at up to
 * `numPoints` evenly spaced abscissas starting at 0 with spacing
 * 5*alpha/numPoints.
 *
 * The normalising constant is computed with a Lanczos log-gamma approximation,
 * so any positive real `alpha` is supported (the previous recursive factorial
 * only terminated for integers and half-integers and overflowed the stack for
 * anything else).
 *
 * @param {number} alpha - Shape parameter; must be a finite number > 0.
 * @param {number} beta - Rate parameter; must be a finite number > 0.
 * @param {number} [numPoints=100] - Number of abscissas to evaluate.
 * @returns {{x: number, y: number}[]} Sampled points in increasing x order.
 *   Returns an empty array when a parameter is invalid. Abscissas whose
 *   density is not finite (e.g. x = 0 when alpha < 1) are omitted so the
 *   result can be plotted directly.
 */
function generateGammaDistribution(alpha, beta, numPoints = 100) {
  const isPositiveFinite = (v) => Number.isFinite(v) && v > 0
  if (!isPositiveFinite(alpha) || !isPositiveFinite(beta)) return []

  // Lanczos coefficients for g = 7, n = 9.
  const LANCZOS = [
    0.99999999999980993, 676.5203681218851, -1259.1392167224028,
    771.32342877765313, -176.61502916214059, 12.507343278686905,
    -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
  ]

  // Natural log of Gamma(z) for z > 0.
  function logGamma(z) {
    // Shift small arguments up via Gamma(z) = Gamma(z + 1) / z.
    if (z < 0.5) return logGamma(z + 1) - Math.log(z)
    const shifted = z - 1
    let series = LANCZOS[0]
    for (let i = 1; i < LANCZOS.length; i++) {
      series += LANCZOS[i] / (shifted + i)
    }
    const t = shifted + 7.5
    return (
      0.5 * Math.log(2 * Math.PI) +
      (shifted + 0.5) * Math.log(t) -
      t +
      Math.log(series)
    )
  }

  // log(beta^alpha / Gamma(alpha)); constant across all sample points.
  const logNormaliser = alpha * Math.log(beta) - logGamma(alpha)

  function gammaPDF(x) {
    if (x < 0) return 0
    if (x === 0) {
      // Limit of x^(alpha - 1) at the origin, handled without log(0).
      if (alpha < 1) return Infinity
      return alpha === 1 ? beta : 0
    }
    // Evaluate in log space to avoid overflow for large parameters.
    return Math.exp(logNormaliser + (alpha - 1) * Math.log(x) - beta * x)
  }

  const points = []
  const step = (5 * alpha) / numPoints
  for (let i = 0; i < numPoints; i++) {
    const x = i * step
    const y = gammaPDF(x)
    if (Number.isFinite(y)) points.push({ x, y })
  }

  return points
}

/**
 * Samples the Laplace(mu, b) density at `numPoints` evenly spaced abscissas
 * starting at mu - 3*b with spacing 6*b/numPoints.
 *
 * @param {number} mu - Location parameter.
 * @param {number} b - Scale parameter.
 * @param {number} [numPoints=100] - Number of samples.
 * @returns {{x: number, y: number}[]} Sampled points in increasing x order.
 */
function generateLaplaceDistribution(mu, b, numPoints = 100) {
  const spacing = (6 * b) / numPoints
  const leftEdge = mu - 3 * b

  const density = (x) => (1 / (2 * b)) * Math.exp(-Math.abs(x - mu) / b)

  const samples = []
  let index = 0
  while (index < numPoints) {
    const x = leftEdge + index * spacing
    samples.push({ x, y: density(x) })
    index += 1
  }
  return samples
}

/**
 * Samples the Beta(alpha, beta) density at up to `numPoints` evenly spaced
 * abscissas x = i / numPoints for i = 0 .. numPoints - 1.
 *
 * The Beta function is computed from a Lanczos log-gamma approximation, so any
 * positive real parameters are supported (the previous recursive factorial
 * only terminated for integers and half-integers and overflowed the stack for
 * anything else).
 *
 * @param {number} alpha - First shape parameter; must be a finite number > 0.
 * @param {number} beta - Second shape parameter; must be a finite number > 0.
 * @param {number} [numPoints=100] - Number of abscissas to evaluate.
 * @returns {{x: number, y: number}[]} Sampled points in increasing x order.
 *   Returns an empty array when a parameter is invalid. Abscissas whose
 *   density is not finite (e.g. x = 0 when alpha < 1) are omitted so the
 *   result can be plotted directly.
 */
function generateBetaDistribution(alpha, beta, numPoints = 100) {
  const isPositiveFinite = (v) => Number.isFinite(v) && v > 0
  if (!isPositiveFinite(alpha) || !isPositiveFinite(beta)) return []

  // Lanczos coefficients for g = 7, n = 9.
  const LANCZOS = [
    0.99999999999980993, 676.5203681218851, -1259.1392167224028,
    771.32342877765313, -176.61502916214059, 12.507343278686905,
    -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
  ]

  // Natural log of Gamma(z) for z > 0.
  function logGamma(z) {
    // Shift small arguments up via Gamma(z) = Gamma(z + 1) / z.
    if (z < 0.5) return logGamma(z + 1) - Math.log(z)
    const shifted = z - 1
    let series = LANCZOS[0]
    for (let i = 1; i < LANCZOS.length; i++) {
      series += LANCZOS[i] / (shifted + i)
    }
    const t = shifted + 7.5
    return (
      0.5 * Math.log(2 * Math.PI) +
      (shifted + 0.5) * Math.log(t) -
      t +
      Math.log(series)
    )
  }

  // log B(alpha, beta); constant across all sample points.
  const logBetaFn = logGamma(alpha) + logGamma(beta) - logGamma(alpha + beta)

  function betaPDF(x) {
    if (x < 0 || x > 1) return 0
    if (x === 0 || x === 1) {
      // Endpoints: use the power form, which yields 0, a finite value or
      // Infinity as appropriate without evaluating log(0).
      return (
        Math.pow(x, alpha - 1) *
        Math.pow(1 - x, beta - 1) *
        Math.exp(-logBetaFn)
      )
    }
    // Interior: evaluate in log space to avoid overflow/underflow.
    return Math.exp(
      (alpha - 1) * Math.log(x) + (beta - 1) * Math.log(1 - x) - logBetaFn,
    )
  }

  const points = []
  const step = 1 / numPoints
  for (let i = 0; i < numPoints; i++) {
    const x = i * step
    const y = betaPDF(x)
    if (Number.isFinite(y)) points.push({ x, y })
  }

  return points
}

function EditHyperparameter({
  params,
  priorName,
  prior,
  onChange,
  onChangePrior,
  onCancel,
  onSave,
}) {
  const { t } = useTranslation()
  const [localPrior, setLocalPrior] = useState(() =>
    JSON.parse(JSON.stringify(prior ?? {})),
  )
  const allowReset =
    JSON.stringify(localPrior) !== JSON.stringify(DEFAULT_PRIORS()[priorName])
  const distributions = useMemo(
    () =>
      DISTRIBUTIONS.filter((d) =>
        ALLOWED_DISTRIBUTIONS[priorName].has(d.id),
      ).map((k) => ({ label: t(k.id), value: k.id })),
    [t, priorName],
  )
  const [graphSettings, setGraphSettings] = useState(() => {
    return {
      type: distributions.find((d) => d.value === localPrior.type)?.value,
      ...DISTRIBUTIONS.find((d) => d.id === localPrior.type)?.variables.reduce(
        (acc, v) => {
          acc[v.id] = localPrior.values[v.id]?.value
          return acc
        },
        {},
      ),
    }
  })
  const { data: graphPoints } = useQuery(
    ['graphPoints', graphSettings],
    async () => {
      if (!graphSettings) return null
      const check = (vars = []) => {
        if (vars.some((s) => graphSettings[s] === undefined)) return false
        return true
      }
      switch (graphSettings?.type) {
        case 'Normal':
          if (!check(['mu', 'sigma'])) return null
          return generateNormalDistribution(
            graphSettings.mu,
            graphSettings.sigma,
          )
        case 'HalfNormal':
          if (!check(['sigma'])) return null
          return generateHalfNormalDistribution(graphSettings.sigma)
        case 'Gamma':
          if (!check(['alpha', 'beta'])) return null
          return generateGammaDistribution(
            graphSettings.alpha,
            graphSettings.beta,
          )
        case 'Laplace':
          if (!check(['mu', 'b'])) return null
          return generateLaplaceDistribution(graphSettings.mu, graphSettings.b)
        case 'Beta':
          if (!check(['alpha', 'beta'])) return null
          return generateBetaDistribution(
            graphSettings.alpha,
            graphSettings.beta,
          )
        default:
          return null
      }
    },
    { staleTime: Infinity },
  )
  const includePoint = new Set()
  if (Array.isArray(graphPoints) && graphPoints?.length) {
    includePoint.add(graphPoints[0]?.x)
    includePoint.add(graphPoints[graphPoints.length - 1]?.x)
    const items = Math.max(Math.round(graphPoints.length / 15), 1)
    graphPoints.slice(1, graphPoints.length - 1).forEach((p, i) => {
      if (i % items === 0) includePoint.add(p.x)
    })
  }

  const setSettings = useDebouncedCallback((v) => setGraphSettings(v), 400)

  return (
    <Row className="mx-4">
      <Col className="ps-0" xs={12}>
        <button
          className="flex items-center flex-nowrap bg-transparent border-0 ps-0"
          onClick={onCancel}
        >
          <FaArrowLeft size={30} className="me-2" />
          <span>{t('Model Hyperparameters')}</span>
        </button>
      </Col>
      <Col className="mt-3" xs={12}>
        <Row>
          <Col xs={3}>
            <Row>
              <Col xs={12}>{t('Hyperparameter')}</Col>
              <Col className="mt-2" xs={12}>
                <CustomSelect
                  options={Object.keys(params || {}).map((s) => ({
                    value: s,
                    label: s,
                  }))}
                  className="basic-single"
                  classNamePrefix="select"
                  isSearchable={true}
                  placeholder={t('Select a datasource')}
                  onChange={(e) => onChangePrior(e.value)}
                  value={{ label: t(priorName), value: priorName }}
                  isClearable={false}
                />
              </Col>
            </Row>
          </Col>
          <Col xs={3}>
            <Row>
              <Col xs={12}>{t('Distribution')}</Col>
              <Col className="mt-2" xs={12}>
                <CustomSelect
                  options={distributions}
                  className="basic-single"
                  classNamePrefix="select"
                  isSearchable={true}
                  placeholder={t('Select a datasource')}
                  onChange={(e) => {
                    if (e.value === localPrior.type) return
                    const d = DISTRIBUTIONS.find((d) => d.id === e.value)
                    const newPrior = {
                      type: e.value,
                      values: d.variables.reduce((acc, v) => {
                        acc[v.id] = { value: v.default }
                        return acc
                      }, {}),
                    }
                    const newGraph = {
                      type: distributions.find((d) => d.value === newPrior.type)
                        ?.value,
                      ...DISTRIBUTIONS.find(
                        (d) => d.id === newPrior.type,
                      )?.variables.reduce((acc, v) => {
                        acc[v.id] = newPrior.values[v.id].value
                        return acc
                      }, {}),
                    }
                    setGraphSettings(newGraph)
                    setLocalPrior(newPrior)
                  }}
                  value={distributions.find((d) => d.value === localPrior.type)}
                  isClearable={false}
                />
              </Col>
            </Row>
          </Col>
          <Col key={localPrior?.type} xs={2}>
            <Row>
              <Col className="pointer-events-none opacity-0" xs={12}>
                {'Placeholder'}
              </Col>
              <Col className="mt-2" xs={12}>
                <Row>
                  {DISTRIBUTIONS.find(
                    (d) => d.id === localPrior.type,
                  )?.variables.map((v, i) => (
                    <Col className="relative" xs={6}>
                      <Form.Control
                        className="text-input-mmm !ps-[30px] !pe-[5px]"
                        id={`prior-${v.id}`}
                        defaultValue={
                          localPrior?.values[v?.id]?.value ??
                          DEFAULT_PRIORS()[priorName]?.values[v?.id]?.value ??
                          ''
                        }
                        onChange={(e) => {
                          const vv = Number.parseFloat(e.target.value)
                          if (!Number.isNaN(vv)) {
                            localPrior.values[v.id] =
                              localPrior.values[v.id] ?? {}
                            localPrior.values[v.id].value = vv
                            setLocalPrior({ ...localPrior })
                            setSettings((d) => ({ ...d, [v.id]: vv }))
                          }
                        }}
                      />
                      <Symbol name={v.id} />
                    </Col>
                  ))}
                </Row>
              </Col>
            </Row>
          </Col>
          <Col xs={4}>
            <Row>
              <Col className="pointer-events-none opacity-0" xs={12}>
                {'Placeholder'}
              </Col>
              <Col className="mt-2" xs={12}>
                <Row className="justify-end">
                  <Col xs="auto">
                    <Button
                      className="!bg-transparent !border-[#96CDFF] !text-[#96CDFF] !rounded-[18px] me-2"
                      onClick={() => {
                        onSave(localPrior)
                        toast.success(t('Saved') + ` ${priorName}`)
                      }}
                    >
                      {t('SAVE')}
                    </Button>
                    <Button
                      className="!bg-transparent !border-[#FF4346] !text-[#FF4346] !rounded-[18px] me-2"
                      disabled={!allowReset}
                      onClick={() => {
                        const reset = DEFAULT_PRIORS()[priorName]
                        const d = DISTRIBUTIONS.find((d) => d.id === reset.type)
                        const newPrior = {
                          type: reset.type,
                          values: d.variables.reduce((acc, v) => {
                            acc[v.id] = { value: v.default }
                            try {
                              document.querySelector(`#prior-${v.id}`).value =
                                v.default
                            } catch (e) {}
                            return acc
                          }, {}),
                        }
                        const newGraph = {
                          type: distributions.find(
                            (d) => d.value === newPrior.type,
                          )?.value,
                          ...DISTRIBUTIONS.find(
                            (d) => d.id === newPrior.type,
                          )?.variables.reduce((acc, v) => {
                            acc[v.id] = newPrior.values[v.id].value
                            return acc
                          }, {}),
                        }
                        setGraphSettings(newGraph)
                        setLocalPrior(newPrior)
                        toast.success(`${priorName} ${t('reset')}`)
                      }}
                    >
                      {t('RESET TO DEFAULT')}
                    </Button>
                  </Col>
                </Row>
              </Col>
            </Row>
          </Col>
        </Row>
      </Col>
      <Col xs={12} className="py-3 min-h-[350px]">
        {Array.isArray(graphPoints) && (
          <>
            <ResponsiveLine
              data={[
                {
                  id: 'prior',
                  data: graphPoints,
                  color: '#4240B5',
                },
              ]}
              margin={{ top: 60, right: 20, bottom: 60, left: 80 }}
              yScale={{
                type: 'linear',
                min: 'auto',
                stacked: false,
                reverse: false,
              }}
              colors={(d) => d.color}
              curve={'monotoneX'}
              xFormat=" >-.2f"
              yFormat=" >-.2f"
              enablePoints={false}
              enableGridX={false}
              enableGridY={false}
              axisTop={null}
              axisRight={null}
              axisBottom={{
                legend: 'X',
                legendOffset: 35,
                orient: 'bottom',
                tickSize: 5,
                tickPadding: 5,
                legendOffset: 46,
                legendPosition: 'middle',
                tickValues: 5,
                format: (v) => {
                  if (includePoint.has(v)) return v.toFixed(2)
                  return ''
                },
              }}
              axisLeft={{
                legendPosition: 'middle',
                legendOffset: -60,
                legend: 'Y',
              }}
              pointSize={10}
              pointColor={{ theme: 'background' }}
              pointBorderWidth={2}
              pointBorderColor={{ from: 'serieColor' }}
              pointLabelYOffset={-12}
              useMesh={true}
              legends={[]}
              theme={{
                fontSize: 11,
                textColor: 'white',
                axis: {
                  ticks: {
                    text: {
                      fontSize: 11,
                    },
                  },
                  legend: {
                    text: {
                      fontSize: 13,
                      fill: 'white',
                    },
                  },
                  domain: {
                    line: {
                      stroke: 'var(--graph-grid-color)',
                      strokeWidth: 1,
                    },
                  },
                },
                tooltip: {
                  container: {
                    color: 'black',
                  },
                },
                legends: {
                  text: {
                    fontSize: 11,
                  },
                },
                grid: {
                  line: {
                    stroke: 'var(--graph-grid-color)',
                    opacity: 0.4,
                  },
                },
              }}
            />
          </>
        )}
      </Col>
    </Row>
  )
}

export default function ConfigureHyperparameters({
  model,
  params,
  onChange,
  ...props
}) {
  const { t } = useTranslation()
  const [edit, setEdit] = useState(null)

  const def = DEFAULT_PRIORS()
  const changes = Object.keys(params || {}).reduce((ac, k) => {
    ac[k] = JSON.stringify(params[k]) !== JSON.stringify(def[k])
    return ac
  }, {})
  const block = Object.values(changes).every((v) => !v)

  const columns = [
    {
      name: t('Hyperparameter'),
      cell: (row) => <div>{t(row[0])}</div>,
    },
    {
      name: t('Distribution'),
      cell: (row) => <div>{row[1].type}</div>,
    },
    {
      name: t('Key values (e.g., α, β, μ, σ)'),
      cell: (row) => <div>{PrintValues(row[1])}</div>,
    },
    {
      name: t('Actions'),
      width: '100px',
      cell: (row) => (
        <div>
          <button
            title={t('Edit parameters')}
            className="bg-transparent border-[#96CDFF] hover:border-[#96CDFF] hover:!bg-[#96CDFF] hover-ic-hyperparams border-2 p-[3px] rounded-md ms-2 "
            onClick={() => setEdit(row[0])}
          >
            <FaPen color="#96CDFF" size={14} />
          </button>
          <button
            title={t('Reset to default')}
            className="bg-transparent border-[#FF4346] hover:border-[#FF4346] hover:!bg-[#FF4346] hover-ic-hyperparams border-2 !p-[2px] rounded-md ms-2 disabled:pointer-events-none disabled:opacity-50"
            disabled={!changes?.[row[0]]}
            onClick={() => {
              onChange({
                ...params,
                [row[0]]: DEFAULT_PRIORS()[row[0]],
              })
              toast.success(`${row[0]} ` + t('reset'))
            }}
          >
            <IoMdRefresh
              color="#FF4346"
              style={{ transform: 'scaleX(-1)' }}
              size={16}
            />
          </button>
        </div>
      ),
    },
  ]
  if (edit)
    return (
      <Row {...props}>
        <Col className="" xs={12}>
          <EditHyperparameter
            key={edit}
            params={params}
            prior={params[edit]}
            priorName={edit}
            changes={changes}
            onCancel={() => setEdit(null)}
            onChangePrior={(p) => setEdit(p)}
            onSave={(values) => {
              onChange({
                ...params,
                [edit]: values,
              })
              setEdit(null)
            }}
          />
        </Col>
      </Row>
    )

  return (
    <Row {...props}>
      <Col xs={12}>
        <h4 className="text-center">{t('Model Hyperparameters')}</h4>
      </Col>
      <Col className="py-5" xs={12}>
        <div className="mx-5 table-priors">
          <DataTable
            theme="mmm-table-theme"
            columns={columns}
            data={Object.entries(params).filter((v) => !!def[v[0]])}
          />
        </div>
        <div className="flex justify-end me-5 mb-1">
          <button
            disabled={block}
            title={t('Reset all')}
            className="bg-transparent text-[#FF4346]  border-none hover:border-none  !duration-200 border-2 !p-[2px] rounded-md ms-2  inline-flex items-center disabled:pointer-events-none disabled:opacity-50 mt-1 font-bold text-sm me-0"
            onClick={() => {
              onChange(DEFAULT_PRIORS())
              toast.success(t('All hyperparameters reset'))
            }}
          >
            {t('Reset all to default')}
          </button>
        </div>
      </Col>
    </Row>
  )
}
