|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 |
| - "execution_count": 1, |
| 5 | + "execution_count": 344, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [ |
| 8 | + { |
| 9 | + "name": "stdout", |
| 10 | + "output_type": "stream", |
| 11 | + "text": [ |
| 12 | + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", |
| 13 | + "To disable this warning, you can either:\n", |
| 14 | + "\t- Avoid using `tokenizers` before the fork if possible\n", |
| 15 | + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", |
| 16 | + "Collecting thefuzz\n", |
| 17 | + " Downloading thefuzz-0.19.0-py2.py3-none-any.whl (17 kB)\n", |
| 18 | + "Installing collected packages: thefuzz\n", |
| 19 | + "Successfully installed thefuzz-0.19.0\n" |
| 20 | + ] |
| 21 | + } |
| 22 | + ], |
| 23 | + "source": [ |
| 24 | + "#!pip install thefuzz" |
| 25 | + ] |
| 26 | + }, |
| 27 | + { |
| 28 | + "cell_type": "code", |
| 29 | + "execution_count": 345, |
6 | 30 | "metadata": {},
|
7 | 31 | "outputs": [],
|
8 | 32 | "source": [
|
9 | 33 | "from gramformer import Gramformer\n",
|
10 | 34 | "import torch\n",
|
11 | 35 | "import spacy\n",
|
12 |
| - "import random" |
| 36 | + "import random\n", |
| 37 | + "from thefuzz import fuzz\n" |
13 | 38 | ]
|
14 | 39 | },
|
15 | 40 | {
|
16 | 41 | "cell_type": "markdown",
|
17 | 42 | "metadata": {},
|
18 | 43 | "source": [
|
19 |
| - "### Original GrammarModel" |
| 44 | + "## Original GrammarModel" |
20 | 45 | ]
|
21 | 46 | },
|
22 | 47 | {
|
|
210 | 235 | "cell_type": "markdown",
|
211 | 236 | "metadata": {},
|
212 | 237 | "source": [
|
213 |
| - "### Remove correction for error types ORTH, OTHER (and PUNCT?)" |
| 238 | + "## Improvements to grammar correction" |
214 | 239 | ]
|
215 | 240 | },
|
216 | 241 | {
|
217 | 242 | "cell_type": "code",
|
218 |
| - "execution_count": 327, |
| 243 | + "execution_count": 413, |
219 | 244 | "metadata": {},
|
220 | 245 | "outputs": [],
|
221 | 246 | "source": [
|
|
256 | 281 | " last_user_input = chat_history[-1].get('text')\n",
|
257 | 282 | " corrected_sentence, correction_message = self.grammar_correction(last_user_input)\n",
|
258 | 283 | " error_types = self.get_edits(last_user_input, corrected_sentence)\n",
|
259 |
| - " overlap_ignore_errors = any(item in error_types for item in self.ignore_errors)\n", |
260 |
| - "\n", |
261 |
| - " if correction_message and (overlap_ignore_errors is False):\n", |
| 284 | + " relevant_error = any(error not in self.ignore_errors for error in error_types) # check if there is an error in the sentence which is not in the ignore list \n", |
| 285 | + " token_sort_ratio = fuzz.token_sort_ratio(corrected_sentence, last_user_input) # calculate token similarity (ignoring punctuation and casing)\n", |
| 286 | + " print(f\"correction_message: {correction_message}\\nErrors detected: {error_types}\\nPresence of a relevant error: {relevant_error}\\nSimilarity Score: {token_sort_ratio}\") # for debugging only\n", |
| 287 | + " \n", |
| 288 | + " if correction_message and relevant_error and token_sort_ratio != 100:\n", |
262 | 289 | " chat_history.append(\n",
|
263 | 290 | " {\n",
|
264 | 291 | " 'sender': 'bot',\n",
|
265 | 292 | " 'text': correction_message,\n",
|
266 |
| - " 'correction': True,\n", |
267 |
| - " 'error_type': error_types\n", |
| 293 | + " 'correction': True\n", |
268 | 294 | " }\n",
|
269 | 295 | " )\n",
|
| 296 | + " \n", |
270 | 297 | " return chat_history \n",
|
271 | 298 | "\n",
|
272 | 299 | "\n",
|
|
295 | 322 | },
|
296 | 323 | {
|
297 | 324 | "cell_type": "code",
|
298 |
| - "execution_count": 328, |
| 325 | + "execution_count": 414, |
299 | 326 | "metadata": {},
|
300 | 327 | "outputs": [
|
301 | 328 | {
|
|
310 | 337 | "gm2 = GrammarModel2(models=1, use_gpu=False)"
|
311 | 338 | ]
|
312 | 339 | },
|
| 340 | + { |
| 341 | + "cell_type": "markdown", |
| 342 | + "metadata": {}, |
| 343 | + "source": [ |
| 344 | + "### 1. Remove correction when no relevant errors are detected(other than those in self.ignore_errors)" |
| 345 | + ] |
| 346 | + }, |
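| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "A minimal sketch of the relevant-error check in isolation. The `ignore_errors` contents and the helper name below are illustrative assumptions only; in `GrammarModel2` the actual list lives in `self.ignore_errors`." |
| + ] |
| + }, |
| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "# Sketch only: the relevant-error check on its own (ignore list assumed for illustration)\n", |
| + "ignore_errors = ['ORTH', 'OTHER']  # assumed contents; GrammarModel2 keeps this in self.ignore_errors\n", |
| + "\n", |
| + "def has_relevant_error(error_types, ignore_errors=ignore_errors):\n", |
| + "    # True if at least one detected error type is not in the ignore list\n", |
| + "    return any(error not in ignore_errors for error in error_types)\n", |
| + "\n", |
| + "print(has_relevant_error(['ORTH']))          # False: only ignorable error types\n", |
| + "print(has_relevant_error(['ORTH', 'VERB']))  # True: 'VERB' is not ignored\n", |
| + "print(has_relevant_error([]))                # False: no errors detected" |
| + ] |
| + }, |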
313 | 347 | {
|
314 | 348 | "cell_type": "code",
|
315 |
| - "execution_count": 341, |
| 349 | + "execution_count": 450, |
316 | 350 | "metadata": {},
|
317 | 351 | "outputs": [],
|
318 | 352 | "source": [
|
319 |
| - "chat_history_ex1 = [{'sender': 'User', 'text': 'Hi bot'}]" |
| 353 | + "chat_history_ex1 = [{'sender': 'User', 'text': 'where are you goin?'}]" |
320 | 354 | ]
|
321 | 355 | },
|
322 | 356 | {
|
323 | 357 | "cell_type": "code",
|
324 |
| - "execution_count": 342, |
| 358 | + "execution_count": 451, |
325 | 359 | "metadata": {},
|
326 |
| - "outputs": [], |
| 360 | + "outputs": [ |
| 361 | + { |
| 362 | + "name": "stdout", |
| 363 | + "output_type": "stream", |
| 364 | + "text": [ |
| 365 | + "correction_message: I think you meant: \"where are you going?\" \n", |
| 366 | + "Errors detected: ['PUNCT']\n", |
| 367 | + "Presence of a relevant error: True\n", |
| 368 | + "Similarity Score: 97\n" |
| 369 | + ] |
| 370 | + } |
| 371 | + ], |
327 | 372 | "source": [
|
328 | 373 | "chat_history = gm2.add_correction_to_chat_history(chat_history_ex1)"
|
329 | 374 | ]
|
330 | 375 | },
|
331 | 376 | {
|
332 | 377 | "cell_type": "code",
|
333 |
| - "execution_count": 343, |
| 378 | + "execution_count": 452, |
334 | 379 | "metadata": {},
|
335 | 380 | "outputs": [
|
336 | 381 | {
|
337 | 382 | "data": {
|
338 | 383 | "text/plain": [
|
339 |
| - "[{'sender': 'User', 'text': 'Hi bot'},\n", |
| 384 | + "[{'sender': 'User', 'text': 'where are you goin?'},\n", |
340 | 385 | " {'sender': 'bot',\n",
|
341 |
| - " 'text': 'This would be better said like this: \"Hi booch!\" ',\n", |
342 |
| - " 'correction': True,\n", |
343 |
| - " 'error_type': ['NOUN']}]" |
| 386 | + " 'text': 'I think you meant: \"where are you going?\" ',\n", |
| 387 | + " 'correction': True}]" |
344 | 388 | ]
|
345 | 389 | },
|
346 |
| - "execution_count": 343, |
| 390 | + "execution_count": 452, |
347 | 391 | "metadata": {},
|
348 | 392 | "output_type": "execute_result"
|
349 | 393 | }
|
|
352 | 396 | "chat_history"
|
353 | 397 | ]
|
354 | 398 | },
|
| 399 | + { |
| 400 | + "cell_type": "markdown", |
| 401 | + "metadata": {}, |
| 402 | + "source": [ |
| 403 | + "### 2. Remove correction when input and correction are very similar" |
| 404 | + ] |
| 405 | + }, |
355 | 406 | {
|
356 | 407 | "cell_type": "code",
|
357 |
| - "execution_count": null, |
| 408 | + "execution_count": 369, |
358 | 409 | "metadata": {},
|
359 | 410 | "outputs": [],
|
360 |
| - "source": [] |
| 411 | + "source": [ |
| 412 | + "# Example of similar sentences which should not be corrected\n", |
| 413 | + "ex1 = \"Hi bot!\"\n", |
| 414 | + "ex2 = \"Hi bot\"\n", |
| 415 | + "ex3 = \"Hi bot.\"\n", |
| 416 | + "ex4 = \"Hi Bot\"\n", |
| 417 | + "ex5 = \"Hi Bot.\"\n", |
| 418 | + "ex6 = \"Hi Bot!\"\n", |
| 419 | + "ex7 = \"Hi Bot Bot!\" # should lead to lower token sort ratio, but same token set ratio compared to ex1" |
| 420 | + ] |
| 421 | + }, |
| 422 | + { |
| 423 | + "cell_type": "code", |
| 424 | + "execution_count": 364, |
| 425 | + "metadata": {}, |
| 426 | + "outputs": [], |
| 427 | + "source": [ |
| 428 | + "# Measure the similarity between 0 and 100 to define a threshold.\n", |
| 429 | + "\n", |
| 430 | + "def measure_similarity(sentence1, sentence2):\n", |
| 431 | + " simple_ratio = fuzz.ratio(sentence1, sentence2)\n", |
| 432 | + " print(f\"simple ratio similarity score: {simple_ratio}\")\n", |
| 433 | + "\n", |
| 434 | + " partial_ratio = fuzz.partial_ratio(sentence1, sentence2) # Return the ratio of the most similar substring.\n", |
| 435 | + " print(f\"partial ratio similarity score: {partial_ratio}\")\n", |
| 436 | + "\n", |
| 437 | + " ratio = fuzz.ratio(sentence1, sentence2)\n", |
| 438 | + " print(f\"ratio similarity score: {ratio}\")\n", |
| 439 | + "\n", |
| 440 | + " token_sort_ratio = fuzz.token_sort_ratio(sentence1, sentence2) # Return a measure of the sequences' similarity sorting the token before comparing. This is what we want to set as threshold.\n", |
| 441 | + " print(f\"token sort ratio similarity score: {token_sort_ratio}\")\n", |
| 442 | + "\n", |
| 443 | + " token_set_ratio = fuzz.token_set_ratio(sentence1, sentence2) # Measures similarity between unique tokens.\n", |
| 444 | + " print(f\"token set ratio similarity score: {token_set_ratio}\")" |
| 445 | + ] |
| 446 | + }, |
| 447 | + { |
| 448 | + "cell_type": "code", |
| 449 | + "execution_count": 378, |
| 450 | + "metadata": {}, |
| 451 | + "outputs": [ |
| 452 | + { |
| 453 | + "name": "stdout", |
| 454 | + "output_type": "stream", |
| 455 | + "text": [ |
| 456 | + "simple ratio similarity score: 67\n", |
| 457 | + "partial ratio similarity score: 71\n", |
| 458 | + "ratio similarity score: 67\n", |
| 459 | + "token sort ratio similarity score: 75\n", |
| 460 | + "token set ratio similarity score: 100\n" |
| 461 | + ] |
| 462 | + } |
| 463 | + ], |
| 464 | + "source": [ |
| 465 | + "measure_similarity(ex1, ex7)" |
| 466 | + ] |
| 467 | + }, |
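| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "A minimal sketch of gating a correction on the token sort ratio, mirroring the `token_sort_ratio != 100` check in `add_correction_to_chat_history`; the function name and the threshold parameter are illustrative assumptions only." |
| + ] |
| + }, |
| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "# Sketch only: skip corrections that only differ in casing, punctuation or word order\n", |
| + "from thefuzz import fuzz\n", |
| + "\n", |
| + "def should_send_correction(original, corrected, threshold=100):\n", |
| + "    # token_sort_ratio lowercases, strips punctuation and sorts tokens before comparing,\n", |
| + "    # so purely cosmetic differences score 100 and are filtered out\n", |
| + "    return fuzz.token_sort_ratio(original, corrected) < threshold\n", |
| + "\n", |
| + "print(should_send_correction(\"Hi bot\", \"Hi Bot!\"))  # False: identical up to casing/punctuation\n", |
| + "print(should_send_correction(\"where are you goin?\", \"where are you going?\"))  # True: a real spelling fix" |
| + ] |
| + }, |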
| 468 | + { |
| 469 | + "cell_type": "code", |
| 470 | + "execution_count": 426, |
| 471 | + "metadata": {}, |
| 472 | + "outputs": [], |
| 473 | + "source": [ |
| 474 | + "## Correction accuracy check" |
| 475 | + ] |
361 | 476 | },
|
362 | 477 | {
|
363 | 478 | "cell_type": "code",
|
|