Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ def create_dataset(
*,
columns: Sequence[Column],
feature_schema: dict | None = None,
flat_schema: dict | None = None,
query_script: str = "",
create_rows: bool | None = True,
validate_version: bool | None = True,
Expand Down Expand Up @@ -831,15 +832,17 @@ def create_dataset(
)

except DatasetNotFoundError:
"""
schema = {
c.name: c.type.to_dict() for c in columns if isinstance(c.type, SQLType)
}
"""
dataset = self.metastore.create_dataset(
name,
project.id if project else None,
feature_schema=feature_schema,
query_script=query_script,
schema=schema,
schema=flat_schema,
ignore_if_exists=True,
description=description,
attrs=attrs,
Expand Down
9 changes: 9 additions & 0 deletions src/datachain/lib/dc/datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from datachain.query.dataset import DatasetQuery, PartitionByType
from datachain.query.schema import DEFAULT_DELIMITER, Column
from datachain.sql.functions import path as pathfunc
from datachain.sql.types import SQLType
from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict

from .database import DEFAULT_DATABASE_BATCH_SIZE
Expand Down Expand Up @@ -639,6 +640,13 @@ def save( # type: ignore[override]

# Schema preparation
schema = self.signals_schema.clone_without_sys_signals().serialize()
flat_schema = {
c.name: c.type.to_dict() # type: ignore[union-attr]
for c in self.signals_schema.clone_with_sys_signals().db_signals(
as_columns=True
)
if isinstance(c.type, SQLType) # type: ignore[union-attr]
}

# Handle retry and delta functionality
if not result:
Expand All @@ -654,6 +662,7 @@ def save( # type: ignore[override]
description=description,
attrs=attrs,
feature_schema=schema,
flat_schema=flat_schema,
update_version=update_version,
**kwargs,
)
Expand Down
6 changes: 6 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sqlalchemy import ColumnElement
from typing_extensions import Literal as LiteralEx

from datachain.data_storage.schema import DataTable
from datachain.func import literal
from datachain.func.func import Func
from datachain.lib.convert.python_to_sql import python_to_sql
Expand Down Expand Up @@ -769,6 +770,11 @@ def clone_without_sys_signals(self) -> "SignalSchema":
schema.pop("sys", None)
return SignalSchema(schema)

def clone_with_sys_signals(self) -> "SignalSchema":
    """Return a new schema combining system columns with this schema's signals.

    A schema is first built from the name/type pairs reported by
    ``DataTable.sys_columns()``. This schema, with any existing ``sys``
    entry removed via ``clone_without_sys_signals()``, is then merged into
    it. The empty string is passed as the second argument to ``merge``.
    """
    sys_types = {}
    for column in DataTable.sys_columns():
        sys_types[column.name] = column.type
    system_part = SignalSchema.from_column_types(sys_types)
    user_part = self.clone_without_sys_signals()
    return system_part.merge(user_part, "")

def merge(
self,
right_schema: "SignalSchema",
Expand Down
2 changes: 2 additions & 0 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,7 @@ def save(
version: str | None = None,
project: Project | None = None,
feature_schema: dict | None = None,
flat_schema: dict | None = None,
dependencies: list[DatasetDependency] | None = None,
description: str | None = None,
attrs: list[str] | None = None,
Expand Down Expand Up @@ -1920,6 +1921,7 @@ def save(
project,
version=version,
feature_schema=feature_schema,
flat_schema=flat_schema,
columns=columns,
description=description,
attrs=attrs,
Expand Down
Loading