import {PropsWithChildren, useEffect, useMemo, useRef, useState} from 'react'
import maskStyles from './scroll-mask.module.css'
import classNames from 'classnames'

/** Own props of ScrollMask (children are added via PropsWithChildren). */
interface ScrollMaskOwnProps {
  /** Selects which edge mask(s) ScrollMask may apply; ScrollMask defaults this to `'both'`. */
  mask?: 'left' | 'right' | 'both'
  /** Extra class names merged onto the scroll container. */
  className?: string
  /** Pixel tolerance used when deciding whether an edge child is in view; ScrollMask defaults this to 24. */
  threshold?: number
  /**
   * Intended element type for the container.
   * NOTE(review): ScrollMask in this file does not read `as` and always
   * renders a `div` — confirm whether it should be honored or removed.
   */
  as?: keyof JSX.IntrinsicElements
}

type ScrollMaskProps = PropsWithChildren<ScrollMaskOwnProps>

/**
 * Scrollable container that applies a CSS mask class on the edge(s) whose
 * content is scrolled out of view.
 *
 * An edge counts as "out of view" when its boundary child element (first for
 * the left edge, last for the right edge) is not fully inside the container's
 * visible area, as decided by `elementInView` with `threshold` px tolerance.
 * Measurement re-runs after mount, on scroll, on window resize, and whenever
 * `threshold` or `children` change.
 *
 * @param mask - Which edge(s) may be masked: `'left'`, `'right'` or `'both'` (default).
 *   An edge not selected here is never masked.
 * @param className - Extra class names merged with the computed mask class.
 * @param children - Content rendered inside the scroll container.
 * @param threshold - Pixel tolerance forwarded to `elementInView` (default 24).
 */
export default function ScrollMask({
  mask = 'both',
  className,
  children,
  threshold = 24,
}: ScrollMaskProps) {
  const scrollContainer = useRef<HTMLDivElement>(null)
  const [startVisible, setStartVisible] = useState(true)
  const [endVisible, setEndVisible] = useState(true)
  // These two states exist only to re-trigger the measurement effect.
  const [scrollPosition, setScrollPosition] = useState(0)
  const [screenWidth, setScreenWidth] = useState(0)

  useEffect(() => {
    // Narrow the ref once so every access below is null-safe.
    const container = scrollContainer.current
    // A zero width (e.g. the container is hidden) gives no meaningful geometry.
    if (!container || !container.offsetWidth) return

    // Use element children: text/comment nodes have no getBoundingClientRect().
    // The cast is safe for elementInView, which only calls getBoundingClientRect().
    const firstChild = container.firstElementChild
    if (firstChild) {
      setStartVisible(
        elementInView(firstChild as HTMLElement, container, threshold),
      )
    }
    const lastChild = container.lastElementChild
    if (lastChild) {
      setEndVisible(
        elementInView(lastChild as HTMLElement, container, threshold),
      )
    }
  }, [scrollPosition, screenWidth, threshold, children])

  useEffect(() => {
    const handleResize = () => setScreenWidth(window.innerWidth)
    window.addEventListener('resize', handleResize)

    return () => {
      window.removeEventListener('resize', handleResize)
    }
  }, [])

  // `mask` restricts which edges may be masked; an allowed edge is masked
  // only when its boundary child is out of view.
  const maskLeft = mask !== 'right' && !startVisible
  const maskRight = mask !== 'left' && !endVisible

  const maskClass = useMemo(() => {
    if (maskLeft && maskRight) return maskStyles.maskBoth
    if (maskLeft) return maskStyles.maskLeft
    if (maskRight) return maskStyles.maskRight
    return undefined
  }, [maskLeft, maskRight])

  return (
    <div
      ref={scrollContainer}
      className={classNames(className, maskClass)}
      onScroll={(event) => setScrollPosition(event.currentTarget.scrollLeft)}
    >
      {children}
    </div>
  )
}

/**
 * Reports whether `element` lies horizontally inside the visible part of
 * `container`, where the visible part is the container's box clipped to
 * `[0, window.innerWidth]`. Outside a browser (no `window`) the viewport width
 * is taken as 0.
 *
 * `threshold` relaxes the check: the element may overhang each visible edge
 * by up to `threshold` pixels and still count as in view.
 */
const elementInView = (
  element: HTMLElement,
  container: HTMLElement,
  threshold: number,
): boolean => {
  const containerRect = container.getBoundingClientRect()
  const elementRect = element.getBoundingClientRect()
  const viewportWidth = typeof window === 'undefined' ? 0 : window.innerWidth

  // Horizontal span of the container that is actually on screen.
  const visibleStart = Math.max(0, containerRect.left)
  const visibleEnd = Math.min(viewportWidth, containerRect.right)

  // Element edges, each pulled inward by the tolerance.
  const tolerantStart = elementRect.left + threshold
  const tolerantEnd = elementRect.left + elementRect.width - threshold

  const startsInside = tolerantStart >= visibleStart
  const endsInside = tolerantEnd <= visibleEnd
  return startsInside && endsInside
}
