diff --git a/Lib/_opcode_metadata.py b/Lib/_opcode_metadata.py
index 3e98489419f..abb748519c3 100644
--- a/Lib/_opcode_metadata.py
+++ b/Lib/_opcode_metadata.py
@@ -136,10 +136,102 @@
'JUMP_IF_FALSE_OR_POP': 129,
'JUMP_IF_TRUE_OR_POP': 130,
'JUMP_IF_NOT_EXC_MATCH': 131,
- 'SET_EXC_INFO': 134,
- 'SUBSCRIPT': 135,
+ 'SET_EXC_INFO': 132,
+ 'SUBSCRIPT': 133,
'RESUME': 149,
- 'LOAD_CLOSURE': 253,
+ 'BINARY_OP_ADD_FLOAT': 150,
+ 'BINARY_OP_ADD_INT': 151,
+ 'BINARY_OP_ADD_UNICODE': 152,
+ 'BINARY_OP_MULTIPLY_FLOAT': 153,
+ 'BINARY_OP_MULTIPLY_INT': 154,
+ 'BINARY_OP_SUBTRACT_FLOAT': 155,
+ 'BINARY_OP_SUBTRACT_INT': 156,
+ 'BINARY_SUBSCR_DICT': 157,
+ 'BINARY_SUBSCR_GETITEM': 158,
+ 'BINARY_SUBSCR_LIST_INT': 159,
+ 'BINARY_SUBSCR_STR_INT': 160,
+ 'BINARY_SUBSCR_TUPLE_INT': 161,
+ 'CALL_ALLOC_AND_ENTER_INIT': 162,
+ 'CALL_BOUND_METHOD_EXACT_ARGS': 163,
+ 'CALL_BOUND_METHOD_GENERAL': 164,
+ 'CALL_BUILTIN_CLASS': 165,
+ 'CALL_BUILTIN_FAST': 166,
+ 'CALL_BUILTIN_FAST_WITH_KEYWORDS': 167,
+ 'CALL_BUILTIN_O': 168,
+ 'CALL_ISINSTANCE': 169,
+ 'CALL_LEN': 170,
+ 'CALL_LIST_APPEND': 171,
+ 'CALL_METHOD_DESCRIPTOR_FAST': 172,
+ 'CALL_METHOD_DESCRIPTOR_FAST_WITH_KEYWORDS': 173,
+ 'CALL_METHOD_DESCRIPTOR_NOARGS': 174,
+ 'CALL_METHOD_DESCRIPTOR_O': 175,
+ 'CALL_NON_PY_GENERAL': 176,
+ 'CALL_PY_EXACT_ARGS': 177,
+ 'CALL_PY_GENERAL': 178,
+ 'CALL_STR_1': 179,
+ 'CALL_TUPLE_1': 180,
+ 'CALL_TYPE_1': 181,
+ 'COMPARE_OP_FLOAT': 182,
+ 'COMPARE_OP_INT': 183,
+ 'COMPARE_OP_STR': 184,
+ 'CONTAINS_OP_DICT': 185,
+ 'CONTAINS_OP_SET': 186,
+ 'FOR_ITER_GEN': 187,
+ 'FOR_ITER_LIST': 188,
+ 'FOR_ITER_RANGE': 189,
+ 'FOR_ITER_TUPLE': 190,
+ 'LOAD_ATTR_CLASS': 191,
+ 'LOAD_ATTR_GETATTRIBUTE_OVERRIDDEN': 192,
+ 'LOAD_ATTR_INSTANCE_VALUE': 193,
+ 'LOAD_ATTR_METHOD_LAZY_DICT': 194,
+ 'LOAD_ATTR_METHOD_NO_DICT': 195,
+ 'LOAD_ATTR_METHOD_WITH_VALUES': 196,
+ 'LOAD_ATTR_MODULE': 197,
+ 'LOAD_ATTR_NONDESCRIPTOR_NO_DICT': 198,
+ 'LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES': 199,
+ 'LOAD_ATTR_PROPERTY': 200,
+ 'LOAD_ATTR_SLOT': 201,
+ 'LOAD_ATTR_WITH_HINT': 202,
+ 'LOAD_GLOBAL_BUILTIN': 203,
+ 'LOAD_GLOBAL_MODULE': 204,
+ 'LOAD_SUPER_ATTR_ATTR': 205,
+ 'LOAD_SUPER_ATTR_METHOD': 206,
+ 'RESUME_CHECK': 207,
+ 'SEND_GEN': 208,
+ 'STORE_ATTR_INSTANCE_VALUE': 209,
+ 'STORE_ATTR_SLOT': 210,
+ 'STORE_ATTR_WITH_HINT': 211,
+ 'STORE_SUBSCR_DICT': 212,
+ 'STORE_SUBSCR_LIST_INT': 213,
+ 'TO_BOOL_ALWAYS_TRUE': 214,
+ 'TO_BOOL_BOOL': 215,
+ 'TO_BOOL_INT': 216,
+ 'TO_BOOL_LIST': 217,
+ 'TO_BOOL_NONE': 218,
+ 'TO_BOOL_STR': 219,
+ 'UNPACK_SEQUENCE_LIST': 220,
+ 'UNPACK_SEQUENCE_TUPLE': 221,
+ 'UNPACK_SEQUENCE_TWO_TUPLE': 222,
+ 'INSTRUMENTED_RESUME': 236,
+ 'INSTRUMENTED_END_FOR': 237,
+ 'INSTRUMENTED_END_SEND': 238,
+ 'INSTRUMENTED_RETURN_VALUE': 239,
+ 'INSTRUMENTED_RETURN_CONST': 240,
+ 'INSTRUMENTED_YIELD_VALUE': 241,
+ 'INSTRUMENTED_LOAD_SUPER_ATTR': 242,
+ 'INSTRUMENTED_FOR_ITER': 243,
+ 'INSTRUMENTED_CALL': 244,
+ 'INSTRUMENTED_CALL_KW': 245,
+ 'INSTRUMENTED_CALL_FUNCTION_EX': 246,
+ 'INSTRUMENTED_INSTRUCTION': 247,
+ 'INSTRUMENTED_JUMP_FORWARD': 248,
+ 'INSTRUMENTED_JUMP_BACKWARD': 249,
+ 'INSTRUMENTED_POP_JUMP_IF_TRUE': 250,
+ 'INSTRUMENTED_POP_JUMP_IF_FALSE': 251,
+ 'INSTRUMENTED_POP_JUMP_IF_NONE': 252,
+ 'INSTRUMENTED_POP_JUMP_IF_NOT_NONE': 253,
+ 'INSTRUMENTED_LINE': 254,
+ 'LOAD_CLOSURE': 255,
'JUMP': 256,
'JUMP_NO_INTERRUPT': 257,
'RESERVED_258': 258,
diff --git a/Lib/http/__init__.py b/Lib/http/__init__.py
index bf8d7d68868..17a47b180e5 100644
--- a/Lib/http/__init__.py
+++ b/Lib/http/__init__.py
@@ -1,14 +1,15 @@
-from enum import IntEnum
+from enum import StrEnum, IntEnum, _simple_enum
-__all__ = ['HTTPStatus']
+__all__ = ['HTTPStatus', 'HTTPMethod']
-class HTTPStatus(IntEnum):
+@_simple_enum(IntEnum)
+class HTTPStatus:
"""HTTP status codes and reason phrases
Status codes from the following RFCs are all observed:
- * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616
+ * RFC 9110: HTTP Semantics, obsoletes 7231, which obsoleted 2616
* RFC 6585: Additional HTTP Status Codes
* RFC 3229: Delta encoding in HTTP
* RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518
@@ -25,11 +26,30 @@ class HTTPStatus(IntEnum):
def __new__(cls, value, phrase, description=''):
obj = int.__new__(cls, value)
obj._value_ = value
-
obj.phrase = phrase
obj.description = description
return obj
+ @property
+ def is_informational(self):
+ return 100 <= self <= 199
+
+ @property
+ def is_success(self):
+ return 200 <= self <= 299
+
+ @property
+ def is_redirection(self):
+ return 300 <= self <= 399
+
+ @property
+ def is_client_error(self):
+ return 400 <= self <= 499
+
+ @property
+ def is_server_error(self):
+ return 500 <= self <= 599
+
# informational
CONTINUE = 100, 'Continue', 'Request received, please continue'
SWITCHING_PROTOCOLS = (101, 'Switching Protocols',
@@ -94,22 +114,25 @@ def __new__(cls, value, phrase, description=''):
'Client must specify Content-Length')
PRECONDITION_FAILED = (412, 'Precondition Failed',
'Precondition in headers is false')
- REQUEST_ENTITY_TOO_LARGE = (413, 'Request Entity Too Large',
- 'Entity is too large')
- REQUEST_URI_TOO_LONG = (414, 'Request-URI Too Long',
+ CONTENT_TOO_LARGE = (413, 'Content Too Large',
+ 'Content is too large')
+ REQUEST_ENTITY_TOO_LARGE = CONTENT_TOO_LARGE
+ URI_TOO_LONG = (414, 'URI Too Long',
'URI is too long')
+ REQUEST_URI_TOO_LONG = URI_TOO_LONG
UNSUPPORTED_MEDIA_TYPE = (415, 'Unsupported Media Type',
'Entity body in unsupported format')
- REQUESTED_RANGE_NOT_SATISFIABLE = (416,
- 'Requested Range Not Satisfiable',
+ RANGE_NOT_SATISFIABLE = (416, 'Range Not Satisfiable',
'Cannot satisfy request range')
+ REQUESTED_RANGE_NOT_SATISFIABLE = RANGE_NOT_SATISFIABLE
EXPECTATION_FAILED = (417, 'Expectation Failed',
'Expect condition could not be satisfied')
IM_A_TEAPOT = (418, 'I\'m a Teapot',
'Server refuses to brew coffee because it is a teapot.')
MISDIRECTED_REQUEST = (421, 'Misdirected Request',
'Server is not able to produce a response')
- UNPROCESSABLE_ENTITY = 422, 'Unprocessable Entity'
+ UNPROCESSABLE_CONTENT = 422, 'Unprocessable Content'
+ UNPROCESSABLE_ENTITY = UNPROCESSABLE_CONTENT
LOCKED = 423, 'Locked'
FAILED_DEPENDENCY = 424, 'Failed Dependency'
TOO_EARLY = 425, 'Too Early'
@@ -148,3 +171,32 @@ def __new__(cls, value, phrase, description=''):
NETWORK_AUTHENTICATION_REQUIRED = (511,
'Network Authentication Required',
'The client needs to authenticate to gain network access')
+
+
+@_simple_enum(StrEnum)
+class HTTPMethod:
+ """HTTP methods and descriptions
+
+ Methods from the following RFCs are all observed:
+
+ * RFC 9110: HTTP Semantics, obsoletes 7231, which obsoleted 2616
+ * RFC 5789: PATCH Method for HTTP
+ """
+ def __new__(cls, value, description):
+ obj = str.__new__(cls, value)
+ obj._value_ = value
+ obj.description = description
+ return obj
+
+ def __repr__(self):
+ return "<%s.%s>" % (self.__class__.__name__, self._name_)
+
+ CONNECT = 'CONNECT', 'Establish a connection to the server.'
+ DELETE = 'DELETE', 'Remove the target.'
+ GET = 'GET', 'Retrieve the target.'
+ HEAD = 'HEAD', 'Same as GET, but only retrieve the status line and header section.'
+ OPTIONS = 'OPTIONS', 'Describe the communication options for the target.'
+ PATCH = 'PATCH', 'Apply partial modifications to a target.'
+ POST = 'POST', 'Perform target-specific processing with the request payload.'
+ PUT = 'PUT', 'Replace the target with the request payload.'
+ TRACE = 'TRACE', 'Perform a message loop-back test along the path to the target.'
diff --git a/Lib/http/client.py b/Lib/http/client.py
index a6ab135b2c3..dd5f4136e9e 100644
--- a/Lib/http/client.py
+++ b/Lib/http/client.py
@@ -111,6 +111,11 @@
_MAXLINE = 65536
_MAXHEADERS = 100
+# Data larger than this will be read in chunks, to prevent extreme
+# overallocation.
+_MIN_READ_BUF_SIZE = 1 << 20
+
+
# Header name/value ABNF (http://tools.ietf.org/html/rfc7230#section-3.2)
#
# VCHAR = %x21-7E
@@ -172,6 +177,13 @@ def _encode(data, name='data'):
"if you want to send it encoded in UTF-8." %
(name.title(), data[err.start:err.end], name)) from None
+def _strip_ipv6_iface(enc_name: bytes) -> bytes:
+ """Remove interface scope from IPv6 address."""
+ enc_name, percent, _ = enc_name.partition(b"%")
+ if percent:
+ assert enc_name.startswith(b'['), enc_name
+ enc_name += b']'
+ return enc_name
class HTTPMessage(email.message.Message):
# XXX The only usage of this method is in
@@ -221,8 +233,9 @@ def _read_headers(fp):
break
return headers
-def parse_headers(fp, _class=HTTPMessage):
- """Parses only RFC2822 headers from a file pointer.
+def _parse_header_lines(header_lines, _class=HTTPMessage):
+ """
+ Parses only RFC 5322 headers from header lines.
email Parser wants to see strings rather than bytes.
But a TextIOWrapper around self.rfile would buffer too many bytes
@@ -231,10 +244,15 @@ def parse_headers(fp, _class=HTTPMessage):
to parse.
"""
- headers = _read_headers(fp)
- hstring = b''.join(headers).decode('iso-8859-1')
+ hstring = b''.join(header_lines).decode('iso-8859-1')
return email.parser.Parser(_class=_class).parsestr(hstring)
+def parse_headers(fp, _class=HTTPMessage):
+ """Parses only RFC 5322 headers from a file pointer."""
+
+ headers = _read_headers(fp)
+ return _parse_header_lines(headers, _class)
+
class HTTPResponse(io.BufferedIOBase):
@@ -448,6 +466,7 @@ def isclosed(self):
return self.fp is None
def read(self, amt=None):
+ """Read and return the response body, or up to the next amt bytes."""
if self.fp is None:
return b""
@@ -458,7 +477,7 @@ def read(self, amt=None):
if self.chunked:
return self._read_chunked(amt)
- if amt is not None:
+ if amt is not None and amt >= 0:
if self.length is not None and amt > self.length:
# clip the read to the "end of response"
amt = self.length
@@ -576,13 +595,11 @@ def _get_chunk_left(self):
def _read_chunked(self, amt=None):
assert self.chunked != _UNKNOWN
+ if amt is not None and amt < 0:
+ amt = None
value = []
try:
- while True:
- chunk_left = self._get_chunk_left()
- if chunk_left is None:
- break
-
+ while (chunk_left := self._get_chunk_left()) is not None:
if amt is not None and amt <= chunk_left:
value.append(self._safe_read(amt))
self.chunk_left = chunk_left - amt
@@ -593,8 +610,8 @@ def _read_chunked(self, amt=None):
amt -= chunk_left
self.chunk_left = 0
return b''.join(value)
- except IncompleteRead:
- raise IncompleteRead(b''.join(value))
+ except IncompleteRead as exc:
+ raise IncompleteRead(b''.join(value)) from exc
def _readinto_chunked(self, b):
assert self.chunked != _UNKNOWN
@@ -627,10 +644,25 @@ def _safe_read(self, amt):
reading. If the bytes are truly not available (due to EOF), then the
IncompleteRead exception can be used to detect the problem.
"""
- data = self.fp.read(amt)
- if len(data) < amt:
- raise IncompleteRead(data, amt-len(data))
- return data
+ cursize = min(amt, _MIN_READ_BUF_SIZE)
+ data = self.fp.read(cursize)
+ if len(data) >= amt:
+ return data
+ if len(data) < cursize:
+ raise IncompleteRead(data, amt - len(data))
+
+ data = io.BytesIO(data)
+ data.seek(0, 2)
+ while True:
+ # This is a geometric increase in read size (never more than
+ # doubling the current length of data per loop iteration).
+ delta = min(cursize, amt - cursize)
+ data.write(self.fp.read(delta))
+ if data.tell() >= amt:
+ return data.getvalue()
+ cursize += delta
+ if data.tell() < cursize:
+ raise IncompleteRead(data.getvalue(), amt - data.tell())
def _safe_readinto(self, b):
"""Same as _safe_read, but for reading into a buffer."""
@@ -655,6 +687,8 @@ def read1(self, n=-1):
self._close_conn()
elif self.length is not None:
self.length -= len(result)
+ if not self.length:
+ self._close_conn()
return result
def peek(self, n=-1):
@@ -679,6 +713,8 @@ def readline(self, limit=-1):
self._close_conn()
elif self.length is not None:
self.length -= len(result)
+ if not self.length:
+ self._close_conn()
return result
def _read1_chunked(self, n):
@@ -786,6 +822,20 @@ def getcode(self):
'''
return self.status
+
+def _create_https_context(http_version):
+ # Function also used by urllib.request to be able to set the check_hostname
+ # attribute on a context object.
+ context = ssl._create_default_https_context()
+ # send ALPN extension to indicate HTTP/1.1 protocol
+ if http_version == 11:
+ context.set_alpn_protocols(['http/1.1'])
+ # enable PHA for TLS 1.3 connections if available
+ if context.post_handshake_auth is not None:
+ context.post_handshake_auth = True
+ return context
+
+
class HTTPConnection:
_http_vsn = 11
@@ -847,6 +897,7 @@ def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
self._tunnel_host = None
self._tunnel_port = None
self._tunnel_headers = {}
+ self._raw_proxy_headers = None
(self.host, self.port) = self._get_hostport(host, port)
@@ -859,9 +910,9 @@ def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
def set_tunnel(self, host, port=None, headers=None):
"""Set up host and port for HTTP CONNECT tunnelling.
- In a connection that uses HTTP CONNECT tunneling, the host passed to the
- constructor is used as a proxy server that relays all communication to
- the endpoint passed to `set_tunnel`. This done by sending an HTTP
+ In a connection that uses HTTP CONNECT tunnelling, the host passed to
+ the constructor is used as a proxy server that relays all communication
+ to the endpoint passed to `set_tunnel`. This is done by sending an HTTP
CONNECT request to the proxy server when the connection is established.
This method must be called before the HTTP connection has been
@@ -869,6 +920,13 @@ def set_tunnel(self, host, port=None, headers=None):
The headers argument should be a mapping of extra HTTP headers to send
with the CONNECT request.
+
+ As HTTP/1.1 is used for the HTTP CONNECT tunnelling request, as per the
+ RFC (https://tools.ietf.org/html/rfc7231#section-4.3.6), an HTTP Host:
+ header must be provided, matching the authority-form of the request
+ target provided as the destination for the CONNECT request. If an
+ HTTP Host: header is not provided via the headers argument, one
+ is generated and transmitted automatically.
"""
if self.sock:
@@ -876,10 +934,15 @@ def set_tunnel(self, host, port=None, headers=None):
self._tunnel_host, self._tunnel_port = self._get_hostport(host, port)
if headers:
- self._tunnel_headers = headers
+ self._tunnel_headers = headers.copy()
else:
self._tunnel_headers.clear()
+ if not any(header.lower() == "host" for header in self._tunnel_headers):
+ encoded_host = self._tunnel_host.encode("idna").decode("ascii")
+ self._tunnel_headers["Host"] = "%s:%d" % (
+ encoded_host, self._tunnel_port)
+
def _get_hostport(self, host, port):
if port is None:
i = host.rfind(':')
@@ -895,17 +958,24 @@ def _get_hostport(self, host, port):
host = host[:i]
else:
port = self.default_port
- if host and host[0] == '[' and host[-1] == ']':
- host = host[1:-1]
+ if host and host[0] == '[' and host[-1] == ']':
+ host = host[1:-1]
return (host, port)
def set_debuglevel(self, level):
self.debuglevel = level
+ def _wrap_ipv6(self, ip):
+ if b':' in ip and ip[0] != b'['[0]:
+ return b"[" + ip + b"]"
+ return ip
+
def _tunnel(self):
- connect = b"CONNECT %s:%d HTTP/1.0\r\n" % (
- self._tunnel_host.encode("ascii"), self._tunnel_port)
+ connect = b"CONNECT %s:%d %s\r\n" % (
+ self._wrap_ipv6(self._tunnel_host.encode("idna")),
+ self._tunnel_port,
+ self._http_vsn_str.encode("ascii"))
headers = [connect]
for header, value in self._tunnel_headers.items():
headers.append(f"{header}: {value}\r\n".encode("latin-1"))
@@ -917,23 +987,35 @@ def _tunnel(self):
del headers
response = self.response_class(self.sock, method=self._method)
- (version, code, message) = response._read_status()
+ try:
+ (version, code, message) = response._read_status()
- if code != http.HTTPStatus.OK:
- self.close()
- raise OSError(f"Tunnel connection failed: {code} {message.strip()}")
- while True:
- line = response.fp.readline(_MAXLINE + 1)
- if len(line) > _MAXLINE:
- raise LineTooLong("header line")
- if not line:
- # for sites which EOF without sending a trailer
- break
- if line in (b'\r\n', b'\n', b''):
- break
+ self._raw_proxy_headers = _read_headers(response.fp)
if self.debuglevel > 0:
- print('header:', line.decode())
+ for header in self._raw_proxy_headers:
+ print('header:', header.decode())
+
+ if code != http.HTTPStatus.OK:
+ self.close()
+ raise OSError(f"Tunnel connection failed: {code} {message.strip()}")
+
+ finally:
+ response.close()
+
+ def get_proxy_response_headers(self):
+ """
+ Returns a dictionary with the headers of the response
+ received from the proxy server to the CONNECT request
+ sent to set the tunnel.
+
+ If the CONNECT request was not sent, the method returns None.
+ """
+ return (
+ _parse_header_lines(self._raw_proxy_headers)
+ if self._raw_proxy_headers is not None
+ else None
+ )
def connect(self):
"""Connect to the host and port specified in __init__."""
@@ -942,7 +1024,7 @@ def connect(self):
(self.host,self.port), self.timeout, self.source_address)
# Might fail in OSs that don't implement TCP_NODELAY
try:
- self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError as e:
if e.errno != errno.ENOPROTOOPT:
raise
@@ -980,14 +1062,11 @@ def send(self, data):
print("send:", repr(data))
if hasattr(data, "read") :
if self.debuglevel > 0:
- print("sendIng a read()able")
+ print("sending a readable")
encode = self._is_textIO(data)
if encode and self.debuglevel > 0:
print("encoding file using iso-8859-1")
- while 1:
- datablock = data.read(self.blocksize)
- if not datablock:
- break
+ while datablock := data.read(self.blocksize):
if encode:
datablock = datablock.encode("iso-8859-1")
sys.audit("http.client.send", self, datablock)
@@ -1013,14 +1092,11 @@ def _output(self, s):
def _read_readable(self, readable):
if self.debuglevel > 0:
- print("sendIng a read()able")
+ print("reading a readable")
encode = self._is_textIO(readable)
if encode and self.debuglevel > 0:
print("encoding file using iso-8859-1")
- while True:
- datablock = readable.read(self.blocksize)
- if not datablock:
- break
+ while datablock := readable.read(self.blocksize):
if encode:
datablock = datablock.encode("iso-8859-1")
yield datablock
@@ -1157,7 +1233,7 @@ def putrequest(self, method, url, skip_host=False,
netloc_enc = netloc.encode("ascii")
except UnicodeEncodeError:
netloc_enc = netloc.encode("idna")
- self.putheader('Host', netloc_enc)
+ self.putheader('Host', _strip_ipv6_iface(netloc_enc))
else:
if self._tunnel_host:
host = self._tunnel_host
@@ -1173,9 +1249,9 @@ def putrequest(self, method, url, skip_host=False,
# As per RFC 273, IPv6 address should be wrapped with []
# when used as Host header
-
- if host.find(':') >= 0:
- host_enc = b'[' + host_enc + b']'
+ host_enc = self._wrap_ipv6(host_enc)
+ if ":" in host:
+ host_enc = _strip_ipv6_iface(host_enc)
if port == self.default_port:
self.putheader('Host', host_enc)
@@ -1400,46 +1476,15 @@ class HTTPSConnection(HTTPConnection):
default_port = HTTPS_PORT
- # XXX Should key_file and cert_file be deprecated in favour of context?
-
- def __init__(self, host, port=None, key_file=None, cert_file=None,
- timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
- source_address=None, *, context=None,
- check_hostname=None, blocksize=8192):
+ def __init__(self, host, port=None,
+ *, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None, context=None, blocksize=8192):
super(HTTPSConnection, self).__init__(host, port, timeout,
source_address,
blocksize=blocksize)
- if (key_file is not None or cert_file is not None or
- check_hostname is not None):
- import warnings
- warnings.warn("key_file, cert_file and check_hostname are "
- "deprecated, use a custom context instead.",
- DeprecationWarning, 2)
- self.key_file = key_file
- self.cert_file = cert_file
if context is None:
- context = ssl._create_default_https_context()
- # send ALPN extension to indicate HTTP/1.1 protocol
- if self._http_vsn == 11:
- context.set_alpn_protocols(['http/1.1'])
- # enable PHA for TLS 1.3 connections if available
- if context.post_handshake_auth is not None:
- context.post_handshake_auth = True
- will_verify = context.verify_mode != ssl.CERT_NONE
- if check_hostname is None:
- check_hostname = context.check_hostname
- if check_hostname and not will_verify:
- raise ValueError("check_hostname needs a SSL context with "
- "either CERT_OPTIONAL or CERT_REQUIRED")
- if key_file or cert_file:
- context.load_cert_chain(cert_file, key_file)
- # cert and key file means the user wants to authenticate.
- # enable TLS 1.3 PHA implicitly even for custom contexts.
- if context.post_handshake_auth is not None:
- context.post_handshake_auth = True
+ context = _create_https_context(self._http_vsn)
self._context = context
- if check_hostname is not None:
- self._context.check_hostname = check_hostname
def connect(self):
"Connect to a host on a given (SSL) port."
diff --git a/Lib/http/cookiejar.py b/Lib/http/cookiejar.py
index 685f6a0b976..9a2f0fb851c 100644
--- a/Lib/http/cookiejar.py
+++ b/Lib/http/cookiejar.py
@@ -34,10 +34,7 @@
import re
import time
import urllib.parse, urllib.request
-try:
- import threading as _threading
-except ImportError:
- import dummy_threading as _threading
+import threading as _threading
import http.client # only for the default HTTP port
from calendar import timegm
@@ -92,8 +89,7 @@ def _timegm(tt):
DAYS = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
MONTHS = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
-MONTHS_LOWER = []
-for month in MONTHS: MONTHS_LOWER.append(month.lower())
+MONTHS_LOWER = [month.lower() for month in MONTHS]
def time2isoz(t=None):
"""Return a string representing time in seconds since epoch, t.
@@ -108,9 +104,9 @@ def time2isoz(t=None):
"""
if t is None:
- dt = datetime.datetime.utcnow()
+ dt = datetime.datetime.now(tz=datetime.UTC)
else:
- dt = datetime.datetime.utcfromtimestamp(t)
+ dt = datetime.datetime.fromtimestamp(t, tz=datetime.UTC)
return "%04d-%02d-%02d %02d:%02d:%02dZ" % (
dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
@@ -126,9 +122,9 @@ def time2netscape(t=None):
"""
if t is None:
- dt = datetime.datetime.utcnow()
+ dt = datetime.datetime.now(tz=datetime.UTC)
else:
- dt = datetime.datetime.utcfromtimestamp(t)
+ dt = datetime.datetime.fromtimestamp(t, tz=datetime.UTC)
return "%s, %02d-%s-%04d %02d:%02d:%02d GMT" % (
DAYS[dt.weekday()], dt.day, MONTHS[dt.month-1],
dt.year, dt.hour, dt.minute, dt.second)
@@ -434,6 +430,7 @@ def split_header_words(header_values):
if pairs: result.append(pairs)
return result
+HEADER_JOIN_TOKEN_RE = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
HEADER_JOIN_ESCAPE_RE = re.compile(r"([\"\\])")
def join_header_words(lists):
"""Do the inverse (almost) of the conversion done by split_header_words.
@@ -441,10 +438,10 @@ def join_header_words(lists):
Takes a list of lists of (key, value) pairs and produces a single header
value. Attribute values are quoted if needed.
- >>> join_header_words([[("text/plain", None), ("charset", "iso-8859-1")]])
- 'text/plain; charset="iso-8859-1"'
- >>> join_header_words([[("text/plain", None)], [("charset", "iso-8859-1")]])
- 'text/plain, charset="iso-8859-1"'
+ >>> join_header_words([[("text/plain", None), ("charset", "iso-8859/1")]])
+ 'text/plain; charset="iso-8859/1"'
+ >>> join_header_words([[("text/plain", None)], [("charset", "iso-8859/1")]])
+ 'text/plain, charset="iso-8859/1"'
"""
headers = []
@@ -452,7 +449,7 @@ def join_header_words(lists):
attr = []
for k, v in pairs:
if v is not None:
- if not re.search(r"^\w+$", v):
+ if not HEADER_JOIN_TOKEN_RE.fullmatch(v):
v = HEADER_JOIN_ESCAPE_RE.sub(r"\\\1", v) # escape " and \
v = '"%s"' % v
k = "%s=%s" % (k, v)
@@ -644,7 +641,7 @@ def eff_request_host(request):
"""
erhn = req_host = request_host(request)
- if req_host.find(".") == -1 and not IPV4_RE.search(req_host):
+ if "." not in req_host:
erhn = req_host + ".local"
return req_host, erhn
@@ -1047,12 +1044,13 @@ def set_ok_domain(self, cookie, request):
else:
undotted_domain = domain
embedded_dots = (undotted_domain.find(".") >= 0)
- if not embedded_dots and domain != ".local":
+ if not embedded_dots and not erhn.endswith(".local"):
_debug(" non-local domain %s contains no embedded dot",
domain)
return False
if cookie.version == 0:
- if (not erhn.endswith(domain) and
+ if (not (erhn.endswith(domain) or
+ erhn.endswith(f"{undotted_domain}.local")) and
(not erhn.startswith(".") and
not ("."+erhn).endswith(domain))):
_debug(" effective request-host %s (even with added "
@@ -1227,14 +1225,9 @@ def path_return_ok(self, path, request):
_debug(" %s does not path-match %s", req_path, path)
return False
-def vals_sorted_by_key(adict):
- keys = sorted(adict.keys())
- return map(adict.get, keys)
-
def deepvalues(mapping):
- """Iterates over nested mapping, depth-first, in sorted order by key."""
- values = vals_sorted_by_key(mapping)
- for obj in values:
+ """Iterates over nested mapping, depth-first"""
+ for obj in list(mapping.values()):
mapping = False
try:
obj.items
@@ -1898,7 +1891,10 @@ def save(self, filename=None, ignore_discard=False, ignore_expires=False):
if self.filename is not None: filename = self.filename
else: raise ValueError(MISSING_FILENAME_TEXT)
- with open(filename, "w") as f:
+ with os.fdopen(
+ os.open(filename, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o600),
+ 'w',
+ ) as f:
# There really isn't an LWP Cookies 2.0 format, but this indicates
# that there is extra information in here (domain_dot and
# port_spec) while still being compatible with libwww-perl, I hope.
@@ -1923,9 +1919,7 @@ def _really_load(self, f, filename, ignore_discard, ignore_expires):
"comment", "commenturl")
try:
- while 1:
- line = f.readline()
- if line == "": break
+ while (line := f.readline()) != "":
if not line.startswith(header):
continue
line = line[len(header):].strip()
@@ -1993,7 +1987,7 @@ class MozillaCookieJar(FileCookieJar):
This class differs from CookieJar only in the format it uses to save and
load cookies to and from a file. This class uses the Mozilla/Netscape
- `cookies.txt' format. lynx uses this file format, too.
+ `cookies.txt' format. curl and lynx use this file format, too.
Don't expect cookies saved while the browser is running to be noticed by
the browser (in fact, Mozilla on unix will overwrite your saved cookies if
@@ -2025,12 +2019,9 @@ def _really_load(self, f, filename, ignore_discard, ignore_expires):
filename)
try:
- while 1:
- line = f.readline()
+ while (line := f.readline()) != "":
rest = {}
- if line == "": break
-
# httponly is a cookie flag as defined in rfc6265
# when encoded in a netscape cookie file,
# the line is prepended with "#HttpOnly_"
@@ -2094,7 +2085,10 @@ def save(self, filename=None, ignore_discard=False, ignore_expires=False):
if self.filename is not None: filename = self.filename
else: raise ValueError(MISSING_FILENAME_TEXT)
- with open(filename, "w") as f:
+ with os.fdopen(
+ os.open(filename, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o600),
+ 'w',
+ ) as f:
f.write(NETSCAPE_HEADER_TEXT)
now = time.time()
for cookie in self:
diff --git a/Lib/http/server.py b/Lib/http/server.py
index 58abadf7377..0ec479003a4 100644
--- a/Lib/http/server.py
+++ b/Lib/http/server.py
@@ -2,18 +2,18 @@
Note: BaseHTTPRequestHandler doesn't implement any HTTP request; see
SimpleHTTPRequestHandler for simple implementations of GET, HEAD and POST,
-and CGIHTTPRequestHandler for CGI scripts.
+and (deprecated) CGIHTTPRequestHandler for CGI scripts.
-It does, however, optionally implement HTTP/1.1 persistent connections,
-as of version 0.3.
+It does, however, optionally implement HTTP/1.1 persistent connections.
Notes on CGIHTTPRequestHandler
------------------------------
-This class implements GET and POST requests to cgi-bin scripts.
+This class is deprecated. It implements GET and POST requests to cgi-bin scripts.
-If the os.fork() function is not present (e.g. on Windows),
-subprocess.Popen() is used as a fallback, with slightly altered semantics.
+If the os.fork() function is not present (Windows), subprocess.Popen() is used,
+with slightly altered but never documented semantics. Use from a threaded
+process is likely to trigger a warning at os.fork() time.
In all cases, the implementation is intentionally naive -- all
requests are executed synchronously.
@@ -93,6 +93,7 @@
import html
import http.client
import io
+import itertools
import mimetypes
import os
import posixpath
@@ -109,11 +110,10 @@
# Default error message template
DEFAULT_ERROR_MESSAGE = """\
-
-
+
+
-
+
Error response
@@ -127,6 +127,10 @@
DEFAULT_ERROR_CONTENT_TYPE = "text/html;charset=utf-8"
+# Data larger than this will be read in chunks, to prevent extreme
+# overallocation.
+_MIN_READ_BUF_SIZE = 1 << 20
+
class HTTPServer(socketserver.TCPServer):
allow_reuse_address = 1 # Seems to make sense in testing environment
@@ -275,6 +279,7 @@ def parse_request(self):
error response has already been sent back.
"""
+ is_http_0_9 = False
self.command = None # set in case of error on the first line
self.request_version = version = self.default_request_version
self.close_connection = True
@@ -300,6 +305,10 @@ def parse_request(self):
# - Leading zeros MUST be ignored by recipients.
if len(version_number) != 2:
raise ValueError
+ if any(not component.isdigit() for component in version_number):
+ raise ValueError("non digit in http version")
+ if any(len(component) > 10 for component in version_number):
+ raise ValueError("unreasonable length http version")
version_number = int(version_number[0]), int(version_number[1])
except (ValueError, IndexError):
self.send_error(
@@ -328,8 +337,21 @@ def parse_request(self):
HTTPStatus.BAD_REQUEST,
"Bad HTTP/0.9 request type (%r)" % command)
return False
+ is_http_0_9 = True
self.command, self.path = command, path
+ # gh-87389: The purpose of replacing '//' with '/' is to protect
+ # against open redirect attacks possibly triggered if the path starts
+ # with '//' because http clients treat //path as an absolute URI
+ # without scheme (similar to http://path) rather than a path.
+ if self.path.startswith('//'):
+ self.path = '/' + self.path.lstrip('/') # Reduce to a single /
+
+ # For HTTP/0.9, headers are not expected at all.
+ if is_http_0_9:
+ self.headers = {}
+ return True
+
# Examine the headers and look for a Connection directive.
try:
self.headers = http.client.parse_headers(self.rfile,
@@ -556,6 +578,11 @@ def log_error(self, format, *args):
self.log_message(format, *args)
+ # https://en.wikipedia.org/wiki/List_of_Unicode_characters#Control_codes
+ _control_char_table = str.maketrans(
+ {c: fr'\x{c:02x}' for c in itertools.chain(range(0x20), range(0x7f,0xa0))})
+ _control_char_table[ord('\\')] = r'\\'
+
def log_message(self, format, *args):
"""Log an arbitrary message.
@@ -571,12 +598,16 @@ def log_message(self, format, *args):
The client ip and current date/time are prefixed to
every message.
+ Unicode control characters are replaced with escaped hex
+ before writing the output to stderr.
+
"""
+ message = format % args
sys.stderr.write("%s - - [%s] %s\n" %
(self.address_string(),
self.log_date_time_string(),
- format%args))
+ message.translate(self._control_char_table)))
def version_string(self):
"""Return the server software version string."""
@@ -637,6 +668,7 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
"""
server_version = "SimpleHTTP/" + __version__
+ index_pages = ("index.html", "index.htm")
extensions_map = _encodings_map_default = {
'.gz': 'application/gzip',
'.Z': 'application/octet-stream',
@@ -680,7 +712,7 @@ def send_head(self):
f = None
if os.path.isdir(path):
parts = urllib.parse.urlsplit(self.path)
- if not parts.path.endswith('/'):
+ if not parts.path.endswith(('/', '%2f', '%2F')):
# redirect browser - doing basically what apache does
self.send_response(HTTPStatus.MOVED_PERMANENTLY)
new_parts = (parts[0], parts[1], parts[2] + '/',
@@ -690,9 +722,9 @@ def send_head(self):
self.send_header("Content-Length", "0")
self.end_headers()
return None
- for index in "index.html", "index.htm":
+ for index in self.index_pages:
index = os.path.join(path, index)
- if os.path.exists(index):
+ if os.path.isfile(index):
path = index
break
else:
@@ -702,7 +734,7 @@ def send_head(self):
# The test for this was added in test_httpserver.py
# However, some OS platforms accept a trailingSlash as a filename
# See discussion on python-dev and Issue34711 regarding
- # parseing and rejection of filenames with a trailing slash
+ # parsing and rejection of filenames with a trailing slash
if path.endswith("/"):
self.send_error(HTTPStatus.NOT_FOUND, "File not found")
return None
@@ -770,21 +802,23 @@ def list_directory(self, path):
return None
list.sort(key=lambda a: a.lower())
r = []
+ displaypath = self.path
+ displaypath = displaypath.split('#', 1)[0]
+ displaypath = displaypath.split('?', 1)[0]
try:
- displaypath = urllib.parse.unquote(self.path,
+ displaypath = urllib.parse.unquote(displaypath,
errors='surrogatepass')
except UnicodeDecodeError:
- displaypath = urllib.parse.unquote(path)
+ displaypath = urllib.parse.unquote(displaypath)
displaypath = html.escape(displaypath, quote=False)
enc = sys.getfilesystemencoding()
- title = 'Directory listing for %s' % displaypath
- r.append('<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN" "http://www.w3.org/TR/html4/strict.dtd">')
- r.append('<html>\n<head>')
- r.append('<meta http-equiv="Content-Type" content="text/html; charset=%s">' % enc)
- r.append('<title>%s</title>\n</head>' % title)
- r.append('<body>\n<h1>%s</h1>' % title)
+ title = f'Directory listing for {displaypath}'
+ r.append('<!DOCTYPE HTML>')
+ r.append('<html lang="en">')
+ r.append('<head>')
+ r.append(f'<meta charset="{enc}">')
+ r.append(f'<title>{title}</title>\n</head>')
+ r.append(f'<body>\n<h1>{title}</h1>')
r.append('<hr>\n<ul>')
for name in list:
fullname = os.path.join(path, name)
@@ -820,14 +854,14 @@ def translate_path(self, path):
"""
# abandon query parameters
- path = path.split('?',1)[0]
- path = path.split('#',1)[0]
+ path = path.split('#', 1)[0]
+ path = path.split('?', 1)[0]
# Don't forget explicit trailing slash when normalizing. Issue17324
- trailing_slash = path.rstrip().endswith('/')
try:
path = urllib.parse.unquote(path, errors='surrogatepass')
except UnicodeDecodeError:
path = urllib.parse.unquote(path)
+ trailing_slash = path.endswith('/')
path = posixpath.normpath(path)
words = path.split('/')
words = filter(None, words)
@@ -877,7 +911,7 @@ def guess_type(self, path):
ext = ext.lower()
if ext in self.extensions_map:
return self.extensions_map[ext]
- guess, _ = mimetypes.guess_type(path)
+ guess, _ = mimetypes.guess_file_type(path)
if guess:
return guess
return 'application/octet-stream'
@@ -966,6 +1000,12 @@ class CGIHTTPRequestHandler(SimpleHTTPRequestHandler):
"""
+ def __init__(self, *args, **kwargs):
+ import warnings
+ warnings._deprecated("http.server.CGIHTTPRequestHandler",
+ remove=(3, 15))
+ super().__init__(*args, **kwargs)
+
# Determine platform specifics
have_fork = hasattr(os, 'fork')
@@ -1078,7 +1118,7 @@ def run_cgi(self):
"CGI script is not executable (%r)" % scriptname)
return
- # Reference: http://hoohoo.ncsa.uiuc.edu/cgi/env.html
+ # Reference: https://www6.uniovi.es/~antonio/ncsa_httpd/cgi/env.html
# XXX Much of the following could be prepared ahead of time!
env = copy.deepcopy(os.environ)
env['SERVER_SOFTWARE'] = self.version_string()
@@ -1198,7 +1238,18 @@ def run_cgi(self):
env = env
)
if self.command.lower() == "post" and nbytes > 0:
- data = self.rfile.read(nbytes)
+ cursize = 0
+ data = self.rfile.read(min(nbytes, _MIN_READ_BUF_SIZE))
+ while len(data) < nbytes and len(data) != cursize:
+ cursize = len(data)
+ # This is a geometric increase in read size (never more
+ # than doubling the current length of data per loop
+ # iteration).
+ delta = min(cursize, nbytes - cursize)
+ try:
+ data += self.rfile.read(delta)
+ except TimeoutError:
+ break
else:
data = None
# throw away additional data [see bug #427345]
@@ -1258,15 +1309,19 @@ def test(HandlerClass=BaseHTTPRequestHandler,
parser = argparse.ArgumentParser()
parser.add_argument('--cgi', action='store_true',
help='run as CGI server')
- parser.add_argument('--bind', '-b', metavar='ADDRESS',
- help='specify alternate bind address '
+ parser.add_argument('-b', '--bind', metavar='ADDRESS',
+ help='bind to this address '
'(default: all interfaces)')
- parser.add_argument('--directory', '-d', default=os.getcwd(),
- help='specify alternate directory '
+ parser.add_argument('-d', '--directory', default=os.getcwd(),
+ help='serve this directory '
'(default: current directory)')
- parser.add_argument('port', action='store', default=8000, type=int,
- nargs='?',
- help='specify alternate port (default: 8000)')
+ parser.add_argument('-p', '--protocol', metavar='VERSION',
+ default='HTTP/1.0',
+ help='conform to this HTTP version '
+ '(default: %(default)s)')
+ parser.add_argument('port', default=8000, type=int, nargs='?',
+ help='bind to this port '
+ '(default: %(default)s)')
args = parser.parse_args()
if args.cgi:
handler_class = CGIHTTPRequestHandler
@@ -1292,4 +1347,5 @@ def finish_request(self, request, client_address):
ServerClass=DualStackServer,
port=args.port,
bind=args.bind,
+ protocol=args.protocol,
)
diff --git a/Lib/json/__init__.py b/Lib/json/__init__.py
index ed2c74771ea..c7a6dcdf77e 100644
--- a/Lib/json/__init__.py
+++ b/Lib/json/__init__.py
@@ -128,8 +128,9 @@ def dump(obj, fp, *, skipkeys=False, ensure_ascii=True, check_circular=True,
instead of raising a ``TypeError``.
If ``ensure_ascii`` is false, then the strings written to ``fp`` can
- contain non-ASCII characters if they appear in strings contained in
- ``obj``. Otherwise, all such characters are escaped in JSON strings.
+ contain non-ASCII and non-printable characters if they appear in strings
+ contained in ``obj``. Otherwise, all such characters are escaped in JSON
+ strings.
If ``check_circular`` is false, then the circular reference check
for container types will be skipped and a circular reference will
@@ -145,10 +146,11 @@ def dump(obj, fp, *, skipkeys=False, ensure_ascii=True, check_circular=True,
level of 0 will only insert newlines. ``None`` is the most compact
representation.
- If specified, ``separators`` should be an ``(item_separator, key_separator)``
- tuple. The default is ``(', ', ': ')`` if *indent* is ``None`` and
- ``(',', ': ')`` otherwise. To get the most compact JSON representation,
- you should specify ``(',', ':')`` to eliminate whitespace.
+ If specified, ``separators`` should be an ``(item_separator,
+ key_separator)`` tuple. The default is ``(', ', ': ')`` if *indent* is
+ ``None`` and ``(',', ': ')`` otherwise. To get the most compact JSON
+ representation, you should specify ``(',', ':')`` to eliminate
+ whitespace.
``default(obj)`` is a function that should return a serializable version
of obj or raise TypeError. The default simply raises TypeError.
@@ -189,9 +191,10 @@ def dumps(obj, *, skipkeys=False, ensure_ascii=True, check_circular=True,
(``str``, ``int``, ``float``, ``bool``, ``None``) will be skipped
instead of raising a ``TypeError``.
- If ``ensure_ascii`` is false, then the return value can contain non-ASCII
- characters if they appear in strings contained in ``obj``. Otherwise, all
- such characters are escaped in JSON strings.
+ If ``ensure_ascii`` is false, then the return value can contain
+ non-ASCII and non-printable characters if they appear in strings
+ contained in ``obj``. Otherwise, all such characters are escaped in
+ JSON strings.
If ``check_circular`` is false, then the circular reference check
for container types will be skipped and a circular reference will
@@ -207,10 +210,11 @@ def dumps(obj, *, skipkeys=False, ensure_ascii=True, check_circular=True,
level of 0 will only insert newlines. ``None`` is the most compact
representation.
- If specified, ``separators`` should be an ``(item_separator, key_separator)``
- tuple. The default is ``(', ', ': ')`` if *indent* is ``None`` and
- ``(',', ': ')`` otherwise. To get the most compact JSON representation,
- you should specify ``(',', ':')`` to eliminate whitespace.
+ If specified, ``separators`` should be an ``(item_separator,
+ key_separator)`` tuple. The default is ``(', ', ': ')`` if *indent* is
+ ``None`` and ``(',', ': ')`` otherwise. To get the most compact JSON
+ representation, you should specify ``(',', ':')`` to eliminate
+ whitespace.
``default(obj)`` is a function that should return a serializable version
of obj or raise TypeError. The default simply raises TypeError.
@@ -281,11 +285,12 @@ def load(fp, *, cls=None, object_hook=None, parse_float=None,
``object_hook`` will be used instead of the ``dict``. This feature
can be used to implement custom decoders (e.g. JSON-RPC class hinting).
- ``object_pairs_hook`` is an optional function that will be called with the
- result of any object literal decoded with an ordered list of pairs. The
- return value of ``object_pairs_hook`` will be used instead of the ``dict``.
- This feature can be used to implement custom decoders. If ``object_hook``
- is also defined, the ``object_pairs_hook`` takes priority.
+ ``object_pairs_hook`` is an optional function that will be called with
+ the result of any object literal decoded with an ordered list of pairs.
+ The return value of ``object_pairs_hook`` will be used instead of the
+ ``dict``. This feature can be used to implement custom decoders. If
+ ``object_hook`` is also defined, the ``object_pairs_hook`` takes
+ priority.
To use a custom ``JSONDecoder`` subclass, specify it with the ``cls``
kwarg; otherwise ``JSONDecoder`` is used.
@@ -306,11 +311,12 @@ def loads(s, *, cls=None, object_hook=None, parse_float=None,
``object_hook`` will be used instead of the ``dict``. This feature
can be used to implement custom decoders (e.g. JSON-RPC class hinting).
- ``object_pairs_hook`` is an optional function that will be called with the
- result of any object literal decoded with an ordered list of pairs. The
- return value of ``object_pairs_hook`` will be used instead of the ``dict``.
- This feature can be used to implement custom decoders. If ``object_hook``
- is also defined, the ``object_pairs_hook`` takes priority.
+ ``object_pairs_hook`` is an optional function that will be called with
+ the result of any object literal decoded with an ordered list of pairs.
+ The return value of ``object_pairs_hook`` will be used instead of the
+ ``dict``. This feature can be used to implement custom decoders. If
+ ``object_hook`` is also defined, the ``object_pairs_hook`` takes
+ priority.
``parse_float``, if specified, will be called with the string
of every JSON float to be decoded. By default this is equivalent to
diff --git a/Lib/json/decoder.py b/Lib/json/decoder.py
index 9e6ca981d76..db87724a897 100644
--- a/Lib/json/decoder.py
+++ b/Lib/json/decoder.py
@@ -311,10 +311,10 @@ def __init__(self, *, object_hook=None, parse_float=None,
place of the given ``dict``. This can be used to provide custom
deserializations (e.g. to support JSON-RPC class hinting).
- ``object_pairs_hook``, if specified will be called with the result of
- every JSON object decoded with an ordered list of pairs. The return
- value of ``object_pairs_hook`` will be used instead of the ``dict``.
- This feature can be used to implement custom decoders.
+ ``object_pairs_hook``, if specified will be called with the result
+ of every JSON object decoded with an ordered list of pairs. The
+ return value of ``object_pairs_hook`` will be used instead of the
+ ``dict``. This feature can be used to implement custom decoders.
If ``object_hook`` is also defined, the ``object_pairs_hook`` takes
priority.
diff --git a/Lib/json/encoder.py b/Lib/json/encoder.py
index 08ef39d1592..0671500d106 100644
--- a/Lib/json/encoder.py
+++ b/Lib/json/encoder.py
@@ -111,9 +111,10 @@ def __init__(self, *, skipkeys=False, ensure_ascii=True,
encoding of keys that are not str, int, float, bool or None.
If skipkeys is True, such items are simply skipped.
- If ensure_ascii is true, the output is guaranteed to be str
- objects with all incoming non-ASCII characters escaped. If
- ensure_ascii is false, the output can contain non-ASCII characters.
+ If ensure_ascii is true, the output is guaranteed to be str objects
+ with all incoming non-ASCII and non-printable characters escaped.
+ If ensure_ascii is false, the output can contain non-ASCII and
+ non-printable characters.
If check_circular is true, then lists, dicts, and custom encoded
objects will be checked for circular references during encoding to
@@ -134,14 +135,15 @@ def __init__(self, *, skipkeys=False, ensure_ascii=True,
indent level. An indent level of 0 will only insert newlines.
None is the most compact representation.
- If specified, separators should be an (item_separator, key_separator)
- tuple. The default is (', ', ': ') if *indent* is ``None`` and
- (',', ': ') otherwise. To get the most compact JSON representation,
- you should specify (',', ':') to eliminate whitespace.
+ If specified, separators should be an (item_separator,
+ key_separator) tuple. The default is (', ', ': ') if *indent* is
+ ``None`` and (',', ': ') otherwise. To get the most compact JSON
+ representation, you should specify (',', ':') to eliminate
+ whitespace.
If specified, default is a function that gets called for objects
- that can't otherwise be serialized. It should return a JSON encodable
- version of the object or raise a ``TypeError``.
+ that can't otherwise be serialized. It should return a JSON
+ encodable version of the object or raise a ``TypeError``.
"""
diff --git a/Lib/test/test__opcode.py b/Lib/test/test__opcode.py
index 60dcdc6cd70..045e010db4c 100644
--- a/Lib/test/test__opcode.py
+++ b/Lib/test/test__opcode.py
@@ -16,6 +16,7 @@ def check_bool_function_result(self, func, ops, expected):
self.assertIsInstance(func(op), bool)
self.assertEqual(func(op), expected)
+ @unittest.expectedFailure # TODO: RUSTPYTHON; Move LoadClosure to pseudos
def test_invalid_opcodes(self):
invalid = [-100, -1, 255, 512, 513, 1000]
self.check_bool_function_result(_opcode.is_valid, invalid, False)
@@ -27,7 +28,6 @@ def test_invalid_opcodes(self):
self.check_bool_function_result(_opcode.has_local, invalid, False)
self.check_bool_function_result(_opcode.has_exc, invalid, False)
- @unittest.expectedFailure # TODO: RUSTPYTHON - no instrumented opcodes
def test_is_valid(self):
names = [
'CACHE',
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index 456767bbe0c..e543cc236c1 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -207,8 +207,6 @@ def __eq__(self, other):
self.assertIsNot(y, x)
self.assertEqual(y.foo, x.foo)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_copy_inst_getnewargs_ex(self):
class C(int):
def __new__(cls, *, foo):
@@ -507,8 +505,6 @@ def __eq__(self, other):
self.assertEqual(y.foo, x.foo)
self.assertIsNot(y.foo, x.foo)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_deepcopy_inst_getnewargs_ex(self):
class C(int):
def __new__(cls, *, foo):
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index b7f93d1bac9..bf9b1875573 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -698,7 +698,6 @@ def test_copy(self):
dialect = csv.get_dialect(name)
self.assertRaises(TypeError, copy.copy, dialect)
- @unittest.expectedFailure # TODO: RUSTPYTHON
def test_pickle(self):
for name in csv.list_dialects():
dialect = csv.get_dialect(name)
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index 7420a49b8f7..2ad302690c0 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -5258,7 +5258,6 @@ def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None,
self.assertEqual(obj.__reduce_ex__(proto), reduce_value)
self.assertEqual(obj.__reduce__(), reduce_value)
- @unittest.expectedFailure # TODO: RUSTPYTHON
def test_reduce(self):
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
args = (-101, "spam")
@@ -5382,7 +5381,6 @@ class C16(list):
for proto in protocols:
self._check_reduce(proto, obj, listitems=list(obj))
- @unittest.expectedFailure # TODO: RUSTPYTHON
def test_special_method_lookup(self):
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
class Picky:
@@ -5515,7 +5513,6 @@ class E(C):
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
- @unittest.expectedFailure # TODO: RUSTPYTHON
def test_reduce_copying(self):
# Tests pickling and copying new-style classes and objects.
global C1
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 5a961711cce..21a3b8edd4e 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2130,7 +2130,6 @@ class NEI(NamedInt, Enum):
test_pickle_dump_load(self.assertIs, NEI.y)
test_pickle_dump_load(self.assertIs, NEI)
- @unittest.expectedFailure # TODO: RUSTPYTHON; fails on pickle
def test_subclasses_with_getnewargs_ex(self):
class NamedInt(int):
__qualname__ = 'NamedInt' # needed for pickle protocol 4
diff --git a/Lib/test/test_http_cookiejar.py b/Lib/test/test_http_cookiejar.py
index 68a693c78b3..51fa4a3d413 100644
--- a/Lib/test/test_http_cookiejar.py
+++ b/Lib/test/test_http_cookiejar.py
@@ -1,14 +1,16 @@
"""Tests for http/cookiejar.py."""
import os
+import stat
+import sys
import re
-import test.support
+from test import support
from test.support import os_helper
from test.support import warnings_helper
+from test.support.testcase import ExtraAssertions
import time
import unittest
import urllib.request
-import pathlib
from http.cookiejar import (time2isoz, http2time, iso2time, time2netscape,
parse_ns_headers, join_header_words, split_header_words, Cookie,
@@ -17,6 +19,7 @@
reach, is_HDN, domain_match, user_domain_match, request_path,
request_port, request_host)
+mswindows = (sys.platform == "win32")
class DateTimeTests(unittest.TestCase):
@@ -104,8 +107,7 @@ def test_http2time_formats(self):
self.assertEqual(http2time(s.lower()), test_t, s.lower())
self.assertEqual(http2time(s.upper()), test_t, s.upper())
- def test_http2time_garbage(self):
- for test in [
+ @support.subTests('test', [
'',
'Garbage',
'Mandag 16. September 1996',
@@ -120,10 +122,9 @@ def test_http2time_garbage(self):
'08-01-3697739',
'09 Feb 19942632 22:23:32 GMT',
'Wed, 09 Feb 1994834 22:23:32 GMT',
- ]:
- self.assertIsNone(http2time(test),
- "http2time(%s) is not None\n"
- "http2time(test) %s" % (test, http2time(test)))
+ ])
+ def test_http2time_garbage(self, test):
+ self.assertIsNone(http2time(test))
def test_http2time_redos_regression_actually_completes(self):
# LOOSE_HTTP_DATE_RE was vulnerable to malicious input which caused catastrophic backtracking (REDoS).
@@ -148,9 +149,7 @@ def parse_date(text):
self.assertEqual(parse_date("1994-02-03 19:45:29 +0530"),
(1994, 2, 3, 14, 15, 29))
- def test_iso2time_formats(self):
- # test iso2time for supported dates.
- tests = [
+ @support.subTests('s', [
'1994-02-03 00:00:00 -0000', # ISO 8601 format
'1994-02-03 00:00:00 +0000', # ISO 8601 format
'1994-02-03 00:00:00', # zone is optional
@@ -163,16 +162,15 @@ def test_iso2time_formats(self):
# A few tests with extra space at various places
' 1994-02-03 ',
' 1994-02-03T00:00:00 ',
- ]
-
+ ])
+ def test_iso2time_formats(self, s):
+ # test iso2time for supported dates.
test_t = 760233600 # assume broken POSIX counting of seconds
- for s in tests:
- self.assertEqual(iso2time(s), test_t, s)
- self.assertEqual(iso2time(s.lower()), test_t, s.lower())
- self.assertEqual(iso2time(s.upper()), test_t, s.upper())
+ self.assertEqual(iso2time(s), test_t, s)
+ self.assertEqual(iso2time(s.lower()), test_t, s.lower())
+ self.assertEqual(iso2time(s.upper()), test_t, s.upper())
- def test_iso2time_garbage(self):
- for test in [
+ @support.subTests('test', [
'',
'Garbage',
'Thursday, 03-Feb-94 00:00:00 GMT',
@@ -185,11 +183,10 @@ def test_iso2time_garbage(self):
'01-01-1980 00:00:62',
'01-01-1980T00:00:62',
'19800101T250000Z',
- ]:
- self.assertIsNone(iso2time(test),
- "iso2time(%r)" % test)
+ ])
+ def test_iso2time_garbage(self, test):
+ self.assertIsNone(iso2time(test))
- @unittest.skip("TODO, RUSTPYTHON, regressed to quadratic complexity")
def test_iso2time_performance_regression(self):
# If ISO_DATE_RE regresses to quadratic complexity, this test will take a very long time to succeed.
# If fixed, it should complete within a fraction of a second.
@@ -199,24 +196,23 @@ def test_iso2time_performance_regression(self):
class HeaderTests(unittest.TestCase):
- def test_parse_ns_headers(self):
- # quotes should be stripped
- expected = [[('foo', 'bar'), ('expires', 2209069412), ('version', '0')]]
- for hdr in [
+ @support.subTests('hdr', [
'foo=bar; expires=01 Jan 2040 22:23:32 GMT',
'foo=bar; expires="01 Jan 2040 22:23:32 GMT"',
- ]:
- self.assertEqual(parse_ns_headers([hdr]), expected)
-
- def test_parse_ns_headers_version(self):
-
+ ])
+ def test_parse_ns_headers(self, hdr):
# quotes should be stripped
- expected = [[('foo', 'bar'), ('version', '1')]]
- for hdr in [
+ expected = [[('foo', 'bar'), ('expires', 2209069412), ('version', '0')]]
+ self.assertEqual(parse_ns_headers([hdr]), expected)
+
+ @support.subTests('hdr', [
'foo=bar; version="1"',
'foo=bar; Version="1"',
- ]:
- self.assertEqual(parse_ns_headers([hdr]), expected)
+ ])
+ def test_parse_ns_headers_version(self, hdr):
+ # quotes should be stripped
+ expected = [[('foo', 'bar'), ('version', '1')]]
+ self.assertEqual(parse_ns_headers([hdr]), expected)
def test_parse_ns_headers_special_names(self):
# names such as 'expires' are not special in first name=value pair
@@ -232,8 +228,7 @@ def test_join_header_words(self):
self.assertEqual(join_header_words([[]]), "")
- def test_split_header_words(self):
- tests = [
+ @support.subTests('arg,expect', [
("foo", [[("foo", None)]]),
("foo=bar", [[("foo", "bar")]]),
(" foo ", [[("foo", None)]]),
@@ -250,24 +245,22 @@ def test_split_header_words(self):
(r'foo; bar=baz, spam=, foo="\,\;\"", bar= ',
[[("foo", None), ("bar", "baz")],
[("spam", "")], [("foo", ',;"')], [("bar", "")]]),
- ]
-
- for arg, expect in tests:
- try:
- result = split_header_words([arg])
- except:
- import traceback, io
- f = io.StringIO()
- traceback.print_exc(None, f)
- result = "(error -- traceback follows)\n\n%s" % f.getvalue()
- self.assertEqual(result, expect, """
+ ])
+ def test_split_header_words(self, arg, expect):
+ try:
+ result = split_header_words([arg])
+ except:
+ import traceback, io
+ f = io.StringIO()
+ traceback.print_exc(None, f)
+ result = "(error -- traceback follows)\n\n%s" % f.getvalue()
+ self.assertEqual(result, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
""" % (arg, expect, result))
- def test_roundtrip(self):
- tests = [
+ @support.subTests('arg,expect', [
("foo", "foo"),
("foo=bar", "foo=bar"),
(" foo ", "foo"),
@@ -276,23 +269,35 @@ def test_roundtrip(self):
("foo=bar;bar=baz", "foo=bar; bar=baz"),
('foo bar baz', "foo; bar; baz"),
(r'foo="\"" bar="\\"', r'foo="\""; bar="\\"'),
+ ("föo=bär", 'föo="bär"'),
('foo,,,bar', 'foo, bar'),
('foo=bar,bar=baz', 'foo=bar, bar=baz'),
+ ("foo=\n", 'foo=""'),
+ ('foo="\n"', 'foo="\n"'),
+ ('foo=bar\n', 'foo=bar'),
+ ('foo="bar\n"', 'foo="bar\n"'),
+ ('foo=bar\nbaz', 'foo=bar; baz'),
+ ('foo="bar\nbaz"', 'foo="bar\nbaz"'),
('text/html; charset=iso-8859-1',
- 'text/html; charset="iso-8859-1"'),
+ 'text/html; charset=iso-8859-1'),
+
+ ('text/html; charset="iso-8859/1"',
+ 'text/html; charset="iso-8859/1"'),
('foo="bar"; port="80,81"; discard, bar=baz',
'foo=bar; port="80,81"; discard, bar=baz'),
(r'Basic realm="\"foo\\\\bar\""',
- r'Basic; realm="\"foo\\\\bar\""')
- ]
-
- for arg, expect in tests:
- input = split_header_words([arg])
- res = join_header_words(input)
- self.assertEqual(res, expect, """
+ r'Basic; realm="\"foo\\\\bar\""'),
+
+ ('n; foo="foo;_", bar="foo,_"',
+ 'n; foo="foo;_", bar="foo,_"'),
+ ])
+ def test_roundtrip(self, arg, expect):
+ input = split_header_words([arg])
+ res = join_header_words(input)
+ self.assertEqual(res, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
@@ -336,9 +341,9 @@ def test_constructor_with_str(self):
self.assertEqual(c.filename, filename)
def test_constructor_with_path_like(self):
- filename = pathlib.Path(os_helper.TESTFN)
- c = LWPCookieJar(filename)
- self.assertEqual(c.filename, os.fspath(filename))
+ filename = os_helper.TESTFN
+ c = LWPCookieJar(os_helper.FakePath(filename))
+ self.assertEqual(c.filename, filename)
def test_constructor_with_none(self):
c = LWPCookieJar(None)
@@ -365,10 +370,63 @@ def test_lwp_valueless_cookie(self):
c = LWPCookieJar()
c.load(filename, ignore_discard=True)
finally:
- try: os.unlink(filename)
- except OSError: pass
+ os_helper.unlink(filename)
self.assertEqual(c._cookies["www.acme.com"]["/"]["boo"].value, None)
+ @unittest.skipIf(mswindows, "windows file permissions are incompatible with file modes")
+ @os_helper.skip_unless_working_chmod
+ def test_lwp_filepermissions(self):
+ # Cookie file should only be readable by the creator
+ filename = os_helper.TESTFN
+ c = LWPCookieJar()
+ interact_netscape(c, "http://www.acme.com/", 'boo')
+ try:
+ c.save(filename, ignore_discard=True)
+ st = os.stat(filename)
+ self.assertEqual(stat.S_IMODE(st.st_mode), 0o600)
+ finally:
+ os_helper.unlink(filename)
+
+ @unittest.skipIf(mswindows, "windows file permissions are incompatible with file modes")
+ @os_helper.skip_unless_working_chmod
+ def test_mozilla_filepermissions(self):
+ # Cookie file should only be readable by the creator
+ filename = os_helper.TESTFN
+ c = MozillaCookieJar()
+ interact_netscape(c, "http://www.acme.com/", 'boo')
+ try:
+ c.save(filename, ignore_discard=True)
+ st = os.stat(filename)
+ self.assertEqual(stat.S_IMODE(st.st_mode), 0o600)
+ finally:
+ os_helper.unlink(filename)
+
+ @unittest.skipIf(mswindows, "windows file permissions are incompatible with file modes")
+ @os_helper.skip_unless_working_chmod
+ def test_cookie_files_are_truncated(self):
+ filename = os_helper.TESTFN
+ for cookiejar_class in (LWPCookieJar, MozillaCookieJar):
+ c = cookiejar_class(filename)
+
+ req = urllib.request.Request("http://www.acme.com/")
+ headers = ["Set-Cookie: pll_lang=en; Max-Age=31536000; path=/"]
+ res = FakeResponse(headers, "http://www.acme.com/")
+ c.extract_cookies(res, req)
+ self.assertEqual(len(c), 1)
+
+ try:
+ # Save the first version with contents:
+ c.save()
+ # Now, clear cookies and re-save:
+ c.clear()
+ c.save()
+ # Check that file was truncated:
+ c.load()
+ finally:
+ os_helper.unlink(filename)
+
+ self.assertEqual(len(c), 0)
+
def test_bad_magic(self):
# OSErrors (eg. file doesn't exist) are allowed to propagate
filename = os_helper.TESTFN
@@ -392,8 +450,7 @@ def test_bad_magic(self):
c = cookiejar_class()
self.assertRaises(LoadError, c.load, filename)
finally:
- try: os.unlink(filename)
- except OSError: pass
+ os_helper.unlink(filename)
class CookieTests(unittest.TestCase):
# XXX
@@ -442,14 +499,7 @@ class CookieTests(unittest.TestCase):
## just the 7 special TLD's listed in their spec. And folks rely on
## that...
- def test_domain_return_ok(self):
- # test optimization: .domain_return_ok() should filter out most
- # domains in the CookieJar before we try to access them (because that
- # may require disk access -- in particular, with MSIECookieJar)
- # This is only a rough check for performance reasons, so it's not too
- # critical as long as it's sufficiently liberal.
- pol = DefaultCookiePolicy()
- for url, domain, ok in [
+ @support.subTests('url,domain,ok', [
("http://foo.bar.com/", "blah.com", False),
("http://foo.bar.com/", "rhubarb.blah.com", False),
("http://foo.bar.com/", "rhubarb.foo.bar.com", False),
@@ -469,11 +519,18 @@ def test_domain_return_ok(self):
("http://foo/", ".local", True),
("http://barfoo.com", ".foo.com", False),
("http://barfoo.com", "foo.com", False),
- ]:
- request = urllib.request.Request(url)
- r = pol.domain_return_ok(domain, request)
- if ok: self.assertTrue(r)
- else: self.assertFalse(r)
+ ])
+ def test_domain_return_ok(self, url, domain, ok):
+ # test optimization: .domain_return_ok() should filter out most
+ # domains in the CookieJar before we try to access them (because that
+ # may require disk access -- in particular, with MSIECookieJar)
+ # This is only a rough check for performance reasons, so it's not too
+ # critical as long as it's sufficiently liberal.
+ pol = DefaultCookiePolicy()
+ request = urllib.request.Request(url)
+ r = pol.domain_return_ok(domain, request)
+ if ok: self.assertTrue(r)
+ else: self.assertFalse(r)
def test_missing_value(self):
# missing = sign in Cookie: header is regarded by Mozilla as a missing
@@ -489,7 +546,7 @@ def test_missing_value(self):
self.assertIsNone(cookie.value)
self.assertEqual(cookie.name, '"spam"')
self.assertEqual(lwp_cookie_str(cookie), (
- r'"spam"; path="/foo/"; domain="www.acme.com"; '
+ r'"spam"; path="/foo/"; domain=www.acme.com; '
'path_spec; discard; version=0'))
old_str = repr(c)
c.save(ignore_expires=True, ignore_discard=True)
@@ -497,7 +554,7 @@ def test_missing_value(self):
c = MozillaCookieJar(filename)
c.revert(ignore_expires=True, ignore_discard=True)
finally:
- os.unlink(c.filename)
+ os_helper.unlink(c.filename)
# cookies unchanged apart from lost info re. whether path was specified
self.assertEqual(
repr(c),
@@ -507,10 +564,7 @@ def test_missing_value(self):
self.assertEqual(interact_netscape(c, "http://www.acme.com/foo/"),
'"spam"; eggs')
- def test_rfc2109_handling(self):
- # RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
- # dependent on policy settings
- for rfc2109_as_netscape, rfc2965, version in [
+ @support.subTests('rfc2109_as_netscape,rfc2965,version', [
# default according to rfc2965 if not explicitly specified
(None, False, 0),
(None, True, 1),
@@ -519,24 +573,27 @@ def test_rfc2109_handling(self):
(False, True, 1),
(True, False, 0),
(True, True, 0),
- ]:
- policy = DefaultCookiePolicy(
- rfc2109_as_netscape=rfc2109_as_netscape,
- rfc2965=rfc2965)
- c = CookieJar(policy)
- interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
- try:
- cookie = c._cookies["www.example.com"]["/"]["ni"]
- except KeyError:
- self.assertIsNone(version) # didn't expect a stored cookie
- else:
- self.assertEqual(cookie.version, version)
- # 2965 cookies are unaffected
- interact_2965(c, "http://www.example.com/",
- "foo=bar; Version=1")
- if rfc2965:
- cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
- self.assertEqual(cookie2965.version, 1)
+ ])
+ def test_rfc2109_handling(self, rfc2109_as_netscape, rfc2965, version):
+ # RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
+ # dependent on policy settings
+ policy = DefaultCookiePolicy(
+ rfc2109_as_netscape=rfc2109_as_netscape,
+ rfc2965=rfc2965)
+ c = CookieJar(policy)
+ interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
+ try:
+ cookie = c._cookies["www.example.com"]["/"]["ni"]
+ except KeyError:
+ self.assertIsNone(version) # didn't expect a stored cookie
+ else:
+ self.assertEqual(cookie.version, version)
+ # 2965 cookies are unaffected
+ interact_2965(c, "http://www.example.com/",
+ "foo=bar; Version=1")
+ if rfc2965:
+ cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
+ self.assertEqual(cookie2965.version, 1)
def test_ns_parser(self):
c = CookieJar()
@@ -597,8 +654,6 @@ def test_ns_parser_special_names(self):
self.assertIn('expires', cookies)
self.assertIn('version', cookies)
- # TODO: RUSTPYTHON; need to update http library to remove warnings
- @unittest.expectedFailure
def test_expires(self):
# if expires is in future, keep cookie...
c = CookieJar()
@@ -706,8 +761,7 @@ def test_default_path_with_query(self):
# Cookie is sent back to the same URI.
self.assertEqual(interact_netscape(cj, uri), value)
- def test_escape_path(self):
- cases = [
+ @support.subTests('arg,result', [
# quoted safe
("/foo%2f/bar", "/foo%2F/bar"),
("/foo%2F/bar", "/foo%2F/bar"),
@@ -727,9 +781,9 @@ def test_escape_path(self):
("/foo/bar\u00fc", "/foo/bar%C3%BC"), # UTF-8 encoded
# unicode
("/foo/bar\uabcd", "/foo/bar%EA%AF%8D"), # UTF-8 encoded
- ]
- for arg, result in cases:
- self.assertEqual(escape_path(arg), result)
+ ])
+ def test_escape_path(self, arg, result):
+ self.assertEqual(escape_path(arg), result)
def test_request_path(self):
# with parameters
@@ -923,6 +977,48 @@ def test_two_component_domain_ns(self):
## self.assertEqual(len(c), 2)
self.assertEqual(len(c), 4)
+ def test_localhost_domain(self):
+ c = CookieJar()
+
+ interact_netscape(c, "http://localhost", "foo=bar; domain=localhost;")
+
+ self.assertEqual(len(c), 1)
+
+ def test_localhost_domain_contents(self):
+ c = CookieJar()
+
+ interact_netscape(c, "http://localhost", "foo=bar; domain=localhost;")
+
+ self.assertEqual(c._cookies[".localhost"]["/"]["foo"].value, "bar")
+
+ def test_localhost_domain_contents_2(self):
+ c = CookieJar()
+
+ interact_netscape(c, "http://localhost", "foo=bar;")
+
+ self.assertEqual(c._cookies["localhost.local"]["/"]["foo"].value, "bar")
+
+ def test_evil_nonlocal_domain(self):
+ c = CookieJar()
+
+ interact_netscape(c, "http://evil.com", "foo=bar; domain=.localhost")
+
+ self.assertEqual(len(c), 0)
+
+ def test_evil_local_domain(self):
+ c = CookieJar()
+
+ interact_netscape(c, "http://localhost", "foo=bar; domain=.evil.com")
+
+ self.assertEqual(len(c), 0)
+
+ def test_evil_local_domain_2(self):
+ c = CookieJar()
+
+ interact_netscape(c, "http://localhost", "foo=bar; domain=.someother.local")
+
+ self.assertEqual(len(c), 0)
+
def test_two_component_domain_rfc2965(self):
pol = DefaultCookiePolicy(rfc2965=True)
c = CookieJar(pol)
@@ -1254,11 +1350,11 @@ def test_Cookie_iterator(self):
r'port="90,100, 80,8080"; '
r'max-age=100; Comment = "Just kidding! (\"|\\\\) "')
- versions = [1, 1, 1, 0, 1]
- names = ["bang", "foo", "foo", "spam", "foo"]
- domains = [".sol.no", "blah.spam.org", "www.acme.com",
- "www.acme.com", "www.acme.com"]
- paths = ["/", "/", "/", "/blah", "/blah/"]
+ versions = [1, 0, 1, 1, 1]
+ names = ["foo", "spam", "foo", "foo", "bang"]
+ domains = ["blah.spam.org", "www.acme.com", "www.acme.com",
+ "www.acme.com", ".sol.no"]
+ paths = ["/", "/blah", "/blah/", "/", "/"]
for i in range(4):
i = 0
@@ -1331,7 +1427,7 @@ def cookiejar_from_cookie_headers(headers):
self.assertIsNone(cookie.expires)
-class LWPCookieTests(unittest.TestCase):
+class LWPCookieTests(unittest.TestCase, ExtraAssertions):
# Tests taken from libwww-perl, with a few modifications and additions.
def test_netscape_example_1(self):
@@ -1423,7 +1519,7 @@ def test_netscape_example_1(self):
h = req.get_header("Cookie")
self.assertIn("PART_NUMBER=ROCKET_LAUNCHER_0001", h)
self.assertIn("CUSTOMER=WILE_E_COYOTE", h)
- self.assertTrue(h.startswith("SHIPPING=FEDEX;"))
+ self.assertStartsWith(h, "SHIPPING=FEDEX;")
def test_netscape_example_2(self):
# Second Example transaction sequence:
@@ -1727,8 +1823,7 @@ def test_rejection(self):
c = LWPCookieJar(policy=pol)
c.load(filename, ignore_discard=True)
finally:
- try: os.unlink(filename)
- except OSError: pass
+ os_helper.unlink(filename)
self.assertEqual(old, repr(c))
@@ -1787,8 +1882,7 @@ def save_and_restore(cj, ignore_discard):
DefaultCookiePolicy(rfc2965=True))
new_c.load(ignore_discard=ignore_discard)
finally:
- try: os.unlink(filename)
- except OSError: pass
+ os_helper.unlink(filename)
return new_c
new_c = save_and_restore(c, True)
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
index d4a6eefe322..275578d53cb 100644
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -1,4 +1,4 @@
-import sys
+import enum
import errno
from http import client, HTTPStatus
import io
@@ -8,7 +8,6 @@
import re
import socket
import threading
-import warnings
import unittest
from unittest import mock
@@ -17,16 +16,19 @@
from test import support
from test.support import os_helper
from test.support import socket_helper
-from test.support import warnings_helper
+from test.support.testcase import ExtraAssertions
+support.requires_working_socket(module=True)
here = os.path.dirname(__file__)
# Self-signed cert file for 'localhost'
-CERT_localhost = os.path.join(here, 'certdata/keycert.pem')
+CERT_localhost = os.path.join(here, 'certdata', 'keycert.pem')
# Self-signed cert file for 'fakehostname'
-CERT_fakehostname = os.path.join(here, 'certdata/keycert2.pem')
+CERT_fakehostname = os.path.join(here, 'certdata', 'keycert2.pem')
# Self-signed cert file for self-signed.pythontest.net
-CERT_selfsigned_pythontestdotnet = os.path.join(here, 'certdata/selfsigned_pythontestdotnet.pem')
+CERT_selfsigned_pythontestdotnet = os.path.join(
+ here, 'certdata', 'selfsigned_pythontestdotnet.pem',
+)
# constants for testing chunked encoding
chunked_start = (
@@ -133,7 +135,7 @@ def connect(self):
def create_connection(self, *pos, **kw):
return FakeSocket(*self.fake_socket_args)
-class HeaderTests(TestCase):
+class HeaderTests(TestCase, ExtraAssertions):
def test_auto_headers(self):
# Some headers are added automatically, but should not be added by
# .request() if they are explicitly set.
@@ -272,7 +274,7 @@ def test_ipv6host_header(self):
sock = FakeSocket('')
conn.sock = sock
conn.request('GET', '/foo')
- self.assertTrue(sock.data.startswith(expected))
+ self.assertStartsWith(sock.data, expected)
expected = b'GET /foo HTTP/1.1\r\nHost: [2001:102A::]\r\n' \
b'Accept-Encoding: identity\r\n\r\n'
@@ -280,7 +282,23 @@ def test_ipv6host_header(self):
sock = FakeSocket('')
conn.sock = sock
conn.request('GET', '/foo')
- self.assertTrue(sock.data.startswith(expected))
+ self.assertStartsWith(sock.data, expected)
+
+ expected = b'GET /foo HTTP/1.1\r\nHost: [fe80::]\r\n' \
+ b'Accept-Encoding: identity\r\n\r\n'
+ conn = client.HTTPConnection('[fe80::%2]')
+ sock = FakeSocket('')
+ conn.sock = sock
+ conn.request('GET', '/foo')
+ self.assertStartsWith(sock.data, expected)
+
+ expected = b'GET /foo HTTP/1.1\r\nHost: [fe80::]:81\r\n' \
+ b'Accept-Encoding: identity\r\n\r\n'
+ conn = client.HTTPConnection('[fe80::%2]:81')
+ sock = FakeSocket('')
+ conn.sock = sock
+ conn.request('GET', '/foo')
+ self.assertStartsWith(sock.data, expected)
def test_malformed_headers_coped_with(self):
# Issue 19996
@@ -318,9 +336,9 @@ def test_parse_all_octets(self):
self.assertIsNotNone(resp.getheader('obs-text'))
self.assertIn('obs-text', resp.msg)
for folded in (resp.getheader('obs-fold'), resp.msg['obs-fold']):
- self.assertTrue(folded.startswith('text'))
+ self.assertStartsWith(folded, 'text')
self.assertIn(' folded with space', folded)
- self.assertTrue(folded.endswith('folded with tab'))
+ self.assertEndsWith(folded, 'folded with tab')
def test_invalid_headers(self):
conn = client.HTTPConnection('example.com')
@@ -520,11 +538,203 @@ def _parse_chunked(self, data):
return b''.join(body)
-class BasicTest(TestCase):
+class BasicTest(TestCase, ExtraAssertions):
def test_dir_with_added_behavior_on_status(self):
# see issue40084
self.assertTrue({'description', 'name', 'phrase', 'value'} <= set(dir(HTTPStatus(404))))
+ def test_simple_httpstatus(self):
+ class CheckedHTTPStatus(enum.IntEnum):
+ """HTTP status codes and reason phrases
+
+ Status codes from the following RFCs are all observed:
+
+ * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616
+ * RFC 6585: Additional HTTP Status Codes
+ * RFC 3229: Delta encoding in HTTP
+ * RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518
+ * RFC 5842: Binding Extensions to WebDAV
+ * RFC 7238: Permanent Redirect
+ * RFC 2295: Transparent Content Negotiation in HTTP
+ * RFC 2774: An HTTP Extension Framework
+ * RFC 7725: An HTTP Status Code to Report Legal Obstacles
+ * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2)
+ * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0)
+ * RFC 8297: An HTTP Status Code for Indicating Hints
+ * RFC 8470: Using Early Data in HTTP
+ """
+ def __new__(cls, value, phrase, description=''):
+ obj = int.__new__(cls, value)
+ obj._value_ = value
+
+ obj.phrase = phrase
+ obj.description = description
+ return obj
+
+ @property
+ def is_informational(self):
+ return 100 <= self <= 199
+
+ @property
+ def is_success(self):
+ return 200 <= self <= 299
+
+ @property
+ def is_redirection(self):
+ return 300 <= self <= 399
+
+ @property
+ def is_client_error(self):
+ return 400 <= self <= 499
+
+ @property
+ def is_server_error(self):
+ return 500 <= self <= 599
+
+ # informational
+ CONTINUE = 100, 'Continue', 'Request received, please continue'
+ SWITCHING_PROTOCOLS = (101, 'Switching Protocols',
+ 'Switching to new protocol; obey Upgrade header')
+ PROCESSING = 102, 'Processing'
+ EARLY_HINTS = 103, 'Early Hints'
+ # success
+ OK = 200, 'OK', 'Request fulfilled, document follows'
+ CREATED = 201, 'Created', 'Document created, URL follows'
+ ACCEPTED = (202, 'Accepted',
+ 'Request accepted, processing continues off-line')
+ NON_AUTHORITATIVE_INFORMATION = (203,
+ 'Non-Authoritative Information', 'Request fulfilled from cache')
+ NO_CONTENT = 204, 'No Content', 'Request fulfilled, nothing follows'
+ RESET_CONTENT = 205, 'Reset Content', 'Clear input form for further input'
+ PARTIAL_CONTENT = 206, 'Partial Content', 'Partial content follows'
+ MULTI_STATUS = 207, 'Multi-Status'
+ ALREADY_REPORTED = 208, 'Already Reported'
+ IM_USED = 226, 'IM Used'
+ # redirection
+ MULTIPLE_CHOICES = (300, 'Multiple Choices',
+ 'Object has several resources -- see URI list')
+ MOVED_PERMANENTLY = (301, 'Moved Permanently',
+ 'Object moved permanently -- see URI list')
+ FOUND = 302, 'Found', 'Object moved temporarily -- see URI list'
+ SEE_OTHER = 303, 'See Other', 'Object moved -- see Method and URL list'
+ NOT_MODIFIED = (304, 'Not Modified',
+ 'Document has not changed since given time')
+ USE_PROXY = (305, 'Use Proxy',
+ 'You must use proxy specified in Location to access this resource')
+ TEMPORARY_REDIRECT = (307, 'Temporary Redirect',
+ 'Object moved temporarily -- see URI list')
+ PERMANENT_REDIRECT = (308, 'Permanent Redirect',
+ 'Object moved permanently -- see URI list')
+ # client error
+ BAD_REQUEST = (400, 'Bad Request',
+ 'Bad request syntax or unsupported method')
+ UNAUTHORIZED = (401, 'Unauthorized',
+ 'No permission -- see authorization schemes')
+ PAYMENT_REQUIRED = (402, 'Payment Required',
+ 'No payment -- see charging schemes')
+ FORBIDDEN = (403, 'Forbidden',
+ 'Request forbidden -- authorization will not help')
+ NOT_FOUND = (404, 'Not Found',
+ 'Nothing matches the given URI')
+ METHOD_NOT_ALLOWED = (405, 'Method Not Allowed',
+ 'Specified method is invalid for this resource')
+ NOT_ACCEPTABLE = (406, 'Not Acceptable',
+ 'URI not available in preferred format')
+ PROXY_AUTHENTICATION_REQUIRED = (407,
+ 'Proxy Authentication Required',
+ 'You must authenticate with this proxy before proceeding')
+ REQUEST_TIMEOUT = (408, 'Request Timeout',
+ 'Request timed out; try again later')
+ CONFLICT = 409, 'Conflict', 'Request conflict'
+ GONE = (410, 'Gone',
+ 'URI no longer exists and has been permanently removed')
+ LENGTH_REQUIRED = (411, 'Length Required',
+ 'Client must specify Content-Length')
+ PRECONDITION_FAILED = (412, 'Precondition Failed',
+ 'Precondition in headers is false')
+ CONTENT_TOO_LARGE = (413, 'Content Too Large',
+ 'Content is too large')
+ REQUEST_ENTITY_TOO_LARGE = CONTENT_TOO_LARGE
+ URI_TOO_LONG = (414, 'URI Too Long', 'URI is too long')
+ REQUEST_URI_TOO_LONG = URI_TOO_LONG
+ UNSUPPORTED_MEDIA_TYPE = (415, 'Unsupported Media Type',
+ 'Entity body in unsupported format')
+ RANGE_NOT_SATISFIABLE = (416,
+ 'Range Not Satisfiable',
+ 'Cannot satisfy request range')
+ REQUESTED_RANGE_NOT_SATISFIABLE = RANGE_NOT_SATISFIABLE
+ EXPECTATION_FAILED = (417, 'Expectation Failed',
+ 'Expect condition could not be satisfied')
+ IM_A_TEAPOT = (418, 'I\'m a Teapot',
+ 'Server refuses to brew coffee because it is a teapot.')
+ MISDIRECTED_REQUEST = (421, 'Misdirected Request',
+ 'Server is not able to produce a response')
+ UNPROCESSABLE_CONTENT = 422, 'Unprocessable Content'
+ UNPROCESSABLE_ENTITY = UNPROCESSABLE_CONTENT
+ LOCKED = 423, 'Locked'
+ FAILED_DEPENDENCY = 424, 'Failed Dependency'
+ TOO_EARLY = 425, 'Too Early'
+ UPGRADE_REQUIRED = 426, 'Upgrade Required'
+ PRECONDITION_REQUIRED = (428, 'Precondition Required',
+ 'The origin server requires the request to be conditional')
+ TOO_MANY_REQUESTS = (429, 'Too Many Requests',
+ 'The user has sent too many requests in '
+ 'a given amount of time ("rate limiting")')
+ REQUEST_HEADER_FIELDS_TOO_LARGE = (431,
+ 'Request Header Fields Too Large',
+ 'The server is unwilling to process the request because its header '
+ 'fields are too large')
+ UNAVAILABLE_FOR_LEGAL_REASONS = (451,
+ 'Unavailable For Legal Reasons',
+ 'The server is denying access to the '
+ 'resource as a consequence of a legal demand')
+ # server errors
+ INTERNAL_SERVER_ERROR = (500, 'Internal Server Error',
+ 'Server got itself in trouble')
+ NOT_IMPLEMENTED = (501, 'Not Implemented',
+ 'Server does not support this operation')
+ BAD_GATEWAY = (502, 'Bad Gateway',
+ 'Invalid responses from another server/proxy')
+ SERVICE_UNAVAILABLE = (503, 'Service Unavailable',
+ 'The server cannot process the request due to a high load')
+ GATEWAY_TIMEOUT = (504, 'Gateway Timeout',
+ 'The gateway server did not receive a timely response')
+ HTTP_VERSION_NOT_SUPPORTED = (505, 'HTTP Version Not Supported',
+ 'Cannot fulfill request')
+ VARIANT_ALSO_NEGOTIATES = 506, 'Variant Also Negotiates'
+ INSUFFICIENT_STORAGE = 507, 'Insufficient Storage'
+ LOOP_DETECTED = 508, 'Loop Detected'
+ NOT_EXTENDED = 510, 'Not Extended'
+ NETWORK_AUTHENTICATION_REQUIRED = (511,
+ 'Network Authentication Required',
+ 'The client needs to authenticate to gain network access')
+ enum._test_simple_enum(CheckedHTTPStatus, HTTPStatus)
+
+ def test_httpstatus_range(self):
+ """Checks that the statuses are in the 100-599 range"""
+
+ for member in HTTPStatus.__members__.values():
+ self.assertGreaterEqual(member, 100)
+ self.assertLessEqual(member, 599)
+
+ def test_httpstatus_category(self):
+ """Checks that the statuses belong to the standard categories"""
+
+ categories = (
+ ((100, 199), "is_informational"),
+ ((200, 299), "is_success"),
+ ((300, 399), "is_redirection"),
+ ((400, 499), "is_client_error"),
+ ((500, 599), "is_server_error"),
+ )
+ for member in HTTPStatus.__members__.values():
+ for (lower, upper), category in categories:
+ category_indicator = getattr(member, category)
+ if lower <= member <= upper:
+ self.assertTrue(category_indicator)
+ else:
+ self.assertFalse(category_indicator)
+
def test_status_lines(self):
# Test HTTP status lines
@@ -780,8 +990,7 @@ def test_send_file(self):
sock = FakeSocket(body)
conn.sock = sock
conn.request('GET', '/foo', body)
- self.assertTrue(sock.data.startswith(expected), '%r != %r' %
- (sock.data[:len(expected)], expected))
+ self.assertStartsWith(sock.data, expected)
def test_send(self):
expected = b'this is a test this is only a test'
@@ -872,6 +1081,25 @@ def test_chunked(self):
self.assertEqual(resp.read(), expected)
resp.close()
+ # Explicit full read
+ for n in (-123, -1, None):
+ with self.subTest('full read', n=n):
+ sock = FakeSocket(chunked_start + last_chunk + chunked_end)
+ resp = client.HTTPResponse(sock, method="GET")
+ resp.begin()
+ self.assertTrue(resp.chunked)
+ self.assertEqual(resp.read(n), expected)
+ resp.close()
+
+ # Read first chunk
+ with self.subTest('read1(-1)'):
+ sock = FakeSocket(chunked_start + last_chunk + chunked_end)
+ resp = client.HTTPResponse(sock, method="GET")
+ resp.begin()
+ self.assertTrue(resp.chunked)
+ self.assertEqual(resp.read1(-1), b"hello worl")
+ resp.close()
+
# Various read sizes
for n in range(1, 12):
sock = FakeSocket(chunked_start + last_chunk + chunked_end)
@@ -1227,6 +1455,72 @@ def run_server():
thread.join()
self.assertEqual(result, b"proxied data\n")
+ def test_large_content_length(self):
+ serv = socket.create_server((HOST, 0))
+ self.addCleanup(serv.close)
+
+ def run_server():
+ [conn, address] = serv.accept()
+ with conn:
+ while conn.recv(1024):
+ conn.sendall(
+ b"HTTP/1.1 200 Ok\r\n"
+ b"Content-Length: %d\r\n"
+ b"\r\n" % size)
+ conn.sendall(b'A' * (size//3))
+ conn.sendall(b'B' * (size - size//3))
+
+ thread = threading.Thread(target=run_server)
+ thread.start()
+ self.addCleanup(thread.join, 1.0)
+
+ conn = client.HTTPConnection(*serv.getsockname())
+ try:
+ for w in range(15, 27):
+ size = 1 << w
+ conn.request("GET", "/")
+ with conn.getresponse() as response:
+ self.assertEqual(len(response.read()), size)
+ finally:
+ conn.close()
+ thread.join(1.0)
+
+ def test_large_content_length_truncated(self):
+ serv = socket.create_server((HOST, 0))
+ self.addCleanup(serv.close)
+
+ def run_server():
+ while True:
+ [conn, address] = serv.accept()
+ with conn:
+ conn.recv(1024)
+ if not size:
+ break
+ conn.sendall(
+ b"HTTP/1.1 200 Ok\r\n"
+ b"Content-Length: %d\r\n"
+ b"\r\n"
+ b"Text" % size)
+
+ thread = threading.Thread(target=run_server)
+ thread.start()
+ self.addCleanup(thread.join, 1.0)
+
+ conn = client.HTTPConnection(*serv.getsockname())
+ try:
+ for w in range(18, 65):
+ size = 1 << w
+ conn.request("GET", "/")
+ with conn.getresponse() as response:
+ self.assertRaises(client.IncompleteRead, response.read)
+ conn.close()
+ finally:
+ conn.close()
+ size = 0
+ conn.request("GET", "/")
+ conn.close()
+ thread.join(1.0)
+
def test_putrequest_override_domain_validation(self):
"""
It should be possible to override the default validation
@@ -1266,7 +1560,7 @@ def _encode_request(self, str_url):
conn.putrequest('GET', '/☃')
-class ExtendedReadTest(TestCase):
+class ExtendedReadTest(TestCase, ExtraAssertions):
"""
Test peek(), read1(), readline()
"""
@@ -1325,7 +1619,7 @@ def mypeek(n=-1):
# then unbounded peek
p2 = resp.peek()
self.assertGreaterEqual(len(p2), len(p))
- self.assertTrue(p2.startswith(p))
+ self.assertStartsWith(p2, p)
next = resp.read(len(p2))
self.assertEqual(next, p2)
else:
@@ -1340,18 +1634,22 @@ def test_readline(self):
resp = self.resp
self._verify_readline(self.resp.readline, self.lines_expected)
- def _verify_readline(self, readline, expected):
+ def test_readline_without_limit(self):
+ self._verify_readline(self.resp.readline, self.lines_expected, limit=-1)
+
+ def _verify_readline(self, readline, expected, limit=5):
all = []
while True:
# short readlines
- line = readline(5)
+ line = readline(limit)
if line and line != b"foo":
if len(line) < 5:
- self.assertTrue(line.endswith(b"\n"))
+ self.assertEndsWith(line, b"\n")
all.append(line)
if not line:
break
self.assertEqual(b"".join(all), expected)
+ self.assertTrue(self.resp.isclosed())
def test_read1(self):
resp = self.resp
@@ -1371,6 +1669,7 @@ def test_read1_unbounded(self):
break
all.append(data)
self.assertEqual(b"".join(all), self.lines_expected)
+ self.assertTrue(resp.isclosed())
def test_read1_bounded(self):
resp = self.resp
@@ -1382,15 +1681,22 @@ def test_read1_bounded(self):
self.assertLessEqual(len(data), 10)
all.append(data)
self.assertEqual(b"".join(all), self.lines_expected)
+ self.assertTrue(resp.isclosed())
def test_read1_0(self):
self.assertEqual(self.resp.read1(0), b"")
+ self.assertFalse(self.resp.isclosed())
def test_peek_0(self):
p = self.resp.peek(0)
self.assertLessEqual(0, len(p))
+class ExtendedReadTestContentLengthKnown(ExtendedReadTest):
+ _header, _body = ExtendedReadTest.lines.split('\r\n\r\n', 1)
+ lines = _header + f'\r\nContent-Length: {len(_body)}\r\n\r\n' + _body
+
+
class ExtendedReadTestChunked(ExtendedReadTest):
"""
Test peek(), read1(), readline() in chunked mode
@@ -1447,7 +1753,7 @@ def readline(self, limit):
raise
-class OfflineTest(TestCase):
+class OfflineTest(TestCase, ExtraAssertions):
def test_all(self):
# Documented objects defined in the module should be in __all__
expected = {"responses"} # Allowlist documented dict() object
@@ -1500,13 +1806,17 @@ def test_client_constants(self):
'GONE',
'LENGTH_REQUIRED',
'PRECONDITION_FAILED',
+ 'CONTENT_TOO_LARGE',
'REQUEST_ENTITY_TOO_LARGE',
+ 'URI_TOO_LONG',
'REQUEST_URI_TOO_LONG',
'UNSUPPORTED_MEDIA_TYPE',
+ 'RANGE_NOT_SATISFIABLE',
'REQUESTED_RANGE_NOT_SATISFIABLE',
'EXPECTATION_FAILED',
'IM_A_TEAPOT',
'MISDIRECTED_REQUEST',
+ 'UNPROCESSABLE_CONTENT',
'UNPROCESSABLE_ENTITY',
'LOCKED',
'FAILED_DEPENDENCY',
@@ -1529,7 +1839,7 @@ def test_client_constants(self):
]
for const in expected:
with self.subTest(constant=const):
- self.assertTrue(hasattr(client, const))
+ self.assertHasAttr(client, const)
class SourceAddressTest(TestCase):
@@ -1766,6 +2076,7 @@ def test_networked_good_cert(self):
h.close()
self.assertIn('nginx', server_string)
+ @support.requires_resource('walltime')
def test_networked_bad_cert(self):
# We feed a "CA" cert that is unrelated to the server's cert
import ssl
@@ -1778,7 +2089,6 @@ def test_networked_bad_cert(self):
h.request('GET', '/')
self.assertEqual(exc_info.exception.reason, 'CERTIFICATE_VERIFY_FAILED')
- @unittest.skipIf(sys.platform == 'darwin', 'Occasionally success on macOS')
def test_local_unknown_cert(self):
# The custom cert isn't known to the default trust bundle
import ssl
@@ -1788,8 +2098,9 @@ def test_local_unknown_cert(self):
h.request('GET', '/')
self.assertEqual(exc_info.exception.reason, 'CERTIFICATE_VERIFY_FAILED')
+ @unittest.expectedFailure # TODO: RUSTPYTHON http.client.RemoteDisconnected: Remote end closed connection without response
def test_local_good_hostname(self):
- # The (valid) cert validates the HTTP hostname
+ # The (valid) cert validates the HTTPS hostname
import ssl
server = self.make_server(CERT_localhost)
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
@@ -1801,8 +2112,9 @@ def test_local_good_hostname(self):
self.addCleanup(resp.close)
self.assertEqual(resp.status, 404)
+ @unittest.expectedFailure # TODO: RUSTPYTHON http.client.RemoteDisconnected: Remote end closed connection without response
def test_local_bad_hostname(self):
- # The (valid) cert doesn't validate the HTTP hostname
+ # The (valid) cert doesn't validate the HTTPS hostname
import ssl
server = self.make_server(CERT_fakehostname)
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
@@ -1810,38 +2122,21 @@ def test_local_bad_hostname(self):
h = client.HTTPSConnection('localhost', server.port, context=context)
with self.assertRaises(ssl.CertificateError):
h.request('GET', '/')
- # Same with explicit check_hostname=True
- with warnings_helper.check_warnings(('', DeprecationWarning)):
- h = client.HTTPSConnection('localhost', server.port,
- context=context, check_hostname=True)
+
+ # Same with explicit context.check_hostname=True
+ context.check_hostname = True
+ h = client.HTTPSConnection('localhost', server.port, context=context)
with self.assertRaises(ssl.CertificateError):
h.request('GET', '/')
- # With check_hostname=False, the mismatching is ignored
- context.check_hostname = False
- with warnings_helper.check_warnings(('', DeprecationWarning)):
- h = client.HTTPSConnection('localhost', server.port,
- context=context, check_hostname=False)
- h.request('GET', '/nonexistent')
- resp = h.getresponse()
- resp.close()
- h.close()
- self.assertEqual(resp.status, 404)
- # The context's check_hostname setting is used if one isn't passed to
- # HTTPSConnection.
+
+ # With context.check_hostname=False, the mismatching is ignored
context.check_hostname = False
h = client.HTTPSConnection('localhost', server.port, context=context)
h.request('GET', '/nonexistent')
resp = h.getresponse()
- self.assertEqual(resp.status, 404)
resp.close()
h.close()
- # Passing check_hostname to HTTPSConnection should override the
- # context's setting.
- with warnings_helper.check_warnings(('', DeprecationWarning)):
- h = client.HTTPSConnection('localhost', server.port,
- context=context, check_hostname=True)
- with self.assertRaises(ssl.CertificateError):
- h.request('GET', '/')
+ self.assertEqual(resp.status, 404)
@unittest.skipIf(not hasattr(client, 'HTTPSConnection'),
'http.client.HTTPSConnection not available')
@@ -1877,11 +2172,9 @@ def test_tls13_pha(self):
self.assertIs(h._context, context)
self.assertFalse(h._context.post_handshake_auth)
- with warnings.catch_warnings():
- warnings.filterwarnings('ignore', 'key_file, cert_file and check_hostname are deprecated',
- DeprecationWarning)
- h = client.HTTPSConnection('localhost', 443, context=context,
- cert_file=CERT_localhost)
+        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context.post_handshake_auth = True
+ h = client.HTTPSConnection('localhost', 443, context=context)
self.assertTrue(h._context.post_handshake_auth)
@@ -2016,14 +2309,15 @@ def test_getting_header_defaultint(self):
header = self.resp.getheader('No-Such-Header',default=42)
self.assertEqual(header, 42)
-class TunnelTests(TestCase):
+class TunnelTests(TestCase, ExtraAssertions):
def setUp(self):
response_text = (
- 'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
+ 'HTTP/1.1 200 OK\r\n\r\n' # Reply to CONNECT
'HTTP/1.1 200 OK\r\n' # Reply to HEAD
'Content-Length: 42\r\n\r\n'
)
self.host = 'proxy.com'
+ self.port = client.HTTP_PORT
self.conn = client.HTTPConnection(self.host)
self.conn._create_connection = self._create_connection(response_text)
@@ -2035,15 +2329,45 @@ def create_connection(address, timeout=None, source_address=None):
return FakeSocket(response_text, host=address[0], port=address[1])
return create_connection
- def test_set_tunnel_host_port_headers(self):
+ def test_set_tunnel_host_port_headers_add_host_missing(self):
tunnel_host = 'destination.com'
tunnel_port = 8888
tunnel_headers = {'User-Agent': 'Mozilla/5.0 (compatible, MSIE 11)'}
+ tunnel_headers_after = tunnel_headers.copy()
+ tunnel_headers_after['Host'] = '%s:%d' % (tunnel_host, tunnel_port)
self.conn.set_tunnel(tunnel_host, port=tunnel_port,
headers=tunnel_headers)
self.conn.request('HEAD', '/', '')
self.assertEqual(self.conn.sock.host, self.host)
- self.assertEqual(self.conn.sock.port, client.HTTP_PORT)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertEqual(self.conn._tunnel_host, tunnel_host)
+ self.assertEqual(self.conn._tunnel_port, tunnel_port)
+ self.assertEqual(self.conn._tunnel_headers, tunnel_headers_after)
+
+ def test_set_tunnel_host_port_headers_set_host_identical(self):
+ tunnel_host = 'destination.com'
+ tunnel_port = 8888
+ tunnel_headers = {'User-Agent': 'Mozilla/5.0 (compatible, MSIE 11)',
+ 'Host': '%s:%d' % (tunnel_host, tunnel_port)}
+ self.conn.set_tunnel(tunnel_host, port=tunnel_port,
+ headers=tunnel_headers)
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertEqual(self.conn._tunnel_host, tunnel_host)
+ self.assertEqual(self.conn._tunnel_port, tunnel_port)
+ self.assertEqual(self.conn._tunnel_headers, tunnel_headers)
+
+ def test_set_tunnel_host_port_headers_set_host_different(self):
+ tunnel_host = 'destination.com'
+ tunnel_port = 8888
+ tunnel_headers = {'User-Agent': 'Mozilla/5.0 (compatible, MSIE 11)',
+ 'Host': '%s:%d' % ('example.com', 4200)}
+ self.conn.set_tunnel(tunnel_host, port=tunnel_port,
+ headers=tunnel_headers)
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
self.assertEqual(self.conn._tunnel_host, tunnel_host)
self.assertEqual(self.conn._tunnel_port, tunnel_port)
self.assertEqual(self.conn._tunnel_headers, tunnel_headers)
@@ -2055,17 +2379,96 @@ def test_disallow_set_tunnel_after_connect(self):
'destination.com')
def test_connect_with_tunnel(self):
- self.conn.set_tunnel('destination.com')
+ d = {
+ b'host': b'destination.com',
+ b'port': client.HTTP_PORT,
+ }
+ self.conn.set_tunnel(d[b'host'].decode('ascii'))
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertIn(b'CONNECT %(host)s:%(port)d HTTP/1.1\r\n'
+ b'Host: %(host)s:%(port)d\r\n\r\n' % d,
+ self.conn.sock.data)
+ self.assertIn(b'HEAD / HTTP/1.1\r\nHost: %(host)s\r\n' % d,
+ self.conn.sock.data)
+
+ def test_connect_with_tunnel_with_default_port(self):
+ d = {
+ b'host': b'destination.com',
+ b'port': client.HTTP_PORT,
+ }
+ self.conn.set_tunnel(d[b'host'].decode('ascii'), port=d[b'port'])
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertIn(b'CONNECT %(host)s:%(port)d HTTP/1.1\r\n'
+ b'Host: %(host)s:%(port)d\r\n\r\n' % d,
+ self.conn.sock.data)
+ self.assertIn(b'HEAD / HTTP/1.1\r\nHost: %(host)s\r\n' % d,
+ self.conn.sock.data)
+
+ def test_connect_with_tunnel_with_nonstandard_port(self):
+ d = {
+ b'host': b'destination.com',
+ b'port': 8888,
+ }
+ self.conn.set_tunnel(d[b'host'].decode('ascii'), port=d[b'port'])
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertIn(b'CONNECT %(host)s:%(port)d HTTP/1.1\r\n'
+ b'Host: %(host)s:%(port)d\r\n\r\n' % d,
+ self.conn.sock.data)
+ self.assertIn(b'HEAD / HTTP/1.1\r\nHost: %(host)s:%(port)d\r\n' % d,
+ self.conn.sock.data)
+
+ # This request is not RFC-valid, but it's been possible with the library
+ # for years, so don't break it unexpectedly... This also tests
+ # case-insensitivity when injecting Host: headers if they're missing.
+ def test_connect_with_tunnel_with_different_host_header(self):
+ d = {
+ b'host': b'destination.com',
+ b'tunnel_host_header': b'example.com:9876',
+ b'port': client.HTTP_PORT,
+ }
+ self.conn.set_tunnel(
+ d[b'host'].decode('ascii'),
+ headers={'HOST': d[b'tunnel_host_header'].decode('ascii')})
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertIn(b'CONNECT %(host)s:%(port)d HTTP/1.1\r\n'
+ b'HOST: %(tunnel_host_header)s\r\n\r\n' % d,
+ self.conn.sock.data)
+ self.assertIn(b'HEAD / HTTP/1.1\r\nHost: %(host)s\r\n' % d,
+ self.conn.sock.data)
+
+ def test_connect_with_tunnel_different_host(self):
+ d = {
+ b'host': b'destination.com',
+ b'port': client.HTTP_PORT,
+ }
+ self.conn.set_tunnel(d[b'host'].decode('ascii'))
+ self.conn.request('HEAD', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertIn(b'CONNECT %(host)s:%(port)d HTTP/1.1\r\n'
+ b'Host: %(host)s:%(port)d\r\n\r\n' % d,
+ self.conn.sock.data)
+ self.assertIn(b'HEAD / HTTP/1.1\r\nHost: %(host)s\r\n' % d,
+ self.conn.sock.data)
+
+ def test_connect_with_tunnel_idna(self):
+ dest = '\u03b4\u03c0\u03b8.gr'
+ dest_port = b'%s:%d' % (dest.encode('idna'), client.HTTP_PORT)
+ expected = b'CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n' % (
+ dest_port, dest_port)
+ self.conn.set_tunnel(dest)
self.conn.request('HEAD', '/', '')
self.assertEqual(self.conn.sock.host, self.host)
self.assertEqual(self.conn.sock.port, client.HTTP_PORT)
- self.assertIn(b'CONNECT destination.com', self.conn.sock.data)
- # issue22095
- self.assertNotIn(b'Host: destination.com:None', self.conn.sock.data)
- self.assertIn(b'Host: destination.com', self.conn.sock.data)
-
- # This test should be removed when CONNECT gets the HTTP/1.1 blessing
- self.assertNotIn(b'Host: proxy.com', self.conn.sock.data)
+ self.assertIn(expected, self.conn.sock.data)
def test_tunnel_connect_single_send_connection_setup(self):
"""Regresstion test for https://bugs.python.org/issue43332."""
@@ -2080,17 +2483,39 @@ def test_tunnel_connect_single_send_connection_setup(self):
msg=f'unexpected number of send calls: {mock_send.mock_calls}')
proxy_setup_data_sent = mock_send.mock_calls[0][1][0]
self.assertIn(b'CONNECT destination.com', proxy_setup_data_sent)
- self.assertTrue(
- proxy_setup_data_sent.endswith(b'\r\n\r\n'),
+ self.assertEndsWith(proxy_setup_data_sent, b'\r\n\r\n',
msg=f'unexpected proxy data sent {proxy_setup_data_sent!r}')
def test_connect_put_request(self):
- self.conn.set_tunnel('destination.com')
+ d = {
+ b'host': b'destination.com',
+ b'port': client.HTTP_PORT,
+ }
+ self.conn.set_tunnel(d[b'host'].decode('ascii'))
+ self.conn.request('PUT', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, self.port)
+ self.assertIn(b'CONNECT %(host)s:%(port)d HTTP/1.1\r\n'
+ b'Host: %(host)s:%(port)d\r\n\r\n' % d,
+ self.conn.sock.data)
+ self.assertIn(b'PUT / HTTP/1.1\r\nHost: %(host)s\r\n' % d,
+ self.conn.sock.data)
+
+ def test_connect_put_request_ipv6(self):
+ self.conn.set_tunnel('[1:2:3::4]', 1234)
+ self.conn.request('PUT', '/', '')
+ self.assertEqual(self.conn.sock.host, self.host)
+ self.assertEqual(self.conn.sock.port, client.HTTP_PORT)
+ self.assertIn(b'CONNECT [1:2:3::4]:1234', self.conn.sock.data)
+ self.assertIn(b'Host: [1:2:3::4]:1234', self.conn.sock.data)
+
+ def test_connect_put_request_ipv6_port(self):
+ self.conn.set_tunnel('[1:2:3::4]:1234')
self.conn.request('PUT', '/', '')
self.assertEqual(self.conn.sock.host, self.host)
self.assertEqual(self.conn.sock.port, client.HTTP_PORT)
- self.assertIn(b'CONNECT destination.com', self.conn.sock.data)
- self.assertIn(b'Host: destination.com', self.conn.sock.data)
+ self.assertIn(b'CONNECT [1:2:3::4]:1234', self.conn.sock.data)
+ self.assertIn(b'Host: [1:2:3::4]:1234', self.conn.sock.data)
def test_tunnel_debuglog(self):
expected_header = 'X-Dummy: 1'
@@ -2105,6 +2530,56 @@ def test_tunnel_debuglog(self):
lines = output.getvalue().splitlines()
self.assertIn('header: {}'.format(expected_header), lines)
+ def test_proxy_response_headers(self):
+ expected_header = ('X-Dummy', '1')
+ response_text = (
+ 'HTTP/1.0 200 OK\r\n'
+ '{0}\r\n\r\n'.format(':'.join(expected_header))
+ )
+
+ self.conn._create_connection = self._create_connection(response_text)
+ self.conn.set_tunnel('destination.com')
+
+ self.conn.request('PUT', '/', '')
+ headers = self.conn.get_proxy_response_headers()
+ self.assertIn(expected_header, headers.items())
+
+ def test_no_proxy_response_headers(self):
+ expected_header = ('X-Dummy', '1')
+ response_text = (
+ 'HTTP/1.0 200 OK\r\n'
+ '{0}\r\n\r\n'.format(':'.join(expected_header))
+ )
+
+ self.conn._create_connection = self._create_connection(response_text)
+
+ self.conn.request('PUT', '/', '')
+ headers = self.conn.get_proxy_response_headers()
+ self.assertIsNone(headers)
+
+ def test_tunnel_leak(self):
+ sock = None
+
+ def _create_connection(address, timeout=None, source_address=None):
+ nonlocal sock
+ sock = FakeSocket(
+ 'HTTP/1.1 404 NOT FOUND\r\n\r\n',
+ host=address[0],
+ port=address[1],
+ )
+ return sock
+
+ self.conn._create_connection = _create_connection
+ self.conn.set_tunnel('destination.com')
+ exc = None
+ try:
+ self.conn.request('HEAD', '/', '')
+ except OSError as e:
+ # keeping a reference to exc keeps response alive in the traceback
+ exc = e
+ self.assertIsNotNone(exc)
+ self.assertTrue(sock.file_closed)
+
if __name__ == '__main__':
unittest.main(verbosity=2)
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index cd689492ca3..63b778d8b97 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -8,6 +8,7 @@
SimpleHTTPRequestHandler, CGIHTTPRequestHandler
from http import server, HTTPStatus
+import contextlib
import os
import socket
import sys
@@ -26,13 +27,16 @@
import datetime
import threading
from unittest import mock
-from io import BytesIO
+from io import BytesIO, StringIO
import unittest
from test import support
-from test.support import os_helper
-from test.support import threading_helper
+from test.support import (
+ is_apple, os_helper, requires_subprocess, threading_helper
+)
+from test.support.testcase import ExtraAssertions
+support.requires_working_socket(module=True)
class NoLogRequestHandler:
def log_message(self, *args):
@@ -64,7 +68,7 @@ def stop(self):
self.join()
-class BaseTestCase(unittest.TestCase):
+class BaseTestCase(unittest.TestCase, ExtraAssertions):
def setUp(self):
self._threads = threading_helper.threading_setup()
os.environ = os_helper.EnvironmentVarGuard()
@@ -163,6 +167,27 @@ def test_version_digits(self):
res = self.con.getresponse()
self.assertEqual(res.status, HTTPStatus.BAD_REQUEST)
+ def test_version_signs_and_underscores(self):
+ self.con._http_vsn_str = 'HTTP/-9_9_9.+9_9_9'
+ self.con.putrequest('GET', '/')
+ self.con.endheaders()
+ res = self.con.getresponse()
+ self.assertEqual(res.status, HTTPStatus.BAD_REQUEST)
+
+ def test_major_version_number_too_long(self):
+ self.con._http_vsn_str = 'HTTP/909876543210.0'
+ self.con.putrequest('GET', '/')
+ self.con.endheaders()
+ res = self.con.getresponse()
+ self.assertEqual(res.status, HTTPStatus.BAD_REQUEST)
+
+ def test_minor_version_number_too_long(self):
+ self.con._http_vsn_str = 'HTTP/1.909876543210'
+ self.con.putrequest('GET', '/')
+ self.con.endheaders()
+ res = self.con.getresponse()
+ self.assertEqual(res.status, HTTPStatus.BAD_REQUEST)
+
def test_version_none_get(self):
self.con._http_vsn_str = ''
self.con.putrequest('GET', '/')
@@ -292,6 +317,44 @@ def test_head_via_send_error(self):
self.assertEqual(b'', data)
+class HTTP09ServerTestCase(BaseTestCase):
+
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ """Request handler for HTTP/0.9 server."""
+
+ def do_GET(self):
+ self.wfile.write(f'OK: here is {self.path}\r\n'.encode())
+
+ def setUp(self):
+ super().setUp()
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.sock = self.enterContext(self.sock)
+ self.sock.connect((self.HOST, self.PORT))
+
+ def test_simple_get(self):
+ self.sock.send(b'GET /index.html\r\n')
+ res = self.sock.recv(1024)
+ self.assertEqual(res, b"OK: here is /index.html\r\n")
+
+ def test_invalid_request(self):
+ self.sock.send(b'POST /index.html\r\n')
+ res = self.sock.recv(1024)
+ self.assertIn(b"Bad HTTP/0.9 request type ('POST')", res)
+
+ def test_single_request(self):
+ self.sock.send(b'GET /foo.html\r\n')
+ res = self.sock.recv(1024)
+ self.assertEqual(res, b"OK: here is /foo.html\r\n")
+
+ # Ignore errors if the connection is already closed,
+ # as this is the expected behavior of HTTP/0.9.
+ with contextlib.suppress(OSError):
+ self.sock.send(b'GET /bar.html\r\n')
+ res = self.sock.recv(1024)
+ # The server should not process our request.
+ self.assertEqual(res, b'')
+
+
class RequestHandlerLoggingTestCase(BaseTestCase):
class request_handler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
@@ -312,8 +375,7 @@ def test_get(self):
self.con.request('GET', '/')
self.con.getresponse()
- self.assertTrue(
- err.getvalue().endswith('"GET / HTTP/1.1" 200 -\n'))
+ self.assertEndsWith(err.getvalue(), '"GET / HTTP/1.1" 200 -\n')
def test_err(self):
self.con = http.client.HTTPConnection(self.HOST, self.PORT)
@@ -324,8 +386,8 @@ def test_err(self):
self.con.getresponse()
lines = err.getvalue().split('\n')
- self.assertTrue(lines[0].endswith('code 404, message File not found'))
- self.assertTrue(lines[1].endswith('"ERROR / HTTP/1.1" 404 -'))
+ self.assertEndsWith(lines[0], 'code 404, message File not found')
+ self.assertEndsWith(lines[1], '"ERROR / HTTP/1.1" 404 -')
class SimpleHTTPServerTestCase(BaseTestCase):
@@ -333,7 +395,7 @@ class request_handler(NoLogRequestHandler, SimpleHTTPRequestHandler):
pass
def setUp(self):
- BaseTestCase.setUp(self)
+ super().setUp()
self.cwd = os.getcwd()
basetempdir = tempfile.gettempdir()
os.chdir(basetempdir)
@@ -361,7 +423,7 @@ def tearDown(self):
except:
pass
finally:
- BaseTestCase.tearDown(self)
+ super().tearDown()
def check_status_and_reason(self, response, status, data=None):
def close_conn():
@@ -388,35 +450,175 @@ def close_conn():
reader.close()
return body
- @unittest.skipIf(sys.platform == 'darwin',
- 'undecodable name cannot always be decoded on macOS')
- @unittest.skipIf(sys.platform == 'win32',
- 'undecodable name cannot be decoded on win32')
- @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
- 'need os_helper.TESTFN_UNDECODABLE')
- def test_undecodable_filename(self):
+ def check_list_dir_dirname(self, dirname, quotedname=None):
+ fullpath = os.path.join(self.tempdir, dirname)
+ try:
+ os.mkdir(os.path.join(self.tempdir, dirname))
+ except (OSError, UnicodeEncodeError):
+ self.skipTest(f'Can not create directory {dirname!a} '
+ f'on current file system')
+
+ if quotedname is None:
+ quotedname = urllib.parse.quote(dirname, errors='surrogatepass')
+ response = self.request(self.base_url + '/' + quotedname + '/')
+ body = self.check_status_and_reason(response, HTTPStatus.OK)
+ displaypath = html.escape(f'{self.base_url}/{dirname}/', quote=False)
enc = sys.getfilesystemencoding()
- filename = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.txt'
- with open(os.path.join(self.tempdir, filename), 'wb') as f:
- f.write(os_helper.TESTFN_UNDECODABLE)
+ prefix = f'listing for {displaypath}'.encode(enc, 'surrogateescape')
+ self.assertIn(prefix + b'title>', body)
+ self.assertIn(prefix + b'h1>', body)
+
+ def check_list_dir_filename(self, filename):
+ fullpath = os.path.join(self.tempdir, filename)
+ content = ascii(fullpath).encode() + (os_helper.TESTFN_UNDECODABLE or b'\xff')
+ try:
+ with open(fullpath, 'wb') as f:
+ f.write(content)
+ except OSError:
+ self.skipTest(f'Can not create file {filename!a} '
+ f'on current file system')
+
response = self.request(self.base_url + '/')
- if sys.platform == 'darwin':
- # On Mac OS the HFS+ filesystem replaces bytes that aren't valid
- # UTF-8 into a percent-encoded value.
- for name in os.listdir(self.tempdir):
- if name != 'test': # Ignore a filename created in setUp().
- filename = name
- break
body = self.check_status_and_reason(response, HTTPStatus.OK)
quotedname = urllib.parse.quote(filename, errors='surrogatepass')
- self.assertIn(('href="%s"' % quotedname)
- .encode(enc, 'surrogateescape'), body)
- self.assertIn(('>%s<' % html.escape(filename, quote=False))
- .encode(enc, 'surrogateescape'), body)
+ enc = response.headers.get_content_charset()
+ self.assertIsNotNone(enc)
+ self.assertIn((f'href="{quotedname}"').encode('ascii'), body)
+ displayname = html.escape(filename, quote=False)
+ self.assertIn(f'>{displayname}<'.encode(enc, 'surrogateescape'), body)
+
response = self.request(self.base_url + '/' + quotedname)
- self.check_status_and_reason(response, HTTPStatus.OK,
- data=os_helper.TESTFN_UNDECODABLE)
+ self.check_status_and_reason(response, HTTPStatus.OK, data=content)
+
+ @unittest.skipUnless(os_helper.TESTFN_NONASCII,
+ 'need os_helper.TESTFN_NONASCII')
+ def test_list_dir_nonascii_dirname(self):
+ dirname = os_helper.TESTFN_NONASCII + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipUnless(os_helper.TESTFN_NONASCII,
+ 'need os_helper.TESTFN_NONASCII')
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
+ def test_list_dir_nonascii_filename(self):
+ filename = os_helper.TESTFN_NONASCII + '.txt'
+ self.check_list_dir_filename(filename)
+
+ @unittest.skipIf(is_apple,
+ 'undecodable name cannot always be decoded on Apple platforms')
+ @unittest.skipIf(sys.platform == 'win32',
+ 'undecodable name cannot be decoded on win32')
+ @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
+ 'need os_helper.TESTFN_UNDECODABLE')
+ def test_list_dir_undecodable_dirname(self):
+ dirname = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipIf(is_apple,
+ 'undecodable name cannot always be decoded on Apple platforms')
+ @unittest.skipIf(sys.platform == 'win32',
+ 'undecodable name cannot be decoded on win32')
+ @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
+ 'need os_helper.TESTFN_UNDECODABLE')
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
+ def test_list_dir_undecodable_filename(self):
+ filename = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.txt'
+ self.check_list_dir_filename(filename)
+
+ def test_list_dir_undecodable_dirname2(self):
+ dirname = '\ufffd.dir'
+ self.check_list_dir_dirname(dirname, quotedname='%ff.dir')
+
+ @unittest.skipUnless(os_helper.TESTFN_UNENCODABLE,
+ 'need os_helper.TESTFN_UNENCODABLE')
+ def test_list_dir_unencodable_dirname(self):
+ dirname = os_helper.TESTFN_UNENCODABLE + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipUnless(os_helper.TESTFN_UNENCODABLE,
+ 'need os_helper.TESTFN_UNENCODABLE')
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
+ def test_list_dir_unencodable_filename(self):
+ filename = os_helper.TESTFN_UNENCODABLE + '.txt'
+ self.check_list_dir_filename(filename)
+
+ def test_list_dir_escape_dirname(self):
+ # Characters that need special treating in URL or HTML.
+        for name in ('q?', 'f#', '&amp;', '&amp;amp;', '<i>', '"dq"', "'sq'",
+                     '%A4', '%E2%82%AC'):
+ with self.subTest(name=name):
+ dirname = name + '.dir'
+ self.check_list_dir_dirname(dirname,
+ quotedname=urllib.parse.quote(dirname, safe='&<>\'"'))
+
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
+ def test_list_dir_escape_filename(self):
+ # Characters that need special treating in URL or HTML.
+        for name in ('q?', 'f#', '&amp;', '&amp;amp;', '<i>', '"dq"', "'sq'",
+                     '%A4', '%E2%82%AC'):
+ with self.subTest(name=name):
+ filename = name + '.txt'
+ self.check_list_dir_filename(filename)
+ os_helper.unlink(os.path.join(self.tempdir, filename))
+
+ def test_list_dir_with_query_and_fragment(self):
+ prefix = f'listing for {self.base_url}/'.encode('latin1')
+ response = self.request(self.base_url + '/#123').read()
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
+ response = self.request(self.base_url + '/?x=123').read()
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
+
+ def test_get_dir_redirect_location_domain_injection_bug(self):
+ """Ensure //evil.co/..%2f../../X does not put //evil.co/ in Location.
+
+ //netloc/ in a Location header is a redirect to a new host.
+ https://github.com/python/cpython/issues/87389
+
+ This checks that a path resolving to a directory on our server cannot
+ resolve into a redirect to another server.
+ """
+ os.mkdir(os.path.join(self.tempdir, 'existing_directory'))
+ url = f'/python.org/..%2f..%2f..%2f..%2f..%2f../%0a%0d/../{self.tempdir_name}/existing_directory'
+ expected_location = f'{url}/' # /python.org.../ single slash single prefix, trailing slash
+ # Canonicalizes to /tmp/tempdir_name/existing_directory which does
+ # exist and is a dir, triggering the 301 redirect logic.
+ response = self.request(url)
+ self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ location = response.getheader('Location')
+ self.assertEqual(location, expected_location, msg='non-attack failed!')
+ # //python.org... multi-slash prefix, no trailing slash
+ attack_url = f'/{url}'
+ response = self.request(attack_url)
+ self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ location = response.getheader('Location')
+ self.assertNotStartsWith(location, '//')
+ self.assertEqual(location, expected_location,
+ msg='Expected Location header to start with a single / and '
+ 'end with a / as this is a directory redirect.')
+
+ # ///python.org... triple-slash prefix, no trailing slash
+ attack3_url = f'//{url}'
+ response = self.request(attack3_url)
+ self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader('Location'), expected_location)
+
+ # If the second word in the http request (Request-URI for the http
+ # method) is a full URI, we don't worry about it, as that'll be parsed
+ # and reassembled as a full URI within BaseHTTPRequestHandler.send_head
+ # so no errant scheme-less //netloc//evil.co/ domain mixup can happen.
+ attack_scheme_netloc_2slash_url = f'https://pypi.org/{url}'
+ expected_scheme_netloc_location = f'{attack_scheme_netloc_2slash_url}/'
+ response = self.request(attack_scheme_netloc_2slash_url)
+ self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ location = response.getheader('Location')
+ # We're just ensuring that the scheme and domain make it through, if
+ # there are or aren't multiple slashes at the start of the path that
+ # follows that isn't important in this Location: header.
+ self.assertStartsWith(location, 'https://pypi.org/')
+
+ @unittest.expectedFailure # TODO: RUSTPYTHON
def test_get(self):
#constructs the path relative to the root directory of the HTTPServer
response = self.request(self.base_url + '/test')
@@ -424,10 +626,19 @@ def test_get(self):
# check for trailing "/" which should return 404. See Issue17324
response = self.request(self.base_url + '/test/')
self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ response = self.request(self.base_url + '/test%2f')
+ self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ response = self.request(self.base_url + '/test%2F')
+ self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
response = self.request(self.base_url + '/')
self.check_status_and_reason(response, HTTPStatus.OK)
+ response = self.request(self.base_url + '%2f')
+ self.check_status_and_reason(response, HTTPStatus.OK)
+ response = self.request(self.base_url + '%2F')
+ self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.base_url)
self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader("Location"), self.base_url + "/")
self.assertEqual(response.getheader("Content-Length"), "0")
response = self.request(self.base_url + '/?hi=2')
self.check_status_and_reason(response, HTTPStatus.OK)
@@ -439,6 +650,9 @@ def test_get(self):
self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
response = self.request('/' + 'ThisDoesNotExist' + '/')
self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ os.makedirs(os.path.join(self.tempdir, 'spam', 'index.html'))
+ response = self.request(self.base_url + '/spam/')
+ self.check_status_and_reason(response, HTTPStatus.OK)
data = b"Dummy index file\r\n"
with open(os.path.join(self.tempdir_name, 'index.html'), 'wb') as f:
@@ -456,6 +670,7 @@ def test_get(self):
finally:
os.chmod(self.tempdir, 0o755)
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
def test_head(self):
response = self.request(
self.base_url + '/test', method='HEAD')
@@ -465,6 +680,7 @@ def test_head(self):
self.assertEqual(response.getheader('content-type'),
'application/octet-stream')
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
def test_browser_cache(self):
"""Check that when a request to /test is sent with the request header
If-Modified-Since set to date of last modification, the server returns
@@ -483,6 +699,7 @@ def test_browser_cache(self):
response = self.request(self.base_url + '/test', headers=headers)
self.check_status_and_reason(response, HTTPStatus.NOT_MODIFIED)
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
def test_browser_cache_file_changed(self):
# with If-Modified-Since earlier than Last-Modified, must return 200
dt = self.last_modif_datetime
@@ -494,6 +711,7 @@ def test_browser_cache_file_changed(self):
response = self.request(self.base_url + '/test', headers=headers)
self.check_status_and_reason(response, HTTPStatus.OK)
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
def test_browser_cache_with_If_None_Match_header(self):
# if If-None-Match header is present, ignore If-Modified-Since
@@ -512,6 +730,7 @@ def test_invalid_requests(self):
response = self.request('/', method='GETs')
self.check_status_and_reason(response, HTTPStatus.NOT_IMPLEMENTED)
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
def test_last_modified(self):
"""Checks that the datetime returned in Last-Modified response header
is the actual datetime of last modification, rounded to the second
@@ -521,6 +740,7 @@ def test_last_modified(self):
last_modif_header = response.headers['Last-modified']
self.assertEqual(last_modif_header, self.last_modif_header)
+ @unittest.expectedFailure # TODO: RUSTPYTHON; http.client.RemoteDisconnected: Remote end closed connection without response
def test_path_without_leading_slash(self):
response = self.request(self.tempdir_name + '/test')
self.check_status_and_reason(response, HTTPStatus.OK, data=self.data)
@@ -530,6 +750,8 @@ def test_path_without_leading_slash(self):
self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.tempdir_name)
self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader("Location"),
+ self.tempdir_name + "/")
response = self.request(self.tempdir_name + '/?hi=2')
self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.tempdir_name + '?hi=1')
@@ -537,27 +759,6 @@ def test_path_without_leading_slash(self):
self.assertEqual(response.getheader("Location"),
self.tempdir_name + "/?hi=1")
- def test_html_escape_filename(self):
-        filename = '<test&>.txt'
- fullpath = os.path.join(self.tempdir, filename)
-
- try:
- open(fullpath, 'wb').close()
- except OSError:
- raise unittest.SkipTest('Can not create file %s on current file '
- 'system' % filename)
-
- try:
- response = self.request(self.base_url + '/')
- body = self.check_status_and_reason(response, HTTPStatus.OK)
- enc = response.headers.get_content_charset()
- finally:
- os.unlink(fullpath) # avoid affecting test_undecodable_filename
-
- self.assertIsNotNone(enc)
- html_text = '>%s<' % html.escape(filename, quote=False)
- self.assertIn(html_text.encode(enc), body)
-
cgi_file1 = """\
#!%s
@@ -569,14 +770,19 @@ def test_html_escape_filename(self):
cgi_file2 = """\
#!%s
-import cgi
+import os
+import sys
+import urllib.parse
print("Content-type: text/html")
print()
-form = cgi.FieldStorage()
-print("%%s, %%s, %%s" %% (form.getfirst("spam"), form.getfirst("eggs"),
- form.getfirst("bacon")))
+content_length = int(os.environ["CONTENT_LENGTH"])
+query_string = sys.stdin.buffer.read(content_length)
+params = {key.decode("utf-8"): val.decode("utf-8")
+ for key, val in urllib.parse.parse_qsl(query_string)}
+
+print("%%s, %%s, %%s" %% (params["spam"], params["eggs"], params["bacon"]))
"""
cgi_file4 = """\
@@ -607,17 +813,40 @@ def test_html_escape_filename(self):
print("")
"""
-@unittest.skipIf(not hasattr(os, '_exit'),
- "TODO: RUSTPYTHON, run_cgi in http/server.py gets stuck as os._exit(127) doesn't currently kill forked processes")
+cgi_file7 = """\
+#!%s
+import os
+import sys
+
+print("Content-type: text/plain")
+print()
+
+content_length = int(os.environ["CONTENT_LENGTH"])
+body = sys.stdin.buffer.read(content_length)
+
+print(f"{content_length} {len(body)}")
+"""
+
+
@unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
"This test can't be run reliably as root (issue #13308).")
+@requires_subprocess()
class CGIHTTPServerTestCase(BaseTestCase):
class request_handler(NoLogRequestHandler, CGIHTTPRequestHandler):
- pass
+ _test_case_self = None # populated by each setUp() method call.
+
+ def __init__(self, *args, **kwargs):
+ with self._test_case_self.assertWarnsRegex(
+ DeprecationWarning,
+ r'http\.server\.CGIHTTPRequestHandler'):
+ # This context also happens to catch and silence the
+ # threading DeprecationWarning from os.fork().
+ super().__init__(*args, **kwargs)
linesep = os.linesep.encode('ascii')
def setUp(self):
+ self.request_handler._test_case_self = self # practical, but yuck.
BaseTestCase.setUp(self)
self.cwd = os.getcwd()
self.parent_dir = tempfile.mkdtemp()
@@ -637,12 +866,13 @@ def setUp(self):
self.file3_path = None
self.file4_path = None
self.file5_path = None
+ self.file6_path = None
+ self.file7_path = None
# The shebang line should be pure ASCII: use symlink if possible.
# See issue #7668.
self._pythonexe_symlink = None
- # TODO: RUSTPYTHON; dl_nt not supported yet
- if os_helper.can_symlink() and sys.platform != 'win32':
+ if os_helper.can_symlink():
self.pythonexe = os.path.join(self.parent_dir, 'python')
self._pythonexe_symlink = support.PythonSymlink(self.pythonexe).__enter__()
else:
@@ -692,9 +922,15 @@ def setUp(self):
file6.write(cgi_file6 % self.pythonexe)
os.chmod(self.file6_path, 0o777)
+ self.file7_path = os.path.join(self.cgi_dir, 'file7.py')
+ with open(self.file7_path, 'w', encoding='utf-8') as file7:
+ file7.write(cgi_file7 % self.pythonexe)
+ os.chmod(self.file7_path, 0o777)
+
os.chdir(self.parent_dir)
def tearDown(self):
+ self.request_handler._test_case_self = None
try:
os.chdir(self.cwd)
if self._pythonexe_symlink:
@@ -713,11 +949,16 @@ def tearDown(self):
os.remove(self.file5_path)
if self.file6_path:
os.remove(self.file6_path)
+ if self.file7_path:
+ os.remove(self.file7_path)
os.rmdir(self.cgi_child_dir)
os.rmdir(self.cgi_dir)
os.rmdir(self.cgi_dir_in_sub_dir)
os.rmdir(self.sub_dir_2)
os.rmdir(self.sub_dir_1)
+ # The 'gmon.out' file can be written in the current working
+ # directory if C-level code profiling with gprof is enabled.
+ os_helper.unlink(os.path.join(self.parent_dir, 'gmon.out'))
os.rmdir(self.parent_dir)
finally:
BaseTestCase.tearDown(self)
@@ -764,8 +1005,7 @@ def test_url_collapse_path(self):
msg='path = %r\nGot: %r\nWanted: %r' %
(path, actual, expected))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: Tuples differ: (b"", None, 200) != (b"Hello World\n", "text/html", )')
def test_headers_and_content(self):
res = self.request('/cgi-bin/file1.py')
self.assertEqual(
@@ -776,9 +1016,7 @@ def test_issue19435(self):
res = self.request('///////////nocgi.py/../cgi-bin/nothere.sh')
self.assertEqual(res.status, HTTPStatus.NOT_FOUND)
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
- @unittest.expectedFailure
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; b"" != b"1, python, 123456\n"')
def test_post(self):
params = urllib.parse.urlencode(
{'spam' : 1, 'eggs' : 'python', 'bacon' : 123456})
@@ -787,13 +1025,30 @@ def test_post(self):
self.assertEqual(res.read(), b'1, python, 123456' + self.linesep)
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: b"" != b"32768 32768\n"')
+ def test_large_content_length(self):
+ for w in range(15, 25):
+ size = 1 << w
+ body = b'X' * size
+ headers = {'Content-Length' : str(size)}
+ res = self.request('/cgi-bin/file7.py', 'POST', body, headers)
+ self.assertEqual(res.read(), b'%d %d' % (size, size) + self.linesep)
+
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: b"" != b"Hello World\n"')
+ def test_large_content_length_truncated(self):
+ with support.swap_attr(self.request_handler, 'timeout', 0.001):
+ for w in range(18, 65):
+ size = 1 << w
+ headers = {'Content-Length' : str(size)}
+ res = self.request('/cgi-bin/file1.py', 'POST', b'x', headers)
+ self.assertEqual(res.read(), b'Hello World' + self.linesep)
+
def test_invaliduri(self):
res = self.request('/cgi-bin/invalid')
res.read()
self.assertEqual(res.status, HTTPStatus.NOT_FOUND)
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: Tuples differ: (b"Hello World\n", "text/html", ) != (b"", None, 200)')
def test_authorization(self):
headers = {b'Authorization' : b'Basic ' +
base64.b64encode(b'username:pass')}
@@ -802,8 +1057,7 @@ def test_authorization(self):
(b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
(res.read(), res.getheader('Content-type'), res.status))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: Tuples differ: (b"Hello World\n", "text/html", ) != (b"", None, 200)')
def test_no_leading_slash(self):
# http://bugs.python.org/issue2254
res = self.request('cgi-bin/file1.py')
@@ -811,8 +1065,7 @@ def test_no_leading_slash(self):
(b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
(res.read(), res.getheader('Content-type'), res.status))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; ValueError: signal only works in main thread')
def test_os_environ_is_not_altered(self):
signature = "Test CGI Server"
os.environ['SERVER_SOFTWARE'] = signature
@@ -822,32 +1075,28 @@ def test_os_environ_is_not_altered(self):
(res.read(), res.getheader('Content-type'), res.status))
self.assertEqual(os.environ['SERVER_SOFTWARE'], signature)
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; ValueError: signal only works in main thread')
def test_urlquote_decoding_in_cgi_check(self):
res = self.request('/cgi-bin%2ffile1.py')
self.assertEqual(
(b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
(res.read(), res.getheader('Content-type'), res.status))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: Tuples differ: (b"Hello World\n", "text/html", ) != (b"", None, 200)')
def test_nested_cgi_path_issue21323(self):
res = self.request('/cgi-bin/child-dir/file3.py')
self.assertEqual(
(b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
(res.read(), res.getheader('Content-type'), res.status))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; ValueError: signal only works in main thread')
def test_query_with_multiple_question_mark(self):
res = self.request('/cgi-bin/file4.py?a=b?c=d')
self.assertEqual(
(b'a=b?c=d' + self.linesep, 'text/html', HTTPStatus.OK),
(res.read(), res.getheader('Content-type'), res.status))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: Tuples differ: (b"k=aa%2F%2Fbb&//q//p//=//a//b//\n", "text/html", ) != (b"", None, 200)')
def test_query_with_continuous_slashes(self):
res = self.request('/cgi-bin/file4.py?k=aa%2F%2Fbb&//q//p//=//a//b//')
self.assertEqual(
@@ -855,8 +1104,7 @@ def test_query_with_continuous_slashes(self):
'text/html', HTTPStatus.OK),
(res.read(), res.getheader('Content-type'), res.status))
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; Tuples differ: (b"", None, 200) != (b"Hello World\n", "text/html", )')
def test_cgi_path_in_sub_directories(self):
try:
CGIHTTPRequestHandler.cgi_directories.append('/sub/dir/cgi-bin')
@@ -867,8 +1115,7 @@ def test_cgi_path_in_sub_directories(self):
finally:
CGIHTTPRequestHandler.cgi_directories.remove('/sub/dir/cgi-bin')
- # TODO: RUSTPYTHON
- @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON; works only on windows")
+ @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON; AssertionError: b"HTTP_ACCEPT=text/html,text/plain" not found in b""')
def test_accept(self):
browser_accept = \
'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8'
@@ -929,7 +1176,7 @@ def numWrites(self):
return len(self.datas)
-class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
+class BaseHTTPRequestHandlerTestCase(unittest.TestCase, ExtraAssertions):
"""Test the functionality of the BaseHTTPServer.
Test the support for the Expect 100-continue header.
@@ -960,6 +1207,27 @@ def verify_http_server_response(self, response):
match = self.HTTPResponseMatch.search(response)
self.assertIsNotNone(match)
+ def test_unprintable_not_logged(self):
+ # We call the method from the class directly as our Socketless
+ # Handler subclass overrode it... nice for everything BUT this test.
+ self.handler.client_address = ('127.0.0.1', 1337)
+ log_message = BaseHTTPRequestHandler.log_message
+ with mock.patch.object(sys, 'stderr', StringIO()) as fake_stderr:
+ log_message(self.handler, '/foo')
+ log_message(self.handler, '/\033bar\000\033')
+ log_message(self.handler, '/spam %s.', 'a')
+ log_message(self.handler, '/spam %s.', '\033\x7f\x9f\xa0beans')
+ log_message(self.handler, '"GET /foo\\b"ar\007 HTTP/1.0"')
+ stderr = fake_stderr.getvalue()
+ self.assertNotIn('\033', stderr) # non-printable chars are caught.
+ self.assertNotIn('\000', stderr) # non-printable chars are caught.
+ lines = stderr.splitlines()
+ self.assertIn('/foo', lines[0])
+ self.assertIn(r'/\x1bbar\x00\x1b', lines[1])
+ self.assertIn('/spam a.', lines[2])
+ self.assertIn('/spam \\x1b\\x7f\\x9f\xa0beans.', lines[3])
+ self.assertIn(r'"GET /foo\\b"ar\x07 HTTP/1.0"', lines[4])
+
def test_http_1_1(self):
result = self.send_typical_request(b'GET / HTTP/1.1\r\n\r\n')
self.verify_http_server_response(result[0])
@@ -996,7 +1264,7 @@ def test_extra_space(self):
b'Host: dummy\r\n'
b'\r\n'
)
- self.assertTrue(result[0].startswith(b'HTTP/1.1 400 '))
+ self.assertStartsWith(result[0], b'HTTP/1.1 400 ')
self.verify_expected_headers(result[1:result.index(b'\r\n')])
self.assertFalse(self.handler.get_called)
@@ -1110,7 +1378,7 @@ def test_request_length(self):
# Issue #10714: huge request lines are discarded, to avoid Denial
# of Service attacks.
result = self.send_typical_request(b'GET ' + b'x' * 65537)
- self.assertEqual(result[0], b'HTTP/1.1 414 Request-URI Too Long\r\n')
+ self.assertEqual(result[0], b'HTTP/1.1 414 URI Too Long\r\n')
self.assertFalse(self.handler.get_called)
self.assertIsInstance(self.handler.requestline, str)
diff --git a/Lib/test/test_json/test_encode_basestring_ascii.py b/Lib/test/test_json/test_encode_basestring_ascii.py
index 6a39b72a09d..c90d3e968e5 100644
--- a/Lib/test/test_json/test_encode_basestring_ascii.py
+++ b/Lib/test/test_json/test_encode_basestring_ascii.py
@@ -8,13 +8,12 @@
('\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'),
('controls', '"controls"'),
('\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'),
+ ('\x00\x1f\x7f', '"\\u0000\\u001f\\u007f"'),
('{"object with 1 member":["array with 1 element"]}', '"{\\"object with 1 member\\":[\\"array with 1 element\\"]}"'),
(' s p a c e d ', '" s p a c e d "'),
('\U0001d120', '"\\ud834\\udd20"'),
('\u03b1\u03a9', '"\\u03b1\\u03a9"'),
("`1~!@#$%^&*()_+-={':[,]}|;.>?", '"`1~!@#$%^&*()_+-={\':[,]}|;.>?"'),
- ('\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'),
- ('\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'),
]
class TestEncodeBasestringAscii:
diff --git a/Lib/test/test_json/test_scanstring.py b/Lib/test/test_json/test_scanstring.py
index a5c46bb64b4..d6922c3b1b9 100644
--- a/Lib/test/test_json/test_scanstring.py
+++ b/Lib/test/test_json/test_scanstring.py
@@ -3,6 +3,7 @@
import unittest # XXX: RUSTPYTHON; importing to be able to skip tests
+
class TestScanstring:
def test_scanstring(self):
scanstring = self.json.decoder.scanstring
@@ -147,7 +148,7 @@ def test_bad_escapes(self):
@unittest.expectedFailure
def test_overflow(self):
with self.assertRaises(OverflowError):
- self.json.decoder.scanstring(b"xxx", sys.maxsize+1)
+ self.json.decoder.scanstring("xxx", sys.maxsize+1)
class TestPyScanstring(TestScanstring, PyTest): pass
diff --git a/Lib/test/test_json/test_unicode.py b/Lib/test/test_json/test_unicode.py
index 4bdb607e7da..be0ac8823d5 100644
--- a/Lib/test/test_json/test_unicode.py
+++ b/Lib/test/test_json/test_unicode.py
@@ -34,6 +34,29 @@ def test_encoding7(self):
j = self.dumps(u + "\n", ensure_ascii=False)
self.assertEqual(j, f'"{u}\\n"')
+ def test_ascii_non_printable_encode(self):
+ u = '\b\t\n\f\r\x00\x1f\x7f'
+ self.assertEqual(self.dumps(u),
+ '"\\b\\t\\n\\f\\r\\u0000\\u001f\\u007f"')
+ self.assertEqual(self.dumps(u, ensure_ascii=False),
+ '"\\b\\t\\n\\f\\r\\u0000\\u001f\x7f"')
+
+ def test_ascii_non_printable_decode(self):
+ self.assertEqual(self.loads('"\\b\\t\\n\\f\\r"'),
+ '\b\t\n\f\r')
+ s = ''.join(map(chr, range(32)))
+ for c in s:
+ self.assertRaises(self.JSONDecodeError, self.loads, f'"{c}"')
+ self.assertEqual(self.loads(f'"{s}"', strict=False), s)
+ self.assertEqual(self.loads('"\x7f"'), '\x7f')
+
+ def test_escaped_decode(self):
+ self.assertEqual(self.loads('"\\b\\t\\n\\f\\r"'), '\b\t\n\f\r')
+ self.assertEqual(self.loads('"\\"\\\\\\/"'), '"\\/')
+ for c in set(map(chr, range(0x100))) - set('"\\/bfnrt'):
+ self.assertRaises(self.JSONDecodeError, self.loads, f'"\\{c}"')
+ self.assertRaises(self.JSONDecodeError, self.loads, f'"\\{c}"', strict=False)
+
def test_big_unicode_encode(self):
u = '\U0001d120'
self.assertEqual(self.dumps(u), '"\\ud834\\udd20"')
@@ -50,6 +73,18 @@ def test_unicode_decode(self):
s = f'"\\u{i:04x}"'
self.assertEqual(self.loads(s), u)
+ def test_single_surrogate_encode(self):
+ self.assertEqual(self.dumps('\uD83D'), '"\\ud83d"')
+ self.assertEqual(self.dumps('\uD83D', ensure_ascii=False), '"\ud83d"')
+ self.assertEqual(self.dumps('\uDC0D'), '"\\udc0d"')
+ self.assertEqual(self.dumps('\uDC0D', ensure_ascii=False), '"\udc0d"')
+
+ def test_single_surrogate_decode(self):
+ self.assertEqual(self.loads('"\uD83D"'), '\ud83d')
+ self.assertEqual(self.loads('"\\uD83D"'), '\ud83d')
+ self.assertEqual(self.loads('"\udc0d"'), '\udc0d')
+ self.assertEqual(self.loads('"\\udc0d"'), '\udc0d')
+
def test_unicode_preservation(self):
self.assertEqual(type(self.loads('""')), str)
self.assertEqual(type(self.loads('"a"')), str)
@@ -104,4 +139,19 @@ def test_object_pairs_hook_with_unicode(self):
class TestPyUnicode(TestUnicode, PyTest): pass
-class TestCUnicode(TestUnicode, CTest): pass
+
+class TestCUnicode(TestUnicode, CTest):
+ # TODO: RUSTPYTHON
+ @unittest.expectedFailure
+ def test_ascii_non_printable_encode(self):
+ return super().test_ascii_non_printable_encode()
+
+ # TODO: RUSTPYTHON
+ @unittest.skip("TODO: RUSTPYTHON; panics with 'str has surrogates'")
+ def test_single_surrogate_decode(self):
+ return super().test_single_surrogate_decode()
+
+ # TODO: RUSTPYTHON
+ @unittest.skip("TODO: RUSTPYTHON; panics with 'str has surrogates'")
+ def test_single_surrogate_encode(self):
+ return super().test_single_surrogate_encode()
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index 8ea77d186e4..12b61e76423 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -736,6 +736,7 @@ def remove_loop(fname, tries):
@threading_helper.requires_working_threading()
@skip_if_asan_fork
@skip_if_tsan_fork
+ @unittest.skip("TODO: RUSTPYTHON; Flaky")
def test_post_fork_child_no_deadlock(self):
"""Ensure child logging locks are not held; bpo-6721 & bpo-36533."""
class _OurHandler(logging.Handler):
diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py
index 1bac61f59e1..4010ef9c340 100644
--- a/Lib/test/test_lzma.py
+++ b/Lib/test/test_lzma.py
@@ -409,8 +409,6 @@ def test_decompressor_bigmem(self, size):
# Pickling raises an exception; there's no way to serialize an lzma_stream.
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_pickle(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises(TypeError):
@@ -2194,4 +2192,4 @@ def test_filter_properties_roundtrip(self):
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py
index 07d9d38d6e4..343e5dd7a6c 100644
--- a/Lib/test/test_memoryio.py
+++ b/Lib/test/test_memoryio.py
@@ -745,8 +745,6 @@ def test_init(self):
def test_issue5449(self):
super().test_issue5449()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_pickling(self):
super().test_pickling()
@@ -777,8 +775,6 @@ def test_truncate(self):
def test_write(self):
super().test_write()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_getstate(self):
memio = self.ioclass()
state = memio.__getstate__()
@@ -911,8 +907,6 @@ def test_newline_none(self):
def test_newlines_property(self):
super().test_newlines_property()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_pickling(self):
super().test_pickling()
@@ -954,8 +948,6 @@ def test_widechar(self):
self.assertEqual(memio.tell(), len(buf) * 2)
self.assertEqual(memio.getvalue(), buf + buf)
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_getstate(self):
memio = self.ioclass()
state = memio.__getstate__()
@@ -1006,8 +998,6 @@ def test_newline_cr(self):
def test_newline_crlf(self):
super().test_newline_crlf()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_newline_default(self):
super().test_newline_default()
@@ -1016,8 +1006,6 @@ def test_newline_default(self):
def test_newline_empty(self):
super().test_newline_empty()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
def test_newline_lf(self):
super().test_newline_lf()
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index ea51b9d0916..7271696a191 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -97,10 +97,6 @@ def dumps(self, arg, proto=None, **kwargs):
def test_picklebuffer_error(self): # TODO(RUSTPYTHON): Remove this test when it passes
return super().test_picklebuffer_error()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_bad_getattr(self): # TODO(RUSTPYTHON): Remove this test when it passes
- return super().test_bad_getattr()
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -135,15 +131,7 @@ def loads(self, buf, **kwds):
def test_c_methods(self): # TODO(RUSTPYTHON): Remove this test when it passes
return super().test_c_methods()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_complex_newobj_ex(self): # TODO(RUSTPYTHON): Remove this test when it passes
- return super().test_complex_newobj_ex()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_py_methods(self): # TODO(RUSTPYTHON): Remove this test when it passes
- return super().test_py_methods()
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -239,10 +227,6 @@ def loads(self, buf, **kwds):
def test_c_methods(self): # TODO(RUSTPYTHON): Remove this test when it passes
return super().test_c_methods()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_complex_newobj_ex(self): # TODO(RUSTPYTHON): Remove this test when it passes
- return super().test_complex_newobj_ex()
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -259,10 +243,6 @@ def test_correctly_quoted_string(self): # TODO(RUSTPYTHON): Remove this test whe
def test_load_python2_str_as_bytes(self): # TODO(RUSTPYTHON): Remove this test when it passes
return super().test_load_python2_str_as_bytes()
- # TODO: RUSTPYTHON
- @unittest.expectedFailure
- def test_py_methods(self): # TODO(RUSTPYTHON): Remove this test when it passes
- return super().test_py_methods()
# TODO: RUSTPYTHON
@unittest.expectedFailure
diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py
index b0bed431d4b..89cabfe0083 100644
--- a/Lib/test/test_robotparser.py
+++ b/Lib/test/test_robotparser.py
@@ -259,6 +259,10 @@ class EmptyQueryStringTest(BaseRobotTest, unittest.TestCase):
good = ['/some/path?']
bad = ['/another/path?']
+ @unittest.expectedFailure # TODO: RUSTPYTHON; self.assertFalse(self.parser.can_fetch(agent, url))\nAssertionError: True is not false
+ def test_bad_urls(self):
+ super().test_bad_urls()
+
class DefaultEntryTest(BaseRequestRateTest, unittest.TestCase):
robots_txt = """\
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 5384e4caf69..9798a4f59c3 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -3525,6 +3525,7 @@ def test_starttls(self):
else:
s.close()
+ @unittest.expectedFailure # TODO: RUSTPYTHON
def test_socketserver(self):
"""Using socketserver to create and manage SSL connections."""
server = make_https_server(self, certfile=SIGNED_CERTFILE)
@@ -4596,7 +4597,7 @@ def server_callback(identity):
with client_context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- @unittest.skip("TODO: rustpython")
+ @unittest.skip("TODO: RUSTPYTHON; Hangs")
def test_thread_recv_while_main_thread_sends(self):
# GH-137583: Locking was added to calls to send() and recv() on SSL
# socket objects. This seemed fine at the surface level because those
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index aee9fb78017..7e3607842fd 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -1556,7 +1556,6 @@ def test_pathname2url_win(self):
@unittest.skipIf(sys.platform == 'win32',
'test specific to POSIX pathnames')
- @unittest.expectedFailure # AssertionError: '//a/b.c' != '////a/b.c'
def test_pathname2url_posix(self):
fn = urllib.request.pathname2url
self.assertEqual(fn('/'), '/')
@@ -1617,7 +1616,6 @@ def test_url2pathname_win(self):
@unittest.skipIf(sys.platform == 'win32',
'test specific to POSIX pathnames')
- @unittest.expectedFailure # AssertionError: '///foo/bar' != '/foo/bar'
def test_url2pathname_posix(self):
fn = urllib.request.url2pathname
self.assertEqual(fn('/foo/bar'), '/foo/bar')
diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py
index 399c94213a6..263472499d6 100644
--- a/Lib/test/test_urllib2.py
+++ b/Lib/test/test_urllib2.py
@@ -1,9 +1,11 @@
import unittest
from test import support
from test.support import os_helper
-from test.support import socket_helper
+from test.support import requires_subprocess
from test.support import warnings_helper
+from test.support.testcase import ExtraAssertions
from test import test_urllib
+from unittest import mock
import os
import io
@@ -14,16 +16,19 @@
import subprocess
import urllib.request
-# The proxy bypass method imported below has logic specific to the OSX
-# proxy config data structure but is testable on all platforms.
+# The proxy bypass method imported below has logic specific to the
+# corresponding system but is testable on all platforms.
from urllib.request import (Request, OpenerDirector, HTTPBasicAuthHandler,
HTTPPasswordMgrWithPriorAuth, _parse_proxy,
+ _proxy_bypass_winreg_override,
_proxy_bypass_macosx_sysconf,
AbstractDigestAuthHandler)
from urllib.parse import urlparse
import urllib.error
import http.client
+support.requires_working_socket(module=True)
+
# XXX
# Request
# CacheFTPHandler (hard to write)
@@ -483,7 +488,18 @@ def build_test_opener(*handler_instances):
return opener
-class MockHTTPHandler(urllib.request.BaseHandler):
+class MockHTTPHandler(urllib.request.HTTPHandler):
+ # Very simple mock HTTP handler with no special behavior other than using a mock HTTP connection
+
+ def __init__(self, debuglevel=None):
+ super(MockHTTPHandler, self).__init__(debuglevel=debuglevel)
+ self.httpconn = MockHTTPClass()
+
+ def http_open(self, req):
+ return self.do_open(self.httpconn, req)
+
+
+class MockHTTPHandlerRedirect(urllib.request.BaseHandler):
# useful for testing redirections and auth
# sends supplied headers and code as first response
# sends 200 OK as second response
@@ -511,16 +527,17 @@ def http_open(self, req):
return MockResponse(200, "OK", msg, "", req.get_full_url())
-class MockHTTPSHandler(urllib.request.AbstractHTTPHandler):
- # Useful for testing the Proxy-Authorization request by verifying the
- # properties of httpcon
+if hasattr(http.client, 'HTTPSConnection'):
+ class MockHTTPSHandler(urllib.request.HTTPSHandler):
+ # Useful for testing the Proxy-Authorization request by verifying the
+ # properties of httpcon
- def __init__(self, debuglevel=0):
- urllib.request.AbstractHTTPHandler.__init__(self, debuglevel=debuglevel)
- self.httpconn = MockHTTPClass()
+ def __init__(self, debuglevel=None, context=None, check_hostname=None):
+ super(MockHTTPSHandler, self).__init__(debuglevel, context, check_hostname)
+ self.httpconn = MockHTTPClass()
- def https_open(self, req):
- return self.do_open(self.httpconn, req)
+ def https_open(self, req):
+ return self.do_open(self.httpconn, req)
class MockHTTPHandlerCheckAuth(urllib.request.BaseHandler):
@@ -701,10 +718,6 @@ def test_processors(self):
def sanepathname2url(path):
- try:
- path.encode("utf-8")
- except UnicodeEncodeError:
- raise unittest.SkipTest("path is not encodable to utf8")
urlpath = urllib.request.pathname2url(path)
if os.name == "nt" and urlpath.startswith("///"):
urlpath = urlpath[2:]
@@ -712,8 +725,9 @@ def sanepathname2url(path):
return urlpath
-class HandlerTests(unittest.TestCase):
+class HandlerTests(unittest.TestCase, ExtraAssertions):
+ @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: None != 'image/gif'
def test_ftp(self):
class MockFTPWrapper:
def __init__(self, data):
@@ -761,7 +775,7 @@ def connect_ftp(self, user, passwd, host, port, dirs,
["foo", "bar"], "", None),
("ftp://localhost/baz.gif;type=a",
"localhost", ftplib.FTP_PORT, "", "", "A",
- [], "baz.gif", None), # XXX really this should guess image/gif
+ [], "baz.gif", "image/gif"),
]:
req = Request(url)
req.timeout = None
@@ -777,6 +791,7 @@ def connect_ftp(self, user, passwd, host, port, dirs,
headers = r.info()
self.assertEqual(headers.get("Content-type"), mimetype)
self.assertEqual(int(headers["Content-length"]), len(data))
+ r.close()
def test_file(self):
import email.utils
@@ -984,6 +999,7 @@ def test_http_body_fileobj(self):
file_obj.close()
+ @requires_subprocess()
def test_http_body_pipe(self):
# A file reading from a pipe.
# A pipe cannot be seek'ed. There is no way to determine the
@@ -1047,12 +1063,37 @@ def test_http_body_array(self):
newreq = h.do_request_(req)
self.assertEqual(int(newreq.get_header('Content-length')),16)
- def test_http_handler_debuglevel(self):
+ def test_http_handler_global_debuglevel(self):
+ with mock.patch.object(http.client.HTTPConnection, 'debuglevel', 6):
+ o = OpenerDirector()
+ h = MockHTTPHandler()
+ o.add_handler(h)
+ o.open("http://www.example.com")
+ self.assertEqual(h._debuglevel, 6)
+
+ def test_http_handler_local_debuglevel(self):
o = OpenerDirector()
- h = MockHTTPSHandler(debuglevel=1)
+ h = MockHTTPHandler(debuglevel=5)
+ o.add_handler(h)
+ o.open("http://www.example.com")
+ self.assertEqual(h._debuglevel, 5)
+
+ @unittest.skipUnless(hasattr(http.client, 'HTTPSConnection'), 'HTTPSConnection required for HTTPS tests.')
+ def test_https_handler_global_debuglevel(self):
+ with mock.patch.object(http.client.HTTPSConnection, 'debuglevel', 7):
+ o = OpenerDirector()
+ h = MockHTTPSHandler()
+ o.add_handler(h)
+ o.open("https://www.example.com")
+ self.assertEqual(h._debuglevel, 7)
+
+ @unittest.skipUnless(hasattr(http.client, 'HTTPSConnection'), 'HTTPSConnection required for HTTPS tests.')
+ def test_https_handler_local_debuglevel(self):
+ o = OpenerDirector()
+ h = MockHTTPSHandler(debuglevel=4)
o.add_handler(h)
o.open("https://www.example.com")
- self.assertEqual(h._debuglevel, 1)
+ self.assertEqual(h._debuglevel, 4)
def test_http_doubleslash(self):
# Checks the presence of any unnecessary double slash in url does not
@@ -1140,15 +1181,15 @@ def test_errors(self):
r = MockResponse(200, "OK", {}, "", url)
newr = h.http_response(req, r)
self.assertIs(r, newr)
- self.assertFalse(hasattr(o, "proto")) # o.error not called
+ self.assertNotHasAttr(o, "proto") # o.error not called
r = MockResponse(202, "Accepted", {}, "", url)
newr = h.http_response(req, r)
self.assertIs(r, newr)
- self.assertFalse(hasattr(o, "proto")) # o.error not called
+ self.assertNotHasAttr(o, "proto") # o.error not called
r = MockResponse(206, "Partial content", {}, "", url)
newr = h.http_response(req, r)
self.assertIs(r, newr)
- self.assertFalse(hasattr(o, "proto")) # o.error not called
+ self.assertNotHasAttr(o, "proto") # o.error not called
# anything else calls o.error (and MockOpener returns None, here)
r = MockResponse(502, "Bad gateway", {}, "", url)
self.assertIsNone(h.http_response(req, r))
@@ -1179,7 +1220,7 @@ def test_redirect(self):
o = h.parent = MockOpener()
# ordinary redirect behaviour
- for code in 301, 302, 303, 307:
+ for code in 301, 302, 303, 307, 308:
for data in None, "blah\nblah\n":
method = getattr(h, "http_error_%s" % code)
req = Request(from_url, data)
@@ -1191,10 +1232,11 @@ def test_redirect(self):
try:
method(req, MockFile(), code, "Blah",
MockHeaders({"location": to_url}))
- except urllib.error.HTTPError:
- # 307 in response to POST requires user OK
- self.assertEqual(code, 307)
+ except urllib.error.HTTPError as err:
+ # 307 and 308 in response to POST require user OK
+ self.assertIn(code, (307, 308))
self.assertIsNotNone(data)
+ err.close()
self.assertEqual(o.req.get_full_url(), to_url)
try:
self.assertEqual(o.req.get_method(), "GET")
@@ -1230,9 +1272,10 @@ def redirect(h, req, url=to_url):
while 1:
redirect(h, req, "http://example.com/")
count = count + 1
- except urllib.error.HTTPError:
+ except urllib.error.HTTPError as err:
# don't stop until max_repeats, because cookies may introduce state
self.assertEqual(count, urllib.request.HTTPRedirectHandler.max_repeats)
+ err.close()
# detect endless non-repeating chain of redirects
req = Request(from_url, origin_req_host="example.com")
@@ -1242,9 +1285,10 @@ def redirect(h, req, url=to_url):
while 1:
redirect(h, req, "http://example.com/%d" % count)
count = count + 1
- except urllib.error.HTTPError:
+ except urllib.error.HTTPError as err:
self.assertEqual(count,
urllib.request.HTTPRedirectHandler.max_redirections)
+ err.close()
def test_invalid_redirect(self):
from_url = "http://example.com/a.html"
@@ -1258,9 +1302,11 @@ def test_invalid_redirect(self):
for scheme in invalid_schemes:
invalid_url = scheme + '://' + schemeless_url
- self.assertRaises(urllib.error.HTTPError, h.http_error_302,
+ with self.assertRaises(urllib.error.HTTPError) as cm:
+ h.http_error_302(
req, MockFile(), 302, "Security Loophole",
MockHeaders({"location": invalid_url}))
+ cm.exception.close()
for scheme in valid_schemes:
valid_url = scheme + '://' + schemeless_url
@@ -1288,7 +1334,7 @@ def test_cookie_redirect(self):
cj = CookieJar()
interact_netscape(cj, "http://www.example.com/", "spam=eggs")
- hh = MockHTTPHandler(302, "Location: http://www.cracker.com/\r\n\r\n")
+ hh = MockHTTPHandlerRedirect(302, "Location: http://www.cracker.com/\r\n\r\n")
hdeh = urllib.request.HTTPDefaultErrorHandler()
hrh = urllib.request.HTTPRedirectHandler()
cp = urllib.request.HTTPCookieProcessor(cj)
@@ -1298,7 +1344,7 @@ def test_cookie_redirect(self):
def test_redirect_fragment(self):
redirected_url = 'http://www.example.com/index.html#OK\r\n\r\n'
- hh = MockHTTPHandler(302, 'Location: ' + redirected_url)
+ hh = MockHTTPHandlerRedirect(302, 'Location: ' + redirected_url)
hdeh = urllib.request.HTTPDefaultErrorHandler()
hrh = urllib.request.HTTPRedirectHandler()
o = build_test_opener(hh, hdeh, hrh)
@@ -1358,7 +1404,16 @@ def http_open(self, req):
response = opener.open('http://example.com/')
expected = b'GET ' + result + b' '
request = handler.last_buf
- self.assertTrue(request.startswith(expected), repr(request))
+ self.assertStartsWith(request, expected)
+
+ def test_redirect_head_request(self):
+ from_url = "http://example.com/a.html"
+ to_url = "http://example.com/b.html"
+ h = urllib.request.HTTPRedirectHandler()
+ req = Request(from_url, method="HEAD")
+ fp = MockFile()
+ new_req = h.redirect_request(req, fp, 302, "Found", {}, to_url)
+ self.assertEqual(new_req.get_method(), "HEAD")
def test_proxy(self):
u = "proxy.example.com:3128"
@@ -1379,7 +1434,8 @@ def test_proxy(self):
[tup[0:2] for tup in o.calls])
def test_proxy_no_proxy(self):
- os.environ['no_proxy'] = 'python.org'
+ env = self.enterContext(os_helper.EnvironmentVarGuard())
+ env['no_proxy'] = 'python.org'
o = OpenerDirector()
ph = urllib.request.ProxyHandler(dict(http="proxy.example.com"))
o.add_handler(ph)
@@ -1391,10 +1447,10 @@ def test_proxy_no_proxy(self):
self.assertEqual(req.host, "www.python.org")
o.open(req)
self.assertEqual(req.host, "www.python.org")
- del os.environ['no_proxy']
def test_proxy_no_proxy_all(self):
- os.environ['no_proxy'] = '*'
+ env = self.enterContext(os_helper.EnvironmentVarGuard())
+ env['no_proxy'] = '*'
o = OpenerDirector()
ph = urllib.request.ProxyHandler(dict(http="proxy.example.com"))
o.add_handler(ph)
@@ -1402,7 +1458,6 @@ def test_proxy_no_proxy_all(self):
self.assertEqual(req.host, "www.python.org")
o.open(req)
self.assertEqual(req.host, "www.python.org")
- del os.environ['no_proxy']
def test_proxy_https(self):
o = OpenerDirector()
@@ -1420,6 +1475,7 @@ def test_proxy_https(self):
self.assertEqual([(handlers[0], "https_open")],
[tup[0:2] for tup in o.calls])
+ @unittest.skipUnless(hasattr(http.client, 'HTTPSConnection'), 'HTTPSConnection required for HTTPS tests.')
def test_proxy_https_proxy_authorization(self):
o = OpenerDirector()
ph = urllib.request.ProxyHandler(dict(https='proxy.example.com:3128'))
@@ -1443,6 +1499,30 @@ def test_proxy_https_proxy_authorization(self):
self.assertEqual(req.host, "proxy.example.com:3128")
self.assertEqual(req.get_header("Proxy-authorization"), "FooBar")
+ @unittest.skipUnless(os.name == "nt", "only relevant for Windows")
+ def test_winreg_proxy_bypass(self):
+ proxy_override = "www.example.com;*.example.net; 192.168.0.1"
+ proxy_bypass = _proxy_bypass_winreg_override
+ for host in ("www.example.com", "www.example.net", "192.168.0.1"):
+ self.assertTrue(proxy_bypass(host, proxy_override),
+ "expected bypass of %s to be true" % host)
+
+ for host in ("example.com", "www.example.org", "example.net",
+ "192.168.0.2"):
+ self.assertFalse(proxy_bypass(host, proxy_override),
+ "expected bypass of %s to be False" % host)
+
+ # check intranet address bypass
+ proxy_override = "example.com; "
+ self.assertTrue(proxy_bypass("example.com", proxy_override),
+ "expected bypass of %s to be true" % host)
+ self.assertFalse(proxy_bypass("example.net", proxy_override),
+ "expected bypass of %s to be False" % host)
+ for host in ("test", "localhost"):
+ self.assertTrue(proxy_bypass(host, proxy_override),
+ "expect to bypass intranet address '%s'"
+ % host)
+
@unittest.skipUnless(sys.platform == 'darwin', "only relevant for OSX")
def test_osx_proxy_bypass(self):
bypass = {
@@ -1483,7 +1563,7 @@ def check_basic_auth(self, headers, realm):
password_manager = MockPasswordManager()
auth_handler = urllib.request.HTTPBasicAuthHandler(password_manager)
body = '\r\n'.join(headers) + '\r\n\r\n'
- http_handler = MockHTTPHandler(401, body)
+ http_handler = MockHTTPHandlerRedirect(401, body)
opener.add_handler(auth_handler)
opener.add_handler(http_handler)
self._test_basic_auth(opener, auth_handler, "Authorization",
@@ -1543,7 +1623,7 @@ def test_proxy_basic_auth(self):
password_manager = MockPasswordManager()
auth_handler = urllib.request.ProxyBasicAuthHandler(password_manager)
realm = "ACME Networks"
- http_handler = MockHTTPHandler(
+ http_handler = MockHTTPHandlerRedirect(
407, 'Proxy-Authenticate: Basic realm="%s"\r\n\r\n' % realm)
opener.add_handler(auth_handler)
opener.add_handler(http_handler)
@@ -1555,11 +1635,11 @@ def test_proxy_basic_auth(self):
def test_basic_and_digest_auth_handlers(self):
# HTTPDigestAuthHandler raised an exception if it couldn't handle a 40*
- # response (http://python.org/sf/1479302), where it should instead
+ # response (https://bugs.python.org/issue1479302), where it should instead
# return None to allow another handler (especially
# HTTPBasicAuthHandler) to handle the response.
- # Also (http://python.org/sf/14797027, RFC 2617 section 1.2), we must
+ # Also (https://bugs.python.org/issue14797027, RFC 2617 section 1.2), we must
# try digest first (since it's the strongest auth scheme), so we record
# order of calls here to check digest comes first:
class RecordingOpenerDirector(OpenerDirector):
@@ -1587,7 +1667,7 @@ def http_error_401(self, *args, **kwds):
digest_handler = TestDigestAuthHandler(password_manager)
basic_handler = TestBasicAuthHandler(password_manager)
realm = "ACME Networks"
- http_handler = MockHTTPHandler(
+ http_handler = MockHTTPHandlerRedirect(
401, 'WWW-Authenticate: Basic realm="%s"\r\n\r\n' % realm)
opener.add_handler(basic_handler)
opener.add_handler(digest_handler)
@@ -1607,7 +1687,7 @@ def test_unsupported_auth_digest_handler(self):
opener = OpenerDirector()
# While using DigestAuthHandler
digest_auth_handler = urllib.request.HTTPDigestAuthHandler(None)
- http_handler = MockHTTPHandler(
+ http_handler = MockHTTPHandlerRedirect(
401, 'WWW-Authenticate: Kerberos\r\n\r\n')
opener.add_handler(digest_auth_handler)
opener.add_handler(http_handler)
@@ -1617,7 +1697,7 @@ def test_unsupported_auth_basic_handler(self):
# While using BasicAuthHandler
opener = OpenerDirector()
basic_auth_handler = urllib.request.HTTPBasicAuthHandler(None)
- http_handler = MockHTTPHandler(
+ http_handler = MockHTTPHandlerRedirect(
401, 'WWW-Authenticate: NTLM\r\n\r\n')
opener.add_handler(basic_auth_handler)
opener.add_handler(http_handler)
@@ -1704,7 +1784,7 @@ def test_basic_prior_auth_send_after_first_success(self):
opener = OpenerDirector()
opener.add_handler(auth_prior_handler)
- http_handler = MockHTTPHandler(
+ http_handler = MockHTTPHandlerRedirect(
401, 'WWW-Authenticate: Basic realm="%s"\r\n\r\n' % None)
opener.add_handler(http_handler)
@@ -1755,7 +1835,7 @@ def test_invalid_closed(self):
self.assertTrue(conn.fakesock.closed, "Connection not closed")
-class MiscTests(unittest.TestCase):
+class MiscTests(unittest.TestCase, ExtraAssertions):
def opener_has_handler(self, opener, handler_class):
self.assertTrue(any(h.__class__ == handler_class
@@ -1814,14 +1894,21 @@ def test_HTTPError_interface(self):
url = code = fp = None
hdrs = 'Content-Length: 42'
err = urllib.error.HTTPError(url, code, msg, hdrs, fp)
- self.assertTrue(hasattr(err, 'reason'))
+ self.assertHasAttr(err, 'reason')
self.assertEqual(err.reason, 'something bad happened')
- self.assertTrue(hasattr(err, 'headers'))
+ self.assertHasAttr(err, 'headers')
self.assertEqual(err.headers, 'Content-Length: 42')
expected_errmsg = 'HTTP Error %s: %s' % (err.code, err.msg)
self.assertEqual(str(err), expected_errmsg)
expected_errmsg = '' % (err.code, err.msg)
self.assertEqual(repr(err), expected_errmsg)
+ err.close()
+
+ def test_gh_98778(self):
+ x = urllib.error.HTTPError("url", 405, "METHOD NOT ALLOWED", None, None)
+ self.assertEqual(getattr(x, "__notes__", ()), ())
+ self.assertIsInstance(x.fp.read(), bytes)
+ x.close()
def test_parse_proxy(self):
parse_proxy_test_cases = [
diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py
index 2c54ef85b4b..9a899785116 100644
--- a/Lib/test/test_urllib2_localnet.py
+++ b/Lib/test/test_urllib2_localnet.py
@@ -8,15 +8,18 @@
import unittest
import hashlib
+from test import support
from test.support import hashlib_helper
from test.support import threading_helper
-from test.support import warnings_helper
+from test.support.testcase import ExtraAssertions
try:
import ssl
except ImportError:
ssl = None
+support.requires_working_socket(module=True)
+
here = os.path.dirname(__file__)
# Self-signed cert file for 'localhost'
CERT_localhost = os.path.join(here, 'certdata', 'keycert.pem')
@@ -314,7 +317,9 @@ def test_basic_auth_httperror(self):
ah = urllib.request.HTTPBasicAuthHandler()
ah.add_password(self.REALM, self.server_url, self.USER, self.INCORRECT_PASSWD)
urllib.request.install_opener(urllib.request.build_opener(ah))
- self.assertRaises(urllib.error.HTTPError, urllib.request.urlopen, self.server_url)
+ with self.assertRaises(urllib.error.HTTPError) as cm:
+ urllib.request.urlopen(self.server_url)
+ cm.exception.close()
@hashlib_helper.requires_hashdigest("md5", openssl=True)
@@ -356,23 +361,23 @@ def stop_server(self):
self.server.stop()
self.server = None
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_proxy_with_bad_password_raises_httperror(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD+"bad")
self.digest_auth_handler.set_qop("auth")
- self.assertRaises(urllib.error.HTTPError,
- self.opener.open,
- self.URL)
+ with self.assertRaises(urllib.error.HTTPError) as cm:
+ self.opener.open(self.URL)
+ cm.exception.close()
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_proxy_with_no_password_raises_httperror(self):
self.digest_auth_handler.set_qop("auth")
- self.assertRaises(urllib.error.HTTPError,
- self.opener.open,
- self.URL)
+ with self.assertRaises(urllib.error.HTTPError) as cm:
+ self.opener.open(self.URL)
+ cm.exception.close()
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_proxy_qop_auth_works(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD)
@@ -381,7 +386,7 @@ def test_proxy_qop_auth_works(self):
while result.read():
pass
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_proxy_qop_auth_int_works_or_throws_urlerror(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD)
@@ -442,7 +447,7 @@ def log_message(self, *args):
return FakeHTTPRequestHandler
-class TestUrlopen(unittest.TestCase):
+class TestUrlopen(unittest.TestCase, ExtraAssertions):
"""Tests urllib.request.urlopen using the network.
These tests are not exhaustive. Assuming that testing using files does a
@@ -506,7 +511,7 @@ def start_https_server(self, responses=None, **kwargs):
handler.port = server.port
return handler
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_redirection(self):
expected_response = b"We got here..."
responses = [
@@ -520,7 +525,7 @@ def test_redirection(self):
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ["/", "/somewhere_else"])
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_chunked(self):
expected_response = b"hello world"
chunked_start = (
@@ -535,7 +540,7 @@ def test_chunked(self):
data = self.urlopen("http://localhost:%s/" % handler.port)
self.assertEqual(data, expected_response)
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_404(self):
expected_response = b"Bad bad bad..."
handler = self.start_server([(404, [], expected_response)])
@@ -551,7 +556,7 @@ def test_404(self):
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ["/weeble"])
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_200(self):
expected_response = b"pycon 2008..."
handler = self.start_server([(200, [], expected_response)])
@@ -559,7 +564,7 @@ def test_200(self):
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ["/bizarre"])
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_200_with_parameters(self):
expected_response = b"pycon 2008..."
handler = self.start_server([(200, [], expected_response)])
@@ -568,41 +573,14 @@ def test_200_with_parameters(self):
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ["/bizarre", b"get=with_feeling"])
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_https(self):
handler = self.start_https_server()
context = ssl.create_default_context(cafile=CERT_localhost)
data = self.urlopen("https://localhost:%s/bizarre" % handler.port, context=context)
self.assertEqual(data, b"we care a bit")
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
- def test_https_with_cafile(self):
- handler = self.start_https_server(certfile=CERT_localhost)
- with warnings_helper.check_warnings(('', DeprecationWarning)):
- # Good cert
- data = self.urlopen("https://localhost:%s/bizarre" % handler.port,
- cafile=CERT_localhost)
- self.assertEqual(data, b"we care a bit")
- # Bad cert
- with self.assertRaises(urllib.error.URLError) as cm:
- self.urlopen("https://localhost:%s/bizarre" % handler.port,
- cafile=CERT_fakehostname)
- # Good cert, but mismatching hostname
- handler = self.start_https_server(certfile=CERT_fakehostname)
- with self.assertRaises(urllib.error.URLError) as cm:
- self.urlopen("https://localhost:%s/bizarre" % handler.port,
- cafile=CERT_fakehostname)
-
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
- def test_https_with_cadefault(self):
- handler = self.start_https_server(certfile=CERT_localhost)
- # Self-signed cert should fail verification with system certificate store
- with warnings_helper.check_warnings(('', DeprecationWarning)):
- with self.assertRaises(urllib.error.URLError) as cm:
- self.urlopen("https://localhost:%s/bizarre" % handler.port,
- cadefault=True)
-
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_https_sni(self):
if ssl is None:
self.skipTest("ssl module required")
@@ -619,7 +597,7 @@ def cb_sni(ssl_sock, server_name, initial_context):
self.urlopen("https://localhost:%s" % handler.port, context=context)
self.assertEqual(sni_name, "localhost")
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_sending_headers(self):
handler = self.start_server()
req = urllib.request.Request("http://localhost:%s/" % handler.port,
@@ -628,7 +606,7 @@ def test_sending_headers(self):
pass
self.assertEqual(handler.headers_received["Range"], "bytes=20-39")
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_sending_headers_camel(self):
handler = self.start_server()
req = urllib.request.Request("http://localhost:%s/" % handler.port,
@@ -638,16 +616,15 @@ def test_sending_headers_camel(self):
self.assertIn("X-Some-Header", handler.headers_received.keys())
self.assertNotIn("X-SoMe-hEader", handler.headers_received.keys())
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_basic(self):
handler = self.start_server()
with urllib.request.urlopen("http://localhost:%s" % handler.port) as open_url:
for attr in ("read", "close", "info", "geturl"):
- self.assertTrue(hasattr(open_url, attr), "object returned from "
- "urlopen lacks the %s attribute" % attr)
+ self.assertHasAttr(open_url, attr)
self.assertTrue(open_url.read(), "calling 'read' failed")
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_info(self):
handler = self.start_server()
open_url = urllib.request.urlopen(
@@ -659,7 +636,7 @@ def test_info(self):
"instance of email.message.Message")
self.assertEqual(info_obj.get_content_subtype(), "plain")
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_geturl(self):
# Make sure same URL as opened is returned by geturl.
handler = self.start_server()
@@ -668,7 +645,7 @@ def test_geturl(self):
url = open_url.geturl()
self.assertEqual(url, "http://localhost:%s" % handler.port)
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_iteration(self):
expected_response = b"pycon 2008..."
handler = self.start_server([(200, [], expected_response)])
@@ -676,7 +653,7 @@ def test_iteration(self):
for line in data:
self.assertEqual(line, expected_response)
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_line_iteration(self):
lines = [b"We\n", b"got\n", b"here\n", b"verylong " * 8192 + b"\n"]
expected_response = b"".join(lines)
@@ -689,7 +666,7 @@ def test_line_iteration(self):
(index, len(lines[index]), len(line)))
self.assertEqual(index + 1, len(lines))
- @unittest.skipIf(os.name == "nt", "TODO: RUSTPYTHON, ValueError: illegal environment variable name")
+ @unittest.skipIf(os.name == 'nt', 'TODO: RUSTPYTHON; ValueError: illegal environment variable name')
def test_issue16464(self):
# See https://bugs.python.org/issue16464
# and https://bugs.python.org/issue46648
@@ -709,6 +686,7 @@ def test_issue16464(self):
self.assertEqual(b"1234567890", request.data)
self.assertEqual("10", request.get_header("Content-length"))
+
def setUpModule():
thread_info = threading_helper.threading_setup()
unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info)
diff --git a/Lib/test/test_urllib2net.py b/Lib/test/test_urllib2net.py
index c70b522d31d..41f170a6ad5 100644
--- a/Lib/test/test_urllib2net.py
+++ b/Lib/test/test_urllib2net.py
@@ -137,7 +137,6 @@ def setUp(self):
# XXX The rest of these tests aren't very good -- they don't check much.
# They do sometimes catch some major disasters, though.
- @unittest.expectedFailure # TODO: RUSTPYTHON urllib.error.URLError:
@support.requires_resource('walltime')
def test_ftp(self):
# Testing the same URL twice exercises the caching in CacheFTPHandler
diff --git a/Lib/test/test_urllib_response.py b/Lib/test/test_urllib_response.py
index 73d2ef0424f..d949fa38bfc 100644
--- a/Lib/test/test_urllib_response.py
+++ b/Lib/test/test_urllib_response.py
@@ -4,6 +4,11 @@
import tempfile
import urllib.response
import unittest
+from test import support
+
+if support.is_wasi:
+ raise unittest.SkipTest("Cannot create socket on WASI")
+
class TestResponse(unittest.TestCase):
@@ -43,6 +48,7 @@ def test_addinfo(self):
info = urllib.response.addinfo(self.fp, self.test_headers)
self.assertEqual(info.info(), self.test_headers)
self.assertEqual(info.headers, self.test_headers)
+ info.close()
def test_addinfourl(self):
url = "http://www.python.org"
@@ -55,6 +61,7 @@ def test_addinfourl(self):
self.assertEqual(infourl.headers, self.test_headers)
self.assertEqual(infourl.url, url)
self.assertEqual(infourl.status, code)
+ infourl.close()
def tearDown(self):
self.sock.close()
diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py
index 1a3b4d4b721..d546e3ef219 100644
--- a/Lib/test/test_wsgiref.py
+++ b/Lib/test/test_wsgiref.py
@@ -134,7 +134,6 @@ def test_environ(self):
b"Python test,Python test 2;query=test;/path/"
)
- @unittest.expectedFailure # TODO: RUSTPYTHON; http library needs to be updated
def test_request_length(self):
out, err = run_amock(data=b"GET " + (b"x" * 65537) + b" HTTP/1.0\n\n")
self.assertEqual(out.splitlines()[0],
diff --git a/Lib/urllib/error.py b/Lib/urllib/error.py
index 8cd901f13f8..a9cd1ecadd6 100644
--- a/Lib/urllib/error.py
+++ b/Lib/urllib/error.py
@@ -10,7 +10,7 @@
an application may want to handle an exception like a regular
response.
"""
-
+import io
import urllib.response
__all__ = ['URLError', 'HTTPError', 'ContentTooShortError']
@@ -42,12 +42,9 @@ def __init__(self, url, code, msg, hdrs, fp):
self.hdrs = hdrs
self.fp = fp
self.filename = url
- # The addinfourl classes depend on fp being a valid file
- # object. In some cases, the HTTPError may not have a valid
- # file object. If this happens, the simplest workaround is to
- # not initialize the base classes.
- if fp is not None:
- self.__super_init(fp, hdrs, url, code)
+ if fp is None:
+ fp = io.BytesIO()
+ self.__super_init(fp, hdrs, url, code)
def __str__(self):
return 'HTTP Error %s: %s' % (self.code, self.msg)
diff --git a/Lib/urllib/parse.py b/Lib/urllib/parse.py
index b35997bc00c..c72138a33ca 100644
--- a/Lib/urllib/parse.py
+++ b/Lib/urllib/parse.py
@@ -25,13 +25,19 @@
scenarios for parsing, and for backward compatibility purposes, some
parsing quirks from older RFCs are retained. The testcases in
test_urlparse.py provides a good indicator of parsing behavior.
+
+The WHATWG URL Parser spec should also be considered. We are not compliant with
+it either due to existing user code API behavior expectations (Hyrum's Law).
+It serves as a useful guide when making changes.
"""
+from collections import namedtuple
+import functools
+import math
import re
-import sys
import types
-import collections
import warnings
+import ipaddress
__all__ = ["urlparse", "urlunparse", "urljoin", "urldefrag",
"urlsplit", "urlunsplit", "urlencode", "parse_qs",
@@ -46,18 +52,18 @@
uses_relative = ['', 'ftp', 'http', 'gopher', 'nntp', 'imap',
'wais', 'file', 'https', 'shttp', 'mms',
- 'prospero', 'rtsp', 'rtspu', 'sftp',
+ 'prospero', 'rtsp', 'rtsps', 'rtspu', 'sftp',
'svn', 'svn+ssh', 'ws', 'wss']
uses_netloc = ['', 'ftp', 'http', 'gopher', 'nntp', 'telnet',
'imap', 'wais', 'file', 'mms', 'https', 'shttp',
- 'snews', 'prospero', 'rtsp', 'rtspu', 'rsync',
+ 'snews', 'prospero', 'rtsp', 'rtsps', 'rtspu', 'rsync',
'svn', 'svn+ssh', 'sftp', 'nfs', 'git', 'git+ssh',
- 'ws', 'wss']
+ 'ws', 'wss', 'itms-services']
uses_params = ['', 'ftp', 'hdl', 'prospero', 'http', 'imap',
- 'https', 'shttp', 'rtsp', 'rtspu', 'sip', 'sips',
- 'mms', 'sftp', 'tel']
+ 'https', 'shttp', 'rtsp', 'rtsps', 'rtspu', 'sip',
+ 'sips', 'mms', 'sftp', 'tel']
# These are not actually used anymore, but should stay for backwards
# compatibility. (They are undocumented, but have a public-looking name.)
@@ -66,7 +72,7 @@
'telnet', 'wais', 'imap', 'snews', 'sip', 'sips']
uses_query = ['', 'http', 'wais', 'imap', 'https', 'shttp', 'mms',
- 'gopher', 'rtsp', 'rtspu', 'sip', 'sips']
+ 'gopher', 'rtsp', 'rtsps', 'rtspu', 'sip', 'sips']
uses_fragment = ['', 'ftp', 'hdl', 'http', 'gopher', 'news',
'nntp', 'wais', 'https', 'shttp', 'snews',
@@ -78,18 +84,17 @@
'0123456789'
'+-.')
+# Leading and trailing C0 control and space to be stripped per WHATWG spec.
+# == "".join([chr(i) for i in range(0, 0x20 + 1)])
+_WHATWG_C0_CONTROL_OR_SPACE = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f '
+
# Unsafe bytes to be removed per WHATWG spec
_UNSAFE_URL_BYTES_TO_REMOVE = ['\t', '\r', '\n']
-# XXX: Consider replacing with functools.lru_cache
-MAX_CACHE_SIZE = 20
-_parse_cache = {}
-
def clear_cache():
- """Clear the parse cache and the quoters cache."""
- _parse_cache.clear()
- _safe_quoters.clear()
-
+ """Clear internal performance caches. Undocumented; some tests want it."""
+ urlsplit.cache_clear()
+ _byte_quoter_factory.cache_clear()
# Helpers for bytes handling
# For 3.2, we deliberately require applications that
@@ -171,12 +176,11 @@ def hostname(self):
def port(self):
port = self._hostinfo[1]
if port is not None:
- try:
- port = int(port, 10)
- except ValueError:
- message = f'Port could not be cast to integer value as {port!r}'
- raise ValueError(message) from None
- if not ( 0 <= port <= 65535):
+ if port.isdigit() and port.isascii():
+ port = int(port)
+ else:
+ raise ValueError(f"Port could not be cast to integer value as {port!r}")
+ if not (0 <= port <= 65535):
raise ValueError("Port out of range 0-65535")
return port
@@ -243,8 +247,6 @@ def _hostinfo(self):
return hostname, port
-from collections import namedtuple
-
_DefragResultBase = namedtuple('DefragResult', 'url fragment')
_SplitResultBase = namedtuple(
'SplitResult', 'scheme netloc path query fragment')
@@ -434,6 +436,37 @@ def _checknetloc(netloc):
raise ValueError("netloc '" + netloc + "' contains invalid " +
"characters under NFKC normalization")
+def _check_bracketed_netloc(netloc):
+ # Note that this function must mirror the splitting
+ # done in NetlocResultMixins._hostinfo().
+ hostname_and_port = netloc.rpartition('@')[2]
+ before_bracket, have_open_br, bracketed = hostname_and_port.partition('[')
+ if have_open_br:
+ # No data is allowed before a bracket.
+ if before_bracket:
+ raise ValueError("Invalid IPv6 URL")
+ hostname, _, port = bracketed.partition(']')
+ # No data is allowed after the bracket but before the port delimiter.
+ if port and not port.startswith(":"):
+ raise ValueError("Invalid IPv6 URL")
+ else:
+ hostname, _, port = hostname_and_port.partition(':')
+ _check_bracketed_host(hostname)
+
+# Valid bracketed hosts are defined in
+# https://www.rfc-editor.org/rfc/rfc3986#page-49 and https://url.spec.whatwg.org/
+def _check_bracketed_host(hostname):
+ if hostname.startswith('v'):
+ if not re.match(r"\Av[a-fA-F0-9]+\..+\Z", hostname):
+ raise ValueError(f"IPvFuture address is invalid")
+ else:
+ ip = ipaddress.ip_address(hostname) # Throws Value Error if not IPv6 or IPv4
+ if isinstance(ip, ipaddress.IPv4Address):
+ raise ValueError(f"An IPv4 address cannot be in brackets")
+
+# typed=True avoids BytesWarnings being emitted during cache key
+# comparison since this API supports both bytes and str input.
+@functools.lru_cache(typed=True)
def urlsplit(url, scheme='', allow_fragments=True):
"""Parse a URL into 5 components:
<scheme>://<netloc>/<path>?<query>#<fragment>
@@ -456,39 +489,37 @@ def urlsplit(url, scheme='', allow_fragments=True):
"""
url, scheme, _coerce_result = _coerce_args(url, scheme)
+ # Only lstrip url as some applications rely on preserving trailing space.
+ # (https://url.spec.whatwg.org/#concept-basic-url-parser would strip both)
+ url = url.lstrip(_WHATWG_C0_CONTROL_OR_SPACE)
+ scheme = scheme.strip(_WHATWG_C0_CONTROL_OR_SPACE)
for b in _UNSAFE_URL_BYTES_TO_REMOVE:
url = url.replace(b, "")
scheme = scheme.replace(b, "")
allow_fragments = bool(allow_fragments)
- key = url, scheme, allow_fragments, type(url), type(scheme)
- cached = _parse_cache.get(key, None)
- if cached:
- return _coerce_result(cached)
- if len(_parse_cache) >= MAX_CACHE_SIZE: # avoid runaway growth
- clear_cache()
netloc = query = fragment = ''
i = url.find(':')
- if i > 0:
+ if i > 0 and url[0].isascii() and url[0].isalpha():
for c in url[:i]:
if c not in scheme_chars:
break
else:
scheme, url = url[:i].lower(), url[i+1:]
-
if url[:2] == '//':
netloc, url = _splitnetloc(url, 2)
if (('[' in netloc and ']' not in netloc) or
(']' in netloc and '[' not in netloc)):
raise ValueError("Invalid IPv6 URL")
+ if '[' in netloc and ']' in netloc:
+ _check_bracketed_netloc(netloc)
if allow_fragments and '#' in url:
url, fragment = url.split('#', 1)
if '?' in url:
url, query = url.split('?', 1)
_checknetloc(netloc)
v = SplitResult(scheme, netloc, url, query, fragment)
- _parse_cache[key] = v
return _coerce_result(v)
def urlunparse(components):
@@ -510,9 +541,13 @@ def urlunsplit(components):
empty query; the RFC states that these are equivalent)."""
scheme, netloc, url, query, fragment, _coerce_result = (
_coerce_args(*components))
- if netloc or (scheme and scheme in uses_netloc and url[:2] != '//'):
+ if netloc:
if url and url[:1] != '/': url = '/' + url
- url = '//' + (netloc or '') + url
+ url = '//' + netloc + url
+ elif url[:2] == '//':
+ url = '//' + url
+ elif scheme and scheme in uses_netloc and (not url or url[:1] == '/'):
+ url = '//' + url
if scheme:
url = scheme + ':' + url
if query:
@@ -611,6 +646,9 @@ def urldefrag(url):
def unquote_to_bytes(string):
"""unquote_to_bytes('abc%20def') -> b'abc def'."""
+ return bytes(_unquote_impl(string))
+
+def _unquote_impl(string: bytes | bytearray | str) -> bytes | bytearray:
# Note: strings are encoded as UTF-8. This is only an issue if it contains
# unescaped non-ASCII characters, which URIs should not.
if not string:
@@ -622,8 +660,8 @@ def unquote_to_bytes(string):
bits = string.split(b'%')
if len(bits) == 1:
return string
- res = [bits[0]]
- append = res.append
+ res = bytearray(bits[0])
+ append = res.extend
# Delay the initialization of the table to not waste memory
# if the function is never called
global _hextobyte
@@ -637,10 +675,20 @@ def unquote_to_bytes(string):
except KeyError:
append(b'%')
append(item)
- return b''.join(res)
+ return res
_asciire = re.compile('([\x00-\x7f]+)')
+def _generate_unquoted_parts(string, encoding, errors):
+ previous_match_end = 0
+ for ascii_match in _asciire.finditer(string):
+ start, end = ascii_match.span()
+ yield string[previous_match_end:start] # Non-ASCII
+ # The ascii_match[1] group == string[start:end].
+ yield _unquote_impl(ascii_match[1]).decode(encoding, errors)
+ previous_match_end = end
+ yield string[previous_match_end:] # Non-ASCII tail
+
def unquote(string, encoding='utf-8', errors='replace'):
"""Replace %xx escapes by their single-character equivalent. The optional
encoding and errors parameters specify how to decode percent-encoded
@@ -652,21 +700,16 @@ def unquote(string, encoding='utf-8', errors='replace'):
unquote('abc%20def') -> 'abc def'.
"""
if isinstance(string, bytes):
- return unquote_to_bytes(string).decode(encoding, errors)
+ return _unquote_impl(string).decode(encoding, errors)
if '%' not in string:
+ # Is it a string-like object?
string.split
return string
if encoding is None:
encoding = 'utf-8'
if errors is None:
errors = 'replace'
- bits = _asciire.split(string)
- res = [bits[0]]
- append = res.append
- for i in range(1, len(bits), 2):
- append(unquote_to_bytes(bits[i]).decode(encoding, errors))
- append(bits[i + 1])
- return ''.join(res)
+ return ''.join(_generate_unquoted_parts(string, encoding, errors))
def parse_qs(qs, keep_blank_values=False, strict_parsing=False,
@@ -740,11 +783,29 @@ def parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
Returns a list, as G-d intended.
"""
- qs, _coerce_result = _coerce_args(qs)
- separator, _ = _coerce_args(separator)
- if not separator or (not isinstance(separator, (str, bytes))):
+ if not separator or not isinstance(separator, (str, bytes)):
raise ValueError("Separator must be of type string or bytes.")
+ if isinstance(qs, str):
+ if not isinstance(separator, str):
+ separator = str(separator, 'ascii')
+ eq = '='
+ def _unquote(s):
+ return unquote_plus(s, encoding=encoding, errors=errors)
+ else:
+ if not qs:
+ return []
+ # Use memoryview() to reject integers and iterables,
+ # acceptable by the bytes constructor.
+ qs = bytes(memoryview(qs))
+ if isinstance(separator, str):
+ separator = bytes(separator, 'ascii')
+ eq = b'='
+ def _unquote(s):
+ return unquote_to_bytes(s.replace(b'+', b' '))
+
+ if not qs:
+ return []
# If max_num_fields is defined then check that the number of fields
# is less than max_num_fields. This prevents a memory exhaustion DOS
@@ -756,25 +817,14 @@ def parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
r = []
for name_value in qs.split(separator):
- if not name_value and not strict_parsing:
- continue
- nv = name_value.split('=', 1)
- if len(nv) != 2:
- if strict_parsing:
+ if name_value or strict_parsing:
+ name, has_eq, value = name_value.partition(eq)
+ if not has_eq and strict_parsing:
raise ValueError("bad query field: %r" % (name_value,))
- # Handle case of a control-name with no equal sign
- if keep_blank_values:
- nv.append('')
- else:
- continue
- if len(nv[1]) or keep_blank_values:
- name = nv[0].replace('+', ' ')
- name = unquote(name, encoding=encoding, errors=errors)
- name = _coerce_result(name)
- value = nv[1].replace('+', ' ')
- value = unquote(value, encoding=encoding, errors=errors)
- value = _coerce_result(value)
- r.append((name, value))
+ if value or keep_blank_values:
+ name = _unquote(name)
+ value = _unquote(value)
+ r.append((name, value))
return r
def unquote_plus(string, encoding='utf-8', errors='replace'):
@@ -791,23 +841,30 @@ def unquote_plus(string, encoding='utf-8', errors='replace'):
b'0123456789'
b'_.-~')
_ALWAYS_SAFE_BYTES = bytes(_ALWAYS_SAFE)
-_safe_quoters = {}
-class Quoter(collections.defaultdict):
- """A mapping from bytes (in range(0,256)) to strings.
+def __getattr__(name):
+ if name == 'Quoter':
+ warnings.warn('Deprecated in 3.11. '
+ 'urllib.parse.Quoter will be removed in Python 3.14. '
+ 'It was not intended to be a public API.',
+ DeprecationWarning, stacklevel=2)
+ return _Quoter
+ raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
+
+class _Quoter(dict):
+ """A mapping from bytes numbers (in range(0,256)) to strings.
String values are percent-encoded byte values, unless the key < 128, and
- in the "safe" set (either the specified safe set, or default set).
+ in either of the specified safe set, or the always safe set.
"""
- # Keeps a cache internally, using defaultdict, for efficiency (lookups
+ # Keeps a cache internally, via __missing__, for efficiency (lookups
# of cached keys don't call Python code at all).
def __init__(self, safe):
"""safe: bytes object."""
self.safe = _ALWAYS_SAFE.union(safe)
def __repr__(self):
- # Without this, will just display as a defaultdict
- return "<%s %r>" % (self.__class__.__name__, dict(self))
+ return f"<Quoter {dict(self)!r}>"
def __missing__(self, b):
# Handle a cache miss. Store quoted string in cache and return.
@@ -886,6 +943,11 @@ def quote_plus(string, safe='', encoding=None, errors=None):
string = quote(string, safe + space, encoding, errors)
return string.replace(' ', '+')
+# Expectation: A typical program is unlikely to create more than 5 of these.
+@functools.lru_cache
+def _byte_quoter_factory(safe):
+ return _Quoter(safe).__getitem__
+
def quote_from_bytes(bs, safe='/'):
"""Like quote(), but accepts a bytes object rather than a str, and does
not perform string-to-bytes encoding. It always returns an ASCII string.
@@ -899,14 +961,19 @@ def quote_from_bytes(bs, safe='/'):
# Normalize 'safe' by converting to bytes and removing non-ASCII chars
safe = safe.encode('ascii', 'ignore')
else:
+ # List comprehensions are faster than generator expressions.
safe = bytes([c for c in safe if c < 128])
if not bs.rstrip(_ALWAYS_SAFE_BYTES + safe):
return bs.decode()
- try:
- quoter = _safe_quoters[safe]
- except KeyError:
- _safe_quoters[safe] = quoter = Quoter(safe).__getitem__
- return ''.join([quoter(char) for char in bs])
+ quoter = _byte_quoter_factory(safe)
+ if (bs_len := len(bs)) < 200_000:
+ return ''.join(map(quoter, bs))
+ else:
+ # This saves memory - https://github.com/python/cpython/issues/95865
+ chunk_size = math.isqrt(bs_len)
+ chunks = [''.join(map(quoter, bs[i:i+chunk_size]))
+ for i in range(0, bs_len, chunk_size)]
+ return ''.join(chunks)
def urlencode(query, doseq=False, safe='', encoding=None, errors=None,
quote_via=quote_plus):
@@ -939,10 +1006,9 @@ def urlencode(query, doseq=False, safe='', encoding=None, errors=None,
# but that's a minor nit. Since the original implementation
# allowed empty dicts that type of behavior probably should be
# preserved for consistency
- except TypeError:
- ty, va, tb = sys.exc_info()
+ except TypeError as err:
raise TypeError("not a valid non-string sequence "
- "or mapping object").with_traceback(tb)
+ "or mapping object") from err
l = []
if not doseq:
@@ -1125,15 +1191,15 @@ def splitnport(host, defport=-1):
def _splitnport(host, defport=-1):
"""Split host and port, returning numeric port.
Return given default port if no ':' found; defaults to -1.
- Return numerical port if a valid number are found after ':'.
+ Return numerical port if a valid number is found after ':'.
Return None if ':' but not a valid number."""
host, delim, port = host.rpartition(':')
if not delim:
host = port
elif port:
- try:
+ if port.isdigit() and port.isascii():
nport = int(port)
- except ValueError:
+ else:
nport = None
return host, nport
return host, defport
diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py
index a0ef60b30de..21d76913feb 100644
--- a/Lib/urllib/request.py
+++ b/Lib/urllib/request.py
@@ -11,8 +11,8 @@
Handlers needed to open the requested URL. For example, the
HTTPHandler performs HTTP GET and POST requests and deals with
non-error returns. The HTTPRedirectHandler automatically deals with
-HTTP 301, 302, 303 and 307 redirect errors, and the HTTPDigestAuthHandler
-deals with digest authentication.
+HTTP 301, 302, 303, 307, and 308 redirect errors, and the
+HTTPDigestAuthHandler deals with digest authentication.
urlopen(url, data=None) -- Basic usage is the same as original
urllib. pass the url and optionally data to post to an HTTP URL, and
@@ -88,7 +88,6 @@
import http.client
import io
import os
-import posixpath
import re
import socket
import string
@@ -137,7 +136,7 @@
_opener = None
def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
- *, cafile=None, capath=None, cadefault=False, context=None):
+ *, context=None):
'''Open the URL url, which can be either a string or a Request object.
*data* must be an object specifying additional data to be sent to
@@ -155,14 +154,6 @@ def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
If *context* is specified, it must be a ssl.SSLContext instance describing
the various SSL options. See HTTPSConnection for more details.
- The optional *cafile* and *capath* parameters specify a set of trusted CA
- certificates for HTTPS requests. cafile should point to a single file
- containing a bundle of CA certificates, whereas capath should point to a
- directory of hashed certificate files. More information can be found in
- ssl.SSLContext.load_verify_locations().
-
- The *cadefault* parameter is ignored.
-
This function always returns an object which can work as a
context manager and has the properties url, headers, and status.
@@ -188,25 +179,7 @@ def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
'''
global _opener
- if cafile or capath or cadefault:
- import warnings
- warnings.warn("cafile, capath and cadefault are deprecated, use a "
- "custom context instead.", DeprecationWarning, 2)
- if context is not None:
- raise ValueError(
- "You can't pass both context and any of cafile, capath, and "
- "cadefault"
- )
- if not _have_ssl:
- raise ValueError('SSL support not available')
- context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH,
- cafile=cafile,
- capath=capath)
- # send ALPN extension to indicate HTTP/1.1 protocol
- context.set_alpn_protocols(['http/1.1'])
- https_handler = HTTPSHandler(context=context)
- opener = build_opener(https_handler)
- elif context:
+ if context:
https_handler = HTTPSHandler(context=context)
opener = build_opener(https_handler)
elif _opener is None:
@@ -266,10 +239,7 @@ def urlretrieve(url, filename=None, reporthook=None, data=None):
if reporthook:
reporthook(blocknum, bs, size)
- while True:
- block = fp.read(bs)
- if not block:
- break
+ while block := fp.read(bs):
read += len(block)
tfp.write(block)
blocknum += 1
@@ -661,7 +631,7 @@ def redirect_request(self, req, fp, code, msg, headers, newurl):
but another Handler might.
"""
m = req.get_method()
- if (not (code in (301, 302, 303, 307) and m in ("GET", "HEAD")
+ if (not (code in (301, 302, 303, 307, 308) and m in ("GET", "HEAD")
or code in (301, 302, 303) and m == "POST")):
raise HTTPError(req.full_url, code, msg, headers, fp)
@@ -680,6 +650,7 @@ def redirect_request(self, req, fp, code, msg, headers, newurl):
newheaders = {k: v for k, v in req.headers.items()
if k.lower() not in CONTENT_HEADERS}
return Request(newurl,
+ method="HEAD" if m == "HEAD" else "GET",
headers=newheaders,
origin_req_host=req.origin_req_host,
unverifiable=True)
@@ -748,7 +719,7 @@ def http_error_302(self, req, fp, code, msg, headers):
return self.parent.open(new, timeout=req.timeout)
- http_error_301 = http_error_303 = http_error_307 = http_error_302
+ http_error_301 = http_error_303 = http_error_307 = http_error_308 = http_error_302
inf_msg = "The HTTP server returned a redirect error that would " \
"lead to an infinite loop.\n" \
@@ -907,9 +878,9 @@ def find_user_password(self, realm, authuri):
class HTTPPasswordMgrWithPriorAuth(HTTPPasswordMgrWithDefaultRealm):
- def __init__(self, *args, **kwargs):
+ def __init__(self):
self.authenticated = {}
- super().__init__(*args, **kwargs)
+ super().__init__()
def add_password(self, realm, uri, user, passwd, is_authenticated=False):
self.update_authenticated(uri, is_authenticated)
@@ -1255,8 +1226,8 @@ def http_error_407(self, req, fp, code, msg, headers):
class AbstractHTTPHandler(BaseHandler):
- def __init__(self, debuglevel=0):
- self._debuglevel = debuglevel
+ def __init__(self, debuglevel=None):
+ self._debuglevel = debuglevel if debuglevel is not None else http.client.HTTPConnection.debuglevel
def set_http_debuglevel(self, level):
self._debuglevel = level
@@ -1382,14 +1353,19 @@ def http_open(self, req):
class HTTPSHandler(AbstractHTTPHandler):
- def __init__(self, debuglevel=0, context=None, check_hostname=None):
+ def __init__(self, debuglevel=None, context=None, check_hostname=None):
+ debuglevel = debuglevel if debuglevel is not None else http.client.HTTPSConnection.debuglevel
AbstractHTTPHandler.__init__(self, debuglevel)
+ if context is None:
+ http_version = http.client.HTTPSConnection._http_vsn
+ context = http.client._create_https_context(http_version)
+ if check_hostname is not None:
+ context.check_hostname = check_hostname
self._context = context
- self._check_hostname = check_hostname
def https_open(self, req):
return self.do_open(http.client.HTTPSConnection, req,
- context=self._context, check_hostname=self._check_hostname)
+ context=self._context)
https_request = AbstractHTTPHandler.do_request_
@@ -1561,6 +1537,7 @@ def ftp_open(self, req):
dirs, file = dirs[:-1], dirs[-1]
if dirs and not dirs[0]:
dirs = dirs[1:]
+ fw = None
try:
fw = self.connect_ftp(user, passwd, host, port, dirs, req.timeout)
type = file and 'I' or 'D'
@@ -1578,9 +1555,12 @@ def ftp_open(self, req):
headers += "Content-length: %d\n" % retrlen
headers = email.message_from_string(headers)
return addinfourl(fp, headers, req.full_url)
- except ftplib.all_errors as exp:
- exc = URLError('ftp error: %r' % exp)
- raise exc.with_traceback(sys.exc_info()[2])
+ except Exception as exp:
+ if fw is not None and not fw.keepalive:
+ fw.close()
+ if isinstance(exp, ftplib.all_errors):
+ raise URLError(exp) from exp
+ raise
def connect_ftp(self, user, passwd, host, port, dirs, timeout):
return ftpwrapper(user, passwd, host, port, dirs, timeout,
@@ -1604,14 +1584,15 @@ def setMaxConns(self, m):
def connect_ftp(self, user, passwd, host, port, dirs, timeout):
key = user, host, port, '/'.join(dirs), timeout
- if key in self.cache:
- self.timeout[key] = time.time() + self.delay
- else:
- self.cache[key] = ftpwrapper(user, passwd, host, port,
- dirs, timeout)
- self.timeout[key] = time.time() + self.delay
+ conn = self.cache.get(key)
+ if conn is None or not conn.keepalive:
+ if conn is not None:
+ conn.close()
+ conn = self.cache[key] = ftpwrapper(user, passwd, host, port,
+ dirs, timeout)
+ self.timeout[key] = time.time() + self.delay
self.check_cache()
- return self.cache[key]
+ return conn
def check_cache(self):
# first check for old ones
@@ -1681,12 +1662,27 @@ def data_open(self, req):
def url2pathname(pathname):
"""OS-specific conversion from a relative URL of the 'file' scheme
to a file system path; not recommended for general use."""
- return unquote(pathname)
+ if pathname[:3] == '///':
+ # URL has an empty authority section, so the path begins on the
+ # third character.
+ pathname = pathname[2:]
+ elif pathname[:12] == '//localhost/':
+ # Skip past 'localhost' authority.
+ pathname = pathname[11:]
+ encoding = sys.getfilesystemencoding()
+ errors = sys.getfilesystemencodeerrors()
+ return unquote(pathname, encoding=encoding, errors=errors)
def pathname2url(pathname):
"""OS-specific conversion from a file system path to a relative URL
of the 'file' scheme; not recommended for general use."""
- return quote(pathname)
+ if pathname[:2] == '//':
+ # Add explicitly empty authority to avoid interpreting the path
+ # as authority.
+ pathname = '//' + pathname
+ encoding = sys.getfilesystemencoding()
+ errors = sys.getfilesystemencodeerrors()
+ return quote(pathname, encoding=encoding, errors=errors)
ftpcache = {}
@@ -1791,7 +1787,7 @@ def open(self, fullurl, data=None):
except (HTTPError, URLError):
raise
except OSError as msg:
- raise OSError('socket error', msg).with_traceback(sys.exc_info()[2])
+ raise OSError('socket error', msg) from msg
def open_unknown(self, fullurl, data=None):
"""Overridable interface to open unknown URL type."""
@@ -1845,10 +1841,7 @@ def retrieve(self, url, filename=None, reporthook=None, data=None):
size = int(headers["Content-Length"])
if reporthook:
reporthook(blocknum, bs, size)
- while 1:
- block = fp.read(bs)
- if not block:
- break
+ while block := fp.read(bs):
read += len(block)
tfp.write(block)
blocknum += 1
@@ -1988,9 +1981,17 @@ def http_error_default(self, url, fp, errcode, errmsg, headers):
if _have_ssl:
def _https_connection(self, host):
- return http.client.HTTPSConnection(host,
- key_file=self.key_file,
- cert_file=self.cert_file)
+ if self.key_file or self.cert_file:
+ http_version = http.client.HTTPSConnection._http_vsn
+ context = http.client._create_https_context(http_version)
+ context.load_cert_chain(self.cert_file, self.key_file)
+ # cert and key file means the user wants to authenticate.
+ # enable TLS 1.3 PHA implicitly even for custom contexts.
+ if context.post_handshake_auth is not None:
+ context.post_handshake_auth = True
+ else:
+ context = None
+ return http.client.HTTPSConnection(host, context=context)
def open_https(self, url, data=None):
"""Use HTTPS protocol."""
@@ -2093,7 +2094,7 @@ def open_ftp(self, url):
headers = email.message_from_string(headers)
return addinfourl(fp, headers, "ftp:" + url)
except ftperrors() as exp:
- raise URLError('ftp error %r' % exp).with_traceback(sys.exc_info()[2])
+ raise URLError(f'ftp error: {exp}') from exp
def open_data(self, url, data=None):
"""Use "data" URL."""
@@ -2211,6 +2212,13 @@ def http_error_307(self, url, fp, errcode, errmsg, headers, data=None):
else:
return self.http_error_default(url, fp, errcode, errmsg, headers)
+ def http_error_308(self, url, fp, errcode, errmsg, headers, data=None):
+ """Error 308 -- relocated, but turn POST into error."""
+ if data is None:
+ return self.http_error_301(url, fp, errcode, errmsg, headers, data)
+ else:
+ return self.http_error_default(url, fp, errcode, errmsg, headers)
+
def http_error_401(self, url, fp, errcode, errmsg, headers, data=None,
retry=False):
"""Error 401 -- authentication required.
@@ -2436,8 +2444,7 @@ def retrfile(self, file, type):
conn, retrlen = self.ftp.ntransfercmd(cmd)
except ftplib.error_perm as reason:
if str(reason)[:3] != '550':
- raise URLError('ftp error: %r' % reason).with_traceback(
- sys.exc_info()[2])
+ raise URLError(f'ftp error: {reason}') from reason
if not conn:
# Set transfer mode to ASCII!
self.ftp.voidcmd('TYPE A')
@@ -2464,7 +2471,13 @@ def retrfile(self, file, type):
return (ftpobj, retrlen)
def endtransfer(self):
+ if not self.busy:
+ return
self.busy = 0
+ try:
+ self.ftp.voidresp()
+ except ftperrors():
+ pass
def close(self):
self.keepalive = False
@@ -2492,28 +2505,34 @@ def getproxies_environment():
this seems to be the standard convention. If you need a
different way, you can pass a proxies dictionary to the
[Fancy]URLopener constructor.
-
"""
- proxies = {}
# in order to prefer lowercase variables, process environment in
# two passes: first matches any, second pass matches lowercase only
- for name, value in os.environ.items():
- name = name.lower()
- if value and name[-6:] == '_proxy':
- proxies[name[:-6]] = value
+
+ # select only environment variables which end in (after making lowercase) _proxy
+ proxies = {}
+ environment = []
+ for name in os.environ:
+ # fast screen underscore position before more expensive case-folding
+ if len(name) > 5 and name[-6] == "_" and name[-5:].lower() == "proxy":
+ value = os.environ[name]
+ proxy_name = name[:-6].lower()
+ environment.append((name, value, proxy_name))
+ if value:
+ proxies[proxy_name] = value
# CVE-2016-1000110 - If we are running as CGI script, forget HTTP_PROXY
# (non-all-lowercase) as it may be set from the web server by a "Proxy:"
# header from the client
# If "proxy" is lowercase, it will still be used thanks to the next block
if 'REQUEST_METHOD' in os.environ:
proxies.pop('http', None)
- for name, value in os.environ.items():
+ for name, value, proxy_name in environment:
+ # not case-folded, checking here for lower-case env vars only
if name[-6:] == '_proxy':
- name = name.lower()
if value:
- proxies[name[:-6]] = value
+ proxies[proxy_name] = value
else:
- proxies.pop(name[:-6], None)
+ proxies.pop(proxy_name, None)
return proxies
def proxy_bypass_environment(host, proxies=None):
@@ -2566,6 +2585,7 @@ def _proxy_bypass_macosx_sysconf(host, proxy_settings):
}
"""
from fnmatch import fnmatch
+ from ipaddress import AddressValueError, IPv4Address
hostonly, port = _splitport(host)
@@ -2582,20 +2602,17 @@ def ip2num(ipAddr):
return True
hostIP = None
+ try:
+ hostIP = int(IPv4Address(hostonly))
+ except AddressValueError:
+ pass
for value in proxy_settings.get('exceptions', ()):
# Items in the list are strings like these: *.local, 169.254/16
if not value: continue
m = re.match(r"(\d+(?:\.\d+)*)(/\d+)?", value)
- if m is not None:
- if hostIP is None:
- try:
- hostIP = socket.gethostbyname(hostonly)
- hostIP = ip2num(hostIP)
- except OSError:
- continue
-
+ if m is not None and hostIP is not None:
base = ip2num(m.group(1))
mask = m.group(2)
if mask is None:
@@ -2618,6 +2635,31 @@ def ip2num(ipAddr):
return False
+# Same as _proxy_bypass_macosx_sysconf, testable on all platforms
+def _proxy_bypass_winreg_override(host, override):
+ """Return True if the host should bypass the proxy server.
+
+ The proxy override list is obtained from the Windows
+ Internet settings proxy override registry value.
+
+ An example of a proxy override value is:
+ "www.example.com;*.example.net; 192.168.0.1"
+ """
+ from fnmatch import fnmatch
+
+ host, _ = _splitport(host)
+ proxy_override = override.split(';')
+ for test in proxy_override:
+ test = test.strip()
+ # "" should bypass the proxy server for all intranet addresses
+ if test == '':
+ if '.' not in host:
+ return True
+ elif fnmatch(host, test):
+ return True
+ return False
+
+
if sys.platform == 'darwin':
from _scproxy import _get_proxy_settings, _get_proxies
@@ -2716,7 +2758,7 @@ def proxy_bypass_registry(host):
import winreg
except ImportError:
# Std modules, so should be around - but you never know!
- return 0
+ return False
try:
internetSettings = winreg.OpenKey(winreg.HKEY_CURRENT_USER,
r'Software\Microsoft\Windows\CurrentVersion\Internet Settings')
@@ -2726,40 +2768,10 @@ def proxy_bypass_registry(host):
'ProxyOverride')[0])
# ^^^^ Returned as Unicode but problems if not converted to ASCII
except OSError:
- return 0
+ return False
if not proxyEnable or not proxyOverride:
- return 0
- # try to make a host list from name and IP address.
- rawHost, port = _splitport(host)
- host = [rawHost]
- try:
- addr = socket.gethostbyname(rawHost)
- if addr != rawHost:
- host.append(addr)
- except OSError:
- pass
- try:
- fqdn = socket.getfqdn(rawHost)
- if fqdn != rawHost:
- host.append(fqdn)
- except OSError:
- pass
- # make a check value list from the registry entry: replace the
- # '' string by the localhost entry and the corresponding
- # canonical entry.
- proxyOverride = proxyOverride.split(';')
- # now check if we match one of the registry values.
- for test in proxyOverride:
- if test == '':
- if '.' not in rawHost:
- return 1
- test = test.replace(".", r"\.") # mask dots
- test = test.replace("*", r".*") # change glob sequence
- test = test.replace("?", r".") # change glob char
- for val in host:
- if re.match(test, val, re.I):
- return 1
- return 0
+ return False
+ return _proxy_bypass_winreg_override(host, proxyOverride)
def proxy_bypass(host):
"""Return True, if host should be bypassed.
diff --git a/Lib/urllib/robotparser.py b/Lib/urllib/robotparser.py
index c58565e3945..63689816f30 100644
--- a/Lib/urllib/robotparser.py
+++ b/Lib/urllib/robotparser.py
@@ -11,6 +11,8 @@
"""
import collections
+import re
+import urllib.error
import urllib.parse
import urllib.request
@@ -19,6 +21,19 @@
RequestRate = collections.namedtuple("RequestRate", "requests seconds")
+def normalize(path):
+ unquoted = urllib.parse.unquote(path, errors='surrogateescape')
+ return urllib.parse.quote(unquoted, errors='surrogateescape')
+
+def normalize_path(path):
+ path, sep, query = path.partition('?')
+ path = normalize(path)
+ if sep:
+ query = re.sub(r'[^=&]+', lambda m: normalize(m[0]), query)
+ path += '?' + query
+ return path
+
+
class RobotFileParser:
""" This class provides a set of methods to read, parse and answer
questions about a single robots.txt file.
@@ -54,7 +69,7 @@ def modified(self):
def set_url(self, url):
"""Sets the URL referring to a robots.txt file."""
self.url = url
- self.host, self.path = urllib.parse.urlparse(url)[1:3]
+ self.host, self.path = urllib.parse.urlsplit(url)[1:3]
def read(self):
"""Reads the robots.txt URL and feeds it to the parser."""
@@ -65,9 +80,10 @@ def read(self):
self.disallow_all = True
elif err.code >= 400 and err.code < 500:
self.allow_all = True
+ err.close()
else:
raw = f.read()
- self.parse(raw.decode("utf-8").splitlines())
+ self.parse(raw.decode("utf-8", "surrogateescape").splitlines())
def _add_entry(self, entry):
if "*" in entry.useragents:
@@ -111,7 +127,7 @@ def parse(self, lines):
line = line.split(':', 1)
if len(line) == 2:
line[0] = line[0].strip().lower()
- line[1] = urllib.parse.unquote(line[1].strip())
+ line[1] = line[1].strip()
if line[0] == "user-agent":
if state == 2:
self._add_entry(entry)
@@ -165,10 +181,9 @@ def can_fetch(self, useragent, url):
return False
# search for given user agent matches
# the first match counts
- parsed_url = urllib.parse.urlparse(urllib.parse.unquote(url))
- url = urllib.parse.urlunparse(('','',parsed_url.path,
- parsed_url.params,parsed_url.query, parsed_url.fragment))
- url = urllib.parse.quote(url)
+ parsed_url = urllib.parse.urlsplit(url)
+ url = urllib.parse.urlunsplit(('', '', *parsed_url[2:]))
+ url = normalize_path(url)
if not url:
url = "/"
for entry in self.entries:
@@ -211,7 +226,6 @@ def __str__(self):
entries = entries + [self.default_entry]
return '\n\n'.join(map(str, entries))
-
class RuleLine:
"""A rule line is a single "Allow:" (allowance==True) or "Disallow:"
(allowance==False) followed by a path."""
@@ -219,8 +233,7 @@ def __init__(self, path, allowance):
if path == '' and not allowance:
# an empty value means allow all
allowance = True
- path = urllib.parse.urlunparse(urllib.parse.urlparse(path))
- self.path = urllib.parse.quote(path)
+ self.path = normalize_path(path)
self.allowance = allowance
def applies_to(self, filename):
@@ -266,7 +279,7 @@ def applies_to(self, useragent):
def allowance(self, filename):
"""Preconditions:
- our agent applies to this entry
- - filename is URL decoded"""
+ - filename is URL encoded"""
for line in self.rulelines:
if line.applies_to(filename):
return line.allowance
diff --git a/crates/compiler-core/src/bytecode/instruction.rs b/crates/compiler-core/src/bytecode/instruction.rs
index 3ebb3666ae2..44a57c44320 100644
--- a/crates/compiler-core/src/bytecode/instruction.rs
+++ b/crates/compiler-core/src/bytecode/instruction.rs
@@ -245,10 +245,7 @@ pub enum Instruction {
YieldValue {
arg: Arg,
} = 118,
- Resume {
- arg: Arg,
- } = 149,
- // ==================== RustPython-only instructions (119-135) ====================
+ // ==================== RustPython-only instructions (119-133) ====================
// Ideally, we want to be fully aligned with CPython opcodes, but we still have some leftovers.
// So we assign random IDs to these opcodes.
Break {
@@ -277,10 +274,106 @@ pub enum Instruction {
target: Arg,
} = 130,
JumpIfNotExcMatch(Arg) = 131,
- SetExcInfo = 134,
- Subscript = 135,
+ SetExcInfo = 132,
+ Subscript = 133,
+ // End of custom instructions
+ Resume {
+ arg: Arg,
+ } = 149,
+ BinaryOpAddFloat = 150, // Placeholder
+ BinaryOpAddInt = 151, // Placeholder
+ BinaryOpAddUnicode = 152, // Placeholder
+ BinaryOpMultiplyFloat = 153, // Placeholder
+ BinaryOpMultiplyInt = 154, // Placeholder
+ BinaryOpSubtractFloat = 155, // Placeholder
+ BinaryOpSubtractInt = 156, // Placeholder
+ BinarySubscrDict = 157, // Placeholder
+ BinarySubscrGetitem = 158, // Placeholder
+ BinarySubscrListInt = 159, // Placeholder
+ BinarySubscrStrInt = 160, // Placeholder
+ BinarySubscrTupleInt = 161, // Placeholder
+ CallAllocAndEnterInit = 162, // Placeholder
+ CallBoundMethodExactArgs = 163, // Placeholder
+ CallBoundMethodGeneral = 164, // Placeholder
+ CallBuiltinClass = 165, // Placeholder
+ CallBuiltinFast = 166, // Placeholder
+ CallBuiltinFastWithKeywords = 167, // Placeholder
+ CallBuiltinO = 168, // Placeholder
+ CallIsinstance = 169, // Placeholder
+ CallLen = 170, // Placeholder
+ CallListAppend = 171, // Placeholder
+ CallMethodDescriptorFast = 172, // Placeholder
+ CallMethodDescriptorFastWithKeywords = 173, // Placeholder
+ CallMethodDescriptorNoargs = 174, // Placeholder
+ CallMethodDescriptorO = 175, // Placeholder
+ CallNonPyGeneral = 176, // Placeholder
+ CallPyExactArgs = 177, // Placeholder
+ CallPyGeneral = 178, // Placeholder
+ CallStr1 = 179, // Placeholder
+ CallTuple1 = 180, // Placeholder
+ CallType1 = 181, // Placeholder
+ CompareOpFloat = 182, // Placeholder
+ CompareOpInt = 183, // Placeholder
+ CompareOpStr = 184, // Placeholder
+ ContainsOpDict = 185, // Placeholder
+ ContainsOpSet = 186, // Placeholder
+ ForIterGen = 187, // Placeholder
+ ForIterList = 188, // Placeholder
+ ForIterRange = 189, // Placeholder
+ ForIterTuple = 190, // Placeholder
+ LoadAttrClass = 191, // Placeholder
+ LoadAttrGetattributeOverridden = 192, // Placeholder
+ LoadAttrInstanceValue = 193, // Placeholder
+ LoadAttrMethodLazyDict = 194, // Placeholder
+ LoadAttrMethodNoDict = 195, // Placeholder
+ LoadAttrMethodWithValues = 196, // Placeholder
+ LoadAttrModule = 197, // Placeholder
+ LoadAttrNondescriptorNoDict = 198, // Placeholder
+ LoadAttrNondescriptorWithValues = 199, // Placeholder
+ LoadAttrProperty = 200, // Placeholder
+ LoadAttrSlot = 201, // Placeholder
+ LoadAttrWithHint = 202, // Placeholder
+ LoadGlobalBuiltin = 203, // Placeholder
+ LoadGlobalModule = 204, // Placeholder
+ LoadSuperAttrAttr = 205, // Placeholder
+ LoadSuperAttrMethod = 206, // Placeholder
+ ResumeCheck = 207, // Placeholder
+ SendGen = 208, // Placeholder
+ StoreAttrInstanceValue = 209, // Placeholder
+ StoreAttrSlot = 210, // Placeholder
+ StoreAttrWithHint = 211, // Placeholder
+ StoreSubscrDict = 212, // Placeholder
+ StoreSubscrListInt = 213, // Placeholder
+ ToBoolAlwaysTrue = 214, // Placeholder
+ ToBoolBool = 215, // Placeholder
+ ToBoolInt = 216, // Placeholder
+ ToBoolList = 217, // Placeholder
+ ToBoolNone = 218, // Placeholder
+ ToBoolStr = 219, // Placeholder
+ UnpackSequenceList = 220, // Placeholder
+ UnpackSequenceTuple = 221, // Placeholder
+ UnpackSequenceTwoTuple = 222, // Placeholder
+ InstrumentedResume = 236, // Placeholder
+ InstrumentedEndFor = 237, // Placeholder
+ InstrumentedEndSend = 238, // Placeholder
+ InstrumentedReturnValue = 239, // Placeholder
+ InstrumentedReturnConst = 240, // Placeholder
+ InstrumentedYieldValue = 241, // Placeholder
+ InstrumentedLoadSuperAttr = 242, // Placeholder
+ InstrumentedForIter = 243, // Placeholder
+ InstrumentedCall = 244, // Placeholder
+ InstrumentedCallKw = 245, // Placeholder
+ InstrumentedCallFunctionEx = 246, // Placeholder
+ InstrumentedInstruction = 247, // Placeholder
+ InstrumentedJumpForward = 248, // Placeholder
+ InstrumentedJumpBackward = 249, // Placeholder
+ InstrumentedPopJumpIfTrue = 250, // Placeholder
+ InstrumentedPopJumpIfFalse = 251, // Placeholder
+ InstrumentedPopJumpIfNone = 252, // Placeholder
+ InstrumentedPopJumpIfNotNone = 253, // Placeholder
+ InstrumentedLine = 254, // Placeholder
// Pseudos (needs to be moved to `PseudoInstruction` enum.
- LoadClosure(Arg) = 253, // TODO: Move to pseudos
+ LoadClosure(Arg) = 255, // TODO: Move to pseudos
}
const _: () = assert!(mem::size_of::() == 1);
@@ -305,6 +398,12 @@ impl TryFrom for Instruction {
// Resume has a non-contiguous opcode (149)
let resume_id = u8::from(Self::Resume { arg: Arg::marker() });
+ let specialized_start = u8::from(Self::BinaryOpAddFloat);
+ let specialized_end = u8::from(Self::UnpackSequenceTwoTuple);
+
+ let instrumented_start = u8::from(Self::InstrumentedResume);
+ let instrumented_end = u8::from(Self::InstrumentedLine);
+
// TODO: Remove this; This instruction needs to be pseudo
let load_closure = u8::from(Self::LoadClosure(Arg::marker()));
@@ -345,6 +444,8 @@ impl TryFrom for Instruction {
|| value == resume_id
|| value == load_closure
|| custom_ops.contains(&value)
+ || (specialized_start..=specialized_end).contains(&value)
+ || (instrumented_start..=instrumented_end).contains(&value)
{
Ok(unsafe { mem::transmute::(value) })
} else {
@@ -589,6 +690,98 @@ impl InstructionMetadata for Instruction {
Self::PopJumpIfNone { .. } => 0,
Self::PopJumpIfNotNone { .. } => 0,
Self::LoadClosure(_) => 1,
+ Self::BinaryOpAddFloat => 0,
+ Self::BinaryOpAddInt => 0,
+ Self::BinaryOpAddUnicode => 0,
+ Self::BinaryOpMultiplyFloat => 0,
+ Self::BinaryOpMultiplyInt => 0,
+ Self::BinaryOpSubtractFloat => 0,
+ Self::BinaryOpSubtractInt => 0,
+ Self::BinarySubscrDict => 0,
+ Self::BinarySubscrGetitem => 0,
+ Self::BinarySubscrListInt => 0,
+ Self::BinarySubscrStrInt => 0,
+ Self::BinarySubscrTupleInt => 0,
+ Self::CallAllocAndEnterInit => 0,
+ Self::CallBoundMethodExactArgs => 0,
+ Self::CallBoundMethodGeneral => 0,
+ Self::CallBuiltinClass => 0,
+ Self::CallBuiltinFast => 0,
+ Self::CallBuiltinFastWithKeywords => 0,
+ Self::CallBuiltinO => 0,
+ Self::CallIsinstance => 0,
+ Self::CallLen => 0,
+ Self::CallListAppend => 0,
+ Self::CallMethodDescriptorFast => 0,
+ Self::CallMethodDescriptorFastWithKeywords => 0,
+ Self::CallMethodDescriptorNoargs => 0,
+ Self::CallMethodDescriptorO => 0,
+ Self::CallNonPyGeneral => 0,
+ Self::CallPyExactArgs => 0,
+ Self::CallPyGeneral => 0,
+ Self::CallStr1 => 0,
+ Self::CallTuple1 => 0,
+ Self::CallType1 => 0,
+ Self::CompareOpFloat => 0,
+ Self::CompareOpInt => 0,
+ Self::CompareOpStr => 0,
+ Self::ContainsOpDict => 0,
+ Self::ContainsOpSet => 0,
+ Self::ForIterGen => 0,
+ Self::ForIterList => 0,
+ Self::ForIterRange => 0,
+ Self::ForIterTuple => 0,
+ Self::LoadAttrClass => 0,
+ Self::LoadAttrGetattributeOverridden => 0,
+ Self::LoadAttrInstanceValue => 0,
+ Self::LoadAttrMethodLazyDict => 0,
+ Self::LoadAttrMethodNoDict => 0,
+ Self::LoadAttrMethodWithValues => 0,
+ Self::LoadAttrModule => 0,
+ Self::LoadAttrNondescriptorNoDict => 0,
+ Self::LoadAttrNondescriptorWithValues => 0,
+ Self::LoadAttrProperty => 0,
+ Self::LoadAttrSlot => 0,
+ Self::LoadAttrWithHint => 0,
+ Self::LoadGlobalBuiltin => 0,
+ Self::LoadGlobalModule => 0,
+ Self::LoadSuperAttrAttr => 0,
+ Self::LoadSuperAttrMethod => 0,
+ Self::ResumeCheck => 0,
+ Self::SendGen => 0,
+ Self::StoreAttrInstanceValue => 0,
+ Self::StoreAttrSlot => 0,
+ Self::StoreAttrWithHint => 0,
+ Self::StoreSubscrDict => 0,
+ Self::StoreSubscrListInt => 0,
+ Self::ToBoolAlwaysTrue => 0,
+ Self::ToBoolBool => 0,
+ Self::ToBoolInt => 0,
+ Self::ToBoolList => 0,
+ Self::ToBoolNone => 0,
+ Self::ToBoolStr => 0,
+ Self::UnpackSequenceList => 0,
+ Self::UnpackSequenceTuple => 0,
+ Self::UnpackSequenceTwoTuple => 0,
+ Self::InstrumentedResume => 0,
+ Self::InstrumentedEndFor => 0,
+ Self::InstrumentedEndSend => 0,
+ Self::InstrumentedReturnValue => 0,
+ Self::InstrumentedReturnConst => 0,
+ Self::InstrumentedYieldValue => 0,
+ Self::InstrumentedLoadSuperAttr => 0,
+ Self::InstrumentedForIter => 0,
+ Self::InstrumentedCall => 0,
+ Self::InstrumentedCallKw => 0,
+ Self::InstrumentedCallFunctionEx => 0,
+ Self::InstrumentedInstruction => 0,
+ Self::InstrumentedJumpForward => 0,
+ Self::InstrumentedJumpBackward => 0,
+ Self::InstrumentedPopJumpIfTrue => 0,
+ Self::InstrumentedPopJumpIfFalse => 0,
+ Self::InstrumentedPopJumpIfNone => 0,
+ Self::InstrumentedPopJumpIfNotNone => 0,
+ Self::InstrumentedLine => 0,
}
}
diff --git a/crates/vm/Lib/python_builtins/__reducelib.py b/crates/vm/Lib/python_builtins/__reducelib.py
deleted file mode 100644
index 0067cd0a818..00000000000
--- a/crates/vm/Lib/python_builtins/__reducelib.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Modified from code from the PyPy project:
-# https://bitbucket.org/pypy/pypy/src/default/pypy/objspace/std/objectobject.py
-
-# The MIT License
-
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or
-# sell copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included
-# in all copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
-# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import copyreg
-
-
-def _abstract_method_error(typ):
- methods = ", ".join(sorted(typ.__abstractmethods__))
- err = "Can't instantiate abstract class %s with abstract methods %s"
- raise TypeError(err % (typ.__name__, methods))
-
-
-def reduce_2(obj):
- cls = obj.__class__
-
- try:
- getnewargs = obj.__getnewargs__
- except AttributeError:
- args = ()
- else:
- args = getnewargs()
- if not isinstance(args, tuple):
- raise TypeError("__getnewargs__ should return a tuple")
-
- try:
- getstate = obj.__getstate__
- except AttributeError:
- state = getattr(obj, "__dict__", None)
- names = slotnames(cls) # not checking for list
- if names is not None:
- slots = {}
- for name in names:
- try:
- value = getattr(obj, name)
- except AttributeError:
- pass
- else:
- slots[name] = value
- if slots:
- state = state, slots
- else:
- state = getstate()
-
- listitems = iter(obj) if isinstance(obj, list) else None
- dictitems = iter(obj.items()) if isinstance(obj, dict) else None
-
- newobj = copyreg.__newobj__
-
- args2 = (cls,) + args
- return newobj, args2, state, listitems, dictitems
-
-
-def slotnames(cls):
- if not isinstance(cls, type):
- return None
-
- try:
- return cls.__dict__["__slotnames__"]
- except KeyError:
- pass
-
- slotnames = copyreg._slotnames(cls)
- if not isinstance(slotnames, list) and slotnames is not None:
- raise TypeError("copyreg._slotnames didn't return a list or None")
- return slotnames
diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs
index 58c683d3fab..3d5159e58c1 100644
--- a/crates/vm/src/builtins/function.rs
+++ b/crates/vm/src/builtins/function.rs
@@ -2,8 +2,8 @@
mod jit;
use super::{
- PyAsyncGen, PyCode, PyCoroutine, PyDictRef, PyGenerator, PyStr, PyStrRef, PyTuple, PyTupleRef,
- PyType,
+ PyAsyncGen, PyCode, PyCoroutine, PyDictRef, PyGenerator, PyModule, PyStr, PyStrRef, PyTuple,
+ PyTupleRef, PyType,
};
#[cfg(feature = "jit")]
use crate::common::lock::OnceCell;
@@ -67,9 +67,15 @@ impl PyFunction {
if let Some(frame) = vm.current_frame() {
frame.builtins.clone().into()
} else {
- vm.builtins.clone().into()
+ vm.builtins.dict().into()
}
});
+ // If builtins is a module, use its __dict__ instead
+ let builtins = if let Some(module) = builtins.downcast_ref::() {
+ module.dict().into()
+ } else {
+ builtins
+ };
let qualname = vm.ctx.new_str(code.qualname.as_str());
let func = Self {
@@ -679,11 +685,11 @@ pub struct PyFunctionNewArgs {
#[pyarg(any, optional)]
name: OptionalArg,
#[pyarg(any, optional)]
- defaults: OptionalArg,
+ argdefs: Option,
#[pyarg(any, optional)]
- closure: OptionalArg,
+ closure: Option,
#[pyarg(any, optional)]
- kwdefaults: OptionalArg,
+ kwdefaults: Option,
}
impl Constructor for PyFunction {
@@ -691,7 +697,7 @@ impl Constructor for PyFunction {
fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult {
// Handle closure - must be a tuple of cells
- let closure = if let Some(closure_tuple) = args.closure.into_option() {
+ let closure = if let Some(closure_tuple) = args.closure {
// Check that closure length matches code's free variables
if closure_tuple.len() != args.code.freevars.len() {
return Err(vm.new_value_error(format!(
@@ -722,10 +728,10 @@ impl Constructor for PyFunction {
if let Some(closure_tuple) = closure {
func.closure = Some(closure_tuple);
}
- if let Some(defaults) = args.defaults.into_option() {
- func.defaults_and_kwdefaults.lock().0 = Some(defaults);
+ if let Some(argdefs) = args.argdefs {
+ func.defaults_and_kwdefaults.lock().0 = Some(argdefs);
}
- if let Some(kwdefaults) = args.kwdefaults.into_option() {
+ if let Some(kwdefaults) = args.kwdefaults {
func.defaults_and_kwdefaults.lock().1 = Some(kwdefaults);
}
diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs
index 6f072542547..982e11afd00 100644
--- a/crates/vm/src/builtins/object.rs
+++ b/crates/vm/src/builtins/object.rs
@@ -184,15 +184,12 @@ fn type_slot_names(typ: &Py, vm: &VirtualMachine) -> PyResult PyResult {
- // TODO: itemsize
- // if required && obj.class().slots.itemsize > 0 {
- // return vm.new_type_error(format!(
- // "cannot pickle {:.200} objects",
- // obj.class().name()
- // ));
- // }
+ // Check itemsize
+ if required && obj.class().slots.itemsize > 0 {
+ return Err(vm.new_type_error(format!("cannot pickle {:.200} objects", obj.class().name())));
+ }
let state = if obj.dict().is_none_or(|d| d.is_empty()) {
vm.ctx.none()
@@ -208,22 +205,36 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine)
type_slot_names(obj.class(), vm).map_err(|_| vm.new_type_error("cannot pickle object"))?;
if required {
- let mut basicsize = obj.class().slots.basicsize;
- // if obj.class().slots.dict_offset > 0
- // && !obj.class().slots.flags.has_feature(PyTypeFlags::MANAGED_DICT)
- // {
- // basicsize += std::mem::size_of::();
- // }
- // if obj.class().slots.weaklist_offset > 0 {
- // basicsize += std::mem::size_of::();
- // }
+ // Start with PyBaseObject_Type's basicsize
+ let mut basicsize = vm.ctx.types.object_type.slots.basicsize;
+
+ // Add __dict__ size if type has dict
+ if obj.class().slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
+ basicsize += core::mem::size_of::();
+ }
+
+ // Add __weakref__ size if type has weakref support
+ let has_weakref = if let Some(ref ext) = obj.class().heaptype_ext {
+ match &ext.slots {
+ None => true, // Heap type without __slots__ has automatic weakref
+ Some(slots) => slots.iter().any(|s| s.as_str() == "__weakref__"),
+ }
+ } else {
+ let weakref_name = vm.ctx.intern_str("__weakref__");
+ obj.class().attributes.read().contains_key(weakref_name)
+ };
+ if has_weakref {
+ basicsize += core::mem::size_of::();
+ }
+
+ // Add slots size
if let Some(ref slot_names) = slot_names {
basicsize += core::mem::size_of::() * slot_names.__len__();
}
+
+ // Fail if actual type's basicsize > expected basicsize
if obj.class().slots.basicsize > basicsize {
- return Err(
- vm.new_type_error(format!("cannot pickle {:.200} object", obj.class().name()))
- );
+ return Err(vm.new_type_error(format!("cannot pickle '{}' object", obj.class().name())));
}
}
@@ -249,7 +260,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine)
Ok(state)
}
-// object_getstate in CPython
+// object_getstate
// fn object_getstate(
// obj: &PyObject,
// required: bool,
@@ -550,11 +561,181 @@ pub fn init(ctx: &Context) {
PyBaseObject::extend_class(ctx, ctx.types.object_type);
}
+/// Get arguments for __new__ from __getnewargs_ex__ or __getnewargs__
+/// Returns (args, kwargs) tuple where either can be None
+fn get_new_arguments(
+ obj: &PyObject,
+ vm: &VirtualMachine,
+) -> PyResult<(Option, Option)> {
+ // First try __getnewargs_ex__
+ if let Some(getnewargs_ex) = vm.get_special_method(obj, identifier!(vm, __getnewargs_ex__))? {
+ let newargs = getnewargs_ex.invoke((), vm)?;
+
+ let newargs_tuple: PyRef = newargs.downcast().map_err(|obj| {
+ vm.new_type_error(format!(
+ "__getnewargs_ex__ should return a tuple, not '{}'",
+ obj.class().name()
+ ))
+ })?;
+
+ if newargs_tuple.len() != 2 {
+ return Err(vm.new_value_error(format!(
+ "__getnewargs_ex__ should return a tuple of length 2, not {}",
+ newargs_tuple.len()
+ )));
+ }
+
+ let args = newargs_tuple.as_slice()[0].clone();
+ let kwargs = newargs_tuple.as_slice()[1].clone();
+
+ let args_tuple: PyRef = args.downcast().map_err(|obj| {
+ vm.new_type_error(format!(
+ "first item of the tuple returned by __getnewargs_ex__ must be a tuple, not '{}'",
+ obj.class().name()
+ ))
+ })?;
+
+ let kwargs_dict: PyRef = kwargs.downcast().map_err(|obj| {
+ vm.new_type_error(format!(
+ "second item of the tuple returned by __getnewargs_ex__ must be a dict, not '{}'",
+ obj.class().name()
+ ))
+ })?;
+
+ return Ok((Some(args_tuple), Some(kwargs_dict)));
+ }
+
+ // Fall back to __getnewargs__
+ if let Some(getnewargs) = vm.get_special_method(obj, identifier!(vm, __getnewargs__))? {
+ let args = getnewargs.invoke((), vm)?;
+
+ let args_tuple: PyRef = args.downcast().map_err(|obj| {
+ vm.new_type_error(format!(
+ "__getnewargs__ should return a tuple, not '{}'",
+ obj.class().name()
+ ))
+ })?;
+
+ return Ok((Some(args_tuple), None));
+ }
+
+ // No __getnewargs_ex__ or __getnewargs__
+ Ok((None, None))
+}
+
+/// Check if __getstate__ is overridden by comparing with object.__getstate__
+fn is_getstate_overridden(obj: &PyObject, vm: &VirtualMachine) -> bool {
+ let obj_cls = obj.class();
+ let object_type = vm.ctx.types.object_type;
+
+ // If the class is object itself, not overridden
+ if obj_cls.is(object_type) {
+ return false;
+ }
+
+ // Check if __getstate__ in the MRO comes from object or elsewhere
+ // If the type has its own __getstate__, it's overridden
+ if let Some(getstate) = obj_cls.get_attr(identifier!(vm, __getstate__))
+ && let Some(obj_getstate) = object_type.get_attr(identifier!(vm, __getstate__))
+ {
+ return !getstate.is(&obj_getstate);
+ }
+ false
+}
+
+/// object_getstate - calls __getstate__ method or default implementation
+fn object_getstate(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
+ // If __getstate__ is not overridden, use the default implementation with required flag
+ if !is_getstate_overridden(obj, vm) {
+ return object_getstate_default(obj, required, vm);
+ }
+
+ // __getstate__ is overridden, call it without required
+ let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
+ getstate.call((), vm)
+}
+
+/// Return (listitems, dictitems): item iterators when obj is a list/dict (or subclass), None for the non-matching slot
+fn get_items_iter(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, PyObjectRef)> {
+ let listitems: PyObjectRef = if obj.fast_isinstance(vm.ctx.types.list_type) {
+ obj.get_iter(vm)?.into()
+ } else {
+ vm.ctx.none()
+ };
+
+ let dictitems: PyObjectRef = if obj.fast_isinstance(vm.ctx.types.dict_type) {
+ let items = vm.call_method(obj, "items", ())?;
+ items.get_iter(vm)?.into()
+ } else {
+ vm.ctx.none()
+ };
+
+ Ok((listitems, dictitems))
+}
+
+/// reduce_newobj - creates reduce tuple for protocol >= 2
+fn reduce_newobj(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
+ // Check if type has tp_new
+ let cls = obj.class();
+ if cls.slots.new.load().is_none() {
+ return Err(vm.new_type_error(format!("cannot pickle '{}' object", cls.name())));
+ }
+
+ let (args, kwargs) = get_new_arguments(&obj, vm)?;
+
+ let copyreg = vm.import("copyreg", 0)?;
+
+ let has_args = args.is_some();
+
+ let (newobj, newargs): (PyObjectRef, PyObjectRef) = if kwargs.is_none()
+ || kwargs.as_ref().is_some_and(|k| k.is_empty())
+ {
+ // Use copyreg.__newobj__
+ let newobj = copyreg.get_attr("__newobj__", vm)?;
+
+ let args_vec: Vec = args.map(|a| a.as_slice().to_vec()).unwrap_or_default();
+
+ // Create (cls, *args) tuple
+ let mut newargs_vec: Vec = vec![cls.to_owned().into()];
+ newargs_vec.extend(args_vec);
+ let newargs = vm.ctx.new_tuple(newargs_vec);
+
+ (newobj, newargs.into())
+ } else {
+ // Use copyreg.__newobj_ex__
+ let newobj = copyreg.get_attr("__newobj_ex__", vm)?;
+ let args_tuple: PyObjectRef = args
+ .map(|a| a.into())
+ .unwrap_or_else(|| vm.ctx.empty_tuple.clone().into());
+ let kwargs_dict: PyObjectRef = kwargs
+ .map(|k| k.into())
+ .unwrap_or_else(|| vm.ctx.new_dict().into());
+
+ let newargs = vm
+ .ctx
+ .new_tuple(vec![cls.to_owned().into(), args_tuple, kwargs_dict]);
+ (newobj, newargs.into())
+ };
+
+ // Determine if state is required
+ // required = !(has_args || is_list || is_dict)
+ let is_list = obj.fast_isinstance(vm.ctx.types.list_type);
+ let is_dict = obj.fast_isinstance(vm.ctx.types.dict_type);
+ let required = !(has_args || is_list || is_dict);
+
+ let state = object_getstate(&obj, required, vm)?;
+
+ let (listitems, dictitems) = get_items_iter(&obj, vm)?;
+
+ let result = vm
+ .ctx
+ .new_tuple(vec![newobj, newargs, state, listitems, dictitems]);
+ Ok(result.into())
+}
+
fn common_reduce(obj: PyObjectRef, proto: usize, vm: &VirtualMachine) -> PyResult {
if proto >= 2 {
- let reducelib = vm.import("__reducelib", 0)?;
- let reduce_2 = reducelib.get_attr("reduce_2", vm)?;
- reduce_2.call((obj,), vm)
+ reduce_newobj(obj, vm)
} else {
let copyreg = vm.import("copyreg", 0)?;
let reduce_ex = copyreg.get_attr("_reduce_ex", vm)?;
diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs
index 54a38ef20e6..0c6eebe15d3 100644
--- a/crates/vm/src/stdlib/io.rs
+++ b/crates/vm/src/stdlib/io.rs
@@ -158,8 +158,8 @@ mod _io {
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
TryFromBorrowedObject, TryFromObject,
builtins::{
- PyBaseExceptionRef, PyBool, PyByteArray, PyBytes, PyBytesRef, PyMemoryView, PyStr,
- PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef,
+ PyBaseExceptionRef, PyBool, PyByteArray, PyBytes, PyBytesRef, PyDict, PyMemoryView,
+ PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef,
},
class::StaticType,
common::lock::{
@@ -4077,6 +4077,67 @@ mod _io {
const fn line_buffering(&self) -> bool {
false
}
+
+ #[pymethod]
+ fn __getstate__(zelf: PyRef, vm: &VirtualMachine) -> PyResult {
+ let buffer = zelf.buffer(vm)?;
+ let content = Wtf8Buf::from_bytes(buffer.getvalue())
+ .map_err(|_| vm.new_value_error("Error Retrieving Value"))?;
+ let pos = buffer.tell();
+ drop(buffer);
+
+ // Get __dict__ if it exists and is non-empty
+ let dict_obj: PyObjectRef = match zelf.as_object().dict() {
+ Some(d) if !d.is_empty() => d.into(),
+ _ => vm.ctx.none(),
+ };
+
+ // Return (content, newline, position, dict)
+ // TODO: store actual newline setting when it's implemented
+ Ok(vm.ctx.new_tuple(vec![
+ vm.ctx.new_str(content).into(),
+ vm.ctx.new_str("\n").into(),
+ vm.ctx.new_int(pos).into(),
+ dict_obj,
+ ]))
+ }
+
+ #[pymethod]
+ fn __setstate__(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
+ if state.len() != 4 {
+ return Err(vm.new_type_error(format!(
+ "__setstate__ argument should be 4-tuple, got {}",
+ state.len()
+ )));
+ }
+
+ let content: PyStrRef = state[0].clone().try_into_value(vm)?;
+ // state[1] is newline - TODO: use when newline handling is implemented
+ let pos: u64 = state[2].clone().try_into_value(vm)?;
+ let dict = &state[3];
+
+ // Set content
+ let raw_bytes = content.as_bytes().to_vec();
+ *zelf.buffer.write() = BufferedIO::new(Cursor::new(raw_bytes));
+
+ // Set position
+ zelf.buffer(vm)?
+ .seek(SeekFrom::Start(pos))
+ .map_err(|err| os_err(vm, err))?;
+
+ // Set __dict__ if provided
+ if !vm.is_none(dict) {
+ let dict_ref: PyRef = dict.clone().try_into_value(vm)?;
+ if let Some(obj_dict) = zelf.as_object().dict() {
+ obj_dict.clear();
+ for (key, value) in dict_ref.into_iter() {
+ obj_dict.set_item(&*key, value, vm)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
}
#[pyattr]
@@ -4225,6 +4286,65 @@ mod _io {
self.closed.store(true);
Ok(())
}
+
+ #[pymethod]
+ fn __getstate__(zelf: PyRef, vm: &VirtualMachine) -> PyResult {
+ let buffer = zelf.buffer(vm)?;
+ let content = buffer.getvalue();
+ let pos = buffer.tell();
+ drop(buffer);
+
+ // Get __dict__ if it exists and is non-empty
+ let dict_obj: PyObjectRef = match zelf.as_object().dict() {
+ Some(d) if !d.is_empty() => d.into(),
+ _ => vm.ctx.none(),
+ };
+
+ // Return (content, position, dict)
+ Ok(vm.ctx.new_tuple(vec![
+ vm.ctx.new_bytes(content).into(),
+ vm.ctx.new_int(pos).into(),
+ dict_obj,
+ ]))
+ }
+
+ #[pymethod]
+ fn __setstate__(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
+ if zelf.closed.load() {
+ return Err(vm.new_value_error("__setstate__ on closed file"));
+ }
+ if state.len() != 3 {
+ return Err(vm.new_type_error(format!(
+ "__setstate__ argument should be 3-tuple, got {}",
+ state.len()
+ )));
+ }
+
+ let content: PyBytesRef = state[0].clone().try_into_value(vm)?;
+ let pos: u64 = state[1].clone().try_into_value(vm)?;
+ let dict = &state[2];
+
+ // Check exports and set content (like CHECK_EXPORTS)
+ let mut buffer = zelf.try_resizable(vm)?;
+ *buffer = BufferedIO::new(Cursor::new(content.as_bytes().to_vec()));
+ buffer
+ .seek(SeekFrom::Start(pos))
+ .map_err(|err| os_err(vm, err))?;
+ drop(buffer);
+
+ // Set __dict__ if provided
+ if !vm.is_none(dict) {
+ let dict_ref: PyRef = dict.clone().try_into_value(vm)?;
+ if let Some(obj_dict) = zelf.as_object().dict() {
+ obj_dict.clear();
+ for (key, value) in dict_ref.into_iter() {
+ obj_dict.set_item(&*key, value, vm)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
}
#[pyclass]
diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/thread.rs
index d51d78015d6..db588e5eab7 100644
--- a/crates/vm/src/stdlib/thread.rs
+++ b/crates/vm/src/stdlib/thread.rs
@@ -516,7 +516,7 @@ pub(crate) mod _thread {
let mut handles = vm.state.shutdown_handles.lock();
// Clean up finished entries
handles.retain(|(inner_weak, _): &ShutdownEntry| {
- inner_weak.upgrade().map_or(false, |inner| {
+ inner_weak.upgrade().is_some_and(|inner| {
let guard = inner.lock();
guard.state != ThreadHandleState::Done && guard.ident != current_ident
})
diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs
index b12352f6eee..65c742e4915 100644
--- a/crates/vm/src/vm/context.rs
+++ b/crates/vm/src/vm/context.rs
@@ -135,6 +135,7 @@ declare_const_name! {
__getformat__,
__getitem__,
__getnewargs__,
+ __getnewargs_ex__,
__getstate__,
__gt__,
__hash__,
diff --git a/extra_tests/snippets/builtins_module.py b/extra_tests/snippets/builtins_module.py
index 6dea94d8d77..bf762425c89 100644
--- a/extra_tests/snippets/builtins_module.py
+++ b/extra_tests/snippets/builtins_module.py
@@ -22,6 +22,17 @@
exec("", namespace)
assert namespace["__builtins__"] == __builtins__.__dict__
+
+# function.__builtins__ should be a dict, not a module
+# See: https://docs.python.org/3/reference/datamodel.html
+def test_func():
+ pass
+
+
+assert isinstance(test_func.__builtins__, dict), (
+ f"function.__builtins__ should be dict, got {type(test_func.__builtins__)}"
+)
+
# with assert_raises(NameError):
# exec('print(__builtins__)', {'__builtins__': {}})