diff --git a/data-scripts/count_wikipedia.py b/data-scripts/count_wikipedia.py
old mode 100755
new mode 100644
index aacf8d0f..348c1c38
--- a/data-scripts/count_wikipedia.py
+++ b/data-scripts/count_wikipedia.py
@@ -3,16 +3,18 @@
 import sys
 import os
 import re
-import codecs
 import operator
 import datetime
-import nltk
 import warnings
+import multiprocessing
+import time
+import io
+import nltk
 
 from unidecode import unidecode
 
 def usage():
-    print '''
+    print('''
 tokenize a directory of text and count unigrams.
 
 usage:
@@ -48,14 +50,16 @@ def usage():
 Then run:
 
 ./WikiExtractor.py -o en_sents --no-templates enwiki-20151002-pages-articles.xml.bz2
-''' % sys.argv[0]
+''' % sys.argv[0])
 
-SENTENCES_PER_BATCH = 500000 # after each batch, delete all counts with count == 1 (hapax legomena)
-PRE_SORT_CUTOFF = 300 # before sorting, discard all words with less than this count
+
+SENTENCES_PER_BATCH = 500000  # after each batch, delete all counts with count == 1 (hapax legomena)
+PRE_SORT_CUTOFF = 300  # before sorting, discard all words with less than this count
 
 ALL_NON_ALPHA = re.compile(r'^[\W\d]*$', re.UNICODE)
 SOME_NON_ALPHA = re.compile(r'[\W\d]', re.UNICODE)
 
+
 class TopTokenCounter(object):
     def __init__(self):
         self.count = {}
@@ -110,7 +114,7 @@ def batch_prune(self):
 
     def pre_sort_prune(self):
         under_cutoff = set()
-        for token, count in self.count.iteritems():
+        for token, count in self.count.items():
             if count < PRE_SORT_CUTOFF:
                 under_cutoff.add(token)
         for token in under_cutoff:
@@ -127,43 +131,81 @@ def get_stats(self):
         ts = self.get_ts()
         return "%s keys(count): %d" % (ts, len(self.count))
 
+    def merge(self, other):
+        self.discarded |= other.discarded
+        self.legomena ^= other.legomena
+        for token, num in other.count.items():
+            if token in self.count:
+                self.count[token] += num
+            else:
+                self.count[token] = num
+
+
+def count_file(path):
+    """
+    Scan the file at the given path, tokenize all lines and return the filled TopTokenCounter
+    and the number of processed lines.
+    """
+    counter = TopTokenCounter()
+    lines = 0
+    for line in io.open(path, 'r', encoding='utf8'):
+        with warnings.catch_warnings():
+            # unidecode() occasionally (rarely but enough to clog terminal output)
+            # complains about surrogate characters in some wikipedia sentences.
+            # ignore those warnings.
+            warnings.simplefilter('ignore')
+            line = unidecode(line)
+        tokens = nltk.word_tokenize(line)
+        counter.add_tokens(tokens)
+        lines += 1
+    return counter, lines
+
+
 def main(input_dir_str, output_filename):
     counter = TopTokenCounter()
-    print counter.get_ts(), 'starting...'
+    print(counter.get_ts(), 'starting...')
+    tic = time.time()
+    pruned_lines = 0
     lines = 0
-    for root, dirs, files in os.walk(input_dir_str, topdown=True):
-        if not files:
-            continue
-        for fname in files:
-            path = os.path.join(root, fname)
-            for line in codecs.open(path, 'r', 'utf8'):
-                with warnings.catch_warnings():
-                    # unidecode() occasionally (rarely but enough to clog terminal outout)
-                    # complains about surrogate characters in some wikipedia sentences.
-                    # ignore those warnings.
-                    warnings.simplefilter('ignore')
-                    line = unidecode(line)
-                tokens = nltk.word_tokenize(line)
-                counter.add_tokens(tokens)
-                lines += 1
-                if lines % SENTENCES_PER_BATCH == 0:
-                    counter.batch_prune()
-                    print counter.get_stats()
-            print 'processing: %s' % path
-    print counter.get_stats()
-    print 'deleting tokens under cutoff of', PRE_SORT_CUTOFF
+    files = 0
+    process_pool = multiprocessing.Pool()
+    # Some Python iterator magic: Pool.imap() maps the given function over the iterable
+    # using the process pool. The iterable is produced by creating the full path of every
+    # file in every directory (thus, the nested generator expression).
+    for fcounter, l in process_pool.imap(
+            count_file, (os.path.join(root, fname)
+                         for root, dirs, files in os.walk(input_dir_str, topdown=True)
+                         if files
+                         for fname in files), 4):
+        lines += l
+        files += 1
+        counter.merge(fcounter)
+        if (lines - pruned_lines) >= SENTENCES_PER_BATCH:
+            counter.batch_prune()
+            pruned_lines = lines
+            print(counter.get_stats())
+
+    toc = time.time()
+    print("Finished reading input data. Read %d files with %d lines in %.2fs."
+          % (files, lines, toc-tic))
+    print(counter.get_stats())
+
+    print('deleting tokens under cutoff of', PRE_SORT_CUTOFF)
     counter.pre_sort_prune()
-    print 'done'
-    print counter.get_stats()
-    print counter.get_ts(), 'sorting...'
+    print('done')
+    print(counter.get_stats())
+
+    print(counter.get_ts(), 'sorting...')
     sorted_pairs = counter.get_sorted_pairs()
-    print counter.get_ts(), 'done'
-    print 'writing...'
-    with codecs.open(output_filename, 'w', 'utf8') as f:
+    print(counter.get_ts(), 'done')
+
+    print('writing...')
+    with io.open(output_filename, 'w', encoding='utf8') as f:
         for token, count in sorted_pairs:
             f.write('%-18s %d\n' % (token, count))
     sys.exit(0)
 
+
 if __name__ == '__main__':
     if len(sys.argv) != 3:
         usage()
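Below the patch, for reference: a minimal, self-contained sketch of the Pool.imap()-over-a-generator pattern the new main() relies on. It is not part of the patch; count_words() is a hypothetical stand-in for count_file(), and 'en_sents' is just an example directory name. It shows how file paths are produced lazily from os.walk(), dispatched to worker processes in chunks of 4, and how per-file results are merged back in the parent process.

    # Sketch only: a toy version of the parallel per-file counting pattern above.
    import os
    import multiprocessing


    def count_words(path):
        # Toy per-file worker: returns (word_count, line_count) for one file.
        words = 0
        lines = 0
        with open(path, 'r', encoding='utf8', errors='ignore') as f:
            for line in f:
                words += len(line.split())
                lines += 1
        return words, lines


    def count_dir(input_dir):
        total_words = 0
        total_lines = 0
        pool = multiprocessing.Pool()
        # The generator yields one path per file under input_dir; imap() consumes it
        # lazily and hands paths to worker processes in chunks of 4.
        paths = (os.path.join(root, fname)
                 for root, dirs, files in os.walk(input_dir, topdown=True)
                 if files
                 for fname in files)
        for words, lines in pool.imap(count_words, paths, 4):
            total_words += words  # merge per-file results in the parent process
            total_lines += lines
        pool.close()
        pool.join()
        return total_words, total_lines


    if __name__ == '__main__':
        print(count_dir('en_sents'))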