diff --git a/be/src/http/action/stream_load.cpp b/be/src/http/action/stream_load.cpp index e86ca32db238e0..0c526908b5a661 100644 --- a/be/src/http/action/stream_load.cpp +++ b/be/src/http/action/stream_load.cpp @@ -157,9 +157,9 @@ void StreamLoadAction::handle(HttpRequest* req) { ctx->load_cost_nanos = MonotonicNanos() - ctx->start_nanos; if (!ctx->status.ok() && !ctx->status.is_publish_timeout()) { - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_exec_env->stream_load_executor()->rollback_txn(ctx); - ctx->need_rollback = false; + ctx->clear_need_rollback(); } if (ctx->body_sink != nullptr) { ctx->body_sink->cancel(ctx->status); @@ -215,7 +215,8 @@ Status StreamLoadAction::_handle_batch_write(starrocks::HttpRequest* http_req, S } int StreamLoadAction::on_header(HttpRequest* req) { - auto* ctx = new StreamLoadContext(_exec_env, &StreamLoadMetrics::instance()->streaming_load_current_processing); + auto* ctx = new StreamLoadContext(_exec_env, _exec_env->load_stream_mgr(), + &StreamLoadMetrics::instance()->streaming_load_current_processing); ctx->ref(); req->set_handler_ctx(ctx); @@ -262,9 +263,9 @@ int StreamLoadAction::on_header(HttpRequest* req) { auto st = _on_header(req, ctx); if (!st.ok()) { ctx->status = st; - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_exec_env->stream_load_executor()->rollback_txn(ctx); - ctx->need_rollback = false; + ctx->clear_need_rollback(); } if (ctx->body_sink != nullptr) { ctx->body_sink->cancel(st); diff --git a/be/src/http/action/transaction_stream_load.cpp b/be/src/http/action/transaction_stream_load.cpp index ee170d276fbfa5..7b4eb37006d3cd 100644 --- a/be/src/http/action/transaction_stream_load.cpp +++ b/be/src/http/action/transaction_stream_load.cpp @@ -99,7 +99,7 @@ static void _send_reply(HttpRequest* req, const std::string& str) { } void TransactionManagerAction::_send_error_reply(HttpRequest* req, const Status& st) { - auto ctx = std::make_unique(_exec_env); + auto ctx = std::make_unique(_exec_env, _exec_env->load_stream_mgr()); ctx->label = req->header(HTTP_LABEL_KEY); auto str = ctx->to_resp_json(req->param(HTTP_TXN_OP_KEY), st); @@ -180,7 +180,7 @@ TransactionStreamLoadAction::TransactionStreamLoadAction(ExecEnv* exec_env) : _e TransactionStreamLoadAction::~TransactionStreamLoadAction() = default; void TransactionStreamLoadAction::_send_error_reply(HttpRequest* req, const Status& st) { - auto ctx = std::make_unique(_exec_env); + auto ctx = std::make_unique(_exec_env, _exec_env->load_stream_mgr()); ctx->label = req->header(HTTP_LABEL_KEY); auto str = ctx->to_resp_json(TXN_LOAD, st); @@ -214,7 +214,7 @@ void TransactionStreamLoadAction::handle(HttpRequest* req) { ctx->last_active_ts = MonotonicNanos(); if (!ctx->status.ok()) { - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_exec_env->transaction_mgr()->_rollback_transaction(ctx); } } @@ -297,7 +297,7 @@ int TransactionStreamLoadAction::on_header(HttpRequest* req) { auto st = _on_header(req, ctx); if (!st.ok()) { ctx->status = st; - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_exec_env->transaction_mgr()->_rollback_transaction(ctx); } auto resp = _exec_env->transaction_mgr()->_build_reply(TXN_LOAD, ctx); diff --git a/be/src/runtime/batch_write/batch_write_mgr.cpp b/be/src/runtime/batch_write/batch_write_mgr.cpp index 8b6f18c147c0b0..34c9f4a291d6b1 100644 --- a/be/src/runtime/batch_write/batch_write_mgr.cpp +++ b/be/src/runtime/batch_write/batch_write_mgr.cpp @@ -138,8 +138,9 @@ StatusOr BatchWriteMgr::create_and_register_pipe( auto pipe = std::make_shared(pipe_name, batch_write_interval_ms, config::merge_commit_stream_load_pipe_block_wait_us, config::merge_commit_stream_load_pipe_max_buffered_bytes); - RETURN_IF_ERROR(exec_env->load_stream_mgr()->put(load_id, pipe)); - StreamLoadContext* ctx = new StreamLoadContext(exec_env, load_id); + auto* load_stream_mgr = exec_env->load_stream_mgr(); + RETURN_IF_ERROR(load_stream_mgr->put(load_id, pipe)); + StreamLoadContext* ctx = new StreamLoadContext(exec_env, load_id, load_stream_mgr); ctx->ref(); ctx->id = load_id; ctx->db = db; @@ -170,7 +171,7 @@ static std::string s_empty; void BatchWriteMgr::receive_stream_load_rpc(ExecEnv* exec_env, brpc::Controller* cntl, const PStreamLoadRequest* request, PStreamLoadResponse* response) { - auto* ctx = new StreamLoadContext(exec_env); + auto* ctx = new StreamLoadContext(exec_env, exec_env->load_stream_mgr()); ctx->ref(); DeferOp defer([&]() { response->set_json_result(ctx->to_json()); diff --git a/be/src/runtime/rejected_record_sync_daemon.cpp b/be/src/runtime/rejected_record_sync_daemon.cpp index b075d21bdcdc25..35a369300c898d 100644 --- a/be/src/runtime/rejected_record_sync_daemon.cpp +++ b/be/src/runtime/rejected_record_sync_daemon.cpp @@ -551,7 +551,7 @@ Status RejectedRecordSyncDaemon::post_to_stream_load(const std::string& payload) // ROOT, skipping password / INSERT-privilege checks. The placeholder // user/passwd fields below are kept syntactically valid; FE thrift // ignores them once the token bypass fires. - StreamLoadContext* ctx = new StreamLoadContext(_env); + StreamLoadContext* ctx = new StreamLoadContext(_env, _env->load_stream_mgr()); ctx->ref(); DeferOp release([&] { if (ctx->unref()) { diff --git a/be/src/runtime/routine_load/routine_load_task_executor.cpp b/be/src/runtime/routine_load/routine_load_task_executor.cpp index 430623b87b5e07..b71c36931110bd 100644 --- a/be/src/runtime/routine_load/routine_load_task_executor.cpp +++ b/be/src/runtime/routine_load/routine_load_task_executor.cpp @@ -84,6 +84,12 @@ std::string build_kafka_source_info(const TKafkaLoadInfo& info) { return std::string(buf.GetString(), buf.GetSize()); } +void set_need_rollback(StreamLoadContext* ctx, ExecEnv* exec_env) { + ctx->set_need_rollback([exec_env](StreamLoadContext* rollback_ctx) { + return exec_env->stream_load_executor()->rollback_txn(rollback_ctx); + }); +} + } // namespace Status RoutineLoadTaskExecutor::init(MetricRegistry* metrics) { @@ -127,7 +133,7 @@ Status RoutineLoadTaskExecutor::get_kafka_partition_meta(const PKafkaMetaProxyRe DCHECK(request.has_kafka_info()); // This context is meaningless, just for unifing the interface - StreamLoadContext ctx(_exec_env); + StreamLoadContext ctx(_exec_env, _exec_env->load_stream_mgr()); ctx.load_type = TLoadType::ROUTINE_LOAD; ctx.load_src_type = TLoadSourceType::KAFKA; ctx.label = "NaN"; @@ -144,7 +150,6 @@ Status RoutineLoadTaskExecutor::get_kafka_partition_meta(const PKafkaMetaProxyRe t_info.__set_properties(properties); ctx.kafka_info = std::make_unique(t_info); - ctx.need_rollback = false; std::shared_ptr consumer; RETURN_IF_ERROR(_data_consumer_pool.get_consumer(&ctx, &consumer)); @@ -170,7 +175,7 @@ Status RoutineLoadTaskExecutor::get_kafka_partition_offset(const PKafkaOffsetPro DCHECK(request.has_kafka_info()); // This context is meaningless, just for unifing the interface - StreamLoadContext ctx(_exec_env); + StreamLoadContext ctx(_exec_env, _exec_env->load_stream_mgr()); ctx.load_type = TLoadType::ROUTINE_LOAD; ctx.load_src_type = TLoadSourceType::KAFKA; ctx.label = "NaN"; @@ -187,7 +192,6 @@ Status RoutineLoadTaskExecutor::get_kafka_partition_offset(const PKafkaOffsetPro t_info.__set_properties(properties); ctx.kafka_info = std::make_unique(t_info); - ctx.need_rollback = false; // convert pb repeated value to vector std::vector partition_ids; @@ -220,7 +224,7 @@ Status RoutineLoadTaskExecutor::get_pulsar_partition_meta(const PPulsarMetaProxy DCHECK(request.has_pulsar_info()); // This context is meaningless, just for unifing the interface - StreamLoadContext ctx(_exec_env); + StreamLoadContext ctx(_exec_env, _exec_env->load_stream_mgr()); ctx.load_type = TLoadType::ROUTINE_LOAD; ctx.load_src_type = TLoadSourceType::PULSAR; ctx.label = "NaN"; @@ -238,7 +242,6 @@ Status RoutineLoadTaskExecutor::get_pulsar_partition_meta(const PPulsarMetaProxy t_info.__set_properties(properties); ctx.pulsar_info = std::make_unique(t_info); - ctx.need_rollback = false; std::shared_ptr consumer; RETURN_IF_ERROR(_data_consumer_pool.get_consumer(&ctx, &consumer)); @@ -257,7 +260,7 @@ Status RoutineLoadTaskExecutor::get_pulsar_partition_backlog(const PPulsarBacklo DCHECK(request.has_pulsar_info()); // This context is meaningless, just for unifing the interface - StreamLoadContext ctx(_exec_env); + StreamLoadContext ctx(_exec_env, _exec_env->load_stream_mgr()); ctx.load_type = TLoadType::ROUTINE_LOAD; ctx.load_src_type = TLoadSourceType::PULSAR; ctx.label = "NaN"; @@ -275,7 +278,6 @@ Status RoutineLoadTaskExecutor::get_pulsar_partition_backlog(const PPulsarBacklo t_info.__set_properties(properties); ctx.pulsar_info = std::make_unique(t_info); - ctx.need_rollback = false; // convert pb repeated value to vector std::vector partitions; @@ -314,7 +316,7 @@ Status RoutineLoadTaskExecutor::submit_task(const TRoutineLoadTask& task) { } // create the context - auto* ctx = new StreamLoadContext(_exec_env); + auto* ctx = new StreamLoadContext(_exec_env, _exec_env->load_stream_mgr()); ctx->load_type = TLoadType::ROUTINE_LOAD; ctx->load_src_type = task.type; ctx->job_id = task.job_id; @@ -348,7 +350,7 @@ Status RoutineLoadTaskExecutor::submit_task(const TRoutineLoadTask& task) { } // the routine load task'txn has alreay began in FE. // so it need to rollback if encounter error. - ctx->need_rollback = true; + set_need_rollback(ctx, _exec_env); if (task.__isset.max_filter_ratio) { ctx->max_filter_ratio = task.max_filter_ratio; } else { @@ -549,9 +551,9 @@ void RoutineLoadTaskExecutor::exec_task(StreamLoadContext* ctx, DataConsumerPool void RoutineLoadTaskExecutor::err_handler(StreamLoadContext* ctx, const Status& st, std::string_view err_msg) { LOG(WARNING) << err_msg << " " << ctx->brief(); ctx->status = st; - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_exec_env->stream_load_executor()->rollback_txn(ctx); - ctx->need_rollback = false; + ctx->clear_need_rollback(); } if (ctx->body_sink != nullptr) { ctx->body_sink->cancel(st); diff --git a/be/src/runtime/stream_load/stream_context_mgr.cpp b/be/src/runtime/stream_load/stream_context_mgr.cpp index f12dd75cc0851d..da8ff20c8bbae7 100644 --- a/be/src/runtime/stream_load/stream_context_mgr.cpp +++ b/be/src/runtime/stream_load/stream_context_mgr.cpp @@ -91,8 +91,9 @@ Status StreamContextMgr::create_channel_context(ExecEnv* exec_env, const std::st int32_t format, StreamLoadContext*& ctx, const TUniqueId& load_id, long txn_id) { auto pipe = std::make_shared(true); - RETURN_IF_ERROR(exec_env->load_stream_mgr()->put(load_id, pipe)); - ctx = new StreamLoadContext(exec_env, load_id); + auto* load_stream_mgr = exec_env->load_stream_mgr(); + RETURN_IF_ERROR(load_stream_mgr->put(load_id, pipe)); + ctx = new StreamLoadContext(exec_env, load_id, load_stream_mgr); if (ctx == nullptr) { return Status::InternalError("allocate stream load context fail"); } @@ -108,7 +109,6 @@ Status StreamContextMgr::create_channel_context(ExecEnv* exec_env, const std::st ctx->start_nanos = UnixSeconds(); ctx->last_active_ts = ctx->start_nanos; - ctx->need_rollback = false; ctx->format = static_cast(format); ctx->body_sink = pipe; diff --git a/be/src/runtime/stream_load/stream_load_context.cpp b/be/src/runtime/stream_load/stream_load_context.cpp index 4c6b15e9dcaa07..096ab35c3480d8 100644 --- a/be/src/runtime/stream_load/stream_load_context.cpp +++ b/be/src/runtime/stream_load/stream_load_context.cpp @@ -36,20 +36,38 @@ #include +#include + +#include "common/logging.h" #include "common/system/master_info.h" #include "compute_env/load/load_stream_mgr.h" -#include "runtime/exec_env.h" -#include "runtime/stream_load/stream_load_executor.h" namespace starrocks { +StreamLoadContext::StreamLoadContext(ExecEnv* exec_env, LoadStreamMgr* load_stream_mgr, IntGauge* running_loads) + : StreamLoadContext(exec_env, UniqueId::gen_uid(), load_stream_mgr, running_loads) {} + +StreamLoadContext::StreamLoadContext(ExecEnv* exec_env, UniqueId id, LoadStreamMgr* load_stream_mgr, + IntGauge* running_loads) + : id(id), _exec_env(exec_env), _load_stream_mgr(load_stream_mgr), _refs(0), _running_loads(running_loads) { + start_nanos = MonotonicNanos(); + if (_running_loads != nullptr) { + _running_loads->increment(1); + } +} + StreamLoadContext::~StreamLoadContext() noexcept { - if (need_rollback) { - (void)_exec_env->stream_load_executor()->rollback_txn(this); - need_rollback = false; + if (_need_rollback) { + DCHECK(_rollback_txn_callback); + if (_rollback_txn_callback) { + (void)_rollback_txn_callback(this); + } + clear_need_rollback(); } - _exec_env->load_stream_mgr()->remove(id); + if (_load_stream_mgr != nullptr) { + _load_stream_mgr->remove(id); + } if (_running_loads != nullptr) { _running_loads->increment(-1); } @@ -335,6 +353,17 @@ bool StreamLoadContext::check_and_set_http_limiter(ConcurrentLimiter* limiter) { return _http_limiter_guard->set_limiter(limiter); } +void StreamLoadContext::set_need_rollback(RollbackTxnCallback callback) { + CHECK(callback); + _need_rollback = true; + _rollback_txn_callback = std::move(callback); +} + +void StreamLoadContext::clear_need_rollback() { + _need_rollback = false; + _rollback_txn_callback = nullptr; +} + void StreamLoadContext::release(StreamLoadContext* context) { if (context != nullptr && context->unref()) { delete context; diff --git a/be/src/runtime/stream_load/stream_load_context.h b/be/src/runtime/stream_load/stream_load_context.h index 771862d2268952..c2a7f9297dc983 100644 --- a/be/src/runtime/stream_load/stream_load_context.h +++ b/be/src/runtime/stream_load/stream_load_context.h @@ -37,6 +37,7 @@ #include #include +#include #include #include @@ -58,6 +59,7 @@ namespace starrocks { +class LoadStreamMgr; class RuntimeProfile; // kafka related info @@ -146,16 +148,12 @@ const std::string DEFAULT_WAREHOUSE = "default_warehouse"; class StreamLoadContext { public: - explicit StreamLoadContext(ExecEnv* exec_env, IntGauge* running_loads = nullptr) - : StreamLoadContext(exec_env, UniqueId::gen_uid(), running_loads) {} - - explicit StreamLoadContext(ExecEnv* exec_env, UniqueId id, IntGauge* running_loads = nullptr) - : id(id), _exec_env(exec_env), _refs(0), _running_loads(running_loads) { - start_nanos = MonotonicNanos(); - if (_running_loads != nullptr) { - _running_loads->increment(1); - } - } + using RollbackTxnCallback = std::function; + + StreamLoadContext(ExecEnv* exec_env, LoadStreamMgr* load_stream_mgr, IntGauge* running_loads = nullptr); + + StreamLoadContext(ExecEnv* exec_env, UniqueId id, LoadStreamMgr* load_stream_mgr, + IntGauge* running_loads = nullptr); ~StreamLoadContext() noexcept; @@ -176,6 +174,10 @@ class StreamLoadContext { bool check_and_set_http_limiter(ConcurrentLimiter* limiter); + void set_need_rollback(RollbackTxnCallback callback); + void clear_need_rollback(); + bool need_rollback() const { return _need_rollback; } + static void release(StreamLoadContext* context); // Returns the Thrift RPC timeout (in milliseconds) shared by the stream-load plan (put) @@ -295,7 +297,6 @@ class StreamLoadContext { std::mutex lock; std::shared_ptr body_sink; - bool need_rollback = false; int64_t txn_id = -1; std::promise promise; @@ -361,8 +362,11 @@ class StreamLoadContext { private: ExecEnv* _exec_env; + LoadStreamMgr* _load_stream_mgr; std::atomic _refs; IntGauge* _running_loads; + bool _need_rollback = false; + RollbackTxnCallback _rollback_txn_callback; }; } // namespace starrocks diff --git a/be/src/runtime/stream_load/stream_load_executor.cpp b/be/src/runtime/stream_load/stream_load_executor.cpp index 4533304a281f04..9223ab83cf4fee 100644 --- a/be/src/runtime/stream_load/stream_load_executor.cpp +++ b/be/src/runtime/stream_load/stream_load_executor.cpp @@ -72,6 +72,7 @@ static StatusOr get_txn_status(const AuthInfo& auth, s std::string_view table, int64_t txn_id); static bool wait_txn_visible_until(const AuthInfo& auth, std::string_view db, std::string_view table, int64_t txn_id, int64_t deadline); +static void set_need_rollback(StreamLoadContext* ctx, ExecEnv* exec_env); Status StreamLoadExecutor::execute_plan_fragment(StreamLoadContext* ctx) { if (process_exit_in_progress()) { @@ -210,7 +211,7 @@ Status StreamLoadExecutor::begin_txn(StreamLoadContext* ctx) { return status; } ctx->txn_id = result.txnId; - ctx->need_rollback = true; + set_need_rollback(ctx, _exec_env); ctx->load_deadline_sec = UnixSeconds() + result.timeout; return Status::OK(); @@ -245,10 +246,10 @@ Status StreamLoadExecutor::commit_txn(StreamLoadContext* ctx) { RETURN_IF_ERROR(commit_txn_internal(request, rpc_timeout_ms, &result)); Status st(result.status); if (st.ok()) { - ctx->need_rollback = false; + ctx->clear_need_rollback(); return st; } else if (st.is_publish_timeout()) { - ctx->need_rollback = false; + ctx->clear_need_rollback(); bool visible = wait_txn_visible_until(ctx->auth, request.db, request.tbl, request.txnId, ctx->load_deadline_sec); return visible ? Status::OK() : st; @@ -258,7 +259,7 @@ Status StreamLoadExecutor::commit_txn(StreamLoadContext* ctx) { std::this_thread::sleep_for(std::chrono::milliseconds(result.retry_interval_ms)); } else if (st.is_time_out()) { if (++retry > 1) { - ctx->need_rollback = true; + set_need_rollback(ctx, _exec_env); return st; } LOG(WARNING) << "commit transaction " << request.txnId << " failed, will retry. errmsg=" << st.message(); @@ -266,7 +267,7 @@ Status StreamLoadExecutor::commit_txn(StreamLoadContext* ctx) { rpc_timeout_ms = (ctx->load_deadline_sec - UnixSeconds()) * 1000; } } else { - ctx->need_rollback = true; + set_need_rollback(ctx, _exec_env); return st; } } @@ -375,7 +376,7 @@ Status StreamLoadExecutor::prepare_txn(StreamLoadContext* ctx) { return status; } // commit success, set need_rollback to false - ctx->need_rollback = false; + ctx->clear_need_rollback(); return Status::OK(); } @@ -517,4 +518,10 @@ bool StreamLoadExecutor::collect_load_stat(StreamLoadContext* ctx, TTxnCommitAtt return false; } +void set_need_rollback(StreamLoadContext* ctx, ExecEnv* exec_env) { + ctx->set_need_rollback([exec_env](StreamLoadContext* rollback_ctx) { + return exec_env->stream_load_executor()->rollback_txn(rollback_ctx); + }); +} + } // namespace starrocks diff --git a/be/src/runtime/stream_load/transaction_mgr.cpp b/be/src/runtime/stream_load/transaction_mgr.cpp index 8eed15ecd88708..ad718cefa0088d 100644 --- a/be/src/runtime/stream_load/transaction_mgr.cpp +++ b/be/src/runtime/stream_load/transaction_mgr.cpp @@ -92,7 +92,7 @@ std::string TransactionMgr::_build_reply(const std::string& txn_op, StreamLoadCo } std::string TransactionMgr::_build_reply(const std::string& label, const std::string& txn_op, const Status& st) { - auto ctx = std::make_unique(_exec_env); + auto ctx = std::make_unique(_exec_env, _exec_env->load_stream_mgr()); ctx->label = label; return ctx->to_resp_json(txn_op, st); } @@ -164,14 +164,14 @@ Status TransactionMgr::begin_transaction(const HttpRequest* req, std::string* re Status st; auto ctx = _exec_env->stream_context_mgr()->get(label); if (ctx == nullptr) { - ctx = new StreamLoadContext(_exec_env, + ctx = new StreamLoadContext(_exec_env, _exec_env->load_stream_mgr(), &StreamLoadMetrics::instance()->transaction_streaming_load_current_processing); ctx->ref(); std::lock_guard l(ctx->lock); st = _begin_transaction(req, ctx); if (!st.ok()) { ctx->status = st; - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_rollback_transaction(ctx); } } @@ -267,7 +267,7 @@ Status TransactionMgr::commit_transaction(const HttpRequest* req, std::string* r if (!st.ok()) { LOG(ERROR) << "Fail to commit txn: " << st << " " << ctx->brief(); ctx->status = st; - if (ctx->need_rollback) { + if (ctx->need_rollback()) { (void)_rollback_transaction(ctx); } } @@ -385,6 +385,7 @@ Status TransactionMgr::_rollback_transaction(StreamLoadContext* ctx) { // 3. rollback transaction by send request to FE RETURN_IF_ERROR(_exec_env->stream_load_executor()->rollback_txn(ctx)); + ctx->clear_need_rollback(); // 4. remove stream load context // By remove context at the end, we can retry when the rollback FE fails diff --git a/be/test/http/stream_load_test.cpp b/be/test/http/stream_load_test.cpp index 4460e9559402db..017115290089ce 100644 --- a/be/test/http/stream_load_test.cpp +++ b/be/test/http/stream_load_test.cpp @@ -332,7 +332,7 @@ TEST_F(StreamLoadActionTest, plan_fail) { TEST_F(StreamLoadActionTest, huge_malloc) { StreamLoadAction action(&_env, _limiter.get()); - auto ctx = new StreamLoadContext(&_env); + auto ctx = new StreamLoadContext(&_env, _env.load_stream_mgr()); ctx->ref(); ctx->body_sink = std::make_shared(); HttpRequest request(_evhttp_req); @@ -543,7 +543,7 @@ TEST_F(StreamLoadActionTest, enable_batch_write_wrong_argument) { TEST_F(StreamLoadActionTest, merge_commit_response) { // success { - StreamLoadContext ctx(&_env); + StreamLoadContext ctx(&_env, _env.load_stream_mgr()); ctx.enable_batch_write = true; ctx.status = Status::OK(); ctx.txn_id = 1; @@ -583,7 +583,7 @@ TEST_F(StreamLoadActionTest, merge_commit_response) { // fail { - StreamLoadContext ctx(&_env); + StreamLoadContext ctx(&_env, _env.load_stream_mgr()); ctx.enable_batch_write = true; ctx.status = Status::InternalError("TestFail"); ctx.txn_id = 2; diff --git a/be/test/http/transaction_stream_load_test.cpp b/be/test/http/transaction_stream_load_test.cpp index 0113ea5e150f2a..7167a735b9894f 100644 --- a/be/test/http/transaction_stream_load_test.cpp +++ b/be/test/http/transaction_stream_load_test.cpp @@ -861,7 +861,7 @@ TEST_F(TransactionStreamLoadActionTest, txn_not_same_load) { TEST_F(TransactionStreamLoadActionTest, huge_malloc) { TransactionStreamLoadAction action(&_env); - auto ctx = new StreamLoadContext(&_env); + auto ctx = new StreamLoadContext(&_env, _env.load_stream_mgr()); ctx->db = "db"; ctx->table = "tbl"; ctx->label = "huge_malloc"; @@ -936,7 +936,7 @@ TEST_F(TransactionStreamLoadActionTest, huge_malloc) { TEST_F(TransactionStreamLoadActionTest, release_resource_for_success_request) { TransactionStreamLoadAction action(&_env); - auto ctx = new StreamLoadContext(&_env); + auto ctx = new StreamLoadContext(&_env, _env.load_stream_mgr()); ctx->ref(); ctx->db = "db"; ctx->table = "tbl"; @@ -998,7 +998,7 @@ TEST_F(TransactionStreamLoadActionTest, release_resource_for_success_request) { TEST_F(TransactionStreamLoadActionTest, release_resource_for_on_header_failure) { TransactionStreamLoadAction action(&_env); - auto ctx = new StreamLoadContext(&_env); + auto ctx = new StreamLoadContext(&_env, _env.load_stream_mgr()); ctx->ref(); ctx->db = "db"; ctx->table = "tbl"; @@ -1055,7 +1055,7 @@ TEST_F(TransactionStreamLoadActionTest, release_resource_for_on_header_failure) TEST_F(TransactionStreamLoadActionTest, on_header_invalid_envelope) { TransactionStreamLoadAction action(&_env); - auto ctx = new StreamLoadContext(&_env); + auto ctx = new StreamLoadContext(&_env, _env.load_stream_mgr()); ctx->ref(); ctx->db = "db"; ctx->table = "tbl"; @@ -1093,7 +1093,7 @@ TEST_F(TransactionStreamLoadActionTest, on_header_invalid_envelope) { TEST_F(TransactionStreamLoadActionTest, release_resource_for_not_handle) { TransactionStreamLoadAction action(&_env); - auto ctx = new StreamLoadContext(&_env); + auto ctx = new StreamLoadContext(&_env, _env.load_stream_mgr()); ctx->ref(); ctx->db = "db"; ctx->table = "tbl"; diff --git a/be/test/runtime/batch_write/batch_write_mgr_test.cpp b/be/test/runtime/batch_write/batch_write_mgr_test.cpp index d48b24bc2a6a99..0a6291cbbb0e74 100644 --- a/be/test/runtime/batch_write/batch_write_mgr_test.cpp +++ b/be/test/runtime/batch_write/batch_write_mgr_test.cpp @@ -63,7 +63,7 @@ class BatchWriteMgrTest : public testing::Test { } StreamLoadContext* build_data_context(const BatchWriteId& batch_write_id, const std::string& data) { - StreamLoadContext* ctx = new StreamLoadContext(_exec_env); + StreamLoadContext* ctx = new StreamLoadContext(_exec_env, _exec_env->load_stream_mgr()); ctx->ref(); ctx->db = batch_write_id.db; ctx->table = batch_write_id.table; diff --git a/be/test/runtime/batch_write/isomorphic_batch_write_test.cpp b/be/test/runtime/batch_write/isomorphic_batch_write_test.cpp index f464640bcfa2c1..09b7cf2ee0f4df 100644 --- a/be/test/runtime/batch_write/isomorphic_batch_write_test.cpp +++ b/be/test/runtime/batch_write/isomorphic_batch_write_test.cpp @@ -65,7 +65,7 @@ class IsomorphicBatchWriteTest : public testing::Test { StreamLoadContext* build_pipe_context(const std::string& label, int64_t txn_id, const BatchWriteId& batch_write_id, std::shared_ptr pipe) { - StreamLoadContext* ctx = new StreamLoadContext(_exec_env); + StreamLoadContext* ctx = new StreamLoadContext(_exec_env, _exec_env->load_stream_mgr()); ctx->ref(); ctx->db = batch_write_id.db; ctx->table = batch_write_id.table; @@ -79,7 +79,7 @@ class IsomorphicBatchWriteTest : public testing::Test { } StreamLoadContext* build_data_context(const BatchWriteId& batch_write_id, const std::string& data) { - StreamLoadContext* ctx = new StreamLoadContext(_exec_env); + StreamLoadContext* ctx = new StreamLoadContext(_exec_env, _exec_env->load_stream_mgr()); ctx->ref(); ctx->db = batch_write_id.db; ctx->table = batch_write_id.table; diff --git a/be/test/runtime/routine_load/data_consumer_test.cpp b/be/test/runtime/routine_load/data_consumer_test.cpp index 2051eefb325f44..a8e2fb4367d3a3 100644 --- a/be/test/runtime/routine_load/data_consumer_test.cpp +++ b/be/test/runtime/routine_load/data_consumer_test.cpp @@ -29,7 +29,8 @@ TEST_F(KafkaDataConsumerTest, test_get_partition_offset_broker_down) { tKafkaLoadInfo.topic = "test_topic"; tKafkaLoadInfo.partition_begin_offset = {{0, 100}}; auto kafka_info = std::make_unique(tKafkaLoadInfo); - StreamLoadContext context(ExecEnv::GetInstance()); + auto* exec_env = ExecEnv::GetInstance(); + StreamLoadContext context(exec_env, exec_env->load_stream_mgr()); context.kafka_info = std::move(kafka_info); KafkaDataConsumer consumer(&context); @@ -52,7 +53,8 @@ TEST_F(KafkaDataConsumerTest, test_get_partition_meta_broker_down) { tKafkaLoadInfo.topic = "test_topic"; tKafkaLoadInfo.partition_begin_offset = {{0, 100}}; auto kafka_info = std::make_unique(tKafkaLoadInfo); - StreamLoadContext context(ExecEnv::GetInstance()); + auto* exec_env = ExecEnv::GetInstance(); + StreamLoadContext context(exec_env, exec_env->load_stream_mgr()); context.kafka_info = std::move(kafka_info); KafkaDataConsumer consumer(&context); diff --git a/be/test/runtime/stream_load/stream_load_context_test.cpp b/be/test/runtime/stream_load/stream_load_context_test.cpp index a00e22de65c1fd..1d11398cf3d13e 100644 --- a/be/test/runtime/stream_load/stream_load_context_test.cpp +++ b/be/test/runtime/stream_load/stream_load_context_test.cpp @@ -21,6 +21,8 @@ #include "base/testutil/assert.h" #include "base/uid_util.h" #include "base/utility/defer_op.h" +#include "compute_env/load/load_stream_mgr.h" +#include "compute_env/load/stream_load_pipe.h" #include "runtime/exec_env.h" #include "runtime/message_body_sink.h" #include "runtime/stream_load/stream_context_mgr.h" @@ -92,7 +94,7 @@ TEST_F(StreamLoadContextTest, calc_put_and_commit_rpc_timeout_ms) { for (const auto& tc : test_cases) { SCOPED_TRACE(tc.description); - StreamLoadContext ctx(_exec_env); + StreamLoadContext ctx(_exec_env, _exec_env->load_stream_mgr()); ctx.timeout_second = tc.timeout_second; int32_t original_timeout_second = ctx.timeout_second; @@ -103,6 +105,60 @@ TEST_F(StreamLoadContextTest, calc_put_and_commit_rpc_timeout_ms) { } } +TEST_F(StreamLoadContextTest, destructor_removes_pipe_from_injected_load_stream_mgr) { + LoadStreamMgr load_stream_mgr; + UniqueId load_id = UniqueId::gen_uid(); + auto pipe = std::make_shared(); + + ASSERT_OK(load_stream_mgr.put(load_id, pipe)); + ASSERT_NE(nullptr, load_stream_mgr.get(load_id)); + + { StreamLoadContext ctx(_exec_env, load_id, &load_stream_mgr); } + + EXPECT_EQ(nullptr, load_stream_mgr.get(load_id)); +} + +TEST_F(StreamLoadContextTest, destructor_runs_rollback_callback_when_needed) { + int rollback_count = 0; + StreamLoadContext* expected_ctx = nullptr; + + { + StreamLoadContext ctx(nullptr, nullptr); + expected_ctx = &ctx; + ctx.set_need_rollback([&](StreamLoadContext* callback_ctx) { + EXPECT_EQ(expected_ctx, callback_ctx); + ++rollback_count; + return Status::OK(); + }); + + EXPECT_TRUE(ctx.need_rollback()); + } + + EXPECT_EQ(1, rollback_count); +} + +TEST_F(StreamLoadContextTest, clear_need_rollback_prevents_destructor_rollback_callback) { + int rollback_count = 0; + + { + StreamLoadContext ctx(nullptr, nullptr); + ctx.set_need_rollback([&](StreamLoadContext* /*callback_ctx*/) { + ++rollback_count; + return Status::OK(); + }); + ctx.clear_need_rollback(); + + EXPECT_FALSE(ctx.need_rollback()); + } + + EXPECT_EQ(0, rollback_count); +} + +TEST_F(StreamLoadContextTest, set_need_rollback_rejects_empty_callback) { + StreamLoadContext ctx(nullptr, nullptr); + ASSERT_DEATH(ctx.set_need_rollback(nullptr), "callback"); +} + TEST_F(StreamLoadContextTest, stream_load_context_handle_cancel_and_close_channel_context) { StreamContextMgr stream_context_mgr; StreamLoadContext* ctx = nullptr;