Skip to content

Commit

Permalink
Consolidate function calling
Browse files Browse the repository at this point in the history
This PR removes duplication in handling side effects and
function calling by introducing a function calling emulation
mode for StylingAgent.

Bug: 360751542
Change-Id: I383ade655828e655700444a971a27de13ce291fa
Reviewed-on: https://chromium-review.googlesource.com/c/devtools/devtools-frontend/+/6286165
Commit-Queue: Alex Rudenko <[email protected]>
Reviewed-by: Nikolay Vitkov <[email protected]>
Reviewed-by: Ergün Erdoğmuş <[email protected]>
  • Loading branch information
OrKoN authored and Devtools-frontend LUCI CQ committed Feb 21, 2025
1 parent 5efc7e9 commit 0621aea
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 293 deletions.
239 changes: 131 additions & 108 deletions front_end/panels/ai_assistance/agents/AiAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,12 @@ export interface FunctionDeclaration<Args extends Record<string, unknown>, Retur
*/
parameters: Host.AidaClient.FunctionObjectParam<keyof Args>;
/**
* Provided a way to give information back to
* the UI before running the the handler
* Provided a way to give information back to the UI.
*/
displayInfoFromArgs?: (
args: Args,
) => {
title?: string, thought?: string, code?: string, suggestions?: [string, ...string[]],
title?: string, thought?: string, action?: string, suggestions?: [string, ...string[]],
};
/**
* Function implementation that the LLM will try to execute,
Expand All @@ -208,8 +207,25 @@ export interface FunctionDeclaration<Args extends Record<string, unknown>, Retur
}) => Promise<FunctionCallHandlerResult<ReturnType>>;
}

const OBSERVATION_PREFIX = 'OBSERVATION:';
const OBSERVATION_PREFIX = 'OBSERVATION: ';

interface AidaFetchResult {
text?: string;
functionCall?: Host.AidaClient.AidaFunctionCallResponse;
completed: boolean;
rpcId?: Host.AidaClient.RpcGlobalId;
}

/**
* AiAgent is a base class for implementing an interaction with AIDA
* that involves one or more requests being sent to AIDA optionally
* utilizing function calling.
*
* TODO: missing a test that action code is yielded before the
* confirmation dialog.
* TODO: missing a test for an error if it took
* more than MAX_STEPS iterations.
*/
export abstract class AiAgent<T> {
/** Subclasses need to define these. */
abstract readonly type: AgentType;
Expand Down Expand Up @@ -274,6 +290,7 @@ export abstract class AiAgent<T> {
function validTemperature(temperature: number|undefined): number|undefined {
return typeof temperature === 'number' && temperature >= 0 ? temperature : undefined;
}
const enableAidaFunctionCalling = declarations.length && !this.functionCallEmulationEnabled;
const request: Host.AidaClient.AidaRequest = {
client: Host.AidaClient.CLIENT_NAME,

Expand All @@ -282,7 +299,7 @@ export abstract class AiAgent<T> {

historical_contexts: history.length ? history : undefined,

...(declarations.length ? {function_declarations: declarations} : {}),
...(enableAidaFunctionCalling ? {function_declarations: declarations} : {}),
options: {
temperature: validTemperature(this.options.temperature),
model_id: this.options.modelId,
Expand All @@ -294,8 +311,8 @@ export abstract class AiAgent<T> {
client_version: Root.Runtime.getChromeVersion(),
},

functionality_type: declarations.length ? Host.AidaClient.FunctionalityType.AGENTIC_CHAT :
Host.AidaClient.FunctionalityType.CHAT,
functionality_type: enableAidaFunctionCalling ? Host.AidaClient.FunctionalityType.AGENTIC_CHAT :
Host.AidaClient.FunctionalityType.CHAT,

client_feature: this.clientFeature,
};
Expand All @@ -314,13 +331,13 @@ export abstract class AiAgent<T> {
return this.#origin;
}

parseResponse(response: Host.AidaClient.AidaResponse): ParsedResponse {
if (response.functionCalls && response.completed) {
throw new Error('Function calling not supported yet');
}
return {
answer: response.explanation,
};
/**
* Parses a streaming text response into a
* though/action/title/answer/suggestions component. This is only used
* by StylingAgent.
*/
parseTextResponse(response: string): ParsedResponse {
return {answer: response};
}

/**
Expand All @@ -346,10 +363,14 @@ export abstract class AiAgent<T> {
return answer;
}

protected handleAction(action: string, options?: {signal?: AbortSignal}):
AsyncGenerator<SideEffectResponse, ActionResponse, void>;
protected handleAction(): never {
throw new Error('Unexpected action found');
/**
* Special mode for StylingAgent that turns custom text output into a
* function call.
*/
protected functionCallEmulationEnabled = false;
protected emulateFunctionCall(_aidaResponse: Host.AidaClient.AidaResponse): Host.AidaClient.AidaFunctionCallResponse|
'no-function-call'|'wait-for-completion' {
throw new Error('Unexpected emulateFunctionCall. Only StylingAgent implements function call emulation');
}

async *
Expand Down Expand Up @@ -391,23 +412,30 @@ export abstract class AiAgent<T> {
};

let rpcId: Host.AidaClient.RpcGlobalId|undefined;
let parsedResponse: ParsedResponse|undefined = undefined;
let textResponse = '';
let functionCall: Host.AidaClient.AidaFunctionCallResponse|undefined = undefined;
try {
for await (const fetchResult of this.#aidaFetch(request, {signal: options.signal})) {
rpcId = fetchResult.rpcId;
parsedResponse = fetchResult.parsedResponse;
textResponse = fetchResult.text ?? '';
functionCall = fetchResult.functionCall;

// Only yield partial responses here and do not add partial answers to the history.
if (!fetchResult.completed && !fetchResult.functionCall && 'answer' in parsedResponse &&
parsedResponse.answer) {
if (!functionCall && !fetchResult.completed) {
const parsed = this.parseTextResponse(textResponse);
const partialAnswer = 'answer' in parsed ? parsed.answer : '';
if (!partialAnswer) {
continue;
}
// Only yield partial responses here and do not add partial answers to the history.
yield {
type: ResponseType.ANSWER,
text: parsedResponse.answer,
text: partialAnswer,
complete: false,
};
}
if (functionCall) {
break;
}
}
} catch (err) {
debugLog('Error calling the AIDA API', err);
Expand All @@ -425,7 +453,11 @@ export abstract class AiAgent<T> {

this.#history.push(request.current_message);

if (parsedResponse && 'answer' in parsedResponse && Boolean(parsedResponse.answer)) {
if (textResponse) {
const parsedResponse = this.parseTextResponse(textResponse);
if (!('answer' in parsedResponse)) {
throw new Error('Expected a completed response to have an answer');
}
this.#history.push({
parts: [{
text: this.formatParsedAnswer(parsedResponse),
Expand All @@ -441,66 +473,24 @@ export abstract class AiAgent<T> {
rpcId,
};
break;
} else if (parsedResponse && !('answer' in parsedResponse)) {
const {
title,
thought,
action,
} = parsedResponse;

if (title) {
yield {
type: ResponseType.TITLE,
title,
rpcId,
};
}

if (thought) {
yield {
type: ResponseType.THOUGHT,
thought,
rpcId,
};
}

this.#history.push({
parts: [{
text: this.#formatParsedStep(parsedResponse),
}],
role: Host.AidaClient.Role.MODEL,
});
}

if (action) {
const result = yield* this.handleAction(action, {signal: options.signal});
if (options?.signal?.aborted) {
if (functionCall) {
try {
const result = yield* this.#callFunction(functionCall.name, functionCall.args, options);
if (options.signal?.aborted) {
yield this.#createErrorResponse(ErrorType.ABORT);
break;
}
query = {text: `${OBSERVATION_PREFIX} ${result.output}`};
// Capture history state for the next iteration query.
request = this.buildRequest(query, Host.AidaClient.Role.USER);
yield result;
}
} else if (functionCall) {
try {
const result = yield* this.#callFunction(functionCall.name, functionCall.args);

if (result.result) {
yield {
type: ResponseType.ACTION,
output: JSON.stringify(result.result),
canceled: false,
};
}

query = {
query = this.functionCallEmulationEnabled ? {text: OBSERVATION_PREFIX + result.result} : {
functionResponse: {
name: functionCall.name,
response: result,
},
};
request = this.buildRequest(query, Host.AidaClient.Role.ROLE_UNSPECIFIED);
request = this.buildRequest(
query,
this.functionCallEmulationEnabled ? Host.AidaClient.Role.USER : Host.AidaClient.Role.ROLE_UNSPECIFIED);
} catch {
yield this.#createErrorResponse(ErrorType.UNKNOWN);
break;
Expand All @@ -524,18 +514,31 @@ export abstract class AiAgent<T> {
if (!call) {
throw new Error(`Function ${name} is not found.`);
}
this.#history.push({
parts: [{
functionCall: {
name,
args,
},
}],
role: Host.AidaClient.Role.MODEL,
});
if (this.functionCallEmulationEnabled) {
if (!call.displayInfoFromArgs) {
throw new Error('functionCallEmulationEnabled requires all functions to provide displayInfoFromArgs');
}
// Emulated function calls are formatted as text.
this.#history.push({
parts: [{text: this.#formatParsedStep(call.displayInfoFromArgs(args))}],
role: Host.AidaClient.Role.MODEL,
});
} else {
this.#history.push({
parts: [{
functionCall: {
name,
args,
},
}],
role: Host.AidaClient.Role.MODEL,
});
}

let code;
if (call.displayInfoFromArgs) {
const {title, thought, code, suggestions} = call.displayInfoFromArgs(args);
const {title, thought, action: callCode} = call.displayInfoFromArgs(args);
code = callCode;
if (title) {
yield {
type: ResponseType.TITLE,
Expand All @@ -549,7 +552,11 @@ export abstract class AiAgent<T> {
thought,
};
}
}

let result = await call.handler(args, options) as FunctionCallHandlerResult<unknown>;

if ('requiresApproval' in result) {
if (code) {
yield {
type: ResponseType.ACTION,
Expand All @@ -558,17 +565,6 @@ export abstract class AiAgent<T> {
};
}

if (suggestions) {
yield {
type: ResponseType.SUGGESTIONS,
suggestions,
};
}
}

let result = await call.handler(args, options);

if ('requiresApproval' in result) {
const sideEffectConfirmationPromiseWithResolvers = this.confirmSideEffect<boolean>();

void sideEffectConfirmationPromiseWithResolvers.promise.then(result => {
Expand Down Expand Up @@ -597,7 +593,8 @@ export abstract class AiAgent<T> {
if (!approvedRun) {
yield {
type: ResponseType.ACTION,
code: '',
code,
output: 'Error: User denied code execution with side effects.',
canceled: true,
};
return {
Expand All @@ -611,17 +608,30 @@ export abstract class AiAgent<T> {
});
}

if ('result' in result) {
yield {
type: ResponseType.ACTION,
code,
output: typeof result.result === 'string' ? result.result : JSON.stringify(result.result),
canceled: false,
};
}

if ('error' in result) {
yield {
type: ResponseType.ACTION,
code,
output: result.error,
canceled: false,
};
}

return result as {result: unknown};
}

async *
#aidaFetch(request: Host.AidaClient.AidaRequest, options?: {signal?: AbortSignal}): AsyncGenerator<
{
parsedResponse: ParsedResponse,
functionCall?: Host.AidaClient.AidaFunctionCallResponse, completed: boolean,
rpcId?: Host.AidaClient.RpcGlobalId,
},
void, void> {
#aidaFetch(request: Host.AidaClient.AidaRequest, options?: {signal?: AbortSignal}):
AsyncGenerator<AidaFetchResult, void, void> {
let aidaResponse: Host.AidaClient.AidaResponse|undefined = undefined;
let response = '';
let rpcId: Host.AidaClient.RpcGlobalId|undefined;
Expand All @@ -631,19 +641,32 @@ export abstract class AiAgent<T> {
debugLog('functionCalls.length', aidaResponse.functionCalls.length);
yield {
rpcId,
parsedResponse: {answer: ''},
functionCall: aidaResponse.functionCalls[0],
completed: true,
};
break;
}

if (this.functionCallEmulationEnabled) {
const emulatedFunctionCall = this.emulateFunctionCall(aidaResponse);
if (emulatedFunctionCall === 'wait-for-completion') {
continue;
}
if (emulatedFunctionCall !== 'no-function-call') {
yield {
rpcId,
functionCall: emulatedFunctionCall,
completed: true,
};
break;
}
}

response = aidaResponse.explanation;
rpcId = aidaResponse.metadata.rpcGlobalId ?? rpcId;
const parsedResponse = this.parseResponse(aidaResponse);
yield {
rpcId,
parsedResponse,
text: aidaResponse.explanation,
completed: aidaResponse.completed,
};
}
Expand Down
Loading

0 comments on commit 0621aea

Please sign in to comment.