import { compact } from 'lodash';
import { ReactNode, useEffect, useState } from 'react';
import { FieldValues } from 'react-hook-form';
import { Merge, UnionToIntersection } from 'system';
import { SchemaOf, object } from 'yup';
import { z } from 'zod';

/**
 * Signature shared by `WizardAction.onClick` and every `splitOptions[].onClick`.
 *
 * Receives the current form `values` together with the wizard's navigation helpers and a
 * `last` flag. It may return (or resolve to) a record, or nothing at all. How the caller uses a
 * returned record is not visible here (TODO: confirm against the wizard component).
 */
export type WizardActionHandler<T extends FieldValues = Record<string, unknown>> = (context: {
  values: Merge<UnionToIntersection<T>>;
  handleNext: () => void;
  handleBack: () => void;
  last: boolean;
}) => Record<string, unknown> | Promise<Record<string, unknown>> | void | Promise<void>;

/**
 * A button-like action attached to a wizard step.
 *
 * Every field is optional. Beyond the declared types, the behavior of these flags lives in the
 * rendering component. The per-field notes below are inferred from names and should be
 * confirmed there.
 */
export type WizardAction<T extends FieldValues = Record<string, unknown>> = {
  label?: string;
  onClick?: WizardActionHandler<T>;
  disabled?: boolean;
  /** Presumably hides a progress indicator while the action runs. Confirm in the renderer. */
  hideProgress?: boolean;
  /** Presumably bypasses form validation before `onClick` runs. Confirm in the renderer. */
  skipValidation?: boolean;
  align?: 'left' | 'right';
  variant?: 'text' | 'outlined' | 'contained';
  /** Additional options, presumably shown as a split-button menu; each has its own handler. */
  splitOptions?: {
    loading?: boolean;
    label?: string;
    onClick?: WizardActionHandler<T>;
  }[];
  loading?: boolean;
};

/**
 * Declarative description of a single wizard step.
 *
 * @typeParam TFields - Shape of the form values this step validates.
 * @typeParam TFieldIds - Type of the optional step identifier.
 * @typeParam TPropTypes - Props type accepted by `FieldsComponent`.
 * @typeParam TComponent - Component that renders this step's fields.
 */
export type WizardStep<
  TFields extends FieldValues = Record<string, unknown>,
  TFieldIds = unknown,
  TPropTypes extends Record<string, unknown> = Record<string, unknown>,
  TComponent extends (props: TPropTypes) => JSX.Element = (props: TPropTypes) => JSX.Element,
> = {
  /** Optional step identifier; its meaning is defined by the consumer. */
  id?: TFieldIds;

  /** Optional heading shown for the step (icon may be a string or any React node). */
  header?: {
    icon?: string | ReactNode;
    subLabel?: string;
    label?: string;
  };

  /**
   * yup validation schema for this step. `useWizard` concatenates the yup schemas of every step
   * up to and including the active one. When any of those steps has a yup schema, the
   * `shape`s are ignored.
   */
  schema?: SchemaOf<TFields>;
  /**
   * zod validation type for this step. `useWizard` intersects these only when none of the
   * steps reached so far defines a yup `schema`.
   */
  shape?: z.ZodType<TFields>;
  /** Component that renders the step's form fields. */
  FieldsComponent: TComponent;
  /** Presumably passed through to `FieldsComponent`. Confirm in the renderer. */
  props?: TPropTypes;

  /** Primary actions for the step. */
  actions?: WizardAction<TFields>[];
  /** Secondary actions; how they differ from `actions` is decided by the renderer. */
  extraActions?: WizardAction<TFields>[];
};

/** Clamps `index` into `[0, stepCount - 1]`, yielding 0 when there are no steps. */
function clampStepIndex(index: number, stepCount: number): number {
  return Math.max(0, Math.min(index, stepCount - 1));
}

/**
 * Tracks the active step of a multi-step wizard and builds the validation schema that applies
 * to it.
 *
 * @param steps - Ordered wizard steps.
 * @param startingStep - Zero-based index of the initially active step (defaults to 0).
 * @param goToStep - One-based step number to jump to whenever this value changes. `undefined`
 *   and values below 1 are ignored.
 * @returns Navigation handlers plus the following values:
 *   - `schema`: the accumulated validation schema.
 *   - `stepIndex` and `activeStep`: the current position, where `stepIndex` is always clamped
 *     to the valid range of `steps`.
 *   - `first` and `last`: flags marking the first and last step.
 */
export function useWizard<T extends Record<string, unknown>>({
  steps,
  startingStep,
  goToStep,
}: {
  steps: WizardStep<T>[];
  startingStep?: number;
  goToStep?: number;
}) {
  const [rawStepIndex, setStepIndex] = useState(startingStep ?? 0);

  // `goToStep` is 1-based; jump only when it names a real step number.
  useEffect(() => {
    if (goToStep !== undefined && goToStep >= 1) setStepIndex(goToStep - 1);
  }, [goToStep]);

  // The requested index may lie past the end (large `startingStep`/`goToStep`), or `steps` may
  // shrink. Clamp on read so `activeStep`, `first` and `last` always describe a real step.
  const stepIndex = clampStepIndex(rawStepIndex, steps.length);
  const activeStep = steps[stepIndex];
  const first = stepIndex === 0;
  const last = stepIndex === steps.length - 1;

  /**
   * Advances one step (no-op on the last step) after a 1 ms timeout.
   *
   * The timeout presumably lets pending form/state updates flush first (confirm). The promise
   * resolves once the step change has been requested. The closure reads `stepIndex` from this
   * render, so repeated calls before a re-render all target the same next step.
   */
  const handleNext = () => {
    return new Promise((resolve) => {
      setTimeout(() => {
        if (!last) {
          setStepIndex(Math.min(stepIndex + 1, steps.length - 1));
        }
        resolve('');
      }, 1);
    });
  };

  /** Goes back one step; does nothing on the first step. */
  const handleBack = () => {
    if (!first) {
      setStepIndex(Math.max(stepIndex - 1, 0));
    }
  };

  // Validation covers every step reached so far, including the active one.
  const reachedSteps = steps.slice(0, stepIndex + 1);
  const stepSchemas = compact(reachedSteps.map((step) => step.schema));
  const stepShapes = compact(reachedSteps.map((step) => step.shape));

  // The schema is chosen as follows:
  // - If any reached step has a yup schema, concatenate the yup schemas.
  // - Otherwise, intersect the zod shapes.
  // - If neither exists, use an empty yup object.
  const schema = stepSchemas.length
    ? stepSchemas.reduce((currentSchema, stepSchema) => currentSchema.concat(stepSchema), object())
    : stepShapes.length
      ? stepShapes.reduce((currentShape, stepShape) => currentShape.and(stepShape))
      : object();

  return {
    handleNext,
    handleBack,
    schema,
    stepIndex,
    activeStep,
    first,
    last,
  };
}
