import { GridCellParams, GridColDef, GridRowSelectionModel, useGridApiRef } from '@mui/x-data-grid-pro';
import { InformaticsQcMetrics, InformaticsQcThresholdEnum, SequenceType, Thresholds } from 'data/InformaticsQcData';
import { QualityCheckStatus, QualityCheckStatuses } from 'data/SampleTrackingData';
import { useEffect, useMemo, useState } from 'react';
import clsx from 'clsx';
import { InformaticsQcMetricsGridToolbar } from './InformaticsQcMetricsGridToolbar';
import { GridThresholdFilterOperators } from 'components/grid/columnFilters/GridThresholdFilterOperators';
import {
  compactGridHeaderClassName,
  labAssignedSampleId,
  medianCoverage,
  multiQc,
  percentDuplication,
  percentGcContent,
  qualityCheckStatus,
  rawTotalSequences,
  readsAligned,
  readsMapped,
  readsMappedPercent,
  readsProperlyPairedPercent,
  sampleId,
  sampleIdentifier,
  selectionStatus,
  sequenceRunId,
  totalReads,
  trimmedR1PercentGc,
  trimmedR2PercentGc,
  uniquelyMapped,
  uniquelyMappedPercent,
} from 'util/Constants';
import { renderCellMultiQcDownloadButton } from 'components/grid/GridCellMultiQcDownloadButton';
import { CompactGridWrapper } from '../../components/grid/CompactGridWrapper';
import useMemoTranslation from 'hooks/UseMemoTranslation';
import { uniq } from 'lodash';

/**
 * One input record for {@link InformaticsQcMetricsGrid}: the QC metrics of a sample plus the
 * additional display/status fields declared here.
 */
export interface InformaticsCheckByGridData extends InformaticsQcMetrics {
  sequenceType: SequenceType;
  sampleIdentifier: string;
  selectionStatus: string;
  qualityCheckStatus: QualityCheckStatus;
  previousQcUpdated: boolean;
}

/** Props accepted by {@link InformaticsQcMetricsGrid}. */
export interface InformaticsQcMetricsGridProps {
  /** Sequence type; selects which metric columns and row transformations are used. */
  seqType: SequenceType;
  /** Records to show, one grid row each. */
  data: InformaticsCheckByGridData[];
  /** Invoked from the toolbar with the selected rows' sample ids and the chosen QC status. */
  onQcSelection(sampleIds: string[], selection: QualityCheckStatus): void;
}

/**
 * Compact data grid of informatics QC metrics for a single sequence type.
 *
 * Rows are checkbox-selectable. The toolbar (`InformaticsQcMetricsGridToolbar`) receives a click
 * handler that reports the selected rows' sample ids together with the chosen QC status. The
 * selection is cleared after a QC decision is submitted and whenever the filter model changes.
 *
 * @param props.seqType - Sequence type; forwarded to the column and row builders.
 * @param props.data - QC records to display.
 * @param props.onQcSelection - Called with the de-duplicated sample ids of the selected rows and the
 *   QC status picked in the toolbar. Not called when no sample id is selected.
 */
export const InformaticsQcMetricsGrid = ({ seqType, data, onQcSelection }: InformaticsQcMetricsGridProps) => {
  const apiRef = useGridApiRef();

  const columns = useColumns(seqType);
  const rows = useRows(data, seqType);

  // Controlled grid selection (row ids) and the sample ids derived from it by the effect below.
  const [selectionModel, setSelectionModel] = useState<GridRowSelectionModel>();
  const [selectedSampleIds, setSelectedSampleIds] = useState<string[]>([]);

  // Clear the selection whenever the filter model changes (presumably so rows hidden by the new
  // filter cannot be acted on).
  const handleFilterModelChange = () => {
    setSelectionModel([]);
  };

  const handleQcButtonClick = (qcSelection: QualityCheckStatus) => {
    // `selectedSampleIds` is always an array (never falsy), so test its length: with nothing
    // selected there is nothing to report.
    if (selectedSampleIds.length === 0) {
      return;
    }

    onQcSelection(selectedSampleIds, qcSelection);
    setSelectionModel([]);
  };

  // Resolve the selected grid row ids to sample ids; `uniq` drops duplicates in case several
  // selected rows carry the same sampleId.
  useEffect(() => {
    if (!selectionModel) {
      return;
    }

    const sampleIds: string[] = [];
    selectionModel.forEach(selected => {
      const id = apiRef.current?.getRow(selected)?.sampleId;
      if (id) {
        sampleIds.push(id);
      }
    });

    setSelectedSampleIds(uniq(sampleIds));
  }, [selectionModel, apiRef]);

  return (
    <CompactGridWrapper
      apiRef={apiRef}
      rows={rows}
      columns={columns}
      checkboxSelection
      disableRowSelectionOnClick
      rowSelectionModel={selectionModel}
      onRowSelectionModelChange={newSelectionModel => setSelectionModel(newSelectionModel)}
      onFilterModelChange={handleFilterModelChange}
      density='compact'
      hideFooterSelectedRowCount={true}
      columnVisibilityModel={{
        sampleId: false,
      }}
      initialState={{
        pinnedColumns: {
          left: ['__check__', labAssignedSampleId],
          right: [qualityCheckStatus],
        },
        // NOTE(review): only the DNA column set defines a medianCoverage column; confirm the grid
        // tolerates this initial sort field for RNA/TNA.
        sorting: {
          sortModel: [{ field: 'medianCoverage', sort: 'desc' }],
        },
      }}
      components={{
        Toolbar: InformaticsQcMetricsGridToolbar,
      }}
      componentsProps={{
        toolbar: {
          onQcButtonClick: handleQcButtonClick,
          anyRowSelected: selectedSampleIds.length !== 0,
          apiRef: apiRef,
        },
      }}
    />
  );
};

/** Parameters read by the numeric value formatters below (the `value` of the grid's formatter params). */
type NumericFormatterParams = { value?: number | null };

/** Structural view of a `Thresholds` entry: classifies a metric value into a QC threshold band. */
interface MetricThreshold {
  getThresholdEnum(value: number): InformaticsQcThresholdEnum;
}

/** Two-decimal display; null/undefined values yield `undefined`. */
const formatFixed2 = ({ value }: NumericFormatterParams) => (value == null ? undefined : value.toFixed(2));

/**
 * Rounds to two decimals, then applies locale formatting (e.g. digit grouping).
 * Null/undefined values yield `undefined` rather than the string "NaN" that
 * `Number(undefined).toLocaleString()` would produce.
 */
const formatLocaleFixed2 = ({ value }: NumericFormatterParams) =>
  value == null ? undefined : Number(value.toFixed(2)).toLocaleString();

/**
 * Builds a numeric metric column with a centred header and right-aligned cells. Each non-null cell
 * is coloured by its QC threshold band, and the column uses the threshold filter operators.
 *
 * @param field - Row field (and column id).
 * @param headerName - Already-translated header text.
 * @param width - Column width.
 * @param threshold - Classifies each non-null cell value for colouring.
 * @param valueFormatter - Optional display formatter. When omitted, no `valueFormatter` key is set
 *   at all, so nothing is overridden with `undefined`.
 */
function thresholdColumn(
  field: string,
  headerName: GridColDef['headerName'],
  width: number,
  threshold: MetricThreshold,
  valueFormatter?: (params: NumericFormatterParams) => string | undefined
): GridColDef {
  return {
    field,
    headerName,
    headerClassName: compactGridHeaderClassName,
    headerAlign: 'center',
    align: 'right',
    width,
    type: 'number',
    cellClassName: (params: GridCellParams) =>
      params.value == null ? '' : getThresholdCellClassName(threshold.getThresholdEnum(params.value as number)),
    ...(valueFormatter ? { valueFormatter } : {}),
    filterOperators: GridThresholdFilterOperators,
  };
}

/**
 * Memoized column definitions for the QC metrics grid.
 *
 * Order: `sampleId` and `sampleIdentifier`, then the metric columns for `seqType` (DNA or RNA; no
 * metric columns for any other type), then selection status, the MultiQC download button, and the
 * QC status.
 *
 * @param seqType - Selects the metric column set.
 * @returns Column definitions, recomputed only when `seqType` or the translation function changes.
 */
const useColumns = (seqType: SequenceType): GridColDef[] => {
  const { t } = useMemoTranslation();

  return useMemo(() => {
    const startColumns: GridColDef[] = [
      {
        field: sampleId,
        headerName: t(sampleId),
        filterable: false,
      },
      {
        field: sampleIdentifier,
        headerName: t(sampleIdentifier),
        headerClassName: compactGridHeaderClassName,
        width: 200,
      },
    ];

    let metricColumns: GridColDef[] = [];
    if (seqType === 'DNA') {
      metricColumns = [
        thresholdColumn(percentDuplication, t(percentDuplication), 110, Thresholds.percentDuplication, formatFixed2),
        // medianCoverage deliberately has no explicit formatter (matches the original definition).
        thresholdColumn(medianCoverage, t(medianCoverage), 125, Thresholds.medianCoverage),
        thresholdColumn(
          rawTotalSequences,
          t(rawTotalSequences),
          125,
          Thresholds.rawTotalSequences,
          formatLocaleFixed2
        ),
        thresholdColumn(readsMapped, t(readsMapped), 125, Thresholds.readsMapped, formatLocaleFixed2),
        thresholdColumn(readsMappedPercent, t(readsMappedPercent), 125, Thresholds.readsMappedPercent, formatFixed2),
        thresholdColumn(percentGcContent, t(percentGcContent), 125, Thresholds.percentGcContent, formatFixed2),
        thresholdColumn(
          readsProperlyPairedPercent,
          t(readsProperlyPairedPercent),
          125,
          Thresholds.readsProperlyPairedPercent,
          formatFixed2
        ),
      ];
    } else if (seqType === 'RNA') {
      metricColumns = [
        thresholdColumn(percentDuplication, t(percentDuplication), 110, Thresholds.percentDuplication, formatFixed2),
        thresholdColumn(totalReads, t(totalReads), 125, Thresholds.totalReads, formatLocaleFixed2),
        thresholdColumn(readsAligned, t(readsAligned), 125, Thresholds.readsAligned, formatLocaleFixed2),
        thresholdColumn(uniquelyMapped, t(uniquelyMapped), 125, Thresholds.uniquelyMapped, formatLocaleFixed2),
        thresholdColumn(
          uniquelyMappedPercent,
          t(uniquelyMappedPercent),
          125,
          Thresholds.uniquelyMappedPercent,
          formatFixed2
        ),
        thresholdColumn(trimmedR1PercentGc, t(trimmedR1PercentGc), 125, Thresholds.trimmedR1PercentGc, formatFixed2),
        thresholdColumn(trimmedR2PercentGc, t(trimmedR2PercentGc), 125, Thresholds.trimmedR2PercentGc, formatFixed2),
      ];
    }

    const endColumns: GridColDef[] = [
      {
        field: selectionStatus,
        headerName: t(selectionStatus),
        headerClassName: compactGridHeaderClassName,
        width: 140,
        headerAlign: 'center',
        align: 'center',
        valueFormatter: ({ value }) => value && t(value),
      },
      {
        field: sequenceRunId,
        headerName: t(multiQc),
        headerClassName: compactGridHeaderClassName,
        width: 80,
        headerAlign: 'center',
        align: 'center',
        filterable: false,
        sortable: false,
        renderCell: renderCellMultiQcDownloadButton,
      },
      {
        field: qualityCheckStatus,
        headerName: t(qualityCheckStatus),
        headerClassName: compactGridHeaderClassName,
        headerAlign: 'center',
        align: 'center',
        width: 135,
        type: 'singleSelect',
        valueOptions: QualityCheckStatuses,
        valueFormatter: ({ value }) => (value === 'NotYetDecided' ? 'Not Yet Decided' : t(value)),
        // Undecided rows render an empty cell; other statuses render their translation.
        renderCell: params => {
          if (params.value === 'NotYetDecided') {
            return <></>;
          } else {
            return <>{t(params.value)}</>;
          }
        },
      },
    ];

    return [...startColumns, ...metricColumns, ...endColumns];
  }, [seqType, t]);
};

/**
 * Builds the CSS class string for a threshold-coloured grid cell.
 *
 * The result always contains `livinglab-grid`, plus `fail`, `warn` or `pass` for whichever
 * threshold member `value` equals.
 *
 * @param value - Threshold band of the cell's value.
 * @returns Space-separated class names produced by `clsx`.
 */
function getThresholdCellClassName(value: InformaticsQcThresholdEnum) {
  const isFail = value === InformaticsQcThresholdEnum.Fail;
  const isWarn = value === InformaticsQcThresholdEnum.Warn;
  const isPass = value === InformaticsQcThresholdEnum.Pass;

  const bandFlags = { fail: isFail, warn: isWarn, pass: isPass };
  return clsx('livinglab-grid', bandFlags);
}

/**
 * Memoized grid rows built from `data`, each tagged with its array index as `id`.
 *
 * DNA and RNA rows pass their read-count metrics through `getMetricPerM`; the other listed metrics
 * are copied over unchanged. TNA rows are the input records plus `id`. Any other `seqType`
 * produces no rows.
 *
 * @param data - QC records to display.
 * @param seqType - Selects which metrics are rescaled.
 * @returns Row objects, recomputed only when `data` or `seqType` changes.
 */
const useRows = (data: InformaticsCheckByGridData[], seqType: SequenceType) =>
  useMemo(() => {
    let gridRows: any[];

    switch (seqType) {
      case 'DNA':
        gridRows = data.map((record, index) => ({
          id: index,
          ...record,
          percentDuplication: record.percentDuplication,
          medianCoverage: record.medianCoverage,
          rawTotalSequences: getMetricPerM(record.rawTotalSequences),
          readsMapped: getMetricPerM(record.readsMapped),
          readsMappedPercent: record.readsMappedPercent,
          percentGcContent: record.percentGcContent,
          readsProperlyPairedPercent: record.readsProperlyPairedPercent,
        }));
        break;
      case 'RNA':
        gridRows = data.map((record, index) => ({
          id: index,
          ...record,
          percentDuplication: record.percentDuplication,
          totalReads: getMetricPerM(record.totalReads),
          readsAligned: getMetricPerM(record.readsAligned),
          uniquelyMapped: getMetricPerM(record.uniquelyMapped),
          uniquelyMappedPercent: record.uniquelyMappedPercent,
          percentDups: record.percentDups,
          trimmedR1PercentGc: record.trimmedR1PercentGc,
          trimmedR2PercentGc: record.trimmedR2PercentGc,
        }));
        break;
      case 'TNA':
        gridRows = data.map((record, index) => ({ id: index, ...record }));
        break;
      default:
        gridRows = [];
    }

    return gridRows;
  }, [data, seqType]);

/**
 * Scales a metric down by a factor of 1000; a missing (null/undefined) metric is returned as-is.
 *
 * NOTE(review): the name suggests "per million", but the divisor is 1000. Confirm whether the
 * source values are already in thousands or whether the divisor should be 1_000_000.
 *
 * @param metric - Raw metric value, possibly undefined.
 * @returns `metric / 1000`, or `metric` unchanged when it is null/undefined.
 */
function getMetricPerM(metric?: number) {
  return metric == null ? metric : metric / 1000;
}
