@@ -31,32 +31,39 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None:
31
31
32
32
33
33
@pytest .fixture (scope = "package" )
34
- def gds () -> Generator [Any , None , None ]:
35
- from gds_helper import aura_api , connect_to_plugin_gds , create_aurads_instance
36
- from graphdatascience import GraphDataScience
34
+ def aura_ds_instance () -> Generator [Any , None , None ]:
35
+ if os .environ .get ("AURA_API_CLIENT_ID" , None ) is None :
36
+ yield None
37
+ return
38
+
39
+ from gds_helper import aura_api , create_aurads_instance
40
+
41
+ api = aura_api ()
42
+ id , dbms_connection_info = create_aurads_instance (api )
43
+
44
+ # setting as environment variables to run notebooks with this connection
45
+ os .environ ["NEO4J_URI" ] = dbms_connection_info .uri
46
+ os .environ ["NEO4J_USER" ] = dbms_connection_info .username
47
+ os .environ ["NEO4J_PASSWORD" ] = dbms_connection_info .password
48
+ yield dbms_connection_info
37
49
38
- use_cloud_setup = os .environ .get ("AURA_API_CLIENT_ID" , None )
50
+ # Clear Neo4j_URI after test (rerun should create a new instance)
51
+ os .environ ["NEO4J_URI" ] = ""
52
+ api .delete_instance (id )
39
53
40
- if use_cloud_setup :
41
- api = aura_api ()
42
- id , dbms_connection_info = create_aurads_instance (api )
43
54
44
- # setting as environment variables to run notebooks with this connection
45
- os . environ [ "NEO4J_URI" ] = dbms_connection_info . uri
46
- os . environ [ "NEO4J_USER" ] = dbms_connection_info . username
47
- os . environ [ "NEO4J_PASSWORD" ] = dbms_connection_info . password
55
+ @ pytest . fixture ( scope = "package" )
56
+ def gds ( aura_ds_instance : Any ) -> Generator [ Any , None , None ]:
57
+ from gds_helper import connect_to_plugin_gds
58
+ from graphdatascience import GraphDataScience
48
59
60
+ if aura_ds_instance :
49
61
yield GraphDataScience (
50
- endpoint = dbms_connection_info .uri ,
51
- auth = (dbms_connection_info .username , dbms_connection_info .password ),
62
+ endpoint = aura_ds_instance .uri ,
63
+ auth = (aura_ds_instance .username , aura_ds_instance .password ),
52
64
aura_ds = True ,
53
65
database = "neo4j" ,
54
66
)
55
-
56
- # Clear Neo4j_URI after test (rerun should create a new instance)
57
- os .environ ["NEO4J_URI" ] = ""
58
-
59
- api .delete_instance (id )
60
67
else :
61
68
NEO4J_URI = os .environ .get ("NEO4J_URI" , "neo4j://localhost:7687" )
62
69
gds = connect_to_plugin_gds (NEO4J_URI )
@@ -65,12 +72,24 @@ def gds() -> Generator[Any, None, None]:
65
72
66
73
67
74
@pytest .fixture (scope = "package" )
68
- def neo4j_session ( ) -> Generator [Any , None , None ]:
75
+ def neo4j_driver ( aura_ds_instance : Any ) -> Generator [Any , None , None ]:
69
76
import neo4j
70
77
71
- NEO4J_URI = os .environ .get ("NEO4J_URI" , "neo4j://localhost:7687" )
78
+ if aura_ds_instance :
79
+ driver = neo4j .GraphDatabase .driver (
80
+ aura_ds_instance .uri , auth = (aura_ds_instance .username , aura_ds_instance .password )
81
+ )
82
+ else :
83
+ NEO4J_URI = os .environ .get ("NEO4J_URI" , "neo4j://localhost:7687" )
84
+ driver = neo4j .GraphDatabase .driver (NEO4J_URI )
85
+
86
+ driver .verify_connectivity ()
87
+ yield driver
72
88
73
- with neo4j .GraphDatabase .driver (NEO4J_URI ) as driver :
74
- driver .verify_connectivity ()
75
- with driver .session () as session :
76
- yield session
89
+ driver .close ()
90
+
91
+
92
+ @pytest .fixture (scope = "package" )
93
+ def neo4j_session (neo4j_driver : Any ) -> Generator [Any , None , None ]:
94
+ with neo4j_driver .session () as session :
95
+ yield session
0 commit comments