Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/graphs/Graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,13 @@ export class StandardGraph extends Graph<t.BaseGraphState, t.GraphNode> {
maxDepth: effectiveSubagentDepth,
createChildGraph: (input): StandardGraph => {
const childGraph = new StandardGraph(input);
childGraph.hookRegistry = this.hookRegistry;
/**
* Do not propagate `humanInTheLoop` into the child graph yet:
* nested subagent interrupts need a stable child checkpoint and
* resume bridge. Child hooks still fire; `ask` decisions fail
* closed inside the subagent until that flow is implemented.
*/
childGraph.toolOutputReferences = this.toolOutputReferences;
childGraph.eagerEventToolExecution = this.eagerEventToolExecution;
childGraph.toolExecution = this.toolExecution;
Expand Down
38 changes: 38 additions & 0 deletions src/hooks/__tests__/executeHooks.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,44 @@ describe('executeHooks', () => {
consoleWarnSpy.mockRestore();
});

describe('abort listener management', () => {
it('uses one abort listener for many hooks on one matcher', async () => {
const registry = new HookRegistry();
const listenerCounts = new Map<AbortSignal, number>();
let maxAbortListeners = 0;
const addEventListenerSpy = jest
.spyOn(AbortSignal.prototype, 'addEventListener')
.mockImplementation(function (
this: AbortSignal,
type: string,
_listener: EventListenerOrEventListenerObject | null
): void {
if (type !== 'abort') {
return;
}
const count = (listenerCounts.get(this) ?? 0) + 1;
listenerCounts.set(this, count);
maxAbortListeners = Math.max(maxAbortListeners, count);
});
const hooks = Array.from({ length: 12 }, () =>
runStartHook(async (): Promise<RunStartHookOutput> => ({}))
);

try {
registry.register('RunStart', { hooks });

await executeHooks({
registry,
input: runStartInput(),
timeoutMs: 1000,
});
expect(maxAbortListeners).toBe(1);
} finally {
addEventListenerSpy.mockRestore();
}
});
});

describe('empty matcher set', () => {
it('returns an empty aggregated result when no matchers are registered', async () => {
const registry = new HookRegistry();
Expand Down
34 changes: 27 additions & 7 deletions src/hooks/executeHooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ interface HookOutcome {
timedOut: boolean;
}

interface AbortRace {
promise: Promise<never>;
cleanup: () => void;
}

function freshResult(): AggregatedHookResult {
return {
additionalContexts: [],
Expand Down Expand Up @@ -110,10 +115,10 @@ async function runHook(
hook: WideCallback,
input: HookInput,
signal: AbortSignal,
abortPromise: Promise<never>,
matcher: WideMatcher
): Promise<HookOutcome> {
const hookPromise = Promise.resolve().then(() => hook(input, signal));
const { promise: abortPromise, cleanup } = makeAbortPromise(signal);
try {
const output = await Promise.race([hookPromise, abortPromise]);
return { matcher, output, error: null, timedOut: false };
Expand All @@ -124,8 +129,22 @@ async function runHook(
error: describeError(err),
timedOut: isTimeout(err),
};
}
}

async function runMatcherHooks(
matcher: WideMatcher,
input: HookInput,
signal: AbortSignal
): Promise<HookOutcome[]> {
const abortRace: AbortRace = makeAbortPromise(signal);
const tasks = matcher.hooks.map((hook) =>
runHook(hook, input, signal, abortRace.promise, matcher)
);
try {
return await Promise.all(tasks);
} finally {
cleanup();
abortRace.cleanup();
}
}

Expand Down Expand Up @@ -373,26 +392,27 @@ export async function executeHooks(
}

// --- SYNC CRITICAL SECTION: once-matcher removal must complete before any await ---
const tasks: Promise<HookOutcome>[] = [];
const tasks: Promise<HookOutcome[]>[] = [];
for (const matcher of matchers) {
if (!matchesQuery(matcher.pattern, matchQuery)) {
continue;
}
if (matcher.once === true) {
registry.removeMatcher(event, matcher, sessionId);
}
if (matcher.hooks.length === 0) {
continue;
}
const perHookTimeout = matcher.timeout ?? timeoutMs;
const matcherSignal = combineSignals(signal, perHookTimeout);
for (const hook of matcher.hooks) {
tasks.push(runHook(hook, input, matcherSignal, matcher));
}
tasks.push(runMatcherHooks(matcher, input, matcherSignal));
}
// --- END SYNC CRITICAL SECTION ---
if (tasks.length === 0) {
return freshResult();
}

const outcomes = await Promise.all(tasks);
const outcomes = (await Promise.all(tasks)).flat();
reportErrors(outcomes, event, logger);
const aggregated = fold(outcomes);
/**
Expand Down
Loading
Loading