10
10
from pathlib import Path
11
11
from typing import AsyncIterator , Awaitable , Callable , cast
12
12
13
- import aiosqlite
14
13
import anyio
15
14
from anyio import TASK_STATUS_IGNORED , Event , Lock , create_task_group
16
15
from anyio .abc import TaskGroup , TaskStatus
17
16
from pycrdt import Doc
17
+ from sqlite_anyio import Connection , connect
18
18
19
19
from .yutils import Decoder , get_new_path , write_var_uint
20
20
@@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
83
83
if self ._task_group is not None :
84
84
raise RuntimeError ("YStore already running" )
85
85
86
- self .started .set ()
87
- self ._starting = False
88
- task_status .started ()
86
+ async with create_task_group () as self ._task_group :
87
+ self .started .set ()
88
+ self ._starting = False
89
+ task_status .started ()
89
90
90
- def stop (self ) -> None :
91
+ async def stop (self ) -> None :
91
92
"""Stop the store."""
92
93
if self ._task_group is None :
93
94
raise RuntimeError ("YStore not running" )
@@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore):
300
301
path : str
301
302
lock : Lock
302
303
db_initialized : Event
304
+ _db : Connection
303
305
304
306
def __init__ (
305
307
self ,
@@ -340,43 +342,54 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
340
342
self ._starting = False
341
343
task_status .started ()
342
344
345
+ async def stop (self ) -> None :
346
+ """Stop the store."""
347
+ if self .db_initialized .is_set ():
348
+ await self ._db .close ()
349
+ await super ().stop ()
350
+
343
351
async def _init_db (self ):
344
352
create_db = False
345
353
move_db = False
346
354
if not await anyio .Path (self .db_path ).exists ():
347
355
create_db = True
348
356
else :
349
357
async with self .lock :
350
- async with aiosqlite .connect (self .db_path ) as db :
351
- cursor = await db .execute (
352
- "SELECT count(name) FROM sqlite_master "
353
- "WHERE type='table' and name='yupdates'"
354
- )
355
- table_exists = (await cursor .fetchone ())[0 ]
356
- if table_exists :
357
- cursor = await db .execute ("pragma user_version" )
358
- version = (await cursor .fetchone ())[0 ]
359
- if version != self .version :
360
- move_db = True
361
- create_db = True
362
- else :
358
+ db = await connect (self .db_path )
359
+ cursor = await db .cursor ()
360
+ await cursor .execute (
361
+ "SELECT count(name) FROM sqlite_master "
362
+ "WHERE type='table' and name='yupdates'"
363
+ )
364
+ table_exists = (await cursor .fetchone ())[0 ]
365
+ if table_exists :
366
+ await cursor .execute ("pragma user_version" )
367
+ version = (await cursor .fetchone ())[0 ]
368
+ if version != self .version :
369
+ move_db = True
363
370
create_db = True
371
+ else :
372
+ create_db = True
373
+ await db .close ()
364
374
if move_db :
365
375
new_path = await get_new_path (self .db_path )
366
376
self .log .warning ("YStore version mismatch, moving %s to %s" , self .db_path , new_path )
367
377
await anyio .Path (self .db_path ).rename (new_path )
368
378
if create_db :
369
379
async with self .lock :
370
- async with aiosqlite .connect (self .db_path ) as db :
371
- await db .execute (
372
- "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
373
- "metadata BLOB, timestamp REAL NOT NULL)"
374
- )
375
- await db .execute (
376
- "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
377
- )
378
- await db .execute (f"PRAGMA user_version = { self .version } " )
379
- await db .commit ()
380
+ db = await connect (self .db_path )
381
+ cursor = await db .cursor ()
382
+ await cursor .execute (
383
+ "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
384
+ "metadata BLOB, timestamp REAL NOT NULL)"
385
+ )
386
+ await cursor .execute (
387
+ "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
388
+ )
389
+ await cursor .execute (f"PRAGMA user_version = { self .version } " )
390
+ await db .commit ()
391
+ await db .close ()
392
+ self ._db = await connect (self .db_path )
380
393
self .db_initialized .set ()
381
394
382
395
async def read (self ) -> AsyncIterator [tuple [bytes , bytes , float ]]:
@@ -388,17 +401,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
388
401
await self .db_initialized .wait ()
389
402
try :
390
403
async with self .lock :
391
- async with aiosqlite . connect ( self .db_path ) as db :
392
- async with db .execute (
393
- "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?" ,
394
- (self .path ,),
395
- ) as cursor :
396
- found = False
397
- async for update , metadata , timestamp in cursor :
398
- found = True
399
- yield update , metadata , timestamp
400
- if not found :
401
- raise YDocNotFound
404
+ cursor = await self ._db . cursor ()
405
+ await cursor .execute (
406
+ "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?" ,
407
+ (self .path ,),
408
+ )
409
+ found = False
410
+ for update , metadata , timestamp in await cursor . fetchall () :
411
+ found = True
412
+ yield update , metadata , timestamp
413
+ if not found :
414
+ raise YDocNotFound
402
415
except Exception :
403
416
raise YDocNotFound
404
417
@@ -410,38 +423,35 @@ async def write(self, data: bytes) -> None:
410
423
"""
411
424
await self .db_initialized .wait ()
412
425
async with self .lock :
413
- async with aiosqlite .connect (self .db_path ) as db :
414
- # first, determine time elapsed since last update
415
- cursor = await db .execute (
416
- "SELECT timestamp FROM yupdates WHERE path = ? "
417
- "ORDER BY timestamp DESC LIMIT 1" ,
418
- (self .path ,),
419
- )
420
- row = await cursor .fetchone ()
421
- diff = (time .time () - row [0 ]) if row else 0
422
-
423
- if self .document_ttl is not None and diff > self .document_ttl :
424
- # squash updates
425
- ydoc = Doc ()
426
- async with db .execute (
427
- "SELECT yupdate FROM yupdates WHERE path = ?" , (self .path ,)
428
- ) as cursor :
429
- async for (update ,) in cursor :
430
- ydoc .apply_update (update )
431
- # delete history
432
- await db .execute ("DELETE FROM yupdates WHERE path = ?" , (self .path ,))
433
- # insert squashed updates
434
- squashed_update = ydoc .get_update ()
435
- metadata = await self .get_metadata ()
436
- await db .execute (
437
- "INSERT INTO yupdates VALUES (?, ?, ?, ?)" ,
438
- (self .path , squashed_update , metadata , time .time ()),
439
- )
440
-
441
- # finally, write this update to the DB
426
+ # first, determine time elapsed since last update
427
+ cursor = await self ._db .cursor ()
428
+ await cursor .execute (
429
+ "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1" ,
430
+ (self .path ,),
431
+ )
432
+ row = await cursor .fetchone ()
433
+ diff = (time .time () - row [0 ]) if row else 0
434
+
435
+ if self .document_ttl is not None and diff > self .document_ttl :
436
+ # squash updates
437
+ ydoc = Doc ()
438
+ await cursor .execute ("SELECT yupdate FROM yupdates WHERE path = ?" , (self .path ,))
439
+ for (update ,) in await cursor .fetchall ():
440
+ ydoc .apply_update (update )
441
+ # delete history
442
+ await cursor .execute ("DELETE FROM yupdates WHERE path = ?" , (self .path ,))
443
+ # insert squashed updates
444
+ squashed_update = ydoc .get_update ()
442
445
metadata = await self .get_metadata ()
443
- await db .execute (
446
+ await cursor .execute (
444
447
"INSERT INTO yupdates VALUES (?, ?, ?, ?)" ,
445
- (self .path , data , metadata , time .time ()),
448
+ (self .path , squashed_update , metadata , time .time ()),
446
449
)
447
- await db .commit ()
450
+
451
+ # finally, write this update to the DB
452
+ metadata = await self .get_metadata ()
453
+ await cursor .execute (
454
+ "INSERT INTO yupdates VALUES (?, ?, ?, ?)" ,
455
+ (self .path , data , metadata , time .time ()),
456
+ )
457
+ await self ._db .commit ()
0 commit comments