Removing kernel messaging in aqua #1304
@@ -7,8 +7,8 @@
from tornado.web import HTTPError

from ads.aqua.app import logger
from ads.aqua.client.client import Client, ExtendedRequestError
from ads.aqua.client.openai_client import OpenAI
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.enums import PredictEndpoints
from ads.aqua.extension.base_handler import AquaAPIhandler

@@ -221,11 +221,49 @@ def list_shapes(self):

class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
    def _extract_text_from_choice(self, choice):
        # choice may be a dict or an object
        if isinstance(choice, dict):
            # streaming chunk: {"delta": {"content": "..."}}
            delta = choice.get("delta")
            if isinstance(delta, dict):
                return delta.get("content") or delta.get("text") or None
            # non-streaming: {"message": {"content": "..."}}
            msg = choice.get("message")
            if isinstance(msg, dict):
                return msg.get("content") or msg.get("text")
            # fallback top-level fields
            return choice.get("text") or choice.get("content")
        # object-like choice
        delta = getattr(choice, "delta", None)
        if delta is not None:
            return getattr(delta, "content", None) or getattr(delta, "text", None)
        msg = getattr(choice, "message", None)
        if msg is not None:
            if isinstance(msg, str):
                return msg
            return getattr(msg, "content", None) or getattr(msg, "text", None)
        return getattr(choice, "text", None) or getattr(choice, "content", None)

    def _extract_text_from_chunk(self, chunk):
        if chunk:
Member: nit: add docstrings, use type hinting
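For illustration, a sketch of this helper with the suggested docstring and type hints (a hypothetical rewrite, not part of this diff):

```python
from typing import Any, Optional

def _extract_text_from_chunk(self, chunk: Any) -> Optional[str]:
    """Return the text fragment carried by a streaming chunk, or None.

    Handles both dict-shaped chunks and OpenAI-style response objects.
    """
    if isinstance(chunk, dict):
        choices = chunk.get("choices") or []
        if choices:
            return self._extract_text_from_choice(choices[0])
        return chunk.get("text") or chunk.get("content")
    choices = getattr(chunk, "choices", None)
    if choices:
        return self._extract_text_from_choice(choices[0])
    return getattr(chunk, "text", None) or getattr(chunk, "content", None)
```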
            if isinstance(chunk, dict):
                choices = chunk.get("choices") or []
                if choices:
                    return self._extract_text_from_choice(choices[0])
                # fallback top-level
                return chunk.get("text") or chunk.get("content")
            # object-like chunk
            choices = getattr(chunk, "choices", None)
            if choices:
                return self._extract_text_from_choice(choices[0])
            return getattr(chunk, "text", None) or getattr(chunk, "content", None)

    def _get_model_deployment_response(
        self,
        model_deployment_id: str,
-        payload: dict,
-        route_override_header: Optional[str],
        payload: dict
    ):
| """ | ||
| Returns the model deployment inference response in a streaming fashion. | ||
|
|
@@ -272,49 +310,160 @@ def _get_model_deployment_response( | |
| """ | ||
|
|
||
        model_deployment = AquaDeploymentApp().get(model_deployment_id)
-        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
-        endpoint_type = model_deployment.environment_variables.get(
-            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
-        )
-        aqua_client = Client(endpoint=endpoint)
-
-        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
-            endpoint_type,
-            route_override_header,
-        ):
        endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
        endpoint_type = payload["endpoint_type"]
Member: let's check if …
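The comment is cut off, but in context it likely concerns reading `endpoint_type` straight out of the payload; a hedged sketch of a guarded lookup (assuming the `Errors` helper used later in this handler):

```python
endpoint_type = payload.get("endpoint_type")
if not endpoint_type:
    raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint_type"))
```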
        aqua_client = OpenAI(base_url=endpoint)

        allowed = {
            "max_tokens",
            "temperature",
            "top_p",
            "stop",
            "n",
            "presence_penalty",
            "frequency_penalty",
            "logprobs",
            "user",
            "echo",
        }
        responses_allowed = {"temperature", "top_p"}

        # normalize and filter
        if payload.get("stop") == []:
            payload["stop"] = None

        encoded_image = "NA"
Member: why is the key named NA here? shouldn't we check if "encoded_image" is present in the payload or not?
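A minimal sketch of the presence check the reviewer asks about (hypothetical fix, not in the diff):

```python
# Only read the image when the caller actually sent one; None otherwise.
encoded_image = payload.get("encoded_image")
```

Downstream branches would then test `encoded_image is None` instead of comparing against the "NA" sentinel.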
        if encoded_image in payload:
            encoded_image = payload["encoded_image"]

        model = payload.pop("model")
        filtered = {k: v for k, v in payload.items() if k in allowed}
        responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}

        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type and encoded_image == "NA":
            try:
-                for chunk in aqua_client.chat(
-                    messages=payload.pop("messages"),
-                    payload=payload,
-                    stream=True,
-                ):
-                    try:
-                        if "text" in chunk["choices"][0]:
-                            yield chunk["choices"][0]["text"]
-                        elif "content" in chunk["choices"][0]["delta"]:
-                            yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
                api_kwargs = {
                    "model": model,
                    "messages": [{"role": "user", "content": payload["prompt"]}],
Member: payload["prompt"] can error out if prompt key isn't present. We can do some key validation at the beginning of the function for all attributes fetched from payload.
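A sketch of the up-front validation the reviewer suggests, reusing the `Errors` helper from this handler (assumed placement at the top of `_get_model_deployment_response`):

```python
# Fail fast on anything the branches below read unconditionally.
for key in ("model", "endpoint_type", "prompt"):
    if key not in payload:
        raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format(key))
```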
| "stream": True, | ||
| **filtered | ||
| } | ||
| if "chat_template" in payload: | ||
| chat_template = payload.pop("chat_template") | ||
| api_kwargs["extra_body"] = {"chat_template": chat_template} | ||
|
|
||
| stream = aqua_client.chat.completions.create(**api_kwargs) | ||
|
|
||
| for chunk in stream: | ||
| if chunk : | ||
| piece = self._extract_text_from_chunk(chunk) | ||
| if piece : | ||
| yield piece | ||
| except ExtendedRequestError as ex: | ||
| raise HTTPError(400, str(ex)) | ||
| except Exception as ex: | ||
| raise HTTPError(500, str(ex)) | ||

        elif (
            endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
            and encoded_image != "NA"
        ):
            file_type = payload.pop("file_type")
            if file_type.startswith("image"):
                api_kwargs = {
                    "model": model,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": payload["prompt"]},
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"{self.encoded_image}"},
Member: this looks incorrect - …
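The comment is truncated, but the flagged line reads `self.encoded_image` while the value lives in the local `encoded_image` variable; presumably the intended line is:

```python
"image_url": {"url": encoded_image},
```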
                                },
                            ],
                        }
                    ],
                    "stream": True,
                    **filtered
                }

                # Add chat_template for image-based chat completions
                if "chat_template" in payload:
                    chat_template = payload.pop("chat_template")
                    api_kwargs["extra_body"] = {"chat_template": chat_template}

                response = aqua_client.chat.completions.create(**api_kwargs)

            elif self.file_type.startswith("audio"):
                api_kwargs = {
                    "model": model,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": payload["prompt"]},
                                {
                                    "type": "audio_url",
                                    "audio_url": {"url": f"{self.encoded_image}"},
Member: this looks incorrect - should …
                                },
                            ],
                        }
                    ],
                    "stream": True,
                    **filtered
                }

                # Add chat_template for audio-based chat completions
                if "chat_template" in payload:
                    chat_template = payload.pop("chat_template")
                    api_kwargs["extra_body"] = {"chat_template": chat_template}

                response = aqua_client.chat.completions.create(**api_kwargs)
            try:
                for chunk in response:
                    piece = self._extract_text_from_chunk(chunk)
                    if piece:
                        print(piece, end="", flush=True)
Member: shouldn't this …
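The thread is cut off here too; given that this method is a generator and every other branch yields its pieces, the intended shape is presumably:

```python
for chunk in response:
    piece = self._extract_text_from_chunk(chunk)
    if piece:
        yield piece
```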
            except ExtendedRequestError as ex:
                raise HTTPError(400, str(ex))
            except Exception as ex:
                raise HTTPError(500, str(ex))
        elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
            try:
-                for chunk in aqua_client.generate(
-                    prompt=payload.pop("prompt"),
-                    payload=payload,
-                    stream=True,
                for chunk in aqua_client.completions.create(
                    prompt=payload["prompt"], stream=True, model=model, **filtered
                ):
-                    try:
-                        yield chunk["choices"][0]["text"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
                    if chunk:
                        piece = self._extract_text_from_chunk(chunk)
                        if piece:
                            yield piece
            except ExtendedRequestError as ex:
                raise HTTPError(400, str(ex))
            except Exception as ex:
                raise HTTPError(500, str(ex))

        elif endpoint_type == PredictEndpoints.RESPONSES:
            api_kwargs = {
                "model": model,
                "input": payload["prompt"],
                "stream": True
            }

            if "temperature" in responses_filtered:
                api_kwargs["temperature"] = responses_filtered["temperature"]
            if "top_p" in responses_filtered:
                api_kwargs["top_p"] = responses_filtered["top_p"]

            response = aqua_client.responses.create(**api_kwargs)
            try:
                for chunk in response:
                    if chunk:
                        piece = self._extract_text_from_chunk(chunk)
                        if piece:
                            yield piece
            except ExtendedRequestError as ex:
                raise HTTPError(400, str(ex))
            except Exception as ex:
                raise HTTPError(500, str(ex))
Member: handle unknown endpoint type with something like: …
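A plausible shape for that fallback, matching the error style used elsewhere in this method (an assumption, since the suggested snippet itself is truncated):

```python
else:
    raise HTTPError(400, f"Unsupported endpoint type: {endpoint_type}")
```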

@@ -340,19 +489,20 @@ def post(self, model_deployment_id):
        prompt = input_data.get("prompt")
        messages = input_data.get("messages")

        if not prompt and not messages:
            raise HTTPError(
                400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
            )
        if not input_data.get("model"):
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
-        route_override_header = self.request.headers.get("route", None)
        self.set_header("Content-Type", "text/event-stream")
        response_gen = self._get_model_deployment_response(
-            model_deployment_id, input_data, route_override_header
            model_deployment_id, input_data
        )
        try:
            for chunk in response_gen:
                print(chunk)
Member: shouldn't we …
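Presumably this refers to the stray debug `print`; the loop only needs to write and flush each chunk before finishing the response:

```python
for chunk in response_gen:
    self.write(chunk)
    self.flush()
self.finish()
```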
                self.write(chunk)
                self.flush()
            self.finish()