diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a47db14f3..cc5c9637c 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -21,3 +21,8 @@ updates: directory: "/tools/xdp" schedule: interval: "daily" + + - package-ecosystem: "cargo" + directory: "/dc" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cd98d0444..66dcd5f8c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,7 +43,7 @@ jobs: # find all child folders in the examples directory # jq -R - raw content is passed in (not json, just strings) # jq -s - slurp the content into an object - # jq '. += ' adds the s2n-quic-xdp crate to the list of crates we build + # jq '. += ' adds the s2n-quic-xdp and s2n-quic-dc crates to the list of crates we build # Many of the xdp crates have much more complex build processes, so we # don't try to build all of them. # jq -c - output the object in (c)ompact mode on a single line, github @@ -63,7 +63,7 @@ jobs: export EXAMPLES=$(find examples/ -maxdepth 1 -mindepth 1 -type d | jq -R | jq -sc) echo "examples=$EXAMPLES" echo "examples=$EXAMPLES" >> $GITHUB_OUTPUT - export CRATES=$(find quic common -name *Cargo.toml | jq -R | jq -s | jq '. += ["tools/xdp/s2n-quic-xdp/Cargo.toml"]' | jq -c) + export CRATES=$(find quic common -name *Cargo.toml | jq -R | jq -s | jq '. += ["tools/xdp/s2n-quic-xdp/Cargo.toml","dc/s2n-quic-dc/Cargo.toml"]' | jq -c) echo "crates=$CRATES" echo "crates=$CRATES" >> $GITHUB_OUTPUT @@ -115,7 +115,7 @@ jobs: # # manual_clamp will panic when min > max # See https://github.com/rust-lang/rust-clippy/pull/10101 - cargo clippy --all-features --all-targets -- -A clippy::derive_partial_eq_without_eq -A clippy::manual_clamp ${{ matrix.args }} + cargo clippy --all-features --all-targets --workspace -- -A clippy::derive_partial_eq_without_eq -A clippy::manual_clamp ${{ matrix.args }} udeps: runs-on: ubuntu-latest @@ -196,11 +196,13 @@ jobs: env: [default] include: - os: windows-latest - # s2n-tls doesn't currently build on windows - exclude: --workspace --exclude s2n-quic-tls + # s2n-tls and s2n-quic-dc don't currently build on windows + exclude: --exclude s2n-quic-tls --exclude s2n-quic-dc - rust: stable os: ubuntu-latest target: aarch64-unknown-linux-gnu + # s2n-quic-dc doesn't currently build on aarch64 + exclude: --exclude s2n-quic-dc - rust: stable os: ubuntu-latest target: i686-unknown-linux-gnu @@ -213,6 +215,8 @@ jobs: os: ubuntu-latest target: native env: S2N_QUIC_PLATFORM_FEATURES_OVERRIDE="" + # s2n-quic-dc requires platform features + exclude: --exclude s2n-quic-dc - rust: stable os: ubuntu-latest target: native @@ -259,11 +263,11 @@ jobs: # Build the tests before running to improve cross compilation speed - name: Run cargo/cross build run: | - ${{ matrix.target != 'native' && 'cross' || 'cargo' }} build --tests ${{ matrix.exclude }} ${{ matrix.target != 'native' && format('--target {0}', matrix.target) || '' }} ${{ matrix.args }} + ${{ matrix.target != 'native' && 'cross' || 'cargo' }} build --tests --workspace ${{ matrix.exclude }} ${{ matrix.target != 'native' && format('--target {0}', matrix.target) || '' }} ${{ matrix.args }} - name: Run cargo/cross test run: | - ${{ matrix.target != 'native' && 'cross' || 'cargo' }} test ${{ matrix.exclude }} ${{ matrix.target != 'native' && format('--target {0}', matrix.target) || '' }} ${{ matrix.args }} + ${{ matrix.target != 'native' && 'cross' || 'cargo' }} test --workspace ${{ matrix.exclude }} ${{ matrix.target != 'native' && 
format('--target {0}', matrix.target) || '' }} ${{ matrix.args }} miri: # miri needs quite a bit of memory so use a larger instance @@ -593,7 +597,7 @@ jobs: - name: Run cargo insta test run: | - cargo insta test --delete-unreferenced-snapshots + cargo insta test --all --delete-unreferenced-snapshots - name: Check to make sure there are no unused snapshots run: | @@ -618,7 +622,7 @@ jobs: - name: Run cargo build run: | cd examples/echo - cargo build --timings --release + cargo build --timings --release --workspace - uses: aws-actions/configure-aws-credentials@v4.0.2 if: github.event_name == 'push' || github.repository == github.event.pull_request.head.repo.full_name @@ -698,7 +702,7 @@ jobs: - name: Run cargo build working-directory: tools/memory-report - run: cargo build --release + run: cargo build --release --workspace - name: Run server working-directory: tools/memory-report diff --git a/Cargo.toml b/Cargo.toml index aaabc3f13..7cc7f6555 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,11 @@ members = [ "common/s2n-*", "quic/s2n-*", + "dc/s2n-*", +] +default-members = [ + "common/s2n-*", + "quic/s2n-*", ] resolver = "2" # don't include any workspaces outside of the main project diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml new file mode 100644 index 000000000..9a95d1b04 --- /dev/null +++ b/dc/s2n-quic-dc/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "s2n-quic-dc" +version = "0.36.0" +description = "Internal crate used by s2n-quic" +repository = "https://github.com/aws/s2n-quic" +authors = ["AWS s2n"] +edition = "2021" +rust-version = "1.71" +license = "Apache-2.0" +# Exclude corpus files when publishing to crates.io +exclude = ["corpus.tar.gz"] + +[features] +testing = [] + +[dependencies] +atomic-waker = "1" +aws-lc-rs = "1" +bytes = "1" +crossbeam-channel = "0.5" +libc = "0.2" +num-rational = { version = "0.4", default-features = false } +once_cell = "1" +s2n-codec = { version = "=0.36.0", path = "../../common/s2n-codec", default-features = false } +s2n-quic-core = { version = "=0.36.0", path = "../../quic/s2n-quic-core", default-features = false } +s2n-quic-platform = { version = "=0.36.0", path = "../../quic/s2n-quic-platform" } +thiserror = "1" +tokio = { version = "1", features = ["io-util"], optional = true } +tracing = "0.1" +zerocopy = { version = "0.7", features = ["derive"] } + +[dev-dependencies] +bolero = "0.10" +s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] } +s2n-quic-core = { path = "../../quic/s2n-quic-core", features = ["testing"] } +tokio = { version = "1", features = ["io-util"] } diff --git a/dc/s2n-quic-dc/benches/Cargo.toml b/dc/s2n-quic-dc/benches/Cargo.toml new file mode 100644 index 000000000..401518e8e --- /dev/null +++ b/dc/s2n-quic-dc/benches/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "benches" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +aws-lc-rs = "1" +criterion = { version = "0.4", features = ["html_reports"] } +s2n-codec = { path = "../../../common/s2n-codec" } +s2n-quic-dc = { path = "../../s2n-quic-dc", features = ["testing"] } + +[[bench]] +name = "bench" +harness = false + +[workspace] +members = ["."] + +[profile.release] +lto = true +codegen-units = 1 +incremental = false + +[profile.bench] +lto = true +codegen-units = 1 +incremental = false diff --git a/dc/s2n-quic-dc/benches/benches/bench.rs b/dc/s2n-quic-dc/benches/benches/bench.rs new file mode 100644 index 000000000..74ba0a444 --- /dev/null +++ b/dc/s2n-quic-dc/benches/benches/bench.rs @@ -0,0 +1,7 @@ +// 
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::{criterion_group, criterion_main}; + +criterion_group!(benches, ::benches::benchmarks); +criterion_main!(benches); diff --git a/dc/s2n-quic-dc/benches/src/crypto.rs b/dc/s2n-quic-dc/benches/src/crypto.rs new file mode 100644 index 000000000..34911befa --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/crypto.rs @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::Criterion; + +pub mod encrypt; +pub mod hkdf; + +pub fn benchmarks(c: &mut Criterion) { + encrypt::benchmarks(c); + hkdf::benchmarks(c); +} diff --git a/dc/s2n-quic-dc/benches/src/crypto/encrypt.rs b/dc/s2n-quic-dc/benches/src/crypto/encrypt.rs new file mode 100644 index 000000000..18baee88f --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/crypto/encrypt.rs @@ -0,0 +1,92 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::{black_box, BenchmarkId, Criterion, Throughput}; + +pub fn benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("crypto/encrypt"); + + let headers = [0, 16]; + + let payloads = [1, 100, 1000, 8900]; + + let inline = [ + ("aes_128_gcm", &aws_lc_rs::aead::AES_128_GCM), + ("aes_256_gcm", &aws_lc_rs::aead::AES_256_GCM), + ]; + + for payload_size in payloads { + let payload = black_box(vec![42u8; payload_size]); + for header_size in headers { + let header = black_box(vec![42u8; header_size]); + + group.throughput(Throughput::Elements(1)); + + let input_name = format!("payload={payload_size},header={header_size}"); + + for (name, algo) in inline { + group.bench_with_input( + BenchmarkId::new(format!("{name}_reuse"), &input_name), + &(&header[..], &payload[..]), + |b, (header, payload)| { + let key = black_box(awslc::key(algo)); + let mut payload = black_box(payload.to_vec()); + let mut packet_number = 0u32; + b.iter(move || { + let _ = black_box(awslc::encrypt( + &key, + &mut packet_number, + header, + &mut payload, + )); + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new(format!("{name}_fresh"), &input_name), + &(&header[..], &payload[..]), + |b, (header, payload)| { + let mut payload = black_box(payload.to_vec()); + let mut packet_number = 0u32; + b.iter(move || { + let key = black_box(awslc::key(algo)); + let _ = black_box(awslc::encrypt( + &key, + &mut packet_number, + header, + &mut payload, + )); + }); + }, + ); + } + } + } +} + +mod awslc { + use aws_lc_rs::aead::{Aad, Algorithm, LessSafeKey, Nonce, UnboundKey, NONCE_LEN}; + + #[inline(never)] + pub fn key(algo: &'static Algorithm) -> LessSafeKey { + let max_key = [42u8; 32]; + let key = &max_key[..algo.key_len()]; + let key = UnboundKey::new(algo, key).unwrap(); + LessSafeKey::new(key) + } + + #[inline(never)] + pub fn encrypt(key: &LessSafeKey, packet_number: &mut u32, header: &[u8], payload: &mut [u8]) { + let mut nonce = [0u8; NONCE_LEN]; + nonce[NONCE_LEN - 8..].copy_from_slice(&(*packet_number as u64).to_be_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + let aad = Aad::from(header); + let mut tag = [0u8; 16]; + key.seal_in_place_scatter(nonce, aad, payload, &[][..], &mut tag) + .unwrap(); + + *packet_number += 1; + } +} diff --git a/dc/s2n-quic-dc/benches/src/crypto/hkdf.rs b/dc/s2n-quic-dc/benches/src/crypto/hkdf.rs new file mode 100644 index 000000000..1f2c11b41 --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/crypto/hkdf.rs 
@@ -0,0 +1,92 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::{black_box, BenchmarkId, Criterion, Throughput}; + +pub fn benchmarks(c: &mut Criterion) { + psk(c); +} + +fn psk(c: &mut Criterion) { + let mut group = c.benchmark_group("crypto/hkdf/psk"); + + group.throughput(Throughput::Elements(1)); + + let prk_lens = [16, 32]; + let key_lens = [16, 32, 64]; + let label_lens = [1, 8, 16, 32, 64]; + let algs = [ + ("sha256", awslc::HKDF_SHA256), + ("sha384", awslc::HKDF_SHA384), + ("sha512", awslc::HKDF_SHA512), + ]; + + for prk_len in prk_lens { + for key_len in key_lens { + for label_len in label_lens { + for (alg_name, alg) in algs { + group.bench_with_input( + BenchmarkId::new( + format!("{alg_name}_reuse"), + format!("prk_len={prk_len},label_len={label_len},out_len={key_len}"), + ), + &key_len, + |b, &key_len| { + let prk = black_box(awslc::prk(&vec![42u8; prk_len], alg)); + let label = black_box(vec![42u8; label_len]); + let mut out = black_box(vec![0u8; key_len]); + b.iter(move || { + let _ = black_box(awslc::derive_psk(&prk, &label, &mut out)); + }); + }, + ); + group.bench_with_input( + BenchmarkId::new( + format!("{alg_name}_fresh"), + format!("prk_len={prk_len},label_len={label_len},out_len={key_len}"), + ), + &key_len, + |b, &key_len| { + let key = black_box(vec![42u8; prk_len]); + let label = black_box(vec![42u8; label_len]); + let mut out = black_box(vec![0u8; key_len]); + b.iter(move || { + let prk = black_box(awslc::prk(&key, alg)); + let _ = black_box(awslc::derive_psk(&prk, &label, &mut out)); + }); + }, + ); + } + } + } + } +} + +mod awslc { + pub use aws_lc_rs::hkdf::*; + + #[inline(never)] + pub fn prk(prk: &[u8], alg: Algorithm) -> Prk { + Prk::new_less_safe(alg, prk) + } + + #[inline(never)] + pub fn derive_psk(prk: &Prk, label: &[u8], out: &mut [u8]) { + let out_len = out.len(); + let out_len = OutLen(out_len); + + prk.expand(&[&label], out_len) + .unwrap() + .fill(&mut out[..out_len.0]) + .unwrap(); + } + + #[derive(Clone, Copy)] + struct OutLen(usize); + + impl KeyType for OutLen { + fn len(&self) -> usize { + self.0 + } + } +} diff --git a/dc/s2n-quic-dc/benches/src/datagram.rs b/dc/s2n-quic-dc/benches/src/datagram.rs new file mode 100644 index 000000000..0ddfaace7 --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/datagram.rs @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::Criterion; + +mod recv; +mod send; + +pub fn benchmarks(c: &mut Criterion) { + send::benches(c); + recv::benches(c); +} diff --git a/dc/s2n-quic-dc/benches/src/datagram/recv.rs b/dc/s2n-quic-dc/benches/src/datagram/recv.rs new file mode 100644 index 000000000..05d0a07a1 --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/datagram/recv.rs @@ -0,0 +1,203 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::{black_box, BenchmarkId, Criterion}; + +const PACKET: [u8; 90] = [ + 64, 0, 42, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 55, 67, 102, 47, 62, 183, 50, 8, 44, + 222, 220, 128, 156, 98, 0, 128, 201, 9, 228, 4, 62, 25, 149, 52, 227, 53, 226, 10, 143, 72, 79, + 180, 16, 46, 173, 156, 16, 215, 240, 248, 7, 147, 159, 101, 36, 161, 156, 117, 188, 75, 88, + 125, 182, 220, 74, 234, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +]; + +macro_rules! 
impl_recv { + ($name:ident) => { + mod $name { + use s2n_quic_dc::packet::datagram::Tag; + // use s2n_quic_dc::datagram::send::send as send_impl; + // pub use s2n_quic_dc::datagram::send::testing::$name::{state, State}; + + /* + #[inline(never)] + pub fn recv(state: &mut State, mut input: &[u8]) { + let _ = send_impl(state, &mut (), &mut input); + } + */ + #[inline(never)] + #[allow(dead_code)] + pub fn parse( + buffer: &mut [u8], + ) -> Option { + let buffer = s2n_codec::DecoderBufferMut::new(buffer); + let (packet, _buffer) = s2n_quic_dc::packet::datagram::decoder::Packet::decode( + buffer, + Tag::default(), + 16, + ) + .ok()?; + Some(packet) + } + } + }; +} + +impl_recv!(null); +impl_recv!(aes_128_gcm); +impl_recv!(aes_256_gcm); + +#[allow(const_item_mutation)] +pub fn benches(c: &mut Criterion) { + let mut group = c.benchmark_group("datagram/recv"); + + // let input = black_box(&mut [1u8, 2, 3][..]); + + group.bench_with_input(BenchmarkId::new("test", 1), &(), |b, _input| { + b.iter(move || { + let _ = black_box(null::parse(black_box(&mut PACKET[..]))); + }); + }); + + /* + + let headers = [0, 16]; + + let payloads = [ + 1, //1, 100, 1000, 1450, + 8900, + ]; + + let inline = [ + ("aes_128_gcm", &aws_lc_rs::aead::AES_128_GCM), + ("aes_256_gcm", &aws_lc_rs::aead::AES_256_GCM), + ]; + + for payload_size in payloads { + let payload = black_box(vec![42u8; payload_size]); + for header_size in headers { + let header = black_box(vec![42u8; header_size]); + + group.throughput(Throughput::Elements(1)); + + let input_name = format!("payload={payload_size},header={header_size}"); + + macro_rules! bench { + ($name:ident) => {{ + let id = BenchmarkId::new(stringify!($name), &input_name); + + if header_size > 0 { + group.bench_with_input( + id, + &(&header[..], &payload[..]), + |b, (header, payload)| { + let mut state = black_box($name::state(creds(42).next().unwrap())); + b.iter(move || { + let _ = + black_box($name::send_header(&mut state, header, payload)); + }); + }, + ); + } else { + group.bench_with_input(id, &payload[..], |b, payload| { + let mut state = black_box($name::state(creds(42).next().unwrap())); + b.iter(move || { + let _ = black_box($name::send(&mut state, payload)); + }); + }); + } + }}; + } + + // bench!(null); + bench!(aes_128_gcm); + bench!(aes_256_gcm); + + for (name, algo) in inline { + group.bench_with_input( + BenchmarkId::new(format!("{name}_inline"), &input_name), + &(&header[..], &payload[..]), + |b, (header, payload)| { + let key = black_box(inline::key(algo)); + let mut payload = black_box(payload.to_vec()); + let mut packet_number = 0u32; + b.iter(move || { + let _ = black_box(inline::send( + &key, + &mut packet_number, + header, + &mut payload, + )); + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new(format!("{name}_inline_scatter"), &input_name), + &(&header[..], &payload[..]), + |b, (header, payload)| { + let key = black_box(inline::key(algo)); + let mut out = black_box(payload.to_vec()); + out.extend(&[0u8; 16]); + let mut packet_number = 0u32; + b.iter(move || { + let _ = black_box(inline::send_scatter( + &key, + &mut packet_number, + header, + payload, + &mut out, + )); + }); + }, + ); + } + } + } + */ +} + +/* +mod inline { + use aws_lc_rs::aead::{Aad, Algorithm, LessSafeKey, Nonce, UnboundKey, NONCE_LEN}; + + #[inline(never)] + pub fn key(algo: &'static Algorithm) -> LessSafeKey { + let max_key = [42u8; 32]; + let key = &max_key[..algo.key_len()]; + let key = UnboundKey::new(algo, key).unwrap(); + LessSafeKey::new(key) + } + + #[inline(never)] + pub 
fn send(key: &LessSafeKey, packet_number: &mut u32, header: &[u8], payload: &mut [u8]) { + let mut nonce = [0u8; NONCE_LEN]; + nonce[NONCE_LEN - 8..].copy_from_slice(&(*packet_number as u64).to_be_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + let aad = Aad::from(header); + let mut tag = [0u8; 16]; + key.seal_in_place_scatter(nonce, aad, payload, &[][..], &mut tag) + .unwrap(); + + *packet_number += 1; + } + + #[inline(never)] + pub fn send_scatter( + key: &LessSafeKey, + packet_number: &mut u32, + header: &[u8], + payload: &[u8], + out: &mut [u8], + ) { + let mut nonce = [0u8; NONCE_LEN]; + nonce[NONCE_LEN - 8..].copy_from_slice(&(*packet_number as u64).to_be_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + let aad = Aad::from(header); + key.seal_in_place_scatter(nonce, aad, &mut [][..], payload, out) + .unwrap(); + + *packet_number += 1; + } +} +*/ diff --git a/dc/s2n-quic-dc/benches/src/datagram/send.rs b/dc/s2n-quic-dc/benches/src/datagram/send.rs new file mode 100644 index 000000000..baf781fcd --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/datagram/send.rs @@ -0,0 +1,176 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::{Criterion, Throughput}; +// use s2n_quic_dc::credentials::testing::iter as creds; + +macro_rules! impl_send { + ($name:ident) => { + mod $name { + // use s2n_quic_dc::datagram::send::send as send_impl; + // pub use s2n_quic_dc::datagram::send::testing::$name::{state, State}; + + /* + #[inline(never)] + pub fn send(state: &mut State, mut input: &[u8]) { + let _ = send_impl(state, &mut (), &mut input); + } + + #[inline(never)] + pub fn send_header(state: &mut State, mut header: &[u8], mut input: &[u8]) { + let _ = send_impl(state, &mut header, &mut input); + } + */ + } + }; +} + +impl_send!(null); +impl_send!(aes_128_gcm); +impl_send!(aes_256_gcm); + +pub fn benches(c: &mut Criterion) { + let mut group = c.benchmark_group("datagram/send"); + + let headers = [0, 16]; + + let payloads = [ + 1, //1, 100, 1000, 1450, + 8900, + ]; + + // let inline = [ + // ("aes_128_gcm", &aws_lc_rs::aead::AES_128_GCM), + // ("aes_256_gcm", &aws_lc_rs::aead::AES_256_GCM), + // ]; + + for _payload_size in payloads { + // let payload = black_box(vec![42u8; payload_size]); + for _header_size in headers { + // let header = black_box(vec![42u8; header_size]); + + group.throughput(Throughput::Elements(1)); + + /* + let input_name = format!("payload={payload_size},header={header_size}"); + + macro_rules! 
bench { + ($name:ident) => {{ + let id = BenchmarkId::new(stringify!($name), &input_name); + + if header_size > 0 { + group.bench_with_input( + id, + &(&header[..], &payload[..]), + |b, (header, payload)| { + let mut state = black_box($name::state(creds(42).next().unwrap())); + b.iter(move || { + let _ = + black_box($name::send_header(&mut state, header, payload)); + }); + }, + ); + } else { + group.bench_with_input(id, &payload[..], |b, payload| { + let mut state = black_box($name::state(creds(42).next().unwrap())); + b.iter(move || { + let _ = black_box($name::send(&mut state, payload)); + }); + }); + } + }}; + } + + bench!(null); + bench!(aes_128_gcm); + bench!(aes_256_gcm); + + for (name, algo) in inline { + group.bench_with_input( + BenchmarkId::new(format!("{name}_inline"), &input_name), + &(&header[..], &payload[..]), + |b, (header, payload)| { + let key = black_box(inline::key(algo)); + let mut payload = black_box(payload.to_vec()); + let mut packet_number = 0u32; + b.iter(move || { + let _ = black_box(inline::send( + &key, + &mut packet_number, + header, + &mut payload, + )); + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new(format!("{name}_inline_scatter"), &input_name), + &(&header[..], &payload[..]), + |b, (header, payload)| { + let key = black_box(inline::key(algo)); + let mut out = black_box(payload.to_vec()); + out.extend(&[0u8; 16]); + let mut packet_number = 0u32; + b.iter(move || { + let _ = black_box(inline::send_scatter( + &key, + &mut packet_number, + header, + payload, + &mut out, + )); + }); + }, + ); + } + */ + } + } +} + +/* mod inline { + use aws_lc_rs::aead::{Aad, Algorithm, LessSafeKey, Nonce, UnboundKey, NONCE_LEN}; + + #[inline(never)] + pub fn key(algo: &'static Algorithm) -> LessSafeKey { + let max_key = [42u8; 32]; + let key = &max_key[..algo.key_len()]; + let key = UnboundKey::new(algo, key).unwrap(); + LessSafeKey::new(key) + } + + #[inline(never)] + pub fn send(key: &LessSafeKey, packet_number: &mut u32, header: &[u8], payload: &mut [u8]) { + let mut nonce = [0u8; NONCE_LEN]; + nonce[NONCE_LEN - 8..].copy_from_slice(&(*packet_number as u64).to_be_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + let aad = Aad::from(header); + let mut tag = [0u8; 16]; + key.seal_in_place_scatter(nonce, aad, payload, &[][..], &mut tag) + .unwrap(); + + *packet_number += 1; + } + + #[inline(never)] + pub fn send_scatter( + key: &LessSafeKey, + packet_number: &mut u32, + header: &[u8], + payload: &[u8], + out: &mut [u8], + ) { + let mut nonce = [0u8; NONCE_LEN]; + nonce[NONCE_LEN - 8..].copy_from_slice(&(*packet_number as u64).to_be_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + let aad = Aad::from(header); + key.seal_in_place_scatter(nonce, aad, &mut [][..], payload, out) + .unwrap(); + + *packet_number += 1; + } +} +*/ diff --git a/dc/s2n-quic-dc/benches/src/lib.rs b/dc/s2n-quic-dc/benches/src/lib.rs new file mode 100644 index 000000000..09ded95a8 --- /dev/null +++ b/dc/s2n-quic-dc/benches/src/lib.rs @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use criterion::Criterion; + +pub mod crypto; +pub mod datagram; + +pub fn benchmarks(c: &mut Criterion) { + crypto::benchmarks(c); + datagram::benchmarks(c); +} diff --git a/dc/s2n-quic-dc/src/allocator.rs b/dc/s2n-quic-dc/src/allocator.rs new file mode 100644 index 000000000..6a0299794 --- /dev/null +++ b/dc/s2n-quic-dc/src/allocator.rs @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::inet::{ExplicitCongestionNotification, SocketAddress}; + +pub trait Allocator { + type Segment: Segment; + type Retransmission: Segment; + + fn alloc(&mut self) -> Option; + + fn get<'a>(&'a self, segment: &'a Self::Segment) -> &'a Vec; + fn get_mut<'a>(&'a mut self, segment: &'a Self::Segment) -> &'a mut Vec; + + fn push(&mut self, segment: Self::Segment); + fn push_with_retransmission(&mut self, segment: Self::Segment) -> Self::Retransmission; + fn retransmit(&mut self, segment: Self::Retransmission) -> Self::Segment; + fn retransmit_copy(&mut self, retransmission: &Self::Retransmission) -> Option; + + fn can_push(&self) -> bool; + fn is_empty(&self) -> bool; + fn segment_len(&self) -> Option; + + fn free(&mut self, segment: Self::Segment); + fn free_retransmission(&mut self, segment: Self::Retransmission); + + fn ecn(&self) -> ExplicitCongestionNotification; + fn set_ecn(&mut self, ecn: ExplicitCongestionNotification); + + fn remote_address(&self) -> SocketAddress; + fn set_remote_address(&mut self, addr: SocketAddress); + fn set_remote_port(&mut self, port: u16); + + fn force_clear(&mut self); +} + +pub trait Segment: 'static + Send + core::fmt::Debug { + fn leak(&mut self); +} diff --git a/dc/s2n-quic-dc/src/congestion.rs b/dc/s2n-quic-dc/src/congestion.rs new file mode 100644 index 000000000..84719a2cd --- /dev/null +++ b/dc/s2n-quic-dc/src/congestion.rs @@ -0,0 +1,174 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::{ + event, random, + recovery::{ + bbr::BbrCongestionController, congestion_controller::Publisher, CongestionController, + RttEstimator, + }, + time::{timer, Timestamp}, +}; + +pub type PacketInfo = ::PacketInfo; + +#[derive(Clone, Debug)] +pub struct Controller { + controller: BbrCongestionController, +} + +impl Controller { + #[inline] + pub fn new(mtu: u16) -> Self { + let mut controller = BbrCongestionController::new(mtu); + let publisher = &mut NoopPublisher; + controller.on_mtu_update(mtu, publisher); + Self { controller } + } + + #[inline] + pub fn on_packet_sent( + &mut self, + time_sent: Timestamp, + sent_bytes: u16, + has_more_app_data: bool, + rtt_estimator: &RttEstimator, + ) -> PacketInfo { + let sent_bytes = sent_bytes as usize; + let app_limited = Some(!has_more_app_data); + let publisher = &mut NoopPublisher; + self.controller + .on_packet_sent(time_sent, sent_bytes, app_limited, rtt_estimator, publisher) + } + + #[inline] + pub fn on_packet_ack( + &mut self, + newest_acked_time_sent: Timestamp, + bytes_acked: usize, + newest_acked_packet_info: PacketInfo, + rtt_estimator: &RttEstimator, + random_generator: &mut dyn random::Generator, + ack_receive_time: Timestamp, + ) { + let publisher = &mut NoopPublisher; + self.controller.on_ack( + newest_acked_time_sent, + bytes_acked, + newest_acked_packet_info, + rtt_estimator, + random_generator, + ack_receive_time, + publisher, + ) + } + + #[inline] + pub fn on_explicit_congestion(&mut self, ce_count: u64, now: Timestamp) { + let publisher = &mut NoopPublisher; + self.controller + .on_explicit_congestion(ce_count, now, publisher); + } + + #[inline] + pub fn on_packet_lost( + &mut self, + bytes_lost: u32, + packet_info: PacketInfo, + random_generator: &mut dyn random::Generator, + now: Timestamp, + ) { + // TODO where do these come from? 
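        // Hedged sketch of what these flags mean in s2n-quic-core's BBR API (my reading of
        // the upstream recovery code, not a statement of this change's intent):
        // `persistent_congestion` reports that the loss spans a persistent-congestion
        // period as defined by RFC 9002 Section 7.6, which makes the controller collapse
        // its congestion window, while `new_loss_burst` marks the first loss in a
        // contiguous run so BBR can count loss bursts rather than every lost packet.
        // Hard-coding both to `false` is the conservative placeholder until dc tracks
        // loss epochs itself.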
+ let persistent_congestion = false; + let new_loss_burst = false; + + let publisher = &mut NoopPublisher; + self.controller.on_packet_lost( + bytes_lost, + packet_info, + persistent_congestion, + new_loss_burst, + random_generator, + now, + publisher, + ); + } + + #[inline] + pub fn is_congestion_limited(&self) -> bool { + self.controller.is_congestion_limited() + } + + #[inline] + pub fn requires_fast_retransmission(&self) -> bool { + self.controller.requires_fast_retransmission() + } + + #[inline] + pub fn congestion_window(&self) -> u32 { + self.controller.congestion_window() + } + + #[inline] + pub fn bytes_in_flight(&self) -> u32 { + self.controller.bytes_in_flight() + } + + #[inline] + pub fn send_quantum(&self) -> usize { + self.controller.send_quantum().unwrap_or(usize::MAX) + } + + #[inline] + pub fn earliest_departure_time(&self) -> Option { + self.controller.earliest_departure_time() + } +} + +impl timer::Provider for Controller { + #[inline] + fn timers(&self, query: &mut Q) -> timer::Result { + if let Some(time) = self.earliest_departure_time() { + let mut timer = timer::Timer::default(); + timer.set(time); + query.on_timer(&timer)?; + } + Ok(()) + } +} + +struct NoopPublisher; + +impl Publisher for NoopPublisher { + #[inline] + fn on_slow_start_exited( + &mut self, + _cause: event::builder::SlowStartExitCause, + _congestion_window: u32, + ) { + // TODO + } + + #[inline] + fn on_delivery_rate_sampled( + &mut self, + _rate_sample: s2n_quic_core::recovery::bandwidth::RateSample, + ) { + // TODO + } + + #[inline] + fn on_pacing_rate_updated( + &mut self, + _pacing_rate: s2n_quic_core::recovery::bandwidth::Bandwidth, + _burst_size: u32, + _pacing_gain: num_rational::Ratio, + ) { + // TODO + } + + #[inline] + fn on_bbr_state_changed(&mut self, _state: event::builder::BbrState) { + // TODO + } +} diff --git a/dc/s2n-quic-dc/src/control.rs b/dc/s2n-quic-dc/src/control.rs new file mode 100644 index 000000000..b32fe7e81 --- /dev/null +++ b/dc/s2n-quic-dc/src/control.rs @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub trait Controller { + /// Returns the source port to which control/reset messages should be sent + fn source_port(&self) -> u16; +} + +impl Controller for u16 { + #[inline] + fn source_port(&self) -> u16 { + *self + } +} diff --git a/dc/s2n-quic-dc/src/credentials.rs b/dc/s2n-quic-dc/src/credentials.rs new file mode 100644 index 000000000..28b3e9f01 --- /dev/null +++ b/dc/s2n-quic-dc/src/credentials.rs @@ -0,0 +1,135 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use core::{ + fmt, + ops::{Deref, DerefMut}, +}; +use s2n_codec::{ + decoder_invariant, decoder_value, + zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}, + zerocopy_value_codec, Encoder, EncoderValue, +}; +use s2n_quic_core::{assume, varint::VarInt}; + +#[cfg(any(test, feature = "testing"))] +pub mod testing; + +#[derive( + Clone, + Copy, + Default, + PartialEq, + Eq, + Hash, + AsBytes, + FromBytes, + FromZeroes, + Unaligned, + PartialOrd, + Ord, +)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +#[repr(C)] +pub struct Id([u8; 16]); + +impl fmt::Debug for Id { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + format_args!("{:#01x}", u128::from_be_bytes(self.0)).fmt(f) + } +} + +impl From<[u8; 16]> for Id { + #[inline] + fn from(v: [u8; 16]) -> Self { + Self(v) + } +} + +impl Deref for Id { + type Target = [u8; 16]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Id { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl s2n_quic_core::probe::Arg for Id { + #[inline] + fn into_usdt(self) -> isize { + // we have to truncate the bytes, but 64 bits should be unique enough for these purposes + let slice = &self.0[..core::mem::size_of::()]; + let bytes = slice.try_into().unwrap(); + usize::from_ne_bytes(bytes).into_usdt() + } +} + +zerocopy_value_codec!(Id); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub struct Credentials { + pub id: Id, + pub generation_id: u32, + pub sequence_id: u16, +} + +const MAX_VALUE: u64 = 1 << (32 + 16); + +decoder_value!( + impl<'a> Credentials { + fn decode(buffer: Buffer) -> Result { + let (id, buffer) = buffer.decode()?; + let (value, buffer) = buffer.decode::()?; + let value = *value; + decoder_invariant!(value <= MAX_VALUE, "invalid range"); + let generation_id = (value >> 16) as u32; + let sequence_id = value as u16; + Ok(( + Self { + id, + generation_id, + sequence_id, + }, + buffer, + )) + } + } +); + +impl EncoderValue for Credentials { + #[inline] + fn encode(&self, encoder: &mut E) { + self.id.encode(encoder); + let generation_id = (self.generation_id as u64) << 16; + let sequence_id = self.sequence_id as u64; + let value = generation_id | sequence_id; + let value = unsafe { + assume!(value <= MAX_VALUE); + VarInt::new_unchecked(value) + }; + value.encode(encoder) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bolero::check; + use s2n_codec::assert_codec_round_trip_value; + + #[test] + fn round_trip_test() { + check!().with_type::().for_each(|v| { + assert_codec_round_trip_value!(Credentials, v); + }) + } +} diff --git a/dc/s2n-quic-dc/src/credentials/testing.rs b/dc/s2n-quic-dc/src/credentials/testing.rs new file mode 100644 index 000000000..c4c6ca413 --- /dev/null +++ b/dc/s2n-quic-dc/src/credentials/testing.rs @@ -0,0 +1,17 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +pub fn new(id: u16, generation_id: u32, sequence_id: u16) -> Credentials { + let id = Id((id as u128).to_be_bytes()); + Credentials { + id, + generation_id, + sequence_id, + } +} + +pub fn iter(id: u16) -> impl Iterator { + (0..).map(move |sequence_id| new(id, 0, sequence_id)) +} diff --git a/dc/s2n-quic-dc/src/crypto.rs b/dc/s2n-quic-dc/src/crypto.rs new file mode 100644 index 000000000..08d2c1f4f --- /dev/null +++ b/dc/s2n-quic-dc/src/crypto.rs @@ -0,0 +1,114 @@ +// Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::credentials::Credentials; +pub use bytes::buf::UninitSlice; +use core::fmt; +pub use s2n_quic_core::packet::KeyPhase; + +pub mod awslc; +#[cfg(any(test, feature = "testing"))] +pub mod testing; + +pub mod encrypt { + use super::*; + + pub trait Key { + fn credentials(&self) -> &Credentials; + + fn tag_len(&self) -> usize; + + /// Encrypt a payload + fn encrypt( + &self, + nonce: N, + header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ); + + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ); + } +} + +pub mod decrypt { + use super::*; + + #[derive(Clone, Copy, Debug)] + pub enum Error { + ReplayPotentiallyDetected, + ReplayDefinitelyDetected, + InvalidTag, + } + + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::ReplayDefinitelyDetected => "key replay prevented", + Self::ReplayPotentiallyDetected => "key replay potentially detected", + Self::InvalidTag => "invalid tag", + } + .fmt(f) + } + } + + impl std::error::Error for Error {} + + pub type Result = core::result::Result; + + pub trait Key { + fn credentials(&self) -> &Credentials; + + fn tag_len(&self) -> usize; + + /// Decrypt a payload + fn decrypt( + &mut self, + nonce: N, + header: &[u8], + payload_in: &[u8], + tag: &[u8], + payload_out: &mut UninitSlice, + ) -> Result; + + /// Decrypt a payload + fn decrypt_in_place( + &mut self, + nonce: N, + header: &[u8], + payload_and_tag: &mut [u8], + ) -> Result; + + fn retransmission_tag( + &mut self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ); + } +} + +pub trait IntoNonce { + fn into_nonce(self) -> [u8; 12]; +} + +impl IntoNonce for u64 { + #[inline] + fn into_nonce(self) -> [u8; 12] { + let mut nonce = [0u8; 12]; + nonce[4..].copy_from_slice(&self.to_be_bytes()); + nonce + } +} + +impl IntoNonce for [u8; 12] { + #[inline] + fn into_nonce(self) -> [u8; 12] { + self + } +} diff --git a/dc/s2n-quic-dc/src/crypto/awslc.rs b/dc/s2n-quic-dc/src/crypto/awslc.rs new file mode 100644 index 000000000..d267425f7 --- /dev/null +++ b/dc/s2n-quic-dc/src/crypto/awslc.rs @@ -0,0 +1,334 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::IntoNonce; +use crate::credentials::Credentials; +use aws_lc_rs::aead::{Aad, Algorithm, LessSafeKey, Nonce, UnboundKey, NONCE_LEN}; +use s2n_quic_core::assume; + +pub use aws_lc_rs::aead::{AES_128_GCM, AES_256_GCM}; + +const TAG_LEN: usize = 16; + +#[derive(Debug)] +pub struct EncryptKey { + credentials: Credentials, + key: LessSafeKey, + iv: Iv, +} + +impl EncryptKey { + #[inline] + pub fn new( + credentials: Credentials, + key: &[u8], + iv: [u8; NONCE_LEN], + algorithm: &'static Algorithm, + ) -> Self { + let key = UnboundKey::new(algorithm, key).unwrap(); + let key = LessSafeKey::new(key); + Self { + credentials, + key, + iv: Iv(iv), + } + } +} + +impl super::encrypt::Key for &EncryptKey { + #[inline] + fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline(always)] + fn tag_len(&self) -> usize { + debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); + TAG_LEN + } + + #[inline] + fn encrypt( + &self, + nonce: N, + header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + let nonce = self.iv.nonce(nonce); + let aad = Aad::from(header); + + let extra_in = extra_payload.unwrap_or(&[][..]); + + unsafe { + assume!(payload_and_tag.len() >= self.tag_len() + extra_in.len()); + } + + let inline_len = payload_and_tag.len() - self.tag_len() - extra_in.len(); + + unsafe { + assume!(payload_and_tag.len() >= inline_len); + } + let (in_out, extra_out_and_tag) = payload_and_tag.split_at_mut(inline_len); + + let result = + self.key + .seal_in_place_scatter(nonce, aad, in_out, extra_in, extra_out_and_tag); + + unsafe { + assume!(result.is_ok()); + } + } + + #[inline] + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) { + retransmission_tag( + &self.key, + &self.iv, + original_packet_number, + retransmission_packet_number, + tag_out, + ) + } +} + +impl super::encrypt::Key for EncryptKey { + #[inline] + fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + fn tag_len(&self) -> usize { + <&Self as super::encrypt::Key>::tag_len(&self) + } + + #[inline] + fn encrypt( + &self, + nonce: N, + header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + <&Self as super::encrypt::Key>::encrypt( + &self, + nonce, + header, + extra_payload, + payload_and_tag, + ) + } + + #[inline] + fn retransmission_tag( + &self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) { + <&Self as super::encrypt::Key>::retransmission_tag( + &self, + original_packet_number, + retransmission_packet_number, + tag_out, + ) + } +} + +#[derive(Debug)] +pub struct DecryptKey { + credentials: Credentials, + key: LessSafeKey, + iv: Iv, +} + +impl DecryptKey { + #[inline] + pub fn new( + credentials: Credentials, + key: &[u8], + iv: [u8; NONCE_LEN], + algorithm: &'static Algorithm, + ) -> Self { + let key = UnboundKey::new(algorithm, key).unwrap(); + let key = LessSafeKey::new(key); + Self { + credentials, + key, + iv: Iv(iv), + } + } +} + +impl super::decrypt::Key for &DecryptKey { + #[inline] + fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + fn tag_len(&self) -> usize { + debug_assert_eq!(TAG_LEN, self.key.algorithm().tag_len()); + TAG_LEN + } + + #[inline] + fn decrypt( + &mut self, + nonce: N, + header: &[u8], + payload_in: &[u8], + tag: &[u8], + payload_out: &mut super::UninitSlice, + ) -> super::decrypt::Result { + 
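        // Hedged note on this path (based on the aws-lc-rs `LessSafeKey` API): the
        // "separate gather" decrypt reads the ciphertext from `payload_in` plus the
        // detached `tag` and writes the recovered plaintext into `payload_out`, so the
        // received datagram buffer is left untouched while the decrypted bytes are
        // copied to wherever the caller needs them.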
debug_assert_eq!(payload_in.len(), payload_out.len()); + + let nonce = self.iv.nonce(nonce); + let aad = Aad::from(header); + + let payload_out = unsafe { + // SAFETY: the payload is not read by aws-lc, only written to + let ptr = payload_out.as_mut_ptr(); + let len = payload_out.len(); + core::slice::from_raw_parts_mut(ptr, len) + }; + + self.key + .open_separate_gather(nonce, aad, payload_in, tag, payload_out) + .map_err(|_| super::decrypt::Error::InvalidTag) + } + + #[inline] + fn decrypt_in_place( + &mut self, + nonce: N, + header: &[u8], + payload_and_tag: &mut [u8], + ) -> super::decrypt::Result { + let nonce = self.iv.nonce(nonce); + let aad = Aad::from(header); + + self.key + .open_in_place(nonce, aad, payload_and_tag) + .map_err(|_| super::decrypt::Error::InvalidTag)?; + + Ok(()) + } + + #[inline] + fn retransmission_tag( + &mut self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) { + retransmission_tag( + &self.key, + &self.iv, + original_packet_number, + retransmission_packet_number, + tag_out, + ) + } +} + +impl super::decrypt::Key for DecryptKey { + fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + fn tag_len(&self) -> usize { + <&Self as super::decrypt::Key>::tag_len(&self) + } + + #[inline] + fn decrypt( + &mut self, + nonce: N, + header: &[u8], + payload_in: &[u8], + tag: &[u8], + payload_out: &mut bytes::buf::UninitSlice, + ) -> super::decrypt::Result { + <&Self as super::decrypt::Key>::decrypt( + &mut &*self, + nonce, + header, + payload_in, + tag, + payload_out, + ) + } + + #[inline] + fn decrypt_in_place( + &mut self, + nonce: N, + header: &[u8], + payload_and_tag: &mut [u8], + ) -> super::decrypt::Result { + <&Self as super::decrypt::Key>::decrypt_in_place( + &mut &*self, + nonce, + header, + payload_and_tag, + ) + } + + #[inline] + fn retransmission_tag( + &mut self, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], + ) { + <&Self as super::decrypt::Key>::retransmission_tag( + &mut &*self, + original_packet_number, + retransmission_packet_number, + tag_out, + ) + } +} + +#[inline] +fn retransmission_tag( + key: &LessSafeKey, + iv: &Iv, + original_packet_number: u64, + retransmission_packet_number: u64, + tag_out: &mut [u8], +) { + debug_assert_eq!(tag_out.len(), TAG_LEN); + + let nonce = iv.nonce(retransmission_packet_number); + let aad = original_packet_number.to_be_bytes(); + let aad = Aad::from(&aad); + + let tag = key.seal_in_place_separate_tag(nonce, aad, &mut []).unwrap(); + + for (a, b) in tag_out.iter_mut().zip(tag.as_ref()) { + *a ^= b; + } +} + +#[derive(Debug)] +struct Iv([u8; NONCE_LEN]); + +impl Iv { + #[inline] + fn nonce(&self, nonce: N) -> Nonce { + let mut nonce = nonce.into_nonce(); + for (dst, src) in nonce.iter_mut().zip(&self.0) { + *dst ^= src; + } + Nonce::assume_unique_for_key(nonce) + } +} diff --git a/dc/s2n-quic-dc/src/crypto/testing.rs b/dc/s2n-quic-dc/src/crypto/testing.rs new file mode 100644 index 000000000..a37ba9fd0 --- /dev/null +++ b/dc/s2n-quic-dc/src/crypto/testing.rs @@ -0,0 +1,109 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::IntoNonce; +use crate::credentials::Credentials; +use s2n_quic_core::assume; + +#[derive(Clone, Debug)] +pub struct Key { + credentials: Credentials, + tag_len: usize, +} + +impl Key { + #[inline] + pub fn new(credentials: Credentials) -> Self { + Self { + credentials, + tag_len: 16, + } + } +} + +impl super::encrypt::Key for Key { + #[inline] + fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + fn tag_len(&self) -> usize { + self.tag_len + } + + #[inline] + fn encrypt( + &self, + _nonce: N, + _header: &[u8], + extra_payload: Option<&[u8]>, + payload_and_tag: &mut [u8], + ) { + if let Some(extra_payload) = extra_payload { + let offset = payload_and_tag.len() - self.tag_len() - extra_payload.len(); + let dest = &mut payload_and_tag[offset..]; + unsafe { + assume!(dest.len() == extra_payload.len() + self.tag_len); + } + let (dest, tag) = dest.split_at_mut(extra_payload.len()); + dest.copy_from_slice(extra_payload); + tag.fill(0); + } + } + + #[inline] + fn retransmission_tag( + &self, + _original_packet_number: u64, + _retransmission_packet_number: u64, + _tag_out: &mut [u8], + ) { + // no-op + } +} + +impl super::decrypt::Key for Key { + #[inline] + fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + fn tag_len(&self) -> usize { + self.tag_len + } + + #[inline] + fn decrypt( + &mut self, + _nonce: N, + _header: &[u8], + payload_in: &[u8], + _tag: &[u8], + payload_out: &mut bytes::buf::UninitSlice, + ) -> Result<(), super::decrypt::Error> { + payload_out.copy_from_slice(payload_in); + Ok(()) + } + + #[inline] + fn decrypt_in_place( + &mut self, + _nonce: N, + _header: &[u8], + _payload_and_tag: &mut [u8], + ) -> Result<(), super::decrypt::Error> { + Ok(()) + } + + #[inline] + fn retransmission_tag( + &mut self, + _original_packet_number: u64, + _retransmission_packet_number: u64, + _tag_out: &mut [u8], + ) { + // no-op + } +} diff --git a/dc/s2n-quic-dc/src/lib.rs b/dc/s2n-quic-dc/src/lib.rs new file mode 100644 index 000000000..18caaacf7 --- /dev/null +++ b/dc/s2n-quic-dc/src/lib.rs @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod allocator; +pub mod congestion; +pub mod control; +pub mod credentials; +pub mod crypto; +pub mod msg; +pub mod packet; +pub mod path; +pub mod pool; +pub mod recovery; +pub mod stream; diff --git a/dc/s2n-quic-dc/src/msg.rs b/dc/s2n-quic-dc/src/msg.rs new file mode 100644 index 000000000..d480a3fbf --- /dev/null +++ b/dc/s2n-quic-dc/src/msg.rs @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod addr; +mod cmsg; +pub mod recv; +pub mod send; diff --git a/dc/s2n-quic-dc/src/msg/addr.rs b/dc/s2n-quic-dc/src/msg/addr.rs new file mode 100644 index 000000000..d003d6850 --- /dev/null +++ b/dc/s2n-quic-dc/src/msg/addr.rs @@ -0,0 +1,164 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use core::{fmt, mem::size_of}; +use libc::{msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6}; +use s2n_quic_core::{ + assume, + inet::{self, SocketAddress}, +}; + +const SIZE: usize = { + let v4 = size_of::(); + let v6 = size_of::(); + if v4 > v6 { + v4 + } else { + v6 + } +}; + +#[repr(align(8))] +pub struct Addr { + msg_name: [u8; SIZE], + msg_namelen: u8, +} + +impl fmt::Debug for Addr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl fmt::Display for Addr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl Default for Addr { + #[inline] + fn default() -> Self { + Self::new(SocketAddress::default()) + } +} + +impl Addr { + #[inline] + pub fn new(value: SocketAddress) -> Self { + let mut v = Self { + msg_name: Default::default(), + msg_namelen: Default::default(), + }; + v.set(value); + v + } + + #[inline] + pub fn get(&self) -> SocketAddress { + match self.msg_namelen as usize { + size if size == size_of::() => { + let sockaddr: &sockaddr_in = unsafe { &*(self.msg_name.as_ptr() as *const _) }; + let port = sockaddr.sin_port.to_be(); + let addr: inet::IpV4Address = sockaddr.sin_addr.s_addr.to_ne_bytes().into(); + inet::SocketAddressV4::new(addr, port).into() + } + size if size == size_of::() => { + let sockaddr: &sockaddr_in6 = unsafe { &*(self.msg_name.as_ptr() as *const _) }; + let port = sockaddr.sin6_port.to_be(); + let addr: inet::IpV6Address = sockaddr.sin6_addr.s6_addr.into(); + inet::SocketAddressV6::new(addr, port).into() + } + _ => unsafe { + assume!(false, "invalid remote address"); + unreachable!() + }, + } + } + + #[inline] + pub fn set(&mut self, remote_address: SocketAddress) { + match remote_address { + SocketAddress::IpV4(addr) => { + let sockaddr: &mut sockaddr_in = + unsafe { &mut *(self.msg_name.as_mut_ptr() as *mut _) }; + sockaddr.sin_family = AF_INET as _; + sockaddr.sin_port = addr.port().to_be(); + sockaddr.sin_addr.s_addr = u32::from_ne_bytes((*addr.ip()).into()); + self.msg_namelen = size_of::() as _; + } + SocketAddress::IpV6(addr) => { + let sockaddr: &mut sockaddr_in6 = + unsafe { &mut *(self.msg_name.as_mut_ptr() as *mut _) }; + sockaddr.sin6_family = AF_INET6 as _; + sockaddr.sin6_port = addr.port().to_be(); + sockaddr.sin6_addr.s6_addr = (*addr.ip()).into(); + self.msg_namelen = size_of::() as _; + } + } + } + + #[inline] + pub fn set_port(&mut self, port: u16) { + match self.msg_namelen as usize { + size if size == size_of::() => { + let sockaddr: &mut sockaddr_in = + unsafe { &mut *(self.msg_name.as_mut_ptr() as *mut _) }; + sockaddr.sin_port = port.to_be(); + } + size if size == size_of::() => { + let sockaddr: &mut sockaddr_in6 = + unsafe { &mut *(self.msg_name.as_mut_ptr() as *mut _) }; + sockaddr.sin6_port = port.to_be(); + } + _ => unsafe { + assume!(false, "invalid remote address"); + unreachable!() + }, + } + } + + #[inline] + pub fn send_with_msg(&mut self, msg: &mut msghdr) { + msg.msg_name = self.msg_name.as_mut_ptr() as *mut _; + msg.msg_namelen = self.msg_namelen as _; + } + + #[inline] + pub fn recv_with_msg(&mut self, msg: &mut msghdr) { + msg.msg_name = self.msg_name.as_mut_ptr() as *mut _; + // use the max size, in case the length changes + msg.msg_namelen = self.msg_name.len() as _; + } + + #[inline] + pub fn update_with_msg(&mut self, msg: &msghdr) { + debug_assert_eq!(self.msg_name.as_ptr(), msg.msg_name as *const u8); + match msg.msg_namelen as usize { + size if size == size_of::() => { + 
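                // After a recvmsg call the kernel reports the actual peer address length in
                // `msg_namelen`; recording it here (a sockaddr_in- vs sockaddr_in6-sized value)
                // is what lets `get()` above decide whether to decode an IPv4 or IPv6 address.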
self.msg_namelen = size as _; + } + size if size == size_of::() => { + self.msg_namelen = size as _; + } + _ => { + unreachable!("invalid remote address") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bolero::check; + + #[test] + fn set_port_test() { + check!().with_type().cloned().for_each(|(addr, port)| { + let mut addr = Addr::new(addr); + addr.set_port(port); + assert_eq!(addr.get().port(), port); + }); + } +} diff --git a/dc/s2n-quic-dc/src/msg/cmsg.rs b/dc/s2n-quic-dc/src/msg/cmsg.rs new file mode 100644 index 000000000..ee7119edd --- /dev/null +++ b/dc/s2n-quic-dc/src/msg/cmsg.rs @@ -0,0 +1,101 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use libc::msghdr; +use s2n_quic_core::{ensure, inet::ExplicitCongestionNotification}; +use s2n_quic_platform::{features, message::cmsg}; + +pub use cmsg::*; + +pub const ENCODER_LEN: usize = { + // TODO calculate based on platform support + 128 +}; + +pub const DECODER_LEN: usize = { + // TODO calculate based on platform support + 128 +}; + +pub const MAX_GRO_SEGMENTS: usize = features::gro::MAX_SEGMENTS; + +#[derive(Debug, Default)] +pub struct Receiver { + ecn: ExplicitCongestionNotification, + segment_len: u16, +} + +impl Receiver { + #[inline] + pub fn with_msg(&mut self, msg: &msghdr, buffer_len: usize) { + // assume we didn't get a GRO cmsg initially + self.segment_len = buffer_len as _; + + ensure!(!msg.msg_control.is_null()); + ensure!(msg.msg_controllen > 0); + + let iter = unsafe { + // SAFETY: the msghdr controllen should be aligned + cmsg::decode::Iter::from_msghdr(msg) + }; + + for (cmsg, value) in iter { + match (cmsg.cmsg_level, cmsg.cmsg_type) { + (level, ty) if features::tos::is_match(level, ty) => { + if let Some(ecn) = features::tos::decode(value) { + // TODO remove this conversion once we consolidate the s2n-quic-core crates + // convert between the vendored s2n-quic-core types + let ecn = { + let ecn = ecn as u8; + ExplicitCongestionNotification::new(ecn) + }; + self.ecn = ecn; + } else { + continue; + } + } + (level, ty) if features::gso::is_match(level, ty) => { + // ignore GSO settings when reading + continue; + } + (level, ty) if features::gro::is_match(level, ty) => { + if let Some(segment_size) = + unsafe { cmsg::decode::value_from_bytes::(value) } + { + self.segment_len = segment_size as _; + } else { + continue; + } + } + _ => { + continue; + } + } + } + } + + #[inline] + pub fn ecn(&self) -> ExplicitCongestionNotification { + self.ecn + } + + #[inline] + pub fn set_ecn(&mut self, ecn: ExplicitCongestionNotification) { + self.ecn = ecn; + } + + #[inline] + pub fn segment_len(&self) -> u16 { + self.segment_len + } + + #[inline] + pub fn set_segment_len(&mut self, len: u16) { + self.segment_len = len; + } + + #[inline] + pub fn take_segment_len(&mut self) -> u16 { + core::mem::replace(&mut self.segment_len, 0) + } +} diff --git a/dc/s2n-quic-dc/src/msg/recv.rs b/dc/s2n-quic-dc/src/msg/recv.rs new file mode 100644 index 000000000..f778118c3 --- /dev/null +++ b/dc/s2n-quic-dc/src/msg/recv.rs @@ -0,0 +1,196 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::{addr::Addr, cmsg}; +use core::fmt; +use libc::{iovec, msghdr, recvmsg}; +use s2n_quic_core::{ + assume, branch, ensure, + inet::{ExplicitCongestionNotification, SocketAddress}, + path::MaxMtu, +}; +use std::{io, os::fd::AsRawFd}; +use tracing::trace; + +pub struct Message { + addr: Addr, + buffer: Vec, + recv: cmsg::Receiver, + payload_len: usize, +} + +impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Message") + .field("addr", &self.addr) + .field("segment_len", &self.recv.segment_len()) + .field("payload_len", &self.payload_len) + .field("ecn", &self.recv.ecn()) + .finish() + } +} + +impl Message { + #[inline] + pub fn new(max_mtu: MaxMtu) -> Self { + let max_mtu: u16 = max_mtu.into(); + let max_mtu = max_mtu as usize; + let buffer_len = cmsg::MAX_GRO_SEGMENTS * max_mtu; + // the recv syscall doesn't return more than this + let buffer_len = buffer_len.min(u16::MAX as _); + let buffer = Vec::with_capacity(buffer_len); + Self { + addr: Addr::default(), + buffer, + recv: Default::default(), + payload_len: 0, + } + } + + #[inline] + pub fn remote_address(&self) -> SocketAddress { + self.addr.get() + } + + #[inline] + pub fn ecn(&self) -> ExplicitCongestionNotification { + self.recv.ecn() + } + + #[inline] + pub fn segments(&mut self) -> impl Iterator { + let payload_len = core::mem::replace(&mut self.payload_len, 0); + Segments { + buffer: &mut self.buffer[..payload_len], + segment_len: self.recv.take_segment_len(), + } + } + + #[inline] + pub fn peek_segments(&mut self) -> impl Iterator { + Segments { + buffer: &mut self.buffer[..self.payload_len], + segment_len: self.recv.segment_len(), + } + } + + #[inline] + pub fn take(&mut self) -> Self { + let capacity = self.buffer.capacity(); + core::mem::replace( + self, + Self { + addr: Addr::default(), + buffer: Vec::with_capacity(capacity), + recv: Default::default(), + payload_len: 0, + }, + ) + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.recv.segment_len() == 0 + } + + #[inline] + pub fn clear(&mut self) { + self.buffer.clear(); + self.payload_len = 0; + self.recv.set_segment_len(0); + } + + #[inline] + pub fn recv(&mut self, s: &S) -> io::Result<()> { + let mut msg = unsafe { core::mem::zeroed::() }; + + let mut iovec = unsafe { core::mem::zeroed::() }; + + iovec.iov_base = self.buffer.as_mut_ptr() as *mut _; + iovec.iov_len = self.buffer.capacity() as _; + + msg.msg_iov = &mut iovec; + msg.msg_iovlen = 1; + + self.addr.recv_with_msg(&mut msg); + + let mut cmsg = cmsg::Storage::<{ cmsg::DECODER_LEN }>::default(); + msg.msg_control = cmsg.as_mut_ptr() as *mut _; + msg.msg_controllen = cmsg.len() as _; + + let flags = Default::default(); + + let result = unsafe { recvmsg(s.as_raw_fd(), &mut msg, flags) }; + + self.addr.update_with_msg(&msg); + + trace!( + src = %self.addr, + cmsg_len = msg.msg_controllen, + result, + ); + + if !branch!(result > 0) { + let error = io::Error::last_os_error(); + + if !matches!( + error.kind(), + io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted + ) { + tracing::error!(?error); + } + + return Err(error); + } + + let len = result as usize; + + unsafe { + assume!(self.buffer.capacity() >= len); + self.buffer.set_len(len); + } + + self.payload_len = len; + self.recv.with_msg(&msg, len); + + Ok(()) + } + + #[inline] + pub fn test_recv( + &mut self, + remote_addr: SocketAddress, + ecn: ExplicitCongestionNotification, + payload: Vec, + ) { + debug_assert!(self.is_empty()); + 
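        // Test-only shortcut: fill in the same fields a real recvmsg would populate
        // (peer address, ECN marking, segment length, payload) so unit tests can drive
        // the receive path without a socket.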
self.addr.set(remote_addr); + self.recv.set_ecn(ecn); + self.recv + .set_segment_len(payload.len().try_into().expect("payload too large")); + self.payload_len = payload.len(); + self.buffer = payload; + } +} + +pub struct Segments<'a> { + buffer: &'a mut [u8], + segment_len: u16, +} + +impl<'a> Iterator for Segments<'a> { + type Item = &'a mut [u8]; + + #[inline] + fn next(&mut self) -> Option { + let len = self.buffer.len().min(self.segment_len as _); + ensure!(len > 0, None); + let (head, tail) = self.buffer.split_at_mut(len); + let (head, tail) = unsafe { + // SAFETY: we're just extending the lifetime of this split off segment + core::mem::transmute((head, tail)) + }; + self.buffer = tail; + Some(head) + } +} diff --git a/dc/s2n-quic-dc/src/msg/send.rs b/dc/s2n-quic-dc/src/msg/send.rs new file mode 100644 index 000000000..dada718ff --- /dev/null +++ b/dc/s2n-quic-dc/src/msg/send.rs @@ -0,0 +1,615 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{addr::Addr, cmsg}; +use crate::allocator::{self, Allocator}; +use core::{fmt, num::NonZeroU16}; +use libc::{iovec, msghdr, sendmsg}; +use s2n_quic_core::{ + assume, ensure, + inet::{ExplicitCongestionNotification, SocketAddress, Unspecified}, +}; +use s2n_quic_platform::features; +use std::{io, os::fd::AsRawFd}; +use tracing::trace; + +type Idx = u16; +type RetransmissionIdx = NonZeroU16; + +#[cfg(debug_assertions)] +type Instance = u64; +#[cfg(not(debug_assertions))] +type Instance = (); + +#[inline(always)] +fn instance_id() -> Instance { + #[cfg(debug_assertions)] + { + use core::sync::atomic::{AtomicU64, Ordering}; + static INSTANCES: AtomicU64 = AtomicU64::new(0); + INSTANCES.fetch_add(1, Ordering::Relaxed) + } +} + +#[derive(Debug)] +pub struct Segment { + idx: Idx, + instance_id: Instance, +} + +impl Segment { + #[inline(always)] + fn get<'a>(&'a self, buffers: &'a [Vec]) -> &'a Vec { + unsafe { + assume!(buffers.len() > self.idx as usize); + } + &buffers[self.idx as usize] + } + + #[inline(always)] + fn get_mut<'a>(&self, buffers: &'a mut [Vec]) -> &'a mut Vec { + unsafe { + assume!(buffers.len() > self.idx as usize); + } + &mut buffers[self.idx as usize] + } +} + +impl allocator::Segment for Segment { + #[inline] + fn leak(&mut self) { + self.idx = Idx::MAX; + } +} + +#[cfg(debug_assertions)] +impl Drop for Segment { + fn drop(&mut self) { + if self.idx != Idx::MAX && !std::thread::panicking() { + panic!("message segment {} leaked", self.idx); + } + } +} + +#[derive(Debug)] +pub struct Retransmission { + idx: RetransmissionIdx, + instance_id: Instance, +} + +impl allocator::Segment for Retransmission { + #[inline] + fn leak(&mut self) { + self.idx = unsafe { RetransmissionIdx::new_unchecked(Idx::MAX) }; + } +} + +impl Retransmission { + #[inline(always)] + fn idx(&self) -> Idx { + self.idx.get() - 1 + } + + #[inline(always)] + fn get<'a>(&'a self, buffers: &'a [Vec]) -> &'a Vec { + let idx = self.idx() as usize; + unsafe { + assume!(buffers.len() > idx); + } + &buffers[idx] + } + + #[inline] + fn into_segment(mut self) -> Segment { + let idx = core::mem::replace(&mut self.idx, unsafe { + RetransmissionIdx::new_unchecked(Idx::MAX) + }); + let idx = idx.get() - 1; + let instance_id = self.instance_id; + Segment { idx, instance_id } + } + + #[inline] + fn from_segment(mut handle: Segment) -> Self { + let idx = core::mem::replace(&mut handle.idx, Idx::MAX); + let idx = idx.saturating_add(1); + let idx = unsafe { RetransmissionIdx::new_unchecked(idx) }; + let 
instance_id = handle.instance_id; + Retransmission { idx, instance_id } + } +} + +#[cfg(debug_assertions)] +impl Drop for Retransmission { + fn drop(&mut self) { + if self.idx.get() != Idx::MAX && !std::thread::panicking() { + panic!("message segment {} leaked", self.idx.get()); + } + } +} + +pub struct Message { + addr: Addr, + gso: features::Gso, + segment_len: u16, + total_len: u16, + can_push: bool, + buffers: Vec>, + free: Vec, + pending_free: Vec, + payload: Vec, + ecn: ExplicitCongestionNotification, + instance_id: Instance, + #[cfg(debug_assertions)] + allocated: std::collections::BTreeSet, +} + +impl fmt::Debug for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut d = f.debug_struct("Message"); + + d.field("addr", &self.addr) + .field("segment_len", &self.segment_len) + .field("total_len", &self.total_len) + .field("can_push", &self.can_push) + .field("buffers", &self.buffers.len()) + .field("free", &self.free.len()) + .field("pending_free", &self.pending_free.len()) + .field("segments", &self.payload.len()) + .field("ecn", &self.ecn); + + #[cfg(debug_assertions)] + { + d.field("instance_id", &self.instance_id) + .field("allocated", &self.allocated.len()); + } + + d.finish() + } +} + +unsafe impl Send for Message {} +unsafe impl Sync for Message {} + +impl Message { + #[inline] + pub fn new(remote_address: SocketAddress, gso: features::Gso) -> Self { + let burst_size = 16; + Self { + addr: Addr::new(remote_address), + gso, + segment_len: 0, + total_len: 0, + can_push: true, + buffers: Vec::with_capacity(burst_size), + free: Vec::with_capacity(burst_size), + pending_free: Vec::with_capacity(burst_size), + payload: Vec::with_capacity(burst_size), + ecn: ExplicitCongestionNotification::NotEct, + instance_id: instance_id(), + #[cfg(debug_assertions)] + allocated: Default::default(), + } + } + + #[inline] + fn push_payload(&mut self, segment: &Segment) { + debug_assert!(self.can_push()); + debug_assert_eq!(segment.instance_id, self.instance_id); + + let mut iovec = unsafe { core::mem::zeroed::() }; + let buffer = segment.get_mut(&mut self.buffers); + + debug_assert!(!buffer.is_empty()); + debug_assert!( + buffer.len() <= u16::MAX as usize, + "cannot transmit more than 2^16 bytes in a single packet" + ); + + let iov_base: *mut u8 = buffer.as_mut_ptr(); + iovec.iov_base = iov_base as *mut _; + iovec.iov_len = buffer.len() as _; + + self.total_len += buffer.len() as u16; + + if self.payload.is_empty() { + self.segment_len = buffer.len() as _; + } else { + debug_assert!(buffer.len() <= self.segment_len as usize); + // the caller can only push until the last undersized segment + self.can_push &= buffer.len() == self.segment_len as usize; + } + + self.payload.push(iovec); + + let max_segments = self.gso.max_segments(); + + self.can_push &= self.payload.len() < max_segments; + + // sendmsg has a limitation on the total length of the payload, even with GSO + let next_size = self.total_len as usize + self.segment_len as usize; + let max_size = u16::MAX as usize; + self.can_push &= next_size <= max_size; + } + + #[inline] + pub fn send(&mut self, s: &S) -> io::Result<()> { + use cmsg::Encoder as _; + + let mut msg = unsafe { core::mem::zeroed::() }; + + msg.msg_iov = self.payload.as_mut_ptr(); + msg.msg_iovlen = self.payload.len() as _; + + debug_assert!( + !self.addr.get().ip().is_unspecified(), + "cannot send packet to unspecified address" + ); + debug_assert!( + self.addr.get().port() != 0, + "cannot send packet to unspecified port" + ); + 
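+ // Attach the destination address and any control messages (the GSO segment
+ // size when more than one segment is queued; ECN marking is still a TODO)
+ // to the msghdr before handing it to `sendmsg`.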
self.addr.send_with_msg(&mut msg); + + let mut cmsg_storage = cmsg::Storage::<{ cmsg::ENCODER_LEN }>::default(); + let mut cmsg = cmsg_storage.encoder(); + if self.ecn != ExplicitCongestionNotification::NotEct { + // TODO enable this once we consolidate s2n-quic-core crates + // let _ = cmsg.encode_ecn(ecn, &addr); + } + + if msg.msg_iovlen > 1 { + let _ = cmsg.encode_gso(self.segment_len); + } + + if !cmsg.is_empty() { + msg.msg_control = cmsg.as_mut_ptr() as *mut _; + msg.msg_controllen = cmsg.len() as _; + } + + let flags = Default::default(); + + let result = unsafe { sendmsg(s.as_raw_fd(), &msg, flags) }; + + trace!( + dest = %self.addr, + segments = self.payload.len(), + segment_len = self.segment_len, + cmsg_len = msg.msg_controllen, + result, + ); + + if result >= 0 { + self.on_transmit(); + Ok(()) + } else { + let error = io::Error::last_os_error(); + self.gso.handle_socket_error(&error); + + if !matches!( + error.kind(), + io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted + ) { + tracing::error!(?error); + + // if we got an unrecoverable error we need to clear the queue to make progress + self.force_clear(); + } + + Err(error) + } + } + + #[inline] + pub fn drain(&mut self) -> Drain { + Drain { + message: self, + index: 0, + } + } + + #[inline] + fn on_transmit(&mut self) { + // reset the current payload + self.payload.clear(); + self.ecn = ExplicitCongestionNotification::NotEct; + self.segment_len = 0; + self.total_len = 0; + self.can_push = true; + + for segment in &self.pending_free { + segment.get_mut(&mut self.buffers).clear(); + #[cfg(debug_assertions)] + assert!(self.allocated.remove(&segment.idx)); + } + + if self.free.is_empty() { + core::mem::swap(&mut self.free, &mut self.pending_free); + } else { + self.free.append(&mut self.pending_free); + } + } +} + +impl Allocator for Message { + type Segment = Segment; + + type Retransmission = Retransmission; + + #[inline] + fn alloc(&mut self) -> Option { + ensure!(self.can_push(), None); + + if let Some(segment) = self.free.pop() { + #[cfg(debug_assertions)] + assert!(self.allocated.insert(segment.idx)); + trace!(operation = "alloc", ?segment); + return Some(segment); + } + + let idx = self.buffers.len().try_into().ok()?; + let instance_id = self.instance_id; + let segment = Segment { idx, instance_id }; + self.buffers.push(vec![]); + + #[cfg(debug_assertions)] + assert!(self.allocated.insert(segment.idx)); + trace!(operation = "alloc", ?segment); + + Some(segment) + } + + #[inline] + fn get<'a>(&'a self, segment: &'a Segment) -> &'a Vec { + debug_assert_eq!(segment.instance_id, self.instance_id); + + #[cfg(debug_assertions)] + assert!(self.allocated.contains(&segment.idx)); + + segment.get(&self.buffers) + } + + #[inline] + fn get_mut(&mut self, segment: &Segment) -> &mut Vec { + debug_assert_eq!(segment.instance_id, self.instance_id); + + #[cfg(debug_assertions)] + assert!(self.allocated.contains(&segment.idx)); + + segment.get_mut(&mut self.buffers) + } + + #[inline] + fn push(&mut self, segment: Segment) { + trace!(operation = "push", ?segment); + self.push_payload(&segment); + + #[cfg(debug_assertions)] + assert!(self.allocated.contains(&segment.idx)); + + self.pending_free.push(segment); + } + + #[inline] + fn push_with_retransmission(&mut self, segment: Segment) -> Retransmission { + trace!(operation = "push_with_retransmission", ?segment); + self.push_payload(&segment); + + #[cfg(debug_assertions)] + assert!(self.allocated.contains(&segment.idx)); + + Retransmission::from_segment(segment) + } + + #[inline] + fn 
retransmit(&mut self, segment: Retransmission) -> Segment { + debug_assert_eq!(segment.instance_id, self.instance_id); + debug_assert!( + self.payload.is_empty(), + "cannot retransmit with pending payload" + ); + + let segment = segment.into_segment(); + + #[cfg(debug_assertions)] + assert!(self.allocated.contains(&segment.idx)); + + segment + } + + #[inline] + fn retransmit_copy(&mut self, retransmission: &Retransmission) -> Option { + debug_assert_eq!(retransmission.instance_id, self.instance_id); + #[cfg(debug_assertions)] + assert!( + self.allocated.contains(&retransmission.idx()), + "{retransmission:?} {self:?}" + ); + + let segment = self.alloc()?; + + let mut target = core::mem::take(self.get_mut(&segment)); + debug_assert!(target.is_empty()); + + let source = retransmission.get(&self.buffers); + debug_assert!( + !source.is_empty(), + "cannot retransmit empty payload; source: {retransmission:?}, target: {segment:?}" + ); + target.extend_from_slice(source); + + *self.get_mut(&segment) = target; + + Some(segment) + } + + #[inline] + fn can_push(&self) -> bool { + self.can_push + } + + #[inline] + fn is_empty(&self) -> bool { + self.payload.is_empty() + } + + #[inline] + fn segment_len(&self) -> Option { + debug_assert_eq!(self.segment_len == 0, self.is_empty()); + if self.segment_len == 0 { + None + } else { + Some(self.segment_len) + } + } + + #[inline] + fn free(&mut self, segment: Segment) { + debug_assert_eq!(segment.instance_id, self.instance_id); + trace!(operation = "free", ?segment); + + #[cfg(debug_assertions)] + assert!(self.allocated.contains(&segment.idx)); + + // if we haven't actually sent anything then immediately free it + if self.is_empty() { + #[cfg(debug_assertions)] + assert!(self.allocated.remove(&segment.idx)); + + self.free.push(segment); + } else { + self.pending_free.push(segment); + } + } + + #[inline] + fn free_retransmission(&mut self, segment: Retransmission) { + debug_assert_eq!(segment.instance_id, self.instance_id); + debug_assert!( + self.payload.is_empty(), + "cannot free a retransmission with pending payload" + ); + + trace!(operation = "free_retransmission", ?segment); + + let segment = segment.into_segment(); + + let buffer = self.get_mut(&segment); + buffer.clear(); + + #[cfg(debug_assertions)] + assert!(self.allocated.remove(&segment.idx)); + + self.free.push(segment); + } + + #[inline] + fn ecn(&self) -> ExplicitCongestionNotification { + self.ecn + } + + #[inline] + fn set_ecn(&mut self, ecn: ExplicitCongestionNotification) { + self.ecn = ecn; + } + + #[inline] + fn remote_address(&self) -> SocketAddress { + self.addr.get() + } + + #[inline] + fn set_remote_address(&mut self, remote_address: SocketAddress) { + self.addr.set(remote_address); + } + + #[inline] + fn set_remote_port(&mut self, port: u16) { + self.addr.set_port(port); + } + + #[inline] + fn force_clear(&mut self) { + self.on_transmit(); + } +} + +#[cfg(debug_assertions)] +impl Drop for Message { + fn drop(&mut self) { + use allocator::Segment; + for segment in &mut self.free { + segment.leak(); + } + for segment in &mut self.pending_free { + segment.leak(); + } + } +} + +pub struct Drain<'a> { + message: &'a mut Message, + index: usize, +} + +impl<'a> Iterator for Drain<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + let v = self.message.payload.get(self.index)?; + self.index += 1; + let v = unsafe { core::slice::from_raw_parts(v.iov_base as *const u8, v.iov_len) }; + Some(v) + } +} + +impl<'a> Drop for Drain<'a> { + #[inline] + fn drop(&mut self) { + 
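+ // Dropping the iterator finalizes the batch: `on_transmit` resets the
+ // pending payload state and returns the consumed segments to the free list.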
self.message.on_transmit(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn happy_path() { + let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + + let addr: std::net::SocketAddr = "127.0.0.1:4433".parse().unwrap(); + let mut message = Message::new(addr.into(), Default::default()); + + let handle = message.alloc().unwrap(); + let payload = message.get_mut(&handle); + payload.extend_from_slice(b"hello\n"); + let hello = message.push_with_retransmission(handle); + + let world = if message.gso.max_segments() > 1 { + let handle = message.alloc().unwrap(); + let payload = message.get_mut(&handle); + payload.extend_from_slice(b"world\n"); + let world = message.push_with_retransmission(handle); + Some(world) + } else { + None + }; + + message.send(&socket).unwrap(); + + let world = world.map(|world| message.retransmit(world)); + let hello = message.retransmit(hello); + + if let Some(world) = world { + assert_eq!(message.get(&world), b"world\n"); + message.push(world); + } + + assert_eq!(message.get(&hello), b"hello\n"); + message.push(hello); + + message.send(&socket).unwrap(); + } +} diff --git a/dc/s2n-quic-dc/src/packet.rs b/dc/s2n-quic-dc/src/packet.rs new file mode 100644 index 000000000..4e5e1ad6f --- /dev/null +++ b/dc/s2n-quic-dc/src/packet.rs @@ -0,0 +1,71 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_codec::DecoderBufferMut; +use s2n_quic_core::varint::VarInt; + +pub type PacketNumber = VarInt; +pub type HeaderLen = VarInt; +pub type PayloadLen = VarInt; + +#[macro_use] +pub mod tag; + +pub mod control; +pub mod datagram; +pub mod secret_control; +pub mod stream; + +pub use tag::Tag; + +#[derive(Debug)] +pub enum Packet<'a> { + Stream(stream::decoder::Packet<'a>), + Datagram(datagram::decoder::Packet<'a>), + Control(control::decoder::Packet<'a>), + StaleKey(secret_control::stale_key::Packet<'a>), + ReplayDetected(secret_control::replay_detected::Packet<'a>), + RequestShards(secret_control::request_shards::Packet<'a>), + UnknownPathSecret(secret_control::unknown_path_secret::Packet<'a>), +} + +impl<'a> s2n_codec::DecoderParameterizedValueMut<'a> for Packet<'a> { + type Parameter = usize; + + #[inline] + fn decode_parameterized_mut( + tag_len: Self::Parameter, + decoder: DecoderBufferMut<'a>, + ) -> s2n_codec::DecoderBufferMutResult { + match decoder.peek().decode().map(|(tag, _)| tag)? { + Tag::Control(_) => { + let (packet, decoder) = control::decoder::Packet::decode(decoder, (), tag_len)?; + Ok((Self::Control(packet), decoder)) + } + Tag::Stream(_) => { + let (packet, decoder) = stream::decoder::Packet::decode(decoder, (), tag_len)?; + Ok((Self::Stream(packet), decoder)) + } + Tag::Datagram(_) => { + let (packet, decoder) = datagram::decoder::Packet::decode(decoder, (), tag_len)?; + Ok((Self::Datagram(packet), decoder)) + } + Tag::StaleKey(_) => { + // TODO + todo!() + } + Tag::ReplayDetected(_) => { + // TODO + todo!() + } + Tag::RequestShards(_) => { + // TODO + todo!() + } + Tag::UnknownPathSecret(_) => { + // TODO + todo!() + } + } + } +} diff --git a/dc/s2n-quic-dc/src/packet/control.rs b/dc/s2n-quic-dc/src/packet/control.rs new file mode 100644 index 000000000..a523df593 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/control.rs @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::tag::Common; +use core::fmt; +use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; + +const NONCE_MASK: u64 = 1 << 63; + +pub mod decoder; +pub mod encoder; + +#[derive(Clone, Copy, PartialEq, Eq, AsBytes, FromBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub struct Tag(Common); + +impl_tag_codec!(Tag); + +impl Default for Tag { + #[inline] + fn default() -> Self { + Self(Common(0b0101_0000)) + } +} + +impl fmt::Debug for Tag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("control::Tag") + .field("is_stream", &self.is_stream()) + .field("has_application_header", &self.has_application_header()) + .finish() + } +} + +impl Tag { + const IS_STREAM_MASK: u8 = 0b0010; + const HAS_APPLICATION_HEADER_MASK: u8 = 0b00_0001; + + pub const MIN: u8 = 0b0101_0000; + pub const MAX: u8 = 0b0101_1111; + + #[inline] + pub fn set_is_stream(&mut self, enabled: bool) { + self.0.set(Self::IS_STREAM_MASK, enabled) + } + + #[inline] + pub fn is_stream(&self) -> bool { + self.0.get(Self::IS_STREAM_MASK) + } + + #[inline] + pub fn set_has_application_header(&mut self, enabled: bool) { + self.0.set(Self::HAS_APPLICATION_HEADER_MASK, enabled) + } + + #[inline] + pub fn has_application_header(&self) -> bool { + self.0.get(Self::HAS_APPLICATION_HEADER_MASK) + } + + #[inline] + fn validate(&self) -> Result<(), s2n_codec::DecoderError> { + let range = Self::MIN..=Self::MAX; + debug_assert!(range.contains(&(self.0).0), "{:?} {:?}", self, range); + s2n_codec::decoder_invariant!(range.contains(&(self.0).0), "invalid control bit pattern"); + Ok(()) + } +} diff --git a/dc/s2n-quic-dc/src/packet/control/decoder.rs b/dc/s2n-quic-dc/src/packet/control/decoder.rs new file mode 100644 index 000000000..b4ca6dd3e --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/control/decoder.rs @@ -0,0 +1,237 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + credentials::Credentials, + packet::{ + control::{self, Tag}, + stream, + }, +}; +use s2n_codec::{ + decoder_invariant, CheckedRange, DecoderBufferMut, DecoderBufferMutResult as R, DecoderError, +}; +use s2n_quic_core::{assume, varint::VarInt}; + +type PacketNumber = VarInt; + +pub trait Validator { + fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError>; +} + +impl Validator for () { + #[inline] + fn validate_tag(&mut self, _tag: Tag) -> Result<(), DecoderError> { + Ok(()) + } +} + +impl Validator for Tag { + #[inline] + fn validate_tag(&mut self, actual: Tag) -> Result<(), DecoderError> { + decoder_invariant!(*self == actual, "unexpected packet type"); + Ok(()) + } +} + +impl Validator for (A, B) +where + A: Validator, + B: Validator, +{ + #[inline] + fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError> { + self.0.validate_tag(tag)?; + self.1.validate_tag(tag)?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct Packet<'a> { + tag: Tag, + credentials: Credentials, + source_control_port: u16, + stream_id: Option, + packet_number: PacketNumber, + header: &'a mut [u8], + application_header: CheckedRange, + control_data: CheckedRange, + auth_tag: &'a mut [u8], +} + +impl<'a> Packet<'a> { + #[inline] + pub fn tag(&self) -> Tag { + self.tag + } + + #[inline] + pub fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + pub fn source_control_port(&self) -> u16 { + self.source_control_port + } + + #[inline] + pub fn stream_id(&self) -> Option<&stream::Id> { + self.stream_id.as_ref() + } + + #[inline] + pub fn crypto_nonce(&self) -> u64 { + self.packet_number.as_u64() | control::NONCE_MASK + } + + #[inline] + pub fn packet_number(&self) -> PacketNumber { + self.packet_number + } + + #[inline] + pub fn application_header(&self) -> &[u8] { + self.application_header.get(self.header) + } + + #[inline] + pub fn control_data(&self) -> &[u8] { + self.control_data.get(self.header) + } + + #[inline] + pub fn control_data_mut(&mut self) -> &mut [u8] { + self.control_data.get_mut(self.header) + } + + #[inline] + pub fn header(&self) -> &[u8] { + self.header + } + + #[inline] + pub fn auth_tag(&self) -> &[u8] { + self.auth_tag + } + + #[inline(always)] + pub fn decode( + buffer: DecoderBufferMut, + mut validator: V, + crypto_tag_len: usize, + ) -> R { + let ( + tag, + credentials, + source_control_port, + stream_id, + packet_number, + header_len, + total_header_len, + application_header_len, + control_data_len, + ) = { + let buffer = buffer.peek(); + + unsafe { + assume!( + crypto_tag_len >= 16, + "tag len needs to be at least 16 bytes" + ); + } + + let start_len = buffer.len(); + + let (tag, buffer) = buffer.decode()?; + validator.validate_tag(tag)?; + + let (credentials, buffer) = buffer.decode()?; + + let (source_control_port, buffer) = buffer.decode()?; + + let (stream_id, buffer) = if tag.is_stream() { + let (stream_id, buffer) = buffer.decode()?; + (Some(stream_id), buffer) + } else { + (None, buffer) + }; + + let (packet_number, buffer) = buffer.decode::()?; + let (control_data_len, buffer) = buffer.decode::()?; + + let (application_header_len, buffer) = if tag.has_application_header() { + let (application_header_len, buffer) = buffer.decode::()?; + ((*application_header_len) as usize, buffer) + } else { + (0, buffer) + }; + + let header_len = start_len - buffer.len(); + + let buffer = buffer.skip(application_header_len)?; + let buffer = buffer.skip(*control_data_len as _)?; + + let total_header_len = 
start_len - buffer.len(); + + let buffer = buffer.skip(crypto_tag_len)?; + + let _ = buffer; + + ( + tag, + credentials, + source_control_port, + stream_id, + packet_number, + header_len, + total_header_len, + application_header_len, + control_data_len, + ) + }; + + unsafe { + assume!(buffer.len() >= total_header_len); + } + let (header, buffer) = buffer.decode_slice(total_header_len)?; + + let (application_header, control_data) = { + let buffer = header.peek(); + unsafe { + assume!(buffer.len() >= header_len); + } + let buffer = buffer.skip(header_len)?; + unsafe { + assume!(buffer.len() >= application_header_len); + } + let (application_header, buffer) = + buffer.skip_into_range(application_header_len, &header)?; + unsafe { + assume!(buffer.len() >= *control_data_len as usize); + } + let (control_data, _) = buffer.skip_into_range(*control_data_len as usize, &header)?; + + (application_header, control_data) + }; + let header = header.into_less_safe_slice(); + + let (auth_tag, buffer) = buffer.decode_slice(crypto_tag_len)?; + let auth_tag = auth_tag.into_less_safe_slice(); + + let packet = Packet { + tag, + credentials, + source_control_port, + stream_id, + packet_number, + header, + application_header, + control_data, + auth_tag, + }; + + Ok((packet, buffer)) + } +} diff --git a/dc/s2n-quic-dc/src/packet/control/encoder.rs b/dc/s2n-quic-dc/src/packet/control/encoder.rs new file mode 100644 index 000000000..3e9c421c4 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/control/encoder.rs @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + crypto::encrypt, + packet::{ + control::{Tag, NONCE_MASK}, + stream, + }, +}; +use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; +use s2n_quic_core::{assume, buffer, varint::VarInt}; + +#[inline(always)] +#[allow(clippy::too_many_arguments)] +pub fn encode( + mut encoder: EncoderBuffer, + source_control_port: u16, + stream_id: Option, + packet_number: VarInt, + header_len: VarInt, + header: &mut H, + control_data_len: VarInt, + control_data: &CD, + crypto: &C, +) -> usize +where + H: buffer::reader::Storage, + CD: EncoderValue, + C: encrypt::Key, +{ + let mut tag = Tag::default(); + + if stream_id.is_some() { + tag.set_is_stream(true); + } + + if *header_len > 0 { + tag.set_has_application_header(true); + } + + let nonce = *packet_number | NONCE_MASK; + + encoder.encode(&tag); + + // encode the credentials being used + encoder.encode(crypto.credentials()); + encoder.encode(&source_control_port); + + encoder.encode(&stream_id); + + encoder.encode(&packet_number); + + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&control_data_len); + } + + if !header.buffer_is_empty() { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&header_len); + } + encoder.write_sized(*header_len as usize, |mut dest| { + let _: Result<(), core::convert::Infallible> = header.copy_into(&mut dest); + }); + } + + encoder.encode(control_data); + + let payload_offset = encoder.len(); + + encoder.advance_position(crypto.tag_len()); + + let packet_len = encoder.len(); + + let slice = encoder.as_mut_slice(); + + { + let (header, payload_and_tag) = unsafe { + assume!(slice.len() >= payload_offset); + slice.split_at_mut(payload_offset) + }; + + crypto.encrypt(nonce, header, None, payload_and_tag); + } + + if cfg!(debug_assertions) { + let decoder = s2n_codec::DecoderBufferMut::new(slice); + let _ = super::decoder::Packet::decode(decoder, (), 
crypto.tag_len()).unwrap(); + } + + packet_len +} diff --git a/dc/s2n-quic-dc/src/packet/datagram.rs b/dc/s2n-quic-dc/src/packet/datagram.rs new file mode 100644 index 000000000..e1f709986 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/datagram.rs @@ -0,0 +1,90 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{tag::Common, HeaderLen, PacketNumber, PayloadLen}; +use core::fmt; +use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; + +pub mod decoder; +pub mod encoder; + +#[derive(Clone, Copy, PartialEq, Eq, AsBytes, FromBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub struct Tag(Common); + +impl_tag_codec!(Tag); + +impl Default for Tag { + #[inline] + fn default() -> Self { + Self(Common(0b0100_0000)) + } +} + +impl fmt::Debug for Tag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("datagram::Tag") + .field("ack_eliciting", &self.ack_eliciting()) + .field("is_connected", &self.is_connected()) + .field("has_length", &self.has_length()) + .field("has_application_header", &self.has_application_header()) + .finish() + } +} + +impl Tag { + const ACK_ELICITING_MASK: u8 = 0b1000; + const IS_CONNECTED_MASK: u8 = 0b0100; + const HAS_LENGTH_MASK: u8 = 0b0010; + const HAS_APPLICATION_HEADER_MASK: u8 = 0b0001; + + pub const MIN: u8 = 0b0100_0000; + pub const MAX: u8 = 0b0100_1111; + + #[inline] + pub fn set_ack_eliciting(&mut self, enabled: bool) { + self.0.set(Self::ACK_ELICITING_MASK, enabled) + } + + #[inline] + pub fn ack_eliciting(&self) -> bool { + self.0.get(Self::ACK_ELICITING_MASK) + } + + #[inline] + pub fn set_is_connected(&mut self, enabled: bool) { + self.0.set(Self::IS_CONNECTED_MASK, enabled) + } + + #[inline] + pub fn is_connected(&self) -> bool { + self.0.get(Self::IS_CONNECTED_MASK) + } + + #[inline] + pub fn set_has_length(&mut self, enabled: bool) { + self.0.set(Self::HAS_LENGTH_MASK, enabled) + } + + #[inline] + pub fn has_length(&self) -> bool { + self.0.get(Self::HAS_LENGTH_MASK) + } + + #[inline] + pub fn set_has_application_header(&mut self, enabled: bool) { + self.0.set(Self::HAS_APPLICATION_HEADER_MASK, enabled) + } + + #[inline] + pub fn has_application_header(&self) -> bool { + self.0.get(Self::HAS_APPLICATION_HEADER_MASK) + } + + #[inline] + fn validate(&self) -> Result<(), s2n_codec::DecoderError> { + let range = Self::MIN..=Self::MAX; + s2n_codec::decoder_invariant!(range.contains(&(self.0).0), "invalid datagram bit pattern"); + Ok(()) + } +} diff --git a/dc/s2n-quic-dc/src/packet/datagram/decoder.rs b/dc/s2n-quic-dc/src/packet/datagram/decoder.rs new file mode 100644 index 000000000..72c9ed61a --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/datagram/decoder.rs @@ -0,0 +1,280 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{credentials::Credentials, packet::datagram::Tag}; +use s2n_codec::{ + decoder_invariant, CheckedRange, DecoderBufferMut, DecoderBufferMutResult as R, DecoderError, +}; +use s2n_quic_core::{assume, varint::VarInt}; + +pub type PacketNumber = VarInt; + +pub trait Validator { + fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError>; +} + +impl Validator for () { + #[inline] + fn validate_tag(&mut self, _tag: Tag) -> Result<(), DecoderError> { + Ok(()) + } +} + +impl Validator for Tag { + #[inline] + fn validate_tag(&mut self, actual: Tag) -> Result<(), DecoderError> { + decoder_invariant!(*self == actual, "unexpected packet type"); + Ok(()) + } +} + +impl Validator for (A, B) +where + A: Validator, + B: Validator, +{ + #[inline] + fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError> { + self.0.validate_tag(tag)?; + self.1.validate_tag(tag)?; + Ok(()) + } +} + +pub struct Packet<'a> { + tag: Tag, + credentials: Credentials, + source_control_port: u16, + packet_number: PacketNumber, + next_expected_control_packet: Option, + header: &'a mut [u8], + application_header: CheckedRange, + control_data: CheckedRange, + payload: &'a mut [u8], + auth_tag: &'a mut [u8], +} + +impl<'a> std::fmt::Debug for Packet<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Packet") + .field("tag", &self.tag) + .field("credentials", &self.credentials) + .field("source_control_port", &self.source_control_port) + .field("packet_number", &self.packet_number) + .field( + "next_expected_control_packet", + &self.next_expected_control_packet, + ) + .field("header", &self.header) + .field("application_header", &self.application_header) + .field("control_data", &self.control_data) + .field("payload_len", &self.payload.len()) + .field("auth_tag", &self.auth_tag) + .finish() + } +} + +impl<'a> Packet<'a> { + #[inline] + pub fn tag(&self) -> Tag { + self.tag + } + + #[inline] + pub fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + pub fn source_control_port(&self) -> u16 { + self.source_control_port + } + + #[inline] + pub fn crypto_nonce(&self) -> u64 { + self.packet_number.as_u64() + } + + #[inline] + pub fn packet_number(&self) -> PacketNumber { + self.packet_number + } + + #[inline] + pub fn next_expected_control_packet(&self) -> Option { + self.next_expected_control_packet + } + + #[inline] + pub fn application_header(&self) -> &[u8] { + self.application_header.get(self.header) + } + + #[inline] + pub fn control_data(&self) -> &[u8] { + self.control_data.get(self.header) + } + + #[inline] + pub fn header(&self) -> &[u8] { + self.header + } + + #[inline] + pub fn payload(&self) -> &[u8] { + self.payload + } + + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + self.payload + } + + #[inline] + pub fn auth_tag(&self) -> &[u8] { + self.auth_tag + } + + #[inline(always)] + pub fn decode( + buffer: DecoderBufferMut, + mut validator: V, + crypto_tag_len: usize, + ) -> R { + let ( + tag, + credentials, + source_control_port, + packet_number, + next_expected_control_packet, + header_len, + total_header_len, + application_header_len, + control_data_len, + payload_len, + ) = { + let buffer = buffer.peek(); + + unsafe { + assume!( + crypto_tag_len >= 16, + "tag len needs to be at least 16 bytes" + ); + } + + let start_len = buffer.len(); + + let (tag, buffer) = buffer.decode()?; + validator.validate_tag(tag)?; + + let (credentials, buffer) = buffer.decode()?; + + let 
(source_control_port, buffer) = buffer.decode()?; + + let (packet_number, buffer) = if tag.is_connected() || tag.ack_eliciting() { + buffer.decode()? + } else { + (VarInt::ZERO, buffer) + }; + + let (payload_len, buffer) = if tag.has_length() { + let (payload_len, buffer) = buffer.decode::()?; + let payload_len = (*payload_len) as usize; + (payload_len, buffer) + } else { + (0, buffer) + }; + + let (next_expected_control_packet, control_data_len, buffer) = if tag.ack_eliciting() { + let (packet_number, buffer) = buffer.decode::()?; + let (control_data_len, buffer) = buffer.decode::()?; + (Some(packet_number), (*control_data_len) as usize, buffer) + } else { + (None, 0usize, buffer) + }; + + let (application_header_len, buffer) = if tag.has_application_header() { + let (application_header_len, buffer) = buffer.decode::()?; + ((*application_header_len) as usize, buffer) + } else { + (0, buffer) + }; + + let header_len = start_len - buffer.len(); + + let buffer = buffer.skip(application_header_len)?; + + let buffer = buffer.skip(control_data_len)?; + + let total_header_len = start_len - buffer.len(); + + let payload_len = if tag.has_length() { + payload_len + } else { + buffer + .len() + .checked_sub(crypto_tag_len) + .ok_or(DecoderError::UnexpectedEof(crypto_tag_len))? + }; + + ( + tag, + credentials, + source_control_port, + packet_number, + next_expected_control_packet, + header_len, + total_header_len, + application_header_len, + control_data_len, + payload_len, + ) + }; + + unsafe { + assume!(buffer.len() >= total_header_len); + } + let (header, buffer) = buffer.decode_slice(total_header_len)?; + + let (application_header, control_data) = { + let buffer = header.peek(); + unsafe { + assume!(buffer.len() >= header_len); + } + let buffer = buffer.skip(header_len)?; + unsafe { + assume!(buffer.len() >= application_header_len); + } + let (application_header, buffer) = + buffer.skip_into_range(application_header_len, &header)?; + unsafe { + assume!(buffer.len() >= control_data_len); + } + let (control_data, _) = buffer.skip_into_range(control_data_len, &header)?; + + (application_header, control_data) + }; + let header = header.into_less_safe_slice(); + + let (payload, buffer) = buffer.decode_slice(payload_len)?; + let payload = payload.into_less_safe_slice(); + + let (auth_tag, buffer) = buffer.decode_slice(crypto_tag_len)?; + let auth_tag = auth_tag.into_less_safe_slice(); + + let packet = Packet { + tag, + credentials, + source_control_port, + packet_number, + next_expected_control_packet, + header, + application_header, + control_data, + payload, + auth_tag, + }; + + Ok((packet, buffer)) + } +} diff --git a/dc/s2n-quic-dc/src/packet/datagram/encoder.rs b/dc/s2n-quic-dc/src/packet/datagram/encoder.rs new file mode 100644 index 000000000..f0713d7de --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/datagram/encoder.rs @@ -0,0 +1,148 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{credentials, crypto::encrypt, packet::datagram::Tag}; +use s2n_codec::{Encoder, EncoderBuffer, EncoderValue}; +use s2n_quic_core::{assume, buffer}; + +#[inline(always)] +pub fn estimate_len( + _packet_number: super::PacketNumber, + next_expected_control_packet: Option, + app_header_len: super::HeaderLen, + payload_len: super::PayloadLen, + crypto_tag_len: usize, +) -> usize { + let app_header_len_usize = *app_header_len as usize; + let payload_len_usize = *payload_len as usize; + + let mut encoder = s2n_codec::EncoderLenEstimator::new(usize::MAX); + + encoder.encode(&Tag::default()); + // credentials + { + encoder.write_zerocopy::(|_| {}); + encoder.write_repeated(8, 0); + } + encoder.encode(&0u16); // source control port + encoder.write_repeated(8, 0); // packet number + encoder.write_repeated(8, 0); // payload len + + if let Some(_packet_number) = next_expected_control_packet { + encoder.write_repeated(8, 0); // next expected control packet + encoder.write_repeated(8, 0); // control_data_len + } + + if app_header_len_usize > 0 { + encoder.write_repeated(8, 0); // application header len + encoder.write_repeated(app_header_len_usize, 0); // application data + } + + encoder.write_repeated(8, 0); + encoder.write_repeated(payload_len_usize, 0); + + encoder.write_repeated(crypto_tag_len, 0); + + encoder.len() +} + +#[inline(always)] +#[allow(clippy::too_many_arguments)] +pub fn encode( + mut encoder: EncoderBuffer, + tag: Tag, + source_control_port: u16, + packet_number: super::PacketNumber, + next_expected_control_packet: Option, + header_len: super::HeaderLen, + header: &mut H, + control_data: &CD, + payload_len: super::PayloadLen, + payload: &mut P, + crypto: &C, +) -> usize +where + H: buffer::reader::Storage, + P: buffer::reader::Storage, + CD: EncoderValue, + C: encrypt::Key, +{ + debug_assert_eq!(tag.ack_eliciting(), next_expected_control_packet.is_some()); + + let header_len_usize = *header_len as usize; + let payload_len_usize = *payload_len as usize; + let nonce = *packet_number; + + encoder.encode(&tag); + + // encode the credentials being used + encoder.encode(crypto.credentials()); + encoder.encode(&source_control_port); + + if tag.is_connected() || tag.ack_eliciting() { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&packet_number); + } + } else { + debug_assert_eq!(packet_number, super::PacketNumber::default()); + } + + if tag.has_length() { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&payload_len); + } + } + + if let Some(packet_number) = next_expected_control_packet { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&packet_number); + } + // TODO write control data len + } + + if !header.buffer_is_empty() { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&header_len); + } + encoder.write_sized(header_len_usize, |mut dest| { + let _: Result<(), core::convert::Infallible> = header.copy_into(&mut dest); + }); + } + + if next_expected_control_packet.is_some() { + encoder.encode(control_data); + } + + let payload_offset = encoder.len(); + + let mut last_chunk = buffer::reader::storage::Chunk::empty(); + encoder.write_sized(payload_len_usize, |mut dest| { + let result: Result = + payload.partial_copy_into(&mut dest); + last_chunk = result.expect("copy is infallible"); + }); + + let last_chunk = if last_chunk.is_empty() { + None + } else { + Some(&last_chunk[..]) + }; + + encoder.advance_position(crypto.tag_len()); + + let 
packet_len = encoder.len(); + + let slice = encoder.as_mut_slice(); + let (header, payload_and_tag) = unsafe { + assume!(slice.len() >= payload_offset); + slice.split_at_mut(payload_offset) + }; + + crypto.encrypt(nonce, header, last_chunk, payload_and_tag); + + packet_len +} diff --git a/dc/s2n-quic-dc/src/packet/reset.rs b/dc/s2n-quic-dc/src/packet/reset.rs new file mode 100644 index 000000000..5d8d63c99 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/reset.rs @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::{ + fmt, + ops::{Deref, DerefMut}, +}; +use s2n_codec::zerocopy_value_codec; +use s2n_quic_core::varint::VarInt; +use zerocopy::{AsBytes, FromBytes, Unaligned}; + +#[derive(Clone, Copy, PartialEq, Eq, AsBytes, FromBytes, Unaligned)] +#[repr(C)] +pub struct Tag(u8); + +zerocopy_value_codec!(Tag); + +impl Default for Tag { + #[inline] + fn default() -> Self { + Self(0b0110_0000) + } +} + +/* +impl fmt::Debug for Tag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("datagram::Tag") + .field("mode", &self.mode()) + .finish() + } +} + +impl Tag { + #[inline] + pub fn mode(&self) -> Mode { + + } +} + +#[derive(Clone, Copy, Debug)] +pub enum Mode { + Early, + Authenticated, + Stateless, +} +*/ diff --git a/dc/s2n-quic-dc/src/packet/secret_control.rs b/dc/s2n-quic-dc/src/packet/secret_control.rs new file mode 100644 index 000000000..397620dee --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control.rs @@ -0,0 +1,177 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + credentials, + crypto::{decrypt, encrypt}, +}; +use s2n_codec::{ + decoder_invariant, decoder_value, DecoderBuffer, DecoderBufferMut, + DecoderBufferMutResult as Rm, DecoderBufferResult as R, DecoderError, DecoderValue, Encoder, + EncoderBuffer, EncoderValue, +}; +use s2n_quic_core::varint::VarInt; +use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; + +#[macro_use] +mod decoder; +mod encoder; +mod nonce; + +const UNKNOWN_PATH_SECRET: u8 = 0b0110_0000; +const STALE_KEY: u8 = 0b0110_0001; +const REPLAY_DETECTED: u8 = 0b0110_0010; +const REQUEST_SHARDS: u8 = 0b0110_0011; + +macro_rules! impl_tag { + ($tag:expr) => { + #[derive(Clone, Copy, PartialEq, Eq, AsBytes, FromBytes, FromZeroes, Unaligned)] + #[repr(C)] + pub struct Tag(u8); + + impl core::fmt::Debug for Tag { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + f.debug_struct(concat!(module_path!(), "::Tag")).finish() + } + } + + impl Tag { + pub const VALUE: u8 = $tag; + } + + decoder_value!( + impl<'a> Tag { + fn decode(buffer: Buffer) -> Result { + let (tag, buffer) = buffer.decode()?; + decoder_invariant!(tag == $tag, "invalid tag"); + Ok((Self(tag), buffer)) + } + } + ); + + impl EncoderValue for Tag { + #[inline] + fn encode(&self, e: &mut E) { + self.0.encode(e) + } + } + + impl Default for Tag { + #[inline] + fn default() -> Self { + Self($tag) + } + } + }; +} + +#[cfg(test)] +macro_rules! 
impl_tests { + ($ty:ident) => { + #[test] + fn round_trip_test() { + use crate::crypto::awslc::{DecryptKey, EncryptKey, AES_128_GCM}; + + let creds = crate::credentials::Credentials { + id: Default::default(), + generation_id: Default::default(), + sequence_id: Default::default(), + }; + let key = &[0u8; 16]; + let iv = [0u8; 12]; + let encrypt = EncryptKey::new(creds, key, iv, &AES_128_GCM); + let decrypt = DecryptKey::new(creds, key, iv, &AES_128_GCM); + + bolero::check!() + .with_type::<$ty>() + .filter(|v| v.validate().is_some()) + .for_each(|value| { + let mut buffer = [0u8; 64]; + let len = { + let encoder = s2n_codec::EncoderBuffer::new(&mut buffer); + value.encode(encoder, (&mut &encrypt)) + }; + + { + use decrypt::Key as _; + let buffer = s2n_codec::DecoderBufferMut::new(&mut buffer[..len]); + let (decoded, _) = Packet::decode(buffer, decrypt.tag_len()).unwrap(); + let decoded = decoded.authenticate(&mut &decrypt).unwrap(); + assert_eq!(value, decoded); + } + + { + use decrypt::Key as _; + let buffer = s2n_codec::DecoderBufferMut::new(&mut buffer[..len]); + let (decoded, _) = crate::packet::secret_control::Packet::decode( + buffer, + decrypt.tag_len(), + ) + .unwrap(); + if let crate::packet::secret_control::Packet::$ty(decoded) = decoded { + let decoded = decoded.authenticate(&mut &decrypt).unwrap(); + assert_eq!(value, decoded); + } else { + panic!("decoded as the wrong packet type"); + } + } + }); + } + }; +} + +pub mod replay_detected; +pub mod request_shards; +pub mod stale_key; +pub mod unknown_path_secret; + +pub use nonce::Nonce; +pub use replay_detected::ReplayDetected; +pub use request_shards::RequestShards; +pub use stale_key::StaleKey; +pub use unknown_path_secret::UnknownPathSecret; + +#[derive(Clone, Copy, Debug)] +pub enum Packet<'a> { + UnknownPathSecret(unknown_path_secret::Packet<'a>), + StaleKey(stale_key::Packet<'a>), + ReplayDetected(replay_detected::Packet<'a>), + RequestShards(request_shards::Packet<'a>), +} + +impl<'a> Packet<'a> { + #[inline] + pub fn decode(buffer: DecoderBufferMut<'a>, crypto_tag_len: usize) -> Rm { + let tag = buffer.peek_byte(0)?; + + Ok(match tag { + UNKNOWN_PATH_SECRET => { + let (packet, buffer) = unknown_path_secret::Packet::decode(buffer)?; + (Self::UnknownPathSecret(packet), buffer) + } + STALE_KEY => { + let (packet, buffer) = stale_key::Packet::decode(buffer, crypto_tag_len)?; + (Self::StaleKey(packet), buffer) + } + REPLAY_DETECTED => { + let (packet, buffer) = replay_detected::Packet::decode(buffer, crypto_tag_len)?; + (Self::ReplayDetected(packet), buffer) + } + REQUEST_SHARDS => { + let (packet, buffer) = request_shards::Packet::decode(buffer, crypto_tag_len)?; + (Self::RequestShards(packet), buffer) + } + _ => return Err(DecoderError::InvariantViolation("invalid tag")), + }) + } + + #[inline] + pub fn credential_id(&self) -> &credentials::Id { + match self { + Self::UnknownPathSecret(p) => p.credential_id(), + Self::StaleKey(p) => p.credential_id(), + Self::ReplayDetected(p) => p.credential_id(), + Self::RequestShards(p) => p.credential_id(), + } + } +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs b/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs new file mode 100644 index 000000000..86e770df9 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/decoder.rs @@ -0,0 +1,91 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use s2n_codec::{ + DecoderBuffer, DecoderBufferMut, DecoderBufferMutResult as Rm, DecoderError, DecoderValue, +}; + +macro_rules! impl_packet { + ($name:ident) => { + #[derive(Clone, Copy, Debug)] + pub struct Packet<'a> { + header: &'a [u8], + value: $name, + crypto_tag: &'a [u8], + } + + impl<'a> Packet<'a> { + #[inline] + pub fn decode(buffer: DecoderBufferMut<'a>, crypto_tag_len: usize) -> Rm { + let header_len = decoder::header_len::<$name>(buffer.peek())?; + let ((header, value, crypto_tag), buffer) = + decoder::header(buffer, header_len, crypto_tag_len)?; + let packet = Self { + header, + value, + crypto_tag, + }; + Ok((packet, buffer)) + } + + #[inline] + pub fn credential_id(&self) -> &crate::credentials::Id { + &self.value.credential_id + } + + #[inline] + pub fn authenticate(&self, crypto: &mut C) -> Option<&$name> + where + C: decrypt::Key, + { + let Self { + header, + value, + crypto_tag, + } = self; + + crypto + .decrypt( + value.nonce(), + header, + &[], + crypto_tag, + bytes::buf::UninitSlice::new(&mut []), + ) + .ok()?; + + Some(value) + } + } + }; +} + +#[inline] +pub fn header_len<'a, T>(buffer: DecoderBuffer<'a>) -> Result +where + T: DecoderValue<'a>, +{ + let before_len = buffer.len(); + let (_, buffer) = buffer.decode::()?; + Ok(before_len - buffer.len()) +} + +#[inline] +pub fn header<'a, T>( + buffer: DecoderBufferMut<'a>, + header_len: usize, + crypto_tag_len: usize, +) -> Rm<'a, (&[u8], T, &[u8])> +where + T: DecoderValue<'a>, +{ + let (header, buffer) = buffer.decode_slice(header_len)?; + let header = header.freeze(); + let (value, _) = header.decode::()?; + let header = header.into_less_safe_slice(); + + let (crypto_tag, buffer) = buffer.decode_slice(crypto_tag_len)?; + let crypto_tag = crypto_tag.into_less_safe_slice(); + + Ok(((header, value, crypto_tag), buffer)) +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs b/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs new file mode 100644 index 000000000..2d66f90f1 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/encoder.rs @@ -0,0 +1,29 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::Nonce; +use crate::crypto::encrypt; +use s2n_codec::{Encoder, EncoderBuffer}; +use s2n_quic_core::assume; + +#[inline] +pub fn finish(mut encoder: EncoderBuffer, nonce: Nonce, crypto: &mut C) -> usize +where + C: encrypt::Key, +{ + let header_offset = encoder.len(); + + encoder.advance_position(crypto.tag_len()); + + let packet_len = encoder.len(); + + let slice = encoder.as_mut_slice(); + let (header, payload_and_tag) = unsafe { + assume!(slice.len() >= header_offset); + slice.split_at_mut(header_offset) + }; + + crypto.encrypt(nonce, header, None, payload_and_tag); + + packet_len +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/nonce.rs b/dc/s2n-quic-dc/src/packet/secret_control/nonce.rs new file mode 100644 index 000000000..f98f79801 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/nonce.rs @@ -0,0 +1,95 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{REPLAY_DETECTED, REQUEST_SHARDS, STALE_KEY, UNKNOWN_PATH_SECRET}; +use crate::crypto::IntoNonce; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub enum Nonce { + UnknownPathSecret, + StaleKey { + // This is the minimum key ID the server will accept (at the time of sending). 
+ // + // This is used for cases where the server intentionally drops state in a manner that cuts + // out a chunk of not-yet-used key ID space. + min_key_id: u64, + }, + ReplayDetected { + // This is the key ID we rejected. + // + // The client should enqueue a handshake but it should keep in mind that this might be + // caused by an attacker replaying packets, so maybe impose rate limiting or ignore "really + // old" replay detected packets. + rejected_key_id: u64, + }, + // Note that this is all purely a hint and currently neither clients and servers will ever + // send this. + RequestShards { + // Total number of distinct key spaces we'd like the client to send in. + // + // Clients MAY react to this, if they can, by attempting to assign these shards across the + // sending threads. + receiving_shards: u16, + // How wide each of the shards should be. + // + // For example, if this is u16::MAX and receiving_shards is 2, then the client should + // attempt to start sending keys from two independent ranges (current, current + 1, ...) + // and (current + u16::MAX, current + u16::MAX + 1, ...), and if either range wraps + // continue on the other side. + // + // If the receiving is roughly randomly distributed across threads (but reading from the + // same socket) on the server side, this will lead to a higher probability that two + // receiving threads aren't going to contend on the same area of replay tracking when + // reading consecutive packets. + shard_width: u64, + }, +} + +impl IntoNonce for Nonce { + #[inline] + fn into_nonce(self) -> [u8; 12] { + let mut nonce = [0; 12]; + match self { + Self::UnknownPathSecret => { + nonce[0] = UNKNOWN_PATH_SECRET; + } + Self::StaleKey { min_key_id } => { + nonce[0] = STALE_KEY; + nonce[1..9].copy_from_slice(&min_key_id.to_be_bytes()); + } + Self::ReplayDetected { rejected_key_id } => { + nonce[0] = REPLAY_DETECTED; + nonce[1..9].copy_from_slice(&rejected_key_id.to_be_bytes()); + } + Self::RequestShards { + receiving_shards, + shard_width, + } => { + nonce[0] = REQUEST_SHARDS; + nonce[1..3].copy_from_slice(&receiving_shards.to_be_bytes()); + nonce[3..11].copy_from_slice(&shard_width.to_be_bytes()); + } + } + nonce + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bolero::check; + + /// ensures output nonces are only equal if the messages are equal + #[test] + #[cfg_attr(kani, kani::proof, kani::solver(cadical))] + fn nonce_uniqueness() { + check!().with_type::<(Nonce, Nonce)>().for_each(|(a, b)| { + if a == b { + assert_eq!(a.into_nonce(), b.into_nonce()); + } else { + assert_ne!(a.into_nonce(), b.into_nonce()); + } + }); + } +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs b/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs new file mode 100644 index 000000000..f7a238b00 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/replay_detected.rs @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +impl_tag!(REPLAY_DETECTED); +impl_packet!(ReplayDetected); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub struct ReplayDetected { + pub credential_id: credentials::Id, + pub rejected_key_id: VarInt, +} + +impl ReplayDetected { + #[inline] + pub fn encode(&self, mut encoder: EncoderBuffer, crypto: &mut C) -> usize + where + C: encrypt::Key, + { + encoder.encode(&Tag::default()); + encoder.encode(&self.credential_id); + encoder.encode(&self.rejected_key_id); + + encoder::finish(encoder, self.nonce(), crypto) + } + + #[inline] + pub fn nonce(&self) -> Nonce { + Nonce::ReplayDetected { + rejected_key_id: self.rejected_key_id.into(), + } + } + + #[cfg(test)] + fn validate(&self) -> Option<()> { + Some(()) + } +} + +impl<'a> DecoderValue<'a> for ReplayDetected { + #[inline] + fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { + let (tag, buffer) = buffer.decode::()?; + decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (credential_id, buffer) = buffer.decode()?; + let (rejected_key_id, buffer) = buffer.decode()?; + let value = Self { + credential_id, + rejected_key_id, + }; + Ok((value, buffer)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl_tests!(ReplayDetected); +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/request_additional_generation.rs b/dc/s2n-quic-dc/src/packet/secret_control/request_additional_generation.rs new file mode 100644 index 000000000..f002ed8d9 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/request_additional_generation.rs @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +impl_tag!(REQUEST_ADDITIONAL_GENERATION); +impl_packet!(RequestAdditionalGeneration); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub struct RequestAdditionalGeneration { + pub credential_id: credentials::Id, + pub generation_id: u32, +} + +impl RequestAdditionalGeneration { + #[inline] + pub fn encode(&self, mut encoder: EncoderBuffer, crypto: &mut C) -> usize + where + C: encrypt::Key, + { + let generation_id = self.generation_id; + + encoder.encode(&Tag::default()); + encoder.encode(&&self.credential_id[..]); + encoder.encode(&VarInt::from(generation_id)); + + encoder::finish( + encoder, + Nonce::RequestAdditionalGeneration { generation_id }, + crypto, + ) + } + + #[inline] + pub fn nonce(&self) -> Nonce { + Nonce::RequestAdditionalGeneration { + generation_id: self.generation_id, + } + } + + #[cfg(test)] + fn validate(&self) -> Option<()> { + Some(()) + } +} + +impl<'a> DecoderValue<'a> for RequestAdditionalGeneration { + #[inline] + fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { + let (tag, buffer) = buffer.decode::()?; + decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (credential_id, buffer) = buffer.decode()?; + let (generation_id, buffer) = decoder::sized(buffer)?; + let value = Self { + credential_id, + generation_id, + }; + Ok((value, buffer)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl_tests!(RequestAdditionalGeneration); +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/request_shards.rs b/dc/s2n-quic-dc/src/packet/secret_control/request_shards.rs new file mode 100644 index 000000000..042f017d8 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/request_shards.rs @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +impl_tag!(REQUEST_SHARDS); +impl_packet!(RequestShards); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub struct RequestShards { + pub credential_id: credentials::Id, + pub receiving_shards: u16, + pub shard_width: u64, +} + +impl RequestShards { + #[inline] + pub fn encode(&self, mut encoder: EncoderBuffer, crypto: &mut C) -> usize + where + C: encrypt::Key, + { + encoder.encode(&Tag::default()); + encoder.encode(&self.credential_id); + encoder.encode(&VarInt::from(self.receiving_shards)); + encoder.encode(&self.shard_width); + + encoder::finish(encoder, self.nonce(), crypto) + } + + #[inline] + pub fn nonce(&self) -> Nonce { + Nonce::RequestShards { + receiving_shards: self.receiving_shards, + shard_width: self.shard_width, + } + } + + #[cfg(test)] + fn validate(&self) -> Option<()> { + Some(()) + } +} + +impl<'a> DecoderValue<'a> for RequestShards { + #[inline] + fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { + let (tag, buffer) = buffer.decode::()?; + decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (credential_id, buffer) = buffer.decode()?; + let (receiving_shards, buffer) = buffer.decode::()?; + let (shard_width, buffer) = buffer.decode()?; + let value = Self { + credential_id, + receiving_shards: receiving_shards + .try_into() + .map_err(|_| DecoderError::InvariantViolation("receiving_shards too big"))?, + shard_width, + }; + Ok((value, buffer)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl_tests!(RequestShards); +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs b/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs new file mode 100644 index 000000000..12fcf059d --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/stale_key.rs @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +impl_tag!(STALE_KEY); +impl_packet!(StaleKey); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub struct StaleKey { + pub credential_id: credentials::Id, + pub min_key_id: VarInt, +} + +impl StaleKey { + #[inline] + pub fn encode(&self, mut encoder: EncoderBuffer, crypto: &mut C) -> usize + where + C: encrypt::Key, + { + encoder.encode(&Tag::default()); + encoder.encode(&self.credential_id); + encoder.encode(&self.min_key_id); + + encoder::finish(encoder, self.nonce(), crypto) + } + + #[inline] + pub fn nonce(&self) -> Nonce { + Nonce::StaleKey { + min_key_id: self.min_key_id.into(), + } + } + + #[cfg(test)] + fn validate(&self) -> Option<()> { + Some(()) + } +} + +impl<'a> DecoderValue<'a> for StaleKey { + #[inline] + fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { + let (tag, buffer) = buffer.decode::()?; + decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (credential_id, buffer) = buffer.decode()?; + let (min_key_id, buffer) = buffer.decode()?; + let value = Self { + credential_id, + min_key_id, + }; + Ok((value, buffer)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl_tests!(StaleKey); +} diff --git a/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs b/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs new file mode 100644 index 000000000..eadcf9525 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/secret_control/unknown_path_secret.rs @@ -0,0 +1,112 @@ +// Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use core::mem::size_of; + +impl_tag!(UNKNOWN_PATH_SECRET); + +const STATELESS_RESET_LEN: usize = 16; + +#[derive(Clone, Copy, Debug)] +pub struct Packet<'a> { + #[allow(dead_code)] + header: &'a [u8], + value: UnknownPathSecret, + crypto_tag: &'a [u8], +} + +impl<'a> Packet<'a> { + #[inline] + pub fn decode(buffer: DecoderBufferMut<'a>) -> Rm { + let header_len = decoder::header_len::(buffer.peek())?; + let ((header, value, crypto_tag), buffer) = + decoder::header(buffer, header_len, STATELESS_RESET_LEN)?; + let packet = Self { + header, + value, + crypto_tag, + }; + Ok((packet, buffer)) + } + + #[inline] + pub fn credential_id(&self) -> &crate::credentials::Id { + &self.value.credential_id + } + + #[inline] + pub fn authenticate( + &self, + stateless_reset: &[u8; STATELESS_RESET_LEN], + ) -> Option<&UnknownPathSecret> { + aws_lc_rs::constant_time::verify_slices_are_equal(self.crypto_tag, stateless_reset).ok()?; + Some(&self.value) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(test, derive(bolero::TypeGenerator))] +pub struct UnknownPathSecret { + pub credential_id: credentials::Id, +} + +impl UnknownPathSecret { + pub const PACKET_SIZE: usize = + size_of::() + size_of::() + STATELESS_RESET_LEN; + + #[inline] + pub fn encode( + &self, + mut encoder: EncoderBuffer, + stateless_reset_tag: &[u8; STATELESS_RESET_LEN], + ) -> usize { + let before = encoder.len(); + encoder.encode(&Tag::default()); + encoder.encode(&&self.credential_id[..]); + encoder.encode(&&stateless_reset_tag[..]); + let after = encoder.len(); + after - before + } + + #[inline] + pub fn nonce(&self) -> Nonce { + Nonce::UnknownPathSecret + } +} + +impl<'a> DecoderValue<'a> for UnknownPathSecret { + #[inline] + fn decode(buffer: DecoderBuffer<'a>) -> R<'a, Self> { + let (tag, buffer) = buffer.decode::()?; + decoder_invariant!(tag == Tag::default(), "invalid tag"); + let (credential_id, buffer) = buffer.decode()?; + let value = Self { credential_id }; + Ok((value, buffer)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trip_test() { + bolero::check!() + .with_type::<(UnknownPathSecret, [u8; 16])>() + .for_each(|(value, stateless_reset)| { + let mut buffer = [0u8; 64]; + let len = { + let encoder = s2n_codec::EncoderBuffer::new(&mut buffer); + value.encode(encoder, stateless_reset) + }; + + { + let buffer = s2n_codec::DecoderBufferMut::new(&mut buffer[..len]); + let (decoded, _) = Packet::decode(buffer).unwrap(); + let decoded = decoded.authenticate(stateless_reset).unwrap(); + assert_eq!(value, decoded); + } + }); + } +} diff --git a/dc/s2n-quic-dc/src/packet/stream.rs b/dc/s2n-quic-dc/src/packet/stream.rs new file mode 100644 index 000000000..fcaa5908a --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/stream.rs @@ -0,0 +1,92 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::tag::Common; +use core::fmt; +use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; + +pub mod decoder; +pub mod encoder; +mod id; + +pub use id::Id; + +#[derive(Clone, Copy, PartialEq, Eq, AsBytes, FromBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub struct Tag(Common); + +impl_tag_codec!(Tag); + +impl Default for Tag { + #[inline] + fn default() -> Self { + Self(Common(0b0000_0000)) + } +} + +impl fmt::Debug for Tag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("stream::Tag") + .field("has_control_data", &self.has_control_data()) + .field("has_final_offset", &self.has_final_offset()) + .field("has_application_header", &self.has_application_header()) + .finish() + } +} + +impl Tag { + const HAS_SOURCE_STREAM_PORT: u8 = 0b01_0000; + const HAS_CONTROL_DATA_MASK: u8 = 0b00_0100; + const HAS_FINAL_OFFSET_MASK: u8 = 0b00_0010; + const HAS_APPLICATION_HEADER_MASK: u8 = 0b00_0001; + + pub const MIN: u8 = 0b0000_0000; + pub const MAX: u8 = 0b0011_1111; + + #[inline] + pub fn set_has_source_stream_port(&mut self, enabled: bool) { + self.0.set(Self::HAS_SOURCE_STREAM_PORT, enabled) + } + + #[inline] + pub fn has_source_stream_port(&self) -> bool { + self.0.get(Self::HAS_SOURCE_STREAM_PORT) + } + + #[inline] + pub fn set_has_control_data(&mut self, enabled: bool) { + self.0.set(Self::HAS_CONTROL_DATA_MASK, enabled) + } + + #[inline] + pub fn has_control_data(&self) -> bool { + self.0.get(Self::HAS_CONTROL_DATA_MASK) + } + + #[inline] + pub fn set_has_final_offset(&mut self, enabled: bool) { + self.0.set(Self::HAS_FINAL_OFFSET_MASK, enabled) + } + + #[inline] + pub fn has_final_offset(&self) -> bool { + self.0.get(Self::HAS_FINAL_OFFSET_MASK) + } + + #[inline] + pub fn set_has_application_header(&mut self, enabled: bool) { + self.0.set(Self::HAS_APPLICATION_HEADER_MASK, enabled) + } + + #[inline] + pub fn has_application_header(&self) -> bool { + self.0.get(Self::HAS_APPLICATION_HEADER_MASK) + } + + #[inline] + fn validate(&self) -> Result<(), s2n_codec::DecoderError> { + let range = Self::MIN..=Self::MAX; + s2n_codec::decoder_invariant!(range.contains(&(self.0).0), "invalid stream bit pattern"); + Ok(()) + } +} diff --git a/dc/s2n-quic-dc/src/packet/stream/decoder.rs b/dc/s2n-quic-dc/src/packet/stream/decoder.rs new file mode 100644 index 000000000..4b605217d --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/stream/decoder.rs @@ -0,0 +1,578 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
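The stream Tag above packs its per-packet flags into the low bits of a single byte through the shared Common helper. The following standalone sketch uses the mask constants from the diff, with set/get as illustrative stand-ins for Common::set and Common::get, to show how a tag byte is composed and why every stream tag stays at or below 0b0011_1111.

```rust
// Standalone sketch of the stream tag-byte flag scheme. The mask values
// mirror the constants in the diff; set/get mirror packet::tag::Common.
const HAS_SOURCE_STREAM_PORT: u8 = 0b01_0000;
const HAS_CONTROL_DATA: u8 = 0b00_0100;
const HAS_FINAL_OFFSET: u8 = 0b00_0010;
const HAS_APPLICATION_HEADER: u8 = 0b00_0001;
const MAX: u8 = 0b0011_1111;

fn set(tag: u8, mask: u8, enabled: bool) -> u8 {
    // clear the bit, then set it again if enabled (same expression as Common::set)
    tag & !mask | if enabled { mask } else { 0 }
}

fn get(tag: u8, mask: u8) -> bool {
    tag & mask != 0
}

fn main() {
    let mut tag = 0u8; // Tag::default()
    tag = set(tag, HAS_CONTROL_DATA, true);
    tag = set(tag, HAS_FINAL_OFFSET, true);
    assert!(get(tag, HAS_CONTROL_DATA));
    assert!(!get(tag, HAS_APPLICATION_HEADER));
    assert!(!get(tag, HAS_SOURCE_STREAM_PORT));
    // stream tags must stay in the 0b00xx_xxxx range; Tag::validate enforces this
    assert!(tag <= MAX);
}
```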
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + credentials::Credentials, + crypto, + packet::stream::{self, Tag}, +}; +use s2n_codec::{ + decoder_invariant, u24, CheckedRange, DecoderBufferMut, DecoderBufferMutResult as R, + DecoderError, +}; +use s2n_quic_core::{assume, varint::VarInt}; + +type PacketNumber = VarInt; + +pub trait Validator { + fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError>; +} + +impl Validator for () { + #[inline] + fn validate_tag(&mut self, _tag: Tag) -> Result<(), DecoderError> { + Ok(()) + } +} + +impl Validator for Tag { + #[inline] + fn validate_tag(&mut self, actual: Tag) -> Result<(), DecoderError> { + decoder_invariant!(*self == actual, "unexpected packet type"); + Ok(()) + } +} + +impl Validator for (A, B) +where + A: Validator, + B: Validator, +{ + #[inline] + fn validate_tag(&mut self, tag: Tag) -> Result<(), DecoderError> { + self.0.validate_tag(tag)?; + self.1.validate_tag(tag)?; + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Owned { + pub tag: Tag, + pub credentials: Credentials, + pub source_control_port: u16, + pub source_stream_port: Option, + pub stream_id: stream::Id, + pub original_packet_number: PacketNumber, + pub packet_number: PacketNumber, + pub retransmission_packet_number_offset: u8, + pub next_expected_control_packet: PacketNumber, + pub stream_offset: VarInt, + pub final_offset: Option, + pub application_header: Vec, + pub control_data: Vec, + pub payload: Vec, + pub auth_tag: Vec, +} + +impl<'a> From> for Owned { + fn from(packet: Packet<'a>) -> Self { + let application_header = packet.application_header().to_vec(); + let control_data = packet.control_data().to_vec(); + + Self { + tag: packet.tag, + credentials: packet.credentials, + source_control_port: packet.source_control_port, + source_stream_port: packet.source_stream_port, + stream_id: packet.stream_id, + original_packet_number: packet.original_packet_number, + packet_number: packet.packet_number, + retransmission_packet_number_offset: packet.retransmission_packet_number_offset, + next_expected_control_packet: packet.next_expected_control_packet, + stream_offset: packet.stream_offset, + final_offset: packet.final_offset, + application_header, + control_data, + payload: packet.payload.to_vec(), + auth_tag: packet.auth_tag.to_vec(), + } + } +} + +#[derive(Debug)] +pub struct Packet<'a> { + tag: Tag, + credentials: Credentials, + source_control_port: u16, + source_stream_port: Option, + stream_id: stream::Id, + original_packet_number: PacketNumber, + packet_number: PacketNumber, + retransmission_packet_number_offset: u8, + next_expected_control_packet: PacketNumber, + stream_offset: VarInt, + final_offset: Option, + header: &'a mut [u8], + application_header: CheckedRange, + control_data: CheckedRange, + payload: &'a mut [u8], + auth_tag: &'a mut [u8], +} + +impl<'a> Packet<'a> { + #[inline] + pub fn tag(&self) -> Tag { + self.tag + } + + #[inline] + pub fn credentials(&self) -> &Credentials { + &self.credentials + } + + #[inline] + pub fn source_control_port(&self) -> u16 { + self.source_control_port + } + + #[inline] + pub fn source_stream_port(&self) -> Option { + self.source_stream_port + } + + #[inline] + pub fn stream_id(&self) -> &stream::Id { + &self.stream_id + } + + #[inline] + pub fn packet_number(&self) -> PacketNumber { + self.packet_number + } + + #[inline] + pub fn is_retransmission(&self) -> bool { + self.packet_number != self.original_packet_number + } + + #[inline] + pub fn next_expected_control_packet(&self) -> 
PacketNumber { + self.next_expected_control_packet + } + + #[inline] + pub fn stream_offset(&self) -> VarInt { + self.stream_offset + } + + #[inline] + pub fn final_offset(&self) -> Option { + self.final_offset + } + + #[inline] + pub fn is_fin(&self) -> bool { + self.final_offset() + .and_then(|offset| offset.checked_sub(self.stream_offset)) + .and_then(|offset| { + let len = VarInt::try_from(self.payload.len()).ok()?; + offset.checked_sub(len) + }) + .map_or(false, |v| *v == 0) + } + + #[inline] + pub fn application_header(&self) -> &[u8] { + self.application_header.get(self.header) + } + + #[inline] + pub fn control_data(&self) -> &[u8] { + self.control_data.get(self.header) + } + + #[inline] + pub fn header(&self) -> &[u8] { + self.header + } + + #[inline] + pub fn payload(&self) -> &[u8] { + self.payload + } + + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + self.payload + } + + #[inline] + pub fn decrypt( + &mut self, + d: &mut D, + payload_out: &mut crypto::UninitSlice, + ) -> Result<(), crypto::decrypt::Error> + where + D: crypto::decrypt::Key, + { + self.remove_retransmit(d); + + let nonce = self.original_packet_number.as_u64(); + let header = &self.header; + let payload = &self.payload; + let auth_tag = &self.auth_tag; + + d.decrypt(nonce, header, payload, auth_tag, payload_out)?; + + Ok(()) + } + + #[inline] + pub fn decrypt_in_place(&mut self, d: &mut D) -> Result<(), crypto::decrypt::Error> + where + D: crypto::decrypt::Key, + { + self.remove_retransmit(d); + + let nonce = self.original_packet_number.as_u64(); + let header = &self.header; + + let payload_len = self.payload.len(); + let payload_ptr = self.payload.as_mut_ptr(); + let tag_len = self.auth_tag.len(); + let tag_ptr = self.auth_tag.as_mut_ptr(); + let payload_and_tag = unsafe { + debug_assert_eq!(payload_ptr.add(payload_len), tag_ptr); + + core::slice::from_raw_parts_mut(payload_ptr, payload_len + tag_len) + }; + + d.decrypt_in_place(nonce, header, payload_and_tag)?; + + Ok(()) + } + + #[inline] + fn remove_retransmit(&mut self, d: &mut D) + where + D: crypto::decrypt::Key, + { + let original_packet_number = self.original_packet_number.as_u64(); + let retransmission_packet_number = self.packet_number.as_u64(); + + if original_packet_number != retransmission_packet_number { + d.retransmission_tag( + original_packet_number, + retransmission_packet_number, + self.auth_tag, + ); + let offset = self.retransmission_packet_number_offset as usize; + let range = offset..offset + 3; + self.header[range].copy_from_slice(&[0; 3]); + } + } + + #[inline] + #[cfg(debug_assertions)] + pub fn retransmit( + buffer: DecoderBufferMut, + retransmission_packet_number: VarInt, + key: &mut K, + ) -> Result<(), DecoderError> + where + K: crypto::encrypt::Key, + { + let buffer = buffer.into_less_safe_slice(); + + let mut before = Self::snapshot(buffer, key.tag_len()); + // the auth tag will have changed so clear it + before.auth_tag.clear(); + + Self::retransmit_impl( + DecoderBufferMut::new(buffer), + retransmission_packet_number, + key, + )?; + + let mut after = Self::snapshot(buffer, key.tag_len()); + assert_eq!(after.packet_number, retransmission_packet_number); + after.packet_number = before.packet_number; + // the auth tag will have changed so clear it + after.auth_tag.clear(); + + assert_eq!(before, after); + + Ok(()) + } + + #[inline] + #[cfg(not(debug_assertions))] + pub fn retransmit( + buffer: DecoderBufferMut, + retransmission_packet_number: VarInt, + key: &mut K, + ) -> Result<(), DecoderError> + where + K: 
crypto::encrypt::Key, + { + Self::retransmit_impl(buffer, retransmission_packet_number, key) + } + + #[inline] + #[cfg(debug_assertions)] + fn snapshot(buffer: &mut [u8], crypto_tag_len: usize) -> Owned { + let buffer = DecoderBufferMut::new(buffer); + let (packet, _buffer) = Self::decode(buffer, (), crypto_tag_len).unwrap(); + packet.into() + } + + #[inline(always)] + fn retransmit_impl( + buffer: DecoderBufferMut, + retransmission_packet_number: VarInt, + key: &mut K, + ) -> Result<(), DecoderError> + where + K: crypto::encrypt::Key, + { + unsafe { + assume!(key.tag_len() >= 16, "tag len needs to be at least 16 bytes"); + } + + let (tag, buffer) = buffer.decode::()?; + + let (credentials, buffer) = buffer.decode::()?; + + debug_assert_eq!(&credentials, key.credentials()); + + let (_source_control_port, buffer) = buffer.decode::()?; + + let (_source_stream_port, buffer) = if tag.has_source_stream_port() { + let (port, buffer) = buffer.decode::()?; + (Some(port), buffer) + } else { + (None, buffer) + }; + + let (stream_id, buffer) = buffer.decode::()?; + + decoder_invariant!( + stream_id.is_reliable, + "only reliable streams can be retransmitted" + ); + + let (original_packet_number, buffer) = buffer.decode::()?; + let (retransmission_packet_number_buffer, buffer) = buffer.decode_slice(3)?; + let retransmission_packet_number_buffer = + retransmission_packet_number_buffer.into_less_safe_slice(); + let retransmission_packet_number_buffer: &mut [u8; 3] = + retransmission_packet_number_buffer.try_into().unwrap(); + + let (_next_expected_control_packet, buffer) = buffer.decode::()?; + let (_stream_offset, buffer) = buffer.decode::()?; + + let auth_tag_offset = buffer + .len() + .checked_sub(key.tag_len()) + .ok_or(DecoderError::InvariantViolation("missing auth tag"))?; + let buffer = buffer.skip(auth_tag_offset)?; + let auth_tag = buffer.into_less_safe_slice(); + + let relative = retransmission_packet_number + .checked_sub(original_packet_number) + .ok_or(DecoderError::InvariantViolation( + "invalid retransmission packet number", + ))?; + + let original_packet_number = original_packet_number.as_u64(); + + let relative: u24 = relative + .as_u64() + .try_into() + .map_err(|_| DecoderError::InvariantViolation("packet is too old"))?; + + // undo the previous retransmission if needed + let prev_value = u24::from_be_bytes(*retransmission_packet_number_buffer); + if prev_value != u24::ZERO { + let retransmission_packet_number = original_packet_number + *prev_value as u64; + key.retransmission_tag( + original_packet_number, + retransmission_packet_number, + auth_tag, + ); + } + + retransmission_packet_number_buffer.copy_from_slice(&relative.to_be_bytes()); + + key.retransmission_tag( + original_packet_number, + retransmission_packet_number.as_u64(), + auth_tag, + ); + + Ok(()) + } + + #[inline(always)] + pub fn decode( + buffer: DecoderBufferMut, + mut validator: V, + crypto_tag_len: usize, + ) -> R { + let ( + tag, + credentials, + source_control_port, + source_stream_port, + stream_id, + original_packet_number, + packet_number, + retransmission_packet_number_offset, + next_expected_control_packet, + stream_offset, + final_offset, + header_len, + total_header_len, + application_header_len, + control_data_len, + payload_len, + ) = { + let buffer = buffer.peek(); + + unsafe { + assume!( + crypto_tag_len >= 16, + "tag len needs to be at least 16 bytes" + ); + } + + let start_len = buffer.len(); + + let (tag, buffer) = buffer.decode()?; + validator.validate_tag(tag)?; + + let (credentials, buffer) = 
buffer.decode()?; + + let (source_control_port, buffer) = buffer.decode()?; + + let (source_stream_port, buffer) = if tag.has_source_stream_port() { + let (port, buffer) = buffer.decode()?; + (Some(port), buffer) + } else { + (None, buffer) + }; + + let (stream_id, buffer) = buffer.decode::()?; + + let (original_packet_number, buffer) = buffer.decode::()?; + + let retransmission_packet_number_offset = (start_len - buffer.len()) as u8; + let (packet_number, buffer) = if stream_id.is_reliable { + let (rel, buffer) = buffer.decode::()?; + let rel = VarInt::from_u32(*rel); + let pn = original_packet_number.checked_add(rel).ok_or( + DecoderError::InvariantViolation("retransmission packet number overflow"), + )?; + (pn, buffer) + } else { + (original_packet_number, buffer) + }; + + let (next_expected_control_packet, buffer) = buffer.decode()?; + let (stream_offset, buffer) = buffer.decode()?; + let (final_offset, buffer) = if tag.has_final_offset() { + let (final_offset, buffer) = buffer.decode()?; + (Some(final_offset), buffer) + } else { + (None, buffer) + }; + let (control_data_len, buffer) = if tag.has_control_data() { + buffer.decode()? + } else { + (VarInt::ZERO, buffer) + }; + let (payload_len, buffer) = buffer.decode::()?; + + let (application_header_len, buffer) = if tag.has_application_header() { + let (application_header_len, buffer) = buffer.decode::()?; + ((*application_header_len) as usize, buffer) + } else { + (0, buffer) + }; + + let header_len = start_len - buffer.len(); + + let buffer = buffer.skip(application_header_len)?; + let buffer = buffer.skip(*control_data_len as _)?; + + let total_header_len = start_len - buffer.len(); + + let buffer = buffer.skip(*payload_len as _)?; + let buffer = buffer.skip(crypto_tag_len)?; + + let _ = buffer; + + ( + tag, + credentials, + source_control_port, + source_stream_port, + stream_id, + original_packet_number, + packet_number, + retransmission_packet_number_offset, + next_expected_control_packet, + stream_offset, + final_offset, + header_len, + total_header_len, + application_header_len, + control_data_len, + payload_len, + ) + }; + + unsafe { + assume!(buffer.len() >= total_header_len); + } + let (header, buffer) = buffer.decode_slice(total_header_len)?; + + let (application_header, control_data) = { + let buffer = header.peek(); + unsafe { + assume!(buffer.len() >= header_len); + } + let buffer = buffer.skip(header_len)?; + unsafe { + assume!(buffer.len() >= application_header_len); + } + let (application_header, buffer) = + buffer.skip_into_range(application_header_len, &header)?; + unsafe { + assume!(buffer.len() >= *control_data_len as usize); + } + let (control_data, _) = buffer.skip_into_range(*control_data_len as usize, &header)?; + + (application_header, control_data) + }; + let header = header.into_less_safe_slice(); + + let (payload, buffer) = buffer.decode_slice(*payload_len as usize)?; + let payload = payload.into_less_safe_slice(); + + let (auth_tag, buffer) = buffer.decode_slice(crypto_tag_len)?; + let auth_tag = auth_tag.into_less_safe_slice(); + + let packet = Packet { + tag, + credentials, + source_control_port, + source_stream_port, + stream_id, + original_packet_number, + packet_number, + retransmission_packet_number_offset, + next_expected_control_packet, + stream_offset, + final_offset, + header, + application_header, + control_data, + payload, + auth_tag, + }; + + Ok((packet, buffer)) + } +} diff --git a/dc/s2n-quic-dc/src/packet/stream/encoder.rs b/dc/s2n-quic-dc/src/packet/stream/encoder.rs new file mode 100644 
index 000000000..a98bd2d56 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/stream/encoder.rs @@ -0,0 +1,173 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + crypto::encrypt, + packet::stream::{self, Tag}, +}; +use s2n_codec::{u24, Encoder, EncoderBuffer, EncoderValue}; +use s2n_quic_core::{ + assume, + buffer::{self, reader::storage::Infallible as _}, + varint::VarInt, +}; + +// TODO make sure this is accurate +pub const MAX_RETRANSMISSION_HEADER_LEN: usize = MAX_HEADER_LEN + (24 / 8); +pub const MAX_HEADER_LEN: usize = 50; + +#[inline(always)] +#[allow(clippy::too_many_arguments)] +pub fn encode( + mut encoder: EncoderBuffer, + source_control_port: u16, + source_stream_port: Option, + stream_id: stream::Id, + packet_number: VarInt, + next_expected_control_packet: VarInt, + header_len: VarInt, + header: &mut H, + control_data_len: VarInt, + control_data: &CD, + payload: &mut P, + crypto: &C, +) -> usize +where + H: buffer::reader::Storage, + P: buffer::Reader, + CD: EncoderValue, + C: encrypt::Key, +{ + let stream_offset = payload.current_offset(); + let final_offset = payload.final_offset(); + + let mut tag = Tag::default(); + + if *control_data_len > 0 { + tag.set_has_control_data(true); + } + + if final_offset.is_some() { + tag.set_has_final_offset(true); + } + + if *header_len > 0 { + tag.set_has_application_header(true); + } + + if source_stream_port.is_some() { + tag.set_has_source_stream_port(true); + } + + let nonce = *packet_number; + + encoder.encode(&tag); + + // encode the credentials being used + encoder.encode(crypto.credentials()); + encoder.encode(&source_control_port); + encoder.encode(&source_stream_port); + + encoder.encode(&stream_id); + + encoder.encode(&packet_number); + if stream_id.is_reliable { + encoder.encode(&u24::default()); + } + encoder.encode(&next_expected_control_packet); + encoder.encode(&stream_offset); + + if let Some(final_offset) = final_offset { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&final_offset); + } + } + + if *control_data_len > 0 { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&control_data_len); + } + } + + let payload_len = { + // TODO compute payload len for the given encoder + let buffered_len = payload.buffered_len(); + + let remaining_payload_capacity = encoder + .remaining_capacity() + .saturating_sub(header_len.encoding_size()) + .saturating_sub(*header_len as usize) + .saturating_sub(*control_data_len as usize) + .saturating_sub(crypto.tag_len()); + + // TODO figure out encoding size for the capacity + let remaining_payload_capacity = remaining_payload_capacity.saturating_sub(1); + + let payload_len = buffered_len.min(remaining_payload_capacity); + + unsafe { + assume!(VarInt::try_from(payload_len).is_ok()); + VarInt::try_from(payload_len).unwrap() + } + }; + + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&payload_len); + } + + if !header.buffer_is_empty() { + unsafe { + assume!(encoder.remaining_capacity() >= 8); + encoder.encode(&header_len); + } + encoder.write_sized(*header_len as usize, |mut dest| { + header.infallible_copy_into(&mut dest); + }); + } + + if *control_data_len > 0 { + encoder.encode(control_data); + } + + let payload_offset = encoder.len(); + + let mut last_chunk = Default::default(); + encoder.write_sized(*payload_len as usize, |mut dest| { + // the payload result is infallible + last_chunk = payload.infallible_partial_copy_into(&mut dest); + }); 
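The payload sizing in encode above reserves space for the header-length prefix, the application header, the control data, and the AEAD tag before clamping the buffered payload to whatever capacity remains. A rough standalone sketch of that arithmetic follows; the function name and the fixed one-byte allowance for the payload-length prefix mirror the TODO-marked approximation in the diff and are illustrative only.

```rust
// Rough standalone sketch of the payload budget computed in encode().
// All parameters are illustrative; the real code queries the EncoderBuffer
// and the encrypt::Key for the remaining capacity and tag length.
fn payload_budget(
    remaining_capacity: usize,
    header_len_prefix: usize, // varint encoding size of header_len
    header_len: usize,
    control_data_len: usize,
    crypto_tag_len: usize,
    buffered_len: usize,
) -> usize {
    let capacity = remaining_capacity
        .saturating_sub(header_len_prefix)
        .saturating_sub(header_len)
        .saturating_sub(control_data_len)
        .saturating_sub(crypto_tag_len)
        // one extra byte is reserved for the payload length prefix itself
        // (marked as a TODO in the diff)
        .saturating_sub(1);
    buffered_len.min(capacity)
}

fn main() {
    // 1460 bytes left in the datagram, an 8-byte application header with a
    // 1-byte length prefix, no control data, and a 16-byte AEAD tag
    let len = payload_budget(1460, 1, 8, 0, 16, 64 * 1024);
    assert_eq!(len, 1460 - 1 - 8 - 16 - 1);
}
```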
+ + let last_chunk = if last_chunk.is_empty() { + None + } else { + Some(&*last_chunk) + }; + + encoder.advance_position(crypto.tag_len()); + + let packet_len = encoder.len(); + + let slice = encoder.as_mut_slice(); + + { + let (header, payload_and_tag) = unsafe { + assume!(slice.len() >= payload_offset); + slice.split_at_mut(payload_offset) + }; + + crypto.encrypt(nonce, header, last_chunk, payload_and_tag); + } + + if cfg!(debug_assertions) { + let decoder = s2n_codec::DecoderBufferMut::new(slice); + let (packet, remaining) = + super::decoder::Packet::decode(decoder, (), crypto.tag_len()).unwrap(); + assert!(remaining.is_empty()); + assert_eq!(packet.packet_number(), packet_number); + } + + packet_len +} diff --git a/dc/s2n-quic-dc/src/packet/stream/id.rs b/dc/s2n-quic-dc/src/packet/stream/id.rs new file mode 100644 index 000000000..5746b9d9d --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/stream/id.rs @@ -0,0 +1,166 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::fmt; +use s2n_codec::{decoder_invariant, decoder_value, Encoder, EncoderValue}; +use s2n_quic_core::{assume, ensure, probe, varint::VarInt}; + +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub struct Id { + pub generation_id: u32, + pub sequence_id: u16, + pub is_reliable: bool, + pub is_bidirectional: bool, +} + +impl fmt::Debug for Id { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if f.alternate() { + f.debug_struct("stream::Id") + .field("generation_id", &self.generation_id) + .field("sequence_id", &self.sequence_id) + .field("is_reliable", &self.is_reliable) + .field("is_bidirectional", &self.is_bidirectional) + .finish() + } else { + self.into_varint().as_u64().fmt(f) + } + } +} + +impl fmt::Display for Id { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.into_varint().fmt(f) + } +} + +impl probe::Arg for Id { + #[inline] + fn into_usdt(self) -> isize { + self.into_varint().into_usdt() + } +} + +impl Id { + #[inline] + pub fn bidirectional(mut self) -> Self { + self.is_bidirectional = true; + self + } + + #[inline] + pub fn reliable(mut self) -> Self { + self.is_reliable = true; + self + } + + #[inline] + pub fn next(&self) -> Option { + let mut generation_id = self.generation_id; + let (sequence_id, overflowed) = self.sequence_id.overflowing_add(1); + if overflowed { + generation_id = generation_id.checked_add(1)?; + } + Some(Self { + generation_id, + sequence_id, + is_reliable: self.is_reliable, + is_bidirectional: self.is_bidirectional, + }) + } + + #[inline] + pub fn iter(&self) -> impl Iterator { + let mut next = Some(*self); + core::iter::from_fn(move || { + let current = next; + next = next.and_then(|v| v.next()); + current + }) + } + + #[inline] + pub fn into_varint(self) -> VarInt { + let generation_id = (self.generation_id as u64) << 18; + let sequence_id = (self.sequence_id as u64) << 2; + let is_reliable = if self.is_reliable { + IS_RELIABLE_MASK + } else { + 0b00 + }; + let is_bidirectional = if self.is_bidirectional { + IS_BIDIRECTIONAL_MASK + } else { + 0b00 + }; + let value = generation_id | sequence_id | is_reliable | is_bidirectional; + unsafe { + assume!(value <= MAX_ID_VALUE); + VarInt::new_unchecked(value) + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TryFromIntError(()); + +impl fmt::Display for TryFromIntError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "could not convert the provided u64 to a stream ID") + } +} + 
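into_varint above packs a stream Id into a single varint: the 32-bit generation above bit 18, the 16-bit sequence above bit 2, and the reliable/bidirectional flags in the two low bits, which keeps the packed value comfortably inside VarInt's 62-bit range. A standalone pack/unpack sketch of that layout, using plain u64 arithmetic instead of the crate's VarInt type:

```rust
// Standalone sketch of the stream::Id wire layout produced by into_varint.
const IS_RELIABLE: u64 = 0b10;
const IS_BIDIRECTIONAL: u64 = 0b01;

fn pack(generation_id: u32, sequence_id: u16, reliable: bool, bidirectional: bool) -> u64 {
    let reliable = if reliable { IS_RELIABLE } else { 0 };
    let bidirectional = if bidirectional { IS_BIDIRECTIONAL } else { 0 };
    (generation_id as u64) << 18 | (sequence_id as u64) << 2 | reliable | bidirectional
}

fn unpack(value: u64) -> (u32, u16, bool, bool) {
    (
        (value >> 18) as u32,
        // truncating to u16 drops the generation bits that sit above bit 18
        (value >> 2) as u16,
        value & IS_RELIABLE == IS_RELIABLE,
        value & IS_BIDIRECTIONAL == IS_BIDIRECTIONAL,
    )
}

fn main() {
    let id = pack(7, 42, true, false);
    assert_eq!(unpack(id), (7, 42, true, false));
    // 32 generation bits + 16 sequence bits + 2 flag bits = 50 bits,
    // comfortably inside VarInt's 62-bit range
    assert!(id < 1 << 62);
}
```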
+impl std::error::Error for TryFromIntError {} + +impl TryFrom for Id { + type Error = TryFromIntError; + + #[inline] + fn try_from(value: u64) -> Result { + ensure!(value <= (1 << (32 + 16)), Err(TryFromIntError(()))); + let generation_id = (value >> 16) as u32; + let sequence_id = value as u16; + Ok(Self { + generation_id, + sequence_id, + is_reliable: false, + is_bidirectional: false, + }) + } +} + +const MAX_ID_VALUE: u64 = 1 << (32 + 16 + 1 + 1); +const IS_RELIABLE_MASK: u64 = 0b10; +const IS_BIDIRECTIONAL_MASK: u64 = 0b01; + +decoder_value!( + impl<'a> Id { + fn decode(buffer: Buffer) -> Result { + let (value, buffer) = buffer.decode::()?; + let value = *value; + decoder_invariant!(value <= MAX_ID_VALUE, "invalid range"); + let generation_id = (value >> 18) as u32; + let sequence_id = (value >> 2) as u16; + let is_reliable = value & IS_RELIABLE_MASK == IS_RELIABLE_MASK; + let is_bidirectional = value & IS_BIDIRECTIONAL_MASK == IS_BIDIRECTIONAL_MASK; + Ok(( + Self { + generation_id, + sequence_id, + is_reliable, + is_bidirectional, + }, + buffer, + )) + } + } +); + +impl EncoderValue for Id { + #[inline] + fn encode(&self, encoder: &mut E) { + self.into_varint().encode(encoder) + } +} diff --git a/dc/s2n-quic-dc/src/packet/tag.rs b/dc/s2n-quic-dc/src/packet/tag.rs new file mode 100644 index 000000000..d03fc9316 --- /dev/null +++ b/dc/s2n-quic-dc/src/packet/tag.rs @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_codec::{decoder_invariant, decoder_value}; +use zerocopy::{AsBytes, FromBytes, FromZeroes, Unaligned}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromBytes, FromZeroes, Unaligned)] +#[repr(C)] +pub(super) struct Common(pub(super) u8); + +impl Common { + #[inline] + pub(super) fn set(&mut self, mask: u8, enabled: bool) { + self.0 = self.0 & !mask | if enabled { mask } else { 0 } + } + + #[inline] + pub(super) fn get(&self, mask: u8) -> bool { + self.0 & mask != 0 + } + + #[inline] + fn validate(&self) -> Result<(), s2n_codec::DecoderError> { + decoder_invariant!(self.0 & 0b1000_0000 == 0, "only short packets are used"); + Ok(()) + } +} + +#[derive(Clone, Copy, Debug)] +pub enum Tag { + Stream(super::stream::Tag), + Datagram(super::datagram::Tag), + Control(super::control::Tag), + StaleKey(super::secret_control::stale_key::Tag), + ReplayDetected(super::secret_control::replay_detected::Tag), + RequestShards(super::secret_control::request_shards::Tag), + UnknownPathSecret(super::secret_control::unknown_path_secret::Tag), +} + +decoder_value!( + impl<'a> Tag { + fn decode(buffer: Buffer) -> Result { + match buffer.peek_byte(0)? 
{ + super::stream::Tag::MIN..=super::stream::Tag::MAX => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::Stream(tag), buffer)) + } + super::datagram::Tag::MIN..=super::datagram::Tag::MAX => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::Datagram(tag), buffer)) + } + super::control::Tag::MIN..=super::control::Tag::MAX => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::Control(tag), buffer)) + } + super::secret_control::stale_key::Tag::VALUE => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::StaleKey(tag), buffer)) + } + super::secret_control::replay_detected::Tag::VALUE => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::ReplayDetected(tag), buffer)) + } + super::secret_control::request_shards::Tag::VALUE => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::RequestShards(tag), buffer)) + } + super::secret_control::unknown_path_secret::Tag::VALUE => { + let (tag, buffer) = buffer.decode()?; + Ok((Self::UnknownPathSecret(tag), buffer)) + } + // reserve this range for other packet types + 0b0110_0000..=0b0111_1111 => Err(s2n_codec::DecoderError::InvariantViolation( + "unexpected packet tag", + )), + 0b1000_0000..=0b1111_1111 => Err(s2n_codec::DecoderError::InvariantViolation( + "only short packets are accepted", + )), + } + } + } +); + +macro_rules! impl_tag_codec { + ($ty:ty) => { + impl s2n_codec::EncoderValue for $ty { + #[inline] + fn encode(&self, encoder: &mut E) { + self.0.encode(encoder); + } + } + + s2n_codec::decoder_value!( + impl<'a> $ty { + fn decode(buffer: Buffer) -> Result { + let (byte, buffer) = buffer.decode()?; + let v = Self(byte); + v.validate()?; + Ok((v, buffer)) + } + } + ); + }; +} + +impl_tag_codec!(Common); diff --git a/dc/s2n-quic-dc/src/path.rs b/dc/s2n-quic-dc/src/path.rs new file mode 100644 index 000000000..e86b636c8 --- /dev/null +++ b/dc/s2n-quic-dc/src/path.rs @@ -0,0 +1,65 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::{ + path::{Handle, MaxMtu, Tuple}, + varint::VarInt, +}; + +#[cfg(any(test, feature = "testing"))] +pub mod testing; + +pub trait Controller { + type Handle: Handle; + + fn handle(&self) -> &Self::Handle; +} + +impl Controller for Tuple { + type Handle = Self; + + #[inline] + fn handle(&self) -> &Self::Handle { + self + } +} + +#[derive(Clone, Copy, Debug)] +pub struct Parameters { + pub max_mtu: MaxMtu, + pub remote_max_data: VarInt, + pub local_max_data: VarInt, +} + +impl Default for Parameters { + fn default() -> Self { + static DEFAULT_MAX_DATA: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { + std::env::var("DC_QUIC_DEFAULT_MAX_DATA") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(1u32 << 25) + .into() + }); + + static DEFAULT_MTU: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { + let mtu = if cfg!(target_os = "linux") { + 8940 + } else { + 1450 + }; + + std::env::var("DC_QUIC_DEFAULT_MTU") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(mtu) + .try_into() + .unwrap() + }); + + Self { + max_mtu: *DEFAULT_MTU, + remote_max_data: *DEFAULT_MAX_DATA, + local_max_data: *DEFAULT_MAX_DATA, + } + } +} diff --git a/dc/s2n-quic-dc/src/path/testing.rs b/dc/s2n-quic-dc/src/path/testing.rs new file mode 100644 index 000000000..22ffd42d9 --- /dev/null +++ b/dc/s2n-quic-dc/src/path/testing.rs @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
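The Tag decoder above dispatches on the first byte of a packet: one contiguous range for stream packets, individual values for the secret-control packets, a reserved range, and a rejected range for anything with the high (long-header) bit set. The sketch below illustrates that style of dispatch; only the stream, reserved, and long-packet ranges are taken from the diff, while the region labelled OtherShortPacket is a placeholder, since the datagram, control, and secret-control tag values are defined elsewhere in the crate.

```rust
// Standalone sketch of first-byte dispatch over tag ranges, in the style of
// packet::tag::Tag. The OtherShortPacket region is an assumption for
// illustration only.
#[derive(Debug, PartialEq)]
enum Kind {
    Stream,
    OtherShortPacket,
    Reserved,
    LongPacket,
}

fn classify(first_byte: u8) -> Kind {
    match first_byte {
        // stream packets own the low quarter of the tag space
        0b0000_0000..=0b0011_1111 => Kind::Stream,
        // datagram, control, and secret-control tags live somewhere in here
        0b0100_0000..=0b0101_1111 => Kind::OtherShortPacket,
        // reserved for future packet types
        0b0110_0000..=0b0111_1111 => Kind::Reserved,
        // the high bit marks QUIC long packets, which this code rejects
        0b1000_0000..=0b1111_1111 => Kind::LongPacket,
    }
}

fn main() {
    assert_eq!(classify(0b0000_0010), Kind::Stream);
    assert_eq!(classify(0b0110_0000), Kind::Reserved);
    assert_eq!(classify(0b1100_0000), Kind::LongPacket);
}
```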
+// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::inet::SocketAddressV4; + +pub use s2n_quic_core::path::Tuple as Handle; + +#[derive(Clone, Debug)] +pub struct Controller { + handle: Handle, +} + +impl Controller { + #[inline] + pub fn server() -> Self { + let local_address = SocketAddressV4::new([127, 0, 0, 1], 4433); + let remote_address = SocketAddressV4::new([127, 0, 0, 2], 4433); + + let local_address = local_address.into(); + let remote_address = remote_address.into(); + + let handle = Handle { + local_address, + remote_address, + }; + Self { handle } + } + + #[inline] + pub fn client() -> Self { + let mut v = Self::server(); + let remote = v.handle.local_address.0; + let local = v.handle.remote_address.0; + v.handle.remote_address = remote.into(); + v.handle.local_address = local.into(); + v + } +} + +impl super::Controller for Controller { + type Handle = Handle; + + #[inline] + fn handle(&self) -> &Self::Handle { + &self.handle + } +} diff --git a/dc/s2n-quic-dc/src/pool.rs b/dc/s2n-quic-dc/src/pool.rs new file mode 100644 index 000000000..9395fe213 --- /dev/null +++ b/dc/s2n-quic-dc/src/pool.rs @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::ops; +use crossbeam_channel as mpmc; +use std::{ + mem::ManuallyDrop, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; +use tracing::{info, trace}; + +pub struct Pool { + release: mpmc::Sender, + acquire: mpmc::Receiver, + stats: Option>, +} + +impl Clone for Pool { + #[inline] + fn clone(&self) -> Self { + Self { + release: self.release.clone().clone(), + acquire: self.acquire.clone(), + stats: self.stats.clone(), + } + } +} + +impl Default for Pool { + #[inline] + fn default() -> Self { + Self::new(2000) + } +} + +impl Pool { + #[inline] + pub fn new(max_entries: usize) -> Self { + let (release, acquire) = mpmc::bounded(max_entries); + let mut pool = Self { + release, + acquire, + stats: None, + }; + + if std::env::var("DC_QUIC_POOL_METRICS").is_ok() { + let stats = Arc::new(Stats::default()); + pool.stats = Some(stats.clone()); + std::thread::spawn(move || loop { + std::thread::sleep(core::time::Duration::from_secs(1)); + let hits = stats.hits.load(Ordering::Relaxed); + let misses = stats.misses.load(Ordering::Relaxed); + let errors = stats.errors.load(Ordering::Relaxed); + let hit_ratio = hits as f64 / (hits + misses) as f64 * 100.0; + info!(hits, misses, errors, hit_ratio); + }); + } + + pool + } + + #[inline] + pub fn get(&self) -> Option> { + let entry = self.acquire.try_recv().ok()?; + let entry = Entry::new(entry, self.release.clone()); + Some(entry) + } + + #[inline] + pub fn get_or_init(&self, f: F) -> Result, E> + where + F: FnOnce() -> Result, + { + if let Some(entry) = self.get() { + if let Some(stats) = self.stats.as_ref() { + stats.hits.fetch_add(1, Ordering::Relaxed); + } + trace!("hit"); + Ok(entry) + } else { + let entry = f(); + + if entry.is_err() { + if let Some(stats) = self.stats.as_ref() { + stats.errors.fetch_add(1, Ordering::Relaxed); + } + } + + let entry = entry?; + + let entry = Entry::new(entry, self.release.clone()); + if let Some(stats) = self.stats.as_ref() { + stats.misses.fetch_add(1, Ordering::Relaxed); + } + trace!("miss"); + Ok(entry) + } + } +} + +#[derive(Default)] +struct Stats { + hits: AtomicUsize, + misses: AtomicUsize, + errors: AtomicUsize, +} + +pub struct Entry { + entry: ManuallyDrop, + pool: mpmc::Sender, +} + +impl Entry { + #[inline] + fn new(entry: T, pool: mpmc::Sender) 
-> Self { + let entry = ManuallyDrop::new(entry); + Self { entry, pool } + } +} + +impl ops::Deref for Entry { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.entry + } +} + +impl ops::DerefMut for Entry { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.entry + } +} + +impl Drop for Entry { + #[inline] + fn drop(&mut self) { + let socket = unsafe { ManuallyDrop::take(&mut self.entry) }; + trace!("release"); + let _ = self.pool.try_send(socket); + } +} diff --git a/dc/s2n-quic-dc/src/recovery.rs b/dc/s2n-quic-dc/src/recovery.rs new file mode 100644 index 000000000..a1de7e5aa --- /dev/null +++ b/dc/s2n-quic-dc/src/recovery.rs @@ -0,0 +1,8 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub use s2n_quic_core::recovery::RttEstimator; + +pub fn rtt_estimator() -> RttEstimator { + RttEstimator::new(core::time::Duration::from_millis(10)) +} diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs new file mode 100644 index 000000000..6b22fea49 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream.rs @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::time::Duration; + +pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(10); + +pub mod packet_map; +pub mod packet_number; +pub mod processing; +pub mod recv; +pub mod send; diff --git a/dc/s2n-quic-dc/src/stream/packet_map.rs b/dc/s2n-quic-dc/src/stream/packet_map.rs new file mode 100644 index 000000000..45ae65466 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/packet_map.rs @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::{inet::ExplicitCongestionNotification, time::Timestamp}; + +pub type Map = s2n_quic_core::packet::number::Map>; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct SentPacketInfo { + pub data: Data, + pub time_sent: Timestamp, + pub ecn: ExplicitCongestionNotification, + pub cc_info: crate::congestion::PacketInfo, +} diff --git a/dc/s2n-quic-dc/src/stream/packet_number.rs b/dc/s2n-quic-dc/src/stream/packet_number.rs new file mode 100644 index 000000000..885c2e7d6 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/packet_number.rs @@ -0,0 +1,30 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::sync::atomic::{AtomicU64, Ordering}; +use s2n_quic_core::varint::VarInt; + +#[derive(Clone, Copy, Debug)] +pub struct ExhaustionError; + +#[derive(Debug, Default)] +pub struct Counter(AtomicU64); + +impl Counter { + #[inline] + pub fn reset(&self) { + self.0.store(0, Ordering::Relaxed) + } + + #[inline] + pub fn next(&self) -> Result { + // https://marabos.nl/atomics/memory-ordering.html#relaxed + // > While atomic operations using relaxed memory ordering do not + // > provide any happens-before relationship, they do guarantee a total + // > modification order of each individual atomic variable. This means + // > that all modifications of the same atomic variable happen in an + // > order that is the same from the perspective of every single thread. 
+ let pn = self.0.fetch_add(1, Ordering::Relaxed); + VarInt::new(pn).map_err(|_| ExhaustionError) + } +} diff --git a/dc/s2n-quic-dc/src/stream/processing.rs b/dc/s2n-quic-dc/src/stream/processing.rs new file mode 100644 index 000000000..ff308a05b --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/processing.rs @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::crypto::decrypt; + +#[derive(Clone, Copy, Debug, thiserror::Error)] +pub enum Error { + #[error("packet could not be decrypted")] + Decrypt, + #[error("packet has already been processed")] + Duplicate, + #[error("the crypto key has been replayed and is invalid")] + KeyReplayPrevented, + #[error("the crypto key has been potentially replayed and is invalid")] + KeyReplayPotentiallyPrevented, +} + +impl From for Error { + fn from(value: decrypt::Error) -> Self { + match value { + decrypt::Error::ReplayDefinitelyDetected => Self::KeyReplayPrevented, + decrypt::Error::ReplayPotentiallyDetected => Self::KeyReplayPotentiallyPrevented, + decrypt::Error::InvalidTag => Self::Decrypt, + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/recv.rs b/dc/s2n-quic-dc/src/stream/recv.rs new file mode 100644 index 000000000..93af7f4ab --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/recv.rs @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::packet::stream::decoder::Packet; +use s2n_quic_core::packet::number::{PacketNumberSpace, SlidingWindow, SlidingWindowError}; + +#[derive(Debug, Default)] +pub struct StreamFilter { + window: SlidingWindow, +} + +impl StreamFilter { + #[inline] + pub fn on_packet(&mut self, packet: &Packet) -> Result<(), SlidingWindowError> { + let packet_number = PacketNumberSpace::Initial.new_packet_number(packet.packet_number()); + self.window.insert(packet_number) + } +} diff --git a/dc/s2n-quic-dc/src/stream/send.rs b/dc/s2n-quic-dc/src/stream/send.rs new file mode 100644 index 000000000..c014fe768 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send.rs @@ -0,0 +1,16 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::too_many_arguments)] + +pub mod application; +pub mod error; +pub mod filter; +pub mod flow; +pub mod path; +pub mod probes; +pub mod transmission; +pub mod worker; + +#[cfg(test)] +mod tests; diff --git a/dc/s2n-quic-dc/src/stream/send/application.rs b/dc/s2n-quic-dc/src/stream/send/application.rs new file mode 100644 index 000000000..dab90a4bb --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/application.rs @@ -0,0 +1,119 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
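The packet-number Counter above hands out packet numbers with a single fetch_add using Ordering::Relaxed; as the quoted reference notes, relaxed operations still participate in a total modification order per atomic, which is all that uniqueness and monotonicity of the issued values require. A standalone sketch exercising the same pattern across threads:

```rust
// Standalone sketch of the relaxed fetch_add counter: every draw is unique
// across threads even though no happens-before relationship is established.
use std::{
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    thread,
};

fn main() {
    let counter = Arc::new(AtomicU64::new(0));
    let mut handles = Vec::new();

    for _ in 0..4 {
        let counter = counter.clone();
        handles.push(thread::spawn(move || {
            // fetch_add participates in the single total modification order
            // of this atomic, so no two draws can return the same value
            (0..1000u32)
                .map(|_| counter.fetch_add(1, Ordering::Relaxed))
                .collect::<Vec<u64>>()
        }));
    }

    let mut issued: Vec<u64> = handles
        .into_iter()
        .flat_map(|handle| handle.join().unwrap())
        .collect();
    issued.sort_unstable();
    issued.dedup();

    // 4 threads x 1000 draws = 4000 distinct "packet numbers"
    assert_eq!(issued.len(), 4000);
}
```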
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + crypto::encrypt, + packet::stream::{self, encoder}, + stream::{ + packet_number, + send::{error::Error, flow, path}, + }, +}; +use s2n_codec::EncoderBuffer; +use s2n_quic_core::{ + buffer::{self, reader::Storage as _, Reader as _}, + ensure, + inet::ExplicitCongestionNotification, + time::Clock, + varint::VarInt, +}; + +pub trait Message { + fn max_segments(&self) -> usize; + fn set_ecn(&mut self, ecn: ExplicitCongestionNotification); + fn push usize>( + &mut self, + clock: &Clk, + is_reliable: bool, + buffer_len: usize, + p: P, + ); +} + +pub struct State { + pub source_control_port: u16, + pub stream_id: stream::Id, +} + +impl State { + #[inline] + pub fn transmit( + &self, + credits: flow::Credits, + path: &path::Info, + storage: &mut I, + packet_number: &packet_number::Counter, + encrypt_key: &E, + clock: &Clk, + message: &mut M, + ) -> Result<(), Error> + where + E: encrypt::Key, + I: buffer::reader::Storage, + Clk: Clock, + M: Message, + { + ensure!(credits.len > 0 || credits.is_fin, Ok(())); + + let mut reader = buffer::reader::Incremental::new(credits.offset); + let mut reader = reader.with_storage(storage, credits.is_fin)?; + debug_assert!( + reader.buffered_len() >= credits.len, + "attempted to acquire more credits than what is buffered" + ); + let mut reader = reader.with_read_limit(credits.len); + + let stream_id = *self.stream_id(); + let max_header_len = self.max_header_len(); + + // TODO set destination address with the current value + + message.set_ecn(path.ecn); + + loop { + let packet_number = packet_number.next()?; + + let buffer_len = { + let estimated_len = reader.buffered_len() + max_header_len; + (path.mtu as usize).min(estimated_len) + }; + + message.push(clock, stream_id.is_reliable, buffer_len, |buffer| { + let encoder = EncoderBuffer::new(buffer); + encoder::encode( + encoder, + self.source_control_port, + None, + stream_id, + packet_number, + path.next_expected_control_packet, + VarInt::ZERO, + &mut &[][..], + VarInt::ZERO, + &(), + &mut reader, + encrypt_key, + ) + }); + + // bail if we've transmitted everything + ensure!(!reader.buffer_is_empty(), break); + } + + Ok(()) + } + + #[inline] + fn stream_id(&self) -> &stream::Id { + &self.stream_id + } + + #[inline] + pub fn max_header_len(&self) -> usize { + if self.stream_id().is_reliable { + encoder::MAX_RETRANSMISSION_HEADER_LEN + } else { + encoder::MAX_HEADER_LEN + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/error.rs b/dc/s2n-quic-dc/src/stream/send/error.rs new file mode 100644 index 000000000..f1bcfe67c --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/error.rs @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
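State::transmit above sizes each datagram as the path MTU capped by the remaining buffered data plus a worst-case header, then loops until the reader is drained. The sketch below models only that segmentation arithmetic with plain integers; the numbers and the assumption that each packet carries buffer_len minus max_header_len payload bytes are illustrative, since the real payload budget also subtracts length prefixes and the AEAD tag.

```rust
// Standalone model of the datagram sizing loop in State::transmit.
// Assumes each packet carries (buffer_len - max_header_len) payload bytes,
// which ignores length prefixes and the AEAD tag, so it is an upper bound.
fn segment(mut remaining: usize, mtu: usize, max_header_len: usize) -> Vec<usize> {
    let mut datagrams = Vec::new();
    while remaining > 0 {
        // estimated_len = buffered data + worst-case header, capped at the MTU
        let buffer_len = mtu.min(remaining + max_header_len);
        let payload = remaining.min(buffer_len - max_header_len);
        datagrams.push(buffer_len);
        remaining -= payload;
    }
    datagrams
}

fn main() {
    // 4000 credited bytes over a 1500-byte MTU with a 50-byte header budget:
    // two full datagrams of 1450 payload bytes each, then 1100 bytes + header
    assert_eq!(segment(4000, 1500, 50), vec![1500, 1500, 1150]);
}
```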
+// SPDX-License-Identifier: Apache-2.0 + +use crate::stream::packet_number; +use s2n_quic_core::{buffer, varint::VarInt}; + +#[derive(Clone, Copy, Debug, thiserror::Error)] +pub enum Error { + #[error("payload provided is too large and exceeded the maximum offset")] + PayloadTooLarge, + #[error("the provided packet buffer is too small for the minimum packet size")] + PacketBufferTooSmall, + #[error("the number of packets able to be sent on the sender has been exceeded")] + PacketNumberExhaustion, + #[error("retransmission not possible")] + RetransmissionFailure, + #[error("stream has been finished")] + StreamFinished, + #[error("the final size of the stream has changed")] + FinalSizeChanged, + #[error("the sender idle timer expired")] + IdleTimeout, + #[error("the stream was reset by the peer with code {code}")] + TransportError { code: VarInt }, + #[error("the stream was closed with application code {error}")] + ApplicationError { + error: s2n_quic_core::application::Error, + }, + #[error("an invalid frame was received: {decoder}")] + FrameError { decoder: s2n_codec::DecoderError }, + #[error("the stream experienced an unrecoverable error")] + FatalError, +} + +impl From for std::io::Error { + #[inline] + fn from(error: Error) -> Self { + use std::io::ErrorKind; + let kind = match error { + Error::PayloadTooLarge => ErrorKind::BrokenPipe, + Error::PacketBufferTooSmall => ErrorKind::InvalidInput, + Error::PacketNumberExhaustion => ErrorKind::BrokenPipe, + Error::RetransmissionFailure => ErrorKind::BrokenPipe, + Error::StreamFinished => ErrorKind::UnexpectedEof, + Error::FinalSizeChanged => ErrorKind::InvalidInput, + Error::IdleTimeout => ErrorKind::TimedOut, + Error::ApplicationError { .. } => ErrorKind::ConnectionReset, + Error::TransportError { .. } => ErrorKind::ConnectionAborted, + Error::FrameError { .. } => ErrorKind::InvalidData, + Error::FatalError => ErrorKind::BrokenPipe, + }; + Self::new(kind, error) + } +} + +impl From for Error { + #[inline] + fn from(_error: packet_number::ExhaustionError) -> Self { + Self::PacketNumberExhaustion + } +} + +impl From> for Error { + #[inline] + fn from(error: buffer::Error) -> Self { + match error { + buffer::Error::OutOfRange => Self::PayloadTooLarge, + buffer::Error::InvalidFin => Self::FinalSizeChanged, + buffer::Error::ReaderError(_) => unreachable!(), + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/filter.rs b/dc/s2n-quic-dc/src/stream/send/filter.rs new file mode 100644 index 000000000..857e93d67 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/filter.rs @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::packet::control::decoder::Packet; +use s2n_quic_core::packet::number::{PacketNumberSpace, SlidingWindow, SlidingWindowError}; + +#[derive(Debug, Default)] +pub struct Filter { + window: SlidingWindow, +} + +impl Filter { + #[inline] + pub fn on_packet(&mut self, packet: &Packet) -> Result<(), SlidingWindowError> { + let packet_number = PacketNumberSpace::Initial.new_packet_number(packet.packet_number()); + self.window.insert(packet_number) + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/flow.rs b/dc/s2n-quic-dc/src/stream/send/flow.rs new file mode 100644 index 000000000..c5fc80267 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/flow.rs @@ -0,0 +1,53 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
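The From conversion in error.rs above maps each send error onto an io::ErrorKind while keeping the original value as the error's source, so async I/O callers can still downcast to the concrete error. A reduced sketch of that pattern follows, using a hypothetical two-variant enum (the real enum has many more variants) and manual Display/Error impls instead of thiserror.

```rust
use std::{error, fmt, io};

// Two-variant stand-in for the send::error::Error enum.
#[derive(Debug, Clone, Copy)]
enum SendError {
    IdleTimeout,
    PacketBufferTooSmall,
}

impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IdleTimeout => write!(f, "the sender idle timer expired"),
            Self::PacketBufferTooSmall => write!(
                f,
                "the provided packet buffer is too small for the minimum packet size"
            ),
        }
    }
}

impl error::Error for SendError {}

impl From<SendError> for io::Error {
    fn from(error: SendError) -> Self {
        let kind = match error {
            SendError::IdleTimeout => io::ErrorKind::TimedOut,
            SendError::PacketBufferTooSmall => io::ErrorKind::InvalidInput,
        };
        // keep the concrete error as the source so callers can downcast
        io::Error::new(kind, error)
    }
}

fn main() {
    let err: io::Error = SendError::IdleTimeout.into();
    assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    assert!(err.get_ref().unwrap().downcast_ref::<SendError>().is_some());
}
```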
+// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::varint::VarInt; + +pub mod blocking; +pub mod non_blocking; + +/// Flow credits acquired by an application request +#[derive(Debug)] +pub struct Credits { + /// The offset at which to write the stream bytes + pub offset: VarInt, + /// The number of bytes which an application must write after acquisition + pub len: usize, + /// Indicates if the stream is being finalized + pub is_fin: bool, +} + +/// An application request for flow credits +#[derive(Clone, Copy, Debug)] +pub struct Request { + /// The number of bytes in the application buffer + pub len: usize, + /// Indicates if the request is finalizing a stream + pub is_fin: bool, +} + +impl Request { + /// Clamps the request with the given number of credits + #[inline] + pub fn clamp(&mut self, credits: u64) { + let len = self.len.min(credits.min(u16::MAX as u64) as usize); + + // if we didn't acquire the entire len, then clear the `is_fin` flag + if self.len != len { + self.is_fin = false; + } + + // update the len based on the provided credits + self.len = len; + } + + /// Constructs a response with the acquired offset + #[inline] + pub fn response(self, offset: VarInt) -> Credits { + Credits { + offset, + len: self.len, + is_fin: self.is_fin, + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs b/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs new file mode 100644 index 000000000..6f58d1507 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/flow/blocking.rs @@ -0,0 +1,230 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::Credits; +use crate::stream::send::{error::Error, flow}; +use s2n_quic_core::{ensure, varint::VarInt}; +use std::sync::{Condvar, Mutex}; + +pub struct State { + state: Mutex, + notify: Condvar, +} + +impl State { + #[inline] + pub fn new(initial_flow_offset: VarInt) -> Self { + Self { + state: Mutex::new(Inner { + stream_offset: VarInt::ZERO, + flow_offset: initial_flow_offset, + is_finished: false, + }), + notify: Condvar::new(), + } + } +} + +struct Inner { + /// Monotonic offset which tracks where the application is currently writing + stream_offset: VarInt, + /// Monotonic offset which indicates the maximum offset the application can write to + flow_offset: VarInt, + /// Indicates that the stream has been finalized + is_finished: bool, +} + +impl State { + /// Called by the background worker to release flow credits + /// + /// Callers MUST ensure the provided offset is monotonic. 
+ #[inline] + pub fn release(&self, flow_offset: VarInt) -> Result<(), Error> { + let mut guard = self.state.lock().map_err(|_| Error::FatalError)?; + + // only notify subscribers if we actually increment the offset + debug_assert!( + guard.flow_offset < flow_offset, + "flow offsets MUST be monotonic" + ); + ensure!(guard.flow_offset < flow_offset, Ok(())); + + guard.flow_offset = flow_offset; + drop(guard); + + self.notify.notify_all(); + + Ok(()) + } + + /// Called by the application to acquire flow credits + #[inline] + pub fn acquire(&self, mut request: flow::Request) -> Result { + let mut guard = self.state.lock().map_err(|_| Error::FatalError)?; + + loop { + ensure!(!guard.is_finished, Err(Error::FinalSizeChanged)); + + // TODO check for an error + + let current_offset = guard.stream_offset; + let flow_offset = guard.flow_offset; + + debug_assert!( + current_offset <= flow_offset, + "current_offset={current_offset} should be <= flow_offset={flow_offset}" + ); + + let Some(flow_credits) = flow_offset + .as_u64() + .checked_sub(current_offset.as_u64()) + .filter(|v| { + // if we're finishing the stream and don't have any buffered data, then we + // don't need any flow control + if request.len == 0 && request.is_fin { + true + } else { + *v > 0 + } + }) + else { + guard = self.notify.wait(guard).map_err(|_| Error::FatalError)?; + continue; + }; + + // clamp the request to the flow credits we have + request.clamp(flow_credits); + + // update the stream offset with the given request + guard.stream_offset = current_offset + .checked_add_usize(request.len) + .ok_or(Error::PayloadTooLarge)?; + + // update the finished status + guard.is_finished |= request.is_fin; + + // drop the lock before notifying all of the waiting Condvar + drop(guard); + + // notify the other handles when we finish + if request.is_fin { + self.notify.notify_all(); + } + + // the offset was correctly updated so return our acquired credits + let credits = request.response(current_offset); + + return Ok(credits); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::send::path; + use std::{ + sync::atomic::{AtomicU64, Ordering}, + thread, + }; + + #[test] + fn concurrent_flow() { + let mut initial_offset = VarInt::from_u8(255); + let expected_len = VarInt::from_u16(u16::MAX); + let state = State::new(initial_offset); + let path_info = path::Info { + mtu: 1500, + send_quantum: 10, + ecn: Default::default(), + next_expected_control_packet: Default::default(), + }; + let total = AtomicU64::new(0); + let workers = 5; + let worker_counts = Vec::from_iter((0..workers).map(|_| AtomicU64::new(0))); + + thread::scope(|s| { + let total = &total; + let path_info = &path_info; + let state = &state; + + for (idx, count) in worker_counts.iter().enumerate() { + s.spawn(move || { + thread::sleep(core::time::Duration::from_millis(10)); + + let mut buffer_len = 1; + let mut is_fin = false; + let max_segments = 10; + let max_header_len = 50; + let mut max_offset = VarInt::ZERO; + + loop { + let mut request = flow::Request { + len: buffer_len, + is_fin, + }; + request.clamp(path_info.max_flow_credits(max_header_len, max_segments)); + + let Ok(credits) = state.acquire(request) else { + break; + }; + + eprintln!( + "thread={idx} offset={}..{}", + credits.offset, + credits.offset + credits.len + ); + buffer_len += 1; + buffer_len = buffer_len.min( + expected_len + .as_u64() + .saturating_sub(credits.offset.as_u64()) + .saturating_sub(credits.len as u64) + as usize, + ); + + assert!(max_offset <= credits.offset); + 
max_offset = credits.offset; + + if buffer_len == 0 { + is_fin = true; + } + total.fetch_add(credits.len as _, Ordering::Relaxed); + count.fetch_add(credits.len as _, Ordering::Relaxed); + } + }); + } + + s.spawn(|| { + let mut credits = 10; + while initial_offset < expected_len { + thread::sleep(core::time::Duration::from_millis(1)); + initial_offset = (initial_offset + credits).min(expected_len); + credits += 1; + let _ = state.release(initial_offset); + } + }); + }); + + assert_eq!(total.load(Ordering::Relaxed), expected_len.as_u64()); + let mut at_least_one_write = true; + for (idx, count) in worker_counts.into_iter().enumerate() { + let count = count.load(Ordering::Relaxed); + eprintln!("thread={idx}, count={}", count); + if count == 0 { + at_least_one_write = false; + } + } + + let _ = at_least_one_write; + + // TODO the Mutex mechanism doesn't fairly distribute between workers so don't make this + // assertion until we can do something more reliable + /* + assert!( + at_least_one_write, + "all workers need to write at least one byte" + ); + */ + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs b/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs new file mode 100644 index 000000000..22f706a89 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/flow/non_blocking.rs @@ -0,0 +1,248 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::Credits; +use crate::stream::send::{error::Error, flow}; +use atomic_waker::AtomicWaker; +use core::{ + sync::atomic::{AtomicU64, Ordering}, + task::{Context, Poll}, +}; +use s2n_quic_core::{ensure, varint::VarInt}; + +const ERROR_MASK: u64 = 1 << 63; +const FINISHED_MASK: u64 = 1 << 62; + +pub struct State { + /// Monotonic offset which tracks where the application is currently writing + stream_offset: AtomicU64, + /// Monotonic offset which indicates the maximum offset the application can write to + flow_offset: AtomicU64, + /// Notifies an application of newly-available flow credits + poll_waker: AtomicWaker, + // TODO add a list for the `acquire` future wakers +} + +impl State { + #[inline] + pub fn new(initial_flow_offset: VarInt) -> Self { + Self { + stream_offset: AtomicU64::new(0), + flow_offset: AtomicU64::new(initial_flow_offset.as_u64()), + poll_waker: AtomicWaker::new(), + } + } +} + +impl State { + /// Called by the background worker to release flow credits + /// + /// Callers MUST ensure the provided offset is monotonic. 
+ #[inline] + pub fn release(&self, flow_offset: VarInt) { + self.flow_offset + .store(flow_offset.as_u64(), Ordering::Release); + self.poll_waker.wake(); + } + + /// Called by the application to acquire flow credits + #[inline] + pub async fn acquire(&self, request: flow::Request) -> Result { + core::future::poll_fn(|cx| self.poll_acquire(cx, request)).await + } + + /// Called by the application to acquire flow credits + #[inline] + pub fn poll_acquire( + &self, + cx: &mut Context, + mut request: flow::Request, + ) -> Poll> { + let mut current_offset = self.acquire_offset()?; + + let mut stored_waker = false; + + loop { + let flow_offset = self.flow_offset.load(Ordering::Acquire); + + let Some(flow_credits) = flow_offset + .checked_sub(current_offset.as_u64()) + .filter(|v| { + // if we're finishing the stream and don't have any buffered data, then we + // don't need any flow control + if request.len == 0 && request.is_fin { + true + } else { + *v > 0 + } + }) + else { + // if we already stored a waker and didn't get more credits then yield the task + ensure!(!stored_waker, Poll::Pending); + stored_waker = true; + + self.poll_waker.register(cx.waker()); + + // make one last effort to acquire some flow credits before going to sleep + current_offset = self.acquire_offset()?; + + continue; + }; + + // clamp the request to the flow credits we have + request.clamp(flow_credits); + + let mut new_offset = current_offset + .as_u64() + .checked_add(request.len as u64) + .ok_or(Error::PayloadTooLarge)?; + + // record that we've sent the final offset + if request.is_fin { + new_offset |= FINISHED_MASK; + } + + let result = self.stream_offset.compare_exchange( + current_offset.as_u64(), + new_offset, + Ordering::Release, // TODO is this the correct ordering? + Ordering::Acquire, + ); + + match result { + Ok(_) => { + // the offset was correctly updated so return our acquired credits + let credits = request.response(current_offset); + return Poll::Ready(Ok(credits)); + } + Err(updated_offset) => { + // the offset was updated from underneath us so try again + current_offset = Self::process_offset(updated_offset)?; + // clear the fact that we stored the waker, since we need to do a full sync + // to get the correct state + stored_waker = false; + continue; + } + } + } + } + + #[inline] + fn acquire_offset(&self) -> Result { + Self::process_offset(self.stream_offset.load(Ordering::Acquire)) + } + + #[inline] + fn process_offset(offset: u64) -> Result { + if offset & ERROR_MASK == ERROR_MASK { + // TODO actually load the error value for the stream + return Err(Error::TransportError { code: VarInt::MAX }); + } + + if offset & FINISHED_MASK == FINISHED_MASK { + return Err(Error::FinalSizeChanged); + } + + Ok(unsafe { VarInt::new_unchecked(offset) }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::send::path; + use std::sync::Arc; + + #[tokio::test] + async fn concurrent_flow() { + let mut initial_offset = VarInt::from_u8(255); + let expected_len = VarInt::from_u16(u16::MAX); + let state = Arc::new(State::new(initial_offset)); + let path_info = path::Info { + mtu: 1500, + send_quantum: 10, + ecn: Default::default(), + next_expected_control_packet: Default::default(), + }; + let total = Arc::new(AtomicU64::new(0)); + // TODO support more than one Waker via intrusive list or something + let workers = 1; + let worker_counts = Vec::from_iter((0..workers).map(|_| Arc::new(AtomicU64::new(0)))); + + let mut tasks = tokio::task::JoinSet::new(); + + for (idx, count) in 
worker_counts.iter().cloned().enumerate() { + let total = total.clone(); + let state = state.clone(); + tasks.spawn(async move { + tokio::time::sleep(core::time::Duration::from_millis(10)).await; + + let mut buffer_len = 1; + let mut is_fin = false; + let max_segments = 10; + let max_header_len = 50; + let mut max_offset = VarInt::ZERO; + + loop { + let mut request = flow::Request { + len: buffer_len, + is_fin, + }; + request.clamp(path_info.max_flow_credits(max_header_len, max_segments)); + + let Ok(credits) = state.acquire(request).await else { + break; + }; + + println!( + "thread={idx} offset={}..{}", + credits.offset, + credits.offset + credits.len + ); + buffer_len += 1; + buffer_len = buffer_len.min( + expected_len + .as_u64() + .saturating_sub(credits.offset.as_u64()) + .saturating_sub(credits.len as u64) as usize, + ); + assert!(max_offset <= credits.offset); + max_offset = credits.offset; + if buffer_len == 0 { + is_fin = true; + } + total.fetch_add(credits.len as _, Ordering::Relaxed); + count.fetch_add(credits.len as _, Ordering::Relaxed); + } + }); + } + + tasks.spawn(async move { + let mut credits = 10; + while initial_offset < expected_len { + tokio::time::sleep(core::time::Duration::from_millis(1)).await; + initial_offset = (initial_offset + credits).min(expected_len); + credits += 1; + state.release(initial_offset); + } + }); + + // make sure all of the tasks complete + while tasks.join_next().await.is_some() {} + + assert_eq!(total.load(Ordering::Relaxed), expected_len.as_u64()); + let mut at_least_one_write = true; + for (idx, count) in worker_counts.into_iter().enumerate() { + let count = count.load(Ordering::Relaxed); + println!("thread={idx}, count={}", count); + if count == 0 { + at_least_one_write = false; + } + } + + assert!( + at_least_one_write, + "all workers need to write at least one byte" + ); + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/path.rs b/dc/s2n-quic-dc/src/stream/send/path.rs new file mode 100644 index 000000000..870aa97f6 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/path.rs @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::sync::atomic::{AtomicU64, Ordering}; +use s2n_quic_core::{inet::ExplicitCongestionNotification, varint::VarInt}; + +/// Contains the current state of a transmission path +pub struct State { + info: AtomicU64, + next_expected_control_packet: AtomicU64, +} + +impl State { + /// Loads a relaxed view of the current path state + #[inline] + pub fn load(&self) -> Info { + // use relaxed since it's ok to be slightly out of sync with the current MTU/send_quantum + let mut data = self.info.load(Ordering::Relaxed); + + let mtu = data as u16; + data >>= 16; + + let send_quantum = data as u8; + data >>= 8; + + let ecn = data as u8; + let ecn = ExplicitCongestionNotification::new(ecn); + data >>= 8; + + // TODO can we store pacing rate in the remaining bits? 
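+ // Layout (least-significant bits first): bits 0..16 = mtu, bits 16..24 = send_quantum, + // bits 24..32 = ecn. For example, mtu = 1500, send_quantum = 10, ecn = NotEct packs the + // low 32 bits as 0x000A_05DC.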
+ + debug_assert_eq!(data, 0, "unexpected extra data"); + + let next_expected_control_packet = + self.next_expected_control_packet.load(Ordering::Relaxed); + let next_expected_control_packet = + VarInt::new(next_expected_control_packet).unwrap_or(VarInt::MAX); + + Info { + mtu, + send_quantum, + ecn, + next_expected_control_packet, + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct Info { + pub mtu: u16, + pub send_quantum: u8, + pub ecn: ExplicitCongestionNotification, + pub next_expected_control_packet: VarInt, +} + +impl Info { + /// Returns the maximum number of flow credits for the current path info + #[inline] + pub fn max_flow_credits(&self, max_header_len: usize, max_segments: usize) -> u64 { + // trim off the headers since those don't count for flow control + let max_payload_size_per_segment = self.mtu as usize - max_header_len; + // clamp the number of segments we can transmit in a single burst + let max_segments = max_segments.min(self.send_quantum as usize); + + let max_payload_size = max_payload_size_per_segment * max_segments; + + max_payload_size as u64 + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/probes.rs b/dc/s2n-quic-dc/src/stream/send/probes.rs new file mode 100644 index 000000000..e74d81470 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/probes.rs @@ -0,0 +1,113 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::too_many_arguments)] + +use crate::{credentials, packet::stream}; +use core::time::Duration; +use s2n_quic_core::{packet::number::PacketNumber, probe, varint::VarInt}; + +probe::define!( + extern "probe" { + /// Called when a control packet is received + #[link_name = s2n_quic_dc__stream__send__control_packet] + pub fn on_control_packet( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: VarInt, + control_data_len: usize, + ); + + /// Called when a control packet is decrypted + #[link_name = s2n_quic_dc__stream__send__control_packet_decrypted] + pub fn on_control_packet_decrypted( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: VarInt, + control_data_len: usize, + valid: bool, + ); + + /// Called when a control packet was dropped due to being a duplicate + #[link_name = s2n_quic_dc__stream__send__control_packet_decrypted] + pub fn on_control_packet_duplicate( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: VarInt, + control_data_len: usize, + ); + + /// Called when a packet was ACK'd + #[link_name = s2n_quic_dc__stream__send__packet_ack] + pub fn on_packet_ack( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: u64, + packet_len: u16, + stream_offset: VarInt, + payload_len: u16, + lifetime: Duration, + ); + + /// Called when a packet was lost + #[link_name = s2n_quic_dc__stream__send__packet_lost] + pub fn on_packet_lost( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: u64, + packet_len: u16, + stream_offset: VarInt, + payload_len: u16, + lifetime: Duration, + needs_retransmission: bool, + ); + + /// Called when a packet was ACK'd + #[link_name = s2n_quic_dc__stream__send__pto_backoff_reset] + pub fn on_pto_backoff_reset( + credential_id: credentials::Id, + stream_id: stream::Id, + previous_value: u32, + ); + + /// Called when the PTO timer is armed + #[link_name = s2n_quic_dc__stream__send__pto_armed] + pub fn on_pto_armed( + credential_id: credentials::Id, + stream_id: stream::Id, + pto_period: Duration, + pto_backoff: u32, + ); + + 
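+ // Note: when probe support is compiled in, each definition here expands to a userspace + // (USDT) tracepoint named by its `link_name`, e.g. `s2n_quic_dc__stream__send__packet_lost`, + // so loss and PTO behavior can be observed with an external tracer without modifying the code.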
/// Called when a range of stream bytes are transmitted + #[link_name = s2n_quic_dc__stream__send__transmit_stream] + pub fn on_transmit_stream( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: PacketNumber, + stream_offset: VarInt, + payload_len: u16, + is_retransmission: bool, + ); + + /// Called when a range of stream bytes are transmitted as a probe + #[link_name = s2n_quic_dc__stream__send__transmit_probe] + pub fn on_transmit_probe( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: PacketNumber, + stream_offset: VarInt, + payload_len: u16, + is_retransmission: bool, + ); + + /// Called when the stream is closed by a CONNECTION_CLOSE frame + #[link_name = s2n_quic_dc__stream__send__close] + pub fn on_close( + credential_id: credentials::Id, + stream_id: stream::Id, + packet_number: VarInt, + error_code: VarInt, + ); + } +); diff --git a/dc/s2n-quic-dc/src/stream/send/tests.rs b/dc/s2n-quic-dc/src/stream/send/tests.rs new file mode 100644 index 000000000..cf1406c94 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/tests.rs @@ -0,0 +1,2 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 diff --git a/dc/s2n-quic-dc/src/stream/send/transmission.rs b/dc/s2n-quic-dc/src/stream/send/transmission.rs new file mode 100644 index 000000000..74e2ed034 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/transmission.rs @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum Type { + Probe, + Stream, +} + +impl Type { + #[inline] + pub fn is_probe(self) -> bool { + matches!(self, Self::Probe) + } + + #[inline] + pub fn is_stream(self) -> bool { + matches!(self, Self::Stream) + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker.rs b/dc/s2n-quic-dc/src/stream/send/worker.rs new file mode 100644 index 000000000..634237a55 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/worker.rs @@ -0,0 +1,1186 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + allocator::{Allocator, Segment}, + congestion, + crypto::{decrypt, encrypt, UninitSlice}, + packet::{ + self, + stream::{self, decoder, encoder}, + }, + path::Parameters, + stream::{ + packet_map, packet_number, processing, + send::{ + error::Error, + filter::Filter, + probes, + transmission::Type as TransmissionType, + worker::{self, retransmission::Segment as Retransmission}, + }, + }, +}; +use s2n_codec::{DecoderBufferMut, EncoderBuffer}; +use s2n_quic_core::{ + branch, ensure, + frame::{self, FrameMut}, + inet::ExplicitCongestionNotification, + interval_set::IntervalSet, + packet::number::PacketNumberSpace, + path::{ecn, INITIAL_PTO_BACKOFF}, + random, + recovery::{Pto, RttEstimator}, + stream::state, + time::{timer, Clock, Timer, Timestamp}, + varint::VarInt, +}; +use std::collections::BinaryHeap; +use tracing::{debug, trace}; + +mod checker; +mod probe; +pub mod retransmission; +pub mod transmission; + +type PacketMap = packet_map::Map>; + +#[derive(Debug)] +pub struct State { + pub stream_id: stream::Id, + rtt_estimator: RttEstimator, + pub sent_packets: PacketMap, + pub state: state::Sender, + control_filter: Filter, + pub retransmissions: BinaryHeap>, + next_expected_control_packet: VarInt, + pub cca: congestion::Controller, + ecn: ecn::Controller, + pub pto: Pto, + pto_backoff: u32, + idle_timer: Timer, + pub error: Option, + unacked_ranges: IntervalSet, + max_sent_offset: VarInt, + pub max_data: VarInt, + checker: checker::Checker, +} + +impl State { + #[inline] + pub fn new(stream_id: stream::Id, params: &Parameters) -> Self { + let mtu = params.max_mtu; + let initial_max_data = params.remote_max_data; + + // initialize the pending data left to send + let mut unacked_ranges = IntervalSet::new(); + unacked_ranges.insert(VarInt::ZERO..=VarInt::MAX).unwrap(); + + let cca = congestion::Controller::new(mtu.into()); + let max_sent_offset = VarInt::ZERO; + + let mut checker = checker::Checker::default(); + checker.on_max_data(initial_max_data); + + Self { + stream_id, + next_expected_control_packet: VarInt::ZERO, + rtt_estimator: crate::recovery::rtt_estimator(), + cca, + sent_packets: Default::default(), + control_filter: Default::default(), + ecn: ecn::Controller::default(), + state: Default::default(), + retransmissions: Default::default(), + pto: Pto::default(), + pto_backoff: INITIAL_PTO_BACKOFF, + idle_timer: Default::default(), + error: None, + unacked_ranges, + max_sent_offset, + max_data: initial_max_data, + checker, + } + } + + /// Returns the current flow offset + #[inline] + pub fn flow_offset(&self) -> VarInt { + let extra_window = self + .cca + .congestion_window() + .saturating_sub(self.cca.bytes_in_flight()); + self.max_data + .min(self.max_sent_offset + extra_window as usize) + } + + /// Called by the worker when it receives a control packet from the peer + #[inline] + pub fn on_control_packet( + &mut self, + decrypt_key: &mut D, + ecn: ExplicitCongestionNotification, + packet: &mut packet::control::decoder::Packet, + random: &mut dyn random::Generator, + clock: &Clk, + message: &mut A, + ) -> Result<(), processing::Error> + where + D: decrypt::Key, + Clk: Clock, + A: Allocator, + { + match self.on_control_packet_impl(decrypt_key, ecn, packet, random, clock, message) { + Ok(None) => {} + Ok(Some(error)) => return Err(error), + Err(error) => { + self.on_error(error, message); + } + } + + self.invariants(); + + Ok(()) + } + + #[inline(always)] + fn on_control_packet_impl( + &mut self, + decrypt_key: &mut D, + 
_ecn: ExplicitCongestionNotification, + packet: &mut packet::control::decoder::Packet, + random: &mut dyn random::Generator, + clock: &Clk, + message: &mut A, + ) -> Result, Error> + where + D: decrypt::Key, + Clk: Clock, + A: Allocator, + { + probes::on_control_packet( + decrypt_key.credentials().id, + self.stream_id, + packet.packet_number(), + packet.control_data().len(), + ); + + // only process the packet after we know it's authentic + let res = decrypt_key.decrypt( + packet.crypto_nonce(), + packet.header(), + &[], + packet.auth_tag(), + UninitSlice::new(&mut []), + ); + + probes::on_control_packet_decrypted( + decrypt_key.credentials().id, + self.stream_id, + packet.packet_number(), + packet.control_data().len(), + res.is_ok(), + ); + + // drop the packet if it failed to authenticate + if let Err(err) = res { + return Ok(Some(err.into())); + } + + // check if we've already seen the packet + ensure!( + self.control_filter.on_packet(packet).is_ok(), + return { + probes::on_control_packet_duplicate( + decrypt_key.credentials().id, + self.stream_id, + packet.packet_number(), + packet.control_data().len(), + ); + // drop the packet if we've already seen it + Ok(Some(processing::Error::Duplicate)) + } + ); + + let packet_number = packet.packet_number(); + + // raise our next expected control packet + { + let pn = packet_number.saturating_add(VarInt::from_u8(1)); + let pn = self.next_expected_control_packet.max(pn); + self.next_expected_control_packet = pn; + } + + let mut newly_acked = false; + + { + let mut decoder = DecoderBufferMut::new(packet.control_data_mut()); + while !decoder.is_empty() { + let (frame, remaining) = decoder + .decode::() + .map_err(|decoder| Error::FrameError { decoder })?; + decoder = remaining; + + trace!(?frame); + + match frame { + FrameMut::Padding(_) => { + continue; + } + FrameMut::Ping(_) => { + // no need to do anything special here + } + FrameMut::Ack(ack) => { + self.on_frame_ack( + decrypt_key, + &ack, + random, + clock, + message, + &mut newly_acked, + )?; + } + FrameMut::MaxData(frame) => { + if self.max_data < frame.maximum_data { + self.max_data = frame.maximum_data; + self.checker.on_max_data(frame.maximum_data); + } + } + FrameMut::ConnectionClose(close) => { + debug!(connection_close = ?close, state = ?self.state); + + probes::on_close( + decrypt_key.credentials().id, + self.stream_id, + packet_number, + close.error_code, + ); + + // if there was no error and we transmitted everything then just shut the + // stream down + if close.error_code == VarInt::ZERO + && close.frame_type.is_some() + && self.state.on_recv_all_acks().is_ok() + { + self.clean_up(message); + // transmit one more PTO packet so we can ACK the peer's + // CONNECTION_CLOSE frame and they can shutdown quickly. Otherwise, + // they'll need to hang around to respond to potential loss. 
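+ // (a CONNECTION_CLOSE that carries a frame type is a transport-level close; error_code 0 + // there means the peer shut down cleanly)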
+ self.pto.force_transmit(); + return Ok(None); + } + + // no need to transmit a reset back to the peer - just close it + let _ = self.state.on_send_reset(); + let _ = self.state.on_recv_reset_ack(); + let error = if close.frame_type.is_some() { + Error::TransportError { + code: close.error_code, + } + } else { + Error::ApplicationError { + error: close.error_code.into(), + } + }; + return Err(error); + } + _ => continue, + } + } + } + + if newly_acked { + if self.pto_backoff != INITIAL_PTO_BACKOFF { + probes::on_pto_backoff_reset( + decrypt_key.credentials().id, + self.stream_id, + self.pto_backoff, + ); + } + + self.pto_backoff = INITIAL_PTO_BACKOFF; + } + + trace!( + retransmissions = self.retransmissions.len(), + packets_in_flight = self.sent_packets.iter().count(), + ); + + // try to transition to the final state if we've sent all of the data + if self.unacked_ranges.is_empty() && self.state.on_recv_all_acks().is_ok() { + self.clean_up(message); + // transmit one more PTO packet so we can ACK the peer's + // CONNECTION_CLOSE frame and they can shutdown quickly. Otherwise, + // they'll need to hang around to respond to potential loss. + self.pto.force_transmit(); + } + + // make sure we have all of the pending packets we need to finish the transmission + if !self.state.is_terminal() { + // TODO pass `unacked_ranges` + self.checker + .check_pending_packets(&self.sent_packets, &self.retransmissions); + } + + // re-arm the idle timer as long as we're still sending data + if self.state.is_ready() || self.state.is_sending() || self.state.is_data_sent() { + self.arm_idle_timer(clock); + } + + Ok(None) + } + + #[inline] + fn on_frame_ack( + &mut self, + decrypt_key: &mut D, + ack: &frame::Ack, + random: &mut dyn random::Generator, + clock: &Clk, + message: &mut A, + newly_acked: &mut bool, + ) -> Result<(), Error> + where + D: decrypt::Key, + Ack: frame::ack::AckRanges, + Clk: Clock, + A: Allocator, + { + // TODO get all of this information + // self.ecn.validate( + // newly_acked_ecn_counts, + // sent_packet_ecn_counts, + // baseline_ecn_counts, + // ack_frame_ecn_counts, + // now, + // rtt, + // path, + // publisher, + // ); + + let ack_time = clock.get_time(); + + let mut max = None; + let mut cca_args = None; + let mut bytes_acked = 0; + + for range in ack.ack_ranges() { + max = max.max(Some(*range.end())); + let pmin = PacketNumberSpace::Initial.new_packet_number(*range.start()); + let pmax = PacketNumberSpace::Initial.new_packet_number(*range.end()); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(pmin, pmax); + for (num, packet) in self.sent_packets.remove_range(range) { + packet.data.on_ack(num); + + if packet.data.included_fin { + let _ = self + .unacked_ranges + .remove(packet.data.stream_offset..=VarInt::MAX); + } else { + let _ = self.unacked_ranges.remove(packet.data.range()); + } + + self.checker + .on_ack(packet.data.stream_offset, packet.data.payload_len); + + self.ecn.on_packet_ack(packet.time_sent, packet.ecn); + bytes_acked += packet.data.cca_len() as usize; + + // record the most recent packet + if cca_args + .as_ref() + .map_or(true, |prev: &(Timestamp, _)| prev.0 < packet.time_sent) + { + cca_args = Some((packet.time_sent, packet.cc_info)); + } + + // free the retransmission segment + if let Some(segment) = packet.data.retransmission { + message.free_retransmission(segment); + } + + probes::on_packet_ack( + decrypt_key.credentials().id, + self.stream_id, + num.as_u64(), + packet.data.packet_len, + packet.data.stream_offset, + packet.data.payload_len, 
+ ack_time.saturating_duration_since(packet.time_sent), + ); + + *newly_acked |= true; + } + } + + if let Some((time_sent, cc_info)) = cca_args { + self.cca.on_packet_ack( + time_sent, + bytes_acked, + cc_info, + &self.rtt_estimator, + random, + ack_time, + ); + } + + let mut is_unrecoverable = false; + + if let Some(lost_max) = max.and_then(|min| min.checked_sub(VarInt::from_u8(2))) { + let lost_min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); + let lost_max = PacketNumberSpace::Initial.new_packet_number(lost_max); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(lost_min, lost_max); + for (num, packet) in self.sent_packets.remove_range(range) { + packet.data.on_loss(num); + + // TODO create a path and publisher + // self.ecn.on_packet_loss(packet.time_sent, packet.ecn, now, path, publisher); + + self.cca.on_packet_lost( + packet.data.cca_len() as _, + packet.cc_info, + random, + ack_time, + ); + + probes::on_packet_lost( + decrypt_key.credentials().id, + self.stream_id, + num.as_u64(), + packet.data.packet_len, + packet.data.stream_offset, + packet.data.payload_len, + ack_time.saturating_duration_since(packet.time_sent), + packet.data.retransmission.is_some(), + ); + + if let Some(segment) = packet.data.retransmission { + let segment = message.retransmit(segment); + let retransmission = Retransmission { + segment, + stream_offset: packet.data.stream_offset, + payload_len: packet.data.payload_len, + ty: TransmissionType::Stream, + included_fin: packet.data.included_fin, + }; + self.retransmissions.push(retransmission); + } else { + // we can only recover reliable streams + is_unrecoverable |= packet.data.payload_len > 0 && !self.stream_id.is_reliable; + } + } + } + + ensure!(!is_unrecoverable, Err(Error::RetransmissionFailure)); + + self.invariants(); + + Ok(()) + } + + /// Called by the worker thread when it becomes aware of the application having transmitted a + /// segment + #[inline] + pub fn on_transmit_segment( + &mut self, + packet_number: VarInt, + time_sent: Timestamp, + transmission: transmission::Info, + ecn: ExplicitCongestionNotification, + mut has_more_app_data: bool, + ) { + has_more_app_data |= !self.retransmissions.is_empty(); + let cc_info = self.cca.on_packet_sent( + time_sent, + transmission.cca_len(), + has_more_app_data, + &self.rtt_estimator, + ); + + // update the max offset that we've transmitted + self.max_sent_offset = self.max_sent_offset.max(transmission.end_offset()); + + // try to transition to start sending + let _ = self.state.on_send_stream(); + if transmission.included_fin { + // if the transmission included the final offset, then transition to that state + let _ = self.state.on_send_fin(); + } + + let info = packet_map::SentPacketInfo { + data: transmission, + time_sent, + ecn, + cc_info, + }; + + let packet_number = PacketNumberSpace::Initial.new_packet_number(packet_number); + self.sent_packets.insert(packet_number, info); + + self.invariants(); + } + + #[inline] + pub fn arm_pto_timer(&mut self, encrypt_key: &mut E, clock: &Clk) + where + E: encrypt::Key, + Clk: Clock, + { + let pto_backoff = self.pto_backoff; + let pto_period = self + .rtt_estimator + .pto_period(pto_backoff, PacketNumberSpace::Initial); + self.pto.update(clock.get_time(), pto_period); + + probes::on_pto_armed( + encrypt_key.credentials().id, + self.stream_id, + pto_period, + pto_backoff, + ); + } + + #[inline] + pub fn on_transmit( + &mut self, + packet_number: &packet_number::Counter, + encrypt_key: &mut E, + source_control_port: u16, + 
source_stream_port: Option, + clock: &Clk, + message: &mut A, + send_quantum: &mut usize, + mtu: u16, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + A: Allocator, + { + if let Err(error) = self.on_transmit_recovery_impl( + packet_number, + encrypt_key, + source_control_port, + source_stream_port, + clock, + message, + send_quantum, + mtu, + ) { + self.on_error(error, message); + return Err(error); + } + + Ok(()) + } + + #[inline] + fn on_transmit_recovery_impl( + &mut self, + packet_number: &packet_number::Counter, + encrypt_key: &mut E, + source_control_port: u16, + source_stream_port: Option, + clock: &Clk, + message: &mut A, + send_quantum: &mut usize, + mtu: u16, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + A: Allocator, + { + // try using a retransmission as a probe + self.on_transmit_retransmission_probe(message)?; + + self.on_transmit_retransmissions( + packet_number, + encrypt_key, + clock, + message, + send_quantum, + mtu, + )?; + + self.on_transmit_probe( + packet_number, + encrypt_key, + source_control_port, + source_stream_port, + clock, + message, + send_quantum, + mtu, + )?; + + Ok(()) + } + + #[inline] + fn on_transmit_retransmission_probe(&mut self, message: &mut A) -> Result<(), Error> + where + A: Allocator, + { + // We'll only have retransmissions if we're reliable + ensure!(self.stream_id.is_reliable, Ok(())); + + let mut transmissions = self.pto.transmissions() as usize; + ensure!(transmissions > 0, Ok(())); + + // Only push a new probe if we don't have existing retransmissions. + // + // The retransmissions structure uses a BinaryHeap, which prioritizes the smallest stream + // offsets, in order to more quickly unblock the peer. If we keep using retransmissions as + // probes, then it can cause issues where we don't make progress and keep sending the same + // segments. 
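+ // (`ensure!` returns early with the given value when its condition is false, so any queued + // retransmissions skip this probe pass entirely)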
+ ensure!(self.retransmissions.is_empty(), Ok(())); + + transmissions = transmissions.saturating_sub(self.retransmissions.len()); + ensure!(transmissions > 0, Ok(())); + + let pending = self + .sent_packets + .iter() + .filter(|(_, packet)| packet.data.retransmission.is_some()) + .take(transmissions); + + for (_pn, packet) in pending { + if let Some(retransmission) = packet.data.retransmission.as_ref() { + let Some(segment) = message.retransmit_copy(retransmission) else { + break; + }; + let retransmission = Retransmission { + segment, + ty: TransmissionType::Probe, + stream_offset: packet.data.stream_offset, + payload_len: packet.data.payload_len, + included_fin: packet.data.included_fin, + }; + self.retransmissions.push(retransmission); + } + } + + Ok(()) + } + + #[inline] + fn on_transmit_retransmissions( + &mut self, + packet_number: &packet_number::Counter, + encrypt_key: &mut E, + clock: &Clk, + message: &mut A, + send_quantum: &mut usize, + mtu: u16, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + A: Allocator, + { + while let Some(retransmission) = self.retransmissions.peek() { + ensure!(message.can_push(), break); + if retransmission.ty.is_probe() { + if self.pto.transmissions() == 0 { + let retrans = self + .retransmissions + .pop() + .expect("retransmission should be available"); + message.free(retrans.segment); + continue; + } + } else { + ensure!(!self.cca.is_congestion_limited(), break); + } + ensure!(*send_quantum >= mtu as usize, break); + + let segment_len = message.segment_len(); + let buffer = message.get_mut(retransmission); + + debug_assert!(!buffer.is_empty(), "empty retransmission buffer submitted"); + + // make sure we have enough space in the current buffer for the payload + ensure!( + segment_len.map_or(true, |s| s as usize >= buffer.len()), + break + ); + + let packet_number = match packet_number.next() { + Ok(pn) => pn, + Err(err) if message.is_empty() => return Err(err.into()), + // if we've sent something wait until `on_transmit` gets called again to return an + // error + Err(_) => break, + }; + + { + let buffer = DecoderBufferMut::new(buffer); + match decoder::Packet::retransmit(buffer, packet_number, encrypt_key) { + Ok(info) => info, + Err(err) => { + let retransmission = self + .retransmissions + .pop() + .expect("retransmission should be available"); + message.free(retransmission.segment); + debug_assert!(false, "{err:?}"); + return Err(Error::RetransmissionFailure); + } + } + }; + + let time_sent = clock.get_time(); + *send_quantum = send_quantum.saturating_sub(buffer.len()); + let packet_len = buffer.len() as u16; + + if branch!(message.is_empty()) { + let ecn = self + .ecn + .ecn(s2n_quic_core::transmission::Mode::Normal, time_sent); + message.set_ecn(ecn); + } + + { + let info = self + .retransmissions + .pop() + .expect("retransmission should be available"); + let stream_offset = info.stream_offset; + let payload_len = info.payload_len; + let ty = info.ty; + let included_fin = info.included_fin; + + let retransmission = if ty.is_stream() && self.stream_id.is_reliable { + let segment = message.push_with_retransmission(info.segment); + Some(segment) + } else { + message.push(info.segment); + None + }; + + let transmission = transmission::Info { + packet_len, + stream_offset, + payload_len, + included_fin, + retransmission, + }; + + self.on_transmit_segment( + packet_number, + time_sent, + transmission, + message.ecn(), + false, + ); + + if self.pto.transmissions() > 0 && ty.is_probe() { + self.pto.on_transmit_once(); + } + + 
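+ // in debug builds, record the transmission with the offset checker and emit the + // corresponding transmit probes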
#[cfg(debug_assertions)] + self.on_transmit_offset( + encrypt_key, + packet_number, + stream_offset, + included_fin, + payload_len, + ty, + true, + clock, + ); + } + } + + Ok(()) + } + + #[inline] + pub fn on_transmit_probe( + &mut self, + packet_number: &packet_number::Counter, + encrypt_key: &mut E, + source_control_port: u16, + source_stream_port: Option, + clock: &Clk, + message: &mut A, + send_quantum: &mut usize, + mtu: u16, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + A: Allocator, + { + while self.pto.transmissions() > 0 { + ensure!(message.can_push(), break); + // don't write a packet unless the segment len is the MTU + if let Some(segment_len) = message.segment_len() { + ensure!(segment_len == mtu, Ok(())); + } + + let mut payload = worker::probe::Probe { + offset: self.max_sent_offset, + final_offset: None, + }; + + let packet_len = self.on_transmit_data_unchecked( + packet_number, + encrypt_key, + source_control_port, + source_stream_port, + &mut payload, + clock, + message, + mtu, + TransmissionType::Probe, + )?; + + ensure!(packet_len > 0, break); + + *send_quantum -= packet_len as usize; + + self.pto.on_transmit_once(); + } + + Ok(()) + } + + #[inline] + pub fn on_transmit_data_unchecked( + &mut self, + packet_number: &packet_number::Counter, + encrypt_key: &mut E, + source_control_port: u16, + source_stream_port: Option, + cleartext_payload: &mut I, + clock: &Clk, + message: &mut A, + mtu: u16, + ty: TransmissionType, + ) -> Result + where + E: encrypt::Key, + I: s2n_quic_core::buffer::Reader, + Clk: Clock, + A: Allocator, + { + // try to allocate a segment in the current buffer + ensure!(let Some(segment) = message.alloc(), Ok(0)); + + // try to get the next packet number + ensure!( + let Ok(packet_number) = packet_number.next(), + return { + message.free(segment); + ensure!(!message.is_empty(), Err(Error::PacketNumberExhaustion)); + Ok(0) + } + ); + + let buffer = message.get_mut(&segment); + + { + let mtu = mtu as usize; + + // grow the buffer if needed + if branch!(buffer.capacity() < mtu) { + // We don't use `resize` here, since that will require initializing the bytes, + // which can add up quickly. This is OK, though, since we're just writing into the + // buffer and not actually reading anything. 
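+ // (`reserve` takes the additional capacity needed beyond the current length, hence the + // subtraction below)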
+ buffer.reserve(mtu - buffer.len()); + } + + unsafe { + debug_assert!(buffer.capacity() >= mtu); + buffer.set_len(mtu as _); + } + } + + self.checker.check_payload(cleartext_payload); + + let stream_offset = cleartext_payload.current_offset(); + let encoder = EncoderBuffer::new(buffer); + let packet_len = encoder::encode( + encoder, + source_control_port, + source_stream_port, + self.stream_id, + packet_number, + self.next_expected_control_packet, + VarInt::ZERO, + &mut &[][..], + VarInt::ZERO, + &(), + cleartext_payload, + encrypt_key, + ); + + // no need to keep going if the output is empty + ensure!( + packet_len > 0, + return { + message.free(segment); + Ok(0) + } + ); + + let payload_len = (cleartext_payload.current_offset() - stream_offset) + .try_into() + .unwrap(); + + let included_fin = cleartext_payload.final_offset().map_or(false, |fin| { + stream_offset.as_u64() + payload_len as u64 == fin.as_u64() + }); + + buffer.truncate(packet_len); + + debug_assert!( + packet_len < 1 << 16, + "cannot write larger packets than 2^16" + ); + let packet_len = packet_len as u16; + + let time_sent = clock.get_time(); + + // get the current ECN marking for this batch on the first transmission + if branch!(message.is_empty()) { + let ecn = self + .ecn + .ecn(s2n_quic_core::transmission::Mode::Normal, time_sent); + message.set_ecn(ecn); + } + + { + let has_more_app_data = branch!(cleartext_payload.buffered_len() > 0); + + let retransmission = if ty.is_stream() && self.stream_id.is_reliable { + let segment = message.push_with_retransmission(segment); + Some(segment) + } else { + message.push(segment); + None + }; + + let transmission = transmission::Info { + packet_len, + stream_offset, + payload_len, + included_fin, + retransmission, + }; + + self.on_transmit_segment( + packet_number, + time_sent, + transmission, + message.ecn(), + has_more_app_data, + ); + } + + Ok(packet_len) + } + + #[inline] + pub fn on_timeout( + &mut self, + _encrypt_key: &mut E, + clock: &Clk, + message: &mut A, + ) -> Result<(), Error> + where + E: encrypt::Key, + Clk: Clock, + A: Allocator, + { + if self.state.is_ready() { + self.arm_idle_timer(clock); + } else if branch!(self.idle_timer.poll_expiration(clock.get_time()).is_ready()) { + self.on_idle_timeout(message); + } + + let packets_in_flight = !self.sent_packets.is_empty(); + if branch!(self + .pto + .on_timeout(packets_in_flight, clock.get_time()) + .is_ready()) + { + // TODO where does this come from + let max_pto_backoff = 1024; + self.pto_backoff = self.pto_backoff.saturating_mul(2).min(max_pto_backoff); + } + + Ok(()) + } + + #[inline] + fn arm_idle_timer(&mut self, clock: &impl Clock) { + // TODO make this configurable + let idle_timeout = crate::stream::DEFAULT_IDLE_TIMEOUT; + self.idle_timer.set(clock.get_time() + idle_timeout); + } + + #[inline] + fn on_idle_timeout(&mut self, message: &mut A) + where + A: Allocator, + { + // we don't want to transmit anything so enter a terminal state + let mut did_transition = false; + did_transition |= self.state.on_send_reset().is_ok(); + did_transition |= self.state.on_recv_reset_ack().is_ok(); + if did_transition { + self.on_error(Error::IdleTimeout, message); + } + } + + #[inline] + pub fn check_error(&self) -> Result<(), Error> { + if let Some(err) = self.error { + Err(err) + } else { + Ok(()) + } + } + + #[inline] + fn on_error(&mut self, error: Error, message: &mut A) + where + A: Allocator, + { + ensure!(self.error.is_none()); + self.error = Some(error); + let _ = self.state.on_queue_reset(); + + 
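+ // release any buffered segments and cancel the timers now that the stream has failed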
self.clean_up(message); + } + + #[inline] + fn clean_up(&mut self, message: &mut A) + where + A: Allocator, + { + // force clear message so we don't get panics + message.force_clear(); + + for retransmission in self.retransmissions.drain() { + message.free(retransmission.segment); + } + let min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); + let max = PacketNumberSpace::Initial.new_packet_number(VarInt::MAX); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(min, max); + for (_pn, info) in self.sent_packets.remove_range(range) { + if let Some(segment) = info.data.retransmission { + message.free_retransmission(segment); + } + } + + self.idle_timer.cancel(); + self.pto.cancel(); + self.unacked_ranges.clear(); + + self.invariants(); + } + + #[inline(always)] + #[cfg_attr(not(debug_assertions), allow(dead_code))] + fn on_transmit_offset( + &mut self, + encrypt_key: &mut impl encrypt::Key, + packet_number: VarInt, + stream_offset: VarInt, + included_fin: bool, + payload_len: u16, + transmission_type: TransmissionType, + is_retransmission: bool, + _clock: &impl Clock, + ) { + let packet_number = s2n_quic_core::packet::number::PacketNumberSpace::Initial + .new_packet_number(packet_number); + let is_probe = matches!(transmission_type, TransmissionType::Probe); + if is_probe { + probes::on_transmit_probe( + encrypt_key.credentials().id, + self.stream_id, + packet_number, + stream_offset, + payload_len, + is_retransmission, + ); + } else { + probes::on_transmit_stream( + encrypt_key.credentials().id, + self.stream_id, + packet_number, + stream_offset, + payload_len, + is_retransmission, + ); + } + self.checker.on_stream_transmission( + stream_offset, + payload_len, + is_retransmission, + is_probe, + ); + trace!( + stream_id = ?self.stream_id, + stream_offset = stream_offset.as_u64(), + payload_len, + included_fin, + is_retransmission, + is_probe = transmission_type.is_probe(), + ); + } + + #[cfg(debug_assertions)] + #[inline] + fn invariants(&self) { + // TODO + } + + #[cfg(not(debug_assertions))] + #[inline(always)] + fn invariants(&self) {} +} + +impl timer::Provider for State { + #[inline] + fn timers(&self, query: &mut Q) -> timer::Result { + // if we're in a terminal state then no timers are needed + ensure!(!self.state.is_terminal(), Ok(())); + + if branch!(matches!(self.state, state::Sender::Send)) { + let mut can_transmit = !self.cca.is_congestion_limited(); + can_transmit |= self.cca.requires_fast_retransmission(); + can_transmit &= self.max_sent_offset < self.max_data; + if can_transmit { + self.cca.timers(query)?; + } + } + self.pto.timers(query)?; + self.idle_timer.timers(query)?; + Ok(()) + } +} + +#[cfg(debug_assertions)] +impl Drop for State { + #[inline] + fn drop(&mut self) { + // ignore any checks for leaking segments since we're cleaning everything up + for mut retransmission in self.retransmissions.drain() { + retransmission.segment.leak(); + } + let min = PacketNumberSpace::Initial.new_packet_number(VarInt::ZERO); + let max = PacketNumberSpace::Initial.new_packet_number(VarInt::MAX); + let range = s2n_quic_core::packet::number::PacketNumberRange::new(min, max); + for (_pn, info) in self.sent_packets.remove_range(range) { + if let Some(mut segment) = info.data.retransmission { + segment.leak(); + } + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker/checker.rs b/dc/s2n-quic-dc/src/stream/send/worker/checker.rs new file mode 100644 index 000000000..1db1f5453 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/worker/checker.rs @@ -0,0 
+1,158 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg_attr(not(debug_assertions), allow(dead_code, unused_imports))] + +use s2n_quic_core::{buffer::Reader, interval_set::IntervalSet, varint::VarInt}; + +#[cfg(debug_assertions)] +macro_rules! run { + ($($tt:tt)*) => { + $($tt)* + } +} + +#[cfg(not(debug_assertions))] +macro_rules! run { + ($($tt:tt)*) => {}; +} + +#[cfg(debug_assertions)] +#[derive(Clone, Debug, Default)] +pub struct Checker { + acked_ranges: IntervalSet, + largest_transmitted_offset: VarInt, + max_data: VarInt, + highest_seen_offset: Option, + final_offset: Option, +} + +#[cfg(not(debug_assertions))] +#[derive(Clone, Debug, Default)] +pub struct Checker {} + +#[allow(unused_variables)] +impl Checker { + #[inline(always)] + pub fn check_payload(&mut self, payload: &impl Reader) { + run!({ + if let Some(final_offset) = payload.final_offset() { + self.on_final_offset(final_offset); + } + self.on_stream_offset( + payload.current_offset(), + payload.buffered_len().min(u16::MAX as _) as _, + ); + }); + } + + #[inline(always)] + pub fn on_ack(&mut self, offset: VarInt, payload_len: u16) { + run!(if payload_len > 0 { + self.acked_ranges + .insert(offset..offset + VarInt::from_u16(payload_len)) + .unwrap(); + }); + } + + #[inline(always)] + pub fn on_max_data(&mut self, max_data: VarInt) { + run!({ + self.max_data = self.max_data.max(max_data); + }); + } + + #[inline(always)] + pub fn check_pending_packets( + &self, + packets: &super::PacketMap, + retransmissions: &super::BinaryHeap>, + ) { + run!({ + let largest_transmitted_offset = self.largest_transmitted_offset; + if largest_transmitted_offset == 0u64 { + return; + } + + let mut missing = IntervalSet::new(); + missing + .insert(VarInt::ZERO..largest_transmitted_offset) + .unwrap(); + // remove all of the ranges we've acked + missing.difference(&self.acked_ranges).unwrap(); + + for (_pn, packet) in packets.iter() { + let offset = packet.data.stream_offset; + let payload_len = packet.data.payload_len; + if payload_len > 0 { + missing + .remove(offset..offset + VarInt::from_u16(payload_len)) + .unwrap(); + } + } + + for packet in retransmissions.iter() { + let offset = packet.stream_offset; + let payload_len = packet.payload_len; + if payload_len > 0 { + missing + .remove(offset..offset + VarInt::from_u16(payload_len)) + .unwrap(); + } + } + + assert!( + missing.is_empty(), + "missing ranges for retransmission {missing:?}" + ); + }); + } + + #[inline(always)] + pub fn on_stream_transmission( + &mut self, + offset: VarInt, + payload_len: u16, + is_retransmission: bool, + is_probe: bool, + ) { + run!({ + self.on_stream_offset(offset, payload_len); + + if !is_retransmission && !is_probe { + assert_eq!(self.largest_transmitted_offset, offset); + } + + let end_offset = offset + VarInt::from_u16(payload_len); + self.largest_transmitted_offset = self.largest_transmitted_offset.max(end_offset); + + assert!(self.largest_transmitted_offset <= self.max_data); + }); + } + + #[inline(always)] + pub fn on_stream_offset(&mut self, offset: VarInt, payload_len: u16) { + run!({ + if let Some(final_offset) = self.final_offset { + assert!(offset <= final_offset); + } + + match self.highest_seen_offset.as_mut() { + Some(prev) => *prev = (*prev).max(offset), + None => self.highest_seen_offset = Some(offset), + } + }); + } + + #[inline(always)] + fn on_final_offset(&mut self, final_offset: VarInt) { + run!({ + self.on_stream_offset(final_offset, 0); + + match self.final_offset 
{ + Some(prev) => assert_eq!(prev, final_offset), + None => self.final_offset = Some(final_offset), + } + }); + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker/probe.rs b/dc/s2n-quic-dc/src/stream/send/worker/probe.rs new file mode 100644 index 000000000..0b70c4ab5 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/worker/probe.rs @@ -0,0 +1,50 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic_core::{buffer, varint::VarInt}; + +#[derive(Clone, Copy, Debug)] +pub struct Probe { + pub offset: VarInt, + pub final_offset: Option, +} + +impl buffer::reader::Storage for Probe { + type Error = core::convert::Infallible; + + #[inline] + fn buffered_len(&self) -> usize { + 0 + } + + #[inline] + fn read_chunk( + &mut self, + _watermark: usize, + ) -> Result, Self::Error> { + Ok(Default::default()) + } + + #[inline] + fn partial_copy_into( + &mut self, + _dest: &mut Dest, + ) -> Result, Self::Error> + where + Dest: buffer::writer::Storage + ?Sized, + { + Ok(Default::default()) + } +} + +impl buffer::Reader for Probe { + #[inline] + fn current_offset(&self) -> VarInt { + self.offset + } + + #[inline] + fn final_offset(&self) -> Option { + self.final_offset + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker/retransmission.rs b/dc/s2n-quic-dc/src/stream/send/worker/retransmission.rs new file mode 100644 index 000000000..20d99d278 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/worker/retransmission.rs @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{allocator, stream::send::transmission}; +use core::cmp::Ordering; +use s2n_quic_core::varint::VarInt; + +#[derive(Debug)] +pub struct Segment { + pub segment: S, + pub ty: transmission::Type, + pub stream_offset: VarInt, + pub payload_len: u16, + pub included_fin: bool, +} + +impl PartialEq for Segment { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl Eq for Segment {} + +impl PartialOrd for Segment { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Segment { + #[inline] + fn cmp(&self, rhs: &Self) -> Ordering { + self.ty + .cmp(&rhs.ty) + .then(self.stream_offset.cmp(&rhs.stream_offset)) + .then(self.payload_len.cmp(&rhs.payload_len)) + .reverse() + } +} + +impl core::ops::Deref for Segment { + type Target = S; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.segment + } +} diff --git a/dc/s2n-quic-dc/src/stream/send/worker/transmission.rs b/dc/s2n-quic-dc/src/stream/send/worker/transmission.rs new file mode 100644 index 000000000..99a6d334c --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/send/worker/transmission.rs @@ -0,0 +1,46 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use crate::allocator::Segment; +use s2n_quic_core::{packet::number::PacketNumber, varint::VarInt}; +use tracing::trace; + +#[derive(Debug)] +pub struct Info { + pub packet_len: u16, + pub retransmission: Option, + pub stream_offset: VarInt, + pub payload_len: u16, + pub included_fin: bool, +} + +impl Info { + #[inline] + pub fn cca_len(&self) -> u16 { + if self.payload_len == 0 { + self.packet_len + } else { + self.payload_len + } + } + + #[inline] + pub fn range(&self) -> core::ops::Range { + self.stream_offset..self.end_offset() + } + + #[inline] + pub fn end_offset(&self) -> VarInt { + self.stream_offset + VarInt::from_u16(self.payload_len) + } + + #[inline(always)] + pub fn on_ack(&self, packet_number: PacketNumber) { + trace!(event = "ack", ?packet_number, range = ?self.range()); + } + + #[inline(always)] + pub fn on_loss(&self, packet_number: PacketNumber) { + trace!(event = "lost", ?packet_number, range = ?self.range()); + } +} diff --git a/quic/s2n-quic-qns/etc/Dockerfile b/quic/s2n-quic-qns/etc/Dockerfile index d86a20c5f..45917b330 100644 --- a/quic/s2n-quic-qns/etc/Dockerfile +++ b/quic/s2n-quic-qns/etc/Dockerfile @@ -26,8 +26,9 @@ FROM rust-base AS sources COPY Cargo.toml /app COPY common /app/common COPY quic /app/quic -# Don't include testing crates +# Don't include testing crates or s2n-quic-dc RUN set -eux; \ + sed -i '/dc/d' Cargo.toml; \ sed -i '/xdp/d' quic/s2n-quic-platform/Cargo.toml; \ sed -i '/xdp/d' quic/s2n-quic-qns/Cargo.toml; \ sed -i '/xdp/d' quic/s2n-quic/Cargo.toml; \