From bdbe926968c20d80660d7ec03ee0a1b4c32f52dd Mon Sep 17 00:00:00 2001 From: Lynn Date: Sun, 22 Oct 2023 03:32:53 -0700 Subject: [PATCH] typing --- src/python/sql_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/python/sql_test.py b/src/python/sql_test.py index 74be988..0e6df4f 100644 --- a/src/python/sql_test.py +++ b/src/python/sql_test.py @@ -21,8 +21,8 @@ class SQLState: def read_table_meta(self, table_name: str) -> dict: return self.state.get(table_name, {}).get("metadata", {}) - def read_table_rows(self, table_name: str) -> dict: - return self.state.get(table_name, {}).get("rows", {}) + def read_table_rows(self, table_name: str) -> list[dict]: + return self.state.get(table_name, {}).get("rows", []) def read_information_schema(self) -> list[dict]: return [data["metadata"] for data in self.state.values()] @@ -104,6 +104,8 @@ def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: for i, arg in enumerate(args): if arg == "VALUES": values_index = i + if values_index is None: + raise ValueError("VALUES not found") keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",") keys = [key.strip() for key in keys] @@ -130,6 +132,8 @@ def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: from_index = i if arg == "WHERE": where_index = i + if from_index is None: + raise ValueError("FROM not found") # get select keys by getting the slice of args before FROM select_keys = " ".join(args[1:from_index]).split(",")