
Commit 6cbb049

add smoke test with pytest into CI
1 parent 8216777 commit 6cbb049

13 files changed: +942 −75 lines

.github/workflows/test.yml

Lines changed: 4 additions & 9 deletions
@@ -5,7 +5,7 @@ on:
    types: [opened, synchronize]
  push:
    branches:
-     - '*'
+     - 'main'

jobs:
  tests:
@@ -24,18 +24,13 @@ jobs:
      with:
        python-version: 3.9

-     - name: Set Spark env
-       run: |
-         export SPARK_LOCAL_IP=127.0.0.1
-         export SPARK_SUBMIT_OPTS="--illegal-access=permit -Dio.netty.tryReflectionSetAccessible=true"
-
      - name: Generate coverage report
        working-directory: ./
        run: |
          pip install -r requirements.txt
-         pip install pyspark==3.1.2 pandas==1.4.2 numpy==1.22.4 coverage
-         coverage run -m unittest discover -s tests -p 'tests_*.py'
+         pip install coverage
+         coverage run -m pytest
          coverage xml

      - name: Publish test coverage
-       uses: codecov/codecov-action@v1
+       uses: codecov/codecov-action@v1
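The revised step can be reproduced locally before pushing. A minimal sketch in Python, assuming requirements.txt now pulls in pytest plus the pyspark/pandas/numpy versions that were previously pinned directly in the workflow (not verified here):

import subprocess

# Mirror the CI "Generate coverage report" step locally (sketch only; assumes
# requirements.txt provides pytest and the Spark/pandas/numpy dependencies).
for cmd in (
    ["pip", "install", "-r", "requirements.txt"],
    ["pip", "install", "coverage"],
    ["coverage", "run", "-m", "pytest"],
    ["coverage", "xml"],
):
    subprocess.run(cmd, check=True)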

solacc.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

solacc.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+src/solacc.py

src/solacc.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
import graph_pipelines
import config
import json
import os
import sys
import click


def get_config(cfg_file):
    try:
        with open(cfg_file, "r") as fp:
            cfg = json.load(fp)
    except IOError:
        print(f"Error opening config file '{cfg_file}'")
        return None
    config.validate_config(cfg)
    return cfg


@click.group()
@click.option('--config', default="config/cga.json", help="use config from given json file")
@click.pass_context
def cli(ctx, config):
    ctx.ensure_object(dict)
    cfg = get_config(config)
    if cfg is None:
        sys.exit()
    ctx.obj['cfg'] = cfg


@cli.command()
@click.option('--prefix', default=False, help="prefix generated notebooks with a number")
@click.pass_context
def generate(ctx, prefix):
    cfg = ctx.obj["cfg"]
    common = {}
    for k, v in cfg.items():
        if type(v) != list and type(v) != dict:
            common[k] = v

    i = 0
    print(f"Generating notebooks for ")
    for nb_spec in cfg["notebooks"]:
        nb_spec.update(common)
        nb_spec["prefix"] = ""
        if prefix:
            nb_spec["prefix"] = f"{i:02d}_"
        print(f"... {nb_spec['prefix']}{nb_spec['id']}")
        if nb_spec["id"] == "dlt_edges":
            assert "dlt" in nb_spec and nb_spec["dlt"] == True
            graph_pipelines.gen_dlt_edges_notebook(nb_spec)
        elif nb_spec["id"] == "create_views":
            graph_pipelines.gen_create_views_notebook(nb_spec)
        else:
            graph_pipelines.gen_simple_notebook(nb_spec)
        i += 1


@cli.command()
@click.pass_context
def deploy(ctx):
    #print(f"deploy.cfg= {json.dumps(ctx.obj['cfg'], indent=2)}")
    click.echo("deploy not implemented yet.")

if __name__ == "__main__":
    cli(obj={})
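The generate command copies every top-level scalar value from the config into each notebook spec before dispatching to the graph_pipelines generators. A standalone illustration of that merge; the config dict below is hypothetical, not the real config/cga.json:

# Illustration of the scalar-key merge performed by `generate` (hypothetical config).
cfg = {
    "target_db": "cga",                      # scalar: copied into every notebook spec
    "notebooks": [{"id": "bronze_sample"}],  # list: skipped by the merge, iterated instead
}
common = {k: v for k, v in cfg.items() if not isinstance(v, (list, dict))}
specs = [dict(nb, **common) for nb in cfg["notebooks"]]
print(specs)  # [{'id': 'bronze_sample', 'target_db': 'cga'}]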

src/test_everything.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
import pytest
import solacc
import subprocess
import os
from click.testing import CliRunner

def results_matches_expected(basedir):
    results_dir = os.path.join(basedir, "results")
    expected_dir = os.path.join(basedir, "expected")
    process = subprocess.Popen(["diff", "-qr", results_dir, expected_dir])
    exit_code = process.wait()
    return not exit_code

def test_generate_smoke():
    runner = CliRunner()
    basedir = "src/tests/smoke01"
    result = runner.invoke(solacc.cli, ['--config', os.path.join(basedir, 'smoke01.json'), 'generate', '--prefix', 'True'])
    expected_result_str = """Generating notebooks for 
... 00_okta_collector
... 01_bronze_sample
... 02_dlt_edges
... 03_create_views
... 04_extract_same_as_edges
... 05_analytics_01_impact
... 06_analytics_02_investigation
"""
    assert result.output == expected_result_str
    assert results_matches_expected(basedir)
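results_matches_expected shells out to diff -qr, so the smoke test relies on a POSIX diff being on the PATH, which typical Linux runners provide. A pure-Python alternative using filecmp, sketched here only for illustration and not part of the commit:

import filecmp
import os

def dirs_match(results_dir, expected_dir):
    # Note: dircmp compares files shallowly (size/mtime) rather than byte-for-byte.
    cmp = filecmp.dircmp(results_dir, expected_dir)
    if cmp.left_only or cmp.right_only or cmp.diff_files or cmp.funny_files:
        return False
    return all(
        dirs_match(os.path.join(results_dir, d), os.path.join(expected_dir, d))
        for d in cmp.common_dirs
    )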
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# Databricks notebook source
# This notebook is designed to be run as a task within a multi-task job workflow.
# These time window input widgets enable the user to do back fill and re-processing within the multi-task job workflow.
#dbutils.widgets.removeAll()
dbutils.widgets.text("okta_start_time", "", "start time (YYYY-mm-ddTHH:MM:SSZ): ")
start_time = dbutils.widgets.get("okta_start_time")
dbutils.widgets.text("okta_end_time", "", "end time (YYYY-mm-ddTHH:MM:SSZ): ")
end_time = dbutils.widgets.get("okta_end_time")

print(start_time + " to " + end_time)


# COMMAND ----------

import json
from datetime import datetime, timedelta  # needed for the 7-day default window below

# Here we use an Okta API token; in production we recommend storing this token in the Databricks secret store:
# https://docs.databricks.com/security/secrets/index.html
cfg = {
    "base_url": "https://dev-74006068.okta.com/api/v1/logs",
    "token": "CHANGEME",
    "start_time": start_time,
    "end_time": end_time,
    "batch_size": 1000,
    "target_db": "",
    "target_table": "okta_bronze",
    "storage_path": "/tmp/solacc_cga"
}

# We need to figure out where/when to execute this DDL; ideally we don't want it in the collector task.
# It is needed so we can query the bronze table for the latest event timestamp.
sql_str = f"""
CREATE TABLE IF NOT EXISTS {cfg['target_db']}.{cfg['target_table']} (
  ingest_ts TIMESTAMP,
  event_ts TIMESTAMP,
  event_date TIMESTAMP,
  raw STRING) USING DELTA PARTITIONED BY (event_date) LOCATION '{cfg['storage_path']}'
"""
print(sql_str)
spark.sql(sql_str)

# If the task parameters (i.e. widgets) are empty, default to the latest timestamp from the bronze table
if len(cfg["start_time"]) == 0 and len(cfg["end_time"]) == 0:
    sql_str = f"""
    select max(event_ts) as latest_event_ts
    from {cfg['target_db']}.{cfg['target_table']}"""

    df = spark.sql(sql_str)
    latest_ts = df.first()["latest_event_ts"]
    if latest_ts is None:
        print("latest_ts is none - default to 7 days from now")
        default_ts = datetime.today() - timedelta(days=7)
        cfg["start_time"] = default_ts.strftime("%Y-%m-%dT%H:%M:%SZ")
    else:
        print("latest_ts from bronze table is " + latest_ts.isoformat())
        cfg["start_time"] = latest_ts.strftime("%Y-%m-%dT%H:%M:%SZ")

print(json.dumps(cfg, indent=2))

# COMMAND ----------

import requests
import json
import re
import datetime

from pyspark.sql import Row
import pyspark.sql.functions as f

def poll_okta_logs(cfg, debug=False):
    MINIMUM_COUNT = 5  # Must be >= 2, see note below

    headers = {'Authorization': 'SSWS ' + cfg["token"]}
    query_params = {
        "limit": str(cfg["batch_size"]),
        "sortOrder": "ASCENDING",
        "since": cfg["start_time"]
    }
    if cfg["end_time"]:
        query_params["until"] = cfg["end_time"]

    url = cfg["base_url"]
    total_cnt = 0
    while True:
        # Request the next link in our sequence:
        r = requests.get(url, headers=headers, params=query_params)

        if not r.status_code == requests.codes.ok:
            break

        ingest_ts = datetime.datetime.now(datetime.timezone.utc)

        # Break apart the records into individual rows
        jsons = []
        jsons.extend([json.dumps(x) for x in r.json()])

        # Make sure we have something to add to the table
        if len(jsons) == 0: break
        # Load into a dataframe
        df = (
            sc.parallelize([Row(raw=x) for x in jsons]).toDF()
            .selectExpr(f"'{ingest_ts.isoformat()}'::timestamp AS ingest_ts",
                        "date_trunc('DAY', raw:published::timestamp) AS event_date",
                        "raw:published::timestamp AS event_ts",
                        "raw AS raw")
        )
        #print("%d %s" % (df.count(), url))
        total_cnt += len(jsons)
        if debug:
            display(df)
        else:
            # Append to delta table
            df.write \
              .option("mergeSchema", "true") \
              .format('delta') \
              .mode('append') \
              .partitionBy("event_date") \
              .save(cfg["storage_path"])

        # When we make an API call, we cause an event, so there is the potential to get
        # into a self-perpetuating loop. Thus we require a certain minimum number
        # of entries before we are willing to loop again.
        if len(jsons) < MINIMUM_COUNT: break

        #print(r.headers["Link"])

        # Look for the 'next' link; note there is also a 'self' link, so we need to get the right one
        rgx = re.search(r"\<([^>]+)\>\; rel=\"next\"", str(r.headers['Link']), re.I)
        if rgx:
            # We got a next link match; set that as the new URL and repeat
            url = rgx.group(1)
            continue
        else:
            # No next link, we are done
            break
    return total_cnt

cnt = poll_okta_logs(cfg)
print(f"Total records polled = {cnt}")

# COMMAND ----------
