Skip to content

Commit

Permalink
Migrate ChatGPT function to openai v1.0 (#1368)
Browse files Browse the repository at this point in the history
Migrate ChatGPT function to openai v1.0.

The test is skipped in circleCI because we must supply the
`OPENAI_API_KEY`. The test passes on local machine.

- [x] Upgrade ChatGPT function.
- [x] Upgrade Dall-e function.
- [x] Update unit test cases.
- [x] Verify that notebooks work correctly.
  • Loading branch information
pchunduri6 committed Nov 17, 2023
1 parent 0c25a44 commit 5aaa447
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 30 deletions.
20 changes: 11 additions & 9 deletions evadb/functions/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,21 @@ def setup(
)
def forward(self, text_df):
try_to_import_openai()
import openai
from openai import OpenAI

@retry(tries=6, delay=20)
def completion_with_backoff(**kwargs):
return openai.ChatCompletion.create(**kwargs)

openai.api_key = self.openai_api_key
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
api_key = self.openai_api_key
if len(self.openai_api_key) == 0:
api_key = os.environ.get("OPENAI_API_KEY", "")
assert (
len(openai.api_key) != 0
len(api_key) != 0
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"

client = OpenAI(api_key=api_key)

@retry(tries=6, delay=20)
def completion_with_backoff(**kwargs):
return client.chat.completions.create(**kwargs)

queries = text_df[text_df.columns[0]]
content = text_df[text_df.columns[0]]
if len(text_df.columns) > 1:
Expand Down
19 changes: 10 additions & 9 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,25 @@ def setup(self, openai_api_key="") -> None:
)
def forward(self, text_df):
try_to_import_openai()
import openai
from openai import OpenAI

openai.api_key = self.openai_api_key
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
api_key = self.openai_api_key
if len(self.openai_api_key) == 0:
api_key = os.environ.get("OPENAI_API_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"
len(api_key) != 0
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"

client = OpenAI(api_key=api_key)

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(prompt=query, n=1, size="1024x1024")
response = client.images.generate(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response["data"][0]["url"])
image_response = requests.get(response.data[0].url)
image = Image.open(BytesIO(image_response.content))

# Convert the image to an array format suitable for the DataFrame
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def read(path, encoding="utf-8"):
"sentence-transformers",
"protobuf",
"bs4",
"openai==0.28", # CHATGPT
"openai>=1.0", # CHATGPT
"gpt4all", # PRIVATE GPT
"sentencepiece", # TRANSFORMERS
]
Expand Down
13 changes: 7 additions & 6 deletions test/integration_tests/long/functions/test_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from test.markers import chatgpt_skip_marker
from test.util import get_evadb_for_testing
Expand All @@ -22,9 +23,8 @@
from evadb.server.command_handler import execute_query_fetch_all


def create_dummy_csv_file(config) -> str:
tmp_dir_from_config = config.get_value("storage", "tmp_dir")

def create_dummy_csv_file(catalog) -> str:
tmp_dir_from_config = catalog.get_configuration_catalog_value("tmp_dir")
df_dict = [
{
"prompt": "summarize",
Expand All @@ -49,17 +49,18 @@ def setUp(self) -> None:
);"""
execute_query_fetch_all(self.evadb, create_table_query)

self.csv_file_path = create_dummy_csv_file(self.evadb.config)
self.csv_file_path = create_dummy_csv_file(self.evadb.catalog())

csv_query = f"""LOAD CSV '{self.csv_file_path}' INTO MyTextCSV;"""
execute_query_fetch_all(self.evadb, csv_query)
os.environ["OPENAI_API_KEY"] = "sk-..."

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyTextCSV;")

@chatgpt_skip_marker
def test_openai_chat_completion_function(self):
function_name = "OpenAIChatCompletion"
function_name = "ChatGPT"
execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name}
Expand All @@ -69,4 +70,4 @@ def test_openai_chat_completion_function(self):

gpt_query = f"SELECT {function_name}('summarize', content) FROM MyTextCSV;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)
self.assertEqual(output_batch.columns, ["openaichatcompletion.response"])
self.assertEqual(output_batch.columns, ["chatgpt.response"])
35 changes: 30 additions & 5 deletions test/unit_tests/test_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@
import unittest
from io import BytesIO
from test.util import get_evadb_for_testing
from typing import List, Optional
from unittest.mock import MagicMock, patch

from PIL import Image
from PIL import Image as PILImage
from pydantic import AnyUrl, BaseModel

from evadb.server.command_handler import execute_query_fetch_all


class Image(BaseModel):
b64_json: Optional[str] # Replace with the actual type if different
revised_prompt: Optional[str] # Replace with the actual type if different
url: AnyUrl


class ImagesResponse(BaseModel):
created: Optional[int] # Replace with the actual type if different
data: List[Image]


class DallEFunctionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
Expand All @@ -43,10 +56,10 @@ def tearDown(self) -> None:

@patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"})
@patch("requests.get")
@patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
@patch("openai.OpenAI")
def test_dalle_image_generation(self, mock_openai, mock_requests_get):
# Generate a 1x1 white pixel PNG image in memory
img = Image.new("RGB", (1, 1), color="white")
img = PILImage.new("RGB", (1, 1), color="white")
img_byte_array = BytesIO()
img.save(img_byte_array, format="PNG")
mock_image_content = img_byte_array.getvalue()
Expand All @@ -55,6 +68,18 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
mock_response.content = mock_image_content
mock_requests_get.return_value = mock_response

# Set up the mock for OpenAI instance
mock_openai_instance = mock_openai.return_value
mock_openai_instance.images.generate.return_value = ImagesResponse(
data=[
Image(
b64_json=None,
revised_prompt=None,
url="https://images.openai.com/1234.png",
)
]
)

function_name = "DallE"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")
Expand All @@ -67,6 +92,6 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
execute_query_fetch_all(self.evadb, gpt_query)

mock_openai_create.assert_called_once_with(
mock_openai_instance.images.generate.assert_called_once_with(
prompt="a surreal painting of a cat", n=1, size="1024x1024"
)

0 comments on commit 5aaa447

Please sign in to comment.