import type { UseInfiniteQueryResult } from '@tanstack/react-query'
import type { Range, VirtualItem } from '@tanstack/react-virtual'
import type { HTMLAttributes, ReactElement } from 'react'
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { Functions } from '@goatlab/js-utils'
import { useFlashListStore } from '@sodium/shared-frontend-schemas'
import { cn } from '@src/utils/cn'
import { fastHash } from '@src/utils/fast-hash'
import {
  defaultRangeExtractor,
  useWindowVirtualizer,
} from '@tanstack/react-virtual'

/** Minimal row shape for the feed; `id` is used as the React key of each rendered row. */
export interface DataWithId {
  id: string | number
}

/**
 * One page of a paginated API response.
 *
 * @typeParam T - Row type held in `data`.
 */
export interface PaginatedResult<T> {
  total: number
  perPage: number
  currentPage: number
  lastPage: number
  /** Presumably `null` when there is no following page — confirm against the API. */
  nextPage: number | null
  /** Presumably `null` when there is no preceding page — confirm against the API. */
  previousPage: number | null
  /** Rows contained in this page. */
  data: T[]
}

/**
 * Page shape required by the virtual feed: a {@link PaginatedResult} whose rows
 * carry an `id`. Derived from `PaginatedResult` rather than repeating its field
 * list so the two declarations cannot drift apart.
 */
type PaginatedBase = PaginatedResult<DataWithId>

/**
 * Resolves the row type carried by an infinite-query result type `T`.
 *
 * Steps: infer the page type `P` from `T`'s `data.pages` array (`data` may be
 * absent, hence `data?`), then infer the row type `D` from each page's `data`
 * array, and keep `D` only if it is assignable to {@link DataWithId}. If any
 * step fails to match, the result is `never`.
 */
type ItemFromInfiniteQuery<T> = T extends { data?: { pages: (infer P)[] } }
  ? P extends { data: (infer D)[] }
    ? D extends DataWithId
      ? D
      : never // rows without an `id`
    : never // page type has no `data` array
  : never // `T` has no `data.pages`

/**
 * Props accepted by `VirtualFeedWindow`.
 *
 * @typeParam T - Result of a TanStack `useInfiniteQuery` call whose pages are
 *   `PaginatedBase` objects. The error type argument is left open (`any`) to
 *   match the constraint declared on `VirtualFeedWindow`, instead of falling
 *   back to the library's default error type, so the two contracts agree.
 */
export interface VirtualFeedParams<
  T extends UseInfiniteQueryResult<{ pages: PaginatedBase[] }, any>,
> {
  /** Query whose loaded pages are flattened into the feed. */
  infiniteQuery: T
  /** Renders a single row; `virtualItem` is the virtualizer slot the row occupies. */
  renderItem: (props: {
    virtualItem: VirtualItem
    item: ItemFromInfiniteQuery<T>
  }) => ReactElement
  /** Forwarded to the window virtualizer's `horizontal` option. */
  horizontal?: boolean
  /** Index of a loaded row to scroll to, aligned to the centre of the viewport. */
  scrollToIndex?: number
  /** Extra class names for the element that contains the virtual rows. */
  containerClassName?: HTMLAttributes<HTMLDivElement>['className']
  /** Extra class names for each row wrapper. */
  itemClassName?: HTMLAttributes<HTMLDivElement>['className']
}

/** How many leading viewable items are considered when picking the focused item. */
const MAX_VIEWABLE_ITEMS = 1

/** Size estimate in px handed to the virtualizer for every row. */
const ESTIMATED_ITEM_SIZE = 950

/** Intersection ratios at which the visibility observer re-evaluates. */
const VISIBILITY_THRESHOLDS = [0.1, 0.5, 1]

/**
 * Tracks which registered element currently has the largest intersection ratio
 * with the viewport.
 *
 * Elements are registered through the returned `observeElement` ref callback and
 * must carry a numeric `data-id` attribute (their virtual index). `activeIndex`
 * is the `data-id` of the most visible element, ties going to the lowest index.
 * While no registered element is visible the previous value is kept; it is
 * `null` until something has been visible.
 *
 * The observer is created lazily on first use: React attaches ref callbacks
 * during commit, before passive effects run, so an observer created only inside
 * `useEffect` would miss elements rendered on the first pass.
 */
function useMostVisibleIndex() {
  const [activeIndex, setActiveIndex] = useState<number | null>(null)
  // Last reported ratio for every registered element (0 until first report).
  // Keyed by element rather than index so a remounted row cannot inherit the
  // ratio of the element it replaced.
  const ratiosRef = useRef(new Map<Element, number>())
  const observerRef = useRef<IntersectionObserver | null>(null)

  const handleEntries = useCallback(
    (entries: IntersectionObserverEntry[], observer: IntersectionObserver) => {
      const ratios = ratiosRef.current
      for (const entry of entries) {
        // Ignore late reports for elements that were already pruned.
        if (ratios.has(entry.target)) {
          ratios.set(entry.target, entry.intersectionRatio)
        }
      }

      let bestIndex = -1
      let bestRatio = 0
      ratios.forEach((ratio, element) => {
        // Rows removed from the DOM would otherwise keep their last ratio forever.
        if (!element.isConnected) {
          observer.unobserve(element)
          ratios.delete(element)
          return
        }
        const rawId = (element as HTMLElement).dataset.id
        if (rawId === undefined) {
          return
        }
        const index = Number(rawId)
        if (
          ratio > bestRatio ||
          (bestIndex !== -1 && ratio === bestRatio && index < bestIndex)
        ) {
          bestIndex = index
          bestRatio = ratio
        }
      })

      // Only report when something is actually visible.
      if (bestIndex !== -1) {
        setActiveIndex(bestIndex)
      }
    },
    [],
  )

  const getObserver = useCallback(() => {
    let observer = observerRef.current
    if (observer === null) {
      observer = new IntersectionObserver(handleEntries, {
        threshold: VISIBILITY_THRESHOLDS,
        root: null,
      })
      observerRef.current = observer
    }
    return observer
  }, [handleEntries])

  useEffect(() => {
    const observer = getObserver()
    // Re-attach every registered element, which covers a StrictMode effect
    // remount after the previous observer was disconnected. observe() is a
    // no-op for targets that are already being observed.
    ratiosRef.current.forEach((_ratio, element) => observer.observe(element))
    return () => {
      observer.disconnect()
      observerRef.current = null
    }
  }, [getObserver])

  const observeElement = useCallback(
    (element: HTMLDivElement | null) => {
      const ratios = ratiosRef.current
      const observer = observerRef.current
      // A callback ref receives null on unmount without saying which element
      // left, so drop every tracked element that is no longer in the DOM.
      ratios.forEach((_ratio, tracked) => {
        if (!tracked.isConnected) {
          observer?.unobserve(tracked)
          ratios.delete(tracked)
        }
      })
      if (element !== null) {
        if (!ratios.has(element)) {
          ratios.set(element, 0)
        }
        getObserver().observe(element)
      }
    },
    [getObserver],
  )

  return { activeIndex, observeElement }
}

/**
 * Window-scrolled virtual feed backed by a TanStack infinite query.
 *
 * Flattens every loaded page into one list and renders only the rows returned by
 * the window virtualizer. It requests the next page once the last loaded row
 * falls inside the rendered range. It also publishes the most visible row to the
 * shared flash-list store, debounced by 100 ms.
 *
 * `scrollToIndex` is applied once per distinct value, as soon as that row has
 * been loaded. To scroll to the same row again, pass `undefined` and then the
 * index again.
 */
export function VirtualFeedWindow<
  T extends UseInfiniteQueryResult<{ pages: PaginatedBase[] }, any>,
>({
  infiniteQuery,
  renderItem,
  horizontal = false,
  scrollToIndex,
  containerClassName,
  itemClassName,
}: VirtualFeedParams<T>) {
  const { status, data, fetchNextPage, hasNextPage, isFetchingNextPage } =
    infiniteQuery
  const {
    setFocusedItemId,
    setFocusedItem,
    setViewableItemIds,
    setViewableItems,
    setFocusedItemIndex,
    setViewableItemsIndex,
  } = useFlashListStore()
  const listRef = React.useRef<HTMLDivElement | null>(null)
  const estimateSize = useCallback(() => ESTIMATED_ITEM_SIZE, [])

  // Keyed on a content hash so a refetch returning identical pages keeps the
  // same array and does not re-trigger the effects that depend on it.
  const flatData = useMemo(
    () => (data?.pages ?? []).flatMap((page) => page.data),
    [fastHash(data)],
  )
  const totalFetched = flatData.length

  const { activeIndex, observeElement } = useMostVisibleIndex()

  // Publishes the given rows as the viewable set; the first one becomes focused.
  // NOTE(review): rows are read for an `index` property, but `DataWithId` rows
  // do not declare one, so the published indexes may be undefined — confirm
  // what the flash-list store expects here.
  const setItemsInStore = useCallback(
    (viewableItems: any[]) => {
      const focusedItem = viewableItems.slice(0, MAX_VIEWABLE_ITEMS)[0]
      setViewableItemIds(viewableItems.map((item) => item?.id))
      setViewableItemsIndex(viewableItems.map((item) => item?.index))
      setViewableItems(viewableItems)
      setFocusedItem(focusedItem)
      setFocusedItemId(focusedItem?.id)
      setFocusedItemIndex(focusedItem?.index)
    },
    [
      setFocusedItemId,
      setFocusedItem,
      setViewableItemIds,
      setViewableItems,
      setFocusedItemIndex,
      setViewableItemsIndex,
    ],
  )

  // Memoised so successive calls reach the same debouncer; a debouncer rebuilt
  // on every render never sees two calls and therefore coalesces nothing.
  const debouncedSetItems = useMemo(
    () => Functions.debounce(setItemsInStore, 100),
    [setItemsInStore],
  )

  useEffect(() => {
    if (activeIndex === null) {
      return
    }
    const item = flatData[activeIndex]
    // The reported index can outlive its row if the data shrank meanwhile.
    if (item === undefined) {
      return
    }
    debouncedSetItems([item])
  }, [activeIndex, flatData, debouncedSetItems])

  const rangeExtractor = useCallback((range: Range) => {
    return defaultRangeExtractor(range)
  }, [])

  const virtualizer = useWindowVirtualizer({
    horizontal,
    count: totalFetched,
    estimateSize,
    overscan: 30,
    rangeExtractor,
    scrollMargin: listRef.current?.offsetTop ?? 0,
  })

  const virtualItems = virtualizer.getVirtualItems()

  // Request the next page once the last loaded row is inside the rendered range.
  useEffect(() => {
    const lastItem = virtualItems[virtualItems.length - 1]
    if (!lastItem) {
      return
    }
    if (
      lastItem.index >= totalFetched - 1 &&
      hasNextPage &&
      !isFetchingNextPage
    ) {
      void fetchNextPage()
    }
  }, [hasNextPage, fetchNextPage, totalFetched, isFetchingNextPage, virtualItems])

  // Last scrollToIndex value that has actually been applied. Without it the
  // effect re-ran whenever a page loaded and pulled the window back to the target.
  const appliedScrollIndexRef = useRef<number | undefined>(undefined)

  useEffect(() => {
    if (scrollToIndex === undefined) {
      appliedScrollIndexRef.current = undefined
      return
    }
    if (appliedScrollIndexRef.current === scrollToIndex) {
      return
    }
    if (
      status === 'success' &&
      scrollToIndex >= 0 &&
      scrollToIndex < totalFetched
    ) {
      virtualizer.scrollToIndex(scrollToIndex, { align: 'center' })
      appliedScrollIndexRef.current = scrollToIndex
    }
  }, [scrollToIndex, virtualizer, status, totalFetched])

  return (
    <div ref={listRef}>
      {status === 'pending' ? (
        <p>Loading...</p>
      ) : (
        <div
          className={cn(containerClassName)}
          style={{
            height: `${virtualizer.getTotalSize()}px`,
            width: '100%',
            position: 'relative',
          }}
        >
          {virtualItems.map((virtualItem) => {
            const item = flatData[virtualItem.index] as ItemFromInfiniteQuery<T>
            if (!item) {
              return null
            }
            return (
              <div
                key={item.id}
                style={{
                  position: 'absolute',
                  top: 0,
                  left: 0,
                  width: '100%',
                  height: `${virtualItem.size}px`,
                  // NOTE(review): rows are always offset along Y even when
                  // `horizontal` is set — confirm horizontal mode is supported.
                  transform: `translateY(${
                    virtualItem.start - virtualizer.options.scrollMargin
                  }px)`,
                }}
                data-id={virtualItem.index}
                ref={observeElement}
                className={cn('w-full', itemClassName)}
              >
                {renderItem({ item, virtualItem })}
              </div>
            )
          })}
        </div>
      )}
    </div>
  )
}
