import React from "react";
import axios from "axios";
import Checkbox from "@mui/material/Checkbox";
import Button from "@mui/material/Button";
import IconButton from "@mui/material/IconButton";
import ArrowBack from "@mui/icons-material/ArrowBack";
import ArrowForward from "@mui/icons-material/ArrowForward";
import Select from "@mui/material/Select";
import MenuItem from "@mui/material/MenuItem";

import { TokenHandler } from "../utils/token-handler";
import { get_access_token } from "../utils/authorization.js";

class InpaintingBenchmark extends React.PureComponent {
  constructor(props) {
    super(props);
    this.state = {
      currentTestId: "",
      tests: [],
      requestList: [],
      currentRequestIndex: 0,
      accessToken: "",
      imageRatings: {},
      selectedImage: null,
    };
  }

  componentDidMount() {
    const tokenCallbackFunc = (jwt) => {
      this.setState({ accessToken: jwt }, () => {
        this.tokenHandler = new TokenHandler(jwt);
        this.fetchAllTests(); // Move this inside setState callback
      });
    };
    get_access_token(tokenCallbackFunc);
    window.addEventListener("keydown", this.handleKeyPress);
  }

  componentWillUnmount() {
    window.removeEventListener("keydown", this.handleKeyPress);
  }

  handleKeyPress = (event) => {
    if (event.code === "ArrowRight") {
      this.handleNavigation(1);
    } else if (event.code === "ArrowLeft") {
      this.handleNavigation(-1);
    } else if (event.code === "Enter") {
      this.submitRating();
    }
  };

  setSelectedImage = (image_url) => {
    this.setState({ selectedImage: image_url });
  };

  fetchAllTests = () => {
    axios
      .get(
        "https://n7jwv35eki.execute-api.us-west-2.amazonaws.com/prod/benchmark_inpainting/test",
        {
          headers: {
            Authorization: `Bearer ${this.state.accessToken}`,
          },
        }
      )
      .then((response) => {
        this.setState({ tests: response.data });
      });
  };

  handleTestIdChange = (event) => {
    const newTestId = event.target.value;
    this.setState({ currentTestId: newTestId }, () =>
      this.handleFetchTest(newTestId)
    );
  };

  handleFetchTest = (testId) => {
    const selectedTest = this.state.tests.find(
      (test) => test.test_id === testId
    );

    axios
      .get(
        "https://n7jwv35eki.execute-api.us-west-2.amazonaws.com/prod/benchmark_inpainting/test",
        {
          params: selectedTest,
          headers: {
            Authorization: `Bearer ${this.state.accessToken}`,
          },
        }
      )
      .then((response) => {
        this.setState({ requestList: response.data });
      });
  };

  handleNavigation = (direction) => {
    let newIndex = this.state.currentRequestIndex + direction;
    if (newIndex >= 0 && newIndex < this.state.requestList.length) {
      this.setState({ currentRequestIndex: newIndex });
      this.setState(this.state.requestList[newIndex]);
    }
  };

  handleRating = (imageId, isChecked) => {
    this.setState((prevState) => ({
      imageRatings: {
        ...prevState.imageRatings,
        [imageId]: isChecked ? 1 : 0,
      },
    }));
  };

  submitRating = () => {
    axios
      .post(
        "https://n7jwv35eki.execute-api.us-west-2.amazonaws.com/prod/benchmark_inpainting/test",
        {
          user_id:
            this.state.requestList[this.state.currentRequestIndex].user_id,
          request_id:
            this.state.requestList[this.state.currentRequestIndex].request_id,
          labels: this.state.imageRatings,
        },
        {
          headers: {
            Authorization: `Bearer ${this.state.accessToken}`,
          },
        }
      )
      .then(() => {
        this.setState({ imageRatings: {} }); // Reset imageRatings to an empty object
      });

    let updatedList = [...this.state.requestList];
    updatedList.splice(this.state.currentRequestIndex, 1);

    this.setState({ requestList: updatedList }, () => this.handleNavigation(0));
  };

  render() {
    const isFirstImage = this.state.currentRequestIndex === 0;
    const isLastImage =
      this.state.currentRequestIndex === this.state.requestList.length - 1;

    return (
      <div style={{ marginTop: "50px" }}>
        <div>
          <Select
            value={this.state.currentTestId}
            onChange={this.handleTestIdChange}
            displayEmpty
            size="small"
            style={{ marginRight: "10px" }}
          >
            <MenuItem value="" disabled>
              Select a test
            </MenuItem>
            {this.state.tests.map((test) => (
              <MenuItem value={test.test_id} key={test.test_id}>
                {test.test_id}
              </MenuItem>
            ))}
          </Select>
        </div>
        {this.state.selectedImage && (
          <div
            style={{
              position: "fixed",
              top: "50%",
              left: "50%",
              transform: "translate(-50%, -50%)",
              zIndex: 1000,
            }}
          >
            <img
              src={this.state.selectedImage}
              alt="description"
              style={{
                maxWidth: "768px",
                maxHeight: "768px",
                height: "auto",
                width: "auto",
              }}
            />
          </div>
        )}
        {this.state.requestList.length > 0 && (
          <div>
            <div
              style={{
                display: "flex",
                flexDirection: "column",
                alignItems: "center",
              }}
            >
              <div
                style={{
                  display: "flex",
                  alignItems: "center",
                  justifyContent: "center",
                }}
              >
                {!isFirstImage && (
                  <IconButton onClick={() => this.handleNavigation(-1)}>
                    <ArrowBack />
                  </IconButton>
                )}
                <div style={{ textAlign: "center" }}>
                  <p>
                    {
                      this.state.requestList[this.state.currentRequestIndex]
                        .content
                    }
                  </p>
                  {this.state.requestList[
                    this.state.currentRequestIndex
                  ]?.image?.map((image, index) => (
                    <img
                      key={index}
                      src={image.image_url}
                      alt="description"
                      title={image.image_id}
                      style={{
                        margin: "0 10px",
                        maxWidth: "256px",
                        maxHeight: "256px",
                        height: "auto",
                        width: "auto",
                      }}
                      onMouseEnter={() =>
                        this.setSelectedImage(image.image_url)
                      }
                      onMouseLeave={() => this.setSelectedImage(null)}
                    />
                  ))}
                </div>
                <div style={{ textAlign: "center" }}>
                  <p>
                    {
                      this.state.requestList[this.state.currentRequestIndex]
                        .prompt_target
                    }
                  </p>
                  <p>
                    {
                      this.state.requestList[this.state.currentRequestIndex]
                        .prompt_new
                    }
                  </p>
                  <div
                    style={{
                      display: "grid",
                      gridTemplateColumns: "repeat(2, 1fr)",
                      gap: "10px",
                    }}
                  >
                    {this.state.requestList[
                      this.state.currentRequestIndex
                    ]?.mask_image?.map((image, index) => (
                      <div key={index}>
                        <img
                          key={index}
                          src={image.image_url}
                          alt="description"
                          title={image.image_id}
                          style={{
                            margin: "0 10px",
                            maxWidth: "256px",
                            maxHeight: "256px",
                            height: "auto",
                            width: "auto",
                          }}
                          onMouseEnter={() =>
                            this.setSelectedImage(image.image_url)
                          }
                          onMouseLeave={() => this.setSelectedImage(null)}
                        />
                        <label>Mask alignment</label>
                        <Checkbox
                          color="primary"
                          checked={
                            this.state.imageRatings[`mask_${index}`] || false
                          }
                          onChange={(event) =>
                            this.handleRating(
                              `mask_${index}`,
                              event.target.checked ? 1 : 0
                            )
                          }
                        />
                      </div>
                    ))}
                  </div>
                </div>
                <div
                  style={{
                    display: "grid",
                    gridTemplateColumns: "repeat(2, 1fr)",
                    gap: "10px",
                  }}
                >
                  {this.state.requestList[
                    this.state.currentRequestIndex
                  ]?.output?.map((image, index) => (
                    <div key={index}>
                      <img
                        key={index}
                        src={image.image_url}
                        alt="description"
                        title={image.image_id}
                        style={{
                          margin: "0 10px",
                          maxWidth: "256px",
                          maxHeight: "256px",
                          height: "auto",
                          width: "auto",
                        }}
                        onMouseEnter={() =>
                          this.setSelectedImage(image.image_url)
                        }
                        onMouseLeave={() => this.setSelectedImage(null)}
                      />
                      <Checkbox
                        color="primary"
                        checked={
                          this.state.imageRatings[`output_${index}`] || false
                        }
                        onChange={(event) =>
                          this.handleRating(
                            `output_${index}`,
                            event.target.checked ? 1 : 0
                          )
                        }
                      />
                    </div>
                  ))}
                </div>
                {!isLastImage && (
                  <IconButton onClick={() => this.handleNavigation(1)}>
                    <ArrowForward />
                  </IconButton>
                )}
              </div>
            </div>
            <div
              style={{
                display: "flex",
                justifyContent: "center",
                marginTop: "10px",
              }}
            >
              <Button
                variant="contained"
                color="primary"
                onClick={this.submitRating}
              >
                Submit
              </Button>
            </div>
          </div>
        )}
      </div>
    );
  }
}

// The benchmark labelling component is this module's default export.
export default InpaintingBenchmark;
