Skip to content

Commit b4e6123

Browse files
Python housekeeping (#188)
* Remove `DefaultOpenAIModel` from API. * Format `model.py`. * Include `spotipy` as what's effectively a devDependency for our examples. * Update version to `0.0.2` so we can publish. * Add explicit support for Python 3.12 in the `classifiers` list. * Fixed type errors in the `math` example. * Updates to `agents`, and `math` example to avoid weird `Unknown` issues. * fixup: remove unused import * Add type: ignore for `spotipy` import. * Remove type-checking errors. * Ensure that 'exit' and 'quit' work in process_requests. * Fix up math example, remove redundancy, document issues around JsonProgram in the math example. * Make the program validators/translators non-generic.
1 parent d67dad1 commit b4e6123

File tree

10 files changed

+118
-78
lines changed

10 files changed

+118
-78
lines changed

python/examples/math/demo.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
import asyncio
2+
from collections.abc import Sequence
23
import json
34
import sys
5+
from typing import cast
46
from dotenv import dotenv_values
57
import schema as math
68
from typechat import Failure, create_language_model, process_requests
7-
from program import TypeChatProgramTranslator, TypeChatProgramValidator, JsonProgram, evaluate_json_program # type: ignore
9+
from program import TypeChatProgramTranslator, TypeChatProgramValidator, evaluate_json_program
810

911
vals = dotenv_values()
1012
model = create_language_model(vals)
11-
validator = TypeChatProgramValidator(JsonProgram)
13+
validator = TypeChatProgramValidator()
1214
translator = TypeChatProgramTranslator(model, validator, math.MathAPI)
1315

1416

15-
async def apply_operations(func: str, args: list[int | float]) -> int | float:
17+
async def apply_operations(func: str, args: Sequence[object]) -> int | float:
1618
print(f"{func}({json.dumps(args)}) ")
19+
20+
for arg in args:
21+
if not isinstance(arg, (int, float)):
22+
raise ValueError("All arguments are expected to be numeric.")
23+
24+
args = cast(Sequence[int | float], args)
25+
1726
match func:
1827
case "add":
1928
return args[0] + args[1]
@@ -38,7 +47,7 @@ async def request_handler(message: str):
3847
else:
3948
result = result.value
4049
print(json.dumps(result, indent=2))
41-
math_result = await evaluate_json_program(result, apply_operations) # type: ignore
50+
math_result = await evaluate_json_program(result, apply_operations)
4251
print(f"Math Result: {math_result}")
4352

4453

python/examples/math/program.py

+60-47
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from __future__ import annotations
22
import asyncio
3+
from collections.abc import Sequence
34
import json
4-
from typing import TypedDict, no_type_check
5+
from typing import Any, TypeAlias, TypedDict, cast
56
from typing_extensions import (
67
TypeVar,
78
Callable,
89
Awaitable,
910
Annotated,
1011
NotRequired,
1112
override,
12-
Sequence,
1313
Doc,
14-
Any,
1514
)
1615

1716
from typechat import (
@@ -23,7 +22,6 @@
2322
TypeChatTranslator,
2423
python_type_to_typescript_schema,
2524
)
26-
import collections.abc
2725

2826
T = TypeVar("T", covariant=True)
2927

@@ -70,75 +68,90 @@
7068
},
7169
)
7270

73-
JsonValue = str | int | float | bool | None | dict[str, "Expression"] | list["Expression"]
74-
Expression = JsonValue | FunctionCall | ResultReference # type: ignore
71+
JsonValue: TypeAlias = str | int | float | bool | None | dict[str, "Expression"] | list["Expression"]
72+
Expression: TypeAlias = JsonValue | FunctionCall | ResultReference
7573

7674
JsonProgram = TypedDict("JsonProgram", {"@steps": list[FunctionCall]})
7775

78-
@no_type_check
7976
async def evaluate_json_program(
80-
program: JsonProgram, onCall: Callable[[str, Sequence[Expression]], Awaitable[Expression]]
81-
) -> Expression | Sequence[Expression]:
82-
results: list[Expression] | Expression = []
83-
84-
@no_type_check
85-
async def evaluate_array(array: Sequence[Expression]) -> Sequence[Expression]:
86-
return await asyncio.gather(*[evaluate_call(e) for e in array])
87-
88-
@no_type_check
89-
async def evaluate_object(expr: FunctionCall):
90-
if "@ref" in expr:
91-
index = expr["@ref"]
92-
if index < len(results):
93-
return results[index]
94-
95-
elif "@func" in expr and "@args" in expr:
96-
function_name = expr["@func"]
97-
return await onCall(function_name, await evaluate_array(expr["@args"]))
98-
99-
elif isinstance(expr, collections.abc.Sequence):
100-
return await evaluate_array(expr)
101-
102-
else:
103-
raise ValueError("This condition should never hit")
104-
105-
@no_type_check
106-
async def evaluate_call(expr: FunctionCall) -> Expression | Sequence[Expression]:
107-
if isinstance(expr, int) or isinstance(expr, float) or isinstance(expr, str):
108-
return expr
109-
return await evaluate_object(expr)
77+
program: JsonProgram,
78+
onCall: Callable[[str, Sequence[object]], Awaitable[JsonValue]]
79+
) -> Expression:
80+
results: list[JsonValue] = []
81+
82+
def evaluate_array(array: Sequence[Expression]) -> Awaitable[list[Expression]]:
83+
return asyncio.gather(*[evaluate_expression(e) for e in array])
84+
85+
async def evaluate_expression(expr: Expression) -> JsonValue:
86+
match expr:
87+
case bool() | int() | float() | str() | None:
88+
return expr
89+
90+
case { "@ref": int(index) } if not isinstance(index, bool):
91+
if 0 <= index < len(results):
92+
return results[index]
93+
94+
raise ValueError(f"Index {index} is out of range [0, {len(results)})")
95+
96+
case { "@ref": ref_value }:
97+
raise ValueError(f"'ref' value must be an integer, but was ${ref_value}")
98+
99+
case { "@func": str(function_name) }:
100+
args: list[Expression]
101+
match expr:
102+
case { "@args": None }:
103+
args = []
104+
case { "@args": list() }:
105+
args = cast(list[Expression], expr["@args"]) # TODO
106+
case { "@args": _ }:
107+
raise ValueError("Given an invalid value for '@args'.")
108+
case _:
109+
args = []
110+
111+
return await onCall(function_name, await evaluate_array(args))
112+
113+
case list(array_expression_elements):
114+
return await evaluate_array(array_expression_elements)
115+
116+
case _:
117+
raise ValueError("This condition should never hit")
110118

111119
for step in program["@steps"]:
112-
results.append(await evaluate_call(step))
120+
results.append(await evaluate_expression(step))
113121

114122
if len(results) > 0:
115123
return results[-1]
116124
else:
117125
return None
118126

119127

120-
class TypeChatProgramValidator(TypeChatValidator[T]):
121-
def __init__(self, py_type: type[T]):
122-
# the base class init method creates a typeAdapter for T. This operation fails for the JsonProgram type
123-
super().__init__(py_type=Any) # type: ignore
128+
class TypeChatProgramValidator(TypeChatValidator[JsonProgram]):
129+
def __init__(self):
130+
# TODO: This example should eventually be updated to use Python 3.12 type aliases
131+
# Passing in `JsonProgram` for `py_type` would cause issues because
132+
# Pydantic's `TypeAdapter` ends up trying to eagerly construct an
133+
# anonymous recursive type. Even a NewType does not work here.
134+
# For now, we just pass in `Any` in place of `JsonProgram`.
135+
super().__init__(py_type=cast(type[JsonProgram], Any))
124136

125137
@override
126-
def validate(self, json_text: str) -> Result[T]:
138+
def validate(self, json_text: str) -> Result[JsonProgram]:
127139
# Pydantic is not able to validate JsonProgram instances. It fails with a recursion error.
128-
# For JsonProgram, simply validate that it has a non-zero number of @steps
140+
# For JsonProgram, so we simply validate that it has a non-zero number of `@steps`.
129141
# TODO: extend validations
130142
typed_dict = json.loads(json_text)
131-
if "@steps" in typed_dict and isinstance(typed_dict["@steps"], collections.abc.Sequence):
143+
if "@steps" in typed_dict and isinstance(typed_dict["@steps"], Sequence):
132144
return Success(typed_dict)
133145
else:
134146
return Failure("This is not a valid program. The program must have an array of @steps")
135147

136148

137-
class TypeChatProgramTranslator(TypeChatTranslator[T]):
149+
class TypeChatProgramTranslator(TypeChatTranslator[JsonProgram]):
138150
_api_declaration_str: str
139151

140-
def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator[T], api_type: type):
152+
def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator, api_type: type):
141153
super().__init__(model=model, validator=validator, target_type=api_type)
154+
# TODO: the conversion result here has errors!
142155
conversion_result = python_type_to_typescript_schema(api_type)
143156
self._api_declaration_str = conversion_result.typescript_schema_str
144157

python/examples/multiSchema/agents.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from collections.abc import Sequence
12
import os
23
import sys
4+
from typing import cast
35

46
examples_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
57
if examples_path not in sys.path:
@@ -14,8 +16,7 @@
1416
from examples.math.program import (
1517
TypeChatProgramTranslator,
1618
TypeChatProgramValidator,
17-
evaluate_json_program, # type: ignore
18-
JsonProgram,
19+
evaluate_json_program,
1920
)
2021

2122
import examples.music.schema as music_schema
@@ -43,16 +44,23 @@ async def handle_request(self, line: str):
4344

4445

4546
class MathAgent:
46-
_validator: TypeChatProgramValidator[JsonProgram]
47-
_translator: TypeChatProgramTranslator[JsonProgram]
47+
_validator: TypeChatProgramValidator
48+
_translator: TypeChatProgramTranslator
4849

4950
def __init__(self, model: TypeChatModel):
5051
super().__init__()
51-
self._validator = TypeChatProgramValidator(JsonProgram)
52+
self._validator = TypeChatProgramValidator()
5253
self._translator = TypeChatProgramTranslator(model, self._validator, math_schema.MathAPI)
5354

54-
async def _handle_json_program_call(self, func: str, args: list[int | float]) -> int | float:
55+
async def _handle_json_program_call(self, func: str, args: Sequence[object]) -> int | float:
5556
print(f"{func}({json.dumps(args)}) ")
57+
58+
for arg in args:
59+
if not isinstance(arg, (int, float)):
60+
raise ValueError("All arguments are expected to be numeric.")
61+
62+
args = cast(Sequence[int | float], args)
63+
5664
match func:
5765
case "add":
5866
return args[0] + args[1]

python/notebooks/math.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"source": [
3535
"from dotenv import dotenv_values\n",
3636
"from typechat import Failure, create_language_model\n",
37-
"from examples.math.program import TypeChatProgramTranslator, TypeChatProgramValidator, JsonProgram, evaluate_json_program\n",
37+
"from examples.math.program import TypeChatProgramTranslator, TypeChatProgramValidator, evaluate_json_program\n",
3838
"from examples.math import schema as math"
3939
]
4040
},
@@ -46,7 +46,7 @@
4646
"source": [
4747
"vals = dotenv_values()\n",
4848
"model = create_language_model(vals)\n",
49-
"validator = TypeChatProgramValidator(JsonProgram)\n",
49+
"validator = TypeChatProgramValidator()\n",
5050
"translator = TypeChatProgramTranslator(model, validator, math.MathAPI)"
5151
]
5252
},

python/pyproject.toml

+3-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ classifiers = [
1717
"Development Status :: 4 - Beta",
1818
"Programming Language :: Python",
1919
"Programming Language :: Python :: 3.11",
20+
"Programming Language :: Python :: 3.12",
2021
"Programming Language :: Python :: Implementation :: CPython",
2122
"Programming Language :: Python :: Implementation :: PyPy",
2223
]
@@ -41,8 +42,10 @@ dependencies = [
4142
"openai>=1.3.6",
4243
"python-dotenv>=1.0.0",
4344
"pytest",
45+
"spotipy", # for examples
4446
"typing_extensions",
4547
]
48+
4649
[tool.hatch.envs.default.scripts]
4750
test = "pytest {args:tests}"
4851
test-cov = "coverage run -m pytest {args:tests}"
@@ -55,11 +58,6 @@ cov = [
5558
"cov-report",
5659
]
5760

58-
[tool.hatch.envs.examples]
59-
extra-dependencies = [
60-
"spotipy",
61-
]
62-
6361
[[tool.hatch.envs.all.matrix]]
6462
python = ["3.11"]
6563

python/src/typechat/__about__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# SPDX-FileCopyrightText: Microsoft Corporation
22
#
33
# SPDX-License-Identifier: MIT
4-
__version__ = "0.0.1"
4+
__version__ = "0.0.2"

python/src/typechat/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5-
from typechat._internal.model import DefaultOpenAIModel, TypeChatModel, create_language_model
5+
from typechat._internal.model import TypeChatModel, create_language_model
66
from typechat._internal.result import Failure, Result, Success
77
from typechat._internal.translator import TypeChatTranslator
88
from typechat._internal.ts_conversion import python_type_to_typescript_schema
99
from typechat._internal.validator import TypeChatValidator
1010
from typechat._internal.interactive import process_requests
1111

1212
__all__ = [
13-
"DefaultOpenAIModel",
1413
"TypeChatModel",
1514
"TypeChatTranslator",
1615
"TypeChatValidator",

python/src/typechat/_internal/interactive.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ async def process_requests(interactive_prompt: str, input_file_name: str | None,
1313
else:
1414
print(interactive_prompt, end="", flush=True)
1515
for line in sys.stdin:
16-
lower_line = line.lower()
16+
lower_line = line.lower().strip()
1717
if lower_line == "quit" or lower_line == "exit":
1818
break
1919
else:

python/src/typechat/_internal/model.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,23 @@ async def complete(self, input: str) -> Result[str]:
3535
return Failure(str(e))
3636

3737
def create_language_model(vals: dict[str,str|None]) -> TypeChatModel:
38-
model:TypeChatModel
38+
model: TypeChatModel
3939
client: openai.AsyncOpenAI | openai.AsyncAzureOpenAI
40-
40+
4141
if "OPENAI_API_KEY" in vals:
4242
client = openai.AsyncOpenAI(api_key=vals["OPENAI_API_KEY"])
4343
model = DefaultOpenAIModel(model_name=vals.get("OPENAI_MODEL", None) or "gpt-35-turbo", client=client)
4444

4545
elif "AZURE_OPENAI_API_KEY" in vals and "AZURE_OPENAI_ENDPOINT" in vals:
4646
os.environ["OPENAI_API_TYPE"] = "azure"
47-
client=openai.AsyncAzureOpenAI(azure_endpoint=vals.get("AZURE_OPENAI_ENDPOINT",None) or "", api_key=vals["AZURE_OPENAI_API_KEY"],api_version="2023-03-15-preview")
47+
client = openai.AsyncAzureOpenAI(
48+
azure_endpoint=vals.get("AZURE_OPENAI_ENDPOINT", None) or "",
49+
api_key=vals["AZURE_OPENAI_API_KEY"],
50+
api_version="2023-03-15-preview",
51+
)
4852
model = DefaultOpenAIModel(model_name=vals.get("AZURE_OPENAI_MODEL", None) or "gpt-35-turbo", client=client)
49-
53+
5054
else:
5155
raise ValueError("Missing environment variables for Open AI or Azure OpenAI model")
52-
53-
return model
56+
57+
return model

0 commit comments

Comments
 (0)