Add faster version of KL-Sum using numpy #200

Open · wants to merge 3 commits into main
Changes from 1 commit
remove print statement
mamei16 committed Dec 2, 2023
commit 3e73148e4831ed699a8ea76948b4694d19424b7e
7 changes: 3 additions & 4 deletions sumy/summarizers/fast_kl.py
@@ -17,7 +17,7 @@ class KLSummarizer(AbstractSummarizer):
     KL Divergence.
     Source: http://www.aclweb.org/anthology/N09-1041
     """
-    MISSING_WORD_VAL = 42.  # placeholder value used for missing words in document
+    MISSING_WORD_VAL = 42.0  # placeholder value used for missing words in document
     stop_words = frozenset()

     def __call__(self, document, sentences_count):
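For reference, the criterion from the cited paper (Haghighi & Vanderwende, 2009) that _kl_divergence below computes, with P_D and P_S the word distributions of the document and the candidate summary:

    \mathrm{KL}(P_D \,\|\, P_S) = \sum_{w} P_D(w)\,\log\frac{P_D(w)}{P_S(w)}

The surrounding _compute_ratings code appears to select sentences greedily so that this divergence stays small.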
@@ -89,7 +89,7 @@ def _joint_freq(wc1, wc2, total_len):

     @staticmethod
     def _kl_divergence(summary_freq, doc_freq, doc_missing_word_mask):
-        summary_freq = np.where((summary_freq != 0.) & doc_missing_word_mask, summary_freq, doc_freq)
+        summary_freq = np.where((summary_freq != 0.0) & doc_missing_word_mask, summary_freq, doc_freq)
         return (doc_freq * np.log(doc_freq / summary_freq)).sum()

     @staticmethod
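A minimal runnable sketch of what the np.where guard in this hunk does; the array values are hypothetical, and the mask convention (True = word present in the document) is inferred from how the mask gates np.where. Entries where the summary frequency is zero, or where the word is missing from the document, are replaced by the document frequency, so their log-ratio terms vanish instead of producing inf or a spurious placeholder contribution:

    import numpy as np

    # Hypothetical frequencies over a 4-word vocabulary.
    # Word 3 is absent from the document, so doc_freq holds the 42.0 placeholder.
    doc_freq = np.array([0.25, 0.25, 0.5, 42.0])
    summary_freq = np.array([0.25, 0.0, 0.25, 42.0])
    doc_missing_word_mask = np.array([True, True, True, False])  # assumed: False = missing from document

    # Unguarded, word 1 would yield log(0.25 / 0.0) = inf.
    safe = np.where((summary_freq != 0.0) & doc_missing_word_mask, summary_freq, doc_freq)
    kl = (doc_freq * np.log(doc_freq / safe)).sum()
    print(kl)  # only word 2 contributes: 0.5 * log(0.5 / 0.25) ~= 0.3466

Replaced entries contribute doc_freq * log(doc_freq / doc_freq) = 0, which is why the guard also neutralizes the 42.0 placeholder.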
@@ -114,7 +114,7 @@ def _compute_ratings(self, sentences):

         # Keep track of number of words in summary and word frequency
         summary_word_list_len = 0
-        summary_word_freq = np.repeat(0., len(vocabulary))
+        summary_word_freq = np.repeat(0.0, len(vocabulary))

         # make it a list so that it can be modified
         sentences_list = list(sentences)
@@ -159,5 +159,4 @@ def _compute_ratings(self, sentences):
             # value is the iteration in which it was removed multiplied by -1 so that
             # the first sentences removed (the most important) have highest values
             ratings[best_sentence] = -1 * len(ratings)
-        print(f"Num interations: {iterations}")
         return ratings
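For context, a hypothetical end-to-end use of this summarizer, assuming the new sumy.summarizers.fast_kl module keeps the same public interface as sumy's existing KL summarizer (the __call__(self, document, sentences_count) signature above suggests it does); the sample text is made up:

    from sumy.parsers.plaintext import PlaintextParser
    from sumy.nlp.tokenizers import Tokenizer
    from sumy.summarizers.fast_kl import KLSummarizer  # module added by this PR

    text = (
        "Numpy vectorization can speed up summarizers considerably. "
        "KL-Sum picks sentences whose word distribution matches the document. "
        "This toy text exists only to exercise the API."
    )
    # Tokenizer("english") requires nltk's punkt data to be installed.
    parser = PlaintextParser.from_string(text, Tokenizer("english"))
    summarizer = KLSummarizer()

    # Returns the sentences_count sentences with the best KL-based ratings.
    for sentence in summarizer(parser.document, sentences_count=2):
        print(sentence)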