Skip to content

Commit

Permalink
snippets
Browse files Browse the repository at this point in the history
  • Loading branch information
coilysiren committed Oct 23, 2023
1 parent 6ac07f6 commit 8ab234e
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 1 deletion.
146 changes: 146 additions & 0 deletions snippets/python/sql_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@

import dataclasses
import json
import re
import typing

import tokenizer_script


@dataclasses.dataclass(frozen=True)
class SQLState:
state: dict

def read_table_meta(self, table_name: str) -> dict:
return self.state.get(table_name, {}).get("metadata", {})

def read_table_rows(self, table_name: str) -> list[dict]:
return self.state.get(table_name, {}).get("rows", [])

def read_information_schema(self) -> list[dict]:
return [data["metadata"] for data in self.state.values()]

def write_table_meta(self, table_name: str, data: dict):
state = self.state
table = state.get(table_name, {})
metadata = table.get("metadata", {})
metadata.update(data)
table["metadata"] = metadata
state[table_name] = table
return self.__class__(state)

def write_table_rows(self, table_name: str, data: dict):
state = self.state
table = state.get(table_name, {})
rows = table.get("rows", [])
rows.append(data)
table["rows"] = rows
state[table_name] = table
return self.__class__(state)


class SQLType:
@staticmethod
def varchar(data) -> str:
data_str = str(data).strip()
data_str = re.sub(r'^["\']', "", data_str) # leading ' or "
data_str = re.sub(r'["\']$', "", data_str) # trailing ' or "
return data_str

@staticmethod
def int(data) -> int:
return int(data.strip())


class SQLFunctions:
@staticmethod
def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]:
output: list[dict] = []
table_name = args[0]

# get columns
columns = {}
columns_str = args[1]
if columns_str:
# fmt: off
columns = {
columns_str[i]: columns_str[i + 1]
for i in range(0, len(columns_str), 2)
}
# fmt: on

if not state.read_table_meta(table_name):
state = state.write_table_meta(
table_name,
{
"table_name": table_name,
"table_schema": table_schema,
"colums": columns,
},
)
return (output, state)

@staticmethod
def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
output: list[dict] = []
table_name = args[0]
keys = args[1]
values = args[3]
key_value_map = dict(zip(keys, values))

sql_type_map = {
"VARCHAR": SQLType.varchar,
"INT": SQLType.int,
}

data = {}
metadata = state.read_table_meta(table_name)
if metadata:
for key, value in key_value_map.items():
data[key] = sql_type_map[metadata["colums"][key]](value)
state = state.write_table_rows(table_name, data)

return (output, state)

@staticmethod
def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
output: list[dict] = []
select_columns = args[0] if isinstance(args[0], list) else [args[0]]
from_value = args[2]

# `information_schema.tables` is a special case
if from_value == "information_schema.tables":
data = state.read_information_schema()
else:
data = state.read_table_rows(from_value)

output = []
for datum in data:
# fmt: off
output.append({
key: datum.get(key)
for key in select_columns
})
# fmt: on

return (output, state)


def run_sql(input_sql: list[str]) -> list[str]:
output = []
state = SQLState(state={})
sql_tokenizer = tokenizer_script.SQLTokenizer(
{
"CREATE TABLE": SQLFunctions.create_table,
"INSERT INTO": SQLFunctions.insert_into,
"SELECT": SQLFunctions.select,
}
)
sql_token_list = sql_tokenizer.tokenize_sql(input_sql)

# iterate over each line of sql
for sql_tokens in sql_token_list:
output, state = sql_tokens.worker_func(state, *sql_tokens.args)

return [json.dumps(output)]

140 changes: 140 additions & 0 deletions snippets/python/tokenizer_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@

import dataclasses
import json
import os
import typing


DEBUG = bool(int(os.getenv("DEBUG", "0")))


@dataclasses.dataclass(frozen=True)
class SQLTokens:
worker_str: str
worker_func: typing.Callable | None
args: list[typing.Any]


@dataclasses.dataclass(frozen=True)
class SQLTokenizer:
sql_function_map: dict[str, typing.Callable | None]

def tokenize_sql(self, sql: list[str]) -> list[SQLTokens]:
# remove comments
sql = [line.strip() for line in sql if not line.startswith("--")]
# re-split on semi-colons, the semi-colons are the true line breaks in SQL
sql = " ".join(sql).split(";")
# remove empty lines
sql = [line.strip() for line in sql if line]

# get worker strings
worker_strs = []
worker_funcs = []
args_strs = []
for line in sql:
this_worker_str = None
# We sort the SQL function map by its key length, longest first.
# This is a low complexity way to ensure that we can match, for example,
# both `SET SESSION AUTHORIZATION` and `SET`.
# fmt: off
sql_function_map_ordered_keys = sorted([
key
for key in self.sql_function_map.keys()
], key=len, reverse=True)
# fmt: on
for key in sql_function_map_ordered_keys:
if line.startswith(key):
this_worker_str = key
worker_strs.append(key)
worker_funcs.append(self.sql_function_map[key])
args_strs.append(line.replace(key, "").strip())
break
if this_worker_str is None:
raise ValueError(f"Unknown worker function: {this_worker_str}")

# tokenize args
args_list: list[list] = []
for i, sentence in enumerate(args_strs):
args_list.append([])
word_start: int | None = 0
inside_list = False
string_start: tuple[int | None, str | None] = (None, None)
for k, letter in enumerate(sentence):
if (string_start[0] is None) and (letter in ["'", '"']):
if DEBUG:
print(f"at letter: {letter}, starting a string")
string_start = (k, letter)
elif (word_start is None) and (letter not in ["(", ")", ",", " "]):
if DEBUG:
print(f"at letter: {letter}, starting a word")
word_start = k
elif (letter == string_start[1]) and (sentence[k - 1] != "\\") and (inside_list):
if DEBUG:
print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}")
string = sentence[string_start[0] : k + 1]
args_list[i][-1].append(string)
string_start = (None, None)
word_start = None
elif (string_start[0] is not None) and (letter == string_start[1]) and (sentence[k - 1] != "\\"):
if DEBUG:
print(f"at letter: {letter}, ending string: {sentence[string_start[0]:k+1]}")
string = sentence[string_start[0] : k + 1]
args_list[i].append(string)
string_start = (None, None)
word_start = None
elif (word_start is not None) and (letter in [")"]) and (inside_list) and (string_start[0] is None):
if DEBUG:
print(
f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}"
)
word = sentence[word_start:k]
args_list[i][-1].append(word)
word_start = None
inside_list = False
elif (
(word_start is not None) and (letter in [" ", ","]) and (inside_list) and (string_start[0] is None)
):
if DEBUG:
print(
f"at letter: {letter}, adding word: {sentence[word_start:k]}, to list: {args_list[i][-1]}"
)
word = sentence[word_start:k]
args_list[i][-1].append(word)
word_start = None
elif (word_start is not None) and (letter in [" ", ")", ","]) and (string_start[0] is None):
if DEBUG:
print(f"at letter: {letter}, adding word: {sentence[word_start:k]}")
word = sentence[word_start:k]
args_list[i].append(word)
word_start = None
elif (word_start is not None) and (k == len(sentence) - 1):
if DEBUG:
print(f"at letter: {letter}, last word: {sentence[word_start:]}")
word = sentence[word_start:]
args_list[i].append(word)
word_start = None
elif letter == "(":
if DEBUG:
print(f"at letter: {letter}, starting a list")
inside_list = True
args_list[i].append([])
word_start = None
elif (inside_list) and (letter in ")"):
if DEBUG:
print(f"at letter: {letter}, ending list")
inside_list = False
elif word_start is not None:
if DEBUG:
print(f"at letter: {letter}, inside of a word: {sentence[word_start:k]}")
else:
if DEBUG:
print(f"at letter: {letter}")

return [
SQLTokens(
worker_str=worker_str,
worker_func=worker_func,
args=args_list,
)
for worker_str, worker_func, args_list in zip(worker_strs, worker_funcs, args_list)
]
2 changes: 1 addition & 1 deletion snippets/ruby/sql_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def self.select(state, *args)

raise 'FROM not found' if from_index.nil?

select_keys = args[1...from_index].join(' ').split(',').map(&:strip)
select_keys = args[1...from_index].join(' ').split(',').map {|s| s.gsub(/[()]/, '')}.map(&:strip)
from_value = args[from_index + 1]

data = if from_value == 'information_schema.tables'
Expand Down

0 comments on commit 8ab234e

Please sign in to comment.