diff --git a/dbldatagen/data_analyzer.py b/dbldatagen/data_analyzer.py index 5aec5245..fa46c944 100644 --- a/dbldatagen/data_analyzer.py +++ b/dbldatagen/data_analyzer.py @@ -5,16 +5,26 @@ """ This module defines the ``DataAnalyzer`` class.
-This code is experimental and both APIs and code generated is liable to change in future versions.
+ .. warning:: + Experimental + + This code is experimental and both the APIs and the generated code are liable to change in future versions. + """
-from pyspark.sql.types import LongType, FloatType, IntegerType, StringType, DoubleType, BooleanType, ShortType, \ - TimestampType, DateType, DecimalType, ByteType, BinaryType, StructType, ArrayType, DataType
+import logging +import pprint +from collections import namedtuple
-import pyspark.sql as ssql
+import numpy as np import pyspark.sql.functions as F +from pyspark import sql +from pyspark.sql import DataFrame +from pyspark.sql.types import LongType, FloatType, IntegerType, StringType, DoubleType, BooleanType, ShortType, \ + TimestampType, DateType, DecimalType, ByteType, BinaryType, StructType, ArrayType, DataType, MapType
-from .utils import strip_margins
+from .html_utils import HtmlUtils from .spark_singleton import SparkSingleton +from .utils import strip_margins, json_value_from_path class DataAnalyzer: @@ -23,10 +33,16 @@ class DataAnalyzer: :param df: Spark dataframe to analyze :param sparkSession: Spark session instance to use when performing spark operations
+ :param maxRows: if specified, determines max number of rows to analyze when `analysisLevel` is "sample"
+ :param analysisLevel: Determines level of analysis to perform. Options are ["minimal", "sample", "full"]. + Default is "sample".
- .. warning:: - Experimental
+ You may increase the categorical values threshold to a higher value using the Spark config option + `dbldatagen.analyzer.categoricalValuesThreshold`, in which case columns with higher numbers of distinct values + will be evaluated to see if they can be represented as a values list. + However, the sampling may fail if this level is set too high. + Experimentally, this should be kept below 100 at present. """ _DEFAULT_GENERATED_NAME = "synthetic_data" @@ -43,20 +59,133 @@ class DataAnalyzer: |# Column definitions are stubs only - modify to generate correct data |#""", '|')
- def __init__(self, df=None, sparkSession=None):
+ _INT_32_MAX = 2 ** 31 - 1 + + _MAX_COLUMN_ELEMENT_LENGTH_THRESHOLD = 40 + _CATEGORICAL_VALUE_DEFAULT_THRESHOLD = 50 + + _MAX_VALUES_LINE_LENGTH = 60 + _CODE_GENERATION_INDENT = 4 + _MEASURE_ROUNDING = 4
+ + # tuple for column info + ColInfo = namedtuple("ColInfo", ["name", "dt", "isArrayColumn", "isNumeric"])
+ + # tuple for values info + ColumnValuesInfo = namedtuple("ColumnValuesInfo", ["name", "statements", "value_refs"])
+ + # options + _ANALYSIS_LEVELS = ["minimal", "sample", "analyze_text", "full"] + _CATEGORICAL_VALUES_THRESHOLD_OPTION = "dbldatagen.analyzer.categoricalValuesThreshold" + _SAMPLE_ROWS_THRESHOLD_OPTION = "dbldatagen.analyzer.sampleRowsThreshold" + _CACHE_SOURCE_OPTION = "dbldatagen.analyzer.cacheSource" + _CACHE_SAMPLE_OPTION = "dbldatagen.analyzer.cacheSample" + + _DEFAULT_SAMPLE_ROWS_THRESHOLD = 10000
+ def __init__(self, df=None, sparkSession=None, maxRows=None, analysisLevel="sample"): """ Constructor: :param df: Dataframe to analyze :param sparkSession: Spark session to use + :param maxRows: if specified, determines max number of rows to analyze. + :param analysisLevel: Determines level of analysis to perform. 
Options are ["minimal", "sample", "full"]. + Default is "sample". + + + + You may increase the categorical values threshold to a higher value, in which case columns with higher numbers + of distinct values will be evaluated to see if they can be represented as a values list. + + However, the current implementation will flag an error if the number of categorical values causes SQL array sizes to be + too large. Experimentally, this should be kept below 100 at present. """ assert df is not None, "dataframe must be supplied" self._df = df
+ assert analysisLevel in self._ANALYSIS_LEVELS, f"analysisLevel must be one of {self._ANALYSIS_LEVELS}" + assert maxRows is None or maxRows > 0, "maxRows must be greater than 0, if supplied" + if sparkSession is None: sparkSession = SparkSingleton.getLocalInstance() self._sparkSession = sparkSession self._dataSummary = None
+ self._columnsInfo = None + self._expandedSampleDf = None
+ + self._valuesCountThreshold = int(self._sparkSession.conf.get(self._CATEGORICAL_VALUES_THRESHOLD_OPTION, + str(self._CATEGORICAL_VALUE_DEFAULT_THRESHOLD)))
+ + # max rows is supplied parameter or default + self._maxRows = maxRows or self._DEFAULT_SAMPLE_ROWS_THRESHOLD + self._df_sampled_data = None + self._analysisLevel = analysisLevel + self._cacheSource = self._sparkSession.conf.get(self._CACHE_SOURCE_OPTION, "false").lower() == "true" + self._cacheSample = self._sparkSession.conf.get(self._CACHE_SAMPLE_OPTION, "true").lower() == "true"
+ + @classmethod + def sampleData(cls, df: DataFrame, maxRows: int): + """ + Sample data from a dataframe specifying the max rows to sample + + :param df: The dataframe to sample + :param maxRows: The maximum number of rows to sample + :return: The dataframe with the sampled data + """
+ assert df is not None, "dataframe must be supplied" + assert maxRows is not None and isinstance(maxRows, int) and maxRows > 0, "maxRows must be a positive integer"
+ + # use count with limit of maxRows + 1 to determine if the dataframe is larger than maxRows + if df.limit(maxRows + 1).count() <= maxRows: + return df
+ + # if the dataframe is larger than maxRows, then sample it + # and constrain the output to the limit of maxRows + return df.sample(maxRows / df.count(), seed=42).limit(maxRows)
+ + @property + def sourceDf(self): + """ Get source dataframe""" + return self._df
+ + @property + def sampledSourceDf(self): + """ Get source dataframe (capped with maxRows if necessary)""" + if self._df_sampled_data is None: + # by default, use the full source + if self._analysisLevel == "full": + self._df_sampled_data = self._df + else: + self._df_sampled_data = self.sampleData(self._df, self._maxRows)
+ + if self._cacheSample: + self._df_sampled_data = self._df_sampled_data.cache()
+ + return self._df_sampled_data
+ + @property + def expandedSampleDf(self): + """ Get dataframe with array values expanded""" + + if self._expandedSampleDf is None: + df_expandedSample = self.sampledSourceDf
+ + # expand source dataframe array columns + columns = df_expandedSample.columns
+ + for column in self.columnsInfo: + if column.isArrayColumn: + df_expandedSample = df_expandedSample.withColumn(column.name, F.explode_outer(F.col(column.name)))
+ + df_expandedSample = df_expandedSample.select(*columns)
+ + if self._cacheSample: + df_expandedSample = df_expandedSample.cache()
+ + self._expandedSampleDf = df_expandedSample
+ + return self._expandedSampleDf def _displayRow(self, row): """Display details for row""" @@ -72,10 +201,10 @@ def _addMeasureToSummary(self, measureName, summaryExpr="''", 
fieldExprs=None, d """ Add a measure to the summary dataframe :param measureName: Name of measure - :param summaryExpr: Summary expression - :param fieldExprs: list of field expressions (or generator) + :param summaryExpr: Summary expression - string or sql.Column + :param fieldExprs: list of field expressions (or generator) - either string or sql.Column instances :param dfData: Source data df - data being summarized - :param rowLimit: Number of rows to get for measure + :param rowLimit: Number of rows to get for measure - usually 1 :param dfSummary: Summary df :return: dfSummary with new measure added """ @@ -83,17 +212,178 @@ def _addMeasureToSummary(self, measureName, summaryExpr="''", fieldExprs=None, d assert measureName is not None and len(measureName) > 0, "invalid measure name" # add measure name and measure summary - exprs = [f"'{measureName}' as measure_", f"string({summaryExpr}) as summary_"] + exprs = [F.lit(measureName).astype(StringType()).alias("measure_")] + + if isinstance(summaryExpr, str): + exprs.append(F.expr(summaryExpr).astype(StringType()).alias("summary_")) + else: + assert isinstance(summaryExpr, sql.Column), "summaryExpr must be string or sql.Column" + exprs.append(summaryExpr.astype(StringType()).alias("summary_")) # add measures for fields - exprs.extend(fieldExprs) + for fieldExpr in fieldExprs: + if isinstance(fieldExpr, str): + exprs.append(F.expr(fieldExpr).astype(StringType())) + else: + assert isinstance(fieldExpr, sql.Column), "fieldExpr must be string or sql.Column" + exprs.append(fieldExpr) - if dfSummary is not None: - dfResult = dfSummary.union(dfData.selectExpr(*exprs).limit(rowLimit)) - else: - dfResult = dfData.selectExpr(*exprs).limit(rowLimit) + dfMeasure = dfData.select(*exprs).limit(rowLimit) if rowLimit is not None else dfData.select(*exprs) + + return dfSummary.union(dfMeasure) if dfSummary is not None else dfMeasure + + @staticmethod + def _is_numeric_type(dtype): + """ return true if dtype is numeric, false otherwise""" + if dtype.lower() in ['smallint', 'tinyint', 'double', 'float', 'bigint', 'int']: + return True + elif dtype.lower().startswith("decimal"): + return True + + return False + + @property + def columnsInfo(self): + """ Get extended columns info. 
+ + :return: List of column info tuples (named tuple - ColumnValuesInfo) + + """ + if self._columnsInfo is None: + df_dtypes = self.sampledSourceDf.dtypes + + # compile column information [ (name, datatype, isArrayColumn, isNumeric) ] + columnsInfo = [self.ColInfo(dtype[0], + dtype[1], + 1 if dtype[1].lower().startswith('array') else 0, + 1 if self._is_numeric_type(dtype[1].lower()) else 0) + for dtype in df_dtypes] + + self._columnsInfo = columnsInfo + return self._columnsInfo + + _URL_PREFIX = r"https?:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}" + _IMAGE_EXTS = r"(png)|(jpg)|(tif)" + + _regex_patterns = { + "alpha_upper": r"[A-Z]+", + "alpha_lower": r"[a-z]+", + "digits": r"[0-9]+", + "alphanumeric": r"[a-zA-Z0-9]+", + "identifier": r"[a-zA-Z0-9_]+", + "image_url": _URL_PREFIX + r"\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)" + _IMAGE_EXTS, + "url": _URL_PREFIX + r"\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)", + "email_common": r"([a-zA-Z0-9._%-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,6})*", + "email_uncommon": r"([a-z0-9_\.\+-]+)@([\da-z\.-]+)\.([a-z\.]{2,6})", + "ip_addr": r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}", + "free_text": r"(\s*[a-zA-Z0-9]+\s*[\?\.;,]*)*" + } + + def _compute_pattern_match_clauses(self): + """Generate string pattern matching expressions to compute probability of matching particular patterns + """ + stmts = [] + + # for each column + for colInfo in self.columnsInfo: + clauses = [] + if colInfo.dt == "string": + + # compute named struct of measures matching specific regular expressions + for k, v in self._regex_patterns.items(): + clauses.append(F.round(F.expr(f"""count_if(`{colInfo.name}` regexp '^{v}$')"""), 4) + .astype(StringType()).alias(k)) + + stmt = F.to_json(F.struct(*clauses)).alias(colInfo.name) + stmts.append(stmt) + else: + stmts.append(F.first(F.lit('')).alias(colInfo.name)) + result = stmts + return result + + @staticmethod + def _left4k(name): + """Return left 4k characters of string""" + return f"left(string({name}), 4096)" + + _WORD_REGEX = r"\\b\\w+\\b" + _SPACE_REGEX = r"\\s+" + _DIGIT_REGEX = r"\\d" + _PUNCTUATION_REGEX = r"[\\?\\.\\;\\,\\!\\{\\}\\[\\]\\(\\)\\>\\<]" + _AT_REGEX = r"\\@" + _PERIOD_REGEX = r"\\." 
+ _HTTP_REGEX = r"^http[s]?\\:\\/\\/" + _ALPHA_REGEX = r"[a-zA-Z]" + _ALPHA_UPPER_REGEX = r"[A-Z]" + _ALPHA_LOWER_REGEX = r"[a-z]" + _HEX_REGEX = r"[0-9a-fA-F]" + + _MINMAXAVG = "minmaxavg" + _BOOLEAN = "boolean" + + _textFeatures = { + 'print_len': ("length(string($name$))", _MINMAXAVG), + 'word_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_WORD_REGEX}', 0))", _MINMAXAVG), + 'space_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_SPACE_REGEX}', 0))", _MINMAXAVG), + 'digit_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_DIGIT_REGEX}', 0))", _MINMAXAVG), + 'punctuation_count': ( + f"size(regexp_extract_all(left(string($name$), 4096), '{_PUNCTUATION_REGEX}', 0))", _MINMAXAVG), + 'at_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_AT_REGEX}', 0))", _MINMAXAVG), + 'period_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_PERIOD_REGEX}', 0))", _MINMAXAVG), + 'http_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_HTTP_REGEX}', 0))", _MINMAXAVG), + 'alpha_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_ALPHA_REGEX}', 0))", _MINMAXAVG), + 'alpha_lower_count': ( + f"size(regexp_extract_all(left(string($name$), 4096), '{_ALPHA_LOWER_REGEX}', 0))", _MINMAXAVG), + 'alpha_upper_count': ( + f"size(regexp_extract_all(left(string($name$), 4096), '{_ALPHA_UPPER_REGEX}', 0))", _MINMAXAVG), + 'hex_digit_count': (f"size(regexp_extract_all(left(string($name$), 4096), '{_HEX_REGEX}', 0))", _MINMAXAVG), + } + + def generateTextFeatures(self, sourceDf): + """ Generate text features from source dataframe + + Generates set of text features for each column (analyzing string representation of each column value) + + :param sourceDf: Source datafame + :return: Dataframe of text features + """ + # generate named struct of text features for each column + + # we need to double escape backslashes in regular expressions as they will be lost in string expansion + + # for each column, extract text features from string representation of column value (leftmost 4096 characters) + + fieldTextFeatures = [] + + # add regular text features + for colInfo in self.columnsInfo: + features_clauses = [] + + for k, v in self._textFeatures.items(): + feature_expr, strategy = v + feature_expr = feature_expr.replace("$name$", colInfo.name) # substitute column name - return dfResult + if strategy == self._MINMAXAVG: + features = F.array(F.min(F.expr(feature_expr)), + F.max(F.expr(feature_expr)), + F.avg(F.expr(feature_expr))) + features_clauses.append(features.alias(k)) + elif strategy == self._BOOLEAN: + feature = F.when(F.expr(feature_expr), F.lit(1)).otherwise(F.lit(0)) + features_clauses.append(feature.alias(k)) + + column_text_features = F.to_json(F.struct(*features_clauses)).alias(colInfo.name) + + fieldTextFeatures.append(column_text_features) + + dfTextFeatures = self._addMeasureToSummary( + 'text_features', + fieldExprs=fieldTextFeatures, + dfData=sourceDf, + dfSummary=None, + rowLimit=None) + + return dfTextFeatures def summarizeToDF(self): """ Generate summary analysis of data set as dataframe @@ -105,86 +395,164 @@ def summarizeToDF(self): The output is also used in code generation to generate more accurate code. 
""" - self._df.cache().createOrReplaceTempView("data_analysis_summary") + # if self._cacheSource: + # self._df.cache().createOrReplaceTempView("data_analysis_summary") + + df_under_analysis = self.sampledSourceDf - total_count = self._df.count() * 1.0 + logger = logging.getLogger(__name__) + logger.info("Analyzing counts") + total_count = df_under_analysis.count() * 1.0 - dtypes = self._df.dtypes + logger.info("Analyzing measures") - # schema information + # feature : schema information, [minimal, sample, complete] dfDataSummary = self._addMeasureToSummary( 'schema', - summaryExpr=f"""to_json(named_struct('column_count', {len(dtypes)}))""", - fieldExprs=[f"'{dtype[1]}' as {dtype[0]}" for dtype in dtypes], - dfData=self._df) + summaryExpr=F.to_json(F.expr(f"""named_struct('column_count', {len(self.columnsInfo)})""")), + fieldExprs=[f"'{colInfo.dt}' as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=self.sourceDf) # count + # feature : count, [minimal, sample, complete] dfDataSummary = self._addMeasureToSummary( 'count', - summaryExpr=f"{total_count}", - fieldExprs=[f"string(count({dtype[0]})) as {dtype[0]}" for dtype in dtypes], - dfData=self._df, + summaryExpr="count(*)", + fieldExprs=[f"string(count({colInfo.name})) as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=self.sourceDf, dfSummary=dfDataSummary) + # feature : probability of nulls, [minimal, sample, complete] dfDataSummary = self._addMeasureToSummary( 'null_probability', - fieldExprs=[f"""string( round( ({total_count} - count({dtype[0]})) /{total_count}, 2)) as {dtype[0]}""" - for dtype in dtypes], - dfData=self._df, + fieldExprs=[ + f"""string(round((count(*) - count({colInfo.name})) /count(*), 5)) as {colInfo.name}""" + for colInfo in self.columnsInfo], + dfData=self.sourceDf, dfSummary=dfDataSummary) - # distinct count + # feature : distinct count, [minimal, sample, complete] dfDataSummary = self._addMeasureToSummary( 'distinct_count', summaryExpr="count(distinct *)", - fieldExprs=[f"string(count(distinct {dtype[0]})) as {dtype[0]}" for dtype in dtypes], - dfData=self._df, + fieldExprs=[f"string(count(distinct {colInfo.name})) as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=self.sourceDf, dfSummary=dfDataSummary) - # min + # feature : item distinct count (i.e distinct count of array items), [minimal, sample, complete] dfDataSummary = self._addMeasureToSummary( - 'min', - fieldExprs=[f"string(min({dtype[0]})) as {dtype[0]}" for dtype in dtypes], - dfData=self._df, + 'item_distinct_count', + summaryExpr="count(distinct *)", + fieldExprs=[f"string(count(distinct {colInfo.name})) as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=self.expandedSampleDf, dfSummary=dfDataSummary) + # feature : item count (i.e count of individual array items), [minimal, sample, complete] dfDataSummary = self._addMeasureToSummary( - 'max', - fieldExprs=[f"string(max({dtype[0]})) as {dtype[0]}" for dtype in dtypes], + 'item_count', + summaryExpr="count(*)", + fieldExprs=[f"string(count({colInfo.name})) as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=self.expandedSampleDf, + dfSummary=dfDataSummary) + + # string characteristics for strings and string representation of other values + # feature : print len max, [minimal, sample, complete] + dfDataSummary = self._addMeasureToSummary( + 'print_len', + fieldExprs=[F.to_json(F.struct(F.expr(f"min(length(string({colInfo.name})))").alias("min"), + F.expr(f"max(length(string({colInfo.name})))").alias("max"), + 
F.expr(f"avg(length(string({colInfo.name})))").alias("avg"))) + .alias(colInfo.name) + for colInfo in self.columnsInfo], dfData=self._df, dfSummary=dfDataSummary) - descriptionDf = self._df.describe().where("summary in ('mean', 'stddev')") - describeData = descriptionDf.collect() + # string characteristics for strings and string representation of other values + # feature : item print len max, [minimal, sample, complete] + dfDataSummary = self._addMeasureToSummary( + 'item_printlen', + fieldExprs=[F.to_json(F.struct(F.expr(f"min(length(string({colInfo.name})))").alias("min"), + F.expr(f"max(length(string({colInfo.name})))").alias("max"), + F.expr(f"avg(length(string({colInfo.name})))").alias("avg"))) + .alias(colInfo.name) + for colInfo in self.columnsInfo], + dfData=self.expandedSampleDf, + dfSummary=dfDataSummary) - for row in describeData: - measure = row['summary'] + # feature : item print len max, [minimal, sample, complete] + metrics_clause = self._compute_pattern_match_clauses() - values = {k[0]: '' for k in dtypes} + # string metrics + # we'll compute probabilities that string values match specific patterns - this can be subsequently + # used to tailor code generation + dfDataSummary = self._addMeasureToSummary( + 'string_patterns', + fieldExprs=metrics_clause, + dfData=self.expandedSampleDf, + dfSummary=dfDataSummary) - row_key_pairs = row.asDict() - for k1 in row_key_pairs: - values[k1] = str(row[k1]) + # min + dfDataSummary = self._addMeasureToSummary( + 'min', + fieldExprs=[f"string(min({colInfo.name})) as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=df_under_analysis, + dfSummary=dfDataSummary) - dfDataSummary = self._addMeasureToSummary( - measure, - fieldExprs=[f"'{values[dtype[0]]}'" for dtype in dtypes], - dfData=self._df, - dfSummary=dfDataSummary) + dfDataSummary = self._addMeasureToSummary( + 'max', + fieldExprs=[f"string(max({colInfo.name})) as {colInfo.name}" for colInfo in self.columnsInfo], + dfData=df_under_analysis, + dfSummary=dfDataSummary) - # string characteristics for strings and string representation of other values dfDataSummary = self._addMeasureToSummary( - 'print_len_min', - fieldExprs=[f"min(length(string({dtype[0]}))) as {dtype[0]}" for dtype in dtypes], - dfData=self._df, + 'cardinality', + fieldExprs=[f"""to_json(named_struct( + 'min', min(cardinality({colInfo.name})), + 'max', max(cardinality({colInfo.name})))) + as {colInfo.name}""" + if colInfo.isArrayColumn else "min(1)" + for colInfo in self.columnsInfo], + dfData=df_under_analysis, dfSummary=dfDataSummary) dfDataSummary = self._addMeasureToSummary( - 'print_len_max', - fieldExprs=[f"max(length(string({dtype[0]}))) as {dtype[0]}" for dtype in dtypes], - dfData=self._df, + 'array_value_min', + fieldExprs=[f"min(array_min({colInfo.name})) as {colInfo.name}" + if colInfo.isArrayColumn else f"first('') as {colInfo.name}" + for colInfo in self.columnsInfo], + dfData=df_under_analysis, + dfSummary=dfDataSummary) + + dfDataSummary = self._addMeasureToSummary( + 'array_value_max', + fieldExprs=[f"max(array_max({colInfo.name})) as {colInfo.name}" + if colInfo.isArrayColumn else f"first('') as {colInfo.name}" + for colInfo in self.columnsInfo], + dfData=df_under_analysis, + dfSummary=dfDataSummary) + + rounding = self._MEASURE_ROUNDING + + dfDataSummary = self._addMeasureToSummary( + 'stats', + fieldExprs=[F.to_json(F.struct( + F.expr(f"round(skewness({colInfo.name}),{rounding})").alias('skewness'), + F.expr(f"round(kurtosis({colInfo.name}),{rounding})").alias('kurtosis'), + 
F.expr(f"round(mean({colInfo.name}),{rounding})").alias('mean'), + F.expr(f"round(stddev_pop({colInfo.name}),{rounding})").alias('stddev'))) + .alias(colInfo.name) + if colInfo.isNumeric + else F.expr("null").alias(colInfo.name) + for colInfo in self.columnsInfo], + dfData=df_under_analysis, dfSummary=dfDataSummary) + if self._analysisLevel in ["analyze_text", "full"]: + logger.info("Analyzing summary text features") + dfTextFeaturesSummary = self.generateTextFeatures(self.expandedSampleDf) + + dfDataSummary = dfDataSummary.union(dfTextFeaturesSummary) + return dfDataSummary def summarize(self, suppressOutput=False): @@ -211,7 +579,7 @@ def summarize(self, suppressOutput=False): return summary @classmethod - def _valueFromSummary(cls, dataSummary, colName, measure, defaultValue): + def _valueFromSummary(cls, dataSummary, colName, measure, defaultValue, jsonPath=None): """ Get value from data summary :param dataSummary: Data summary to search, optional @@ -219,6 +587,8 @@ def _valueFromSummary(cls, dataSummary, colName, measure, defaultValue): :param measure: Measure name of measure to get value for :param defaultValue: Default value if any other argument is not specified or value could not be found in data summary + :param jsonPath: if jsonPath is supplied, treat initial result as JSON data and perform lookup according to + the supplied json path. :return: Value from lookup or `defaultValue` if not found """ if dataSummary is not None and colName is not None and measure is not None: @@ -226,13 +596,20 @@ def _valueFromSummary(cls, dataSummary, colName, measure, defaultValue): measureValues = dataSummary[measure] if colName in measureValues: - return measureValues[colName] + result = measureValues[colName] + + if jsonPath is not None: + result = json_value_from_path(jsonPath, result, defaultValue) + + if result is not None: + return result # return default value if value could not be looked up or found return defaultValue @classmethod - def _generatorDefaultAttributesFromType(cls, sqlType, colName=None, dataSummary=None, sourceDf=None): + def _generatorDefaultAttributesFromType(cls, sqlType, colName=None, isArrayElement=False, dataSummary=None, + sourceDf=None, valuesInfo=None): """ Generate default set of attributes for each data type :param sqlType: Instance of `pyspark.sql.types.DataType` @@ -248,31 +625,38 @@ def _generatorDefaultAttributesFromType(cls, sqlType, colName=None, dataSummary= """ assert isinstance(sqlType, DataType) - if sqlType == StringType(): + min_attribute = "min" if not isArrayElement else "array_value_min" + max_attribute = "max" if not isArrayElement else "array_value_max" + + if valuesInfo is not None and \ + colName in valuesInfo and not isinstance(sqlType, (BinaryType, StructType, MapType)): + result = valuesInfo[colName].value_refs + assert result is not None and len(result) > 0 + elif sqlType == StringType(): result = """template=r'\\\\w'""" elif sqlType in [IntegerType(), LongType()]: - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000) + minValue = cls._valueFromSummary(dataSummary, colName, min_attribute, defaultValue=0) + maxValue = cls._valueFromSummary(dataSummary, colName, max_attribute, defaultValue=1000000) result = f"""minValue={minValue}, maxValue={maxValue}""" elif sqlType == ByteType(): - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", 
defaultValue=127) + minValue = cls._valueFromSummary(dataSummary, colName, min_attribute, defaultValue=0) + maxValue = cls._valueFromSummary(dataSummary, colName, max_attribute, defaultValue=127) result = f"""minValue={minValue}, maxValue={maxValue}""" elif sqlType == ShortType(): - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=32767) + minValue = cls._valueFromSummary(dataSummary, colName, min_attribute, defaultValue=0) + maxValue = cls._valueFromSummary(dataSummary, colName, max_attribute, defaultValue=32767) result = f"""minValue={minValue}, maxValue={maxValue}""" elif sqlType == BooleanType(): result = """expr='id % 2 = 1'""" elif sqlType == DateType(): result = """expr='current_date()'""" elif isinstance(sqlType, DecimalType): - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000) + minValue = cls._valueFromSummary(dataSummary, colName, min_attribute, defaultValue=0) + maxValue = cls._valueFromSummary(dataSummary, colName, max_attribute, defaultValue=1000) result = f"""minValue={minValue}, maxValue={maxValue}""" elif sqlType in [FloatType(), DoubleType()]: - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0.0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000.0) + minValue = cls._valueFromSummary(dataSummary, colName, min_attribute, defaultValue=0.0) + maxValue = cls._valueFromSummary(dataSummary, colName, max_attribute, defaultValue=1000000.0) result = f"""minValue={minValue}, maxValue={maxValue}, step=0.1""" elif sqlType == TimestampType(): result = """begin="2020-01-01 01:00:00", end="2020-12-31 23:59:00", interval="1 minute" """ @@ -288,8 +672,99 @@ def _generatorDefaultAttributesFromType(cls, sqlType, colName=None, dataSummary= return result + def _cleanse_name(self, col_name): + """cleanse column name for use in code""" + return col_name.replace(' ', '_') + + def _format_values_list(self, values): + """ Format values """ + pp = pprint.PrettyPrinter(indent=self._CODE_GENERATION_INDENT, + width=self._MAX_VALUES_LINE_LENGTH, + compact=True) + values = pp.pformat(values) + + return values + + def _processCategoricalValuesInfo(self, dataSummary=None, sourceDf=None): + """ Computes values clauses for appropriate columns + + :param dataSummary: Data summary + :param sourceDf: Source data dataframe + :return: Map from column name to ColumnValueInfo tuples + where ColumnValuesInfo = namedtuple("ColumnValuesInfo", ["name", "statements", "weights", "values"]) + """ + assert dataSummary is not None + assert sourceDf is not None + + results = {} + + logger = logging.getLogger(__name__) + logger.info("Performing categorical data analysis") + + for fld in sourceDf.schema.fields: + col_name = fld.name + col_base_type = fld.dataType.elementType if isinstance(fld.dataType, ArrayType) else fld.dataType + col_type = col_base_type.simpleString() + + stmts = [] + value_refs = [] + + # we'll compute values set for elements whose max printable length < MAX_COLUMN_ELEMENT_LENGTH_THRESHOLD + # whose count of distinct elements is < MAX_DISTINCT_THRESHOLD + # and where the type is numeric or string either by itself or in array variant + numDistinct = int(self._valueFromSummary(dataSummary, col_name, "item_distinct_count", + defaultValue=self._INT_32_MAX)) + maxPrintable = int(self._valueFromSummary(dataSummary, col_name, 
"item_max_printlen", + defaultValue=self._INT_32_MAX)) + + if self._valuesCountThreshold > numDistinct > 1 and \ + maxPrintable < self._MAX_COLUMN_ELEMENT_LENGTH_THRESHOLD and \ + col_type in ["float", "double", "int", "smallint", "bigint", "tinyint", "string"]: + logger.info(f"Retrieving categorical values for column `{col_name}`") + + value_rows = sorted(sourceDf.select(col_name).where(f"{col_name} is not null") + .groupBy(col_name).count().collect(), + key=lambda r1, sk=col_name: r1[sk]) + values = [r[col_name] for r in value_rows] + weights = [r['count'] for r in value_rows] + + # simplify the weights + countNonNull = int(self._valueFromSummary(dataSummary, col_name, "item_count", + defaultValue=sum(weights))) + + weights = ((np.array(weights) / countNonNull) * 100.0).round().astype(np.uint64) + weights = np.maximum(weights, 1) # minumum weight must be 1 + + # divide by GCD to get simplified weights + gcd = np.gcd.reduce(weights) + + weights = (weights / gcd).astype(np.uint64) + + # if all of the weights are within 10% of mean, ignore the weights + avg_weight = np.mean(weights) + weight_threshold = avg_weight * 0.1 + weight_test = np.abs(weights - avg_weight) + if np.all(weight_test < weight_threshold): + weights = None + else: + weights = list(weights) + + safe_name = self._cleanse_name(col_name) + + if weights is not None: + stmts.append(f"{safe_name}_weights = {self._format_values_list(weights)}") + value_refs.append(f"""weights = {safe_name}_weights""") + + stmts.append(f"{safe_name}_values = {self._format_values_list(values)}") + value_refs.append(f"""values={safe_name}_values""") + + results[col_name] = self.ColumnValuesInfo(col_name, stmts, ", ".join(value_refs)) + + return results + @classmethod - def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppressOutput=False, name=None): + def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppressOutput=False, name=None, + valuesInfo=None): """ Generate outline data generator code from an existing dataframe @@ -307,6 +782,7 @@ def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppr :param sourceDf: Source dataframe to retrieve attributes of real data, optional :param suppressOutput: Suppress printing of generated code if True :param name: Optional name for data generator + :param valuesInfo: References and statements for `values` clauses :return: String containing skeleton code """ @@ -320,10 +796,17 @@ def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppr stmts.append(cls._GENERATED_COMMENT) stmts.append("import dbldatagen as dg") - stmts.append("import pyspark.sql.types") stmts.append(cls._GENERATED_FROM_SCHEMA_COMMENT) + if valuesInfo is not None: + for k, v in valuesInfo.items(): + stmts.append("") + stmts.append(f"# values for column `{k}`") + for line in v.statements: + stmts.append(line) + + stmts.append("") stmts.append(strip_margins( f"""generation_spec = ( | dg.DataGenerator(sparkSession=spark, @@ -340,15 +823,34 @@ def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppr if isinstance(fld.dataType, ArrayType): col_type = fld.dataType.elementType.simpleString() - field_attributes = cls._generatorDefaultAttributesFromType(fld.dataType.elementType) # no data look up - array_attributes = """structType='array', numFeatures=(2,6)""" + field_attributes = cls._generatorDefaultAttributesFromType(fld.dataType.elementType, + colName=col_name, + isArrayElement=True, + dataSummary=dataSummary, + sourceDf=sourceDf, 
+ valuesInfo=valuesInfo)
+ + if dataSummary is not None: + minLength = cls._valueFromSummary(dataSummary, col_name, "cardinality", jsonPath="min", + defaultValue=2) + maxLength = cls._valueFromSummary(dataSummary, col_name, "cardinality", jsonPath="max", + defaultValue=6)
+ + if minLength != maxLength: + array_attributes = f"""structType='array', numFeatures=({minLength}, {maxLength})""" + else: + array_attributes = f"""structType='array', numFeatures={minLength}"""
+ + else: + array_attributes = """structType='array', numFeatures=(2,6)""" name_and_type = f"""'{col_name}', '{col_type}'""" stmts.append(indent + f""".withColumn({name_and_type}, {field_attributes}, {array_attributes})""") else: field_attributes = cls._generatorDefaultAttributesFromType(fld.dataType, colName=col_name, dataSummary=dataSummary,
- sourceDf=sourceDf)
+ sourceDf=sourceDf, + valuesInfo=valuesInfo) stmts.append(indent + f""".withColumn('{col_name}', '{col_type}', {field_attributes})""") stmts.append(indent + ")") @@ -359,7 +861,7 @@ def _scriptDataGeneratorCode(cls, schema, dataSummary=None, sourceDf=None, suppr return "\n".join(stmts) @classmethod
- def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None):
+ def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None, asHtml=False): """ Generate outline data generator code from an existing dataframe @@ -374,15 +876,19 @@ def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None): :param schema: Pyspark schema - i.e. manually constructed StructType or return value from `dataframe.schema` :param suppressOutput: Suppress printing of generated code if True
+ :param asHtml: If True, will generate HTML suitable for notebook ``displayHtml``. If True, also suppresses printed output :param name: Optional name for data generator
- :return: String containing skeleton code
+ :return: String containing skeleton code (in Html form if `asHtml` is True) """
- return cls._scriptDataGeneratorCode(schema, - suppressOutput=suppressOutput, - name=name)
+ generated_code = cls._scriptDataGeneratorCode(schema, suppressOutput=asHtml or suppressOutput, name=name)
+ + if asHtml: + generated_code = HtmlUtils.formatCodeAsHtml(generated_code)
- def scriptDataGeneratorFromData(self, suppressOutput=False, name=None):
+ return generated_code
+ + def scriptDataGeneratorFromData(self, suppressOutput=False, name=None, asHtml=False): """ Generate outline data generator code from an existing dataframe @@ -397,22 +903,41 @@ def scriptDataGeneratorFromData(self, suppressOutput=False, name=None): :param suppressOutput: Suppress printing of generated code if True :param name: Optional name for data generator
- :return: String containing skeleton code
+ :param asHtml: If True, will generate Html suitable for notebook ``displayHtml``. 
If true, suppresses output + :return: String containing skeleton code (in Html form if `asHtml` is True) """ - assert self._df is not None - assert type(self._df) is ssql.DataFrame, "sourceDf must be a valid Pyspark dataframe" + assert self.sampledSourceDf is not None + assert type(self.sampledSourceDf) is sql.DataFrame, "sourceDf must be a valid Pyspark dataframe" if self._dataSummary is None: + logger = logging.getLogger(__name__) + logger.info("Performing data analysis in preparation for code generation") + df_summary = self.summarizeToDF() self._dataSummary = {} - for row in df_summary.collect(): + logger.info("Performing summary analysis ...") + + analysis_measures = df_summary.collect() + + logger.info("Processing summary analysis results") + + for row in analysis_measures: row_key_pairs = row.asDict() self._dataSummary[row['measure_']] = row_key_pairs - return self._scriptDataGeneratorCode(self._df.schema, - suppressOutput=suppressOutput, - name=name, - dataSummary=self._dataSummary, - sourceDf=self._df) + values_info = self._processCategoricalValuesInfo(dataSummary=self._dataSummary, + sourceDf=self.expandedSampleDf) + + generated_code = self._scriptDataGeneratorCode(self.sampledSourceDf.schema, + suppressOutput=asHtml or suppressOutput, + name=name, + dataSummary=self._dataSummary, + sourceDf=self.sampledSourceDf, + valuesInfo=values_info) + + if asHtml: + generated_code = HtmlUtils.formatCodeAsHtml(generated_code) + + return generated_code diff --git a/docs/source/generating_from_existing_data.rst b/docs/source/generating_from_existing_data.rst index 6e56b537..ae994e94 100644 --- a/docs/source/generating_from_existing_data.rst +++ b/docs/source/generating_from_existing_data.rst @@ -112,13 +112,49 @@ For example, the following code will generate synthetic data generation code fro .. code-block:: python - import dbldatagen as dg + import dbldatagen as dg - dfSource = spark.read.format("parquet").load("/tmp/your/source/dataset") + # In a Databricks runtime environment + # The folder `dbfs:/databricks-datasets` contains a variety of sample data sets + dfSource = spark.read.format("parquet").load("dbfs:/databricks-datasets/amazon/test4K/") + + da = dg.DataAnalyzer(dfSource) + + df2 = da.summarizeToDF() + generatedCode = da.scriptDataGeneratorFromData(suppressOutput=True) + + print(generatedCode) - analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data) - generatedCode = analyzer.scriptDataGeneratorFromData() +It is not intended to generate complete code to reproduce the dataset but serves as a starting point +for generating a synthetic data set that mirrors the original source. +This will produce the following generated code. +.. 
code-block:: python + import dbldatagen as dg + + # Column definitions are stubs only - modify to generate correct data + # + + # values for column `rating` + rating_weights = [10, 5, 8, 18, 59] + rating_values = [1.0, 2.0, 3.0, 4.0, 5.0] + generation_spec = ( + dg.DataGenerator(sparkSession=spark, + name='synthetic_data', + rows=100000, + random=True, + ) + .withColumn('asin', 'string', template=r'\\w') + .withColumn('brand', 'string', template=r'\\w') + .withColumn('helpful', 'bigint', minValue=0, maxValue=417, structType='array', numFeatures=2) + .withColumn('img', 'string', template=r'\\w') + .withColumn('price', 'double', minValue=0.01, maxValue=962.0, step=0.1) + .withColumn('rating', 'double', weights = rating_weights, values=rating_values) + .withColumn('review', 'string', template=r'\\w') + .withColumn('time', 'bigint', minValue=921369600, maxValue=1406073600) + .withColumn('title', 'string', template=r'\\w') + .withColumn('user', 'string', template=r'\\w') + ) diff --git a/makefile b/makefile index e76e0952..c8812db4 100644 --- a/makefile +++ b/makefile @@ -35,7 +35,7 @@ create-github-build-env: install-dev-dependencies: @echo "$(OK_COLOR)=> installing dev environment requirements$(NO_COLOR)" - pip install -r python/dev_require.txt + pip install -v -r python/dev_require.txt clean-dev-env: @echo "$(OK_COLOR)=> Cleaning dev environment$(NO_COLOR)" diff --git a/tests/test_generation_from_data.py b/tests/test_generation_from_data.py index fab15809..c93d26ae 100644 --- a/tests/test_generation_from_data.py +++ b/tests/test_generation_from_data.py @@ -1,13 +1,20 @@ -import logging import ast -import pytest +import logging +import re +from html.parser import HTMLParser -import pyspark.sql as ssql +import pyspark.sql.functions as F +import pytest import dbldatagen as dg -spark = dg.SparkSingleton.getLocalInstance("unit tests") +@pytest.fixture(scope="class") +def spark(): + sparkSession = dg.SparkSingleton.getLocalInstance("unit tests") + sparkSession.conf.set("dbldatagen.data_analysis.checkpoint", "true") + sparkSession.sparkContext.setCheckpointDir( "/tmp/dbldatagen/checkpoint") + return sparkSession @pytest.fixture(scope="class") @@ -18,14 +25,62 @@ def setupLogging(): class TestGenerationFromData: SMALL_ROW_COUNT = 10000 - - @pytest.fixture + MEDIUM_ROW_COUNT = 50000 + + class SimpleValidator(HTMLParser): # pylint: disable=abstract-method + def __init__(self): + super().__init__() + self._errors = [] + self._tags = {} + + def handle_starttag(self, tag, attrs): + if tag in self._tags: + self._tags[tag] += 1 + else: + self._tags[tag] = 1 + + def handle_endtag(self, tag): + if tag in self._tags: + self._tags[tag] -= 1 + else: + self._errors.append(f"end tag {tag} found without start tag") + self._tags[tag] = -1 + + def checkHtml(self, htmlText): + """ + Check if htmlText produces errors + + :param htmlText: html text to parse + :return: Returns the the list of errors + """ + for tag, count in self._tags.items(): + if count > 0: + self._errors.append(f"tag {tag} has {count} additional start tags") + elif count < 0: + self._errors.append(f"tag {tag} has {-count} additional end tags") + return self._errors + + @pytest.fixture(scope="class") def testLogger(self): logger = logging.getLogger(__name__) return logger - @pytest.fixture - def generation_spec(self): + def mk_generation_spec(self, spark, row_count): + + country_codes = [ + "CN", "US", "FR", "CA", "IN", "JM", "IE", "PK", "GB", "IL", "AU", + "SG", "ES", "GE", "MX", "ET", "SA", "LB", "NL", + ] + country_weights = [ + 1300, 365, 67, 38, 
1300, 3, 7, 212, 67, 9, 25, 6, 47, 83, + 126, 109, 58, 8, 17, + ] + + eurozone_countries = ["Austria", "Belgium", "Cyprus", "Estonia", "Finland", "France", "Germany", "Greece", + "Ireland", "Italy", "Latvia", "Lithuania", "Luxembourg", "Malta", "Netherlands", + "Portugal", "Slovakia", "Slovenia", "Spain" + ] + spec = ( dg.DataGenerator(sparkSession=spark, name='test_generator', rows=self.SMALL_ROW_COUNT, seedMethod='hash_fieldname') @@ -36,9 +91,9 @@ def generation_spec(self): baseColumn="asin") .withColumn('price', 'double', min=1.0, max=999.0, random=True, step=0.01) .withColumn('rating', 'double', values=[1.0, 2, 0, 3.0, 4.0, 5.0], random=True) - .withColumn('review', 'string', text=dg.ILText((1, 3), (1, 4), (3, 8)), random=True) + .withColumn('review', 'string', text=dg.ILText((1, 3), (1, 4), (3, 8)), random=True, percentNulls=0.1) .withColumn('time', 'bigint', expr="now()", percentNulls=0.1) - .withColumn('title', 'string', template=r"\w|\w \w \w|\w \w \w||\w \w \w \w", random=True) + .withColumn('title', 'string', template=r"\w|\w \w \w|\w \w \w|\w \w \w \w", random=True) .withColumn('user', 'string', expr="hex(abs(hash(id)))") .withColumn("event_ts", "timestamp", begin="2020-01-01 01:00:00", end="2020-12-31 23:59:00", @@ -47,63 +102,143 @@ def generation_spec(self): numColumns=(2, 4), structType="array") .withColumn("tf_flag", "boolean", expr="id % 2 = 1") .withColumn("short_value", "short", max=32767, percentNulls=0.1) + .withColumn("string_values", "string", values=["one", "two", "three"]) + .withColumn("country_codes", "string", values=country_codes, weights=country_weights) + .withColumn("euro_countries", "string", values=eurozone_countries) + .withColumn("int_value", "int", min=100, max=200, percentNulls=0.1) .withColumn("byte_value", "tinyint", max=127) .withColumn("decimal_value", "decimal(10,2)", max=1000000) - .withColumn("decimal_value", "decimal(10,2)", max=1000000) .withColumn("date_value", "date", expr="current_date()", random=True) .withColumn("binary_value", "binary", expr="cast('spark' as binary)", random=True) ) return spec - def test_code_generation1(self, generation_spec, setupLogging): + @pytest.fixture(scope="class") + def small_source_data_df(self, spark): + generation_spec = self.mk_generation_spec(spark, self.SMALL_ROW_COUNT) df_source_data = generation_spec.build() - df_source_data.show() + return df_source_data.checkpoint(eager=True) - analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data) + @pytest.fixture(scope="class") + def medium_source_data_df(self, spark): + generation_spec = self.mk_generation_spec(spark, self.MEDIUM_ROW_COUNT) + df_source_data = generation_spec.build() + return df_source_data.checkpoint(eager=True) + + def test_code_generation1(self, small_source_data_df, setupLogging, spark): + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df) generatedCode = analyzer.scriptDataGeneratorFromData() - for fld in df_source_data.schema: + for fld in small_source_data_df.schema: assert f"withColumn('{fld.name}'" in generatedCode # check generated code for syntax errors ast_tree = ast.parse(generatedCode) assert ast_tree is not None - def test_code_generation_from_schema(self, generation_spec, setupLogging): - df_source_data = generation_spec.build() - generatedCode = dg.DataAnalyzer.scriptDataGeneratorFromSchema(df_source_data.schema) + def test_code_generation_as_html(self, small_source_data_df, setupLogging, spark): + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df) + + generatedCode = 
analyzer.scriptDataGeneratorFromData(asHtml=True) - for fld in df_source_data.schema: + # note the generated code does not have html tags + validator = self.SimpleValidator() + parsing_errors = validator.checkHtml(generatedCode) + + assert len(parsing_errors) == 0, "Number of errors should be zero" + + print(generatedCode) + + def test_code_generation_from_schema(self, small_source_data_df, setupLogging): + generatedCode = dg.DataAnalyzer.scriptDataGeneratorFromSchema(small_source_data_df.schema) + + for fld in small_source_data_df.schema: assert f"withColumn('{fld.name}'" in generatedCode # check generated code for syntax errors ast_tree = ast.parse(generatedCode) assert ast_tree is not None - def test_summarize(self, testLogger, generation_spec): - testLogger.info("Building test data") + def test_code_generation_as_html_from_schema(self, small_source_data_df, setupLogging): + generatedCode = dg.DataAnalyzer.scriptDataGeneratorFromSchema(small_source_data_df.schema, asHtml=True) - df_source_data = generation_spec.build() + # note the generated code does not have html tags + validator = self.SimpleValidator() + parsing_errors = validator.checkHtml(generatedCode) + + assert len(parsing_errors) == 0, "Number of errors should be zero" + + print(generatedCode) + + def test_summarize(self, testLogger, small_source_data_df, spark): testLogger.info("Creating data analyzer") - analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data) + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df, maxRows=1000) testLogger.info("Summarizing data analyzer results") analyzer.summarize() - def test_summarize_to_df(self, generation_spec, testLogger): - testLogger.info("Building test data") - - df_source_data = generation_spec.build() - + def test_summarize_to_df(self, small_source_data_df, testLogger, spark): testLogger.info("Creating data analyzer") - analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data) + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df, maxRows=1000) testLogger.info("Summarizing data analyzer results") df = analyzer.summarizeToDF() df.show() + + def test_generate_text_features(self, small_source_data_df, testLogger, spark): + testLogger.info("Creating data analyzer") + + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df, maxRows=1000) + + df_text_features = analyzer.generateTextFeatures(small_source_data_df).limit(10) + df_text_features.show() + + # data = df_text_features.selectExpr("get_json_object(asin, '$.print_len') as asin").limit(10).collect() + data = (df_text_features.select(F.get_json_object(F.col("asin"), "$.print_len").alias("asin")) + .limit(10).collect()) + assert data[0]['asin'] is not None + + @pytest.mark.parametrize("sampleString, expectedMatch", + [("0234", "digits"), + ("http://www.yahoo.com", "url"), + ("http://www.yahoo.com/test.png", "image_url"), + ("info+new_account@databrickslabs.com", "email_uncommon"), + ("abcdefg", "alpha_lower"), + ("ABCDEFG", "alpha_upper"), + ("A09", "alphanumeric"), + ("this is a test ", "free_text"), + ("test_function", "identifier"), + ("10.0.0.1", "ip_addr") + ]) + def test_match_patterns(self, sampleString, expectedMatch, small_source_data_df, spark): + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df) + + pattern_match_result = "" + for k, v in analyzer._regex_patterns.items(): + pattern = f"^{v}$" + + if re.match(pattern, sampleString) is not None: + pattern_match_result = k + break + + assert pattern_match_result == expectedMatch, 
f"expected match to be {expectedMatch}" + + def test_source_data_property(self, small_source_data_df, spark): + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df, maxRows=500) + + count_rows = analyzer.sampledSourceDf.count() + assert abs(count_rows - 500) < 50, "expected count to be close to 500" + + def test_sample_data(self, small_source_data_df, spark): + # create a DataAnalyzer object + analyzer = dg.DataAnalyzer(sparkSession=spark, df=small_source_data_df, maxRows=500) + + # sample the data + df_sample = analyzer.sampleData(small_source_data_df, 100) + assert df_sample.count() <= 100, "expected count to be 100"