diff --git a/dbldatagen/data_generator.py b/dbldatagen/data_generator.py
index 9f99e903..36ce8be1 100644
--- a/dbldatagen/data_generator.py
+++ b/dbldatagen/data_generator.py
@@ -22,8 +22,18 @@ from .spark_singleton import SparkSingleton
 from .utils import ensure, topologicalSort, DataGenError, deprecated, split_list_matching_condition
 
+START_TIMESTAMP_OPTION = "startTimestamp"
+ROWS_PER_SECOND_OPTION = "rowsPerSecond"
+AGE_LIMIT_OPTION = "ageLimit"
+NUM_PARTITIONS_OPTION = "numPartitions"
+ROWS_PER_BATCH_OPTION = "rowsPerBatch"
+STREAMING_SOURCE_OPTION = "streamingSource"
+
 _OLD_MIN_OPTION = 'min'
 _OLD_MAX_OPTION = 'max'
+RATE_SOURCE = "rate"
+RATE_PER_MICRO_BATCH_SOURCE = "rate-micro-batch"
+SPARK_RATE_MICROBATCH_VERSION = "3.2.1"
 
 _STREAMING_TIMESTAMP_COLUMN = "_source_timestamp"
 
@@ -1058,32 +1068,124 @@ def _getBaseDataFrame(self, startId=0, streaming=False, options=None):
                 df1 = df1.withColumnRenamed(SPARK_RANGE_COLUMN, self._seedColumnName)
 
         else:
-            status = (
-                f"Generating streaming data frame with ids from {startId} to {end_id} with {id_partitions} partitions")
-            self.logger.info(status)
-            self.executionHistory.append(status)
+            df1 = self._getStreamingBaseDataFrame(startId, options)
+
+        return df1
 
-            df1 = (self.sparkSession.readStream
-                   .format("rate"))
-            if options is not None:
-                if "rowsPerSecond" not in options:
-                    options['rowsPerSecond'] = 1
-                if "numPartitions" not in options:
-                    options['numPartitions'] = id_partitions
+    def _getStreamingSource(self, options=None, spark_version=None):
+        """ get streaming source from options
 
-                for k, v in options.items():
-                    df1 = df1.option(k, v)
-                df1 = (df1.load()
-                       .withColumnRenamed("value", self._seedColumnName)
-                       )
+        :param options: dictionary of options
+        :returns: streaming source if present in options, or the default streaming source if not present
+
+        The default streaming source is determined from the Spark version in use:
+        if running on Spark 3.2.1 or later, `rate-micro-batch` is used as the source, otherwise `rate` is used
+        """
+        streaming_source = None
+        if options is not None:
+            if STREAMING_SOURCE_OPTION in options:
+                streaming_source = options[STREAMING_SOURCE_OPTION]
+                assert streaming_source in [RATE_SOURCE, RATE_PER_MICRO_BATCH_SOURCE], \
+                    f"Invalid streaming source - only ['{RATE_SOURCE}', '{RATE_PER_MICRO_BATCH_SOURCE}'] supported"
+
+        if spark_version is None:
+            spark_version = self.sparkSession.version
+
+        if streaming_source is None:
+            # if using Spark 3.2.1 or later, the default should be RATE_PER_MICRO_BATCH_SOURCE
+            if spark_version >= SPARK_RATE_MICROBATCH_VERSION:
+                streaming_source = RATE_PER_MICRO_BATCH_SOURCE
             else:
-                df1 = (df1.option("rowsPerSecond", 1)
-                       .option("numPartitions", id_partitions)
-                       .load()
-                       .withColumnRenamed("value", self._seedColumnName)
-                       )
+                streaming_source = RATE_SOURCE
+
+        return streaming_source
+
+    def _getCurrentSparkTimestamp(self, asLong=False):
+        """ get current spark timestamp
+
+        :param asLong: if True, returns current spark timestamp as long, string otherwise
+        """
+        if asLong:
+            return (self.sparkSession.sql("select cast(now() as long) as start_timestamp")
+                    .collect()[0]['start_timestamp'])
+        else:
+            return (self.sparkSession.sql("select cast(now() as string) as start_timestamp")
+                    .collect()[0]['start_timestamp'])
+
+    def _prepareStreamingOptions(self, options=None, spark_version=None):
+        default_streaming_partitions = (self.partitions if self.partitions is not None
+                                        else self.sparkSession.sparkContext.defaultParallelism)
+
+        streaming_source = self._getStreamingSource(options, spark_version)
+
+        if options is None:
+            new_options = ({ROWS_PER_SECOND_OPTION: default_streaming_partitions} if streaming_source == RATE_SOURCE
+                           else {ROWS_PER_BATCH_OPTION: default_streaming_partitions})
+        else:
+            new_options = options.copy()
+
+        if NUM_PARTITIONS_OPTION in new_options:
+            streaming_partitions = new_options[NUM_PARTITIONS_OPTION]
+        else:
+            streaming_partitions = default_streaming_partitions
+            new_options[NUM_PARTITIONS_OPTION] = streaming_partitions
+
+        if streaming_source == RATE_PER_MICRO_BATCH_SOURCE:
+            if START_TIMESTAMP_OPTION not in new_options:
+                new_options[START_TIMESTAMP_OPTION] = self._getCurrentSparkTimestamp(asLong=True)
+
+            if ROWS_PER_BATCH_OPTION not in new_options:
+                # generate one row per partition
+                new_options[ROWS_PER_BATCH_OPTION] = streaming_partitions
+
+        elif streaming_source == RATE_SOURCE:
+            if ROWS_PER_SECOND_OPTION not in new_options:
+                new_options[ROWS_PER_SECOND_OPTION] = streaming_partitions
+        else:
+            assert streaming_source in [RATE_SOURCE, RATE_PER_MICRO_BATCH_SOURCE], \
+                f"Invalid streaming source - only ['{RATE_SOURCE}', '{RATE_PER_MICRO_BATCH_SOURCE}'] supported"
+
+        return streaming_source, new_options
+
+    def _getStreamingBaseDataFrame(self, startId=0, options=None):
+        """Generate base streaming data frame"""
+        end_id = self._rowCount + startId
+
+        # determine streaming source
+        streaming_source, options = self._prepareStreamingOptions(options)
+        partitions = options[NUM_PARTITIONS_OPTION]
+
+        if streaming_source == RATE_SOURCE:
+            status = f"Generating streaming data with rate source with {partitions} partitions"
+        else:
+            status = f"Generating streaming data with rate-micro-batch source with {partitions} partitions"
+
+        self.logger.info(status)
+        self.executionHistory.append(status)
+
+        age_limit_interval = None
+
+        if STREAMING_SOURCE_OPTION in options:
+            options.pop(STREAMING_SOURCE_OPTION)
+
+        if AGE_LIMIT_OPTION in options:
+            age_limit_interval = options.pop(AGE_LIMIT_OPTION)
+            assert age_limit_interval is not None and float(age_limit_interval) > 0.0, "invalid age limit"
+
+        assert AGE_LIMIT_OPTION not in options
+        assert STREAMING_SOURCE_OPTION not in options
+
+        df1 = self.sparkSession.readStream.format(streaming_source)
+
+        for k, v in options.items():
+            df1 = df1.option(k, v)
+
+        df1 = df1.load().withColumnRenamed("value", self._seedColumnName)
+        if age_limit_interval is not None:
+            df1 = df1.where(f"""abs(cast(now() as double) - cast(`timestamp` as double ))
+                                < cast({age_limit_interval} as double)""")
 
         return df1
 
     def _computeColumnBuildOrder(self):
diff --git a/makefile b/makefile
index e76e0952..e9cec3dc 100644
--- a/makefile
+++ b/makefile
@@ -29,6 +29,10 @@ create-dev-env:
 	@echo "$(OK_COLOR)=> making conda dev environment$(NO_COLOR)"
 	conda create -n $(ENV_NAME) python=3.8.10
 
+create-dev-env-321:
+	@echo "$(OK_COLOR)=> making conda dev environment for Spark 3.2.1$(NO_COLOR)"
+	conda create -n $(ENV_NAME) python=3.8.10
+
 create-github-build-env:
 	@echo "$(OK_COLOR)=> making conda dev environment$(NO_COLOR)"
 	conda create -n pip_$(ENV_NAME) python=3.8
@@ -37,6 +41,10 @@ install-dev-dependencies:
 	@echo "$(OK_COLOR)=> installing dev environment requirements$(NO_COLOR)"
 	pip install -r python/dev_require.txt
 
+install-dev-dependencies321:
+	@echo "$(OK_COLOR)=> installing dev environment requirements for Spark 3.2.1$(NO_COLOR)"
+	pip install -r python/dev_require_321.txt
+
 clean-dev-env:
 	@echo "$(OK_COLOR)=> Cleaning dev environment$(NO_COLOR)"
 	@echo "Current version: $(CURRENT_VERSION)"
diff --git a/python/dev_require_321.txt b/python/dev_require_321.txt
new file mode 100644
index 00000000..56846a62
--- /dev/null
+++ b/python/dev_require_321.txt
@@ -0,0 +1,33 @@
+# The following packages are used in building the test data generator framework.
+# All packages used are already installed in the Databricks runtime environment for version 6.5 or later
+numpy==1.20.1
+pandas==1.2.4
+pickleshare==0.7.5
+py4j==0.10.9.3
+pyarrow==4.0.0
+pyspark==3.2.1
+python-dateutil==2.8.1
+six==1.15.0
+
+# The following packages are required for development only
+wheel==0.36.2
+setuptools==52.0.0
+bumpversion
+pytest
+pytest-cov
+pytest-timeout
+rstcheck
+prospector
+
+# The following packages are only required for building documentation and are not required at runtime
+sphinx==5.0.0
+sphinx_rtd_theme
+nbsphinx
+numpydoc==0.8
+pypandoc
+ipython==7.16.3
+recommonmark
+sphinx-markdown-builder
+rst2pdf==0.98
+Jinja2 < 3.1
+
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index cf10273a..cb9f5218 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -10,18 +10,382 @@ spark = dg.SparkSingleton.getLocalInstance("streaming tests")
 
 
-class TestStreaming():
+class TestStreaming:
     row_count = 100000
     column_count = 10
-    time_to_run = 10
+    time_to_run = 8
     rows_per_second = 5000
 
+    def getTestDataSpec(self):
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                         rows=self.row_count,
+                                         partitions=spark.sparkContext.defaultParallelism,
+                                         seedMethod='hash_fieldname')
+                        .withIdOutput()
+                        .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)",
+                                    numColumns=self.column_count)
+                        .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                        .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                        .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                        .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
+                        .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
+                        )
+        return testDataSpec
+
+    def test_get_current_spark_timestamp(self):
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+        ts = testDataSpec._getCurrentSparkTimestamp(asLong=False)
+
+        assert type(ts) is str
+        assert ts is not None and len(ts.strip()) > 0
+        print(ts)
+
+    def test_get_current_spark_timestamp2(self):
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+        ts = testDataSpec._getCurrentSparkTimestamp(asLong=True)
+
+        assert type(ts) is int
+        print(ts)
+
+    def test_get_current_spark_version(self):
+        assert spark.version > "3.0.0"
+        assert spark.version <= "6.0.0"
+
+    @pytest.mark.parametrize("options_supplied,expected,spark_version_override",
+                             [(None, "rate" if spark.version < "3.2.1" else "rate-micro-batch", None),
+                              (None, "rate", "3.0.0"),
+                              (None, "rate-micro-batch", "3.2.1"),
+                              ({'streamingSource': 'rate'}, 'rate', None),
+                              ({'streamingSource': 'rate-micro-batch'}, 'rate-micro-batch', None)])
+    def test_streaming_source_options(self, options_supplied, expected, spark_version_override):
+        print("options", options_supplied)
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+
+        result = testDataSpec._getStreamingSource(options_supplied, spark_version_override)
+        print("Options:", options_supplied, "retval:", result)
+
+        assert result == expected
+
+    @pytest.mark.parametrize("options_supplied,source_expected,options_expected,spark_version_override",
+                             [(None, "rate" if spark.version < "3.2.1" else "rate-micro-batch",
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'rowsPerBatch': spark.sparkContext.defaultParallelism,
+                                'startTimestamp': "*"} if spark.version >= "3.2.1"
+                               else {'numPartitions': spark.sparkContext.defaultParallelism,
+                                     'rowsPerSecond': spark.sparkContext.defaultParallelism}, None),
+
+                              (None, "rate", {'numPartitions': spark.sparkContext.defaultParallelism,
+                                              'rowsPerSecond': spark.sparkContext.defaultParallelism}, "3.0.0"),
+
+                              (None, "rate-micro-batch",
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'rowsPerBatch': spark.sparkContext.defaultParallelism,
+                                'startTimestamp': "*"}, "3.2.1"),
+
+                              ({'streamingSource': 'rate'}, 'rate',
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'streamingSource': 'rate',
+                                'rowsPerSecond': spark.sparkContext.defaultParallelism}, None),
+
+                              ({'streamingSource': 'rate', 'rowsPerSecond': 5000}, 'rate',
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'streamingSource': 'rate',
+                                'rowsPerSecond': 5000}, None),
+
+                              ({'streamingSource': 'rate', 'numPartitions': 10}, 'rate',
+                               {'numPartitions': 10, 'rowsPerSecond': 10, 'streamingSource': 'rate'}, None),
+
+                              ({'streamingSource': 'rate', 'numPartitions': 10, 'rowsPerSecond': 5000}, 'rate',
+                               {'numPartitions': 10, 'rowsPerSecond': 5000, 'streamingSource': 'rate'}, None),
+
+                              ({'streamingSource': 'rate-micro-batch'}, 'rate-micro-batch',
+                               {'streamingSource': 'rate-micro-batch',
+                                'numPartitions': spark.sparkContext.defaultParallelism,
+                                'startTimestamp': '*',
+                                'rowsPerBatch': spark.sparkContext.defaultParallelism}, None),
+
+                              ({'streamingSource': 'rate-micro-batch', 'numPartitions': 20}, 'rate-micro-batch',
+                               {'streamingSource': 'rate-micro-batch',
+                                'numPartitions': 20,
+                                'startTimestamp': '*',
+                                'rowsPerBatch': 20}, None),
+
+                              ({'streamingSource': 'rate-micro-batch', 'numPartitions': 20, 'rowsPerBatch': 4300},
+                               'rate-micro-batch',
+                               {'streamingSource': 'rate-micro-batch',
+                                'numPartitions': 20,
+                                'startTimestamp': '*',
+                                'rowsPerBatch': 4300}, None),
+                              ])
+    def test_prepare_options(self, options_supplied, source_expected, options_expected, spark_version_override):
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+
+        streaming_source, new_options = testDataSpec._prepareStreamingOptions(options_supplied,
+                                                                              spark_version_override)
+        print("Options supplied:", options_supplied, "streamingSource:", streaming_source)
+
+        assert streaming_source == source_expected, "unexpected streaming source"
+
+        if streaming_source == "rate-micro-batch":
+            assert "startTimestamp" in new_options
+            assert "startTimestamp" in options_expected
+            if options_expected["startTimestamp"] == "*":
+                options_expected.pop("startTimestamp")
+                new_options.pop("startTimestamp")
+
+        print("options expected:", options_expected)
+
+        assert new_options == options_expected, "unexpected options"
+
     @pytest.fixture
-    def getStreamingDirs(self):
+    def getBaseDir(self, request):
         time_now = int(round(time.time() * 1000))
+        base_dir = f"/tmp/testdatagenerator_{request.node.originalname}_{time_now}"
+        yield base_dir
+        print("cleaning base dir")
+        shutil.rmtree(base_dir)
+
+    @pytest.fixture
+    def getCheckpoint(self, getBaseDir, request):
+        checkpoint_dir = os.path.join(getBaseDir, "checkpoint1")
+        os.makedirs(checkpoint_dir)
+
+        yield checkpoint_dir
+        print("cleaning checkpoint dir")
+
+    @pytest.fixture
+    def getDataDir(self, getBaseDir, request):
+        data_dir = os.path.join(getBaseDir, "data1")
+        os.makedirs(data_dir)
+
+        yield data_dir
+        print("cleaning data dir")
+
+    def test_fixture1(self, getCheckpoint, getDataDir):
+        print(getCheckpoint)
+        print(getDataDir)
+
+    def test_streaming_basic_rate(self, getDataDir, getCheckpoint):
+        test_dir = getDataDir
+        checkpoint_dir = getCheckpoint
+
+        try:
+
+            testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                             rows=self.row_count,
+                                             partitions=spark.sparkContext.defaultParallelism,
+                                             seedMethod='hash_fieldname')
+                            .withIdOutput())
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1,
+                                                     'streamingSource': 'rate'})
+
+            (dfTestData.writeStream
+             .option("checkpointLocation", checkpoint_dir)
+             .outputMode("append")
+             .format("parquet")
+             .start(test_dir)
+             )
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            pass
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+    def test_streaming_basic_rate_micro_batch(self, getDataDir, getCheckpoint):
+        test_dir = getDataDir
+        checkpoint_dir = getCheckpoint
+
+        try:
+
+            testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                             rows=self.row_count,
+                                             partitions=spark.sparkContext.defaultParallelism,
+                                             seedMethod='hash_fieldname')
+                            .withIdOutput()
+                            .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                            .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                            .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                            )
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerBatch': 1000,
+                                                     'streamingSource': 'rate-micro-batch',
+                                                     'startTimestamp': 0})
+
+            (dfTestData.writeStream
+             .option("checkpointLocation", checkpoint_dir)
+             .outputMode("append")
+             .format("parquet")
+             .start(test_dir)
+             )
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            pass
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+    def test_streaming_rate_source(self):
+        print(spark.version)
+        test_dir, checkpoint_dir, base_dir = self.getDataAndCheckpoint("test1")
+
+        new_data_rows = 0
+
+        self.makeDataAndCheckpointDirs(test_dir, checkpoint_dir)
+
+        try:
+
+            testDataSpec = self.getTestDataSpec()
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1,
+                                                     'streamingSource': 'rate'})
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            shutil.rmtree(base_dir)
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+    def test_streaming(self):
+        print(spark.version)
+        test_dir, checkpoint_dir, base_dir = self.getDataAndCheckpoint("test1")
+
+        new_data_rows = 0
+
+        self.makeDataAndCheckpointDirs(test_dir, checkpoint_dir)
+
+        try:
+
+            testDataSpec = self.getTestDataSpec()
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1})
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            shutil.rmtree(base_dir)
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+    def test_streaming_with_age_limit(self):
+        print(spark.version)
+
+        time_now = int(round(time.time() * 1000))
+        base_dir = f"/tmp/testdatagenerator_{time_now}"
         print("test dir created")
         data_dir = os.path.join(base_dir, "data")
+        checkpoint_dir = os.path.join(base_dir, "checkpoint")
         os.makedirs(data_dir)
         os.makedirs(checkpoint_dir)
@@ -41,61 +405,51 @@ def test_streaming(self, getStreamingDirs, seedColumnName):
                                              partitions=4, seedMethod='hash_fieldname', seedColumnName=seedColumnName))
         else:
             testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count,
-                                             partitions=4, seedMethod='hash_fieldname'))
+                                             partitions=4, seedMethod='hash_fieldname')
+                            .withIdOutput()
+                            .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)",
+                                        numColumns=self.column_count)
+                            .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                            .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                            .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                            .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
+                            .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
 
-        testDataSpec = (testDataSpec
-                        .withIdOutput()
-                        .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)",
-                                    numColumns=self.column_count)
-                        .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
-                        .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
-                        .withColumn("code3", StringType(), values=['a', 'b', 'c'])
-                        .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
-                        .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
+                            )
 
-                        )
+        dfTestData = testDataSpec.build(withStreaming=True,
+                                        options={'rowsPerSecond': self.rows_per_second,
+                                                 'ageLimit': 1})
 
-        dfTestData = testDataSpec.build(withStreaming=True,
-                                        options={'rowsPerSecond': self.rows_per_second})
+        (dfTestData
+         .writeStream
+         .format("parquet")
+         .outputMode("append")
+         .option("path", test_dir)
+         .option("checkpointLocation", checkpoint_dir)
+         .start())
 
-        # check that seed column is in schema
-        fields = [c.name for c in dfTestData.schema.fields]
+        start_time = time.time()
+        time.sleep(self.time_to_run)
 
-        if seedColumnName is not None:
-            assert seedColumnName in fields
-            assert "id" not in fields if seedColumnName != "id" else True
+        # note stopping the stream may produce exceptions - these can be ignored
+        recent_progress = []
+        for x in spark.streams.active:
+            recent_progress.append(x.recentProgress)
+            print(x)
+            x.stop()
 
-        sq = (dfTestData
-              .writeStream
-              .format("parquet")
-              .outputMode("append")
-              .option("path", test_dir)
-              .option("checkpointLocation", checkpoint_dir)
-              .start())
+        end_time = time.time()
 
-        # loop until we get one seconds worth of data
-        start_time = time.time()
-        elapsed_time = 0
-        rows_retrieved = 0
-        time_limit = 10.0
-
-        while elapsed_time < time_limit and rows_retrieved <= self.rows_per_second:
-            time.sleep(1)
-
-            elapsed_time = time.time() - start_time
-
-            try:
-                df2 = spark.read.format("parquet").load(test_dir)
-                rows_retrieved = df2.count()
+        # read newly written data
+        df2 = spark.read.format("parquet").load(test_dir)
 
-            # ignore file or metadata not found issues arising from read before stream has written first batch
-            except Exception as exc:  # pylint: disable=broad-exception-caught
-                print("Exception:", exc)
+        new_data_rows = df2.count()
 
-        if sq.isActive:
-            sq.stop()
+        print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            shutil.rmtree(base_dir)
 
-        end_time = time.time()
 
         print("*** Done ***")
 
         print(f"read {rows_retrieved} rows from newly written data")
@@ -177,4 +531,6 @@ def test_streaming_trigger_once(self, getStreamingDirs, seedColumnName):
         print("elapsed time (seconds)", end_time - start_time)
 
         # check that we have at least one second of data
-        assert rows_retrieved >= self.rows_per_second
+        assert rows_retrieved >= int(self.rows_per_second / 4)
+
+
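
Reviewer note (not part of the patch): the sketch below illustrates how the streaming options introduced in this change are expected to be used together, based on the tests in this diff. The column definitions, dataset name, output path, and checkpoint location are illustrative placeholders; only the `build(withStreaming=True, options=...)` call and the option names (`streamingSource`, `rowsPerBatch`, `startTimestamp`, `rowsPerSecond`, `numPartitions`, `ageLimit`) come from the change itself.

import dbldatagen as dg
from pyspark.sql.types import IntegerType, StringType

spark = dg.SparkSingleton.getLocalInstance("streaming usage sketch")

# data spec mirroring the columns used in the tests above
testDataSpec = (dg.DataGenerator(sparkSession=spark, name="streaming_example", rows=100000,
                                 partitions=spark.sparkContext.defaultParallelism,
                                 seedMethod='hash_fieldname')
                .withIdOutput()
                .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
                .withColumn("code3", StringType(), values=['a', 'b', 'c']))

# on Spark 3.2.1 or later the default source is 'rate-micro-batch'; the options below select it explicitly
dfStream = testDataSpec.build(withStreaming=True,
                              options={'streamingSource': 'rate-micro-batch',
                                       'rowsPerBatch': 1000,
                                       'startTimestamp': 0})

# write the stream out; both paths are placeholders chosen for illustration only
(dfStream.writeStream
 .format("parquet")
 .outputMode("append")
 .option("path", "/tmp/streaming_example/data")
 .option("checkpointLocation", "/tmp/streaming_example/checkpoint")
 .start())

For the plain 'rate' source, the pattern exercised by the tests is options={'streamingSource': 'rate', 'rowsPerSecond': 5000, 'ageLimit': 1}, where 'ageLimit' filters out rows whose rate-source timestamp is older than the given number of seconds.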