import Dagre from '@dagrejs/dagre';
import { forceCenter, forceLink, forceManyBody, forceSimulation } from "d3-force";
import { forEach, values } from 'lodash';
import { Edge, Node } from 'reactflow';

/** Options accepted by the Dagre-based layout helper in this module. */
interface ILayoutOption {
    /** Rank direction passed to Dagre as `rankdir`: "TB" (top to bottom) or "LR" (left to right). */
    direction: "TB" | "LR";
}

/**
 * Lays out React Flow nodes with Dagre's layered graph layout.
 *
 * A new Dagre graph is built on every call. The previous implementation reused
 * one module-level graph that was never cleared, so nodes and edges from earlier
 * calls stayed in it and were laid out again alongside the current ones.
 *
 * @param nodes - Nodes to position. Their `width`/`height` are given to Dagre,
 *   falling back to 0 when unset so the computed positions are never NaN.
 * @param edges - Edges used to rank the nodes. Edges whose source or target is
 *   not among `nodes` are skipped so they cannot add phantom nodes to the layout.
 * @param options - `direction` is forwarded to Dagre as `rankdir` ("TB" or "LR").
 * @returns Copies of `nodes` whose `position` is the top-left corner derived from
 *   Dagre's computed centre, together with the unchanged `edges` array.
 */
export const getDagreLayoutedElements = (
    nodes: Node[] = [],
    edges: Edge[] = [],
    options: ILayoutOption = {
        direction: "LR",
    }
) => {
    // Fresh graph per call: a shared graph accumulates stale nodes/edges.
    const graph = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({}));
    graph.setGraph({ rankdir: options.direction, edgesep: 300 });

    const nodeIds = new Set(nodes.map((node) => node.id));

    // Edges are registered before nodes, as in the original code, to keep the
    // same node insertion order (which may affect Dagre's node ordering).
    edges.forEach((edge) => {
        if (nodeIds.has(edge.source) && nodeIds.has(edge.target)) {
            graph.setEdge(edge.source, edge.target);
        }
    });

    nodes.forEach((node) =>
        graph.setNode(node.id, {
            // width/height may be unset (e.g. before the node has been
            // measured); without a fallback the position math yields NaN.
            height: node.height ?? 0,
            width: node.width ?? 0,
            label: node.id,
        })
    );

    Dagre.layout(graph);

    return {
        nodes: nodes.map((node) => {
            const { x, y, width, height } = graph.node(node.id);
            return {
                ...node,
                position: {
                    // Dagre reports the node centre; shift to the top-left corner.
                    x: x - width / 2,
                    y: y - height / 2,
                },
            };
        }),
        edges,
    };
};

/** Number of synchronous simulation ticks run for each force-directed layout. */
const FORCE_LAYOUT_TICKS = 3000;

/**
 * Lays out React Flow nodes with a d3-force simulation (many-body repulsion,
 * a centering force and link springs), stepped synchronously.
 *
 * A fresh simulation is created on every call. The previous module-level
 * simulation never had its alpha reset, so every call after the first started
 * from an already-cooled simulation and the result depended on call history.
 * It also read `window` at module load, freezing the viewport size at import.
 *
 * @param nodes - Nodes whose current `position` seeds the simulation. They are
 *   not mutated; positioned copies are returned instead.
 * @param edges - Edges that pull their endpoints together. Edges whose source or
 *   target is not among `nodes` are ignored (d3's forceLink throws on them).
 * @returns Copies of `nodes`, in the same order, with `position` set to the
 *   simulated coordinates, together with the unchanged `edges` array.
 */
export const getForceFieldLayoutedElements = (nodes: Node[], edges: Edge[]) => {
    // Node id -> array index; forceLink's default id accessor is the node index.
    const indexById = new Map<string, number>();
    forEach(nodes, (node, i) => {
        indexById.set(node.id, i);
    });

    const simulationNodes = nodes.map((node, index) => ({
        index,
        x: node.position.x,
        y: node.position.y,
    }));

    const simulationEdges: { source: number; target: number }[] = [];
    forEach(edges, (edge) => {
        const source = indexById.get(edge.source);
        const target = indexById.get(edge.target);
        if (source !== undefined && target !== undefined) {
            simulationEdges.push({ source, target });
        }
    });

    const simulation = forceSimulation()
        .force('charge', forceManyBody().strength(-1000))
        // Viewport is read at call time rather than at module load.
        .force('center', forceCenter(window.innerWidth / 2, window.innerHeight / 2))
        .alphaTarget(0.05)
        .stop();

    simulation.nodes(simulationNodes).force(
        'link',
        forceLink().links(simulationEdges)
            .strength(0.2)
            .distance(300)
    );

    for (let i = 0; i < FORCE_LAYOUT_TICKS; ++i) simulation.tick();

    // Build new node objects instead of mutating the caller's nodes in place.
    const layouted: Record<number, Node> = {};
    forEach(simulationNodes, ({ index, x, y }) => {
        layouted[index] = { ...nodes[index], position: { x, y } };
    });

    return { nodes: values(layouted), edges };
};