diff --git a/integration-tests/Cargo.lock b/integration-tests/Cargo.lock index 7bf8a977c..d5901658d 100644 --- a/integration-tests/Cargo.lock +++ b/integration-tests/Cargo.lock @@ -1664,6 +1664,7 @@ dependencies = [ "bitcoin_core_sv2", "clap", "config", + "dashmap", "hex", "hotpath", "serde", @@ -2204,6 +2205,7 @@ dependencies = [ "bitcoin_core_sv2", "clap", "config", + "dashmap", "hex", "hotpath", "serde", diff --git a/miner-apps/Cargo.lock b/miner-apps/Cargo.lock index 2e64eba04..eb9cb6a5e 100644 --- a/miner-apps/Cargo.lock +++ b/miner-apps/Cargo.lock @@ -1524,6 +1524,7 @@ dependencies = [ "bitcoin_core_sv2", "clap", "config", + "dashmap", "hex", "hotpath", "serde", diff --git a/miner-apps/jd-client/Cargo.toml b/miner-apps/jd-client/Cargo.toml index 43085be54..e7918405f 100644 --- a/miner-apps/jd-client/Cargo.toml +++ b/miner-apps/jd-client/Cargo.toml @@ -18,6 +18,7 @@ path = "src/lib/mod.rs" [dependencies] stratum-apps = { path = "../../stratum-apps", features = ["jd_client"] } async-channel = "1.5.1" +dashmap = "6.1.0" serde = { version = "1.0.89", default-features = false, features = ["derive", "alloc"] } tokio = { version = "1.44.1", features = ["full"] } ext-config = { version = "0.14.0", features = ["toml"], package = "config" } diff --git a/miner-apps/jd-client/src/lib/channel_manager/downstream_message_handler.rs b/miner-apps/jd-client/src/lib/channel_manager/downstream_message_handler.rs index d108d0dda..63bc0c2bd 100644 --- a/miner-apps/jd-client/src/lib/channel_manager/downstream_message_handler.rs +++ b/miner-apps/jd-client/src/lib/channel_manager/downstream_message_handler.rs @@ -111,33 +111,35 @@ impl RouteMessageTo<'_> { ) -> Result<(), JDCErrorKind> { match self { RouteMessageTo::Downstream((downstream_id, message)) => { - _ = channel_manager_channel.downstream_sender.send(( - downstream_id, - message.into_static(), - None, - )); + let sender = channel_manager_channel + .downstream_sender + .get(&downstream_id) + .map(|r| r.value().clone()); + if let Some(sender) = sender { + sender.send((message.into_static(), None)).await?; + } } RouteMessageTo::Upstream(message) => { if get_jd_mode() != JdMode::SoloMining { let message_static = message.into_static(); let sv2_frame: Sv2Frame = AnyMessage::Mining(message_static).try_into()?; - _ = channel_manager_channel + channel_manager_channel .upstream_sender .send(sv2_frame) - .await; + .await?; } } RouteMessageTo::JobDeclarator(message) => { - _ = channel_manager_channel + channel_manager_channel .jd_sender .send(message.into_static()) - .await; + .await?; } RouteMessageTo::TemplateProvider(message) => { - _ = channel_manager_channel + channel_manager_channel .tp_sender .send(message.into_static()) - .await; + .await?; } } Ok(()) @@ -460,9 +462,15 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { }) })?; - for messages in messages { - let _ = messages.forward(&self.channel_manager_channel).await; + for message in messages { + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } + Ok(()) } @@ -723,8 +731,13 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { }) })?; - for messages in messages { - let _ = messages.forward(&self.channel_manager_channel).await; + for message in messages { + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -918,8 +931,13 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { messages }); - for messages in messages { - let _ = messages.forward(&self.channel_manager_channel).await; + for message in messages { + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -1153,8 +1171,13 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { }) })?; - for messages in messages { - let _ = messages.forward(&self.channel_manager_channel).await; + for message in messages { + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -1411,8 +1434,13 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { }) })?; - for messages in messages { - _ = messages.forward(&self.channel_manager_channel).await; + for message in messages { + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) diff --git a/miner-apps/jd-client/src/lib/channel_manager/mod.rs b/miner-apps/jd-client/src/lib/channel_manager/mod.rs index 76b7b3e2f..a34f7d48e 100644 --- a/miner-apps/jd-client/src/lib/channel_manager/mod.rs +++ b/miner-apps/jd-client/src/lib/channel_manager/mod.rs @@ -7,8 +7,9 @@ use std::{ }, }; -use async_channel::{Receiver, Sender}; +use async_channel::{unbounded, Receiver, Sender}; use bitcoin_core_sv2::CancellationToken; +use dashmap::DashMap; use stratum_apps::{ coinbase_output_constraints::coinbase_output_constraints_message, custom_mutex::Mutex, @@ -55,7 +56,7 @@ use stratum_apps::{ }, }, }; -use tokio::{net::TcpListener, select, sync::broadcast}; +use tokio::{net::TcpListener, select}; use tracing::{debug, error, info, warn}; use crate::{ @@ -65,8 +66,8 @@ use crate::{ error::{self, JDCError, JDCErrorKind, JDCResult}, status::{handle_error, Status, StatusSender}, utils::{ - AtomicUpstreamState, DownstreamChannelJobId, PendingChannelRequest, SharesOrderedByDiff, - UpstreamState, + AtomicUpstreamState, DownstreamChannelJobId, DownstreamMessage, PendingChannelRequest, + SharesOrderedByDiff, UpstreamState, }, }; pub mod downstream_message_handler; @@ -246,7 +247,7 @@ pub struct ChannelManagerChannel { jd_receiver: Receiver>, tp_sender: Sender>, tp_receiver: Receiver>, - downstream_sender: broadcast::Sender<(DownstreamId, Mining<'static>, Option>)>, + downstream_sender: Arc>>, downstream_receiver: Receiver<(DownstreamId, Mining<'static>, Option>)>, } @@ -281,7 +282,6 @@ impl ChannelManager { jd_receiver: Receiver>, tp_sender: Sender>, tp_receiver: Receiver>, - downstream_sender: broadcast::Sender<(DownstreamId, Mining<'static>, Option>)>, downstream_receiver: Receiver<(DownstreamId, Mining<'static>, Option>)>, coinbase_outputs: Vec, supported_extensions: Vec, @@ -337,7 +337,7 @@ impl ChannelManager { jd_receiver, tp_sender, tp_receiver, - downstream_sender, + downstream_sender: Arc::new(DashMap::new()), downstream_receiver, }; @@ -437,11 +437,6 @@ impl ChannelManager { fallback_coordinator: FallbackCoordinator, status_sender: Sender, channel_manager_sender: Sender<(DownstreamId, Mining<'static>, Option>)>, - channel_manager_receiver: broadcast::Sender<( - DownstreamId, - Mining<'static>, - Option>, - )>, supported_extensions: Vec, required_extensions: Vec, ) -> JDCResult<(), error::ChannelManager> { @@ -493,7 +488,6 @@ impl ChannelManager { let fallback_coordinator_inner = fallback_coordinator.clone(); let status_sender_inner = status_sender.clone(); let channel_manager_sender_inner = channel_manager_sender.clone(); - let channel_manager_receiver_inner = channel_manager_receiver.clone(); let task_manager_inner = task_manager_clone.clone(); let supported_extensions_inner = supported_extensions.clone(); let required_extensions_inner = required_extensions.clone(); @@ -531,12 +525,14 @@ impl ChannelManager { } }; + let (channel_manager_sender_ds, channel_manager_receiver_ds) = unbounded(); + let downstream = Downstream::new( downstream_id, channel_id_factory, group_channel, channel_manager_sender_inner, - channel_manager_receiver_inner, + channel_manager_receiver_ds, noise_stream, cancellation_token_inner.clone(), fallback_coordinator_inner.clone(), @@ -545,6 +541,8 @@ impl ChannelManager { required_extensions_inner, ); + this.channel_manager_channel.downstream_sender.insert(downstream_id, channel_manager_sender_ds); + this.channel_manager_data.super_safe_lock(|data| { data.downstream.insert(downstream_id, downstream.clone()); }); @@ -687,6 +685,9 @@ impl ChannelManager { .vardiff .retain(|key, _| key.downstream_id != downstream_id); }); + self.channel_manager_channel + .downstream_sender + .remove(&downstream_id); Ok(()) } @@ -1187,7 +1188,12 @@ impl ChannelManager { }); for message in messages { - let _ = message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } info!("Vardiff update cycle complete"); diff --git a/miner-apps/jd-client/src/lib/channel_manager/template_message_handler.rs b/miner-apps/jd-client/src/lib/channel_manager/template_message_handler.rs index d48e99a47..64e921bb7 100644 --- a/miner-apps/jd-client/src/lib/channel_manager/template_message_handler.rs +++ b/miner-apps/jd-client/src/lib/channel_manager/template_message_handler.rs @@ -234,7 +234,12 @@ impl HandleTemplateDistributionMessagesFromServerAsync for ChannelManager { } for message in messages { - let _ = message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -606,7 +611,12 @@ impl HandleTemplateDistributionMessagesFromServerAsync for ChannelManager { } for message in messages { - let _ = message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) diff --git a/miner-apps/jd-client/src/lib/channel_manager/upstream_message_handler.rs b/miner-apps/jd-client/src/lib/channel_manager/upstream_message_handler.rs index 7c2932a4c..5b0494d19 100644 --- a/miner-apps/jd-client/src/lib/channel_manager/upstream_message_handler.rs +++ b/miner-apps/jd-client/src/lib/channel_manager/upstream_message_handler.rs @@ -482,7 +482,12 @@ impl HandleMiningMessagesFromServerAsync for ChannelManager { })?; for message in messages_results.into_iter().flatten() { - let _ = message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) } diff --git a/miner-apps/jd-client/src/lib/downstream/common_message_handler.rs b/miner-apps/jd-client/src/lib/downstream/common_message_handler.rs index 0786c70ef..affafba2f 100644 --- a/miner-apps/jd-client/src/lib/downstream/common_message_handler.rs +++ b/miner-apps/jd-client/src/lib/downstream/common_message_handler.rs @@ -14,7 +14,7 @@ use stratum_apps::{ }, utils::types::Sv2Frame, }; -use tracing::info; +use tracing::{error, info}; #[cfg_attr(not(test), hotpath::measure_all)] impl HandleCommonMessagesFromClientAsync for Downstream { @@ -69,7 +69,12 @@ impl HandleCommonMessagesFromClientAsync for Downstream { let frame: Sv2Frame = AnyMessage::Common(response.into_static().into()) .try_into() .map_err(JDCError::shutdown)?; - _ = self.downstream_channel.downstream_sender.send(frame).await; + if let Err(e) = self.downstream_channel.downstream_sender.send(frame).await { + error!( + "Failed to send SetupConnectionError to downstream {}: {e}", + self.downstream_id + ); + } return Err(JDCError::disconnect( JDCErrorKind::SetupConnectionError, @@ -89,7 +94,12 @@ impl HandleCommonMessagesFromClientAsync for Downstream { let frame: Sv2Frame = AnyMessage::Common(response.into_static().into()) .try_into() .map_err(JDCError::shutdown)?; - _ = self.downstream_channel.downstream_sender.send(frame).await; + if let Err(e) = self.downstream_channel.downstream_sender.send(frame).await { + error!( + "Failed to send SetupConnectionError to downstream {}: {e}", + self.downstream_id + ); + } return Err(JDCError::disconnect( JDCErrorKind::SetupConnectionError, @@ -109,7 +119,16 @@ impl HandleCommonMessagesFromClientAsync for Downstream { .try_into() .map_err(JDCError::shutdown)?; - _ = self.downstream_channel.downstream_sender.send(frame).await; + if let Err(e) = self.downstream_channel.downstream_sender.send(frame).await { + error!( + "Failed to send SetupConnectionSuccess to downstream {}: {e}", + self.downstream_id + ); + return Err(JDCError::disconnect( + JDCErrorKind::ChannelErrorSender, + self.downstream_id, + )); + } Ok(()) } diff --git a/miner-apps/jd-client/src/lib/downstream/extensions_message_handler.rs b/miner-apps/jd-client/src/lib/downstream/extensions_message_handler.rs index 40940acc5..4780b993e 100644 --- a/miner-apps/jd-client/src/lib/downstream/extensions_message_handler.rs +++ b/miner-apps/jd-client/src/lib/downstream/extensions_message_handler.rs @@ -89,7 +89,16 @@ impl HandleExtensionsFromClientAsync for Downstream { let frame: Sv2Frame = AnyMessage::Extensions(error.into()) .try_into() .map_err(JDCError::shutdown)?; - _ = self.downstream_channel.downstream_sender.send(frame).await; + if let Err(e) = self.downstream_channel.downstream_sender.send(frame).await { + error!( + "Failed to send RequestExtensionsError to downstream {}: {e}", + self.downstream_id + ); + return Err(JDCError::disconnect( + JDCErrorKind::ChannelErrorSender, + self.downstream_id, + )); + } // If required extensions are missing, the server SHOULD disconnect the client if !missing_required.is_empty() { @@ -121,7 +130,16 @@ impl HandleExtensionsFromClientAsync for Downstream { let frame: Sv2Frame = AnyMessage::Extensions(success.into()) .try_into() .map_err(JDCError::shutdown)?; - _ = self.downstream_channel.downstream_sender.send(frame).await; + if let Err(e) = self.downstream_channel.downstream_sender.send(frame).await { + error!( + "Failed to send RequestExtensionsSuccess to downstream {}: {e}", + self.downstream_id + ); + return Err(JDCError::disconnect( + JDCErrorKind::ChannelErrorSender, + self.downstream_id, + )); + } info!( "Downstream {}: Stored negotiated extensions: {:?}", diff --git a/miner-apps/jd-client/src/lib/downstream/mod.rs b/miner-apps/jd-client/src/lib/downstream/mod.rs index f1594f56f..58d3e37d7 100644 --- a/miner-apps/jd-client/src/lib/downstream/mod.rs +++ b/miner-apps/jd-client/src/lib/downstream/mod.rs @@ -24,7 +24,6 @@ use stratum_apps::{ }; use bitcoin_core_sv2::CancellationToken; -use tokio::sync::broadcast; use tracing::{debug, error, warn}; use crate::{ @@ -72,7 +71,7 @@ pub struct DownstreamData { #[derive(Clone)] pub struct DownstreamChannel { channel_manager_sender: Sender<(DownstreamId, Mining<'static>, Option>)>, - channel_manager_receiver: broadcast::Sender<(DownstreamId, Mining<'static>, Option>)>, + channel_manager_receiver: Receiver<(Mining<'static>, Option>)>, downstream_sender: Sender, downstream_receiver: Receiver, /// Per-connection cancellation token (child of the global token). @@ -98,11 +97,7 @@ impl Downstream { channel_id_factory: AtomicU32, group_channel: GroupChannel<'static, DefaultJobStore>>, channel_manager_sender: Sender<(DownstreamId, Mining<'static>, Option>)>, - channel_manager_receiver: broadcast::Sender<( - DownstreamId, - Mining<'static>, - Option>, - )>, + channel_manager_receiver: Receiver<(Mining<'static>, Option>)>, noise_stream: NoiseTcpStream, cancellation_token: CancellationToken, fallback_coordinator: FallbackCoordinator, @@ -183,7 +178,6 @@ impl Downstream { return; } - let mut receiver = self.downstream_channel.channel_manager_receiver.subscribe(); task_manager.spawn(async move { let fallback_handler = fallback_coordinator.register(); let fallback_token = fallback_coordinator.token(); @@ -209,7 +203,7 @@ impl Downstream { } } } - res = self_clone_2.handle_channel_manager_message(&mut receiver) => { + res = self_clone_2.handle_channel_manager_message() => { if let Err(e) = res { error!(?e, "Error handling channel manager message for {downstream_id}"); if handle_error(&status_sender, e).await { @@ -248,28 +242,20 @@ impl Downstream { } // Handles messages sent from the channel manager to this downstream. - async fn handle_channel_manager_message( - self, - receiver: &mut broadcast::Receiver<(DownstreamId, Mining<'static>, Option>)>, - ) -> JDCResult<(), error::Downstream> { - let (downstream_id, message, _tlv_fields) = match receiver.recv().await { + async fn handle_channel_manager_message(self) -> JDCResult<(), error::Downstream> { + let (message, _tlv_fields) = match self + .downstream_channel + .channel_manager_receiver + .recv() + .await + { Ok(msg) => msg, Err(e) => { - warn!(?e, "Broadcast receive failed"); - return Err(JDCError::shutdown( - JDCErrorKind::BroadcastChannelErrorReceiver(e), - )); + warn!(?e, "Channel receive failed"); + return Err(JDCError::shutdown(JDCErrorKind::ChannelErrorReceiver(e))); } }; - if downstream_id != self.downstream_id { - debug!( - ?downstream_id, - "Message ignored for non-matching downstream" - ); - return Ok(()); - } - let message = AnyMessage::Mining(message); let sv2_frame: Sv2Frame = message.try_into().map_err(JDCError::shutdown)?; diff --git a/miner-apps/jd-client/src/lib/error.rs b/miner-apps/jd-client/src/lib/error.rs index 0e57874ec..727911d26 100644 --- a/miner-apps/jd-client/src/lib/error.rs +++ b/miner-apps/jd-client/src/lib/error.rs @@ -38,7 +38,7 @@ use stratum_apps::{ RequestId, TemplateId, VardiffKey, }, }; -use tokio::{sync::broadcast, time::error::Elapsed}; +use tokio::time::error::Elapsed; pub type JDCResult = Result>; @@ -164,8 +164,6 @@ pub enum JDCErrorKind { ChannelErrorReceiver(async_channel::RecvError), /// Channel sender error ChannelErrorSender, - /// Broadcast channel receiver error - BroadcastChannelErrorReceiver(broadcast::error::RecvError), /// Network helpers error NetworkHelpersError(network_helpers::Error), /// Unexpected message @@ -269,9 +267,6 @@ impl fmt::Display for JDCErrorKind { ParseInt(ref e) => write!(f, "Bad convert from `String` to `int`: `{e:?}`"), ChannelErrorReceiver(ref e) => write!(f, "Channel receive error: `{e:?}`"), Parser(ref e) => write!(f, "Parser error: `{e:?}`"), - BroadcastChannelErrorReceiver(ref e) => { - write!(f, "Broadcast channel receive error: {e:?}") - } ChannelErrorSender => write!(f, "Sender error"), NetworkHelpersError(ref e) => write!(f, "Network error: {e:?}"), UnexpectedMessage(extension_type, message_type) => { @@ -504,6 +499,12 @@ impl From for JDCErrorKind { } } +impl From> for JDCErrorKind { + fn from(_: async_channel::SendError) -> Self { + JDCErrorKind::ChannelErrorSender + } +} + impl HandlerErrorType for JDCError { fn parse_error(error: ParserError) -> Self { Self { diff --git a/miner-apps/jd-client/src/lib/mod.rs b/miner-apps/jd-client/src/lib/mod.rs index d3eab5f2d..c62b1f588 100644 --- a/miner-apps/jd-client/src/lib/mod.rs +++ b/miner-apps/jd-client/src/lib/mod.rs @@ -16,7 +16,7 @@ use stratum_apps::{ tp_type::TemplateProviderType, utils::types::{Sv2Frame, GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}, }; -use tokio::sync::{broadcast, Notify}; +use tokio::sync::Notify; use tracing::{debug, error, info, warn}; use crate::{ @@ -96,8 +96,6 @@ impl JobDeclaratorClient { let (channel_manager_to_jd_sender, channel_manager_to_jd_receiver) = unbounded(); let (jd_to_channel_manager_sender, jd_to_channel_manager_receiver) = unbounded(); - let (channel_manager_to_downstream_sender, _channel_manager_to_downstream_receiver) = - broadcast::channel(10); let (downstream_to_channel_manager_sender, downstream_to_channel_manager_receiver) = unbounded(); @@ -114,7 +112,6 @@ impl JobDeclaratorClient { jd_to_channel_manager_receiver.clone(), channel_manager_to_tp_sender.clone(), tp_to_channel_manager_receiver.clone(), - channel_manager_to_downstream_sender.clone(), downstream_to_channel_manager_receiver, encoded_outputs.clone(), self.config.supported_extensions().to_vec(), @@ -332,7 +329,6 @@ impl JobDeclaratorClient { fallback_coordinator.clone(), status_sender.clone(), downstream_to_channel_manager_sender.clone(), - channel_manager_to_downstream_sender.clone(), self.config.supported_extensions().to_vec(), self.config.required_extensions().to_vec(), ) @@ -398,8 +394,6 @@ impl JobDeclaratorClient { let (channel_manager_to_jd_sender_new, channel_manager_to_jd_receiver_new) = unbounded(); let (jd_to_channel_manager_sender_new, jd_to_channel_manager_receiver_new) = unbounded(); - let (channel_manager_to_downstream_sender_new, _channel_manager_to_downstream_receiver_new) = - broadcast::channel(10); let (downstream_to_channel_manager_sender_new, downstream_to_channel_manager_receiver_new) = unbounded(); @@ -412,7 +406,6 @@ impl JobDeclaratorClient { jd_to_channel_manager_receiver_new.clone(), channel_manager_to_tp_sender.clone(), tp_to_channel_manager_receiver.clone(), - channel_manager_to_downstream_sender_new.clone(), downstream_to_channel_manager_receiver_new.clone(), encoded_outputs.clone(), self.config.supported_extensions().to_vec(), @@ -538,7 +531,6 @@ impl JobDeclaratorClient { fallback_coordinator.clone(), status_sender.clone(), downstream_to_channel_manager_sender_new.clone(), - channel_manager_to_downstream_sender_new.clone(), self.config.supported_extensions().to_vec(), self.config.required_extensions().to_vec(), ) diff --git a/miner-apps/jd-client/src/lib/utils.rs b/miner-apps/jd-client/src/lib/utils.rs index 90eda4922..dd44340ba 100644 --- a/miner-apps/jd-client/src/lib/utils.rs +++ b/miner-apps/jd-client/src/lib/utils.rs @@ -32,7 +32,7 @@ use stratum_apps::{ CloseChannel, OpenExtendedMiningChannel, OpenStandardMiningChannel, SubmitSharesExtended, }, - parsers_sv2::{JobDeclaration, Mining}, + parsers_sv2::{JobDeclaration, Mining, Tlv}, }, utils::types::{ChannelId, DownstreamId, Hashrate, JobId}, }; @@ -44,6 +44,8 @@ use crate::{ error::JDCErrorKind, }; +pub(crate) type DownstreamMessage = (Mining<'static>, Option>); + /// Represents a single upstream entry (Pool + JDS pair) with raw address strings /// that are resolved via DNS at connection time. #[derive(Debug, Clone)] diff --git a/miner-apps/translator/src/lib/error.rs b/miner-apps/translator/src/lib/error.rs index 8854b2a44..2b003f6a0 100644 --- a/miner-apps/translator/src/lib/error.rs +++ b/miner-apps/translator/src/lib/error.rs @@ -29,7 +29,6 @@ use stratum_apps::{ MessageType, }, }; -use tokio::sync::broadcast; pub type TproxyResult = Result>; @@ -158,10 +157,6 @@ pub enum TproxyErrorKind { ChannelErrorReceiver(async_channel::RecvError), /// Channel sender error ChannelErrorSender, - /// Broadcast channel receiver error - BroadcastChannelErrorReceiver(broadcast::error::RecvError), - /// Tokio channel receiver error - TokioChannelErrorRecv(tokio::sync::broadcast::error::RecvError), /// Error converting SetDifficulty to Message SetDifficultyToMessage(SetDifficulty), /// Received an unexpected message type @@ -223,11 +218,7 @@ impl fmt::Display for TproxyErrorKind { ParseInt(ref e) => write!(f, "Bad convert from `String` to `int`: `{e:?}`"), PoisonLock => write!(f, "Poison Lock error"), ChannelErrorReceiver(ref e) => write!(f, "Channel receive error: `{e:?}`"), - BroadcastChannelErrorReceiver(ref e) => { - write!(f, "Broadcast channel receive error: {e:?}") - } ChannelErrorSender => write!(f, "Sender error"), - TokioChannelErrorRecv(ref e) => write!(f, "Channel receive error: `{e:?}`"), SetDifficultyToMessage(ref e) => { write!(f, "Error converting SetDifficulty to Message: `{e:?}`") } @@ -333,12 +324,6 @@ impl From for TproxyErrorKind { } } -impl From for TproxyErrorKind { - fn from(e: tokio::sync::broadcast::error::RecvError) -> Self { - TproxyErrorKind::TokioChannelErrorRecv(e) - } -} - //*** LOCK ERRORS *** impl From> for TproxyErrorKind { fn from(_e: PoisonError) -> Self { diff --git a/miner-apps/translator/src/lib/sv1/downstream/channel.rs b/miner-apps/translator/src/lib/sv1/downstream/channel.rs index 151fea760..c69305c59 100644 --- a/miner-apps/translator/src/lib/sv1/downstream/channel.rs +++ b/miner-apps/translator/src/lib/sv1/downstream/channel.rs @@ -1,9 +1,5 @@ use async_channel::{Receiver, Sender}; -use stratum_apps::{ - stratum_core::sv1_api::json_rpc, - utils::types::{ChannelId, DownstreamId}, -}; -use tokio::sync::broadcast; +use stratum_apps::{stratum_core::sv1_api::json_rpc, utils::types::DownstreamId}; use tokio_util::sync::CancellationToken; use tracing::debug; @@ -12,8 +8,7 @@ pub struct DownstreamChannelState { pub downstream_sv1_sender: Sender, pub downstream_sv1_receiver: Receiver, pub sv1_server_sender: Sender<(DownstreamId, json_rpc::Message)>, - pub sv1_server_broadcast: - broadcast::Sender<(ChannelId, Option, json_rpc::Message)>, /* channel_id, optional downstream_id, message */ + pub sv1_server_receiver: Receiver, /// Per-connection cancellation token (child of the global token). /// Cancelled when this downstream's task loop exits, causing /// the associated SV1 I/O task to shut down. @@ -26,17 +21,13 @@ impl DownstreamChannelState { downstream_sv1_sender: Sender, downstream_sv1_receiver: Receiver, sv1_server_sender: Sender<(DownstreamId, json_rpc::Message)>, - sv1_server_broadcast: broadcast::Sender<( - ChannelId, - Option, - json_rpc::Message, - )>, + sv1_server_receiver: Receiver, connection_token: CancellationToken, ) -> Self { Self { downstream_sv1_receiver, downstream_sv1_sender, - sv1_server_broadcast, + sv1_server_receiver, sv1_server_sender, connection_token, } diff --git a/miner-apps/translator/src/lib/sv1/downstream/downstream.rs b/miner-apps/translator/src/lib/sv1/downstream/downstream.rs index 560625c57..0e0569c18 100644 --- a/miner-apps/translator/src/lib/sv1/downstream/downstream.rs +++ b/miner-apps/translator/src/lib/sv1/downstream/downstream.rs @@ -2,7 +2,6 @@ use crate::{ error::{self, TproxyError, TproxyErrorKind, TproxyResult}, status::{handle_error, StatusSender}, sv1::downstream::{channel::DownstreamChannelState, data::DownstreamData}, - utils::AGGREGATED_CHANNEL_ID, }; use async_channel::{Receiver, Sender}; use std::{ @@ -23,9 +22,8 @@ use stratum_apps::{ }, }, task_manager::TaskManager, - utils::types::{ChannelId, DownstreamId, Hashrate}, + utils::types::{DownstreamId, Hashrate}, }; -use tokio::sync::broadcast; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, warn}; @@ -62,11 +60,7 @@ impl Downstream { downstream_sv1_sender: Sender, downstream_sv1_receiver: Receiver, sv1_server_sender: Sender<(DownstreamId, json_rpc::Message)>, - sv1_server_broadcast: broadcast::Sender<( - ChannelId, - Option, - json_rpc::Message, - )>, + sv1_server_receiver: Receiver, target: Target, hashrate: Option, connection_token: CancellationToken, @@ -76,7 +70,7 @@ impl Downstream { downstream_sv1_sender, downstream_sv1_receiver, sv1_server_sender, - sv1_server_broadcast, + sv1_server_receiver, connection_token, ); Self { @@ -106,10 +100,6 @@ impl Downstream { status_sender: StatusSender, task_manager: Arc, ) { - let mut sv1_server_receiver = self - .downstream_channel_state - .sv1_server_broadcast - .subscribe(); let downstream_id = self.downstream_id; task_manager.spawn(async move { // we just spawned a new task that's relevant to fallback coordination @@ -141,7 +131,7 @@ impl Downstream { } // Handle server -> downstream message - res = self.handle_sv1_server_message(&mut sv1_server_receiver) => { + res = self.handle_sv1_server_message() => { if let Err(e) = res { error!("Downstream {downstream_id}: error in server message handler: {e:?}"); if handle_error(&status_sender, e).await { @@ -181,25 +171,16 @@ impl Downstream { /// complete /// - On handshake completion: sends cached messages in correct order (set_difficulty first, /// then notify) - pub async fn handle_sv1_server_message( - &self, - sv1_server_receiver: &mut broadcast::Receiver<( - ChannelId, - Option, - json_rpc::Message, - )>, - ) -> TproxyResult<(), error::Downstream> { - match sv1_server_receiver.recv().await { - Ok((channel_id, downstream_id, message)) => { - let my_channel_id = self.downstream_data.super_safe_lock(|d| d.channel_id); - let my_downstream_id = self.downstream_id; + pub async fn handle_sv1_server_message(&self) -> TproxyResult<(), error::Downstream> { + match self + .downstream_channel_state + .sv1_server_receiver + .recv() + .await + { + Ok(message) => { + let downstream_id = self.downstream_id; let handshake_complete = self.sv1_handshake_complete.load(Ordering::SeqCst); - let id_matches = (my_channel_id == Some(channel_id) - || channel_id == AGGREGATED_CHANNEL_ID) - && (downstream_id.is_none() || downstream_id == Some(my_downstream_id)); - if !id_matches { - return Ok(()); // Message not intended for this downstream - } // Check if this is a queued message response let is_queued_sv1_handshake_response = self @@ -267,7 +248,7 @@ impl Downstream { "Down: Failed to send mining.set_difficulty to downstream: {:?}", e ); - TproxyError::disconnect(TproxyErrorKind::ChannelErrorSender, downstream_id.unwrap_or(0)) + TproxyError::disconnect(TproxyErrorKind::ChannelErrorSender, downstream_id) })?; } @@ -279,7 +260,7 @@ impl Downstream { .await .map_err(|e| { error!("Down: Failed to send mining.notify to downstream: {:?}", e); - TproxyError::disconnect(TproxyErrorKind::ChannelErrorSender, downstream_id.unwrap_or(0)) + TproxyError::disconnect(TproxyErrorKind::ChannelErrorSender, downstream_id) })?; } return Ok(()); @@ -297,7 +278,7 @@ impl Downstream { ); TproxyError::disconnect( TproxyErrorKind::ChannelErrorSender, - downstream_id.unwrap_or(0), + downstream_id, ) })?; } @@ -338,7 +319,7 @@ impl Downstream { error!("Down: Failed to send queued message to downstream: {:?}", e); TproxyError::disconnect( TproxyErrorKind::ChannelErrorSender, - downstream_id.unwrap_or(0), + downstream_id, ) })?; } else { @@ -348,12 +329,11 @@ impl Downstream { } } Err(e) => { - let downstream_id = self.downstream_id; error!( "Sv1 message handler error for downstream {}: {:?}", - downstream_id, e + self.downstream_id, e ); - return Err(TproxyError::disconnect(e, downstream_id)); + return Err(TproxyError::shutdown(e)); } } diff --git a/miner-apps/translator/src/lib/sv1/sv1_server/channel.rs b/miner-apps/translator/src/lib/sv1/sv1_server/channel.rs index 035651749..f8471078d 100644 --- a/miner-apps/translator/src/lib/sv1/sv1_server/channel.rs +++ b/miner-apps/translator/src/lib/sv1/sv1_server/channel.rs @@ -1,16 +1,13 @@ use async_channel::{unbounded, Receiver, Sender}; +use dashmap::DashMap; +use std::sync::Arc; use stratum_apps::stratum_core::parsers_sv2::{Mining, Tlv}; -use stratum_apps::{ - stratum_core::sv1_api::json_rpc, - utils::types::{ChannelId, DownstreamId}, -}; -use tokio::sync::broadcast; +use stratum_apps::{stratum_core::sv1_api::json_rpc, utils::types::DownstreamId}; #[derive(Clone)] pub struct Sv1ServerChannelState { - pub sv1_server_to_downstream_sender: - broadcast::Sender<(ChannelId, Option, json_rpc::Message)>, + pub sv1_server_to_downstream_sender: Arc>>, pub downstream_to_sv1_server_sender: Sender<(DownstreamId, json_rpc::Message)>, pub downstream_to_sv1_server_receiver: Receiver<(DownstreamId, json_rpc::Message)>, pub channel_manager_receiver: Receiver<(Mining<'static>, Option>)>, @@ -23,11 +20,10 @@ impl Sv1ServerChannelState { channel_manager_receiver: Receiver<(Mining<'static>, Option>)>, channel_manager_sender: Sender<(Mining<'static>, Option>)>, ) -> Self { - let (sv1_server_to_downstream_sender, _) = broadcast::channel(1000); let (downstream_to_sv1_server_sender, downstream_to_sv1_server_receiver) = unbounded(); Self { - sv1_server_to_downstream_sender, + sv1_server_to_downstream_sender: Arc::new(DashMap::new()), downstream_to_sv1_server_receiver, downstream_to_sv1_server_sender, channel_manager_receiver, @@ -40,7 +36,5 @@ impl Sv1ServerChannelState { self.channel_manager_sender.close(); self.downstream_to_sv1_server_receiver.close(); self.downstream_to_sv1_server_sender.close(); - self.channel_manager_receiver.close(); - self.channel_manager_sender.close(); } } diff --git a/miner-apps/translator/src/lib/sv1/sv1_server/difficulty_manager.rs b/miner-apps/translator/src/lib/sv1/sv1_server/difficulty_manager.rs index db2bf5638..deb59d788 100644 --- a/miner-apps/translator/src/lib/sv1/sv1_server/difficulty_manager.rs +++ b/miner-apps/translator/src/lib/sv1/sv1_server/difficulty_manager.rs @@ -150,24 +150,27 @@ impl Sv1Server { } // Process immediate set_difficulty updates (for new_target >= upstream_target) - for (channel_id, downstream_id, target) in immediate_updates { + for (_channel_id, downstream_id, target) in immediate_updates { // Send set_difficulty message immediately if let Ok(set_difficulty_msg) = build_sv1_set_difficulty_from_sv2_target(target) { - if let Err(e) = self + let ds_id = downstream_id.unwrap_or(0); + if let Some(sender) = self .sv1_server_channel_state .sv1_server_to_downstream_sender - .send((channel_id, downstream_id, set_difficulty_msg)) + .get(&ds_id) + .map(|r| r.value().clone()) { - error!( - "Failed to send immediate SetDifficulty message to downstream {}: {:?}", - downstream_id.unwrap_or(0), - e - ); - } else { - trace!( - "Sent immediate SetDifficulty to downstream {} (new_target >= upstream_target)", - downstream_id.unwrap_or(0) - ); + if let Err(e) = sender.send(set_difficulty_msg).await { + error!( + "Failed to send immediate SetDifficulty message to downstream {}: {:?}", + ds_id, e + ); + } else { + trace!( + "Sent immediate SetDifficulty to downstream {} (new_target >= upstream_target)", + ds_id + ); + } } } } @@ -346,22 +349,22 @@ impl Sv1Server { channel_id ); - let affected = self.downstreams.iter().find(|downstream| { - downstream - .downstream_data - .super_safe_lock(|d| d.channel_id == Some(channel_id)) - }); - - let Some(downstream) = affected else { + let Some(downstream_id_ref) = self.channel_id_to_downstream_id.get(&channel_id) else { warn!("No downstream found for channel {}", channel_id); return; }; + let downstream_id = *downstream_id_ref; + drop(downstream_id_ref); - let downstream_id = downstream.downstream_id; - - downstream.downstream_data.super_safe_lock(|d| { - d.set_upstream_target(new_upstream_target, downstream_id); - }); + { + let Some(downstream) = self.downstreams.get(&downstream_id) else { + warn!("No downstream found for downstream_id {}", downstream_id); + return; + }; + downstream.downstream_data.super_safe_lock(|d| { + d.set_upstream_target(new_upstream_target, downstream_id); + }); + } trace!("Updated upstream target for downstream {}", downstream_id); @@ -426,7 +429,7 @@ impl Sv1Server { .get(&update.downstream_id) .and_then(|ds| ds.downstream_data.super_safe_lock(|d| d.channel_id)); - let Some(channel_id) = channel_id else { + let Some(_channel_id) = channel_id else { trace!( "Skipping SetDifficulty for downstream {}: no channel_id yet", update.downstream_id @@ -446,17 +449,20 @@ impl Sv1Server { } }; - if let Err(e) = self + if let Some(sender) = self .sv1_server_channel_state .sv1_server_to_downstream_sender - .send((channel_id, Some(update.downstream_id), set_difficulty_msg)) + .get(&update.downstream_id) + .map(|r| r.value().clone()) { - error!( - "Failed to send SetDifficulty to downstream {}: {:?}", - update.downstream_id, e - ); - } else { - trace!("Sent SetDifficulty to downstream {}", update.downstream_id); + if let Err(e) = sender.send(set_difficulty_msg).await { + error!( + "Failed to send SetDifficulty to downstream {}: {:?}", + update.downstream_id, e + ); + } else { + trace!("Sent SetDifficulty to downstream {}", update.downstream_id); + } } } } diff --git a/miner-apps/translator/src/lib/sv1/sv1_server/sv1_server.rs b/miner-apps/translator/src/lib/sv1/sv1_server/sv1_server.rs index f1af8cd2e..b02998757 100644 --- a/miner-apps/translator/src/lib/sv1/sv1_server/sv1_server.rs +++ b/miner-apps/translator/src/lib/sv1/sv1_server/sv1_server.rs @@ -12,7 +12,6 @@ use crate::{ use async_channel::{Receiver, Sender}; use dashmap::DashMap; use std::{ - collections::HashMap, net::SocketAddr, sync::{ atomic::{AtomicU32, AtomicUsize, Ordering}, @@ -73,6 +72,7 @@ pub struct Sv1Server { pub(crate) request_id_factory: Arc, pub(crate) downstreams: Arc>, pub(crate) request_id_to_downstream_id: Arc>, + pub(crate) channel_id_to_downstream_id: Arc>, pub(crate) vardiff: Arc>>>, /// HashMap to store the SetNewPrevHash for each channel /// Used in both aggregated and non-aggregated mode @@ -86,6 +86,54 @@ pub struct Sv1Server { #[cfg_attr(not(test), hotpath::measure_all)] impl Sv1Server { + /// Sends a message to downstream(s) for the given channel_id. + /// + /// In aggregated mode the channel manager rewrites the job's channel_id to + /// `AGGREGATED_CHANNEL_ID` before forwarding, which signals a broadcast: send to every + /// connected downstream. + async fn send_to_channel( + &self, + channel_id: ChannelId, + msg: stratum_apps::stratum_core::sv1_api::json_rpc::Message, + ) { + if channel_id == AGGREGATED_CHANNEL_ID { + // Broadcast to every connected downstream. + for (downstream_id, sender) in self + .sv1_server_channel_state + .sv1_server_to_downstream_sender + .iter() + .map(|e| (*e.key(), e.value().clone())) + .collect::>() + { + if let Err(e) = sender.send(msg.clone()).await { + warn!( + "Failed to send notify to downstream {}: channel closed: {}", + downstream_id, e + ); + } + } + } else { + // Non-aggregated: send to the single downstream that owns this channel_id. + if let Some((downstream_id, sender)) = self + .channel_id_to_downstream_id + .get(&channel_id) + .and_then(|downstream_id| { + self.sv1_server_channel_state + .sv1_server_to_downstream_sender + .get(downstream_id.value()) + .map(|s| (*downstream_id, s.value().clone())) + }) + { + if let Err(e) = sender.send(msg).await { + warn!( + "Failed to send notify to downstream {}: channel closed: {}", + downstream_id, e + ); + } + } + } + } + /// Cleans up server state and closes communication channels. pub fn cleanup(&self) { self.prevhashes.clear(); @@ -94,6 +142,7 @@ impl Sv1Server { self.vardiff.clear(); } self.downstreams.clear(); + self.channel_id_to_downstream_id.clear(); self.request_id_to_downstream_id.clear(); self.pending_target_updates .safe_lock(|updates| updates.clear()) @@ -132,6 +181,7 @@ impl Sv1Server { request_id_factory: Arc::new(AtomicU32::new(1)), downstreams: Arc::new(DashMap::new()), request_id_to_downstream_id: Arc::new(DashMap::new()), + channel_id_to_downstream_id: Arc::new(DashMap::new()), vardiff: Arc::new(DashMap::new()), prevhashes: Arc::new(DashMap::new()), pending_target_updates: Arc::new(Mutex::new(Vec::new())), @@ -233,12 +283,14 @@ impl Sv1Server { connection_token.clone(), ).await; let downstream_id = self.downstream_id_factory.fetch_add(1, Ordering::Relaxed); + let (sv1_server_sender, sv1_server_receiver) = async_channel::unbounded(); + self.sv1_server_channel_state.sv1_server_to_downstream_sender.insert(downstream_id, sv1_server_sender); let downstream = Downstream::new( downstream_id, connection.sender().clone(), connection.receiver().clone(), self.sv1_server_channel_state.downstream_to_sv1_server_sender.clone(), - self.sv1_server_channel_state.sv1_server_to_downstream_sender.clone(), + sv1_server_receiver, first_target, Some(self.config.downstream_difficulty_config.min_individual_miner_hashrate), connection_token, @@ -321,83 +373,84 @@ impl Sv1Server { .await .map_err(TproxyError::shutdown)?; - let downstream = self.downstreams.get(&downstream_id); + let Some(downstream) = self + .downstreams + .get(&downstream_id) + .map(|r| r.value().clone()) + else { + return Ok(()); + }; - if let Some(downstream) = downstream { - let channel_id = downstream + let channel_id = downstream + .downstream_data + .super_safe_lock(|data| data.channel_id); + if channel_id.is_none() { + let is_first_message = downstream .downstream_data - .super_safe_lock(|data| data.channel_id); - if channel_id.is_none() { - let is_first_message = downstream - .downstream_data - .super_safe_lock(|d| d.queued_sv1_handshake_messages.is_empty()); - if is_first_message { - self.handle_open_channel_request(downstream_id).await?; - debug!( - "Down: Sent OpenChannel request for downstream {}", - downstream_id - ); - } - debug!("Down: Queuing Sv1 message until channel is established"); - downstream.downstream_data.super_safe_lock(|data| { - data.queued_sv1_handshake_messages - .push(downstream_message.clone()) - }); - return Ok(()); + .super_safe_lock(|d| d.queued_sv1_handshake_messages.is_empty()); + if is_first_message { + self.handle_open_channel_request(downstream_id).await?; + debug!( + "Down: Sent OpenChannel request for downstream {}", + downstream_id + ); } + debug!("Down: Queuing Sv1 message until channel is established"); + downstream.downstream_data.super_safe_lock(|data| { + data.queued_sv1_handshake_messages + .push(downstream_message.clone()) + }); + return Ok(()); + } - let response = self - .clone() - .handle_message(Some(downstream_id), downstream_message.clone()); + let response = self + .clone() + .handle_message(Some(downstream_id), downstream_message.clone()); - match response { - Ok(Some(response_msg)) => { - debug!( - "Down: Sending Sv1 message to downstream: {:?}", - response_msg - ); - downstream - .downstream_channel_state - .downstream_sv1_sender - .send(response_msg.into()) - .await - .map_err(|error| { - error!("Down: Failed to send message to downstream: {error:?}"); - TproxyError::disconnect( - TproxyErrorKind::ChannelErrorSender, - downstream_id, - ) - })?; - - // Check if this was an authorize message and handle sv1 handshake completion - if let json_rpc::Message::StandardRequest(request) = &downstream_message { - if request.method == "mining.authorize" { - info!("Down: Handling mining.authorize after handshake completion"); - if let Err(e) = downstream.handle_sv1_handshake_completion().await { - error!("Down: Failed to handle handshake completion: {:?}", e); - return Err(TproxyError::disconnect(e, downstream_id)); - } + match response { + Ok(Some(response_msg)) => { + debug!( + "Down: Sending Sv1 message to downstream: {:?}", + response_msg + ); + downstream + .downstream_channel_state + .downstream_sv1_sender + .send(response_msg.into()) + .await + .map_err(|error| { + error!("Down: Failed to send message to downstream: {error:?}"); + TproxyError::disconnect(TproxyErrorKind::ChannelErrorSender, downstream_id) + })?; + + // Check if this was an authorize message and handle sv1 handshake completion + if let json_rpc::Message::StandardRequest(request) = &downstream_message { + if request.method == "mining.authorize" { + info!("Down: Handling mining.authorize after handshake completion"); + if let Err(e) = downstream.handle_sv1_handshake_completion().await { + error!("Down: Failed to handle handshake completion: {:?}", e); + return Err(TproxyError::disconnect(e, downstream_id)); } } } - Ok(None) => { - // Message was handled but no response needed - } - Err(e) => { - error!("Down: Error handling downstream message: {:?}", e); - return Err(TproxyError::disconnect(e, downstream_id)); - } } - - // Check if there's a pending share to send to the Sv1Server - let pending_share = downstream - .downstream_data - .super_safe_lock(|d| d.pending_share.take()); - if let Some(share) = pending_share { - self.handle_submit_shares(share).await?; + Ok(None) => { + // Message was handled but no response needed + } + Err(e) => { + error!("Down: Error handling downstream message: {:?}", e); + return Err(TproxyError::disconnect(e, downstream_id)); } } + // Check if there's a pending share to send to the Sv1Server + let pending_share = downstream + .downstream_data + .super_safe_lock(|d| d.pending_share.take()); + if let Some(share) = pending_share { + self.handle_submit_shares(share).await?; + } + Ok(()) } @@ -454,10 +507,18 @@ impl Sv1Server { // Only add TLV fields with user identity in non-aggregated mode let tlv_fields = if is_non_aggregated() { - let user_identity_string = self + let Some(downstream) = self .downstreams .get(&message.downstream_id) - .unwrap() + .map(|r| r.value().clone()) + else { + warn!( + "Downstream {} disconnected before share could be submitted, dropping share", + message.downstream_id + ); + return Ok(()); + }; + let user_identity_string = downstream .downstream_data .super_safe_lock(|d| d.user_identity.clone()); UserIdentity::new(&user_identity_string) @@ -568,13 +629,16 @@ impl Sv1Server { d.set_upstream_target(initial_target, downstream_id); }) .map_err(TproxyError::shutdown)?; + self.channel_id_to_downstream_id + .insert(m.channel_id, downstream_id); // Process all queued messages now that channel is established - if let Ok(queued_messages) = downstream.downstream_data.safe_lock(|d| { + let queued_messages = downstream.downstream_data.super_safe_lock(|d| { let messages = d.queued_sv1_handshake_messages.clone(); d.queued_sv1_handshake_messages.clear(); messages - }) { + }); + { if !queued_messages.is_empty() { info!( "Processing {} queued Sv1 messages for downstream {}", @@ -591,18 +655,19 @@ impl Sv1Server { if let Ok(Some(response_msg)) = self.clone().handle_message(Some(downstream_id), message) { - self.sv1_server_channel_state + if let Some(sender) = self + .sv1_server_channel_state .sv1_server_to_downstream_sender - .send(( - m.channel_id, - Some(downstream_id), - response_msg.into(), - )) - .map_err(|_| { - TproxyError::shutdown( + .get(&downstream_id) + .map(|r| r.value().clone()) + { + sender.send(response_msg.into()).await.map_err(|_| { + TproxyError::disconnect( TproxyErrorKind::ChannelErrorSender, + downstream_id, ) })?; + } } } } @@ -615,10 +680,19 @@ impl Sv1Server { )) })?; // send the set_difficulty message to the downstream - self.sv1_server_channel_state + if let Some(sender) = self + .sv1_server_channel_state .sv1_server_to_downstream_sender - .send((m.channel_id, None, set_difficulty)) - .map_err(|_| TproxyError::shutdown(TproxyErrorKind::ChannelErrorSender))?; + .get(&downstream_id) + .map(|r| r.value().clone()) + { + sender.send(set_difficulty).await.map_err(|_| { + TproxyError::disconnect( + TproxyErrorKind::ChannelErrorSender, + downstream_id, + ) + })?; + } } else { error!("Downstream not found for downstream_id: {}", downstream_id); } @@ -629,7 +703,12 @@ impl Sv1Server { "Received NewExtendedMiningJob for channel id: {}", m.channel_id ); - if let Some(prevhash) = self.prevhashes.get(&m.channel_id) { + // Clone the prevhash immediately so the DashMap guard is not held across .await. + if let Some(prevhash) = self + .prevhashes + .get(&m.channel_id) + .map(|r| r.value().clone()) + { let prevhash = prevhash.as_static(); let clean_jobs = m.job_id == prevhash.job_id; let notify = @@ -644,16 +723,18 @@ impl Sv1Server { AGGREGATED_CHANNEL_ID }; - let mut channel_jobs = self.valid_sv1_jobs.entry(job_channel_id).or_default(); - if clean_jobs { - channel_jobs.clear(); + { + let mut channel_jobs = + self.valid_sv1_jobs.entry(job_channel_id).or_default(); + if clean_jobs { + channel_jobs.clear(); + } + channel_jobs.push(notify_parsed); } - channel_jobs.push(notify_parsed); - let _ = self - .sv1_server_channel_state - .sv1_server_to_downstream_sender - .send((m.channel_id, None, notify.into())); + let notify_msg: stratum_apps::stratum_core::sv1_api::json_rpc::Message = + notify.into(); + self.send_to_channel(m.channel_id, notify_msg).await; } } @@ -702,7 +783,17 @@ impl Sv1Server { downstream_id: DownstreamId, ) -> TproxyResult<(), error::Sv1Server> { let config = &self.config.downstream_difficulty_config; - let downstream = self.downstreams.get(&downstream_id).unwrap(); + let Some(downstream) = self + .downstreams + .get(&downstream_id) + .map(|r| r.value().clone()) + else { + warn!( + "Downstream {} disconnected before channel could be opened, skipping", + downstream_id + ); + return Ok(()); + }; let hashrate = config.min_individual_miner_hashrate as f64; let shares_per_min = config.shares_per_minute as f64; @@ -744,22 +835,6 @@ impl Sv1Server { Ok(()) } - /// Retrieves a downstream connection by ID from the provided map. - /// - /// # Arguments - /// * `downstream_id` - The ID of the downstream connection to find - /// * `downstream` - HashMap containing downstream connections - /// - /// # Returns - /// * `Some(Downstream)` - If a downstream with the given ID exists - /// * `None` - If no downstream with the given ID is found - pub fn get_downstream( - downstream_id: DownstreamId, - downstream: HashMap, - ) -> Option { - downstream.get(&downstream_id).cloned() - } - /// Extracts the downstream ID from a Downstream instance. /// /// # Arguments @@ -787,6 +862,9 @@ impl Sv1Server { // Only remove from vardiff map if vardiff is enabled self.vardiff.remove(&downstream_id); } + self.sv1_server_channel_state + .sv1_server_to_downstream_sender + .remove(&downstream_id); let current_downstream = self.downstreams.remove(&downstream_id); if let Some((downstream_id, downstream)) = current_downstream { @@ -799,6 +877,7 @@ impl Sv1Server { let channel_id = downstream.downstream_data.super_safe_lock(|d| d.channel_id); if let Some(channel_id) = channel_id { + self.channel_id_to_downstream_id.remove(&channel_id); if !self.config.aggregate_channels { info!("Sending CloseChannel message: {channel_id} for downstream: {downstream_id}"); let reason_code = @@ -883,31 +962,38 @@ impl Sv1Server { target: Target, derived_hashrate: Option, ) -> TproxyResult<(), error::Sv1Server> { - for downstream in self.downstreams.iter() { - let downstream_id = downstream.key(); - let downstream = downstream.value(); - let channel_id = downstream.downstream_data.super_safe_lock(|d| { - let channel_id = d.channel_id?; - - d.set_upstream_target(target, *downstream_id); - d.set_pending_target(target, *downstream_id); - - // Update pending hashrate derived from the upstream target - if let Some(hr) = derived_hashrate { - d.set_pending_hashrate(Some(hr as f32), *downstream_id); + let tasks: Vec<(DownstreamId, _)> = self + .downstreams + .iter() + .filter_map(|entry| { + let downstream_id = *entry.key(); + let has_channel = entry.value().downstream_data.super_safe_lock(|d| { + let channel_id = d.channel_id?; + d.set_upstream_target(target, downstream_id); + d.set_pending_target(target, downstream_id); + if let Some(hr) = derived_hashrate { + d.set_pending_hashrate(Some(hr as f32), downstream_id); + } + Some(channel_id) + }); + if has_channel.is_none() { + trace!( + "Skipping downstream {}: no channel_id set (vardiff disabled)", + downstream_id + ); + return None; } - - Some(channel_id) - }); - - let Some(channel_id) = channel_id else { - trace!( - "Skipping downstream {}: no channel_id set (vardiff disabled)", - downstream_id - ); - continue; - }; - + let sender = self + .sv1_server_channel_state + .sv1_server_to_downstream_sender + .get(&downstream_id)? + .value() + .clone(); + Some((downstream_id, sender)) + }) + .collect(); + + for (downstream_id, sender) in tasks { let set_difficulty_msg = match build_sv1_set_difficulty_from_sv2_target(target) { Ok(msg) => msg, Err(e) => { @@ -918,12 +1004,7 @@ impl Sv1Server { return Err(TproxyError::shutdown(e)); } }; - - if let Err(e) = self - .sv1_server_channel_state - .sv1_server_to_downstream_sender - .send((channel_id, Some(*downstream_id), set_difficulty_msg)) - { + if let Err(e) = sender.send(set_difficulty_msg).await { error!( "Failed to send SetDifficulty to downstream {}: {:?}", downstream_id, e @@ -948,13 +1029,7 @@ impl Sv1Server { target: Target, derived_hashrate: Option, ) -> TproxyResult<(), error::Sv1Server> { - let affected = self.downstreams.iter().find(|downstream| { - downstream - .downstream_data - .super_safe_lock(|d| d.channel_id == Some(channel_id)) - }); - - let Some(downstream) = affected else { + let Some(downstream_id_ref) = self.channel_id_to_downstream_id.get(&channel_id) else { warn!( "No downstream found for channel {} when vardiff is disabled", channel_id @@ -976,17 +1051,17 @@ impl Sv1Server { TproxyErrorKind::DownstreamNotFoundWithChannelId(channel_id), )); }; - - let downstream_id = downstream.key(); - let downstream = downstream.value(); - + let downstream_id = *downstream_id_ref; + drop(downstream_id_ref); + let Some(downstream) = self.downstreams.get(&downstream_id) else { + return Ok(()); + }; downstream.downstream_data.super_safe_lock(|d| { - d.set_upstream_target(target, *downstream_id); - d.set_pending_target(target, *downstream_id); - + d.set_upstream_target(target, downstream_id); + d.set_pending_target(target, downstream_id); // Update pending hashrate derived from the upstream target if let Some(hr) = derived_hashrate { - d.set_pending_hashrate(Some(hr as f32), *downstream_id); + d.set_pending_hashrate(Some(hr as f32), downstream_id); } }); @@ -1001,21 +1076,24 @@ impl Sv1Server { } }; - if let Err(e) = self + let sender = self .sv1_server_channel_state .sv1_server_to_downstream_sender - .send((channel_id, Some(*downstream_id), set_difficulty_msg)) - { - error!( - "Failed to send SetDifficulty to downstream {}: {:?}", - downstream_id, e - ); - return Err(TproxyError::shutdown(TproxyErrorKind::ChannelErrorSender)); - } else { - debug!( - "Sent SetDifficulty to downstream {} for channel {} (vardiff disabled)", - downstream_id, channel_id - ); + .get(&downstream_id) + .map(|r| r.value().clone()); + if let Some(sender) = sender { + if let Err(e) = sender.send(set_difficulty_msg).await { + error!( + "Failed to send SetDifficulty to downstream {}: {:?}", + downstream_id, e + ); + return Err(TproxyError::shutdown(TproxyErrorKind::ChannelErrorSender)); + } else { + debug!( + "Sent SetDifficulty to downstream {} for channel {} (vardiff disabled)", + downstream_id, channel_id + ); + } } Ok(()) } @@ -1130,14 +1208,19 @@ impl Sv1Server { downstream_id, notify.job_id, notify.time.0 ); - if let Err(e) = self + let sent = match self .sv1_server_channel_state .sv1_server_to_downstream_sender - .send((channel_id.unwrap_or(0), Some(downstream_id), notify.into())) + .get(&downstream_id) + .map(|r| r.value().clone()) { + Some(sender) => sender.send(notify.into()).await.is_ok(), + None => false, + }; + if !sent { warn!( - "Failed to send keepalive job to downstream {}: {:?}", - downstream_id, e + "Failed to send keepalive job to downstream {}", + downstream_id ); } else if let Some(downstream) = self.downstreams.get(&downstream_id) { downstream.downstream_data.super_safe_lock(|d| { @@ -1224,7 +1307,7 @@ mod tests { use super::*; use crate::config::{DownstreamDifficultyConfig, TranslatorConfig, Upstream}; use async_channel::unbounded; - use std::{collections::HashMap, str::FromStr}; + use std::str::FromStr; use stratum_apps::key_utils::Secp256k1PublicKey; fn create_test_config() -> TranslatorConfig { @@ -1282,15 +1365,6 @@ mod tests { assert!(server.config.downstream_difficulty_config.enable_vardiff); } - #[test] - fn test_get_downstream_basic() { - let downstreams = HashMap::new(); - - // Test non-existing downstream - let not_found = Sv1Server::get_downstream(999, downstreams); - assert!(not_found.is_none()); - } - #[tokio::test] async fn test_send_set_difficulty_to_all_downstreams_empty() { let server = create_test_sv1_server(); diff --git a/pool-apps/Cargo.lock b/pool-apps/Cargo.lock index e3b322016..f15895de2 100644 --- a/pool-apps/Cargo.lock +++ b/pool-apps/Cargo.lock @@ -769,6 +769,20 @@ dependencies = [ "cipher", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -1890,6 +1904,7 @@ dependencies = [ "bitcoin_core_sv2", "clap", "config", + "dashmap", "hex", "hotpath", "serde", diff --git a/pool-apps/pool/Cargo.toml b/pool-apps/pool/Cargo.toml index 8d4551c17..9b3095c61 100644 --- a/pool-apps/pool/Cargo.toml +++ b/pool-apps/pool/Cargo.toml @@ -27,6 +27,7 @@ clap = { version = "4.5.39", features = ["derive"] } bitcoin_core_sv2 = { path = "../../bitcoin-core-sv2" } hex = "0.4.3" hotpath = "0.9" +dashmap = "6.1.0" [features] default = ["monitoring"] diff --git a/pool-apps/pool/src/lib/channel_manager/mining_message_handler.rs b/pool-apps/pool/src/lib/channel_manager/mining_message_handler.rs index 643716fbc..7814ce62d 100644 --- a/pool-apps/pool/src/lib/channel_manager/mining_message_handler.rs +++ b/pool-apps/pool/src/lib/channel_manager/mining_message_handler.rs @@ -247,7 +247,12 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -510,9 +515,13 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + error!("Failed to forward message {e:?}"); + } } - Ok(()) } @@ -672,7 +681,12 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -864,7 +878,12 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -995,7 +1014,12 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -1086,7 +1110,11 @@ impl HandleMiningMessagesFromClientAsync for ChannelManager { }) })?; - message.forward(&self.channel_manager_channel).await; + message + .forward(&self.channel_manager_channel) + .await + .map_err(|e| PoolError::disconnect(e, downstream_id))?; + Ok(()) } } diff --git a/pool-apps/pool/src/lib/channel_manager/mod.rs b/pool-apps/pool/src/lib/channel_manager/mod.rs index 47c7e34f2..334675ade 100644 --- a/pool-apps/pool/src/lib/channel_manager/mod.rs +++ b/pool-apps/pool/src/lib/channel_manager/mod.rs @@ -7,9 +7,10 @@ use std::{ }, }; -use async_channel::{Receiver, Sender}; +use async_channel::{unbounded, Receiver, Sender}; use bitcoin_core_sv2::CancellationToken; use core::sync::atomic::Ordering; +use dashmap::DashMap; use stratum_apps::{ coinbase_output_constraints::coinbase_output_constraints_message, config_helpers::CoinbaseRewardScript, @@ -37,7 +38,7 @@ use stratum_apps::{ task_manager::TaskManager, utils::types::{ChannelId, DownstreamId, SharesPerMinute, VardiffKey}, }; -use tokio::{net::TcpListener, select, sync::broadcast}; +use tokio::{net::TcpListener, select}; use tracing::{debug, error, info, warn}; use crate::{ @@ -45,6 +46,7 @@ use crate::{ downstream::Downstream, error::{self, PoolError, PoolErrorKind, PoolResult}, status::{handle_error, Status, StatusSender}, + utils::DownstreamMessage, }; mod mining_message_handler; @@ -81,7 +83,7 @@ pub struct ChannelManagerData { pub struct ChannelManagerChannel { tp_sender: Sender>, tp_receiver: Receiver>, - downstream_sender: broadcast::Sender<(usize, Mining<'static>, Option>)>, + downstream_sender: Arc>>, downstream_receiver: Receiver<(usize, Mining<'static>, Option>)>, } @@ -110,7 +112,6 @@ impl ChannelManager { config: PoolConfig, tp_sender: Sender>, tp_receiver: Receiver>, - downstream_sender: broadcast::Sender<(DownstreamId, Mining<'static>, Option>)>, downstream_receiver: Receiver<(DownstreamId, Mining<'static>, Option>)>, coinbase_outputs: Vec, ) -> PoolResult { @@ -150,7 +151,7 @@ impl ChannelManager { let channel_manager_channel = ChannelManagerChannel { tp_sender, tp_receiver, - downstream_sender, + downstream_sender: Arc::new(DashMap::new()), downstream_receiver, }; @@ -230,11 +231,6 @@ impl ChannelManager { cancellation_token: CancellationToken, status_sender: Sender, channel_manager_sender: Sender<(DownstreamId, Mining<'static>, Option>)>, - channel_manager_receiver: broadcast::Sender<( - DownstreamId, - Mining<'static>, - Option>, - )>, ) -> PoolResult<(), error::ChannelManager> { // todo: let start_downstream_server accept Arc, instead of clone. let this = Arc::new(self); @@ -287,7 +283,6 @@ impl ChannelManager { let cancellation_token_inner = cancellation_token_clone.clone(); let status_sender_inner = status_sender.clone(); let channel_manager_sender_inner = channel_manager_sender.clone(); - let channel_manager_receiver_inner = channel_manager_receiver.clone(); let task_manager_inner = task_manager_clone.clone(); task_manager_clone.spawn(async move { @@ -323,12 +318,16 @@ impl ChannelManager { } }; + let (channel_manager_sender, channel_manager_receiver) = unbounded(); + + + let downstream = Downstream::new( downstream_id, channel_id_factory, group_channel, channel_manager_sender_inner, - channel_manager_receiver_inner, + channel_manager_receiver, noise_stream, cancellation_token_inner.clone(), task_manager_inner.clone(), @@ -336,6 +335,8 @@ impl ChannelManager { this.required_extensions.clone(), ); + this.channel_manager_channel.downstream_sender.insert(downstream_id, channel_manager_sender); + this.channel_manager_data.super_safe_lock(|data| { data.downstream.insert(downstream_id, downstream.clone()); }); @@ -431,6 +432,9 @@ impl ChannelManager { .vardiff .retain(|key, _| key.downstream_id != downstream_id); }); + self.channel_manager_channel + .downstream_sender + .remove(&downstream_id); Ok(()) } @@ -611,7 +615,12 @@ impl ChannelManager { }); for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + error!("Failed to forward message {e:?}"); + } } info!("Vardiff update cycle complete"); @@ -646,7 +655,7 @@ impl ChannelManager { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub enum RouteMessageTo<'a> { /// Route to the template provider subsystem. TemplateProvider(TemplateDistribution<'a>), @@ -667,21 +676,27 @@ impl<'a> From<(DownstreamId, Mining<'a>)> for RouteMessageTo<'a> { } impl RouteMessageTo<'_> { - pub async fn forward(self, channel_manager_channel: &ChannelManagerChannel) { + pub async fn forward( + self, + channel_manager_channel: &ChannelManagerChannel, + ) -> Result<(), PoolErrorKind> { match self { RouteMessageTo::Downstream((downstream_id, message)) => { - _ = channel_manager_channel.downstream_sender.send(( - downstream_id, - message.into_static(), - None, - )); + let sender = channel_manager_channel + .downstream_sender + .get(&downstream_id) + .map(|r| r.value().clone()); + if let Some(sender) = sender { + sender.send((message.into_static(), None)).await?; + } } RouteMessageTo::TemplateProvider(message) => { - _ = channel_manager_channel + channel_manager_channel .tp_sender .send(message.into_static()) - .await; + .await?; } } + Ok(()) } } diff --git a/pool-apps/pool/src/lib/channel_manager/template_distribution_message_handler.rs b/pool-apps/pool/src/lib/channel_manager/template_distribution_message_handler.rs index 5024e11cb..8c85a24c7 100644 --- a/pool-apps/pool/src/lib/channel_manager/template_distribution_message_handler.rs +++ b/pool-apps/pool/src/lib/channel_manager/template_distribution_message_handler.rs @@ -128,7 +128,12 @@ impl HandleTemplateDistributionMessagesFromServerAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) @@ -245,7 +250,12 @@ impl HandleTemplateDistributionMessagesFromServerAsync for ChannelManager { })?; for message in messages { - message.forward(&self.channel_manager_channel).await; + // A send can only fail if the receiver side of the channel is closed. + // Since this is an unbounded channel, it cannot fail due to capacity + // limits (which would only apply to bounded channels). + if let Err(e) = message.forward(&self.channel_manager_channel).await { + tracing::error!("Failed to forward message {e:?}"); + } } Ok(()) diff --git a/pool-apps/pool/src/lib/downstream/mod.rs b/pool-apps/pool/src/lib/downstream/mod.rs index 7917f1045..f135ae93d 100644 --- a/pool-apps/pool/src/lib/downstream/mod.rs +++ b/pool-apps/pool/src/lib/downstream/mod.rs @@ -29,7 +29,6 @@ use stratum_apps::{ types::{ChannelId, DownstreamId, Message, Sv2Frame}, }, }; -use tokio::sync::broadcast; use tracing::{debug, error, warn}; use crate::{ @@ -71,7 +70,7 @@ pub struct DownstreamData { #[derive(Clone)] pub struct DownstreamChannel { channel_manager_sender: Sender<(DownstreamId, Mining<'static>, Option>)>, - channel_manager_receiver: broadcast::Sender<(DownstreamId, Mining<'static>, Option>)>, + channel_manager_receiver: Receiver<(Mining<'static>, Option>)>, downstream_sender: Sender, downstream_receiver: Receiver, /// Per-connection cancellation token (child of the global token). @@ -103,11 +102,7 @@ impl Downstream { channel_id_factory: AtomicU32, group_channel: GroupChannel<'static, DefaultJobStore>>, channel_manager_sender: Sender<(DownstreamId, Mining<'static>, Option>)>, - channel_manager_receiver: broadcast::Sender<( - DownstreamId, - Mining<'static>, - Option>, - )>, + channel_manager_receiver: Receiver<(Mining<'static>, Option>)>, noise_stream: NoiseTcpStream, cancellation_token: CancellationToken, task_manager: Arc, @@ -186,7 +181,6 @@ impl Downstream { return; } - let mut receiver = self.downstream_channel.channel_manager_receiver.subscribe(); task_manager.spawn(async move { loop { let mut self_clone_1 = self.clone(); @@ -205,7 +199,7 @@ impl Downstream { } } } - res = self_clone_2.handle_channel_manager_message(&mut receiver) => { + res = self_clone_2.handle_channel_manager_message() => { if let Err(e) = res { error!(?e, "Error handling channel manager message for {downstream_id}"); if handle_error(&status_sender, e).await { @@ -213,7 +207,6 @@ impl Downstream { } } } - } } @@ -255,26 +248,20 @@ impl Downstream { } // Handles messages sent from the channel manager to this downstream. - async fn handle_channel_manager_message( - self, - receiver: &mut broadcast::Receiver<(DownstreamId, Mining<'static>, Option>)>, - ) -> PoolResult<(), error::Downstream> { - let (downstream_id, msg, _tlv_fields) = match receiver.recv().await { + async fn handle_channel_manager_message(self) -> PoolResult<(), error::Downstream> { + let (msg, _tlv_fields) = match self + .downstream_channel + .channel_manager_receiver + .recv() + .await + { Ok(msg) => msg, Err(e) => { warn!(?e, "Broadcast receive failed"); - return Ok(()); + return Err(PoolError::shutdown(PoolErrorKind::ChannelRecv(e))); } }; - if downstream_id != self.downstream_id { - debug!( - ?downstream_id, - "Message ignored for non-matching downstream" - ); - return Ok(()); - } - let message = AnyMessage::Mining(msg); let std_frame: Sv2Frame = message.try_into().map_err(PoolError::shutdown)?; diff --git a/pool-apps/pool/src/lib/mod.rs b/pool-apps/pool/src/lib/mod.rs index 21b2d12f4..9b852e023 100644 --- a/pool-apps/pool/src/lib/mod.rs +++ b/pool-apps/pool/src/lib/mod.rs @@ -13,7 +13,7 @@ use stratum_apps::{ stratum_core::bitcoin::consensus::Encodable, task_manager::TaskManager, tp_type::TemplateProviderType, utils::types::GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS, }; -use tokio::sync::{broadcast, Notify}; +use tokio::sync::Notify; use tracing::{debug, error, info, warn}; use crate::{ @@ -72,8 +72,6 @@ impl PoolSv2 { let (status_sender, status_receiver) = unbounded(); - let (channel_manager_to_downstream_sender, _channel_manager_to_downstream_receiver) = - broadcast::channel(10); let (downstream_to_channel_manager_sender, downstream_to_channel_manager_receiver) = unbounded(); @@ -86,7 +84,6 @@ impl PoolSv2 { self.config.clone(), channel_manager_to_tp_sender.clone(), tp_to_channel_manager_receiver, - channel_manager_to_downstream_sender.clone(), downstream_to_channel_manager_receiver, encoded_outputs.clone(), ) @@ -211,7 +208,6 @@ impl PoolSv2 { cancellation_token.clone(), status_sender, downstream_to_channel_manager_sender, - channel_manager_to_downstream_sender, ) .await?; diff --git a/pool-apps/pool/src/lib/utils.rs b/pool-apps/pool/src/lib/utils.rs index 75d9240f2..f648dfdba 100644 --- a/pool-apps/pool/src/lib/utils.rs +++ b/pool-apps/pool/src/lib/utils.rs @@ -4,12 +4,15 @@ use stratum_apps::{ binary_sv2::Str0255, common_messages_sv2::{Protocol, SetupConnection}, mining_sv2::CloseChannel, + parsers_sv2::{Mining, Tlv}, }, utils::types::ChannelId, }; use crate::error::PoolErrorKind; +pub(crate) type DownstreamMessage = (Mining<'static>, Option>); + /// Constructs a `SetupConnection` message for the mining protocol. #[allow(clippy::result_large_err)] pub fn get_setup_connection_message(