Skip to content

Commit 2f34c46

Browse files
authored
Merge pull request #38 from VectorInstitute/fix-anthropic-client
Fixed model client (Anthropic support and retry logic), resume logic for cap generation
2 parents 0de51d5 + 435fbb0 commit 2f34c46

File tree

11 files changed

+963
-815
lines changed

11 files changed

+963
-815
lines changed

poetry.lock

Lines changed: 208 additions & 100 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ dependencies = [
1818
"langchain>=0.3.19",
1919
"matplotlib>=3.10.0",
2020
"omegaconf>=2.3.0",
21-
"openai>=1.92.0",
22-
"pyautogen>=0.2.22",
21+
"openai>=1.102.0",
22+
"ag2>=0.3.2",
23+
"autogen-ext[openai,anthropic]>=0.7.4",
24+
"anthropic>=0.64.0",
2325
"ratelimit>=2.2.1",
2426
"torchvision (>=0.21.0,<0.22.0)",
2527
"torchaudio (>=2.6.0,<3.0.0)",

src/agentic_capability_generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
def main(cfg: DictConfig) -> None:
3030
"""Run the multi-agent debate-based capability generation system."""
3131
areas_tag = cfg.pipeline_tags.areas_tag
32+
resume_tag = getattr(cfg.pipeline_tags, "resume_capabilities_tag", None)
3233
domain_name = cfg.global_cfg.domain
3334
exp_id = cfg.exp_cfg.exp_id
3435
num_capabilities_per_area = cfg.capability_generation.num_capabilities_per_area
@@ -68,18 +69,26 @@ def main(cfg: DictConfig) -> None:
6869
)
6970
return
7071

72+
if resume_tag:
73+
msg = f"Resuming capability generation from tag: {resume_tag}"
74+
log.info(msg)
75+
span.update(
76+
metadata={"resume_tag_found": msg, "resume_tag": resume_tag}
77+
)
78+
7179
span.update_trace(
7280
metadata={
7381
"domain": domain_name,
7482
"exp_id": exp_id,
7583
"areas_tag": areas_tag,
84+
"resume_tag": resume_tag,
7685
"num_capabilities_per_area": num_capabilities_per_area,
7786
"config": config_yaml,
7887
},
7988
tags=["agentic_capability_generation", exp_id],
8089
)
8190

82-
asyncio.run(generate_capabilities(cfg, areas_tag, lf))
91+
asyncio.run(generate_capabilities(cfg, areas_tag, lf, resume_tag))
8392

8493
msg = (
8594
"Multi-agent debate-based capability generation completed successfully"

src/area_generation/generator.py

Lines changed: 71 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
import traceback
5-
from contextlib import nullcontext
65
from datetime import datetime
76
from pathlib import Path
87

@@ -13,13 +12,13 @@
1312
DefaultTopicId,
1413
SingleThreadedAgentRuntime,
1514
)
16-
from autogen_ext.models.openai import OpenAIChatCompletionClient
1715
from langfuse import Langfuse
1816
from omegaconf import DictConfig
1917

2018
from src.area_generation.messages import Domain
2119
from src.area_generation.moderator import AreaModerator
2220
from src.area_generation.scientist import AreaScientist
21+
from src.utils.model_client_utils import get_model_client
2322

2423

2524
log = logging.getLogger("agentic_area_gen.generator")
@@ -36,29 +35,20 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
3635
num_areas = cfg.area_generation.num_areas
3736
areas_tag = f"_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
3837

39-
lf = langfuse_client
40-
if lf is None:
41-
lf = Langfuse()
42-
43-
with (
44-
lf.start_as_current_span(
45-
name=f"ace_area_generation:{domain_name}:{exp_id}:{areas_tag}"
46-
)
47-
if lf
48-
else nullcontext() as span
49-
):
38+
with langfuse_client.start_as_current_span(
39+
name=f"ace_area_generation:{domain_name}:{exp_id}:{areas_tag}"
40+
) as span:
5041
try:
5142
msg = f"Areas will be saved with tag: {areas_tag}"
5243
log.info(msg)
53-
if span:
54-
span.update(
55-
metadata={
56-
"generation_started": msg,
57-
"areas_tag": areas_tag,
58-
"domain": domain_name,
59-
"exp_id": exp_id,
60-
}
61-
)
44+
span.update(
45+
metadata={
46+
"generation_started": msg,
47+
"areas_tag": areas_tag,
48+
"domain": domain_name,
49+
"exp_id": exp_id,
50+
}
51+
)
6252

6353
output_dir = (
6454
Path.home()
@@ -72,82 +62,79 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
7262

7363
msg = f"Output directory: {output_dir}"
7464
log.info(msg)
75-
if span:
76-
span.update(
77-
metadata={
78-
"output_directory_configured": msg,
79-
"output_dir": str(output_dir),
80-
}
81-
)
65+
span.update(
66+
metadata={
67+
"output_directory_configured": msg,
68+
"output_dir": str(output_dir),
69+
}
70+
)
8271

83-
if span:
84-
span.update_trace(
85-
metadata={
86-
"domain": domain_name,
87-
"exp_id": exp_id,
88-
"max_round": max_round,
89-
"num_areas": num_areas,
90-
"areas_tag": areas_tag,
91-
},
92-
tags=["area_generation_process", exp_id],
93-
)
72+
span.update_trace(
73+
metadata={
74+
"domain": domain_name,
75+
"exp_id": exp_id,
76+
"max_round": max_round,
77+
"num_areas": num_areas,
78+
"areas_tag": areas_tag,
79+
},
80+
tags=["area_generation_process", exp_id],
81+
)
9482

9583
runtime = SingleThreadedAgentRuntime()
9684

9785
await AreaScientist.register(
9886
runtime,
9987
"AreaScientistA",
10088
lambda: AreaScientist(
101-
model_client=OpenAIChatCompletionClient(
102-
model=cfg.agents.scientist_a.model_name,
89+
model_client=get_model_client(
90+
model_name=cfg.agents.scientist_a.model_name,
10391
seed=cfg.agents.scientist_a.seed,
10492
),
10593
scientist_id="A",
106-
langfuse_client=lf,
94+
langfuse_client=langfuse_client,
10795
),
10896
)
10997

11098
await AreaScientist.register(
11199
runtime,
112100
"AreaScientistB",
113101
lambda: AreaScientist(
114-
model_client=OpenAIChatCompletionClient(
115-
model=cfg.agents.scientist_b.model_name,
102+
model_client=get_model_client(
103+
model_name=cfg.agents.scientist_b.model_name,
116104
seed=cfg.agents.scientist_b.seed,
117105
),
118106
scientist_id="B",
119-
langfuse_client=lf,
107+
langfuse_client=langfuse_client,
120108
),
121109
)
122110

123111
await AreaModerator.register(
124112
runtime,
125113
"AreaModerator",
126114
lambda: AreaModerator(
127-
model_client=OpenAIChatCompletionClient(
128-
model=cfg.agents.moderator.model_name,
115+
model_client=get_model_client(
116+
model_name=cfg.agents.moderator.model_name,
129117
seed=cfg.agents.moderator.seed,
130118
),
131119
num_scientists=2,
132120
num_final_areas=num_areas,
133121
max_round=max_round,
134122
output_dir=output_dir,
135-
langfuse_client=lf,
123+
langfuse_client=langfuse_client,
136124
),
137125
)
138126

139127
msg = "All area agents registered successfully"
140128
log.info(msg)
141-
if span:
142-
span.update(
143-
metadata={
144-
"agents_registered": msg,
145-
"scientists": ["A", "B"],
146-
"moderator": True,
147-
"max_rounds": max_round,
148-
"expected_areas": num_areas,
149-
}
150-
)
129+
span.update(
130+
metadata={
131+
"agents_registered": msg,
132+
"scientists": ["A", "B"],
133+
"moderator": True,
134+
"max_rounds": max_round,
135+
"expected_areas": num_areas,
136+
}
137+
)
151138

152139
runtime.start()
153140

@@ -156,35 +143,32 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
156143

157144
msg = f"Domain message published: {domain_name}"
158145
log.info(msg)
159-
if span:
160-
span.update(
161-
metadata={
162-
"domain_published": msg,
163-
"domain_name": domain_name,
164-
}
165-
)
146+
span.update(
147+
metadata={
148+
"domain_published": msg,
149+
"domain_name": domain_name,
150+
}
151+
)
166152

167153
try:
168154
await runtime.stop_when_idle()
169155

170156
msg = "Runtime stopped - area generation completed"
171157
log.info(msg)
172-
if span:
173-
span.update(metadata={"runtime_completed": msg})
158+
span.update(metadata={"runtime_completed": msg})
174159

175160
print(f"Areas generated with tag: {areas_tag}")
176161
except Exception as e:
177162
msg = f"Error while waiting for runtime to stop: {e}"
178163
log.error(msg)
179-
if span:
180-
span.update(
181-
level="ERROR",
182-
status_message=str(e),
183-
metadata={
184-
"runtime_error": msg,
185-
"error": str(e),
186-
},
187-
)
164+
span.update(
165+
level="ERROR",
166+
status_message=str(e),
167+
metadata={
168+
"runtime_error": msg,
169+
"error": str(e),
170+
},
171+
)
188172
raise
189173

190174
except Exception as e:
@@ -194,17 +178,16 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
194178
log.error(error_msg)
195179
log.error(traceback_msg)
196180

197-
if span:
198-
span.update(
199-
level="ERROR",
200-
status_message=str(e),
201-
metadata={
202-
"generation_error": error_msg,
203-
"error": str(e),
204-
"traceback": traceback_msg,
205-
},
206-
)
181+
span.update(
182+
level="ERROR",
183+
status_message=str(e),
184+
metadata={
185+
"generation_error": error_msg,
186+
"error": str(e),
187+
"traceback": traceback_msg,
188+
},
189+
)
207190

208191
if langfuse_client is None:
209-
lf.flush()
192+
langfuse_client.flush()
210193
raise

0 commit comments

Comments
 (0)