diff --git a/crates/winscard/src/lib.rs b/crates/winscard/src/lib.rs index dbe5f931..36348937 100644 --- a/crates/winscard/src/lib.rs +++ b/crates/winscard/src/lib.rs @@ -184,6 +184,13 @@ impl From for Error { } } +#[cfg(feature = "std")] +impl From for Error { + fn from(value: std::string::FromUtf16Error) -> Self { + Error::new(ErrorKind::InvalidParameter, value.to_string()) + } +} + #[cfg(feature = "std")] impl From for Error { fn from(value: std::ffi::NulError) -> Self { diff --git a/ffi/src/sspi/credentials_attributes.rs b/ffi/src/sspi/credentials_attributes.rs index 5ea6c154..7e90de32 100644 --- a/ffi/src/sspi/credentials_attributes.rs +++ b/ffi/src/sspi/credentials_attributes.rs @@ -108,9 +108,10 @@ pub unsafe fn extract_kdc_proxy_settings(p_buffer: NonNull) -> Result()) - }); + }) + .map_err(Error::from)?; let client_tls_cred = if *client_tls_cred_offset != 0 && *client_tls_cred_length != 0 { // SAFETY: @@ -129,7 +130,10 @@ pub unsafe fn extract_kdc_proxy_settings(p_buffer: NonNull) -> Result Result(); - // SAFETY: - // - `p_buffer` is guaranteed to be non-null due to the prior check. - // - The memory region `p_buffer` contains a valid null-terminator at the end of string. - // - The memory region `p_buffer` points to is valid for reads of bytes up to and including null-terminator. - let kdc_url = unsafe { c_w_str_to_string((*cred_attr).kdc_url.cast_const()) }; + let kdc_url = try_execute!( + // SAFETY: + // - `p_buffer` is guaranteed to be non-null due to the prior check. + // - The memory region `p_buffer` contains a valid null-terminator at the end of string. + // - The memory region `p_buffer` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string((*cred_attr).kdc_url.cast_const()) }.map_err(Error::from) + ); credentials_handle.attributes.kdc_url = Some(kdc_url); 0 @@ -1545,32 +1553,42 @@ pub unsafe extern "system" fn ChangeAccountPasswordW( check_null!(psz_new_password); check_null!(p_output); - // SAFETY: - // - `psz_package_name` is guaranteed to be non-null due to the prior check. - // - The memory region `psz_package_name` contains a valid null-terminator at the end of string. - // - The memory region `psz_package_name` points to is valid for reads of bytes up to and including null-terminator. - let mut security_package_name = unsafe { c_w_str_to_string(psz_package_name) }; + let mut security_package_name = try_execute!( + // SAFETY: + // - `psz_package_name` is guaranteed to be non-null due to the prior check. + // - The memory region `psz_package_name` contains a valid null-terminator at the end of string. + // - The memory region `psz_package_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(psz_package_name) }.map_err(Error::from) + ); - // SAFETY: - // - `psz_domain_name` is guaranteed to be non-null due to the prior check. - // - The memory region `psz_domain_name` contains a valid null-terminator at the end of string. - // - The memory region `psz_domain_name` points to is valid for reads of bytes up to and including null-terminator. - let mut domain = unsafe { c_w_str_to_string(psz_domain_name) }; - // SAFETY: - // - `psz_account_name` is guaranteed to be non-null due to the prior check. - // - The memory region `psz_account_name` contains a valid null-terminator at the end of string. - // - The memory region `psz_account_name` points to is valid for reads of bytes up to and including null-terminator. - let mut username = unsafe { c_w_str_to_string(psz_account_name) }; - // SAFETY: - // - `psz_old_password` is guaranteed to be non-null due to the prior check. - // - The memory region `psz_old_password` contains a valid null-terminator at the end of string. - // - The memory region `psz_old_password` points to is valid for reads of bytes up to and including null-terminator. - let mut password = Secret::new(unsafe { c_w_str_to_string(psz_old_password) }); - // SAFETY: - // - `psz_new_password` is guaranteed to be non-null due to the prior check. - // - The memory region `psz_new_password` contains a valid null-terminator at the end of string. - // - The memory region `psz_new_password` points to is valid for reads of bytes up to and including null-terminator. - let mut new_password = Secret::new(unsafe { c_w_str_to_string(psz_new_password) }); + let mut domain = try_execute!( + // SAFETY: + // - `psz_domain_name` is guaranteed to be non-null due to the prior check. + // - The memory region `psz_domain_name` contains a valid null-terminator at the end of string. + // - The memory region `psz_domain_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(psz_domain_name) }.map_err(Error::from) + ); + let mut username = try_execute!( + // SAFETY: + // - `psz_account_name` is guaranteed to be non-null due to the prior check. + // - The memory region `psz_account_name` contains a valid null-terminator at the end of string. + // - The memory region `psz_account_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(psz_account_name) }.map_err(Error::from) + ); + let mut password = Secret::new(try_execute!( + // SAFETY: + // - `psz_old_password` is guaranteed to be non-null due to the prior check. + // - The memory region `psz_old_password` contains a valid null-terminator at the end of string. + // - The memory region `psz_old_password` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(psz_old_password) }.map_err(Error::from) + )); + let mut new_password = Secret::new(try_execute!( + // SAFETY: + // - `psz_new_password` is guaranteed to be non-null due to the prior check. + // - The memory region `psz_new_password` contains a valid null-terminator at the end of string. + // - The memory region `psz_new_password` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(psz_new_password) }.map_err(Error::from) + )); // SAFETY: // * `security_package_name' is a `String`. diff --git a/ffi/src/sspi/sec_pkg_info.rs b/ffi/src/sspi/sec_pkg_info.rs index f23cd0e5..20f3d9b5 100644 --- a/ffi/src/sspi/sec_pkg_info.rs +++ b/ffi/src/sspi/sec_pkg_info.rs @@ -2,7 +2,7 @@ use std::ffi::CStr; use std::mem::size_of; use std::ptr::copy_nonoverlapping; -use sspi::{enumerate_security_packages, str_to_w_buff, PackageInfo, KERBEROS_VERSION}; +use sspi::{enumerate_security_packages, str_to_w_buff, Error, PackageInfo, KERBEROS_VERSION}; #[cfg(windows)] use symbol_rename_macro::rename_symbol; @@ -448,11 +448,13 @@ pub unsafe extern "system" fn QuerySecurityPackageInfoW( check_null!(p_package_name); check_null!(pp_package_info); - // SAFETY: - // - `p_package_name` is guaranteed to be non-null due to the prior check. - // - The memory region `p_package_name` contains a valid null-terminator at the end of string. - // - The memory region `p_package_name` points to is valid for reads of bytes up to and including null-terminator. - let pkg_name = unsafe { c_w_str_to_string(p_package_name) }; + let pkg_name = try_execute!( + // SAFETY: + // - `p_package_name` is guaranteed to be non-null due to the prior check. + // - The memory region `p_package_name` contains a valid null-terminator at the end of string. + // - The memory region `p_package_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(p_package_name) }.map_err(Error::from) + ); let pkg_info: RawSecPkgInfoW = try_execute!(enumerate_security_packages()) .into_iter() diff --git a/ffi/src/sspi/sec_winnt_auth_identity.rs b/ffi/src/sspi/sec_winnt_auth_identity.rs index 5351adff..682a041b 100644 --- a/ffi/src/sspi/sec_winnt_auth_identity.rs +++ b/ffi/src/sspi/sec_winnt_auth_identity.rs @@ -478,13 +478,18 @@ pub unsafe fn auth_data_to_identity_buffers_w( let auth_data = unsafe { auth_data.as_ref() }.expect("auth_data pointer should not be null"); if !auth_data.package_list.is_null() && auth_data.package_list_length > 0 { - // SAFETY: `package_list` is not null due to a prior check. - *package_list = Some(String::from_utf16_lossy(unsafe { - from_raw_parts( - auth_data.package_list, - usize::try_from(auth_data.package_list_length).unwrap(), + *package_list = Some( + String::from_utf16( + // SAFETY: `package_list` is not null due to a prior check. + unsafe { + from_raw_parts( + auth_data.package_list, + usize::try_from(auth_data.package_list_length).unwrap(), + ) + }, ) - })); + .map_err(Error::from)?, + ); } ( @@ -1263,18 +1268,21 @@ mod tests { assert_eq!( "user", - String::from_utf16_lossy(from_raw_parts((*identity).user, (*identity).user_length as usize)) + String::from_utf16(from_raw_parts((*identity).user, (*identity).user_length as usize)) + .expect("user is a correct utf-16 string") ); assert_eq!( "pass", - String::from_utf16_lossy(from_raw_parts( + String::from_utf16(from_raw_parts( (*identity).password, (*identity).password_length as usize )) + .expect("password is a correct utf-16 string") ); assert_eq!( "domain", - String::from_utf16_lossy(from_raw_parts((*identity).domain, (*identity).domain_length as usize)) + String::from_utf16(from_raw_parts((*identity).domain, (*identity).domain_length as usize)) + .expect("domain is a correct utf-16 string") ); let status = SspiFreeAuthIdentity(identity as *mut _); diff --git a/ffi/src/utils.rs b/ffi/src/utils.rs index d829f1c1..b7ee74e8 100644 --- a/ffi/src/utils.rs +++ b/ffi/src/utils.rs @@ -1,4 +1,5 @@ use std::slice::from_raw_parts; +use std::string::FromUtf16Error; use libc::c_char; @@ -13,7 +14,7 @@ pub(crate) fn into_raw_ptr(value: T) -> *mut T { /// Behavior is undefined is any of the following conditions are violated: /// /// * `s` must be a [valid], null-terminated C string. -pub(crate) unsafe fn c_w_str_to_string(s: *const u16) -> String { +pub(crate) unsafe fn c_w_str_to_string(s: *const u16) -> Result { let mut len = 0; // SAFETY: `s` is a valid, null-terminated C string. @@ -22,7 +23,7 @@ pub(crate) unsafe fn c_w_str_to_string(s: *const u16) -> String { } // SAFETY: `s` is a valid, null-terminated C string. - String::from_utf16_lossy(unsafe { from_raw_parts(s, len) }) + String::from_utf16(unsafe { from_raw_parts(s, len) }) } /// The returned length includes the null terminator char. diff --git a/ffi/src/winscard/scard.rs b/ffi/src/winscard/scard.rs index 11dc340a..8d2df75e 100644 --- a/ffi/src/winscard/scard.rs +++ b/ffi/src/winscard/scard.rs @@ -155,11 +155,13 @@ pub unsafe extern "system" fn SCardConnectW( check_null!(ph_card); check_null!(pdw_active_protocol); - // SAFETY: - // - `sz_reader` is guaranteed to be non-null due to the prior check. - // - The memory region `sz_reader` contains a valid null-terminator at the end of string. - // - The memory region `sz_reader` points to is valid for reads of bytes up to and including null-terminator. - let reader_name = unsafe { c_w_str_to_string(sz_reader) }; + let reader_name = try_execute!( + // SAFETY: + // - `sz_reader` is guaranteed to be non-null due to the prior check. + // - The memory region `sz_reader` contains a valid null-terminator at the end of string. + // - The memory region `sz_reader` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(sz_reader) }.map_err(Error::from) + ); try_execute!( // SAFETY: diff --git a/ffi/src/winscard/scard_context.rs b/ffi/src/winscard/scard_context.rs index 33d1d33f..2eb51f57 100644 --- a/ffi/src/winscard/scard_context.rs +++ b/ffi/src/winscard/scard_context.rs @@ -14,7 +14,7 @@ use libc::c_void; use symbol_rename_macro::rename_symbol; use uuid::Uuid; use winscard::winscard::{CurrentState, ReaderState, WinScardContext}; -use winscard::{ErrorKind, ScardContext as PivCardContext, SmartCardInfo, WinScardResult}; +use winscard::{Error, ErrorKind, ScardContext as PivCardContext, SmartCardInfo, WinScardResult}; use super::buf_alloc::{build_buf_request_type, build_buf_request_type_wide, save_out_buf, save_out_buf_wide}; use crate::utils::{c_w_str_to_string, into_raw_ptr, str_encode_utf16}; @@ -579,11 +579,13 @@ pub unsafe extern "system" fn SCardGetCardTypeProviderNameW( check_null!(szProvider); check_null!(pcch_provider); - // SAFETY: - // - `sz_card_name` is guaranteed to be non-null due to the prior check. - // - The memory region `sz_card_name` contains a valid null-terminator at the end of string. - // - The memory region `sz_card_name` points to is valid for reads of bytes up to and including null-terminator. - let card_name = unsafe { c_w_str_to_string(sz_card_name) }; + let card_name = try_execute!( + // SAFETY: + // - `sz_card_name` is guaranteed to be non-null due to the prior check. + // - The memory region `sz_card_name` contains a valid null-terminator at the end of string. + // - The memory region `sz_card_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(sz_card_name) }.map_err(Error::from) + ); let context_handle = try_execute!( // SAFETY: @@ -1038,7 +1040,7 @@ pub unsafe extern "system" fn SCardGetStatusChangeA( atr: c_reader.rgb_atr, }) }) - .collect::, winscard::Error>>()); + .collect::, Error>>()); try_execute!(context.get_status_change(dw_timeout, &mut reader_states)); for (reader_state, c_reader_state) in reader_states.iter().zip(c_reader_states.iter_mut()) { @@ -1096,11 +1098,13 @@ pub unsafe extern "system" fn SCardGetStatusChangeW( check_null!(c_reader.sz_reader, "reader name in reader state"); Ok(ReaderState { - // SAFETY: - // - `c_reader.sz_reader` is guaranteed to be non-null due to the prior check. - // - The memory region `c_reader.sz_reader` contains a valid null-terminator at the end of string. - // - The memory region `c_reader.sz_reader` points to is valid for reads of bytes up to and including null-terminator. - reader_name: Cow::Owned(unsafe { c_w_str_to_string(c_reader.sz_reader) }), + reader_name: Cow::Owned( + // SAFETY: + // - `c_reader.sz_reader` is guaranteed to be non-null due to the prior check. + // - The memory region `c_reader.sz_reader` contains a valid null-terminator at the end of string. + // - The memory region `c_reader.sz_reader` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(c_reader.sz_reader) }.map_err(Error::from)?, + ), user_data: c_reader.pv_user_data as usize, current_state: CurrentState::from_bits(c_reader.dw_current_state).unwrap_or_default(), event_state: CurrentState::from_bits(c_reader.dw_event_state).unwrap_or_default(), @@ -1108,7 +1112,7 @@ pub unsafe extern "system" fn SCardGetStatusChangeW( atr: c_reader.rgb_atr, }) }) - .collect::, winscard::Error>>()); + .collect::, Error>>()); try_execute!(context.get_status_change(dw_timeout, &mut reader_states)); for (reader_state, c_reader_state) in reader_states.iter().zip(c_reader_states.iter_mut()) { @@ -1259,11 +1263,14 @@ pub unsafe extern "system" fn SCardReadCacheW( ) -> ScardStatus { check_null!(lookup_name); - // SAFETY: - // - `lookup_name` is guaranteed to be non-null due to the prior check. - // - The memory region `lookup_name` contains a valid null-terminator at the end of string. - // - The memory region `lookup_name` points to is valid for reads of bytes up to and including null-terminator. - let lookup_name = unsafe { c_w_str_to_string(lookup_name) }; + let lookup_name = try_execute!( + // SAFETY: + // - `lookup_name` is guaranteed to be non-null due to the prior check. + // - The memory region `lookup_name` contains a valid null-terminator at the end of string. + // - The memory region `lookup_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(lookup_name) }.map_err(Error::from) + ); + try_execute!( // SAFETY: // - `context` is a valid raw scard context handle. @@ -1394,11 +1401,13 @@ pub unsafe extern "system" fn SCardWriteCacheW( ) -> ScardStatus { check_null!(lookup_name); - // SAFETY: - // - `lookup_name` is guaranteed to be non-null due to the prior check. - // - The memory region `lookup_name` contains a valid null-terminator at the end of string. - // - The memory region `lookup_name` points to is valid for reads of bytes up to and including null-terminator. - let lookup_name = unsafe { c_w_str_to_string(lookup_name) }; + let lookup_name = try_execute!( + // SAFETY: + // - `lookup_name` is guaranteed to be non-null due to the prior check. + // - The memory region `lookup_name` contains a valid null-terminator at the end of string. + // - The memory region `lookup_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(lookup_name) }.map_err(Error::from) + ); try_execute!( // SAFETY: // - `context` is a valid raw scard context handle. @@ -1515,11 +1524,13 @@ pub unsafe extern "system" fn SCardGetReaderIconW( ) -> ScardStatus { check_null!(sz_reader_name); - // SAFETY: - // - `sz_reader_name` is guaranteed to be non-null due to the prior check. - // - The memory region `sz_reader_name` contains a valid null-terminator at the end of string. - // - The memory region `sz_reader_name` points to is valid for reads of bytes up to and including null-terminator. - let reader_name = unsafe { c_w_str_to_string(sz_reader_name) }; + let reader_name = try_execute!( + // SAFETY: + // - `sz_reader_name` is guaranteed to be non-null due to the prior check. + // - The memory region `sz_reader_name` contains a valid null-terminator at the end of string. + // - The memory region `sz_reader_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(sz_reader_name) }.map_err(Error::from) + ); try_execute!( // SAFETY: @@ -1614,11 +1625,13 @@ pub unsafe extern "system" fn SCardGetDeviceTypeIdW( ) -> ScardStatus { check_null!(sz_reader_name); - // SAFETY: - // - `sz_reader_name` is guaranteed to be non-null due to the prior check. - // - The memory region `sz_reader_name` contains a valid null-terminator at the end of string. - // - The memory region `sz_reader_name` points to is valid for reads of bytes up to and including null-terminator. - let reader_name = unsafe { c_w_str_to_string(sz_reader_name) }; + let reader_name = try_execute!( + // SAFETY: + // - `sz_reader_name` is guaranteed to be non-null due to the prior check. + // - The memory region `sz_reader_name` contains a valid null-terminator at the end of string. + // - The memory region `sz_reader_name` points to is valid for reads of bytes up to and including null-terminator. + unsafe { c_w_str_to_string(sz_reader_name) }.map_err(Error::from) + ); try_execute!( // SAFETY: diff --git a/src/auth_identity.rs b/src/auth_identity.rs index 36e828b6..abd5be36 100644 --- a/src/auth_identity.rs +++ b/src/auth_identity.rs @@ -5,6 +5,7 @@ use crate::{utils, Error, Secret}; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum UsernameError { MixedFormat, + InvalidUtf16, } impl std::error::Error for UsernameError {} @@ -13,6 +14,7 @@ impl fmt::Display for UsernameError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { UsernameError::MixedFormat => write!(f, "mixed username format"), + UsernameError::InvalidUtf16 => write!(f, "invalid UTF-16 string"), } } } @@ -226,15 +228,18 @@ impl TryFrom<&AuthIdentityBuffers> for AuthIdentity { type Error = UsernameError; fn try_from(credentials_buffers: &AuthIdentityBuffers) -> Result { - let account_name = utils::bytes_to_utf16_string(&credentials_buffers.user); + let account_name = + utils::bytes_to_utf16_string(&credentials_buffers.user).map_err(|_| UsernameError::InvalidUtf16)?; let domain_name = if !credentials_buffers.domain.is_empty() { - Some(utils::bytes_to_utf16_string(&credentials_buffers.domain)) + Some(utils::bytes_to_utf16_string(&credentials_buffers.domain).map_err(|_| UsernameError::InvalidUtf16)?) } else { None }; let username = Username::new(&account_name, domain_name.as_deref())?; - let password = utils::bytes_to_utf16_string(credentials_buffers.password.as_ref()).into(); + let password = utils::bytes_to_utf16_string(credentials_buffers.password.as_ref()) + .map_err(|_| UsernameError::InvalidUtf16)? + .into(); Ok(Self { username, password }) } @@ -367,7 +372,7 @@ mod scard_credentials { fn try_from(value: &SmartCardIdentityBuffers) -> Result { let private_key = if let Some(key) = &value.private_key_pem { Some(SecretPrivateKey::new( - PrivateKey::from_pem_str(&utils::bytes_to_utf16_string(key)).map_err(|e| { + PrivateKey::from_pem_str(&utils::bytes_to_utf16_string(key)?).map_err(|e| { Error::new( ErrorKind::InternalError, format!("Unable to create a PrivateKey from a PEM string: {}", e), @@ -380,12 +385,20 @@ mod scard_credentials { Ok(Self { certificate: picky_asn1_der::from_bytes(&value.certificate)?, - reader_name: utils::bytes_to_utf16_string(&value.reader_name), - pin: utils::bytes_to_utf16_string(value.pin.as_ref()).into_bytes().into(), - username: utils::bytes_to_utf16_string(&value.username), - card_name: value.card_name.as_deref().map(utils::bytes_to_utf16_string), - container_name: value.container_name.as_deref().map(utils::bytes_to_utf16_string), - csp_name: utils::bytes_to_utf16_string(&value.csp_name), + reader_name: utils::bytes_to_utf16_string(&value.reader_name)?, + pin: utils::bytes_to_utf16_string(value.pin.as_ref())?.into_bytes().into(), + username: utils::bytes_to_utf16_string(&value.username)?, + card_name: value + .card_name + .as_deref() + .map(utils::bytes_to_utf16_string) + .transpose()?, + container_name: value + .container_name + .as_deref() + .map(utils::bytes_to_utf16_string) + .transpose()?, + csp_name: utils::bytes_to_utf16_string(&value.csp_name)?, private_key, scard_type: value.scard_type.clone(), }) diff --git a/src/kerberos/client/mod.rs b/src/kerberos/client/mod.rs index 775c8e2d..3f4e7018 100644 --- a/src/kerberos/client/mod.rs +++ b/src/kerberos/client/mod.rs @@ -113,9 +113,9 @@ pub async fn initialize_security_context<'a>( let (username, password, realm, cname_type) = match credentials { CredentialsBuffers::AuthIdentity(auth_identity) => { - let username = utf16_bytes_to_utf8_string(&auth_identity.user); - let domain = utf16_bytes_to_utf8_string(&auth_identity.domain); - let password = utf16_bytes_to_utf8_string(auth_identity.password.as_ref()); + let username = utf16_bytes_to_utf8_string(&auth_identity.user)?; + let domain = utf16_bytes_to_utf8_string(&auth_identity.domain)?; + let password = utf16_bytes_to_utf8_string(auth_identity.password.as_ref())?; let realm = get_client_principal_realm(&username, &domain); let cname_type = get_client_principal_name_type(&username, &domain); @@ -124,8 +124,8 @@ pub async fn initialize_security_context<'a>( } #[cfg(feature = "scard")] CredentialsBuffers::SmartCard(smart_card) => { - let username = utf16_bytes_to_utf8_string(&smart_card.username); - let password = utf16_bytes_to_utf8_string(smart_card.pin.as_ref()); + let username = utf16_bytes_to_utf8_string(&smart_card.username)?; + let password = utf16_bytes_to_utf8_string(smart_card.pin.as_ref())?; let realm = get_client_principal_realm(&username, ""); let cname_type = get_client_principal_name_type(&username, ""); @@ -150,7 +150,7 @@ pub async fn initialize_security_context<'a>( let pa_data_options = match credentials { CredentialsBuffers::AuthIdentity(auth_identity) => { - let domain = utf16_bytes_to_utf8_string(&auth_identity.domain); + let domain = utf16_bytes_to_utf8_string(&auth_identity.domain)?; let salt = format!("{}{}", domain, username); AsReqPaDataOptions::AuthIdentity(GenerateAsPaDataOptions { diff --git a/src/kerberos/mod.rs b/src/kerberos/mod.rs index 6ac3fc26..fd86fca1 100644 --- a/src/kerberos/mod.rs +++ b/src/kerberos/mod.rs @@ -532,7 +532,7 @@ impl Sspi for Kerberos { if let Some(CredentialsBuffers::SmartCard(ref identity_buffers)) = self.auth_identity { use crate::utils::utf16_bytes_to_utf8_string; - let username = utf16_bytes_to_utf8_string(&identity_buffers.username); + let username = utf16_bytes_to_utf8_string(&identity_buffers.username)?; let username = crate::Username::parse(&username).map_err(|e| Error::new(ErrorKind::InvalidParameter, e))?; return Ok(ContextNames { username }); } diff --git a/src/kerberos/server/as_exchange.rs b/src/kerberos/server/as_exchange.rs index b66a4f85..a71d62ba 100644 --- a/src/kerberos/server/as_exchange.rs +++ b/src/kerberos/server/as_exchange.rs @@ -48,9 +48,9 @@ pub(super) async fn request_tgt( let (username, password, realm, cname_type) = match credentials { CredentialsBuffers::AuthIdentity(auth_identity) => { - let username = utf16_bytes_to_utf8_string(&auth_identity.user); - let domain = utf16_bytes_to_utf8_string(&auth_identity.domain); - let password = utf16_bytes_to_utf8_string(auth_identity.password.as_ref()); + let username = utf16_bytes_to_utf8_string(&auth_identity.user)?; + let domain = utf16_bytes_to_utf8_string(&auth_identity.domain)?; + let password = utf16_bytes_to_utf8_string(auth_identity.password.as_ref())?; let realm = get_client_principal_realm(&username, &domain); let cname_type = get_client_principal_name_type(&username, &domain); @@ -83,7 +83,7 @@ pub(super) async fn request_tgt( let pa_data_options = match credentials { CredentialsBuffers::AuthIdentity(auth_identity) => { - let domain = utf16_bytes_to_utf8_string(&auth_identity.domain); + let domain = utf16_bytes_to_utf8_string(&auth_identity.domain)?; let salt = format!("{}{}", domain, username); AsReqPaDataOptions::AuthIdentity(GenerateAsPaDataOptions { diff --git a/src/negotiate.rs b/src/negotiate.rs index fd22b9a2..3bf41db1 100644 --- a/src/negotiate.rs +++ b/src/negotiate.rs @@ -613,7 +613,7 @@ impl<'a> Negotiate { #[cfg(feature = "scard")] if let Some(Some(CredentialsBuffers::SmartCard(identity))) = builder.credentials_handle { if let NegotiatedProtocol::Ntlm(_) = &self.protocol { - let username = crate::utils::bytes_to_utf16_string(&identity.username); + let username = crate::utils::bytes_to_utf16_string(&identity.username)?; let host = detect_kdc_url(&get_client_principal_realm(&username, "")) .ok_or_else(|| Error::new(ErrorKind::NoAuthenticatingAuthority, "can not detect KDC url"))?; debug!("Negotiate: try Kerberos"); diff --git a/src/ntlm/messages/computations.rs b/src/ntlm/messages/computations.rs index cca798ce..58c71507 100644 --- a/src/ntlm/messages/computations.rs +++ b/src/ntlm/messages/computations.rs @@ -182,7 +182,7 @@ pub(super) fn compute_ntlm_v2_hash(identity: &AuthIdentityBuffers) -> crate::Res compute_md4(identity.password.as_ref()) }; - let user_utf16 = utils::bytes_to_utf16_string(identity.user.as_ref()); + let user_utf16 = utils::bytes_to_utf16_string(identity.user.as_ref())?; let mut user_uppercase_with_domain = utils::string_to_utf16(user_utf16.to_uppercase().as_str()); user_uppercase_with_domain.extend(&identity.domain); diff --git a/src/utils.rs b/src/utils.rs index b1e38af1..70441b8a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -20,13 +20,13 @@ pub fn str_to_w_buff(data: &str) -> Vec { data.encode_utf16().chain(std::iter::once(0)).collect() } -pub(crate) fn bytes_to_utf16_string(mut value: &[u8]) -> String { +pub(crate) fn bytes_to_utf16_string(mut value: &[u8]) -> Result { let mut value_u16 = vec![0x00; value.len() / 2]; value .read_u16_into::(value_u16.as_mut()) .expect("read_u16_into cannot fail at this point"); - String::from_utf16_lossy(value_u16.as_ref()) + String::from_utf16(value_u16.as_ref()).map_err(Error::from) } #[cfg_attr(not(target_os = "windows"), allow(unused))] @@ -34,14 +34,15 @@ pub(crate) fn is_azure_ad_domain(domain: &str) -> bool { domain == crate::pku2u::AZURE_AD_DOMAIN } -pub fn utf16_bytes_to_utf8_string(data: &[u8]) -> String { +pub fn utf16_bytes_to_utf8_string(data: &[u8]) -> Result { debug_assert_eq!(data.len() % 2, 0); - String::from_utf16_lossy( + String::from_utf16( &data .chunks(2) .map(|c| u16::from_le_bytes(c.try_into().unwrap())) .collect::>(), ) + .map_err(Error::from) } pub(crate) fn generate_random_symmetric_key(cipher: &CipherSuite, rnd: &mut StdRng) -> Vec {