1 change: 1 addition & 0 deletions ads/aqua/common/enums.py
@@ -24,6 +24,7 @@ class PredictEndpoints(ExtendedEnum):
CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
EMBEDDING_ENDPOINT = "/v1/embedding"
RESPONSES = "/v1/responses"


class Tags(ExtendedEnum):
228 changes: 189 additions & 39 deletions ads/aqua/extension/deployment_handler.py
@@ -7,8 +7,8 @@

from tornado.web import HTTPError

from ads.aqua.app import logger
from ads.aqua.client.client import Client, ExtendedRequestError
from ads.aqua.client.openai_client import OpenAI
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.enums import PredictEndpoints
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -221,11 +221,49 @@ def list_shapes(self):


class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):

def _extract_text_from_choice(self, choice):
Review comment (Member): nit: add docstrings, use type hinting
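A sketch of what the reviewer is asking for, covering both helpers in this class; the Any/Optional hints are an assumption inferred from how the methods are used below:

from typing import Any, Optional

def _extract_text_from_choice(self, choice: Any) -> Optional[str]:
    """Return the text content of a dict- or object-shaped choice, or None."""
    ...

def _extract_text_from_chunk(self, chunk: Any) -> Optional[str]:
    """Return the text content of a streamed chunk, or None."""
    ...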

# choice may be a dict or an object
if isinstance(choice, dict):
# streaming chunk: {"delta": {"content": "..."}}
delta = choice.get("delta")
if isinstance(delta, dict):
return delta.get("content") or delta.get("text") or None
# non-streaming: {"message": {"content": "..."}}
msg = choice.get("message")
if isinstance(msg, dict):
return msg.get("content") or msg.get("text")
# fallback top-level fields
return choice.get("text") or choice.get("content")
# object-like choice
delta = getattr(choice, "delta", None)
if delta is not None:
return getattr(delta, "content", None) or getattr(delta, "text", None)
msg = getattr(choice, "message", None)
if msg is not None:
if isinstance(msg, str):
return msg
return getattr(msg, "content", None) or getattr(msg, "text", None)
return getattr(choice, "text", None) or getattr(choice, "content", None)

def _extract_text_from_chunk(self, chunk):
if chunk :
Review comment (Member): nit: add docstrings, use type hinting

if isinstance(chunk, dict):
choices = chunk.get("choices") or []
if choices:
return self._extract_text_from_choice(choices[0])
# fallback top-level
return chunk.get("text") or chunk.get("content")
# object-like chunk
choices = getattr(chunk, "choices", None)
if choices:
return self._extract_text_from_choice(choices[0])
return getattr(chunk, "text", None) or getattr(chunk, "content", None)

def _get_model_deployment_response(
self,
model_deployment_id: str,
payload: dict,
route_override_header: Optional[str],
payload: dict
):
"""
Returns the model deployment inference response in a streaming fashion.
@@ -272,49 +310,160 @@ def _get_model_deployment_response(
"""

model_deployment = AquaDeploymentApp().get(model_deployment_id)
endpoint = model_deployment.endpoint + "/predictWithResponseStream"
endpoint_type = model_deployment.environment_variables.get(
"MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
)
aqua_client = Client(endpoint=endpoint)

if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
endpoint_type,
route_override_header,
):
endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
endpoint_type = payload["endpoint_type"]
Review comment (Member): let's check if endpoint_type key is present in the payload - add key error validation.
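A minimal sketch of that validation, reusing the Errors helper this handler already uses in post(); the 400 status is an assumption:

endpoint_type = payload.get("endpoint_type")
if not endpoint_type:
    raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint_type"))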

aqua_client = OpenAI(base_url=endpoint)

allowed = {
"max_tokens",
"temperature",
"top_p",
"stop",
"n",
"presence_penalty",
"frequency_penalty",
"logprobs",
"user",
"echo",
}
responses_allowed = {
"temperature", "top_p"
}

# normalize and filter
if payload.get("stop") == []:
payload["stop"] = None

encoded_image = "NA"
Review comment (Member): why is the key named NA here? shouldn't we check if "encoded_image" is present in the payload or not?
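A sketch of the suggested lookup, dropping the "NA" sentinel; downstream branch conditions would then compare against None instead:

encoded_image = payload.get("encoded_image")  # None when the key is absent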

if encoded_image in payload :
encoded_image = payload["encoded_image"]

model = payload.pop("model")
filtered = {k: v for k, v in payload.items() if k in allowed}
responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}

if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type and encoded_image == "NA":
try:
for chunk in aqua_client.chat(
messages=payload.pop("messages"),
payload=payload,
stream=True,
):
try:
if "text" in chunk["choices"][0]:
yield chunk["choices"][0]["text"]
elif "content" in chunk["choices"][0]["delta"]:
yield chunk["choices"][0]["delta"]["content"]
except Exception as e:
logger.debug(
f"Exception occurred while parsing streaming response: {e}"
)
api_kwargs = {
"model": model,
"messages": [{"role": "user", "content": payload["prompt"]}],
Review comment (Member): payload["prompt"] can error out if prompt key isn't present. We can do some key validation at the beginning of the function for all attributes fetched from payload.

"stream": True,
**filtered
}
if "chat_template" in payload:
chat_template = payload.pop("chat_template")
api_kwargs["extra_body"] = {"chat_template": chat_template}

stream = aqua_client.chat.completions.create(**api_kwargs)

for chunk in stream:
if chunk :
piece = self._extract_text_from_chunk(chunk)
if piece :
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
except Exception as ex:
raise HTTPError(500, str(ex))

elif (
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
and encoded_image != "NA"
):
file_type = payload.pop("file_type")
if file_type.startswith("image"):
api_kwargs = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": payload["prompt"]},
{
"type": "image_url",
"image_url": {"url": f"{self.encoded_image}"},
Review comment (@VipulMascarenhas, Nov 24, 2025): this looks incorrect - self does not have encoded_image attribute
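The likely intent is the local variable read from the payload earlier in this function, e.g.:

"image_url": {"url": encoded_image},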

},
],
}
],
"stream": True,
**filtered
}

# Add chat_template for image-based chat completions
if "chat_template" in payload:
chat_template = payload.pop("chat_template")
api_kwargs["extra_body"] = {"chat_template": chat_template}

response = aqua_client.chat.completions.create(**api_kwargs)

elif self.file_type.startswith("audio"):
Review comment (Member): self does not have file_type attribute, shouldn't this be just file_type?
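Presumably this should reference the local popped from the payload above:

elif file_type.startswith("audio"):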

api_kwargs = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": payload["prompt"]},
{
"type": "audio_url",
"audio_url": {"url": f"{self.encoded_image}"},
Review comment (Member): this looks incorrect - should encoded_image be passed when file type is audio?
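If the same payload key carries the encoded audio (an assumption; a media-neutral key such as encoded_file may be preferable), the minimal fix mirrors the image branch:

"audio_url": {"url": encoded_image},  # assumption: the payload key is reused for audio uploads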

},
],
}
],
"stream": True,
**filtered
}

# Add chat_template for audio-based chat completions
if "chat_template" in payload:
chat_template = payload.pop("chat_template")
api_kwargs["extra_body"] = {"chat_template": chat_template}

response = aqua_client.chat.completions.create(**api_kwargs)
try:
for chunk in response:
piece = self._extract_text_from_chunk(chunk)
if piece:
print(piece, end="", flush=True)
Review comment (Member): shouldn't this yield instead of printing the chunk?
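The loop would then mirror the other branches and stream the text back to the caller:

for chunk in response:
    piece = self._extract_text_from_chunk(chunk)
    if piece:
        yield piece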

except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
except Exception as ex:
raise HTTPError(500, str(ex))
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
try:
for chunk in aqua_client.generate(
prompt=payload.pop("prompt"),
payload=payload,
stream=True,
for chunk in aqua_client.completions.create(
prompt=payload["prompt"], stream=True, model=model, **filtered
):
try:
yield chunk["choices"][0]["text"]
except Exception as e:
logger.debug(
f"Exception occurred while parsing streaming response: {e}"
)
if chunk :
piece = self._extract_text_from_chunk(chunk)
if piece :
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
except Exception as ex:
raise HTTPError(500, str(ex))

elif endpoint_type == PredictEndpoints.RESPONSES:
api_kwargs = {
"model": model,
"input": payload["prompt"],
"stream": True
}

if "temperature" in responses_filtered:
api_kwargs["temperature"] = responses_filtered["temperature"]
if "top_p" in responses_filtered:
api_kwargs["top_p"] = responses_filtered["top_p"]

response = aqua_client.responses.create(**api_kwargs)
try:
for chunk in response:
if chunk :
piece = self._extract_text_from_chunk(chunk)
if piece :
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
except Exception as ex:
Review comment (Member): handle unknown endpoint type with something like:

else:
    raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")

@@ -340,19 +489,20 @@ def post(self, model_deployment_id):
prompt = input_data.get("prompt")
messages = input_data.get("messages")


if not prompt and not messages:
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
)
if not input_data.get("model"):
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
route_override_header = self.request.headers.get("route", None)
self.set_header("Content-Type", "text/event-stream")
response_gen = self._get_model_deployment_response(
model_deployment_id, input_data, route_override_header
model_deployment_id, input_data
)
try:
for chunk in response_gen:
print(chunk)
Review comment (Member): shouldn't we yield the chunk in this case?
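The stray print looks like debug residue; in a Tornado handler the write/flush pair below already streams the chunk, so the loop presumably only needs:

for chunk in response_gen:
    self.write(chunk)
    self.flush()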

self.write(chunk)
self.flush()
self.finish()
3 changes: 1 addition & 2 deletions tests/unitary/with_extras/aqua/test_deployment_handler.py
@@ -274,8 +274,7 @@ def test_post(self, mock_get_model_deployment_response):

mock_get_model_deployment_response.assert_called_with(
"mock-deployment-id",
{"prompt": "Hello", "model": "some-model"},
"test-route",
{"prompt": "Hello", "model": "some-model"}
)
self.handler.write.assert_any_call("chunk1")
self.handler.write.assert_any_call("chunk2")