diff --git a/dbldatagen/data_generator.py b/dbldatagen/data_generator.py
index 9f99e903..36ce8be1 100644
--- a/dbldatagen/data_generator.py
+++ b/dbldatagen/data_generator.py
@@ -22,8 +22,18 @@
 from .spark_singleton import SparkSingleton
 from .utils import ensure, topologicalSort, DataGenError, deprecated, split_list_matching_condition
 
+START_TIMESTAMP_OPTION = "startTimestamp"
+ROWS_PER_SECOND_OPTION = "rowsPerSecond"
+AGE_LIMIT_OPTION = "ageLimit"
+NUM_PARTITIONS_OPTION = "numPartitions"
+ROWS_PER_BATCH_OPTION = "rowsPerBatch"
+STREAMING_SOURCE_OPTION = "streamingSource"
+
 _OLD_MIN_OPTION = 'min'
 _OLD_MAX_OPTION = 'max'
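+# supported streaming sources; rate-micro-batch is used by default from SPARK_RATE_MICROBATCH_VERSION onwards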
+RATE_SOURCE = "rate"
+RATE_PER_MICRO_BATCH_SOURCE = "rate-micro-batch"
+SPARK_RATE_MICROBATCH_VERSION = "3.2.1"
 
 _STREAMING_TIMESTAMP_COLUMN = "_source_timestamp"
 
@@ -1058,32 +1068,124 @@ def _getBaseDataFrame(self, startId=0, streaming=False, options=None):
                 df1 = df1.withColumnRenamed(SPARK_RANGE_COLUMN, self._seedColumnName)
 
         else:
-            status = (
-                f"Generating streaming data frame with ids from {startId} to {end_id} with {id_partitions} partitions")
-            self.logger.info(status)
-            self.executionHistory.append(status)
+            df1 = self._getStreamingBaseDataFrame(startId, options)
+
+        return df1
 
-            df1 = (self.sparkSession.readStream
-                   .format("rate"))
-            if options is not None:
-                if "rowsPerSecond" not in options:
-                    options['rowsPerSecond'] = 1
-                if "numPartitions" not in options:
-                    options['numPartitions'] = id_partitions
+    def _getStreamingSource(self, options=None, spark_version=None):
+        """ get streaming source from options
 
-                for k, v in options.items():
-                    df1 = df1.option(k, v)
-                df1 = (df1.load()
-                       .withColumnRenamed("value", self._seedColumnName)
-                       )
+        :param options: dictionary of options
+        :param spark_version: Spark version string used to determine the default source; if None, the
+                              version of the active Spark session is used
+        :returns: the streaming source named in `options`, or the default source if none is specified
+
+        The default streaming source depends on the Spark version: `rate-micro-batch` is used on
+        Spark version 3.2.1 or later, and `rate` otherwise.
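+
+        Example (illustrative only; `dataspec` stands for a previously constructed data generator)::
+
+            df = dataspec.build(withStreaming=True,
+                                options={"streamingSource": "rate", "rowsPerSecond": 100})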
+        """
+        streaming_source = None
+        if options is not None:
+            if STREAMING_SOURCE_OPTION in options:
+                streaming_source = options[STREAMING_SOURCE_OPTION]
+                assert streaming_source in [RATE_SOURCE, RATE_PER_MICRO_BATCH_SOURCE], \
+                    f"Invalid streaming source - only ['{RATE_SOURCE}', ['{RATE_PER_MICRO_BATCH_SOURCE}'] supported"
+
+        if spark_version is None:
+            spark_version = self.sparkSession.version
+
+        if streaming_source is None:
+            # if using Spark 3.2.1 or later, default to the rate-micro-batch source
+            if spark_version >= SPARK_RATE_MICROBATCH_VERSION:
+                streaming_source = RATE_PER_MICRO_BATCH_SOURCE
             else:
-                df1 = (df1.option("rowsPerSecond", 1)
-                       .option("numPartitions", id_partitions)
-                       .load()
-                       .withColumnRenamed("value", self._seedColumnName)
-                       )
+                streaming_source = RATE_SOURCE
+
+        return streaming_source
+
+    def _getCurrentSparkTimestamp(self, asLong=False):
+        """ get current spark timestamp
+
+        :param asLong: if True, return the current Spark timestamp as a long (seconds since the epoch),
+                       otherwise return it as a string
+        :returns: the current Spark timestamp in the requested form
+        """
+        if asLong:
+            return (self.sparkSession.sql(f"select cast(now() as long) as start_timestamp")
+                    .collect()[0]['start_timestamp'])
+        else:
+            return (self.sparkSession.sql(f"select cast(now() as string) as start_timestamp")
+                    .collect()[0]['start_timestamp'])
+
+    def _prepareStreamingOptions(self, options=None, spark_version=None):
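+        """ prepare the effective options for the streaming source
+
+        Ensures `numPartitions` is set and fills in source-specific defaults: `rowsPerSecond` for the
+        `rate` source, and `rowsPerBatch` / `startTimestamp` for the `rate-micro-batch` source.
+        The supplied options dictionary is not modified - a copy is returned.
+
+        :param options: dictionary of options, or None to use defaults
+        :param spark_version: optional Spark version override used when determining the default source
+        :returns: tuple of (streaming source name, effective options dictionary)
+        """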
+        default_streaming_partitions = (self.partitions if self.partitions is not None
+                                        else self.sparkSession.sparkContext.defaultParallelism)
+
+        streaming_source = self._getStreamingSource(options, spark_version)
+
+        if options is None:
+            new_options = ({ROWS_PER_SECOND_OPTION: default_streaming_partitions} if streaming_source == RATE_SOURCE
+                           else {ROWS_PER_BATCH_OPTION: default_streaming_partitions})
+        else:
+            new_options = options.copy()
+
+        if NUM_PARTITIONS_OPTION in new_options:
+            streaming_partitions = new_options[NUM_PARTITIONS_OPTION]
+        else:
+            streaming_partitions = default_streaming_partitions
+            new_options[NUM_PARTITIONS_OPTION] = streaming_partitions
+
+        if streaming_source == RATE_PER_MICRO_BATCH_SOURCE:
+            if START_TIMESTAMP_OPTION not in new_options:
+                new_options[START_TIMESTAMP_OPTION] = self._getCurrentSparkTimestamp(asLong=True)
+
+            if ROWS_PER_BATCH_OPTION not in new_options:
+                # default to one row per partition per micro-batch
+                new_options[ROWS_PER_BATCH_OPTION] = streaming_partitions
+
+        elif streaming_source == RATE_SOURCE:
+            if ROWS_PER_SECOND_OPTION not in new_options:
+                new_options[ROWS_PER_SECOND_OPTION] = streaming_partitions
+        else:
+            assert streaming_source in [RATE_SOURCE, RATE_PER_MICRO_BATCH_SOURCE], \
+                f"Invalid streaming source - only ['{RATE_SOURCE}', ['{RATE_PER_MICRO_BATCH_SOURCE}'] supported"
+
+        return streaming_source, new_options
+
+    def _getStreamingBaseDataFrame(self, startId=0, options=None):
+        """Generate base streaming data frame"""
+        end_id = self._rowCount + startId
+
+        # determine the streaming source and fill in default options
+        streaming_source, options = self._prepareStreamingOptions(options)
+        partitions = options[NUM_PARTITIONS_OPTION]
+
+        status = f"Generating streaming data with {streaming_source} source using {partitions} partitions"
+
+        self.logger.info(status)
+        self.executionHistory.append(status)
+
+        age_limit_interval = None
+
+        if STREAMING_SOURCE_OPTION in options:
+            options.pop(STREAMING_SOURCE_OPTION)
+
+        if AGE_LIMIT_OPTION in options:
+            age_limit_interval = options.pop(AGE_LIMIT_OPTION)
+            assert age_limit_interval is not None and float(age_limit_interval) > 0.0, \
+                "invalid age limit - must be a positive number of seconds"
+
+        assert AGE_LIMIT_OPTION not in options
+        assert STREAMING_SOURCE_OPTION not in options
+
+        df1 = self.sparkSession.readStream.format(streaming_source)
+
+        for k, v in options.items():
+            df1 = df1.option(k, v)
+
+        df1 = df1.load().withColumnRenamed("value", self._seedColumnName)
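+        # the `ageLimit` option (in seconds) retains only rows whose source `timestamp` lies within
+        # the specified interval of the current time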
 
+        if age_limit_interval is not None:
+            df1 = df1.where(f"""abs(cast(now() as double) - cast(`timestamp` as double )) 
+                                  < cast({age_limit_interval} as double)""")
         return df1
 
     def _computeColumnBuildOrder(self):
diff --git a/makefile b/makefile
index e76e0952..e9cec3dc 100644
--- a/makefile
+++ b/makefile
@@ -29,6 +29,10 @@ create-dev-env:
 	@echo "$(OK_COLOR)=> making conda dev environment$(NO_COLOR)"
 	conda create -n $(ENV_NAME) python=3.8.10
 
+create-dev-env-321:
+	@echo "$(OK_COLOR)=> making conda dev environment for Spark 3.2.1$(NO_COLOR)"
+	conda create -n $(ENV_NAME) python=3.8.10
+
 create-github-build-env:
 	@echo "$(OK_COLOR)=> making conda dev environment$(NO_COLOR)"
 	conda create -n pip_$(ENV_NAME) python=3.8
@@ -37,6 +41,10 @@ install-dev-dependencies:
 	@echo "$(OK_COLOR)=> installing dev environment requirements$(NO_COLOR)"
 	pip install -r python/dev_require.txt
 
+install-dev-dependencies-321:
+	@echo "$(OK_COLOR)=> installing dev environment requirements for Spark 3.2.1$(NO_COLOR)"
+	pip install -r python/dev_require_321.txt
+
 clean-dev-env:
 	@echo "$(OK_COLOR)=> Cleaning dev environment$(NO_COLOR)"
 	@echo "Current version: $(CURRENT_VERSION)"
diff --git a/python/dev_require_321.txt b/python/dev_require_321.txt
new file mode 100644
index 00000000..56846a62
--- /dev/null
+++ b/python/dev_require_321.txt
@@ -0,0 +1,33 @@
+# The following packages are used in building the test data generator framework.
+# All packages used are already installed in Databricks runtime environments that include Spark 3.2.1 or later
+numpy==1.20.1
+pandas==1.2.4
+pickleshare==0.7.5
+py4j==0.10.9.3
+pyarrow==4.0.0
+pyspark==3.2.1
+python-dateutil==2.8.1
+six==1.15.0
+
+# The following packages are required for development only
+wheel==0.36.2
+setuptools==52.0.0
+bumpversion
+pytest
+pytest-cov
+pytest-timeout
+rstcheck
+prospector
+
+# The following packages are only required for building documentation and are not required at runtime
+sphinx==5.0.0
+sphinx_rtd_theme
+nbsphinx
+numpydoc==0.8
+pypandoc
+ipython==7.16.3
+recommonmark
+sphinx-markdown-builder
+rst2pdf==0.98
+Jinja2 < 3.1
+
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index cf10273a..cb9f5218 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -10,18 +10,382 @@
 spark = dg.SparkSingleton.getLocalInstance("streaming tests")
 
 
-class TestStreaming():
+class TestStreaming:
     row_count = 100000
     column_count = 10
-    time_to_run = 10
+    time_to_run = 8
     rows_per_second = 5000
 
+    def getTestDataSpec(self):
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                         rows=self.row_count,
+                                         partitions=spark.sparkContext.defaultParallelism,
+                                         seedMethod='hash_fieldname')
+                        .withIdOutput()
+                        .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)",
+                                    numColumns=self.column_count)
+                        .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                        .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                        .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                        .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
+                        .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
+                        )
+        return testDataSpec
+
+    def test_get_current_spark_timestamp(self):
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+        ts = testDataSpec._getCurrentSparkTimestamp(asLong=False)
+
+        assert type(ts) is str
+        assert ts is not None and len(ts.strip()) > 0
+        print(ts)
+
+    def test_get_current_spark_timestamp2(self):
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+        ts = testDataSpec._getCurrentSparkTimestamp(asLong=True)
+
+        assert type(ts) is int
+        print(ts)
+
+    def test_get_current_spark_version(self):
+        assert spark.version > "3.0.0"
+        assert spark.version <= "6.0.0"
+
+    @pytest.mark.parametrize("options_supplied,expected,spark_version_override",
+                             [(None, "rate" if spark.version < "3.2.1" else "rate-micro-batch", None),
+                              (None, "rate", "3.0.0"),
+                              (None, "rate-micro-batch", "3.2.1"),
+                              ({'streamingSource': 'rate'}, 'rate', None),
+                              ({'streamingSource': 'rate-micro-batch'}, 'rate-micro-batch', None)])
+    def test_streaming_source_options(self, options_supplied, expected, spark_version_override):
+        print("options", options_supplied)
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+
+        result = testDataSpec._getStreamingSource(options_supplied, spark_version_override)
+        print("Options:", options_supplied, "retval:", result)
+
+        assert result == expected
+
+    @pytest.mark.parametrize("options_supplied,source_expected,options_expected,spark_version_override",
+                             [(None, "rate" if spark.version < "3.2.1" else "rate-micro-batch",
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'rowsPerBatch': spark.sparkContext.defaultParallelism,
+                                'startTimestamp': "*"} if spark.version >= "3.2.1"
+                               else {'numPartitions': spark.sparkContext.defaultParallelism,
+                                     'rowsPerSecond': spark.sparkContext.defaultParallelism}, None),
+
+                              (None, "rate", {'numPartitions': spark.sparkContext.defaultParallelism,
+                                              'rowsPerSecond': spark.sparkContext.defaultParallelism}, "3.0.0"),
+
+                              (None, "rate-micro-batch",
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'rowsPerBatch': spark.sparkContext.defaultParallelism,
+                                'startTimestamp': "*"}, "3.2.1"),
+
+                              ({'streamingSource': 'rate'}, 'rate',
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'streamingSource': 'rate',
+                                'rowsPerSecond': spark.sparkContext.defaultParallelism}, None),
+
+                              ({'streamingSource': 'rate', 'rowsPerSecond': 5000}, 'rate',
+                               {'numPartitions': spark.sparkContext.defaultParallelism,
+                                'streamingSource': 'rate',
+                                'rowsPerSecond': 5000}, None),
+
+                              ({'streamingSource': 'rate', 'numPartitions': 10}, 'rate',
+                               {'numPartitions': 10, 'rowsPerSecond': 10, 'streamingSource': 'rate'}, None),
+
+                              ({'streamingSource': 'rate', 'numPartitions': 10, 'rowsPerSecond': 5000}, 'rate',
+                               {'numPartitions': 10, 'rowsPerSecond': 5000, 'streamingSource': 'rate'}, None),
+
+                              ({'streamingSource': 'rate-micro-batch'}, 'rate-micro-batch',
+                               {'streamingSource': 'rate-micro-batch',
+                                'numPartitions': spark.sparkContext.defaultParallelism,
+                                'startTimestamp': '*',
+                                'rowsPerBatch': spark.sparkContext.defaultParallelism}, None),
+
+                              ({'streamingSource': 'rate-micro-batch', 'numPartitions':20}, 'rate-micro-batch',
+                               {'streamingSource': 'rate-micro-batch',
+                                'numPartitions': 20,
+                                'startTimestamp': '*',
+                                'rowsPerBatch': 20}, None),
+
+                              ({'streamingSource': 'rate-micro-batch', 'numPartitions': 20, 'rowsPerBatch': 4300},
+                               'rate-micro-batch',
+                               {'streamingSource': 'rate-micro-batch',
+                                'numPartitions': 20,
+                                'startTimestamp': '*',
+                                'rowsPerBatch': 4300}, None),
+                              ])
+    def test_prepare_options(self, options_supplied, source_expected, options_expected, spark_version_override):
+        testDataSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                        rows=self.row_count,
+                                        partitions=spark.sparkContext.defaultParallelism,
+                                        seedMethod='hash_fieldname')
+
+        streaming_source, new_options = testDataSpec._prepareStreamingOptions(options_supplied, spark_version_override)
+        print("Options supplied:", options_supplied, "streamingSource:", streaming_source)
+
+        assert streaming_source == source_expected, "unexpected streaming source"
+
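+        # an expected startTimestamp of "*" is a wildcard - the actual default is the current Spark time,
+        # so the key is removed from both dictionaries before comparing them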
+        if streaming_source == "rate-micro-batch":
+            assert "startTimestamp" in new_options
+            assert "startTimestamp" in options_expected
+            if options_expected["startTimestamp"] == "*":
+                options_expected.pop("startTimestamp")
+                new_options.pop("startTimestamp")
+
+        print("options expected:", options_expected)
+
+        assert new_options == options_expected, "unexpected options"
+
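+    # the directory fixtures below build on getBaseDir, which creates a unique temporary directory per
+    # test and removes it (including the checkpoint and data subdirectories) once the test completes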
     @pytest.fixture
-    def getStreamingDirs(self):
+    def getBaseDir(self, request):
         time_now = int(round(time.time() * 1000))
+        base_dir = f"/tmp/testdatagenerator_{request.node.originalname}_{time_now}"
+        yield base_dir
+        print("cleaning base dir")
+        shutil.rmtree(base_dir)
+
+    @pytest.fixture
+    def getCheckpoint(self, getBaseDir, request):
+        checkpoint_dir = os.path.join(getBaseDir, "checkpoint1")
+        os.makedirs(checkpoint_dir)
+
+        yield checkpoint_dir
+        print("cleaning checkpoint dir")
+
+    @pytest.fixture
+    def getDataDir(self, getBaseDir, request):
+        data_dir = os.path.join(getBaseDir, "data1")
+        os.makedirs(data_dir)
+
+        yield data_dir
+        print("cleaning data dir")
+
+    def test_fixture1(self, getCheckpoint, getDataDir):
+        print(getCheckpoint)
+        print(getDataDir)
+
+    def test_streaming_basic_rate(self, getDataDir, getCheckpoint):
+        test_dir = getDataDir
+        checkpoint_dir = getCheckpoint
+
+        try:
+
+            testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                         rows=self.row_count,
+                                         partitions=spark.sparkContext.defaultParallelism,
+                                         seedMethod='hash_fieldname')
+                            .withIdOutput())
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1,
+                                                     'streamingSource': 'rate'})
+
+            (dfTestData.writeStream
+                       .option("checkpointLocation", checkpoint_dir)
+                       .outputMode("append")
+                       .format("parquet")
+                       .start(test_dir)
+             )
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            pass
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+    def test_streaming_basic_rate_micro_batch(self, getDataDir, getCheckpoint):
+        test_dir = getDataDir
+        checkpoint_dir = getCheckpoint
+
+        try:
+
+            testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1",
+                                         rows=self.row_count,
+                                         partitions=spark.sparkContext.defaultParallelism,
+                                         seedMethod='hash_fieldname')
+                            .withIdOutput()
+                            .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                            .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                            .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                            )
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerBatch': 1000,
+                                                     'streamingSource': 'rate-micro-batch',
+                                                     'startTimestamp': 0})
+
+            (dfTestData.writeStream
+                       .option("checkpointLocation", checkpoint_dir)
+                       .outputMode("append")
+                       .format("parquet")
+                       .start(test_dir)
+             )
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            pass
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that a meaningful number of rows was written
+        assert new_data_rows > self.rows_per_second
+
+
+    def test_streaming_rate_source(self, getDataDir, getCheckpoint):
+        print(spark.version)
+        test_dir = getDataDir
+        checkpoint_dir = getCheckpoint
+
+        new_data_rows = 0
+
+        try:
+            testDataSpec = self.getTestDataSpec()
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1,
+                                                     'streamingSource': 'rate'})
+
+            (dfTestData.writeStream
+                       .option("checkpointLocation", checkpoint_dir)
+                       .outputMode("append")
+                       .format("parquet")
+                       .start(test_dir)
+             )
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            pass
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+
+    def test_streaming(self, getDataDir, getCheckpoint):
+        print(spark.version)
+        test_dir = getDataDir
+        checkpoint_dir = getCheckpoint
+
+        new_data_rows = 0
+
+        try:
+            testDataSpec = self.getTestDataSpec()
+
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1})
+
+            (dfTestData.writeStream
+                       .option("checkpointLocation", checkpoint_dir)
+                       .outputMode("append")
+                       .format("parquet")
+                       .start(test_dir)
+             )
+
+            start_time = time.time()
+            time.sleep(self.time_to_run)
+
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
+
+            end_time = time.time()
+
+            # read newly written data
+            df2 = spark.read.format("parquet").load(test_dir)
+
+            new_data_rows = df2.count()
+
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            pass
+
+        print("*** Done ***")
+
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert new_data_rows > self.rows_per_second
+
+    def test_streaming_with_age_limit(self):
+        print(spark.version)
+
+        time_now = int(round(time.time() * 1000))
+
         base_dir = f"/tmp/testdatagenerator_{time_now}"
         print("test dir created")
         data_dir = os.path.join(base_dir, "data")
+
         checkpoint_dir = os.path.join(base_dir, "checkpoint")
         os.makedirs(data_dir)
         os.makedirs(checkpoint_dir)
@@ -41,61 +405,51 @@ def test_streaming(self, getStreamingDirs, seedColumnName):
                                              partitions=4, seedMethod='hash_fieldname', seedColumnName=seedColumnName))
         else:
             testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count,
-                                             partitions=4, seedMethod='hash_fieldname'))
+                                             partitions=4, seedMethod='hash_fieldname')
+                            .withIdOutput()
+                            .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)",
+                                        numColumns=self.column_count)
+                            .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                            .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                            .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                            .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
+                            .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
 
-        testDataSpec = (testDataSpec
-                        .withIdOutput()
-                        .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)",
-                                    numColumns=self.column_count)
-                        .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
-                        .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
-                        .withColumn("code3", StringType(), values=['a', 'b', 'c'])
-                        .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
-                        .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
+                            )
 
-                        )
+            dfTestData = testDataSpec.build(withStreaming=True,
+                                            options={'rowsPerSecond': self.rows_per_second,
+                                                     'ageLimit': 1})
 
-        dfTestData = testDataSpec.build(withStreaming=True,
-                                        options={'rowsPerSecond': self.rows_per_second})
+            (dfTestData
+             .writeStream
+             .format("parquet")
+             .outputMode("append")
+             .option("path", data_dir)
+             .option("checkpointLocation", checkpoint_dir)
+             .start())
 
-        # check that seed column is in schema
-        fields = [c.name for c in dfTestData.schema.fields]
+            start_time = time.time()
+            time.sleep(self.time_to_run)
 
-        if seedColumnName is not None:
-            assert seedColumnName in fields
-            assert "id" not in fields if seedColumnName != "id" else True
+            # note stopping the stream may produce exceptions - these can be ignored
+            recent_progress = []
+            for x in spark.streams.active:
+                recent_progress.append(x.recentProgress)
+                print(x)
+                x.stop()
 
-        sq = (dfTestData
-              .writeStream
-              .format("parquet")
-              .outputMode("append")
-              .option("path", test_dir)
-              .option("checkpointLocation", checkpoint_dir)
-              .start())
+            end_time = time.time()
 
-        # loop until we get one seconds worth of data
-        start_time = time.time()
-        elapsed_time = 0
-        rows_retrieved = 0
-        time_limit = 10.0
-
-        while elapsed_time < time_limit and rows_retrieved <= self.rows_per_second:
-            time.sleep(1)
-
-            elapsed_time = time.time() - start_time
-
-            try:
-                df2 = spark.read.format("parquet").load(test_dir)
-                rows_retrieved = df2.count()
+            # read newly written data
+            df2 = spark.read.format("parquet").load(data_dir)
 
-            # ignore file or metadata not found issues arising from read before stream has written first batch
-            except Exception as exc:  # pylint: disable=broad-exception-caught
-                print("Exception:", exc)
+            new_data_rows = df2.count()
 
-        if sq.isActive:
-            sq.stop()
+            print("read {} rows from newly written data".format(new_data_rows))
+        finally:
+            shutil.rmtree(base_dir)
 
-        end_time = time.time()
 
         print("*** Done ***")
         print(f"read {rows_retrieved} rows from newly written data")
@@ -177,4 +531,6 @@ def test_streaming_trigger_once(self, getStreamingDirs, seedColumnName):
         print("elapsed time (seconds)", end_time - start_time)
 
         # check that we have at least one second of data
-        assert rows_retrieved >= self.rows_per_second
+        assert rows_retrieved > int(self.rows_per_second / 4)