1
1
import json
2
2
import os
3
3
import socket
4
+ import subprocess
5
+ import threading
4
6
import typing as t
7
+ import uuid
5
8
from codecs import open
6
9
from contextlib import contextmanager
7
- from os .path import abspath , dirname , isfile , join
10
+ from os .path import abspath , basename , dirname , isfile , join
8
11
from pathlib import Path
9
12
from random import choice
10
13
from string import ascii_lowercase , ascii_uppercase , digits
31
34
from sqlalchemy_utils import database_exists , drop_database
32
35
33
36
from . import database , factories , models
37
+ from .utils import generate_ssl_certs , stream_logs
34
38
35
39
36
40
def pytest_addoption (parser : "Parser" ):
@@ -70,6 +74,27 @@ def pytest_addoption(parser: "Parser"):
70
74
help = "The TCP port of the MySQL server." ,
71
75
)
72
76
77
+ parser .addoption (
78
+ "--mysql-ssl-ca" ,
79
+ dest = "mysql_ssl_ca" ,
80
+ default = None ,
81
+ help = "Path to SSL CA certificate file." ,
82
+ )
83
+
84
+ parser .addoption (
85
+ "--mysql-ssl-cert" ,
86
+ dest = "mysql_ssl_cert" ,
87
+ default = None ,
88
+ help = "Path to SSL certificate file." ,
89
+ )
90
+
91
+ parser .addoption (
92
+ "--mysql-ssl-key" ,
93
+ dest = "mysql_ssl_key" ,
94
+ default = None ,
95
+ help = "Path to SSL key file." ,
96
+ )
97
+
73
98
parser .addoption (
74
99
"--no-docker" ,
75
100
dest = "use_docker" ,
@@ -159,10 +184,13 @@ class MySQLCredentials(t.NamedTuple):
159
184
host : str
160
185
port : int
161
186
database : str
187
+ ssl_ca : t .Optional [str ] = None
188
+ ssl_cert : t .Optional [str ] = None
189
+ ssl_key : t .Optional [str ] = None
162
190
163
191
164
192
@pytest .fixture (scope = "session" )
165
- def mysql_credentials (pytestconfig : Config ) -> MySQLCredentials :
193
+ def mysql_credentials (request , pytestconfig : Config , tmp_path_factory : pytest . TempPathFactory ) -> MySQLCredentials :
166
194
db_credentials_file : str = abspath (join (dirname (__file__ ), "db_credentials.json" ))
167
195
if isfile (db_credentials_file ):
168
196
with open (db_credentials_file , "r" , "utf-8" ) as fh :
@@ -173,6 +201,9 @@ def mysql_credentials(pytestconfig: Config) -> MySQLCredentials:
173
201
database = db_credentials ["mysql_database" ],
174
202
host = db_credentials ["mysql_host" ],
175
203
port = db_credentials ["mysql_port" ],
204
+ ssl_ca = db_credentials .get ("mysql_ssl_ca" ),
205
+ ssl_cert = db_credentials .get ("mysql_ssl_cert" ),
206
+ ssl_key = db_credentials .get ("mysql_ssl_key" ),
176
207
)
177
208
178
209
port : int = pytestconfig .getoption ("mysql_port" ) or 3306
@@ -182,12 +213,35 @@ def mysql_credentials(pytestconfig: Config) -> MySQLCredentials:
182
213
pytest .fail (f"No ports appear to be available on the host { pytestconfig .getoption ('mysql_host' )} " )
183
214
port += 1
184
215
216
+ ssl_credentials = {
217
+ "ssl_ca" : pytestconfig .getoption ("mysql_ssl_ca" ) or None ,
218
+ "ssl_cert" : pytestconfig .getoption ("mysql_ssl_cert" ) or None ,
219
+ "ssl_key" : pytestconfig .getoption ("mysql_ssl_key" ) or None ,
220
+ }
221
+
222
+ if hasattr (request , "param" ) and request .param == "ssl" :
223
+ certs_dir = tmp_path_factory .getbasetemp () / "certs"
224
+ if not certs_dir .exists ():
225
+ certs_dir .mkdir (parents = True )
226
+ generate_ssl_certs (certs_dir )
227
+
228
+ # FIXED: docker perms
229
+ subprocess .call (["chmod" , "0644" , str (certs_dir / "ca-key.pem" )])
230
+ subprocess .call (["chmod" , "0644" , str (certs_dir / "server-key.pem" )])
231
+
232
+ ssl_credentials = {
233
+ "ssl_ca" : str (certs_dir / "ca.pem" ),
234
+ "ssl_cert" : str (certs_dir / "server-cert.pem" ),
235
+ "ssl_key" : str (certs_dir / "server-key.pem" ),
236
+ }
237
+
185
238
return MySQLCredentials (
186
239
user = pytestconfig .getoption ("mysql_user" ) or "tester" ,
187
240
password = pytestconfig .getoption ("mysql_password" ) or "testpass" ,
188
241
database = pytestconfig .getoption ("mysql_database" ) or "test_db" ,
189
242
host = pytestconfig .getoption ("mysql_host" ) or "0.0.0.0" ,
190
243
port = port ,
244
+ ** ssl_credentials ,
191
245
)
192
246
193
247
@@ -222,33 +276,72 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) ->
222
276
except (HTTPError , NotFound ) as err :
223
277
pytest .fail (str (err ))
224
278
279
+ ssl_cmds = []
280
+ ssl_args = {}
281
+ ssl_volumes = {}
282
+ host_certs_dir = None
283
+ container_certs_dir = "/etc/mysql/certs"
284
+
285
+ if mysql_credentials .ssl_ca :
286
+ host_certs_dir = dirname (mysql_credentials .ssl_ca )
287
+ ssl_cmds .append (f"--ssl-ca={ container_certs_dir } /{ basename (mysql_credentials .ssl_ca )} " )
288
+ ssl_args ["ssl_ca" ] = mysql_credentials .ssl_ca
289
+
290
+ if mysql_credentials .ssl_cert :
291
+ host_certs_dir = dirname (mysql_credentials .ssl_cert )
292
+ ssl_cmds .append (f"--ssl-cert={ container_certs_dir } /{ basename (mysql_credentials .ssl_cert )} " )
293
+ ssl_args ["ssl_cert" ] = f"{ host_certs_dir } /client-cert.pem"
294
+
295
+ if mysql_credentials .ssl_key :
296
+ host_certs_dir = dirname (mysql_credentials .ssl_key )
297
+ ssl_cmds .append (f"--ssl-key={ container_certs_dir } /{ basename (mysql_credentials .ssl_key )} " )
298
+ ssl_args ["ssl_key" ] = f"{ host_certs_dir } /client-key.pem"
299
+
300
+ if host_certs_dir :
301
+ ssl_volumes [host_certs_dir ] = {"bind" : container_certs_dir , "mode" : "ro" }
302
+
303
+ if ssl_args :
304
+ ssl_args ["ssl_verify_cert" ] = True
305
+
306
+ container_name = f"pytest_mysql_to_sqlite3_{ uuid .uuid4 ().hex [:10 ]} "
307
+
225
308
container = client .containers .run (
226
309
image = docker_mysql_image ,
227
- name = "pytest_mysql_to_sqlite3" ,
310
+ name = container_name ,
228
311
ports = {"3306/tcp" : (mysql_credentials .host , f"{ mysql_credentials .port } /tcp" )},
229
312
environment = {
230
313
"MYSQL_RANDOM_ROOT_PASSWORD" : "yes" ,
231
314
"MYSQL_USER" : mysql_credentials .user ,
232
315
"MYSQL_PASSWORD" : mysql_credentials .password ,
233
316
"MYSQL_DATABASE" : mysql_credentials .database ,
234
317
},
318
+ volumes = ssl_volumes ,
235
319
command = [
236
320
"--character-set-server=utf8mb4" ,
237
321
"--collation-server=utf8mb4_unicode_ci" ,
238
- ],
322
+ ]
323
+ + ssl_cmds ,
239
324
detach = True ,
240
325
auto_remove = True ,
241
326
)
242
327
328
+ log_thread = threading .Thread (target = stream_logs , args = (container ,))
329
+ # The thread will terminate when the main program terminates
330
+ log_thread .daemon = True
331
+ log_thread .start ()
332
+
243
333
while not mysql_available and mysql_connection_retries > 0 :
244
334
try :
335
+ print (f"Attempt #{ mysql_connection_retries } to connect to MySQL..." )
336
+
245
337
mysql_connection = mysql .connector .connect (
246
338
user = mysql_credentials .user ,
247
339
password = mysql_credentials .password ,
248
340
host = mysql_credentials .host ,
249
341
port = mysql_credentials .port ,
250
342
charset = "utf8mb4" ,
251
343
collation = "utf8mb4_unicode_ci" ,
344
+ ** ssl_args ,
252
345
)
253
346
except mysql .connector .Error as err :
254
347
if err .errno == errorcode .CR_SERVER_LOST :
@@ -270,6 +363,10 @@ def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) ->
270
363
if use_docker and container is not None :
271
364
container .kill ()
272
365
366
+ # Wait for the log thread to finish (optional)
367
+ if "log_thread" in locals () and log_thread .is_alive ():
368
+ log_thread .join (timeout = 5 )
369
+
273
370
274
371
@pytest .fixture (scope = "session" )
275
372
def mysql_database (
0 commit comments