from functools import partial
from typing import Any, Dict, List
from app.models import ChatMessage
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from app.core.graph.members import Leader, LeaderNode, Member, SummariserNode, TeamState, WorkerNode
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnableLambda

model = ChatOpenAI(model="gpt-3.5-turbo")

# Create the Member/Leader class instances in members
def format_teams(teams: Dict[str, Any]):
    """Update the team members to use Member/Leader instances"""
    for team in teams:
        members = teams[team]["members"]
        for k, v in members.items():
            teams[team]["members"][k] = Leader(**v) if v["type"] == "leader" else Member(**v)
    return teams

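# Note: with the commented-out `teams` dict at the bottom of this file, format_teams()
# above leaves the outer structure untouched but replaces each plain member dict with a
# Member instance (or a Leader instance when "type" == "leader").
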
def router(state: TeamState):
    # "next" is expected to be set by the leader node to a member name or "FINISH"
    return state["next"]

def enter_chain(state: TeamState, team: Dict[str, str | Dict[str, Member | Leader]]):
    """
    Initialise the sub-graph state.
    This makes it so that the states of each graph don't get intermixed.
    """
    task = state["task"]
    team_name = team["name"]
    team_members = team["members"]

    results = {
        "messages": task,
        "team_name": team_name,
        "team_members": team_members,
    }
    return results

def exit_chain(state: TeamState):
    """
    Pass the final response back to the top-level graph's state.
    """
    answer = state["messages"][-1]
    return {"messages": [answer]}

def create_graph(teams: Dict[str, Dict[str, str | Dict[str, Member | Leader]]], leader_name: str):
    """
    Create the team's graph.
    """
    build = StateGraph(TeamState)
    # Add the leader (entry point) and summariser (finish point) nodes
    build.add_node(leader_name, RunnableLambda(LeaderNode(model).delegate))
    build.add_node("summariser", RunnableLambda(SummariserNode(model).summarise))

    members = teams[leader_name]["members"]
    for name, member in members.items():
        if isinstance(member, Member):
            build.add_node(name, RunnableLambda(WorkerNode(model).work))
        elif isinstance(member, Leader):
            # Leaders get their own sub-graph, wrapped so its state stays separate
            subgraph = create_graph(teams, leader_name=name)
            enter = partial(enter_chain, team=teams[name])
            build.add_node(name, enter | subgraph | exit_chain)
        else:
            continue
        # Each member reports back to the team leader
        build.add_edge(name, leader_name)

    # The leader routes to one of its members, or to the summariser on "FINISH"
    conditional_mapping = {v: v for v in members}
    conditional_mapping["FINISH"] = "summariser"
    build.add_conditional_edges(leader_name, router, conditional_mapping)

    build.set_entry_point(leader_name)
    build.set_finish_point("summariser")
    graph = build.compile()
    return graph


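# Resulting topology for each team (per the wiring above): the leader delegates to one
# member at a time via the conditional edges, every member reports back to the leader,
# and "FINISH" routes to the summariser, which ends the graph. Leader-type members are
# expanded recursively into sub-graphs bracketed by enter_chain/exit_chain so that
# parent and child states don't get mixed.

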
async def generator(teams: dict, team_leader: str, messages: List[ChatMessage]):
    """Create the graph and stream the response"""
    format_teams(teams)
    root = create_graph(teams, leader_name=team_leader)
    messages = [
        HumanMessage(message.content) if message.type == "human" else AIMessage(message.content)
        for message in messages
    ]

    async for output in root.astream({
        "messages": messages,
        "team_name": teams[team_leader]["name"],
        "team_members": teams[team_leader]["members"],
    }):
        for key, value in output.items():
            if key != "__end__":
                response = {key: value}
                # Format each update as a server-sent event
                formatted_output = f"data: {response}\n\n"
                yield formatted_output

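# Example of consuming `generator` directly (a sketch; it assumes ChatMessage exposes
# `type` and `content` as used above, and that `teams` follows the structure of the
# commented-out example below):
#
# import asyncio
#
# async def demo():
#     history = [ChatMessage(type="human", content="What is the best food in Singapore?")]
#     async for event in generator(teams, "TravelExpertLeader", history):
#         print(event, end="")
#
# asyncio.run(demo())
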
# teams = {
#     "FoodExpertLeader": {
#         "name": "FoodExperts",
#         "members": {
#             "ChineseFoodExpert": {
#                 "type": "worker",
#                 "name": "ChineseFoodExpert",
#                 "backstory": "Studied culinary school in Singapore. Well-versed in hawker to fine-dining experiences. ISFP.",
#                 "role": "Provide Chinese food suggestions in Singapore",
#                 "tools": []
#             },
#             "MalayFoodExpert": {
#                 "type": "worker",
#                 "name": "MalayFoodExpert",
#                 "backstory": "Studied culinary school in Singapore. Well-versed in hawker to fine-dining experiences. INTP.",
#                 "role": "Provide Malay food suggestions in Singapore",
#                 "tools": []
#             },
#         }
#     },
#     "TravelExpertLeader": {
#         "name": "TravelKakis",
#         "members": {
#             "FoodExpertLeader": {
#                 "type": "leader",
#                 "name": "FoodExpertLeader",
#                 "role": "Gather inputs from your team and provide diverse food suggestions in Singapore.",
#                 "tools": []
#             },
#             "HistoryExpert": {
#                 "type": "worker",
#                 "name": "HistoryExpert",
#                 "backstory": "Studied Singapore history. Well-versed in Singapore architecture. INTJ.",
#                 "role": "Provide places to sight-see with a history/architecture angle",
#                 "tools": ["search"]
#             }
#         }
#     }
# }

# format_teams(teams)

# team_leader = "TravelExpertLeader"

# root = create_graph(teams, team_leader)

# messages = [
#     HumanMessage("What is the best food in Singapore")
# ]

# initial_state = {
#     "messages": messages,
#     "team_name": teams[team_leader]["name"],
#     "team_members": teams[team_leader]["members"],
# }

# async def main():
#     async for s in root.astream(initial_state):
#         if "__end__" not in s:
#             print(s)
#             print("----")

# import asyncio

# asyncio.run(main())