import { Box } from "@mui/material";
import { ResponsiveBar } from "@nivo/bar";
import moment from "moment";
import numeral from "numeral";
import colorScheme from "./colorScheme";

const tooltipConstructor = ({ id, value, indexValue, color }) => (
  <Box
    sx={{
      backgroundColor: "white",
      padding: "10px",
      border: `2px solid ${color}`,
      borderRadius: "4px",
      boxShadow: "0 2px 4px rgba(0,0,0,0.1)",
    }}
  >
    <Box sx={{ mb: 1, fontSize: "0.875rem" }}>
      <strong>Story ID:</strong> {id}
    </Box>
    <Box sx={{ mb: 1, fontSize: "0.875rem" }}>
      <strong>Pageviews:</strong> {numeral(value).format("0,0")}
    </Box>
    <Box sx={{ fontSize: "0.875rem" }}>
      <strong>Date:</strong> {moment(indexValue).format("LL")}
    </Box>
  </Box>
);

/**
 * Builds one legend configuration object per row of story ids, so that a long
 * list of ids wraps onto several rows below the chart.
 *
 * Colours are assigned by the id's position in the full `storyIds` list,
 * cycling through `colorScheme` when there are more ids than colours.
 *
 * @param {Array<string|number>} [storyIds] - Ids to list; a nullish value is treated as empty.
 * @param {string[]} colorScheme - Palette to cycle through.
 * @param {object} [options]
 * @param {number} [options.itemsPerRow=7] - Maximum legend items per row (clamped to at least 1).
 * @param {number} [options.startY=120] - `translateY` of the first row.
 * @param {number} [options.rowHeight=30] - Vertical distance between consecutive rows.
 * @returns {object[]} Legend configuration objects, one per row (empty when there are no ids).
 */
const createLegends = (storyIds, colorScheme, options = {}) => {
  const { itemsPerRow = 7, startY = 120, rowHeight = 30 } = options;

  const ids = storyIds ?? [];
  const paletteSize = colorScheme?.length ?? 0;

  // A zero, negative or NaN row size would never advance through the ids,
  // so fall back to one item per row in that case.
  const perRow = Math.max(1, Math.floor(itemsPerRow) || 1);

  const legends = [];

  for (let startIndex = 0; startIndex < ids.length; startIndex += perRow) {
    const row = legends.length;

    legends.push({
      data: ids.slice(startIndex, startIndex + perRow).map((cur, index) => ({
        id: cur,
        label: cur,
        // Wrap the absolute position (not just the in-row offset) so rows
        // past the end of the palette still receive a defined colour.
        color:
          paletteSize > 0
            ? colorScheme[(startIndex + index) % paletteSize]
            : undefined,
      })),
      anchor: "bottom-left",
      direction: "row",
      justify: false,
      translateX: 0,
      translateY: startY + row * rowHeight,
      itemsSpacing: 2,
      itemWidth: 100,
      itemHeight: 20,
      itemDirection: "left-to-right",
      itemOpacity: 0.85,
      symbolSize: 20,
      effects: [
        {
          on: "hover",
          style: {
            itemOpacity: 1,
          },
        },
      ],
    });
  }

  return legends;
};

const BarStackedTS = ({ data }) => {
  if (!data?.data?.length) {
    return (
      <Box
        sx={{
          height: 400,
          display: "flex",
          alignItems: "center",
          justifyContent: "center",
        }}
      >
        No data available
      </Box>
    );
  }

  // Process and validate data
  const processedData = data.data
    .sort((a, b) => new Date(a.date) - new Date(b.date))
    .map((item) => ({
      ...item,
      // Add total pageviews for each day
      totalViews: Object.entries(item)
        .filter(([key]) => key !== "date")
        .reduce((sum, [, value]) => sum + (Number(value) || 0), 0),
    }));

  // Calculate max value for better y-axis scaling
  const maxValue = Math.max(...processedData.map((item) => item.totalViews));
  // Get the order of magnitude of the maximum value
  const magnitude = Math.floor(Math.log10(maxValue));

  // Determine the appropriate scale unit based on magnitude
  const scale = Math.pow(10, magnitude) / 2;

  // Round up to nearest scale unit and add a bit of padding
  const yAxisMax = Math.ceil(maxValue / scale) * scale;

  const legends = createLegends(data.storyIds, colorScheme);

  return (
    <Box
      sx={{
        height: {
          xs: "420px",
          sm: "500px",
          lg: `${600 + legends.length * 30}px`,
        },
        width: "100%",
      }}
    >
      <ResponsiveBar
        data={processedData}
        indexBy="date"
        keys={data.storyIds}
        colors={colorScheme}
        margin={{
          top: 50,
          right: 30,
          bottom: 150 + legends.length * 30,
          left: 60,
        }}
        padding={0.2}
        valueScale={{ type: "linear", min: 0, max: yAxisMax }}
        indexScale={{ type: "band", round: true }}
        enableLabel={false}
        axisBottom={{
          tickRotation: -45,
          format: (v) => moment(v).format("MMM D, YYYY"),
          tickSize: 5,
          tickPadding: 5,
          // Only show every nth tick if we have more than 10 data points
          tickValues:
            processedData.length > 10
              ? processedData
                  .filter(
                    (_, index) =>
                      index % Math.ceil(processedData.length / 10) === 0,
                  )
                  .map((d) => d.date)
              : undefined,
        }}
        axisLeft={{
          format: (v) => numeral(v).format("0,0"),
          tickSize: 5,
          tickPadding: 5,
        }}
        tooltip={tooltipConstructor}
        legends={legends}
        role="application"
        ariaLabel="Pageviews over time"
        animate={true}
        motionStiffness={90}
        motionDamping={15}
      />
    </Box>
  );
};

export default BarStackedTS;
