From 563a4837f278168dab141d6a8eb3449b6af9b99e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 18 Mar 2025 19:05:54 -0700 Subject: [PATCH 01/11] Annotations prototype, refs #716 --- llm/__init__.py | 4 + llm/cli.py | 438 +++++++++++++++++---------- llm/default_plugins/openai_models.py | 66 +++- llm/migrations.py | 15 + llm/models.py | 101 +++++- 5 files changed, 463 insertions(+), 161 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index ba7d3a12..cc409ba8 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -4,11 +4,13 @@ NeedsKeyException, ) from .models import ( + Annotation, AsyncConversation, AsyncKeyModel, AsyncModel, AsyncResponse, Attachment, + Chunk, Conversation, EmbeddingModel, EmbeddingModelWithAliases, @@ -31,10 +33,12 @@ import struct __all__ = [ + "Annotation", "AsyncConversation", "AsyncKeyModel", "AsyncResponse", "Attachment", + "Chunk", "Collection", "Conversation", "get_async_model", diff --git a/llm/cli.py b/llm/cli.py index fb04dcd4..5f6cf9e0 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -11,6 +11,7 @@ AsyncConversation, AsyncKeyModel, AsyncResponse, + Chunk, Collection, Conversation, Response, @@ -1092,40 +1093,48 @@ def logs_list( sql_format["extra_where"] = where_ + " and ".join(where_bits) final_sql = sql.format(**sql_format) - rows = list( - db.query( - final_sql, - { - "model": model_id, - "query": query, - "conversation_id": conversation_id, - "schema_id": schema_id, - "id_gt": id_gt, - "id_gte": id_gte, - }, - ) - ) - # Reverse the order - we do this because we 'order by id desc limit 3' to get the - # 3 most recent results, but we still want to display them in chronological order - # ... except for searches where we don't do this + # Fetch the rows from the database + query_params = { + "model": model_id, + "query": query, + "conversation_id": conversation_id, + "schema_id": schema_id, + "id_gt": id_gt, + "id_gte": id_gte, + } + + # Instead of processing the rows directly, use Response.from_row() + responses = [] + for row in db.query(final_sql, query_params): + try: + response_obj = Response.from_row(db, row) + responses.append(response_obj) + except Exception as e: + click.echo( + f"Warning: Error processing row {row.get('id')}: {str(e)}", err=True + ) + raise + + # Reverse the order if not a search query to display in chronological order if not query and not data: - rows.reverse() + responses.reverse() # Fetch any attachments - ids = [row["id"] for row in rows] + ids = [response.id for response in responses] attachments = list(db.query(ATTACHMENTS_SQL.format(",".join("?" * len(ids))), ids)) attachments_by_id = {} for attachment in attachments: attachments_by_id.setdefault(attachment["response_id"], []).append(attachment) + # Process the data options if data or data_array or data_key or data_ids: # Special case for --data to output valid JSON to_output = [] - for row in rows: - response = row["response"] or "" + for response_obj in responses: + response_text = response_obj.text() try: - decoded = json.loads(response) + decoded = json.loads(response_text) new_items = [] if ( isinstance(decoded, dict) @@ -1138,160 +1147,275 @@ def logs_list( new_items.append(decoded) if data_ids: for item in new_items: - item[find_unused_key(item, "response_id")] = row["id"] - item[find_unused_key(item, "conversation_id")] = row["id"] + item[find_unused_key(item, "response_id")] = response_obj.id + item[find_unused_key(item, "conversation_id")] = ( + response_obj.conversation.id + if response_obj.conversation + else None + ) to_output.extend(new_items) except ValueError: pass click.echo(output_rows_as_json(to_output, not data_array)) return - for row in rows: - if truncate: - row["prompt"] = _truncate_string(row["prompt"]) - row["response"] = _truncate_string(row["response"]) - # Either decode or remove all JSON keys - keys = list(row.keys()) - for key in keys: - if key.endswith("_json") and row[key] is not None: - if truncate: - del row[key] - else: - row[key] = json.loads(row[key]) + # Handle extraction options + if extract or extract_last: + # Extract and return first code block + for response_obj in responses: + output = extract_fenced_code_block(response_obj.text(), last=extract_last) + if output is not None: + click.echo(output) + return + return + + # Handle response-only output + if response: + # Just output the last response + if responses: + click.echo(responses[-1].text()) + return - output = None + # Process JSON output if json_output: - # Output as JSON if requested - for row in rows: - row["attachments"] = [ + output_list = [] + for response_obj in responses: + response_dict = { + "id": response_obj.id, + "model": response_obj.model.model_id if response_obj.model else None, + "prompt": response_obj.prompt.prompt, + "system": response_obj.prompt.system, + "response": response_obj.text(), + "conversation_id": ( + response_obj.conversation.id if response_obj.conversation else None + ), + "datetime_utc": response_obj.datetime_utc(), + "input_tokens": response_obj.input_tokens, + "output_tokens": response_obj.output_tokens, + "token_details": response_obj.token_details, + } + + # Add prompt_json and response_json if available + if hasattr(response_obj, "_prompt_json") and response_obj._prompt_json: + response_dict["prompt_json"] = response_obj._prompt_json + if hasattr(response_obj, "response_json") and response_obj.response_json: + response_dict["response_json"] = response_obj.response_json + + # Add conversation name and model if conversation exists + if response_obj.conversation: + response_dict["conversation_name"] = response_obj.conversation.name + response_dict["conversation_model"] = ( + response_obj.conversation.model.model_id + if response_obj.conversation.model + else None + ) + + # Add attachments + response_dict["attachments"] = [ {k: v for k, v in attachment.items() if k != "response_id"} - for attachment in attachments_by_id.get(row["id"], []) + for attachment in attachments_by_id.get(response_obj.id, []) ] - output = json.dumps(list(rows), indent=2) - elif extract or extract_last: - # Extract and return first code block - for row in rows: - output = extract_fenced_code_block(row["response"], last=extract_last) - if output is not None: - break - elif response: - # Just output the last response - if rows: - output = rows[-1]["response"] - if output is not None: - click.echo(output) - else: - # Output neatly formatted human-readable logs - current_system = None - should_show_conversation = True - for row in rows: - if short: - system = _truncate_string(row["system"], 120, end=True) - prompt = _truncate_string(row["prompt"], 120, end=True) - cid = row["conversation_id"] - attachments = attachments_by_id.get(row["id"]) - obj = { - "model": row["model"], - "datetime": row["datetime_utc"].split(".")[0], - "conversation": cid, + output_list.append(response_dict) + + click.echo(json.dumps(output_list, indent=2)) + return + + # Handle the regular output format + current_system = None + should_show_conversation = True + for response_obj in responses: + if short: + system = ( + _truncate_string(response_obj.prompt.system, 120, end=True) + if response_obj.prompt.system + else None + ) + prompt = ( + _truncate_string(response_obj.prompt.prompt, 120, end=True) + if response_obj.prompt.prompt + else None + ) + cid = response_obj.conversation.id if response_obj.conversation else None + response_attachments = attachments_by_id.get(response_obj.id, []) + + obj = { + "model": response_obj.model.model_id if response_obj.model else None, + "datetime": response_obj.datetime_utc().split(".")[0], + "conversation": cid, + } + if system: + obj["system"] = system + if prompt: + obj["prompt"] = prompt + if response_attachments: + items = [] + for attachment in response_attachments: + details = {"type": attachment["type"]} + if attachment.get("path"): + details["path"] = attachment["path"] + if attachment.get("url"): + details["url"] = attachment["url"] + items.append(details) + obj["attachments"] = items + if usage and (response_obj.input_tokens or response_obj.output_tokens): + usage_details = { + "input": response_obj.input_tokens, + "output": response_obj.output_tokens, } - if system: - obj["system"] = system - if prompt: - obj["prompt"] = prompt - if attachments: - items = [] - for attachment in attachments: - details = {"type": attachment["type"]} - if attachment.get("path"): - details["path"] = attachment["path"] - if attachment.get("url"): - details["url"] = attachment["url"] - items.append(details) - obj["attachments"] = items - if usage and (row["input_tokens"] or row["output_tokens"]): - usage_details = { - "input": row["input_tokens"], - "output": row["output_tokens"], - } - if row["token_details"]: - usage_details["details"] = json.loads(row["token_details"]) - obj["usage"] = usage_details - click.echo(yaml.dump([obj], sort_keys=False).strip()) - continue + if response_obj.token_details: + usage_details["details"] = response_obj.token_details + obj["usage"] = usage_details + click.echo(yaml.dump([obj], sort_keys=False).strip()) + continue + + # Full output format + click.echo( + "# {}{}\n{}".format( + response_obj.datetime_utc().split(".")[0], + ( + " conversation: {} id: {}".format( + ( + response_obj.conversation.id + if response_obj.conversation + else None + ), + response_obj.id, + ) + if should_show_conversation + else "" + ), + ( + "\nModel: **{}**\n".format( + response_obj.model.model_id if response_obj.model else None + ) + if should_show_conversation + else "" + ), + ) + ) + # In conversation log mode only show it for the first one + if conversation_id: + should_show_conversation = False + click.echo("## Prompt\n\n{}".format(response_obj.prompt.prompt or "-- none --")) + if response_obj.prompt.system != current_system: + if response_obj.prompt.system is not None: + click.echo("\n## System\n\n{}".format(response_obj.prompt.system)) + current_system = response_obj.prompt.system + + # Handle schema if present + if response_obj.prompt.schema: click.echo( - "# {}{}\n{}".format( - row["datetime_utc"].split(".")[0], - ( - " conversation: {} id: {}".format( - row["conversation_id"], row["id"] - ) - if should_show_conversation - else "" - ), - ( - "\nModel: **{}**\n".format(row["model"]) - if should_show_conversation - else "" - ), + "\n## Schema\n\n```json\n{}\n```".format( + json.dumps(response_obj.prompt.schema, indent=2) ) ) - # In conversation log mode only show it for the first one - if conversation_id: - should_show_conversation = False - click.echo("## Prompt\n\n{}".format(row["prompt"] or "-- none --")) - if row["system"] != current_system: - if row["system"] is not None: - click.echo("\n## System\n\n{}".format(row["system"])) - current_system = row["system"] - if row["schema_json"]: - click.echo( - "\n## Schema\n\n```json\n{}\n```".format( - json.dumps(row["schema_json"], indent=2) - ) - ) - attachments = attachments_by_id.get(row["id"]) - if attachments: - click.echo("\n### Attachments\n") - for i, attachment in enumerate(attachments, 1): - if attachment["path"]: - path = attachment["path"] - click.echo( - "{}. **{}**: `{}`".format(i, attachment["type"], path) - ) - elif attachment["url"]: - click.echo( - "{}. **{}**: {}".format( - i, attachment["type"], attachment["url"] - ) + + # Handle attachments + response_attachments = attachments_by_id.get(response_obj.id, []) + if response_attachments: + click.echo("\n### Attachments\n") + for i, attachment in enumerate(response_attachments, 1): + if attachment["path"]: + path = attachment["path"] + click.echo("{}. **{}**: `{}`".format(i, attachment["type"], path)) + elif attachment["url"]: + click.echo( + "{}. **{}**: {}".format( + i, attachment["type"], attachment["url"] ) - elif attachment["content_length"]: - click.echo( - "{}. **{}**: `<{} bytes>`".format( - i, - attachment["type"], - f"{attachment['content_length']:,}", - ) + ) + elif attachment["content_length"]: + click.echo( + "{}. **{}**: `<{} bytes>`".format( + i, + attachment["type"], + f"{attachment['content_length']:,}", ) + ) - # If a schema was provided and the row is valid JSON, pretty print and syntax highlight it - response = row["response"] - if row["schema_json"]: - try: - parsed = json.loads(response) - response = "```json\n{}\n```".format(json.dumps(parsed, indent=2)) - except ValueError: - pass - click.echo("\n## Response\n\n{}\n".format(response)) - if usage: - token_usage = token_usage_string( - row["input_tokens"], - row["output_tokens"], - json.loads(row["token_details"]) if row["token_details"] else None, - ) - if token_usage: - click.echo("## Token usage:\n\n{}\n".format(token_usage)) + # Handle response output + response_text = response_obj.text() + # If a schema was provided and the row is valid JSON, pretty print and syntax highlight it + if response_obj.prompt.schema: + try: + parsed = json.loads(response_text) + response_text = "```json\n{}\n```".format(json.dumps(parsed, indent=2)) + except ValueError: + pass + + if response_obj.annotations: + response_text = format_chunks(response_obj.chunks()) + + click.echo("\n## Response\n\n{}\n".format(response_text)) + + # Handle token usage + if usage and (response_obj.input_tokens or response_obj.output_tokens): + token_usage = token_usage_string( + response_obj.input_tokens, + response_obj.output_tokens, + response_obj.token_details, + ) + if token_usage: + click.echo("## Token usage:\n\n{}\n".format(token_usage)) + + +def format_chunks(chunks: Iterable[Chunk]) -> str: + """ + Format a list of Chunk objects into a structured text with annotations. + + Args: + chunks: A list of Chunk objects containing text and metadata + + Returns: + A formatted string with annotations listed at the end + """ + result = "" + annotations = [] + annotation_index = 1 + + # First pass to collect all text and mark positions for annotations + combined_text = "" + annotation_positions = [] + + for chunk in chunks: + # Add the chunk text to the combined text + start_pos = len(combined_text) + combined_text += chunk.text + end_pos = len(combined_text) + + # If chunk has data, record its position and data + if chunk.data: + annotation_positions.append( + { + "start": start_pos, + "end": end_pos, + "data": chunk.data, + "index": annotation_index, + } + ) + annotation_index += 1 + + # Second pass to build the result with annotation markers + result = combined_text + + # Sort annotations in reverse order to avoid messing up indices when inserting + for pos in sorted(annotation_positions, key=lambda x: x["end"], reverse=True): + # Store the annotation data + annotation_data = textwrap.indent(json.dumps(pos["data"], indent=2), " ") + annotations.append(f" [{pos['index']}]\n{annotation_data}") + + # Insert the annotation marker at the end of the chunk + annotation_marker = f" 「{result[pos['start']:pos['end']]}」[{pos['index']}]" + result = result[: pos["end"]] + annotation_marker + result[pos["end"] :] + + annotations.reverse() + + # Add annotations section if there are any + if annotations: + result += "\n\n### Annotations\n\n" + "\n\n".join(annotations) + + return result @cli.group( diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 9bc37e68..136e63c7 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -116,6 +116,13 @@ def register_models(register): aliases=("3.5-instruct", "chatgpt-instruct"), ) + # Search models + for model_id in ("gpt-4o-search-preview", "gpt-4o-mini-search-preview"): + register( + Chat(model_id, search_preview=True), + AsyncChat(model_id, search_preview=True), + ) + # Load extra models extra_path = llm.user_dir() / "extra-openai-models.yaml" if not extra_path.exists(): @@ -351,7 +358,7 @@ def validate_logit_bias(cls, logit_bias): return validated_logit_bias -class ReasoningEffortEnum(str, Enum): +class LowMediumHighEnum(str, Enum): low = "low" medium = "medium" high = "high" @@ -362,7 +369,7 @@ class OptionsForReasoning(SharedOptions): description="Output a valid JSON object {...}. Prompt must mention JSON.", default=None, ) - reasoning_effort: Optional[ReasoningEffortEnum] = Field( + reasoning_effort: Optional[LowMediumHighEnum] = Field( description=( "Constraints effort on reasoning for reasoning models. Currently supported " "values are low, medium, and high. Reducing reasoning effort can result in " @@ -372,6 +379,15 @@ class OptionsForReasoning(SharedOptions): ) +class OptionsForSearchPreview(SharedOptions): + search_context_size: Optional[LowMediumHighEnum] = Field( + description=( + "How much context is retrieved from the web to help the tool formulate a response" + ), + default=None, + ) + + def _attachment(attachment): url = attachment.url base64_content = "" @@ -418,6 +434,7 @@ def __init__( reasoning=False, supports_schema=False, allows_system_prompt=True, + search_preview=False, ): self.model_id = model_id self.key = key @@ -431,12 +448,16 @@ def __init__( self.can_stream = can_stream self.vision = vision self.allows_system_prompt = allows_system_prompt + self.search_preview = search_preview self.attachment_types = set() if reasoning: self.Options = OptionsForReasoning + if search_preview: + self.Options = OptionsForSearchPreview + if vision: self.attachment_types.update( { @@ -511,6 +532,22 @@ def set_usage(self, response, usage): input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage) ) + def set_annotations(self, response, annotations: list): + # Annotation(type='url_citation', url_citation=AnnotationURLCitation( + # end_index=358, start_index=284, title='...', url='https://...')) + to_add = [] + for annotation in annotations: + if annotation["type"] == "url_citation": + data = annotation["url_citation"] + start_index = data.pop("start_index") + end_index = data.pop("end_index") + to_add.append( + llm.Annotation( + start_index=start_index, end_index=end_index, data=data + ) + ) + response.add_annotations(to_add) + def get_client(self, key, *, async_=False): kwargs = {} if self.api_base: @@ -550,6 +587,13 @@ def build_kwargs(self, prompt, stream): } if stream: kwargs["stream_options"] = {"include_usage": True} + if self.search_preview: + kwargs["web_search_options"] = {} + if prompt.options.search_context_size: + kwargs.pop("search_context_size", None) + kwargs["web_search_options"][ + "search_context_size" + ] = prompt.options.search_context_size return kwargs @@ -571,6 +615,7 @@ def execute(self, prompt, stream, response, conversation=None, key=None): kwargs = self.build_kwargs(prompt, stream) client = self.get_client(key) usage = None + annotations = [] if stream: completion = client.chat.completions.create( model=self.model_name or self.model_id, @@ -581,6 +626,10 @@ def execute(self, prompt, stream, response, conversation=None, key=None): chunks = [] for chunk in completion: chunks.append(chunk) + try: + annotations.extend(chunk.choices[0].delta.annotations) + except (AttributeError, IndexError): + pass if chunk.usage: usage = chunk.usage.model_dump() try: @@ -589,7 +638,11 @@ def execute(self, prompt, stream, response, conversation=None, key=None): content = None if content is not None: yield content - response.response_json = remove_dict_none_values(combine_chunks(chunks)) + final_json = remove_dict_none_values(combine_chunks(chunks)) + if annotations: + final_json["annotations"] = annotations + self.set_annotations(response, annotations) + response.response_json = final_json else: completion = client.chat.completions.create( model=self.model_name or self.model_id, @@ -600,6 +653,13 @@ def execute(self, prompt, stream, response, conversation=None, key=None): usage = completion.usage.model_dump() response.response_json = remove_dict_none_values(completion.model_dump()) yield completion.choices[0].message.content + try: + if completion.choices[0].message.annotations: + self.set_annotations( + response, completion.choices[0].message.annotations + ) + except AttributeError: + pass self.set_usage(response, usage) response._prompt_json = redact_data({"messages": messages}) diff --git a/llm/migrations.py b/llm/migrations.py index 0b93188f..82791687 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -255,3 +255,18 @@ def m014_schemas(db): db["responses"].enable_fts( ["prompt", "response"], create_triggers=True, replace=True ) + + +@migration +def m015_response_annotations(db): + db["response_annotations"].create( + { + "id": int, + "response_id": str, + "start_index": int, + "end_index": int, + "data": str, + }, + pk="id", + foreign_keys=(("response_id", "responses", "id"),), + ) diff --git a/llm/models.py b/llm/models.py index a008a3e4..4be36b60 100644 --- a/llm/models.py +++ b/llm/models.py @@ -234,11 +234,36 @@ def __repr__(self): return f"<{self.__class__.__name__}: {self.id} - {count} response{s}" +class Annotation(BaseModel): + start_index: int + end_index: int + data: dict + + @classmethod + def from_row(cls, row): + return cls( + start_index=row["start_index"], + end_index=row["end_index"], + data=json.loads(row["data"]), + ) + + +class Chunk(BaseModel): + text: str + data: Optional[dict] + start_index: int + end_index: int + + def __str__(self): + return self.text + + class _BaseResponse: """Base response class shared between sync and async responses""" prompt: "Prompt" stream: bool + _annotations: Optional[List[Annotation]] = None conversation: Optional["_BaseConversation"] = None _key: Optional[str] = None @@ -255,6 +280,7 @@ def __init__( self.model = model self.stream = stream self._key = key + self._annotations = [] self._chunks: List[str] = [] self._done = False self.response_json = None @@ -282,6 +308,13 @@ def set_usage( self.output_tokens = output self.token_details = details + def add_annotations(self, annotations: List[Annotation]): + self._annotations.extend(annotations) + + @property + def annotations(self): + return self._annotations or [] + @classmethod def from_row(cls, db, row, _async=False): from llm import get_model, get_async_model @@ -293,7 +326,8 @@ def from_row(cls, db, row, _async=False): # Schema schema = None - if row["schema_id"]: + schema_id = row.get("schema_id") + if schema_id: schema = json.loads(db["schemas"].get(row["schema_id"])["content"]) response = cls( @@ -326,8 +360,61 @@ def from_row(cls, db, row, _async=False): [row["id"]], ) ] + # Annotations + response._annotations = [ + Annotation.from_row(arow) + for arow in db.query( + """ + select id, start_index, end_index, data + from response_annotations + where response_id = ? + order by start_index + """, + [row["id"]], + ) + ] return response + def chunks(self) -> Iterator[str]: + return self.chunks_from_text(self.text()) + + # iterates over chunks of text, so an iterator of Chunk + def chunks_from_text(self, text) -> Iterator[Chunk]: + annotations = sorted(self.annotations, key=lambda a: a.start_index) + + current_index = 0 + + for annotation in annotations: + # If there's a gap before this annotation, yield a gap chunk + if current_index < annotation.start_index: + gap_text = text[current_index : annotation.start_index] + yield Chunk( + text=gap_text, + data=None, + start_index=current_index, + end_index=annotation.start_index, + ) + + # Yield the chunk for this annotation + chunk_text = text[annotation.start_index : annotation.end_index] + yield Chunk( + text=chunk_text, + data=annotation.data, + start_index=annotation.start_index, + end_index=annotation.end_index, + ) + + current_index = annotation.end_index + + # If there's text after the last annotation, yield a final gap chunk + if current_index < len(text): + yield Chunk( + text=text[current_index:], + data=None, + start_index=current_index, + end_index=len(text), + ) + def token_usage(self) -> str: return token_usage_string( self.input_tokens, self.output_tokens, self.token_details @@ -377,6 +464,18 @@ def log_to_db(self, db): "schema_id": schema_id, } db["responses"].insert(response) + + if self.annotations: + db["response_annotations"].insert_all( + { + "response_id": response_id, + "start_index": annotation.start_index, + "end_index": annotation.end_index, + "data": json.dumps(annotation.data), + } + for annotation in self.annotations + ) + # Persist any attachments - loop through with index for index, attachment in enumerate(self.prompt.attachments): attachment_id = attachment.id() From 90634b98c51676256b289cf99326f32cf76b6c04 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 18 Mar 2025 19:26:52 -0700 Subject: [PATCH 02/11] Rename chunk.data to chunk.annotation Refs https://github.com/simonw/llm/issues/716#issuecomment-2735161066 --- llm/cli.py | 5 ++--- llm/models.py | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 5f6cf9e0..025649bb 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1384,13 +1384,12 @@ def format_chunks(chunks: Iterable[Chunk]) -> str: combined_text += chunk.text end_pos = len(combined_text) - # If chunk has data, record its position and data - if chunk.data: + if chunk.annotation: annotation_positions.append( { "start": start_pos, "end": end_pos, - "data": chunk.data, + "data": chunk.annotation, "index": annotation_index, } ) diff --git a/llm/models.py b/llm/models.py index 4be36b60..b7795722 100644 --- a/llm/models.py +++ b/llm/models.py @@ -250,7 +250,7 @@ def from_row(cls, row): class Chunk(BaseModel): text: str - data: Optional[dict] + annotation: Optional[dict] start_index: int end_index: int @@ -390,7 +390,7 @@ def chunks_from_text(self, text) -> Iterator[Chunk]: gap_text = text[current_index : annotation.start_index] yield Chunk( text=gap_text, - data=None, + annotation=None, start_index=current_index, end_index=annotation.start_index, ) @@ -399,7 +399,7 @@ def chunks_from_text(self, text) -> Iterator[Chunk]: chunk_text = text[annotation.start_index : annotation.end_index] yield Chunk( text=chunk_text, - data=annotation.data, + annotation=annotation.data, start_index=annotation.start_index, end_index=annotation.end_index, ) @@ -410,7 +410,7 @@ def chunks_from_text(self, text) -> Iterator[Chunk]: if current_index < len(text): yield Chunk( text=text[current_index:], - data=None, + annotation=None, start_index=current_index, end_index=len(text), ) From 4d8aacf220002fc499006e425c19ef43fae10636 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 19 Mar 2025 21:43:17 -0700 Subject: [PATCH 03/11] Model feature list for advanced plugins documentation !stable-docs --- docs/plugins/advanced-model-plugins.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index 0db0dbdb..dd9e5665 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -1,9 +1,15 @@ (advanced-model-plugins)= # Advanced model plugins -The {ref}`model plugin tutorial ` covers the basics of developing a plugin that adds support for a new model. +The {ref}`model plugin tutorial ` covers the basics of developing a plugin that adds support for a new model. This document covers more advanced topics. -This document covers more advanced topics. +Features to consider for your model plugin include: + +- {ref}`Accepting API keys ` using the standard mechanism that incorporates `llm keys set`, environment variables and support for passing an explicit key to the model. +- Including support for {ref}`Async models ` that can be used with Python's `asyncio` library. +- Support for {ref}`structured output ` using JSON schemas. +- Handling {ref}`attachments ` (images, audio and more) for multi-modal models. +- Tracking {ref}`token usage ` for models that charge by the token. (advanced-model-plugins-api-keys)= From 43ccbb7f92828550e48373e4c3840c26e111d144 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 19 Mar 2025 22:00:27 -0700 Subject: [PATCH 04/11] Advanced plugin docs for supporting annotations Refs #716 - describes a yield llm.Chunk() mechanism that does not yet exist. --- docs/openai-models.md | 2 ++ docs/plugins/advanced-model-plugins.md | 50 ++++++++++++++++++++++++++ docs/usage.md | 28 +++++++++++++++ 3 files changed, 80 insertions(+) diff --git a/docs/openai-models.md b/docs/openai-models.md index 4125033f..4850bfe5 100644 --- a/docs/openai-models.md +++ b/docs/openai-models.md @@ -55,6 +55,8 @@ OpenAI Chat: o1-preview OpenAI Chat: o1-mini OpenAI Chat: o3-mini OpenAI Completion: gpt-3.5-turbo-instruct (aliases: 3.5-instruct, chatgpt-instruct) +OpenAI Chat: gpt-4o-search-preview +OpenAI Chat: gpt-4o-mini-search-preview ``` diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index dd9e5665..820f4ea9 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -9,6 +9,7 @@ Features to consider for your model plugin include: - Including support for {ref}`Async models ` that can be used with Python's `asyncio` library. - Support for {ref}`structured output ` using JSON schemas. - Handling {ref}`attachments ` (images, audio and more) for multi-modal models. +- Supporting {ref}`annotations ` for models that return different types of text, or objects that should be attached to sections of the response. - Tracking {ref}`token usage ` for models that charge by the token. (advanced-model-plugins-api-keys)= @@ -243,3 +244,52 @@ This example logs 15 input tokens, 340 output tokens and notes that 37 tokens we ```python response.set_usage(input=15, output=340, details={"cached": 37}) ``` + +(advanced-model-plugins-annotations)= + +## Models that return annotations + +Some models may return additional structured data to accompany their text output. LLM calls these **annotations**. Common use-cases for these include: + +- Reasoning models that return a portion of text representing "thinking" tokens prior to the main response. +- Models that return structured citation information attached to portions of the text. +- Similarly, some search models return references to search reults used to generate the response. + +Model plugins can return these annotations directly from their `execute()` method. This method usually yields a series of strings - to attach a citation to one of these strings, return a `Chunk` object instead: + +```python +from llm import Chunk + +... + # Inside the execute() method: + yield llm.Chunk( + text="This has an annotation", + annotation={ + "title": "Document title", + "url": "https://example.com/document", + } + ) +``` +The `annotation=` must be a dictionary but can take any shape. LLM will automatically record the annotation with the start and end index of the generated text that it is attached to. + +Some annotations may need to be attached to a point in the document without a separate end index. In this case the `text=` parameter should be set to `None`. + +Models may exist that do not return their annotations as part of the general stream but instead produce them at the end of the response, specifying start and end indexes to show which parts of the text they should be attached to. This is often the case for non-streaming APIs. + +For these cases the `response.add_annotations()` method should be used at the end of the `.execute()` method: + +```python +response.add_annotations([ + llm.Annotation( + start_index=0, + end_index=10, + data={ + "title": "Document title", + "url": "https://example.com/document" + } + ) +]) +``` +The method accepts a list of `llm.Annotation` objects, each with a `start_index=`, `end_index=` and `data=` dictionary describing the annotation. + +For annotations that are attached to a point rather than a range the `start_index=` and `end_index=` should be the same integer value. \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md index 16fe0e19..d418bfb7 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -742,6 +742,34 @@ OpenAI Completion: gpt-3.5-turbo-instruct (aliases: 3.5-instruct, chatgpt-instru Include the log probabilities of most likely N per token Features: - streaming +OpenAI Chat: gpt-4o-search-preview + Options: + temperature: float + max_tokens: int + top_p: float + frequency_penalty: float + presence_penalty: float + stop: str + logit_bias: dict, str + seed: int + search_context_size: str + Features: + - streaming + - async +OpenAI Chat: gpt-4o-mini-search-preview + Options: + temperature: float + max_tokens: int + top_p: float + frequency_penalty: float + presence_penalty: float + stop: str + logit_bias: dict, str + seed: int + search_context_size: str + Features: + - streaming + - async ``` From 236c808666260d0497cb43bb59e2c601e70d9887 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 21 Mar 2025 15:05:53 -0700 Subject: [PATCH 05/11] Update .excute() signature to allow str or Chunk, refs #716 --- llm/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llm/models.py b/llm/models.py index b7795722..1a395999 100644 --- a/llm/models.py +++ b/llm/models.py @@ -858,7 +858,7 @@ def execute( stream: bool, response: Response, conversation: Optional[Conversation], - ) -> Iterator[str]: + ) -> Iterator[Union[str, Chunk]]: pass @@ -871,7 +871,7 @@ def execute( response: Response, conversation: Optional[Conversation], key: Optional[str], - ) -> Iterator[str]: + ) -> Iterator[Union[str, Chunk]]: pass @@ -914,7 +914,7 @@ async def execute( stream: bool, response: AsyncResponse, conversation: Optional[AsyncConversation], - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[Union[str, Chunk], None]: yield "" @@ -927,7 +927,7 @@ async def execute( response: AsyncResponse, conversation: Optional[AsyncConversation], key: Optional[str], - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[Union[str, Chunk], None]: yield "" From 2ce2510945382af203745d28a01e6c04359b029a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 22 Mar 2025 22:01:08 -0700 Subject: [PATCH 06/11] Various mypy fixes relating to Union[Chunk, str] - refs #716 --- llm/models.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/llm/models.py b/llm/models.py index 1a395999..10617297 100644 --- a/llm/models.py +++ b/llm/models.py @@ -263,7 +263,7 @@ class _BaseResponse: prompt: "Prompt" stream: bool - _annotations: Optional[List[Annotation]] = None + _annotations: List[Annotation] = field(default_factory=list) conversation: Optional["_BaseConversation"] = None _key: Optional[str] = None @@ -280,8 +280,8 @@ def __init__( self.model = model self.stream = stream self._key = key - self._annotations = [] - self._chunks: List[str] = [] + self._annotations: List[Annotation] = [] + self._chunks: List[Union[Chunk, str]] = [] self._done = False self.response_json = None self.conversation = conversation @@ -375,9 +375,6 @@ def from_row(cls, db, row, _async=False): ] return response - def chunks(self) -> Iterator[str]: - return self.chunks_from_text(self.text()) - # iterates over chunks of text, so an iterator of Chunk def chunks_from_text(self, text) -> Iterator[Chunk]: annotations = sorted(self.annotations, key=lambda a: a.start_index) @@ -502,6 +499,9 @@ class Response(_BaseResponse): model: "Model" conversation: Optional["Conversation"] = None + def chunks(self) -> Iterator[Chunk]: + return self.chunks_from_text(self.text()) + def on_done(self, callback): if not self._done: self.done_callbacks.append(callback) @@ -521,7 +521,7 @@ def _force(self): def text(self) -> str: self._force() - return "".join(self._chunks) + return "".join(map(str, self._chunks)) def text_or_raise(self) -> str: return self.text() @@ -546,7 +546,7 @@ def usage(self) -> Usage: details=self.token_details, ) - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[Union[Chunk, str]]: self._start = time.monotonic() self._start_utcnow = datetime.datetime.now(datetime.timezone.utc) if self._done: @@ -592,6 +592,9 @@ class AsyncResponse(_BaseResponse): model: "AsyncModel" conversation: Optional["AsyncConversation"] = None + async def chunks(self) -> Iterator[Chunk]: + return self.chunks_from_text(await self.text()) + @classmethod def from_row(cls, db, row, _async=False): return super().from_row(db, row, _async=True) @@ -617,7 +620,7 @@ def __aiter__(self): self._start_utcnow = datetime.datetime.now(datetime.timezone.utc) return self - async def __anext__(self) -> str: + async def __anext__(self) -> Union[Chunk, str]: if self._done: if not self._chunks: raise StopAsyncIteration @@ -666,11 +669,11 @@ async def _force(self): def text_or_raise(self) -> str: if not self._done: raise ValueError("Response not yet awaited") - return "".join(self._chunks) + return "".join(map(str, self._chunks)) async def text(self) -> str: await self._force() - return "".join(self._chunks) + return "".join(map(str, self._chunks)) async def json(self) -> Optional[Dict[str, Any]]: await self._force() From 42491b6db086a20429fbe6f9cb3dc9f0cd7d8bcb Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 22 Mar 2025 22:34:14 -0700 Subject: [PATCH 07/11] Added llm.examples with example plugins Refs https://github.com/simonw/llm/pull/847#issuecomment-2746026240 --- llm/examples.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 llm/examples.py diff --git a/llm/examples.py b/llm/examples.py new file mode 100644 index 00000000..a2f01ebc --- /dev/null +++ b/llm/examples.py @@ -0,0 +1,61 @@ +import llm +import random +from typing import AsyncGenerator, Union + + +def build_markov_table(text): + words = text.split() + transitions = {} + # Loop through all but the last word + for i in range(len(words) - 1): + word = words[i] + next_word = words[i + 1] + transitions.setdefault(word, []).append(next_word) + return transitions + + +def generate(transitions, length, start_word=None): + all_words = list(transitions.keys()) + next_word = start_word or random.choice(all_words) + for i in range(length): + yield next_word + options = transitions.get(next_word) or all_words + next_word = random.choice(options) + + +class Markov(llm.Model): + model_id = "markov" + + def execute(self, prompt, stream, response, conversation): + text = prompt.prompt + transitions = build_markov_table(text) + for word in generate(transitions, 20): + yield word + " " + + +class AnnotationsModel(llm.Model): + model_id = "annotations" + can_stream = True + + def execute(self, prompt, stream, response, conversation): + yield "Here is text before the annotation. " + yield llm.Chunk( + text="This is the annotated text. ", + annotation={"title": "Annotation Title", "content": "Annotation Content"}, + ) + yield "Here is text after the annotation." + + +class AnnotationsModelAsync(llm.AsyncModel): + model_id = "annotations" + can_stream = True + + async def execute( + self, prompt, stream, response, conversation=None + ) -> AsyncGenerator[Union[llm.Chunk, str], None]: + yield "Here is text before the annotation. " + yield llm.Chunk( + text="This is the annotated text. ", + annotation={"title": "Annotation Title", "content": "Annotation Content"}, + ) + yield "Here is text after the annotation." From d7cd630ac55c84b824534f5663eb956910c45e25 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 22 Mar 2025 22:35:56 -0700 Subject: [PATCH 08/11] Now tracking start/end index for llm.Chunk from .execute() Also printing those out in streaming mode for non-async models, as a debug thing --- docs/plugins/advanced-model-plugins.md | 4 +-- llm/cli.py | 3 ++ llm/models.py | 41 +++++++++++++++++--------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index 820f4ea9..328cb1e1 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -59,7 +59,7 @@ class MyAsyncModel(llm.AsyncModel): async def execute( self, prompt, stream, response, conversation=None - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[Union[llm.Chunk, str], None]: if stream: completion = await client.chat.completions.create( model=self.model_id, @@ -83,7 +83,7 @@ class MyAsyncModel(llm.AsyncKeyModel): ... async def execute( self, prompt, stream, response, conversation=None, key=None - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[Union[llm.Chunk, str], None]: ``` diff --git a/llm/cli.py b/llm/cli.py index 91955a07..2c39b624 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -10,6 +10,7 @@ AsyncConversation, AsyncKeyModel, AsyncResponse, + Chunk, Collection, Conversation, Response, @@ -561,6 +562,8 @@ async def inner(): ) if should_stream: for chunk in response: + if isinstance(chunk, Chunk) and chunk.annotation: + print(chunk.annotation) print(chunk, end="") sys.stdout.flush() print("") diff --git a/llm/models.py b/llm/models.py index 10617297..206ab6d9 100644 --- a/llm/models.py +++ b/llm/models.py @@ -28,7 +28,7 @@ ) from abc import ABC, abstractmethod import json -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from ulid import ULID CONVERSATION_NAME_LENGTH = 32 @@ -250,9 +250,9 @@ def from_row(cls, row): class Chunk(BaseModel): text: str - annotation: Optional[dict] - start_index: int - end_index: int + annotation: Dict[str, Any] = Field(default_factory=dict) + start_index: Optional[int] = None + end_index: Optional[int] = None def __str__(self): return self.text @@ -387,7 +387,7 @@ def chunks_from_text(self, text) -> Iterator[Chunk]: gap_text = text[current_index : annotation.start_index] yield Chunk( text=gap_text, - annotation=None, + annotation={}, start_index=current_index, end_index=annotation.start_index, ) @@ -407,7 +407,7 @@ def chunks_from_text(self, text) -> Iterator[Chunk]: if current_index < len(text): yield Chunk( text=text[current_index:], - annotation=None, + annotation={}, start_index=current_index, end_index=len(text), ) @@ -554,26 +554,32 @@ def __iter__(self) -> Iterator[Union[Chunk, str]]: return if isinstance(self.model, Model): - for chunk in self.model.execute( + chunk_iter = self.model.execute( self.prompt, stream=self.stream, response=self, conversation=self.conversation, - ): - yield chunk - self._chunks.append(chunk) + ) elif isinstance(self.model, KeyModel): - for chunk in self.model.execute( + chunk_iter = self.model.execute( self.prompt, stream=self.stream, response=self, conversation=self.conversation, key=self.model.get_key(self._key), - ): - yield chunk - self._chunks.append(chunk) + ) else: raise Exception("self.model must be a Model or KeyModel") + index = 0 + for chunk in chunk_iter: + if isinstance(chunk, Chunk): + chunk.start_index = index + index += len(chunk.text) + chunk.end_index = index + else: + index += len(chunk) + yield chunk + self._chunks.append(chunk) if self.conversation: self.conversation.responses.append(self) @@ -618,6 +624,7 @@ async def _on_done(self): def __aiter__(self): self._start = time.monotonic() self._start_utcnow = datetime.datetime.now(datetime.timezone.utc) + self._generator_index = 0 return self async def __anext__(self) -> Union[Chunk, str]: @@ -650,6 +657,12 @@ async def __anext__(self) -> Union[Chunk, str]: try: chunk = await self._generator.__anext__() + if isinstance(chunk, Chunk): + chunk.start_index = self._generator_index + self._generator_index += len(chunk.text) + chunk.end_index = self._generator_index + else: + self._generator_index += len(chunk) self._chunks.append(chunk) return chunk except StopAsyncIteration: From bc9ba5b4e5af3c931b94e64550317ff7fe25ce9f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 25 Mar 2025 19:28:10 -0700 Subject: [PATCH 09/11] Extra OpenAI docs including mention of PDFs, closes #834 --- docs/openai-models.md | 9 +++++++++ docs/usage.md | 1 + 2 files changed, 10 insertions(+) diff --git a/docs/openai-models.md b/docs/openai-models.md index e45cbbb9..731b0ad3 100644 --- a/docs/openai-models.md +++ b/docs/openai-models.md @@ -66,6 +66,15 @@ See [the OpenAI models documentation](https://platform.openai.com/docs/models) f [o1-pro](https://platform.openai.com/docs/models/o1-pro) is not available through the Chat Completions API used by LLM's default OpenAI plugin. You can install the new [llm-openai-plugin](https://github.com/simonw/llm-openai-plugin) plugin to access that model. +## Model features + +The following features work with OpenAI models: + +- {ref}`System prompts ` can be used to provide instructions that have a higher weight than the prompt itself. +- {ref}`Attachments `. Many OpenAI models support image inputs - check which ones using `llm models --options`. Any model that accepts images can also accept PDFs. +- {ref}`Schemas ` can be used to influence the JSON structure of the model output. +- {ref}`Model options ` can be used to set parameters like `temperature`. Use `llm models --options` for a full list of supported options. + (openai-models-embedding)= ## OpenAI embedding models diff --git a/docs/usage.md b/docs/usage.md index 26754daf..08703227 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -45,6 +45,7 @@ Will run a prompt of: ``` For models that support them, {ref}`system prompts ` are a better tool for this kind of prompting. +(usage-model-options)= ### Model options Some models support options. You can pass these using `-o/--option name value` - for example, to set the temperature to 1.5 run this: From 2dfd5e167cab3c5111986ec7f7f588da2be22893 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 25 Mar 2025 19:29:58 -0700 Subject: [PATCH 10/11] Link to two more blog entries !stable-docs --- docs/index.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/index.md b/docs/index.md index 26165749..59c043cd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,12 +17,10 @@ Here's a [YouTube video demo](https://www.youtube.com/watch?v=QUXQNi6jQ30) and [ Background on this project: - [llm, ttok and strip-tags—CLI tools for working with ChatGPT and other LLMs](https://simonwillison.net/2023/May/18/cli-tools-for-llms/) - [The LLM CLI tool now supports self-hosted language models via plugins](https://simonwillison.net/2023/Jul/12/llm/) -- [Accessing Llama 2 from the command-line with the llm-replicate plugin](https://simonwillison.net/2023/Jul/18/accessing-llama-2/) -- [Run Llama 2 on your own Mac using LLM and Homebrew](https://simonwillison.net/2023/Aug/1/llama-2-mac/) -- [Catching up on the weird world of LLMs](https://simonwillison.net/2023/Aug/3/weird-world-of-llms/) - [LLM now provides tools for working with embeddings](https://simonwillison.net/2023/Sep/4/llm-embeddings/) - [Build an image search engine with llm-clip, chat with models with llm chat](https://simonwillison.net/2023/Sep/12/llm-clip-and-chat/) -- [Many options for running Mistral models in your terminal using LLM](https://simonwillison.net/2023/Dec/18/mistral/) +- [You can now run prompts against images, audio and video in your terminal using LLM](https://simonwillison.net/2024/Oct/29/llm-multi-modal/) +- [Structured data extraction from unstructured content using LLM schemas](https://simonwillison.net/2025/Feb/28/llm-schemas/) For more check out [the llm tag](https://simonwillison.net/tags/llm/) on my blog. From 2217945d345e67160409ceb9927a87f53cd47ffe Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 27 Mar 2025 20:36:42 -0700 Subject: [PATCH 11/11] Allow -t to take a URL to a template, closes #856 --- docs/templates.md | 5 +++++ llm/cli.py | 36 +++++++++++++++++++++++------------- tests/conftest.py | 25 +++++++++++++++++++++++++ tests/test_templates.py | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 13 deletions(-) diff --git a/docs/templates.md b/docs/templates.md index 2d6cd43f..db20a45c 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -59,6 +59,11 @@ This can be combined with the `-m` option to specify a different model: curl -s https://llm.datasette.io/en/latest/ | \ llm -t summarize -m gpt-3.5-turbo-16k ``` +Templates can also be specified as full URLs to YAML files: +```bash +llm -t https://raw.githubusercontent.com/simonw/llm-templates/refs/heads/main/python-app.yaml \ + 'Python app to pick a random line from a file' +``` (prompt-templates-list)= diff --git a/llm/cli.py b/llm/cli.py index 2c39b624..80b22476 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -2527,7 +2527,28 @@ def logs_db_path(): return user_dir() / "logs.db" +def _parse_yaml_template(name, content): + try: + loaded = yaml.safe_load(content) + except yaml.YAMLError as ex: + raise click.ClickException("Invalid YAML: {}".format(str(ex))) + if isinstance(loaded, str): + return Template(name=name, prompt=loaded) + loaded["name"] = name + try: + return Template(**loaded) + except pydantic.ValidationError as ex: + msg = "A validation error occurred:\n" + msg += render_errors(ex.errors()) + raise click.ClickException(msg) + + def load_template(name): + if name.startswith("https://") or name.startswith("http://"): + response = httpx.get(name) + response.raise_for_status() + return _parse_yaml_template(name, response.text) + if ":" in name: prefix, rest = name.split(":", 1) loaders = get_template_loaders() @@ -2544,19 +2565,8 @@ def load_template(name): path = template_dir() / f"{name}.yaml" if not path.exists(): raise click.ClickException(f"Invalid template: {name}") - try: - loaded = yaml.safe_load(path.read_text()) - except yaml.YAMLError as ex: - raise click.ClickException("Invalid YAML: {}".format(str(ex))) - if isinstance(loaded, str): - return Template(name=name, prompt=loaded) - loaded["name"] = name - try: - return Template(**loaded) - except pydantic.ValidationError as ex: - msg = "A validation error occurred:\n" - msg += render_errors(ex.errors()) - raise click.ClickException(msg) + content = path.read_text() + return _parse_yaml_template(name, content) def get_history(chat_id): diff --git a/tests/conftest.py b/tests/conftest.py index 4469f3aa..cd053dac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,6 +82,15 @@ def execute(self, prompt, stream, response, conversation): ) +class EchoModel(llm.Model): + model_id = "echo" + + def execute(self, prompt, stream, response, conversation): + yield "system:\n{}\n\nprompt:\n{}".format( + prompt.system or "", prompt.prompt or "" + ) + + class MockKeyModel(llm.KeyModel): model_id = "mock_key" needs_key = "mock" @@ -207,6 +216,22 @@ def register_models(self, register): pm.unregister(name="undo-mock-models-plugin") +@pytest.fixture(autouse=True) +def register_echo_model(): + class EchoModelPlugin: + __name__ = "EchoModelPlugin" + + @llm.hookimpl + def register_models(self, register): + register(EchoModel()) + + pm.register(EchoModelPlugin(), name="undo-EchoModelPlugin") + try: + yield + finally: + pm.unregister(name="undo-EchoModelPlugin") + + @pytest.fixture def mocked_openai_chat(httpx_mock): httpx_mock.add_response( diff --git a/tests/test_templates.py b/tests/test_templates.py index 33163dbb..a9000467 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -294,3 +294,35 @@ def test_execute_prompt_with_a_template( assert result.exit_code == 1 assert result.output.strip() == expected_error mocked_openai_chat.reset() + + +@pytest.mark.parametrize( + "template,expected", + ( + ("system: system\nprompt: prompt", "system:\nsystem\n\nprompt:\nprompt"), + ( + "prompt: |\n This is\n ```\n code to extract\n ```", + "system:\n\n\nprompt:\nThis is\n```\ncode to extract\n```", + ), + # Now try that with extract: true + ( + "extract: true\nprompt: |\n This is\n ```\n code to extract\n ```", + "code to extract", + ), + ), +) +def test_execute_prompt_from_template_url(httpx_mock, template, expected): + httpx_mock.add_response( + url="https://example.com/prompt.yaml", + method="GET", + text=template, + status_code=200, + ) + runner = CliRunner() + result = runner.invoke( + cli, + ["-t", "https://example.com/prompt.yaml", "-m", "echo"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert result.output.strip() == expected.strip()