Skip to content

Commit 40d2d73

Browse files
authored
Merge pull request #17 from developmentseed/log
Log all the steps into aimstack
2 parents 973d57c + 1d5aee3 commit 40d2d73

File tree

6 files changed

+31
-11
lines changed

6 files changed

+31
-11
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,7 @@ cython_debug/
160160
#.idea/
161161

162162
# streamlit cache
163-
cache/
163+
cache/
164+
165+
# AIM experiment runs
166+
.aim/

agents/l4m_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def base_agent(
1818
llm=llm,
1919
tools=tools,
2020
agent=agent_type,
21-
max_iterations=2,
21+
max_iterations=5,
2222
early_stopping_method="generate",
2323
verbose=True,
2424
# memory=memory,

app.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import streamlit as st
66
from streamlit_folium import folium_static
77

8+
import langchain
89
from langchain.agents import AgentType
910
from langchain.chat_models import ChatOpenAI
1011
from langchain.tools import Tool, DuckDuckGoSearchRun
1112
from langchain.callbacks import (
1213
StreamlitCallbackHandler,
14+
AimCallbackHandler,
1315
get_openai_callback,
1416
)
1517

@@ -21,6 +23,9 @@
2123
from tools.stac.search import STACSearchTool
2224
from agents.l4m_agent import base_agent
2325

26+
# DEBUG
27+
langchain.debug = True
28+
2429

2530
@st.cache_resource(ttl="1h")
2631
def get_agent(
@@ -60,8 +65,6 @@ def get_agent(
6065

6166

6267
def run_query(agent, query):
63-
st_callback = StreamlitCallbackHandler(st.container())
64-
response = agent.run(query, callbacks=[st_callback])
6568
return response
6669

6770

@@ -117,7 +120,9 @@ def plot_vector(df):
117120
st.session_state.total_cost = 0
118121

119122
with st.sidebar:
120-
openai_api_key = st.text_input("OpenAI API Key", type="password")
123+
openai_api_key = os.getenv("OPENAI_API_KEY")
124+
if not openai_api_key:
125+
openai_api_key = st.text_input("OpenAI API Key", type="password")
121126

122127
st.subheader("OpenAI Usage")
123128
total_tokens = st.empty()
@@ -147,9 +152,18 @@ def plot_vector(df):
147152
st.info("Please add your OpenAI API key to continue.")
148153
st.stop()
149154

155+
aim_callback = AimCallbackHandler(
156+
repo=".",
157+
experiment_name="LLLLLM: Base Agent v0.1",
158+
)
159+
150160
agent = get_agent(openai_api_key)
161+
151162
with get_openai_callback() as cb:
152-
response = run_query(agent, prompt)
163+
st_callback = StreamlitCallbackHandler(st.container())
164+
response = agent.run(prompt, callbacks=[st_callback, aim_callback])
165+
166+
aim_callback.flush_tracker(langchain_asset=agent, reset=False, finish=True)
153167

154168
# Log OpenAI stats
155169
# print(f"Model name: {response.llm_output.get('model_name', '')}")

environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ dependencies:
1818
- streamlit==1.24.1
1919
- streamlit-folium==0.12.0
2020
- watchdog==3.0.0
21+
- aim==3.17.5

tools/geopy/geocode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class GeopyGeocodeTool(BaseTool):
2121
def _run(self, place: str) -> tuple:
2222
locator = Nominatim(user_agent="geocode")
2323
location = locator.geocode(place)
24+
if location is None:
25+
return ("geocode", "Not a recognised address in Nomatim.")
2426
return ("geocode", (location.latitude, location.longitude))
2527

2628
def _arun(self, place: str):

tools/osmnx/geometry.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77

88

99
class PlaceWithTags(BaseModel):
10-
"Name of a place and tags in OSM."
10+
"Name of a place on the map and tags in OSM."
1111

12-
place: str = Field(..., description="name of a place")
13-
tags: Dict[str, str] = Field(..., description="open street maps tags")
12+
place: str = Field(..., description="name of a place on the map.")
13+
tags: Dict[str, str] = Field(..., description="open street maps tags.")
1414

1515

1616
class OSMnxGeometryTool(BaseTool):
1717
"""Tool to query geometries from Open Street Map (OSM)."""
1818

1919
name: str = "geometry"
2020
args_schema: Type[BaseModel] = PlaceWithTags
21-
description: str = "Use this tool to get geometry of different features of a place like building footprints, parks, lakes, hospitals, schools etc. \
22-
Pass the name of the place & relevant tags of OSM as args."
21+
description: str = "Use this tool to get geometry of different features of the place like building footprints, parks, lakes, hospitals, schools etc. \
22+
Pass the name of the place & tags of OSM as args."
2323
return_direct = True
2424

2525
def _run(self, place: str, tags: Dict[str, str]) -> gpd.GeoDataFrame:

0 commit comments

Comments
 (0)