import '@tensorflow/tfjs-backend-webgl';
import * as tfjsWasm from '@tensorflow/tfjs-backend-wasm';
import 'regenerator-runtime/runtime';
import React, { useState, useEffect, useRef } from 'react';
import * as mpPose from '@mediapipe/pose';
import * as posedetection from '@tensorflow-models/pose-detection';

// Tell the WASM backend where to load its .wasm binaries from: the jsDelivr
// copy of @tensorflow/tfjs-backend-wasm pinned to tfjsWasm.version_wasm.
const wasmBinaryBaseUrl =
    `https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm@${tfjsWasm.version_wasm}/dist/`;
tfjsWasm.setWasmPaths(wasmBinaryBaseUrl);

import {Camera} from '../camera';
import {setupDatGui} from '../option_panel';
import {STATE} from '../params';
import {setupStats} from '../stats_panel';
import {setBackendAndEnvFlags} from '../util';
import Overlay from '../components/overlay/overlay';

// Handles shared between the detection loop functions below and the Embed
// component's setup effect.
let detector;
let camera;
let stats;

// Running inference-time statistics that feed the custom FPS panel.
let startInferenceTime;
let numInferences = 0;
let inferenceTimeSum = 0;
let lastPanelUpdate = 0;

// Id of the most recently scheduled requestAnimationFrame callback.
let rafId;

// NOTE(review): restartDetection is not referenced in this file; the require
// may exist only for the module's side effects — confirm before removing.
const restartDetection = require('../pose_detection');

/**
 * Creates a pose detector for the model currently selected in STATE.
 *
 * Reads STATE.model, STATE.backend and STATE.modelConfig to build the
 * model-specific configuration passed to posedetection.createDetector.
 *
 * @returns {Promise<Object|undefined>} The detector, or undefined when
 *     STATE.model matches none of the handled models.
 * @throws {Error} When BlazePose is selected with a backend whose runtime
 *     prefix is neither 'mediapipe' nor 'tfjs'.
 */
async function createDetector() {
  switch (STATE.model) {
    case posedetection.SupportedModels.PoseNet:
      return posedetection.createDetector(STATE.model, {
        quantBytes: 4,
        architecture: 'MobileNetV1',
        outputStride: 16,
        inputResolution: {width: 500, height: 500},
        multiplier: 0.75
      });

    case posedetection.SupportedModels.BlazePose: {
      // The runtime is the part of STATE.backend before the first '-'.
      const runtime = STATE.backend.split('-')[0];
      if (runtime === 'mediapipe') {
        return posedetection.createDetector(STATE.model, {
          runtime,
          modelType: STATE.modelConfig.type,
          solutionPath: `https://cdn.jsdelivr.net/npm/@mediapipe/pose@${mpPose.VERSION}`
        });
      }
      if (runtime === 'tfjs') {
        return posedetection.createDetector(
            STATE.model, {runtime, modelType: STATE.modelConfig.type});
      }
      // Fail loudly rather than falling through into the MoveNet case, which
      // would build a MoveNet config for a BlazePose model.
      throw new Error(`Unsupported BlazePose runtime: ${runtime}`);
    }

    case posedetection.SupportedModels.MoveNet: {
      let modelType;
      if (STATE.modelConfig.type === 'lightning') {
        modelType = posedetection.movenet.modelType.SINGLEPOSE_LIGHTNING;
      } else if (STATE.modelConfig.type === 'thunder') {
        modelType = posedetection.movenet.modelType.SINGLEPOSE_THUNDER;
      } else if (STATE.modelConfig.type === 'multipose') {
        modelType = posedetection.movenet.modelType.MULTIPOSE_LIGHTNING;
      }
      const modelConfig = {modelType};

      // Only override the model URL when a custom one is configured.
      if (STATE.modelConfig.customModel !== '') {
        modelConfig.modelUrl = STATE.modelConfig.customModel;
      }
      // Tracking is only forwarded for the multi-pose model.
      if (STATE.modelConfig.type === 'multipose') {
        modelConfig.enableTracking = STATE.modelConfig.enableTracking;
      }
      return posedetection.createDetector(STATE.model, modelConfig);
    }

    default:
      // NOTE(review): unknown models resolve to undefined, which callers treat
      // like a missing detector; consider throwing instead.
      return undefined;
  }
}
  
/**
 * Applies pending option-panel changes before the next frame is rendered.
 *
 * - Re-creates the camera when the target FPS or size option changed.
 * - Re-creates the detector (re-applying backend/env flags first when those
 *   changed) when the model, flags or backend changed.
 *
 * Any failure while switching backend or creating the detector is reported
 * via alert() and leaves `detector` null, which renderResult treats as
 * "skip inference". The change flags are always cleared afterwards, so the
 * render loop keeps running instead of dying on an unhandled rejection.
 */
async function checkGuiUpdate() {
  if (STATE.isTargetFPSChanged || STATE.isSizeOptionChanged) {
    camera = await Camera.setupCamera(STATE.camera);
    STATE.isTargetFPSChanged = false;
    STATE.isSizeOptionChanged = false;
  }

  if (STATE.isModelChanged || STATE.isFlagChanged || STATE.isBackendChanged) {
    // Held true during the swap: renderPrediction skips renderResult and
    // renderResult skips drawing while this flag is set.
    STATE.isModelChanged = true;

    window.cancelAnimationFrame(rafId);

    if (detector != null) {
      detector.dispose();
      // Never leave a disposed detector reachable, even if re-creation fails.
      detector = null;
    }

    try {
      if (STATE.isFlagChanged || STATE.isBackendChanged) {
        await setBackendAndEnvFlags(STATE.flags, STATE.backend);
      }
      detector = await createDetector();
    } catch (error) {
      detector = null;
      alert(error);
    }

    STATE.isFlagChanged = false;
    STATE.isBackendChanged = false;
    STATE.isModelChanged = false;
  }
}
  
/** Stamps the start time of the estimatePoses call being measured. */
function beginEstimatePosesStats() {
  const clock = performance || Date;
  startInferenceTime = clock.now();
}
  
/**
 * Adds the time elapsed since beginEstimatePosesStats() to the running total.
 * At most once per second, pushes the average rate (1000 / mean ms per
 * inference) to the custom FPS panel and resets the running totals.
 */
function endEstimatePosesStats() {
  const PANEL_UPDATE_INTERVAL_MS = 1000;
  const now = (performance || Date).now();

  inferenceTimeSum += now - startInferenceTime;
  numInferences += 1;

  if (now - lastPanelUpdate >= PANEL_UPDATE_INTERVAL_MS) {
    const meanInferenceMs = inferenceTimeSum / numInferences;
    inferenceTimeSum = 0;
    numInferences = 0;
    stats.customFpsPanel.update(1000.0 / meanInferenceMs, 120 /* maxValue */);
    lastPanelUpdate = now;
  }
}
  
/**
 * Runs one inference pass on the current camera frame and draws the result.
 *
 * Waits until the video has frame data, then (if a detector exists) runs
 * estimatePoses on camera.video. A detector that throws is disposed, cleared
 * and the error is shown via alert(). The camera frame is always drawn; the
 * poses are drawn only when present and no model change is in progress.
 */
async function renderResult() {
  // readyState < 2 (HAVE_CURRENT_DATA) means no frame is available yet.
  if (camera.video.readyState < 2) {
    await new Promise((resolve) => {
      // Resolve with camera.video explicitly: a bare `video` identifier here
      // only resolves through the browser's named-element global.
      camera.video.addEventListener(
          'loadeddata', () => resolve(camera.video), {once: true});
    });
  }

  let poses = null;

  // Detector can be null if initialization failed (for example when loading
  // from a URL that does not exist).
  if (detector != null) {
    // FPS only counts the time it takes to finish estimatePoses.
    beginEstimatePosesStats();

    // Detectors can throw errors, for example when using custom URLs that
    // contain a model that doesn't provide the expected output.
    try {
      poses = await detector.estimatePoses(
          camera.video,
          {maxPoses: STATE.modelConfig.maxPoses, flipHorizontal: false});
    } catch (error) {
      detector.dispose();
      detector = null;
      alert(error);
    }

    endEstimatePosesStats();
  }

  camera.drawCtx();

  // Skip drawing while a model change is in progress: the poses may come from
  // the detector that is being replaced.
  if (poses && poses.length > 0 && !STATE.isModelChanged) {
    camera.drawResults(poses);
  }
}
  
/**
 * One tick of the detection loop: apply pending GUI changes, render a frame
 * unless a model swap is in progress, then schedule the next tick.
 */
async function renderPrediction() {
  await checkGuiUpdate();

  const isSwappingModel = STATE.isModelChanged;
  if (!isSwappingModel) {
    await renderResult();
  }

  rafId = requestAnimationFrame(renderPrediction);
}



const Embed = () => {
  const [width, setWidth] = useState(0);
  const [height, setHeight] = useState(0);
  const [message, setMessage] = useState("");
  const [showInstructions, setShowInstructions] = useState(true);
  const [leftScore, setLeftScore] = useState(-1);
  const [rightScore, setRightScore] = useState(-1);
  const [totalScore, setTotalScore] = useState(-1);
  const [imageUrl, setImageUrl] = useState(null);
  const [leftScoreColor, setLeftScoreColor] = useState('silver');
  const [rightScoreColor, setRightScoreColor] = useState('silver');
  const [totalScoreColor, setTotalScoreColor] = useState('silver');
  const [showAssessment, setShowAssessment] = useState(true);
  const [showResults, setShowResults] = useState(false);
  const [isLookingLeft, setIsLookingLeft] = useState(false);
  const [results, setResults] = useState([]);

  let currentMaxScore = useRef(180); // Use a ref to keep the mutable value
  let notificationMsg = useRef("");

  const getScoreColor = (score, maxScore) => {
    if (score <= 0) {
      return 'silver';
    } else if (score >= maxScore) {
      return 'blue';
    } else if (score >= 0.9 * maxScore) {
      return 'green';
    } else if (score >= 0.8 * maxScore) {
      return 'orange';
    } else {
      return 'red';
    }
  };

  const handleMessage = (e) => {

    if (!e.data || !e.data.msgType) {
      //console.warn("Received message without msgType: ", e.data);
      return;
    }

    switch (e.data.msgType) {
      case "scoreMessage":
          try {
            const msgScore = JSON.parse(e.data.msgBody);
            
            setLeftScore(msgScore.leftScore);
            setRightScore(msgScore.rightScore);
            setTotalScore(msgScore.totalScore);

            setLeftScoreColor(getScoreColor(msgScore.leftScore, currentMaxScore.current));
            setRightScoreColor(getScoreColor(msgScore.rightScore, currentMaxScore.current));
            setTotalScoreColor(getScoreColor(msgScore.totalScore, currentMaxScore.current));

          } catch (error) {
              console.error("Error for scoreMessage:", error);
          }
          break;
      case "statusMessage":
          try {
            const msgStatus = JSON.parse(e.data.msgBody);
            setMessage(msgStatus.status);
            
            if (msgStatus.notification && msgStatus.notification != "" && msgStatus.notification !== notificationMsg.current) {
              console.log("Audio notification: " + msgStatus.notification);
              notificationMsg.current = msgStatus.notification;
            }

          } catch (error) {
              console.error("Error for status:", error);
          }
          break;
      case "currentTestMessage":
          try {
            const msgCurrentTest = JSON.parse(e.data.msgBody);
            if (msgCurrentTest.gif) {
              setImageUrl(msgCurrentTest.gif);
            } else {
              setImageUrl(null);
            }

            currentMaxScore.current = msgCurrentTest.score;

          } catch (error) {
              console.error("Error for currentTestMsg:", error);
          }
          break;
      case "resultMessage":
          try {
            const msgResult = JSON.parse(e.data.msgBody);
            if (msgResult) {
              setResults(msgResult);
              setShowAssessment(false);
              setShowResults(true);
            }

          } catch (error) {
              console.error("Error for resultMsg:", error);
          }
          break;
      case "userMessage":
          const msgUser = JSON.parse(e.data.msgBody);
          if (msgUser.inAssessment && msgUser.inAssessment == "true") {
            setShowInstructions(false);
          } else {
            setShowInstructions(true);
          }
          if (msgUser.orientation && msgUser.orientation == "left") {
            setIsLookingLeft(true);
          } else {
            setIsLookingLeft(false);
          }
          break;
      case "cameraSize":
          console.log("cameraSize:" + e.data.msgBody);
          break;
      default:
        console.log("Unhandled action from Analyzer: ", JSON.stringify(e.data));
        break;
    }
  }
  
  useEffect(() => {
    
    async function setUp() {
      const urlParams = new URLSearchParams('?model=blazepose');
      let params = new URLSearchParams(document.location.search);
      if (!urlParams.has('model')) {
        alert('Cannot find model in the query string. Use /?model=blazepose');
        return;
      }
      if (params.has('showUI') && params.get('showUI') == 0) {
        document.getElementById("overlay-container").style.display = "none";
      }
      await setupDatGui(urlParams);
      stats = setupStats();
      camera = await Camera.setupCamera(STATE.camera);
      await setBackendAndEnvFlags(STATE.flags, STATE.backend);
      detector = await createDetector();
      renderPrediction();

      const video = document.querySelector("#video");
      setWidth(video.width);
      setHeight(video.height);

      window.addEventListener("message", handleMessage)
    }

    setUp();
  },[]);

  const formatName = (name) => {
    return name
      .split('_')
      .map(word => word.charAt(0).toUpperCase() + word.slice(1))
      .join(' ');
  };

  const getIconAndColor = (score, maxScore) => {
    if (score >= maxScore) {
      return { icon: 'fas fa-check-circle', color: 'blue' };
    } else if (score >= 0.9 * maxScore) {
      return { icon: 'fas fa-check-circle', color: 'green' };
    } else if (score >= 0.8 * maxScore) {
      return { icon: 'fas fa-exclamation-circle', color: 'orange' };
    } else {
      return { icon: 'fas fa-exclamation-circle', color: 'red' };
    }
  };

  return (<>
      <div id="stats"></div>
      <div id="main">
          {showAssessment && (<div className="container">
              <div className="canvas-wrapper">
                  <canvas id="output"></canvas>
                  <video id="video" playsInline style={{transform: 'scaleX(-1)', visibility: 'hidden', width: 'auto', height: 'auto', WebkitTransform: 'scaleX(-1)'}}>
                  </video>
                  <div id="overlay-container">
                    <Overlay 
                      width={width} 
                      height={height} 
                      message={message} 
                      imageUrl={imageUrl} 
                      showInstructions={showInstructions} 
                      leftScore={leftScore} 
                      rightScore={rightScore} 
                      totalScore={totalScore} 
                      leftScoreColor={leftScoreColor} 
                      rightScoreColor={rightScoreColor} 
                      totalScoreColor={totalScoreColor}
                      isLookingLeft={isLookingLeft} 
                    />      
                  </div>              
              </div>
              <div id="scatter-gl-container"></div>
          </div>)}
          {showResults && (
            <div id="resultAssessment">
              <h2 style={{ textAlign: 'center', color: 'white' }}>Results</h2>
              <table style={{ width: '50%', color: 'white', textAlign: 'center', margin: '30px auto' }}>
                <thead>
                  <tr>
                    <th>Motion</th>
                    <th>Score</th>
                  </tr>
                </thead>
                <tbody>
                  {results.map((result, index) => (
                    <tr key={index}>
                      <td>{formatName(result.name)}</td>
                      <td>
                        {result.totalScore !== -1 ? (
                          <p>
                            <em 
                              className={getIconAndColor(result.totalScore, result.normalScore).icon} 
                              style={{ color: getIconAndColor(result.totalScore, result.normalScore).color, marginRight: '8px' }}
                            ></em>
                            Total: {result.totalScore}
                          </p>
                        ) : (
                          <>
                            <p>
                              <em 
                                className={getIconAndColor(result.leftScore, result.normalScore).icon} 
                                style={{ color: getIconAndColor(result.leftScore, result.normalScore).color, marginRight: '8px' }}
                              ></em>
                              Left: {result.leftScore}
                            </p>
                            <p>
                              <em 
                                className={getIconAndColor(result.rightScore, result.normalScore).icon} 
                                style={{ color: getIconAndColor(result.rightScore, result.normalScore).color, marginRight: '8px' }}
                              ></em>
                              Right: {result.rightScore}
                            </p>
                          </>
                        )}
                      </td>
                    </tr>
                  ))}
                </tbody>
              </table>
            </div>
          )}

      </div>
  </>);
};

export default Embed;
