Skip to content

Commit

Permalink
Add merge sort tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Antti Kaihola committed Jul 4, 2020
1 parent 557e9f1 commit ce65045
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
50 changes: 50 additions & 0 deletions pgtricks/mergesort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from heapq import merge
from tempfile import TemporaryFile
from typing import List, IO, Iterable, Iterator, Optional, cast

import sys


class MergeSort(Iterable[str]):
def __init__(self, directory: str = ".", max_memory: int = 190) -> None:
self._directory = directory
self._max_memory = max_memory
self._partitions: List[IO[str]] = []
self._iterating: Optional[Iterable[str]] = None
self._buffer: List[str] = []
self._memory_counter = 0
self._flush()

def append(self, line: str) -> None:
if self._iterating:
raise ValueError("Can't append lines after starting to sort")
self._memory_counter -= sys.getsizeof(self._buffer)
self._buffer.append(line)
self._memory_counter += sys.getsizeof(self._buffer)
self._memory_counter += sys.getsizeof(line)
if self._memory_counter >= self._max_memory:
self._flush()

def _flush(self) -> None:
if self._buffer:
self._partitions.append(TemporaryFile(mode="w+", dir=self._directory))
self._partitions[-1].writelines(sorted(self._buffer))
self._buffer = []
self._memory_counter = sys.getsizeof(self._buffer)

def __next__(self) -> str:
if not self._iterating:
if self._partitions:
# At least one partition has already been flushed to disk.
# Iterate the merge sort for all partitions.
self._flush()
for partition in self._partitions:
partition.seek(0)
self._iterating = merge(*self._partitions)
else:
# All lines fit in memory. Iterate the list of lines directly.
self._iterating = iter(sorted(self._buffer))
return next(cast(Iterator[str], self._iterating))

def __iter__(self) -> Iterator[str]:
return self
77 changes: 77 additions & 0 deletions pgtricks/tests/test_mergesort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from types import GeneratorType
from typing import cast, Iterable

import pytest

from pgtricks.mergesort import MergeSort


def test_mergesort_append(tmpdir):
m = MergeSort(directory=tmpdir, max_memory=190)
m.append('1\n')
assert m._buffer == ['1\n']
m.append('2\n')
assert m._buffer == []
m.append('3\n')
assert m._buffer == ['3\n']
assert len(m._partitions) == 1
assert m._partitions[0].tell() == 4
m._partitions[0].seek(0)
assert m._partitions[0].read() == '1\n2\n'


def test_mergesort_flush(tmpdir):
m = MergeSort(directory=tmpdir, max_memory=190)
for value in [1, 2, 3]:
m.append(f'{value}\n')
m._flush()
assert len(m._partitions) == 2
assert m._partitions[0].tell() == 4
m._partitions[0].seek(0)
assert m._partitions[0].read() == '1\n2\n'
assert m._partitions[1].tell() == 2
m._partitions[1].seek(0)
assert m._partitions[1].read() == '3\n'


def test_mergesort_iterate_disk(tmpdir):
m = MergeSort(directory=tmpdir, max_memory=190)
for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]:
m.append(f'{value}\n')
assert next(m) == '1\n'
assert isinstance(m._iterating, GeneratorType)
assert next(m) == '1\n'
assert next(m) == '2\n'
assert next(m) == '3\n'
assert next(m) == '3\n'
assert next(m) == '4\n'
assert next(m) == '4\n'
assert next(m) == '5\n'
assert next(m) == '5\n'
assert next(m) == '6\n'
assert next(m) == '8\n'
assert next(m) == '9\n'
with pytest.raises(StopIteration):
next(m)


def test_mergesort_iterate_memory(tmpdir):
m = MergeSort(directory=tmpdir, max_memory=1000000)
for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]:
m.append(f'{value}\n')
assert next(m) == '1\n'
assert not isinstance(m._iterating, GeneratorType)
assert iter(cast(Iterable[str], m._iterating)) is m._iterating
assert next(m) == '1\n'
assert next(m) == '2\n'
assert next(m) == '3\n'
assert next(m) == '3\n'
assert next(m) == '4\n'
assert next(m) == '4\n'
assert next(m) == '5\n'
assert next(m) == '5\n'
assert next(m) == '6\n'
assert next(m) == '8\n'
assert next(m) == '9\n'
with pytest.raises(StopIteration):
next(m)

0 comments on commit ce65045

Please sign in to comment.