|
| 1 | +import { useChainCallback } from '@/composables/functional/useChainCallback' |
| 2 | +import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums' |
| 3 | +import type { ISlotType } from '@/lib/litegraph/src/interfaces' |
1 | 4 | import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode' |
2 | 5 | import type { INodeInputSlot } from '@/lib/litegraph/src/interfaces' |
3 | 6 | import { LiteGraph } from '@/lib/litegraph/src/litegraph' |
4 | 7 | import { transformInputSpecV1ToV2 } from '@/schemas/nodeDef/migration' |
5 | 8 | import type { ComboInputSpec, InputSpec } from '@/schemas/nodeDefSchema' |
| 9 | +import type { InputSpec as InputSpecV2 } from '@/schemas/nodeDef/nodeDefSchemaV2' |
6 | 10 | import { zDynamicComboInputSpec } from '@/schemas/nodeDefSchema' |
7 | 11 | import { useLitegraphService } from '@/services/litegraphService' |
8 | 12 | import { app } from '@/scripts/app' |
@@ -139,3 +143,141 @@ function dynamicComboWidget( |
139 | 143 | } |
140 | 144 |
|
141 | 145 | export const dynamicWidgets = { COMFY_DYNAMICCOMBO_V3: dynamicComboWidget } |
| 146 | + |
| 147 | +export function applyAutoGrow(node: LGraphNode, inputSpec: InputSpecV2) { |
| 148 | + const { addNodeInput } = useLitegraphService() |
| 149 | + //@ts-expect-error - implement min, define inputSpec |
| 150 | + const { input, min, names, prefix, max } = inputSpec.template |
| 151 | + const inputTypes: [Record<string, InputSpec> | undefined, boolean][] = [ |
| 152 | + [input.required, false], |
| 153 | + [input.optional, true] |
| 154 | + ] |
| 155 | + const inputsV2 = inputTypes.flatMap(([inputType, isOptional]) => |
| 156 | + Object.entries(inputType ?? {}).map(([name, v]) => |
| 157 | + transformInputSpecV1ToV2(v, { name, isOptional }) |
| 158 | + ) |
| 159 | + ) |
| 160 | + if (inputsV2.length !== 1) throw new Error('Not Implemented') |
| 161 | + |
| 162 | + function nameToInputIndex(name: string) { |
| 163 | + const index = node.inputs.findIndex((input) => input.name === name) |
| 164 | + if (index === -1) throw new Error('Failed to find input') |
| 165 | + return index |
| 166 | + } |
| 167 | + function nameToInput(name: string) { |
| 168 | + return node.inputs[nameToInputIndex(name)] |
| 169 | + } |
| 170 | + |
| 171 | + //In the distance, someone shouting YAGNI |
| 172 | + const trackedInputs: string[][] = [] |
| 173 | + function addInputGroup(insertionIndex: number) { |
| 174 | + const ordinal = trackedInputs.length |
| 175 | + const inputGroup: string[] = [] |
| 176 | + for (const input of inputsV2) { |
| 177 | + const namedSpec = { |
| 178 | + ...input, |
| 179 | + name: names ? names[ordinal] : prefix + ordinal |
| 180 | + } |
| 181 | + addNodeInput(node, namedSpec) |
| 182 | + const addedInput = node.spliceInputs(node.inputs.length - 1, 1)[0] |
| 183 | + node.spliceInputs(insertionIndex++, 0, addedInput) |
| 184 | + inputGroup.push(namedSpec.name) |
| 185 | + } |
| 186 | + trackedInputs.push(inputGroup) |
| 187 | + app.canvas.setDirty(true, true) |
| 188 | + } |
| 189 | + addInputGroup(node.inputs.length) |
| 190 | + function removeInputGroup(inputName: string) { |
| 191 | + const groupIndex = trackedInputs.findIndex((ig) => |
| 192 | + ig.some((inpName) => inpName === inputName) |
| 193 | + ) |
| 194 | + if (groupIndex == -1) throw new Error('Failed to find group') |
| 195 | + const group = trackedInputs[groupIndex] |
| 196 | + for (const nameToRemove of group) { |
| 197 | + const inputIndex = nameToInputIndex(nameToRemove) |
| 198 | + node.spliceInputs(inputIndex, 1) |
| 199 | + } |
| 200 | + trackedInputs.splice(groupIndex, 1) |
| 201 | + node.size[1] = node.computeSize([...node.size])[1] |
| 202 | + app.canvas.setDirty(true, true) |
| 203 | + } |
| 204 | + |
| 205 | + function inputConnected(index: number) { |
| 206 | + const input = node.inputs[index] |
| 207 | + const groupIndex = trackedInputs.findIndex((ig) => |
| 208 | + ig.some((inputName) => inputName === input.name) |
| 209 | + ) |
| 210 | + if (groupIndex == -1) throw new Error('Failed to find group') |
| 211 | + if ( |
| 212 | + groupIndex + 1 === trackedInputs.length && |
| 213 | + trackedInputs.length < (max ?? names.length) |
| 214 | + ) { |
| 215 | + const lastInput = trackedInputs[groupIndex].at(-1) |
| 216 | + if (!lastInput) return |
| 217 | + const insertionIndex = nameToInputIndex(lastInput) + 1 |
| 218 | + if (insertionIndex === 0) throw new Error('Failed to find Input') |
| 219 | + addInputGroup(insertionIndex) |
| 220 | + } |
| 221 | + } |
| 222 | + function inputDisconnected(index: number) { |
| 223 | + const input = node.inputs[index] |
| 224 | + if (trackedInputs.length === 1) return |
| 225 | + const groupIndex = trackedInputs.findIndex((ig) => |
| 226 | + ig.some((inputName) => inputName === input.name) |
| 227 | + ) |
| 228 | + if (groupIndex == -1) throw new Error('Failed to find group') |
| 229 | + if ( |
| 230 | + trackedInputs[groupIndex].some( |
| 231 | + (inputName) => nameToInput(inputName).link != null |
| 232 | + ) |
| 233 | + ) |
| 234 | + return |
| 235 | + //For each group from here to last group, bubble swap links |
| 236 | + for (let column = 0; column < trackedInputs[0].length; column++) { |
| 237 | + let prevInput = nameToInputIndex(trackedInputs[groupIndex][column]) |
| 238 | + for (let i = groupIndex + 1; i < trackedInputs.length; i++) { |
| 239 | + const curInput = nameToInputIndex(trackedInputs[i][column]) |
| 240 | + const linkId = node.inputs[curInput].link |
| 241 | + node.inputs[prevInput].link = linkId |
| 242 | + const link = linkId && node.graph?.links?.[linkId] |
| 243 | + if (link) link.target_slot = prevInput |
| 244 | + prevInput = curInput |
| 245 | + } |
| 246 | + node.inputs[prevInput].link = null |
| 247 | + } |
| 248 | + if ( |
| 249 | + trackedInputs.at(-2) && |
| 250 | + !trackedInputs.at(-2)?.some((name) => !!nameToInput(name).link) |
| 251 | + ) |
| 252 | + removeInputGroup(trackedInputs.at(-1)![0]) |
| 253 | + } |
| 254 | + |
| 255 | + let pendingConnection: number | undefined |
| 256 | + let swappingConnection = false |
| 257 | + const originalOnConnectInput = node.onConnectInput |
| 258 | + node.onConnectInput = function (slot: number, ...args) { |
| 259 | + pendingConnection = slot |
| 260 | + setTimeout(() => (pendingConnection = undefined), 50) |
| 261 | + return originalOnConnectInput?.apply(this, [slot, ...args]) ?? true |
| 262 | + } |
| 263 | + node.onConnectionsChange = useChainCallback( |
| 264 | + node.onConnectionsChange, |
| 265 | + (type: ISlotType, index: number, isConnected: boolean) => { |
| 266 | + if (type !== NodeSlotType.INPUT) return |
| 267 | + const inputName = node.inputs[index].name |
| 268 | + if (!trackedInputs.flat().some((name) => name === inputName)) return |
| 269 | + if (isConnected) { |
| 270 | + if (swappingConnection) return |
| 271 | + inputConnected(index) |
| 272 | + } else { |
| 273 | + if (pendingConnection === index) { |
| 274 | + swappingConnection = true |
| 275 | + setTimeout(() => (swappingConnection = false), 50) |
| 276 | + return |
| 277 | + } |
| 278 | + inputDisconnected(index) |
| 279 | + } |
| 280 | + } |
| 281 | + ) |
| 282 | +} |
| 283 | +//COMFY_AUTOGROW_V3 |
0 commit comments