Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions src/key_exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,14 @@ impl<'a> TryFrom<IncomingPacket<'a>> for EcdhKeyExchangeInit<'a> {
type Error = Error;

fn try_from(packet: IncomingPacket<'a>) -> Result<Self, Error> {
let Decoded {
value: r#type,
next,
} = MessageType::decode(packet.payload)?;
if r#type != MessageType::KeyExchangeEcdhInit {
if packet.message_type != MessageType::KeyExchangeEcdhInit {
return Err(Error::InvalidPacket("unexpected message type"));
}

let Decoded {
value: client_ephemeral_public_key,
next,
} = <&[u8]>::decode(next)?;
} = <&[u8]>::decode(packet.payload)?;

if !next.is_empty() {
debug!(bytes = ?next, "unexpected trailing bytes");
Expand Down Expand Up @@ -303,18 +299,14 @@ impl<'a> TryFrom<IncomingPacket<'a>> for KeyExchangeInit<'a> {
type Error = Error;

fn try_from(packet: IncomingPacket<'a>) -> Result<Self, Self::Error> {
let Decoded {
value: r#type,
next,
} = MessageType::decode(packet.payload)?;
if r#type != MessageType::KeyExchangeInit {
if packet.message_type != MessageType::KeyExchangeInit {
return Err(Error::InvalidPacket("unexpected message type"));
}

let Decoded {
value: cookie,
next,
} = <[u8; 16]>::decode(next)?;
} = <[u8; 16]>::decode(packet.payload)?;

let Decoded {
value: key_exchange_algorithms,
Expand Down Expand Up @@ -426,16 +418,12 @@ impl<'a> TryFrom<IncomingPacket<'a>> for NewKeys {
type Error = Error;

fn try_from(packet: IncomingPacket<'a>) -> Result<Self, Self::Error> {
let Decoded {
value: r#type,
next,
} = MessageType::decode(packet.payload)?;
if r#type != MessageType::NewKeys {
if packet.message_type != MessageType::NewKeys {
return Err(Error::InvalidPacket("unexpected message type"));
}

if !next.is_empty() {
debug!(bytes = ?next, "unexpected trailing bytes");
if !packet.payload.is_empty() {
debug!(bytes = ?packet.payload, "unexpected trailing bytes");
return Err(Error::InvalidPacket("unexpected trailing bytes"));
}

Expand Down
12 changes: 11 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ mod key_exchange;
use key_exchange::KeyExchange;
mod proto;
use proto::{AesCtrWriteKeys, Completion, Decoded, MessageType, ReadState, WriteState};
mod service;
mod userauth;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: should be called user_auth.


use crate::{
key_exchange::{EcdhKeyExchangeInit, KeyExchangeInit, NewKeys, RawKeySet},
Expand Down Expand Up @@ -57,7 +59,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
return;
}
};
exchange.prefixed(packet.payload);
exchange.update(&((packet.payload.len() + 1) as u32).to_be_bytes());
exchange.update(&[u8::from(packet.message_type)]);
exchange.update(packet.payload);
let peer_key_exchange_init = match KeyExchangeInit::try_from(packet) {
Ok(key_exchange_init) => key_exchange_init,
Err(error) => {
Expand Down Expand Up @@ -296,6 +300,12 @@ enum Error {
InvalidMac,
#[error("unreachable code: {0}")]
Unreachable(&'static str),
#[error("not ready for new packets")]
#[expect(
unused,
reason = "Use marking from the service trait is failing in the compiler"
)]
NotReady,
}

#[derive(Debug, Error)]
Expand Down
60 changes: 55 additions & 5 deletions src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::future;
use core::iter;
use core::pin::Pin;
use core::task::{ready, Context, Poll};
use std::io;
use std::{borrow::Cow, io};

use aws_lc_rs::{
cipher::{self, StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey},
Expand Down Expand Up @@ -41,7 +41,9 @@ impl ReadState {
) -> Result<IncomingPacket<'a>, Error> {
loop {
match self.poll_packet()? {
Completion::Complete(packet_length) => return self.decode_packet(packet_length),
Completion::Complete((sequence_number, packet_length)) => {
return self.decode_packet(sequence_number, packet_length)
}
Completion::Incomplete(_amount) => {
let _ = self.buffer(stream).await?;
continue;
Expand All @@ -51,7 +53,7 @@ impl ReadState {
}

// This and decode_packet are split because of a borrowck limitation
pub(crate) fn poll_packet(&mut self) -> Result<Completion<PacketLength>, Error> {
pub(crate) fn poll_packet(&mut self) -> Result<Completion<(u32, PacketLength)>, Error> {
// Compact the internal buffer
if self.last_length > 0 {
self.buf.copy_within(self.last_length.., 0);
Expand Down Expand Up @@ -145,13 +147,15 @@ impl ReadState {

// Note: this needs to be done AFTER the IO to ensure
// this async function is cancel-safe
let sequence_number = self.sequence_number;
self.sequence_number = self.sequence_number.wrapping_add(1);
self.last_length = 4 + packet_length.inner as usize + mac_len;
Ok(Completion::Complete(packet_length))
Ok(Completion::Complete((sequence_number, packet_length)))
}

pub(crate) fn decode_packet<'a>(
&'a self,
sequence_number: u32,
packet_length: PacketLength,
) -> Result<IncomingPacket<'a>, Error> {
let Decoded {
Expand All @@ -164,6 +168,17 @@ impl ReadState {
return Err(Error::Incomplete(Some(payload_len - next.len())));
};

let Decoded {
value: message_type,
next: payload,
} = MessageType::decode(payload).map_err(|e| {
if matches!(e, Error::Incomplete(_)) {
Error::InvalidPacket("Packet without message type")
} else {
e
}
})?;

let Some(next) = next.get(payload_len..) else {
return Err(Error::Unreachable(
"unable to extract rest after fixed-length slice",
Expand All @@ -176,7 +191,11 @@ impl ReadState {
)));
};

Ok(IncomingPacket { payload })
Ok(IncomingPacket {
sequence_number,
message_type,
payload,
})
}

pub(crate) async fn buffer<'a>(
Expand Down Expand Up @@ -403,6 +422,10 @@ pub(crate) enum MessageType {
NewKeys,
KeyExchangeEcdhInit,
KeyExchangeEcdhReply,
UserauthRequest,
UserauthFailure,
UserauthSuccess,
UserauthBanner,
Unknown(u8),
}

Expand Down Expand Up @@ -435,6 +458,10 @@ impl From<u8> for MessageType {
21 => Self::NewKeys,
30 => Self::KeyExchangeEcdhInit,
31 => Self::KeyExchangeEcdhReply,
50 => Self::UserauthRequest,
51 => Self::UserauthFailure,
52 => Self::UserauthSuccess,
53 => Self::UserauthBanner,
value => Self::Unknown(value),
}
}
Expand All @@ -453,15 +480,38 @@ impl From<MessageType> for u8 {
MessageType::NewKeys => 21,
MessageType::KeyExchangeEcdhInit => 30,
MessageType::KeyExchangeEcdhReply => 31,
MessageType::UserauthRequest => 50,
MessageType::UserauthFailure => 51,
MessageType::UserauthSuccess => 52,
MessageType::UserauthBanner => 53,
MessageType::Unknown(value) => value,
}
}
}

pub(crate) struct IncomingPacket<'a> {
pub(crate) sequence_number: u32,
pub(crate) message_type: MessageType,
pub(crate) payload: &'a [u8],
}

impl IncomingPacket<'_> {
pub(crate) fn unimplemented(self) -> OutgoingPacket<'static> {
OutgoingPacket {
message_type: MessageType::Unimplemented,
payload: Cow::Owned(self.sequence_number.to_be_bytes().to_vec()),
}
}
}

pub(crate) struct OutgoingPacket<'a> {
#[expect(unused)]
pub(crate) message_type: MessageType,
// FIXME: Figure out a way to encode the payload requiring fewer boxes/copies.
#[expect(unused)]
pub(crate) payload: Cow<'a, [u8]>,
}

/// An encoded outgoing packet including length field and padding, but
/// excluding encryption and MAC
#[must_use]
Expand Down
31 changes: 31 additions & 0 deletions src/service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::{
proto::{IncomingPacket, OutgoingPacket},
Error,
};

// FIXME: move this to the transport connection code once that is implemented.
#[expect(unused)]
pub(crate) enum ConnectionEvent {
Close,
}

#[expect(unused)]
pub(crate) trait Service {
/// Service name used by SshTransportConnection during handshake
const NAME: &'static [u8];

/// Poll for packets to transmit through the transport layer.
///
/// Should be called first of the poll functions.
fn poll_transmit(&mut self) -> Option<OutgoingPacket<'_>>;
/// Poll for connection events that need handling by the
/// transport layer.
///
/// Should be called second of the poll functions. However
/// services should ensure themselves that all outgoing packets
/// are sent before emitting a connectionevent that results in
/// termination of the connection or service.
fn poll_event(&mut self) -> Option<ConnectionEvent>;
/// Handle a packet
fn handle_packet(&mut self, packet: IncomingPacket<'_>) -> Result<(), Error>;
}
Loading