diff --git a/evadb/catalog/catalog_manager.py b/evadb/catalog/catalog_manager.py index 70d3e0acff..a788c30400 100644 --- a/evadb/catalog/catalog_manager.py +++ b/evadb/catalog/catalog_manager.py @@ -515,6 +515,20 @@ def get_function_catalog_entry_by_name(self, name: str) -> FunctionCatalogEntry: """ return self._function_service.get_entry_by_name(name) + def get_function_catalog_entries_by_type( + self, type: str + ) -> List[FunctionCatalogEntry]: + """ + Get function information based on type. + + Arguments: + type (str): type of the function + + Returns: + List of FunctionCatalogEntry object + """ + return self._function_service.get_entries_by_type(type) + def delete_function_catalog_entry_by_name(self, function_name: str) -> bool: return self._function_service.delete_entry_by_name(function_name) diff --git a/evadb/catalog/services/function_catalog_service.py b/evadb/catalog/services/function_catalog_service.py index 0c6c272d3c..c1244da1ba 100644 --- a/evadb/catalog/services/function_catalog_service.py +++ b/evadb/catalog/services/function_catalog_service.py @@ -100,6 +100,24 @@ def get_entry_by_name(self, name: str) -> FunctionCatalogEntry: return function_obj.as_dataclass() return None + def get_entries_by_type(self, function_type: str) -> List[FunctionCatalogEntry]: + """returns the function entries that matches the type provided. + Empty list if no such entry found. + + Arguments: + type (str): name to be searched + """ + + entries = ( + self.session.execute( + select(self.model).filter(self.model._type == function_type) + ) + .scalars() + .all() + ) + + return [entry.as_dataclass() for entry in entries] + def get_entry_by_id(self, id: int, return_alchemy=False) -> FunctionCatalogEntry: """return the function entry that matches the id provided. None if no such entry found. diff --git a/evadb/configuration/constants.py b/evadb/configuration/constants.py index 126e6bcfca..c8b84ca933 100644 --- a/evadb/configuration/constants.py +++ b/evadb/configuration/constants.py @@ -38,3 +38,4 @@ DEFAULT_XGBOOST_TASK = "regression" DEFAULT_SKLEARN_TRAIN_MODEL = "rf" SKLEARN_SUPPORTED_MODELS = ["rf", "extra_tree", "kneighbor"] +TRAINING_FRAMEWORKS = ["Sklearn", "Ludwig", "XGBoost", "Forecasting"] diff --git a/evadb/executor/drop_object_executor.py b/evadb/executor/drop_object_executor.py index c4f108052e..d27012328e 100644 --- a/evadb/executor/drop_object_executor.py +++ b/evadb/executor/drop_object_executor.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil import pandas as pd +from evadb.configuration.constants import TRAINING_FRAMEWORKS from evadb.database import EvaDBDatabase from evadb.executor.abstract_executor import AbstractExecutor from evadb.executor.executor_utils import ExecutorError, handle_vector_store_params @@ -24,6 +27,7 @@ from evadb.plan_nodes.drop_object_plan import DropObjectPlan from evadb.storage.storage_engine import StorageEngine from evadb.third_party.vector_stores.utils import VectorStoreFactory +from evadb.utils.generic_utils import string_comparison_case_insensitive from evadb.utils.logging_manager import logger @@ -94,19 +98,58 @@ def _handle_drop_function(self, function_name: str, if_exists: bool): function_entry = self.catalog().get_function_catalog_entry_by_name( function_name ) - for cache in function_entry.dep_caches: - self.catalog().drop_function_cache_catalog_entry(cache) - - # todo also delete the indexes associated with the table - - self.catalog().delete_function_catalog_entry_by_name(function_name) - - return Batch( - pd.DataFrame( - {f"Function {function_name} successfully dropped"}, - index=[0], + # training framework model cleanup on drop function + err_msg = ( + f"Error removing {function_entry.type} model for function {function_name}." + ) + try: + if function_entry.type.lower() in [x.lower() for x in TRAINING_FRAMEWORKS]: + filtered_metadata = list( + filter(lambda x: x.key == "model_path", function_entry.metadata) ) + if len(filtered_metadata) > 0: + model_path = os.path.abspath(filtered_metadata[0].value) + """For 'Forecasting' the entire function catalog of forecasting functions + is checked to see if the model path is shared""" + if string_comparison_case_insensitive( + function_entry.type, "Forecasting" + ): + forecasting_function_entries = ( + self.catalog().get_function_catalog_entries_by_type( + function_entry.type + ) + ) + functions_using_same_model = sum( + 1 + for entry in forecasting_function_entries + if any( + x.key == "model_path" + and os.path.abspath(x.value) == model_path + for x in entry.metadata + ) + ) + if functions_using_same_model == 1: + dir_path = os.path.abspath(os.path.dirname(model_path)) + if os.path.exists(dir_path): + shutil.rmtree(dir_path) + else: + if os.path.exists(model_path): + os.remove(model_path) + + except Exception as e: + raise RuntimeError(f"{err_msg}\n{e}") + + for cache in function_entry.dep_caches: + self.catalog().drop_function_cache_catalog_entry(cache) + + # todo also delete the indexes associated with the table + self.catalog().delete_function_catalog_entry_by_name(function_name) + return Batch( + pd.DataFrame( + {f"Function {function_name} successfully dropped"}, + index=[0], ) + ) def _handle_drop_index(self, index_name: str, if_exists: bool): index_obj = self.catalog().get_index_catalog_entry_by_name(index_name)