/*
 * Copyright Starburst Data, Inc. All rights reserved.
 *
 * THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE OF STARBURST DATA.
 * The copyright notice above does not evidence any
 * actual or intended publication of such source code.
 *
 * Redistribution of this material is strictly prohibited.
 */
import React, { ReactElement, useCallback, useEffect, useRef, useState } from 'react';
import { graphlib } from 'dagre';
import { GraphNode, GroupNode } from './GroupNode';
import { GraphEdge, GroupEdgePath } from './GroupEdgePath';
import { GraphGroup, GroupCluster } from './GroupCluster';
import { createDragObservable$ } from './dragObservable';
import { darkThemePalette, palette } from '../../themes/palette';
import { createScrollObservable$, Scroll } from './scrollObservable';
import { shadows } from '../../themes/shadows';
import { createUseStyles } from 'react-jss';
import clsx from 'clsx';
import { useThemeMode } from '../../app/UIThemeContextProvider';

/**
 * Optional class-name overrides for parts of the rendered graph.
 */
export interface GraphClasses {
    /**
     * Forwarded unchanged to `GroupNode` as its `classes` prop (defaults to `{}` when omitted).
     * NOTE(review): the meaning of the keys is defined by `GroupNode` — confirm there.
     */
    nodes?: Record<string, string>;
}

/** Props accepted by the `Graph` component. */
interface GraphProps {
    /** dagre layout; its `graph().width`/`graph().height` are used to fit the view and it is passed to every child group. */
    layout: graphlib.Graph;
    /** Nodes drawn by `GroupNode`. */
    nodes: GraphNode[];
    /** Edges drawn by `GroupEdgePath`. */
    edges: GraphEdge[];
    /** Clusters drawn by `GroupCluster`. */
    groups: GraphGroup[];
    /** Forwarded to `GroupNode`; defaults to a no-op when omitted. */
    onNodeHover?: (node: GraphNode) => void;
    /** Forwarded to `GroupNode`; defaults to a no-op when omitted. */
    onNodeHoverLost?: (node: GraphNode) => void;
    /** Optional class-name overrides; only `nodes` is currently consumed. */
    classes?: GraphClasses;
}

/** Current zoom factor together with the bounds that wheel zooming is clamped to. */
interface ScaleDescriptor {
    /** Scale factor emitted into the SVG `scale(...)` transform. */
    value: number;
    /** Lower clamp bound applied to `value` when zooming. */
    min: number;
    /** Upper clamp bound applied to `value` when zooming. */
    max: number;
}

/** Pan offset emitted into the SVG `translate(...)` transform. */
interface TranslateDescriptor {
    x: number;
    y: number;
}

/** Complete view state of the graph: pan offset plus zoom. */
interface TransformDescriptor {
    translate: TranslateDescriptor;
    scale: ScaleDescriptor;
}

/**
 * Pannable, zoomable SVG rendering of a dagre layout.
 *
 * The graph is scaled and centred to fit its container on mount, and re-fitted whenever the
 * layout's reported width/height change. Dragging the SVG pans the view; the wheel zooms
 * around the pointer, clamped to the bounds produced by the fit.
 *
 * @param layout - dagre graph; `graph().width/height` drive the fit and it is passed to children.
 * @param nodes - nodes rendered by `GroupNode`.
 * @param edges - edges rendered by `GroupEdgePath`.
 * @param groups - clusters rendered by `GroupCluster`.
 * @param onNodeHover - forwarded to `GroupNode`; defaults to a no-op.
 * @param onNodeHoverLost - forwarded to `GroupNode`; defaults to a no-op.
 * @param classes - optional `nodes` class map forwarded to `GroupNode`.
 */
export const Graph = ({
    layout,
    nodes,
    edges,
    groups,
    onNodeHover = () => void 0,
    onNodeHoverLost = () => void 0,
    classes: { nodes: nodeClasses = {} } = {},
}: GraphProps): ReactElement => {
    // Declared before any callback that reads it, so the reading order matches the data flow.
    const svgRef = useRef<SVGSVGElement>(null);
    const [transform, setTransform] = useState<TransformDescriptor>({
        translate: {
            x: 0,
            y: 0,
        },
        scale: {
            value: 1,
            min: 1,
            max: 1,
        },
    });

    // dagre may not report a size; fall back to 1000 like the container fallback below.
    const graphWidth = layout.graph().width || 1000;
    const graphHeight = layout.graph().height || 1000;

    // Zoom around the pointer: keep the graph point under (x, y) fixed while the scale changes.
    const onWheelEventHandler = useCallback(function (scroll: Scroll) {
        const { scrollDelta, x, y } = scroll;
        setTransform((prevState) => {
            if (svgRef.current === null) {
                return prevState;
            }
            const newScale = calculateScaleChange(
                prevState.scale.value,
                scrollDelta,
                prevState.scale.min,
                prevState.scale.max
            );

            const scrollPoint = {
                x: (x - prevState.translate.x) / prevState.scale.value,
                y: (y - prevState.translate.y) / prevState.scale.value,
            };
            const scrollPointAfterRescale = {
                x: (x - prevState.translate.x) / newScale,
                y: (y - prevState.translate.y) / newScale,
            };
            const translationDiff = {
                x: (scrollPointAfterRescale.x - scrollPoint.x) * newScale,
                y: (scrollPointAfterRescale.y - scrollPoint.y) * newScale,
            };

            return {
                ...prevState,
                translate: {
                    x: prevState.translate.x + translationDiff.x,
                    y: prevState.translate.y + translationDiff.y,
                },
                scale: {
                    ...prevState.scale,
                    value: newScale,
                },
            };
        });
    }, []);

    // Fit the graph into the container. Keyed on the layout's dimensions (not the layout object)
    // so a re-created but identically sized layout does not reset the user's pan/zoom, while a
    // resized layout gets a fresh fit and matching zoom bounds.
    useEffect(() => {
        const svg = svgRef.current;
        if (!svg) {
            return;
        }
        setTransform(
            calculateFitTransform(
                graphWidth,
                graphHeight,
                svg.clientWidth || 1000,
                svg.clientHeight || 1000
            )
        );
    }, [graphWidth, graphHeight]);

    // Wire drag-to-pan and wheel-to-zoom once; `onWheelEventHandler` is stable (empty deps).
    useEffect(() => {
        const svg = svgRef.current;
        if (!svg) {
            return;
        }
        const dragSubscription = createDragObservable$(svg).subscribe(
            ({ movementX, movementY }) => {
                setTransform((prevState) => ({
                    ...prevState,
                    translate: {
                        x: prevState.translate.x + movementX,
                        y: prevState.translate.y + movementY,
                    },
                }));
            }
        );
        const scrollSubscription = createScrollObservable$(svg).subscribe(onWheelEventHandler);
        return () => {
            dragSubscription.unsubscribe();
            scrollSubscription.unsubscribe();
        };
    }, [onWheelEventHandler]);

    const themeMode = useThemeMode();
    const classes = useStyles();
    return (
        <div
            className={clsx(classes.root, {
                [classes.dmRoot]: themeMode === 'dark',
            })}>
            <svg ref={svgRef} width={'100%'} height={'100%'}>
                <defs>
                    <marker
                        id="arrow-end"
                        viewBox="0 0 10 10"
                        refX="9"
                        refY="5"
                        markerUnits="strokeWidth"
                        markerWidth="10"
                        markerHeight="10"
                        orient="auto">
                        <path d="M 0 0 L 10 5 L 0 10 L 5 5 z" fill={palette.info} />
                    </marker>
                    <marker
                        id="arrow-start"
                        viewBox="0 0 4 4"
                        refX="0"
                        refY="2"
                        markerUnits="strokeWidth"
                        markerWidth="4"
                        markerHeight="4"
                        orient="auto">
                        <path d="M 0 0 L 4 2 L 0 4 z" fill={palette.info} />
                    </marker>
                </defs>
                <g transform={createTransformString(transform)}>
                    <GroupCluster layout={layout} groups={groups} />
                    <GroupNode
                        classes={nodeClasses}
                        layout={layout}
                        nodes={nodes}
                        onNodeHover={onNodeHover}
                        onNodeHoverLost={onNodeHoverLost}
                    />
                    <GroupEdgePath layout={layout} edges={edges} />
                </g>
            </svg>
        </div>
    );
};

// JSS class definitions for the Graph component's wrapping <div>.
const useStyles = createUseStyles({
    // Always applied to the wrapper; the grab cursor hints that the canvas can be dragged to pan.
    root: {
        cursor: 'grab',
        backgroundColor: palette.nebulaNavy15,
        boxShadow: shadows[2],
        padding: '1rem',
        borderRadius: '4px',
        width: '100%',
        height: '100%',
    },
    // Added on top of `root` when the theme mode is 'dark'; only overrides the background.
    dmRoot: {
        backgroundColor: darkThemePalette.bgLevel3,
    },
});

/**
 * Serialise a view transform into an SVG `transform` attribute value.
 * Translation is emitted before scale, so the pan offset is not itself scaled.
 */
function createTransformString(transform: TransformDescriptor): string {
    const {
        translate: { x: offsetX, y: offsetY },
        scale: { value: zoom },
    } = transform;
    const translatePart = `translate(${offsetX}, ${offsetY})`;
    const scalePart = `scale(${zoom})`;
    return `${translatePart} ${scalePart}`;
}

/**
 * Apply a wheel delta to a zoom factor and clamp the result.
 *
 * Every 10000 units of delta change the scale by 1; a positive delta zooms out.
 * The upper bound is applied first and the lower bound last, so if `minScale` exceeds
 * `maxScale` the result is `minScale`.
 */
function calculateScaleChange(
    oldScale: number,
    delta: number,
    minScale: number,
    maxScale: number
): number {
    const step = delta / 10000;
    const cappedAbove = Math.min(oldScale - step, maxScale);
    return Math.max(cappedAbove, minScale);
}

/**
 * Compute the transform that uniformly scales a graph to fit its container and centres it.
 *
 * The fitted scale is also used as the minimum zoom bound. The maximum zoom bound is 1 (100%),
 * except when the graph already fits at a larger scale; then the maximum is raised to the fitted
 * scale so that `min <= max` always holds and the initial `value` lies within its own bounds.
 *
 * @param graphWidth - width of the laid-out graph; expected to be > 0.
 * @param graphHeight - height of the laid-out graph; expected to be > 0.
 * @param containerWidth - available width for the graph.
 * @param containerHeight - available height for the graph.
 * @returns translate/scale descriptor for the fitted, centred view.
 */
function calculateFitTransform(
    graphWidth: number,
    graphHeight: number,
    containerWidth: number,
    containerHeight: number
): TransformDescriptor {
    // Use the tighter of the two axes so the whole graph is visible.
    const scaleValue = Math.min(containerWidth / graphWidth, containerHeight / graphHeight);
    return {
        translate: {
            // Split the leftover space evenly on both sides to centre the scaled graph.
            x: (containerWidth - scaleValue * graphWidth) / 2,
            y: (containerHeight - scaleValue * graphHeight) / 2,
        },
        scale: {
            value: scaleValue,
            min: scaleValue,
            // Never let the upper bound drop below the lower bound (small graphs fit at > 100%).
            max: Math.max(1, scaleValue),
        },
    };
}
