diff --git a/src/python/sql_test.py b/src/python/sql_test.py index 74be988..0e6df4f 100644 --- a/src/python/sql_test.py +++ b/src/python/sql_test.py @@ -21,8 +21,8 @@ class SQLState: def read_table_meta(self, table_name: str) -> dict: return self.state.get(table_name, {}).get("metadata", {}) - def read_table_rows(self, table_name: str) -> dict: - return self.state.get(table_name, {}).get("rows", {}) + def read_table_rows(self, table_name: str) -> list[dict]: + return self.state.get(table_name, {}).get("rows", []) def read_information_schema(self) -> list[dict]: return [data["metadata"] for data in self.state.values()] @@ -104,6 +104,8 @@ def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]: for i, arg in enumerate(args): if arg == "VALUES": values_index = i + if values_index is None: + raise ValueError("VALUES not found") keys = " ".join(args[3:values_index]).replace("(", "").replace(")", "").split(",") keys = [key.strip() for key in keys] @@ -130,6 +132,8 @@ def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]: from_index = i if arg == "WHERE": where_index = i + if from_index is None: + raise ValueError("FROM not found") # get select keys by getting the slice of args before FROM select_keys = " ".join(args[1:from_index]).split(",")