|
1 | 1 | import functools
|
2 | 2 |
|
3 | 3 | from sqlalchemy import create_engine
|
4 |
| -from sqlalchemy.orm import sessionmaker, close_all_sessions, Session |
| 4 | +from sqlalchemy.orm import sessionmaker, close_all_sessions, Session, \ |
| 5 | + scoped_session |
5 | 6 | from yhttp.core import HTTPStatus
|
6 | 7 |
|
7 | 8 |
|
8 |
| -class DatabaseManager: |
9 |
| - def __init__(self, app, basemodel): |
10 |
| - self.app = app |
| 9 | +class ORM: |
| 10 | + def __init__(self, basemodel, url=None): |
| 11 | + self.url = url |
11 | 12 | self.engine = None
|
12 |
| - self.sessionfactory = sessionmaker() |
13 | 13 | self.basemodel = basemodel
|
| 14 | + self.session = scoped_session(sessionmaker()) |
14 | 15 |
|
15 |
| - def __enter__(self) -> sessionmaker: |
16 |
| - if self.engine is None: |
17 |
| - if 'db' not in self.app.settings: |
18 |
| - raise ValueError( |
19 |
| - 'Please provide db.url configuration entry, for example: ' |
20 |
| - 'postgresql://:@/dbname' |
21 |
| - ) |
22 |
| - |
23 |
| - self.engine = create_engine( |
24 |
| - self.app.settings.db.url, |
25 |
| - isolation_level='REPEATABLE READ' |
26 |
| - ) |
| 16 | + def copy(self, url=None): |
| 17 | + return ORM(self.basemodel, url=url or self.app.settings.db.url) |
27 | 18 |
|
28 |
| - self.sessionfactory.configure(bind=self.engine) |
29 |
| - return self.sessionfactory |
| 19 | + def create_objects(self): |
| 20 | + return self.basemodel.metadata.create_all(self.engine) |
30 | 21 |
|
31 |
| - def __exit__(self, exc_type, exc_value, traceback): |
| 22 | + def connect(self, url=None): |
| 23 | + u = url or self.url |
| 24 | + assert self.engine is None |
| 25 | + assert u is not None |
| 26 | + |
| 27 | + self.engine = create_engine(u, isolation_level='REPEATABLE READ') |
| 28 | + self.session.configure(bind=self.engine) |
| 29 | + |
| 30 | + def disconnect(self): |
32 | 31 | close_all_sessions()
|
33 | 32 | self.engine.dispose()
|
| 33 | + self.engine = None |
34 | 34 |
|
35 |
| - def create_objects(self): |
36 |
| - return self.basemodel.metadata.create_all(self.engine) |
| 35 | + def __enter__(self) -> Session: |
| 36 | + self.connect() |
| 37 | + return self |
| 38 | + |
| 39 | + def __exit__(self, exc_type, exc_value, traceback): |
| 40 | + self.disconnect() |
37 | 41 |
|
38 |
| - def session(self) -> Session: |
39 |
| - return self.sessionfactory.begin() |
| 42 | + |
| 43 | +class ApplicationORM(ORM): |
| 44 | + def __init__(self, basemodel, app): |
| 45 | + self.app = app |
| 46 | + super().__init__(basemodel) |
| 47 | + |
| 48 | + def connect(self, url=None): |
| 49 | + if 'db' not in self.app.settings or 'url' not in self.app.settings.db: |
| 50 | + raise ValueError( |
| 51 | + 'Please provide db.url configuration entry, for example: ' |
| 52 | + 'postgresql://:@/dbname' |
| 53 | + ) |
| 54 | + |
| 55 | + return super().connect(url=url or self.app.settings.db.url) |
40 | 56 |
|
41 | 57 | def __call__(self, handler):
|
42 | 58 | @functools.wraps(handler)
|
43 |
| - def outter(req, *a, **kw): |
44 |
| - with self.session() as session: |
45 |
| - req.dbsession = session |
46 |
| - try: |
47 |
| - return handler(req, *a, **kw) |
48 |
| - except HTTPStatus as ex: |
49 |
| - if ex.keepheaders: |
50 |
| - return ex |
51 |
| - |
52 |
| - raise |
53 |
| - finally: |
54 |
| - del req.dbsession |
| 59 | + def outter(*a, **kw): |
| 60 | + try: |
| 61 | + return handler(*a, **kw) |
| 62 | + except HTTPStatus as ex: |
| 63 | + if ex.keepheaders: |
| 64 | + return ex |
| 65 | + |
| 66 | + raise |
| 67 | + finally: |
| 68 | + self.session.reset() |
55 | 69 |
|
56 | 70 | return outter
|
0 commit comments