diff --git a/Makefile b/Makefile index df228d3..ce0c41c 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ run_all: run_train run_pred run_evaluate # default: # cat tests/lifecycle/test_output.txt -test_mlflow_config: - @pytest \ - tests/lifecycle/test_mlflow.py::TestMlflow::test_mlflow_experiment_is_not_null \ - tests/lifecycle/test_mlflow.py::TestMlflow::test_mlflow_model_name_is_not_null +# test_mlflow_config: +# @pytest \ +# tests/lifecycle/test_mlflow.py::TestMlflow::test_mlflow_experiment_is_not_null \ +# tests/lifecycle/test_mlflow.py::TestMlflow::test_mlflow_model_name_is_not_null diff --git a/dfake/interface/main.py b/dfake/interface/main.py index 7d9a83e..856779e 100644 --- a/dfake/interface/main.py +++ b/dfake/interface/main.py @@ -11,13 +11,14 @@ from dfake.params import * from dfake.dl_logic.model import initialize_model, compile_model, train_model, evaluate_model -from dfake.dl_logic.registry import load_model, save_model, save_results +from dfake.dl_logic.registry import load_model, save_model, save_results, mlflow_run from google.cloud import storage from PIL import Image +@mlflow_run def train(learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE, patience=PATIENCE @@ -31,14 +32,8 @@ def train(learning_rate=LEARNING_RATE, print("\n⭐️ Use case: train") print("\nLoading preprocessed validation data...") - - client = storage.Client() - bucket = client.bucket(BUCKET_NAME) - train_data_dir = bucket.blob(f"data/{DATA_SIZE}/train") - val_data_dir = bucket.blob(f"data/{DATA_SIZE}/valid") - #Lightweight dataset - # train_data_dir = Path(LOCAL_DATA_PATH).joinpath(f"{DATA_SIZE}", "train") - # val_data_dir = Path(LOCAL_DATA_PATH).joinpath(f"{DATA_SIZE}", "valid") + train_data_dir = Path(LOCAL_DATA_PATH).joinpath(f"{DATA_SIZE}", "train") + val_data_dir = Path(LOCAL_DATA_PATH).joinpath(f"{DATA_SIZE}", "valid") #Load data @@ -100,6 +95,7 @@ def train(learning_rate=LEARNING_RATE, return val_accuracy, val_recall, val_precision +@mlflow_run def evaluate(): """ Evaluate the performance of the model on test data @@ -110,11 +106,7 @@ def evaluate(): model = load_model() assert model is not None - client = storage.Client() - bucket = client.bucket(BUCKET_NAME) - test_data_dir = bucket.blob(f"data/{DATA_SIZE}/test") - - # test_data_dir = Path(LOCAL_DATA_PATH).joinpath(f"{DATA_SIZE}", "test") + test_data_dir = Path(LOCAL_DATA_PATH).joinpath(f"{DATA_SIZE}", "test") test_ds = image_dataset_from_directory( diff --git a/setup.py b/setup.py index c0643ce..6380860 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ requirements = [x.strip() for x in content if "git+" not in x] setup(name='dfake_models', - version="0.0.1", + version="0.0.2", description="D-fake models", license="MIT", author="Le Wagon",