Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions wren-ai-service/src/pipelines/indexing/db_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging
import sys
import uuid
Expand Down Expand Up @@ -29,7 +28,7 @@
@component
class DDLChunker:
@component.output_types(documents=List[Document])
async def run(
def run(
self,
mdl: Dict[str, Any],
column_batch_size: int,
Expand All @@ -48,7 +47,7 @@ def _additional_meta() -> Dict[str, Any]:
},
"content": chunk["payload"],
}
for chunk in await self._get_ddl_commands(
for chunk in self._get_ddl_commands(
**mdl, column_batch_size=column_batch_size
)
]
Expand All @@ -63,7 +62,7 @@ def _additional_meta() -> Dict[str, Any]:
]
}

async def _model_preprocessor(
def _model_preprocessor(
self, models: List[Dict[str, Any]], **kwargs
) -> List[Dict[str, Any]]:
def _column_preprocessor(
Expand All @@ -81,9 +80,9 @@ def _column_preprocessor(
**addition,
}

async def _preprocessor(model: Dict[str, Any], **kwargs) -> Dict[str, Any]:
def _preprocessor(model: Dict[str, Any], **kwargs) -> Dict[str, Any]:
addition = {
key: await helper(model, **kwargs)
key: helper(model, **kwargs)
for key, helper in helper.MODEL_PREPROCESSORS.items()
if helper.condition(model, **kwargs)
}
Expand All @@ -100,11 +99,9 @@ async def _preprocessor(model: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"primaryKey": model.get("primaryKey", ""),
}

tasks = [_preprocessor(model, **kwargs) for model in models]

return await asyncio.gather(*tasks)
return [_preprocessor(model, **kwargs) for model in models]

async def _get_ddl_commands(
def _get_ddl_commands(
self,
models: List[Dict[str, Any]],
relationships: List[Dict[str, Any]],
Expand All @@ -115,7 +112,7 @@ async def _get_ddl_commands(
) -> List[dict]:
return (
self._convert_models_and_relationships(
await self._model_preprocessor(models, **kwargs),
self._model_preprocessor(models, **kwargs),
relationships,
column_batch_size,
)
Expand Down Expand Up @@ -300,13 +297,13 @@ def validate_mdl(mdl_str: str, validator: MDLValidator) -> Dict[str, Any]:


@observe(capture_input=False)
async def chunk(
def chunk(
mdl: Dict[str, Any],
chunker: DDLChunker,
column_batch_size: int,
project_id: Optional[str] = None,
) -> Dict[str, Any]:
return await chunker.run(
return chunker.run(
mdl=mdl,
column_batch_size=column_batch_size,
project_id=project_id,
Expand Down
60 changes: 24 additions & 36 deletions wren-ai-service/tests/pytest/pipelines/indexing/test_db_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@
from src.pipelines.indexing.db_schema import DBSchema, DDLChunker


@pytest.mark.asyncio
async def test_empty_mdl():
def test_empty_mdl():
chunker = DDLChunker()
mdl = {"models": [], "views": [], "relationships": [], "metrics": []}

document = await chunker.run(mdl, column_batch_size=1)
document = chunker.run(mdl, column_batch_size=1)
assert document == {"documents": []}


@pytest.mark.asyncio
async def test_single_model():
def test_single_model():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -35,7 +33,7 @@ async def test_single_model():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 1

document: Document = actual["documents"][0]
Expand All @@ -49,8 +47,7 @@ async def test_single_model():
)


@pytest.mark.asyncio
async def test_multiple_models():
def test_multiple_models():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -74,7 +71,7 @@ async def test_multiple_models():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 2

document_1: Document = actual["documents"][0]
Expand All @@ -98,8 +95,7 @@ async def test_multiple_models():
)


@pytest.mark.asyncio
async def test_column_is_primary_key():
def test_column_is_primary_key():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -119,7 +115,7 @@ async def test_column_is_primary_key():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 2

document_0: Document = actual["documents"][0]
Expand All @@ -140,8 +136,7 @@ async def test_column_is_primary_key():
)


@pytest.mark.asyncio
async def test_column_with_properties():
def test_column_with_properties():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -164,7 +159,7 @@ async def test_column_with_properties():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 2

document_0: Document = actual["documents"][0]
Expand Down Expand Up @@ -195,8 +190,7 @@ async def test_column_with_properties():
)


@pytest.mark.asyncio
async def test_column_with_nested_columns():
def test_column_with_nested_columns():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -221,7 +215,7 @@ async def test_column_with_nested_columns():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 2

document_0: Document = actual["documents"][0]
Expand All @@ -242,8 +236,7 @@ async def test_column_with_nested_columns():
)


@pytest.mark.asyncio
async def test_column_with_calculated_property():
def test_column_with_calculated_property():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -264,7 +257,7 @@ async def test_column_with_calculated_property():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 2

document_0: Document = actual["documents"][0]
Expand All @@ -285,8 +278,7 @@ async def test_column_with_calculated_property():
)


@pytest.mark.asyncio
async def test_column_with_relationship():
def test_column_with_relationship():
chunker = DDLChunker()
mdl = {
"models": [
Expand Down Expand Up @@ -328,7 +320,7 @@ async def test_column_with_relationship():
"metrics": [],
}

actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 6

document_0: Document = actual["documents"][0]
Expand Down Expand Up @@ -381,8 +373,7 @@ async def test_column_with_relationship():
)


@pytest.mark.asyncio
async def test_column_batch_size():
def test_column_batch_size():
chunker = DDLChunker()
mdl = {
"models": [
Expand All @@ -399,7 +390,7 @@ async def test_column_batch_size():
"relationships": [],
"metrics": [],
}
actual = await chunker.run(mdl, column_batch_size=2)
actual = chunker.run(mdl, column_batch_size=2)
assert len(actual["documents"]) == 3

document_0: Document = actual["documents"][0]
Expand Down Expand Up @@ -444,16 +435,15 @@ async def test_column_batch_size():
)


@pytest.mark.asyncio
async def test_view():
def test_view():
chunker = DDLChunker()
mdl = {
"models": [],
"views": [{"name": "view_1", "statement": "SELECT * FROM user"}],
"relationships": [],
"metrics": [],
}
actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 1

document_0: Document = actual["documents"][0]
Expand All @@ -468,8 +458,7 @@ async def test_view():
)


@pytest.mark.asyncio
async def test_view_with_properties():
def test_view_with_properties():
chunker = DDLChunker()
mdl = {
"models": [],
Expand All @@ -483,7 +472,7 @@ async def test_view_with_properties():
"relationships": [],
"metrics": [],
}
actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 1

document_0: Document = actual["documents"][0]
Expand All @@ -498,8 +487,7 @@ async def test_view_with_properties():
)


@pytest.mark.asyncio
async def test_metric():
def test_metric():
chunker = DDLChunker()
mdl = {
"models": [],
Expand All @@ -518,7 +506,7 @@ async def test_metric():
}
],
}
actual = await chunker.run(mdl, column_batch_size=1)
actual = chunker.run(mdl, column_batch_size=1)
assert len(actual["documents"]) == 1

document_0: Document = actual["documents"][0]
Expand Down