import { sumBy } from 'lodash';

import {
  Graph,
  GraphData,
  GraphExtend,
  GraphSettings,
  Link,
  Node,
} from '../model';
import { getWidth } from '../utils/links';
/** Factors by which the layout extent was shrunk on each axis to make room for margins. */
type GraphDimensions = {
  scaleX: number;
  scaleY: number;
};

/** Space reserved on each side of the graph for circular links. */
type Margin = {
  top: number;
  right: number;
  bottom: number;
  left: number;
};

/**
 * Returns the value-to-pixel ratio that would make one column's nodes, plus
 * the `py` padding between them, exactly fill the vertical extent.
 *
 * A column whose nodes carry no value (total 0) cannot constrain the ratio,
 * so it yields `Infinity` — which never wins a minimum taken over columns —
 * instead of dividing by zero (which could produce NaN or -Infinity).
 *
 * @param nodes - nodes of a single column; missing `value`s count as 0
 * @param extend - layout extent providing `y0`/`y1`
 * @param py - vertical padding between adjacent nodes
 */
const getMinVal = (nodes: Node[], extend: GraphExtend, py: number) => {
  const sumNodes = sumBy(nodes, (d) => d.value ?? 0);
  if (sumNodes === 0) {
    return Infinity;
  }
  const availableHeight = extend.y1 - extend.y0 - (nodes.length - 1) * py;
  return availableHeight / sumNodes;
};

/**
 * Computes the global value-to-pixel factor `ky`: the tightest per-column
 * ratio (so that every column fits vertically), multiplied by `settings.scale`.
 * With no columns the tightest ratio stays `Infinity`.
 */
const getKy = (data: GraphData, py: number, settings: GraphSettings) => {
  let tightest = Infinity;

  data.forEachColumn((nodes) => {
    tightest = Math.min(tightest, getMinVal(nodes, data.extend, py));
  });

  return settings.scale * tightest;
};

/**
 * Pads a positive accumulated link width with room for the curve radius and
 * the vertical gap; zero or negative totals are returned unchanged.
 */
const caclulateMarginValue = (
  value: number,
  { baseRadius, verticalMargin }: GraphSettings,
) => {
  if (value <= 0) {
    return value;
  }
  return value + verticalMargin + baseRadius;
};

/**
 * Sums the widths of circular links to find how much room each side of the
 * graph needs:
 * - top/bottom: every circular link, split by `circularLinkType === 'top'`;
 * - left: circular links whose target sits in column 0;
 * - right: circular links whose source sits in the last column.
 * Each non-zero total is then padded via `caclulateMarginValue`.
 */
const calculateMargin = (
  graph: Readonly<GraphData>,
  { settings }: Pick<Graph<Node, Link>, 'settings'>,
): Margin => {
  const totals = { top: 0, bottom: 0, left: 0, right: 0 };
  const lastColumn = graph.maxColumns();

  graph.forEachLink((link: Link) => {
    const linkWidth = getWidth(link, graph.extend.ky);
    if (!link.circular) {
      return;
    }

    if (link.circularLinkType === 'top') {
      totals.top += linkWidth;
    } else {
      totals.bottom += linkWidth;
    }

    const { source, target } = graph.getNodeLinks(link);
    if (target.column === 0) {
      totals.left += linkWidth;
    }
    if (source.column === lastColumn) {
      totals.right += linkWidth;
    }
  });

  return {
    top: caclulateMarginValue(totals.top, settings),
    bottom: caclulateMarginValue(totals.bottom, settings),
    left: caclulateMarginValue(totals.left, settings),
    right: caclulateMarginValue(totals.right, settings),
  };
};

/**
 * Shrinks the layout extent so that the given margins fit around it, writing
 * the new `x0`/`x1`/`y0`/`y1` back via `setExtendValue`.
 *
 * The right edge is left untouched when there is no right margin.
 *
 * @returns the horizontal and vertical shrink factors that were applied
 */
const calculateGraphDimsions = (
  graph: Readonly<GraphData>,
  margin: Margin,
): GraphDimensions => {
  // All reads happen before any write, so the new bounds are derived from the old ones.
  const { x0, x1, y0, y1 } = graph.extend;
  const width = x1 - x0;
  const height = y1 - y0;

  const scaleX = width / (width + margin.right + margin.left);
  const scaleY = height / (height + margin.top + margin.bottom);

  const next = {
    x0: x0 * scaleX + margin.left,
    x1: margin.right === 0 ? x1 : x1 * scaleX,
    y0: y0 * scaleY + margin.top,
    y1: y1 * scaleY,
  };

  graph.setExtendValue('x0', next.x0);
  graph.setExtendValue('x1', next.x1);
  graph.setExtendValue('y0', next.y0);
  graph.setExtendValue('y1', next.y1);

  return { scaleX, scaleY };
};

/**
 * Computes a node's horizontal span. Columns are spread evenly across the
 * extent: column 0 starts at `extend.x0` and column `maxColumn` ends at
 * `extend.x1`.
 *
 * @param extend - current layout extent; missing `x0`/`x1` are treated as 0
 * @param node - node whose `column` fixes its position (missing column → 0)
 * @param nodeDimensions - supplies the node width; a missing width defaults to 10
 * @param maxColumn - index of the right-most column; values <= 0 (a
 *   single-column graph) fall back to 1 so the spacing formula never divides by zero
 * @returns the node's left (`x0`) and right (`x1`) coordinates
 */
const calculateNodeSize = (
  extend: GraphExtend,
  node: Node,
  { nodeDimensions }: Pick<GraphSettings, 'nodeDimensions'>,
  maxColumn: number,
) => {
  const x0 = extend.x0 ?? 0;
  const x1 = extend.x1 ?? 0;
  const column = node.column ?? 0;
  // `?? 1` alone does not catch maxColumn === 0, where 0 * (d / 0) yields NaN.
  const columnCount = maxColumn > 0 ? maxColumn : 1;
  // Apply the width default before the width is used in the spacing formula.
  const width = nodeDimensions(node).width ?? 10;
  const nodeX0 = x0 + column * ((x1 - x0 - width) / columnCount);
  return { x0: nodeX0, x1: nodeX0 + width };
};

/**
 * Resizes a sankey layout so circular links have room around the graph.
 *
 * Steps, all applied in place to `graph.graph`:
 * 1. set the vertical node padding `py`. When `settings.paddingRatio` is
 *    truthy, it is derived from that ratio and the smallest per-column value
 *    wins. Otherwise `settings.nodePadding` is used.
 * 2. compute the value-to-pixel factor `ky` using that padding;
 * 3. measure the margins needed by circular links and shrink the extent;
 * 4. reposition every node horizontally inside the shrunken extent;
 * 5. scale `ky` by the vertical shrink factor and recompute link widths.
 *
 * @param graph - wrapper providing `settings` and the mutable `graph` data
 * @returns the same (mutated) `graph.graph` data object
 */
export const adjustSankeySize = (graph: Readonly<Graph<any, any>>) => {
  const { settings, graph: data } = graph;
  const maxColumn = data.maxColumns();

  // Override py if a padding ratio has been set.
  const { paddingRatio } = settings;
  if (paddingRatio) {
    let padding = Infinity;
    data.forEachColumn((nodes) => {
      const thisPadding = (data.extend.y1 * paddingRatio) / (nodes.length + 1);
      padding = thisPadding < padding ? thisPadding : padding;
    });
    data.setExtendValue('py', padding);
  } else {
    data.setExtendValue('py', settings.nodePadding);
  }

  // Read py only after it was set above. Reading it earlier would feed a
  // stale padding into the ky computation.
  const py = data.extend.py ?? 0;
  data.setExtendValue('ky', getKy(data, py, settings));

  const margin = calculateMargin(data, graph);
  const graphDimensions = calculateGraphDimsions(data, margin);

  data.forEachNode((node) => {
    const { x0, x1 } = calculateNodeSize(
      data.extend,
      node,
      settings,
      maxColumn,
    );
    node.setValue('x0', x0);
    node.setValue('x1', x1);
  });

  // Re-calculate link widths with ky scaled by the vertical shrink factor.
  // The new ky is computed from `data.extend` and passed directly, so it does
  // not depend on whether setExtendValue mutates or replaces the extent object.
  const rescaledKy = data.extend.ky * graphDimensions.scaleY;
  data.setExtendValue('ky', rescaledKy);

  data.forEachLink((link) => {
    link.setValue('width', getWidth(link, rescaledKy));
  });

  return data;
};
