diff --git a/src/key_exchange.rs b/src/key_exchange.rs index 7fb8cc21..b929c2d9 100644 --- a/src/key_exchange.rs +++ b/src/key_exchange.rs @@ -109,18 +109,14 @@ impl<'a> TryFrom> for EcdhKeyExchangeInit<'a> { type Error = Error; fn try_from(packet: IncomingPacket<'a>) -> Result { - let Decoded { - value: r#type, - next, - } = MessageType::decode(packet.payload)?; - if r#type != MessageType::KeyExchangeEcdhInit { + if packet.message_type != MessageType::KeyExchangeEcdhInit { return Err(Error::InvalidPacket("unexpected message type")); } let Decoded { value: client_ephemeral_public_key, next, - } = <&[u8]>::decode(next)?; + } = <&[u8]>::decode(packet.payload)?; if !next.is_empty() { debug!(bytes = ?next, "unexpected trailing bytes"); @@ -303,18 +299,14 @@ impl<'a> TryFrom> for KeyExchangeInit<'a> { type Error = Error; fn try_from(packet: IncomingPacket<'a>) -> Result { - let Decoded { - value: r#type, - next, - } = MessageType::decode(packet.payload)?; - if r#type != MessageType::KeyExchangeInit { + if packet.message_type != MessageType::KeyExchangeInit { return Err(Error::InvalidPacket("unexpected message type")); } let Decoded { value: cookie, next, - } = <[u8; 16]>::decode(next)?; + } = <[u8; 16]>::decode(packet.payload)?; let Decoded { value: key_exchange_algorithms, @@ -426,16 +418,12 @@ impl<'a> TryFrom> for NewKeys { type Error = Error; fn try_from(packet: IncomingPacket<'a>) -> Result { - let Decoded { - value: r#type, - next, - } = MessageType::decode(packet.payload)?; - if r#type != MessageType::NewKeys { + if packet.message_type != MessageType::NewKeys { return Err(Error::InvalidPacket("unexpected message type")); } - if !next.is_empty() { - debug!(bytes = ?next, "unexpected trailing bytes"); + if !packet.payload.is_empty() { + debug!(bytes = ?packet.payload, "unexpected trailing bytes"); return Err(Error::InvalidPacket("unexpected trailing bytes")); } diff --git a/src/lib.rs b/src/lib.rs index 32f96645..1fe7546a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,8 @@ mod key_exchange; use key_exchange::KeyExchange; mod proto; use proto::{AesCtrWriteKeys, Completion, Decoded, MessageType, ReadState, WriteState}; +mod service; +mod userauth; use crate::{ key_exchange::{EcdhKeyExchangeInit, KeyExchangeInit, NewKeys, RawKeySet}, @@ -57,7 +59,9 @@ impl Connection { return; } }; - exchange.prefixed(packet.payload); + exchange.update(&((packet.payload.len() + 1) as u32).to_be_bytes()); + exchange.update(&[u8::from(packet.message_type)]); + exchange.update(packet.payload); let peer_key_exchange_init = match KeyExchangeInit::try_from(packet) { Ok(key_exchange_init) => key_exchange_init, Err(error) => { @@ -296,6 +300,12 @@ enum Error { InvalidMac, #[error("unreachable code: {0}")] Unreachable(&'static str), + #[error("not ready for new packets")] + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + NotReady, } #[derive(Debug, Error)] diff --git a/src/proto.rs b/src/proto.rs index 410c8173..04765dbe 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -2,7 +2,7 @@ use core::future; use core::iter; use core::pin::Pin; use core::task::{ready, Context, Poll}; -use std::io; +use std::{borrow::Cow, io}; use aws_lc_rs::{ cipher::{self, StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey}, @@ -41,7 +41,9 @@ impl ReadState { ) -> Result, Error> { loop { match self.poll_packet()? { - Completion::Complete(packet_length) => return self.decode_packet(packet_length), + Completion::Complete((sequence_number, packet_length)) => { + return self.decode_packet(sequence_number, packet_length) + } Completion::Incomplete(_amount) => { let _ = self.buffer(stream).await?; continue; @@ -51,7 +53,7 @@ impl ReadState { } // This and decode_packet are split because of a borrowck limitation - pub(crate) fn poll_packet(&mut self) -> Result, Error> { + pub(crate) fn poll_packet(&mut self) -> Result, Error> { // Compact the internal buffer if self.last_length > 0 { self.buf.copy_within(self.last_length.., 0); @@ -145,13 +147,15 @@ impl ReadState { // Note: this needs to be done AFTER the IO to ensure // this async function is cancel-safe + let sequence_number = self.sequence_number; self.sequence_number = self.sequence_number.wrapping_add(1); self.last_length = 4 + packet_length.inner as usize + mac_len; - Ok(Completion::Complete(packet_length)) + Ok(Completion::Complete((sequence_number, packet_length))) } pub(crate) fn decode_packet<'a>( &'a self, + sequence_number: u32, packet_length: PacketLength, ) -> Result, Error> { let Decoded { @@ -164,6 +168,17 @@ impl ReadState { return Err(Error::Incomplete(Some(payload_len - next.len()))); }; + let Decoded { + value: message_type, + next: payload, + } = MessageType::decode(payload).map_err(|e| { + if matches!(e, Error::Incomplete(_)) { + Error::InvalidPacket("Packet without message type") + } else { + e + } + })?; + let Some(next) = next.get(payload_len..) else { return Err(Error::Unreachable( "unable to extract rest after fixed-length slice", @@ -176,7 +191,11 @@ impl ReadState { ))); }; - Ok(IncomingPacket { payload }) + Ok(IncomingPacket { + sequence_number, + message_type, + payload, + }) } pub(crate) async fn buffer<'a>( @@ -403,6 +422,10 @@ pub(crate) enum MessageType { NewKeys, KeyExchangeEcdhInit, KeyExchangeEcdhReply, + UserauthRequest, + UserauthFailure, + UserauthSuccess, + UserauthBanner, Unknown(u8), } @@ -435,6 +458,10 @@ impl From for MessageType { 21 => Self::NewKeys, 30 => Self::KeyExchangeEcdhInit, 31 => Self::KeyExchangeEcdhReply, + 50 => Self::UserauthRequest, + 51 => Self::UserauthFailure, + 52 => Self::UserauthSuccess, + 53 => Self::UserauthBanner, value => Self::Unknown(value), } } @@ -453,15 +480,38 @@ impl From for u8 { MessageType::NewKeys => 21, MessageType::KeyExchangeEcdhInit => 30, MessageType::KeyExchangeEcdhReply => 31, + MessageType::UserauthRequest => 50, + MessageType::UserauthFailure => 51, + MessageType::UserauthSuccess => 52, + MessageType::UserauthBanner => 53, MessageType::Unknown(value) => value, } } } pub(crate) struct IncomingPacket<'a> { + pub(crate) sequence_number: u32, + pub(crate) message_type: MessageType, pub(crate) payload: &'a [u8], } +impl IncomingPacket<'_> { + pub(crate) fn unimplemented(self) -> OutgoingPacket<'static> { + OutgoingPacket { + message_type: MessageType::Unimplemented, + payload: Cow::Owned(self.sequence_number.to_be_bytes().to_vec()), + } + } +} + +pub(crate) struct OutgoingPacket<'a> { + #[expect(unused)] + pub(crate) message_type: MessageType, + // FIXME: Figure out a way to encode the payload requiring fewer boxes/copies. + #[expect(unused)] + pub(crate) payload: Cow<'a, [u8]>, +} + /// An encoded outgoing packet including length field and padding, but /// excluding encryption and MAC #[must_use] diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 00000000..f2b1fd15 --- /dev/null +++ b/src/service.rs @@ -0,0 +1,31 @@ +use crate::{ + proto::{IncomingPacket, OutgoingPacket}, + Error, +}; + +// FIXME: move this to the transport connection code once that is implemented. +#[expect(unused)] +pub(crate) enum ConnectionEvent { + Close, +} + +#[expect(unused)] +pub(crate) trait Service { + /// Service name used by SshTransportConnection during handshake + const NAME: &'static [u8]; + + /// Poll for packets to transmit through the transport layer. + /// + /// Should be called first of the poll functions. + fn poll_transmit(&mut self) -> Option>; + /// Poll for connection events that need handling by the + /// transport layer. + /// + /// Should be called second of the poll functions. However + /// services should ensure themselves that all outgoing packets + /// are sent before emitting a connectionevent that results in + /// termination of the connection or service. + fn poll_event(&mut self) -> Option; + /// Handle a packet + fn handle_packet(&mut self, packet: IncomingPacket<'_>) -> Result<(), Error>; +} diff --git a/src/userauth.rs b/src/userauth.rs new file mode 100644 index 00000000..fe58fe38 --- /dev/null +++ b/src/userauth.rs @@ -0,0 +1,446 @@ +use std::borrow::Cow; + +use crate::{ + proto::{Decode, Decoded, IncomingPacket, MessageType, OutgoingPacket}, + service::Service, + Error, +}; + +pub(crate) struct SshUserauth { + state: AuthState, +} + +impl SshUserauth { + #[expect(unused)] + pub(crate) fn new() -> Self { + Self { + state: AuthState::WaitingForAuthRequest(WaitingForAuthRequest::new_session()), + } + } + + #[expect(unused)] + pub(crate) fn poll_authrequest(&mut self) -> Option> { + match &self.state { + AuthState::WaitingForAuthDecision(waiting_for_auth_decision) => { + match waiting_for_auth_decision.method() { + AuthenticationMethod::None => Some(AuthRequest::None(NoneAuthRequest { + state: &mut self.state, + })), + AuthenticationMethod::Unknown(cow) => { + unreachable!("Pending auth request with unknown method") + } + } + } + _ => None, + } + } + + #[expect(unused)] + pub(crate) fn poll_complete( + mut self, + ) -> Result<(impl FnOnce(S) -> SshUserauthWrapper, AuthData), Self> { + match self.state.take() { + AuthState::AuthCompleted(auth_completed) => Ok(( + |inner| SshUserauthWrapper { inner }, + AuthData { + username: auth_completed.username, + service: auth_completed.service, + }, + )), + s => { + self.state = s; + Err(self) + } + } + } +} + +/// Main result of authentication +pub(crate) struct AuthData { + username: Vec, + service: Vec, +} + +impl AuthData { + #[expect(unused)] + fn username(&self) -> &[u8] { + &self.username + } + + #[expect(unused)] + fn service(&self) -> &[u8] { + &self.service + } +} + +/// Wrapper needed to handle the authentication messages after completion of authentication +#[expect(unused)] +pub(crate) struct SshUserauthWrapper { + inner: S, +} + +impl Service for SshUserauthWrapper { + const NAME: &'static [u8] = S::NAME; + + fn poll_transmit(&mut self) -> Option> { + self.inner.poll_transmit() + } + + fn poll_event(&mut self) -> Option { + self.inner.poll_event() + } + + fn handle_packet(&mut self, packet: IncomingPacket<'_>) -> Result<(), Error> { + match packet.message_type { + MessageType::UserauthRequest + | MessageType::UserauthSuccess + | MessageType::UserauthFailure + | MessageType::UserauthBanner => { + // Ignore per RFC4252 section 5.3 + Ok(()) + } + _ => self.inner.handle_packet(packet), + } + } +} + +// FIXME: Implement actual proper authentication methods +pub(crate) enum AuthRequest<'a> { + #[expect(unused)] + None(NoneAuthRequest<'a>), +} + +pub(crate) struct NoneAuthRequest<'a> { + state: &'a mut AuthState, +} + +impl NoneAuthRequest<'_> { + #[expect(unused)] + pub(crate) fn username(&self) -> &[u8] { + match &self.state { + AuthState::WaitingForAuthDecision(waiting_for_auth_decision) => { + waiting_for_auth_decision.username() + } + _ => unreachable!("Invalid state for auth request"), + } + } + + #[expect(unused)] + pub(crate) fn service(&self) -> &[u8] { + match &self.state { + AuthState::WaitingForAuthDecision(waiting_for_auth_decision) => { + waiting_for_auth_decision.service() + } + _ => unreachable!("Invalid state for auth request"), + } + } + + #[expect(unused)] + pub(crate) fn accept(self) { + match self.state.take() { + AuthState::WaitingForAuthDecision(waiting_for_auth_decision) => { + *self.state = + AuthState::AuthCompletedWaitingForTransmit(waiting_for_auth_decision.accept()) + } + _ => unreachable!("Invalid state for auth request"), + } + } + + #[expect(unused)] + pub(crate) fn decline(self) { + match self.state.take() { + AuthState::WaitingForAuthDecision(waiting_for_auth_decision) => { + *self.state = AuthState::AuthFailed(waiting_for_auth_decision.decline()) + } + _ => unreachable!("Invalid state for auth request"), + } + } +} + +impl Service for SshUserauth { + const NAME: &'static [u8] = b"ssh-userauth"; + + fn poll_transmit(&mut self) -> Option> { + match self.state.take() { + AuthState::AuthCompletedWaitingForTransmit(auth_completed_waiting_for_transmit) => { + let (new_state, packet) = auth_completed_waiting_for_transmit.advance(); + self.state = AuthState::AuthCompleted(new_state); + Some(packet) + } + AuthState::AuthFailed(auth_failed) => { + let (new_state, packet) = auth_failed.advance(); + self.state = AuthState::WaitingForAuthRequest(new_state); + Some(packet) + } + AuthState::WaitingToSendUnimplemented(waiting_to_send_unimplemented) => { + let (new_state, packet) = waiting_to_send_unimplemented.advance(); + self.state = AuthState::WaitingForAuthRequest(new_state); + Some(packet) + } + AuthState::Poisoned => { + panic!("Poisoned authentication state. Error was non-recoverable.") + } + s => { + self.state = s; + None + } + } + } + + fn poll_event(&mut self) -> Option { + None + } + + fn handle_packet(&mut self, packet: IncomingPacket<'_>) -> Result<(), Error> { + match self.state.take() { + AuthState::WaitingForAuthRequest(waiting_for_auth_request) => { + if packet.message_type == MessageType::UserauthRequest { + let request = UserauthRequest::try_from(packet)?; + self.state = match waiting_for_auth_request.advance(request) { + Ok(waiting_for_auth_decision) => { + AuthState::WaitingForAuthDecision(waiting_for_auth_decision) + } + Err(auth_failed) => AuthState::AuthFailed(auth_failed), + }; + } else { + self.state = AuthState::WaitingToSendUnimplemented( + WaitingToSendUnimplemented::from(packet), + ); + } + } + AuthState::WaitingForAuthDecision(_) + | AuthState::AuthCompletedWaitingForTransmit(_) + | AuthState::AuthCompleted(_) + | AuthState::AuthFailed(_) + | AuthState::WaitingToSendUnimplemented(_) => return Err(Error::NotReady), + AuthState::Poisoned => { + panic!("Poisoned authentication state. Error was non-recoverable.") + } + } + + Ok(()) + } +} + +#[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" +)] +enum AuthState { + WaitingForAuthRequest(WaitingForAuthRequest), + WaitingForAuthDecision(WaitingForAuthDecision), + AuthCompletedWaitingForTransmit(AuthCompletedWaitingForTransmit), + AuthCompleted(AuthCompleted), + AuthFailed(AuthFailed), + WaitingToSendUnimplemented(WaitingToSendUnimplemented), + Poisoned, +} + +impl AuthState { + fn take(&mut self) -> Self { + let mut val = Self::Poisoned; + core::mem::swap(&mut val, self); + val + } +} + +struct WaitingForAuthRequest {} + +impl WaitingForAuthRequest { + fn new_session() -> Self { + Self {} + } + + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + fn advance(self, request: UserauthRequest<'_>) -> Result { + match request.method { + AuthenticationMethod::None => Ok(WaitingForAuthDecision { + username: request.username.to_vec(), + service: request.service.to_vec(), + method: AuthenticationMethod::None, + }), + AuthenticationMethod::Unknown(_) => Err(AuthFailed {}), + } + } +} + +impl WaitingForAuthRequest {} + +struct WaitingForAuthDecision { + username: Vec, + service: Vec, + method: AuthenticationMethod<'static>, +} + +impl WaitingForAuthDecision { + fn username(&self) -> &[u8] { + &self.username + } + + fn service(&self) -> &[u8] { + &self.service + } + + fn method(&self) -> AuthenticationMethod<'_> { + self.method.borrowed() + } + + fn accept(self) -> AuthCompletedWaitingForTransmit { + AuthCompletedWaitingForTransmit { + username: self.username, + service: self.service, + } + } + + fn decline(self) -> AuthFailed { + AuthFailed {} + } +} + +struct AuthCompletedWaitingForTransmit { + username: Vec, + service: Vec, +} + +impl AuthCompletedWaitingForTransmit { + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + fn advance(self) -> (AuthCompleted, OutgoingPacket<'static>) { + ( + AuthCompleted { + username: self.username, + service: self.service, + }, + OutgoingPacket { + message_type: MessageType::UserauthSuccess, + payload: Cow::Borrowed(&[]), + }, + ) + } +} + +struct AuthCompleted { + username: Vec, + service: Vec, +} + +struct AuthFailed {} + +impl AuthFailed { + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + fn advance(self) -> (WaitingForAuthRequest, OutgoingPacket<'static>) { + ( + WaitingForAuthRequest {}, + OutgoingPacket { + message_type: MessageType::UserauthFailure, + payload: Cow::Borrowed(b"\0\0\0\0\0"), + }, + ) + } +} + +struct WaitingToSendUnimplemented { + packet: OutgoingPacket<'static>, +} + +impl From> for WaitingToSendUnimplemented { + fn from(value: IncomingPacket<'_>) -> Self { + Self { + packet: value.unimplemented(), + } + } +} + +impl WaitingToSendUnimplemented { + #[expect( + unused, + reason = "Use marking from the service trait is failing in the compiler" + )] + fn advance(self) -> (WaitingForAuthRequest, OutgoingPacket<'static>) { + (WaitingForAuthRequest {}, self.packet) + } +} + +#[derive(Debug, Clone)] +enum AuthenticationMethod<'a> { + None, + Unknown(Cow<'a, [u8]>), +} + +impl<'b> AuthenticationMethod<'b> { + fn borrowed<'a: 'b>(&'a self) -> AuthenticationMethod<'a> { + match self { + Self::None => Self::None, + Self::Unknown(Cow::Borrowed(b)) => Self::Unknown(Cow::Borrowed(b)), + Self::Unknown(Cow::Owned(b)) => Self::Unknown(Cow::Borrowed(b)), + } + } +} + +impl<'a> Decode<'a> for AuthenticationMethod<'a> { + fn decode(bytes: &'a [u8]) -> Result, Error> { + let Decoded { + value: bytestring, + next, + } = <&[u8]>::decode(bytes)?; + Ok(Decoded { + value: AuthenticationMethod::from(bytestring), + next, + }) + } +} + +impl<'a> From<&'a [u8]> for AuthenticationMethod<'a> { + fn from(value: &'a [u8]) -> Self { + match value { + b"none" => Self::None, + _ => Self::Unknown(Cow::Borrowed(value)), + } + } +} + +struct UserauthRequest<'a> { + username: &'a [u8], + service: &'a [u8], + method: AuthenticationMethod<'a>, + #[expect(unused)] + method_specific_data: &'a [u8], +} + +impl<'a> TryFrom> for UserauthRequest<'a> { + type Error = Error; + + fn try_from(packet: IncomingPacket<'a>) -> Result { + if packet.message_type != MessageType::UserauthRequest { + return Err(Error::InvalidPacket("unexpected message type")); + } + + let Decoded { + value: username, + next, + } = <&[u8]>::decode(packet.payload)?; + let Decoded { + value: service, + next, + } = <&[u8]>::decode(next)?; + let Decoded { + value: method, + next: method_specific_data, + } = AuthenticationMethod::decode(next)?; + + Ok(Self { + username, + service, + method, + method_specific_data, + }) + } +}