import { CircularProgress, Skeleton } from '@mui/material';
import moment from 'moment-timezone';
import React, { useEffect, useMemo, useRef } from 'react';
import { useState } from 'react';

import { PatientAssessmentRead } from '@headway/api/models/PatientAssessmentRead';
import { PatientAssessmentType } from '@headway/api/models/PatientAssessmentType';
import { Badge } from '@headway/helix/Badge';
import { Button } from '@headway/helix/Button';
import { IconCaretRight } from '@headway/helix/icons/CaretRight';
import { InfoIconTip } from '@headway/helix/InfoIconTip';
import { SubBodyText } from '@headway/helix/SubBodyText';
import {
  Cell,
  Column,
  Row,
  Table,
  TableBody,
  TableHeader,
} from '@headway/helix/Table';

import { usePatientAssessmentCache } from 'hooks/usePatientAssessment';
import { usePatientAssessmentRecurrenceSchedules } from 'hooks/usePatientAssessmentRecurrenceSchedules';
import { usePatientAssessments } from 'hooks/usePatientAssessments';

import { AssessmentBadgeInfo, SelectedAssessmentInfo } from '../helpers/types';
import { indicatesSiRisk, isASRSInvalid } from '../helpers/utils';
import { ResultBadges, ScoreBadge } from './InsightBadges';
import { ScoreDiffBadge } from './InsightBadges';

/**
 * Number of most-recent assessments an assessment table lists before the user
 * expands it with "Show all". The collapsed query fetches one more than this so
 * the last visible row can still display a score difference.
 */
export const INITIAL_NUM_ASSESSMENTS_SHOWN = 3;

/** Props for the "Result" column tooltip. */
interface ResultTooltipProps {
  /** Assessment whose type and scorable responses decide which tooltip, if any, is shown. */
  assessment: PatientAssessmentRead;
}

/**
 * Optional info tooltip rendered next to an assessment's result badges.
 *
 * - When the assessment has scorable responses that `isASRSInvalid` rejects,
 *   warns that the assessment is designed for adults.
 * - Otherwise, for ANCHOR assessments, explains that there is no score.
 * - In every other case renders nothing.
 */
const ResultTooltip = ({ assessment }: ResultTooltipProps) => {
  const { assessmentType: type, scorableResponseJson: responses } = assessment;

  if (responses && isASRSInvalid(type, responses)) {
    return (
      <InfoIconTip size="medium">
        This assessment is designed for people 18+
      </InfoIconTip>
    );
  }

  if (type === PatientAssessmentType.ANCHOR) {
    return <InfoIconTip>This assessment does not have a score</InfoIconTip>;
  }

  return null;
};

/** Props for the per-type assessment table. */
interface AssessmentTableProps {
  /** Provider-patient whose assessments and recurrence schedules are loaded. */
  providerPatientId: number;
  /** The single assessment type this table lists. */
  assessmentType: PatientAssessmentType;
  /**
   * Called when a row is activated, with that assessment's id and the score of
   * the next-older assessment in the fetched list (undefined if there is none).
   */
  onAssessmentSelected: (
    selectedAssessmentInfo: SelectedAssessmentInfo
  ) => void;

  /**
   * Intended to be called once per table instance, the first time fetched
   * assessments are visible, with the ids of those rows (never the scheduled
   * placeholder) and this table's assessment type. Feeds the
   * 'Client Assessment Page Viewed' event.
   */
  addInitiallyVisibleAssessmentIds: (
    assessmentIdsByType: Set<number>,
    assessmentType: PatientAssessmentType
  ) => void;
}

/**
 * Builds the query arguments the assessment table uses to fetch one
 * provider-patient's assessments of a single type, newest first.
 *
 * While collapsed, `INITIAL_NUM_ASSESSMENTS_SHOWN + 1` rows are requested: the
 * extra (hidden) row gives the last visible assessment a predecessor for its
 * "Since last" score difference. While expanded, no limit is sent.
 *
 * @param providerPatientId - provider-patient to query; passed through as-is.
 * @param assessmentType - assessment type to filter on.
 * @param showingAll - whether the table is expanded.
 * @returns snake_case query arguments for `usePatientAssessments`.
 */
export const getAssessmentTableQueryKeyArgs = (
  providerPatientId: number | undefined,
  assessmentType: PatientAssessmentType,
  showingAll: boolean
) => {
  const collapsedLimit = INITIAL_NUM_ASSESSMENTS_SHOWN + 1;
  const limit = showingAll ? undefined : collapsedLimit;

  return {
    provider_patient_id: providerPatientId,
    assessment_type: assessmentType,
    limit,
    order_by: 'original_created_on',
    order: 'desc',
  };
};

/**
 * Displays a table of all assessments of one type for a given provider-patient,
 * newest first.
 *
 * - Only the newest `INITIAL_NUM_ASSESSMENTS_SHOWN` assessments are listed until
 *   the user presses "Show all" (offered when more than that many exist).
 * - If an active recurrence schedule exists for this type, a disabled
 *   "(scheduled)" placeholder row for its next scheduled date is listed first.
 * - Rows without a completion date are disabled; activating any other row calls
 *   `onAssessmentSelected`.
 * - A skeleton is rendered until assessment data is available.
 */
export const AssessmentTable = ({
  providerPatientId,
  assessmentType,
  onAssessmentSelected,
  addInitiallyVisibleAssessmentIds,
}: AssessmentTableProps) => {
  const patientAssessmentCache = usePatientAssessmentCache();
  const [showingAll, setShowingAll] = useState<boolean>(false);
  // Guards the one-time addInitiallyVisibleAssessmentIds call. A ref rather than
  // state: flipping it needs no re-render, and it keeps its value across
  // StrictMode's double effect invocation, so the callback cannot fire twice.
  const hasReportedInitiallyVisibleIdsRef = useRef(false);

  const { data, isPreviousData, isFetchedAfterMount } = usePatientAssessments(
    getAssessmentTableQueryKeyArgs(
      providerPatientId,
      assessmentType,
      showingAll
    ),
    {
      refetchOnWindowFocus: false,
      keepPreviousData: true,
    }
  );

  // Proactively update single-assessment caches with fetched data, so the results modal loads
  // instantly.
  useEffect(() => {
    if (isFetchedAfterMount && data) {
      for (const assessment of data.data) {
        patientAssessmentCache.set({ assessmentId: assessment.id }, assessment);
      }
    }
  }, [isFetchedAfterMount, data, patientAssessmentCache]);

  // Newest first (the query orders by original_created_on desc). While collapsed
  // this may hold one extra, hidden assessment used only for score differences.
  const sortedAssessments = useMemo(() => data?.data ?? [], [data?.data]);
  const assessmentsToShow = useMemo(
    () =>
      showingAll
        ? sortedAssessments
        : sortedAssessments.slice(0, INITIAL_NUM_ASSESSMENTS_SHOWN),
    [showingAll, sortedAssessments]
  );
  const totalCount = data?.totalCount ?? 0;

  // Get active recurrence schedules
  const { data: patientAssessmentRecurrenceSchedules } =
    usePatientAssessmentRecurrenceSchedules(
      {
        providerPatientId: providerPatientId,
      },
      {
        refetchOnWindowFocus: false,
      }
    );
  const patientAssessmentRecurrenceSchedule =
    patientAssessmentRecurrenceSchedules?.find(
      (schedule) => schedule.assessmentType === assessmentType
    );

  // Prepend assessmentsToShow with a placeholder for the next scheduled assessment.
  // Its null id/completedOn make it render as a disabled "(scheduled)" row.
  const scheduledAssessment = patientAssessmentRecurrenceSchedule && {
    score: null,
    originalCreatedOn: patientAssessmentRecurrenceSchedule.nextScheduledDate,
    completedOn: null,
    id: null,
    assessmentType: assessmentType,
    safetyRisk: null,
    scorableResponseJson: null,
    subscores: null,
  };
  const allAssessments = scheduledAssessment
    ? [scheduledAssessment, ...assessmentsToShow]
    : assessmentsToShow;
  // Same prefix as allAssessments, so row index `idx` lines up and `idx + 1` is
  // the next-older assessment (possibly the hidden extra one).
  const allSortedAssessments = scheduledAssessment
    ? [scheduledAssessment, ...sortedAssessments]
    : sortedAssessments;

  const labelId = `assessment-table-label-${providerPatientId}-${assessmentType}`;

  const disabledAssessmentKeys = new Set(
    allAssessments
      .filter((assessment) => !assessment.completedOn)
      .map((assessment) => String(assessment.id))
  );

  const indicesByAssessmentId = useMemo(
    () =>
      sortedAssessments.reduce(
        (acc, current, idx) => {
          acc[current.id] = idx;
          return acc;
        },
        {} as { [id: number]: number }
      ),
    [sortedAssessments]
  );

  // Report the assessments visible the first time any are, for the
  // 'Client Assessment Page Viewed' event; the scheduled placeholder is excluded.
  // This lives in an effect because the callback belongs to the parent: if it
  // updates parent state, calling it during render makes React warn about
  // updating one component while rendering another.
  useEffect(() => {
    if (
      hasReportedInitiallyVisibleIdsRef.current ||
      assessmentsToShow.length === 0
    ) {
      return;
    }
    hasReportedInitiallyVisibleIdsRef.current = true;
    addInitiallyVisibleAssessmentIds(
      new Set(assessmentsToShow.map((assessment) => assessment.id)),
      assessmentType
    );
  }, [assessmentsToShow, assessmentType, addInitiallyVisibleAssessmentIds]);

  const handleRowAction = (key: React.Key) => {
    const id = Number(key);
    onAssessmentSelected({
      id: id,
      previousScore: sortedAssessments[indicesByAssessmentId[id] + 1]?.score,
    });
  };

  // No assessment data yet (still loading, or the query produced none, e.g. on
  // error): show a placeholder instead of an empty table.
  if (!data) {
    return <Skeleton variant="rectangular" height={60} />;
  }

  return (
    <>
      <Table
        aria-labelledby={labelId}
        disabledKeys={disabledAssessmentKeys}
        onRowAction={handleRowAction}
      >
        <TableHeader>
          <Column width={170}>Sent</Column>
          <Column width={135}>Completed</Column>
          <Column width={300}>Result</Column>
          <Column width={135} align="right">
            {assessmentType === PatientAssessmentType.ASRS
              ? 'Percentile'
              : 'Score'}
          </Column>
          <Column width={135} align="right">
            Since last
          </Column>
          <Column align="right"> </Column>
        </TableHeader>
        <TableBody>
          {allAssessments.map((assessment, idx) => {
            const {
              score,
              originalCreatedOn,
              completedOn,
              id,
              assessmentType,
              safetyRisk,
              subscores,
            } = assessment;
            // Score of the next-older assessment. 0 is a valid score, so test
            // for null/undefined rather than truthiness.
            const prevScore = allSortedAssessments[idx + 1]?.score;
            return (
              <Row key={id}>
                <Cell>
                  <SubBodyText color="gray">
                    {`${moment(originalCreatedOn).format('M/D/YY')}${
                      id === null ? ' (scheduled)' : ''
                    }`}
                  </SubBodyText>
                </Cell>
                <Cell>
                  <SubBodyText color="gray">
                    {completedOn ? (
                      moment(completedOn).format('M/D/YY')
                    ) : (
                      <SubBodyText>—</SubBodyText>
                    )}
                  </SubBodyText>
                </Cell>
                <Cell>
                  {score != null ||
                  isASRSInvalid(
                    assessmentType,
                    assessment.scorableResponseJson
                  ) ||
                  (assessmentType === PatientAssessmentType.ANCHOR &&
                    completedOn != null) ? (
                    <div className="flex flex-row gap-x-1">
                      {indicatesSiRisk(assessmentType, safetyRisk) && (
                        <Badge variant="negative">Safety risk</Badge>
                      )}
                      <ResultBadges
                        assessment={assessment as AssessmentBadgeInfo}
                      />
                      <ResultTooltip
                        assessment={assessment as PatientAssessmentRead}
                      />
                    </div>
                  ) : (
                    <SubBodyText color="gray">—</SubBodyText>
                  )}
                </Cell>
                <Cell>
                  <div className="flex justify-end">
                    {score != null ? (
                      <ScoreBadge assessment={assessment} />
                    ) : (
                      <SubBodyText color="gray">—</SubBodyText>
                    )}
                  </div>
                </Cell>
                <Cell>
                  {score != null && prevScore != null ? (
                    <div className="flex justify-end">
                      <ScoreDiffBadge
                        assessmentType={assessmentType}
                        prevScore={prevScore}
                        currentScore={score}
                      />
                    </div>
                  ) : (
                    <SubBodyText>—</SubBodyText>
                  )}
                </Cell>
                <Cell>
                  {completedOn && (
                    <div className="flex">
                      <IconCaretRight size="1em" aria-label="View assessment" />
                    </div>
                  )}
                </Cell>
              </Row>
            );
          })}
        </TableBody>
      </Table>
      {totalCount > INITIAL_NUM_ASSESSMENTS_SHOWN && (
        <div className="flex flex-row-reverse px-4 py-2">
          {isPreviousData ? (
            <CircularProgress size={22} />
          ) : (
            <Button
              variant="link"
              onPress={() => setShowingAll((curr) => !curr)}
            >
              {showingAll ? 'Show less' : 'Show all'}
            </Button>
          )}
        </div>
      )}
    </>
  );
};
