22
33import logging
44import traceback
5- from contextlib import nullcontext
65from datetime import datetime
76from pathlib import Path
87
1312 DefaultTopicId ,
1413 SingleThreadedAgentRuntime ,
1514)
16- from autogen_ext .models .openai import OpenAIChatCompletionClient
1715from langfuse import Langfuse
1816from omegaconf import DictConfig
1917
2018from src .area_generation .messages import Domain
2119from src .area_generation .moderator import AreaModerator
2220from src .area_generation .scientist import AreaScientist
21+ from src .utils .model_client_utils import get_model_client
2322
2423
2524log = logging .getLogger ("agentic_area_gen.generator" )
@@ -36,29 +35,20 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
3635 num_areas = cfg .area_generation .num_areas
3736 areas_tag = f"_{ datetime .now ().strftime ('%Y%m%d_%H%M%S' )} "
3837
39- lf = langfuse_client
40- if lf is None :
41- lf = Langfuse ()
42-
43- with (
44- lf .start_as_current_span (
45- name = f"ace_area_generation:{ domain_name } :{ exp_id } :{ areas_tag } "
46- )
47- if lf
48- else nullcontext () as span
49- ):
38+ with langfuse_client .start_as_current_span (
39+ name = f"ace_area_generation:{ domain_name } :{ exp_id } :{ areas_tag } "
40+ ) as span :
5041 try :
5142 msg = f"Areas will be saved with tag: { areas_tag } "
5243 log .info (msg )
53- if span :
54- span .update (
55- metadata = {
56- "generation_started" : msg ,
57- "areas_tag" : areas_tag ,
58- "domain" : domain_name ,
59- "exp_id" : exp_id ,
60- }
61- )
44+ span .update (
45+ metadata = {
46+ "generation_started" : msg ,
47+ "areas_tag" : areas_tag ,
48+ "domain" : domain_name ,
49+ "exp_id" : exp_id ,
50+ }
51+ )
6252
6353 output_dir = (
6454 Path .home ()
@@ -72,82 +62,79 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
7262
7363 msg = f"Output directory: { output_dir } "
7464 log .info (msg )
75- if span :
76- span .update (
77- metadata = {
78- "output_directory_configured" : msg ,
79- "output_dir" : str (output_dir ),
80- }
81- )
65+ span .update (
66+ metadata = {
67+ "output_directory_configured" : msg ,
68+ "output_dir" : str (output_dir ),
69+ }
70+ )
8271
83- if span :
84- span .update_trace (
85- metadata = {
86- "domain" : domain_name ,
87- "exp_id" : exp_id ,
88- "max_round" : max_round ,
89- "num_areas" : num_areas ,
90- "areas_tag" : areas_tag ,
91- },
92- tags = ["area_generation_process" , exp_id ],
93- )
72+ span .update_trace (
73+ metadata = {
74+ "domain" : domain_name ,
75+ "exp_id" : exp_id ,
76+ "max_round" : max_round ,
77+ "num_areas" : num_areas ,
78+ "areas_tag" : areas_tag ,
79+ },
80+ tags = ["area_generation_process" , exp_id ],
81+ )
9482
9583 runtime = SingleThreadedAgentRuntime ()
9684
9785 await AreaScientist .register (
9886 runtime ,
9987 "AreaScientistA" ,
10088 lambda : AreaScientist (
101- model_client = OpenAIChatCompletionClient (
102- model = cfg .agents .scientist_a .model_name ,
89+ model_client = get_model_client (
90+ model_name = cfg .agents .scientist_a .model_name ,
10391 seed = cfg .agents .scientist_a .seed ,
10492 ),
10593 scientist_id = "A" ,
106- langfuse_client = lf ,
94+ langfuse_client = langfuse_client ,
10795 ),
10896 )
10997
11098 await AreaScientist .register (
11199 runtime ,
112100 "AreaScientistB" ,
113101 lambda : AreaScientist (
114- model_client = OpenAIChatCompletionClient (
115- model = cfg .agents .scientist_b .model_name ,
102+ model_client = get_model_client (
103+ model_name = cfg .agents .scientist_b .model_name ,
116104 seed = cfg .agents .scientist_b .seed ,
117105 ),
118106 scientist_id = "B" ,
119- langfuse_client = lf ,
107+ langfuse_client = langfuse_client ,
120108 ),
121109 )
122110
123111 await AreaModerator .register (
124112 runtime ,
125113 "AreaModerator" ,
126114 lambda : AreaModerator (
127- model_client = OpenAIChatCompletionClient (
128- model = cfg .agents .moderator .model_name ,
115+ model_client = get_model_client (
116+ model_name = cfg .agents .moderator .model_name ,
129117 seed = cfg .agents .moderator .seed ,
130118 ),
131119 num_scientists = 2 ,
132120 num_final_areas = num_areas ,
133121 max_round = max_round ,
134122 output_dir = output_dir ,
135- langfuse_client = lf ,
123+ langfuse_client = langfuse_client ,
136124 ),
137125 )
138126
139127 msg = "All area agents registered successfully"
140128 log .info (msg )
141- if span :
142- span .update (
143- metadata = {
144- "agents_registered" : msg ,
145- "scientists" : ["A" , "B" ],
146- "moderator" : True ,
147- "max_rounds" : max_round ,
148- "expected_areas" : num_areas ,
149- }
150- )
129+ span .update (
130+ metadata = {
131+ "agents_registered" : msg ,
132+ "scientists" : ["A" , "B" ],
133+ "moderator" : True ,
134+ "max_rounds" : max_round ,
135+ "expected_areas" : num_areas ,
136+ }
137+ )
151138
152139 runtime .start ()
153140
@@ -156,35 +143,32 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
156143
157144 msg = f"Domain message published: { domain_name } "
158145 log .info (msg )
159- if span :
160- span .update (
161- metadata = {
162- "domain_published" : msg ,
163- "domain_name" : domain_name ,
164- }
165- )
146+ span .update (
147+ metadata = {
148+ "domain_published" : msg ,
149+ "domain_name" : domain_name ,
150+ }
151+ )
166152
167153 try :
168154 await runtime .stop_when_idle ()
169155
170156 msg = "Runtime stopped - area generation completed"
171157 log .info (msg )
172- if span :
173- span .update (metadata = {"runtime_completed" : msg })
158+ span .update (metadata = {"runtime_completed" : msg })
174159
175160 print (f"Areas generated with tag: { areas_tag } " )
176161 except Exception as e :
177162 msg = f"Error while waiting for runtime to stop: { e } "
178163 log .error (msg )
179- if span :
180- span .update (
181- level = "ERROR" ,
182- status_message = str (e ),
183- metadata = {
184- "runtime_error" : msg ,
185- "error" : str (e ),
186- },
187- )
164+ span .update (
165+ level = "ERROR" ,
166+ status_message = str (e ),
167+ metadata = {
168+ "runtime_error" : msg ,
169+ "error" : str (e ),
170+ },
171+ )
188172 raise
189173
190174 except Exception as e :
@@ -194,17 +178,16 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
194178 log .error (error_msg )
195179 log .error (traceback_msg )
196180
197- if span :
198- span .update (
199- level = "ERROR" ,
200- status_message = str (e ),
201- metadata = {
202- "generation_error" : error_msg ,
203- "error" : str (e ),
204- "traceback" : traceback_msg ,
205- },
206- )
181+ span .update (
182+ level = "ERROR" ,
183+ status_message = str (e ),
184+ metadata = {
185+ "generation_error" : error_msg ,
186+ "error" : str (e ),
187+ "traceback" : traceback_msg ,
188+ },
189+ )
207190
208191 if langfuse_client is None :
209- lf .flush ()
192+ langfuse_client .flush ()
210193 raise
0 commit comments