diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 00000000..5b4c30df --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,583 @@ +use core::sync::atomic::AtomicU64; +use std::{ + borrow::Cow, + collections::{HashMap, VecDeque}, +}; + +use crate::{ + proto::{Decode, Decoded, Encode, IncomingPacket, MessageType, OutgoingPacket}, + service::{ConnectionEvent, Service}, + Error, +}; + +const BUFFER_SIZE: u32 = 1024; + +pub(crate) struct SshConnectionService { + connection_id: ConnectionId, + next_channel_id: u32, + channels: HashMap, + pending_channels: VecDeque, + pending_packets: VecDeque>, +} + +impl Service for SshConnectionService { + const NAME: &'static [u8] = b"ssh-connection"; + + fn poll_transmit(&mut self) -> Option> { + self.pending_packets.pop_front() + } + + fn poll_event(&mut self) -> Option { + // FIXME: Implement clean closing of connections. + None + } + + fn handle_packet(&mut self, packet: IncomingPacket<'_>) -> Result<(), Error> { + match packet.message_type { + MessageType::ChannelOpen => { + let message = ChannelOpen::try_from(packet)?; + self.pending_channels.push_back(PendingChannelData { + channel_type: message.channel_type.to_vec(), + remote_id: message.remote_id, + initial_window_size: message.initial_window_size, + max_packet_size: message.max_packet_size, + type_specific_data: message.type_specific_data.to_vec(), + }); + } + MessageType::ChannelData => { + let message = ChannelData::try_from(packet)?; + if let Some(ChannelState::Active(channel_state)) = + self.channels.get_mut(&message.channel_id) + { + channel_state.stdin.extend_from_slice(message.data); + } + } + MessageType::ChannelWindowAdjust => { + let message = ChannelWindowAdjust::try_from(packet)?; + if let Some(ChannelState::Active(channel_state)) = + self.channels.get_mut(&message.channel_id) + { + channel_state.window_size = channel_state + .window_size + .saturating_add(message.bytes_to_add); + } + } + MessageType::ChannelClose => { + let message = ChannelClose::try_from(packet)?; + if let Some(ChannelState::Active(channel_state)) = + self.channels.get_mut(&message.channel_id) + { + self.pending_packets.push_back( + ChannelClose { + channel_id: channel_state.remote_id, + } + .into_packet(), + ); + } + self.channels.remove(&message.channel_id); + } + MessageType::ChannelRequest => { + let message = ChannelRequest::try_from(packet)?; + if let Some(ChannelState::Active(channel_state)) = + self.channels.get_mut(&message.channel_id) + { + // FIXME: Implement proper handling instead of this shim to make the remote happy + match message.request_type { + b"pty-req" | b"shell" if message.want_reply => { + self.pending_packets.push_back( + ChannelSuccess { + channel_id: channel_state.remote_id, + } + .into_packet(), + ); + } + _ if message.want_reply => { + self.pending_packets.push_back( + ChannelFailure { + channel_id: channel_state.remote_id, + } + .into_packet(), + ); + } + _ => {} + } + } + } + _ => self.pending_packets.push_back(packet.unimplemented()), + } + + Ok(()) + } +} + +impl SshConnectionService { + #[expect(unused)] + pub(crate) fn new() -> Self { + Self { + connection_id: ConnectionId::new(), + next_channel_id: 0, + channels: HashMap::new(), + pending_channels: VecDeque::new(), + pending_packets: VecDeque::new(), + } + } + + #[expect(unused)] + pub(crate) fn poll_pending_channel<'a>(&'a mut self) -> Option> { + self.pending_channels + .pop_front() + .map(|data| PendingChannel { + data, + connection: self, + }) + } + + fn get_next_channel_id(&mut self) -> u32 { + loop { + let id = self.next_channel_id; + self.next_channel_id += 1; + if !self.channels.contains_key(&id) { + return id; + } + } + } +} + +pub(crate) struct PendingChannel<'a> { + data: PendingChannelData, + connection: &'a mut SshConnectionService, +} + +impl PendingChannel<'_> { + #[expect(unused)] + pub(crate) fn channel_type(&self) -> &[u8] { + &self.data.channel_type + } + + #[expect(unused)] + pub(crate) fn type_specific_data(&self) -> &[u8] { + &self.data.type_specific_data + } + + #[expect(unused)] + pub(crate) fn accept(self) -> ChannelId { + let our_id = self.connection.get_next_channel_id(); + self.connection.pending_packets.push_back( + ChannelOpenConfirmation { + remote_id: self.data.remote_id, + our_id, + initial_window_size: BUFFER_SIZE, + maximum_packet_size: BUFFER_SIZE, + } + .into_packet(), + ); + self.connection.channels.insert( + our_id, + ChannelState::Active(ActiveChannelData { + remote_id: self.data.remote_id, + window_size: self.data.initial_window_size, + max_packet_size: self.data.max_packet_size, + stdin: vec![], + }), + ); + + let result = ChannelId { + connection_id: self.connection.connection_id, + our_channel_id: our_id, + }; + //Inhibit drop + core::mem::forget(self); + result + } + + #[expect(unused)] + pub(crate) fn decline(self) { + // Actual logic is in drop of self. + } +} + +impl Drop for PendingChannel<'_> { + fn drop(&mut self) { + self.connection.pending_packets.push_back( + ChannelOpenFailure { + remote_id: self.data.remote_id, + } + .into_packet(), + ); + } +} + +pub(crate) struct Channel<'a> { + id: u32, + connection: &'a mut SshConnectionService, +} + +impl Channel<'_> { + #[expect(unused)] + pub(crate) fn poll_recv(&mut self) -> Option> { + let Some(ChannelState::Active(state)) = self.connection.channels.get_mut(&self.id) else { + unreachable!("Channel struct for non-existing channel"); + }; + + if state.stdin.is_empty() { + None + } else { + let mut result = vec![]; + core::mem::swap(&mut result, &mut state.stdin); + self.connection.pending_packets.push_back( + ChannelWindowAdjust { + channel_id: state.remote_id, + bytes_to_add: state.stdin.len() as u32, + } + .into_packet(), + ); + Some(result) + } + } + + #[expect(unused)] + fn send(&mut self, buf: &[u8]) -> usize { + let Some(ChannelState::Active(state)) = self.connection.channels.get_mut(&self.id) else { + unreachable!("Channel struct for non-existing channel"); + }; + + let output_len = buf + .len() + .min(state.max_packet_size.min(state.window_size) as usize); + if output_len > 0 { + self.connection.pending_packets.push_back( + ChannelData { + channel_id: state.remote_id, + data: &buf[..output_len], + } + .into_packet(), + ); + state.window_size -= output_len as u32; + } + + output_len + } + + #[expect(unused)] + fn close(self) { + let Some(ChannelState::Active(state)) = self.connection.channels.get_mut(&self.id) else { + unreachable!("Channel struct for non-existing channel"); + }; + + self.connection.pending_packets.push_back( + ChannelClose { + channel_id: state.remote_id, + } + .into_packet(), + ); + *self.connection.channels.get_mut(&self.id).unwrap() = ChannelState::Closing; + } +} + +struct PendingChannelData { + channel_type: Vec, + remote_id: u32, + initial_window_size: u32, + max_packet_size: u32, + type_specific_data: Vec, +} + +enum ChannelState { + Closing, + Active(ActiveChannelData), +} + +struct ActiveChannelData { + remote_id: u32, + window_size: u32, + max_packet_size: u32, + stdin: Vec, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub(crate) struct ChannelId { + connection_id: ConnectionId, + our_channel_id: u32, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +struct ConnectionId { + id: u64, +} + +impl ConnectionId { + fn new() -> Self { + static NEXT: AtomicU64 = AtomicU64::new(0); + Self { + id: NEXT.fetch_add(1, core::sync::atomic::Ordering::Relaxed), + } + } +} + +#[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" +)] +struct ChannelOpen<'a> { + channel_type: &'a [u8], + remote_id: u32, + initial_window_size: u32, + max_packet_size: u32, + type_specific_data: &'a [u8], +} + +impl<'a> TryFrom> for ChannelOpen<'a> { + type Error = Error; + + fn try_from(packet: IncomingPacket<'a>) -> Result { + if packet.message_type != MessageType::ChannelOpen { + return Err(Error::InvalidPacket("unexpected message type")); + } + + let Decoded { + value: channel_type, + next, + } = <&[u8]>::decode(packet.payload)?; + let Decoded { + value: remote_id, + next, + } = u32::decode(next)?; + let Decoded { + value: initial_window_size, + next, + } = u32::decode(next)?; + let Decoded { + value: max_packet_size, + next: type_specific_data, + } = u32::decode(next)?; + Ok(ChannelOpen { + channel_type, + remote_id, + initial_window_size, + max_packet_size, + type_specific_data, + }) + } +} + +struct ChannelOpenConfirmation { + remote_id: u32, + our_id: u32, + initial_window_size: u32, + maximum_packet_size: u32, +} + +impl ChannelOpenConfirmation { + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(16); + self.remote_id.encode(&mut payload); + self.our_id.encode(&mut payload); + self.initial_window_size.encode(&mut payload); + self.maximum_packet_size.encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelOpenConfirmation, + payload: Cow::Owned(payload), + } + } +} + +struct ChannelOpenFailure { + remote_id: u32, +} + +impl ChannelOpenFailure { + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(16); + self.remote_id.encode(&mut payload); + 2u32.encode(&mut payload); + b"".encode(&mut payload); + b"".encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelOpenFailure, + payload: Cow::Owned(payload), + } + } +} + +struct ChannelData<'a> { + channel_id: u32, + data: &'a [u8], +} + +impl<'a> TryFrom> for ChannelData<'a> { + type Error = Error; + + fn try_from(packet: IncomingPacket<'a>) -> Result { + if packet.message_type != MessageType::ChannelData { + return Err(Error::InvalidPacket("unexpected message type")); + } + + let Decoded { + value: channel_id, + next, + } = u32::decode(packet.payload)?; + let Decoded { value: data, .. } = <&[u8]>::decode(next)?; + + Ok(ChannelData { channel_id, data }) + } +} + +impl ChannelData<'_> { + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(self.data.len() + 8); + self.channel_id.encode(&mut payload); + self.data.encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelData, + payload: Cow::Owned(payload), + } + } +} + +struct ChannelWindowAdjust { + channel_id: u32, + bytes_to_add: u32, +} + +impl TryFrom> for ChannelWindowAdjust { + type Error = Error; + + fn try_from(packet: IncomingPacket<'_>) -> Result { + if packet.message_type != MessageType::ChannelWindowAdjust { + return Err(Error::InvalidPacket("unexpected message type")); + } + + let Decoded { + value: channel_id, + next, + } = u32::decode(packet.payload)?; + let Decoded { + value: bytes_to_add, + .. + } = u32::decode(next)?; + + Ok(Self { + channel_id, + bytes_to_add, + }) + } +} + +impl ChannelWindowAdjust { + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(8); + self.channel_id.encode(&mut payload); + self.bytes_to_add.encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelWindowAdjust, + payload: Cow::Owned(payload), + } + } +} + +struct ChannelClose { + channel_id: u32, +} + +impl TryFrom> for ChannelClose { + type Error = Error; + + fn try_from(packet: IncomingPacket<'_>) -> Result { + if packet.message_type != MessageType::ChannelClose { + return Err(Error::InvalidPacket("unexpected message type")); + } + + let Decoded { + value: channel_id, .. + } = u32::decode(packet.payload)?; + + Ok(Self { channel_id }) + } +} + +impl ChannelClose { + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(4); + self.channel_id.encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelClose, + payload: Cow::Owned(payload), + } + } +} + +#[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" +)] +struct ChannelRequest<'a> { + channel_id: u32, + request_type: &'a [u8], + want_reply: bool, +} + +impl<'a> TryFrom> for ChannelRequest<'a> { + type Error = Error; + + fn try_from(packet: IncomingPacket<'a>) -> Result { + if packet.message_type != MessageType::ChannelRequest { + return Err(Error::InvalidPacket("unexpected message type")); + } + + let Decoded { + value: channel_id, + next, + } = u32::decode(packet.payload)?; + let Decoded { + value: request_type, + next, + } = <&[u8]>::decode(next)?; + let Decoded { + value: want_reply, .. + } = bool::decode(next)?; + + Ok(ChannelRequest { + channel_id, + request_type, + want_reply, + }) + } +} + +struct ChannelSuccess { + channel_id: u32, +} + +impl ChannelSuccess { + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(4); + self.channel_id.encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelSuccess, + payload: Cow::Owned(payload), + } + } +} + +struct ChannelFailure { + channel_id: u32, +} + +impl ChannelFailure { + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + fn into_packet(self) -> OutgoingPacket<'static> { + let mut payload = Vec::with_capacity(4); + self.channel_id.encode(&mut payload); + OutgoingPacket { + message_type: MessageType::ChannelFailure, + payload: Cow::Owned(payload), + } + } +} diff --git a/src/key_exchange.rs b/src/key_exchange.rs index 7fb8cc21..b929c2d9 100644 --- a/src/key_exchange.rs +++ b/src/key_exchange.rs @@ -109,18 +109,14 @@ impl<'a> TryFrom> for EcdhKeyExchangeInit<'a> { type Error = Error; fn try_from(packet: IncomingPacket<'a>) -> Result { - let Decoded { - value: r#type, - next, - } = MessageType::decode(packet.payload)?; - if r#type != MessageType::KeyExchangeEcdhInit { + if packet.message_type != MessageType::KeyExchangeEcdhInit { return Err(Error::InvalidPacket("unexpected message type")); } let Decoded { value: client_ephemeral_public_key, next, - } = <&[u8]>::decode(next)?; + } = <&[u8]>::decode(packet.payload)?; if !next.is_empty() { debug!(bytes = ?next, "unexpected trailing bytes"); @@ -303,18 +299,14 @@ impl<'a> TryFrom> for KeyExchangeInit<'a> { type Error = Error; fn try_from(packet: IncomingPacket<'a>) -> Result { - let Decoded { - value: r#type, - next, - } = MessageType::decode(packet.payload)?; - if r#type != MessageType::KeyExchangeInit { + if packet.message_type != MessageType::KeyExchangeInit { return Err(Error::InvalidPacket("unexpected message type")); } let Decoded { value: cookie, next, - } = <[u8; 16]>::decode(next)?; + } = <[u8; 16]>::decode(packet.payload)?; let Decoded { value: key_exchange_algorithms, @@ -426,16 +418,12 @@ impl<'a> TryFrom> for NewKeys { type Error = Error; fn try_from(packet: IncomingPacket<'a>) -> Result { - let Decoded { - value: r#type, - next, - } = MessageType::decode(packet.payload)?; - if r#type != MessageType::NewKeys { + if packet.message_type != MessageType::NewKeys { return Err(Error::InvalidPacket("unexpected message type")); } - if !next.is_empty() { - debug!(bytes = ?next, "unexpected trailing bytes"); + if !packet.payload.is_empty() { + debug!(bytes = ?packet.payload, "unexpected trailing bytes"); return Err(Error::InvalidPacket("unexpected trailing bytes")); } diff --git a/src/lib.rs b/src/lib.rs index 32f96645..ea1f3f2f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,10 +6,12 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, error, instrument, warn}; +mod connection; mod key_exchange; use key_exchange::KeyExchange; mod proto; use proto::{AesCtrWriteKeys, Completion, Decoded, MessageType, ReadState, WriteState}; +mod service; use crate::{ key_exchange::{EcdhKeyExchangeInit, KeyExchangeInit, NewKeys, RawKeySet}, @@ -57,7 +59,9 @@ impl Connection { return; } }; - exchange.prefixed(packet.payload); + exchange.update(&((packet.payload.len() + 1) as u32).to_be_bytes()); + exchange.update(&[u8::from(packet.message_type)]); + exchange.update(packet.payload); let peer_key_exchange_init = match KeyExchangeInit::try_from(packet) { Ok(key_exchange_init) => key_exchange_init, Err(error) => { diff --git a/src/proto.rs b/src/proto.rs index 410c8173..36cf800d 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -2,7 +2,7 @@ use core::future; use core::iter; use core::pin::Pin; use core::task::{ready, Context, Poll}; -use std::io; +use std::{borrow::Cow, io}; use aws_lc_rs::{ cipher::{self, StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey}, @@ -41,7 +41,9 @@ impl ReadState { ) -> Result, Error> { loop { match self.poll_packet()? { - Completion::Complete(packet_length) => return self.decode_packet(packet_length), + Completion::Complete((sequence_number, packet_length)) => { + return self.decode_packet(sequence_number, packet_length) + } Completion::Incomplete(_amount) => { let _ = self.buffer(stream).await?; continue; @@ -51,7 +53,7 @@ impl ReadState { } // This and decode_packet are split because of a borrowck limitation - pub(crate) fn poll_packet(&mut self) -> Result, Error> { + pub(crate) fn poll_packet(&mut self) -> Result, Error> { // Compact the internal buffer if self.last_length > 0 { self.buf.copy_within(self.last_length.., 0); @@ -145,13 +147,15 @@ impl ReadState { // Note: this needs to be done AFTER the IO to ensure // this async function is cancel-safe + let sequence_number = self.sequence_number; self.sequence_number = self.sequence_number.wrapping_add(1); self.last_length = 4 + packet_length.inner as usize + mac_len; - Ok(Completion::Complete(packet_length)) + Ok(Completion::Complete((sequence_number, packet_length))) } pub(crate) fn decode_packet<'a>( &'a self, + sequence_number: u32, packet_length: PacketLength, ) -> Result, Error> { let Decoded { @@ -164,6 +168,17 @@ impl ReadState { return Err(Error::Incomplete(Some(payload_len - next.len()))); }; + let Decoded { + value: message_type, + next: payload, + } = MessageType::decode(payload).map_err(|e| { + if matches!(e, Error::Incomplete(_)) { + Error::InvalidPacket("Packet without message type") + } else { + e + } + })?; + let Some(next) = next.get(payload_len..) else { return Err(Error::Unreachable( "unable to extract rest after fixed-length slice", @@ -176,7 +191,11 @@ impl ReadState { ))); }; - Ok(IncomingPacket { payload }) + Ok(IncomingPacket { + sequence_number, + message_type, + payload, + }) } pub(crate) async fn buffer<'a>( @@ -403,6 +422,20 @@ pub(crate) enum MessageType { NewKeys, KeyExchangeEcdhInit, KeyExchangeEcdhReply, + GlobalRequest, + RequestSuccess, + RequestFailure, + ChannelOpen, + ChannelOpenConfirmation, + ChannelOpenFailure, + ChannelWindowAdjust, + ChannelData, + ChannelExtendedData, + ChannelEof, + ChannelClose, + ChannelRequest, + ChannelSuccess, + ChannelFailure, Unknown(u8), } @@ -435,6 +468,20 @@ impl From for MessageType { 21 => Self::NewKeys, 30 => Self::KeyExchangeEcdhInit, 31 => Self::KeyExchangeEcdhReply, + 80 => Self::GlobalRequest, + 81 => Self::RequestSuccess, + 82 => Self::RequestFailure, + 90 => Self::ChannelOpen, + 91 => Self::ChannelOpenConfirmation, + 92 => Self::ChannelOpenFailure, + 93 => Self::ChannelWindowAdjust, + 94 => Self::ChannelData, + 95 => Self::ChannelExtendedData, + 96 => Self::ChannelEof, + 97 => Self::ChannelClose, + 98 => Self::ChannelRequest, + 99 => Self::ChannelSuccess, + 100 => Self::ChannelFailure, value => Self::Unknown(value), } } @@ -453,15 +500,49 @@ impl From for u8 { MessageType::NewKeys => 21, MessageType::KeyExchangeEcdhInit => 30, MessageType::KeyExchangeEcdhReply => 31, + MessageType::GlobalRequest => 80, + MessageType::RequestSuccess => 81, + MessageType::RequestFailure => 82, + MessageType::ChannelOpen => 90, + MessageType::ChannelOpenConfirmation => 91, + MessageType::ChannelOpenFailure => 92, + MessageType::ChannelWindowAdjust => 93, + MessageType::ChannelData => 94, + MessageType::ChannelExtendedData => 95, + MessageType::ChannelEof => 96, + MessageType::ChannelClose => 97, + MessageType::ChannelRequest => 98, + MessageType::ChannelSuccess => 99, + MessageType::ChannelFailure => 100, MessageType::Unknown(value) => value, } } } pub(crate) struct IncomingPacket<'a> { + pub(crate) sequence_number: u32, + pub(crate) message_type: MessageType, pub(crate) payload: &'a [u8], } +impl IncomingPacket<'_> { + #[expect(unused)] + pub(crate) fn unimplemented(self) -> OutgoingPacket<'static> { + OutgoingPacket { + message_type: MessageType::Unimplemented, + payload: Cow::Owned(self.sequence_number.to_be_bytes().to_vec()), + } + } +} + +pub(crate) struct OutgoingPacket<'a> { + #[expect(unused)] + pub(crate) message_type: MessageType, + // FIXME: Figure out a way to encode the payload requiring fewer boxes/copies. + #[expect(unused)] + pub(crate) payload: Cow<'a, [u8]>, +} + /// An encoded outgoing packet including length field and padding, but /// excluding encryption and MAC #[must_use] diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 00000000..f2b1fd15 --- /dev/null +++ b/src/service.rs @@ -0,0 +1,31 @@ +use crate::{ + proto::{IncomingPacket, OutgoingPacket}, + Error, +}; + +// FIXME: move this to the transport connection code once that is implemented. +#[expect(unused)] +pub(crate) enum ConnectionEvent { + Close, +} + +#[expect(unused)] +pub(crate) trait Service { + /// Service name used by SshTransportConnection during handshake + const NAME: &'static [u8]; + + /// Poll for packets to transmit through the transport layer. + /// + /// Should be called first of the poll functions. + fn poll_transmit(&mut self) -> Option>; + /// Poll for connection events that need handling by the + /// transport layer. + /// + /// Should be called second of the poll functions. However + /// services should ensure themselves that all outgoing packets + /// are sent before emitting a connectionevent that results in + /// termination of the connection or service. + fn poll_event(&mut self) -> Option; + /// Handle a packet + fn handle_packet(&mut self, packet: IncomingPacket<'_>) -> Result<(), Error>; +}