Fix race condition that prevents full flush in the BulkProcessor #185

Merged · 2 commits · Dec 13, 2023
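The change, in short: flush_queue() used to sleep for one flush interval and then poll the stats["queue"] counter, which races with worker threads that have already popped a batch but not yet recorded it. The fix adds Records.force_queue_flush(), which pushes the remaining records out as next batches regardless of batch size, and makes flush_queue() poll the queue itself with a final safety re-check. A minimal usage sketch of the affected API, mirroring the integration tests below (assumes a reachable Xata database; how XataClient picks up credentials is an assumption here, not part of this PR):

from xata.client import XataClient
from xata.helpers import BulkProcessor

client = XataClient()  # assumed: credentials resolved from the environment

bp = BulkProcessor(client, thread_pool_size=4, batch_size=50)
bp.put_records("Posts", [{"title": "hello", "text": "world"} for _ in range(123)])

# Before this fix, flush_queue() could return while batches were still in
# flight; with it, the call drains the queue completely before returning.
bp.flush_queue()

stats = bp.get_stats()
assert stats["queue"] == 0 and stats["failed_batches"] == 0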
94 changes: 80 additions & 14 deletions tests/integration-tests/helpers_bulkprocessor_test.py
@@ -36,18 +36,27 @@ def setup_class(self):

assert self.client.databases().create(self.db_name).is_success()
assert self.client.table().create("Posts").is_success()
assert self.client.table().create("Users").is_success()

# create schema
r = self.client.table().set_schema(
assert self.client.table().set_schema(
"Posts",
{
"columns": [
{"name": "title", "type": "string"},
{"name": "text", "type": "text"},
]
},
)
assert r.is_success()
).is_success()
assert self.client.table().set_schema(
"Users",
{
"columns": [
{"name": "username", "type": "string"},
{"name": "email", "type": "string"},
]
},
).is_success()

def teardown_class(self):
assert self.client.databases().delete(self.db_name).is_success()
@@ -61,24 +70,81 @@ def _get_record(self) -> dict:
"title": self.fake.company(),
"text": self.fake.text(),
}

def _get_user(self) -> dict:
return {
"username": self.fake.name(),
"email": self.fake.email(),
}

def test_bulk_insert_records(self, record: dict):
pt = 2
bp = BulkProcessor(
self.client,
thread_pool_size=1,
batch_size=5,
)
bp.put_records("Posts", [self._get_record() for x in range(42)])
bp.flush_queue()

r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 42

stats = bp.get_stats()
assert stats["total"] == 42
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["tables"]["Posts"] == 42

def test_flush_queue(self):
assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()

bp = BulkProcessor(
self.client,
thread_pool_size=4,
batch_size=50,
flush_interval=1,
processing_timeout=pt,
)
bp.put_records("Posts", [self._get_record() for x in range(10)])
bp.put_records("Posts", [self._get_record() for x in range(1000)])
bp.flush_queue()

r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 1000

stats = bp.get_stats()
assert stats["total"] == 1000
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["tables"]["Posts"] == 1000

def test_multiple_tables(self):
assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()

bp = BulkProcessor(
self.client,
thread_pool_size=3,
batch_size=42,
)
for it in range(33):
bp.put_records("Posts", [self._get_record() for x in range(9)])
bp.put_records("Users", [self._get_user() for x in range(7)])
bp.flush_queue()

# wait until indexed :shrug:
time.sleep(pt)
utils.wait_until_records_are_indexed("Posts")
r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "summaries" in r
assert r["summaries"][0]["proof"] == 33 * 9

r = self.client.search_and_filter().search_table("Posts", {})
r = self.client.data().summarize("Users", {"summaries": {"proof": {"count": "*"}}})
assert r.is_success()
assert "records" in r
assert len(r["records"]) > 0
assert len(r["records"]) <= 10
assert "summaries" in r
assert r["summaries"][0]["proof"] == 33 * 7

stats = bp.get_stats()
assert stats["queue"] == 0
assert stats["failed_batches"] == 0
assert stats["tables"]["Posts"] == 33 * 9
assert stats["tables"]["Users"] == 33 * 7
assert stats["total"] == stats["tables"]["Posts"] + stats["tables"]["Users"]
38 changes: 25 additions & 13 deletions xata/helpers.py
@@ -31,7 +31,7 @@
BP_DEFAULT_FLUSH_INTERVAL = 5
BP_DEFAULT_PROCESSING_TIMEOUT = 0.025
BP_DEFAULT_THROW_EXCEPTION = False
BP_VERSION = "0.2.1"
BP_VERSION = "0.3.0"
TRX_MAX_OPERATIONS = 1000
TRX_VERSION = "0.1.0"
TRX_BACKOFF = 0.1
@@ -182,19 +182,23 @@ def get_stats(self):
def flush_queue(self):
"""
Flush all records from the queue.
https://github.com/xataio/xata-py/issues/184
"""
self.logger.debug("flushing queue with %d records .." % (self.records.size()))
self.records.set_flush_interval(0)
self.processing_timeout = 0

# If the queue is not empty wait for one flush interval.
# Purpose is a race condition with self.stats["queue"]
if self.records.size() > 0:
time.sleep(self.flush_interval)
# force flush the records queue and shorten the processing times
self.records.force_queue_flush()
self.processing_timeout = 0.001
wait = 0.005 * len(self.thread_workers)

while self.stats["queue"] > 0:
while self.records.size() > 0:
self.logger.debug("flushing queue with %d records." % self.stats["queue"])
time.sleep(self.processing_timeout / len(self.thread_workers) + 0.01)
time.sleep(wait)

# Last poor man's check if queue is fully flushed
if self.records.size() > 0 or self.stats["queue"] > 0:
self.logger.debug("one more flush interval necessary with queue at %d records." % self.stats["queue"])
time.sleep(wait)

class Records(object):
"""
@@ -208,14 +212,22 @@ def __init__(self, batch_size: int, flush_interval: int, logger):
"""
self.batch_size = batch_size
self.flush_interval = flush_interval
self.force_flush = False
self.logger = logger

self.store = dict()
self.store_ptr = 0
self.lock = Lock()

def set_flush_interval(self, interval: int):
self.flush_interval = interval
def force_queue_flush(self):
"""
Force next batch to be available
https://github.com/xataio/xata-py/issues/184
"""
with self.lock:
self.force_flush = True
self.flush_interval = 0.001
self.batch_size = 1

def put(self, table_name: str, records: list[dict]):
"""
@@ -264,8 +276,8 @@ def next_batch(self) -> dict:
self.flush_interval,
)
)
# pop records ?
if len(self.store[table_name]["records"]) >= self.batch_size or flush_needed:
# force flush table, batch size reached or timer exceeded
if self.force_flush or len(self.store[table_name]["records"]) >= self.batch_size or flush_needed:
self.store[table_name]["flushed"] = time.time()
rs = self.store[table_name]["records"][0 : self.batch_size]
del self.store[table_name]["records"][0 : self.batch_size]
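For reference, the shape of BulkProcessor.get_stats() that the integration tests assert on (keys taken from the assertions above; values shown are the ones from test_flush_queue, anything beyond these keys is not guaranteed here):

stats = {
    "total": 1000,              # records processed in total
    "queue": 0,                 # records still waiting in the queue
    "failed_batches": 0,        # batches that failed to ingest
    "tables": {"Posts": 1000},  # per-table record counts
}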