diff --git a/packages/pam/handlers/rdp/bridge.go b/packages/pam/handlers/rdp/bridge.go index f582c864..17970ccc 100644 --- a/packages/pam/handlers/rdp/bridge.go +++ b/packages/pam/handlers/rdp/bridge.go @@ -15,3 +15,44 @@ type Bridge struct { handle uint64 cleanup func() } + +// EventType discriminates the variants in Event. +type EventType uint8 + +const ( + EventTypeKeyboard EventType = 1 + EventTypeUnicode EventType = 2 + EventTypeMouse EventType = 3 + EventTypeTargetFrame EventType = 4 +) + +// Action identifies the RDP framing of a TargetFrame event. +type Action uint8 + +const ( + ActionX224 Action = 0 + ActionFastPath Action = 1 +) + +// Fields are reused across variants; switch on Type. +type Event struct { + Type EventType + ElapsedNs uint64 + Scancode uint8 + CodePoint uint16 + X uint16 + Y uint16 + Flags uint32 + WheelDelta int32 + Action Action + Payload []byte +} + +// PollResult discriminates PollEvent outcomes. +type PollResult uint8 + +const ( + PollOK PollResult = 0 + PollTimeout PollResult = 1 + PollEnded PollResult = 2 +) diff --git a/packages/pam/handlers/rdp/bridge_cgo_shared.go b/packages/pam/handlers/rdp/bridge_cgo_shared.go index 9a822e6f..f5181057 100644 --- a/packages/pam/handlers/rdp/bridge_cgo_shared.go +++ b/packages/pam/handlers/rdp/bridge_cgo_shared.go @@ -5,6 +5,7 @@ package rdp /* #cgo CFLAGS: -I${SRCDIR}/native/include +#include #include "rdp_bridge.h" */ import "C" @@ -14,6 +15,8 @@ import ( "errors" "fmt" "net" + "time" + "unsafe" ) func (p *RDPProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { @@ -36,6 +39,26 @@ func (p *RDPProxy) HandleConnection(ctx context.Context, clientConn net.Conn) er } defer bridge.Close() + // Drain bridge tap events into the session logger. The Rust side closes + // the events channel when the session ends, so the goroutine exits via + // PollEnded without needing an explicit shutdown signal. + drainCtx, cancelDrain := context.WithCancel(ctx) + drainDone := make(chan struct{}) + go func() { + defer close(drainDone) + drainBridgeEvents(drainCtx, bridge, p.config.SessionLogger, p.config.SessionID, p.config.SessionStartedAt) + }() + defer func() { + cancelDrain() + // Wait briefly for the drain loop to exit so a cancelled session + // can't race the Bridge.Close below. PollEvent's timeout caps how + // long this can take. + select { + case <-drainDone: + case <-time.After(2 * pollTimeout): + } + }() + waitErr := make(chan error, 1) go func() { waitErr <- bridge.Wait() }() @@ -89,8 +112,62 @@ func (b *Bridge) Close() error { return nil } -// IsSupported reports whether this build has a real RDP bridge. Used -// by the gateway to decide whether to advertise RDP in the capabilities -// response: a stub-build gateway that advertises support would route -// RDP sessions only to fail them at connect time. +// True when the real bridge is compiled in (vs the stub). func IsSupported() bool { return true } + +// PollEvent drains one tap event with the given timeout. The returned Event +// is only meaningful when result == PollOK. PollEvent is not safe to call +// concurrently for the same Bridge; serialize calls in a single goroutine. +func (b *Bridge) PollEvent(timeout time.Duration) (PollResult, Event, error) { + timeoutMs := timeout.Milliseconds() + if timeoutMs < 0 { + timeoutMs = 0 + } + if timeoutMs > int64(^C.uint32_t(0)) { + timeoutMs = int64(^C.uint32_t(0)) + } + + var raw C.struct_RdpEvent + rc := C.rdp_bridge_poll_event(C.uint64_t(b.handle), &raw, C.uint32_t(timeoutMs)) + + switch rc { + case C.RDP_POLL_OK: + // fall through to event materialization below + case C.RDP_POLL_TIMEOUT: + return PollTimeout, Event{}, nil + case C.RDP_POLL_ENDED: + return PollEnded, Event{}, nil + case C.RDP_POLL_INVALID_HANDLE: + return PollEnded, Event{}, ErrInvalidHandle + default: + return PollEnded, Event{}, fmt.Errorf("rdp bridge: poll returned unexpected status %d", int32(rc)) + } + + ev := Event{ + Type: EventType(uint8(raw.event_type)), + ElapsedNs: uint64(raw.elapsed_ns), + Flags: uint32(raw.flags), + WheelDelta: int32(raw.wheel_delta), + Action: Action(uint8(raw.action)), + } + switch ev.Type { + case EventTypeKeyboard: + ev.Scancode = uint8(raw.value_a) + case EventTypeUnicode: + ev.CodePoint = uint16(raw.value_a) + case EventTypeMouse: + ev.X = uint16(raw.value_a) + ev.Y = uint16(raw.value_b) + case EventTypeTargetFrame: + // Always free the libc-malloc'd buffer Rust handed us, even if + // the copy below is empty -- ownership transfer is unconditional. + if raw.payload_ptr != nil { + defer C.free(unsafe.Pointer(raw.payload_ptr)) + if raw.payload_len > 0 { + ev.Payload = C.GoBytes(unsafe.Pointer(raw.payload_ptr), C.int(raw.payload_len)) + } + } + } + + return PollOK, ev, nil +} diff --git a/packages/pam/handlers/rdp/bridge_cgo_unix.go b/packages/pam/handlers/rdp/bridge_cgo_unix.go index 91b24d38..f940eb6d 100644 --- a/packages/pam/handlers/rdp/bridge_cgo_unix.go +++ b/packages/pam/handlers/rdp/bridge_cgo_unix.go @@ -62,19 +62,8 @@ func startWithDupedFD(dupFd int, targetHost string, targetPort uint16, username, return &Bridge{handle: uint64(handle)}, nil } -// StartWithReadWriter adapts an fd-less Go byte stream (e.g. *tls.Conn -// from the gateway's mTLS-wrapped virtual connection) to the bridge, -// which needs a real file descriptor because the Rust side uses tokio's -// TcpStream::from_raw_fd and does direct async I/O on the socket. -// -// Trick: open a loopback TCP pair. Hand one end's fd to the bridge (it -// thinks it has a real client). Keep the other end in Go and shuttle -// bytes between it and rw with two io.Copy goroutines. -// -// rw (e.g. *tls.Conn) <-io.Copy-> peer <-kernel loopback-> accepted (fd -> Rust bridge) -// -// Cost: two extra in-process copies and a loopback round-trip per byte. -// Negligible vs. the TLS + CredSSP work on either side. +// Adapts an fd-less Go byte stream to the Rust bridge (which needs a real fd +// for tokio's TcpStream::from_raw_fd) by routing through a loopback TCP pair. func StartWithReadWriter(rw io.ReadWriter, targetHost string, targetPort uint16, username, password string) (*Bridge, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/packages/pam/handlers/rdp/bridge_stub.go b/packages/pam/handlers/rdp/bridge_stub.go index 37a3bcdf..0d704908 100644 --- a/packages/pam/handlers/rdp/bridge_stub.go +++ b/packages/pam/handlers/rdp/bridge_stub.go @@ -6,6 +6,7 @@ import ( "context" "io" "net" + "time" ) // Stub implementations for builds without `-tags rdp` or on platforms @@ -29,6 +30,10 @@ func (b *Bridge) Wait() error { return ErrRdpUnavailable } func (b *Bridge) Cancel() error { return ErrRdpUnavailable } func (b *Bridge) Close() error { return ErrRdpUnavailable } +func (b *Bridge) PollEvent(_ time.Duration) (PollResult, Event, error) { + return PollEnded, Event{}, ErrRdpUnavailable +} + // IsSupported reports whether this build has a real RDP bridge. See the // rdp-enabled counterpart in bridge_cgo_shared.go for details. func IsSupported() bool { return false } diff --git a/packages/pam/handlers/rdp/native/Cargo.lock b/packages/pam/handlers/rdp/native/Cargo.lock index 5c04a3e5..c4652505 100644 --- a/packages/pam/handlers/rdp/native/Cargo.lock +++ b/packages/pam/handlers/rdp/native/Cargo.lock @@ -1309,9 +1309,11 @@ dependencies = [ "bytes", "ironrdp-acceptor", "ironrdp-connector", + "ironrdp-core", "ironrdp-pdu", "ironrdp-tls", "ironrdp-tokio", + "libc", "libz-sys", "rcgen", "rustls", diff --git a/packages/pam/handlers/rdp/native/Cargo.toml b/packages/pam/handlers/rdp/native/Cargo.toml index 500a2117..cb53a5d2 100644 --- a/packages/pam/handlers/rdp/native/Cargo.toml +++ b/packages/pam/handlers/rdp/native/Cargo.toml @@ -13,10 +13,12 @@ path = "src/lib.rs" [dependencies] ironrdp-acceptor = "0.8" ironrdp-connector = "0.8" +ironrdp-core = "0.1" ironrdp-tokio = { version = "0.8", features = ["reqwest"] } ironrdp-pdu = "0.7" ironrdp-tls = { version = "0.2", features = ["rustls"] } x509-cert = { version = "0.2", features = ["std"] } +libc = "0.2" tokio = { version = "1", features = ["full"] } tokio-util = "0.7" diff --git a/packages/pam/handlers/rdp/native/include/rdp_bridge.h b/packages/pam/handlers/rdp/native/include/rdp_bridge.h index 83088768..888afbbd 100644 --- a/packages/pam/handlers/rdp/native/include/rdp_bridge.h +++ b/packages/pam/handlers/rdp/native/include/rdp_bridge.h @@ -1,8 +1,5 @@ -/* - * infisical-rdp-bridge C ABI. See ffi.rs for details. Lifecycle: - * start_* -> wait -> free; cancel may be called from any thread. - * start_* transfers ownership of the client fd/socket to the bridge. - */ +/* C ABI; see ffi.rs. Lifecycle: start_* -> wait -> free. start_* takes + * ownership of the client fd/socket. cancel is thread-safe. */ #ifndef INFISICAL_RDP_BRIDGE_H #define INFISICAL_RDP_BRIDGE_H @@ -46,6 +43,35 @@ int32_t rdp_bridge_wait(uint64_t handle); int32_t rdp_bridge_cancel(uint64_t handle); int32_t rdp_bridge_free(uint64_t handle); +/* Poll return codes (distinct number space from the bridge status codes + * above; consumed by rdp_bridge_poll_event only). */ +#define RDP_POLL_OK 0 +#define RDP_POLL_TIMEOUT 1 +#define RDP_POLL_ENDED 2 +#define RDP_POLL_INVALID_HANDLE -1 + +/* Event type discriminator. */ +#define RDP_EVENT_KEYBOARD 1 +#define RDP_EVENT_UNICODE 2 +#define RDP_EVENT_MOUSE 3 +#define RDP_EVENT_TARGET_FRAME 4 + +/* Fields reused across variants; check event_type. For TargetFrame, + * payload_ptr is libc-malloc'd and the Go caller must C.free it. */ +struct RdpEvent { + uint8_t event_type; + uint64_t elapsed_ns; + uint32_t value_a; + uint32_t value_b; + uint32_t flags; + int32_t wheel_delta; + uint8_t action; + uint8_t *payload_ptr; + uint32_t payload_len; +}; + +int32_t rdp_bridge_poll_event(uint64_t handle, struct RdpEvent *out, uint32_t timeout_ms); + #ifdef __cplusplus } #endif diff --git a/packages/pam/handlers/rdp/native/src/bridge.rs b/packages/pam/handlers/rdp/native/src/bridge.rs index cfe5e992..da866ac2 100644 --- a/packages/pam/handlers/rdp/native/src/bridge.rs +++ b/packages/pam/handlers/rdp/native/src/bridge.rs @@ -1,9 +1,10 @@ -//! MITM bridge. Runs acceptor + connector only through CredSSP (to inject -//! credentials), then byte-forwards between the two TLS streams. Letting -//! client and target negotiate MCS/capabilities/share-state directly -//! avoids drift that breaks strict clients (Windows App, mstsc). +//! MITM bridge. Runs acceptor + connector through CredSSP only, then byte- +//! forwards. Letting client/target negotiate MCS directly avoids drift +//! that breaks strict clients (Windows App, mstsc). +use std::borrow::Cow; use std::sync::Arc; +use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use ironrdp_acceptor::{Acceptor, BeginResult}; @@ -11,20 +12,31 @@ use ironrdp_connector::credssp::{CredsspSequence, KerberosConfig}; use ironrdp_connector::sspi::credssp::ClientState; use ironrdp_connector::sspi::generator::GeneratorState; use ironrdp_connector::{encode_x224_packet, ClientConnector, ClientConnectorState}; +use ironrdp_core::ReadCursor; use ironrdp_pdu::gcc::ConferenceCreateRequest; +use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent}; use ironrdp_pdu::ironrdp_core::{decode, WriteBuf}; -use ironrdp_pdu::mcs::ConnectInitial; +use ironrdp_pdu::mcs::{ConnectInitial, SendDataRequest}; use ironrdp_pdu::nego::SecurityProtocol; use ironrdp_pdu::rdp::client_info::Credentials as AcceptorCredentials; +use ironrdp_pdu::rdp::headers::{ShareControlHeader, ShareControlPdu}; use ironrdp_pdu::x224::{X224Data, X224}; +use ironrdp_pdu::Action; use ironrdp_tokio::reqwest::ReqwestNetworkClient; use ironrdp_tokio::{FramedWrite, NetworkClient}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; use tokio_util::sync::CancellationToken; -use tracing::info; +use tracing::{debug, info, warn}; +use crate::cap_filter; use crate::config::{connector_config, DEFAULT_HEIGHT, DEFAULT_WIDTH}; +use crate::events::{elapsed_ns_since, EventSender, SessionEvent}; + +/// Cap on c2t PDUs to inspect before giving up on the cap filter. +const CONFIRM_ACTIVE_SCAN_MAX_PDUS: usize = 32; +/// Wall-clock cap on the cap-filter scan window. +const CONFIRM_ACTIVE_SCAN_MAX_DURATION: Duration = Duration::from_secs(5); // The acceptor side of the bridge expects the user to type the target // username with an empty password. The real password is injected by the @@ -42,9 +54,10 @@ pub async fn run_mitm( client_tcp: TcpStream, target: TargetEndpoint, cancel: CancellationToken, + tx: EventSender, ) -> Result<()> { tokio::select! { - result = run_mitm_inner(client_tcp, target) => result, + result = run_mitm_inner(client_tcp, target, tx) => result, _ = cancel.cancelled() => { info!("session canceled by caller"); Ok(()) @@ -52,7 +65,11 @@ pub async fn run_mitm( } } -async fn run_mitm_inner(client_tcp: TcpStream, target: TargetEndpoint) -> Result<()> { +async fn run_mitm_inner( + client_tcp: TcpStream, + target: TargetEndpoint, + tx: EventSender, +) -> Result<()> { // Our tree pulls both ring (direct) and aws-lc-rs (via reqwest); rustls // 0.23 needs an explicit provider when more than one is compiled in. let _ = rustls::crypto::ring::default_provider().install_default(); @@ -66,10 +83,8 @@ async fn run_mitm_inner(client_tcp: TcpStream, target: TargetEndpoint) -> Result let (mut client_stream, client_leftover) = acceptor_output; let (mut target_stream, target_leftover) = connector_output; - // Strip virtual channels (clipboard, drives, audio, USB, etc.) from the - // client's MCS Connect Initial before forwarding. Mouse/keyboard/screen - // ride the implicit MCS I/O channel, not virtual channels, so they're - // unaffected. + // Strip virtual channels (clipboard, drives, audio, USB) from MCS Connect Initial. + // Mouse/keyboard/screen ride the implicit I/O channel and are unaffected. filter_client_mcs_connect_initial(&mut client_stream, &mut target_stream, client_leftover) .await .context("filter client MCS Connect Initial")?; @@ -92,24 +107,334 @@ async fn run_mitm_inner(client_tcp: TcpStream, target: TargetEndpoint) -> Result .await .context("flush target stream before passthrough")?; - // Real RDP clients hard-close TCP without TLS close_notify, which - // rustls surfaces as UnexpectedEof. Treat that as clean shutdown. - match tokio::io::copy_bidirectional(&mut client_stream, &mut target_stream).await { - Ok(_) => info!("session ended cleanly"), - Err(e) if is_unexpected_eof(&e) => info!("session ended (peer hard-closed)"), - Err(e) => return Err(e).context("passthrough copy_bidirectional"), + // PDU-framed bridge with an event tap. read_pdu is pure TPKT/FastPath + // framing (no state machine) so this preserves the "no MCS drift" + // property of the byte-level copy_bidirectional it replaces. + let client_framed = ironrdp_tokio::TokioFramed::new(client_stream); + let target_framed = ironrdp_tokio::TokioFramed::new(target_stream); + bridge_pdus(client_framed, target_framed, tx).await +} + +async fn bridge_pdus( + client_framed: ironrdp_tokio::TokioFramed, + target_framed: ironrdp_tokio::TokioFramed, + tx: EventSender, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, +{ + let (mut client_read, mut client_write) = ironrdp_tokio::split_tokio_framed(client_framed); + let (mut target_read, mut target_write) = ironrdp_tokio::split_tokio_framed(target_framed); + + let started_at = Instant::now(); + let tx_c2t = tx.clone(); + let tx_t2c = tx; + + let c2t = async move { + let mut cap_filter = CapFilterState::Scanning { + started_at: Instant::now(), + pdus_seen: 0, + info_done: false, + confirm_done: false, + }; + loop { + let (action, frame) = match client_read.read_pdu().await { + Ok(v) => v, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err::<_, anyhow::Error>(e.into()), + }; + tap_client_to_target(action, &frame, started_at, &tx_c2t); + + let bytes_to_forward: Vec = match cap_filter.consider(action, &frame) { + CapFilterDecision::Forward => frame.to_vec(), + CapFilterDecision::Replace(modified) => modified, + }; + target_write + .write_all(&bytes_to_forward) + .await + .context("write client PDU to target")?; + } + Ok(()) + }; + + let t2c = async move { + loop { + let (action, frame) = match target_read.read_pdu().await { + Ok(v) => v, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err::<_, anyhow::Error>(e.into()), + }; + tap_target_to_client(action, &frame, started_at, &tx_t2c); + client_write + .write_all(&frame) + .await + .context("write target PDU to client")?; + } + Ok(()) + }; + + match tokio::try_join!(c2t, t2c) { + Ok(_) => { + info!("session ended cleanly"); + Ok(()) + } + Err(e) => Err(e).context("bridge_pdus"), + } +} + +/// One-shot c2t scan that patches Client Info + Client Confirm Active. +enum CapFilterState { + Scanning { + started_at: Instant, + pdus_seen: usize, + info_done: bool, + confirm_done: bool, + }, + Done, +} + +enum CapFilterDecision { + Forward, + Replace(Vec), +} + +impl CapFilterState { + fn consider(&mut self, action: Action, frame: &[u8]) -> CapFilterDecision { + let CapFilterState::Scanning { + started_at, + pdus_seen, + info_done, + confirm_done, + } = self + else { + return CapFilterDecision::Forward; + }; + + if action != Action::X224 { + return CapFilterDecision::Forward; + } + + *pdus_seen += 1; + if *pdus_seen > CONFIRM_ACTIVE_SCAN_MAX_PDUS + || started_at.elapsed() > CONFIRM_ACTIVE_SCAN_MAX_DURATION + { + warn!( + pdus_seen, + info_done = *info_done, + confirm_done = *confirm_done, + "scan window exhausted before both filters fired" + ); + *self = CapFilterState::Done; + return CapFilterDecision::Forward; + } + + // The two filters are disjoint, so a match short-circuits. + if !*info_done { + if let Some(modified) = try_filter_client_info(frame) { + *info_done = true; + let both_done = *info_done && *confirm_done; + if both_done { + *self = CapFilterState::Done; + } + return CapFilterDecision::Replace(modified); + } + } + if !*confirm_done { + if let Some(modified) = try_filter_confirm_active(frame) { + *confirm_done = true; + let both_done = *info_done && *confirm_done; + if both_done { + *self = CapFilterState::Done; + } + return CapFilterDecision::Replace(modified); + } + } + CapFilterDecision::Forward + } +} + +#[derive(Debug, Clone, Copy)] +struct ByteRange { + offset: usize, + len: usize, +} + +impl ByteRange { + fn slice<'a>(&self, frame: &'a [u8]) -> &'a [u8] { + &frame[self.offset..self.offset + self.len] + } + + fn slice_mut<'a>(&self, frame: &'a mut [u8]) -> &'a mut [u8] { + &mut frame[self.offset..self.offset + self.len] } - Ok(()) } -fn is_unexpected_eof(err: &std::io::Error) -> bool { - err.kind() == std::io::ErrorKind::UnexpectedEof +/// Locate `send_data.user_data` inside `frame`. Bails on Cow::Owned. +fn user_data_range_within(frame: &[u8], send_data: &SendDataRequest<'_>) -> Option { + let slice: &[u8] = match &send_data.user_data { + Cow::Borrowed(s) => *s, + Cow::Owned(_) => return None, + }; + let frame_start = frame.as_ptr() as usize; + let slice_start = slice.as_ptr() as usize; + if slice_start < frame_start || slice_start + slice.len() > frame_start + frame.len() { + return None; + } + Some(ByteRange { + offset: slice_start - frame_start, + len: slice.len(), + }) +} + +fn locate_client_info(frame: &[u8]) -> Option { + const SEC_INFO_PKT: u16 = 0x0040; + let send_data = decode::>>(frame).ok()?.0; + let user_data = user_data_range_within(frame, &send_data)?; + if user_data.len < 4 { + return None; + } + let bytes = user_data.slice(frame); + let sec_flags = u16::from_le_bytes([bytes[0], bytes[1]]); + (sec_flags & SEC_INFO_PKT != 0).then_some(user_data) +} + +struct ConfirmActiveLayout { + user_data: ByteRange, + caps_start_in_user_data: usize, +} + +fn locate_confirm_active(frame: &[u8]) -> Option { + let send_data = decode::>>(frame).ok()?.0; + let share_control = decode::(send_data.user_data.as_ref()).ok()?; + if !matches!( + share_control.share_control_pdu, + ShareControlPdu::ClientConfirmActive(_), + ) { + return None; + } + let user_data = user_data_range_within(frame, &send_data)?; + let caps_start_in_user_data = parse_confirm_active_caps_start(user_data.slice(frame))?; + Some(ConfirmActiveLayout { + user_data, + caps_start_in_user_data, + }) +} + +/// MS-RDPBCGR 2.2.1.13.2.1: ShareControlHeader(10) + originatorId(2) + +/// sourceDescLen(2) + combinedLen(2) + sourceDescriptor(var) + numCaps(2) + pad(2) +fn parse_confirm_active_caps_start(user_data: &[u8]) -> Option { + let mut p = 10 + 2; + if user_data.len() < p + 4 { + return None; + } + let source_desc_len = u16::from_le_bytes([user_data[p], user_data[p + 1]]) as usize; + p += 4 + source_desc_len + 4; + (p <= user_data.len()).then_some(p) +} + +fn try_filter_client_info(frame: &[u8]) -> Option> { + let user_data = locate_client_info(frame)?; + let mut out = frame.to_vec(); + if !cap_filter::client_info::clear_compression(user_data.slice_mut(&mut out)) { + return None; + } + debug!("Client Info PDU: cleared INFO_COMPRESSION + CompressionTypeMask"); + Some(out) +} + +fn try_filter_confirm_active(frame: &[u8]) -> Option> { + let layout = locate_confirm_active(frame)?; + let user_data_bytes = layout.user_data.slice(frame); + + let mut order_body_offset_in_frame: Option = None; + let mut codecs_body_offset_in_frame: Option = None; + for cap in cap_filter::walk_caps(user_data_bytes, layout.caps_start_in_user_data) { + let body_offset_in_frame = layout.user_data.offset + cap.body_offset_in_user_data; + match cap.cap_type { + cap_filter::cap_types::ORDER if cap.cap_len >= cap_filter::order_cap::BODY_LEN + 4 => { + order_body_offset_in_frame = Some(body_offset_in_frame); + } + cap_filter::cap_types::BITMAP_CODECS + if cap.cap_len >= cap_filter::bitmap_codecs_cap::MIN_BODY_LEN + 4 => + { + codecs_body_offset_in_frame = Some(body_offset_in_frame); + } + _ => {} + } + } + + // Without Order patched, server emits unrenderable Orders. + let order_offset = order_body_offset_in_frame?; + let mut out = frame.to_vec(); + cap_filter::order_cap::clear_order_support( + &mut out[order_offset..order_offset + cap_filter::order_cap::BODY_LEN], + ); + if let Some(codecs_offset) = codecs_body_offset_in_frame { + cap_filter::bitmap_codecs_cap::clear_codec_count(&mut out[codecs_offset..]); + } + debug!("Confirm Active: cleared Order support + BitmapCodecs count"); + Some(out) } -// Reads the client's MCS Connect Initial PDU, removes any virtual channels -// declared in its Client Network Data block, and forwards the rewritten PDU -// to the target. Any bytes after the PDU (rare; PDUs typically arrive one at -// a time at this stage) are forwarded unchanged. +fn tap_client_to_target(action: Action, frame: &[u8], started_at: Instant, tx: &EventSender) { + if action != Action::FastPath { + return; + } + let input: FastPathInput = match decode_fast_path_input(frame) { + Ok(input) => input, + Err(e) => { + warn!(?e, "failed to decode FastPathInput"); + return; + } + }; + let elapsed_ns = elapsed_ns_since(started_at); + for event in input.input_events() { + let session_event = match *event { + FastPathInputEvent::KeyboardEvent(flags, scancode) => SessionEvent::KeyboardInput { + scancode, + flags, + elapsed_ns, + }, + FastPathInputEvent::UnicodeKeyboardEvent(flags, code_point) => { + SessionEvent::UnicodeInput { + code_point, + flags, + elapsed_ns, + } + } + FastPathInputEvent::MouseEvent(pdu) => SessionEvent::MouseInput { + x: pdu.x_position, + y: pdu.y_position, + flags: pdu.flags, + wheel_delta: pdu.number_of_wheel_rotation_units, + elapsed_ns, + }, + // MouseEventEx, MouseEventRel, QoeEvent, SyncEvent: skip for now; + // uncommon in normal sessions and not needed for replay V1. + _ => continue, + }; + // send error means the receiver was dropped (poll loop exited). + // The bridge keeps forwarding bytes regardless. + let _ = tx.send(session_event); + } +} + +fn tap_target_to_client(action: Action, frame: &[u8], started_at: Instant, tx: &EventSender) { + let _ = tx.send(SessionEvent::TargetFrame { + action, + payload: frame.to_vec(), + elapsed_ns: elapsed_ns_since(started_at), + }); +} + +fn decode_fast_path_input(frame: &[u8]) -> anyhow::Result { + use ironrdp_core::Decode as _; + let mut cursor = ReadCursor::new(frame); + FastPathInput::decode(&mut cursor).map_err(|e| anyhow::anyhow!("decode FastPathInput: {e}")) +} + +// Strips virtual channels from the Client Network Data block of MCS Connect Initial. async fn filter_client_mcs_connect_initial( client_stream: &mut ErasedStream, target_stream: &mut ErasedStream, @@ -409,3 +734,50 @@ pub trait AsyncReadWrite: AsyncRead + AsyncWrite {} impl AsyncReadWrite for T where T: AsyncRead + AsyncWrite {} pub type ErasedStream = Box; + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a synthetic ConfirmActive user_data prefix: + /// ShareControlHeader(10) + originatorId(2) + sourceDescLen(2) + + /// combinedLen(2) + sourceDescriptor(source_desc_len) + numCaps(2) + pad(2) + fn confirm_active_prefix(source_desc_len: usize) -> Vec { + let mut buf = vec![0xAA_u8; 10 + 2]; + buf.extend_from_slice(&(source_desc_len as u16).to_le_bytes()); + buf.extend_from_slice(&0xBBBB_u16.to_le_bytes()); + buf.extend_from_slice(&vec![0xCC; source_desc_len]); + buf.extend_from_slice(&0xDDDD_u16.to_le_bytes()); + buf.extend_from_slice(&0xEEEE_u16.to_le_bytes()); + buf + } + + #[test] + fn caps_start_after_variable_source_descriptor() { + let user_data = confirm_active_prefix(6); + let p = parse_confirm_active_caps_start(&user_data).expect("caps start"); + assert_eq!(p, 12 + 4 + 6 + 4); + assert_eq!(p, user_data.len()); + } + + #[test] + fn caps_start_works_when_source_descriptor_is_empty() { + let user_data = confirm_active_prefix(0); + let p = parse_confirm_active_caps_start(&user_data).expect("caps start"); + assert_eq!(p, 12 + 4 + 0 + 4); + } + + #[test] + fn caps_start_returns_none_when_header_truncated() { + let user_data = vec![0u8; 15]; + assert!(parse_confirm_active_caps_start(&user_data).is_none()); + } + + #[test] + fn caps_start_returns_none_when_source_desc_len_overflows() { + let mut user_data = vec![0u8; 12]; + user_data.extend_from_slice(&9999_u16.to_le_bytes()); + user_data.extend_from_slice(&[0u8; 2]); + assert!(parse_confirm_active_caps_start(&user_data).is_none()); + } +} diff --git a/packages/pam/handlers/rdp/native/src/cap_filter.rs b/packages/pam/handlers/rdp/native/src/cap_filter.rs new file mode 100644 index 00000000..f584e260 --- /dev/null +++ b/packages/pam/handlers/rdp/native/src/cap_filter.rs @@ -0,0 +1,276 @@ +//! Byte-level patches for Confirm Active / Client Info PDUs. +//! IronRDP's typed decode->encode loses unrelated fields, so we mutate in place. + +/// MS-RDPBCGR 2.2.7 +pub mod cap_types { + pub const ORDER: u16 = 0x0003; + pub const BITMAP_CODECS: u16 = 0x001d; +} + +/// MS-RDPBCGR 2.2.7.1.3 +pub mod order_cap { + use std::ops::Range; + + pub const BODY_LEN: usize = 84; + pub const ORDER_SUPPORT: Range = 32..64; + + /// Forces server to fall back to Bitmap updates. + /// orderFlags untouched so NEGOTIATEORDERSUPPORT (mandatory) stays set. + pub fn clear_order_support(body: &mut [u8]) { + body[ORDER_SUPPORT].fill(0); + } +} + +/// MS-RDPBCGR 2.2.7.2.10 +pub mod bitmap_codecs_cap { + pub const CODEC_COUNT_OFFSET: usize = 0; + pub const MIN_BODY_LEN: usize = 1; + + /// Prevents server from picking RFX/NSCodec/AVC. + pub fn clear_codec_count(body: &mut [u8]) { + body[CODEC_COUNT_OFFSET] = 0; + } +} + +/// MS-RDPBCGR 2.2.1.11.1.1, given user_data of an MCS Send Data Request +/// whose security header has SEC_INFO_PKT set. +pub mod client_info { + use std::ops::Range; + + /// 4 bytes security header + 4 bytes CodePage. + pub const FLAGS: Range = 8..12; + pub const INFO_COMPRESSION: u32 = 0x0000_0080; + pub const COMPRESSION_TYPE_MASK: u32 = 0x0000_1E00; + + /// Disables MPPC bulk compression (IronRDP-session can't decompress it). + pub fn clear_compression(user_data: &mut [u8]) -> bool { + if user_data.len() < FLAGS.end { + return false; + } + let bytes: [u8; 4] = match user_data[FLAGS.clone()].try_into() { + Ok(b) => b, + Err(_) => return false, + }; + let flags = u32::from_le_bytes(bytes); + let new_flags = flags & !(INFO_COMPRESSION | COMPRESSION_TYPE_MASK); + if flags == new_flags { + return false; + } + user_data[FLAGS.clone()].copy_from_slice(&new_flags.to_le_bytes()); + true + } +} + +#[derive(Debug, Clone, Copy)] +pub struct WalkedCap { + pub cap_type: u16, + pub cap_len: usize, + pub body_offset_in_user_data: usize, +} + +/// Stops on a malformed cap header. +pub fn walk_caps(user_data: &[u8], caps_start: usize) -> CapIter<'_> { + CapIter { + user_data, + cursor: caps_start, + } +} + +pub struct CapIter<'a> { + user_data: &'a [u8], + cursor: usize, +} + +impl<'a> Iterator for CapIter<'a> { + type Item = WalkedCap; + + fn next(&mut self) -> Option { + if self.cursor + 4 > self.user_data.len() { + return None; + } + let cap_type = + u16::from_le_bytes([self.user_data[self.cursor], self.user_data[self.cursor + 1]]); + let cap_len = u16::from_le_bytes([ + self.user_data[self.cursor + 2], + self.user_data[self.cursor + 3], + ]) as usize; + if cap_len < 4 || self.cursor + cap_len > self.user_data.len() { + return None; + } + let item = WalkedCap { + cap_type, + cap_len, + body_offset_in_user_data: self.cursor + 4, + }; + self.cursor += cap_len; + Some(item) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn order_clear_zeros_only_the_support_array() { + let mut body = vec![0xff_u8; order_cap::BODY_LEN]; + order_cap::clear_order_support(&mut body); + assert_eq!(&body[order_cap::ORDER_SUPPORT], &[0; 32]); + assert_eq!(&body[28..32], &[0xff; 4]); + assert_eq!(&body[64..68], &[0xff; 4]); + } + + #[test] + fn bitmap_codecs_clears_only_first_byte() { + let mut body = vec![0xff_u8; 16]; + bitmap_codecs_cap::clear_codec_count(&mut body); + assert_eq!(body[0], 0); + assert_eq!(&body[1..], &[0xff; 15]); + } + + #[test] + fn client_info_clears_compression_bits() { + let mut user_data = vec![0u8; 12]; + user_data[8..12].copy_from_slice(&0x0000_1E80_u32.to_le_bytes()); + assert!(client_info::clear_compression(&mut user_data)); + let new_flags = u32::from_le_bytes(user_data[8..12].try_into().unwrap()); + assert_eq!(new_flags, 0); + } + + #[test] + fn client_info_noop_when_compression_already_off() { + let mut user_data = vec![0u8; 12]; + user_data[8..12].copy_from_slice(&0x0000_0040_u32.to_le_bytes()); + assert!(!client_info::clear_compression(&mut user_data)); + } + + #[test] + fn client_info_returns_false_when_user_data_too_short() { + let mut user_data = vec![0u8; 11]; + assert!(!client_info::clear_compression(&mut user_data)); + } + + #[test] + fn client_info_preserves_unrelated_flag_bits() { + let mut user_data = vec![0xAB_u8; 12]; + // INFO_COMPRESSION + CompressionTypeMask + INFO_AUTOLOGON(0x0008) + INFO_UNICODE(0x0010) + let original = 0x0000_1E80_u32 | 0x0000_0008 | 0x0000_0010; + user_data[8..12].copy_from_slice(&original.to_le_bytes()); + assert!(client_info::clear_compression(&mut user_data)); + let new_flags = u32::from_le_bytes(user_data[8..12].try_into().unwrap()); + assert_eq!(new_flags, 0x0000_0008 | 0x0000_0010); + assert_eq!(&user_data[..8], &[0xAB; 8]); + } + + #[test] + fn walk_caps_iterates_each_cap() { + let mut user_data = vec![0u8; 8]; + user_data.extend_from_slice(&[0x01, 0x00, 0x08, 0x00, 0xaa, 0xbb, 0xcc, 0xdd]); + user_data.extend_from_slice(&[ + 0x03, 0x00, 0x0c, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + ]); + let caps: Vec<_> = walk_caps(&user_data, 8).collect(); + assert_eq!(caps.len(), 2); + assert_eq!(caps[0].cap_type, 0x0001); + assert_eq!(caps[0].cap_len, 8); + assert_eq!(caps[0].body_offset_in_user_data, 12); + assert_eq!(caps[1].cap_type, 0x0003); + assert_eq!(caps[1].cap_len, 12); + assert_eq!(caps[1].body_offset_in_user_data, 20); + } + + #[test] + fn walk_caps_stops_on_malformed_header() { + let mut user_data = vec![0u8; 4]; + user_data.extend_from_slice(&[0x01, 0x00, 0x64, 0x00]); + let caps: Vec<_> = walk_caps(&user_data, 4).collect(); + assert_eq!(caps.len(), 0); + } + + #[test] + fn walk_caps_stops_on_cap_len_below_header_size() { + let user_data = vec![0x01, 0x00, 0x02, 0x00]; + let caps: Vec<_> = walk_caps(&user_data, 0).collect(); + assert_eq!(caps.len(), 0); + } + + /// End-to-end byte-preservation contract: walk a synthetic caps array + /// containing Order, BitmapCodecs, and an unrelated cap; patch only + /// the targeted fields; assert every other byte is identical. + #[test] + fn walk_and_patch_preserves_unrelated_bytes() { + let mut buf: Vec = Vec::new(); + + // Cap 1: unrelated cap_type=0x0001, len=8, body filled with 0x77 + buf.extend_from_slice(&[0x01, 0x00, 0x08, 0x00]); + buf.extend_from_slice(&[0x77; 4]); + let unrelated_range = 0..buf.len(); + + // Cap 2: Order (0x0003), full body of 0xFF + 4-byte header + let order_header_offset = buf.len(); + let order_total_len = (order_cap::BODY_LEN + 4) as u16; + buf.extend_from_slice(&[0x03, 0x00]); + buf.extend_from_slice(&order_total_len.to_le_bytes()); + let order_body_offset = buf.len(); + buf.extend_from_slice(&vec![0xFF; order_cap::BODY_LEN]); + + // Cap 3: BitmapCodecs (0x001d), 4-byte header + body of 0xEE + let codecs_header_offset = buf.len(); + let codecs_body_len = 16usize; + buf.extend_from_slice(&[0x1D, 0x00]); + buf.extend_from_slice(&((codecs_body_len + 4) as u16).to_le_bytes()); + let codecs_body_offset = buf.len(); + buf.extend_from_slice(&vec![0xEE; codecs_body_len]); + + // Cap 4: trailing unrelated cap (filter must not stop early or read past it) + let trailing_offset = buf.len(); + buf.extend_from_slice(&[0x02, 0x00, 0x06, 0x00, 0x55, 0x55]); + + let original = buf.clone(); + + let caps: Vec<_> = walk_caps(&buf, 0).collect(); + assert_eq!(caps.len(), 4); + assert_eq!(caps[0].body_offset_in_user_data, order_header_offset - 4); + assert_eq!(caps[1].cap_type, cap_types::ORDER); + assert_eq!(caps[1].body_offset_in_user_data, order_body_offset); + assert_eq!(caps[2].cap_type, cap_types::BITMAP_CODECS); + assert_eq!(caps[2].body_offset_in_user_data, codecs_body_offset); + assert_eq!(caps[3].body_offset_in_user_data, trailing_offset + 4); + + order_cap::clear_order_support( + &mut buf[order_body_offset..order_body_offset + order_cap::BODY_LEN], + ); + bitmap_codecs_cap::clear_codec_count(&mut buf[codecs_body_offset..]); + + // Unrelated cap: byte-identical + assert_eq!(&buf[unrelated_range.clone()], &original[unrelated_range]); + // Order cap: header preserved, only ORDER_SUPPORT range zeroed + assert_eq!( + &buf[order_header_offset..order_body_offset], + &original[order_header_offset..order_body_offset] + ); + let zeroed_start = order_body_offset + order_cap::ORDER_SUPPORT.start; + let zeroed_end = order_body_offset + order_cap::ORDER_SUPPORT.end; + assert_eq!( + &buf[order_body_offset..zeroed_start], + &original[order_body_offset..zeroed_start] + ); + assert_eq!(&buf[zeroed_start..zeroed_end], &[0u8; 32]); + assert_eq!( + &buf[zeroed_end..codecs_header_offset], + &original[zeroed_end..codecs_header_offset] + ); + // BitmapCodecs cap: header preserved, only first body byte zeroed + assert_eq!( + &buf[codecs_header_offset..codecs_body_offset], + &original[codecs_header_offset..codecs_body_offset] + ); + assert_eq!(buf[codecs_body_offset], 0); + assert_eq!( + &buf[codecs_body_offset + 1..trailing_offset], + &original[codecs_body_offset + 1..trailing_offset] + ); + // Trailing cap: byte-identical + assert_eq!(&buf[trailing_offset..], &original[trailing_offset..]); + } +} diff --git a/packages/pam/handlers/rdp/native/src/config.rs b/packages/pam/handlers/rdp/native/src/config.rs index b1f9a77a..f7588e4b 100644 --- a/packages/pam/handlers/rdp/native/src/config.rs +++ b/packages/pam/handlers/rdp/native/src/config.rs @@ -17,10 +17,8 @@ pub fn connector_config(username: String, password: String) -> Config { }, desktop_scale_factor: 0, - // Advertise HYBRID_EX|HYBRID|SSL to match what native clients send. - // Windows App validates the target's echoed clientRequestedProtocols - // against what it sent on the acceptor side; if the sets diverge it - // disconnects right after Connect Response. + // Match native client's HYBRID_EX|HYBRID|SSL set; Windows App validates the + // target echo against what it sent and disconnects on divergence. enable_tls: true, enable_credssp: true, diff --git a/packages/pam/handlers/rdp/native/src/events.rs b/packages/pam/handlers/rdp/native/src/events.rs new file mode 100644 index 00000000..ffb10fd3 --- /dev/null +++ b/packages/pam/handlers/rdp/native/src/events.rs @@ -0,0 +1,46 @@ +//! Bridge tap events. Input is FastPath-decoded c2t; TargetFrame is raw t2c +//! PDU bytes (decoded at replay time in the browser). + +use std::time::Instant; + +use ironrdp_pdu::input::fast_path::KeyboardFlags; +use ironrdp_pdu::input::mouse::PointerFlags; +use ironrdp_pdu::Action; +use tokio::sync::mpsc; + +#[derive(Debug, Clone)] +pub enum SessionEvent { + KeyboardInput { + scancode: u8, + flags: KeyboardFlags, + elapsed_ns: u64, + }, + UnicodeInput { + code_point: u16, + flags: KeyboardFlags, + elapsed_ns: u64, + }, + MouseInput { + x: u16, + y: u16, + flags: PointerFlags, + wheel_delta: i16, + elapsed_ns: u64, + }, + TargetFrame { + action: Action, + payload: Vec, + elapsed_ns: u64, + }, +} + +pub fn elapsed_ns_since(started_at: Instant) -> u64 { + started_at.elapsed().as_nanos() as u64 +} + +pub type EventSender = mpsc::UnboundedSender; +pub type EventReceiver = mpsc::UnboundedReceiver; + +pub fn channel() -> (EventSender, EventReceiver) { + mpsc::unbounded_channel() +} diff --git a/packages/pam/handlers/rdp/native/src/ffi.rs b/packages/pam/handlers/rdp/native/src/ffi.rs index ecef7782..d178bfaa 100644 --- a/packages/pam/handlers/rdp/native/src/ffi.rs +++ b/packages/pam/handlers/rdp/native/src/ffi.rs @@ -1,8 +1,5 @@ -//! C ABI for the bridge. Called from Go via CGo. -//! -//! Each session runs on its own OS thread with a current-thread tokio -//! runtime. `start_*` transfers ownership of the client fd/socket to -//! Rust (Go hands in a dup). Contract: wait, then free. +//! C ABI for the bridge. Each session runs on its own thread with a +//! current-thread tokio runtime. Caller contract: wait, then free. use std::collections::HashMap; use std::ffi::{c_char, CStr}; @@ -10,12 +7,15 @@ use std::net::TcpStream as StdTcpStream; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{LazyLock, Mutex}; use std::thread::JoinHandle; +use std::time::Duration; use tokio::net::TcpStream; +use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tracing::{error, info}; use crate::bridge::{run_mitm, TargetEndpoint}; +use crate::events::{self, SessionEvent}; pub const RDP_BRIDGE_OK: i32 = 0; pub const RDP_BRIDGE_SESSION_ERROR: i32 = 1; @@ -24,10 +24,146 @@ pub const RDP_BRIDGE_INVALID_HANDLE: i32 = -1; pub const RDP_BRIDGE_BAD_ARG: i32 = -2; pub const RDP_BRIDGE_RUNTIME_ERROR: i32 = -3; +// Distinct number space from the bridge status codes above; consumed by +// a different Go function. +pub const RDP_POLL_OK: i32 = 0; +pub const RDP_POLL_TIMEOUT: i32 = 1; +pub const RDP_POLL_ENDED: i32 = 2; +pub const RDP_POLL_INVALID_HANDLE: i32 = -1; + +#[repr(u8)] +pub enum RdpEventType { + Keyboard = 1, + Unicode = 2, + Mouse = 3, + TargetFrame = 4, +} + +/// Fields are reused across variants; check `event_type` first. +/// For TargetFrame, `payload_ptr` is libc::malloc'd; Go must libc::free it. +#[repr(C)] +pub struct RdpEvent { + pub event_type: u8, + /// Nanoseconds since bridge start. + pub elapsed_ns: u64, + /// Keyboard: scancode. Unicode: code point. Mouse: x. TargetFrame: bytes. + pub value_a: u32, + /// Mouse: y. Others: 0. + pub value_b: u32, + /// Keyboard / Unicode / Mouse flags (raw bits from the RDP layer). + pub flags: u32, + /// Mouse wheel delta (signed). 0 for others. + pub wheel_delta: i32, + /// TargetFrame: 0 = X.224, 1 = FastPath. 0 for others. + pub action: u8, + pub payload_ptr: *mut u8, + pub payload_len: u32, +} + +impl RdpEvent { + const fn zero() -> Self { + Self { + event_type: 0, + elapsed_ns: 0, + value_a: 0, + value_b: 0, + flags: 0, + wheel_delta: 0, + action: 0, + payload_ptr: std::ptr::null_mut(), + payload_len: 0, + } + } + + fn from_session_event(ev: SessionEvent) -> Self { + match ev { + SessionEvent::KeyboardInput { + scancode, + flags, + elapsed_ns, + } => Self { + event_type: RdpEventType::Keyboard as u8, + elapsed_ns, + value_a: scancode.into(), + flags: flags.bits().into(), + ..Self::zero() + }, + SessionEvent::UnicodeInput { + code_point, + flags, + elapsed_ns, + } => Self { + event_type: RdpEventType::Unicode as u8, + elapsed_ns, + value_a: code_point.into(), + flags: flags.bits().into(), + ..Self::zero() + }, + SessionEvent::MouseInput { + x, + y, + flags, + wheel_delta, + elapsed_ns, + } => Self { + event_type: RdpEventType::Mouse as u8, + elapsed_ns, + value_a: x.into(), + value_b: y.into(), + flags: flags.bits().into(), + wheel_delta: wheel_delta.into(), + ..Self::zero() + }, + SessionEvent::TargetFrame { + action, + payload, + elapsed_ns, + } => { + // Copy into a libc::malloc'd buffer the Go caller will free. + // Using libc (not Rust's allocator) lets Go free directly via + // C.free without an extra trip back through the FFI. + let len = payload.len(); + let ptr = if len == 0 { + std::ptr::null_mut() + } else { + unsafe { + let p = libc::malloc(len) as *mut u8; + if p.is_null() { + std::ptr::null_mut() + } else { + std::ptr::copy_nonoverlapping(payload.as_ptr(), p, len); + p + } + } + }; + Self { + event_type: RdpEventType::TargetFrame as u8, + elapsed_ns, + value_a: len as u32, + action: match action { + ironrdp_pdu::Action::X224 => 0, + ironrdp_pdu::Action::FastPath => 1, + }, + payload_ptr: ptr, + payload_len: len as u32, + ..Self::zero() + } + } + } + } +} + struct BridgeEntry { cancel: CancellationToken, // Taken by wait(); None afterward. join: Mutex>>>, + // Receiver side of the bridge's event channel. Polled by Go via + // rdp_bridge_poll_event. Wrapped in Option so the poll loop can take it + // out for the duration of the await without holding the HANDLES lock. + events_rx: Mutex>>, + // Set once the events channel has reported closed; subsequent polls + // short-circuit to RDP_POLL_ENDED. + events_ended: Mutex, } static HANDLES: LazyLock>> = @@ -64,6 +200,8 @@ fn spawn_session( let cancel = CancellationToken::new(); let cancel_for_thread = cancel.clone(); + let (events_tx, events_rx) = events::channel(); + let join = std::thread::Builder::new() .name("rdp-bridge-session".to_owned()) .spawn(move || -> anyhow::Result<()> { @@ -78,13 +216,15 @@ fn spawn_session( username, password, }; - run_mitm(client, endpoint, cancel_for_thread).await + run_mitm(client, endpoint, cancel_for_thread, events_tx).await }) })?; Ok(register(BridgeEntry { cancel, join: Mutex::new(Some(join)), + events_rx: Mutex::new(Some(events_rx)), + events_ended: Mutex::new(false), })) } @@ -227,3 +367,71 @@ pub extern "C" fn rdp_bridge_free(handle: u64) -> i32 { RDP_BRIDGE_INVALID_HANDLE } } + +/// Poll the next event, blocking up to `timeout_ms` ms. On RDP_POLL_OK, +/// caller owns *payload_ptr (must libc::free). +#[no_mangle] +pub unsafe extern "C" fn rdp_bridge_poll_event( + handle: u64, + out: *mut RdpEvent, + timeout_ms: u32, +) -> i32 { + if out.is_null() { + return RDP_POLL_INVALID_HANDLE; + } + + // Avoid holding the HANDLES lock across the await. + let take_result: Result>, i32> = { + let handles = HANDLES.lock().expect("HANDLES poisoned"); + match handles.get(&handle) { + None => Err(RDP_POLL_INVALID_HANDLE), + Some(entry) => { + if *entry.events_ended.lock().expect("events_ended poisoned") { + Err(RDP_POLL_ENDED) + } else { + Ok(entry.events_rx.lock().expect("events_rx poisoned").take()) + } + } + } + }; + let mut rx = match take_result { + Ok(Some(rx)) => rx, + // Concurrent poll on the same handle; callers must serialize. + Ok(None) => return RDP_POLL_INVALID_HANDLE, + Err(code) => return code, + }; + + // Short-lived single-thread runtime just for the timeout. Cheap; the + // bridge thread runs its own runtime. + let result = { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .expect("build poll runtime"); + rt.block_on(async { + tokio::time::timeout(Duration::from_millis(timeout_ms.into()), rx.recv()).await + }) + }; + + let outcome = match result { + Ok(Some(event)) => { + let rdp_event = RdpEvent::from_session_event(event); + unsafe { out.write(rdp_event) }; + RDP_POLL_OK + } + Ok(None) => RDP_POLL_ENDED, // sender side dropped (bridge ended) + Err(_timeout) => RDP_POLL_TIMEOUT, + }; + + // Restore the receiver, or mark ended if the channel reported closed. + let handles = HANDLES.lock().expect("HANDLES poisoned"); + if let Some(entry) = handles.get(&handle) { + if outcome == RDP_POLL_ENDED { + *entry.events_ended.lock().expect("events_ended poisoned") = true; + } else { + *entry.events_rx.lock().expect("events_rx poisoned") = Some(rx); + } + } + + outcome +} diff --git a/packages/pam/handlers/rdp/native/src/lib.rs b/packages/pam/handlers/rdp/native/src/lib.rs index 61c64480..abb6f0bd 100644 --- a/packages/pam/handlers/rdp/native/src/lib.rs +++ b/packages/pam/handlers/rdp/native/src/lib.rs @@ -3,5 +3,7 @@ //! passes bytes through. pub mod bridge; +pub mod cap_filter; pub mod config; +pub mod events; pub mod ffi; diff --git a/packages/pam/handlers/rdp/proxy.go b/packages/pam/handlers/rdp/proxy.go index e113902a..b5220f60 100644 --- a/packages/pam/handlers/rdp/proxy.go +++ b/packages/pam/handlers/rdp/proxy.go @@ -1,6 +1,13 @@ package rdp import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/rs/zerolog/log" + "github.com/Infisical/infisical-merge/packages/pam/session" ) @@ -10,9 +17,11 @@ type RDPProxyConfig struct { InjectUsername string InjectPassword string SessionID string - // Retained for API symmetry with other PAM handlers; not yet written - // through (no RDP session recording in this MVP). - SessionLogger session.SessionLogger + SessionLogger session.SessionLogger + // Session-anchored origin for elapsedNs. The Rust bridge restarts its + // own clock per RDP client connection; we rewrite each event's elapsedNs + // against this anchor so timestamps stay monotonic across reconnects. + SessionStartedAt time.Time } type RDPProxy struct { @@ -22,3 +31,122 @@ type RDPProxy struct { func NewRDPProxy(config RDPProxyConfig) *RDPProxy { return &RDPProxy{config: config} } + +// Wire envelopes carried inside TerminalEvent.Data for ChannelType=RDP. +type rdpTargetFrameEnvelope struct { + Type string `json:"type"` // "target_frame" + Action string `json:"action"` // "x224" | "fastpath" + Payload []byte `json:"payload"` // raw PDU bytes (base64 by Go's json.Marshal) + ElapsedNs uint64 `json:"elapsedNs"` +} + +type rdpKeyboardEnvelope struct { + Type string `json:"type"` // "keyboard" + Scancode uint8 `json:"scancode"` + Flags uint32 `json:"flags"` + ElapsedNs uint64 `json:"elapsedNs"` +} + +type rdpUnicodeEnvelope struct { + Type string `json:"type"` // "unicode" + CodePoint uint16 `json:"codePoint"` + Flags uint32 `json:"flags"` + ElapsedNs uint64 `json:"elapsedNs"` +} + +type rdpMouseEnvelope struct { + Type string `json:"type"` // "mouse" + X uint16 `json:"x"` + Y uint16 `json:"y"` + Flags uint32 `json:"flags"` + WheelDelta int32 `json:"wheelDelta"` + ElapsedNs uint64 `json:"elapsedNs"` +} + +// Bounds bridge poll latency so Cancel ends the drain loop promptly. +const pollTimeout = 250 * time.Millisecond + +var errUnknownRdpEventType = errors.New("rdp: unknown event type") + +// Logger errors are warned but don't stop the drain; dropping one event is +// better than back-pressuring the bridge byte stream. +func drainBridgeEvents(ctx context.Context, b *Bridge, logger session.SessionLogger, sessionID string, sessionStartedAt time.Time) { + if logger == nil { + return + } + for { + if ctx.Err() != nil { + return + } + result, ev, err := b.PollEvent(pollTimeout) + if err != nil { + log.Debug().Err(err).Str("sessionID", sessionID).Msg("rdp event drain stopped") + return + } + switch result { + case PollEnded: + return + case PollTimeout: + continue + case PollOK: + if !sessionStartedAt.IsZero() { + ev.ElapsedNs = uint64(time.Since(sessionStartedAt).Nanoseconds()) + } + data, encErr := encodeRdpEvent(ev) + if encErr != nil { + log.Warn().Err(encErr).Str("sessionID", sessionID).Uint8("type", uint8(ev.Type)).Msg("encode RDP event") + continue + } + te := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventRDP, + ChannelType: session.TerminalChannelRDP, + Data: data, + ElapsedTime: float64(ev.ElapsedNs) / 1e9, + } + if logErr := logger.LogTerminalEvent(te); logErr != nil { + log.Warn().Err(logErr).Str("sessionID", sessionID).Msg("log RDP event") + } + } + } +} + +func encodeRdpEvent(ev Event) ([]byte, error) { + switch ev.Type { + case EventTypeTargetFrame: + action := "x224" + if ev.Action == ActionFastPath { + action = "fastpath" + } + return json.Marshal(rdpTargetFrameEnvelope{ + Type: "target_frame", + Action: action, + Payload: ev.Payload, + ElapsedNs: ev.ElapsedNs, + }) + case EventTypeKeyboard: + return json.Marshal(rdpKeyboardEnvelope{ + Type: "keyboard", + Scancode: ev.Scancode, + Flags: ev.Flags, + ElapsedNs: ev.ElapsedNs, + }) + case EventTypeUnicode: + return json.Marshal(rdpUnicodeEnvelope{ + Type: "unicode", + CodePoint: ev.CodePoint, + Flags: ev.Flags, + ElapsedNs: ev.ElapsedNs, + }) + case EventTypeMouse: + return json.Marshal(rdpMouseEnvelope{ + Type: "mouse", + X: ev.X, + Y: ev.Y, + Flags: ev.Flags, + WheelDelta: ev.WheelDelta, + ElapsedNs: ev.ElapsedNs, + }) + } + return nil, errUnknownRdpEventType +} diff --git a/packages/pam/local/rdp-proxy.go b/packages/pam/local/rdp-proxy.go index af3b43ef..68760d25 100644 --- a/packages/pam/local/rdp-proxy.go +++ b/packages/pam/local/rdp-proxy.go @@ -18,22 +18,15 @@ import ( "github.com/rs/zerolog/log" ) -// RDPProxyServer exposes a local loopback TCP listener that tunnels bytes -// to the gateway's RDP MITM bridge via the existing mTLS + SSH relay. The -// user's RDP client connects to the loopback port; the gateway takes care -// of credential injection and forwarding to the Windows target. +// Loopback listener that tunnels RDP client traffic to the gateway's MITM bridge. type RDPProxyServer struct { BaseProxyServer server net.Listener port int - rdpFilePath string // path to the generated .rdp file, if any + rdpFilePath string } -// StartRDPLocalProxy is the CLI entry point for `infisical pam rdp access`. -// It creates a PAM session with the backend, binds a loopback listener, -// writes a .rdp file pointing at that loopback, optionally launches the -// user's default RDP client, and forwards accepted connections to the -// gateway. +// CLI entry point for `infisical pam rdp access`. func StartRDPLocalProxy(accessToken string, accessParams PAMAccessParams, projectID string, durationStr string, port int, noLaunch bool) { log.Info().Msgf("Starting RDP proxy for account: %s", accessParams.GetDisplayName()) log.Info().Msgf("Session duration: %s", durationStr) @@ -164,10 +157,8 @@ func (p *RDPProxyServer) gracefulShutdown() { p.shutdownOnce.Do(func() { log.Info().Msg("Starting graceful shutdown of RDP proxy...") - // Remove the .rdp file first: p.cancel() below unblocks Run(), - // which returns to main, which may exit before the rest of this - // goroutine completes. Do the cleanup that has to happen before - // anything that could let main race ahead. + // p.cancel() below can return main before this goroutine finishes; + // remove the .rdp file before risking that race. if p.rdpFilePath != "" { if err := os.Remove(p.rdpFilePath); err != nil && !os.IsNotExist(err) { log.Debug().Err(err).Str("path", p.rdpFilePath).Msg("Failed to remove .rdp file on exit") @@ -308,15 +299,8 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { log.Info().Msgf("RDP connection closed for client: %s", clientConn.RemoteAddr().String()) } -// writeRDPFile creates a .rdp file pointing at the local loopback -// listener. Files live under `~/.infisical/rdp/` to match the CLI's -// existing convention for per-user state (alongside the login config -// and update-check cache). Filename includes the session ID so -// concurrent sessions don't collide. The file is removed on graceful -// shutdown (see gracefulShutdown) since the embedded loopback port -// becomes invalid as soon as the CLI exits; reopening the file later -// would just dial a dead port. -// Falls back to the OS temp dir if the home directory can't be resolved. +// Generates a per-session .rdp file under ~/.infisical/rdp/ pointing at +// the loopback listener. Removed on graceful shutdown. func writeRDPFile(listenPort int, sessionID, username string) (string, error) { filename := fmt.Sprintf("infisical-rdp-%s.rdp", sessionID) diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 0cd6c29e..567e08e2 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -417,13 +417,17 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo if credentials.Port <= 0 || credentials.Port > 65535 { return fmt.Errorf("rdp: target port %d out of range", credentials.Port) } + // Anchor event timestamps to the session-level start so reconnects + // within the same PAM session don't restart elapsedNs from zero. + sessionStartedAt, _ := pamConfig.SessionUploader.GetSessionStartedAt(pamConfig.SessionId) rdpConfig := rdp.RDPProxyConfig{ - TargetHost: credentials.Host, - TargetPort: uint16(credentials.Port), - InjectUsername: credentials.Username, - InjectPassword: credentials.Password, - SessionID: pamConfig.SessionId, - SessionLogger: sessionLogger, + TargetHost: credentials.Host, + TargetPort: uint16(credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + SessionStartedAt: sessionStartedAt, } proxy := rdp.NewRDPProxy(rdpConfig) log.Info(). diff --git a/packages/pam/session/logger.go b/packages/pam/session/logger.go index 77c3c3e3..cfddd621 100644 --- a/packages/pam/session/logger.go +++ b/packages/pam/session/logger.go @@ -31,6 +31,7 @@ type TerminalEventType string const ( TerminalEventInput TerminalEventType = "input" // Data from user to server TerminalEventOutput TerminalEventType = "output" // Data from server to user + TerminalEventRDP TerminalEventType = "rdp" // RDP tap event (see TerminalChannelRDP) ) // TerminalChannelType represents the type of SSH channel @@ -40,6 +41,7 @@ const ( TerminalChannelShell TerminalChannelType = "terminal" // Interactive shell session TerminalChannelExec TerminalChannelType = "exec" // Single command execution TerminalChannelSFTP TerminalChannelType = "sftp" // SFTP file transfer + TerminalChannelRDP TerminalChannelType = "rdp" // RDP frame/input tap; Data carries an RDP-specific JSON envelope ) // TerminalEvent represents a single event in a terminal session @@ -305,7 +307,14 @@ func (sl *EncryptedSessionLogger) LogTerminalEvent(event TerminalEvent) error { if event.ElapsedTime == 0 { event.ElapsedTime = time.Since(sl.sessionStart).Seconds() } - event.Data = sl.applyMasking(event.Data) + // RDP carries a structured JSON envelope (with base64-encoded PDU + // bytes, scancodes, etc.) in Data, not free-form terminal text. + // Masking patterns are SSH-shaped regexes; running them over the + // envelope would corrupt valid recordings whenever a pattern + // happened to match a substring of the JSON or base64. + if event.ChannelType != TerminalChannelRDP { + event.Data = sl.applyMasking(event.Data) + } return json.Marshal(event) }) } diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index 6f43781c..5d016f72 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -273,27 +273,27 @@ func deletePersistedOffset(filename string) { _ = os.Remove(offsetFilePath(filename)) } -// readFromOffset reads length-prefixed encrypted records from filename starting at offset, -// decrypts each, and returns them as a JSON array payload plus the new file offset. -// When maxPayloadBytes > 0, stops accumulating once the next entry would push the serialized JSON array past that limit -// Returns nil payload (and the unchanged offset) if there are no new records. -func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadBytes int) ([]byte, int64, error) { +// Returns (payload JSON array, new offset, last entry's elapsedMs, err). +// lastEntryElapsedMs is 0 if entries lack the field. maxPayloadBytes>0 +// caps the JSON array size; caller loops for the rest. +func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadBytes int) ([]byte, int64, int64, error) { recordingDir := GetSessionRecordingDir() fullPath := filepath.Join(recordingDir, filename) file, err := os.Open(fullPath) if err != nil { - return nil, offset, fmt.Errorf("failed to open session file: %w", err) + return nil, offset, 0, fmt.Errorf("failed to open session file: %w", err) } defer file.Close() if _, err := file.Seek(offset, io.SeekStart); err != nil { - return nil, offset, fmt.Errorf("failed to seek to offset %d: %w", offset, err) + return nil, offset, 0, fmt.Errorf("failed to seek to offset %d: %w", offset, err) } var entries []json.RawMessage newOffset := offset runningSize := 2 // account for JSON array brackets [] + var lastEntryElapsedMs int64 for { lengthBytes := make([]byte, 4) @@ -301,7 +301,7 @@ func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadByte if err == io.EOF || err == io.ErrUnexpectedEOF { break // No more complete records } - return nil, newOffset, fmt.Errorf("failed to read length prefix: %w", err) + return nil, newOffset, 0, fmt.Errorf("failed to read length prefix: %w", err) } length := binary.BigEndian.Uint32(lengthBytes) @@ -312,7 +312,7 @@ func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadByte decryptedData, err := DecryptData(encryptedData, encryptionKey) if err != nil { - return nil, newOffset, fmt.Errorf("failed to decrypt record at offset %d: %w", newOffset, err) + return nil, newOffset, 0, fmt.Errorf("failed to decrypt record at offset %d: %w", newOffset, err) } entrySize := len(decryptedData) @@ -323,21 +323,40 @@ func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadByte break // would exceed budget; caller will loop for the rest } + // Probe the entry's elapsedTime field. Absent on non-terminal events. + var probe struct { + ElapsedTime float64 `json:"elapsedTime"` + } + if jsonErr := json.Unmarshal(decryptedData, &probe); jsonErr == nil && probe.ElapsedTime > 0 { + lastEntryElapsedMs = int64(probe.ElapsedTime * 1000) + } + entries = append(entries, json.RawMessage(decryptedData)) newOffset += int64(4 + length) runningSize += entrySize } if len(entries) == 0 { - return nil, newOffset, nil + return nil, newOffset, 0, nil } payload, err := json.Marshal(entries) if err != nil { - return nil, newOffset, fmt.Errorf("failed to marshal event batch: %w", err) + return nil, newOffset, 0, fmt.Errorf("failed to marshal event batch: %w", err) } - return payload, newOffset, nil + return payload, newOffset, lastEntryElapsedMs, nil +} + +// Stable across gateway restarts and per-connection bridge restarts. +func (su *SessionUploader) GetSessionStartedAt(sessionID string) (time.Time, bool) { + su.activeSessionsMu.RLock() + defer su.activeSessionsMu.RUnlock() + state, ok := su.activeSessions[sessionID] + if !ok { + return time.Time{}, false + } + return state.startedAt, true } // RegisterSession registers a session for incremental batch uploads, resuming from @@ -415,12 +434,8 @@ func (su *SessionUploader) startUploadRoutine() { }() } -// resumeInProgressSessions re-registers non-expired recording files into the upload loop at startup. -// A gateway restart kills all proxy connections, so any file on disk is from a session that is -// already over from the customer's perspective. Re-registering restores offset tracking so the -// ticker-based flush and chunk reconciliation can drive uploads to completion over subsequent ticks. -// Already-expired files are skipped here and handled exclusively by uploadExpiredSessionFiles -// to avoid duplicate back-to-back cleanup attempts on the same file at startup. +// Re-registers non-expired recording files at startup so the flush ticker +// can drain them. Expired files are handled by uploadExpiredSessionFiles. func (su *SessionUploader) resumeInProgressSessions() { allFiles, err := ListSessionFiles() if err != nil { @@ -494,10 +509,7 @@ func (su *SessionUploader) flushActiveSessions() { } } -// flushSession reads new events from the session recording file since the last uploaded offset, -// uploads them as a batch, and advances the offset on success. Returns nil when there is nothing -// to do (session not registered, already in legacy mode, no new events) or when a 404 cleanly -// transitions the session to legacy mode; the caller treats those as success. +// Uploads new events as a batch and advances the offset on success. func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { su.activeSessionsMu.RLock() state, ok := su.activeSessions[sessionID] @@ -518,7 +530,7 @@ func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { currentOffset := state.fileOffset for { - payload, newOffset, err := readFromOffset(state.filename, encryptionKey, currentOffset, pamRecordingMaxPlaintextBytes) + payload, newOffset, lastEntryElapsedMs, err := readFromOffset(state.filename, encryptionKey, currentOffset, pamRecordingMaxPlaintextBytes) if err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to read session events for chunk upload") break @@ -527,7 +539,12 @@ func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { break } - endElapsedMs := time.Since(state.startedAt).Milliseconds() + // Prefer the last event's actual elapsedTime; fall back to wallclock for + // non-terminal sessions whose entries lack the field (HTTP, Kubernetes). + endElapsedMs := lastEntryElapsedMs + if endElapsedMs <= startElapsedMs { + endElapsedMs = time.Since(state.startedAt).Milliseconds() + } pc, encErr := su.chunkUploader.EncryptAndQueueChunk(sessionID, payload, startElapsedMs, endElapsedMs) if encErr != nil { @@ -551,7 +568,7 @@ func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { return nil } - payload, newOffset, err := readFromOffset(state.filename, encryptionKey, state.fileOffset, 0) + payload, newOffset, _, err := readFromOffset(state.filename, encryptionKey, state.fileOffset, 0) if err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to read session events for batch upload") return err @@ -700,10 +717,8 @@ func (su *SessionUploader) CleanupPAMSession(sessionID string, reason string) er su.RegisterSession(sessionID) } - // Final flush: upload any remaining events before we delete the file. Any failure on this path - // (key fetch, batch flush, or legacy bulk upload) returns early with the recording file, registry - // entry, and persisted offset intact so uploadExpiredSessionFiles can retry once the file crosses - // ExpiresAt. Deleting on failure would lose unuploaded events unrecoverably. + // On any failure here, return early so uploadExpiredSessionFiles can retry + // past ExpiresAt; deleting the file on failure would lose events. encryptionKey, err := su.credentialsManager.GetPAMSessionEncryptionKey() if err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Could not get encryption key for final flush, keeping recording file for retry") @@ -714,8 +729,7 @@ func (su *SessionUploader) CleanupPAMSession(sessionID string, reason string) er return flushErr } - // If the batch endpoint was not supported (or this session was already in legacy mode), - // fall back to a single bulk upload of the whole file. + // Legacy fallback: single bulk upload of the whole file. su.activeSessionsMu.RLock() state, stateExists := su.activeSessions[sessionID] su.activeSessionsMu.RUnlock()