From 5aaa447dbabf89236177081672603ccae23c4437 Mon Sep 17 00:00:00 2001 From: Pramod Chunduri <43007047+pchunduri6@users.noreply.github.com> Date: Fri, 17 Nov 2023 17:46:01 -0500 Subject: [PATCH] Migrate ChatGPT function to openai v1.0 (#1368) Migrate the ChatGPT and Dall-e functions to openai v1.0. The ChatGPT integration test is skipped in CircleCI because it requires a real `OPENAI_API_KEY` to be supplied. The test passes on a local machine. - [x] Upgrade ChatGPT function. - [x] Upgrade Dall-e function. - [x] Update unit test cases. - [x] Verify that notebooks work correctly. --- evadb/functions/chatgpt.py | 20 ++++++----- evadb/functions/dalle.py | 19 +++++----- setup.py | 2 +- .../long/functions/test_chatgpt.py | 13 +++---- test/unit_tests/test_dalle.py | 35 ++++++++++++++++--- 5 files changed, 59 insertions(+), 30 deletions(-) diff --git a/evadb/functions/chatgpt.py b/evadb/functions/chatgpt.py index fadc61191..bf0d33868 100644 --- a/evadb/functions/chatgpt.py +++ b/evadb/functions/chatgpt.py @@ -115,19 +115,21 @@ def setup( ) def forward(self, text_df): try_to_import_openai() - import openai + from openai import OpenAI - @retry(tries=6, delay=20) - def completion_with_backoff(**kwargs): - return openai.ChatCompletion.create(**kwargs) - - openai.api_key = self.openai_api_key - if len(openai.api_key) == 0: - openai.api_key = os.environ.get("OPENAI_API_KEY", "") + api_key = self.openai_api_key + if len(self.openai_api_key) == 0: + api_key = os.environ.get("OPENAI_API_KEY", "") assert ( - len(openai.api_key) != 0 + len(api_key) != 0 ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" + client = OpenAI(api_key=api_key) + + @retry(tries=6, delay=20) + def completion_with_backoff(**kwargs): + return client.chat.completions.create(**kwargs) + queries = text_df[text_df.columns[0]] content = text_df[text_df.columns[0]] if len(text_df.columns) > 1: diff --git a/evadb/functions/dalle.py b/evadb/functions/dalle.py index 7c1dc39dd..03c2e77f8 100644 ---
a/evadb/functions/dalle.py +++ b/evadb/functions/dalle.py @@ -56,24 +56,25 @@ def setup(self, openai_api_key="") -> None: ) def forward(self, text_df): try_to_import_openai() - import openai + from openai import OpenAI - openai.api_key = self.openai_api_key - # If not found, try OS Environment Variable - if len(openai.api_key) == 0: - openai.api_key = os.environ.get("OPENAI_API_KEY", "") + api_key = self.openai_api_key + if len(self.openai_api_key) == 0: + api_key = os.environ.get("OPENAI_API_KEY", "") assert ( - len(openai.api_key) != 0 - ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" + len(api_key) != 0 + ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" + + client = OpenAI(api_key=api_key) def generate_image(text_df: PandasDataframe): results = [] queries = text_df[text_df.columns[0]] for query in queries: - response = openai.Image.create(prompt=query, n=1, size="1024x1024") + response = client.images.generate(prompt=query, n=1, size="1024x1024") # Download the image from the link - image_response = requests.get(response["data"][0]["url"]) + image_response = requests.get(response.data[0].url) image = Image.open(BytesIO(image_response.content)) # Convert the image to an array format suitable for the DataFrame diff --git a/setup.py b/setup.py index 3334fa836..61dc0b8c6 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def read(path, encoding="utf-8"): "sentence-transformers", "protobuf", "bs4", - "openai==0.28", # CHATGPT + "openai>=1.0", # CHATGPT "gpt4all", # PRIVATE GPT "sentencepiece", # TRANSFORMERS ] diff --git a/test/integration_tests/long/functions/test_chatgpt.py b/test/integration_tests/long/functions/test_chatgpt.py index b72612d05..3f8cd9a92 100644 --- a/test/integration_tests/long/functions/test_chatgpt.py +++ b/test/integration_tests/long/functions/test_chatgpt.py @@ -13,6 +13,7 @@ # See the License for the specific language 
governing permissions and # limitations under the License. +import os import unittest from test.markers import chatgpt_skip_marker from test.util import get_evadb_for_testing @@ -22,9 +23,8 @@ from evadb.server.command_handler import execute_query_fetch_all -def create_dummy_csv_file(config) -> str: - tmp_dir_from_config = config.get_value("storage", "tmp_dir") - +def create_dummy_csv_file(catalog) -> str: + tmp_dir_from_config = catalog.get_configuration_catalog_value("tmp_dir") df_dict = [ { "prompt": "summarize", @@ -49,17 +49,18 @@ def setUp(self) -> None: );""" execute_query_fetch_all(self.evadb, create_table_query) - self.csv_file_path = create_dummy_csv_file(self.evadb.config) + self.csv_file_path = create_dummy_csv_file(self.evadb.catalog()) csv_query = f"""LOAD CSV '{self.csv_file_path}' INTO MyTextCSV;""" execute_query_fetch_all(self.evadb, csv_query) + os.environ["OPENAI_API_KEY"] = "sk-..." def tearDown(self) -> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyTextCSV;") @chatgpt_skip_marker def test_openai_chat_completion_function(self): - function_name = "OpenAIChatCompletion" + function_name = "ChatGPT" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name} @@ -69,4 +70,4 @@ def test_openai_chat_completion_function(self): gpt_query = f"SELECT {function_name}('summarize', content) FROM MyTextCSV;" output_batch = execute_query_fetch_all(self.evadb, gpt_query) - self.assertEqual(output_batch.columns, ["openaichatcompletion.response"]) + self.assertEqual(output_batch.columns, ["chatgpt.response"]) diff --git a/test/unit_tests/test_dalle.py b/test/unit_tests/test_dalle.py index c434a4db4..373e90e44 100644 --- a/test/unit_tests/test_dalle.py +++ b/test/unit_tests/test_dalle.py @@ -16,13 +16,26 @@ import unittest from io import BytesIO from test.util import get_evadb_for_testing +from typing import List, Optional from unittest.mock import 
MagicMock, patch -from PIL import Image +from PIL import Image as PILImage +from pydantic import AnyUrl, BaseModel from evadb.server.command_handler import execute_query_fetch_all +class Image(BaseModel): + b64_json: Optional[str] # Replace with the actual type if different + revised_prompt: Optional[str] # Replace with the actual type if different + url: AnyUrl + + +class ImagesResponse(BaseModel): + created: Optional[int] # Replace with the actual type if different + data: List[Image] + + class DallEFunctionTest(unittest.TestCase): def setUp(self) -> None: self.evadb = get_evadb_for_testing() @@ -43,10 +56,10 @@ def tearDown(self) -> None: @patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"}) @patch("requests.get") - @patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]}) - def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): + @patch("openai.OpenAI") + def test_dalle_image_generation(self, mock_openai, mock_requests_get): # Generate a 1x1 white pixel PNG image in memory - img = Image.new("RGB", (1, 1), color="white") + img = PILImage.new("RGB", (1, 1), color="white") img_byte_array = BytesIO() img.save(img_byte_array, format="PNG") mock_image_content = img_byte_array.getvalue() @@ -55,6 +68,18 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): mock_response.content = mock_image_content mock_requests_get.return_value = mock_response + # Set up the mock for OpenAI instance + mock_openai_instance = mock_openai.return_value + mock_openai_instance.images.generate.return_value = ImagesResponse( + data=[ + Image( + b64_json=None, + revised_prompt=None, + url="https://images.openai.com/1234.png", + ) + ] + ) + function_name = "DallE" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") @@ -67,6 +92,6 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;" 
execute_query_fetch_all(self.evadb, gpt_query) - mock_openai_create.assert_called_once_with( + mock_openai_instance.images.generate.assert_called_once_with( prompt="a surreal painting of a cat", n=1, size="1024x1024" )