1313# limitations under the License.
1414#
1515import json
16- import os
1716
18- import matplotlib .pyplot as plt
1917import mlrun
2018import mlrun .datastore
21- import mlrun .utils
2219import mlrun .feature_store as fs
20+ import mlrun .utils
2321import numpy as np
2422import pandas as pd
25- import seaborn as sns
26- from mlrun .artifacts import PlotArtifact
23+ import plotly . express as px
24+ from mlrun .artifacts import PlotlyArtifact
2725from mlrun .datastore .targets import ParquetTarget
2826# MLRun utils
2927from mlrun .utils .helpers import create_class
4240}
4341
4442
45- def _clear_current_figure ():
46- """
47- Clear matplotlib current figure.
48- """
49- plt .cla ()
50- plt .clf ()
51- plt .close ()
52-
53-
5443def show_values_on_bars (axs , h_v = "v" , space = 0.4 ):
5544 def _show_on_single_plot (ax_ ):
5645 if h_v == "v" :
@@ -74,33 +63,18 @@ def _show_on_single_plot(ax_):
7463
7564
7665def plot_stat (context , stat_name , stat_df ):
77- _clear_current_figure ()
78-
79- # Add chart
80- ax = plt .axes ()
81- stat_chart = sns .barplot (
66+ sorted_df = stat_df .sort_values (stat_name )
67+ fig = px .bar (
68+ data_frame = sorted_df ,
8269 x = stat_name ,
83- y = " index" ,
84- data = stat_df . sort_values ( stat_name , ascending = False ). reset_index () ,
85- ax = ax ,
70+ y = sorted_df . index ,
71+ title = f" { stat_name } feature scores" ,
72+ color = stat_name ,
8673 )
87- plt .tight_layout ()
88-
89- for p in stat_chart .patches :
90- width = p .get_width ()
91- plt .text (
92- 5 + p .get_width (),
93- p .get_y () + 0.55 * p .get_height (),
94- "{:1.2f}" .format (width ),
95- ha = "center" ,
96- va = "center" ,
97- )
98-
9974 context .log_artifact (
100- PlotArtifact ( f" { stat_name } " , body = plt . gcf () ),
101- local_path = os . path . join ( "plots" , "feature_selection" , f"{ stat_name } .html" ) ,
75+ item = PlotlyArtifact ( key = stat_name , figure = fig ),
76+ local_path = f"{ stat_name } .html" ,
10277 )
103- _clear_current_figure ()
10478
10579
10680def feature_selection (
@@ -115,7 +89,6 @@ def feature_selection(
11589 sample_ratio : float = None ,
11690 output_vector_name : float = None ,
11791 ignore_type_errors : bool = False ,
118- is_feature_vector : bool = False ,
11992):
12093 """
12194 Applies selected feature selection statistical functions or models on our 'df_artifact'.
@@ -138,10 +111,9 @@ def feature_selection(
138111 model name (ex. LinearSVC), formalized json (contains 'CLASS',
139112 'FIT', 'META') or a path to such json file.
140113 :param max_scaled_scores: produce feature scores table scaled with max_scaler.
141- :param sample_ratio: percentage of the dataset the user whishes to compute the feature selection process on.
114+ :param sample_ratio: percentage of the dataset the user wishes to compute the feature selection process on.
142115 :param output_vector_name: creates a new feature vector containing only the identifies features.
143116 :param ignore_type_errors: skips datatypes that are neither float nor int within the feature vector.
144- :param is_feature_vector: bool stating if the data is passed as a feature vector.
145117 """
146118 stat_filters = stat_filters or DEFAULT_STAT_FILTERS
147119 model_filters = model_filters or DEFAULT_MODEL_FILTERS
0 commit comments