diff --git a/presto/scripts/run_benchmark.sh b/presto/scripts/run_benchmark.sh index f54aca7e..88aa622b 100755 --- a/presto/scripts/run_benchmark.sh +++ b/presto/scripts/run_benchmark.sh @@ -45,6 +45,7 @@ OPTIONS: -v, --verbose Print debug logs for worker/engine detection (e.g. node URIs, cluster-tag, GPU model). Use when engine is misdetected or the run fails. + --parallel Run given queries in parallel EXAMPLES: $0 -b tpch -s bench_sf100 @@ -171,6 +172,10 @@ parse_args() { export PRESTO_BENCHMARK_DEBUG=1 shift ;; + --parallel) + PARALLEL=true + shift + ;; *) echo "Error: Unknown argument $1" print_help @@ -247,6 +252,10 @@ if [[ "${SKIP_DROP_CACHE}" == "true" ]]; then PYTEST_ARGS+=("--skip-drop-cache") fi +if [[ "${PARALLEL}" == "true" ]]; then + PYTEST_ARGS+=("--parallel") +fi + source "${SCRIPT_DIR}/../../scripts/py_env_functions.sh" trap delete_python_virtual_env EXIT diff --git a/presto/testing/common/test_utils.py b/presto/testing/common/test_utils.py index 604ac420..7579f519 100644 --- a/presto/testing/common/test_utils.py +++ b/presto/testing/common/test_utils.py @@ -14,7 +14,7 @@ def get_table_external_location(schema_name, table, presto_cursor): - create_table_text = presto_cursor.execute(f"SHOW CREATE TABLE hive.{schema_name}.{table}").fetchone() + create_table_text = presto_cursor.cursor().execute(f"SHOW CREATE TABLE hive.{schema_name}.{table}").fetchone() test_pattern = r"external_location = 'file:/var/lib/presto/data/hive/data/integration_test/(.*)'" user_pattern = r"external_location = 'file:/var/lib/presto/data/hive/data/user_data/(.*)'" assert len(create_table_text) == 1 @@ -46,7 +46,7 @@ def get_scale_factor(request, presto_cursor): if bool(schema_name): # If a schema name is specified, get the scale factor from the metadata file located # where the table are fetching data from. - table = presto_cursor.execute(f"SHOW TABLES in {schema_name}").fetchone()[0] + table = presto_cursor.cursor().execute(f"SHOW TABLES in {schema_name}").fetchone()[0] location = get_table_external_location(schema_name, table, presto_cursor) repository_path = os.path.dirname(location) else: diff --git a/presto/testing/performance_benchmarks/common_fixtures.py b/presto/testing/performance_benchmarks/common_fixtures.py index 0892ea44..d253ef36 100644 --- a/presto/testing/performance_benchmarks/common_fixtures.py +++ b/presto/testing/performance_benchmarks/common_fixtures.py @@ -61,8 +61,9 @@ def presto_cursor(request): port = request.config.getoption("--port") user = request.config.getoption("--user") schema = request.config.getoption("--schema-name") - conn = prestodb.dbapi.connect(host=hostname, port=port, user=user, catalog="hive", schema=schema) - return conn.cursor() + + return prestodb.dbapi.connect(host=hostname, port=port, user=user, catalog="hive", + schema=schema) @pytest.fixture(scope="module") @@ -71,6 +72,7 @@ def benchmark_query(request, presto_cursor, benchmark_queries, benchmark_result_ profile = request.config.getoption("--profile") profile_script_path = request.config.getoption("--profile-script-path") metrics = request.config.getoption("--metrics") + parallel = request.config.getoption("--parallel") benchmark_type = request.node.obj.BENCHMARK_TYPE bench_output_dir = request.config.getoption("--output-dir") hostname = request.config.getoption("--hostname") @@ -93,6 +95,8 @@ def benchmark_query(request, presto_cursor, benchmark_queries, benchmark_result_ failed_queries_dict = benchmark_dict[BenchmarkKeys.FAILED_QUERIES_KEY] assert failed_queries_dict == {} + threads_dict = {} + def benchmark_query_function(query_id): profile_output_file_path = None try: @@ -102,8 +106,9 @@ def benchmark_query_function(query_id): start_profiler(profile_script_path, profile_output_file_path) result = [] for iteration_num in range(iterations): - cursor = presto_cursor.execute( - "--" + str(benchmark_type) + "_" + str(query_id) + "--" + "\n" + benchmark_queries[query_id] + cursor = presto_cursor.cursor().execute( + "--" + str(benchmark_type) + "_" + str(query_id) + "--" + "\n" + + benchmark_queries[query_id] ) result.append(cursor.stats["elapsedTimeMillis"]) @@ -138,5 +143,19 @@ def benchmark_query_function(query_id): finally: if profile and profile_output_file_path is not None: stop_profiler(profile_script_path, profile_output_file_path) - - return benchmark_query_function + + def parallel_benchmark_query_function(query_id): + local_query_id = query_id + threads_dict[local_query_id] = threading.Thread(target=benchmark_query_function, args=(local_query_id,)) + num_queries = len(request.config.getoption("--queries").split(",")) + num_threads = len(threads_dict.keys()) + if (num_queries == num_threads): + for q in threads_dict.keys(): + threads_dict[q].start() + for q in threads_dict.keys(): + threads_dict[q].join() + + if (parallel): + return parallel_benchmark_query_function + else: + return benchmark_query_function diff --git a/presto/testing/performance_benchmarks/conftest.py b/presto/testing/performance_benchmarks/conftest.py index 7989ed24..1c2c98e0 100644 --- a/presto/testing/performance_benchmarks/conftest.py +++ b/presto/testing/performance_benchmarks/conftest.py @@ -44,6 +44,7 @@ def pytest_addoption(parser): parser.addoption("--profile-script-path") parser.addoption("--metrics", action="store_true", default=False) parser.addoption("--skip-drop-cache", action="store_true", default=False) + parser.addoption("--parallel", action="store_true", default=False) def pytest_configure(config):