diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index ec8b26b..23633aa 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -15,7 +15,7 @@ jobs: python-version: '3.11' - name: Install dependencies run: | - python -m pip install ".[all]"" + python -m pip install ".[all,dev]" - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/README.md b/README.md index 314dec1..fc32703 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,26 @@ -# mlex_utils +# Utils for MLExchange Platform + +**mlex_utils** is a utility package designed to support the MLExchange platform by providing convenient tools and extensions for Dash Plotly and Prefect. Currently, this package focuses on facilitating job launches and monitoring workflows with Prefect, making it easier to manage and track your machine learning tasks. As the platform evolves, mlex_utils will continue to expand, incorporating additional utilities to enhance the MLExchange experience and streamline data workflows. + +## Features +- Utilities for integrating Dash Plotly components to orchestrate ML jobs using Dash Bootstrap Components and Dash Mantime Components +- Prefect job management tools for launching, scheduling, and monitoring ML jobs. + +## Installation + +To install `mlex_utils`, you can create a new Python environment and install all the dependencies: + +``` +conda create -n new_env python==3.11 +conda activate new_env +pip install . +``` + +Alternatively, you can choose to install a set of utils according to your use case. For example, to install the Prefect-related dependencies and utils: + +``` +pip install ".[prefect]" +``` ## Copyright MLExchange Copyright (c) 2024, The Regents of the University of California, diff --git a/examples/assets/mlex.ico b/examples/assets/mlex.ico new file mode 100644 index 0000000..ec85074 Binary files /dev/null and b/examples/assets/mlex.ico differ diff --git a/examples/assets/mlex.png b/examples/assets/mlex.png new file mode 100644 index 0000000..9cf706d Binary files /dev/null and b/examples/assets/mlex.png differ diff --git a/examples/assets/models_dbc.json b/examples/assets/models_dbc.json new file mode 100644 index 0000000..aed7e91 --- /dev/null +++ b/examples/assets/models_dbc.json @@ -0,0 +1,1126 @@ +{ + "contents": [ + { + "model_name": "MSDNet", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "MSDNets in DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "layer_width", + "title": "Layer Width", + "param_key": "layer_width", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "num_layers", + "title": "# Layers", + "param_key": "num_layers", + "value": 3, + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "custom_dilation", + "title": "Custom Dilation", + "param_key": "custom_dilation", + "value": false, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "max_dilation", + "title": "Maximum Dilation", + "param_key": "max_dilation", + "value": 5, + "comp_group": "train_model" + }, + { + "type": "str", + "name": "dilation_array", + "title": "Dilation Array", + "param_key": "dilation_array", + "value": "[1, 2, 4]", + "placeholder": "e.g. [1, 2, 4]", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": { + "1": "1", + "100": "100", + "1000": "1000" + }, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "CrossEntropyLoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1]", + "placeholder": "e.g [0.1, 0.4, 0.5]", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_window", + "title": "Qlty Window", + "param_key": "qlty_window", + "value": 64, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_step", + "title": "Qlty Step", + "param_key": "qlty_step", + "value": 32, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_border", + "title": "Qlty Border", + "param_key": "qlty_border", + "value": 8, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 1, + "step": 0.05, + "value": 0.2, + "marks": { + "0": "0%", + "0.5": "50%", + "1": "100%" + }, + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Batch Size Training", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Batch Size Validation", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Batch Size Inference", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "TUNet", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet in DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": { + "1": "1", + "100": "100", + "1000": "1000" + }, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "CrossEntropyLoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1]", + "placeholder": "e.g [0.1, 0.4, 0.5]", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_window", + "title": "Qlty Window", + "param_key": "qlty_window", + "value": 64, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_step", + "title": "Qlty Step", + "param_key": "qlty_step", + "value": 32, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_border", + "title": "Qlty Border", + "param_key": "qlty_border", + "value": 8, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 1, + "step": 0.05, + "value": 0.2, + "marks": { + "0": "0%", + "0.5": "50%", + "1": "100%" + }, + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "TUNet3+", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet3+ DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "carryover_channels", + "title": "Carryover Channels", + "param_key": "carryover_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": { + "1": "1", + "100": "100", + "1000": "1000" + }, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "CrossEntropyLoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1]", + "placeholder": "e.g [0.1, 0.4, 0.5]", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_window", + "title": "Qlty Window", + "param_key": "qlty_window", + "value": 64, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_step", + "title": "Qlty Step", + "param_key": "qlty_step", + "value": 32, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_border", + "title": "Qlty Border", + "param_key": "qlty_border", + "value": 8, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 1, + "step": 0.05, + "value": 0.2, + "marks": { + "0": "0%", + "0.5": "50%", + "1": "100%" + }, + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": { + "16": "16", + "32": "32", + "64": "64", + "128": "128" + }, + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + } + ] +} diff --git a/examples/assets/models_dmc.json b/examples/assets/models_dmc.json new file mode 100644 index 0000000..7d7e158 --- /dev/null +++ b/examples/assets/models_dmc.json @@ -0,0 +1,1298 @@ +{ + "contents": [ + { + "model_name": "MSDNet", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "MSDNets in DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "layer_width", + "title": "Layer Width", + "param_key": "layer_width", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "num_layers", + "title": "# Layers", + "param_key": "num_layers", + "value": 3, + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "custom_dilation", + "title": "Custom Dilation", + "param_key": "custom_dilation", + "checked": false, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "max_dilation", + "title": "Maximum Dilation", + "param_key": "max_dilation", + "value": 5, + "comp_group": "train_model" + }, + { + "type": "str", + "name": "dilation_array", + "title": "Dilation Array", + "param_key": "dilation_array", + "value": "[1, 2, 4]", + "placeholder": "e.g. [1, 2, 4]", + "error": "Provide a list of ints for dilation", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "data": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "CrossEntropyLoss", + "data": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1]", + "placeholder": "e.g [0.1, 0.4, 0.5]", + "error": "Provide a list with a float for each class", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "data": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_window", + "title": "Qlty Window", + "param_key": "qlty_window", + "value": 64, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_step", + "title": "Qlty Step", + "param_key": "qlty_step", + "value": 32, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_border", + "title": "Qlty Border", + "param_key": "qlty_border", + "value": 8, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 1, + "step": 0.05, + "value": 0.2, + "precision": 2, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 0.5, + "label": "50%" + }, + { + "value": 1, + "label": "100%" + } + ], + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "checked": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Batch Size Training", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Batch Size Validation", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Batch Size Inference", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "TUNet", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet in DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "data": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "CrossEntropyLoss", + "data": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1]", + "placeholder": "e.g [0.1, 0.4, 0.5]", + "error": "Provide a list with a float for each class", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "data": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_window", + "title": "Qlty Window", + "param_key": "qlty_window", + "value": 64, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_step", + "title": "Qlty Step", + "param_key": "qlty_step", + "value": 32, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_border", + "title": "Qlty Border", + "param_key": "qlty_border", + "value": 8, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 1, + "step": 0.05, + "value": 0.2, + "precision": 2, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 0.5, + "label": "50%" + }, + { + "value": 1, + "label": "100%" + } + ], + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "checked": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "TUNet3+", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet3+ DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "carryover_channels", + "title": "Carryover Channels", + "param_key": "carryover_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "data": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "CrossEntropyLoss", + "data": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1]", + "placeholder": "e.g [0.1, 0.4, 0.5]", + "error": "Provide a list with a float for each class", + "debounce": 1000, + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "data": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_window", + "title": "Qlty Window", + "param_key": "qlty_window", + "value": 64, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_step", + "title": "Qlty Step", + "param_key": "qlty_step", + "value": 32, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "qlty_border", + "title": "Qlty Border", + "param_key": "qlty_border", + "value": 8, + "min": 0, + "max": 4096, + "step": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 1, + "step": 0.05, + "value": 0.2, + "precision": 2, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 0.5, + "label": "50%" + }, + { + "value": 1, + "label": "100%" + } + ], + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "checked": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + } + ] +} diff --git a/examples/dbc_example.py b/examples/dbc_example.py new file mode 100644 index 0000000..84be3a3 --- /dev/null +++ b/examples/dbc_example.py @@ -0,0 +1,130 @@ +import uuid + +import dash_bootstrap_components as dbc +from dash import ALL, MATCH, Dash, Input, Output, callback, html +from models_utils import Models + +from mlex_utils.dash_utils.components_bootstrap.component_utils import header +from mlex_utils.dash_utils.mlex_components import MLExComponents + + +def get_control_panel(job_manager): + control_panel = dbc.Accordion( + [ + dbc.AccordionItem( + job_manager, + title="Model Configuration", + ), + ], + style={"position": "sticky", "top": "10%", "width": "100%", "padding": "1px"}, + ) + return control_panel + + +# Get models +models = Models(modelfile_path="./examples/assets/models_dbc.json") + +# Get MLExchange dash components +mlex_components = MLExComponents("dbc") +job_manager = mlex_components.get_job_manager(model_list=models.modelname_list) + +app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) +app.title = "Utils Example" +app._favicon = "mlex.ico" + +app_header = header( + "MLExchange | Utils Example", + "https://mlexchange.als.lbl.gov", + "https://mlexchange.als.lbl.gov/docs", + app.get_asset_url("mlex.png"), +) + +app.layout = html.Div( + [ + app_header, + dbc.Container( + children=[ + dbc.Row( + [ + dbc.Col( + get_control_panel(job_manager), + style={ + "display": "flex", + "margin-top": "1em", + "max-width": "450px", + }, + ), + dbc.Col( + dbc.Card( + children=[ + dbc.CardHeader("Model Parameters"), + dbc.CardBody( + children=[ + html.Div( + id="model-params-out", + ) + ] + ), + ], + style={"margin-top": "1em"}, + ), + ), + ] + ), + ], + fluid=True, + ), + ], +) + + +@callback( + Output( + { + "component": "DbcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": MATCH, + }, + "children", + ), + Input( + { + "component": "DbcJobManagerAIO", + "subcomponent": "model-list", + "aio_id": MATCH, + }, + "value", + ), +) +def update_model_parameters(model_name): + model = models[model_name] + if model["gui_parameters"]: + item_list = mlex_components.get_parameter_items( + _id={"type": str(uuid.uuid4())}, json_blob=model["gui_parameters"] + ) + return item_list + else: + return html.Div("Model has no parameters") + + +@callback( + Output("model-params-out", "children"), + Input( + { + "component": "DbcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": ALL, + }, + "children", + ), + prevent_initial_call=True, +) +def update_model_parameters_output(model_parameter_container): + model_parameters, parameter_errors = mlex_components.get_parameters_values( + model_parameter_container[0] + ) + return str(model_parameters) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/examples/dmc_example.py b/examples/dmc_example.py new file mode 100644 index 0000000..0abd800 --- /dev/null +++ b/examples/dmc_example.py @@ -0,0 +1,252 @@ +import uuid + +import dash_mantine_components as dmc +from dash import ALL, MATCH, Dash, Input, Output, callback, html +from dash_iconify import DashIconify +from models_utils import Models + +from mlex_utils.dash_utils.components_mantime.component_utils import ( + DmcControlItem as ControlItem, +) +from mlex_utils.dash_utils.components_mantime.component_utils import ( + _accordion_item, + _tooltip, + drawer_section, +) +from mlex_utils.dash_utils.mlex_components import MLExComponents + + +def layout(job_manager): + """ + Returns the layout for the control panel in the app UI + """ + return drawer_section( + "MLExchange Utils Example with DMC", + dmc.Stack( + style={"width": "400px"}, + children=[ + dmc.AccordionMultiple( + id="control-accordion", + value=["data-select", "image-transformations", "run-model"], + children=[ + _accordion_item( + "Data selection", + "majesticons:data-line", + "data-select", + id="data-selection-controls", + children=[ + dmc.Space(h=5), + ControlItem( + "Dataset", + "image-selector", + dmc.Grid( + [ + dmc.Select( + id="project-name-src", + data=[], + placeholder="Select an image to view...", + ), + dmc.ActionIcon( + _tooltip( + "Refresh dataset", + children=[ + DashIconify( + icon="mdi:refresh-circle", + width=20, + ), + ], + ), + size="xs", + variant="subtle", + id="refresh-tiled", + n_clicks=0, + style={"margin": "auto"}, + ), + ], + style={"margin": "0px"}, + ), + ), + dmc.Space(h=10), + ], + ), + _accordion_item( + "Image transformations", + "fluent-mdl2:image-pixel", + "image-transformations", + id="image-transformation-controls", + children=html.Div( + [ + dmc.Space(h=5), + ControlItem( + "Brightness", + "bightness-text", + [ + dmc.Grid( + [ + dmc.Slider( + id={ + "type": "slider", + "index": "brightness", + }, + value=100, + min=0, + max=255, + step=1, + color="gray", + size="sm", + style={"width": "225px"}, + ), + dmc.ActionIcon( + _tooltip( + "Reset brightness", + children=[ + DashIconify( + icon="fluent:arrow-reset-32-regular", + width=15, + ), + ], + ), + size="xs", + variant="subtle", + id={ + "type": "reset", + "index": "brightness", + }, + n_clicks=0, + style={"margin": "auto"}, + ), + ], + style={"margin": "0px"}, + ) + ], + ), + dmc.Space(h=20), + ControlItem( + "Contrast", + "contrast-text", + dmc.Grid( + [ + dmc.Slider( + id={ + "type": "slider", + "index": "contrast", + }, + value=100, + min=0, + max=255, + step=1, + color="gray", + size="sm", + style={"width": "225px"}, + ), + dmc.ActionIcon( + _tooltip( + "Reset contrast", + children=[ + DashIconify( + icon="fluent:arrow-reset-32-regular", + width=15, + ), + ], + ), + size="xs", + variant="subtle", + id={ + "type": "reset", + "index": "contrast", + }, + n_clicks=0, + style={"margin": "auto"}, + ), + ], + style={"margin": "0px"}, + ), + ), + dmc.Space(h=10), + ] + ), + ), + _accordion_item( + "Model configuration", + "carbon:ibm-watson-machine-learning", + "run-model", + id="model-configuration", + children=job_manager, + ), + ], + ), + ], + ), + ) + + +# Get models +models = Models(modelfile_path="./examples/assets/models_dmc.json") + +# Get MLExchange dash components +mlex_components = MLExComponents("dmc") +job_manager = mlex_components.get_job_manager(model_list=models.modelname_list) + +app = Dash(__name__) +app.layout = dmc.MantineProvider( + theme={"colorScheme": "light"}, + children=[ + layout(job_manager), + html.Div( + id="model-params-out", + style={"margin-left": "450px"}, + ), + ], +) + + +@callback( + Output( + { + "component": "DmcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": MATCH, + }, + "children", + ), + Input( + { + "component": "DmcJobManagerAIO", + "subcomponent": "model-list", + "aio_id": MATCH, + }, + "value", + ), +) +def update_model_parameters(model_name): + model = models[model_name] + if model["gui_parameters"]: + item_list = mlex_components.get_parameter_items( + _id={"type": str(uuid.uuid4())}, json_blob=model["gui_parameters"] + ) + return item_list + else: + return html.Div("Model has no parameters") + + +@callback( + Output("model-params-out", "children"), + Input( + { + "component": "DmcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": ALL, + }, + "children", + ), + prevent_initial_call=True, +) +def update_model_parameters_output(model_parameter_container): + model_parameters, parameter_errors = mlex_components.get_parameters_values( + model_parameter_container[0] + ) + return str(model_parameters) + + +if __name__ == "__main__": + app.run_server(debug=True, port=8051) diff --git a/examples/models_utils.py b/examples/models_utils.py new file mode 100644 index 0000000..69ed192 --- /dev/null +++ b/examples/models_utils.py @@ -0,0 +1,20 @@ +import json + + +class Models: + def __init__(self, modelfile_path="./examples/assets/models.json"): + self.path = modelfile_path + f = open(self.path) + + contents = json.load(f)["contents"] + self.modelname_list = [content["model_name"] for content in contents] + self.models = {} + + for i, n in enumerate(self.modelname_list): + self.models[n] = contents[i] + + def __getitem__(self, key): + try: + return self.models[key] + except KeyError: + raise KeyError(f"A model with name {key} does not exist.") diff --git a/mlex_utils/dash_utils/__init__.py b/mlex_utils/dash_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlex_utils/dash_utils/callbacks/manage_jobs.py b/mlex_utils/dash_utils/callbacks/manage_jobs.py new file mode 100644 index 0000000..5406b4c --- /dev/null +++ b/mlex_utils/dash_utils/callbacks/manage_jobs.py @@ -0,0 +1,96 @@ +from dash import html, no_update + +from mlex_utils.prefect_utils.core import ( + cancel_flow_run, + delete_flow_run, + get_flow_run_logs, + get_flow_run_name, + get_flow_run_state, + query_flow_runs, +) + +DEV_JOBS = [ + {"label": "❌ DLSIA ABC 03/11/2024 15:38PM", "value": "uid0001"}, + {"label": "🕑 DLSIA XYC 03/11/2024 14:21PM", "value": "uid0002"}, + {"label": "✅ DLSIA CBA 03/11/2024 10:02AM", "value": "uid0003"}, +] + + +def _check_job(prefect_tags, mode): + if mode == "dev": + data = DEV_JOBS + else: + data = query_flow_runs(tags=prefect_tags) + return data + + +def _check_train_job(prefect_tags, mode): + """ + This callback populates the train job selector dropdown with job names and ids from Prefect. + The list of jobs is filtered by the selected project in the project selector dropdown. + This callback displays the current status of the job as part of the job name in the dropdown. + In "dev" mode, the dropdown is populated with the sample data above. + """ + return _check_job(prefect_tags + ["train"], mode) + + +def _check_dependent_job(job_id, project_name, prefect_tags, mode): + if mode == "dev": + return DEV_JOBS, no_update + else: + if job_id is not None and get_flow_run_state(job_id) == "COMPLETED": + job_name = get_flow_run_name(job_id) + if job_name is not None: + data = query_flow_runs( + flow_run_name=job_name, + tags=prefect_tags.append(project_name), + ) + selected_value = None if len(data) == 0 else no_update + return data, selected_value + return [], None + + +def _check_inference_job(train_job_id, project_name, prefect_tags, mode): + """ + This callback populates the inference job selector dropdown with job names and ids from Prefect. + The list of jobs is filtered by the selected train job in the train job selector dropdown. + The selected value is set to None if the list of jobs is empty. + This callback displays the current status of the job as part of the job name in the dropdown. + In "dev" mode, the dropdown is populated with the sample data above. + """ + return _check_dependent_job( + train_job_id, project_name, prefect_tags + ["inference"], mode + ) + + +def _get_job_logs(job_id, mode): + """ + This callback retrieves the logs of the selected job from Prefect. + The logs are displayed in the logs textarea. + In "dev" mode, the logs are retrieved from the sample data above. + """ + if mode == "dev": + return "Sample logs" + else: + logs = get_flow_run_logs(job_id) + return [item for pair in zip(logs, [html.Br()] * len(logs)) for item in pair][ + :-1 + ] + + +def _cancel_job(job_id, mode): + """ + This callback cancels the selected job in Prefect. + """ + if mode != "dev": + cancel_flow_run(job_id) + pass + + +def _delete_job(job_id, mode): + """ + This callback deletes the selected job in Prefect. + """ + if mode != "dev": + delete_flow_run(job_id) + pass diff --git a/mlex_utils/dash_utils/components_bootstrap/__init__.py b/mlex_utils/dash_utils/components_bootstrap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlex_utils/dash_utils/components_bootstrap/advanced_options.py b/mlex_utils/dash_utils/components_bootstrap/advanced_options.py new file mode 100644 index 0000000..f0a81b3 --- /dev/null +++ b/mlex_utils/dash_utils/components_bootstrap/advanced_options.py @@ -0,0 +1,225 @@ +import uuid + +import dash_bootstrap_components as dbc +from dash import MATCH, Input, Output, State, callback, dcc + + +class DbcAdvancedOptionsAIO(dbc.Modal): + + class ids: + + advanced_options_modal = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + "aio_id": aio_id, + } + + cancel_button = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "cancel-button", + "aio_id": aio_id, + } + + delete_button = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "delete-button", + "aio_id": aio_id, + } + + logs_area = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "logs-area", + "aio_id": aio_id, + } + + job_id = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + "aio_id": aio_id, + } + + warning_delete_modal = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-delete-modal", + "aio_id": aio_id, + } + + warning_cancel_modal = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-cancel-modal", + "aio_id": aio_id, + } + + warning_confirm_delete = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-delete", + "aio_id": aio_id, + } + + warning_confirm_cancel = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-cancel", + "aio_id": aio_id, + } + + warning_undo_delete = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-undo-delete", + "aio_id": aio_id, + } + + warning_undo_cancel = lambda aio_id: { # noqa: E731 + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-undo-cancel", + "aio_id": aio_id, + } + + ids = ids + + def __init__( + self, + cancel_button_props=None, + delete_button_props=None, + logs_area_props=None, + aio_id=None, + ): + """ + JobExecutionAIO is an All-in-One component that is composed + of a parent `html.Div` with a button to cancel, delete, and advance a job. + - `cancel_button_props` - A dictionary of properties passed into the Button component for the cancel button. + - `delete_button_props` - A dictionary of properties passed into the Button component for the delete button. + - `logs_area_props` - A dictionary of properties passed into the Textarea component for the logs area. + - `aio_id` - The All-in-One component ID used to generate the markdown and dropdown components's dictionary IDs. + """ + if aio_id is None: + aio_id = str(uuid.uuid4()) + + cancel_button_props, delete_button_props = self._update_props( + cancel_button_props, delete_button_props + ) + + super().__init__( + id=self.ids.advanced_options_modal(aio_id), + children=[ + dbc.ModalHeader("Advanced Options"), + dbc.ModalBody(id=self.ids.logs_area(aio_id)), + dbc.ModalFooter( + dbc.Accordion( + dbc.AccordionItem( + title="Danger Zone", + children=dbc.Row( + [ + dbc.Col( + dbc.Button( + "Cancel Job", + id=self.ids.cancel_button(aio_id), + **cancel_button_props, + ), + ), + dbc.Col( + dbc.Button( + "Delete Job", + id=self.ids.delete_button(aio_id), + **delete_button_props, + ), + ), + ], + ), + ), + start_collapsed=True, + flush=True, + style={"width": "100%", "--bs-accordion-active-bg": "#ffb3b3"}, + ), + ), + dcc.Store(id=self.ids.job_id(aio_id), data=None), + dbc.Modal( + id=self.ids.warning_cancel_modal(aio_id), + children=[ + dbc.ModalHeader("Warning"), + dbc.ModalBody("Are you sure you want to cancel this job?"), + dbc.ModalFooter( + [ + dbc.Button( + "YES", + id=self.ids.warning_confirm_cancel(aio_id), + color="danger", + className="ml-auto", + ), + dbc.Button( + "NO", + id=self.ids.warning_undo_cancel(aio_id), + className="ml-auto", + ), + ] + ), + ], + ), + dbc.Modal( + id=self.ids.warning_delete_modal(aio_id), + children=[ + dbc.ModalHeader("Warning"), + dbc.ModalBody("Are you sure you want to delete this job?"), + dbc.ModalFooter( + [ + dbc.Button( + "YES", + id=self.ids.warning_confirm_delete(aio_id), + color="danger", + className="ml-auto", + ), + dbc.Button( + "NO", + id=self.ids.warning_undo_delete(aio_id), + className="ml-auto", + ), + ] + ), + ], + ), + ], + scrollable=True, + size="lg", + ) + + def _update_props(self, cancel_button_props, delete_button_props): + cancel_button_props = cancel_button_props.copy() if cancel_button_props else {} + delete_button_props = delete_button_props.copy() if delete_button_props else {} + + cancel_button_props = self._update_button_props( + cancel_button_props, "danger", {"width": "100%"} + ) + delete_button_props = self._update_button_props( + delete_button_props, "danger", {"width": "100%"} + ) + return cancel_button_props, delete_button_props + + def _update_button_props(self, button_props, color, style): + button_props["color"] = color + button_props["style"] = style + return button_props + + @staticmethod + @callback( + Output(ids.warning_cancel_modal(MATCH), "is_open"), + Input(ids.cancel_button(MATCH), "n_clicks"), + Input(ids.warning_undo_cancel(MATCH), "n_clicks"), + State(ids.warning_cancel_modal(MATCH), "is_open"), + prevent_initial_call=True, + ) + def toggle_warning_cancel_modal( + cancel_button_n_clicks, undo_cancel_button_n_clicks, is_open + ): + return not is_open + + @staticmethod + @callback( + Output(ids.warning_delete_modal(MATCH), "is_open"), + Input(ids.delete_button(MATCH), "n_clicks"), + Input(ids.warning_undo_delete(MATCH), "n_clicks"), + State(ids.warning_delete_modal(MATCH), "is_open"), + prevent_initial_call=True, + ) + def toggle_warning_delete_modal( + delete_button_n_clicks, undo_delete_button_n_clicks, is_open + ): + return not is_open diff --git a/mlex_utils/dash_utils/components_bootstrap/component_utils.py b/mlex_utils/dash_utils/components_bootstrap/component_utils.py new file mode 100644 index 0000000..d8274c4 --- /dev/null +++ b/mlex_utils/dash_utils/components_bootstrap/component_utils.py @@ -0,0 +1,133 @@ +import dash_bootstrap_components as dbc +from dash import html +from dash_iconify import DashIconify + + +class DbcControlItem(dbc.Row): + """ + Customized layout for a control item + """ + + def __init__(self, title, title_id, item, style={"width": "100%", "margin": "0px"}): + super(DbcControlItem, self).__init__( + children=[ + dbc.Label( + title, + id=title_id, + size="sm", + style={ + "width": "100%", + "align-content": "center", + "paddingRight": "5px", + "text-align": "right", + }, + ), + html.Div(item, style={"width": "265px"}), + ], + style=style, + className="g-0", + ) + + +def header(app_title, github_url, help_url, app_logo="assets/mlex.png"): + """ + This header will exist at the top of the webpage rather than browser tab + Args: + app_title: Title of dash app + github_url: URL to github repo + help_url: URL to help page + app_logo: URL to app logo + """ + header = dbc.Navbar( + dbc.Container( + [ + dbc.Row( + [ + dbc.Col( + html.Img( + id="logo", + src=app_logo, + height="60px", + ), + md="auto", + ), + dbc.Col( + [ + html.Div( + [ + html.H3( + app_title, + style={ + "color": "white", + "padding": "0px", + "margin": "0px", + }, + ), + ], + id="app-title", + ) + ], + md=True, + align="center", + ), + ], + align="center", + ), + dbc.Row( + [ + dbc.Col( + [ + dbc.NavbarToggler(id="navbar-toggler"), + dbc.Collapse( + dbc.Nav( + [ + dbc.NavItem( + dbc.NavLink( + DashIconify( + icon="mdi:github", + width=50, + ), + style={ + "margin-right": "1rem", + "color": "#00313C", + "background-color": "white", + }, + href=github_url, + active=True, + ) + ), + dbc.NavItem( + dbc.NavLink( + DashIconify( + icon="mdi:help", + width=50, + ), + style={ + "color": "#00313C", + "background-color": "white", + }, + href=help_url, + active=True, + ) + ), + ], + navbar=True, + pills=True, + ), + id="navbar-collapse", + navbar=True, + ), + ], + md=2, + ), + ], + align="center", + ), + ], + fluid=True, + ), + dark=True, + color="dark", + sticky="top", + ) + return header diff --git a/mlex_utils/dash_utils/components_bootstrap/job_manager.py b/mlex_utils/dash_utils/components_bootstrap/job_manager.py new file mode 100644 index 0000000..f0d9dc5 --- /dev/null +++ b/mlex_utils/dash_utils/components_bootstrap/job_manager.py @@ -0,0 +1,497 @@ +import uuid + +import dash_bootstrap_components as dbc +from dash import ( + MATCH, + Input, + Output, + State, + callback, + callback_context, + dcc, + html, + no_update, +) +from dash_iconify import DashIconify + +from mlex_utils.dash_utils.callbacks.manage_jobs import ( + _cancel_job, + _check_inference_job, + _check_train_job, + _delete_job, + _get_job_logs, +) +from mlex_utils.dash_utils.components_bootstrap.advanced_options import ( + DbcAdvancedOptionsAIO, +) +from mlex_utils.dash_utils.components_bootstrap.component_utils import DbcControlItem + + +class DbcJobManagerAIO(html.Div): + + class ids: + + job_name_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "job-name-title", + "aio_id": aio_id, + } + + job_name = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "job-name", + "aio_id": aio_id, + } + + train_button = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "train-button", + "aio_id": aio_id, + } + + train_dropdown_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "train-dropdown-title", + "aio_id": aio_id, + } + + train_dropdown = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "train-dropdown", + "aio_id": aio_id, + } + + advanced_options_modal_train = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "advanced-options-modal-train", + "aio_id": aio_id, + } + + inference_button = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "inference-button", + "aio_id": aio_id, + } + + inference_dropdown_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "inference-dropdown-title", + "aio_id": aio_id, + } + + inference_dropdown = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "inference-dropdown", + "aio_id": aio_id, + } + + advanced_options_modal_inference = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "advanced-options-modal-inference", + "aio_id": aio_id, + } + + check_job = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "check-job", + "aio_id": aio_id, + } + + project_name_id = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "project-name-id", + "aio_id": aio_id, + } + + notifications_container = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "notifications-container", + "aio_id": aio_id, + } + + model_parameters = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": aio_id, + } + + model_list_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "model-list-title", + "aio_id": aio_id, + } + + model_list = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "model-list", + "aio_id": aio_id, + } + + show_training_stats = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "show-training-stats", + "aio_id": aio_id, + } + + show_training_stats_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "show-training-stats-title", + "aio_id": aio_id, + } + + ids = ids + + def __init__( + self, + model_list=["Test Model"], + prefect_tags=[], + mode="dev", + train_button_props=None, + inference_button_props=None, + show_training_stats_button_props=None, + modal_props=None, + aio_id=None, + ): + """ + DbcJobManagerAIO is an All-in-One component that is composed + of a parent `html.Div` with a button to train and infer a model. + - `model_list` - A list of models + - `prefect_tags` - A list of tags used to filter Prefect flow runs. + - `mode` - The mode of the component. If "dev", the component will display sample data. + - `train_button_props` - A dictionary of properties passed into the Button component for the train button. + - `inference_button_props` - A dictionary of properties passed into the Button component for the inference button. + - `show_training_stats_button_props` - A dictionary of properties passed into the Button component for the + show training stats button. + - `modal_props` - A dictionary of properties passed into the Modal component for the advanced options modal. + - `aio_id` - The All-in-One component ID used to generate the markdown and dropdown components's dictionary IDs. + """ + if aio_id is None: + aio_id = str(uuid.uuid4()) + + if train_button_props is None: + train_button_props = {"color": "primary", "style": {"width": "100%"}} + if inference_button_props is None: + inference_button_props = {"color": "primary", "style": {"width": "100%"}} + if show_training_stats_button_props is None: + show_training_stats_button_props = { + "disabled": True, + "color": "secondary", + "style": {"width": "100%"}, + } + if modal_props is None: + modal_props = {"style": {}} + + self._aio_id = aio_id + self._prefect_tags = prefect_tags + self._mode = mode + + super().__init__( + [ + DbcControlItem( + "Algorithm", + self.ids.model_list_title(aio_id), + dbc.Select( + id=self.ids.model_list(aio_id), + options=[ + {"label": model, "value": model} for model in model_list + ], + value=(model_list[0] if model_list[0] else None), + ), + ), + html.Div(id=self.ids.model_parameters(aio_id)), + html.P(), + DbcControlItem( + "Name", + self.ids.job_name_title(aio_id), + dbc.Input( + id=self.ids.job_name(aio_id), + type="text", + placeholder="Name your job...", + style={"width": "100%"}, + ), + ), + html.Div(style={"height": "10px"}), + dbc.Button( + "Train", id=self.ids.train_button(aio_id), **train_button_props + ), + html.Div(style={"height": "10px"}), + DbcControlItem( + "Trained Jobs", + self.ids.train_dropdown_title(aio_id), + [ + dbc.Row( + [ + dbc.Col( + dbc.Select( + id=self.ids.train_dropdown(aio_id), + ), + width=10, + ), + dbc.Col( + dbc.Button( + DashIconify( + icon="mdi:settings", + style={"padding": "0px"}, + ), + id=self.ids.advanced_options_modal_train( + aio_id + ), + color="secondary", + style={"height": "36px", "line-height": "1"}, + ), + width=2, + ), + ], + className="g-1", + ), + ], + ), + html.Div(style={"height": "10px"}), + DbcControlItem( + "", + self.ids.show_training_stats_title(aio_id), + dbc.Button( + "Show Training Stats", + id=self.ids.show_training_stats(aio_id), + **show_training_stats_button_props, + ), + ), + html.Div(style={"height": "10px"}), + dbc.Button( + "Inference", + id=self.ids.inference_button(aio_id), + **inference_button_props, + ), + html.Div(style={"height": "10px"}), + DbcControlItem( + "Inference Jobs", + self.ids.inference_dropdown_title(aio_id), + [ + dbc.Row( + [ + dbc.Col( + dbc.Select( + id=self.ids.inference_dropdown(aio_id), + ), + width=10, + ), + dbc.Col( + dbc.Button( + DashIconify(icon="mdi:settings"), + id=self.ids.advanced_options_modal_inference( + aio_id + ), + color="secondary", + style={"height": "36px", "line-height": "1"}, + ), + width=1, + ), + ], + className="g-1", + ), + ], + ), + html.Div(style={"height": "10px"}), + DbcAdvancedOptionsAIO(aio_id=aio_id), + html.Div(id=self.ids.notifications_container(aio_id)), + dcc.Interval( + id=self.ids.check_job(aio_id), + interval=5000, + ), + dcc.Store( + id=self.ids.project_name_id(aio_id), + data="", + ), + ] + ) + + self.register_callbacks() + + @staticmethod + @callback( + Output( + { + "aio_id": MATCH, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + ), + Output( + { + "aio_id": MATCH, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + Input(ids.advanced_options_modal_train(MATCH), "n_clicks"), + Input(ids.advanced_options_modal_inference(MATCH), "n_clicks"), + State( + { + "aio_id": MATCH, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + ), + State(ids.train_dropdown(MATCH), "value"), + State(ids.inference_dropdown(MATCH), "value"), + prevent_initial_call=True, + ) + def toggle_modal(n1, n2, is_open, train_job_id, inference_job_id): + button_id = callback_context.triggered[0]["prop_id"].split(".")[0] + if "train" in button_id: + job_id = train_job_id + else: + job_id = inference_job_id + return not is_open, job_id + + @staticmethod + @callback( + Output(ids.advanced_options_modal_train(MATCH), "disabled"), + Input(ids.train_dropdown(MATCH), "value"), + ) + def disable_advanced_train_options(train_job_id): + if train_job_id is not None: + return False + return True + + @staticmethod + @callback( + Output(ids.advanced_options_modal_inference(MATCH), "disabled"), + Input(ids.inference_dropdown(MATCH), "value"), + ) + def disable_advanced_inference_options(inference_job_id): + if inference_job_id is not None: + return False + return True + + def register_callbacks(self): + + @callback( + Output(self.ids.train_dropdown(self._aio_id), "options"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + ) + def check_train_job(n_intervals): + return _check_train_job(self._prefect_tags, self._mode) + + @callback( + Output(self.ids.inference_dropdown(self._aio_id), "options"), + Output(self.ids.inference_dropdown(self._aio_id), "value"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + Input(self.ids.train_dropdown(self._aio_id), "value"), + State(self.ids.project_name_id(self._aio_id), "data"), + prevent_initial_call=True, + ) + def check_inferece_job(n_intervals, train_job_id, project_name): + return _check_inference_job( + train_job_id, project_name, self._prefect_tags, self._mode + ) + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-cancel", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def cancel_job(n_clicks, job_id): + _cancel_job(job_id, self._mode) + return False + + @callback( + Output( + self.ids.train_dropdown(self._aio_id), "value", allow_duplicate=True + ), + Output( + self.ids.inference_dropdown(self._aio_id), "value", allow_duplicate=True + ), + Output( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-delete", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + State(self.ids.train_dropdown(self._aio_id), "value"), + State(self.ids.inference_dropdown(self._aio_id), "value"), + prevent_initial_call=True, + ) + def delete_job(n_clicks, job_id, train_job_id, inference_job_id): + _delete_job(job_id, self._mode) + if job_id == train_job_id: + return None, no_update, False + return no_update, None, False + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "logs-area", + }, + "children", + ), + Input( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + ), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + State( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def get_logs(is_open, n_intervals, job_id): + if job_id is None: + return "No logs available" + return _get_job_logs(job_id, self._mode) diff --git a/mlex_utils/dash_utils/components_bootstrap/job_manager_minimal.py b/mlex_utils/dash_utils/components_bootstrap/job_manager_minimal.py new file mode 100644 index 0000000..b215beb --- /dev/null +++ b/mlex_utils/dash_utils/components_bootstrap/job_manager_minimal.py @@ -0,0 +1,375 @@ +import uuid + +import dash_bootstrap_components as dbc +from dash import MATCH, Input, Output, State, callback, dcc, html +from dash_iconify import DashIconify + +from mlex_utils.dash_utils.callbacks.manage_jobs import ( + _cancel_job, + _check_dependent_job, + _check_job, + _delete_job, + _get_job_logs, +) +from mlex_utils.dash_utils.components_bootstrap.advanced_options import ( + DbcAdvancedOptionsAIO, +) +from mlex_utils.dash_utils.components_bootstrap.component_utils import DbcControlItem + + +class DbcJobManagerMinimalAIO(html.Div): + + class ids: + + job_name_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "job-name-title", + "aio_id": aio_id, + } + + job_name = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "job-name", + "aio_id": aio_id, + } + + run_button = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "run-button", + "aio_id": aio_id, + } + + run_dropdown_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "run-dropdown-title", + "aio_id": aio_id, + } + + run_dropdown = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "run-dropdown", + "aio_id": aio_id, + } + + advanced_options_modal_run = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "advanced-options-modal-run", + "aio_id": aio_id, + } + + check_job = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "check-job", + "aio_id": aio_id, + } + + project_name_id = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "project-name-id", + "aio_id": aio_id, + } + + notifications_container = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "notifications-container", + "aio_id": aio_id, + } + + model_parameters = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": aio_id, + } + + model_list_title = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "model-list-title", + "aio_id": aio_id, + } + + model_list = lambda aio_id: { # noqa: E731 + "component": "DbcJobManagerAIO", + "subcomponent": "model-list", + "aio_id": aio_id, + } + + ids = ids + + def __init__( + self, + model_list=["Test Model"], + prefect_tags=[], + mode="dev", + run_button_props=None, + modal_props=None, + aio_id=None, + dependency_id=None, + ): + """ + DbcJobManagerAIO is an All-in-One component that is composed + of a parent `html.Div` with a button to run and infer a model. + - `model_list` - A list of models + - `prefect_tags` - A list of tags used to filter Prefect flow runs. + - `mode` - The mode of the component. If "dev", the component will display sample data. + - `run_button_props` - A dictionary of properties passed into the Button component for the run button. + - `modal_props` - A dictionary of properties passed into the Modal component for the advanced options modal. + - `aio_id` - The All-in-One component ID used to generate the markdown and dropdown components's dictionary IDs. + - `dependency_id` - Check list of jobs that are dependent on the completion of the job of this component id + (dropdown). + """ + if aio_id is None: + aio_id = str(uuid.uuid4()) + + if run_button_props is None: + run_button_props = {"color": "primary", "style": {"width": "100%"}} + if modal_props is None: + modal_props = {"style": {}} + + self._aio_id = aio_id + self._prefect_tags = prefect_tags + self._mode = mode + self._dependency_id = dependency_id + + super().__init__( + [ + DbcControlItem( + "Algorithm", + self.ids.model_list_title(aio_id), + dbc.Select( + id=self.ids.model_list(aio_id), + options=[ + {"label": model, "value": model} for model in model_list + ], + value=(model_list[0] if model_list[0] else None), + ), + ), + html.Div(id=self.ids.model_parameters(aio_id)), + html.P(), + DbcControlItem( + "Name", + self.ids.job_name_title(aio_id), + dbc.Input( + id=self.ids.job_name(aio_id), + type="text", + placeholder="Name your job...", + style={"width": "100%"}, + ), + ), + html.Div(style={"height": "10px"}), + dbc.Button("Run", id=self.ids.run_button(aio_id), **run_button_props), + html.Div(style={"height": "10px"}), + DbcControlItem( + "Jobs", + self.ids.run_dropdown_title(aio_id), + [ + dbc.Row( + [ + dbc.Col( + dbc.Select( + id=self.ids.run_dropdown(aio_id), + ), + width=10, + ), + dbc.Col( + dbc.Button( + DashIconify( + icon="mdi:settings", + style={"padding": "0px"}, + ), + id=self.ids.advanced_options_modal_run(aio_id), + color="secondary", + style={"height": "36px", "line-height": "1"}, + ), + width=2, + ), + ], + className="g-1", + ), + ], + ), + html.Div(style={"height": "10px"}), + DbcAdvancedOptionsAIO(aio_id=aio_id), + html.Div(id=self.ids.notifications_container(aio_id)), + dcc.Interval( + id=self.ids.check_job(aio_id), + interval=5000, + ), + dcc.Store( + id=self.ids.project_name_id(aio_id), + data="", + ), + ] + ) + + self.register_callbacks() + + @staticmethod + @callback( + Output( + { + "aio_id": MATCH, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + ), + Output( + { + "aio_id": MATCH, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + Input(ids.advanced_options_modal_run(MATCH), "n_clicks"), + State( + { + "aio_id": MATCH, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + ), + State(ids.run_dropdown(MATCH), "value"), + prevent_initial_call=True, + ) + def toggle_modal(n1, is_open, run_job_id): + return not is_open, run_job_id + + @staticmethod + @callback( + Output(ids.advanced_options_modal_run(MATCH), "disabled"), + Input(ids.run_dropdown(MATCH), "value"), + ) + def disable_advanced_run_options(run_job_id): + if run_job_id is not None: + return False + return True + + def register_callbacks(self): + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-cancel", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def cancel_job(n_clicks, job_id): + _cancel_job(job_id, self._mode) + return False + + @callback( + Output(self.ids.run_dropdown(self._aio_id), "value", allow_duplicate=True), + Output( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-delete", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + State(self.ids.run_dropdown(self._aio_id), "value"), + prevent_initial_call=True, + ) + def delete_job(n_clicks, job_id, run_job_id): + _delete_job(job_id, self._mode) + return None, False + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "logs-area", + }, + "children", + ), + Input( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "is_open", + ), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + State( + { + "aio_id": self._aio_id, + "component": "DbcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def get_logs(is_open, n_intervals, job_id): + if job_id is None: + return "No logs available" + return _get_job_logs(job_id, self._mode) + + if self._dependency_id is not None: + + @callback( + Output( + self.ids.run_dropdown(self._aio_id), "options", allow_duplicate=True + ), + Output(self.ids.run_dropdown(self._aio_id), "value"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + Input(self._dependency_id, "value"), + State(self.ids.project_name_id(self._aio_id), "data"), + prevent_initial_call=True, + ) + def check_dependent_job(n_intervals, dependent_job_id, project_name): + jobs = _check_dependent_job( + dependent_job_id, project_name, self._prefect_tags, self._mode + ) + return jobs + + else: + + @callback( + Output(self.ids.run_dropdown(self._aio_id), "options"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + ) + def check_run_job(n_intervals): + return _check_job(self._prefect_tags, self._mode) diff --git a/mlex_utils/dash_utils/components_bootstrap/parameter_items.py b/mlex_utils/dash_utils/components_bootstrap/parameter_items.py new file mode 100644 index 0000000..7d63600 --- /dev/null +++ b/mlex_utils/dash_utils/components_bootstrap/parameter_items.py @@ -0,0 +1,225 @@ +import dash_bootstrap_components as dbc +from dash import dcc + +from mlex_utils.dash_utils.components_bootstrap.component_utils import DbcControlItem + + +class DbcSimpleItem(DbcControlItem): + def __init__( + self, + name, + base_id, + item_type, + title=None, + param_key=None, + visible=True, + debounce=True, + **kwargs, + ): + + if param_key is None: + param_key = name + + self.input = dbc.Input( + type=item_type, + debounce=debounce, + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DbcSimpleItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DbcNumberItem(DbcSimpleItem): + def __init__(self, *args, **kwargs): + super(DbcNumberItem, self).__init__(*args, item_type="number", **kwargs) + + +class DbcStrItem(DbcSimpleItem): + def __init__(self, *args, **kwargs): + super(DbcStrItem, self).__init__(*args, item_type="text", **kwargs) + + +class DbcSliderItem(DbcControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + + if param_key is None: + param_key = name + + self.input = dcc.Slider( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + tooltip={"placement": "bottom", "always_visible": True}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DbcSliderItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DbcDropdownItem(DbcControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + + if param_key is None: + param_key = name + + self.input = dbc.Select( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DbcDropdownItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DbcRadioItem(DbcControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + + if param_key is None: + param_key = name + + self.input = dbc.RadioItems( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DbcRadioItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DbcBoolItem(DbcControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + + if param_key is None: + param_key = name + + self.input = dbc.Switch( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + label=title, + label_style={"margin": "0px 0px 0px 0px"}, + # input_style={"height": "36px"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DbcBoolItem, self).__init__( + title="", # title is already in the switch + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DbcParameterItems(dbc.Form): + type_map = { + "float": DbcNumberItem, + "int": DbcNumberItem, + "str": DbcStrItem, + "slider": DbcSliderItem, + "dropdown": DbcDropdownItem, + "radio": DbcRadioItem, + "bool": DbcBoolItem, + } + + def __init__(self, _id, json_blob, values=None): + super(DbcParameterItems, self).__init__(id=_id, children=[]) + self._json_blob = json_blob + self.children = self.build_children(values=values) + + def _determine_type(self, parameter_dict): + if "type" in parameter_dict: + if parameter_dict["type"] in self.type_map: + return parameter_dict["type"] + elif parameter_dict["type"].__name__ in self.type_map: + return parameter_dict["type"].__name__ + elif type(parameter_dict["value"]) in self.type_map: + return type(parameter_dict["value"]) + raise TypeError( + f"No item type could be determined for this parameter: {parameter_dict}" + ) + + def build_children(self, values=None): + children = [] + for json_record in self._json_blob: + # Build a parameter dict from self.json_blob + type = json_record.get("type", self._determine_type(json_record)) + json_record = json_record.copy() + if values and json_record["name"] in values: + json_record["value"] = values[json_record["name"]] + json_record.pop("type", None) + # TODO: Use comp_group to fix training parameters and enable/disable parameters that + # do not fall into the scope of the job (training, inference, etc.) + if "comp_group" in json_record: + json_record.pop("comp_group", None) + item = self.type_map[type](**json_record, base_id=self.id) + children.append(item) + + return children diff --git a/mlex_utils/dash_utils/components_mantime/__init__.py b/mlex_utils/dash_utils/components_mantime/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlex_utils/dash_utils/components_mantime/advanced_options.py b/mlex_utils/dash_utils/components_mantime/advanced_options.py new file mode 100644 index 0000000..5e229b0 --- /dev/null +++ b/mlex_utils/dash_utils/components_mantime/advanced_options.py @@ -0,0 +1,279 @@ +import uuid + +import dash_mantine_components as dmc +from dash import MATCH, Input, Output, State, callback, dcc +from dash_iconify import DashIconify + + +class DmcAdvancedOptionsAIO(dmc.Modal): + + class ids: + + advanced_options_modal = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + "aio_id": aio_id, + } + + cancel_button = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "cancel-button", + "aio_id": aio_id, + } + + delete_button = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "delete-button", + "aio_id": aio_id, + } + + logs_area = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "logs-area", + "aio_id": aio_id, + } + + job_id = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + "aio_id": aio_id, + } + + warning_delete_modal = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-delete-modal", + "aio_id": aio_id, + } + + warning_cancel_modal = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-cancel-modal", + "aio_id": aio_id, + } + + warning_confirm_delete = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-delete", + "aio_id": aio_id, + } + + warning_confirm_cancel = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-cancel", + "aio_id": aio_id, + } + + warning_undo_delete = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-undo-delete", + "aio_id": aio_id, + } + + warning_undo_cancel = lambda aio_id: { # noqa: E731 + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-undo-cancel", + "aio_id": aio_id, + } + + ids = ids + + def __init__( + self, + cancel_button_props=None, + delete_button_props=None, + logs_area_props=None, + aio_id=None, + ): + """ + JobExecutionAIO is an All-in-One component that is composed + of a parent `html.Div` with a button to cancel, delete, and advance a job. + - `cancel_button_props` - A dictionary of properties passed into the Button component for the cancel button. + - `delete_button_props` - A dictionary of properties passed into the Button component for the delete button. + - `logs_area_props` - A dictionary of properties passed into the Textarea component for the logs area. + - `aio_id` - The All-in-One component ID used to generate the markdown and dropdown components's dictionary IDs. + """ + if aio_id is None: + aio_id = str(uuid.uuid4()) + + cancel_button_props = self._update_button_props(cancel_button_props) + delete_button_props = self._update_button_props(delete_button_props) + + super().__init__( + title="Advanced Options", + id=self.ids.advanced_options_modal(aio_id), + opened=False, + children=[ + dmc.ScrollArea( + children=dmc.Paper( + [ + dmc.Text( + "These are the logs...", + id=self.ids.logs_area(aio_id), + ), + ], + style={"width": "100%", "height": 200, "margin-bottom": "10px"}, + ), + ), + dmc.Accordion( + children=[ + dmc.AccordionItem( + [ + dmc.AccordionControl( + "Danger Zone", + icon=DashIconify( + icon="mdi:alert-circle", + # color=dmc.DEFAULT_THEME["colors"]["red"][6], + width=20, + ), + ), + dmc.AccordionPanel( + [ + dmc.Grid( + [ + dmc.Col( + dmc.Button( + "Cancel Job", + id=self.ids.cancel_button( + aio_id + ), + **cancel_button_props, + ), + span=6, + ), + dmc.Col( + dmc.Button( + "Delete Job", + id=self.ids.delete_button( + aio_id + ), + **delete_button_props, + ), + span=6, + ), + ] + ), + ], + ), + ], + value="danger_zone", + ), + ], + ), + dcc.Store(id=self.ids.job_id(aio_id), data=None), + dmc.Modal( + title="Warning", + id=self.ids.warning_cancel_modal(aio_id), + opened=False, + children=[ + dmc.Text("Are you sure you want to cancel this job?"), + dmc.Space(h=25), + dmc.Grid( + [ + dmc.Col( + dmc.Button( + "YES", + id=self.ids.warning_confirm_cancel(aio_id), + color="red", + style={"width": "100%", "margin": "5px"}, + ), + span=6, + ), + dmc.Col( + dmc.Button( + "NO", + id=self.ids.warning_undo_cancel(aio_id), + style={"width": "100%", "margin": "5px"}, + ), + span=6, + ), + ] + ), + ], + ), + dmc.Modal( + title="Warning", + id=self.ids.warning_delete_modal(aio_id), + opened=False, + children=[ + dmc.Text("Are you sure you want to delete this job?"), + dmc.Space(h=25), + dmc.Grid( + [ + dmc.Col( + dmc.Button( + "YES", + id=self.ids.warning_confirm_delete(aio_id), + variant="light", + color="red", + style={"width": "100%", "margin": "5px"}, + ), + span=6, + ), + dmc.Col( + dmc.Button( + "NO", + id=self.ids.warning_undo_delete(aio_id), + variant="light", + style={"width": "100%", "margin": "5px"}, + ), + span=6, + ), + ] + ), + ], + ), + ], + ) + + def _update_button_props( + self, + button_props, + variant="light", + color="red", + style={"width": "100%", "margin": "5px"}, + ): + button_props = button_props.copy() if button_props else {} + button_props["variant"] = ( + variant if "variant" not in button_props else button_props["variant"] + ) + button_props["color"] = ( + color if "color" not in button_props else button_props["color"] + ) + button_props["style"] = ( + style if "style" not in button_props else button_props["style"] + ) + return button_props + + @staticmethod + @callback( + Output(ids.warning_cancel_modal(MATCH), "opened"), + Input(ids.cancel_button(MATCH), "n_clicks"), + Input(ids.warning_undo_cancel(MATCH), "n_clicks"), + Input(ids.warning_confirm_cancel(MATCH), "n_clicks"), + State(ids.warning_cancel_modal(MATCH), "opened"), + prevent_initial_call=True, + ) + def toggle_warning_cancel_modal( + cancel_button_n_clicks, + undo_cancel_button_n_clicks, + confirm_cancel_button_n_clicks, + is_open, + ): + return not is_open + + @staticmethod + @callback( + Output(ids.warning_delete_modal(MATCH), "opened"), + Input(ids.delete_button(MATCH), "n_clicks"), + Input(ids.warning_undo_delete(MATCH), "n_clicks"), + Input(ids.warning_confirm_delete(MATCH), "n_clicks"), + State(ids.warning_delete_modal(MATCH), "opened"), + prevent_initial_call=True, + ) + def toggle_warning_delete_modal( + delete_button_n_clicks, + undo_delete_button_n_clicks, + confirm_delete_n_clicks, + is_open, + ): + return not is_open diff --git a/mlex_utils/dash_utils/components_mantime/component_utils.py b/mlex_utils/dash_utils/components_mantime/component_utils.py new file mode 100644 index 0000000..6298a39 --- /dev/null +++ b/mlex_utils/dash_utils/components_mantime/component_utils.py @@ -0,0 +1,109 @@ +import dash_mantine_components as dmc +from dash import html +from dash_iconify import DashIconify + + +class DmcControlItem(dmc.Grid): + """ + Customized layout for a control item + """ + + def __init__(self, title, title_id, item, style={}): + super(DmcControlItem, self).__init__( + children=[ + dmc.Text( + title, + id=title_id, + size="sm", + style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, + align="right", + ), + html.Div(item, style={"width": "265px", "margin": "auto"}), + ], + style=style, + ) + + +def _accordion_item(title, icon, value, children, id): + """ + Returns a customized layout for an accordion item + """ + panel = dmc.AccordionPanel(children=children, id=id) + return dmc.AccordionItem( + [ + dmc.AccordionControl( + title, + icon=DashIconify( + icon=icon, + color="#00313C", + width=20, + ), + ), + panel, + ], + value=value, + ) + + +def _tooltip(text, children): + """ + Returns a customized layout for a tooltip + """ + return dmc.Tooltip( + label=text, withArrow=True, position="top", color="#464646", children=children + ) + + +def drawer_section(title, children): + """ + This components creates an affix button that opens a drawer with the given children. + Drawer is set to have height and width of fit-content, meaning it won't be full height. + """ + return html.Div( + [ + dmc.Affix( + dmc.Button( + [ + DashIconify( + icon="circum:settings", + height=25, + style={"cursor": "pointer"}, + ), + dmc.Text("Controls", size="sm"), + ], + id="drawer-controls-open-button", + size="lg", + radius="sm", + compact=True, + variant="outline", + color="gray", + ), + position={"left": "25px", "top": "25px"}, + ), + dmc.Drawer( + title=dmc.Text(title, weight=700), + id="drawer-controls", + padding="md", + transition="fade", + transitionDuration=500, + shadow="md", + withOverlay=False, + position="left", + zIndex=10000, + styles={ + "drawer": { + "width": "fit-content", + "height": "fit-content", + "max-height": "100%", + "overflow-y": "auto", + "margin": "0px", + }, + "root": { + "opacity": "0.95", + }, + }, + children=children, + opened=True, + ), + ] + ) diff --git a/mlex_utils/dash_utils/components_mantime/job_manager.py b/mlex_utils/dash_utils/components_mantime/job_manager.py new file mode 100644 index 0000000..51bd511 --- /dev/null +++ b/mlex_utils/dash_utils/components_mantime/job_manager.py @@ -0,0 +1,521 @@ +import uuid + +import dash_mantine_components as dmc +from dash import ( + MATCH, + Input, + Output, + State, + callback, + callback_context, + dcc, + html, + no_update, +) +from dash_iconify import DashIconify + +from mlex_utils.dash_utils.callbacks.manage_jobs import ( + _cancel_job, + _check_inference_job, + _check_train_job, + _delete_job, + _get_job_logs, +) +from mlex_utils.dash_utils.components_mantime.advanced_options import ( + DmcAdvancedOptionsAIO, +) +from mlex_utils.dash_utils.components_mantime.component_utils import ( + DmcControlItem, + _tooltip, +) + + +class DmcJobManagerAIO(html.Div): + + class ids: + + job_name_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "job-name-title", + "aio_id": aio_id, + } + + job_name = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "job-name", + "aio_id": aio_id, + } + + train_button = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "train-button", + "aio_id": aio_id, + } + + train_dropdown_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "train-dropdown-title", + "aio_id": aio_id, + } + + train_dropdown = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "train-dropdown", + "aio_id": aio_id, + } + + advanced_options_modal_train = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "advanced-options-modal-train", + "aio_id": aio_id, + } + + training_stats_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "training-stats-title", + "aio_id": aio_id, + } + + training_stats = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "training-stats", + "aio_id": aio_id, + } + + inference_button = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "inference-button", + "aio_id": aio_id, + } + + inference_dropdown_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "inference-dropdown-title", + "aio_id": aio_id, + } + + inference_dropdown = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "inference-dropdown", + "aio_id": aio_id, + } + + advanced_options_modal_inference = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "advanced-options-modal-inference", + "aio_id": aio_id, + } + + advanced_options_modal = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "advanced-options-modal", + "aio_id": aio_id, + } + + check_job = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "check-job", + "aio_id": aio_id, + } + + project_name_id = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "project-name-id", + "aio_id": aio_id, + } + + notifications_container = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "notifications-container", + "aio_id": aio_id, + } + + model_parameters = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": aio_id, + } + + model_list_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "model-list-title", + "aio_id": aio_id, + } + + model_list = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "model-list", + "aio_id": aio_id, + } + + show_training_stats = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "show-training-stats", + "aio_id": aio_id, + } + + show_training_stats_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "show-training-stats-title", + "aio_id": aio_id, + } + + ids = ids + + def __init__( + self, + model_list=["Test Model"], + prefect_tags=[], + mode="dev", + train_button_props=None, + inference_button_props=None, + show_training_stats_button_props=None, + modal_props=None, + aio_id=None, + ): + """ + DmcJobManagerAIO is an All-in-One component that is composed + of a parent `html.Div` with a button to train and infer a model. + - `model_list` - A list of models + - `prefect_tags` - A list of tags used to filter Prefect flow runs. + - `mode` - The mode of the component. If "dev", the component will display sample data. + - `train_button_props` - A dictionary of properties passed into the Button component for the train button. + - `inference_button_props` - A dictionary of properties passed into the Button component for the inference button. + - `show_training_stats_button_props` - A dictionary of properties passed into the Button component for the + show training stats button. + - `modal_props` - A dictionary of properties passed into the Modal component for the advanced options modal. + - `aio_id` - The All-in-One component ID used to generate the markdown and dropdown components's dictionary IDs. + """ + if aio_id is None: + aio_id = str(uuid.uuid4()) + + if train_button_props is None: + train_button_props = { + "variant": "light", + "style": {"width": "100%", "margin": "5px"}, + } + if inference_button_props is None: + inference_button_props = { + "variant": "light", + "style": {"width": "100%", "margin": "5px"}, + } + if show_training_stats_button_props is None: + show_training_stats_button_props = { + "size": "sm", + "radius": "lg", + "color": "gray", + "disabled": True, + "style": {"width": "100%"}, + } + if modal_props is None: + modal_props = {"style": {"margin": "10px 10px 10px 250px"}} + + self._aio_id = aio_id + self._prefect_tags = prefect_tags + self._mode = mode + + super().__init__( + [ + DmcControlItem( + "Algorithm", + self.ids.model_list_title(aio_id), + dmc.Select( + id=self.ids.model_list(aio_id), + data=model_list, + value=(model_list[0] if model_list[0] else None), + ), + ), + dmc.Space(h=15), + html.Div(id=self.ids.model_parameters(aio_id)), + dmc.Space(h=25), + DmcControlItem( + "Name", + self.ids.job_name_title(aio_id), + dmc.TextInput( + placeholder="Name your job...", + id=self.ids.job_name(aio_id), + ), + ), + dmc.Space(h=10), + dmc.Button( + "Train", id=self.ids.train_button(aio_id), **train_button_props + ), + dmc.Space(h=10), + DmcControlItem( + "Trained Jobs", + self.ids.train_dropdown_title(aio_id), + dmc.Grid( + [ + dmc.Select( + placeholder="Select a job...", + id=self.ids.train_dropdown(aio_id), + ), + dmc.ActionIcon( + _tooltip( + "Advanced Options", + children=[ + DashIconify( + icon="mdi:settings-applications", + width=30, + ), + ], + ), + size="xs", + variant="subtle", + id=self.ids.advanced_options_modal_train(aio_id), + n_clicks=0, + style={"margin": "auto"}, + ), + ], + style={"margin": "0px"}, + ), + ), + dmc.Space(h=25), + DmcControlItem( + "", + self.ids.show_training_stats_title(aio_id), + dmc.Button( + "Show Training Stats", + id=self.ids.show_training_stats(aio_id), + **show_training_stats_button_props, + ), + ), + dmc.Space(h=10), + dmc.Button( + "Inference", + id=self.ids.inference_button(aio_id), + **inference_button_props, + ), + dmc.Space(h=10), + DmcControlItem( + "Inference Jobs", + self.ids.inference_dropdown_title(aio_id), + dmc.Grid( + [ + dmc.Select( + placeholder="Select a job...", + id=self.ids.inference_dropdown(aio_id), + ), + dmc.ActionIcon( + _tooltip( + "Advanced Options", + children=[ + DashIconify( + icon="mdi:settings-applications", + width=30, + ), + ], + ), + size="xs", + variant="subtle", + id=self.ids.advanced_options_modal_inference(aio_id), + n_clicks=0, + style={"margin": "auto"}, + ), + ], + style={"margin": "0px"}, + ), + ), + DmcAdvancedOptionsAIO(aio_id=aio_id), + html.Div(id=self.ids.notifications_container(aio_id)), + dcc.Interval( + id=self.ids.check_job(aio_id), + interval=5000, + ), + dcc.Store( + id=self.ids.project_name_id(aio_id), + data="", + ), + ] + ) + + self.register_callbacks() + + @staticmethod + @callback( + Output( + { + "aio_id": MATCH, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + ), + Output( + { + "aio_id": MATCH, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + Input(ids.advanced_options_modal_train(MATCH), "n_clicks"), + Input(ids.advanced_options_modal_inference(MATCH), "n_clicks"), + State( + { + "aio_id": MATCH, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + ), + State(ids.train_dropdown(MATCH), "value"), + State(ids.inference_dropdown(MATCH), "value"), + prevent_initial_call=True, + ) + def toggle_modal(n1, n2, is_open, train_job_id, inference_job_id): + button_id = callback_context.triggered[0]["prop_id"].split(".")[0] + if "train" in button_id: + job_id = train_job_id + else: + job_id = inference_job_id + return not is_open, job_id + + @staticmethod + @callback( + Output(ids.advanced_options_modal_train(MATCH), "disabled"), + Input(ids.train_dropdown(MATCH), "value"), + ) + def disable_advanced_train_options(train_job_id): + if train_job_id is not None: + return False + return True + + @staticmethod + @callback( + Output(ids.advanced_options_modal_inference(MATCH), "disabled"), + Input(ids.inference_dropdown(MATCH), "value"), + ) + def disable_advanced_inference_options(inference_job_id): + if inference_job_id is not None: + return False + return True + + def register_callbacks(self): + + @callback( + Output(self.ids.train_dropdown(self._aio_id), "data"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + ) + def check_train_job(n_intervals): + return _check_train_job(self._prefect_tags, self._mode) + + @callback( + Output(self.ids.inference_dropdown(self._aio_id), "data"), + Output(self.ids.inference_dropdown(self._aio_id), "value"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + Input(self.ids.train_dropdown(self._aio_id), "value"), + State(self.ids.project_name_id(self._aio_id), "data"), + prevent_initial_call=True, + ) + def check_inference_job(n_intervals, train_job_id, project_name): + return _check_inference_job( + train_job_id, project_name, self._prefect_tags, self._mode + ) + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-cancel", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def cancel_job(n_clicks, job_id): + _cancel_job(job_id, self._mode) + return False + + @callback( + Output( + self.ids.train_dropdown(self._aio_id), "value", allow_duplicate=True + ), + Output( + self.ids.inference_dropdown(self._aio_id), "value", allow_duplicate=True + ), + Output( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-delete", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + State(self.ids.train_dropdown(self._aio_id), "value"), + State(self.ids.inference_dropdown(self._aio_id), "value"), + prevent_initial_call=True, + ) + def delete_job(n_clicks, job_id, train_job_id, inference_job_id): + _delete_job(job_id, self._mode) + if job_id == train_job_id: + return None, no_update, False + return no_update, None, False + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "logs-area", + }, + "children", + ), + Input( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + ), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + State( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def get_logs(is_open, n_intervals, job_id): + if job_id is None: + return "No logs available" + return _get_job_logs(job_id, self._mode) diff --git a/mlex_utils/dash_utils/components_mantime/job_manager_minimal.py b/mlex_utils/dash_utils/components_mantime/job_manager_minimal.py new file mode 100644 index 0000000..dab789b --- /dev/null +++ b/mlex_utils/dash_utils/components_mantime/job_manager_minimal.py @@ -0,0 +1,381 @@ +import uuid + +import dash_mantine_components as dmc +from dash import MATCH, Input, Output, State, callback, dcc, html +from dash_iconify import DashIconify + +from mlex_utils.dash_utils.callbacks.manage_jobs import ( + _cancel_job, + _check_inference_job, + _check_train_job, + _delete_job, + _get_job_logs, +) +from mlex_utils.dash_utils.components_mantime.advanced_options import ( + DmcAdvancedOptionsAIO, +) +from mlex_utils.dash_utils.components_mantime.component_utils import ( + DmcControlItem, + _tooltip, +) + + +class DmcJobManagerMinimalAIO(html.Div): + + class ids: + + job_name_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "job-name-title", + "aio_id": aio_id, + } + + job_name = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "job-name", + "aio_id": aio_id, + } + + run_button = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "run-button", + "aio_id": aio_id, + } + + run_dropdown_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "run-dropdown-title", + "aio_id": aio_id, + } + + run_dropdown = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "run-dropdown", + "aio_id": aio_id, + } + + advanced_options_modal_run = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "advanced-options-modal-run", + "aio_id": aio_id, + } + + advanced_options_modal = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "advanced-options-modal", + "aio_id": aio_id, + } + + check_job = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "check-job", + "aio_id": aio_id, + } + + project_name_id = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "project-name-id", + "aio_id": aio_id, + } + + notifications_container = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "notifications-container", + "aio_id": aio_id, + } + + model_parameters = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "model-parameters", + "aio_id": aio_id, + } + + model_list_title = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "model-list-title", + "aio_id": aio_id, + } + + model_list = lambda aio_id: { # noqa: E731 + "component": "DmcJobManagerAIO", + "subcomponent": "model-list", + "aio_id": aio_id, + } + + ids = ids + + def __init__( + self, + model_list=["Test Model"], + prefect_tags=[], + mode="dev", + run_button_props=None, + modal_props=None, + aio_id=None, + dependency=None, + ): + """ + DmcJobManagerAIO is an All-in-One component that is composed + of a parent `html.Div` with a button to run and infer a model. + - `model_list` - A list of models + - `prefect_tags` - A list of tags used to filter Prefect flow runs. + - `mode` - The mode of the component. If "dev", the component will display sample data. + - `run_button_props` - A dictionary of properties passed into the Button component for the run button. + - `modal_props` - A dictionary of properties passed into the Modal component for the advanced options modal. + - `aio_id` - The All-in-One component ID used to generate the markdown and dropdown components's dictionary IDs. + - `dependency` - List of jobs is dependent on the completion of the value of this component (dropdown). + """ + if aio_id is None: + aio_id = str(uuid.uuid4()) + + if run_button_props is None: + run_button_props = { + "variant": "light", + "style": {"width": "100%", "margin": "5px"}, + } + + if modal_props is None: + modal_props = {"style": {"margin": "10px 10px 10px 250px"}} + + self._aio_id = aio_id + self._prefect_tags = prefect_tags + self._mode = mode + self._dependency = dependency + + super().__init__( + [ + DmcControlItem( + "Algorithm", + self.ids.model_list_title(aio_id), + dmc.Select( + id=self.ids.model_list(aio_id), + data=model_list, + value=(model_list[0] if model_list[0] else None), + ), + ), + dmc.Space(h=15), + html.Div(id=self.ids.model_parameters(aio_id)), + dmc.Space(h=25), + DmcControlItem( + "Name", + self.ids.job_name_title(aio_id), + dmc.TextInput( + placeholder="Name your job...", + id=self.ids.job_name(aio_id), + ), + ), + dmc.Space(h=10), + dmc.Button("Run", id=self.ids.run_button(aio_id), **run_button_props), + dmc.Space(h=10), + DmcControlItem( + "Jobs", + self.ids.run_dropdown_title(aio_id), + dmc.Grid( + [ + dmc.Select( + placeholder="Select a job...", + id=self.ids.run_dropdown(aio_id), + ), + dmc.ActionIcon( + _tooltip( + "Advanced Options", + children=[ + DashIconify( + icon="mdi:settings-applications", + width=30, + ), + ], + ), + size="xs", + variant="subtle", + id=self.ids.advanced_options_modal_run(aio_id), + n_clicks=0, + style={"margin": "auto"}, + ), + ], + style={"margin": "0px"}, + ), + ), + dmc.Space(h=25), + DmcAdvancedOptionsAIO(aio_id=aio_id), + html.Div(id=self.ids.notifications_container(aio_id)), + dcc.Interval( + id=self.ids.check_job(aio_id), + interval=5000, + ), + dcc.Store( + id=self.ids.project_name_id(aio_id), + data="", + ), + ] + ) + + self.register_callbacks() + + @staticmethod + @callback( + Output( + { + "aio_id": MATCH, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + ), + Output( + { + "aio_id": MATCH, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + Input(ids.advanced_options_modal_run(MATCH), "n_clicks"), + State( + { + "aio_id": MATCH, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + ), + State(ids.run_dropdown(MATCH), "value"), + prevent_initial_call=True, + ) + def toggle_modal(n1, is_open, run_job_id): + return not is_open, run_job_id + + @staticmethod + @callback( + Output(ids.advanced_options_modal_run(MATCH), "disabled"), + Input(ids.run_dropdown(MATCH), "value"), + ) + def disable_advanced_run_options(run_job_id): + if run_job_id is not None: + return False + return True + + def register_callbacks(self): + + @callback( + Output(self.ids.run_dropdown(self._aio_id), "data"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + ) + def check_run_job(n_intervals): + return _check_train_job(self._prefect_tags, self._mode) + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-cancel", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def cancel_job(n_clicks, job_id): + _cancel_job(job_id, self._mode) + return False + + @callback( + Output(self.ids.run_dropdown(self._aio_id), "value", allow_duplicate=True), + Output( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + allow_duplicate=True, + ), + Input( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "warning-confirm-delete", + }, + "n_clicks", + ), + State( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + State(self.ids.run_dropdown(self._aio_id), "value"), + prevent_initial_call=True, + ) + def delete_job(n_clicks, job_id, run_job_id): + _delete_job(job_id, self._mode) + return None, False + + @callback( + Output( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "logs-area", + }, + "children", + ), + Input( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "advanced-options-modal", + }, + "opened", + ), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + State( + { + "aio_id": self._aio_id, + "component": "DmcAdvancedOptionsAIO", + "subcomponent": "job-id", + }, + "data", + ), + prevent_initial_call=True, + ) + def get_logs(is_open, n_intervals, job_id): + if job_id is None: + return "No logs available" + return _get_job_logs(job_id, self._mode) + + if self._dependency: + + @callback( + Output( + self.ids.run_dropdown(self._aio_id), "data", allow_duplicate=True + ), + Output(self.ids.run_dropdown(self._aio_id), "value"), + Input(self.ids.check_job(self._aio_id), "n_intervals"), + Input(self._dependency, "value"), + State(self.ids.project_name_id(self._aio_id), "data"), + prevent_initial_call=True, + ) + def check_dependant_job(n_intervals, dependant_job_id, project_name): + return _check_inference_job( + dependant_job_id, project_name, self._prefect_tags, self._mode + ) diff --git a/mlex_utils/dash_utils/components_mantime/parameter_items.py b/mlex_utils/dash_utils/components_mantime/parameter_items.py new file mode 100644 index 0000000..aee82cf --- /dev/null +++ b/mlex_utils/dash_utils/components_mantime/parameter_items.py @@ -0,0 +1,260 @@ +import dash_bootstrap_components as dbc + +# TODO: Tentatively remove dbc.Form for mantime components +import dash_mantine_components as dmc + +from mlex_utils.dash_utils.components_mantime.component_utils import DmcControlItem + + +class DmcNumberItem(DmcControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.NumberInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DmcNumberItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DmcStrItem(DmcControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.TextInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DmcStrItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DmcSliderItem(DmcControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.Slider( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + labelAlwaysOn=False, + color="gray", + size="sm", + **kwargs, + ) + + style = {"padding": "15px 0px 15px 0px"} + if not visible: + style["display"] = "none" + + super(DmcSliderItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DmcDropdownItem(DmcControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.Select( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DmcDropdownItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DmcRadioItem(DmcControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + if param_key is None: + param_key = name + + options = [ + dmc.Radio(option["label"], value=option["value"]) + for option in kwargs["options"] + ] + kwargs.pop("options", None) + self.input = dmc.RadioGroup( + options, + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DmcRadioItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DmcBoolItem(DmcControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + if param_key is None: + param_key = name + + self.input = dmc.Switch( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + label=title, + size="sm", + radius="lg", + color="gray", + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DmcBoolItem, self).__init__( + title="", # title is already in the switch + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DmcParameterItems(dbc.Form): + type_map = { + "float": DmcNumberItem, + "int": DmcNumberItem, + "str": DmcStrItem, + "slider": DmcSliderItem, + "dropdown": DmcDropdownItem, + "radio": DmcRadioItem, + "bool": DmcBoolItem, + } + + def __init__(self, _id, json_blob, values=None): + super(DmcParameterItems, self).__init__(id=_id, children=[]) + self._json_blob = json_blob + self.children = self.build_children(values=values) + + def _determine_type(self, parameter_dict): + if "type" in parameter_dict: + if parameter_dict["type"] in self.type_map: + return parameter_dict["type"] + elif parameter_dict["type"].__name__ in self.type_map: + return parameter_dict["type"].__name__ + elif type(parameter_dict["value"]) in self.type_map: + return type(parameter_dict["value"]) + raise TypeError( + f"No item type could be determined for this parameter: {parameter_dict}" + ) + + def build_children(self, values=None): + children = [] + for json_record in self._json_blob: + # Build a parameter dict from self.json_blob + type = json_record.get("type", self._determine_type(json_record)) + json_record = json_record.copy() + if values and json_record["name"] in values: + json_record["value"] = values[json_record["name"]] + json_record.pop("type", None) + # TODO: Use comp_group to fix training parameters and enable/disable parameters that + # do not fall into the scope of the job (training, inference, etc.) + if "comp_group" in json_record: + json_record.pop("comp_group", None) + item = self.type_map[type](**json_record, base_id=self.id) + children.append(item) + + return children diff --git a/mlex_utils/dash_utils/mlex_components.py b/mlex_utils/dash_utils/mlex_components.py new file mode 100644 index 0000000..17657d5 --- /dev/null +++ b/mlex_utils/dash_utils/mlex_components.py @@ -0,0 +1,108 @@ +import importlib + +UI_STYLE_IMPLEMENTATIONS = { + "dbc": { + "job_manager": "mlex_utils.dash_utils.components_bootstrap.job_manager.DbcJobManagerAIO", + "job_manager_minimal": ( + "mlex_utils.dash_utils.components_bootstrap." + "job_manager_minimal.DbcJobManagerMinimalAIO" + ), + "parameter_items": ( + "mlex_utils.dash_utils.components_bootstrap.parameter_items." + "DbcParameterItems" + ), + }, + "dmc": { + "job_manager": "mlex_utils.dash_utils.components_mantime.job_manager.DmcJobManagerAIO", + "job_manager_minimal": ( + "mlex_utils.dash_utils.components_mantime." + "job_manager_minimal.DmcJobManagerMinimalAIO" + ), + "parameter_items": ( + "mlex_utils.dash_utils.components_mantime.parameter_items." + "DmcParameterItems" + ), + }, +} + + +def import_from_path(import_path: str): + """Dynamically import a class/function given its full import path.""" + module_path, attr_name = import_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, attr_name) + + +class MLExComponents: + ALLOWED_UI_STYLES = {"dbc", "dmc"} + + def __init__(self, ui_style): + if ui_style not in self.ALLOWED_UI_STYLES: + raise ValueError( + f"ui_style must be one of {self.ALLOWED_UI_STYLES}, got {ui_style}" + ) + self.ui_style = ui_style + + def get_job_manager(self, **kwargs): + path = UI_STYLE_IMPLEMENTATIONS[self.ui_style]["job_manager"] + cls = import_from_path(path) + return cls(**kwargs) + + def get_job_manager_minimal(self, **kwargs): + path = UI_STYLE_IMPLEMENTATIONS[self.ui_style]["job_manager_minimal"] + cls = import_from_path(path) + return cls(**kwargs) + + def get_parameter_items(self, **kwargs): + path = UI_STYLE_IMPLEMENTATIONS[self.ui_style]["parameter_items"] + cls = import_from_path(path) + return cls(**kwargs) + + @staticmethod + def get_parameters_values(parameters): + """ + Extracts parameters from the children component of a ParameterItems component, + if there are any errors in the input, it will return an error status + """ + errors = False + input_params = {} + for param in parameters["props"]["children"]: + # param["props"]["children"][0] is the label + # param["props"]["children"][1] is the input + parameter_container = param["props"]["children"][1] + # The actual parameter item is the first and only child of the parameter container + parameter_item = parameter_container["props"]["children"]["props"] + key = parameter_item["id"]["param_key"] + if "value" in parameter_item: + value = parameter_item["value"] + elif "checked" in parameter_item: + value = parameter_item["checked"] + if "error" in parameter_item: + if parameter_item["error"] is not False: + errors = True + input_params[key] = value + return input_params, errors + + # TODO: Consider changing the background of the components to indicate the change + @staticmethod + def update_parameters_values(current_parameters, new_values): + """ + Updates the current parameters with the new values + """ + parameters_children = current_parameters["props"].get("children", []) + + for param in parameters_children: + # param["props"]["children"][1] is the container for the input + # The actual input props are at ["props"]["children"]["props"] + input_props = param["props"]["children"][1]["props"]["children"]["props"] + key = input_props["id"]["param_key"] + + if key in new_values: + value = new_values[key] + # Update "value" if present, otherwise "checked" + if "value" in input_props: + input_props["value"] = value + elif "checked" in input_props: + input_props["checked"] = bool(value) + + return current_parameters diff --git a/mlex_utils/prefect_utils/core.py b/mlex_utils/prefect_utils/core.py index 49ba41c..6a53410 100644 --- a/mlex_utils/prefect_utils/core.py +++ b/mlex_utils/prefect_utils/core.py @@ -7,7 +7,11 @@ FlowRunFilterName, FlowRunFilterParentFlowRunId, FlowRunFilterTags, + LogFilter, + LogFilterFlowRunId, ) +from prefect.client.schemas.objects import State, StateType +from prefect.client.schemas.sorting import LogSort async def _schedule( @@ -45,19 +49,61 @@ def schedule_prefect_flow( return flow_run_id -async def _get_name(flow_run_id): +async def _delete( + flow_run_id: str, +): + async with get_client() as client: + await client.delete_flow_run(flow_run_id) + + +def delete_flow_run(flow_run_id: str): + asyncio.run(_delete(flow_run_id)) + + +async def _set_state( + flow_run_id: str, + state: StateType, + force: bool = False, +): + async with get_client() as client: + await client.set_flow_run_state(flow_run_id, state, force=force) + + +async def _get_flow_run_state(flow_run_id): async with get_client() as client: flow_run = await client.read_flow_run(flow_run_id) - if flow_run.state.is_final(): + return flow_run.state + + +def get_flow_run_state(flow_run_id): + flow_run_state = asyncio.run(_get_flow_run_state(flow_run_id)) + return flow_run_state.type + + +def cancel_flow_run(flow_run_id: str): + flow_run_state = asyncio.run(_get_flow_run_state(flow_run_id)) + if not flow_run_state.is_final(): + asyncio.run(_set_state(flow_run_id, State(type=StateType.CANCELLED))) + pass + + +async def _get_name(flow_run_id, is_completed): + async with get_client() as client: + flow_run = await client.read_flow_run(flow_run_id) + if flow_run and not is_completed: + return flow_run.name + elif flow_run and flow_run.state.is_final(): if flow_run.state.is_completed(): return flow_run.name return None -# TODO: Get flow_run_id when the flow has not completed as well -def get_flow_run_name(flow_run_id): - """Retrieves the name of the flow with the given id.""" - return asyncio.run(_get_name(flow_run_id)) +def get_flow_run_name(flow_run_id, is_completed=False): + """ + Retrieves the name of the flow with the given id. + If is_completed is True, it will return the name of the flow only if it is completed. + """ + return asyncio.run(_get_name(flow_run_id, is_completed)) async def _flow_run_query( @@ -80,8 +126,28 @@ async def _flow_run_query( return flow_runs -# TODO: Change name to query_flow_runs because it takes both names and tags -def get_flow_runs_by_name(flow_run_name=None, tags=None): +async def _read_flow_run(flow_run_id): + async with get_client() as client: + flow_run = await client.read_flow_run(flow_run_id) + return flow_run + + +async def _read_flow_run_logs(flow_run_id, limit=200, offset=0): + async with get_client() as client: + flow_run_logs = await client.read_logs( + log_filter=LogFilter( + flow_run_id=LogFilterFlowRunId( + any_=[flow_run_id], + ), + ), + limit=limit, + offset=offset, + sort=LogSort.TIMESTAMP_ASC, + ) + return flow_run_logs + + +def query_flow_runs(flow_run_name=None, tags=None): flow_runs_by_name = [] flow_runs = asyncio.run(_flow_run_query(tags, flow_run_name=flow_run_name)) for flow_run in flow_runs: @@ -105,3 +171,13 @@ def get_children_flow_run_ids(parent_flow_run_id, sort="START_TIME_ASC"): str(children_flow_run.id) for children_flow_run in children_flow_runs ] return children_flow_run_ids + + +def get_flow_run_logs(flow_run_id): + flow_run_logs = asyncio.run(_read_flow_run_logs(flow_run_id)) + return [log.message for log in flow_run_logs] + + +def get_flow_run_parameters(flow_run_id): + flow_run = asyncio.run(_read_flow_run(flow_run_id)) + return flow_run.parameters diff --git a/mlex_utils/test/test_dash.py b/mlex_utils/test/test_dash.py new file mode 100644 index 0000000..08fb635 --- /dev/null +++ b/mlex_utils/test/test_dash.py @@ -0,0 +1,193 @@ +import uuid +from contextvars import copy_context + +import pytest +from dash._callback_context import context_value +from dash._utils import AttributeDict + +from mlex_utils.dash_utils.components_bootstrap.advanced_options import ( + DbcAdvancedOptionsAIO, +) +from mlex_utils.dash_utils.components_mantime.advanced_options import ( + DmcAdvancedOptionsAIO, +) +from mlex_utils.dash_utils.mlex_components import MLExComponents + +model_parameters = [ + { + "type": "float", + "name": "float_param", + "title": "Float Parameter", + "param_key": "float_param", + "value": 1, + "comp_group": "group_1", + }, + { + "type": "int", + "name": "int_param", + "title": "Integer Parameter", + "param_key": "int_param", + "value": 1, + "comp_group": "group_1", + }, + { + "type": "str", + "name": "str_param", + "title": "String Parameter", + "param_key": "str_param", + "value": "test", + "comp_group": "group_1", + }, + { + "type": "slider", + "name": "slider", + "title": "Slider", + "param_key": "slider", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "group_1", + }, + { + "type": "dropdown", + "name": "dropdown", + "title": "Dropdown", + "param_key": "dropdown", + "comp_group": "group_1", + }, + { + "type": "radio", + "name": "radio", + "title": "Radio", + "param_key": "radio", + "options": [ + {"label": "Option 1", "value": 1}, + {"label": "Option 2", "value": 2}, + ], + "comp_group": "group_1", + }, + { + "type": "bool", + "name": "bool", + "title": "Bool", + "param_key": "bool", + "comp_group": "group_1", + }, +] + +new_values = { + "float_param": 2.0, + "int_param": 2, + "str_param": "test2", + "slider": 50, + "dropdown": "option_1", + "radio": 2, + "bool": True, +} + + +def serialize_dash_components(obj): + if hasattr(obj, "to_plotly_json"): + # Serialize the Dash component + serialized_obj = obj.to_plotly_json() + # Recursively process the serialized object's props + if isinstance(serialized_obj, dict) and "props" in serialized_obj: + serialized_obj["props"] = serialize_dash_components(serialized_obj["props"]) + return serialized_obj + elif isinstance(obj, dict): + # Recursively process dictionary items + return {key: serialize_dash_components(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + # Recursively process list or tuple items + return [serialize_dash_components(item) for item in obj] + else: + # Return the object as is (e.g., strings, numbers) + return obj + + +@pytest.mark.parametrize("component_type", ["dbc", "dmc"]) +def test_get_job_manager(component_type): + mlex_components = MLExComponents(component_type) + job_manager = mlex_components.get_job_manager() + assert job_manager is not None + + +@pytest.mark.parametrize("component_type", ["dbc", "dmc"]) +def test_get_job_manager_minimal(component_type): + mlex_components = MLExComponents(component_type) + job_manager = mlex_components.get_job_manager_minimal() + assert job_manager is not None + + +@pytest.mark.parametrize("component_type", ["dbc", "dmc"]) +def test_advanced_options_modal(component_type): + mlex_components = MLExComponents(component_type) + job_manager = mlex_components.get_job_manager() + + def run_callback(n1, n2, is_open, train_job_id, inference_job_id, prop_id): + context_value.set(AttributeDict(**{"triggered_inputs": [{"prop_id": prop_id}]})) + return job_manager.toggle_modal(n1, n2, is_open, train_job_id, inference_job_id) + + ctx = copy_context() + + # Open advanced options modal with train job id + output = ctx.run( + run_callback, 1, 0, False, "uid0001", None, "train-button.n_clicks" + ) + assert output[0] is True and output[1] == "uid0001" + + # Close advanced options modal + output = ctx.run(run_callback, 0, 1, True, None, None, "train-button.n_clicks") + assert output[0] is False and output[1] is None + + # Open advanced options modal with no inference job id + output = ctx.run(run_callback, 0, 1, False, None, None, "inference-button.n_clicks") + assert output[0] is True and output[1] is None + + # Disable train advanced options modal button + assert not job_manager.disable_advanced_train_options("uid0001") + + # Disable inference advanced options modal button + assert job_manager.disable_advanced_inference_options(None) + + +@pytest.mark.parametrize("component_type", ["dbc", "dmc"]) +def test_get_parameters(component_type): + mlex_components = MLExComponents(component_type) + parameters = mlex_components.get_parameter_items( + _id={"type": str(uuid.uuid4())}, json_blob=model_parameters + ) + assert parameters is not None + parameters = serialize_dash_components(parameters) + parameters_dict, params_errors = mlex_components.get_parameters_values(parameters) + assert isinstance(parameters_dict, dict) and params_errors is False + + +@pytest.mark.parametrize("component_type", ["dbc", "dmc"]) +def test_get_and_update_parameters(component_type): + mlex_components = MLExComponents(component_type) + parameters = mlex_components.get_parameter_items( + _id={"type": str(uuid.uuid4())}, json_blob=model_parameters + ) + assert parameters is not None + parameters = serialize_dash_components(parameters) + parameters_dict, params_errors = mlex_components.get_parameters_values(parameters) + assert isinstance(parameters_dict, dict) and params_errors is False + + new_parameters = mlex_components.update_parameters_values(parameters, new_values) + assert new_parameters is not None + new_parameters = serialize_dash_components(new_parameters) + new_parameters_dict, new_params_errors = mlex_components.get_parameters_values( + new_parameters + ) + assert isinstance(new_parameters_dict, dict) and new_params_errors is False + + +def test_toggle_warnings_dbc(): + assert not DbcAdvancedOptionsAIO.toggle_warning_cancel_modal(0, 0, True) + assert DbcAdvancedOptionsAIO.toggle_warning_delete_modal(0, 0, False) + + +def test_toggle_warnings_dmc(): + assert not DmcAdvancedOptionsAIO.toggle_warning_cancel_modal(0, 0, 0, True) + assert DmcAdvancedOptionsAIO.toggle_warning_delete_modal(0, 0, 1, False) diff --git a/mlex_utils/test/test_prefect.py b/mlex_utils/test/test_prefect.py index e85fcc9..d8c57b1 100644 --- a/mlex_utils/test/test_prefect.py +++ b/mlex_utils/test/test_prefect.py @@ -2,14 +2,20 @@ import uuid from prefect import context, flow, get_client +from prefect.client.schemas.objects import StateType from prefect.deployments import Deployment from prefect.engine import create_then_begin_flow_run from prefect.testing.utilities import prefect_test_harness from mlex_utils.prefect_utils.core import ( + cancel_flow_run, + delete_flow_run, get_children_flow_run_ids, + get_flow_run_logs, get_flow_run_name, - get_flow_runs_by_name, + get_flow_run_parameters, + get_flow_run_state, + query_flow_runs, schedule_prefect_flow, ) @@ -35,14 +41,6 @@ def parent_flow(model_name): return parent_flow_run_id -deployment = Deployment.build_from_flow( - flow=parent_flow, - name="test_deployment", - version="1", - tags=["Test tag"], -) - - async def run_flow(): async with get_client() as client: flow_run_id = await create_then_begin_flow_run( @@ -58,6 +56,12 @@ async def run_flow(): def test_schedule_prefect_flows(): with prefect_test_harness(): + deployment = Deployment.build_from_flow( + flow=parent_flow, + name="test_deployment", + version="1", + tags=["Test tag"], + ) # Add deployment deployment.apply() @@ -70,14 +74,14 @@ def test_schedule_prefect_flows(): assert isinstance(flow_run_id, uuid.UUID) -def test_monitor_prefect_flows(): +def test_monitor_prefect_flow_runs(): with prefect_test_harness(): # Run flow flow_run_id = asyncio.run(run_flow()) assert isinstance(flow_run_id, str) # Get flow runs by name - flow_runs = get_flow_runs_by_name() + flow_runs = query_flow_runs() assert len(flow_runs) == 3 # Get flow run name @@ -87,3 +91,75 @@ def test_monitor_prefect_flows(): # Get children flow run ids children_flow_run_ids = get_children_flow_run_ids(flow_run_id) assert len(children_flow_run_ids) == 2 + + +def test_delete_prefect_flow_runs(): + with prefect_test_harness(): + # Run flow + flow_run_id = asyncio.run(run_flow()) + assert isinstance(flow_run_id, str) + + # Get flow runs by name + flow_runs = query_flow_runs() + assert len(flow_runs) == 3 + + # Delete flow run + delete_flow_run(flow_run_id) + + # Get flow runs by name + flow_runs = query_flow_runs() + assert len(flow_runs) < 3 + + +def test_cancel_prefect_flow_runs(): + with prefect_test_harness(): + deployment = Deployment.build_from_flow( + flow=parent_flow, + name="test_deployment", + version="1", + tags=["Test tag"], + ) + # Add deployment + deployment.apply() + + # Schedule parent flow + flow_run_id = schedule_prefect_flow( + deployment_name="Parent Flow/test_deployment", + parameters={"model_name": "model_name"}, + flow_run_name="flow_run_name", + ) + + # Change flow run state + flow_run_state = get_flow_run_state(flow_run_id) + assert flow_run_state not in [StateType.COMPLETED, StateType.FAILED] + + # Cancel flow run + cancel_flow_run(flow_run_id) + + # Check status in name + flow_run_name = get_flow_run_name(flow_run_id) + flow_run_label = query_flow_runs(flow_run_name)[0]["label"] + assert flow_run_label[0] == "🚫" + + +def test_get_flow_run_logs(): + with prefect_test_harness(): + # Run flow + flow_run_id = asyncio.run(run_flow()) + assert isinstance(flow_run_id, str) + + # Get flow run logs + flow_run_logs = get_flow_run_logs(flow_run_id) + assert len(flow_run_logs) > 0 + assert isinstance(flow_run_logs[0], str) + + +def test_get_flow_run_parameters(): + with prefect_test_harness(): + # Run flow + flow_run_id = asyncio.run(run_flow()) + assert isinstance(flow_run_id, str) + + # Get flow run logs + flow_run_parameters = get_flow_run_parameters(flow_run_id) + assert isinstance(flow_run_parameters, dict) diff --git a/pyproject.toml b/pyproject.toml index ba42586..0dd6a5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ +[tool.isort] +profile = "black" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" -[tool.hatch.build.targets.wheel] - packages = ["mlex_utils/**/*"] - [tool.hatch.metadata] allow-direct-references = true @@ -22,14 +22,32 @@ dependencies = [] [project.optional-dependencies] all = [ "prefect==2.14.21", + "dash==2.9.3", + "dash-bootstrap-components==1.6.0", + "dash-mantine-components==0.12.1", + "dash-core-components==2.0.0", + "dash-html-components==2.0.0", + "dash-iconify==0.1.2", + "griffe >= 0.49.0, <1.0.0", ] prefect = [ "prefect==2.14.21", + "griffe >= 0.49.0, <1.0.0", +] + +dash = [ + "dash==2.9.3", + "dash-bootstrap-components==1.6.0", + "dash-mantine-components==0.12.1", + "dash-core-components==2.0.0", + "dash-html-components==2.0.0", + "dash-iconify==0.1.2", ] dev = [ "black==24.2.0", + "dash[testing]==2.9.3", "flake8==7.0.0", "isort==5.13.2", "pre-commit==3.6.2",