// %BANNER_BEGIN%
// ---------------------------------------------------------------------
// %COPYRIGHT_BEGIN%
// Copyright (c) (2021-2022) Magic Leap, Inc. All Rights Reserved.
// Use of this file is governed by the Software License Agreement, located here: https://www.magicleap.com/software-license-agreement-ml2
// Terms and conditions applicable to third-party materials accompanying this distribution may also be found in the top-level NOTICE file appearing herein.
// %COPYRIGHT_END%
// ---------------------------------------------------------------------
// %BANNER_END%

using System;
using System.Collections;
using System.Collections.Generic;
using MagicLeap.Android;
using Unity.XR.CoreUtils;
using UnityEngine.XR.OpenXR;
using UnityEngine.XR.OpenXR.Features.Interactions;
using UnityEngine.XR.OpenXR.Features.MagicLeapSupport;

namespace UnityEngine.XR.MagicLeap
{
    /// <summary>
    /// Detects the focus distance by utilizing the eye tracking fixation point either
    /// directly or in conjunction with sphere casting colliders in the scene.  If
    /// eye tracking is not used or not available, this detector will fall back to
    /// sphere casting from headpose.
    /// This component expects a MagicLeapCamera to be in the scene and will set
    /// the MagicLeapCamera.StereoConvergencePoint to control focus distance.
    /// </summary>
    public class StereoConvergenceDetector : MonoBehaviour
    {
        #region NestedType / Constructors
        /// <summary>
        /// Controls whether the eye tracking fixation point participates in focus
        /// point detection, or whether headpose alone is used.
        /// </summary>
        [Serializable]
        public enum EyeTrackingOptions
        {
            /// <summary>Ignore eye tracking entirely; sphere cast from headpose.</summary>
            DoNotUseEyeTracking_UseHeadpose,
            /// <summary>Sphere cast along the gaze ray through the eye fixation point.</summary>
            SphereCastThroughEyeFixationPoint,
        }
        #endregion NestedType / Constructors

        #region [SerializeField] Private Members
        [Header("Sphere Casting")]
        [SerializeField]
        [Tooltip("Choose if eye tracking is used at all along with how to utilize the eye fixation point.  " +
                    "Headpose vector will provide a fall back if eye tracking is not used or not available.")]
        private EyeTrackingOptions _eyeTrackingOption = EyeTrackingOptions.SphereCastThroughEyeFixationPoint;

        [SerializeField]
        [Tooltip("The interval in seconds between detecting the focus point via sphere cast or direct eye fixation point.")]
        private float _sphereCastInterval = .1f;

        [SerializeField]
        [Tooltip("The radius to use for the sphere cast when sphere casting is used.")]
        private float _sphereCastRadius = .075f;

        [SerializeField]
        [Tooltip("The Raycast will set the distance to the furthest object when cast through multiple objects. This is so the object that is being observed and the object behind it look stable." +
                 "When set to 0 the raycast will only detect the first object. When non 0, the raycast will select the object behind the target as long as it's within the specified distance.")]
        private float _maxGroupDistance = 3.0f; // Maximum distance for objects to be considered in the same group

        [SerializeField]
        [Tooltip("The max number of objects the Raycast can intersect with.")]
        private RaycastHit[] _hitsBuffer = new RaycastHit[30]; // Adjust size based on expected number of hits

        [SerializeField]
        [Tooltip("The layer mask for the sphere cast.")]
        private LayerMask _sphereCastMask;

        [Header("Debug Visuals")]
        [SerializeField]
        [Tooltip("Whether to show debug visuals for focus point detection.")]
        private bool _showDebugVisuals = false;

        [SerializeField]
        [Tooltip("Material representing sphere cast hit point.")]
        private Material _hitPointMaterial;
        #endregion [SerializeField] Private Members

        #region Private Members
        // Transform handed to SetFocusDistance; positioned at the detected focus
        // point every detection cycle. Created in Awake regardless of debug visuals.
        private GameObject _convergencePoint = null;
        // Optional sphere primitive shown at the hit point when _showDebugVisuals is on.
        private GameObject _hitPointVisual = null;
        private Coroutine _raycastRoutine = null;
        private InputDevice _eyesDevice;
        // Non-null only when the main camera lives under an XROrigin rig; used to
        // convert eye tracking poses from rig space into world space.
        private XROrigin _xrOrigin = null;
        #endregion Private Members

        // OpenXR feature that receives the focus distance each cycle.
        private MagicLeapRenderingExtensionsFeature renderFeature;

        #region MonoBehaviour Methods
        private void Awake()
        {
            SetupConvergencePointObject();
        }

        private void Start()
        {
            // Request EyeTracking when an eye tracking option is selected
            InitializeEyeTracking();
            InitializeRenderFeature();
            InitializeMainCameraSettings();
        }

        /// <summary>
        /// Requests the eye tracking permission when the selected option uses eye
        /// tracking. Denial is logged and detection falls back to headpose.
        /// </summary>
        private void InitializeEyeTracking()
        {
            if (_eyeTrackingOption != EyeTrackingOptions.DoNotUseEyeTracking_UseHeadpose)
            {
                Permissions.RequestPermission(MLPermission.EyeTracking, null, OnPermissionDenied, OnPermissionDenied);
            }
        }

        /// <summary>
        /// Caches the OpenXR rendering extensions feature used to push the focus
        /// distance to the compositor; disables this component when it is missing
        /// or disabled.
        /// </summary>
        private void InitializeRenderFeature()
        {
            renderFeature = OpenXRSettings.Instance.GetFeature<MagicLeapRenderingExtensionsFeature>();
            if (renderFeature == null || !renderFeature.enabled)
            {
                // Note: the "enabled" line previously relied on UnityEngine.Object's
                // implicit bool conversion (ternary else-branch was the feature itself);
                // the explicit boolean below logs the same value ("False" when absent).
                Debug.LogError("Focus Distance cannot be set. Disabling script. " +
                               "Ensure all requirements are met : \n" +
                               $"Render Feature is present : {renderFeature != null} \n" +
                               $"Render Feature is enabled : {renderFeature != null && renderFeature.enabled}");
                enabled = false;
            }
        }

        /// <summary>
        /// Verifies a main camera exists (disabling the component otherwise) and
        /// caches its XROrigin rig, if any, for eye pose space conversion.
        /// </summary>
        private void InitializeMainCameraSettings()
        {
            if (!Camera.main)
            {
                Debug.LogError("No Main Camera Detected in Scene. Disabling script.");
                enabled = false;
                return;
            }
            // Detect if the main camera is part of an XROrigin-based rig by obtaining the
            // XROrigin reference as a parent.
            _xrOrigin = Camera.main.GetComponentInParent<XROrigin>();
        }

        private void OnEnable()
        {
            _raycastRoutine = StartCoroutine(DetectConvergencePoint());
        }

        private void OnDisable()
        {
            StopDetectionRoutine();

            if (_showDebugVisuals)
            {
                DisplayDebugVisuals(false);
            }
        }

        private void OnDestroy()
        {
            StopDetectionRoutine();

            if (_convergencePoint != null)
            {
                Destroy(_convergencePoint);
                _convergencePoint = null;
            }
        }
        #endregion MonoBehaviour Methods

        #region Private Methods
        // Stops the running detection coroutine, if any. Shared by OnDisable/OnDestroy.
        private void StopDetectionRoutine()
        {
            if (_raycastRoutine != null)
            {
                StopCoroutine(_raycastRoutine);
                _raycastRoutine = null;
            }
        }

        /// <summary>
        /// Creates the (empty) convergence point object and, when debug visuals are
        /// enabled, a sphere primitive parented to it.
        /// </summary>
        private void SetupConvergencePointObject()
        {
            // BUGFIX: the convergence point must exist regardless of debug visuals.
            // UpdateConvergencePointAndVisuals positions it on every detection cycle
            // and previously threw a NullReferenceException when _showDebugVisuals
            // was false, because the early return skipped its creation.
            _convergencePoint = new GameObject("Stereo Convergence Point");

            if (!_showDebugVisuals) return;
            _hitPointVisual = CreateDebugVisual(_hitPointMaterial);
        }

        // Builds the small sphere primitive used as the hit point debug visual.
        private GameObject CreateDebugVisual(Material material)
        {
            var primitive = GameObject.CreatePrimitive(PrimitiveType.Sphere);
            SetupPrimitiveVisual(primitive, material);
            return primitive;
        }

        // Parents the primitive to the convergence point, hides it, applies the
        // material, and strips its collider so the visual never intercepts the
        // sphere cast itself.
        private void SetupPrimitiveVisual(GameObject primitive, Material material)
        {
            primitive.layer = gameObject.layer;
            primitive.transform.SetParent(_convergencePoint.transform);
            primitive.SetActive(false);
            if (material != null) primitive.GetComponent<Renderer>().material = material;
            Destroy(primitive.GetComponent<Collider>());
        }

        /// <summary>
        /// Runs while the component is enabled: every _sphereCastInterval seconds,
        /// sphere casts from either the eye gaze ray or headpose and updates the
        /// stereo convergence point / focus distance.
        /// </summary>
        private IEnumerator DetectConvergencePoint()
        {
            while (enabled)
            {
                // Interval of 0 (or less) means "every frame".
                yield return _sphereCastInterval > 0 ? new WaitForSeconds(_sphereCastInterval) : null;

                // Headpose is the default ray; eye tracking overrides it below when available.
                Vector3 rayOrigin = Camera.main.transform.position;
                Vector3 rayDirection = Camera.main.transform.forward;

                // BUGFIX: honor the serialized option. Previously a granted permission
                // alone switched detection to eye tracking, ignoring an explicit
                // DoNotUseEyeTracking_UseHeadpose selection.
                bool useEyeTracking = _eyeTrackingOption != EyeTrackingOptions.DoNotUseEyeTracking_UseHeadpose
                                      && Permissions.CheckPermission(MLPermission.EyeTracking);

                // BUGFIX: device acquisition no longer blocks. The previous coroutine
                // waited forever for a valid eyes device, so the documented headpose
                // fallback never happened; now an unavailable device simply leaves the
                // headpose ray in place for this cycle.
                if (useEyeTracking
                    && TryAcquireEyesDevice()
                    && TryGetEyeTrackingData(out Vector3 position, out Quaternion rotation))
                {
                    rayDirection = TransformRotationBasedOnXROrigin(rotation) * Vector3.forward;
                    rayOrigin = TransformBasedOnXROrigin(position);
                }

                RaycastHit? raycastHit =
                    SelectFurthestObjectWithinCloseGroup(rayOrigin, rayDirection, _sphereCastRadius, _sphereCastMask);
                if (raycastHit.HasValue)
                {
                    UpdateConvergencePointAndVisuals(raycastHit.Value.point);
                }
                else
                {
                    // Nothing hit: clamp the current convergence (no target) and hide visuals.
                    SetFocusDistance(null);
                    DisplayDebugVisuals(false);
                }
            }
        }

        /// <summary>
        /// Sphere casts along the given ray and selects the hit to focus on. The
        /// nearest hit is the observed target; the result is the furthest hit that
        /// lies within _maxGroupDistance behind it, so the target and an object just
        /// behind it share a stable focus distance. With _maxGroupDistance = 0 only
        /// the nearest hit is selected, matching the field's tooltip.
        /// </summary>
        /// <returns>The selected hit, or null when nothing was hit.</returns>
        private RaycastHit? SelectFurthestObjectWithinCloseGroup(Vector3 rayOrigin, Vector3 rayDirection, float sphereCastRadius, LayerMask sphereCastMask)
        {
            int hitCount = Physics.SphereCastNonAlloc(new Ray(rayOrigin, rayDirection), sphereCastRadius, _hitsBuffer, Camera.main.farClipPlane, sphereCastMask);

            if (hitCount == 0)
            {
                return null; // Return null if no hits
            }

            // Sort only the populated portion of the buffer by distance from the origin.
            Array.Sort(_hitsBuffer, 0, hitCount, Comparer<RaycastHit>.Create((a, b) => a.distance.CompareTo(b.distance)));

            // BUGFIX: the previous grouping loop returned null for a single hit (and
            // for two widely separated hits), and its sliding group start could select
            // a hit from a group that did not contain the observed nearest object.
            // Anchor the group at the nearest hit and extend it while successive hits
            // stay within _maxGroupDistance of it.
            RaycastHit furthestInGroup = _hitsBuffer[0];
            for (int i = 1; i < hitCount; i++)
            {
                if (_hitsBuffer[i].distance - _hitsBuffer[0].distance > _maxGroupDistance)
                {
                    break; // Hits are sorted, so every later hit is further still.
                }
                furthestInGroup = _hitsBuffer[i];
            }

            return furthestInGroup;
        }

        /// <summary>
        /// Attempts to (re)acquire the Magic Leap eye tracking input device.
        /// Returns true when a valid device is available; never blocks.
        /// </summary>
        private bool TryAcquireEyesDevice()
        {
            if (!_eyesDevice.isValid)
            {
                _eyesDevice = InputSubsystem.Utils.FindMagicLeapDevice(
                    InputDeviceCharacteristics.EyeTracking | InputDeviceCharacteristics.TrackedDevice);
            }
            return _eyesDevice.isValid;
        }

        /// <summary>
        /// Reads the gaze pose from the eyes device. The pose is converted into world
        /// space by the caller via the XROrigin helpers below.
        /// </summary>
        /// <param name="position">Device-reported gaze position.</param>
        /// <param name="rotation">Device-reported gaze rotation.</param>
        /// <returns>True when the device is tracked and both pose values were read.</returns>
        private bool TryGetEyeTrackingData(out Vector3 position, out Quaternion rotation)
        {
            // Separate locals replace the original's self-referential out-assign trick
            // (`bool isTracked = ...TryGetFeatureValue(..., out isTracked) && isTracked`).
            bool hasTrackingState = _eyesDevice.TryGetFeatureValue(CommonUsages.isTracked, out bool isTracked);
            bool hasPositionData = _eyesDevice.TryGetFeatureValue(EyeTrackingUsages.gazePosition, out position);
            bool hasRotationData = _eyesDevice.TryGetFeatureValue(EyeTrackingUsages.gazeRotation, out rotation);
            return hasTrackingState && isTracked && hasPositionData && hasRotationData;
        }

        // Converts a rig-space position to world space when an XROrigin is present;
        // otherwise the position is already usable as-is.
        private Vector3 TransformBasedOnXROrigin(Vector3 position)
        {
            return _xrOrigin != null ? _xrOrigin.CameraFloorOffsetObject.transform.TransformPoint(position) : position;
        }

        // Converts a rig-space rotation to world space when an XROrigin is present.
        private Quaternion TransformRotationBasedOnXROrigin(Quaternion rotation)
        {
            return _xrOrigin != null ? _xrOrigin.CameraFloorOffsetObject.transform.rotation * rotation : rotation;
        }

        /// <summary>
        /// Moves the convergence point to the detected focus point, pushes the new
        /// focus distance, and positions the debug visual when enabled.
        /// </summary>
        private void UpdateConvergencePointAndVisuals(Vector3 focusPoint)
        {
            _convergencePoint.transform.position = focusPoint;
            SetFocusDistance(_convergencePoint.transform);

            if (_showDebugVisuals)
            {
                DisplayDebugVisuals(true);
                _hitPointVisual.transform.localScale = Vector3.one * .02f;
                _hitPointVisual.transform.position = focusPoint;
            }
        }

        /// <summary>
        /// Computes the focus distance from the target (or keeps the camera's current
        /// convergence when the target is null), clamps it to the near clip plane, and
        /// applies it to both the camera and the rendering extensions feature.
        /// </summary>
        private void SetFocusDistance(Transform focusTarget)
        {
            // Get Focus Distance and log warnings if not within the allowed value bounds.
            Camera currentCamera = Camera.main;
            if (currentCamera == null)
            {
                return;
            }

            float focusDistance = currentCamera.stereoConvergence;
            if (focusTarget != null)
            {
                // From Unity documentation:
                // Note that camera space matches OpenGL convention: camera's forward is the negative Z axis.
                // This is different from Unity's convention, where forward is the positive Z axis.
                Vector3 worldForward = new Vector3(0.0f, 0.0f, -1.0f);
                Vector3 camForward = currentCamera.cameraToWorldMatrix.MultiplyVector(worldForward);
                camForward = camForward.normalized;

                // We are only interested in the focus object's distance to the camera
                // forward tangent plane.
                // BUGFIX: measure from the camera's position, not this component's
                // transform — the original was only correct when the script happened
                // to be attached to the camera itself.
                focusDistance = Vector3.Dot(focusTarget.position - currentCamera.transform.position, camForward);
            }

            // Never focus closer than the near clip plane.
            float nearClip = currentCamera.nearClipPlane;
            if (focusDistance < nearClip)
            {
                focusDistance = nearClip;
            }

            if (renderFeature == null)
                return;
            currentCamera.stereoConvergence = focusDistance;
            renderFeature.FocusDistance = focusDistance;
        }

        // Shows or hides the hit point visual (no-op when it was never created).
        private void DisplayDebugVisuals(bool show)
        {
            if (_hitPointVisual != null)
            {
                _hitPointVisual.SetActive(show);
            }
        }

        // Permission callback: log and continue — detection falls back to headpose
        // because DetectConvergencePoint re-checks the permission each cycle.
        private void OnPermissionDenied(string permission)
        {
           Debug.LogError($"{permission} denied, falling back to Headpose sphere cast.");
        }
        #endregion Private Methods
    }
}
