|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import matplotlib.dates as mdates |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +from scipy.stats import f, mannwhitneyu, shapiro, ttest_ind |
| 10 | + |
| 11 | + |
class BucketTest:
    """Compute and render charts and statistics for a bucket (A/B) test.

    The instance wraps a data frame in which one column identifies the group
    each row belongs to, one column holds the measured variable and one column
    holds the x-axis value used for plotting.
    """

    def __init__(self, df: pd.DataFrame, variable: str, group: str, x_axis='date', custom_title='',
                 custom_day_interval=1, custom_ylabel=''):
        """Create a new bucket test with the given attributes.

        Args:
            df: Data frame holding the observations of every group.
            variable: Name of the column holding the measured values.
            group: Name of the column identifying the group of each row.
            x_axis: Name of the column used as the x-axis when plotting.
            custom_title: Chart title; when empty a default title is built
                from ``variable`` and ``group``.
            custom_day_interval: Interval, in days, between major x-axis ticks.
            custom_ylabel: Y-axis label; when empty ``variable`` is used.
        """
        self.df = df
        self.variable = variable
        self.x_axis = x_axis
        self.group = group
        self.custom_title = custom_title
        self.custom_day_interval = custom_day_interval
        self.custom_ylabel = custom_ylabel

    def render(self, figure_size_x=12, figure_size_y=5, line_width=3, title_font_size=16, legend_font_size=14,
               rotation=30):
        """Plot the variable along the x-axis with one line per group, save the chart and show it.

        The figure is written to a file named ``Chart`` in the current working
        directory (matplotlib chooses the output format from its savefig
        defaults) and is then displayed.

        Args:
            figure_size_x: Figure width passed to ``plt.subplots``.
            figure_size_y: Figure height passed to ``plt.subplots``.
            line_width: Width of every plotted line.
            title_font_size: Font size of the chart title.
            legend_font_size: Font size of the legend entries.
            rotation: Rotation, in degrees, of the x-axis tick labels.
        """
        fig, ax = plt.subplots(figsize=(figure_size_x, figure_size_y))
        for group_value in self.df[self.group].unique():
            # set_index without inplace returns a new frame, so the filtered
            # slice of self.df is never mutated (no SettingWithCopyWarning).
            group_df = self.df[self.df[self.group] == group_value].set_index(self.x_axis, drop=False)
            ax.plot(group_df[self.variable], label=group_value, linewidth=line_width)

        # Title customization
        title = self.custom_title or '{} per {}'.format(self.variable, self.group)
        plt.title(title, fontsize=title_font_size)
        plt.legend(bbox_to_anchor=(1.3, 0.8), frameon=False, fontsize=legend_font_size)

        # Y-label customization
        plt.ylabel(self.custom_ylabel or self.variable)

        plt.ylim(0)
        plt.xticks(rotation=rotation)
        self.__set_locator_and_formatter__(ax)

        # Save BEFORE showing: with a blocking backend the figure is torn down
        # once the window is closed, so saving afterwards writes a blank image.
        fig.savefig(Path('Chart'))
        plt.show()

    def __set_locator_and_formatter__(self, ax):
        """Place a major x-axis tick every ``custom_day_interval`` days, labelled as YYYY-MM-DD."""
        # Major locator customization
        ax.xaxis.set_major_locator(mdates.DayLocator(interval=self.custom_day_interval))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

    @staticmethod
    def _f_test_pvalue(group_a, group_b):
        """Return the two-sided p-value of the F-test for equality of the two sample variances."""
        sample_a = np.asarray(group_a, dtype=float)
        sample_b = np.asarray(group_b, dtype=float)

        # The F distribution with (n - 1) degrees of freedom applies to the
        # ratio of unbiased sample variances, hence ddof=1.
        f_statistic = np.var(sample_a, ddof=1) / np.var(sample_b, ddof=1)
        dof_a = sample_a.size - 1
        dof_b = sample_b.size - 1

        # Either variance may be the larger one, so both tails count as
        # evidence against equal variances; the lower tail alone would report
        # a value close to 1 whenever group A has the larger variance.
        lower_tail = f.cdf(f_statistic, dof_a, dof_b)
        upper_tail = f.sf(f_statistic, dof_a, dof_b)
        # np.minimum (unlike the builtin min) propagates NaN instead of hiding it.
        return float(np.minimum(1.0, 2 * np.minimum(lower_tail, upper_tail)))

    def compute_pvalues(self, alpha=0.01):
        """Compare the first two groups of the bucket test and print the resulting p-values.

        Only the first two distinct values of the group column (in order of
        appearance) are compared. The method prints the Shapiro-Wilk normality
        p-value of each group and the F-test p-value for equality of variances,
        then selects the significance test:

        * both groups look normal and variances look equal: Student's t-test;
        * both groups look normal but variances differ: Welch's t-test;
        * otherwise: two-sided Mann-Whitney U test.

        Args:
            alpha: Significance level used for every decision above.

        Raises:
            IndexError: If the group column holds fewer than two distinct values.
        """
        # Distinct group labels, in order of appearance
        group_values = self.df[self.group].unique()

        # Samples of group A and group B
        group_a = self.df.loc[self.df[self.group] == group_values[0], self.variable]
        group_b = self.df.loc[self.df[self.group] == group_values[1], self.variable]

        # normality
        _, normality_pvalue_a = shapiro(group_a)
        _, normality_pvalue_b = shapiro(group_b)
        print('Shapiro group A p-value: ', normality_pvalue_a)
        print('Shapiro group B p-value: ', normality_pvalue_b)

        # variance
        f_pvalue = self._f_test_pvalue(group_a, group_b)
        print('F test p-value: ', f_pvalue)

        both_normal = normality_pvalue_a > alpha and normality_pvalue_b > alpha
        if not both_normal:
            # Mann-Whitney U test; the alternative is explicit because the
            # default differs between SciPy versions.
            mannwhitneyu_pvalue = mannwhitneyu(group_a, group_b, alternative='two-sided').pvalue
            print('Mann-Whitney U test: ', mannwhitneyu_pvalue)
            print('Statistical significance: ', mannwhitneyu_pvalue <= alpha)
        elif f_pvalue > alpha:
            # T-test
            ttest_pvalue = ttest_ind(group_a, group_b).pvalue
            print('T-test p-value: ', ttest_pvalue)
            print('Statistical significance: ', ttest_pvalue <= alpha)
        else:
            # Welch's test
            welch_pvalue = ttest_ind(group_a, group_b, equal_var=False).pvalue
            print('Welch p-value: ', welch_pvalue)
            print('Statistical significance: ', welch_pvalue <= alpha)
0 commit comments