|
1 | 1 | from __future__ import annotations
|
2 | 2 | import asyncio
|
| 3 | +from collections.abc import Sequence |
3 | 4 | import json
|
4 |
| -from typing import TypedDict, no_type_check |
| 5 | +from typing import Any, TypeAlias, TypedDict, cast |
5 | 6 | from typing_extensions import (
|
6 | 7 | TypeVar,
|
7 | 8 | Callable,
|
8 | 9 | Awaitable,
|
9 | 10 | Annotated,
|
10 | 11 | NotRequired,
|
11 | 12 | override,
|
12 |
| - Sequence, |
13 | 13 | Doc,
|
14 |
| - Any, |
15 | 14 | )
|
16 | 15 |
|
17 | 16 | from typechat import (
|
|
23 | 22 | TypeChatTranslator,
|
24 | 23 | python_type_to_typescript_schema,
|
25 | 24 | )
|
26 |
| -import collections.abc |
27 | 25 |
|
28 | 26 | T = TypeVar("T", covariant=True)
|
29 | 27 |
|
|
70 | 68 | },
|
71 | 69 | )
|
72 | 70 |
|
73 |
| -JsonValue = str | int | float | bool | None | dict[str, "Expression"] | list["Expression"] |
74 |
| -Expression = JsonValue | FunctionCall | ResultReference # type: ignore |
| 71 | +JsonValue: TypeAlias = str | int | float | bool | None | dict[str, "Expression"] | list["Expression"] |
| 72 | +Expression: TypeAlias = JsonValue | FunctionCall | ResultReference |
75 | 73 |
|
76 | 74 | JsonProgram = TypedDict("JsonProgram", {"@steps": list[FunctionCall]})
|
77 | 75 |
|
78 |
| -@no_type_check |
79 | 76 | async def evaluate_json_program(
|
80 |
| - program: JsonProgram, onCall: Callable[[str, Sequence[Expression]], Awaitable[Expression]] |
81 |
| -) -> Expression | Sequence[Expression]: |
82 |
| - results: list[Expression] | Expression = [] |
83 |
| - |
84 |
| - @no_type_check |
85 |
| - async def evaluate_array(array: Sequence[Expression]) -> Sequence[Expression]: |
86 |
| - return await asyncio.gather(*[evaluate_call(e) for e in array]) |
87 |
| - |
88 |
| - @no_type_check |
89 |
| - async def evaluate_object(expr: FunctionCall): |
90 |
| - if "@ref" in expr: |
91 |
| - index = expr["@ref"] |
92 |
| - if index < len(results): |
93 |
| - return results[index] |
94 |
| - |
95 |
| - elif "@func" in expr and "@args" in expr: |
96 |
| - function_name = expr["@func"] |
97 |
| - return await onCall(function_name, await evaluate_array(expr["@args"])) |
98 |
| - |
99 |
| - elif isinstance(expr, collections.abc.Sequence): |
100 |
| - return await evaluate_array(expr) |
101 |
| - |
102 |
| - else: |
103 |
| - raise ValueError("This condition should never hit") |
104 |
| - |
105 |
| - @no_type_check |
106 |
| - async def evaluate_call(expr: FunctionCall) -> Expression | Sequence[Expression]: |
107 |
| - if isinstance(expr, int) or isinstance(expr, float) or isinstance(expr, str): |
108 |
| - return expr |
109 |
| - return await evaluate_object(expr) |
| 77 | + program: JsonProgram, |
| 78 | + onCall: Callable[[str, Sequence[object]], Awaitable[JsonValue]] |
| 79 | +) -> Expression: |
| 80 | + results: list[JsonValue] = [] |
| 81 | + |
| 82 | + def evaluate_array(array: Sequence[Expression]) -> Awaitable[list[Expression]]: |
| 83 | + return asyncio.gather(*[evaluate_expression(e) for e in array]) |
| 84 | + |
| 85 | + async def evaluate_expression(expr: Expression) -> JsonValue: |
| 86 | + match expr: |
| 87 | + case bool() | int() | float() | str() | None: |
| 88 | + return expr |
| 89 | + |
| 90 | + case { "@ref": int(index) } if not isinstance(index, bool): |
| 91 | + if 0 <= index < len(results): |
| 92 | + return results[index] |
| 93 | + |
| 94 | + raise ValueError(f"Index {index} is out of range [0, {len(results)})") |
| 95 | + |
| 96 | + case { "@ref": ref_value }: |
| 97 | + raise ValueError(f"'ref' value must be an integer, but was ${ref_value}") |
| 98 | + |
| 99 | + case { "@func": str(function_name) }: |
| 100 | + args: list[Expression] |
| 101 | + match expr: |
| 102 | + case { "@args": None }: |
| 103 | + args = [] |
| 104 | + case { "@args": list() }: |
| 105 | + args = cast(list[Expression], expr["@args"]) # TODO |
| 106 | + case { "@args": _ }: |
| 107 | + raise ValueError("Given an invalid value for '@args'.") |
| 108 | + case _: |
| 109 | + args = [] |
| 110 | + |
| 111 | + return await onCall(function_name, await evaluate_array(args)) |
| 112 | + |
| 113 | + case list(array_expression_elements): |
| 114 | + return await evaluate_array(array_expression_elements) |
| 115 | + |
| 116 | + case _: |
| 117 | + raise ValueError("This condition should never hit") |
110 | 118 |
|
111 | 119 | for step in program["@steps"]:
|
112 |
| - results.append(await evaluate_call(step)) |
| 120 | + results.append(await evaluate_expression(step)) |
113 | 121 |
|
114 | 122 | if len(results) > 0:
|
115 | 123 | return results[-1]
|
116 | 124 | else:
|
117 | 125 | return None
|
118 | 126 |
|
119 | 127 |
|
120 |
| -class TypeChatProgramValidator(TypeChatValidator[T]): |
121 |
| - def __init__(self, py_type: type[T]): |
122 |
| - # the base class init method creates a typeAdapter for T. This operation fails for the JsonProgram type |
123 |
| - super().__init__(py_type=Any) # type: ignore |
| 128 | +class TypeChatProgramValidator(TypeChatValidator[JsonProgram]): |
| 129 | + def __init__(self): |
| 130 | + # TODO: This example should eventually be updated to use Python 3.12 type aliases |
| 131 | + # Passing in `JsonProgram` for `py_type` would cause issues because |
| 132 | + # Pydantic's `TypeAdapter` ends up trying to eagerly construct an |
| 133 | + # anonymous recursive type. Even a NewType does not work here. |
| 134 | + # For now, we just pass in `Any` in place of `JsonProgram`. |
| 135 | + super().__init__(py_type=cast(type[JsonProgram], Any)) |
124 | 136 |
|
125 | 137 | @override
|
126 |
| - def validate(self, json_text: str) -> Result[T]: |
| 138 | + def validate(self, json_text: str) -> Result[JsonProgram]: |
127 | 139 | # Pydantic is not able to validate JsonProgram instances. It fails with a recursion error.
|
128 |
| - # For JsonProgram, simply validate that it has a non-zero number of @steps |
| 140 | + # For JsonProgram, so we simply validate that it has a non-zero number of `@steps`. |
129 | 141 | # TODO: extend validations
|
130 | 142 | typed_dict = json.loads(json_text)
|
131 |
| - if "@steps" in typed_dict and isinstance(typed_dict["@steps"], collections.abc.Sequence): |
| 143 | + if "@steps" in typed_dict and isinstance(typed_dict["@steps"], Sequence): |
132 | 144 | return Success(typed_dict)
|
133 | 145 | else:
|
134 | 146 | return Failure("This is not a valid program. The program must have an array of @steps")
|
135 | 147 |
|
136 | 148 |
|
137 |
| -class TypeChatProgramTranslator(TypeChatTranslator[T]): |
| 149 | +class TypeChatProgramTranslator(TypeChatTranslator[JsonProgram]): |
138 | 150 | _api_declaration_str: str
|
139 | 151 |
|
140 |
| - def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator[T], api_type: type): |
| 152 | + def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator, api_type: type): |
141 | 153 | super().__init__(model=model, validator=validator, target_type=api_type)
|
| 154 | + # TODO: the conversion result here has errors! |
142 | 155 | conversion_result = python_type_to_typescript_schema(api_type)
|
143 | 156 | self._api_declaration_str = conversion_result.typescript_schema_str
|
144 | 157 |
|
|
0 commit comments