19 changes: 9 additions & 10 deletions detection_rules/rule.py
@@ -1440,15 +1440,14 @@ def get_packaged_integrations(
# if both exist, rule tags are only used if defined in definitions for non-dataset packages
# of machine learning analytic packages

-rule_integrations = meta.get("integration", [])
-if rule_integrations:
-    for integration in rule_integrations:
Contributor Author: simple style fix, replacing the if condition with a more robust default value via

    rule_integrations = meta.get("integration") or []
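For illustration, the difference shows up when the metadata carries the key with an explicit null value (a minimal sketch, not part of the diff):

    meta = {"integration": None}
    meta.get("integration", [])    # -> None; the default only applies when the key is missing
    meta.get("integration") or []  # -> []; falsy values are normalized to an empty list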

-        ineligible_integrations = [
-            *definitions.NON_DATASET_PACKAGES,
-            *map(str.lower, definitions.MACHINE_LEARNING_PACKAGES),
-        ]
-        if integration in ineligible_integrations or isinstance(data, MachineLearningRuleData):
-            packaged_integrations.append({"package": integration, "integration": None})
+rule_integrations = meta.get("integration") or []
+for integration in rule_integrations:
+    ineligible_integrations = [
+        *definitions.NON_DATASET_PACKAGES,
+        *map(str.lower, definitions.MACHINE_LEARNING_PACKAGES),
+    ]
+    if integration in ineligible_integrations or isinstance(data, MachineLearningRuleData):
+        packaged_integrations.append({"package": integration, "integration": None})

packaged_integrations.extend(parse_datasets(list(datasets), package_manifest))

@@ -1762,7 +1761,7 @@ def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> lis
else:
package = value

-if package in list(package_manifest):
+if package in package_manifest:
Contributor Author: small style fix; the in operator already tests membership against a dict's keys, so the list() call is unnecessary.

packaged_integrations.append({"package": package, "integration": integration})
return packaged_integrations

253 changes: 251 additions & 2 deletions detection_rules/rule_validators.py
@@ -6,6 +6,7 @@
"""Validation logic for rules containing queries."""

import re
+import time
import typing
from collections.abc import Callable
from enum import Enum
@@ -15,16 +16,18 @@
import click
import eql # type: ignore[reportMissingTypeStubs]
import kql # type: ignore[reportMissingTypeStubs]
+from elasticsearch import Elasticsearch # type: ignore[reportMissingTypeStubs]
from eql import ast # type: ignore[reportMissingTypeStubs]
from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint # type: ignore[reportMissingTypeStubs]
from eql.parser import _parse as base_parse # type: ignore[reportMissingTypeStubs]
+from kibana import Kibana # type: ignore[reportMissingTypeStubs]
from marshmallow import ValidationError
from semver import Version

-from . import ecs, endgame
+from . import ecs, endgame, integrations, utils
from .config import CUSTOM_RULES_DIR, load_current_package_version, parse_rules_config
from .custom_schemas import update_auto_generated_schema
-from .integrations import get_integration_schema_data, load_integrations_manifests
+from .integrations import get_integration_schema_data, load_integrations_manifests, load_integrations_schemas
from .rule import EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, set_eql_config
from .schemas import get_stack_schemas

@@ -639,6 +642,22 @@ def validate_integration(
pass


def convert_to_nested_schema(flat_schemas: dict[str, str]) -> dict[str, Any]:
"""Convert a flat schema to a nested schema with 'properties' for each sub-key."""
nested_schema = {}

for key, value in flat_schemas.items():
parts = key.split(".")
current_level = nested_schema

for part in parts[:-1]:
current_level = current_level.setdefault(part, {}).setdefault("properties", {})

current_level[parts[-1]] = {"type": value}

return nested_schema
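# Illustrative sketch, not part of the diff: a single flat entry expands into nested
# "properties" levels, e.g.
#   convert_to_nested_schema({"azure.auditlogs.operation_name": "keyword"})
#   # -> {"azure": {"properties": {"auditlogs": {"properties": {"operation_name": {"type": "keyword"}}}}}}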


def extract_error_field(source: str, exc: eql.EqlParseError | kql.KqlParseError) -> str | None:
"""Extract the field name from an EQL or KQL parse error."""
lines = source.splitlines()
@@ -647,3 +666,233 @@ def extract_error_field(source: str, exc: eql.EqlParseError | kql.KqlParseError)
start = exc.column # type: ignore[reportUnknownMemberType]
stop = start + len(exc.caret.strip()) # type: ignore[reportUnknownVariableType]
return re.sub(r"^\W+|\W+$", "", line[start:stop]) # type: ignore[reportUnknownArgumentType]


def traverse_schema(keys: list[str], current_schema: dict[str, Any] | None) -> str | None:
"""Recursively traverse the schema to find the type of the column."""
key = keys[0]
if not current_schema:
return None
column = current_schema.get(key) or {}
column_type = column.get("type") if column else None
if not column_type and len(keys) > 1:
return traverse_schema(keys[1:], current_schema=column.get("properties"))
return column_type
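# Illustrative usage, not part of the diff, against a nested schema like the one above:
#   schema = {"process": {"properties": {"name": {"type": "keyword"}}}}
#   traverse_schema(["process", "name"], schema)     # -> "keyword"
#   traverse_schema(["process", "missing"], schema)  # -> None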


def validate_columns_input_mapping(query_columns: list[dict[str, str]], combined_mappings: dict[str, Any]) -> bool:
"""Validate that the columns in the ESQL query match the provided mappings."""
mismatched_columns: list[str] = []

for column in query_columns:
column_name = column["name"]
if column_name.startswith("Esql.") or column_name.startswith("Esql_priv."):
continue
column_type = column["type"]

# Check if the column exists in combined_mappings or a valid field generated from a function or operator
keys = column_name.split(".")
schema_type = traverse_schema(keys, combined_mappings)

# Validate the type
if not schema_type or column_type != schema_type:
mismatched_columns.append(
f"Dynamic field `{column_name}` is not correctly mapped. "
f"If not dynamic: expected `{schema_type}`, got `{column_type}`."
)

# Raise an error if there are mismatches
if mismatched_columns:
raise ValueError("Column validation errors:\n" + "\n".join(mismatched_columns))

return True
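# Hypothetical illustration, not part of the diff:
#   query_columns = [
#       {"name": "Esql.process_count", "type": "long"},  # dynamic field, skipped by naming convention
#       {"name": "event.duration", "type": "keyword"},   # mapped as "long" below -> reported as mismatch
#   ]
#   combined_mappings = {"event": {"properties": {"duration": {"type": "long"}}}}
#   validate_columns_input_mapping(query_columns, combined_mappings)  # raises ValueError for event.duration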


def validate_esql_rule(kibana_client: Kibana, elastic_client: Elasticsearch, contents: TOMLRuleContents) -> None:
rule_id = contents.data.rule_id

# FIXME perhaps move this to utils
def log(val: str) -> None:
print(f"{rule_id}:", val)

kibana_details = kibana_client.get("/api/status")
stack_version = kibana_details["version"]["number"]

log(f"Validating against {stack_version} stack")

indices_str, indices = utils.get_esql_query_indices(contents.data.query)
log(f"Extracted indices from query: {', '.join(indices)}")

# Get mappings for all matching existing index templates

existing_mappings: dict[str, Any] = {}

for index in indices:
index_tmpl_mappings = get_simulated_template_mappings(elastic_client, index)
combine_dicts(existing_mappings, index_tmpl_mappings)

log(f"Collected mappings: {len(existing_mappings)}")

# Collect mappings for the integrations

rule_integrations = []
if contents.metadata.integration:
if isinstance(contents.metadata.integration, list):
rule_integrations = contents.metadata.integration
else:
rule_integrations = [contents.metadata.integration]

if len(rule_integrations) > 0:
log(f"Working with rule integrations: {', '.join(rule_integrations)}")
else:
log("No integrations found in the rule")

package_manifests = load_integrations_manifests()
integration_schemas = load_integrations_schemas()

integration_mappings = {}

for integration in rule_integrations:
# Assume the integration value is a package name
package = integration

package_version, _ = integrations.find_latest_compatible_version(
package,
"",
Version.parse(stack_version),
package_manifests,
)

package_schema = integration_schemas[package][package_version]

# Add schemas for all streams in the package
for stream in package_schema:
flat_schema = package_schema[stream]
stream_mappings = flat_schema_to_mapping(flat_schema)
combine_dicts(integration_mappings, stream_mappings)

log(f"Integration mappings prepared: {len(integration_mappings)}")

combined_mappings = {}
combine_dicts(combined_mappings, existing_mappings)
combine_dicts(combined_mappings, integration_mappings)
# NOTE non-ecs schema needs to have formatting updates prior to merge
# NOTE non-ecs schema uses Kibana reserved word "properties" as a field name
# e.g. "azure.auditlogs.properties.target_resources.0.display_name": "keyword",
non_ecs_mapping = {}
non_ecs = ecs.get_non_ecs_schema()
for index in indices:
non_ecs_mapping.update(non_ecs.get(index, {}))
non_ecs_mapping = ecs.flatten(non_ecs_mapping)
non_ecs_mapping = convert_to_nested_schema(non_ecs_mapping)
if non_ecs_mapping:
combine_dicts(combined_mappings, non_ecs_mapping)

if not combined_mappings:
log("ERROR: no mappings found for the rule")
raise ValueError("No mappings found")

# Creating a test index with the test name
suffix = str(int(time.time() * 1000))
test_index = f"rule-test-index-{suffix}"

# creating an index
response = elastic_client.indices.create(
index=test_index,
mappings={"properties": combined_mappings},
settings={
"index.mapping.total_fields.limit": 10000,
"index.mapping.nested_fields.limit": 500,
"index.mapping.nested_objects.limit": 10000,
},
)
log(f"Index `{test_index}` created: {response}")

# Replace all sources with the test index
query = contents.data.query
query = query.replace(indices_str, test_index)

try:
log(f"Executing a query against `{test_index}`")
response = elastic_client.esql.query(query=query)
log(f"Got query response: {response}")
query_columns = response.get("columns", [])
finally:
response = elastic_client.indices.delete(index=test_index)
log(f"Test index `{test_index}` deleted: {response}")

query_column_names = [c["name"] for c in query_columns]
log(f"Got query columns: {', '.join(query_column_names)}")

# FIXME Perhaps update rule_validator's get_required_fields as well,
# so that every field is either directly mapped to the schema or annotated as a dynamic field
if validate_columns_input_mapping(query_columns, combined_mappings):
log("All dynamic columns have proper formatting.")
else:
log("Dynamic column(s) have improper formatting.")


def get_simulated_template_mappings(elastic_client: Elasticsearch, name: str) -> dict[str, Any]:
"""
Return the mappings that an existing index template would apply
to the specified index

https://elasticsearch-py.readthedocs.io/en/stable/api/indices.html#elasticsearch.client.IndicesClient.simulate_index_template
"""
template = elastic_client.indices.simulate_index_template(name=name)
if not template:
return {}
return template["template"]["mappings"]["properties"]


def get_indices(elastic_client: Elasticsearch, index: str) -> list[str]:
"""Fetch indices that match the provided name from Elasticsearch"""
# `index` arg here supports wildcards
return [i["index"] for i in elastic_client.cat.indices(index=index, format="json")]


def combine_dicts(dest: dict[Any, Any], src: dict[Any, Any]) -> None:
"""Combine two dictionaries recursively."""
for k, v in src.items():
if k in dest and isinstance(dest[k], dict) and isinstance(v, dict):
combine_dicts(dest[k], v)
else:
dest[k] = v
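# Illustrative sketch, not part of the diff: nested keys from both dicts are merged in place, e.g.
#   dest = {"process": {"properties": {"name": {"type": "keyword"}}}}
#   combine_dicts(dest, {"process": {"properties": {"pid": {"type": "long"}}}})
#   # dest now holds both process.name and process.pid under the same "properties" node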


def flat_schema_to_mapping(flat_schema: dict[str, str]) -> dict[str, Any]:
"""
Convert dicts with flat JSON paths and values into a nested mapping with
intermediary `properties`, `fields` and `type` fields.
"""

# Sorting here ensures that 'a.b' processed before 'a.b.c', allowing us to correctly
# detect and handle multi-fields.
sorted_items = sorted(flat_schema.items())
result = {}

for field_path, field_type in sorted_items:
parts = field_path.split(".")
current_level = result

for part in parts[:-1]:
node = current_level.setdefault(part, {})

if "type" in node and node["type"] not in ("nested", "object"):
current_level = node.setdefault("fields", {})
else:
current_level = node.setdefault("properties", {})

leaf_key = parts[-1]
current_level[leaf_key] = {"type": field_type}

# add `scaling_factor` field missing in the schema
# https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/number#scaled-float-params
if field_type == "scaled_float":
current_level[leaf_key]["scaling_factor"] = 1000

# add `path` field for `alias` fields, set to a dummy value
if field_type == "alias":
current_level[leaf_key]["path"] = "@timestamp"

return result
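# Illustrative sketch, not part of the diff: multi-fields are detected because "url.original"
# sorts before "url.original.text", e.g.
#   flat_schema_to_mapping({"url.original": "keyword", "url.original.text": "match_only_text"})
#   # -> {"url": {"properties": {"original": {"type": "keyword",
#   #                                         "fields": {"text": {"type": "match_only_text"}}}}}}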
13 changes: 13 additions & 0 deletions detection_rules/utils.py
@@ -528,3 +528,16 @@ def get_identifiers(self) -> list[str]:
# another group we're not expecting
raise ValueError("Unrecognized named group in pattern", self.pattern)
return ids


FROM_SOURCES_REGEX = re.compile(r"^\s*FROM\s+(?P<sources>.+?)\s*(?:\||\bmetadata\b|//|$)", re.IGNORECASE | re.MULTILINE)


def get_esql_query_indices(query: str) -> tuple[str, list[str]]:
match = FROM_SOURCES_REGEX.search(query)

if not match:
return "", []

sources_str = match.group("sources")
return sources_str, [source.strip() for source in sources_str.split(",")]
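# Illustrative sketch, not part of the diff:
#   get_esql_query_indices("FROM logs-endpoint.events.process-*, logs-windows.sysmon_operational-* metadata _id | WHERE true")
#   # -> ("logs-endpoint.events.process-*, logs-windows.sysmon_operational-*",
#   #     ["logs-endpoint.events.process-*", "logs-windows.sysmon_operational-*"])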
2 changes: 1 addition & 1 deletion hunting/definitions.py
@@ -59,5 +59,5 @@ def validate_esql_query(self, query: str) -> None:
# Check if either "stats by" or "| keep" exists in the query
if not stats_by_pattern.search(query) and not keep_pattern.search(query):
raise ValueError(
f"Hunt: {self.name} contains an ES|QL query that mustcontain either 'stats by' or 'keep' functions."
f"Hunt: {self.name} contains an ES|QL query that must contain either 'stats by' or 'keep' functions"
)