Skip to content

Commit 696da33

Browse files
authored
[Feature-selection] Replace matplotlib with plotly (#815)
1 parent 55a6023 commit 696da33

File tree

5 files changed

+59
-74
lines changed

5 files changed

+59
-74
lines changed

feature_selection/feature_selection.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,15 @@
1313
# limitations under the License.
1414
#
1515
import json
16-
import os
1716

18-
import matplotlib.pyplot as plt
1917
import mlrun
2018
import mlrun.datastore
21-
import mlrun.utils
2219
import mlrun.feature_store as fs
20+
import mlrun.utils
2321
import numpy as np
2422
import pandas as pd
25-
import seaborn as sns
26-
from mlrun.artifacts import PlotArtifact
23+
import plotly.express as px
24+
from mlrun.artifacts import PlotlyArtifact
2725
from mlrun.datastore.targets import ParquetTarget
2826
# MLRun utils
2927
from mlrun.utils.helpers import create_class
@@ -42,15 +40,6 @@
4240
}
4341

4442

45-
def _clear_current_figure():
46-
"""
47-
Clear matplotlib current figure.
48-
"""
49-
plt.cla()
50-
plt.clf()
51-
plt.close()
52-
53-
5443
def show_values_on_bars(axs, h_v="v", space=0.4):
5544
def _show_on_single_plot(ax_):
5645
if h_v == "v":
@@ -74,33 +63,18 @@ def _show_on_single_plot(ax_):
7463

7564

7665
def plot_stat(context, stat_name, stat_df):
77-
_clear_current_figure()
78-
79-
# Add chart
80-
ax = plt.axes()
81-
stat_chart = sns.barplot(
66+
sorted_df = stat_df.sort_values(stat_name)
67+
fig = px.bar(
68+
data_frame=sorted_df,
8269
x=stat_name,
83-
y="index",
84-
data=stat_df.sort_values(stat_name, ascending=False).reset_index(),
85-
ax=ax,
70+
y=sorted_df.index,
71+
title=f"{stat_name} feature scores",
72+
color=stat_name,
8673
)
87-
plt.tight_layout()
88-
89-
for p in stat_chart.patches:
90-
width = p.get_width()
91-
plt.text(
92-
5 + p.get_width(),
93-
p.get_y() + 0.55 * p.get_height(),
94-
"{:1.2f}".format(width),
95-
ha="center",
96-
va="center",
97-
)
98-
9974
context.log_artifact(
100-
PlotArtifact(f"{stat_name}", body=plt.gcf()),
101-
local_path=os.path.join("plots", "feature_selection", f"{stat_name}.html"),
75+
item=PlotlyArtifact(key=stat_name, figure=fig),
76+
local_path=f"{stat_name}.html",
10277
)
103-
_clear_current_figure()
10478

10579

10680
def feature_selection(
@@ -115,7 +89,6 @@ def feature_selection(
11589
sample_ratio: float = None,
11690
output_vector_name: float = None,
11791
ignore_type_errors: bool = False,
118-
is_feature_vector: bool = False,
11992
):
12093
"""
12194
Applies selected feature selection statistical functions or models on our 'df_artifact'.
@@ -138,10 +111,9 @@ def feature_selection(
138111
model name (ex. LinearSVC), formalized json (contains 'CLASS',
139112
'FIT', 'META') or a path to such json file.
140113
:param max_scaled_scores: produce feature scores table scaled with max_scaler.
141-
:param sample_ratio: percentage of the dataset the user whishes to compute the feature selection process on.
114+
:param sample_ratio: percentage of the dataset the user wishes to compute the feature selection process on.
142115
:param output_vector_name: creates a new feature vector containing only the identified features.
143116
:param ignore_type_errors: skips datatypes that are neither float nor int within the feature vector.
144-
:param is_feature_vector: bool stating if the data is passed as a feature vector.
145117
"""
146118
stat_filters = stat_filters or DEFAULT_STAT_FILTERS
147119
model_filters = model_filters or DEFAULT_MODEL_FILTERS

0 commit comments

Comments
 (0)