-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6ac07f6
commit 8ab234e
Showing
3 changed files
with
287 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
|
||
import dataclasses | ||
import json | ||
import re | ||
import typing | ||
|
||
import tokenizer_script | ||
|
||
|
||
@dataclasses.dataclass(frozen=True)
class SQLState:
    """Immutable snapshot of the in-memory database.

    ``state`` maps table name -> ``{"metadata": {...}, "rows": [...]}``.
    The ``write_*`` methods return a NEW ``SQLState`` and leave the
    receiver untouched.  (The original implementation mutated the shared
    nested dicts in place before wrapping them in a new instance, so
    every older snapshot silently observed the change — defeating the
    frozen/immutable contract.)
    """

    state: dict

    def read_table_meta(self, table_name: str) -> dict:
        """Return the metadata dict for *table_name* ({} if unknown)."""
        return self.state.get(table_name, {}).get("metadata", {})

    def read_table_rows(self, table_name: str) -> list[dict]:
        """Return the stored rows for *table_name* ([] if unknown)."""
        return self.state.get(table_name, {}).get("rows", [])

    def read_information_schema(self) -> list[dict]:
        """Return the metadata of every known table."""
        return [data["metadata"] for data in self.state.values()]

    def write_table_meta(self, table_name: str, data: dict) -> "SQLState":
        """Return a new state with *data* merged into the table's metadata.

        Copies every container it touches (top-level map, table entry,
        metadata dict) so the receiver is never mutated.
        """
        state = dict(self.state)
        table = dict(state.get(table_name, {}))
        metadata = dict(table.get("metadata", {}))
        metadata.update(data)
        table["metadata"] = metadata
        state[table_name] = table
        return self.__class__(state)

    def write_table_rows(self, table_name: str, data: dict) -> "SQLState":
        """Return a new state with *data* appended to the table's rows.

        Copies the row list as well, so appends never leak into the
        receiver or any earlier snapshot.
        """
        state = dict(self.state)
        table = dict(state.get(table_name, {}))
        rows = list(table.get("rows", []))
        rows.append(data)
        table["rows"] = rows
        state[table_name] = table
        return self.__class__(state)
|
||
|
||
class SQLType:
    """Converters from raw SQL literal text to Python values."""

    @staticmethod
    def varchar(data) -> str:
        """Trim whitespace, then peel at most one quote off each end."""
        text = str(data).strip()
        if text[:1] in ("'", '"'):  # single leading ' or "
            text = text[1:]
        if text[-1:] in ("'", '"'):  # single trailing ' or "
            text = text[:-1]
        return text

    @staticmethod
    def int(data) -> int:
        """Parse an integer literal, ignoring surrounding whitespace."""
        return int(data.strip())
|
||
|
||
class SQLFunctions:
    """Worker implementations for the supported SQL statements.

    Every worker has the shape ``(state, *args) -> (output_rows, new_state)``
    and threads the immutable ``SQLState`` forward.
    """

    @staticmethod
    def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]:
        """CREATE TABLE: register metadata for a new table; emits no rows."""
        output: list[dict] = []
        table_name = args[0]
        column_tokens = args[1]

        # Column tokens alternate name, type, name, type, ...
        columns = {}
        if column_tokens:
            for idx in range(0, len(column_tokens), 2):
                columns[column_tokens[idx]] = column_tokens[idx + 1]

        # Only register the table the first time it is created.
        if not state.read_table_meta(table_name):
            # NOTE: "colums" (sic) is the key the rest of the file reads.
            state = state.write_table_meta(
                table_name,
                {
                    "table_name": table_name,
                    "table_schema": table_schema,
                    "colums": columns,
                },
            )
        return (output, state)

    @staticmethod
    def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
        """INSERT INTO: coerce values via the declared column types, append the row."""
        output: list[dict] = []
        table_name = args[0]
        # args layout from the tokenizer: (table, column_list, "VALUES", value_list)
        raw_row = dict(zip(args[1], args[3]))

        converters = {
            "VARCHAR": SQLType.varchar,
            "INT": SQLType.int,
        }

        metadata = state.read_table_meta(table_name)
        if metadata:
            column_types = metadata["colums"]  # (sic) key written by create_table
            typed_row = {
                column: converters[column_types[column]](value)
                for column, value in raw_row.items()
            }
            state = state.write_table_rows(table_name, typed_row)

        return (output, state)

    @staticmethod
    def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
        """SELECT: project the requested columns from a table; state unchanged."""
        wanted = args[0] if isinstance(args[0], list) else [args[0]]
        source = args[2]  # args[1] is the literal FROM token

        # `information_schema.tables` is a virtual table backed by the metadata.
        if source == "information_schema.tables":
            rows = state.read_information_schema()
        else:
            rows = state.read_table_rows(source)

        # Missing columns project as None (dict.get default).
        output = [
            {column: row.get(column) for column in wanted}
            for row in rows
        ]

        return (output, state)
|
||
|
||
def run_sql(input_sql: list[str]) -> list[str]:
    """Execute *input_sql* (raw SQL text lines) and return the result as JSON.

    The state is threaded through each statement's worker in order; the
    ``output`` of each statement overwrites the previous one, so only the
    final statement's output is serialized.
    """
    state = SQLState(state={})
    dispatch = {
        "CREATE TABLE": SQLFunctions.create_table,
        "INSERT INTO": SQLFunctions.insert_into,
        "SELECT": SQLFunctions.select,
    }
    sql_tokenizer = tokenizer_script.SQLTokenizer(dispatch)

    output: list = []
    # iterate over each line of sql
    for tokens in sql_tokenizer.tokenize_sql(input_sql):
        output, state = tokens.worker_func(state, *tokens.args)

    return [json.dumps(output)]
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
|
||
import dataclasses | ||
import json | ||
import os | ||
import typing | ||
|
||
|
||
# Verbose tokenizer tracing; enable with DEBUG=1 (any non-zero integer).
DEBUG = bool(int(os.getenv("DEBUG", "0")))
|
||
|
||
@dataclasses.dataclass(frozen=True)
class SQLTokens:
    """One tokenized SQL statement: matched keyword, its handler, and its args."""

    # The SQL keyword phrase that matched (e.g. "CREATE TABLE").
    worker_str: str
    # Handler looked up in the tokenizer's function map; may be None.
    worker_func: typing.Callable | None
    # Arguments parsed from the remainder of the statement;
    # parenthesized groups appear as nested lists.
    args: list[typing.Any]
|
||
|
||
@dataclasses.dataclass(frozen=True)
class SQLTokenizer:
    """Splits raw SQL text into ``SQLTokens`` using a keyword -> handler map."""

    sql_function_map: dict[str, typing.Callable | None]

    def tokenize_sql(self, sql: list[str]) -> list[SQLTokens]:
        """Tokenize *sql* (a list of text lines) into one ``SQLTokens`` per statement.

        Raises:
            ValueError: when a statement starts with no known keyword.
        """
        # remove comments
        sql = [line.strip() for line in sql if not line.startswith("--")]
        # re-split on semi-colons, the semi-colons are the true line breaks in SQL
        sql = " ".join(sql).split(";")
        # remove empty lines
        sql = [line.strip() for line in sql if line]

        # We sort the SQL function map by its key length, longest first.
        # This is a low complexity way to ensure that we can match, for example,
        # both `SET SESSION AUTHORIZATION` and `SET`.  The ordering is
        # loop-invariant, so compute it once (the original re-sorted it for
        # every statement).
        ordered_keys = sorted(self.sql_function_map.keys(), key=len, reverse=True)

        # get worker strings
        worker_strs = []
        worker_funcs = []
        args_strs = []
        for line in sql:
            this_worker_str = None
            for key in ordered_keys:
                if line.startswith(key):
                    this_worker_str = key
                    worker_strs.append(key)
                    worker_funcs.append(self.sql_function_map[key])
                    # Strip only the LEADING keyword.  The original used
                    # str.replace, which would also delete the keyword text if
                    # it re-appeared inside the arguments.
                    args_strs.append(line.removeprefix(key).strip())
                    break
            if this_worker_str is None:
                # Report the statement that failed to match (the original
                # interpolated this_worker_str, which is always None here).
                raise ValueError(f"Unknown worker function: {line}")

        # tokenize args: a character-level scan tracking bare words, quoted
        # strings, and one level of parenthesized lists.
        args_list: list[list] = []
        for i, sentence in enumerate(args_strs):
            args_list.append([])
            word_start: int | None = 0
            inside_list = False
            string_start: tuple[int | None, str | None] = (None, None)
            for k, letter in enumerate(sentence):
                if (string_start[0] is None) and (letter in ["'", '"']):
                    if DEBUG:
                        print(f"at letter: {letter}, starting a string")
                    string_start = (k, letter)
                elif (word_start is None) and (letter not in ["(", ")", ",", " "]):
                    if DEBUG:
                        print(f"at letter: {letter}, starting a word")
                    word_start = k
                elif (letter == string_start[1]) and (sentence[k - 1] != "\\") and (inside_list):
                    # Closing quote inside a parenthesized list: the quoted
                    # string (quotes included) joins the current nested list.
                    if DEBUG:
                        print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}")
                    string = sentence[string_start[0] : k + 1]
                    args_list[i][-1].append(string)
                    string_start = (None, None)
                    word_start = None
                elif (string_start[0] is not None) and (letter == string_start[1]) and (sentence[k - 1] != "\\"):
                    # Closing quote at top level: quoted string becomes its own arg.
                    if DEBUG:
                        print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}")
                    string = sentence[string_start[0] : k + 1]
                    args_list[i].append(string)
                    string_start = (None, None)
                    word_start = None
                elif (word_start is not None) and (letter in [")"]) and (inside_list) and (string_start[0] is None):
                    # ")" both terminates the current word and closes the list.
                    if DEBUG:
                        print(
                            f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}"
                        )
                    word = sentence[word_start:k]
                    args_list[i][-1].append(word)
                    word_start = None
                    inside_list = False
                elif (
                    (word_start is not None) and (letter in [" ", ","]) and (inside_list) and (string_start[0] is None)
                ):
                    # Separator inside a list: finish the word into the nested list.
                    if DEBUG:
                        print(
                            f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}"
                        )
                    word = sentence[word_start:k]
                    args_list[i][-1].append(word)
                    word_start = None
                elif (word_start is not None) and (letter in [" ", ")", ","]) and (string_start[0] is None):
                    # Separator at top level: finish the word as its own arg.
                    if DEBUG:
                        print(f"at letter: {letter}, adding word: {sentence[word_start:k]}")
                    word = sentence[word_start:k]
                    args_list[i].append(word)
                    word_start = None
                elif (word_start is not None) and (k == len(sentence) - 1):
                    # End of sentence flushes any in-progress word.
                    if DEBUG:
                        print(f"at letter: {letter}, last word: {sentence[word_start:]}")
                    word = sentence[word_start:]
                    args_list[i].append(word)
                    word_start = None
                elif letter == "(":
                    if DEBUG:
                        print(f"at letter: {letter}, starting a list")
                    inside_list = True
                    args_list[i].append([])
                    word_start = None
                elif (inside_list) and (letter in ")"):
                    if DEBUG:
                        print(f"at letter: {letter}, ending list")
                    inside_list = False
                elif word_start is not None:
                    if DEBUG:
                        print(f"at letter: {letter}, inside of a word: {sentence[word_start:k]}")
                else:
                    if DEBUG:
                        print(f"at letter: {letter}")

        # NOTE: the original comprehension reused the name `args_list` as the
        # zip loop variable, shadowing the outer list; renamed for clarity.
        return [
            SQLTokens(
                worker_str=worker_str,
                worker_func=worker_func,
                args=args,
            )
            for worker_str, worker_func, args in zip(worker_strs, worker_funcs, args_list)
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters