diff --git a/README.md b/README.md index 903b224..8a42ced 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,23 @@ HUGGINGFACEHUB_API_TOKEN="MY_HF_API_KEY" You can also pass `huggingfacehub_api_token` as a named parameter. +#### AWS Bedrock +Create your access keys in the security credentials section of your AWS user. +Then fill in the files `~/.aws/config` and `~/.aws/credentials` on Linux and macOS, or `%USERPROFILE%\.aws\config` and `%USERPROFILE%\.aws\credentials` on Windows: + +In credentials: +```shell +[default] +aws_access_key_id = +aws_secret_access_key = +``` + +In config: +```shell +[default] +region = +``` ### ⚙️ Install locally @@ -238,6 +254,7 @@ The current available tasks in Promptmeteo are: | `CodeGenerator` | Code generation | | `ApiGenerator` | API REST generation | | `ApiFormatter` | API REST correction | +| `Summarizer` | Text summarization | ### ✅ Available Model @@ -259,3 +276,5 @@ The current available `model_name` and `language` values are: | google | text-bison@001 | en | | google | text-bison-32k | es | | google | text-bison-32k | en | +| bedrock | anthropic.claude-v2 | en | +| bedrock | anthropic.claude-v2 | es | diff --git a/examples/05_test_openai_classification_prompts.ipynb b/examples/05_test_openai_classification_prompts.ipynb new file mode 100644 index 0000000..fb85aad --- /dev/null +++ b/examples/05_test_openai_classification_prompts.ipynb @@ -0,0 +1,1718 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "73d808a5-a4dd-4ec5-b4bc-438f1dd5596f", + "metadata": {}, + "source": [ + "# OpenAI test classification prompts." + ] + }, + { + "cell_type": "markdown", + "id": "7ffd3703-20b9-46b6-8ad3-502ec27d2e49", + "metadata": {}, + "source": [ + "In this notebook we are going to use the **Amazon Review Dataset** to test Promptmeteo in the sentiment analysis task" + ] + }, + { + "cell_type": "markdown", + "id": "dad1d69f-59bf-4f34-9000-46e5851d99f6", + "metadata": {}, + "source": [ + "## 1. 
Data Preparation - EN - Build sentiment dataset." + ] + }, + { + "cell_type": "markdown", + "id": "7993f034-cdd6-4424-b6d1-5917b0638583", + "metadata": {}, + "source": [ + "The dataset contains reviews from Amazon in English collected between November 1, 2015 and November 1, 2019. Each record in the dataset contains the review text, the review title, the star rating, an anonymized reviewer ID, an anonymized product ID and the coarse-grained product category (e.g. ‘books’, ‘appliances’, etc.). The corpus is balanced across stars, so each star rating constitutes 20% of the reviews in each language." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d2c9f56c-4717-42a8-b9d4-cfcc20215f99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 2)
REVIEWTARGET
strstr
"I reuse my Nes…"positive"
"Fits great kin…"positive"
"Movie freezes …"negative"
"This is my thi…"positive"
"For the money,…"neutral"
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌───────────────────────────────────┬──────────┐\n", + "│ REVIEW ┆ TARGET │\n", + "│ --- ┆ --- │\n", + "│ str ┆ str │\n", + "╞═══════════════════════════════════╪══════════╡\n", + "│ I reuse my Nespresso capsules an… ┆ positive │\n", + "│ Fits great kinda expensive but i… ┆ positive │\n", + "│ Movie freezes up. Can't watch it… ┆ negative │\n", + "│ This is my third G-shock as I ha… ┆ positive │\n", + "│ For the money, it's a good buy..… ┆ neutral │\n", + "└───────────────────────────────────┴──────────┘" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import polars as pl\n", + "import sys; sys.path.append('..')\n", + "\n", + "data = pl.read_parquet('../data/amazon_reviews_en/amazon_reviews_multi-test.parquet')\n", + "sql = pl.SQLContext()\n", + "sql.register('data', data)\n", + "\n", + "sentiment_data = sql.execute(\"\"\"\n", + " SELECT\n", + " review_body as REVIEW,\n", + " CASE\n", + " WHEN stars=1 THEN 'negative'\n", + " WHEN stars=3 THEN 'neutral'\n", + " WHEN stars=5 THEN 'positive'\n", + " ELSE null\n", + " END AS TARGET,\n", + " FROM data\n", + " WHERE stars!=2 AND stars!=4;\n", + " \"\"\").collect().sample(fraction=1.0, shuffle=True, seed=0)\n", + "\n", + "train_reviews = sentiment_data.head(100).select('REVIEW').to_series().to_list()\n", + "train_targets = sentiment_data.head(100).select('TARGET').to_series().to_list()\n", + "\n", + "test_reviews = sentiment_data.tail(100).select('REVIEW').to_series().to_list()\n", + "test_targets = sentiment_data.tail(100).select('TARGET').to_series().to_list()\n", + "\n", + "sentiment_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4baccad9-20fd-48a6-adb4-d75d2335a3ac", + "metadata": {}, + "outputs": [], + "source": [ + "token = 'sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'" + ] + }, + { + "cell_type": "markdown", + "id": "34b211cc-0fcf-47a9-ad80-edd70577648b", + 
"metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "6f43c63d-b523-418b-bee8-6faf65186e61", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "37078106-46f9-4f2b-8719-199fef2aad51", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "51e6344e-549f-4c33-8ba9-6256e47c2c09", + "metadata": {}, + "source": [ + "## 2. EN - Sin entrenamiento" + ] + }, + { + "cell_type": "markdown", + "id": "9153b637-b41b-42a0-8cb5-2f7d273ee842", + "metadata": {}, + "source": [ + "### Prueba 1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "98ff2491-70ae-48b4-a887-e1f81c33ef8a", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"I need you to help me with a text classification task.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"The texts you will be processing are from the {__DOMAIN__} domain.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"I want you to classify the texts into one of the following categories:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Please provide a step-by-step argument for your answer, explain why you\n", + " believe your final choice is justified, and make sure to conclude your\n", + " explanation with the name of the class you have selected as the correct\n", + " one, in lowercase and without punctuation.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"In your response, include only the name of the class as a single word, in\n", + " lowercase, without punctuation, and without adding any other statements or\n", + " words.\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b491bbff-68ae-4265-a08a-56108ed94fdb", + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'en',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'product reviews',\n", + " prompt_labels = ['positive','negative','neutral'],\n", + " selector_k = 0,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "pred_targets = [pred if len(pred)==1 else [''] for pred in pred_targets]\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "76722bfe-f1b9-47fe-9f87-2bc9f97062ca", + "metadata": {}, + "source": [ + "### Prueba 2" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "349e834b-a836-4093-b705-210ccc618c90", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"I need you to help me with a text classification task.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"The texts you will be processing are from the {__DOMAIN__} domain.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"I want you to classify the texts into one of the following categories:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Think step by step you answer.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"In your response, include only the name of the class predicted.\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "988735fe-a05b-4c01-a4a4-4c262f28add1", + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'en',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'product reviews',\n", + " prompt_labels = ['positive','negative','neutral'],\n", + " selector_k = 0,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "c703c30f-d696-4379-84a3-c536a6d7f565", + "metadata": {}, + "source": [ + "## 3. EN - Con entrenamiento" + ] + }, + { + "cell_type": "markdown", + "id": "71ce3032-eaa0-4ba9-b7ff-2f455f57a66f", + "metadata": {}, + "source": [ + "### Prueba 1" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "ef488d68-63b3-4f15-8e4c-d56633ddb147", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"I need you to help me with a text classification task.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"The texts you will be processing are from the {__DOMAIN__} domain.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"I want you to classify the texts into one of the following categories:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Please provide a step-by-step argument for your answer, explain why you\n", + " believe your final choice is justified, and make sure to conclude your\n", + " explanation with the name of the class you 
have selected as the correct\n", + " one, in lowercase and without punctuation.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"In your response, include only the name of the class as a single word, in\n", + " lowercase, without punctuation, and without adding any other statements or\n", + " words.\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "172c20d7-cecb-44ef-9fff-82d075b1d53f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'en',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " selector_k = 10,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "model.train(\n", + " examples = train_reviews,\n", + " annotations = train_targets\n", + ")\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "5e9681b2-c628-4a3f-968d-c5cbecc298ec", + "metadata": {}, + "source": [ + "### Prueba 2" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "9def17bf-b9fc-4afd-af9a-fdeef0c974aa", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"I need you to help me with a text classification task.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"The texts you will be processing are from the {__DOMAIN__} domain.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"I want you to classify the texts into one of the following categories:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Think step by step you answer.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"In your response, include only the name of the class predicted.\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "0c198bff-2112-4141-a528-3d97c9d16a33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + 
"execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'en',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'product reviews',\n", + " prompt_labels = ['positive','negative','neutral'],\n", + " selector_k = 20,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "model.train(\n", + " examples = train_reviews,\n", + " annotations = train_targets,\n", + ")\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "1d69e21c-d1d9-4897-a6df-ec5ef43fe4d0", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd55f1c6-ac6e-447f-b291-1d0ac50e9048", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "0d24dafd-c16a-40f6-8537-69ff06a1b34d", + "metadata": {}, + "source": [ + "### Prueba 3" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "f0bab4a5-da98-4362-9f52-c3f5ca534ccd", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"I need you to help me with a text classification task.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"\"\n", + "\n", + "PROMPT_LABELS:\n", + " \"\"\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"\"\n", + "\n", + "ANSWER_FORMAT:\n", + " \"\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": 
"c0cf795a-4563-4f6b-a970-20c47eca9870", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'en',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " selector_k = 20,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "model.train(\n", + " examples = train_reviews,\n", + " annotations = train_targets,\n", + ")\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "ea111fd4-6640-4631-b6ed-897fa95d10e2", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "b9993ba0-7b7b-44d5-b89b-d399ce4d9649", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "194182f1-0296-4d51-8d5a-aa736d722406", + "metadata": {}, + "source": [ + "## 4. Data Preparation - SP - Build sentiment dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "0b3ff622-fdb4-49da-ace2-c9e2a1cffcfb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 2)
REVIEWTARGET
strstr
"El filtro de d…"positive"
"Un poquito esc…"positive"
"Para qué decir…"negative"
"Mi hija esta e…"positive"
"Se podría mejo…"neutral"
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌───────────────────────────────────┬──────────┐\n", + "│ REVIEW ┆ TARGET │\n", + "│ --- ┆ --- │\n", + "│ str ┆ str │\n", + "╞═══════════════════════════════════╪══════════╡\n", + "│ El filtro de de aire es como si … ┆ positive │\n", + "│ Un poquito escaso pero funciona … ┆ positive │\n", + "│ Para qué decir más... ┆ negative │\n", + "│ Mi hija esta en un campamento y … ┆ positive │\n", + "│ Se podría mejorar el ajuste del … ┆ neutral │\n", + "└───────────────────────────────────┴──────────┘" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import polars as pl\n", + "\n", + "data = pl.read_parquet('../data/amazon_reviews_sp/amazon_reviews_multi-test.parquet')\n", + "data.head()\n", + "\n", + "sql = pl.SQLContext()\n", + "sql.register('data', data)\n", + "\n", + "sentiment_data = sql.execute(\"\"\"\n", + " SELECT\n", + " review_body as REVIEW,\n", + " CASE\n", + " WHEN stars=1 THEN 'negative'\n", + " WHEN stars=3 THEN 'neutral'\n", + " WHEN stars=5 THEN 'positive'\n", + " ELSE null\n", + " END AS TARGET,\n", + " FROM data\n", + " WHERE stars!=2 AND stars!=4;\n", + " \"\"\").collect().sample(fraction=1.0, shuffle=True, seed=0)\n", + "\n", + "train_reviews = sentiment_data.head(100).select('REVIEW').to_series().to_list()\n", + "train_targets = sentiment_data.head(100).select('TARGET').to_series().to_list()\n", + "\n", + "test_reviews = sentiment_data.tail(100).select('REVIEW').to_series().to_list()\n", + "test_targets = sentiment_data.tail(100).select('TARGET').to_series().to_list()\n", + "\n", + "sentiment_data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "cdfdb171-0391-4e3e-beca-072d56267948", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "d4d2b731-7d7c-4217-9a51-47e6689bbe7f", + "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "de2d1b15-01fd-4c98-8c92-29256d242cac", 
+ "metadata": {}, + "source": [ + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "8d322711-2b28-442e-8c19-108c3a083a50", + "metadata": {}, + "source": [ + "## 5 SP - Sin Entrenamiento" + ] + }, + { + "cell_type": "markdown", + "id": "0cf9f12a-1a9b-4732-8d2c-59c9ed52f8a2", + "metadata": {}, + "source": [ + "### Prueba 1" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "0abdc935-3926-461c-9e45-ab90374a5fdb", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"Necesito que me ayudes en una tarea de clasificación de texto.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"Los textos que vas procesar del ambito de {__DOMAIN__}.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"Quiero que me clasifiques los textos una de las siguientes categorías:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Por favor argumenta tu respuesta paso a paso, explica por qué crees que\n", + " está justificada tu elección final, y asegúrate de que acabas tu\n", + " explicación con el nombre de la clase que has escogido como la\n", + " correcta, en minúscula y sin puntuación.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"En tu respuesta incluye sólo el nombre de la clase, como una única\n", + " palabra, en minúscula, sin puntuación, y sin añadir ninguna otra\n", + " afirmación o palabra.\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "9d88b784-746a-45d3-83b6-67c1f09a878a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'es',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'opiniones de productos',\n", + " prompt_labels = ['positiva','negativa','neutral'],\n", + " selector_k = 0,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "pred_targets = [pred if len(pred)==1 else [''] for pred in pred_targets]\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "e55bc005-7447-48f1-967a-6f1ae76f37bb", + "metadata": {}, + "source": [ + "### Prueba 2" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "9ae93a75-bddc-410b-bc29-4019a928ddae", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"Necesito que me ayudes en una tarea de clasificación de texto.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"Los textos que vas procesar del ambito de {__DOMAIN__}.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"Quiero que me clasifiques los textos una de las siguientes categorías:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Argumenta tu respuesta paso a paso.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"En tu respuesta incluye sólo el nombre de la clase, como una única\n", + " respuesta\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": 
"20ee64ed-cfb9-4fd8-ae5f-9bdf1fddf5a3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'es',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'opiniones de productos',\n", + " prompt_labels = ['positiva','negativa','neutra'],\n", + " selector_k = 0,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "cf2744f0-ea90-4873-92bc-8d44c52621e1", + "metadata": {}, + "source": [ + "### Prueba 3" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "db7ab35e-0714-4308-b24d-69f08ca13900", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"Necesito que me ayudes en una tarea de clasificación de texto.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"Quiero que me clasifiques los textos una de las siguientes categorías:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "32f5bd08-c05a-4f33-8564-35e0eba3dfb5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'es',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'opiniones de productos',\n", + " prompt_labels = ['positiva','negativa','neutra'],\n", + " selector_k = 0,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "pred_targets = [pred if len(pred)==1 else [''] for pred in pred_targets]\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "504d176d-7402-44ab-8a64-acfcfc6c6459", + "metadata": {}, + "source": [ + "## ES - Con entrenamiento" + ] + }, + { + "cell_type": "markdown", + "id": "f62c3cd5-0724-46ac-99af-9e10de663976", + "metadata": {}, + "source": [ + "### Prueba 1" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "b4f26ab8-ab4f-4631-bcf9-ea48fa670171", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"Necesito que me ayudes en una tarea de clasificación de texto.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"Los textos que vas procesar del ambito de {__DOMAIN__}.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"Quiero que me clasifiques los textos una de las siguientes categorías:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Por favor argumenta tu respuesta paso a paso, explica por qué crees que\n", + " está justificada tu elección final, y 
asegúrate de que acabas tu\n", + " explicación con el nombre de la clase que has escogido como la\n", + " correcta, en minúscula y sin puntuación.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"En tu respuesta incluye sólo el nombre de la clase, como una única\n", + " palabra, en minúscula, sin puntuación, y sin añadir ninguna otra\n", + " afirmación o palabra.\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "1495f569-7656-4567-be30-75b7aedb401e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'es',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'opiniones de productos',\n", + " prompt_labels = ['positiva','negativa','neutra'],\n", + " selector_k = 10,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "model.train(\n", + " examples = train_reviews,\n", + " annotations = train_targets,\n", + ")\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "pred_targets = [pred if len(pred)==1 else [''] for pred in pred_targets]\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "8006de83-a826-478c-bee7-6bf38d0fa23a", + "metadata": {}, + "source": [ + "### Prueba 2" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "e2ed16b3-fe75-48e0-9ee5-801aefa95143", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"Necesito que me ayudes en una tarea de clasificación de texto.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"Los textos que vas procesar del ambito de {__DOMAIN__}.\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"Quiero que me clasifiques los textos una de las siguientes categorías:\n", + " {__LABELS__}.\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"Argumenta tu respuesta paso a paso.\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"En tu respuesta incluye sólo el nombre de la clase, como una única\n", + " 
respuesta\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "80360efc-af76-4d15-b935-586a79e33abc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'es',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'opiniones de productos',\n", + " prompt_labels = ['positiva','negativa','neutra'],\n", + " selector_k = 10,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "model.train(\n", + " examples = train_reviews,\n", + " annotations = train_targets,\n", + ")\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "pred_targets = [pred if len(pred)==1 else [''] for pred in pred_targets]\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "f50d8cd3-af0f-416a-8f9d-a6d57cd86ac6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['positiva'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " 
[''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra']]" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred_targets" + ] + }, + { + "cell_type": "markdown", + "id": "25261a04-f35a-4eef-8d57-1d5019bcd424", + "metadata": {}, + "source": [ + "### Prueba 3" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "74a71c53-6983-49d6-aac6-f2d0a2e88e2f", + "metadata": {}, + "outputs": [], + "source": [ + "prompt='''\n", + "TEMPLATE:\n", + " \"Necesito que me ayudes en una tarea de clasificación de texto.\n", + " {__PROMPT_DOMAIN__}\n", + " {__PROMPT_LABELS__}\n", + "\n", + " {__CHAIN_THOUGHT__}\n", + " {__ANSWER_FORMAT__}\"\n", + "\n", + "\n", + "PROMPT_DOMAIN:\n", + " \"\"\n", + "\n", + "\n", + "PROMPT_LABELS:\n", + " \"\"\n", + "\n", + "\n", + "PROMPT_DETAIL:\n", + " \"\"\n", + "\n", + "\n", + "CHAIN_THOUGHT:\n", + " \"\"\n", + "\n", + "\n", + "ANSWER_FORMAT:\n", + " \"\"\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "ea907475-c58a-4b77-904e-0a4f30d78ed4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "from sklearn.metrics import confusion_matrix\n", + "from promptmeteo import DocumentClassifier\n", + "\n", + "model = DocumentClassifier(\n", + " language = 'es',\n", + " model_name = 'text-davinci-003',\n", + " model_provider_name = 'openai',\n", + " model_provider_token = token,\n", + " prompt_domain = 'opiniones de productos',\n", + " prompt_labels = ['positiva','negativa','neutra'],\n", + " selector_k = 20,\n", + ")\n", + "\n", + "model.task.prompt.read_prompt(prompt)\n", + "\n", + "model.train(\n", + " examples = train_reviews,\n", + " annotations = train_targets,\n", + ")\n", + "\n", + "pred_targets = model.predict(test_reviews)\n", + "pred_targets = [pred if len(pred)==1 else [''] for pred in pred_targets]\n", + "\n", + "sns.heatmap(\n", + " confusion_matrix(test_targets, pred_targets),\n", + " annot=True,\n", + " cmap='Blues')" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "100497a5-bffa-47c8-905a-b9569cce41bb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " 
[''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " [''],\n", + " ['neutra'],\n", + " [''],\n", + " ['neutra']]" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred_targets" + ] + }, + { + "cell_type": "markdown", + "id": "572a6af8-112e-43d3-be26-79b100facd50", + "metadata": {}, + "source": [ + "## Conclusiones" + ] + }, + { + "cell_type": "markdown", + "id": "f2070ce4-4ede-44be-9764-d3186d3e62dd", + "metadata": {}, + "source": [ + "* Parece que con el modelo Flan-t5-small, el mejor resultado se obtiene añadiendo más ejemplos y quitando la instrucción del prompt\n", + "\n", + "* Parece que hay muchos errores asociados con que la respuesta tenga un espacio antes de la respuesta" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/promptmeteo/__init__.py b/promptmeteo/__init__.py index f7ee081..8b9c6a7 100644 --- a/promptmeteo/__init__.py +++ b/promptmeteo/__init__.py @@ -4,3 +4,4 @@ from .document_classifier import DocumentClassifier from 
.api_generator import APIGenerator from .api_formatter import APIFormatter +from .summarizer import Summarizer \ No newline at end of file diff --git a/promptmeteo/api_formatter.py b/promptmeteo/api_formatter.py index b6d7c15..c55ce86 100644 --- a/promptmeteo/api_formatter.py +++ b/promptmeteo/api_formatter.py @@ -382,4 +382,4 @@ def replace_values(orig_dict, replace_dict): sort_keys=False, ) - return api + return api \ No newline at end of file diff --git a/promptmeteo/models/__init__.py b/promptmeteo/models/__init__.py index 27db4c7..ec3fd32 100644 --- a/promptmeteo/models/__init__.py +++ b/promptmeteo/models/__init__.py @@ -29,6 +29,7 @@ from .hf_hub_api import HFHubApiLLM from .hf_pipeline import HFPipelineLLM from .google_vertexai import GoogleVertexAILLM +from .bedrock import BedrockLLM class ModelProvider(str, Enum): @@ -42,6 +43,7 @@ class ModelProvider(str, Enum): PROVIDER_2: str = "hf_hub_api" PROVIDER_3: str = "hf_pipeline" PROVIDER_4: str = "google-vertexai" + PROVIDER_5: str = "bedrock" class ModelFactory: @@ -57,6 +59,7 @@ class ModelFactory: ModelProvider.PROVIDER_2: HFHubApiLLM, ModelProvider.PROVIDER_3: HFPipelineLLM, ModelProvider.PROVIDER_3: GoogleVertexAILLM, + ModelProvider.PROVIDER_5: BedrockLLM } @classmethod @@ -87,6 +90,9 @@ def factory_method( elif model_provider_name == ModelProvider.PROVIDER_4.value: model_cls = GoogleVertexAILLM + + elif model_provider_name == ModelProvider.PROVIDER_5.value: + model_cls = BedrockLLM else: raise ValueError( diff --git a/promptmeteo/models/bedrock.py b/promptmeteo/models/bedrock.py new file mode 100644 index 0000000..6536d13 --- /dev/null +++ b/promptmeteo/models/bedrock.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 + +# Copyright (c) 2023 Paradigma Digital S.L. 
+ +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from enum import Enum +from typing import Dict +from typing import Optional +import os +import boto3 +from langchain.llms.bedrock import Bedrock +from langchain.embeddings import HuggingFaceEmbeddings + +from .base import BaseModel + + +class ModelTypes(str, Enum): + + """ + Enum of available model types. + """ + + AnthropicClaudeV2: str = "anthropic.claude-v2" + + @classmethod + def has_value( + cls, + value: str, + ) -> bool: + """ + Checks if the value is in the enum or not. + """ + + return value in cls._value2member_map_ + + +class ModelEnum(Enum): + + """ + Model Parameters. 
+ """ + + class AnthropicClaudeV2: + + """ + Default parameters for Anthropic Claude V2 + """ + + client = Bedrock + embedding = HuggingFaceEmbeddings + model_task: str = "text2text-generation" + params: dict = { + 'max_tokens_to_sample': 2048, + 'temperature': 0.3, + 'top_k': 250, + 'top_p': 0.999, + 'stop_sequences': ['Human:'] + } + + +class BedrockLLM(BaseModel): + + """ + Bedrock LLM model. + """ + + def __init__( + self, + model_name: Optional[str] = "", + model_params: Optional[Dict] = None, + model_provider_token: Optional[str] = "", + **kwargs + ) -> None: + """ + Make predictions using a model served through AWS Bedrock. + The LLM is instantiated with a boto3 `bedrock-runtime` client built from `**kwargs`. + """ + + if not ModelTypes.has_value(model_name): + raise ValueError( + f"`model_name`={model_name} not in supported model names: " + f"{[i.name for i in ModelTypes]}" + ) + self.boto3_bedrock = boto3.client('bedrock-runtime', **kwargs) + super(BedrockLLM, self).__init__() + + # Model name + model = ModelTypes(model_name).name + + # Model parameters + if not model_params: + model_params = ( + ModelEnum[model].value.params + if not model_params + else model_params + ) + self.model_params = model_params + + # Model + self._llm = ModelEnum[model].value.client( + model_id=model_name, + model_kwargs=self.model_params, + client = self.boto3_bedrock + ) + + embedding_name = "sentence-transformers/all-MiniLM-L6-v2" + if os.path.exists("/home/models/all-MiniLM-L6-v2"): + embedding_name = "/home/models/all-MiniLM-L6-v2" + + self._embeddings = HuggingFaceEmbeddings(model_name=embedding_name) diff --git a/promptmeteo/parsers/__init__.py b/promptmeteo/parsers/__init__.py index 9021f1e..000af85 100644 --- a/promptmeteo/parsers/__init__.py +++ b/promptmeteo/parsers/__init__.py @@ -27,6 +27,7 @@ from .base import BaseParser from .dummy_parser import DummyParser from .classification_parser import ClassificationParser +from .json_parser import JSONParser class ParserTypes(str, Enum): @@ 
-41,6 +42,8 @@ class ParserTypes(str, Enum): PARSER_4: str = "code-generation" PARSER_5: str = "api-generation" PARSER_6: str = "api-correction" + PARSER_7: str = "json-info-extraction" + PARSER_8: str = "summarization" class ParserFactory: @@ -77,6 +80,12 @@ def factory_method( elif task_type == ParserTypes.PARSER_6.value: parser_cls = ApiParser + + elif task_type == ParserTypes.PARSER_7.value: + parser_cls = JSONParser + + elif task_type == ParserTypes.PARSER_8.value: + parser_cls = DummyParser else: raise ValueError( diff --git a/promptmeteo/parsers/json_parser.py b/promptmeteo/parsers/json_parser.py new file mode 100644 index 0000000..c7f3ca9 --- /dev/null +++ b/promptmeteo/parsers/json_parser.py @@ -0,0 +1,67 @@ +#!/usr/bin/python3 + +# Copyright (c) 2023 Paradigma Digital S.L. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +from typing import List +import re +from .base import BaseParser +import regex +import json + + +class JSONParser(BaseParser): + + """ + Parser for potential JSON outputs + """ + + def run( + self, + text: str, + ) -> List[str]: + """ + Given a response string from an LLM, returns the response expected for + the task. + """ + + try: + json_output = self._preprocess(text) + json_obtanaied = json.loads(json_output) + return json_output + except: + return "" + + + def _preprocess( + self, + text: str, + ) -> str: + """ + Preprocess output string before parsing result to solve common mistakes + such as end-of-line presence and beginning and finishing with empty + space. + """ + pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}') + str_json = pattern.findall(text)[0] + + str_json = str_json.replace("'",'"') + + return str_json diff --git a/promptmeteo/prompts/anthropic.claude-v2_en_classification.prompt b/promptmeteo/prompts/anthropic.claude-v2_en_classification.prompt new file mode 100644 index 0000000..11bbfe9 --- /dev/null +++ b/promptmeteo/prompts/anthropic.claude-v2_en_classification.prompt @@ -0,0 +1,61 @@ +# Copyright (c) 2023 Paradigma Digital S.L. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +TEMPLATE: + "I need you to help me with a text classification task. + {__PROMPT_DOMAIN__} + {__PROMPT_LABELS__} + + {__CHAIN_THOUGHT__} + {__ANSWER_FORMAT__} + {__SHOT_EXAMPLES__} + {__PROMPT_SAMPLE__}" + + +PROMPT_DOMAIN: + "The texts you will be processing are from the {__DOMAIN__} domain." + + +PROMPT_LABELS: + "I want you to classify the texts into one of the following categories: + {__LABELS__}." + + +PROMPT_DETAIL: + "" + +SHOT_EXAMPLES: + "Examples:\n\n{__EXAMPLES__}" + +PROMPT_SAMPLE: + "\n\n{__SAMPLE__}\n" + +CHAIN_THOUGHT: + "Please provide a step-by-step argument for your answer, explain why you + believe your final choice is justified, and make sure to conclude your + explanation with the name of the class you have selected as the correct + one, in lowercase and without punctuation." + + +ANSWER_FORMAT: + "In your response, include only the name of the class as a single word, in + lowercase, without punctuation, and without adding any other statements or + words." diff --git a/promptmeteo/prompts/anthropic.claude-v2_en_summarization.prompt b/promptmeteo/prompts/anthropic.claude-v2_en_summarization.prompt new file mode 100644 index 0000000..5399512 --- /dev/null +++ b/promptmeteo/prompts/anthropic.claude-v2_en_summarization.prompt @@ -0,0 +1,51 @@ +# Copyright (c) 2023 Paradigma Digital S.L. 
+ +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +TEMPLATE: + "Human: {__PROMPT_DOMAIN__} + {__PROMPT_SAMPLE__} + {__PROMPT_DETAIL__} + {__CHAIN_THOUGHT__} + {__ANSWER_FORMAT__}" + +PROMPT_SAMPLE: + " + \n\"{__SAMPLE__}\"\n + " + +PROMPT_DOMAIN: + "{__DOMAIN__}" + +PROMPT_DETAIL: + " + Based on the text segment, please build a precise summary and do not invent information. + + Assistant: Here is the summary: + " + +SHOT_EXAMPLES: + "" + +CHAIN_THOUGHT: + "" + +ANSWER_FORMAT: + "" diff --git a/promptmeteo/prompts/anthropic.claude-v2_es_classification.prompt b/promptmeteo/prompts/anthropic.claude-v2_es_classification.prompt new file mode 100644 index 0000000..65df7e4 --- /dev/null +++ b/promptmeteo/prompts/anthropic.claude-v2_es_classification.prompt @@ -0,0 +1,57 @@ +# Copyright (c) 2023 Paradigma Digital S.L. 
+ +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +TEMPLATE: + "Necesito que me ayudes en una tarea de clasificación de texto. + {__PROMPT_DOMAIN__} + {__PROMPT_LABELS__} + + {__CHAIN_THOUGHT__} + {__ANSWER_FORMAT__} + {__SHOT_EXAMPLES__} + {__PROMPT_SAMPLE__}" + + +PROMPT_DOMAIN: + "Los textos que vas procesar del ámbito de {__DOMAIN__}." + + +PROMPT_LABELS: + "Quiero que me clasifiques los textos una de las siguientes categorías: + {__LABELS__}." + + +PROMPT_DETAIL: + "" + +SHOT_EXAMPLES: + "Ejemplos:\n\n{__EXAMPLES__}" + +PROMPT_SAMPLE: + "\n\n{__SAMPLE__}\n" + +CHAIN_THOUGHT: + "" + + +ANSWER_FORMAT: + "En tu respuesta incluye sólo el nombre de la clase, como una única + palabra." 
diff --git a/promptmeteo/prompts/anthropic.claude-v2_es_summarization.prompt b/promptmeteo/prompts/anthropic.claude-v2_es_summarization.prompt new file mode 100644 index 0000000..d124fe0 --- /dev/null +++ b/promptmeteo/prompts/anthropic.claude-v2_es_summarization.prompt @@ -0,0 +1,51 @@ +# Copyright (c) 2023 Paradigma Digital S.L. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +TEMPLATE: + "Human: {__PROMPT_DOMAIN__} + {__PROMPT_SAMPLE__} + {__PROMPT_DETAIL__} + {__CHAIN_THOUGHT__} + {__ANSWER_FORMAT__}" + +PROMPT_SAMPLE: + " + \n\"{__SAMPLE__}\"\n + " + +PROMPT_DOMAIN: + "{__DOMAIN__}" + +PROMPT_DETAIL: + " + Basado en el segmento de texto, por favor genera un resumen preciso y no invente información. 
+ + Assistant: A continuación muestro el resumen: + " + +SHOT_EXAMPLES: + "" + +CHAIN_THOUGHT: + "" + +ANSWER_FORMAT: + "" diff --git a/promptmeteo/prompts/base.py b/promptmeteo/prompts/base.py index 44c1fd3..2f7f155 100644 --- a/promptmeteo/prompts/base.py +++ b/promptmeteo/prompts/base.py @@ -193,11 +193,8 @@ def run( ) # Domain - prompt_variables["__PROMPT_DOMAIN__"] = ( - self.PROMPT_DOMAIN.format(__DOMAIN__=self._prompt_domain) - if self._prompt_domain - else "" - ) + prompt_variables["__PROMPT_DOMAIN__"] = self.PROMPT_DOMAIN.format(__DOMAIN__=self._prompt_domain) + # Detail prompt_detail = ( @@ -205,11 +202,8 @@ def run( if isinstance(self._prompt_detail, list) else self._prompt_detail ) - prompt_variables["__PROMPT_DETAIL__"] = ( - self.PROMPT_DETAIL.format(__DETAIL__=prompt_detail) - if self._prompt_detail - else "" - ) + prompt_variables["__PROMPT_DETAIL__"] = self.PROMPT_DETAIL.format(__DETAIL__=prompt_detail) + return PromptTemplate.from_template( PromptTemplate.from_template(self.TEMPLATE).format( diff --git a/promptmeteo/prompts/fake-static_es_classification.prompt b/promptmeteo/prompts/fake-static_es_classification.prompt index cf59492..2df720b 100644 --- a/promptmeteo/prompts/fake-static_es_classification.prompt +++ b/promptmeteo/prompts/fake-static_es_classification.prompt @@ -31,7 +31,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/fake-static_es_ner.prompt b/promptmeteo/prompts/fake-static_es_ner.prompt index 16c3eea..c2c11bd 100644 --- a/promptmeteo/prompts/fake-static_es_ner.prompt +++ b/promptmeteo/prompts/fake-static_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." 
PROMPT_LABELS: diff --git a/promptmeteo/prompts/google-flan-t5-small_es_classification.prompt b/promptmeteo/prompts/google-flan-t5-small_es_classification.prompt index 42eb520..65df7e4 100644 --- a/promptmeteo/prompts/google-flan-t5-small_es_classification.prompt +++ b/promptmeteo/prompts/google-flan-t5-small_es_classification.prompt @@ -31,7 +31,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/google-flan-t5-small_es_ner.prompt b/promptmeteo/prompts/google-flan-t5-small_es_ner.prompt index a5a08af..8dbe0fd 100644 --- a/promptmeteo/prompts/google-flan-t5-small_es_ner.prompt +++ b/promptmeteo/prompts/google-flan-t5-small_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/google-flan-t5-xxl_es_classification.prompt b/promptmeteo/prompts/google-flan-t5-xxl_es_classification.prompt index c2f0018..00360c0 100644 --- a/promptmeteo/prompts/google-flan-t5-xxl_es_classification.prompt +++ b/promptmeteo/prompts/google-flan-t5-xxl_es_classification.prompt @@ -31,7 +31,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/google-flan-t5-xxl_es_ner.prompt b/promptmeteo/prompts/google-flan-t5-xxl_es_ner.prompt index 16c3eea..c2c11bd 100644 --- a/promptmeteo/prompts/google-flan-t5-xxl_es_ner.prompt +++ b/promptmeteo/prompts/google-flan-t5-xxl_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." 
PROMPT_LABELS: diff --git a/promptmeteo/prompts/gpt-3.5-turbo-16k_es_summarization.prompt b/promptmeteo/prompts/gpt-3.5-turbo-16k_es_summarization.prompt new file mode 100644 index 0000000..b0c7a9d --- /dev/null +++ b/promptmeteo/prompts/gpt-3.5-turbo-16k_es_summarization.prompt @@ -0,0 +1,49 @@ +# Copyright (c) 2023 Paradigma Digital S.L. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +TEMPLATE: + "{__PROMPT_DOMAIN__} + {__PROMPT_SAMPLE__} + {__PROMPT_DETAIL__} + {__CHAIN_THOUGHT__} + {__ANSWER_FORMAT__}" + +PROMPT_SAMPLE: + " + \n\"{__SAMPLE__}\"\n + " + +PROMPT_DOMAIN: + "{__DOMAIN__}" + +PROMPT_DETAIL: + " + Basado en el segmento de texto, por favor genera un resumen preciso y no invente información. 
+ " + +SHOT_EXAMPLES: + "" + +CHAIN_THOUGHT: + "" + +ANSWER_FORMAT: + "" diff --git a/promptmeteo/prompts/gpt-3.5-turbo-instruct_es_classification.prompt b/promptmeteo/prompts/gpt-3.5-turbo-instruct_es_classification.prompt index c2f0018..00360c0 100644 --- a/promptmeteo/prompts/gpt-3.5-turbo-instruct_es_classification.prompt +++ b/promptmeteo/prompts/gpt-3.5-turbo-instruct_es_classification.prompt @@ -31,7 +31,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-bison-32k_es_classification.prompt b/promptmeteo/prompts/text-bison-32k_es_classification.prompt index c2f0018..00360c0 100644 --- a/promptmeteo/prompts/text-bison-32k_es_classification.prompt +++ b/promptmeteo/prompts/text-bison-32k_es_classification.prompt @@ -31,7 +31,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-bison-32k_es_ner.prompt b/promptmeteo/prompts/text-bison-32k_es_ner.prompt index 16c3eea..c2c11bd 100644 --- a/promptmeteo/prompts/text-bison-32k_es_ner.prompt +++ b/promptmeteo/prompts/text-bison-32k_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-bison@001_es_classification.prompt b/promptmeteo/prompts/text-bison@001_es_classification.prompt index 1680e8b..874738e 100644 --- a/promptmeteo/prompts/text-bison@001_es_classification.prompt +++ b/promptmeteo/prompts/text-bison@001_es_classification.prompt @@ -30,7 +30,7 @@ TEMPLATE: {__PROMPT_SAMPLE__}" PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." 
PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-bison@001_es_ner.prompt b/promptmeteo/prompts/text-bison@001_es_ner.prompt index 16c3eea..c2c11bd 100644 --- a/promptmeteo/prompts/text-bison@001_es_ner.prompt +++ b/promptmeteo/prompts/text-bison@001_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-bison_es_classification.prompt b/promptmeteo/prompts/text-bison_es_classification.prompt index cf59492..2df720b 100644 --- a/promptmeteo/prompts/text-bison_es_classification.prompt +++ b/promptmeteo/prompts/text-bison_es_classification.prompt @@ -31,7 +31,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-bison_es_ner.prompt b/promptmeteo/prompts/text-bison_es_ner.prompt index 16c3eea..c2c11bd 100644 --- a/promptmeteo/prompts/text-bison_es_ner.prompt +++ b/promptmeteo/prompts/text-bison_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." PROMPT_LABELS: diff --git a/promptmeteo/prompts/text-davinci-003_es_ner.prompt b/promptmeteo/prompts/text-davinci-003_es_ner.prompt index 16c3eea..c2c11bd 100644 --- a/promptmeteo/prompts/text-davinci-003_es_ner.prompt +++ b/promptmeteo/prompts/text-davinci-003_es_ner.prompt @@ -29,7 +29,7 @@ TEMPLATE: PROMPT_DOMAIN: - "Los textos que vas procesar del ambito de {__DOMAIN__}." + "Los textos que vas procesar del ámbito de {__DOMAIN__}." 
PROMPT_LABELS: diff --git a/promptmeteo/summarizer.py b/promptmeteo/summarizer.py new file mode 100644 index 0000000..3423b7a --- /dev/null +++ b/promptmeteo/summarizer.py @@ -0,0 +1,159 @@ +#%% +#!/usr/bin/python3 +# Copyright (c) 2023 Paradigma Digital S.L. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+import re +import tarfile +import tempfile +import json +import os + +import yaml +from copy import deepcopy +from typing import List + +try: + from typing import Self +except ImportError: + from typing_extensions import Self +from langchain.prompts import PromptTemplate + +from .base import BaseUnsupervised +from .tasks import TaskTypes, TaskBuilder +from .tools import add_docstring_from +from .validations import version_validation + + +class Summarizer(BaseUnsupervised): + + """ + Class for text summarization + """ + TASK_TYPE = TaskTypes.SUMMARIZATION.value + + @add_docstring_from(BaseUnsupervised.__init__) + def __init__( + self, + **kwargs, + ) -> None: + """ + Example + ------- + + >>> from promptmeteo import Summarizer + + >>> model = Summarizer( + >>> language="es", + >>> prompt_domain="A partir del siguiente texto:", + >>> model_name="anthropic.claude-v2", + >>> model_provider_name = "bedrock" + >>> ) + >>> model.predict([text]) + """ + super(Summarizer, self).__init__(**kwargs) + + + + + + @add_docstring_from(BaseUnsupervised.train) + def train( + self, + ) -> Self: + """ + Train the Summarizer using a single empty example. + + Parameters + ---------- + + None + + + Returns + ------- + + self + + """ + super(Summarizer, self).train(examples=[""]) + + return self + + @classmethod + @add_docstring_from(BaseUnsupervised.load_model) + def load_model( + cls, + model_path: str, + ) -> Self: + """ + Loads a model artifact to make new predictions. + + Parameters + ---------- + + model_path : str + + + Returns + ------- + + self : Promptmeteo + + """ + + model_dir = os.path.dirname(model_path) + model_name = os.path.basename(model_path) + + if not model_name.endswith(".meteo"): + raise ValueError( + f"{cls.__name__} error in `load_model()`. " + f'model_path="{model_path}" has a bad model name extension. ' + f"Model name must end with `.meteo` (i.e. 
`./model.meteo`)" + ) + + if not os.path.exists(model_path): + raise ValueError( + f"{cls.__name__} error in `load_model()`. " + f"directory {model_dir} does not exists." + ) + + with tempfile.TemporaryDirectory() as tmp: + with tarfile.open(model_path, "r:gz") as tar: + tar.extractall(tmp) + + init_tmp_path = os.path.join( + tmp, f"{os.path.splitext(model_name)[0]}.init" + ) + + with open(init_tmp_path) as f_init: + params = json.load(f_init) + + self = cls(**params) + + self.builder.build_selector_by_load( + model_path=os.path.join(tmp, model_name), + selector_type=self.SELECTOR_TYPE, + selector_k=self._selector_k, + selector_algorithm=self._selector_algorithm, + ) + + self._is_trained = True + + return self diff --git a/promptmeteo/tasks/task_builder.py b/promptmeteo/tasks/task_builder.py index 5f77944..36cec4f 100644 --- a/promptmeteo/tasks/task_builder.py +++ b/promptmeteo/tasks/task_builder.py @@ -50,6 +50,7 @@ class TaskTypes(str, Enum): CODE_GENERATION: str = "code-generation" API_GENERATION: str = "api-generation" API_CORRECTION: str = "api-correction" + SUMMARIZATION: str = "summarization" class TaskBuilder: diff --git a/pyproject.toml b/pyproject.toml index 13d397d..fe5ee8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "pydantic == 1.10.11", "faiss-cpu == 1.7.4", "tiktoken==0.4.0", + "boto3==1.34.23" ] [tool.setuptools_scm] diff --git a/tests/test_models.py b/tests/test_models.py index fa3cbc7..11c35fd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -57,6 +57,34 @@ def test_model_openai(self): ) assert error.value.args[0] == invalid_provider + + def test_model_bedrock(self): + from promptmeteo.models.bedrock import BedrockLLM + from promptmeteo.models.bedrock import ModelTypes + + for model_name in ModelTypes: + BedrockLLM( + model_name=model_name.value, + model_params={}, + model_provider_token="TEST_TOKEN", + region_name="us-east-1" + ) + + with pytest.raises(ValueError) as error: + BedrockLLM( + 
model_name="WRONG_NAME", + model_params={}, + model_provider_token="TEST_TOKEN", + region_name="us-east-1" + ) + + invalid_provider = ( + "`model_name`=WRONG_NAME not in supported model names: " + f"{[i.name for i in ModelTypes]}" + ) + assert error.value.args[0] == invalid_provider + + def test_model_fakellm(self): from promptmeteo.models.fake_llm import ModelTypes from promptmeteo.models.fake_llm import FakeLLM diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 8a4ceef..2a9885e 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -4,6 +4,7 @@ from promptmeteo.parsers import ParserFactory from promptmeteo.parsers.dummy_parser import DummyParser from promptmeteo.parsers.classification_parser import ClassificationParser +from promptmeteo.parsers.json_parser import JSONParser class Testparsers: @@ -37,3 +38,31 @@ def test_classification_parser(self): assert ["true"] == parser.run("True") assert ["true"] == parser.run("blabla, true, blabla") assert ["true", "false"] == parser.run("true, false") + + def test_json_parser(self): + parser = JSONParser(prompt_labels=["true","false"]) + + wrong_json = """{ + "item1":[1,2,3,4], + 'item2':{ + "item2.1":"test item", + 'item2.2":["i211","i212",'i213'] + }, + "item3":"this is the third item" + }""" + + correct_json = """{ + "item1":[1,2,3,4], + "item2":{ + "item2.1":"test item", + "item2.2":["i211","i212","i213"] + }, + "item3":"this is the third item" + }""" + + assert correct_json == parser.run(wrong_json) + assert parser.run(correct_json) == parser.run(wrong_json) + import json + json.loads(parser.run(wrong_json)) + assert "" == parser.run("This is not a json format") + assert "" == parser.run("{This is not a complete json") diff --git a/tests/tools/dictionary_checker.py b/tests/tools/dictionary_checker.py index e2fe023..977e8e8 100644 --- a/tests/tools/dictionary_checker.py +++ b/tests/tools/dictionary_checker.py @@ -9,7 +9,7 @@ class DictionaryChecker: ADDED_WORDS = { "en": {"openapi": 1, "api": 
1, "schema": 1, "schemas": 1}, - "es": {"openapi": 1, "api": 1, "sample":1, "examples":1}, + "es": {"openapi": 1, "api": 1, "sample":1, "examples":1, "generes":1, "human":1, "assistant":1}, } def __init__(self, language: str):