import React, { useCallback, useEffect } from 'react';
import {
  CardContent,
  ToggleButton,
  Box,
  Stack,
  Typography,
  Paper,
} from '@mui/material';

import { StyledToggleButtonGroup, TabPanel } from 'components/atoms';
import { ModelTabView } from 'types/app';
import { executeGraphqlOperation } from 'api';
import {
  twinsGetMemoryQuery,
  twinsGetModelQuery,
  twinsGetPromptQuery,
} from 'graphql/queries';
import { GraphQLResult } from 'aws-amplify/api';
import { useParams } from 'react-router-dom';
import { Memory, Model, Prompt as PromptType } from '@twins/types';
import { useModelConfig } from 'use/model-config';
import { PromptPosition } from 'types/enums';
import { tabs } from './tab-definition';
import { useUser } from 'use/user';

/**
 * Settings screen for a single model configuration.
 *
 * Reads `modelConfigID` from the route. It loads that config's memory, models
 * and prompts over GraphQL into the `useModelConfig` store, then renders a
 * toggle-button tab bar. Only the selected tab's component is rendered below it.
 */
export default function ModelSettings() {
  const { modelConfigID } = useParams<{ modelConfigID: string }>();
  const { setPrompts, setMemory, setModels, setSystemPrompt, setFirstMessage } =
    useModelConfig();
  const [value, setValue] = React.useState<ModelTabView>('model');
  const { getJWT } = useUser();

  // An exclusive ToggleButtonGroup reports `null` when the already-selected
  // button is clicked again. Ignore that so one tab always stays selected.
  const handleChange = (
    _event: React.MouseEvent<HTMLElement>,
    newValue: ModelTabView | null,
  ) => {
    if (newValue !== null) {
      setValue(newValue);
    }
  };

  /**
   * Fetches the memory entries of the current config and stores them.
   *
   * @param isCurrent - checked after the request resolves. When it returns
   *   false, the response is discarded because it belongs to an effect run
   *   that has already been cleaned up.
   */
  const getMemory = useCallback(
    async (isCurrent: () => boolean = () => true) => {
      const jwt = await getJWT();
      const { data }: GraphQLResult<unknown> = await executeGraphqlOperation(
        twinsGetMemoryQuery,
        { input: { modelConfigID } },
        jwt,
      );
      if (!data || !isCurrent()) return;
      // GraphQL can return a null field (e.g. alongside partial errors);
      // skip the update instead of throwing on `.data` of null.
      const memory = (
        data as { twinsGetMemory?: { data?: Memory[] | null } | null }
      ).twinsGetMemory?.data;
      if (memory) {
        setMemory(memory);
      }
    },
    [getJWT, modelConfigID, setMemory],
  );

  /**
   * Fetches the prompts of the current config and stores them. It also copies
   * the MAIN prompt's content into the system prompt and the START prompt's
   * content into the first message, when those prompts exist.
   *
   * @param isCurrent - see `getMemory`.
   */
  const getPrompts = useCallback(
    async (isCurrent: () => boolean = () => true) => {
      const jwt = await getJWT();
      const { data }: GraphQLResult<unknown> = await executeGraphqlOperation(
        twinsGetPromptQuery,
        { input: { modelConfigID } },
        jwt,
      );
      if (!data || !isCurrent()) return;
      const prompts = (
        data as { twinsGetPrompt?: { data?: PromptType[] | null } | null }
      ).twinsGetPrompt?.data;
      if (!prompts) return;

      setPrompts(prompts);
      const mainPrompt = prompts.find(
        (prompt) => prompt.position === PromptPosition.MAIN,
      );
      const startPrompt = prompts.find(
        (prompt) => prompt.position === PromptPosition.START,
      );
      if (mainPrompt) {
        setSystemPrompt(mainPrompt.content ?? '');
      }
      if (startPrompt) {
        setFirstMessage(startPrompt.content ?? '');
      }
    },
    [getJWT, modelConfigID, setPrompts, setSystemPrompt, setFirstMessage],
  );

  /**
   * Fetches the models of the current config and stores them.
   *
   * @param isCurrent - see `getMemory`.
   */
  const getModels = useCallback(
    async (isCurrent: () => boolean = () => true) => {
      const jwt = await getJWT();
      const { data }: GraphQLResult<unknown> = await executeGraphqlOperation(
        twinsGetModelQuery,
        { input: { modelConfigID } },
        jwt,
      );
      if (!data || !isCurrent()) return;
      const models = (
        data as { twinsGetModel?: { data?: Model[] | null } | null }
      ).twinsGetModel?.data;
      if (models) {
        setModels(models);
      }
    },
    [getJWT, modelConfigID, setModels],
  );

  useEffect(() => {
    // Without an ID in the route there is nothing to load.
    if (!modelConfigID) return undefined;

    // Cleared on cleanup so that responses for a previous modelConfigID (or
    // arriving after unmount) do not overwrite the shared store.
    let active = true;
    const isCurrent = () => active;
    const reportFailure = (what: string) => (error: unknown) => {
      console.error(
        `Failed to load ${what} for model config ${modelConfigID}`,
        error,
      );
    };

    // The three requests are independent, so they run concurrently. A
    // failure in one is reported and does not affect the others.
    getMemory(isCurrent).catch(reportFailure('memory'));
    getModels(isCurrent).catch(reportFailure('models'));
    getPrompts(isCurrent).catch(reportFailure('prompts'));

    return () => {
      active = false;
    };
  }, [modelConfigID, getPrompts, getModels, getMemory]);

  return (
    <Paper sx={{ p: 2 }}>
      <Box
        mb={2}
        display="flex"
        justifyContent="center"
      >
        <Paper
          elevation={0}
          sx={{
            display: 'flex',
            border: (theme) => `1px solid ${theme.palette.divider}`,
          }}
        >
          <StyledToggleButtonGroup
            value={value}
            exclusive
            onChange={handleChange}
            aria-label="model settings toggle buttons"
            sx={{ justifyContent: 'center' }}
          >
            {tabs.map((tab) => (
              <ToggleButton
                size="small"
                key={tab.value}
                value={tab.value}
                aria-label={tab.label}
                sx={{ flex: 1, textTransform: 'none' }}
                disabled={tab.disabled}
              >
                <Stack
                  direction="row"
                  justifyContent="center"
                  alignItems="center"
                  spacing={1}
                >
                  <Typography
                    variant="body2"
                    sx={{ fontWeight: 'bold' }}
                  >
                    {tab.label}
                  </Typography>
                </Stack>
              </ToggleButton>
            ))}
          </StyledToggleButtonGroup>
        </Paper>
      </Box>
      <CardContent style={{ padding: 0 }}>
        {tabs.map((tab, index) => (
          <TabPanel
            // Same key as the matching ToggleButton above.
            key={tab.value}
            index={index}
            value={tab.key}
          >
            {/* Only the selected tab's component is mounted. */}
            {value === tab.value && tab.component}
          </TabPanel>
        ))}
      </CardContent>
    </Paper>
  );
}
