Skip to content

Commit

Permalink
fix: training framework dumped model cleanup on drop function, closes #…
Browse files Browse the repository at this point in the history
  • Loading branch information
aayushacharya committed Jan 25, 2024
1 parent 6fcac41 commit 6e19a02
Showing 1 changed file with 54 additions and 11 deletions.
65 changes: 54 additions & 11 deletions evadb/executor/drop_object_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil

import pandas as pd

from evadb.configuration.constants import TRAINING_FRAMEWORKS
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.executor.executor_utils import ExecutorError, handle_vector_store_params
Expand All @@ -24,6 +27,7 @@
from evadb.plan_nodes.drop_object_plan import DropObjectPlan
from evadb.storage.storage_engine import StorageEngine
from evadb.third_party.vector_stores.utils import VectorStoreFactory
from evadb.utils.generic_utils import string_comparison_case_insensitive
from evadb.utils.logging_manager import logger


Expand Down Expand Up @@ -94,19 +98,58 @@ def _handle_drop_function(self, function_name: str, if_exists: bool):
function_entry = self.catalog().get_function_catalog_entry_by_name(
function_name
)
for cache in function_entry.dep_caches:
self.catalog().drop_function_cache_catalog_entry(cache)

# todo also delete the indexes associated with the table

self.catalog().delete_function_catalog_entry_by_name(function_name)

return Batch(
pd.DataFrame(
{f"Function {function_name} successfully dropped"},
index=[0],
# training framework model cleanup on drop function
err_msg = (
f"Error removing {function_entry.type} model for function {function_name}."
)
try:
if function_entry.type.lower() in [x.lower() for x in TRAINING_FRAMEWORKS]:
filtered_metadata = list(
filter(lambda x: x.key == "model_path", function_entry.metadata)
)
if len(filtered_metadata) > 0:
model_path = os.path.abspath(filtered_metadata[0].value)
"""For 'Forecasting' the entire function catalog of forecasting functions
is checked to see if the model path is shared"""
if string_comparison_case_insensitive(
function_entry.type, "Forecasting"
):
forecasting_function_entries = (
self.catalog().get_function_catalog_entries_by_type(
function_entry.type
)
)
functions_using_same_model = sum(
1
for entry in forecasting_function_entries
if any(
x.key == "model_path"
and os.path.abspath(x.value) == model_path
for x in entry.metadata
)
)
if functions_using_same_model == 1:
dir_path = os.path.abspath(os.path.dirname(model_path))
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
else:
if os.path.exists(model_path):
os.remove(model_path)

except Exception as e:
raise RuntimeError(f"{err_msg}\n{e}")

for cache in function_entry.dep_caches:
self.catalog().drop_function_cache_catalog_entry(cache)

# todo also delete the indexes associated with the table
self.catalog().delete_function_catalog_entry_by_name(function_name)
return Batch(
pd.DataFrame(
{f"Function {function_name} successfully dropped"},
index=[0],
)
)

def _handle_drop_index(self, index_name: str, if_exists: bool):
index_obj = self.catalog().get_index_catalog_entry_by_name(index_name)
Expand Down

0 comments on commit 6e19a02

Please sign in to comment.