
Commit aff5961

improve grammar correction by adding a check for similarity with thefuzz
1 parent b61bc1b commit aff5961

File tree

3 files changed: +147 −28 lines changed


chatbot/grammar_correction.py

+7 −4
@@ -2,6 +2,7 @@
 import torch
 import spacy
 import random
+from thefuzz import fuzz

 class GrammarModel(Gramformer):
     """
@@ -39,17 +40,19 @@ def add_correction_to_chat_history(self, chat_history):
         last_user_input = chat_history[-1].get('text')
         corrected_sentence, correction_message = self.grammar_correction(last_user_input)
         error_types = self.get_edits(last_user_input, corrected_sentence)
-        overlap_ignore_errors = any(item in error_types for item in self.ignore_errors)
-
-        if correction_message and (overlap_ignore_errors is False):
+        relevant_error = any(error not in self.ignore_errors for error in error_types)  # check if there is an error in the sentence which is not in the ignore list
+        token_sort_ratio = fuzz.token_sort_ratio(corrected_sentence, last_user_input)  # calculate token similarity (ignoring punctuation and casing)
+
+        if correction_message and relevant_error and token_sort_ratio != 100:
             chat_history.append(
                 {
                     'sender': 'bot',
                     'text': correction_message,
                     'correction': True
                 }
             )
-        return chat_history
+
+        return chat_history


     def _get_edits(self, input_sentence, corrected_sentence):
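Taken together, the new gate only sends a correction when the detected edits include at least one error type outside ignore_errors and when the corrected sentence differs from the user input by more than casing and punctuation. A minimal standalone sketch of that check follows; the values of ignore_errors, error_types, user_input and corrected are illustrative placeholders, not taken from the repository.

from thefuzz import fuzz

# Illustrative placeholder values; in the commit these come from the Gramformer-based model above.
ignore_errors = ['ORTH', 'OTHER']   # error types the bot should not comment on
error_types = ['PUNCT']             # edit types detected in the last user input
user_input = "where are you goin?"
corrected = "where are you going?"

# Relevant only if at least one detected error is not on the ignore list.
relevant_error = any(error not in ignore_errors for error in error_types)

# token_sort_ratio lowercases, strips punctuation and sorts tokens before comparing,
# so two inputs that differ only in casing or punctuation score 100.
similarity = fuzz.token_sort_ratio(corrected, user_input)

if relevant_error and similarity != 100:
    print(f'I think you meant: "{corrected}"')
else:
    print("No correction worth sending")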

chatbot/notebooks/grammar_model_improvements.ipynb

+138 −23
@@ -2,21 +2,46 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 344,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+      "Collecting thefuzz\n",
+      " Downloading thefuzz-0.19.0-py2.py3-none-any.whl (17 kB)\n",
+      "Installing collected packages: thefuzz\n",
+      "Successfully installed thefuzz-0.19.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "#!pip install thefuzz"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 345,
    "metadata": {},
    "outputs": [],
    "source": [
     "from gramformer import Gramformer\n",
     "import torch\n",
     "import spacy\n",
-    "import random"
+    "import random\n",
+    "from thefuzz import fuzz\n"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Original GrammarModel"
+    "## Original GrammarModel"
    ]
   },
   {
@@ -210,12 +235,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Remove correction for error types ORTH, OTHER (and PUNCT?)"
+    "## Improvements to grammar correction"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 327,
+   "execution_count": 413,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -256,17 +281,19 @@
     "        last_user_input = chat_history[-1].get('text')\n",
     "        corrected_sentence, correction_message = self.grammar_correction(last_user_input)\n",
     "        error_types = self.get_edits(last_user_input, corrected_sentence)\n",
-    "        overlap_ignore_errors = any(item in error_types for item in self.ignore_errors)\n",
-    "\n",
-    "        if correction_message and (overlap_ignore_errors is False):\n",
+    "        relevant_error = any(error not in self.ignore_errors for error in error_types)  # check if there is an error in the sentence which is not in the ignore list\n",
+    "        token_sort_ratio = fuzz.token_sort_ratio(corrected_sentence, last_user_input)  # calculate token similarity (ignoring punctuation and casing)\n",
+    "        print(f\"correction_message: {correction_message}\\nErrors detected: {error_types}\\nPresence of a relevant error: {relevant_error}\\nSimilarity Score: {token_sort_ratio}\")  # for debugging only\n",
+    "\n",
+    "        if correction_message and relevant_error and token_sort_ratio != 100:\n",
     "            chat_history.append(\n",
     "                {\n",
     "                    'sender': 'bot',\n",
     "                    'text': correction_message,\n",
-    "                    'correction': True,\n",
-    "                    'error_type': error_types\n",
+    "                    'correction': True\n",
     "                }\n",
     "            )\n",
+    "\n",
     "        return chat_history\n",
     "\n",
     "\n",
@@ -295,7 +322,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 328,
+   "execution_count": 414,
    "metadata": {},
    "outputs": [
     {
@@ -310,40 +337,57 @@
    "gm2 = GrammarModel2(models=1, use_gpu=False)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. Remove correction when no relevant errors are detected (other than those in self.ignore_errors)"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 341,
+   "execution_count": 450,
    "metadata": {},
    "outputs": [],
    "source": [
-    "chat_history_ex1 = [{'sender': 'User', 'text': 'Hi bot'}]"
+    "chat_history_ex1 = [{'sender': 'User', 'text': 'where are you goin?'}]"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 342,
+   "execution_count": 451,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "correction_message: I think you meant: \"where are you going?\" \n",
+      "Errors detected: ['PUNCT']\n",
+      "Presence of a relevant error: True\n",
+      "Similarity Score: 97\n"
+     ]
+    }
+   ],
    "source": [
     "chat_history = gm2.add_correction_to_chat_history(chat_history_ex1)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 343,
+   "execution_count": 452,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[{'sender': 'User', 'text': 'Hi bot'},\n",
+       "[{'sender': 'User', 'text': 'where are you goin?'},\n",
        " {'sender': 'bot',\n",
-       "  'text': 'This would be better said like this: \"Hi booch!\" ',\n",
-       "  'correction': True,\n",
-       "  'error_type': ['NOUN']}]"
+       "  'text': 'I think you meant: \"where are you going?\" ',\n",
+       "  'correction': True}]"
       ]
      },
-     "execution_count": 343,
+     "execution_count": 452,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -352,12 +396,83 @@
    "chat_history"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Remove correction when input and correction are very similar"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 369,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "# Examples of similar sentences which should not be corrected\n",
+    "ex1 = \"Hi bot!\"\n",
+    "ex2 = \"Hi bot\"\n",
+    "ex3 = \"Hi bot.\"\n",
+    "ex4 = \"Hi Bot\"\n",
+    "ex5 = \"Hi Bot.\"\n",
+    "ex6 = \"Hi Bot!\"\n",
+    "ex7 = \"Hi Bot Bot!\"  # should lead to a lower token sort ratio, but the same token set ratio compared to ex1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 364,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Measure the similarity between 0 and 100 to define a threshold.\n",
+    "\n",
+    "def measure_similarity(sentence1, sentence2):\n",
+    "    simple_ratio = fuzz.ratio(sentence1, sentence2)\n",
+    "    print(f\"simple ratio similarity score: {simple_ratio}\")\n",
+    "\n",
+    "    partial_ratio = fuzz.partial_ratio(sentence1, sentence2)  # Return the ratio of the most similar substring.\n",
+    "    print(f\"partial ratio similarity score: {partial_ratio}\")\n",
+    "\n",
+    "    ratio = fuzz.ratio(sentence1, sentence2)\n",
+    "    print(f\"ratio similarity score: {ratio}\")\n",
+    "\n",
+    "    token_sort_ratio = fuzz.token_sort_ratio(sentence1, sentence2)  # Return a measure of the sequences' similarity, sorting the tokens before comparing. This is what we want to use as the threshold.\n",
+    "    print(f\"token sort ratio similarity score: {token_sort_ratio}\")\n",
+    "\n",
+    "    token_set_ratio = fuzz.token_set_ratio(sentence1, sentence2)  # Measures similarity between unique tokens.\n",
+    "    print(f\"token set ratio similarity score: {token_set_ratio}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 378,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "simple ratio similarity score: 67\n",
+      "partial ratio similarity score: 71\n",
+      "ratio similarity score: 67\n",
+      "token sort ratio similarity score: 75\n",
+      "token set ratio similarity score: 100\n"
+     ]
+    }
+   ],
+   "source": [
+    "measure_similarity(ex1, ex7)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 426,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Correction accuracy check"
+   ]
   },
   {
    "cell_type": "code",

requirements.txt

+2 −1
@@ -13,4 +13,5 @@ ipykernel~=6.9.2
 openai~=0.16.0
 waitress~=2.1.1
 plotly~=5.6.0
-nbformat~=5.2.0
+nbformat~=5.2.0
+thefuzz~=0.19.0

0 commit comments
