diff --git a/implementations/report_generation/README.md b/implementations/report_generation/README.md index d5a35767..5f30539d 100644 --- a/implementations/report_generation/README.md +++ b/implementations/report_generation/README.md @@ -92,9 +92,9 @@ uv run --env-file .env python -m implementations.report_generation.evaluate --da ``` This script will run the Report Generation Agent against each element of the dataset -and then use an LLM-as-a-judge Evaluator Agent to evaluate each result. The evaluator -will check the data generated to produce the report against the ground truth and -produce a True/False score along with a reasoning. +and then use LLM-as-a-judge Evaluator Agents to evaluate each result. The evaluator +will check the data generated to produce the report and also the trajectory the +agent used against the ground truth and produce True/False scores along with a reasoning. At the end of the run, an evaluation report will be displayed along with a link to check details about the evaluation in Langfuse. diff --git a/implementations/report_generation/data/OnlineRetailReportEval.json b/implementations/report_generation/data/OnlineRetailReportEval.json index a4805a11..8446317c 100644 --- a/implementations/report_generation/data/OnlineRetailReportEval.json +++ b/implementations/report_generation/data/OnlineRetailReportEval.json @@ -3,192 +3,323 @@ "id": "1", "input": "Generate a monthly sales performance report.", "expected_output": { - "report_data": [ - ["2010-12", 748957.02], - ["2011-01", 560000.26], - ["2011-02", 498062.65], - ["2011-03", 683267.08], - ["2011-04", 493207.12], - ["2011-05", 723333.51], - ["2011-06", 691123.12], - ["2011-07", 681300.11], - ["2011-08", 682680.51], - ["2011-09", 1019687.62], - ["2011-10", 1070704.67], - ["2011-11", 1461756.25], - ["2011-12", 433668.01] - ], - "report_columns": ["SalesMonth", "TotalSales"], - "filename": "monthly_sales_performance_report.xlsx" + "final_report": { + "report_data": [ + ["2010-12", 748957.02], + ["2011-01", 560000.26], + ["2011-02", 498062.65], + ["2011-03", 683267.08], + ["2011-04", 493207.12], + ["2011-05", 723333.51], + ["2011-06", 691123.12], + ["2011-07", 681300.11], + ["2011-08", 682680.51], + ["2011-09", 1019687.62], + ["2011-10", 1070704.67], + ["2011-11", 1461756.25], + ["2011-12", 433668.01] + ], + "report_columns": ["SalesMonth", "TotalSales"], + "filename": "monthly_sales_performance_report.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to retrieve the sales performance (quantity * price) per month", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } }, { "id": "2", "input": "Generate a report of the top 5 selling products per year and the total sales value for each product.", "expected_output": { - "report_data": [ - ["2010", "REGENCY CAKESTAND 3 TIER", 26897.360000000022], - ["2010", "DOTCOM POSTAGE", 24671.189999999995], - ["2010", "WHITE HANGING HEART T-LIGHT HOLDER", 9877.820000000005], - ["2010", "RED WOOLLY HOTTIE WHITE HEART.", 9291.729999999996], - ["2010", "PAPER CHAIN KIT 50'S CHRISTMAS ", 9205.149999999994], - ["2011", "DOTCOM POSTAGE", 181574.29000000004], - ["2011", "REGENCY CAKESTAND 3 TIER", 137864.8299999998], - ["2011", "PARTY BUNTING", 97095.24000000046], - ["2011", "WHITE HANGING HEART T-LIGHT HOLDER", 89790.64999999909], - ["2011", "JUMBO BAG RED RETROSPOT", 88383.68000000181] - ], - "report_columns": ["SaleYear", "Description", "TotalSales"], - "filename": "top_selling_products_report.xlsx" + "final_report": { + "report_data": [ + ["2010", "REGENCY CAKESTAND 3 TIER", 26897.360000000022], + ["2010", "DOTCOM POSTAGE", 24671.189999999995], + ["2010", "WHITE HANGING HEART T-LIGHT HOLDER", 9877.820000000005], + ["2010", "RED WOOLLY HOTTIE WHITE HEART.", 9291.729999999996], + ["2010", "PAPER CHAIN KIT 50'S CHRISTMAS ", 9205.149999999994], + ["2011", "DOTCOM POSTAGE", 181574.29000000004], + ["2011", "REGENCY CAKESTAND 3 TIER", 137864.8299999998], + ["2011", "PARTY BUNTING", 97095.24000000046], + ["2011", "WHITE HANGING HEART T-LIGHT HOLDER", 89790.64999999909], + ["2011", "JUMBO BAG RED RETROSPOT", 88383.68000000181] + ], + "report_columns": ["SaleYear", "Description", "TotalSales"], + "filename": "top_selling_products_report.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to retrieve the top 5 selling products grouped by year", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } }, { "id": "3", "input": "Generate a report of the average order value per invoice per month.", "expected_output": { - "report_data": [ - ["2010-12", 369.6727640671277], - ["2011-01", 377.61312204990026], - ["2011-02", 355.50510349750186], - ["2011-03", 343.35029145728674], - ["2011-04", 281.8326405714284], - ["2011-05", 333.79488232579683], - ["2011-06", 343.32991554893266], - ["2011-07", 353.3714268672203], - ["2011-08", 392.5707360552048], - ["2011-09", 438.0101469072177], - ["2011-10", 405.8774336618668], - ["2011-11", 421.86327561327636], - ["2011-12", 427.25912315270944] - ], - "report_columns": ["SaleMonth", "AverageOrderValue"], - "filename": "average_order_value_per_month.xlsx" + "final_report": { + "report_data": [ + ["2010-12", 369.6727640671277], + ["2011-01", 377.61312204990026], + ["2011-02", 355.50510349750186], + ["2011-03", 343.35029145728674], + ["2011-04", 281.8326405714284], + ["2011-05", 333.79488232579683], + ["2011-06", 343.32991554893266], + ["2011-07", 353.3714268672203], + ["2011-08", 392.5707360552048], + ["2011-09", 438.0101469072177], + ["2011-10", 405.8774336618668], + ["2011-11", 421.86327561327636], + ["2011-12", 427.25912315270944] + ], + "report_columns": ["SaleMonth", "AverageOrderValue"], + "filename": "average_order_value_per_month.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to retrieve the average total order value per invoice grouped by year-month. This can only be properly achieved by using a subselect query to first get the amount per invoice.", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } }, { "id": "4", "input": "Generate a report with the month-over-month trends in sales. The report should include the monthly sales, the month-over-month change and the percentage change.", "expected_output": { - "report_data": [ - ["2010-12", 748957.02, 748957.02, 0], - ["2011-01", 560000.26, -188956.76, -25.23], - ["2011-02", 498062.65, -61937.61, -11.06], - ["2011-03", 683267.08, 185204.43, 37.18], - ["2011-04", 493207.12, -190059.96, -27.82], - ["2011-05", 723333.51, 230126.39, 46.66], - ["2011-06", 691123.12, -32210.39, -4.45], - ["2011-07", 681300.11, -9823.01, -1.42], - ["2011-08", 682680.51, 1380.4, 0.2], - ["2011-09", 1019687.62, 337007.11, 49.37], - ["2011-10", 1070704.67, 51017.05, 5], - ["2011-11", 1461756.25, 391051.58, 36.52], - ["2011-12", 433668.01, -1028088.24, -70.33] - ], - "report_columns": [ - "Month", - "Monthly Sales", - "MoM Change", - "% MoM Change" - ], - "filename": "month_over_month_sales_report.xlsx" + "final_report": { + "report_data": [ + ["2010-12", 748957.02, 748957.02, 0], + ["2011-01", 560000.26, -188956.76, -25.23], + ["2011-02", 498062.65, -61937.61, -11.06], + ["2011-03", 683267.08, 185204.43, 37.18], + ["2011-04", 493207.12, -190059.96, -27.82], + ["2011-05", 723333.51, 230126.39, 46.66], + ["2011-06", 691123.12, -32210.39, -4.45], + ["2011-07", 681300.11, -9823.01, -1.42], + ["2011-08", 682680.51, 1380.4, 0.2], + ["2011-09", 1019687.62, 337007.11, 49.37], + ["2011-10", 1070704.67, 51017.05, 5], + ["2011-11", 1461756.25, 391051.58, 36.52], + ["2011-12", 433668.01, -1028088.24, -70.33] + ], + "report_columns": [ + "Month", + "Monthly Sales", + "MoM Change", + "% MoM Change" + ], + "filename": "month_over_month_sales_report.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "output_text", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to retrieve the monthly sales grouped by year-month", + "Calculate the month-over-month change and the percentage change with the previous query result", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } }, { "id": "5", "input": "Generate a report on sales revenue by country per year.", "expected_output": { - "report_data": [ - ["2010", "Australia", 1005.1], - ["2010", "Austria", 257.04], - ["2010", "Bahrain", 205.74], - ["2010", "Belgium", 1809.91], - ["2010", "Channel Islands", 363.53], - ["2010", "Cyprus", 1590.82], - ["2010", "Denmark", 1281.5], - ["2010", "EIRE", 9029.95], - ["2010", "Finland", 892.8], - ["2010", "France", 9575.36], - ["2010", "Germany", 14562.84], - ["2010", "Iceland", 711.79], - ["2010", "Israel", -227.44], - ["2010", "Italy", 794.5], - ["2010", "Japan", 7705.07], - ["2010", "Lithuania", 1661.06], - ["2010", "Netherlands", 8784.48], - ["2010", "Norway", 3787.12], - ["2010", "Poland", 248.16], - ["2010", "Portugal", 2380.12], - ["2010", "Spain", 1843.73], - ["2010", "Sweden", 2646.3], - ["2010", "Switzerland", 1304.92], - ["2010", "United Kingdom", 676742.62], - ["2011", "Australia", 136072.17], - ["2011", "Austria", 9897.28], - ["2011", "Bahrain", 342.66], - ["2011", "Belgium", 39101.05], - ["2011", "Brazil", 1143.6], - ["2011", "Canada", 3666.38], - ["2011", "Channel Islands", 19722.76], - ["2011", "Cyprus", 11355.47], - ["2011", "Czech Republic", 707.72], - ["2011", "Denmark", 17486.64], - ["2011", "EIRE", 254246.87], - ["2011", "European Community", 1291.75], - ["2011", "Finland", 21433.94], - ["2011", "France", 187828.54], - ["2011", "Germany", 207135.37], - ["2011", "Greece", 4710.52], - ["2011", "Hong Kong", 10117.04], - ["2011", "Iceland", 3598.21], - ["2011", "Israel", 8135.26], - ["2011", "Italy", 16096.01], - ["2011", "Japan", 27635.55], - ["2011", "Lebanon", 1693.88], - ["2011", "Malta", 2505.47], - ["2011", "Netherlands", 275877.06], - ["2011", "Norway", 31376.34], - ["2011", "Poland", 6964.98], - ["2011", "Portugal", 26986.9], - ["2011", "RSA", 1002.31], - ["2011", "Saudi Arabia", 131.17], - ["2011", "Singapore", 9120.39], - ["2011", "Spain", 52930.85], - ["2011", "Sweden", 33949.61], - ["2011", "Switzerland", 55080.43], - ["2011", "USA", 1730.92], - ["2011", "United Arab Emirates", 1902.28], - ["2011", "United Kingdom", 7511063.74], - ["2011", "Unspecified", 4749.79] - ], - "report_columns": ["SaleYear", "Country", "Revenue"], - "filename": "sales_revenue_by_country_per_year.xlsx" + "final_report": { + "report_data": [ + ["2010", "Australia", 1005.1], + ["2010", "Austria", 257.04], + ["2010", "Bahrain", 205.74], + ["2010", "Belgium", 1809.91], + ["2010", "Channel Islands", 363.53], + ["2010", "Cyprus", 1590.82], + ["2010", "Denmark", 1281.5], + ["2010", "EIRE", 9029.95], + ["2010", "Finland", 892.8], + ["2010", "France", 9575.36], + ["2010", "Germany", 14562.84], + ["2010", "Iceland", 711.79], + ["2010", "Israel", -227.44], + ["2010", "Italy", 794.5], + ["2010", "Japan", 7705.07], + ["2010", "Lithuania", 1661.06], + ["2010", "Netherlands", 8784.48], + ["2010", "Norway", 3787.12], + ["2010", "Poland", 248.16], + ["2010", "Portugal", 2380.12], + ["2010", "Spain", 1843.73], + ["2010", "Sweden", 2646.3], + ["2010", "Switzerland", 1304.92], + ["2010", "United Kingdom", 676742.62], + ["2011", "Australia", 136072.17], + ["2011", "Austria", 9897.28], + ["2011", "Bahrain", 342.66], + ["2011", "Belgium", 39101.05], + ["2011", "Brazil", 1143.6], + ["2011", "Canada", 3666.38], + ["2011", "Channel Islands", 19722.76], + ["2011", "Cyprus", 11355.47], + ["2011", "Czech Republic", 707.72], + ["2011", "Denmark", 17486.64], + ["2011", "EIRE", 254246.87], + ["2011", "European Community", 1291.75], + ["2011", "Finland", 21433.94], + ["2011", "France", 187828.54], + ["2011", "Germany", 207135.37], + ["2011", "Greece", 4710.52], + ["2011", "Hong Kong", 10117.04], + ["2011", "Iceland", 3598.21], + ["2011", "Israel", 8135.26], + ["2011", "Italy", 16096.01], + ["2011", "Japan", 27635.55], + ["2011", "Lebanon", 1693.88], + ["2011", "Malta", 2505.47], + ["2011", "Netherlands", 275877.06], + ["2011", "Norway", 31376.34], + ["2011", "Poland", 6964.98], + ["2011", "Portugal", 26986.9], + ["2011", "RSA", 1002.31], + ["2011", "Saudi Arabia", 131.17], + ["2011", "Singapore", 9120.39], + ["2011", "Spain", 52930.85], + ["2011", "Sweden", 33949.61], + ["2011", "Switzerland", 55080.43], + ["2011", "USA", 1730.92], + ["2011", "United Arab Emirates", 1902.28], + ["2011", "United Kingdom", 7511063.74], + ["2011", "Unspecified", 4749.79] + ], + "report_columns": ["SaleYear", "Country", "Revenue"], + "filename": "sales_revenue_by_country_per_year.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to retrtieve the total amount of sales grouped by country and year", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } }, { "id": "6", "input": "Generate a report on the 5 highest-value customers per year vs. the average customer.", "expected_output": { - "report_data": [ - [2010, 18102, 27834.61, 647.1343389830506], - [2010, 15061, 19950.660000000007, 647.1343389830506], - [2010, 16029, 13112.52, 647.1343389830506], - [2010, 14646, 8591.879999999997, 647.1343389830506], - [2010, 14911, 7737.939999999999, 647.1343389830506], - [2011, 14646, 271614.13999999996, 1975.993842180096], - [2011, 18102, 231822.69000000006, 1975.993842180096], - [2011, 17450, 192521.9500000001, 1975.993842180096], - [2011, 16446, 168472.5, 1975.993842180096], - [2011, 14911, 136087.11999999956, 1975.993842180096] - ], - "report_columns": ["SaleYear", "CustomerID", "TopCustomerValue", "AverageValue"], - "filename": "highest_value_customers_report.xlsx" + "final_report": { + "report_data": [ + [2010, 18102, 27834.61, 585.0253375527616], + [2010, 15061, 19950.660000000007, 585.0253375527616], + [2010, 16029, 13112.52, 585.0253375527616], + [2010, 14646, 8591.879999999997, 585.0253375527616], + [2010, 16210, 7000.639999999999, 585.0253375527616], + [2011, 14646, 270897.13999999984, 1825.03812299787], + [2011, 18102, 228603.88000000006, 1825.03812299787], + [2011, 17450, 185453.33000000013, 1825.03812299787], + [2011, 14911, 125815.48999999973, 1825.03812299787], + [2011, 12415, 123725.44999999987, 1825.03812299787] + ], + "report_columns": ["SaleYear", "CustomerID", "TopCustomerValue", "AverageValue"], + "filename": "highest_value_customers_report.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to retrtieve or create a view that containsthe total amount of sales grouped by customer and year", + "Query to retrieve the total amount of sales of the top 5 customers per year, and the average value of all customers", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } }, { "id": "7", "input": "Generate a report on the average amount spent by one time buyers for each year vs. the average customer.", "expected_output": { - "report_data": [ - ["2010", 196.46396396396398, 789.2065542676543], - ["2011", 342.7430707154745, 2119.856516843317], - ["Overall", 330.3767486671746, 2229.075676652107] - ], - "report_columns": ["Year", "Average Spent by One-Time Buyers", "Average Spent by All Customers"], - "filename": "average_spending_report.xlsx" + "final_report": { + "report_data": [ + ["2010", 196.46396396396398, 789.2065542676543], + ["2011", 342.7430707154745, 2119.856516843317] + ], + "report_columns": ["Year", "Average Spent by One-Time Buyers", "Average Spent by All Customers"], + "filename": "average_spending_report.xlsx" + }, + "trajectory": { + "actions": [ + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "execute_sql_query", + "write_report_to_file", + "output_text" + ], + "description": [ + "Check what are the tables that are available in the database", + "Check what are the columns that are available in the sales table", + "Query to find out the number of invoices and the total spent by customer", + "Query to retrieve the average spent for one time buyers only, grouped by year", + "Send the report data to the function that writes the report to disk", + "Output text to the user with the report file as a Gradio hyperlink" + ] + } } } ] diff --git a/implementations/report_generation/evaluate.py b/implementations/report_generation/evaluate.py index d5b4447d..453c7183 100644 --- a/implementations/report_generation/evaluate.py +++ b/implementations/report_generation/evaluate.py @@ -2,6 +2,7 @@ import asyncio import logging +from typing import Any import agents import click @@ -10,11 +11,20 @@ from langfuse._client.datasets import DatasetItemClient from langfuse.experiment import Evaluation from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses.response_output_message import ResponseOutputMessage +from openai.types.responses.response_output_refusal import ResponseOutputRefusal +from openai.types.responses.response_output_text import ResponseOutputText from pydantic import BaseModel from tenacity import retry, stop_after_attempt, wait_exponential from implementations.report_generation.data.langfuse_upload import DEFAULT_EVALUATION_DATASET_NAME from implementations.report_generation.main import get_report_generation_agent +from implementations.report_generation.prompts import ( + RESULT_EVALUATOR_INSTRUCTIONS, + RESULT_EVALUATOR_TEMPLATE, + TRAJECTORY_EVALUATOR_INSTRUCTIONS, + TRAJECTORY_EVALUATOR_TEMPLATE, +) load_dotenv(verbose=True) @@ -22,32 +32,15 @@ logger = logging.getLogger(__name__) -EVALUATOR_INSTRUCTIONS = """\ -Evaluate whether the "Proposed Answer" to the given "Question" matches the "Ground Truth".""" - -ADDITONAL_EVALUATOR_INSTRUCTIONS = """\ -Disregard the following aspects when comparing the "Proposed Answer" to the "Ground Truth": -- The order of the items should not matter, unless explicitly specified in the "Question". -- The formatting of the values should not matter, unless explicitly specified in the "Question". -- The column and row names have to be similar but not necessarily exact, unless explicitly specified in the "Question". -- The filename has to be similar by name but not necessarily exact, unless explicitly specified in the "Question". -- The numerical values should be equal to the second decimal place. -""" - -EVALUATOR_TEMPLATE = """\ -# Question - -{question} - -# Ground Truth - -{ground_truth} - -# Proposed Answer - -{proposed_response} - -""" +# Will have the structure: +# { +# "final_report": str | None, +# "trajectory": { +# "actions": list[str], +# "parameters": list[str], +# }, +# } +EvaluationOutput = dict[str, None | Any] class EvaluatorResponse(BaseModel): @@ -78,7 +71,7 @@ async def evaluate(dataset_name: str): name="Evaluate Report Generation Agent", description="Evaluate the Report Generation Agent with data from Langfuse", task=agent_task, - evaluators=[llm_evaluator], + evaluators=[final_result_evaluator, trajectory_evaluator], max_concurrency=1, ) @@ -92,7 +85,7 @@ async def evaluate(dataset_name: str): logger.warning(f"Client manager services not closed successfully: {e}") -async def agent_task(*, item: DatasetItemClient, **kwargs) -> str | None: +async def agent_task(*, item: DatasetItemClient, **kwargs) -> EvaluationOutput: """Run the report generation agent against an item from a Langfuse dataset. Parameters @@ -102,30 +95,121 @@ async def agent_task(*, item: DatasetItemClient, **kwargs) -> str | None: Returns ------- - str | None - The arguments sent by the report generation agent to the write_report_to_file - function. Returns None if the agent did not call the function. + EvaluationOutput + The output of the report generation agent with the values it should + be evaluated against. """ # Define and run the report generation agent report_generation_agent = get_report_generation_agent(enable_trace=True) result = await run_agent_with_retry(report_generation_agent, item.input) - # Extract the report data from the result by returning the - # arguments to the write_report_to_file function call - # Reversing the responses to get the last write_report_to_file call - # in case a failed call has been made first - for raw_response in reversed(result.raw_responses): + # Extract the report data and trajectory from the agent's response + actions = [] + parameters = [] + final_report = None + for raw_response in result.raw_responses: for output in raw_response.output: - if isinstance(output, ResponseFunctionToolCall) and "write_report_to_file" in output.name: - return output.arguments + # The trajectory will be the list of actions and the + # parameters passed to each one of them + if isinstance(output, ResponseFunctionToolCall): + actions.append(output.name) + parameters.append(output.arguments) + + # The final report will be the arguments sent by the + # write_report_to_file function call + # If there is more than one call to the write_report_to_file function, + # the last one will be used because the previous calls were likely + # failed calls + if isinstance(output, ResponseFunctionToolCall) and "write_report_to_file" in output.name: + final_report = output.arguments + + if isinstance(output, ResponseOutputMessage): + for content in output.content: + actions.append(content.type) + if isinstance(content, ResponseOutputText): + parameters.append(content.text) + elif isinstance(content, ResponseOutputRefusal): + parameters.append(content.refusal) + + if final_report is None: + logger.warning("No call to write_report_to_file function found in the agent's response") + + return { + "final_report": final_report, + "trajectory": { + "actions": actions, + "parameters": parameters, + }, + } + + +async def final_result_evaluator( + *, + input: str, + output: EvaluationOutput, + expected_output: EvaluationOutput, + **kwargs, +) -> Evaluation: + # ruff: noqa: A002 + """Evaluate the proposed final answer against the ground truth. - logger.warning("No call to write_report_to_file function found in the agent's response") - return None + Uses LLM-as-a-judge and returns the reasoning behind the answer. + Parameters + ---------- + input : str + The input to the report generation agent. + output : EvaluationOutput + The output of the report generation agent with the values it should be + evaluated against. + expected_output : EvaluationOutput + The evaluation output the report generation agent should have. + kwargs : dict + Additional keyword arguments. -async def llm_evaluator(*, input: str, output: str, expected_output: str, **kwargs) -> Evaluation: + Returns + ------- + Evaluation + The evaluation result, including the reasoning behind the answer. + """ + # Define the evaluator agent + client_manager = AsyncClientManager.get_instance() + evaluator_agent = agents.Agent( + name="Final Result Evaluator Agent", + instructions=RESULT_EVALUATOR_INSTRUCTIONS, + output_type=EvaluatorResponse, + model=agents.OpenAIChatCompletionsModel( + model=client_manager.configs.default_planner_model, + openai_client=client_manager.openai_client, + ), + ) + # Format the input for the evaluator agent + evaluator_input = RESULT_EVALUATOR_TEMPLATE.format( + question=input, + ground_truth=expected_output["final_report"], + proposed_response=output["final_report"], + ) + # Run the evaluator agent with retry + result = await run_agent_with_retry(evaluator_agent, evaluator_input) + evaluation_response = result.final_output_as(EvaluatorResponse) + + # Return the evaluation result + return Evaluation( + name="Final Result", + value=evaluation_response.is_answer_correct, + comment=evaluation_response.explanation, + ) + + +async def trajectory_evaluator( + *, + input: str, + output: EvaluationOutput, + expected_output: EvaluationOutput, + **kwargs, +) -> Evaluation: # ruff: noqa: A002 - """Evaluate the proposed answer against the ground truth. + """Evaluate the agent's trajectory against the ground truth. Uses LLM-as-a-judge and returns the reasoning behind the answer. @@ -133,10 +217,11 @@ async def llm_evaluator(*, input: str, output: str, expected_output: str, **kwar ---------- input : str The input to the report generation agent. - output : str - The output of the report generation agent. - expected_output : str - The expected output of the report generation agent. + output : EvaluationOutput + The output of the report generation agent with the values it should be + evaluated against. + expected_output : EvaluationOutput + The evaluation output the report generation agent should have. kwargs : dict Additional keyword arguments. @@ -148,19 +233,25 @@ async def llm_evaluator(*, input: str, output: str, expected_output: str, **kwar # Define the evaluator agent client_manager = AsyncClientManager.get_instance() evaluator_agent = agents.Agent( - name="Evaluator Agent", - instructions=EVALUATOR_INSTRUCTIONS + ADDITONAL_EVALUATOR_INSTRUCTIONS, + name="Trajectory Evaluator Agent", + instructions=TRAJECTORY_EVALUATOR_INSTRUCTIONS, output_type=EvaluatorResponse, model=agents.OpenAIChatCompletionsModel( model=client_manager.configs.default_planner_model, openai_client=client_manager.openai_client, ), ) + + assert isinstance(expected_output["trajectory"], dict), "Expected trajectory must be a dictionary" + assert isinstance(output["trajectory"], dict), "Actual trajectory must be a dictionary" + # Format the input for the evaluator agent - evaluator_input = EVALUATOR_TEMPLATE.format( + evaluator_input = TRAJECTORY_EVALUATOR_TEMPLATE.format( question=input, - ground_truth=expected_output, - proposed_response=output, + expected_actions=expected_output["trajectory"]["actions"], + expected_descriptions=expected_output["trajectory"]["description"], + actual_actions=output["trajectory"]["actions"], + actual_parameters=output["trajectory"]["parameters"], ) # Run the evaluator agent with retry result = await run_agent_with_retry(evaluator_agent, evaluator_input) @@ -168,7 +259,7 @@ async def llm_evaluator(*, input: str, output: str, expected_output: str, **kwar # Return the evaluation result return Evaluation( - name="LLM-as-a-judge", + name="Trajectory", value=evaluation_response.is_answer_correct, comment=evaluation_response.explanation, ) diff --git a/implementations/report_generation/main.py b/implementations/report_generation/main.py index acd5006b..9e0a6553 100644 --- a/implementations/report_generation/main.py +++ b/implementations/report_generation/main.py @@ -17,6 +17,7 @@ from gradio.components.chatbot import ChatMessage from implementations.report_generation.file_writer import get_reports_output_path, write_report_to_file +from implementations.report_generation.prompts import MAIN_AGENT_INSTRUCTIONS load_dotenv(verbose=True) @@ -24,18 +25,6 @@ logger = logging.getLogger(__name__) -REACT_INSTRUCTIONS = """\ -Perform the task using the SQLite database tool. \ -EACH TIME before invoking the function, you must explain your reasons for doing so. \ -If the SQL query did not return intended results, try again. \ -For best performance, divide complex queries into simpler sub-queries. \ -Do not make up information. \ -When the report is done, use the report file writer tool to write it to a file. \ -Make sure the write_report_to_file tool is called so it generates the report file. \ -At the end, provide the report file as a downloadable hyperlink to the user. \ -Make sure the link can be clicked on by the user. -""" - LANGFUSE_PROJECT_NAME = "Report Generation" @@ -79,17 +68,24 @@ def get_report_generation_agent(enable_trace: bool = True) -> agents.Agent: # Define an agent using the OpenAI Agent SDK return agents.Agent( name="Report Generation Agent", # Agent name for logging and debugging purposes - instructions=REACT_INSTRUCTIONS, # System instructions for the agent + instructions=MAIN_AGENT_INSTRUCTIONS, # System instructions for the agent # Tools available to the agent # We wrap the `search_knowledgebase` method with `function_tool`, which # will construct the tool definition JSON schema by extracting the necessary # information from the method signature and docstring. tools=[ - agents.function_tool(client_manager.sqlite_connection(get_sqlite_db_path()).execute), - agents.function_tool(write_report_to_file), + agents.function_tool( + client_manager.sqlite_connection(get_sqlite_db_path()).execute, + name_override="execute_sql_query", + description_override="Execute a SQL query against the SQLite database.", + ), + agents.function_tool( + write_report_to_file, + description_override="Write the report data to a file.", + ), ], model=agents.OpenAIChatCompletionsModel( - model=client_manager.configs.default_planner_model, + model=client_manager.configs.default_worker_model, openai_client=client_manager.openai_client, ), ) diff --git a/implementations/report_generation/prompts.py b/implementations/report_generation/prompts.py new file mode 100644 index 00000000..4a8ca199 --- /dev/null +++ b/implementations/report_generation/prompts.py @@ -0,0 +1,70 @@ +"""Prompts for the report generation and evaluator agents.""" + +MAIN_AGENT_INSTRUCTIONS = """\ +Perform the task using the SQLite database tool. \ +EACH TIME before invoking the function, you must explain your reasons for doing so. \ +If the SQL query did not return intended results, try again. \ +For best performance, divide complex queries into simpler sub-queries. \ +Do not make up information. \ +When the report is done, use the report file writer tool to write it to a file. \ +Make sure the write_report_to_file tool is called so it generates the report file. \ +At the end, provide the report file as a downloadable hyperlink to the user. \ +Make sure the link can be clicked on by the user. +""" + +TRAJECTORY_EVALUATOR_INSTRUCTIONS = """\ +You are evaluating if an agent has followed the correct trajectory to generate a report.\ +The agent is a Report Generation Agent that uses the SQLite database tool to generate a report\ +and return the report as a downloadable file to the user.\ +You will be presented with the "Question" that has been asked to the agent along with two sets of data:\ +- The "Expected Trajectory" of the agent, which contains:\ + - A list ids for the actions the agent is expected to perform\ + - A list of rough descriptions of what has been passed as parameters to the actions\ +- The "Actual Trajectory" of the agent, which contains:\ + - A list ids for the actions the agent performed\ + - A list of parameters that has been passed to each one of the actions\ +It's OK if the agent makes mistakes and performs additional steps, or if the queries do not exactly match\ +the description, as long as the queries performed end up satisfying the "Question".\ +It is important that the last action to be of type "output_text" and that itproduces a link to the report file. +""" + +TRAJECTORY_EVALUATOR_TEMPLATE = """\ +# Question + +{question} + +# Expected Trajectory + +actions: {expected_actions} +descriptions: {expected_descriptions} + +# Actual Trajectory + +actions: {actual_actions} +parameters: {actual_parameters} +""" + +RESULT_EVALUATOR_INSTRUCTIONS = """\ +Evaluate whether the "Proposed Answer" to the given "Question" matches the "Ground Truth".\ +Disregard the following aspects when comparing the "Proposed Answer" to the "Ground Truth":\ +- The order of the items should not matter, unless explicitly specified in the "Question".\ +- The formatting of the values should not matter, unless explicitly specified in the "Question".\ +- The column and row names have to be similar but not necessarily exact, unless explicitly specified in the "Question".\ +- The filename has to be similar by name but not necessarily exact, unless explicitly specified in the "Question".\ +- The numerical values should be equal to the second decimal place. +""" + +RESULT_EVALUATOR_TEMPLATE = """\ +# Question + +{question} + +# Ground Truth + +{ground_truth} + +# Proposed Answer + +{proposed_response} + +"""