From 274f2dc7994dccf9d8db5bfc454f0f4c3b837443 Mon Sep 17 00:00:00 2001 From: Praneeth Date: Tue, 13 Aug 2019 18:29:17 +0530 Subject: [PATCH] add segment_long --- deepsegment/deepsegment.py | 57 +++++++++++++++++++++++++++++++++++--- deepsegment/test.py | 5 ---- 2 files changed, 53 insertions(+), 9 deletions(-) delete mode 100644 deepsegment/test.py diff --git a/deepsegment/deepsegment.py b/deepsegment/deepsegment.py index 795f6ba..44d9c90 100644 --- a/deepsegment/deepsegment.py +++ b/deepsegment/deepsegment.py @@ -4,6 +4,8 @@ import pydload import pickle import os +import logging +import time model_links = { 'en': { @@ -14,6 +16,17 @@ } +def chunk(l, n): + """Yield successive n-sized chunks from l.""" + chunked_l = [] + for i in range(0, len(l), n): + chunked_l.append(l[i:i + n]) + + if not chunked_l: + chunked_l = [l] + + return chunked_l + class DeepSegment(): seqtag_model = None data_converter = None @@ -54,25 +67,61 @@ def segment(self, sents): if not DeepSegment.seqtag_model: print('Please load the model first') + string_output = False if not isinstance(sents, list): + logging.warn("Batch input strings for faster inference.") + string_output = True sents = [sents] sents = [sent.strip().split() for sent in sents] + + max_len = len(max(sents, key=len)) + if max_len >= 40: + logging.warn("Consider using segment_long for longer sentences.") + encoded_sents = DeepSegment.data_converter.transform(sents) all_tags = DeepSegment.seqtag_model.predict(encoded_sents) all_tags = [np.argmax(_tags, axis=1).tolist() for _tags in all_tags] - segmented_sentences = [] - for sent, tags in zip(sents, all_tags): + segmented_sentences = [[] for _ in sents] + for sent_index, (sent, tags) in enumerate(zip(sents, all_tags)): segmented_sent = [] for i, (word, tag) in enumerate(zip(sent, tags)): if tag == 2 and i > 0 and segmented_sent: segmented_sent = ' '.join(segmented_sent) - segmented_sentences.append(segmented_sent) + segmented_sentences[sent_index].append(segmented_sent) segmented_sent = [] segmented_sent.append(word) if segmented_sent: - segmented_sentences.append(' '.join(segmented_sent)) + segmented_sentences[sent_index].append(' '.join(segmented_sent)) + if string_output: + return segmented_sentences[0] + return segmented_sentences + + def segment_long(self, sent, n_window=None): + if not n_window: + logging.warn("Using default n_window=10. Set this parameter based on your data.") + n_window = 10 + + if isinstance(sent, list): + logging.error("segment_long doesn't support batching as of now. Batching will be added in a future release.") + return None + + segmented = [] + sent = sent.split() + prefix = [] + while sent: + current_n_window = n_window - len(prefix) + if current_n_window == 0: + current_n_window = n_window + + window = prefix + sent[:current_n_window] + sent = sent[current_n_window:] + segmented_window = self.segment([' '.join(window)])[0] + segmented += segmented_window[:-1] + prefix = segmented_window[-1].split() + + return segmented \ No newline at end of file diff --git a/deepsegment/test.py b/deepsegment/test.py deleted file mode 100644 index 41ca559..0000000 --- a/deepsegment/test.py +++ /dev/null @@ -1,5 +0,0 @@ -from deepsegment import DeepSegment - -m = DeepSegment() - -print(m.segment('I am batman who are you'))