import type { Range } from '@tanstack/react-virtual'
import type { UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared'
import type { ReactElement } from 'react'
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { Functions } from '@goatlab/js-utils'
import { useFlashListStore } from '@sodium/shared-frontend-schemas'
import {
  defaultRangeExtractor,
  useWindowVirtualizer,
} from '@tanstack/react-virtual'

/**
 * Minimum shape for an item rendered by `VirtualFeed`: any record that
 * carries an `id`.
 */
export interface DataWithId {
  /** Arbitrary additional fields are allowed. */
  [key: string]: any
  /** Identifier of the item; `VirtualFeed` uses it as the React key of each row. */
  id: string | number
}

/**
 * One page of a paginated response, as consumed by `VirtualFeed`
 * (which reads `data` from every page and `total` from the first page).
 *
 * NOTE(review): the meanings of the paging fields below are inferred from
 * their names only — confirm against the API that produces them.
 */
export interface PaginatedResult<T extends DataWithId> {
  /** Items contained in this page. */
  data: T[]
  /** Presumably the total number of items across all pages. */
  total: number
  /** Presumably the page size. */
  perPage: number
  currentPage: number
  previousPage: number
  nextPage: number
  lastPage: number
}

/**
 * Props accepted by `VirtualFeed`.
 */
export interface VirtualFeedParams<T extends DataWithId> {
  /**
   * tRPC infinite-query result whose pages are flattened into the feed.
   *
   * NOTE(review): `any & PaginatedResult<T>` collapses to `any` in
   * TypeScript, so page contents are not type-checked here. Narrowing it
   * may break existing callers, so it is left unchanged.
   */
  infiniteQuery: UseTRPCInfiniteQueryResult<
    any & PaginatedResult<T>,
    Record<string, any>,
    Record<string, any>
  >
  /** When provided, these items are rendered instead of the query's pages. */
  overrideData?: T[]
  /** Renders one row from its data item and its virtualizer item. */
  renderItem: (props: { item: any; virtualItem: any }) => ReactElement
  /** Forwarded to the virtualizer's `horizontal` option. */
  horizontal?: boolean
  /** When defined, the feed scrolls this index to the center of the view. */
  scrollToIndex?: number
}

/** Only the first entry of the viewable items is treated as the focused item. */
const MAX_VIEWABLE_ITEMS = 1

/**
 * Window-virtualized, infinitely scrolling feed backed by a tRPC infinite query.
 *
 * - Rows are virtualized with `useWindowVirtualizer`.
 * - The next page is requested once the last virtual row reaches the end of
 *   the loaded data and `hasNextPage` is true.
 * - An `IntersectionObserver` (viewport root) tracks which rendered row is the
 *   most visible; that row is pushed, debounced by 100 ms, into
 *   `useFlashListStore` as the focused/viewable item.
 *
 * @param infiniteQuery Result of a tRPC infinite query; each page is expected
 *   to carry a `data` array, and the first page a `total` count.
 * @param renderItem Renders one row given its data item and virtual item.
 * @param overrideData When provided, rendered instead of the query's pages.
 * @param horizontal Forwarded to the virtualizer.
 * @param scrollToIndex When defined, that index is scrolled to the center.
 */
export function VirtualFeed<T extends DataWithId>({
  infiniteQuery,
  renderItem,
  overrideData,
  horizontal = false,
  scrollToIndex,
}: VirtualFeedParams<T>) {
  const {
    status,
    data,
    fetchNextPage,
    isFetching,
    hasNextPage,
    isFetchingNextPage,
  } = infiniteQuery
  const {
    setFocusedItemId,
    setFocusedItem,
    setViewableItemIds,
    setViewableItems,
    setFocusedItemIndex,
    setViewableItemsIndex,
  } = useFlashListStore()

  const flatData = useMemo(
    () => overrideData ?? data?.pages.flatMap((page) => page.data) ?? [],
    [data, overrideData],
  )

  const totalDBRowCount = data?.pages?.[0]?.total ?? 0
  const totalFetched = flatData.length

  // Requests the next page when the given element is scrolled to within
  // 300px of its bottom.
  // NOTE(review): this is only wired to the wrapper's `onScroll`, and that
  // wrapper has no overflow styling here, so it likely never fires; the
  // virtual-item effect below is what normally loads more pages.
  const fetchMoreOnBottomReached = useCallback(
    (containerRefElement?: HTMLDivElement | null) => {
      if (!containerRefElement) {
        return
      }
      const { scrollHeight, scrollTop, clientHeight } = containerRefElement
      if (
        scrollHeight - scrollTop - clientHeight < 300 &&
        !isFetching &&
        totalFetched < totalDBRowCount &&
        totalFetched !== 0
      ) {
        fetchNextPage().catch(console.error)
      }
    },
    [fetchNextPage, isFetching, totalFetched, totalDBRowCount],
  )

  // `data-id` (virtual index, as a string) of the most visible row.
  const [activeItem, setActiveItem] = useState<string | null>(null)

  // Writes the viewable items to the store. All of them are stored as
  // viewable; only the first MAX_VIEWABLE_ITEMS are considered for focus.
  const setItemsInStore = useCallback(
    (viewableItems: any[]) => {
      const focusedItem = viewableItems.slice(0, MAX_VIEWABLE_ITEMS)[0]
      const focusedItemId = focusedItem?.id
      const focusedItemIndex = focusedItem?.index

      const viewableItemIds = viewableItems.map((item: any) => item?.id)
      const viewableItemsIndex = viewableItems.map((item: any) => item?.index)

      setViewableItemIds(viewableItemIds)
      setViewableItemsIndex(viewableItemsIndex)
      setViewableItems(viewableItems)
      setFocusedItem(focusedItem)
      setFocusedItemId(focusedItemId)
      setFocusedItemIndex(focusedItemIndex)
    },
    [
      setViewableItemIds,
      setViewableItemsIndex,
      setViewableItems,
      setFocusedItem,
      setFocusedItemId,
      setFocusedItemIndex,
    ],
  )

  // Memoized so the debounce timer survives re-renders. A new debounced
  // function per render would never coalesce successive calls.
  const debouncedSetItems = useMemo(
    () => Functions.debounce(setItemsInStore, 100),
    [setItemsInStore],
  )

  const observer = useRef<IntersectionObserver | null>(null)
  // Every row element handed to the ref callback. Rows committed before the
  // observer exists (e.g. cached data on first render) are observed as soon
  // as it is created.
  const trackedElements = useRef<Set<HTMLElement>>(new Set())

  // Item observer: tracks the intersection ratio of each mounted row and
  // marks the most visible one as active.
  useEffect(() => {
    const tracked = trackedElements.current
    // Latest intersection ratio per row element.
    const ratios = new Map<HTMLElement, number>()

    const io = new IntersectionObserver(
      (entries, observerInstance) => {
        for (const entry of entries) {
          ratios.set(entry.target as HTMLElement, entry.intersectionRatio)
        }

        let bestId = ''
        let bestRatio = 0
        let bestIndex = Number.POSITIVE_INFINITY
        ratios.forEach((ratio, element) => {
          if (!element.isConnected) {
            // The virtualizer unmounted this row: stop observing it so the
            // detached node is released and its stale ratio cannot win.
            ratios.delete(element)
            tracked.delete(element)
            observerInstance.unobserve(element)
            return
          }
          const id = element.dataset.id
          if (id === undefined) {
            return
          }
          const index = Number(id)
          // Highest ratio wins; on a tie the lowest index wins.
          if (
            ratio > bestRatio ||
            (ratio > 0 && ratio === bestRatio && index < bestIndex)
          ) {
            bestId = id
            bestRatio = ratio
            bestIndex = index
          }
        })

        setActiveItem(bestId)
      },
      {
        threshold: [0.1, 0.5, 1],
        root: null, // null = the browser viewport
      },
    )

    observer.current = io
    tracked.forEach((element) => io.observe(element))

    return () => {
      io.disconnect()
      if (observer.current === io) {
        observer.current = null
      }
    }
  }, [])

  // Publish the active row to the store whenever it changes.
  // `flatData` is deliberately not a dependency: re-publishing on every data
  // change could loop when `overrideData` is a new array on each render.
  useEffect(() => {
    if (!activeItem) {
      return
    }
    const item = flatData[Number(activeItem)]
    if (item === undefined) {
      // The index no longer maps to loaded data; nothing to publish.
      return
    }
    debouncedSetItems([item])
  }, [activeItem])

  // Stable ref callback: remember the row element and observe it if the
  // observer already exists (observing an element twice is a no-op).
  const observeElement = useCallback((element: HTMLDivElement | null) => {
    if (!element) {
      return
    }
    trackedElements.current.add(element)
    observer.current?.observe(element)
  }, [])

  const estimateSize = useCallback(() => 60, [])
  const rangeExtractor = useCallback(
    (range: Range) => defaultRangeExtractor(range),
    [],
  )

  const virtualizer = useWindowVirtualizer({
    horizontal,
    count: totalFetched,
    estimateSize,
    overscan: 20,
    rangeExtractor,
  })
  const virtualItems = virtualizer.getVirtualItems()

  // Load the next page once the last rendered virtual row reaches the end of
  // the loaded data.
  useEffect(() => {
    const lastItem = virtualItems[virtualItems.length - 1]
    if (!lastItem) {
      return
    }

    if (
      lastItem.index >= flatData.length - 1 &&
      hasNextPage &&
      !isFetchingNextPage
    ) {
      fetchNextPage().catch(console.error)
    }
  }, [hasNextPage, fetchNextPage, flatData.length, isFetchingNextPage, virtualItems])

  useEffect(() => {
    if (scrollToIndex !== undefined && virtualizer.scrollToIndex) {
      virtualizer.scrollToIndex(scrollToIndex, { align: 'center' })
    }
  }, [scrollToIndex, virtualizer])

  // NOTE(review): rows are rendered in normal flow inside a `height: auto`
  // container; neither `virtualItem.start` nor `getTotalSize()` is applied
  // here, so presumably `renderItem` handles positioning — confirm.
  return (
    <div>
      {status === 'pending' ? (
        <p>Loading...</p>
      ) : (
        <div
          className="relative"
          onScroll={(e) => {
            fetchMoreOnBottomReached(e.currentTarget)
          }}
        >
          <div
            style={{
              height: 'auto',
              width: '100%',
              position: 'relative',
            }}
          >
            {virtualItems.map((virtualItem) => {
              const item = flatData[virtualItem.index]
              return (
                <div
                  key={item.id}
                  data-id={virtualItem.index}
                  ref={observeElement}
                >
                  {renderItem({
                    item,
                    virtualItem,
                  })}
                </div>
              )
            })}
          </div>
        </div>
      )}
    </div>
  )
}
