diff --git a/packages/core/src/Diagram.ts b/packages/core/src/Diagram.ts index c2ca6321..5a1a4a41 100644 --- a/packages/core/src/Diagram.ts +++ b/packages/core/src/Diagram.ts @@ -47,6 +47,12 @@ export class Diagram { return this } + disconnect(linkId: LinkId) { + this.links = this.links.filter(l => l.id !== linkId) + + return this + } + linksAtInputPortId(id: PortId | undefined): Link[] { return this.links.filter(link => link.targetPortId === id) } diff --git a/packages/ui/src/components/DataStory/DataStoryCanvas.tsx b/packages/ui/src/components/DataStory/DataStoryCanvas.tsx index d8e52a4e..9a36d35d 100644 --- a/packages/ui/src/components/DataStory/DataStoryCanvas.tsx +++ b/packages/ui/src/components/DataStory/DataStoryCanvas.tsx @@ -3,6 +3,7 @@ import React, { forwardRef, useCallback, useEffect, useId, useMemo, useRef, useS import { Background, BackgroundVariant, + Edge, EdgeChange, NodeChange, ReactFlow, @@ -25,6 +26,7 @@ import { useEscapeKey } from './hooks/useEscapeKey'; import { keyManager } from './keyManager'; import { getNodesWithNewSelection } from './getNodesWithNewSelection'; import { createDataStoryId, LinkCount, LinkId, NodeStatus, RequestObserverType } from '@data-story/core'; +import { useDragNode } from './useDragNode'; const nodeTypes = { commentNodeComponent: CommentNodeComponent, @@ -67,16 +69,19 @@ const Flow = ({ onNodesChange: state.onNodesChange, onEdgesChange: state.onEdgesChange, connect: state.connect, + disconnect: state.disconnect, onInit: state.onInit, onRun: state.onRun, addNodeFromDescription: state.addNodeFromDescription, toDiagram: state.toDiagram, updateEdgeCounts: state.updateEdgeCounts, updateEdgeStatus: state.updateEdgeStatus, + setEdges: state.setEdges, }); const { connect, + disconnect, nodes, edges, onNodesChange, @@ -87,6 +92,7 @@ const Flow = ({ toDiagram, updateEdgeCounts, updateEdgeStatus, + setEdges, } = useStore(selector, shallow); const id = useId() @@ -178,6 +184,11 @@ const Flow = ({ }); useEscapeKey(() => setSidebarKey!(''), flowRef); + const { draggedNode, onNodeDragStop, onNodeDrag } = useDragNode({ + connect, + disconnect, + edges, + }); return ( <> @@ -188,6 +199,22 @@ const Flow = ({ stroke-dashoffset: -10; } } + .react-flow__edge:hover { + cursor: crosshair; + } + ${draggedNode ? ` + .react-flow__edge { + opacity: 0.5; + } + .react-flow__edge[data-testid="rf__edge-${draggedNode.droppedOnEdge?.id}"] { + opacity: 1; + filter: drop-shadow(0 0 5px #4f46e5); + } + .react-flow__edge[data-testid="rf__edge-${draggedNode.droppedOnEdge?.id}"] path { + stroke: #4f46e5; + stroke-width: 3; + } + ` : ''} `} { + onNodesDelete={(nodesToDelete) => { + // console.log('onNodesDelete', nodesToDelete); + + nodesToDelete.forEach(node => { + const store = reactFlowStore.getState(); + const { edges } = store; + + // Find all incoming and outgoing edges for this node + const incomingEdges = edges.filter(e => e.target === node.id); + const outgoingEdges = edges.filter(e => e.source === node.id); + + // console.log({ + // incomingEdges, + // outgoingEdges, + // }); + + // For each incoming edge, connect it to all outgoing edges + incomingEdges.forEach(inEdge => { + outgoingEdges.forEach(outEdge => { + // Create a connection that will be handled by the store's connect method + connect({ + source: inEdge.source, + sourceHandle: inEdge.sourceHandle ?? null, + target: outEdge.target, + targetHandle: outEdge.targetHandle ?? null, + }); + }); + }); + }); + // focus on the diagram after node deletion to enhance hotkey usage focusOnFlow(); }} @@ -227,6 +283,8 @@ const Flow = ({ client, }); }} + onNodeDrag={onNodeDrag} + onNodeDragStop={onNodeDragStop} minZoom={0.25} maxZoom={8} fitView={true} @@ -235,7 +293,16 @@ const Flow = ({ }} onDragOver={useCallback((event) => { event.preventDefault(); - event.dataTransfer.dropEffect = 'move'; + // Allow dropping on edges + const target = event.target as HTMLElement; + const isEdge = target.closest('.react-flow__edge'); + const hasNodeType = event.dataTransfer.types.includes('application/reactflow'); + + if (isEdge && hasNodeType) { + event.dataTransfer.dropEffect = 'copy'; + } else { + event.dataTransfer.dropEffect = 'move'; + } }, [])} onDrop={ useCallback((event) => { diff --git a/packages/ui/src/components/DataStory/modals/addNodeForm.tsx b/packages/ui/src/components/DataStory/modals/addNodeForm.tsx index 5ca11051..1164e765 100644 --- a/packages/ui/src/components/DataStory/modals/addNodeForm.tsx +++ b/packages/ui/src/components/DataStory/modals/addNodeForm.tsx @@ -84,6 +84,11 @@ export const AddNodeFormContent = (props: AddNodeModalContentProps) => { )} key={nodeDescription.name} onClick={() => doAddNode(nodeDescription)} + draggable="true" + onDragStart={(event) => { + event.dataTransfer.setData('application/reactflow', nodeDescription.name); + event.dataTransfer.effectAllowed = 'move'; + }} >
{nodeDescription.category || 'Core'}:: diff --git a/packages/ui/src/components/DataStory/store/store.tsx b/packages/ui/src/components/DataStory/store/store.tsx index d22df91b..6443b7b1 100644 --- a/packages/ui/src/components/DataStory/store/store.tsx +++ b/packages/ui/src/components/DataStory/store/store.tsx @@ -66,6 +66,13 @@ export const createStore = () => createWithEqualityFn((set, get) => // Update the diagram get().updateDiagram(diagram) }, + disconnect: (linkId: string) => { + const diagram = get().toDiagram() + + diagram.disconnect(linkId) + + get().updateDiagram(diagram) + }, addNode: (node: ReactFlowNode) => { set({ nodes: [ diff --git a/packages/ui/src/components/DataStory/types.ts b/packages/ui/src/components/DataStory/types.ts index 791c7407..b137f7a6 100644 --- a/packages/ui/src/components/DataStory/types.ts +++ b/packages/ui/src/components/DataStory/types.ts @@ -113,6 +113,7 @@ export type StoreSchema = { updateEdgeStatus: (edgeStatus: { nodeId: NodeId, status: NodeStatus }[]) => void setEdges: (edges: Edge[]) => void; connect: OnConnect; + disconnect: (linkId: string) => void; /** Global Params */ params: Param[], diff --git a/packages/ui/src/components/DataStory/useDragNode.tsx b/packages/ui/src/components/DataStory/useDragNode.tsx new file mode 100644 index 00000000..53968e2c --- /dev/null +++ b/packages/ui/src/components/DataStory/useDragNode.tsx @@ -0,0 +1,125 @@ +import { type MouseEvent as ReactMouseEvent, useCallback, useState } from 'react'; +import { Edge } from '@xyflow/react'; +import { StoreSchema } from './types'; +import { ReactFlowNode } from '../Node/ReactFlowNode'; + +interface IntersectionResult { + isIntersecting: boolean; + edge?: Edge; + edgeElement?: SVGPathElement; +} + +function isIntersecting( + edgeRect: DOMRect, + nodeRect: DOMRect, + threshold: number = 0.33, +): boolean { + // check if there is any overlap. + if ( + edgeRect.left > nodeRect.right || + edgeRect.right < nodeRect.left || + edgeRect.top > nodeRect.bottom || + edgeRect.bottom < nodeRect.top + ) { + return false; + } + + // calculate the overlap area + const overlapLeft = Math.max(edgeRect.left, nodeRect.left); + const overlapRight = Math.min(edgeRect.right, nodeRect.right); + const overlapTop = Math.max(edgeRect.top, nodeRect.top); + const overlapBottom = Math.min(edgeRect.bottom, nodeRect.bottom); + + const overlapArea = + (overlapRight - overlapLeft) * (overlapBottom - overlapTop); + const nodeArea = nodeRect.width * nodeRect.height; + + // calculate the overlap ratio + const overlapRatio = overlapArea / nodeArea; + + return overlapRatio > threshold; +} + +export function useDragNode({ + connect, + disconnect, + edges, +}: { + connect: StoreSchema['connect']; + disconnect: StoreSchema['disconnect']; + edges: StoreSchema['edges']; +}) { + const [draggedNode, setDraggedNode] = useState<{ node: any, droppedOnEdge: any } | null>(null); + + const checkNodeEdgeIntersection = useCallback(( + dragNodeRect: DOMRect, + ): IntersectionResult => { + for (const edge of edges) { + const edgeElement = document.querySelector( + `[data-id="${edge.id}"]`, + ) as SVGPathElement; + + if (!edgeElement) continue; + + const edgeRect = edgeElement.getBoundingClientRect(); + const isEdgeCrossingNode = isIntersecting(edgeRect, dragNodeRect); + + if (isEdgeCrossingNode) { + return { + isIntersecting: true, + edge, + edgeElement: edgeElement, + }; + } + } + + return { isIntersecting: false }; + }, [edges.length]); + + const onNodeDrag = useCallback((event: ReactMouseEvent, node: ReactFlowNode) => { + // @ts-ignore + const nodeRect = event.target!.getBoundingClientRect() as unknown as DOMRect; + + const { edgeElement, edge, isIntersecting } = checkNodeEdgeIntersection(nodeRect); + + // The node must have inputs and outputs that can be connected + const isConnected = node.data.inputs.length > 0 && node.data.outputs.length > 0; + + if (isConnected && isIntersecting) { + edgeElement!.getAttribute('data-testid')?.replace('rf__edge-', ''); + setDraggedNode({ node, droppedOnEdge: edge }); + return; + } + + setDraggedNode(null); + }, [checkNodeEdgeIntersection]); + + const onNodeDragStop = useCallback((event: any, node: ReactFlowNode, nodes: ReactFlowNode[]) => { + if (!draggedNode?.droppedOnEdge) return; + const { node: node1, droppedOnEdge } = draggedNode; + const nodeOutputId = node.data.outputs[0].id; + const nodeInputId = node.data.inputs[0].id; + + connect({ + source: droppedOnEdge.source, + target: node.id, + sourceHandle: droppedOnEdge.sourceHandle, + targetHandle: nodeInputId, + }); + connect({ + source: node.id, + target: droppedOnEdge.target, + sourceHandle: nodeOutputId, + targetHandle: droppedOnEdge.targetHandle, + }); + + disconnect(droppedOnEdge.id) + setDraggedNode(null); + }, [draggedNode, connect, disconnect]); + + return { + onNodeDrag, + onNodeDragStop, + draggedNode, + } +} \ No newline at end of file