Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for uploading zip, tar.gz, and git repos as files #144

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 152 additions & 57 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from agency_swarm.tools import BaseTool, FileSearch, CodeInterpreter
from agency_swarm.user import User
from agency_swarm.util.files import determine_file_type
from agency_swarm.util.helpers import extract_zip, extract_tar, git_clone
from agency_swarm.util.helpers.file_upload_helpers import is_file_extension_supported
from agency_swarm.util.shared_state import SharedState
from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, FileSearchToolCall

from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, \
FileSearchToolCall

from agency_swarm.util.streaming import AgencyEventHandler

Expand Down Expand Up @@ -138,12 +140,12 @@ def get_completion(self, message: str,
Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread.
"""
res = self.main_thread.get_completion(message=message,
message_files=message_files,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice,
yield_messages=yield_messages)
message_files=message_files,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice,
yield_messages=yield_messages)

if not yield_messages:
while True:
Expand All @@ -154,7 +156,6 @@ def get_completion(self, message: str,

return res


def get_completion_stream(self,
message: str,
event_handler: type(AgencyEventHandler),
Expand Down Expand Up @@ -183,13 +184,13 @@ def get_completion_stream(self,
raise Exception("Event handler must not be an instance.")

res = self.main_thread.get_completion_stream(message=message,
message_files=message_files,
event_handler=event_handler,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice
)
message_files=message_files,
event_handler=event_handler,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice
)

while True:
try:
Expand Down Expand Up @@ -230,8 +231,16 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs):
images = []
message_file_names = None
uploading_files = False
cloning_files = False
recipient_agents = [agent.name for agent in self.main_recipients]
recipient_agent = self.main_recipients[0]
recipient_agent: Agent = self.main_recipients[0]

if isinstance(recipient_agent.files_folder, list):
to_folder = recipient_agent.files_folder[0]
else:
to_folder = recipient_agent.files_folder

os.makedirs(to_folder, exist_ok=True)

with gr.Blocks(js=js) as demo:
chatbot_queue = queue.Queue()
Expand All @@ -240,15 +249,76 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs):
with gr.Column(scale=9):
dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents,
value=recipient_agent.name)
msg = gr.Textbox(label="Your Message", lines=4)
msg = gr.Textbox(label="Your Message", lines=10)
with gr.Column(scale=1):
file_upload = gr.Files(label="OpenAI Files", type="filepath")
repo_clone = gr.Textbox(label="OpenAI GIT URL", lines=1)
button = gr.Button(value="Send", variant="primary")

def handle_dropdown_change(selected_option):
nonlocal recipient_agent
recipient_agent = self._get_agent_by_name(selected_option)

def handle_file_clone(repo_url):
nonlocal attachments
nonlocal message_file_names
nonlocal cloning_files
nonlocal images
cloning_files = True
attachments = []
message_file_names = []

try:
extracted_files = []
for file_path in git_clone(repo_url, to_folder):
if file_path.endswith('.zip'):
extracted_files.extend(extract_zip(file_path, to_folder))
elif file_path.endswith('.tar.gz'):
extracted_files.extend(extract_tar(file_path, to_folder))
else:
extracted_files.append(file_path)

print(f"Found {', '.join(extracted_files)}")

for file in extracted_files:
if is_file_extension_supported(file):
file_type = determine_file_type(file)
purpose = "assistants" if file_type != "vision" else "vision"
tools = [{
"type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [
{"type": "file_search"}]

with open(file, 'rb') as f:
try:
# Upload the file to OpenAI
uploaded_file = self.main_thread.client.files.create(
file=f,
purpose=purpose
)

if file_type == "vision":
images.append({
"type": "image_file",
"image_file": {"file_id": uploaded_file.id}
})
else:
attachments.append({
"file_id": uploaded_file.id,
"tools": tools
})

message_file_names.append(uploaded_file.filename)
print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}")
except Exception as e:
print(f"Uploading error: {e}")
return attachments
except Exception as e:
print(f"Error: {e}")
finally:
cloning_files = False
cloning_files = False
return "No files uploaded"

def handle_file_upload(file_list):
nonlocal attachments
nonlocal message_file_names
Expand All @@ -259,47 +329,65 @@ def handle_file_upload(file_list):
message_file_names = []
if file_list:
try:
extracted_files = []
for file_obj in file_list:
file_type = determine_file_type(file_obj.name)
purpose = "assistants" if file_type != "vision" else "vision"
tools = [{"type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [{"type": "file_search"}]

with open(file_obj.name, 'rb') as f:
# Upload the file to OpenAI
file = self.main_thread.client.files.create(
file=f,
purpose=purpose
)

if file_type == "vision":
images.append({
"type": "image_file",
"image_file": {"file_id": file.id}
})
else:
attachments.append({
"file_id": file.id,
"tools": tools
})
file_path = file_obj.name

message_file_names.append(file.filename)
print(f"Uploaded file ID: {file.id}")
if file_path.endswith('.zip'):
extracted_files.extend(extract_zip(file_path, to_folder))
elif file_path.endswith('.tar.gz'):
extracted_files.extend(extract_tar(file_path, to_folder))
else:
extracted_files.append(file_path)

print(f"Found {', '.join(extracted_files)}")

for file in extracted_files:
if is_file_extension_supported(file):
file_type = determine_file_type(file)
purpose = "assistants" if file_type != "vision" else "vision"
tools = [{
"type": "code_interpreter"}] if file_type == "assistants.code_interpreter" else [
{"type": "file_search"}]

with open(file, 'rb') as f:
try:
# Upload the file to OpenAI
uploaded_file = self.main_thread.client.files.create(
file=f,
purpose=purpose
)

if file_type == "vision":
images.append({
"type": "image_file",
"image_file": {"file_id": uploaded_file.id}
})
else:
attachments.append({
"file_id": uploaded_file.id,
"tools": tools
})

message_file_names.append(uploaded_file.filename)
print(f"Uploaded file ID: {uploaded_file.id}: {uploaded_file.filename}")
except Exception as e:
print(f"Uploading error: {e}")
return attachments
except Exception as e:
print(f"Error: {e}")
return str(e)
finally:
uploading_files = False

uploading_files = False
return "No files uploaded"

def user(user_message, history):
if not user_message.strip():
return user_message, history

nonlocal message_file_names
nonlocal uploading_files
nonlocal cloning_files
nonlocal images
nonlocal attachments
nonlocal recipient_agent
Expand All @@ -312,13 +400,15 @@ def check_and_add_tools_in_attachments(attachments, recipient_agent):
if not any(isinstance(t, FileSearch) for t in recipient_agent.tools):
# Add FileSearch tool if it does not exist
recipient_agent.tools.append(FileSearch)
recipient_agent.client.beta.assistants.update(recipient_agent.id, tools=recipient_agent.get_oai_tools())
recipient_agent.client.beta.assistants.update(recipient_agent.id,
tools=recipient_agent.get_oai_tools())
print("Added FileSearch tool to recipient agent to analyze the file.")
elif tool["type"] == "code_interpreter":
if not any(isinstance(t, CodeInterpreter) for t in recipient_agent.tools):
# Add CodeInterpreter tool if it does not exist
recipient_agent.tools.append(CodeInterpreter)
recipient_agent.client.beta.assistants.update(recipient_agent.id, tools=recipient_agent.get_oai_tools())
recipient_agent.client.beta.assistants.update(recipient_agent.id,
tools=recipient_agent.get_oai_tools())
print("Added CodeInterpreter tool to recipient agent to analyze the file.")
return None

Expand Down Expand Up @@ -361,7 +451,6 @@ def on_message_created(self, message: Message) -> None:
if content.type == "text":
full_content += content.text.value + "\n"


self.message_output = MessageOutput("text", self.agent_name, self.recipient_agent_name,
full_content)

Expand All @@ -381,7 +470,7 @@ def on_tool_call_created(self, tool_call: ToolCall):
if isinstance(tool_call, dict):
if "type" not in tool_call:
tool_call["type"] = "function"

if tool_call["type"] == "function":
tool_call = FunctionToolCall(**tool_call)
elif tool_call["type"] == "code_interpreter":
Expand All @@ -403,7 +492,7 @@ def on_tool_call_done(self, snapshot: ToolCall):
if isinstance(snapshot, dict):
if "type" not in snapshot:
snapshot["type"] = "function"

if snapshot["type"] == "function":
snapshot = FunctionToolCall(**snapshot)
elif snapshot["type"] == "code_interpreter":
Expand All @@ -412,7 +501,7 @@ def on_tool_call_done(self, snapshot: ToolCall):
snapshot = FileSearchToolCall(**snapshot)
else:
raise ValueError("Invalid tool call type: " + snapshot["type"])

self.message_output = None

# TODO: add support for code interpreter and retrieval tools
Expand Down Expand Up @@ -470,15 +559,21 @@ def bot(original_message, history):
nonlocal recipient_agent
nonlocal images
nonlocal uploading_files
nonlocal cloning_files

if uploading_files:
history.append([None, "Uploading files... Please wait."])
yield "", history
return "", history

if cloning_files:
history.append([None, "Cloning files... Please wait."])
yield "", history
return "", history

print("Message files: ", attachments)
print("Images: ", images)

if images and len(images) > 0:
original_message = [
{
Expand All @@ -488,7 +583,6 @@ def bot(original_message, history):
*images
]


completion_thread = threading.Thread(target=self.get_completion_stream, args=(
original_message, GradioEventHandler, [], recipient_agent, "", attachments, None))
completion_thread.start()
Expand All @@ -497,7 +591,7 @@ def bot(original_message, history):
message_file_names = []
images = []
uploading_files = False

cloning_files = False
new_message = True
while True:
try:
Expand Down Expand Up @@ -530,6 +624,7 @@ def bot(original_message, history):
)
dropdown.change(handle_dropdown_change, dropdown)
file_upload.change(handle_file_upload, file_upload)
repo_clone.change(handle_file_clone, repo_clone)
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, [msg, chatbot], [msg, chatbot]
)
Expand Down Expand Up @@ -605,7 +700,7 @@ def on_tool_call_created(self, tool_call):
if isinstance(tool_call, dict):
if "type" not in tool_call:
tool_call["type"] = "function"

if tool_call["type"] == "function":
tool_call = FunctionToolCall(**tool_call)
elif tool_call["type"] == "code_interpreter":
Expand All @@ -626,7 +721,7 @@ def on_tool_call_delta(self, delta, snapshot):
if isinstance(snapshot, dict):
if "type" not in snapshot:
snapshot["type"] = "function"

if snapshot["type"] == "function":
snapshot = FunctionToolCall(**snapshot)
elif snapshot["type"] == "code_interpreter":
Expand All @@ -635,7 +730,7 @@ def on_tool_call_delta(self, delta, snapshot):
snapshot = FileSearchToolCall(**snapshot)
else:
raise ValueError("Invalid tool call type: " + snapshot["type"])

self.message_output.cprint_update(str(snapshot.function))

@override
Expand Down Expand Up @@ -762,7 +857,7 @@ def _init_agents(self):
agent.max_completion_tokens = self.max_completion_tokens
if self.truncation_strategy is not None and agent.truncation_strategy is None:
agent.truncation_strategy = self.truncation_strategy

if not agent.shared_state:
agent.shared_state = self.shared_state

Expand Down
Loading