Skip to content

Commit

Permalink
Fix llama index CI (mlflow#14115)
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 authored Dec 19, 2024
1 parent 103b143 commit 6ad165e
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions tests/llama_index/test_llama_index_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ def test_trace_llm_complete_stream():
assert attr["model_dict"]["model"] == model_name


def _get_llm_input_content_json(content):
if Version(llama_index.core.__version__) >= Version("0.12.5"):
# in llama-index >= 0.12.5, the input content json format is changed to
# {"blocks": {"block_type": "text", "text": <content>} }
return {
"blocks": [
{
"block_type": "text",
"text": content,
}
]
}
return {"content": content}


@pytest.mark.parametrize("is_async", [True, False])
def test_trace_llm_chat(is_async):
llm = OpenAI()
Expand All @@ -139,19 +154,22 @@ def test_trace_llm_chat(is_async):
assert len(spans) == 1
assert spans[0].name == "OpenAI.achat" if is_async else "OpenAI.chat"
assert spans[0].span_type == SpanType.CHAT_MODEL

content_json = _get_llm_input_content_json("Hello")
assert spans[0].inputs == {
"messages": [{"role": "system", "content": "Hello", "additional_kwargs": {}}]
"messages": [{"role": "system", **content_json, "additional_kwargs": {}}]
}
# `additional_kwargs` was broken until 0.1.30 release of llama-index-llms-openai
expected_kwargs = (
{"completion_tokens": 12, "prompt_tokens": 9, "total_tokens": 21}
if llama_oai_version >= Version("0.1.30")
else {}
)
output_content_json = _get_llm_input_content_json('[{"role": "system", "content": "Hello"}]')
assert spans[0].outputs == {
"message": {
"role": "assistant",
"content": '[{"role": "system", "content": "Hello"}]',
**output_content_json,
"additional_kwargs": {},
},
"raw": ANY,
Expand Down Expand Up @@ -197,19 +215,22 @@ def test_trace_llm_chat_stream():
assert len(spans) == 1
assert spans[0].name == "OpenAI.stream_chat"
assert spans[0].span_type == SpanType.CHAT_MODEL

content_json = _get_llm_input_content_json("Hello")
assert spans[0].inputs == {
"messages": [{"role": "system", "content": "Hello", "additional_kwargs": {}}]
"messages": [{"role": "system", **content_json, "additional_kwargs": {}}]
}
# `additional_kwargs` was broken until 0.1.30 release of llama-index-llms-openai
expected_kwargs = (
{"completion_tokens": 12, "prompt_tokens": 9, "total_tokens": 21}
if llama_oai_version >= Version("0.1.30")
else {}
)
output_content_json = _get_llm_input_content_json("Hello world")
assert spans[0].outputs == {
"message": {
"role": "assistant",
"content": "Hello world",
**output_content_json,
"additional_kwargs": {},
},
"raw": ANY,
Expand Down Expand Up @@ -288,7 +309,12 @@ def test_trace_retriever(multi_index, is_async):
assert spans[0].span_type == SpanType.RETRIEVER
assert spans[0].inputs == {"str_or_query_bundle": "apple"}
assert len(spans[0].outputs) == 1
assert spans[0].outputs[0]["page_content"] == retrieved[0].text

if Version(llama_index.core.__version__) >= Version("0.12.5"):
retrieved_text = retrieved[0].node.text
else:
retrieved_text = retrieved[0].text
assert spans[0].outputs[0]["page_content"] == retrieved_text

assert spans[1].name.startswith("VectorIndexRetriever")
assert spans[1].span_type == SpanType.RETRIEVER
Expand Down

0 comments on commit 6ad165e

Please sign in to comment.