Skip to content

Commit 114b3ea

Browse files
committed
Add ability to use single https connection
1 parent 007984e commit 114b3ea

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

cpapi/mgmt_api.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class APIClientArgs:
4040
# context possible values - web_api (default) or gaia_api
4141
def __init__(self, port=None, fingerprint=None, sid=None, server="127.0.0.1", http_debug_level=0,
4242
api_calls=None, debug_file="", proxy_host=None, proxy_port=8080,
43-
api_version=None, unsafe=False, unsafe_auto_accept=False, context="web_api"):
43+
api_version=None, unsafe=False, unsafe_auto_accept=False, context="web_api", single_conn=True):
4444
self.port = port
4545
# management server fingerprint
4646
self.fingerprint = fingerprint
@@ -66,6 +66,8 @@ def __init__(self, port=None, fingerprint=None, sid=None, server="127.0.0.1", ht
6666
self.unsafe_auto_accept = unsafe_auto_accept
6767
# The context of using the client - defaults to web_api
6868
self.context = context
69+
# Indicates that the client should use single HTTPS connection
70+
self.single_conn = single_conn
6971

7072

7173
class APIClient:
@@ -108,6 +110,10 @@ def __init__(self, api_client_args=None):
108110
self.unsafe_auto_accept = api_client_args.unsafe_auto_accept
109111
# The context of using the client - defaults to web_api
110112
self.context = api_client_args.context
113+
# HTTPS connection
114+
self.conn = None
115+
# Indicates that the client should use single HTTPS connection
116+
self.single_conn = api_client_args.single_conn
111117

112118
def __enter__(self):
113119
return self
@@ -265,8 +271,8 @@ def api_call(self, command, payload=None, sid=None, wait_for_task=True, timeout=
265271
:side-effects: updates the class's uid and server variables
266272
"""
267273
timeout_start = time.time()
268-
269-
self.check_fingerprint()
274+
if self.check_fingerprint() is False:
275+
return APIResponse("", False, err_message="Invalid fingerprint")
270276
if payload is None:
271277
payload = {}
272278
# Convert the json payload to a string if needed
@@ -292,23 +298,8 @@ def api_call(self, command, payload=None, sid=None, wait_for_task=True, timeout=
292298
if sid is not None:
293299
_headers["X-chkp-sid"] = sid
294300

295-
# Create ssl context with no ssl verification, we do it by ourselves
296-
context = ssl.create_default_context()
297-
context.check_hostname = False
298-
context.verify_mode = ssl.CERT_NONE
299-
300-
# create https connection
301-
if self.proxy_host and self.proxy_port:
302-
conn = HTTPSConnection(self.proxy_host, self.proxy_port, context=context)
303-
conn.set_tunnel(self.server, self.get_port())
304-
else:
305-
conn = HTTPSConnection(self.server, self.get_port(), context=context)
306-
307-
# Set fingerprint
308-
conn.fingerprint = self.fingerprint
309-
310-
# Set debug level
311-
conn.set_debuglevel(self.http_debug_level)
301+
# init https connection. if single connection is True, use last connection
302+
conn = self.get_https_connection()
312303
url = "/" + self.context + "/" + (("v" + str(self.api_version) + "/") if self.api_version else "") + command
313304
response = None
314305
try:
@@ -328,7 +319,8 @@ def api_call(self, command, payload=None, sid=None, wait_for_task=True, timeout=
328319
except Exception as err:
329320
res = APIResponse("", False, err_message=err)
330321
finally:
331-
conn.close()
322+
if not self.single_conn:
323+
conn.close()
332324

333325
if response:
334326
res.status_code = response.status
@@ -464,21 +456,13 @@ def gen_api_query(self, command, details_level="standard", container_keys=None,
464456

465457
def get_server_fingerprint(self):
466458
"""
467-
Initiates an HTTPS connection to the server and extracts the SHA1 fingerprint from the server's certificate.
459+
Initiates an HTTPS connection to the server if need and extracts the SHA1 fingerprint from the server's certificate.
468460
:return: string with SHA1 fingerprint (all uppercase letters)
469461
"""
470-
context = ssl.create_default_context()
471-
context.check_hostname = False
472-
context.verify_mode = ssl.CERT_NONE
473-
474-
if self.proxy_host and self.proxy_port:
475-
conn = HTTPSConnection(self.proxy_host, self.proxy_port, context=context)
476-
conn.set_tunnel(self.server, self.get_port())
477-
else:
478-
conn = HTTPSConnection(self.server, self.get_port(), context=context)
479-
462+
conn = self.get_https_connection()
480463
fingerprint_hash = conn.get_fingerprint_hash()
481-
conn.close()
464+
if not self.single_conn:
465+
conn.close()
482466
return fingerprint_hash
483467

484468
def __wait_for_task(self, task_id, timeout=-1):
@@ -723,22 +707,46 @@ def read_fingerprint_from_file(server, filename="fingerprints.txt"):
723707
return json_dict[server]
724708
return ""
725709

710+
def create_https_connection(self):
711+
context = ssl.create_default_context()
712+
context.check_hostname = False
713+
context.verify_mode = ssl.CERT_NONE
714+
# create https connection
715+
if self.proxy_host and self.proxy_port:
716+
conn = HTTPSConnection(self.proxy_host, self.proxy_port, context=context)
717+
conn.set_tunnel(self.server, self.get_port())
718+
else:
719+
conn = HTTPSConnection(self.server, self.get_port(), context=context)
720+
721+
# Set fingerprint
722+
conn.fingerprint = self.fingerprint
723+
724+
# Set debug level
725+
conn.set_debuglevel(self.http_debug_level)
726+
conn.connect()
727+
return conn
728+
729+
def get_https_connection(self):
730+
if self.single_conn:
731+
if self.conn is None:
732+
self.conn = self.create_https_connection()
733+
return self.conn
734+
return self.create_https_connection()
735+
736+
def close_connection(self):
737+
if self.conn:
738+
self.conn.close()
739+
726740

727741
class HTTPSConnection(http_client.HTTPSConnection):
728742
"""
729743
A class for making HTTPS connections that overrides the default HTTPS checks (e.g. not accepting
730744
self-signed-certificates) and replaces them with a server fingerprint check.
731745
"""
732-
733746
def connect(self):
734747
http_client.HTTPConnection.connect(self)
735748
self.sock = ssl.wrap_socket(self.sock, self.key_file, self.cert_file, cert_reqs=ssl.CERT_NONE)
736749

737750
def get_fingerprint_hash(self):
738-
try:
739-
http_client.HTTPConnection.connect(self)
740-
self.sock = ssl.wrap_socket(self.sock, self.key_file, self.cert_file, cert_reqs=ssl.CERT_NONE)
741-
except Exception:
742-
return ""
743751
fingerprint = hashlib.new("SHA1", self.sock.getpeercert(True)).hexdigest()
744752
return fingerprint.upper()

0 commit comments

Comments
 (0)