Update vs sync scan and sync endpoint to account for new cidr_org relation #814

Merged: 3 commits, Mar 24, 2025

Changes from all commits
88 changes: 45 additions & 43 deletions backend/src/xfd_django/xfd_api/api_methods/sync.py
@@ -8,6 +8,7 @@

# Standard Python Libraries
from datetime import datetime
import json
from uuid import uuid4

# Third-Party Libraries
@@ -17,47 +18,52 @@
from xfd_mini_dl.models import Organization, Sector

from ..helpers.s3_client import S3Client
from ..utils.csv_utils import convert_csv_to_json, create_checksum
from ..utils.csv_utils import create_checksum


async def sync_post(sync_body, request: Request):
"""Ingest and persist organization data to the data lake."""
headers = request.headers
request_checksum = headers.get("x-checksum")

if not request_checksum or not sync_body.data:
raise HTTPException(status_code=500, detail="Missing checksum")

if request_checksum != create_checksum(sync_body.data):
raise HTTPException(status_code=500, detail="Missing checksum")

# Use MinIO client to save CSV data to S3
s3_client = S3Client()
start_bound, end_bound = parse_cursor(headers.get("x-cursor"))
file_name = generate_s3_filename(start_bound, end_bound)

s3_url = s3_client.save_csv(sync_body.data, file_name)
if not s3_url:
return {"status": 500}

parsed_data = convert_csv_to_json(sync_body.data)
for item in parsed_data:
try:
org = save_organization_to_mdl(
org_dict=item,
network_list=item["cidrs"],
location=item["location"],
db_name="mini_data_lake_integration",
)

if org:
link_parent_organization(org, item.get("parent"))
link_sectors_to_organization(org, item.get("sectors", []))
try:
headers = request.headers
request_checksum = headers.get("x-checksum")
if not request_checksum or not sync_body.data:
raise HTTPException(status_code=500)

if request_checksum != create_checksum(sync_body.data):
raise HTTPException(status_code=500)

# Use MinIO client to save CSV data to S3
s3_client = S3Client()
start_bound, end_bound = parse_cursor(headers.get("x-cursor"))
file_name = generate_s3_filename(start_bound, end_bound)

s3_url = s3_client.save_csv(sync_body.data, file_name)
if not s3_url:
raise HTTPException(status_code=500)

parsed_data = json.loads(sync_body.data)

for item in parsed_data:
try:
org = save_organization_to_mdl(
org_dict=item,
network_list=item["cidrs"],
location=item["location"],
db_name="mini_data_lake",
)

except Exception as e:
print("Error processing item:", e)
if org:
link_parent_organization(
org, item.get("parent"), db_name="mini_data_lake"
)
link_sectors_to_organization(
org, item.get("sectors", []), db_name="mini_data_lake"
)

return {"status": 200}
except Exception:
raise HTTPException(status_code=500)
except Exception:
raise HTTPException(status_code=500)


def parse_cursor(cursor_header):
@@ -75,7 +81,7 @@ def generate_s3_filename(start_bound, end_bound):
return f"lz_org_sync/{now.month}-{now.day}-{now.year}/{start_bound}-{end_bound}.csv"


def link_parent_organization(org, parent_data):
def link_parent_organization(org, parent_data, db_name="mini_data_lake_lz"):
"""Link an organization to its parent if applicable."""
if not isinstance(parent_data, dict):
return
@@ -85,9 +91,7 @@ def link_parent_organization(org, parent_data):
return

try:
parent_org = Organization.objects.using("mini_data_lake_integration").get(
acronym=parent_acronym
)
parent_org = Organization.objects.using(db_name).get(acronym=parent_acronym)
org.parent = parent_org
org.save()
except Organization.DoesNotExist:
@@ -96,7 +100,7 @@
print("Error while linking parent org to child org:", e)


def link_sectors_to_organization(org, sectors):
def link_sectors_to_organization(org, sectors, db_name="mini_data_lake_lz"):
"""Associate sectors with the organization."""
if not isinstance(sectors, list):
return
@@ -108,9 +112,7 @@

try:
with transaction.atomic():
sector_obj, created = Sector.objects.using(
"mini_data_lake_integration"
).get_or_create(
sector_obj, created = Sector.objects.using(db_name).get_or_create(
acronym=sector_acronym,
defaults={"id": str(uuid4()), "name": sector.get("name")},
)
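For reference, the checksum helper is imported from ..utils.csv_utils and parse_cursor's body is collapsed in this view. A minimal sketch of plausible implementations, assuming create_checksum is a SHA-256 hex digest of the raw payload and assuming the "start-end" cursor format implied by send_csv_to_sync's x-cursor header; neither body is shown in the diff:

    # Hypothetical stand-ins; the digest algorithm and cursor format are
    # assumptions inferred from how send_csv_to_sync builds the x-checksum
    # and x-cursor headers.
    import hashlib


    def create_checksum(data: str) -> str:
        """Hex digest both sender and receiver compute over the raw payload."""
        return hashlib.sha256(data.encode("utf-8")).hexdigest()


    def parse_cursor(cursor_header: str):
        """Split a "start-end" cursor header into its two bounds."""
        start_bound, end_bound = cursor_header.split("-", 1)
        return start_bound, end_bound

Note that after this change sync_post parses the payload with json.loads but still archives it through save_csv under a .csv filename, so the object landing in S3 is JSON text.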
2 changes: 0 additions & 2 deletions backend/src/xfd_django/xfd_api/schema_models/sync.py
@@ -10,8 +10,6 @@
class SyncResponse(BaseModel):
"""Response model for sync operations."""

status: int


class SyncBody(BaseModel):
"""Request body model for sync operations."""
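Dropping status: int leaves SyncResponse as a bare model. A hedged reconstruction of both schema models after this change; SyncBody's data field is inferred from sync_post's use of sync_body.data and sits below the visible diff, so treat it as an assumption:

    # Presumed post-change schema models; SyncBody's field is an inference,
    # not shown in this diff.
    from pydantic import BaseModel


    class SyncResponse(BaseModel):
        """Response model for sync operations."""


    class SyncBody(BaseModel):
        """Request body model for sync operations."""

        data: str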
111 changes: 48 additions & 63 deletions backend/src/xfd_django/xfd_api/tasks/vulnScanningSync.py
@@ -15,10 +15,11 @@
import psycopg2
import requests
from xfd_api.utils.chunk import chunk_list_by_bytes
from xfd_api.utils.csv_utils import convert_to_csv, create_checksum
from xfd_api.utils.csv_utils import create_checksum
from xfd_api.utils.hash import hash_ip
from xfd_api.utils.scan_utils.vuln_scanning_sync_utils import (
fetch_orgs_and_relations,
load_test_data,
save_cve_to_datalake,
save_host,
save_ip_to_datalake,
@@ -50,42 +51,6 @@ async def handler(event):
return {"statusCode": 500, "body": str(e)}


# Used for loading test data from file for vuln_scans, port_scans, hosts, tickets
def load_test_data(data_set: str) -> list:
"""Load test data from local files for scanning simulations.

Args:
data_set (str): The type of data set to load (e.g., "requests", "vuln_scan").

Returns:
list: The parsed JSON data from the file.

Raises:
ValueError: If an unknown data_set is provided.
FileNotFoundError: If the specified file does not exist.
"""
file_paths = {
"requests": "~/Downloads/requests_full_redshift.json",
"vuln_scan": "~/Downloads/vuln_scan_sample.json",
"port_scans": "~/Downloads/port_scans_sample.json",
"hosts": "~/Downloads/hosts_sample.json",
"tickets": "~/Downloads/tickets_sample.json",
}

file_path = file_paths.get(data_set)

if file_path is None:
raise ValueError(f"Unknown data set: {data_set}")

expanded_path = os.path.expanduser(file_path)

if not os.path.exists(expanded_path):
raise FileNotFoundError(f"Test data file not found: {expanded_path}")

with open(expanded_path, encoding="utf-8") as file:
return json.load(file)


def query_redshift(query, params=None):
"""Execute a query on Redshift and return results as a list of dictionaries."""
conn = psycopg2.connect(
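query_redshift's body is collapsed past the connect call above. A plausible completion, assuming connection parameters come from environment variables and a RealDictCursor produces the documented list of dictionaries; the env-var names are hypothetical:

    # Hypothetical completion of the truncated query_redshift; the env-var
    # names and RealDictCursor choice are assumptions consistent with the
    # docstring ("return results as a list of dictionaries").
    import os

    import psycopg2
    import psycopg2.extras


    def query_redshift(query, params=None):
        """Execute a query on Redshift and return results as a list of dictionaries."""
        conn = psycopg2.connect(
            host=os.environ.get("REDSHIFT_HOST"),
            dbname=os.environ.get("REDSHIFT_DATABASE"),
            user=os.environ.get("REDSHIFT_USER"),
            password=os.environ.get("REDSHIFT_PASSWORD"),
        )
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(query, params)
                return [dict(row) for row in cursor.fetchall()]
        finally:
            conn.close()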
@@ -113,7 +78,8 @@ def main():
print("Starting VS Sync scan")

# Load request data
request_list = load_test_data("requests")
request_list = fetch_from_redshift("SELECT * FROM vmtableau.requests;")

org_id_dict = process_orgs(request_list)

# Process Organizations & Relations
@@ -150,35 +116,61 @@ def main():
print("VS Sync scan completed successfully!")


def detect_data_set(query):
"""Detect the data set from the query."""
if "vulns_scans" in query:
return "vuln_scan"
if "hosts" in query:
return "hosts"
if "tickets" in query:
return "tickets"
if "port_scans" in query:
return "port_scans"
return None


def fetch_from_redshift(query):
"""Fetch data from Redshift and log execution time."""
IS_LOCAL = True
if IS_LOCAL:
data_set = detect_data_set(query)
return load_test_data(data_set)
try:
start_time = datetime.datetime.now()
result = query_redshift(query)
end_time = datetime.datetime.now()
duration_seconds = (end_time - start_time).total_seconds()
print(f"[Redshift] [{duration_seconds}s] [{len(result)} records] {query}")
return result
return result.rows
except Exception as e:
print(f"Error fetching data from Redshift: {e}")
return []


def save_json_to_file(data, filename="test.json"):
"""Save JSON data to a file."""
try:
with open(filename, "w", encoding="utf-8") as file:
json.dump(data, file, indent=4)
print(f"Data successfully saved to {filename}")
except Exception as e:
print(f"Error saving JSON to file: {e}")


def process_organizations_and_relations():
"""Fetch organizations and sync with the external API."""
try:
shaped_orgs = fetch_orgs_and_relations()
if not shaped_orgs:
return

print("Shaped orgs exist, chunking and processing")
chunks = chunk_list_by_bytes(shaped_orgs, 4194304)

chunks = chunk_list_by_bytes(shaped_orgs, 2097152)
for idx, chunk_info in enumerate(chunks):
chunk = chunk_info["chunk"]
bounds = chunk_info["bounds"]
csv_data = convert_to_csv(chunk)
send_csv_to_sync(csv_data, bounds)
# save_json_to_file(json.dumps(chunk), f"org_chunk_{idx}.json")
send_csv_to_sync(json.dumps(chunk), bounds)

except Exception as e:
print(f"Error processing organization data: {e}")

@@ -196,11 +188,11 @@ def send_csv_to_sync(csv_data, bounds):
"x-checksum": checksum,
"x-cursor": f"{bounds['start']}-{bounds['end']}",
"Content-Type": "application/json",
"Authorization": os.environ.get("DMZ_API_KEY"),
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImU0N2Q0MDk3LWM5ZTEtNGMxZS1iMmY0LWY2Mjg4OTU3NTlkMCIsImVtYWlsIjoiSkFOU09OLkJVTkNFQGFzc29jaWF0ZXMuY2lzYS5kaHMuZ292IiwiZXhwIjoxNzQxMzc1NTQwfQ.Yo4A0fBP8cF00ODKgxD5U2NaffXv1k3nmaN4jv5kyVs",
}

response = requests.post(
os.environ.get("DMZ_SYNC_ENDPOINT"), json=body, headers=headers, timeout=60
"http://localhost:3000/sync", json=body, headers=headers, timeout=60
)
if response.status_code == 200:
print("CSV successfully sent to /sync")
@@ -364,18 +356,18 @@ def process_orgs(request_list):
return org_id_dict


def link_parent_child_organizations(parent_child_dict, org_id_dict):
def link_parent_child_organizations(
parent_child_dict, org_id_dict, db_name="mini_data_lake_lz"
):
"""Link child organizations to their respective parent organizations."""
for parent_acronym, child_acronyms in parent_child_dict.items():
parent_id = org_id_dict.get(parent_acronym)
if not parent_id:
print(f"Parent acronym {parent_acronym} not found in org_id_dict")
continue

try:
parent_org = Organization.objects.get(id=parent_id)
parent_org = Organization.objects.using(db_name).get(id=parent_id)
except Organization.DoesNotExist:
print(f"Parent organization {parent_id} does not exist")
continue

# Collect child organization IDs
@@ -387,21 +379,19 @@ def link_parent_child_organizations(parent_child_dict, org_id_dict):

# Update parent field for child organizations
if children_ids:
Organization.objects.filter(id__in=children_ids).update(
Organization.objects.using(db_name).filter(id__in=children_ids).update(
parent=parent_org.id
)
print(
f"Successfully linked {len(children_ids)} children to parent {parent_acronym}"
)


def assign_organizations_to_sectors(sector_child_dict, org_id_dict):
def assign_organizations_to_sectors(
sector_child_dict, org_id_dict, db_name="mini_data_lake_lz"
):
"""Assign organizations to sectors based on sector-child relationships."""
for sector_id, child_acronyms in sector_child_dict.items():
try:
sector = Sector.objects.get(id=sector_id)
sector = Sector.objects.using(db_name).get(id=sector_id)
except Sector.DoesNotExist:
print(f"Sector {sector_id} does not exist")
continue

organization_ids = [
@@ -412,10 +402,7 @@ def assign_organizations_to_sectors(sector_child_dict, org_id_dict):

if organization_ids:
sector.organizations.add(
*Organization.objects.filter(id__in=organization_ids)
)
print(
f"Successfully added {len(organization_ids)} organizations to sector {sector_id}"
*Organization.objects.using(db_name).filter(id__in=organization_ids)
)
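One subtlety behind the .using(db_name) changes above: absent a custom database router, Django records each instance's source database on instance._state.db and routes related-manager writes there, so fetching the Sector with .using() is what keeps the M2M add on the same database. A small illustration:

    # With no custom router, the M2M insert goes to the database the sector
    # instance was loaded from (tracked on sector._state.db).
    from xfd_mini_dl.models import Organization, Sector


    def add_orgs_to_sector(sector_id, organization_ids, db_name="mini_data_lake_lz"):
        sector = Sector.objects.using(db_name).get(id=sector_id)
        orgs = Organization.objects.using(db_name).filter(id__in=organization_ids)
        sector.organizations.add(*orgs)  # writes to db_name, not "default"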


@@ -437,7 +424,6 @@ def process_request(request_list, sector_child_dict, parent_child_dict, org_id_dict):
# Skip non-sector records
if "type" not in request["agency"]:
if request["_id"] in non_sector_list:
print("Record missing ID, skipping to next")
continue

process_sector(request, sector_child_dict)
@@ -482,7 +468,6 @@ def process_sector(request, sector_child_dict):
"retired": sector_data["retired"],
},
)
print(f"{'Created' if created else 'Updated'} sector {sector_obj.id}")
sector_child_dict[sector_obj.id] = request["children"]
except Exception as e:
print("Error occurred creating sector", e)