diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 5a59f372ad..9b631fb397 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1780,8 +1780,6 @@ def test_networked_bad_cert(self): h.request('GET', '/') self.assertEqual(exc_info.exception.reason, 'CERTIFICATE_VERIFY_FAILED') - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipIf(sys.platform == 'darwin', 'Occasionally success on macOS') def test_local_unknown_cert(self): # The custom cert isn't known to the default trust bundle diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index fe86454473..bf77d6b690 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -39,9 +39,7 @@ mod _ssl { socket::{self, PySocket}, vm::{ Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef, PyWeak, - }, + builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef, PyWeak}, class_or_notimplemented, convert::{ToPyException, ToPyObject}, exceptions, @@ -68,7 +66,8 @@ mod _ssl { ffi::CStr, fmt, io::{Read, Write}, - path::Path, + path::{Path, PathBuf}, + sync::LazyLock, time::Instant, }; @@ -91,6 +90,7 @@ mod _ssl { // X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, SSL_ERROR_ZERO_RETURN, SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE, + SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT, SSL_OP_LEGACY_SERVER_CONNECT as OP_LEGACY_SERVER_CONNECT, SSL_OP_NO_SSLv2 as OP_NO_SSLv2, SSL_OP_NO_SSLv3 as OP_NO_SSLv3, @@ -193,7 +193,8 @@ mod _ssl { #[pyattr(name = "_OPENSSL_API_VERSION")] fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo { - let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16).unwrap(); + let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16) + .expect("OPENSSL_API_VERSION is malformed"); parse_version_info(openssl_api_version) } @@ -251,7 +252,8 @@ mod _ssl { /// SSL/TLS connection terminated abruptly. #[pyattr(name = "SSLEOFError", once)] fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef { - PyType::new_simple_heap("ssl.SSLEOFError", &ssl_error(vm), &vm.ctx).unwrap() + vm.ctx + .new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)])) } type OpensslVersionInfo = (u8, u8, u8, u8, u8); @@ -352,14 +354,17 @@ mod _ssl { } type PyNid = (libc::c_int, String, String, Option); - fn obj2py(obj: &Asn1ObjectRef) -> PyNid { + fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult { let nid = obj.nid(); - ( - nid.as_raw(), - nid.short_name().unwrap().to_owned(), - nid.long_name().unwrap().to_owned(), - obj2txt(obj, true), - ) + let short_name = nid + .short_name() + .map_err(|_| vm.new_value_error("NID has no short name".to_owned()))? + .to_owned(); + let long_name = nid + .long_name() + .map_err(|_| vm.new_value_error("NID has no long name".to_owned()))? + .to_owned(); + Ok((nid.as_raw(), short_name, long_name, obj2txt(obj, true))) } #[derive(FromArgs)] @@ -373,55 +378,81 @@ mod _ssl { fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { _txt2obj(&args.txt.to_cstring(vm)?, !args.name) .as_deref() - .map(obj2py) .ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt))) + .and_then(|obj| obj2py(obj, vm)) } #[pyfunction] fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult { _nid2obj(Nid::from_raw(nid)) .as_deref() - .map(obj2py) .ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}"))) + .and_then(|obj| obj2py(obj, vm)) } - fn get_cert_file_dir() -> (&'static Path, &'static Path) { - let probe = probe(); - // on windows, these should be utf8 strings - fn path_from_bytes(c: &CStr) -> &Path { + // Lazily compute and cache cert file/dir paths + static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| { + fn path_from_cstr(c: &CStr) -> PathBuf { #[cfg(unix)] { use std::os::unix::ffi::OsStrExt; - std::ffi::OsStr::from_bytes(c.to_bytes()).as_ref() + std::ffi::OsStr::from_bytes(c.to_bytes()).into() } #[cfg(windows)] { - c.to_str().unwrap().as_ref() + // Use lossy conversion for potential non-UTF8 + PathBuf::from(c.to_string_lossy().as_ref()) } } - let cert_file = probe.cert_file.as_deref().unwrap_or_else(|| { - path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) - }); - let cert_dir = probe.cert_dir.as_deref().unwrap_or_else(|| { - path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) - }); + + let probe = probe(); + let cert_file = probe + .cert_file + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| { + path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) + }); + let cert_dir = probe + .cert_dir + .as_ref() + .map(PathBuf::from) + .unwrap_or_else(|| { + path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) + }); (cert_file, cert_dir) + }); + + fn get_cert_file_dir() -> (&'static Path, &'static Path) { + let (cert_file, cert_dir) = &*CERT_PATHS; + (cert_file.as_path(), cert_dir.as_path()) } + // Lazily compute and cache cert environment variable names + static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| { + let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } + .to_string_lossy() + .into_owned(); + let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } + .to_string_lossy() + .into_owned(); + (cert_file_env, cert_dir_env) + }); + #[pyfunction] fn get_default_verify_paths( vm: &VirtualMachine, ) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> { - let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } - .to_str() - .unwrap(); - let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } - .to_str() - .unwrap(); + let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES; let (cert_file, cert_dir) = get_cert_file_dir(); let cert_file = OsPath::new_str(cert_file).filename(vm); let cert_dir = OsPath::new_str(cert_dir).filename(vm); - Ok((cert_file_env, cert_file, cert_dir_env, cert_dir)) + Ok(( + cert_file_env.as_str(), + cert_file, + cert_dir_env.as_str(), + cert_dir, + )) } #[pyfunction(name = "RAND_status")] @@ -522,6 +553,7 @@ mod _ssl { options |= SslOptions::CIPHER_SERVER_PREFERENCE; options |= SslOptions::SINGLE_DH_USE; options |= SslOptions::SINGLE_ECDH_USE; + options |= SslOptions::ENABLE_MIDDLEBOX_COMPAT; builder.set_options(options); let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY; @@ -536,6 +568,13 @@ mod _ssl { .set_session_id_context(b"Python") .map_err(|e| convert_openssl_error(vm, e))?; + // Set default verify flags: VERIFY_X509_TRUSTED_FIRST + unsafe { + let ctx_ptr = builder.as_ptr(); + let param = sys::SSL_CTX_get0_param(ctx_ptr); + sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST); + } + PySslContext { ctx: PyRwLock::new(builder), check_hostname: AtomicCell::new(check_hostname), @@ -846,8 +885,16 @@ mod _ssl { let certs = ctx.cert_store().all_certificates(); #[cfg(not(ossl300))] let certs = ctx.cert_store().objects().iter().filter_map(|x| x.x509()); + + // Filter to only include CA certificates (Basic Constraints: CA=TRUE) let certs = certs .into_iter() + .filter(|cert| { + unsafe { + // X509_check_ca() returns 1 for CA certificates + X509_check_ca(cert.as_ptr()) == 1 + } + }) .map(|ref cert| cert_to_py(vm, cert, binary_form)) .collect::, _>>()?; Ok(certs) @@ -884,6 +931,20 @@ mod _ssl { args: WrapSocketArgs, vm: &VirtualMachine, ) -> PyResult { + // validate socket type and context protocol + if !args.server_side && zelf.protocol == SslVersion::TlsServer { + return Err(vm.new_exception_msg( + ssl_error(vm), + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + )); + } + if args.server_side && zelf.protocol == SslVersion::TlsClient { + return Err(vm.new_exception_msg( + ssl_error(vm), + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + )); + } + let mut ssl = ssl::Ssl::new(&zelf.ctx()).map_err(|e| convert_openssl_error(vm, e))?; let socket_type = if args.server_side { @@ -1681,6 +1742,12 @@ mod _ssl { unsafe impl Sync for PySslMemoryBio {} // OpenSSL functions not in openssl-sys + + unsafe extern "C" { + // X509_check_ca returns 1 for CA certificates, 0 otherwise + fn X509_check_ca(x: *const sys::X509) -> libc::c_int; + } + unsafe extern "C" { fn SSL_get_ciphers(ssl: *const sys::SSL) -> *const sys::stack_st_SSL_CIPHER; } @@ -1857,12 +1924,12 @@ mod _ssl { } #[pygetset] - fn id(&self, vm: &VirtualMachine) -> PyObjectRef { + fn id(&self, vm: &VirtualMachine) -> PyBytesRef { unsafe { let mut len: libc::c_uint = 0; let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len); let id_slice = std::slice::from_raw_parts(id_ptr, len as usize); - vm.ctx.new_bytes(id_slice.to_vec()).into() + vm.ctx.new_bytes(id_slice.to_vec()) } } @@ -1900,23 +1967,39 @@ mod _ssl { "certificate verify failed" => "CERTIFICATE_VERIFY_FAILED", _ => default_errstr, }; - let msg = if let Some(lib) = e.library() { - // add `library` attribute - let attr_name = vm.ctx.as_ref().intern_str("library"); - cls.set_attr(attr_name, vm.ctx.new_str(lib).into()); + + // Build message + let lib_obj = e.library(); + let msg = if let Some(lib) = lib_obj { format!("[{lib}] {errstr} ({file}:{line})") } else { format!("{errstr} ({file}:{line})") }; - // add `reason` attribute - let attr_name = vm.ctx.as_ref().intern_str("reason"); - cls.set_attr(attr_name, vm.ctx.new_str(errstr).into()); + // Create exception instance let reason = sys::ERR_GET_REASON(e.code()); - vm.new_exception( + let exc = vm.new_exception( cls, vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()], - ) + ); + + // Set attributes on instance, not class + let exc_obj: PyObjectRef = exc.into(); + + // Set reason attribute (always set, even if just the error string) + let reason_value = vm.ctx.new_str(errstr); + let _ = exc_obj.set_attr("reason", reason_value, vm); + + // Set library attribute (None if not available) + let library_value: PyObjectRef = if let Some(lib) = lib_obj { + vm.ctx.new_str(lib).into() + } else { + vm.ctx.none() + }; + let _ = exc_obj.set_attr("library", library_value, vm); + + // Convert back to PyBaseExceptionRef + exc_obj.downcast().unwrap() } None => vm.new_exception_empty(cls), } @@ -2013,7 +2096,8 @@ mod _ssl { dict.set_item("subject", name_to_py(cert.subject_name())?, vm)?; dict.set_item("issuer", name_to_py(cert.issuer_name())?, vm)?; - dict.set_item("version", vm.new_pyobj(cert.version()), vm)?; + // X.509 version: OpenSSL uses 0-based (0=v1, 1=v2, 2=v3) but Python uses 1-based (1=v1, 2=v2, 3=v3) + dict.set_item("version", vm.new_pyobj(cert.version() + 1), vm)?; let serial_num = cert .serial_number() @@ -2226,21 +2310,18 @@ mod windows { Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), other => vm.new_pyobj(other), }; - let usage: PyObjectRef = match c.valid_uses()? { + let usage: PyObjectRef = match c.valid_uses().map_err(|e| e.to_pyexception(vm))? { ValidUses::All => vm.ctx.new_bool(true).into(), ValidUses::Oids(oids) => PyFrozenSet::from_iter( vm, oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()), - ) - .unwrap() + )? .into_ref(&vm.ctx) .into(), }; Ok(vm.new_tuple((cert, enc_type, usage)).into()) }); - let certs = certs - .collect::, _>>() - .map_err(|e: std::io::Error| e.to_pyexception(vm))?; + let certs: Vec = certs.collect::>>()?; Ok(certs) } }