import { AIEvalResult, AIModelResult, AIRequest } from "@incident-io/api";
import { Form } from "@incident-shared/forms";
import { NumberInputV2 } from "@incident-shared/forms/v2/inputs/NumberInputV2";
import { StaticMultiSelectV2 } from "@incident-shared/forms/v2/inputs/StaticSelectV2";
import { ColorPaletteEnum } from "@incident-shared/utils/ColorPalettes";
import {
  Badge,
  BadgeSize,
  BadgeTheme,
  Icon,
  IconEnum,
  IconSize,
  LoadingWrapper,
  Tooltip,
} from "@incident-ui";
import { Button, ButtonTheme } from "@incident-ui/Button/Button";
import {
  Drawer,
  DrawerBody,
  DrawerContents,
  DrawerFooter,
  DrawerTitle,
} from "@incident-ui/Drawer/Drawer";
import {
  Table,
  TableCell,
  TableHeaderCell,
  TableRow,
} from "@incident-ui/Table/Table";
import { sumBy } from "lodash";
import { useMemo, useState } from "react";
import { useForm } from "react-hook-form";
import { cacheKey, useAPI, useMutationV2 } from "src/utils/swr";
import { useCounter } from "usehooks-ts";

import { CodeViewer } from "../common/CodeViewer";
import {
  getEvalOverride,
  removeEvalOverride,
  upsertEvalOverride,
} from "../common/evals";
import { LabeledValue } from "../common/LabeledValue";
import { displayCost } from "../common/utils";
import { ExpectedResultWarning } from "./AISpanDrawer";

/** Values collected by the "Rerun prompt" form. */
type FormData = {
  /** Number of repeats requested; sent to the rerun API as `repeat`. */
  repeat: number;
  /** Names of the models to rerun the request against. */
  models: string[];
};

/**
 * Drawer for re-running a recorded AI request against one or more models
 * (via the `aIStaff*` endpoints) and comparing each model's output with an
 * "expected" result.
 *
 * The expected result comes from `getEvalOverride(request)`. It can be edited
 * inline (persisted with `upsertEvalOverride`) or reset (`removeEvalOverride`).
 *
 * @param request - the AI request to re-run
 * @param onClose - called when the drawer is dismissed
 */
export const RerunAIRequestDrawer = ({
  request,
  onClose,
}: {
  request: AIRequest;
  onClose: () => void;
}) => {
  const [results, setResults] = useState<AIModelResult[]>([]);
  // An empty array is truthy, so "have we run yet?" must check the length.
  const hasResults = results.length > 0;
  const formMethods = useForm<FormData>({
    defaultValues: {
      models: [request.model],
      repeat: 1,
    },
  });

  const { data: modelsData } = useAPI(
    "aIStaffListModels",
    {},
    {
      fallbackData: { models: [] },
    },
  );

  const [isEditingResponse, setIsEditingResponse] = useState(false);
  // Bumping this counter re-runs the memo below (re-reading the override) and
  // remounts the CodeViewer through its `key`.
  const { count, increment: reRender } = useCounter();
  const { hasOverride, result } = useMemo(
    () => getEvalOverride(request),
    // `count` is not read inside the memo: it is a manual cache-buster.
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [request, count],
  );

  // Re-read the override shortly after it has been changed. The 500ms delay is
  // kept from the original implementation.
  // NOTE(review): presumably it lets the override store settle - confirm.
  const scheduleRefresh = () => {
    setTimeout(() => {
      reRender();
    }, 500);
  };

  const onUpdateResult = (res: string) => {
    upsertEvalOverride({
      requestId: request.id,
      result: res,
    });
    scheduleRefresh();
  };

  const onDiscardOverride = () => {
    removeEvalOverride(request.id);
    scheduleRefresh();
  };

  const { trigger: onSubmit, isMutating } = useMutationV2(
    async (apiClient, data: FormData) => {
      const res = await apiClient.aIStaffRerunAIRequest({
        id: request.id,
        rerunAIRequestRequestBody: {
          models: data.models,
          repeat: data.repeat,
          expected: result,
        },
      });
      setResults(res.results);
    },
    {
      invalidate: [cacheKey.all("aIStaffListAIRequests")],
      showErrorToast: "Failed to rerun AI request",
    },
  );

  // Only completion models can rerun a prompt.
  const modelOptions =
    modelsData?.models
      .filter((m) => m.model_type === "completions")
      .map((model) => ({
        label: model.name,
        value: model.name,
        sort_key: model.name,
      })) || [];

  return (
    <Drawer onClose={onClose} width="medium" className="overflow-y-hidden">
      <DrawerContents>
        <DrawerTitle
          compact
          title="Rerun prompt"
          onClose={onClose}
          icon={IconEnum.Refresh}
          color={ColorPaletteEnum.Purple}
        />
        <DrawerBody className="overflow-y-auto">
          <Form.Root
            formMethods={formMethods}
            onSubmit={onSubmit}
            id="rerun-ai-request"
          >
            <CodeViewer
              mode="yaml"
              key={count}
              title={
                <div className="flex items-center gap-2">
                  <span>Expected result</span>
                  {hasOverride && !isEditingResponse && (
                    <div className="flex">
                      <Badge
                        size={BadgeSize.Small}
                        theme={BadgeTheme.Info}
                        className="rounded-r-none border-r-0"
                      >
                        Edited
                      </Badge>
                      <Button
                        size={BadgeSize.Small}
                        icon={IconEnum.Close}
                        theme={ButtonTheme.Secondary}
                        className="rounded-l-none"
                        analyticsTrackingId={null}
                        title="discard"
                        onClick={onDiscardOverride}
                      />
                    </div>
                  )}
                </div>
              }
              content={result}
              isEditing={isEditingResponse}
              setIsEditing={setIsEditingResponse}
              onEdit={onUpdateResult}
            />
            <ExpectedResultWarning result={result} />

            <StaticMultiSelectV2
              formMethods={formMethods}
              name="models"
              label="Models"
              placeholder="Select models to run"
              options={modelOptions}
              required="Please select at least one model"
            />

            <NumberInputV2
              formMethods={formMethods}
              name="repeat"
              label="Number of repeats"
              min={1}
              max={10}
              required="Please enter number of repeats"
            />
          </Form.Root>
          {/* Show the results area once we have results, or while the first
              run is in flight so the loading state is visible. */}
          {(hasResults || isMutating) && (
            <>
              <hr />
              <LoadingWrapper loading={isMutating}>
                <div className={"flex flex-col gap-2"}>
                  {results.map((evalResult) => (
                    <EvalResultForModel
                      key={evalResult.model_name}
                      result={evalResult}
                      expected={result}
                    />
                  ))}
                </div>
              </LoadingWrapper>
            </>
          )}
        </DrawerBody>
        <DrawerFooter className="flex justify-end">
          <Button
            form="rerun-ai-request"
            type="submit"
            loading={isMutating}
            theme={ButtonTheme.Primary}
            analyticsTrackingId={null}
          >
            {hasResults ? "Re-run" : "Run"}
          </Button>
        </DrawerFooter>
      </DrawerContents>
    </Drawer>
  );
};

/**
 * Results of rerunning a prompt against a single model.
 *
 * The header shows the model name, the total cost (prompt + checks, with a
 * per-part breakdown in a tooltip) and an expand/collapse-all toggle. Below it
 * is a table with one row per case (`result.cases`) and one column per check.
 *
 * @param result - the model's eval result
 * @param expected - the expected output that each case's actual output is diffed against
 */
const EvalResultForModel = ({
  result,
  expected,
}: {
  result: AIModelResult;
  expected: string;
}) => {
  const [expandAll, setExpandAll] = useState(false);

  // Union of check names across all cases, in first-seen order. EvalResultRow
  // renders a "?" placeholder for checks a given case doesn't have. With no
  // cases this is [] rather than undefined, which previously crashed on
  // `.length` below.
  const orderedCheckNames = Array.from(
    new Set(
      result.cases.flatMap((c) => c.check_results.map((check) => check.name)),
    ),
  );

  // `repeat(0, 1fr)` is invalid CSS, so only emit it when there are checks.
  const gridTemplateColumns =
    orderedCheckNames.length > 0
      ? `repeat(${orderedCheckNames.length}, 1fr) auto`
      : "auto";

  const requestCostSum = sumBy(result.cases, (c) => c.request_cost_cents);
  const checkCostSum = sumBy(result.cases, (c) => c.check_cost_cents);
  const totalCost = requestCostSum + checkCostSum;

  return (
    <div className="flex flex-col gap-2">
      <div className="flex items-center gap-2 justify-between">
        <div className="text-base-bold">{result.model_name}</div>
        <div className="text-content-tertiary flex items-center gap-2">
          <Tooltip
            content={
              <div>
                <p>
                  Cost of running the prompts themselves:{" "}
                  {displayCost(requestCostSum, 4)}
                </p>
                <p>
                  Cost of running the checks: {displayCost(checkCostSum, 4)}
                </p>
              </div>
            }
          >
            <div className="text-content-tertiary flex items-center gap-0.5">
              <Icon id={IconEnum.PiggyBank} size={IconSize.Small} />
              {displayCost(totalCost)}
            </div>
          </Tooltip>
          <Button
            size={BadgeSize.Small}
            onClick={() => setExpandAll((prev) => !prev)}
            analyticsTrackingId={null}
          >
            {expandAll ? "Collapse all" : "Expand all"}
          </Button>
        </div>
      </div>
      <Table
        gridTemplateColumns={gridTemplateColumns}
        data={result.cases}
        wrappedInBox
        // Remount when expandAll changes so each row re-reads expandByDefault
        key={expandAll ? "expanded" : "collapsed"}
        header={
          <>
            {orderedCheckNames.map((name) => (
              <TableHeaderCell key={name} title={name} />
            ))}
            <TableHeaderCell />
          </>
        }
        renderRow={(rowResult, idx) => (
          <EvalResultRow
            expected={expected}
            key={idx}
            result={rowResult}
            // A single case is always shown expanded
            expandByDefault={result.cases.length === 1 || expandAll}
            orderedCheckNames={orderedCheckNames}
          />
        )}
      />
    </div>
  );
};

/**
 * A single case (one run) of a model eval, rendered as two rows.
 *
 * The summary row has a pass/fail cell per check (hover for the reasoning) and
 * the case cost. Clicking it toggles a detail row that lists the reasoning for
 * failed checks and diffs the actual output against `expected`.
 *
 * @param expected - expected output used as the diff base
 * @param result - this case's eval result
 * @param orderedCheckNames - column order; checks missing from this case render a "?"
 * @param expandByDefault - initial expanded state (read only on mount)
 */
const EvalResultRow = ({
  expected,
  result,
  orderedCheckNames,
  expandByDefault,
}: {
  expected: string;
  result: AIEvalResult;
  orderedCheckNames: string[];
  expandByDefault: boolean;
}) => {
  const [expanded, setExpanded] = useState(expandByDefault);

  const failedChecks = result.check_results.filter((c) => !c.pass);

  return (
    <>
      <TableRow
        isLastRow={false}
        onClick={() => setExpanded((prev) => !prev)}
        // `group` is required for the `group-hover:` styles on the expand icon
        className="group hover:bg-gray-50 cursor-pointer"
      >
        {orderedCheckNames.map((name, idx) => {
          const check = result.check_results.find((c) => c.name === name);
          if (!check) {
            // This case has no result for this check
            return (
              <TableCell key={idx}>
                <Icon id={IconEnum.QuestionMark} className="text-slate-300" />
              </TableCell>
            );
          }

          return (
            <TableCell key={idx}>
              <Tooltip content={check.reasoning}>
                <div className="flex items-center gap-2">
                  <Icon
                    id={check.pass ? IconEnum.TickCircle : IconEnum.CloseCircle}
                    className={check.pass ? "text-green-600" : "text-red-600"}
                  />
                  {check.grade !== undefined && <span>{check.grade}</span>}
                </div>
              </Tooltip>
            </TableCell>
          );
        })}
        <TableCell>
          <Tooltip
            content={
              <div>
                <p>
                  Cost of running the prompt itself:{" "}
                  {displayCost(result.request_cost_cents, 4)}
                </p>
                <p>
                  Cost of running the checks:{" "}
                  {displayCost(result.check_cost_cents, 4)}
                </p>
              </div>
            }
          >
            <div className="text-content-tertiary flex items-center gap-0.5">
              <Icon id={IconEnum.PiggyBank} size={IconSize.Small} />
              {displayCost(result.request_cost_cents + result.check_cost_cents)}
            </div>
          </Tooltip>
          <Icon
            id={expanded ? IconEnum.Collapse : IconEnum.Expand}
            className="text-slate-600 group-hover:text-slate-900 transition"
          />
        </TableCell>
      </TableRow>

      {expanded && (
        <TableRow isLastRow={false}>
          <TableCell
            // Span every check column plus the trailing cost column
            gridColumn={`span ${orderedCheckNames.length + 1}`}
            className="flex flex-col gap-4 items-start bg-surface-secondary"
          >
            {/* Show the reasoning for any failed checks */}
            {failedChecks.length > 0 && (
              <div className="flex flex-col gap-2 pt-2">
                <div className="text-sm-bold">Failed checks</div>
                {failedChecks.map((check) => (
                  <LabeledValue
                    key={check.name}
                    label={check.name}
                    labelClassName="w-32"
                    value={check.reasoning}
                  />
                ))}
              </div>
            )}
            <CodeViewer
              mode="yaml"
              title="Actual"
              content={result.actual}
              codeClassName="bg-white"
              diffWith={expected}
            />
          </TableCell>
        </TableRow>
      )}
    </>
  );
};
