diff --git a/docs/guides/customizing_sqlmesh.md b/docs/guides/customizing_sqlmesh.md index 3b95b6ba8..e653b266b 100644 --- a/docs/guides/customizing_sqlmesh.md +++ b/docs/guides/customizing_sqlmesh.md @@ -50,7 +50,7 @@ class CustomLoader(SqlMeshLoader): # Call SqlMeshLoader's normal `_load_models` method to ingest models from file and parse model SQL models = super()._load_models(macros, jinja_macros, gateway, audits, signals) - new_models = {} + new_models: UniqueKeyDict[str, Model] = {} # Loop through the existing model names/objects for model_name, model in models.items(): # Create list of existing and new post-statements @@ -64,11 +64,19 @@ class CustomLoader(SqlMeshLoader): # Create a copy of the model with the `post_statements_` field updated new_models[model_name] = model.copy(update={"post_statements_": new_post_statements}) - return new_models + # Load the new models to ensure that the modified models are correctly serialized + return self._load_models_from_definitions( + models=new_models, + macros=macros, + jinja_macros=jinja_macros, + audits=audits, + signals=signals, + ) + # Pass the CustomLoader class to the SQLMesh configuration object config = Config( # < your configuration parameters here >, loader=CustomLoader, ) -``` \ No newline at end of file +``` diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 3c688b852..a23b8e4d1 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -18,7 +18,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits -from sqlmesh.core.dialect import parse +from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.environment import EnvironmentStatements from sqlmesh.core.linter.rule import Rule from sqlmesh.core.linter.definition import RuleSet @@ -551,6 +551,57 @@ def _load_python_models( return models + def _load_models_from_definitions( + self, + models: UniqueKeyDict[str, Model], + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], + ) -> UniqueKeyDict[str, Model]: + out: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + + for model_name, model in models.items(): + if model.is_sql: + try: + model_definition = model.render_definition(include_python=False) + formatted_model = format_model_expressions( + model_definition, + self.config.model_defaults.dialect, + ) + expressions = parse(formatted_model) + + reloaded_models = load_sql_based_models( + expressions, + self._get_variables, + defaults=self.config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, + audit_definitions=audits, + default_audits=self.config.model_defaults.audits, + path=model._path, + module_path=self.config_path, + dialect=self.config.model_defaults.dialect, + time_column_format=self.config.time_column_format, + physical_schema_mapping=self.config.physical_schema_mapping, + project=self.config.project, + default_catalog=self.context.default_catalog, + infer_names=self.config.model_naming.infer_names, + signal_definitions=signals, + ) + except Exception as ex: + raise ConfigError( + f"Failed to reload model definition at '{model._path}'.\n{ex}" + ) + + for new_model in reloaded_models: + out[new_model.fqn] = new_model + + else: # TODO: Reload of seeds, python models + out[model.fqn] = model + + return out + def load_materializations(self) -> None: with sys_path(self.config_path): self._load_materializations()