Skip to content

Commit a88fa09

Browse files
author
Kareem Zidane
authored
Merge pull request #157 from cs50/refactor/scoped-session-fix
v7.0.0
2 parents b6647cf + d0b3f99 commit a88fa09

24 files changed

+995
-852
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
"Topic :: Software Development :: Libraries :: Python Modules"
1111
],
1212
description="CS50 library for Python",
13-
install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"],
13+
install_requires=["Flask>=1.0", "SQLAlchemy<2", "sqlparse", "termcolor"],
1414
keywords="cs50",
1515
name="cs50",
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="6.0.5"
19+
version="7.0.0"
2020
)

src/cs50/__init__.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,5 @@
1-
import logging
2-
import os
3-
import sys
4-
5-
6-
# Disable cs50 logger by default
7-
logging.getLogger("cs50").disabled = True
8-
9-
# Import cs50_*
10-
from .cs50 import get_char, get_float, get_int, get_string
11-
try:
12-
from .cs50 import get_long
13-
except ImportError:
14-
pass
15-
16-
# Hook into flask importing
17-
from . import flask
18-
19-
# Wrap SQLAlchemy
1+
from .cs50 import get_float, get_int, get_string
202
from .sql import SQL
3+
from ._logger import _setup_logger
4+
5+
_setup_logger()

src/cs50/_engine.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import threading
2+
3+
from ._engine_util import create_engine
4+
5+
6+
thread_local_data = threading.local()
7+
8+
9+
class Engine:
10+
"""Wraps a SQLAlchemy engine.
11+
"""
12+
13+
def __init__(self, url):
14+
self._engine = create_engine(url)
15+
16+
def get_transaction_connection(self):
17+
"""
18+
:returns: A new connection with autocommit disabled (to be used for transactions).
19+
"""
20+
21+
_thread_local_connections()[self] = self._engine.connect().execution_options(
22+
autocommit=False)
23+
return self.get_existing_transaction_connection()
24+
25+
def get_connection(self):
26+
"""
27+
:returns: A new connection with autocommit enabled
28+
"""
29+
30+
return self._engine.connect().execution_options(autocommit=True)
31+
32+
def get_existing_transaction_connection(self):
33+
"""
34+
:returns: The transaction connection bound to this Engine instance, if one exists, or None.
35+
"""
36+
37+
return _thread_local_connections().get(self)
38+
39+
def close_transaction_connection(self):
40+
"""Closes the transaction connection bound to this Engine instance, if one exists and
41+
removes it.
42+
"""
43+
44+
connection = self.get_existing_transaction_connection()
45+
if connection:
46+
connection.close()
47+
del _thread_local_connections()[self]
48+
49+
def is_postgres(self):
50+
return self._engine.dialect.name in {"postgres", "postgresql"}
51+
52+
def __getattr__(self, attr):
53+
return getattr(self._engine, attr)
54+
55+
def _thread_local_connections():
56+
"""
57+
:returns: A thread local dict to keep track of transaction connection. If one does not exist,
58+
creates one.
59+
"""
60+
61+
try:
62+
connections = thread_local_data.connections
63+
except AttributeError:
64+
connections = thread_local_data.connections = {}
65+
66+
return connections

src/cs50/_engine_util.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Utility functions used by _session.py.
2+
"""
3+
4+
import os
5+
import sqlite3
6+
7+
import sqlalchemy
8+
9+
sqlite_url_prefix = "sqlite:///"
10+
11+
12+
def create_engine(url, **kwargs):
13+
"""Creates a new SQLAlchemy engine. If ``url`` is a URL for a SQLite database, makes sure that
14+
the SQLite file exits and enables foreign key constraints.
15+
"""
16+
17+
try:
18+
engine = sqlalchemy.create_engine(url, **kwargs)
19+
except sqlalchemy.exc.ArgumentError:
20+
raise RuntimeError(f"invalid URL: {url}") from None
21+
22+
if _is_sqlite_url(url):
23+
_assert_sqlite_file_exists(url)
24+
sqlalchemy.event.listen(engine, "connect", _enable_sqlite_foreign_key_constraints)
25+
26+
return engine
27+
28+
def _is_sqlite_url(url):
29+
return url.startswith(sqlite_url_prefix)
30+
31+
32+
def _assert_sqlite_file_exists(url):
33+
path = url[len(sqlite_url_prefix):]
34+
if not os.path.exists(path):
35+
raise RuntimeError(f"does not exist: {path}")
36+
if not os.path.isfile(path):
37+
raise RuntimeError(f"not a file: {path}")
38+
39+
40+
def _enable_sqlite_foreign_key_constraints(dbapi_connection, _):
41+
cursor = dbapi_connection.cursor()
42+
cursor.execute("PRAGMA foreign_keys=ON")
43+
cursor.close()

src/cs50/_logger.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Sets up logging for the library.
2+
"""
3+
4+
import logging
5+
import os.path
6+
import re
7+
import sys
8+
import traceback
9+
10+
import termcolor
11+
12+
13+
def green(msg):
14+
return _colored(msg, "green")
15+
16+
17+
def red(msg):
18+
return _colored(msg, "red")
19+
20+
21+
def yellow(msg):
22+
return _colored(msg, "yellow")
23+
24+
25+
def _colored(msg, color):
26+
return termcolor.colored(str(msg), color)
27+
28+
29+
def _setup_logger():
30+
_configure_default_logger()
31+
_patch_root_handler_format_exception()
32+
_configure_cs50_logger()
33+
_patch_excepthook()
34+
35+
36+
def _configure_default_logger():
37+
"""Configures a default handler and formatter to prevent flask and werkzeug from adding theirs.
38+
"""
39+
40+
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
41+
42+
43+
def _patch_root_handler_format_exception():
44+
"""Patches formatException for the root handler to use ``_format_exception``.
45+
"""
46+
47+
try:
48+
formatter = logging.root.handlers[0].formatter
49+
formatter.formatException = lambda exc_info: _format_exception(*exc_info)
50+
except IndexError:
51+
pass
52+
53+
54+
def _configure_cs50_logger():
55+
"""Disables the cs50 logger by default. Disables logging propagation to prevent messages from
56+
being logged more than once. Sets the logging handler and formatter.
57+
"""
58+
59+
_logger = logging.getLogger("cs50")
60+
_logger.disabled = True
61+
_logger.setLevel(logging.DEBUG)
62+
63+
# Log messages once
64+
_logger.propagate = False
65+
66+
handler = logging.StreamHandler()
67+
handler.setLevel(logging.DEBUG)
68+
69+
formatter = logging.Formatter("%(levelname)s: %(message)s")
70+
formatter.formatException = lambda exc_info: _format_exception(*exc_info)
71+
handler.setFormatter(formatter)
72+
_logger.addHandler(handler)
73+
74+
75+
def _patch_excepthook():
76+
sys.excepthook = lambda type_, value, exc_tb: print(
77+
_format_exception(type_, value, exc_tb), file=sys.stderr)
78+
79+
80+
def _format_exception(type_, value, exc_tb):
81+
"""Formats traceback, darkening entries from global site-packages directories and user-specific
82+
site-packages directory.
83+
https://stackoverflow.com/a/46071447/5156190
84+
"""
85+
86+
# Absolute paths to site-packages
87+
packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:])
88+
89+
# Highlight lines not referring to files in site-packages
90+
lines = []
91+
for line in traceback.format_exception(type_, value, exc_tb):
92+
matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line)
93+
if matches and matches.group(1).startswith(packages):
94+
lines += line
95+
else:
96+
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
97+
lines.append(matches.group(1) + yellow(matches.group(2)) + matches.group(3))
98+
return "".join(lines).rstrip()

src/cs50/_sql_sanitizer.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import datetime
2+
import re
3+
4+
import sqlalchemy
5+
import sqlparse
6+
7+
8+
class SQLSanitizer:
9+
"""Sanitizes SQL values.
10+
"""
11+
12+
def __init__(self, dialect):
13+
self._dialect = dialect
14+
15+
def escape(self, value):
16+
"""Escapes value using engine's conversion function.
17+
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
18+
19+
:param value: The value to be sanitized
20+
21+
:returns: The sanitized value
22+
"""
23+
# pylint: disable=too-many-return-statements
24+
if isinstance(value, (list, tuple)):
25+
return self.escape_iterable(value)
26+
27+
if isinstance(value, bool):
28+
return sqlparse.sql.Token(
29+
sqlparse.tokens.Number,
30+
sqlalchemy.types.Boolean().literal_processor(self._dialect)(value))
31+
32+
if isinstance(value, bytes):
33+
if self._dialect.name in {"mysql", "sqlite"}:
34+
# https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
35+
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'")
36+
if self._dialect.name in {"postgres", "postgresql"}:
37+
# https://dba.stackexchange.com/a/203359
38+
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'")
39+
40+
raise RuntimeError(f"unsupported value: {value}")
41+
42+
string_processor = sqlalchemy.types.String().literal_processor(self._dialect)
43+
if isinstance(value, datetime.date):
44+
return sqlparse.sql.Token(
45+
sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d")))
46+
47+
if isinstance(value, datetime.datetime):
48+
return sqlparse.sql.Token(
49+
sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S")))
50+
51+
if isinstance(value, datetime.time):
52+
return sqlparse.sql.Token(
53+
sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S")))
54+
55+
if isinstance(value, float):
56+
return sqlparse.sql.Token(
57+
sqlparse.tokens.Number,
58+
sqlalchemy.types.Float().literal_processor(self._dialect)(value))
59+
60+
if isinstance(value, int):
61+
return sqlparse.sql.Token(
62+
sqlparse.tokens.Number,
63+
sqlalchemy.types.Integer().literal_processor(self._dialect)(value))
64+
65+
if isinstance(value, str):
66+
return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value))
67+
68+
if value is None:
69+
return sqlparse.sql.Token(
70+
sqlparse.tokens.Keyword,
71+
sqlalchemy.types.NullType().literal_processor(self._dialect)(value))
72+
73+
raise RuntimeError(f"unsupported value: {value}")
74+
75+
def escape_iterable(self, iterable):
76+
"""Escapes each value in iterable and joins all the escaped values with ", ", formatted for
77+
SQL's ``IN`` operator.
78+
79+
:param: An iterable of values to be escaped
80+
81+
:returns: A comma-separated list of escaped values from ``iterable``
82+
:rtype: :class:`sqlparse.sql.TokenList`
83+
"""
84+
85+
return sqlparse.sql.TokenList(
86+
sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable])))
87+
88+
89+
def escape_verbatim_colon(value):
90+
"""Escapes verbatim colon from a value so as it is not confused with a parameter marker.
91+
"""
92+
93+
# E.g., ':foo, ":foo, :foo will be replaced with
94+
# '\:foo, "\:foo, \:foo respectively
95+
return re.sub(r"(^(?:'|\")|\s+):", r"\1\:", value)

src/cs50/_sql_util.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Utility functions used by sql.py.
2+
"""
3+
4+
import contextlib
5+
import decimal
6+
import warnings
7+
8+
import sqlalchemy
9+
10+
11+
def process_select_result(result):
12+
"""Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a
13+
row in the result set.
14+
15+
:param result: A SQLAlchemy result
16+
:type result: :class:`sqlalchemy.engine.Result`
17+
"""
18+
rows = [dict(row) for row in result.fetchall()]
19+
for row in rows:
20+
for column in row:
21+
# Coerce decimal.Decimal objects to float objects
22+
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
23+
if isinstance(row[column], decimal.Decimal):
24+
row[column] = float(row[column])
25+
26+
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
27+
elif isinstance(row[column], memoryview):
28+
row[column] = bytes(row[column])
29+
30+
return rows
31+
32+
33+
@contextlib.contextmanager
34+
def raise_errors_for_warnings():
35+
"""Catches warnings and raises errors instead.
36+
"""
37+
38+
with warnings.catch_warnings():
39+
warnings.simplefilter("error")
40+
yield
41+
42+
43+
def postgres_lastval(connection):
44+
"""
45+
:returns: The ID of the last inserted row, if defined in this session, or None
46+
"""
47+
48+
try:
49+
return connection.execute("SELECT LASTVAL()").first()[0]
50+
except sqlalchemy.exc.OperationalError:
51+
return None

0 commit comments

Comments
 (0)