Skip to content

Commit 5c06978

Browse files
authored
Merge pull request #172 from up-n-atom/type-checking
Respect pep8 and revert 659c8f4
2 parents 3ccbd99 + 53cf4d2 commit 5c06978

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@
1818
package_dir={"": "src"},
1919
packages=["cs50"],
2020
url="https://github.com/cs50/python-cs50",
21-
version="9.2.4"
21+
version="9.2.5"
2222
)

src/cs50/cs50.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def get_string(prompt):
135135
as line endings. If user inputs only a line ending, returns "", not None.
136136
Returns None upon error or no input whatsoever (i.e., just EOF).
137137
"""
138-
if type(prompt) is not str:
138+
if not isinstance(prompt, str):
139139
raise TypeError("prompt must be of type str")
140140
try:
141141
return input(prompt)

src/cs50/sql.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def connect(dbapi_connection, connection_record):
8181

8282
# Enable foreign key constraints
8383
try:
84-
if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite
84+
if isinstance(dbapi_connection, sqlite3.Connection): # If back end is sqlite
8585
cursor = dbapi_connection.cursor()
8686
cursor.execute("PRAGMA foreign_keys=ON")
8787
cursor.close()
@@ -350,11 +350,11 @@ def teardown_appcontext(exception):
350350

351351
# Coerce decimal.Decimal objects to float objects
352352
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
353-
if type(row[column]) is decimal.Decimal:
353+
if isinstance(row[column], decimal.Decimal):
354354
row[column] = float(row[column])
355355

356356
# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
357-
elif type(row[column]) is memoryview:
357+
elif isinstance(row[column], memoryview):
358358
row[column] = bytes(row[column])
359359

360360
# Rows to be returned
@@ -432,52 +432,52 @@ def __escape(value):
432432
import sqlalchemy
433433

434434
# bool
435-
if type(value) is bool:
435+
if isinstance(value, bool):
436436
return sqlparse.sql.Token(
437437
sqlparse.tokens.Number,
438438
sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value))
439439

440440
# bytes
441-
elif type(value) is bytes:
441+
elif isinstance(value, bytes):
442442
if self._engine.url.get_backend_name() in ["mysql", "sqlite"]:
443443
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
444444
elif self._engine.url.get_backend_name() == "postgresql":
445445
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359
446446
else:
447447
raise RuntimeError("unsupported value: {}".format(value))
448448

449-
# datetime.date
450-
elif type(value) is datetime.date:
449+
# datetime.datetime
450+
elif isinstance(value, datetime.datetime):
451451
return sqlparse.sql.Token(
452452
sqlparse.tokens.String,
453-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d")))
453+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
454454

455-
# datetime.datetime
456-
elif type(value) is datetime.datetime:
455+
# datetime.date
456+
elif isinstance(value, datetime.date):
457457
return sqlparse.sql.Token(
458458
sqlparse.tokens.String,
459-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
459+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d")))
460460

461461
# datetime.time
462-
elif type(value) is datetime.time:
462+
elif isinstance(value, datetime.time):
463463
return sqlparse.sql.Token(
464464
sqlparse.tokens.String,
465465
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S")))
466466

467467
# float
468-
elif type(value) is float:
468+
elif isinstance(value, float):
469469
return sqlparse.sql.Token(
470470
sqlparse.tokens.Number,
471471
sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value))
472472

473473
# int
474-
elif type(value) is int:
474+
elif isinstance(value, int):
475475
return sqlparse.sql.Token(
476476
sqlparse.tokens.Number,
477477
sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value))
478478

479479
# str
480-
elif type(value) is str:
480+
elif isinstance(value, str):
481481
return sqlparse.sql.Token(
482482
sqlparse.tokens.String,
483483
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value))
@@ -493,7 +493,7 @@ def __escape(value):
493493
raise RuntimeError("unsupported value: {}".format(value))
494494

495495
# Escape value(s), separating with commas as needed
496-
if type(value) in [list, tuple]:
496+
if isinstance(value, (list, tuple)):
497497
return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value])))
498498
else:
499499
return __escape(value)

0 commit comments

Comments
 (0)