diff --git a/CHANGELOG.md b/CHANGELOG.md index 1531004d3..c4b532b35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## [Unreleased] +- Create `elasticstack_kibana_security_detection_rule` resource. ([#1290](https://github.com/elastic/terraform-provider-elasticstack/pull/1290)) - Add `elasticstack_kibana_export_saved_objects` data source ([#1293](https://github.com/elastic/terraform-provider-elasticstack/pull/1293)) - Create `elasticstack_kibana_maintenance_window` resource. ([#1224](https://github.com/elastic/terraform-provider-elasticstack/pull/1224)) - Add support for `solution` field in `elasticstack_kibana_space` resource and data source ([#1102](https://github.com/elastic/terraform-provider-elasticstack/issues/1102)) diff --git a/docs/resources/kibana_security_detection_rule.md b/docs/resources/kibana_security_detection_rule.md new file mode 100644 index 000000000..ff5acf49f --- /dev/null +++ b/docs/resources/kibana_security_detection_rule.md @@ -0,0 +1,431 @@ + +--- +# generated by https://github.com/hashicorp/terraform-plugin-docs +page_title: "elasticstack_kibana_security_detection_rule Resource - terraform-provider-elasticstack" +subcategory: "Kibana" +description: |- + Creates or updates a Kibana security detection rule. See the rules API documentation https://www.elastic.co/guide/en/security/current/rules-api-create.html for more details. +--- + +# elasticstack_kibana_security_detection_rule (Resource) + +Creates or updates a Kibana security detection rule. See the [rules API documentation](https://www.elastic.co/guide/en/security/current/rules-api-create.html) for more details. + +## Example Usage + +```terraform +provider "elasticstack" { + kibana {} +} + +# Basic security detection rule +resource "elasticstack_kibana_security_detection_rule" "example" { + name = "Suspicious Activity Detection" + type = "query" + query = "event.action:logon AND user.name:admin" + language = "kuery" + enabled = true + description = "Detects suspicious admin logon activities" + severity = "high" + risk_score = 75 + from = "now-6m" + to = "now" + interval = "5m" + + author = ["Security Team"] + tags = ["security", "authentication", "admin"] + license = "Elastic License v2" + false_positives = ["Legitimate admin access during maintenance windows"] + references = [ + "https://example.com/security-docs", + "https://example.com/admin-access-policy" + ] + + note = "Investigate the source IP and verify if the admin access is legitimate." + setup = "Ensure that authentication logs are being collected and indexed." +} + +# Advanced security detection rule with custom settings +resource "elasticstack_kibana_security_detection_rule" "advanced" { + name = "Advanced Threat Detection" + type = "query" + query = "process.name:powershell.exe AND process.args:*encoded*" + language = "kuery" + enabled = true + description = "Detects encoded PowerShell commands which may indicate malicious activity" + severity = "critical" + risk_score = 90 + from = "now-10m" + to = "now" + interval = "2m" + max_signals = 200 + version = 1 + + index = [ + "winlogbeat-*", + "logs-windows-*" + ] + + author = [ + "Threat Intelligence Team", + "SOC Analysts" + ] + + tags = [ + "windows", + "powershell", + "encoded", + "malware", + "critical" + ] + + false_positives = [ + "Legitimate encoded PowerShell scripts used by automation", + "Software installation scripts" + ] + + references = [ + "https://attack.mitre.org/techniques/T1059/001/", + "https://example.com/powershell-security-guide" + ] + + license = "Elastic License v2" + note = <<-EOT + ## Investigation Steps + 1. Examine the full PowerShell command line + 2. Decode any base64 encoded content + 3. Check the parent process that spawned PowerShell + 4. Review network connections made during execution + 5. Check for file system modifications + EOT + + setup = <<-EOT + ## Prerequisites + - Windows endpoint monitoring must be enabled + - PowerShell logging should be configured + - Sysmon or equivalent process monitoring required + EOT +} +``` + + +## Schema + +### Required + +- `description` (String) The rule's description. +- `name` (String) A human-readable name for the rule. +- `type` (String) Rule type. Supported types: query, eql, esql, machine_learning, new_terms, saved_query, threat_match, threshold. + +### Optional + +- `actions` (Attributes List) Array of automated actions taken when alerts are generated by the rule. (see [below for nested schema](#nestedatt--actions)) +- `alert_suppression` (Attributes) Defines alert suppression configuration to reduce duplicate alerts. (see [below for nested schema](#nestedatt--alert_suppression)) +- `anomaly_threshold` (Number) Anomaly score threshold above which the rule creates an alert. Valid values are from 0 to 100. Required for machine_learning rules. +- `author` (List of String) The rule's author. +- `building_block_type` (String) Determines if the rule acts as a building block. If set, value must be `default`. Building-block alerts are not displayed in the UI by default and are used as a foundation for other rules. +- `concurrent_searches` (Number) Number of concurrent searches for threat intelligence. Optional for threat_match rules. +- `data_view_id` (String) Data view ID for the rule. Not supported for esql and machine_learning rule types. +- `enabled` (Boolean) Determines whether the rule is enabled. +- `exceptions_list` (Attributes List) Array of exception containers to prevent the rule from generating alerts. (see [below for nested schema](#nestedatt--exceptions_list)) +- `false_positives` (List of String) String array used to describe common reasons why the rule may issue false-positive alerts. +- `filters` (String) Query and filter context array to define alert conditions as JSON. Supports complex filter structures including bool queries, term filters, range filters, etc. Available for all rule types. +- `from` (String) Time from which data is analyzed each time the rule runs, using a date math range. +- `history_window_start` (String) Start date to use when checking if a term has been seen before. Supports relative dates like 'now-30d'. Required for new_terms rules. +- `index` (List of String) Indices on which the rule functions. +- `interval` (String) Frequency of rule execution, using a date math range. +- `investigation_fields` (List of String) Array of field names to include in alert investigation. Available for all rule types. +- `items_per_search` (Number) Number of items to search for in each concurrent search. Optional for threat_match rules. +- `language` (String) The query language (KQL or Lucene). +- `license` (String) The rule's license. +- `machine_learning_job_id` (List of String) Machine learning job ID(s) the rule monitors for anomaly scores. Required for machine_learning rules. +- `max_signals` (Number) Maximum number of alerts the rule can create during a single run. +- `namespace` (String) Alerts index namespace. Available for all rule types. +- `new_terms_fields` (List of String) Field names containing the new terms. Required for new_terms rules. +- `note` (String) Notes to help investigate alerts produced by the rule. +- `query` (String) The query language definition. +- `references` (List of String) String array containing references and URLs to sources of additional information. +- `related_integrations` (Attributes List) Array of related integrations that provide additional context for the rule. (see [below for nested schema](#nestedatt--related_integrations)) +- `required_fields` (Attributes List) Array of Elasticsearch fields and types that must be present in source indices for the rule to function properly. (see [below for nested schema](#nestedatt--required_fields)) +- `response_actions` (Attributes List) Array of response actions to take when alerts are generated by the rule. (see [below for nested schema](#nestedatt--response_actions)) +- `risk_score` (Number) A numerical representation of the alert's severity from 0 to 100. +- `risk_score_mapping` (Attributes List) Array of risk score mappings to override the default risk score based on source event field values. (see [below for nested schema](#nestedatt--risk_score_mapping)) +- `rule_id` (String) A stable unique identifier for the rule object. If omitted, a UUID is generated. +- `rule_name_override` (String) Override the rule name in Kibana. Available for all rule types. +- `saved_id` (String) Identifier of the saved query used for the rule. Required for saved_query rules. +- `setup` (String) Setup guide with instructions on rule prerequisites. +- `severity` (String) Severity level of alerts produced by the rule. +- `severity_mapping` (Attributes List) Array of severity mappings to override the default severity based on source event field values. (see [below for nested schema](#nestedatt--severity_mapping)) +- `space_id` (String) An identifier for the space. If space_id is not provided, the default space is used. +- `tags` (List of String) String array containing words and phrases to help categorize, filter, and search rules. +- `threat` (Attributes List) MITRE ATT&CK framework threat information. (see [below for nested schema](#nestedatt--threat)) +- `threat_filters` (List of String) Additional filters for threat intelligence data. Optional for threat_match rules. +- `threat_index` (List of String) Array of index patterns for the threat intelligence indices. Required for threat_match rules. +- `threat_indicator_path` (String) Path to the threat indicator in the indicator documents. Optional for threat_match rules. +- `threat_mapping` (Attributes List) Array of threat mappings that specify how to match events with threat intelligence. Required for threat_match rules. (see [below for nested schema](#nestedatt--threat_mapping)) +- `threat_query` (String) Query used to filter threat intelligence data. Optional for threat_match rules. +- `threshold` (Attributes) Threshold settings for the rule. Required for threshold rules. (see [below for nested schema](#nestedatt--threshold)) +- `tiebreaker_field` (String) Sets the tiebreaker field. Required for EQL rules when event.dataset is not provided. +- `timeline_id` (String) Timeline template ID for the rule. +- `timeline_title` (String) Timeline template title for the rule. +- `timestamp_override` (String) Field name to use for timestamp override. Available for all rule types. +- `timestamp_override_fallback_disabled` (Boolean) Disables timestamp override fallback. Available for all rule types. +- `to` (String) Time to which data is analyzed each time the rule runs, using a date math range. +- `version` (Number) The rule's version number. + +### Read-Only + +- `created_at` (String) The time the rule was created. +- `created_by` (String) The user who created the rule. +- `id` (String) Internal identifier of the resource +- `revision` (Number) The rule's revision number. +- `updated_at` (String) The time the rule was last updated. +- `updated_by` (String) The user who last updated the rule. + + +### Nested Schema for `actions` + +Required: + +- `action_type_id` (String) The action type used for sending notifications (e.g., .slack, .email, .webhook, .pagerduty, etc.). +- `id` (String) The connector ID. +- `params` (Map of String) Object containing the allowed connector fields, which varies according to the connector type. + +Optional: + +- `alerts_filter` (Map of String) Object containing an action's conditional filters. +- `frequency` (Attributes) The action frequency defines when the action runs. (see [below for nested schema](#nestedatt--actions--frequency)) +- `group` (String) Optionally groups actions by use cases. Use 'default' for alert notifications. +- `uuid` (String) A unique identifier for the action. + + +### Nested Schema for `actions.frequency` + +Required: + +- `notify_when` (String) Defines how often rules run actions. Valid values: onActionGroupChange, onActiveAlert, onThrottleInterval. +- `summary` (Boolean) Action summary indicates whether we will send a summary notification about all the generated alerts or notification per individual alert. +- `throttle` (String) Time interval for throttling actions (e.g., '1h', '30m', 'no_actions', 'rule'). + + + + +### Nested Schema for `alert_suppression` + +Optional: + +- `duration` (String) Duration for which alerts are suppressed. +- `group_by` (List of String) Array of field names to group alerts by for suppression. +- `missing_fields_strategy` (String) Strategy for handling missing fields in suppression grouping: 'suppress' - only one alert will be created per suppress by bucket, 'doNotSuppress' - per each document a separate alert will be created. + + + +### Nested Schema for `exceptions_list` + +Required: + +- `id` (String) The exception container ID. +- `list_id` (String) The exception container's list ID. +- `namespace_type` (String) The namespace type for the exception container. +- `type` (String) The type of exception container. + + + +### Nested Schema for `related_integrations` + +Required: + +- `package` (String) Name of the integration package. +- `version` (String) Version of the integration package. + +Optional: + +- `integration` (String) Name of the specific integration. + + + +### Nested Schema for `required_fields` + +Required: + +- `name` (String) Name of the Elasticsearch field. +- `type` (String) Type of the Elasticsearch field. + +Read-Only: + +- `ecs` (Boolean) Indicates whether the field is ECS-compliant. This is computed by the backend based on the field name and type. + + + +### Nested Schema for `response_actions` + +Required: + +- `action_type_id` (String) The action type used for response actions (.osquery, .endpoint). +- `params` (Attributes) Parameters for the response action. Structure varies based on action_type_id. (see [below for nested schema](#nestedatt--response_actions--params)) + + +### Nested Schema for `response_actions.params` + +Optional: + +- `command` (String) Command to run (endpoint only). Valid values: isolate, kill-process, suspend-process. +- `comment` (String) Comment describing the action (endpoint only). +- `config` (Attributes) Configuration for process commands (endpoint only). (see [below for nested schema](#nestedatt--response_actions--params--config)) +- `ecs_mapping` (Map of String) Map Osquery results columns to ECS fields (osquery only). +- `pack_id` (String) Query pack identifier (osquery only). +- `queries` (Attributes List) Array of queries to run (osquery only). (see [below for nested schema](#nestedatt--response_actions--params--queries)) +- `query` (String) SQL query to run (osquery only). Example: 'SELECT * FROM processes;' +- `saved_query_id` (String) Saved query identifier (osquery only). +- `timeout` (Number) Timeout period in seconds (osquery only). Min: 60, Max: 900. + + +### Nested Schema for `response_actions.params.config` + +Required: + +- `field` (String) Field to use instead of process.pid. + +Optional: + +- `overwrite` (Boolean) Whether to overwrite field with process.pid. + + + +### Nested Schema for `response_actions.params.queries` + +Required: + +- `id` (String) Query ID. +- `query` (String) Query to run. + +Optional: + +- `ecs_mapping` (Map of String) ECS field mappings for this query. +- `platform` (String) Platform to run the query on. +- `removed` (Boolean) Whether the query is removed. +- `snapshot` (Boolean) Whether this is a snapshot query. +- `version` (String) Query version. + + + + + +### Nested Schema for `risk_score_mapping` + +Required: + +- `field` (String) Source event field used to override the default risk_score. +- `operator` (String) Operator to use for field value matching. Currently only 'equals' is supported. +- `value` (String) Value to match against the field. + +Optional: + +- `risk_score` (Number) Risk score to use when the field matches the value (0-100). If omitted, uses the rule's default risk_score. + + + +### Nested Schema for `severity_mapping` + +Required: + +- `field` (String) Source event field used to override the default severity. +- `operator` (String) Operator to use for field value matching. Currently only 'equals' is supported. +- `severity` (String) Severity level to use when the field matches the value. +- `value` (String) Value to match against the field. + + + +### Nested Schema for `threat` + +Required: + +- `framework` (String) Threat framework (typically 'MITRE ATT&CK'). +- `tactic` (Attributes) MITRE ATT&CK tactic information. (see [below for nested schema](#nestedatt--threat--tactic)) + +Optional: + +- `technique` (Attributes List) MITRE ATT&CK technique information. (see [below for nested schema](#nestedatt--threat--technique)) + + +### Nested Schema for `threat.tactic` + +Required: + +- `id` (String) MITRE ATT&CK tactic ID. +- `name` (String) MITRE ATT&CK tactic name. +- `reference` (String) MITRE ATT&CK tactic reference URL. + + + +### Nested Schema for `threat.technique` + +Required: + +- `id` (String) MITRE ATT&CK technique ID. +- `name` (String) MITRE ATT&CK technique name. +- `reference` (String) MITRE ATT&CK technique reference URL. + +Optional: + +- `subtechnique` (Attributes List) MITRE ATT&CK sub-technique information. (see [below for nested schema](#nestedatt--threat--technique--subtechnique)) + + +### Nested Schema for `threat.technique.subtechnique` + +Required: + +- `id` (String) MITRE ATT&CK sub-technique ID. +- `name` (String) MITRE ATT&CK sub-technique name. +- `reference` (String) MITRE ATT&CK sub-technique reference URL. + + + + + +### Nested Schema for `threat_mapping` + +Required: + +- `entries` (Attributes List) Array of mapping entries. (see [below for nested schema](#nestedatt--threat_mapping--entries)) + + +### Nested Schema for `threat_mapping.entries` + +Required: + +- `field` (String) Event field to match. +- `type` (String) Type of match (mapping). +- `value` (String) Threat intelligence field to match against. + + + + +### Nested Schema for `threshold` + +Required: + +- `value` (Number) The threshold value from which an alert is generated. + +Optional: + +- `cardinality` (Attributes List) Cardinality settings for threshold rule. (see [below for nested schema](#nestedatt--threshold--cardinality)) +- `field` (List of String) Field(s) to use for threshold aggregation. + + +### Nested Schema for `threshold.cardinality` + +Required: + +- `field` (String) The field on which to calculate and compare the cardinality. +- `value` (Number) The threshold cardinality value. + +## Import + +Import is supported using the following syntax: + +The [`terraform import` command](https://developer.hashicorp.com/terraform/cli/commands/import) can be used, for example: + +```shell +terraform import elasticstack_kibana_security_detection_rule.example default/12345678-1234-1234-1234-123456789abc +``` diff --git a/examples/resources/elasticstack_kibana_security_detection_rule/import.sh b/examples/resources/elasticstack_kibana_security_detection_rule/import.sh new file mode 100644 index 000000000..2836ff959 --- /dev/null +++ b/examples/resources/elasticstack_kibana_security_detection_rule/import.sh @@ -0,0 +1 @@ +terraform import elasticstack_kibana_security_detection_rule.example default/12345678-1234-1234-1234-123456789abc \ No newline at end of file diff --git a/examples/resources/elasticstack_kibana_security_detection_rule/resource.tf b/examples/resources/elasticstack_kibana_security_detection_rule/resource.tf new file mode 100644 index 000000000..d4283b5a1 --- /dev/null +++ b/examples/resources/elasticstack_kibana_security_detection_rule/resource.tf @@ -0,0 +1,92 @@ +provider "elasticstack" { + kibana {} +} + +# Basic security detection rule +resource "elasticstack_kibana_security_detection_rule" "example" { + name = "Suspicious Activity Detection" + type = "query" + query = "event.action:logon AND user.name:admin" + language = "kuery" + enabled = true + description = "Detects suspicious admin logon activities" + severity = "high" + risk_score = 75 + from = "now-6m" + to = "now" + interval = "5m" + + author = ["Security Team"] + tags = ["security", "authentication", "admin"] + license = "Elastic License v2" + false_positives = ["Legitimate admin access during maintenance windows"] + references = [ + "https://example.com/security-docs", + "https://example.com/admin-access-policy" + ] + + note = "Investigate the source IP and verify if the admin access is legitimate." + setup = "Ensure that authentication logs are being collected and indexed." +} + +# Advanced security detection rule with custom settings +resource "elasticstack_kibana_security_detection_rule" "advanced" { + name = "Advanced Threat Detection" + type = "query" + query = "process.name:powershell.exe AND process.args:*encoded*" + language = "kuery" + enabled = true + description = "Detects encoded PowerShell commands which may indicate malicious activity" + severity = "critical" + risk_score = 90 + from = "now-10m" + to = "now" + interval = "2m" + max_signals = 200 + version = 1 + + index = [ + "winlogbeat-*", + "logs-windows-*" + ] + + author = [ + "Threat Intelligence Team", + "SOC Analysts" + ] + + tags = [ + "windows", + "powershell", + "encoded", + "malware", + "critical" + ] + + false_positives = [ + "Legitimate encoded PowerShell scripts used by automation", + "Software installation scripts" + ] + + references = [ + "https://attack.mitre.org/techniques/T1059/001/", + "https://example.com/powershell-security-guide" + ] + + license = "Elastic License v2" + note = <<-EOT + ## Investigation Steps + 1. Examine the full PowerShell command line + 2. Decode any base64 encoded content + 3. Check the parent process that spawned PowerShell + 4. Review network connections made during execution + 5. Check for file system modifications + EOT + + setup = <<-EOT + ## Prerequisites + - Windows endpoint monitoring must be enabled + - PowerShell logging should be configured + - Sysmon or equivalent process monitoring required + EOT +} \ No newline at end of file diff --git a/generated/kbapi/kibana.gen.go b/generated/kbapi/kibana.gen.go index 3c947f747..500ac2ba1 100644 --- a/generated/kbapi/kibana.gen.go +++ b/generated/kbapi/kibana.gen.go @@ -51577,7 +51577,7 @@ func (t SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item) AsSLO // FromSLOsTimesliceMetricBasicMetricWithField overwrites any union data inside the SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item as the provided SLOsTimesliceMetricBasicMetricWithField func (t *SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item) FromSLOsTimesliceMetricBasicMetricWithField(v SLOsTimesliceMetricBasicMetricWithField) error { - v.Aggregation = "avg" + v.Aggregation = "sum" b, err := json.Marshal(v) t.union = b return err @@ -51585,7 +51585,7 @@ func (t *SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item) From // MergeSLOsTimesliceMetricBasicMetricWithField performs a merge with any union data inside the SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item, using the provided SLOsTimesliceMetricBasicMetricWithField func (t *SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item) MergeSLOsTimesliceMetricBasicMetricWithField(v SLOsTimesliceMetricBasicMetricWithField) error { - v.Aggregation = "avg" + v.Aggregation = "sum" b, err := json.Marshal(v) if err != nil { return err @@ -51666,12 +51666,12 @@ func (t SLOsIndicatorPropertiesTimesliceMetric_Params_Metric_Metrics_Item) Value return nil, err } switch discriminator { - case "avg": - return t.AsSLOsTimesliceMetricBasicMetricWithField() case "doc_count": return t.AsSLOsTimesliceMetricDocCountMetric() case "percentile": return t.AsSLOsTimesliceMetricPercentileMetric() + case "sum": + return t.AsSLOsTimesliceMetricBasicMetricWithField() default: return nil, errors.New("unknown discriminator value: " + discriminator) } @@ -53532,6 +53532,7 @@ func (t SecurityDetectionsAPIResponseAction) AsSecurityDetectionsAPIOsqueryRespo // FromSecurityDetectionsAPIOsqueryResponseAction overwrites any union data inside the SecurityDetectionsAPIResponseAction as the provided SecurityDetectionsAPIOsqueryResponseAction func (t *SecurityDetectionsAPIResponseAction) FromSecurityDetectionsAPIOsqueryResponseAction(v SecurityDetectionsAPIOsqueryResponseAction) error { + v.ActionTypeId = ".osquery" b, err := json.Marshal(v) t.union = b return err @@ -53539,6 +53540,7 @@ func (t *SecurityDetectionsAPIResponseAction) FromSecurityDetectionsAPIOsqueryRe // MergeSecurityDetectionsAPIOsqueryResponseAction performs a merge with any union data inside the SecurityDetectionsAPIResponseAction, using the provided SecurityDetectionsAPIOsqueryResponseAction func (t *SecurityDetectionsAPIResponseAction) MergeSecurityDetectionsAPIOsqueryResponseAction(v SecurityDetectionsAPIOsqueryResponseAction) error { + v.ActionTypeId = ".osquery" b, err := json.Marshal(v) if err != nil { return err @@ -53558,6 +53560,7 @@ func (t SecurityDetectionsAPIResponseAction) AsSecurityDetectionsAPIEndpointResp // FromSecurityDetectionsAPIEndpointResponseAction overwrites any union data inside the SecurityDetectionsAPIResponseAction as the provided SecurityDetectionsAPIEndpointResponseAction func (t *SecurityDetectionsAPIResponseAction) FromSecurityDetectionsAPIEndpointResponseAction(v SecurityDetectionsAPIEndpointResponseAction) error { + v.ActionTypeId = ".endpoint" b, err := json.Marshal(v) t.union = b return err @@ -53565,6 +53568,7 @@ func (t *SecurityDetectionsAPIResponseAction) FromSecurityDetectionsAPIEndpointR // MergeSecurityDetectionsAPIEndpointResponseAction performs a merge with any union data inside the SecurityDetectionsAPIResponseAction, using the provided SecurityDetectionsAPIEndpointResponseAction func (t *SecurityDetectionsAPIResponseAction) MergeSecurityDetectionsAPIEndpointResponseAction(v SecurityDetectionsAPIEndpointResponseAction) error { + v.ActionTypeId = ".endpoint" b, err := json.Marshal(v) if err != nil { return err @@ -53575,6 +53579,29 @@ func (t *SecurityDetectionsAPIResponseAction) MergeSecurityDetectionsAPIEndpoint return err } +func (t SecurityDetectionsAPIResponseAction) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"action_type_id"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t SecurityDetectionsAPIResponseAction) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case ".endpoint": + return t.AsSecurityDetectionsAPIEndpointResponseAction() + case ".osquery": + return t.AsSecurityDetectionsAPIOsqueryResponseAction() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + func (t SecurityDetectionsAPIResponseAction) MarshalJSON() ([]byte, error) { b, err := t.union.MarshalJSON() return b, err @@ -53656,6 +53683,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPIEqlRuleCrea // FromSecurityDetectionsAPIEqlRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPIEqlRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIEqlRuleCreateProps(v SecurityDetectionsAPIEqlRuleCreateProps) error { + v.Type = "eql" b, err := json.Marshal(v) t.union = b return err @@ -53663,6 +53691,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIEqlRuleC // MergeSecurityDetectionsAPIEqlRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPIEqlRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIEqlRuleCreateProps(v SecurityDetectionsAPIEqlRuleCreateProps) error { + v.Type = "eql" b, err := json.Marshal(v) if err != nil { return err @@ -53682,6 +53711,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPIQueryRuleCr // FromSecurityDetectionsAPIQueryRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPIQueryRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIQueryRuleCreateProps(v SecurityDetectionsAPIQueryRuleCreateProps) error { + v.Type = "query" b, err := json.Marshal(v) t.union = b return err @@ -53689,6 +53719,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIQueryRul // MergeSecurityDetectionsAPIQueryRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPIQueryRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIQueryRuleCreateProps(v SecurityDetectionsAPIQueryRuleCreateProps) error { + v.Type = "query" b, err := json.Marshal(v) if err != nil { return err @@ -53708,6 +53739,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPISavedQueryR // FromSecurityDetectionsAPISavedQueryRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPISavedQueryRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPISavedQueryRuleCreateProps(v SecurityDetectionsAPISavedQueryRuleCreateProps) error { + v.Type = "saved_query" b, err := json.Marshal(v) t.union = b return err @@ -53715,6 +53747,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPISavedQue // MergeSecurityDetectionsAPISavedQueryRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPISavedQueryRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPISavedQueryRuleCreateProps(v SecurityDetectionsAPISavedQueryRuleCreateProps) error { + v.Type = "saved_query" b, err := json.Marshal(v) if err != nil { return err @@ -53734,6 +53767,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPIThresholdRu // FromSecurityDetectionsAPIThresholdRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPIThresholdRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIThresholdRuleCreateProps(v SecurityDetectionsAPIThresholdRuleCreateProps) error { + v.Type = "threshold" b, err := json.Marshal(v) t.union = b return err @@ -53741,6 +53775,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIThreshol // MergeSecurityDetectionsAPIThresholdRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPIThresholdRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIThresholdRuleCreateProps(v SecurityDetectionsAPIThresholdRuleCreateProps) error { + v.Type = "threshold" b, err := json.Marshal(v) if err != nil { return err @@ -53760,6 +53795,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPIThreatMatch // FromSecurityDetectionsAPIThreatMatchRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPIThreatMatchRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIThreatMatchRuleCreateProps(v SecurityDetectionsAPIThreatMatchRuleCreateProps) error { + v.Type = "threat_match" b, err := json.Marshal(v) t.union = b return err @@ -53767,6 +53803,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIThreatMa // MergeSecurityDetectionsAPIThreatMatchRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPIThreatMatchRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIThreatMatchRuleCreateProps(v SecurityDetectionsAPIThreatMatchRuleCreateProps) error { + v.Type = "threat_match" b, err := json.Marshal(v) if err != nil { return err @@ -53786,6 +53823,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPIMachineLear // FromSecurityDetectionsAPIMachineLearningRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPIMachineLearningRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIMachineLearningRuleCreateProps(v SecurityDetectionsAPIMachineLearningRuleCreateProps) error { + v.Type = "machine_learning" b, err := json.Marshal(v) t.union = b return err @@ -53793,6 +53831,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIMachineL // MergeSecurityDetectionsAPIMachineLearningRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPIMachineLearningRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIMachineLearningRuleCreateProps(v SecurityDetectionsAPIMachineLearningRuleCreateProps) error { + v.Type = "machine_learning" b, err := json.Marshal(v) if err != nil { return err @@ -53812,6 +53851,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPINewTermsRul // FromSecurityDetectionsAPINewTermsRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPINewTermsRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPINewTermsRuleCreateProps(v SecurityDetectionsAPINewTermsRuleCreateProps) error { + v.Type = "new_terms" b, err := json.Marshal(v) t.union = b return err @@ -53819,6 +53859,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPINewTerms // MergeSecurityDetectionsAPINewTermsRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPINewTermsRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPINewTermsRuleCreateProps(v SecurityDetectionsAPINewTermsRuleCreateProps) error { + v.Type = "new_terms" b, err := json.Marshal(v) if err != nil { return err @@ -53838,6 +53879,7 @@ func (t SecurityDetectionsAPIRuleCreateProps) AsSecurityDetectionsAPIEsqlRuleCre // FromSecurityDetectionsAPIEsqlRuleCreateProps overwrites any union data inside the SecurityDetectionsAPIRuleCreateProps as the provided SecurityDetectionsAPIEsqlRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIEsqlRuleCreateProps(v SecurityDetectionsAPIEsqlRuleCreateProps) error { + v.Type = "esql" b, err := json.Marshal(v) t.union = b return err @@ -53845,6 +53887,7 @@ func (t *SecurityDetectionsAPIRuleCreateProps) FromSecurityDetectionsAPIEsqlRule // MergeSecurityDetectionsAPIEsqlRuleCreateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleCreateProps, using the provided SecurityDetectionsAPIEsqlRuleCreateProps func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIEsqlRuleCreateProps(v SecurityDetectionsAPIEsqlRuleCreateProps) error { + v.Type = "esql" b, err := json.Marshal(v) if err != nil { return err @@ -53855,6 +53898,41 @@ func (t *SecurityDetectionsAPIRuleCreateProps) MergeSecurityDetectionsAPIEsqlRul return err } +func (t SecurityDetectionsAPIRuleCreateProps) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"type"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t SecurityDetectionsAPIRuleCreateProps) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "eql": + return t.AsSecurityDetectionsAPIEqlRuleCreateProps() + case "esql": + return t.AsSecurityDetectionsAPIEsqlRuleCreateProps() + case "machine_learning": + return t.AsSecurityDetectionsAPIMachineLearningRuleCreateProps() + case "new_terms": + return t.AsSecurityDetectionsAPINewTermsRuleCreateProps() + case "query": + return t.AsSecurityDetectionsAPIQueryRuleCreateProps() + case "saved_query": + return t.AsSecurityDetectionsAPISavedQueryRuleCreateProps() + case "threat_match": + return t.AsSecurityDetectionsAPIThreatMatchRuleCreateProps() + case "threshold": + return t.AsSecurityDetectionsAPIThresholdRuleCreateProps() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + func (t SecurityDetectionsAPIRuleCreateProps) MarshalJSON() ([]byte, error) { b, err := t.union.MarshalJSON() return b, err @@ -54092,6 +54170,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPIEqlRule() (Sec // FromSecurityDetectionsAPIEqlRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPIEqlRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIEqlRule(v SecurityDetectionsAPIEqlRule) error { + v.Type = "eql" b, err := json.Marshal(v) t.union = b return err @@ -54099,6 +54178,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIEqlRule(v S // MergeSecurityDetectionsAPIEqlRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPIEqlRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIEqlRule(v SecurityDetectionsAPIEqlRule) error { + v.Type = "eql" b, err := json.Marshal(v) if err != nil { return err @@ -54118,6 +54198,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPIQueryRule() (S // FromSecurityDetectionsAPIQueryRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPIQueryRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIQueryRule(v SecurityDetectionsAPIQueryRule) error { + v.Type = "query" b, err := json.Marshal(v) t.union = b return err @@ -54125,6 +54206,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIQueryRule(v // MergeSecurityDetectionsAPIQueryRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPIQueryRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIQueryRule(v SecurityDetectionsAPIQueryRule) error { + v.Type = "query" b, err := json.Marshal(v) if err != nil { return err @@ -54144,6 +54226,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPISavedQueryRule // FromSecurityDetectionsAPISavedQueryRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPISavedQueryRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPISavedQueryRule(v SecurityDetectionsAPISavedQueryRule) error { + v.Type = "saved_query" b, err := json.Marshal(v) t.union = b return err @@ -54151,6 +54234,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPISavedQueryR // MergeSecurityDetectionsAPISavedQueryRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPISavedQueryRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPISavedQueryRule(v SecurityDetectionsAPISavedQueryRule) error { + v.Type = "saved_query" b, err := json.Marshal(v) if err != nil { return err @@ -54170,6 +54254,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPIThresholdRule( // FromSecurityDetectionsAPIThresholdRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPIThresholdRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIThresholdRule(v SecurityDetectionsAPIThresholdRule) error { + v.Type = "threshold" b, err := json.Marshal(v) t.union = b return err @@ -54177,6 +54262,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIThresholdRu // MergeSecurityDetectionsAPIThresholdRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPIThresholdRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIThresholdRule(v SecurityDetectionsAPIThresholdRule) error { + v.Type = "threshold" b, err := json.Marshal(v) if err != nil { return err @@ -54196,6 +54282,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPIThreatMatchRul // FromSecurityDetectionsAPIThreatMatchRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPIThreatMatchRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIThreatMatchRule(v SecurityDetectionsAPIThreatMatchRule) error { + v.Type = "threat_match" b, err := json.Marshal(v) t.union = b return err @@ -54203,6 +54290,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIThreatMatch // MergeSecurityDetectionsAPIThreatMatchRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPIThreatMatchRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIThreatMatchRule(v SecurityDetectionsAPIThreatMatchRule) error { + v.Type = "threat_match" b, err := json.Marshal(v) if err != nil { return err @@ -54222,6 +54310,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPIMachineLearnin // FromSecurityDetectionsAPIMachineLearningRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPIMachineLearningRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIMachineLearningRule(v SecurityDetectionsAPIMachineLearningRule) error { + v.Type = "machine_learning" b, err := json.Marshal(v) t.union = b return err @@ -54229,6 +54318,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIMachineLear // MergeSecurityDetectionsAPIMachineLearningRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPIMachineLearningRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIMachineLearningRule(v SecurityDetectionsAPIMachineLearningRule) error { + v.Type = "machine_learning" b, err := json.Marshal(v) if err != nil { return err @@ -54248,6 +54338,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPINewTermsRule() // FromSecurityDetectionsAPINewTermsRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPINewTermsRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPINewTermsRule(v SecurityDetectionsAPINewTermsRule) error { + v.Type = "new_terms" b, err := json.Marshal(v) t.union = b return err @@ -54255,6 +54346,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPINewTermsRul // MergeSecurityDetectionsAPINewTermsRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPINewTermsRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPINewTermsRule(v SecurityDetectionsAPINewTermsRule) error { + v.Type = "new_terms" b, err := json.Marshal(v) if err != nil { return err @@ -54274,6 +54366,7 @@ func (t SecurityDetectionsAPIRuleResponse) AsSecurityDetectionsAPIEsqlRule() (Se // FromSecurityDetectionsAPIEsqlRule overwrites any union data inside the SecurityDetectionsAPIRuleResponse as the provided SecurityDetectionsAPIEsqlRule func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIEsqlRule(v SecurityDetectionsAPIEsqlRule) error { + v.Type = "esql" b, err := json.Marshal(v) t.union = b return err @@ -54281,6 +54374,7 @@ func (t *SecurityDetectionsAPIRuleResponse) FromSecurityDetectionsAPIEsqlRule(v // MergeSecurityDetectionsAPIEsqlRule performs a merge with any union data inside the SecurityDetectionsAPIRuleResponse, using the provided SecurityDetectionsAPIEsqlRule func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIEsqlRule(v SecurityDetectionsAPIEsqlRule) error { + v.Type = "esql" b, err := json.Marshal(v) if err != nil { return err @@ -54291,6 +54385,41 @@ func (t *SecurityDetectionsAPIRuleResponse) MergeSecurityDetectionsAPIEsqlRule(v return err } +func (t SecurityDetectionsAPIRuleResponse) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"type"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t SecurityDetectionsAPIRuleResponse) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "eql": + return t.AsSecurityDetectionsAPIEqlRule() + case "esql": + return t.AsSecurityDetectionsAPIEsqlRule() + case "machine_learning": + return t.AsSecurityDetectionsAPIMachineLearningRule() + case "new_terms": + return t.AsSecurityDetectionsAPINewTermsRule() + case "query": + return t.AsSecurityDetectionsAPIQueryRule() + case "saved_query": + return t.AsSecurityDetectionsAPISavedQueryRule() + case "threat_match": + return t.AsSecurityDetectionsAPIThreatMatchRule() + case "threshold": + return t.AsSecurityDetectionsAPIThresholdRule() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + func (t SecurityDetectionsAPIRuleResponse) MarshalJSON() ([]byte, error) { b, err := t.union.MarshalJSON() return b, err @@ -54372,6 +54501,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPIEqlRuleUpda // FromSecurityDetectionsAPIEqlRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPIEqlRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIEqlRuleUpdateProps(v SecurityDetectionsAPIEqlRuleUpdateProps) error { + v.Type = "eql" b, err := json.Marshal(v) t.union = b return err @@ -54379,6 +54509,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIEqlRuleU // MergeSecurityDetectionsAPIEqlRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPIEqlRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIEqlRuleUpdateProps(v SecurityDetectionsAPIEqlRuleUpdateProps) error { + v.Type = "eql" b, err := json.Marshal(v) if err != nil { return err @@ -54398,6 +54529,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPIQueryRuleUp // FromSecurityDetectionsAPIQueryRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPIQueryRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIQueryRuleUpdateProps(v SecurityDetectionsAPIQueryRuleUpdateProps) error { + v.Type = "query" b, err := json.Marshal(v) t.union = b return err @@ -54405,6 +54537,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIQueryRul // MergeSecurityDetectionsAPIQueryRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPIQueryRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIQueryRuleUpdateProps(v SecurityDetectionsAPIQueryRuleUpdateProps) error { + v.Type = "query" b, err := json.Marshal(v) if err != nil { return err @@ -54424,6 +54557,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPISavedQueryR // FromSecurityDetectionsAPISavedQueryRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPISavedQueryRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPISavedQueryRuleUpdateProps(v SecurityDetectionsAPISavedQueryRuleUpdateProps) error { + v.Type = "saved_query" b, err := json.Marshal(v) t.union = b return err @@ -54431,6 +54565,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPISavedQue // MergeSecurityDetectionsAPISavedQueryRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPISavedQueryRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPISavedQueryRuleUpdateProps(v SecurityDetectionsAPISavedQueryRuleUpdateProps) error { + v.Type = "saved_query" b, err := json.Marshal(v) if err != nil { return err @@ -54450,6 +54585,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPIThresholdRu // FromSecurityDetectionsAPIThresholdRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPIThresholdRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIThresholdRuleUpdateProps(v SecurityDetectionsAPIThresholdRuleUpdateProps) error { + v.Type = "threshold" b, err := json.Marshal(v) t.union = b return err @@ -54457,6 +54593,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIThreshol // MergeSecurityDetectionsAPIThresholdRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPIThresholdRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIThresholdRuleUpdateProps(v SecurityDetectionsAPIThresholdRuleUpdateProps) error { + v.Type = "threshold" b, err := json.Marshal(v) if err != nil { return err @@ -54476,6 +54613,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPIThreatMatch // FromSecurityDetectionsAPIThreatMatchRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPIThreatMatchRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIThreatMatchRuleUpdateProps(v SecurityDetectionsAPIThreatMatchRuleUpdateProps) error { + v.Type = "threat_match" b, err := json.Marshal(v) t.union = b return err @@ -54483,6 +54621,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIThreatMa // MergeSecurityDetectionsAPIThreatMatchRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPIThreatMatchRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIThreatMatchRuleUpdateProps(v SecurityDetectionsAPIThreatMatchRuleUpdateProps) error { + v.Type = "threat_match" b, err := json.Marshal(v) if err != nil { return err @@ -54502,6 +54641,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPIMachineLear // FromSecurityDetectionsAPIMachineLearningRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPIMachineLearningRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIMachineLearningRuleUpdateProps(v SecurityDetectionsAPIMachineLearningRuleUpdateProps) error { + v.Type = "machine_learning" b, err := json.Marshal(v) t.union = b return err @@ -54509,6 +54649,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIMachineL // MergeSecurityDetectionsAPIMachineLearningRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPIMachineLearningRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIMachineLearningRuleUpdateProps(v SecurityDetectionsAPIMachineLearningRuleUpdateProps) error { + v.Type = "machine_learning" b, err := json.Marshal(v) if err != nil { return err @@ -54528,6 +54669,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPINewTermsRul // FromSecurityDetectionsAPINewTermsRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPINewTermsRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPINewTermsRuleUpdateProps(v SecurityDetectionsAPINewTermsRuleUpdateProps) error { + v.Type = "new_terms" b, err := json.Marshal(v) t.union = b return err @@ -54535,6 +54677,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPINewTerms // MergeSecurityDetectionsAPINewTermsRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPINewTermsRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPINewTermsRuleUpdateProps(v SecurityDetectionsAPINewTermsRuleUpdateProps) error { + v.Type = "new_terms" b, err := json.Marshal(v) if err != nil { return err @@ -54554,6 +54697,7 @@ func (t SecurityDetectionsAPIRuleUpdateProps) AsSecurityDetectionsAPIEsqlRuleUpd // FromSecurityDetectionsAPIEsqlRuleUpdateProps overwrites any union data inside the SecurityDetectionsAPIRuleUpdateProps as the provided SecurityDetectionsAPIEsqlRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIEsqlRuleUpdateProps(v SecurityDetectionsAPIEsqlRuleUpdateProps) error { + v.Type = "esql" b, err := json.Marshal(v) t.union = b return err @@ -54561,6 +54705,7 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) FromSecurityDetectionsAPIEsqlRule // MergeSecurityDetectionsAPIEsqlRuleUpdateProps performs a merge with any union data inside the SecurityDetectionsAPIRuleUpdateProps, using the provided SecurityDetectionsAPIEsqlRuleUpdateProps func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIEsqlRuleUpdateProps(v SecurityDetectionsAPIEsqlRuleUpdateProps) error { + v.Type = "esql" b, err := json.Marshal(v) if err != nil { return err @@ -54571,6 +54716,41 @@ func (t *SecurityDetectionsAPIRuleUpdateProps) MergeSecurityDetectionsAPIEsqlRul return err } +func (t SecurityDetectionsAPIRuleUpdateProps) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"type"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t SecurityDetectionsAPIRuleUpdateProps) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "eql": + return t.AsSecurityDetectionsAPIEqlRuleUpdateProps() + case "esql": + return t.AsSecurityDetectionsAPIEsqlRuleUpdateProps() + case "machine_learning": + return t.AsSecurityDetectionsAPIMachineLearningRuleUpdateProps() + case "new_terms": + return t.AsSecurityDetectionsAPINewTermsRuleUpdateProps() + case "query": + return t.AsSecurityDetectionsAPIQueryRuleUpdateProps() + case "saved_query": + return t.AsSecurityDetectionsAPISavedQueryRuleUpdateProps() + case "threat_match": + return t.AsSecurityDetectionsAPIThreatMatchRuleUpdateProps() + case "threshold": + return t.AsSecurityDetectionsAPIThresholdRuleUpdateProps() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + func (t SecurityDetectionsAPIRuleUpdateProps) MarshalJSON() ([]byte, error) { b, err := t.union.MarshalJSON() return b, err diff --git a/generated/kbapi/transform_schema.go b/generated/kbapi/transform_schema.go index 009fcb3e8..82a841583 100644 --- a/generated/kbapi/transform_schema.go +++ b/generated/kbapi/transform_schema.go @@ -812,6 +812,56 @@ func transformKibanaPaths(schema *Schema) { schema.Components.CreateRef(schema, "Data_views_create_data_view_request_object_inner", "schemas.Data_views_create_data_view_request_object.properties.data_view") schema.Components.CreateRef(schema, "Data_views_update_data_view_request_object_inner", "schemas.Data_views_update_data_view_request_object.properties.data_view") + schema.Components.Set("schemas.Security_Detections_API_RuleResponse.discriminator", Map{ + "mapping": Map{ + "eql": "#/components/schemas/Security_Detections_API_EqlRule", + "esql": "#/components/schemas/Security_Detections_API_EsqlRule", + "machine_learning": "#/components/schemas/Security_Detections_API_MachineLearningRule", + "new_terms": "#/components/schemas/Security_Detections_API_NewTermsRule", + "query": "#/components/schemas/Security_Detections_API_QueryRule", + "saved_query": "#/components/schemas/Security_Detections_API_SavedQueryRule", + "threat_match": "#/components/schemas/Security_Detections_API_ThreatMatchRule", + "threshold": "#/components/schemas/Security_Detections_API_ThresholdRule", + }, + "propertyName": "type", + }) + + schema.Components.Set("schemas.Security_Detections_API_RuleCreateProps.discriminator", Map{ + "mapping": Map{ + "eql": "#/components/schemas/Security_Detections_API_EqlRuleCreateProps", + "esql": "#/components/schemas/Security_Detections_API_EsqlRuleCreateProps", + "machine_learning": "#/components/schemas/Security_Detections_API_MachineLearningRuleCreateProps", + "new_terms": "#/components/schemas/Security_Detections_API_NewTermsRuleCreateProps", + "query": "#/components/schemas/Security_Detections_API_QueryRuleCreateProps", + "saved_query": "#/components/schemas/Security_Detections_API_SavedQueryRuleCreateProps", + "threat_match": "#/components/schemas/Security_Detections_API_ThreatMatchRuleCreateProps", + "threshold": "#/components/schemas/Security_Detections_API_ThresholdRuleCreateProps", + }, + "propertyName": "type", + }) + + schema.Components.Set("schemas.Security_Detections_API_RuleUpdateProps.discriminator", Map{ + "mapping": Map{ + "eql": "#/components/schemas/Security_Detections_API_EqlRuleUpdateProps", + "esql": "#/components/schemas/Security_Detections_API_EsqlRuleUpdateProps", + "machine_learning": "#/components/schemas/Security_Detections_API_MachineLearningRuleUpdateProps", + "new_terms": "#/components/schemas/Security_Detections_API_NewTermsRuleUpdateProps", + "query": "#/components/schemas/Security_Detections_API_QueryRuleUpdateProps", + "saved_query": "#/components/schemas/Security_Detections_API_SavedQueryRuleUpdateProps", + "threat_match": "#/components/schemas/Security_Detections_API_ThreatMatchRuleUpdateProps", + "threshold": "#/components/schemas/Security_Detections_API_ThresholdRuleUpdateProps", + }, + "propertyName": "type", + }) + + schema.Components.Set("schemas.Security_Detections_API_ResponseAction.discriminator", Map{ + "mapping": Map{ + ".osquery": "#/components/schemas/Security_Detections_API_OsqueryResponseAction", + ".endpoint": "#/components/schemas/Security_Detections_API_EndpointResponseAction", + }, + "propertyName": "action_type_id", + }) + } func removeBrokenDiscriminator(schema *Schema) { @@ -826,10 +876,7 @@ func removeBrokenDiscriminator(schema *Schema) { "Security_AI_Assistant_API_KnowledgeBaseEntryResponse", "Security_AI_Assistant_API_KnowledgeBaseEntryUpdateProps", "Security_AI_Assistant_API_KnowledgeBaseEntryUpdateRouteProps", - "Security_Detections_API_RuleCreateProps", - "Security_Detections_API_RuleResponse", "Security_Detections_API_RuleSource", - "Security_Detections_API_RuleUpdateProps", "Security_Endpoint_Exceptions_API_ExceptionListItemEntry", "Security_Exceptions_API_ExceptionListItemEntry", } diff --git a/internal/clients/api_client.go b/internal/clients/api_client.go index 2861994f5..1361aff1b 100644 --- a/internal/clients/api_client.go +++ b/internal/clients/api_client.go @@ -356,6 +356,10 @@ func (a *ApiClient) EnforceMinVersion(ctx context.Context, minVersion *version.V return serverVersion.GreaterThanOrEqual(minVersion), nil } +type MinVersionEnforceable interface { + EnforceMinVersion(ctx context.Context, minVersion *version.Version) (bool, diag.Diagnostics) +} + func (a *ApiClient) ServerVersion(ctx context.Context) (*version.Version, diag.Diagnostics) { if a.elasticsearch != nil { return a.versionFromElasticsearch(ctx) diff --git a/internal/kibana/security_detection_rule/acc_test.go b/internal/kibana/security_detection_rule/acc_test.go new file mode 100644 index 000000000..9ba1f4a25 --- /dev/null +++ b/internal/kibana/security_detection_rule/acc_test.go @@ -0,0 +1,4524 @@ +package security_detection_rule_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/acctest" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/clients/kibana_oapi" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/elastic/terraform-provider-elasticstack/internal/versionutils" + "github.com/google/uuid" + "github.com/hashicorp/go-version" + "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/terraform" +) + +// checkResourceJSONAttr compares the JSON string value of a resource attribute +func checkResourceJSONAttr(name, key, expectedJSON string) resource.TestCheckFunc { + return func(s *terraform.State) error { + ms := s.RootModule() + rs, ok := ms.Resources[name] + if !ok { + return fmt.Errorf("Not found: %s in %s", name, ms.Path) + } + is := rs.Primary + if is == nil { + return fmt.Errorf("No primary instance: %s in %s", name, ms.Path) + } + + actualJSON, ok := is.Attributes[key] + if !ok { + return fmt.Errorf("%s: Attribute '%s' not found", name, key) + } + + if eq, err := utils.JSONBytesEqual([]byte(expectedJSON), []byte(actualJSON)); !eq { + return fmt.Errorf( + "%s: Attribute '%s' expected %#v, got %#v (: %v)", + name, + key, + expectedJSON, + actualJSON, + err) + } + return nil + } +} + +var minVersionSupport = version.Must(version.NewVersion("8.11.0")) +var minResponseActionVersionSupport = version.Must(version.NewVersion("8.16.0")) + +func TestAccResourceSecurityDetectionRule_Query(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_query("test-query-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-query-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "query"), + resource.TestCheckResourceAttr(resourceName, "query", "*:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "50"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "test-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "test-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Query Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "@timestamp"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.severity"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "85"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "event.action"), + + // Check filters field + checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"must": [{"term": {"event.category": "authentication"}}], "must_not": [{"term": {"event.outcome": "success"}}]}}]`), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "windows"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "system"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "event.type"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "host.os.type"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.severity_level"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "critical"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "critical"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM processes WHERE name = 'malicious.exe';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "300"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.name", "name"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.pid", "pid"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "isolate"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Isolate host due to suspicious activity"), + + // Check alert suppression + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "2"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.1", "host.name"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "5m"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "suppress"), + + // Verify building_block_type is not set by default + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_queryUpdate("test-query-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-query-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "updated-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Query Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.ingested"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.risk_level"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "critical"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "95"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "event.action"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "source.ip"), + + // Check filters field (updated values) + checkResourceJSONAttr(resourceName, "filters", `[{"range": {"@timestamp": {"gte": "now-1h", "lte": "now"}}}, {"terms": {"event.action": ["login", "logout", "access"]}}]`), + + // Check related integrations (updated values) + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "2"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "linux"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "auditd"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.1.package", "network"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.1.version", "1.5.0"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations.1.integration"), + + // Check required fields (updated values) + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "event.category"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "process.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.2.name", "custom.field"), + resource.TestCheckResourceAttr(resourceName, "required_fields.2.type", "text"), + resource.TestCheckResourceAttr(resourceName, "required_fields.2.ecs", "false"), + + // Check severity mapping (updated values) + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "2"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "alert.severity"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "high"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "high"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.1.field", "alert.severity"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.1.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.1.value", "medium"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.1.severity", "medium"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.pack_id", "incident_response_pack"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "600"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.host.name", "hostname"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.name", "process_name"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.id", "query1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.query", "SELECT * FROM logged_in_users;"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.platform", "linux"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.version", "4.6.0"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.1.id", "query2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.1.query", "SELECT * FROM processes WHERE state = 'R';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.1.platform", "linux"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.1.version", "4.6.0"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.1.ecs_mapping.process.pid", "pid"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.1.ecs_mapping.process.command_line", "cmdline"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "kill-process"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Kill suspicious process identified during investigation"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.config.field", "process.entity_id"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.config.overwrite", "true"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_queryRemoveFilters("test-query-rule-no-filters"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-query-rule-no-filters"), + resource.TestCheckResourceAttr(resourceName, "description", "Test query rule with filters removed"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + + // Verify filters field is not present when not specified + resource.TestCheckNoResourceAttr(resourceName, "filters"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_EQL(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_eql("test-eql-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-eql-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "eql"), + resource.TestCheckResourceAttr(resourceName, "query", "process where process.name == \"cmd.exe\""), + resource.TestCheckResourceAttr(resourceName, "language", "eql"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test EQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "70"), + resource.TestCheckResourceAttr(resourceName, "index.0", "winlogbeat-*"), + resource.TestCheckResourceAttr(resourceName, "tiebreaker_field", "@timestamp"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "eql-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "eql-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom EQL Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "process.start"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "process.executable"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "C:\\Windows\\System32\\cmd.exe"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "75"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "process.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "process.executable"), + + // Check filters field + checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"filter": [{"term": {"process.parent.name": "explorer.exe"}}]}}]`), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "windows"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "system"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "process.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "event.type"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.severity_level"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "high"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "high"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.saved_query_id", "suspicious_processes"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "300"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_eqlUpdate("test-eql-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-eql-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "process where process.name == \"powershell.exe\""), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test EQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "critical"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "90"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom EQL Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "process.end"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "process.parent.name"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "cmd.exe"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "95"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "process.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "process.executable"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "process.parent.name"), + + // Check filters field (updated values) + checkResourceJSONAttr(resourceName, "filters", `[{"exists": {"field": "process.code_signature.trusted"}}, {"term": {"host.os.family": "windows"}}]`), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "windows"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "system"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "process.parent.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "event.category"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.severity_level"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "critical"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "critical"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.pack_id", "eql_response_pack"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "450"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.executable", "executable_path"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.parent.name", "parent_name"), + + // Check alert suppression (updated values) + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "2"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "process.parent.name"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.1", "host.name"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "45m"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "doNotSuppress"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_ESQL(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_esql("test-esql-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-esql-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "esql"), + resource.TestCheckResourceAttr(resourceName, "query", "FROM logs-* | WHERE event.action == \"login\" | STATS count(*) BY user.name"), + resource.TestCheckResourceAttr(resourceName, "language", "esql"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test ESQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "60"), + resource.TestCheckResourceAttr(resourceName, "namespace", "esql-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom ESQL Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.created"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "user.domain"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "admin"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "80"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "user.domain"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "system"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "auth"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "user.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "event.action"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "user.domain"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "admin"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "high"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM users WHERE username LIKE '%admin%';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "400"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.domain", "domain"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "isolate"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Isolate host due to suspicious admin activity"), + + // Check alert suppression + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "2"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.1", "user.domain"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "15m"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "doNotSuppress"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_esqlUpdate("test-esql-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-esql-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "FROM logs-* | WHERE event.action == \"logout\" | STATS count(*) BY user.name, source.ip"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test ESQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "80"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom ESQL Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.start"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "failure"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "95"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "user.domain"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "event.outcome"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "system"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "auth"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "user.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "failure"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "critical"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.saved_query_id", "failed_login_investigation"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "500"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.outcome", "outcome"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.source.ip", "source_ip"), + + resource.TestCheckResourceAttr(resourceName, "exceptions_list.#", "1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.id", "esql-exception-1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.list_id", "esql-rule-exceptions"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.namespace_type", "single"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.type", "detection"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_MachineLearning(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_machineLearning("test-ml-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-ml-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "machine_learning"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test ML security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "critical"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "90"), + resource.TestCheckResourceAttr(resourceName, "anomaly_threshold", "75"), + resource.TestCheckResourceAttr(resourceName, "machine_learning_job_id.0", "test-ml-job"), + + resource.TestCheckResourceAttr(resourceName, "namespace", "ml-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom ML Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "ml.job_id"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "ml.anomaly_score"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "critical"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "100"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "ml.anomaly_score"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "ml.job_id"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "ml"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "anomaly_detection"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "ml.anomaly_score"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "double"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "false"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "ml.job_id"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "false"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "ml.anomaly_score"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "critical"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "critical"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM processes WHERE pid IN (SELECT DISTINCT pid FROM connections WHERE remote_address NOT LIKE '10.%' AND remote_address NOT LIKE '192.168.%' AND remote_address NOT LIKE '127.%');"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "600"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.pid", "pid"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.process.name", "name"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.ml.anomaly_score", "anomaly_score"), + + // Check alert suppression + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "1"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "ml.job_id"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "30m"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "suppress"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_machineLearningUpdate("test-ml-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-ml-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test ML security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "85"), + resource.TestCheckResourceAttr(resourceName, "anomaly_threshold", "80"), + resource.TestCheckResourceAttr(resourceName, "machine_learning_job_id.0", "test-ml-job"), + resource.TestCheckResourceAttr(resourceName, "machine_learning_job_id.1", "test-ml-job-2"), + + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom ML Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "ml.anomaly_score"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "ml.is_anomaly"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "true"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "95"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "ml.anomaly_score"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "ml.job_id"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "ml.is_anomaly"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "ml"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "anomaly_detection"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "ml.is_anomaly"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "boolean"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "false"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "ml.job_id"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "false"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "ml.is_anomaly"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "true"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "high"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.pack_id", "ml_anomaly_investigation"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "700"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.ml.job_id", "job_id"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.ml.is_anomaly", "is_anomaly"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.host.name", "hostname"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.id", "ml_query1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.query", "SELECT * FROM system_info;"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.platform", "linux"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.version", "4.7.0"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "isolate"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Collect process tree for ML anomaly investigation"), + + resource.TestCheckResourceAttr(resourceName, "exceptions_list.#", "1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.id", "ml-exception-1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.list_id", "ml-rule-exceptions"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.namespace_type", "agnostic"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.type", "detection"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_NewTerms(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_newTerms("test-new-terms-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-new-terms-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "new_terms"), + resource.TestCheckResourceAttr(resourceName, "query", "user.name:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test new terms security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "50"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "new_terms_fields.0", "user.name"), + + // Check filters field + checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"should": [{"wildcard": {"user.domain": "*.internal"}}, {"term": {"user.type": "service_account"}}]}}]`), + + resource.TestCheckResourceAttr(resourceName, "history_window_start", "now-14d"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "new-terms-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "new-terms-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom New Terms Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "user.created"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "user.type"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "service_account"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "65"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "user.type"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "security"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "users"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "user.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "user.type"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "false"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "user.type"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "service_account"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "medium"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM last WHERE username = '{{user.name}}';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "350"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.type", "user_type"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.host.name", "hostname"), + + // Check alert suppression + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "2"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.1", "user.type"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "20m"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "doNotSuppress"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_newTermsUpdate("test-new-terms-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-new-terms-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "user.name:* AND source.ip:*"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test new terms security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "index.1", "audit-*"), + resource.TestCheckResourceAttr(resourceName, "new_terms_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "new_terms_fields.1", "source.ip"), + + // Check filters field (updated values) + checkResourceJSONAttr(resourceName, "filters", `[{"geo_distance": {"distance": "1000km", "source.geo.location": {"lat": 40.12, "lon": -71.34}}}]`), + + resource.TestCheckResourceAttr(resourceName, "history_window_start", "now-30d"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom New Terms Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "user.last_login"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "2"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "user.roles"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "admin"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "95"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.1.field", "source.geo.country_name"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.1.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.1.value", "CN"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.1.risk_score", "85"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "4"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "user.type"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.3", "user.roles"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.saved_query_id", "admin_user_investigation"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "800"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.roles", "roles"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.source.ip", "source_ip"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "isolate"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Isolate host due to new admin user activity from suspicious IP"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_SavedQuery(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_savedQuery("test-saved-query-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-saved-query-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "saved_query"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test saved query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "30"), + resource.TestCheckResourceAttr(resourceName, "saved_id", "test-saved-query-id"), + + // Check filters field + checkResourceJSONAttr(resourceName, "filters", `[{"prefix": {"event.action": "user_"}}]`), + + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "saved-query-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "saved-query-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Saved Query Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.start"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.category"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "authentication"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "45"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "event.category"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "event.action"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "system"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "logs"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "event.category"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "event.action"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.category"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "authentication"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "low"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM logged_in_users WHERE user = '{{user.name}}';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "250"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.category", "category"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.action", "action"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + + // Check alert suppression + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "2"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "event.category"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.1", "event.action"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "8h"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "suppress"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_savedQueryUpdate("test-saved-query-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-saved-query-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "event.action:*"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test saved query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "60"), + resource.TestCheckResourceAttr(resourceName, "saved_id", "test-saved-query-id-updated"), + + // Check filters field (updated values) + checkResourceJSONAttr(resourceName, "filters", `[{"script": {"script": {"source": "doc['event.severity'].value > 2"}}}]`), + + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "index.1", "audit-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-saved-query-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "updated-saved-query-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Saved Query Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.end"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.type"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "access"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "70"), + + // Check investigation fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "host.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "process.name"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "system"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "logs"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "event.type"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "host.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.type"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "access"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "medium"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.pack_id", "access_investigation_pack"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "400"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.type", "type"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.host.name", "hostname"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.id", "access_query1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.query", "SELECT * FROM users WHERE username = '{{user.name}}';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.platform", "linux"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.version", "4.8.0"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.ecs_mapping.user.id", "uid"), + + resource.TestCheckResourceAttr(resourceName, "exceptions_list.#", "1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.id", "saved-query-exception-1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.list_id", "saved-query-exceptions"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.namespace_type", "agnostic"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.type", "detection"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_ThreatMatch(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_threatMatch("test-threat-match-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threat-match-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "threat_match"), + resource.TestCheckResourceAttr(resourceName, "query", "destination.ip:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test threat match security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "80"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "threat-match-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "threat-match-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Threat Match Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "threat.indicator.first_seen"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + resource.TestCheckResourceAttr(resourceName, "threat_index.0", "threat-intel-*"), + resource.TestCheckResourceAttr(resourceName, "threat_query", "threat.indicator.type:ip"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.field", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.type", "mapping"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.value", "threat.indicator.ip"), + + // Check filters field + checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"must_not": [{"term": {"destination.ip": "127.0.0.1"}}]}}]`), + + // Check investigation_fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "source.ip"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "threat_intel"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "indicators"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "threat.indicator.ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "threat.indicator.confidence"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "high"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "high"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "threat.indicator.confidence"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "85"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM listening_ports WHERE address = '{{destination.ip}}';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "300"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.destination.ip", "dest_ip"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.threat.indicator.ip", "threat_ip"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.threat.indicator.confidence", "confidence"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "isolate"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Isolate host due to threat match on destination IP"), + + // Check alert suppression + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.#", "2"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.0", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.group_by.1", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "1h"), + resource.TestCheckResourceAttr(resourceName, "alert_suppression.missing_fields_strategy", "doNotSuppress"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_threatMatchUpdate("test-threat-match-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threat-match-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "destination.ip:* OR source.ip:*"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test threat match security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "critical"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "95"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "index.1", "network-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-threat-match-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "updated-threat-match-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Threat Match Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "threat.indicator.last_seen"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + resource.TestCheckResourceAttr(resourceName, "threat_index.0", "threat-intel-*"), + resource.TestCheckResourceAttr(resourceName, "threat_index.1", "ioc-*"), + resource.TestCheckResourceAttr(resourceName, "threat_query", "threat.indicator.type:(ip OR domain)"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.field", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.1.entries.0.field", "source.ip"), + + // Check filters field (updated values) + checkResourceJSONAttr(resourceName, "filters", `[{"regexp": {"destination.domain": ".*\\.suspicious\\.com"}}]`), + + // Check investigation_fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "threat.indicator.type"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "threat_intel"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "indicators"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.2.name", "threat.indicator.ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.2.type", "ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.2.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "threat.indicator.confidence"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "high"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "critical"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "threat.indicator.confidence"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "100"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.saved_query_id", "threat_intel_investigation"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "450"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.source.ip", "src_ip"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.destination.ip", "dest_ip"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.threat.indicator.type", "threat_type"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "kill-process"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Kill processes communicating with known threat indicators"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.config.field", "process.entity_id"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.config.overwrite", "true"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_Threshold(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_threshold("test-threshold-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threshold-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "threshold"), + resource.TestCheckResourceAttr(resourceName, "query", "event.action:login"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test threshold security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "threshold-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "threshold-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Threshold Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.created"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "false"), + resource.TestCheckResourceAttr(resourceName, "threshold.value", "10"), + resource.TestCheckResourceAttr(resourceName, "threshold.field.0", "user.name"), + + // Check filters field + checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"filter": [{"range": {"event.ingested": {"gte": "now-24h"}}}]}}]`), + + // Check investigation_fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "event.action"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "system"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "1.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "auth"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "event.action"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "user.name"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "success"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "medium"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "success"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "45"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.query", "SELECT * FROM logged_in_users WHERE user = '{{user.name}}' ORDER BY time DESC LIMIT 10;"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "200"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.action", "action"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.outcome", "outcome"), + + // Check alert suppression (threshold rules only support duration) + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "30m"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_thresholdUpdate("test-threshold-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threshold-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "event.action:(login OR logout)"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test threshold security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "index.1", "audit-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-threshold-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "updated-threshold-namespace"), + resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Threshold Rule Name"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.start"), + resource.TestCheckResourceAttr(resourceName, "timestamp_override_fallback_disabled", "true"), + resource.TestCheckResourceAttr(resourceName, "threshold.value", "20"), + resource.TestCheckResourceAttr(resourceName, "threshold.field.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "threshold.field.1", "source.ip"), + + // Check filters field (updated values) + checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"should": [{"match": {"user.roles": "admin"}}, {"term": {"event.severity": "high"}}], "minimum_should_match": 1}}]`), + + // Check investigation_fields + resource.TestCheckResourceAttr(resourceName, "investigation_fields.#", "3"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.1", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "investigation_fields.2", "event.outcome"), + + // Check related integrations + resource.TestCheckResourceAttr(resourceName, "related_integrations.#", "1"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.package", "system"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.version", "2.0.0"), + resource.TestCheckResourceAttr(resourceName, "related_integrations.0.integration", "auth"), + + // Check required fields + resource.TestCheckResourceAttr(resourceName, "required_fields.#", "2"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.name", "event.action"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.type", "keyword"), + resource.TestCheckResourceAttr(resourceName, "required_fields.0.ecs", "true"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.name", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.type", "ip"), + resource.TestCheckResourceAttr(resourceName, "required_fields.1.ecs", "true"), + + // Check severity mapping + resource.TestCheckResourceAttr(resourceName, "severity_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.field", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.value", "failure"), + resource.TestCheckResourceAttr(resourceName, "severity_mapping.0.severity", "high"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "event.outcome"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "failure"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "90"), + + // Check response actions + resource.TestCheckResourceAttr(resourceName, "response_actions.#", "2"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.action_type_id", ".osquery"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.pack_id", "login_failure_investigation"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.timeout", "350"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.event.outcome", "outcome"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.source.ip", "source_ip"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.ecs_mapping.user.name", "username"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.#", "1"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.id", "failed_login_query"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.query", "SELECT * FROM last WHERE type = 7 AND username = '{{user.name}}';"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.platform", "linux"), + resource.TestCheckResourceAttr(resourceName, "response_actions.0.params.queries.0.version", "4.9.0"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.action_type_id", ".endpoint"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.command", "isolate"), + resource.TestCheckResourceAttr(resourceName, "response_actions.1.params.comment", "Isolate host due to multiple failed login attempts"), + + // Check updated alert suppression (threshold rules only support duration) + resource.TestCheckResourceAttr(resourceName, "alert_suppression.duration", "45h"), + ), + }, + }, + }) +} + +func testAccCheckSecurityDetectionRuleDestroy(s *terraform.State) error { + client, err := clients.NewAcceptanceTestingClient() + if err != nil { + return err + } + + kbClient, err := client.GetKibanaOapiClient() + if err != nil { + return err + } + + for _, rs := range s.RootModule().Resources { + switch rs.Type { + case "elasticstack_kibana_security_detection_rule": + // Parse ID to get space_id and rule_id + parts := strings.Split(rs.Primary.ID, "/") + if len(parts) != 2 { + return fmt.Errorf("invalid resource ID format: %s", rs.Primary.ID) + } + ruleId := parts[1] + + // Check if the rule still exists + ruleObjectId := kbapi.SecurityDetectionsAPIRuleObjectId(uuid.MustParse(ruleId)) + params := &kbapi.ReadRuleParams{ + Id: &ruleObjectId, + } + + response, err := kbClient.API.ReadRuleWithResponse(context.Background(), params) + if err != nil { + return fmt.Errorf("failed to read security detection rule: %v", err) + } + + // If the rule still exists (status 200), it means destroy failed + if response.StatusCode() == 200 { + return fmt.Errorf("security detection rule (%s) still exists", ruleId) + } + + // If we get a 404, that's expected - the rule was properly destroyed + // Any other status code indicates an error + if response.StatusCode() != 404 { + return fmt.Errorf("unexpected status code when checking security detection rule: %d", response.StatusCode()) + } + + case "elasticstack_kibana_action_connector": + // Parse ID to get space_id and connector_id + compId, _ := clients.CompositeIdFromStr(rs.Primary.ID) + + // Get connector client from the Kibana OAPI client + oapiClient, err := client.GetKibanaOapiClient() + if err != nil { + return err + } + + connector, diags := kibana_oapi.GetConnector(context.Background(), oapiClient, compId.ResourceId, compId.ClusterId) + if diags.HasError() { + return fmt.Errorf("failed to get connector: %v", diags) + } + + if connector != nil { + return fmt.Errorf("action connector (%s) still exists", compId.ResourceId) + } + } + } + + return nil +} + +func testAccSecurityDetectionRuleConfig_query(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Test query security detection rule" + severity = "medium" + risk_score = 50 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "test-data-view-id" + namespace = "test-namespace" + rule_name_override = "Custom Query Rule Name" + timestamp_override = "@timestamp" + timestamp_override_fallback_disabled = true + + filters = jsonencode([ + { + "bool" = { + "must" = [ + { + "term" = { + "event.category" = "authentication" + } + } + ] + "must_not" = [ + { + "term" = { + "event.outcome" = "success" + } + } + ] + } + } + ]) + + investigation_fields = ["user.name", "event.action"] + + risk_score_mapping = [ + { + field = "event.severity" + operator = "equals" + value = "high" + risk_score = 85 + } + ] + + related_integrations = [ + { + package = "windows" + version = "1.0.0" + integration = "system" + } + ] + + required_fields = [ + { + name = "event.type" + type = "keyword" + }, + { + name = "host.os.type" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.severity_level" + operator = "equals" + value = "critical" + severity = "critical" + } + ] + + alert_suppression = { + group_by = ["user.name", "host.name"] + duration = "5m" + missing_fields_strategy = "suppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM processes WHERE name = 'malicious.exe';" + timeout = 300 + ecs_mapping = { + "process.name" = "name" + "process.pid" = "pid" + } + } + }, + { + action_type_id = ".endpoint" + params = { + command = "isolate" + comment = "Isolate host due to suspicious activity" + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_queryUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Updated test query security detection rule" + severity = "high" + risk_score = 75 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + author = ["Test Author"] + tags = ["test", "automation"] + license = "Elastic License v2" + data_view_id = "updated-data-view-id" + namespace = "updated-namespace" + rule_name_override = "Updated Custom Query Rule Name" + timestamp_override = "event.ingested" + timestamp_override_fallback_disabled = false + + filters = jsonencode([ + { + "range" = { + "@timestamp" = { + "gte" = "now-1h" + "lte" = "now" + } + } + }, + { + "terms" = { + "event.action" = ["login", "logout", "access"] + } + } + ]) + + investigation_fields = ["user.name", "event.action", "source.ip"] + + risk_score_mapping = [ + { + field = "event.risk_level" + operator = "equals" + value = "critical" + risk_score = 95 + } + ] + + related_integrations = [ + { + package = "linux" + version = "2.0.0" + integration = "auditd" + }, + { + package = "network" + version = "1.5.0" + } + ] + + required_fields = [ + { + name = "event.category" + type = "keyword" + }, + { + name = "process.name" + type = "keyword" + }, + { + name = "custom.field" + type = "text" + } + ] + + severity_mapping = [ + { + field = "alert.severity" + operator = "equals" + value = "high" + severity = "high" + }, + { + field = "alert.severity" + operator = "equals" + value = "medium" + severity = "medium" + } + ] + + response_actions = [ + { + action_type_id = ".osquery" + params = { + pack_id = "incident_response_pack" + timeout = 600 + ecs_mapping = { + "host.name" = "hostname" + "user.name" = "username" + "process.name" = "process_name" + } + queries = [ + { + id = "query1" + query = "SELECT * FROM logged_in_users;" + platform = "linux" + version = "4.6.0" + }, + { + id = "query2" + query = "SELECT * FROM processes WHERE state = 'R';" + platform = "linux" + version = "4.6.0" + ecs_mapping = { + "process.pid" = "pid" + "process.command_line" = "cmdline" + } + } + ] + } + }, + { + action_type_id = ".endpoint" + params = { + command = "kill-process" + comment = "Kill suspicious process identified during investigation" + config = { + field = "process.entity_id" + overwrite = true + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_eql(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "eql" + query = "process where process.name == \"cmd.exe\"" + language = "eql" + enabled = true + description = "Test EQL security detection rule" + severity = "high" + risk_score = 70 + from = "now-6m" + to = "now" + interval = "5m" + index = ["winlogbeat-*"] + tiebreaker_field = "@timestamp" + data_view_id = "eql-data-view-id" + namespace = "eql-namespace" + rule_name_override = "Custom EQL Rule Name" + timestamp_override = "process.start" + timestamp_override_fallback_disabled = false + + filters = jsonencode([ + { + "bool" = { + "filter" = [ + { + "term" = { + "process.parent.name" = "explorer.exe" + } + } + ] + } + } + ]) + + investigation_fields = ["process.name", "process.executable"] + + risk_score_mapping = [ + { + field = "process.executable" + operator = "equals" + value = "C:\\Windows\\System32\\cmd.exe" + risk_score = 75 + } + ] + + related_integrations = [ + { + package = "windows" + version = "1.0.0" + integration = "system" + } + ] + + required_fields = [ + { + name = "process.name" + type = "keyword" + }, + { + name = "event.type" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.severity_level" + operator = "equals" + value = "high" + severity = "high" + } + ] + + alert_suppression = { + group_by = ["process.name", "user.name"] + duration = "10m" + missing_fields_strategy = "suppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + saved_query_id = "suspicious_processes" + timeout = 300 + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_eqlUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "eql" + query = "process where process.name == \"powershell.exe\"" + language = "eql" + enabled = true + description = "Updated test EQL security detection rule" + severity = "critical" + risk_score = 90 + from = "now-6m" + to = "now" + interval = "5m" + index = ["winlogbeat-*"] + tiebreaker_field = "@timestamp" + author = ["Test Author"] + tags = ["test", "eql", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom EQL Rule Name" + timestamp_override = "process.end" + timestamp_override_fallback_disabled = true + + filters = jsonencode([ + { + "exists" = { + "field" = "process.code_signature.trusted" + } + }, + { + "term" = { + "host.os.family" = "windows" + } + } + ]) + + investigation_fields = ["process.name", "process.executable", "process.parent.name"] + + risk_score_mapping = [ + { + field = "process.parent.name" + operator = "equals" + value = "cmd.exe" + risk_score = 95 + } + ] + + related_integrations = [ + { + package = "windows" + version = "2.0.0" + integration = "system" + } + ] + + required_fields = [ + { + name = "process.parent.name" + type = "keyword" + }, + { + name = "event.category" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.severity_level" + operator = "equals" + value = "critical" + severity = "critical" + } + ] + + alert_suppression = { + group_by = ["process.parent.name", "host.name"] + duration = "45m" + missing_fields_strategy = "doNotSuppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + pack_id = "eql_response_pack" + timeout = 450 + ecs_mapping = { + "process.executable" = "executable_path" + "process.parent.name" = "parent_name" + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_esql(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "esql" + query = "FROM logs-* | WHERE event.action == \"login\" | STATS count(*) BY user.name" + language = "esql" + enabled = true + description = "Test ESQL security detection rule" + severity = "medium" + risk_score = 60 + from = "now-6m" + to = "now" + interval = "5m" + namespace = "esql-namespace" + rule_name_override = "Custom ESQL Rule Name" + timestamp_override = "event.created" + timestamp_override_fallback_disabled = true + + investigation_fields = ["user.name", "user.domain"] + + risk_score_mapping = [ + { + field = "user.domain" + operator = "equals" + value = "admin" + risk_score = 80 + } + ] + + related_integrations = [ + { + package = "system" + version = "1.0.0" + integration = "auth" + } + ] + + required_fields = [ + { + name = "user.name" + type = "keyword" + }, + { + name = "event.action" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "user.domain" + operator = "equals" + value = "admin" + severity = "high" + } + ] + + alert_suppression = { + group_by = ["user.name", "user.domain"] + duration = "15m" + missing_fields_strategy = "doNotSuppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM users WHERE username LIKE '%%admin%%';" + timeout = 400 + ecs_mapping = { + "user.name" = "username" + "user.domain" = "domain" + } + } + }, + { + action_type_id = ".endpoint" + params = { + command = "isolate" + comment = "Isolate host due to suspicious admin activity" + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_esqlUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "esql" + query = "FROM logs-* | WHERE event.action == \"logout\" | STATS count(*) BY user.name, source.ip" + language = "esql" + enabled = true + description = "Updated test ESQL security detection rule" + severity = "high" + risk_score = 80 + from = "now-6m" + to = "now" + interval = "5m" + author = ["Test Author"] + tags = ["test", "esql", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom ESQL Rule Name" + timestamp_override = "event.start" + timestamp_override_fallback_disabled = false + + investigation_fields = ["user.name", "user.domain", "event.outcome"] + + risk_score_mapping = [ + { + field = "event.outcome" + operator = "equals" + value = "failure" + risk_score = 95 + } + ] + + related_integrations = [ + { + package = "system" + version = "2.0.0" + integration = "auth" + } + ] + + required_fields = [ + { + name = "user.name" + type = "keyword" + }, + { + name = "event.outcome" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.outcome" + operator = "equals" + value = "failure" + severity = "critical" + } + ] + + response_actions = [ + { + action_type_id = ".osquery" + params = { + saved_query_id = "failed_login_investigation" + timeout = 500 + ecs_mapping = { + "event.outcome" = "outcome" + "user.name" = "username" + "source.ip" = "source_ip" + } + } + } + ] + + exceptions_list = [ + { + id = "esql-exception-1" + list_id = "esql-rule-exceptions" + namespace_type = "single" + type = "detection" + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_machineLearning(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "machine_learning" + enabled = true + description = "Test ML security detection rule" + severity = "critical" + risk_score = 90 + from = "now-6m" + to = "now" + interval = "5m" + anomaly_threshold = 75 + machine_learning_job_id = ["test-ml-job"] + namespace = "ml-namespace" + rule_name_override = "Custom ML Rule Name" + timestamp_override = "ml.job_id" + timestamp_override_fallback_disabled = false + + investigation_fields = ["ml.anomaly_score", "ml.job_id"] + + risk_score_mapping = [ + { + field = "ml.anomaly_score" + operator = "equals" + value = "critical" + risk_score = 100 + } + ] + + related_integrations = [ + { + package = "ml" + version = "1.0.0" + integration = "anomaly_detection" + } + ] + + required_fields = [ + { + name = "ml.anomaly_score" + type = "double" + }, + { + name = "ml.job_id" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "ml.anomaly_score" + operator = "equals" + value = "critical" + severity = "critical" + } + ] + + alert_suppression = { + group_by = ["ml.job_id"] + duration = "30m" + missing_fields_strategy = "suppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM processes WHERE pid IN (SELECT DISTINCT pid FROM connections WHERE remote_address NOT LIKE '10.%%' AND remote_address NOT LIKE '192.168.%%' AND remote_address NOT LIKE '127.%%');" + timeout = 600 + ecs_mapping = { + "process.pid" = "pid" + "process.name" = "name" + "ml.anomaly_score" = "anomaly_score" + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_machineLearningUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "machine_learning" + enabled = true + description = "Updated test ML security detection rule" + severity = "high" + risk_score = 85 + from = "now-6m" + to = "now" + interval = "5m" + anomaly_threshold = 80 + machine_learning_job_id = ["test-ml-job", "test-ml-job-2"] + author = ["Test Author"] + tags = ["test", "ml", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom ML Rule Name" + timestamp_override = "ml.anomaly_score" + timestamp_override_fallback_disabled = true + + investigation_fields = ["ml.anomaly_score", "ml.job_id", "ml.is_anomaly"] + + risk_score_mapping = [ + { + field = "ml.is_anomaly" + operator = "equals" + value = "true" + risk_score = 95 + } + ] + + related_integrations = [ + { + package = "ml" + version = "2.0.0" + integration = "anomaly_detection" + } + ] + + required_fields = [ + { + name = "ml.is_anomaly" + type = "boolean" + }, + { + name = "ml.job_id" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "ml.is_anomaly" + operator = "equals" + value = "true" + severity = "high" + } + ] + + response_actions = [ + { + action_type_id = ".osquery" + params = { + pack_id = "ml_anomaly_investigation" + timeout = 700 + ecs_mapping = { + "ml.job_id" = "job_id" + "ml.is_anomaly" = "is_anomaly" + "host.name" = "hostname" + } + queries = [ + { + id = "ml_query1" + query = "SELECT * FROM system_info;" + platform = "linux" + version = "4.7.0" + } + ] + } + }, + { + action_type_id = ".endpoint" + params = { + command = "isolate" + comment = "Collect process tree for ML anomaly investigation" + } + } + ] + + exceptions_list = [ + { + id = "ml-exception-1" + list_id = "ml-rule-exceptions" + namespace_type = "agnostic" + type = "detection" + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_newTerms(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "new_terms" + query = "user.name:*" + language = "kuery" + enabled = true + description = "Test new terms security detection rule" + severity = "medium" + risk_score = 50 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + new_terms_fields = ["user.name"] + history_window_start = "now-14d" + data_view_id = "new-terms-data-view-id" + namespace = "new-terms-namespace" + rule_name_override = "Custom New Terms Rule Name" + timestamp_override = "user.created" + timestamp_override_fallback_disabled = true + + filters = jsonencode([ + { + "bool" = { + "should" = [ + { + "wildcard" = { + "user.domain" = "*.internal" + } + }, + { + "term" = { + "user.type" = "service_account" + } + } + ] + } + } + ]) + + investigation_fields = ["user.name", "user.type"] + + risk_score_mapping = [ + { + field = "user.type" + operator = "equals" + value = "service_account" + risk_score = 65 + } + ] + + related_integrations = [ + { + package = "security" + version = "1.0.0" + integration = "users" + } + ] + + required_fields = [ + { + name = "user.name" + type = "keyword" + }, + { + name = "user.type" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "user.type" + operator = "equals" + value = "service_account" + severity = "medium" + } + ] + + alert_suppression = { + group_by = ["user.name", "user.type"] + duration = "20m" + missing_fields_strategy = "doNotSuppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM last WHERE username = '{{user.name}}';" + timeout = 350 + ecs_mapping = { + "user.name" = "username" + "user.type" = "user_type" + "host.name" = "hostname" + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_newTermsUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "new_terms" + query = "user.name:* AND source.ip:*" + language = "kuery" + enabled = true + description = "Updated test new terms security detection rule" + severity = "high" + risk_score = 75 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*", "audit-*"] + new_terms_fields = ["user.name", "source.ip"] + history_window_start = "now-30d" + author = ["Test Author"] + tags = ["test", "new-terms", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom New Terms Rule Name" + timestamp_override = "user.last_login" + timestamp_override_fallback_disabled = false + + filters = jsonencode([ + { + "geo_distance" = { + "distance" = "1000km" + "source.geo.location" = { + "lat" = 40.12 + "lon" = -71.34 + } + } + } + ]) + + investigation_fields = ["user.name", "user.type", "source.ip", "user.roles"] + + risk_score_mapping = [ + { + field = "user.roles" + operator = "equals" + value = "admin" + risk_score = 95 + }, + { + field = "source.geo.country_name" + operator = "equals" + value = "CN" + risk_score = 85 + } + ] + + related_integrations = [ + { + package = "security" + version = "2.0.0" + integration = "users" + } + ] + + required_fields = [ + { + name = "user.name" + type = "keyword" + }, + { + name = "source.ip" + type = "ip" + }, + { + name = "user.roles" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "user.roles" + operator = "equals" + value = "admin" + severity = "high" + } + ] + + response_actions = [ + { + action_type_id = ".osquery" + params = { + saved_query_id = "admin_user_investigation" + timeout = 800 + ecs_mapping = { + "user.roles" = "roles" + "source.ip" = "source_ip" + "user.name" = "username" + } + } + }, + { + action_type_id = ".endpoint" + params = { + command = "isolate" + comment = "Isolate host due to new admin user activity from suspicious IP" + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_savedQuery(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "saved_query" + query = "*:*" + enabled = true + description = "Test saved query security detection rule" + severity = "low" + risk_score = 30 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + saved_id = "test-saved-query-id" + data_view_id = "saved-query-data-view-id" + namespace = "saved-query-namespace" + rule_name_override = "Custom Saved Query Rule Name" + timestamp_override = "event.start" + timestamp_override_fallback_disabled = false + + filters = jsonencode([ + { + "prefix" = { + "event.action" = "user_" + } + } + ]) + + investigation_fields = ["event.category", "event.action"] + + risk_score_mapping = [ + { + field = "event.category" + operator = "equals" + value = "authentication" + risk_score = 45 + } + ] + + related_integrations = [ + { + package = "system" + version = "1.0.0" + integration = "logs" + } + ] + + required_fields = [ + { + name = "event.category" + type = "keyword" + }, + { + name = "event.action" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.category" + operator = "equals" + value = "authentication" + severity = "low" + } + ] + + alert_suppression = { + group_by = ["event.category", "event.action"] + duration = "8h" + missing_fields_strategy = "suppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM logged_in_users WHERE user = '{{user.name}}';" + timeout = 250 + ecs_mapping = { + "event.category" = "category" + "event.action" = "action" + "user.name" = "username" + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_savedQueryUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "saved_query" + query = "event.action:*" + enabled = true + description = "Updated test saved query security detection rule" + severity = "medium" + risk_score = 60 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*", "audit-*"] + saved_id = "test-saved-query-id-updated" + data_view_id = "updated-saved-query-data-view-id" + namespace = "updated-saved-query-namespace" + author = ["Test Author"] + tags = ["test", "saved-query", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom Saved Query Rule Name" + timestamp_override = "event.end" + timestamp_override_fallback_disabled = true + + filters = jsonencode([ + { + "script" = { + "script" = { + "source" = "doc['event.severity'].value > 2" + } + } + } + ]) + + investigation_fields = ["host.name", "user.name", "process.name"] + + risk_score_mapping = [ + { + field = "event.type" + operator = "equals" + value = "access" + risk_score = 70 + } + ] + + related_integrations = [ + { + package = "system" + version = "2.0.0" + integration = "logs" + } + ] + + required_fields = [ + { + name = "event.type" + type = "keyword" + }, + { + name = "host.name" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.type" + operator = "equals" + value = "access" + severity = "medium" + } + ] + + response_actions = [ + { + action_type_id = ".osquery" + params = { + pack_id = "access_investigation_pack" + timeout = 400 + ecs_mapping = { + "event.type" = "type" + "host.name" = "hostname" + "user.name" = "username" + } + queries = [ + { + id = "access_query1" + query = "SELECT * FROM users WHERE username = '{{user.name}}';" + platform = "linux" + version = "4.8.0" + ecs_mapping = { + "user.id" = "uid" + } + } + ] + } + } + ] + + exceptions_list = [ + { + id = "saved-query-exception-1" + list_id = "saved-query-exceptions" + namespace_type = "agnostic" + type = "detection" + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_threatMatch(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threat_match" + query = "destination.ip:*" + language = "kuery" + enabled = true + description = "Test threat match security detection rule" + severity = "high" + risk_score = 80 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "threat-match-data-view-id" + namespace = "threat-match-namespace" + rule_name_override = "Custom Threat Match Rule Name" + timestamp_override = "threat.indicator.first_seen" + timestamp_override_fallback_disabled = true + threat_index = ["threat-intel-*"] + threat_query = "threat.indicator.type:ip" + + filters = jsonencode([ + { + "bool" = { + "must_not" = [ + { + "term" = { + "destination.ip" = "127.0.0.1" + } + } + ] + } + } + ]) + + investigation_fields = ["destination.ip", "source.ip"] + + threat_mapping = [ + { + entries = [ + { + field = "destination.ip" + type = "mapping" + value = "threat.indicator.ip" + } + ] + } + ] + + risk_score_mapping = [ + { + field = "threat.indicator.confidence" + operator = "equals" + value = "medium" + risk_score = 85 + } + ] + + related_integrations = [ + { + package = "threat_intel" + version = "1.0.0" + integration = "indicators" + } + ] + + required_fields = [ + { + name = "destination.ip" + type = "ip" + }, + { + name = "threat.indicator.ip" + type = "ip" + } + ] + + severity_mapping = [ + { + field = "threat.indicator.confidence" + operator = "equals" + value = "high" + severity = "high" + } + ] + + alert_suppression = { + group_by = ["destination.ip", "source.ip"] + duration = "1h" + missing_fields_strategy = "doNotSuppress" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM listening_ports WHERE address = '{{destination.ip}}';" + timeout = 300 + ecs_mapping = { + "destination.ip" = "dest_ip" + "threat.indicator.ip" = "threat_ip" + "threat.indicator.confidence" = "confidence" + } + } + }, + { + action_type_id = ".endpoint" + params = { + command = "isolate" + comment = "Isolate host due to threat match on destination IP" + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_threatMatchUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threat_match" + query = "destination.ip:* OR source.ip:*" + language = "kuery" + enabled = true + description = "Updated test threat match security detection rule" + severity = "critical" + risk_score = 95 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*", "network-*"] + data_view_id = "updated-threat-match-data-view-id" + namespace = "updated-threat-match-namespace" + threat_index = ["threat-intel-*", "ioc-*"] + threat_query = "threat.indicator.type:(ip OR domain)" + author = ["Test Author"] + tags = ["test", "threat-match", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom Threat Match Rule Name" + timestamp_override = "threat.indicator.last_seen" + timestamp_override_fallback_disabled = false + + filters = jsonencode([ + { + "regexp" = { + "destination.domain" = ".*\\.suspicious\\.com" + } + } + ]) + + investigation_fields = ["destination.ip", "source.ip", "threat.indicator.type"] + + threat_mapping = [ + { + entries = [ + { + field = "destination.ip" + type = "mapping" + value = "threat.indicator.ip" + } + ] + }, + { + entries = [ + { + field = "source.ip" + type = "mapping" + value = "threat.indicator.ip" + } + ] + } + ] + + risk_score_mapping = [ + { + field = "threat.indicator.confidence" + operator = "equals" + value = "high" + risk_score = 100 + } + ] + + related_integrations = [ + { + package = "threat_intel" + version = "2.0.0" + integration = "indicators" + } + ] + + required_fields = [ + { + name = "destination.ip" + type = "ip" + }, + { + name = "source.ip" + type = "ip" + }, + { + name = "threat.indicator.ip" + type = "ip" + } + ] + + severity_mapping = [ + { + field = "threat.indicator.confidence" + operator = "equals" + value = "high" + severity = "critical" + } + ] + + response_actions = [ + { + action_type_id = ".osquery" + params = { + saved_query_id = "threat_intel_investigation" + timeout = 450 + ecs_mapping = { + "source.ip" = "src_ip" + "destination.ip" = "dest_ip" + "threat.indicator.type" = "threat_type" + } + } + }, + { + action_type_id = ".endpoint" + params = { + command = "kill-process" + comment = "Kill processes communicating with known threat indicators" + config = { + field = "process.entity_id" + overwrite = true + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_threshold(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threshold" + query = "event.action:login" + language = "kuery" + enabled = true + description = "Test threshold security detection rule" + severity = "medium" + risk_score = 55 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "threshold-data-view-id" + namespace = "threshold-namespace" + rule_name_override = "Custom Threshold Rule Name" + timestamp_override = "event.created" + timestamp_override_fallback_disabled = false + + filters = jsonencode([ + { + "bool" = { + "filter" = [ + { + "range" = { + "event.ingested" = { + "gte" = "now-24h" + } + } + } + ] + } + } + ]) + + investigation_fields = ["user.name", "event.action"] + + threshold = { + value = 10 + field = ["user.name"] + } + + risk_score_mapping = [ + { + field = "event.outcome" + operator = "equals" + value = "success" + risk_score = 45 + } + ] + + related_integrations = [ + { + package = "system" + version = "1.0.0" + integration = "auth" + } + ] + + required_fields = [ + { + name = "event.action" + type = "keyword" + }, + { + name = "user.name" + type = "keyword" + } + ] + + severity_mapping = [ + { + field = "event.outcome" + operator = "equals" + value = "success" + severity = "medium" + } + ] + + alert_suppression = { + duration = "30m" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + query = "SELECT * FROM logged_in_users WHERE user = '{{user.name}}' ORDER BY time DESC LIMIT 10;" + timeout = 200 + ecs_mapping = { + "user.name" = "username" + "event.action" = "action" + "event.outcome" = "outcome" + } + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_thresholdUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threshold" + query = "event.action:(login OR logout)" + language = "kuery" + enabled = true + description = "Updated test threshold security detection rule" + severity = "high" + risk_score = 75 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*", "audit-*"] + data_view_id = "updated-threshold-data-view-id" + namespace = "updated-threshold-namespace" + author = ["Test Author"] + tags = ["test", "threshold", "automation"] + license = "Elastic License v2" + rule_name_override = "Updated Custom Threshold Rule Name" + timestamp_override = "event.start" + timestamp_override_fallback_disabled = true + + filters = jsonencode([ + { + "bool" = { + "should" = [ + { + "match" = { + "user.roles" = "admin" + } + }, + { + "term" = { + "event.severity" = "high" + } + } + ] + "minimum_should_match" = 1 + } + } + ]) + + investigation_fields = ["user.name", "source.ip", "event.outcome"] + + threshold = { + value = 20 + field = ["user.name", "source.ip"] + } + + risk_score_mapping = [ + { + field = "event.outcome" + operator = "equals" + value = "failure" + risk_score = 90 + } + ] + + related_integrations = [ + { + package = "system" + version = "2.0.0" + integration = "auth" + } + ] + + required_fields = [ + { + name = "event.action" + type = "keyword" + }, + { + name = "source.ip" + type = "ip" + } + ] + + severity_mapping = [ + { + field = "event.outcome" + operator = "equals" + value = "failure" + severity = "high" + } + ] + + alert_suppression = { + duration = "45h" + } + + response_actions = [ + { + action_type_id = ".osquery" + params = { + pack_id = "login_failure_investigation" + timeout = 350 + ecs_mapping = { + "event.outcome" = "outcome" + "source.ip" = "source_ip" + "user.name" = "username" + } + queries = [ + { + id = "failed_login_query" + query = "SELECT * FROM last WHERE type = 7 AND username = '{{user.name}}';" + platform = "linux" + version = "4.9.0" + } + ] + } + }, + { + action_type_id = ".endpoint" + params = { + command = "isolate" + comment = "Isolate host due to multiple failed login attempts" + } + } + ] +} +`, name) +} + +func TestAccResourceSecurityDetectionRule_WithConnectorAction(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + connectorResourceName := "elasticstack_kibana_action_connector.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_withConnectorAction("test-rule-with-action"), + Check: resource.ComposeTestCheckFunc( + // Check connector attributes + resource.TestCheckResourceAttr(connectorResourceName, "name", "test connector 1"), + resource.TestCheckResourceAttr(connectorResourceName, "connector_id", "1d30b67b-f90b-4e28-87c2-137cba361509"), + resource.TestCheckResourceAttr(connectorResourceName, "connector_type_id", ".cases-webhook"), + resource.TestCheckResourceAttrSet(connectorResourceName, "config"), + resource.TestCheckResourceAttrSet(connectorResourceName, "secrets"), + + // Check security detection rule attributes + resource.TestCheckResourceAttr(resourceName, "name", "test-rule-with-action"), + resource.TestCheckResourceAttr(resourceName, "type", "query"), + resource.TestCheckResourceAttr(resourceName, "query", "user.name:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test security detection rule with connector action"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "50"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "connector-action-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "connector-action-namespace"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "user.privileged"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "true"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "75"), + + // Check action attributes + resource.TestCheckResourceAttr(resourceName, "actions.#", "1"), + resource.TestCheckResourceAttr(resourceName, "actions.0.action_type_id", ".cases-webhook"), + resource.TestCheckResourceAttr(resourceName, "actions.0.id", "1d30b67b-f90b-4e28-87c2-137cba361509"), + resource.TestCheckResourceAttr(resourceName, "actions.0.group", "default"), + resource.TestCheckResourceAttr(resourceName, "actions.0.params.message", "CRITICAL EQL Alert: PowerShell process detected"), + resource.TestCheckResourceAttr(resourceName, "actions.0.frequency.notify_when", "onActiveAlert"), + resource.TestCheckResourceAttr(resourceName, "actions.0.frequency.summary", "true"), + resource.TestCheckResourceAttr(resourceName, "actions.0.frequency.throttle", "10m"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_withConnectorActionUpdate("test-rule-with-action-updated"), + Check: resource.ComposeTestCheckFunc( + // Check updated rule attributes + resource.TestCheckResourceAttr(resourceName, "name", "test-rule-with-action-updated"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test security detection rule with connector action"), + resource.TestCheckResourceAttr(resourceName, "severity", "high"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-connector-action-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "updated-connector-action-namespace"), + resource.TestCheckResourceAttr(resourceName, "tags.#", "2"), + resource.TestCheckResourceAttr(resourceName, "tags.0", "test"), + resource.TestCheckResourceAttr(resourceName, "tags.1", "terraform"), + + // Check risk score mapping + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.#", "1"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.field", "user.privileged"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.operator", "equals"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.value", "true"), + resource.TestCheckResourceAttr(resourceName, "risk_score_mapping.0.risk_score", "95"), + + // Check updated action attributes + resource.TestCheckResourceAttr(resourceName, "actions.0.params.message", "UPDATED CRITICAL Alert: Security event detected"), + resource.TestCheckResourceAttr(resourceName, "actions.0.frequency.throttle", "5m"), + + // Check exceptions list attributes + resource.TestCheckResourceAttr(resourceName, "exceptions_list.#", "1"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.id", "test-action-exception"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.list_id", "action-rule-exceptions"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.namespace_type", "single"), + resource.TestCheckResourceAttr(resourceName, "exceptions_list.0.type", "detection"), + ), + }, + }, + }) +} + +func testAccSecurityDetectionRuleConfig_withConnectorAction(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_action_connector" "test" { + name = "test connector 1" + connector_id = "1d30b67b-f90b-4e28-87c2-137cba361509" + config = jsonencode({ + createIncidentJson = "{}" + createIncidentResponseKey = "key" + createIncidentUrl = "https://www.elastic.co/" + getIncidentResponseExternalTitleKey = "title" + getIncidentUrl = "https://www.elastic.co/" + updateIncidentJson = "{}" + updateIncidentUrl = "https://elasticsearch.com/" + viewIncidentUrl = "https://www.elastic.co/" + createIncidentMethod = "put" + }) + secrets = jsonencode({ + user = "user2" + password = "password2" + }) + connector_type_id = ".cases-webhook" +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + description = "Test security detection rule with connector action" + type = "query" + severity = "medium" + risk_score = 50 + enabled = true + query = "user.name:*" + language = "kuery" + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "connector-action-data-view-id" + namespace = "connector-action-namespace" + + risk_score_mapping = [ + { + field = "user.privileged" + operator = "equals" + value = "true" + risk_score = 75 + } + ] + + actions = [ + { + action_type_id = ".cases-webhook" + id = "${elasticstack_kibana_action_connector.test.connector_id}" + params = { + message = "CRITICAL EQL Alert: PowerShell process detected" + } + group = "default" + frequency = { + notify_when = "onActiveAlert" + summary = true + throttle = "10m" + } + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_withConnectorActionUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_action_connector" "test" { + name = "test connector 1" + connector_id = "1d30b67b-f90b-4e28-87c2-137cba361509" + config = jsonencode({ + createIncidentJson = "{}" + createIncidentResponseKey = "key" + createIncidentUrl = "https://www.elastic.co/" + getIncidentResponseExternalTitleKey = "title" + getIncidentUrl = "https://www.elastic.co/" + updateIncidentJson = "{}" + updateIncidentUrl = "https://elasticsearch.com/" + viewIncidentUrl = "https://www.elastic.co/" + createIncidentMethod = "put" + }) + secrets = jsonencode({ + user = "user2" + password = "password2" + }) + connector_type_id = ".cases-webhook" +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + description = "Updated test security detection rule with connector action" + type = "query" + severity = "high" + risk_score = 75 + enabled = true + query = "user.name:*" + language = "kuery" + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "updated-connector-action-data-view-id" + namespace = "updated-connector-action-namespace" + + tags = ["test", "terraform"] + + risk_score_mapping = [ + { + field = "user.privileged" + operator = "equals" + value = "true" + risk_score = 95 + } + ] + + actions = [ + { + action_type_id = ".cases-webhook" + id = "${elasticstack_kibana_action_connector.test.connector_id}" + params = { + message = "UPDATED CRITICAL Alert: Security event detected" + } + group = "default" + frequency = { + notify_when = "onActiveAlert" + summary = true + throttle = "5m" + } + } + ] + + exceptions_list = [ + { + id = "test-action-exception" + list_id = "action-rule-exceptions" + namespace_type = "single" + type = "detection" + } + ] +} +`, name) +} + +func TestAccResourceSecurityDetectionRule_BuildingBlockType(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_buildingBlockType("test-building-block-rule"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-building-block-rule"), + resource.TestCheckResourceAttr(resourceName, "type", "query"), + resource.TestCheckResourceAttr(resourceName, "query", "process.name:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Test building block security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "building-block-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "building-block-namespace"), + resource.TestCheckResourceAttr(resourceName, "building_block_type", "default"), + + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_buildingBlockTypeUpdate("test-building-block-rule-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-building-block-rule-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "process.name:* AND user.name:*"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated test building block security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "40"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-building-block-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "updated-building-block-namespace"), + resource.TestCheckResourceAttr(resourceName, "building_block_type", "default"), + resource.TestCheckResourceAttr(resourceName, "tags.#", "2"), + resource.TestCheckResourceAttr(resourceName, "tags.0", "building-block"), + resource.TestCheckResourceAttr(resourceName, "tags.1", "test"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minResponseActionVersionSupport), + Config: testAccSecurityDetectionRuleConfig_buildingBlockTypeRemoved("test-building-block-rule-no-type"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-building-block-rule-no-type"), + resource.TestCheckResourceAttr(resourceName, "description", "Test rule without building block type"), + resource.TestCheckResourceAttr(resourceName, "data_view_id", "no-building-block-data-view-id"), + resource.TestCheckResourceAttr(resourceName, "namespace", "no-building-block-namespace"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + }, + }) +} + +func testAccSecurityDetectionRuleConfig_buildingBlockType(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "process.name:*" + language = "kuery" + enabled = true + description = "Test building block security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "building-block-data-view-id" + namespace = "building-block-namespace" + building_block_type = "default" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_buildingBlockTypeUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "process.name:* AND user.name:*" + language = "kuery" + enabled = true + description = "Updated test building block security detection rule" + severity = "medium" + risk_score = 40 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "updated-building-block-data-view-id" + namespace = "updated-building-block-namespace" + building_block_type = "default" + author = ["Test Author"] + tags = ["building-block", "test"] + license = "Elastic License v2" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_buildingBlockTypeRemoved(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "process.name:*" + language = "kuery" + enabled = true + description = "Test rule without building block type" + severity = "medium" + risk_score = 50 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "no-building-block-data-view-id" + namespace = "no-building-block-namespace" +} +`, name) +} + +func TestAccResourceSecurityDetectionRule_QueryMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_queryMinimal("test-query-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-query-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "query"), + resource.TestCheckResourceAttr(resourceName, "query", "*:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_queryMinimalUpdate("test-query-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-query-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "type", "query"), + resource.TestCheckResourceAttr(resourceName, "query", "event.category:authentication"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "false"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "index.1", "winlogbeat-*"), + + // Verify required fields are still set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are still not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + }, + }) +} + +func testAccSecurityDetectionRuleConfig_queryMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Minimal test query security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_queryMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "event.category:authentication" + language = "kuery" + enabled = false + description = "Updated minimal test query security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + index = ["logs-*", "winlogbeat-*"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_queryRemoveFilters(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Test query rule with filters removed" + severity = "medium" + risk_score = 55 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + data_view_id = "no-filters-data-view-id" + namespace = "no-filters-namespace" + + # Note: No filters field specified - this tests removing filters from a rule +} +`, name) +} + +func TestAccResourceSecurityDetectionRule_EQLMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_eqlMinimal("test-eql-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-eql-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "eql"), + resource.TestCheckResourceAttr(resourceName, "query", "process where process.name == \"cmd.exe\""), + resource.TestCheckResourceAttr(resourceName, "language", "eql"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test EQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "winlogbeat-*"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + resource.TestCheckNoResourceAttr(resourceName, "tiebreaker_field"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_eqlMinimalUpdate("test-eql-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-eql-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "process where process.name == \"powershell.exe\""), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test EQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_ESQLMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_esqlMinimal("test-esql-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-esql-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "esql"), + resource.TestCheckResourceAttr(resourceName, "query", "FROM logs-* | WHERE event.action == \"login\" | STATS count(*) BY user.name"), + resource.TestCheckResourceAttr(resourceName, "language", "esql"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test ESQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + // Note: index is not checked for ESQL as it doesn't use index patterns + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_esqlMinimalUpdate("test-esql-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-esql-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "FROM logs-* | WHERE event.action == \"logout\" | STATS count(*) BY user.name, source.ip"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test ESQL security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_MachineLearningMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_machineLearningMinimal("test-ml-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-ml-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "machine_learning"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test ML security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "anomaly_threshold", "75"), + resource.TestCheckResourceAttr(resourceName, "machine_learning_job_id.0", "test-ml-job"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_machineLearningMinimalUpdate("test-ml-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-ml-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test ML security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "anomaly_threshold", "80"), + resource.TestCheckResourceAttr(resourceName, "machine_learning_job_id.0", "test-ml-job"), + resource.TestCheckResourceAttr(resourceName, "machine_learning_job_id.1", "test-ml-job-2"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_NewTermsMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_newTermsMinimal("test-new-terms-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-new-terms-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "new_terms"), + resource.TestCheckResourceAttr(resourceName, "query", "user.name:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test new terms security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "new_terms_fields.0", "user.name"), + resource.TestCheckResourceAttr(resourceName, "history_window_start", "now-14d"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_newTermsMinimalUpdate("test-new-terms-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-new-terms-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "host.name:*"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test new terms security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "new_terms_fields.0", "host.name"), + resource.TestCheckResourceAttr(resourceName, "history_window_start", "now-7d"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_SavedQueryMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_savedQueryMinimal("test-saved-query-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-saved-query-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "saved_query"), + resource.TestCheckResourceAttr(resourceName, "query", "*:*"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test saved query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "saved_id", "test-saved-query-id"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_savedQueryMinimalUpdate("test-saved-query-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-saved-query-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "event.category:authentication"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test saved query security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "saved_id", "test-saved-query-id-updated"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_ThreatMatchMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_threatMatchMinimal("test-threat-match-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threat-match-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "threat_match"), + resource.TestCheckResourceAttr(resourceName, "query", "destination.ip:*"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test threat match security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "threat_index.0", "threat-intel-*"), + resource.TestCheckResourceAttr(resourceName, "threat_query", "threat.indicator.type:ip"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.field", "destination.ip"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.type", "mapping"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.value", "threat.indicator.ip"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_threatMatchMinimalUpdate("test-threat-match-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threat-match-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "source.ip:*"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test threat match security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "threat_query", "threat.indicator.type:domain"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.field", "source.ip"), + resource.TestCheckResourceAttr(resourceName, "threat_mapping.0.entries.0.value", "threat.indicator.domain"), + ), + }, + }, + }) +} + +func TestAccResourceSecurityDetectionRule_ThresholdMinimal(t *testing.T) { + resourceName := "elasticstack_kibana_security_detection_rule.test" + + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + CheckDestroy: testAccCheckSecurityDetectionRuleDestroy, + Steps: []resource.TestStep{ + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_thresholdMinimal("test-threshold-rule-minimal"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threshold-rule-minimal"), + resource.TestCheckResourceAttr(resourceName, "type", "threshold"), + resource.TestCheckResourceAttr(resourceName, "query", "event.action:login"), + resource.TestCheckResourceAttr(resourceName, "language", "kuery"), + resource.TestCheckResourceAttr(resourceName, "enabled", "true"), + resource.TestCheckResourceAttr(resourceName, "description", "Minimal test threshold security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "low"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), + resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), + resource.TestCheckResourceAttr(resourceName, "threshold.value", "10"), + resource.TestCheckResourceAttr(resourceName, "threshold.field.0", "user.name"), + + // Verify only required fields are set + resource.TestCheckResourceAttrSet(resourceName, "id"), + resource.TestCheckResourceAttrSet(resourceName, "rule_id"), + resource.TestCheckResourceAttrSet(resourceName, "created_at"), + resource.TestCheckResourceAttrSet(resourceName, "created_by"), + + // Verify optional fields are not set + resource.TestCheckNoResourceAttr(resourceName, "data_view_id"), + resource.TestCheckNoResourceAttr(resourceName, "namespace"), + resource.TestCheckNoResourceAttr(resourceName, "rule_name_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override"), + resource.TestCheckNoResourceAttr(resourceName, "timestamp_override_fallback_disabled"), + resource.TestCheckNoResourceAttr(resourceName, "filters"), + resource.TestCheckNoResourceAttr(resourceName, "investigation_fields"), + resource.TestCheckNoResourceAttr(resourceName, "risk_score_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "related_integrations"), + resource.TestCheckNoResourceAttr(resourceName, "required_fields"), + resource.TestCheckNoResourceAttr(resourceName, "severity_mapping"), + resource.TestCheckNoResourceAttr(resourceName, "response_actions"), + resource.TestCheckNoResourceAttr(resourceName, "alert_suppression"), + resource.TestCheckNoResourceAttr(resourceName, "building_block_type"), + ), + }, + { + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Config: testAccSecurityDetectionRuleConfig_thresholdMinimalUpdate("test-threshold-rule-minimal-updated"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceName, "name", "test-threshold-rule-minimal-updated"), + resource.TestCheckResourceAttr(resourceName, "query", "event.action:logout"), + resource.TestCheckResourceAttr(resourceName, "description", "Updated minimal test threshold security detection rule"), + resource.TestCheckResourceAttr(resourceName, "severity", "medium"), + resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), + resource.TestCheckResourceAttr(resourceName, "threshold.value", "20"), + resource.TestCheckResourceAttr(resourceName, "threshold.field.0", "host.name"), + ), + }, + }, + }) +} + +func testAccSecurityDetectionRuleConfig_eqlMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "eql" + query = "process where process.name == \"cmd.exe\"" + language = "eql" + enabled = true + description = "Minimal test EQL security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["winlogbeat-*"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_eqlMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "eql" + query = "process where process.name == \"powershell.exe\"" + language = "eql" + enabled = true + description = "Updated minimal test EQL security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + index = ["winlogbeat-*", "sysmon-*"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_esqlMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "esql" + query = "FROM logs-* | WHERE event.action == \"login\" | STATS count(*) BY user.name" + language = "esql" + enabled = true + description = "Minimal test ESQL security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_esqlMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "esql" + query = "FROM logs-* | WHERE event.action == \"logout\" | STATS count(*) BY user.name, source.ip" + language = "esql" + enabled = false + description = "Updated minimal test ESQL security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_machineLearningMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "machine_learning" + enabled = true + description = "Minimal test ML security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + anomaly_threshold = 75 + machine_learning_job_id = ["test-ml-job"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_machineLearningMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "machine_learning" + enabled = false + description = "Updated minimal test ML security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + anomaly_threshold = 80 + machine_learning_job_id = ["test-ml-job", "test-ml-job-2"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_newTermsMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "new_terms" + query = "user.name:*" + language = "kuery" + enabled = true + description = "Minimal test new terms security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + new_terms_fields = ["user.name"] + history_window_start = "now-14d" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_newTermsMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "new_terms" + query = "host.name:*" + language = "kuery" + enabled = false + description = "Updated minimal test new terms security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + index = ["logs-*", "winlogbeat-*"] + new_terms_fields = ["host.name"] + history_window_start = "now-7d" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_savedQueryMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "saved_query" + query = "*:*" + enabled = true + description = "Minimal test saved query security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + saved_id = "test-saved-query-id" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_savedQueryMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "saved_query" + query = "event.category:authentication" + enabled = false + description = "Updated minimal test saved query security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + index = ["logs-*", "winlogbeat-*"] + saved_id = "test-saved-query-id-updated" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_threatMatchMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threat_match" + query = "destination.ip:*" + language = "kuery" + enabled = true + description = "Minimal test threat match security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + threat_index = ["threat-intel-*"] + threat_query = "threat.indicator.type:ip" + + threat_mapping = [ + { + entries = [ + { + field = "destination.ip" + type = "mapping" + value = "threat.indicator.ip" + } + ] + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_threatMatchMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threat_match" + query = "source.ip:*" + language = "kuery" + enabled = false + description = "Updated minimal test threat match security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + index = ["logs-*", "winlogbeat-*"] + threat_index = ["threat-intel-*", "misp-*"] + threat_query = "threat.indicator.type:domain" + + threat_mapping = [ + { + entries = [ + { + field = "source.ip" + type = "mapping" + value = "threat.indicator.domain" + } + ] + } + ] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_thresholdMinimal(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threshold" + query = "event.action:login" + language = "kuery" + enabled = true + description = "Minimal test threshold security detection rule" + severity = "low" + risk_score = 21 + from = "now-6m" + to = "now" + interval = "5m" + index = ["logs-*"] + + threshold = { + value = 10 + field = ["user.name"] + } +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_thresholdMinimalUpdate(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "threshold" + query = "event.action:logout" + language = "kuery" + enabled = false + description = "Updated minimal test threshold security detection rule" + severity = "medium" + risk_score = 55 + from = "now-12m" + to = "now" + interval = "10m" + index = ["logs-*", "winlogbeat-*"] + + threshold = { + value = 20 + field = ["host.name"] + } +} +`, name) +} diff --git a/internal/kibana/security_detection_rule/create.go b/internal/kibana/security_detection_rule/create.go new file mode 100644 index 000000000..7a0e0e6ee --- /dev/null +++ b/internal/kibana/security_detection_rule/create.go @@ -0,0 +1,88 @@ +package security_detection_rule + +import ( + "context" + "fmt" + + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +func (r *securityDetectionRuleResource) Create(ctx context.Context, req resource.CreateRequest, resp *resource.CreateResponse) { + var data SecurityDetectionRuleData + + resp.Diagnostics.Append(req.Plan.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return + } + + // Create the rule using kbapi client + kbClient, err := r.client.GetKibanaOapiClient() + if err != nil { + resp.Diagnostics.AddError( + "Error getting Kibana client", + "Could not get Kibana OAPI client: "+err.Error(), + ) + return + } + + // Build the create request + createProps, diags := data.toCreateProps(ctx, r.client) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + // Create the rule + response, err := kbClient.API.CreateRuleWithResponse(ctx, createProps) + if err != nil { + resp.Diagnostics.AddError( + "Error creating security detection rule", + "Could not create security detection rule: "+err.Error(), + ) + return + } + + if response.StatusCode() != 200 { + resp.Diagnostics.AddError( + "Error creating security detection rule", + fmt.Sprintf("API returned status %d: %s", response.StatusCode(), string(response.Body)), + ) + return + } + + // Parse the response to get the ID, then use Read logic for consistency + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + // Set the ID based on the created rule + id, diags := extractId(response.JSON200) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + compId := clients.CompositeId{ + ClusterId: data.SpaceId.ValueString(), + ResourceId: id, + } + data.Id = types.StringValue(compId.String()) + readData, diags := r.read(ctx, id, data.SpaceId.ValueString()) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + if readData == nil { + resp.Diagnostics.AddError( + "Error reading created security detection rule", + "Could not read security detection rule after creation", + ) + return + } + + resp.Diagnostics.Append(resp.State.Set(ctx, readData)...) +} diff --git a/internal/kibana/security_detection_rule/delete.go b/internal/kibana/security_detection_rule/delete.go new file mode 100644 index 000000000..d2e1a48bb --- /dev/null +++ b/internal/kibana/security_detection_rule/delete.go @@ -0,0 +1,70 @@ +package security_detection_rule + +import ( + "context" + "fmt" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/resource" +) + +func (r *securityDetectionRuleResource) Delete(ctx context.Context, req resource.DeleteRequest, resp *resource.DeleteResponse) { + var data SecurityDetectionRuleData + + resp.Diagnostics.Append(req.State.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return + } + + // Parse ID to get space_id and rule_id + compId, diags := clients.CompositeIdFromStrFw(data.Id.ValueString()) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + // Get the rule using kbapi client + kbClient, err := r.client.GetKibanaOapiClient() + if err != nil { + resp.Diagnostics.AddError( + "Error getting Kibana client", + "Could not get Kibana OAPI client: "+err.Error(), + ) + return + } + + // Delete the rule + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + resp.Diagnostics.AddError("ID was not a valid UUID", err.Error()) + return + } + ruleObjectId := kbapi.SecurityDetectionsAPIRuleObjectId(uid) + params := &kbapi.DeleteRuleParams{ + Id: &ruleObjectId, + } + + response, err := kbClient.API.DeleteRuleWithResponse(ctx, params) + if err != nil { + resp.Diagnostics.AddError( + "Error deleting security detection rule", + "Could not delete security detection rule: "+err.Error(), + ) + return + } + + if response.StatusCode() == 404 { + // Rule was already deleted, which is fine + return + } + + if response.StatusCode() != 200 { + resp.Diagnostics.AddError( + "Error deleting security detection rule", + fmt.Sprintf("API returned status %d: %s", response.StatusCode(), string(response.Body)), + ) + return + } +} diff --git a/internal/kibana/security_detection_rule/models.go b/internal/kibana/security_detection_rule/models.go new file mode 100644 index 000000000..f8a7791af --- /dev/null +++ b/internal/kibana/security_detection_rule/models.go @@ -0,0 +1,882 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes" + "github.com/hashicorp/go-version" + "github.com/hashicorp/terraform-plugin-framework-jsontypes/jsontypes" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +// MinVersionResponseActions defines the minimum server version required for response actions +var MinVersionResponseActions = version.Must(version.NewVersion("8.16.0")) + +type SecurityDetectionRuleData struct { + Id types.String `tfsdk:"id"` + SpaceId types.String `tfsdk:"space_id"` + RuleId types.String `tfsdk:"rule_id"` + Name types.String `tfsdk:"name"` + Type types.String `tfsdk:"type"` + Query types.String `tfsdk:"query"` + Language types.String `tfsdk:"language"` + Index types.List `tfsdk:"index"` + Enabled types.Bool `tfsdk:"enabled"` + From types.String `tfsdk:"from"` + To types.String `tfsdk:"to"` + Interval types.String `tfsdk:"interval"` + + // Rule content + Description types.String `tfsdk:"description"` + RiskScore types.Int64 `tfsdk:"risk_score"` + RiskScoreMapping types.List `tfsdk:"risk_score_mapping"` + Severity types.String `tfsdk:"severity"` + SeverityMapping types.List `tfsdk:"severity_mapping"` + Author types.List `tfsdk:"author"` + Tags types.List `tfsdk:"tags"` + License types.String `tfsdk:"license"` + RelatedIntegrations types.List `tfsdk:"related_integrations"` + RequiredFields types.List `tfsdk:"required_fields"` + + // Optional fields + FalsePositives types.List `tfsdk:"false_positives"` + References types.List `tfsdk:"references"` + Note types.String `tfsdk:"note"` + Setup types.String `tfsdk:"setup"` + MaxSignals types.Int64 `tfsdk:"max_signals"` + Version types.Int64 `tfsdk:"version"` + + // Read-only fields + CreatedAt types.String `tfsdk:"created_at"` + CreatedBy types.String `tfsdk:"created_by"` + UpdatedAt types.String `tfsdk:"updated_at"` + UpdatedBy types.String `tfsdk:"updated_by"` + Revision types.Int64 `tfsdk:"revision"` + + // EQL-specific fields + TiebreakerField types.String `tfsdk:"tiebreaker_field"` + + // Machine Learning-specific fields + AnomalyThreshold types.Int64 `tfsdk:"anomaly_threshold"` + MachineLearningJobId types.List `tfsdk:"machine_learning_job_id"` + + // New Terms-specific fields + NewTermsFields types.List `tfsdk:"new_terms_fields"` + HistoryWindowStart types.String `tfsdk:"history_window_start"` + + // Saved Query-specific fields + SavedId types.String `tfsdk:"saved_id"` + + // Threat Match-specific fields + ThreatIndex types.List `tfsdk:"threat_index"` + ThreatQuery types.String `tfsdk:"threat_query"` + ThreatMapping types.List `tfsdk:"threat_mapping"` + ThreatFilters types.List `tfsdk:"threat_filters"` + ThreatIndicatorPath types.String `tfsdk:"threat_indicator_path"` + ConcurrentSearches types.Int64 `tfsdk:"concurrent_searches"` + ItemsPerSearch types.Int64 `tfsdk:"items_per_search"` + + // Threshold-specific fields + Threshold types.Object `tfsdk:"threshold"` + + // Optional timeline fields (common across multiple rule types) + TimelineId types.String `tfsdk:"timeline_id"` + TimelineTitle types.String `tfsdk:"timeline_title"` + + // Threat field (common across multiple rule types) + Threat types.List `tfsdk:"threat"` + + // Actions field (common across all rule types) + Actions types.List `tfsdk:"actions"` + + // Response actions field (common across all rule types) + ResponseActions types.List `tfsdk:"response_actions"` + + // Exceptions list field (common across all rule types) + ExceptionsList types.List `tfsdk:"exceptions_list"` + + // Alert suppression field (common across all rule types) + AlertSuppression types.Object `tfsdk:"alert_suppression"` + + // Building block type field (common across all rule types) + BuildingBlockType types.String `tfsdk:"building_block_type"` + + // Data view ID field (common across all rule types) + DataViewId types.String `tfsdk:"data_view_id"` + + // Namespace field (common across all rule types) + Namespace types.String `tfsdk:"namespace"` + + // Rule name override field (common across all rule types) + RuleNameOverride types.String `tfsdk:"rule_name_override"` + + // Timestamp override fields (common across all rule types) + TimestampOverride types.String `tfsdk:"timestamp_override"` + TimestampOverrideFallbackDisabled types.Bool `tfsdk:"timestamp_override_fallback_disabled"` + + // Investigation fields (common across all rule types) + InvestigationFields types.List `tfsdk:"investigation_fields"` + + // Filters field (common across all rule types) - Query and filter context array to define alert conditions + Filters jsontypes.Normalized `tfsdk:"filters"` +} +type SecurityDetectionRuleTfData struct { + ThreatMapping types.List `tfsdk:"threat_mapping"` +} + +type SecurityDetectionRuleTfDataItem struct { + Entries types.List `tfsdk:"entries"` +} + +type SecurityDetectionRuleTfDataItemEntry struct { + Field types.String `tfsdk:"field"` + Type types.String `tfsdk:"type"` + Value types.String `tfsdk:"value"` +} + +type ThresholdModel struct { + Field types.List `tfsdk:"field"` + Value types.Int64 `tfsdk:"value"` + Cardinality types.List `tfsdk:"cardinality"` +} + +type AlertSuppressionModel struct { + GroupBy types.List `tfsdk:"group_by"` + Duration customtypes.Duration `tfsdk:"duration"` + MissingFieldsStrategy types.String `tfsdk:"missing_fields_strategy"` +} + +type CardinalityModel struct { + Field types.String `tfsdk:"field"` + Value types.Int64 `tfsdk:"value"` +} + +type ActionModel struct { + ActionTypeId types.String `tfsdk:"action_type_id"` + Id types.String `tfsdk:"id"` + Params types.Map `tfsdk:"params"` + Group types.String `tfsdk:"group"` + Uuid types.String `tfsdk:"uuid"` + AlertsFilter types.Map `tfsdk:"alerts_filter"` + Frequency types.Object `tfsdk:"frequency"` +} + +type ActionFrequencyModel struct { + NotifyWhen types.String `tfsdk:"notify_when"` + Summary types.Bool `tfsdk:"summary"` + Throttle types.String `tfsdk:"throttle"` +} + +type ResponseActionModel struct { + ActionTypeId types.String `tfsdk:"action_type_id"` + Params types.Object `tfsdk:"params"` +} + +type ResponseActionParamsModel struct { + // Osquery params + Query types.String `tfsdk:"query"` + PackId types.String `tfsdk:"pack_id"` + SavedQueryId types.String `tfsdk:"saved_query_id"` + Timeout types.Int64 `tfsdk:"timeout"` + EcsMapping types.Map `tfsdk:"ecs_mapping"` + Queries types.List `tfsdk:"queries"` + + // Endpoint params + Command types.String `tfsdk:"command"` + Comment types.String `tfsdk:"comment"` + Config types.Object `tfsdk:"config"` +} + +type OsqueryQueryModel struct { + Id types.String `tfsdk:"id"` + Query types.String `tfsdk:"query"` + Platform types.String `tfsdk:"platform"` + Version types.String `tfsdk:"version"` + Removed types.Bool `tfsdk:"removed"` + Snapshot types.Bool `tfsdk:"snapshot"` + EcsMapping types.Map `tfsdk:"ecs_mapping"` +} + +type EndpointProcessConfigModel struct { + Field types.String `tfsdk:"field"` + Overwrite types.Bool `tfsdk:"overwrite"` +} + +type ExceptionsListModel struct { + Id types.String `tfsdk:"id"` + ListId types.String `tfsdk:"list_id"` + NamespaceType types.String `tfsdk:"namespace_type"` + Type types.String `tfsdk:"type"` +} + +type RiskScoreMappingModel struct { + Field types.String `tfsdk:"field"` + Operator types.String `tfsdk:"operator"` + Value types.String `tfsdk:"value"` + RiskScore types.Int64 `tfsdk:"risk_score"` +} + +type RelatedIntegrationModel struct { + Package types.String `tfsdk:"package"` + Version types.String `tfsdk:"version"` + Integration types.String `tfsdk:"integration"` +} + +type RequiredFieldModel struct { + Name types.String `tfsdk:"name"` + Type types.String `tfsdk:"type"` + Ecs types.Bool `tfsdk:"ecs"` +} + +type SeverityMappingModel struct { + Field types.String `tfsdk:"field"` + Operator types.String `tfsdk:"operator"` + Value types.String `tfsdk:"value"` + Severity types.String `tfsdk:"severity"` +} + +// CommonCreateProps holds all the field pointers for setting common create properties +type CommonCreateProps struct { + Actions **[]kbapi.SecurityDetectionsAPIRuleAction + ResponseActions **[]kbapi.SecurityDetectionsAPIResponseAction + RuleId **kbapi.SecurityDetectionsAPIRuleSignatureId + Enabled **kbapi.SecurityDetectionsAPIIsRuleEnabled + From **kbapi.SecurityDetectionsAPIRuleIntervalFrom + To **kbapi.SecurityDetectionsAPIRuleIntervalTo + Interval **kbapi.SecurityDetectionsAPIRuleInterval + Index **[]string + Author **[]string + Tags **[]string + FalsePositives **[]string + References **[]string + License **kbapi.SecurityDetectionsAPIRuleLicense + Note **kbapi.SecurityDetectionsAPIInvestigationGuide + Setup **kbapi.SecurityDetectionsAPISetupGuide + MaxSignals **kbapi.SecurityDetectionsAPIMaxSignals + Version **kbapi.SecurityDetectionsAPIRuleVersion + ExceptionsList **[]kbapi.SecurityDetectionsAPIRuleExceptionList + AlertSuppression **kbapi.SecurityDetectionsAPIAlertSuppression + RiskScoreMapping **kbapi.SecurityDetectionsAPIRiskScoreMapping + SeverityMapping **kbapi.SecurityDetectionsAPISeverityMapping + RelatedIntegrations **kbapi.SecurityDetectionsAPIRelatedIntegrationArray + RequiredFields **[]kbapi.SecurityDetectionsAPIRequiredFieldInput + BuildingBlockType **kbapi.SecurityDetectionsAPIBuildingBlockType + DataViewId **kbapi.SecurityDetectionsAPIDataViewId + Namespace **kbapi.SecurityDetectionsAPIAlertsIndexNamespace + RuleNameOverride **kbapi.SecurityDetectionsAPIRuleNameOverride + TimestampOverride **kbapi.SecurityDetectionsAPITimestampOverride + TimestampOverrideFallbackDisabled **kbapi.SecurityDetectionsAPITimestampOverrideFallbackDisabled + InvestigationFields **kbapi.SecurityDetectionsAPIInvestigationFields + Filters **kbapi.SecurityDetectionsAPIRuleFilterArray +} + +// CommonUpdateProps holds all the field pointers for setting common update properties +type CommonUpdateProps struct { + Actions **[]kbapi.SecurityDetectionsAPIRuleAction + ResponseActions **[]kbapi.SecurityDetectionsAPIResponseAction + RuleId **kbapi.SecurityDetectionsAPIRuleSignatureId + Enabled **kbapi.SecurityDetectionsAPIIsRuleEnabled + From **kbapi.SecurityDetectionsAPIRuleIntervalFrom + To **kbapi.SecurityDetectionsAPIRuleIntervalTo + Interval **kbapi.SecurityDetectionsAPIRuleInterval + Index **[]string + Author **[]string + Tags **[]string + FalsePositives **[]string + References **[]string + License **kbapi.SecurityDetectionsAPIRuleLicense + Note **kbapi.SecurityDetectionsAPIInvestigationGuide + Setup **kbapi.SecurityDetectionsAPISetupGuide + MaxSignals **kbapi.SecurityDetectionsAPIMaxSignals + Version **kbapi.SecurityDetectionsAPIRuleVersion + ExceptionsList **[]kbapi.SecurityDetectionsAPIRuleExceptionList + AlertSuppression **kbapi.SecurityDetectionsAPIAlertSuppression + RiskScoreMapping **kbapi.SecurityDetectionsAPIRiskScoreMapping + SeverityMapping **kbapi.SecurityDetectionsAPISeverityMapping + RelatedIntegrations **kbapi.SecurityDetectionsAPIRelatedIntegrationArray + RequiredFields **[]kbapi.SecurityDetectionsAPIRequiredFieldInput + BuildingBlockType **kbapi.SecurityDetectionsAPIBuildingBlockType + DataViewId **kbapi.SecurityDetectionsAPIDataViewId + Namespace **kbapi.SecurityDetectionsAPIAlertsIndexNamespace + RuleNameOverride **kbapi.SecurityDetectionsAPIRuleNameOverride + TimestampOverride **kbapi.SecurityDetectionsAPITimestampOverride + TimestampOverrideFallbackDisabled **kbapi.SecurityDetectionsAPITimestampOverrideFallbackDisabled + InvestigationFields **kbapi.SecurityDetectionsAPIInvestigationFields + Filters **kbapi.SecurityDetectionsAPIRuleFilterArray +} + +// Helper function to set common properties across all rule types +func (d SecurityDetectionRuleData) setCommonCreateProps( + ctx context.Context, + props *CommonCreateProps, + diags *diag.Diagnostics, + client clients.MinVersionEnforceable, +) { + // Set optional rule_id if provided + if props.RuleId != nil && utils.IsKnown(d.RuleId) { + id := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + *props.RuleId = &id + } + + // Set enabled status + if props.Enabled != nil && utils.IsKnown(d.Enabled) { + isEnabled := kbapi.SecurityDetectionsAPIIsRuleEnabled(d.Enabled.ValueBool()) + *props.Enabled = &isEnabled + } + + // Set time range + if props.From != nil && utils.IsKnown(d.From) { + fromTime := kbapi.SecurityDetectionsAPIRuleIntervalFrom(d.From.ValueString()) + *props.From = &fromTime + } + + if props.To != nil && utils.IsKnown(d.To) { + toTime := kbapi.SecurityDetectionsAPIRuleIntervalTo(d.To.ValueString()) + *props.To = &toTime + } + + // Set interval + if props.Interval != nil && utils.IsKnown(d.Interval) { + intervalTime := kbapi.SecurityDetectionsAPIRuleInterval(d.Interval.ValueString()) + *props.Interval = &intervalTime + } + + // Set index patterns (if index pointer is provided) + if props.Index != nil && utils.IsKnown(d.Index) { + indexList := utils.ListTypeAs[string](ctx, d.Index, path.Root("index"), diags) + if !diags.HasError() && len(indexList) > 0 { + *props.Index = &indexList + } + } + + // Set author + if props.Author != nil && utils.IsKnown(d.Author) { + authorList := utils.ListTypeAs[string](ctx, d.Author, path.Root("author"), diags) + if !diags.HasError() && len(authorList) > 0 { + *props.Author = &authorList + } + } + + // Set tags + if props.Tags != nil && utils.IsKnown(d.Tags) { + tagsList := utils.ListTypeAs[string](ctx, d.Tags, path.Root("tags"), diags) + if !diags.HasError() && len(tagsList) > 0 { + *props.Tags = &tagsList + } + } + + // Set false positives + if props.FalsePositives != nil && utils.IsKnown(d.FalsePositives) { + fpList := utils.ListTypeAs[string](ctx, d.FalsePositives, path.Root("false_positives"), diags) + if !diags.HasError() && len(fpList) > 0 { + *props.FalsePositives = &fpList + } + } + + // Set references + if props.References != nil && utils.IsKnown(d.References) { + refList := utils.ListTypeAs[string](ctx, d.References, path.Root("references"), diags) + if !diags.HasError() && len(refList) > 0 { + *props.References = &refList + } + } + + // Set optional string fields + if props.License != nil && utils.IsKnown(d.License) { + ruleLicense := kbapi.SecurityDetectionsAPIRuleLicense(d.License.ValueString()) + *props.License = &ruleLicense + } + + if props.Note != nil && utils.IsKnown(d.Note) { + ruleNote := kbapi.SecurityDetectionsAPIInvestigationGuide(d.Note.ValueString()) + *props.Note = &ruleNote + } + + if props.Setup != nil && utils.IsKnown(d.Setup) { + ruleSetup := kbapi.SecurityDetectionsAPISetupGuide(d.Setup.ValueString()) + *props.Setup = &ruleSetup + } + + // Set max signals + if props.MaxSignals != nil && utils.IsKnown(d.MaxSignals) { + maxSig := kbapi.SecurityDetectionsAPIMaxSignals(d.MaxSignals.ValueInt64()) + *props.MaxSignals = &maxSig + } + + // Set version + if props.Version != nil && utils.IsKnown(d.Version) { + ruleVersion := kbapi.SecurityDetectionsAPIRuleVersion(d.Version.ValueInt64()) + *props.Version = &ruleVersion + } + + // Set actions + if props.Actions != nil && utils.IsKnown(d.Actions) { + actions, actionDiags := d.actionsToApi(ctx) + diags.Append(actionDiags...) + if !actionDiags.HasError() && len(actions) > 0 { + *props.Actions = &actions + } + } + + // Set exceptions list + if props.ExceptionsList != nil && utils.IsKnown(d.ExceptionsList) { + exceptionsList, exceptionsListDiags := d.exceptionsListToApi(ctx) + diags.Append(exceptionsListDiags...) + if !exceptionsListDiags.HasError() && len(exceptionsList) > 0 { + *props.ExceptionsList = &exceptionsList + } + } + + // Set risk score mapping + if props.RiskScoreMapping != nil && utils.IsKnown(d.RiskScoreMapping) { + riskScoreMapping, riskScoreMappingDiags := d.riskScoreMappingToApi(ctx) + diags.Append(riskScoreMappingDiags...) + if !riskScoreMappingDiags.HasError() && len(riskScoreMapping) > 0 { + *props.RiskScoreMapping = &riskScoreMapping + } + } + + // Set building block type + if props.BuildingBlockType != nil && utils.IsKnown(d.BuildingBlockType) { + buildingBlockType := kbapi.SecurityDetectionsAPIBuildingBlockType(d.BuildingBlockType.ValueString()) + *props.BuildingBlockType = &buildingBlockType + } + + // Set data view ID + if props.DataViewId != nil && utils.IsKnown(d.DataViewId) { + dataViewId := kbapi.SecurityDetectionsAPIDataViewId(d.DataViewId.ValueString()) + *props.DataViewId = &dataViewId + } + + // Set namespace + if props.Namespace != nil && utils.IsKnown(d.Namespace) { + namespace := kbapi.SecurityDetectionsAPIAlertsIndexNamespace(d.Namespace.ValueString()) + *props.Namespace = &namespace + } + + // Set rule name override + if props.RuleNameOverride != nil && utils.IsKnown(d.RuleNameOverride) { + ruleNameOverride := kbapi.SecurityDetectionsAPIRuleNameOverride(d.RuleNameOverride.ValueString()) + *props.RuleNameOverride = &ruleNameOverride + } + + // Set timestamp override + if props.TimestampOverride != nil && utils.IsKnown(d.TimestampOverride) { + timestampOverride := kbapi.SecurityDetectionsAPITimestampOverride(d.TimestampOverride.ValueString()) + *props.TimestampOverride = ×tampOverride + } + + // Set timestamp override fallback disabled + if props.TimestampOverrideFallbackDisabled != nil && utils.IsKnown(d.TimestampOverrideFallbackDisabled) { + timestampOverrideFallbackDisabled := kbapi.SecurityDetectionsAPITimestampOverrideFallbackDisabled(d.TimestampOverrideFallbackDisabled.ValueBool()) + *props.TimestampOverrideFallbackDisabled = ×tampOverrideFallbackDisabled + } + + // Set severity mapping + if props.SeverityMapping != nil && utils.IsKnown(d.SeverityMapping) { + severityMapping, severityMappingDiags := d.severityMappingToApi(ctx) + diags.Append(severityMappingDiags...) + if !severityMappingDiags.HasError() && severityMapping != nil && len(*severityMapping) > 0 { + *props.SeverityMapping = severityMapping + } + } + + // Set related integrations + if props.RelatedIntegrations != nil && utils.IsKnown(d.RelatedIntegrations) { + relatedIntegrations, relatedIntegrationsDiags := d.relatedIntegrationsToApi(ctx) + diags.Append(relatedIntegrationsDiags...) + if !relatedIntegrationsDiags.HasError() && relatedIntegrations != nil && len(*relatedIntegrations) > 0 { + *props.RelatedIntegrations = relatedIntegrations + } + } + + // Set required fields + if props.RequiredFields != nil && utils.IsKnown(d.RequiredFields) { + requiredFields, requiredFieldsDiags := d.requiredFieldsToApi(ctx) + diags.Append(requiredFieldsDiags...) + if !requiredFieldsDiags.HasError() && requiredFields != nil && len(*requiredFields) > 0 { + *props.RequiredFields = requiredFields + } + } + + // Set investigation fields + if props.InvestigationFields != nil { + investigationFields, investigationFieldsDiags := d.investigationFieldsToApi(ctx) + if !investigationFieldsDiags.HasError() && investigationFields != nil { + *props.InvestigationFields = investigationFields + } + diags.Append(investigationFieldsDiags...) + } + + // Set response actions + if props.ResponseActions != nil && utils.IsKnown(d.ResponseActions) { + responseActions, responseActionsDiags := d.responseActionsToApi(ctx, client) + diags.Append(responseActionsDiags...) + if !responseActionsDiags.HasError() && len(responseActions) > 0 { + *props.ResponseActions = &responseActions + } + } + + // Set filters + if props.Filters != nil && utils.IsKnown(d.Filters) { + filters, filtersDiags := d.filtersToApi(ctx) + diags.Append(filtersDiags...) + if !filtersDiags.HasError() && filters != nil { + *props.Filters = filters + } + } + + // Set alert suppression + if props.AlertSuppression != nil { + alertSuppression := d.alertSuppressionToApi(ctx, diags) + if alertSuppression != nil { + *props.AlertSuppression = alertSuppression + } + } +} + +// Helper function to set common update properties across all rule types +func (d SecurityDetectionRuleData) setCommonUpdateProps( + ctx context.Context, + props *CommonUpdateProps, + diags *diag.Diagnostics, + client clients.MinVersionEnforceable, +) { + // Set enabled status + if props.Enabled != nil && utils.IsKnown(d.Enabled) { + isEnabled := kbapi.SecurityDetectionsAPIIsRuleEnabled(d.Enabled.ValueBool()) + *props.Enabled = &isEnabled + } + + // Set time range + if props.From != nil && utils.IsKnown(d.From) { + fromTime := kbapi.SecurityDetectionsAPIRuleIntervalFrom(d.From.ValueString()) + *props.From = &fromTime + } + + if props.To != nil && utils.IsKnown(d.To) { + toTime := kbapi.SecurityDetectionsAPIRuleIntervalTo(d.To.ValueString()) + *props.To = &toTime + } + + // Set interval + if props.Interval != nil && utils.IsKnown(d.Interval) { + intervalTime := kbapi.SecurityDetectionsAPIRuleInterval(d.Interval.ValueString()) + *props.Interval = &intervalTime + } + + // Set index patterns (if index pointer is provided) + if props.Index != nil && utils.IsKnown(d.Index) { + indexList := utils.ListTypeAs[string](ctx, d.Index, path.Root("index"), diags) + if !diags.HasError() { + *props.Index = &indexList + } + } + + // Set author + if props.Author != nil && utils.IsKnown(d.Author) { + authorList := utils.ListTypeAs[string](ctx, d.Author, path.Root("author"), diags) + if !diags.HasError() { + *props.Author = &authorList + } + } + + // Set tags + if props.Tags != nil && utils.IsKnown(d.Tags) { + tagsList := utils.ListTypeAs[string](ctx, d.Tags, path.Root("tags"), diags) + if !diags.HasError() { + *props.Tags = &tagsList + } + } + + // Set false positives + if props.FalsePositives != nil && utils.IsKnown(d.FalsePositives) { + fpList := utils.ListTypeAs[string](ctx, d.FalsePositives, path.Root("false_positives"), diags) + if !diags.HasError() { + *props.FalsePositives = &fpList + } + } + + // Set references + if props.References != nil && utils.IsKnown(d.References) { + refList := utils.ListTypeAs[string](ctx, d.References, path.Root("references"), diags) + if !diags.HasError() { + *props.References = &refList + } + } + + // Set optional string fields + if props.License != nil && utils.IsKnown(d.License) { + ruleLicense := kbapi.SecurityDetectionsAPIRuleLicense(d.License.ValueString()) + *props.License = &ruleLicense + } + + if props.Note != nil && utils.IsKnown(d.Note) { + ruleNote := kbapi.SecurityDetectionsAPIInvestigationGuide(d.Note.ValueString()) + *props.Note = &ruleNote + } + + if props.Setup != nil && utils.IsKnown(d.Setup) { + ruleSetup := kbapi.SecurityDetectionsAPISetupGuide(d.Setup.ValueString()) + *props.Setup = &ruleSetup + } + + // Set max signals + if props.MaxSignals != nil && utils.IsKnown(d.MaxSignals) { + maxSig := kbapi.SecurityDetectionsAPIMaxSignals(d.MaxSignals.ValueInt64()) + *props.MaxSignals = &maxSig + } + + // Set version + if props.Version != nil && utils.IsKnown(d.Version) { + ruleVersion := kbapi.SecurityDetectionsAPIRuleVersion(d.Version.ValueInt64()) + *props.Version = &ruleVersion + } + + // Set actions + if props.Actions != nil && utils.IsKnown(d.Actions) { + actions, actionDiags := d.actionsToApi(ctx) + diags.Append(actionDiags...) + if !actionDiags.HasError() && len(actions) > 0 { + *props.Actions = &actions + } + } + + // Set exceptions list + if props.ExceptionsList != nil && utils.IsKnown(d.ExceptionsList) { + exceptionsList, exceptionsListDiags := d.exceptionsListToApi(ctx) + diags.Append(exceptionsListDiags...) + if !exceptionsListDiags.HasError() && len(exceptionsList) > 0 { + *props.ExceptionsList = &exceptionsList + } + } + + // Set risk score mapping + if props.RiskScoreMapping != nil && utils.IsKnown(d.RiskScoreMapping) { + riskScoreMapping, riskScoreMappingDiags := d.riskScoreMappingToApi(ctx) + diags.Append(riskScoreMappingDiags...) + if !riskScoreMappingDiags.HasError() && len(riskScoreMapping) > 0 { + *props.RiskScoreMapping = &riskScoreMapping + } + } + + // Set building block type + if props.BuildingBlockType != nil && utils.IsKnown(d.BuildingBlockType) { + buildingBlockType := kbapi.SecurityDetectionsAPIBuildingBlockType(d.BuildingBlockType.ValueString()) + *props.BuildingBlockType = &buildingBlockType + } + + // Set data view ID + if props.DataViewId != nil && utils.IsKnown(d.DataViewId) { + dataViewId := kbapi.SecurityDetectionsAPIDataViewId(d.DataViewId.ValueString()) + *props.DataViewId = &dataViewId + } + + // Set namespace + if props.Namespace != nil && utils.IsKnown(d.Namespace) { + namespace := kbapi.SecurityDetectionsAPIAlertsIndexNamespace(d.Namespace.ValueString()) + *props.Namespace = &namespace + } + + // Set rule name override + if props.RuleNameOverride != nil && utils.IsKnown(d.RuleNameOverride) { + ruleNameOverride := kbapi.SecurityDetectionsAPIRuleNameOverride(d.RuleNameOverride.ValueString()) + *props.RuleNameOverride = &ruleNameOverride + } + + // Set timestamp override + if props.TimestampOverride != nil && utils.IsKnown(d.TimestampOverride) { + timestampOverride := kbapi.SecurityDetectionsAPITimestampOverride(d.TimestampOverride.ValueString()) + *props.TimestampOverride = ×tampOverride + } + + // Set timestamp override fallback disabled + if props.TimestampOverrideFallbackDisabled != nil && utils.IsKnown(d.TimestampOverrideFallbackDisabled) { + timestampOverrideFallbackDisabled := kbapi.SecurityDetectionsAPITimestampOverrideFallbackDisabled(d.TimestampOverrideFallbackDisabled.ValueBool()) + *props.TimestampOverrideFallbackDisabled = ×tampOverrideFallbackDisabled + } + + // Set severity mapping + if props.SeverityMapping != nil && utils.IsKnown(d.SeverityMapping) { + severityMapping, severityMappingDiags := d.severityMappingToApi(ctx) + diags.Append(severityMappingDiags...) + if !severityMappingDiags.HasError() && severityMapping != nil && len(*severityMapping) > 0 { + *props.SeverityMapping = severityMapping + } + } + + // Set related integrations + if props.RelatedIntegrations != nil && utils.IsKnown(d.RelatedIntegrations) { + relatedIntegrations, relatedIntegrationsDiags := d.relatedIntegrationsToApi(ctx) + diags.Append(relatedIntegrationsDiags...) + if !relatedIntegrationsDiags.HasError() && relatedIntegrations != nil && len(*relatedIntegrations) > 0 { + *props.RelatedIntegrations = relatedIntegrations + } + } + + // Set required fields + if props.RequiredFields != nil && utils.IsKnown(d.RequiredFields) { + requiredFields, requiredFieldsDiags := d.requiredFieldsToApi(ctx) + diags.Append(requiredFieldsDiags...) + if !requiredFieldsDiags.HasError() && requiredFields != nil && len(*requiredFields) > 0 { + *props.RequiredFields = requiredFields + } + } + + // Set investigation fields + if props.InvestigationFields != nil { + investigationFields, investigationFieldsDiags := d.investigationFieldsToApi(ctx) + if !investigationFieldsDiags.HasError() && investigationFields != nil { + *props.InvestigationFields = investigationFields + } + diags.Append(investigationFieldsDiags...) + } + + // Set response actions + if props.ResponseActions != nil && utils.IsKnown(d.ResponseActions) { + responseActions, responseActionsDiags := d.responseActionsToApi(ctx, client) + diags.Append(responseActionsDiags...) + if !responseActionsDiags.HasError() && len(responseActions) > 0 { + *props.ResponseActions = &responseActions + } + } + + // Set filters + if props.Filters != nil && utils.IsKnown(d.Filters) { + filters, filtersDiags := d.filtersToApi(ctx) + diags.Append(filtersDiags...) + if !filtersDiags.HasError() && filters != nil { + *props.Filters = filters + } + } + + // Set alert suppression + if props.AlertSuppression != nil { + alertSuppression := d.alertSuppressionToApi(ctx, diags) + if alertSuppression != nil { + *props.AlertSuppression = alertSuppression + } + } +} + +// Helper function to initialize fields that should be set to default values for all rule types +func (d *SecurityDetectionRuleData) initializeAllFieldsToDefaults(ctx context.Context, diags *diag.Diagnostics) { + + // Initialize fields that should be empty lists for all rule types initially + if !utils.IsKnown(d.Author) { + d.Author = types.ListNull(types.StringType) + } + if !utils.IsKnown(d.Tags) { + d.Tags = types.ListNull(types.StringType) + } + if !utils.IsKnown(d.FalsePositives) { + d.FalsePositives = types.ListNull(types.StringType) + } + if !utils.IsKnown(d.References) { + d.References = types.ListNull(types.StringType) + } + + // Initialize new common fields with proper empty lists + if !utils.IsKnown(d.RelatedIntegrations) { + d.RelatedIntegrations = types.ListNull(getRelatedIntegrationElementType()) + } + if !utils.IsKnown(d.RequiredFields) { + d.RequiredFields = types.ListNull(getRequiredFieldElementType()) + } + if !utils.IsKnown(d.SeverityMapping) { + d.SeverityMapping = types.ListNull(getSeverityMappingElementType()) + } + + // Initialize building block type to null by default + if !utils.IsKnown(d.BuildingBlockType) { + d.BuildingBlockType = types.StringNull() + } + + // Actions field (common across all rule types) + if !utils.IsKnown(d.Actions) { + d.Actions = types.ListNull(getActionElementType()) + } + + // Exceptions list field (common across all rule types) + if !utils.IsKnown(d.ExceptionsList) { + d.ExceptionsList = types.ListNull(getExceptionsListElementType()) + } + + // Initialize all type-specific fields to null/empty by default + d.initializeTypeSpecificFieldsToDefaults(ctx, diags) +} + +// Helper function to initialize type-specific fields to default/null values +func (d *SecurityDetectionRuleData) initializeTypeSpecificFieldsToDefaults(ctx context.Context, diags *diag.Diagnostics) { + // EQL-specific fields + if !utils.IsKnown(d.TiebreakerField) { + d.TiebreakerField = types.StringNull() + } + + // Machine Learning-specific fields + if !utils.IsKnown(d.AnomalyThreshold) { + d.AnomalyThreshold = types.Int64Null() + } + if !utils.IsKnown(d.MachineLearningJobId) { + d.MachineLearningJobId = types.ListNull(types.StringType) + } + + // New Terms-specific fields + if !utils.IsKnown(d.NewTermsFields) { + d.NewTermsFields = types.ListNull(types.StringType) + } + if !utils.IsKnown(d.HistoryWindowStart) { + d.HistoryWindowStart = types.StringNull() + } + + // Saved Query-specific fields + if !utils.IsKnown(d.SavedId) { + d.SavedId = types.StringNull() + } + + // Threat Match-specific fields + if !utils.IsKnown(d.ThreatIndex) { + d.ThreatIndex = types.ListNull(types.StringType) + } + if !utils.IsKnown(d.ThreatQuery) { + d.ThreatQuery = types.StringNull() + } + if !utils.IsKnown(d.ThreatMapping) { + d.ThreatMapping = types.ListNull(getThreatMappingElementType()) + } + if !utils.IsKnown(d.ThreatFilters) { + d.ThreatFilters = types.ListNull(types.StringType) + } + if !utils.IsKnown(d.ThreatIndicatorPath) { + d.ThreatIndicatorPath = types.StringNull() + } + if !utils.IsKnown(d.ConcurrentSearches) { + d.ConcurrentSearches = types.Int64Null() + } + if !utils.IsKnown(d.ItemsPerSearch) { + d.ItemsPerSearch = types.Int64Null() + } + + // Threshold-specific fields + if !utils.IsKnown(d.Threshold) { + d.Threshold = types.ObjectNull(getThresholdType()) + } + + // Timeline fields (common across multiple rule types) + if !utils.IsKnown(d.TimelineId) { + d.TimelineId = types.StringNull() + } + if !utils.IsKnown(d.TimelineTitle) { + d.TimelineTitle = types.StringNull() + } + + // Threat field (common across multiple rule types) - MITRE ATT&CK framework + if !utils.IsKnown(d.Threat) { + d.Threat = types.ListNull(getThreatElementType()) + } +} diff --git a/internal/kibana/security_detection_rule/models_eql.go b/internal/kibana/security_detection_rule/models_eql.go new file mode 100644 index 000000000..96f9aa9ca --- /dev/null +++ b/internal/kibana/security_detection_rule/models_eql.go @@ -0,0 +1,319 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type EqlRuleProcessor struct{} + +func (e EqlRuleProcessor) HandlesRuleType(t string) bool { + return t == "eql" +} + +func (e EqlRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return toEqlRuleCreateProps(ctx, client, d) +} + +func (e EqlRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return toEqlRuleUpdateProps(ctx, client, d) +} + +func (e EqlRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPIEqlRule) + return ok +} + +func (e EqlRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPIEqlRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return updateFromEqlRule(ctx, &value, d) +} + +func (e EqlRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPIEqlRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func toEqlRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + eqlRule := kbapi.SecurityDetectionsAPIEqlRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIEqlRuleCreatePropsType("eql"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + Language: kbapi.SecurityDetectionsAPIEqlQueryLanguage("eql"), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &eqlRule.Actions, + ResponseActions: &eqlRule.ResponseActions, + RuleId: &eqlRule.RuleId, + Enabled: &eqlRule.Enabled, + From: &eqlRule.From, + To: &eqlRule.To, + Interval: &eqlRule.Interval, + Index: &eqlRule.Index, + Author: &eqlRule.Author, + Tags: &eqlRule.Tags, + FalsePositives: &eqlRule.FalsePositives, + References: &eqlRule.References, + License: &eqlRule.License, + Note: &eqlRule.Note, + Setup: &eqlRule.Setup, + MaxSignals: &eqlRule.MaxSignals, + Version: &eqlRule.Version, + ExceptionsList: &eqlRule.ExceptionsList, + AlertSuppression: &eqlRule.AlertSuppression, + RiskScoreMapping: &eqlRule.RiskScoreMapping, + SeverityMapping: &eqlRule.SeverityMapping, + RelatedIntegrations: &eqlRule.RelatedIntegrations, + RequiredFields: &eqlRule.RequiredFields, + BuildingBlockType: &eqlRule.BuildingBlockType, + DataViewId: &eqlRule.DataViewId, + Namespace: &eqlRule.Namespace, + RuleNameOverride: &eqlRule.RuleNameOverride, + TimestampOverride: &eqlRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &eqlRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &eqlRule.InvestigationFields, + Filters: &eqlRule.Filters, + }, &diags, client) + + // Set EQL-specific fields + if utils.IsKnown(d.TiebreakerField) { + tiebreakerField := kbapi.SecurityDetectionsAPITiebreakerField(d.TiebreakerField.ValueString()) + eqlRule.TiebreakerField = &tiebreakerField + } + + // Convert to union type + err := createProps.FromSecurityDetectionsAPIEqlRuleCreateProps(eqlRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert EQL rule properties: "+err.Error(), + ) + } + + return createProps, diags +} +func toEqlRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + eqlRule := kbapi.SecurityDetectionsAPIEqlRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIEqlRuleUpdatePropsType("eql"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + Language: kbapi.SecurityDetectionsAPIEqlQueryLanguage("eql"), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + eqlRule.RuleId = &ruleId + eqlRule.Id = nil // if rule_id is set, we cant send id + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &eqlRule.Actions, + ResponseActions: &eqlRule.ResponseActions, + RuleId: &eqlRule.RuleId, + Enabled: &eqlRule.Enabled, + From: &eqlRule.From, + To: &eqlRule.To, + Interval: &eqlRule.Interval, + Index: &eqlRule.Index, + Author: &eqlRule.Author, + Tags: &eqlRule.Tags, + FalsePositives: &eqlRule.FalsePositives, + References: &eqlRule.References, + License: &eqlRule.License, + Note: &eqlRule.Note, + Setup: &eqlRule.Setup, + MaxSignals: &eqlRule.MaxSignals, + Version: &eqlRule.Version, + ExceptionsList: &eqlRule.ExceptionsList, + AlertSuppression: &eqlRule.AlertSuppression, + RiskScoreMapping: &eqlRule.RiskScoreMapping, + SeverityMapping: &eqlRule.SeverityMapping, + RelatedIntegrations: &eqlRule.RelatedIntegrations, + RequiredFields: &eqlRule.RequiredFields, + BuildingBlockType: &eqlRule.BuildingBlockType, + DataViewId: &eqlRule.DataViewId, + Namespace: &eqlRule.Namespace, + RuleNameOverride: &eqlRule.RuleNameOverride, + TimestampOverride: &eqlRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &eqlRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &eqlRule.InvestigationFields, + Filters: &eqlRule.Filters, + }, &diags, client) + + // Set EQL-specific fields + if utils.IsKnown(d.TiebreakerField) { + tiebreakerField := kbapi.SecurityDetectionsAPITiebreakerField(d.TiebreakerField.ValueString()) + eqlRule.TiebreakerField = &tiebreakerField + } + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPIEqlRuleUpdateProps(eqlRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert EQL rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} +func updateFromEqlRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPIEqlRule, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields + diags.Append(d.updateDataViewIdFromApi(ctx, rule.DataViewId)...) + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + d.Query = types.StringValue(rule.Query) + d.Language = types.StringValue(string(rule.Language)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + + // Update read-only fields + d.CreatedAt = utils.TimeToStringValue(rule.CreatedAt) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = utils.TimeToStringValue(rule.UpdatedAt) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // Update index patterns + diags.Append(d.updateIndexFromApi(ctx, rule.Index)...) + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // EQL-specific fields + if rule.TiebreakerField != nil { + d.TiebreakerField = types.StringValue(string(*rule.TiebreakerField)) + } else { + d.TiebreakerField = types.StringNull() + } + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update filters field + filtersDiags := d.updateFiltersFromApi(ctx, rule.Filters) + diags.Append(filtersDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_esql.go b/internal/kibana/security_detection_rule/models_esql.go new file mode 100644 index 000000000..c7229e390 --- /dev/null +++ b/internal/kibana/security_detection_rule/models_esql.go @@ -0,0 +1,302 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type EsqlRuleProcessor struct{} + +func (e EsqlRuleProcessor) HandlesRuleType(t string) bool { + return t == "esql" +} + +func (e EsqlRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return d.toEsqlRuleCreateProps(ctx, client) +} + +func (e EsqlRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return d.toEsqlRuleUpdateProps(ctx, client) +} + +func (e EsqlRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPIEsqlRule) + return ok +} + +func (e EsqlRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPIEsqlRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return d.updateFromEsqlRule(ctx, &value) +} + +func (e EsqlRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPIEsqlRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func (d SecurityDetectionRuleData) toEsqlRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + esqlRule := kbapi.SecurityDetectionsAPIEsqlRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIEsqlRuleCreatePropsType("esql"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + Language: kbapi.SecurityDetectionsAPIEsqlQueryLanguage("esql"), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &esqlRule.Actions, + ResponseActions: &esqlRule.ResponseActions, + RuleId: &esqlRule.RuleId, + Enabled: &esqlRule.Enabled, + From: &esqlRule.From, + To: &esqlRule.To, + Interval: &esqlRule.Interval, + Index: nil, // ESQL rules don't use index patterns + Author: &esqlRule.Author, + Tags: &esqlRule.Tags, + FalsePositives: &esqlRule.FalsePositives, + References: &esqlRule.References, + License: &esqlRule.License, + Note: &esqlRule.Note, + Setup: &esqlRule.Setup, + MaxSignals: &esqlRule.MaxSignals, + Version: &esqlRule.Version, + ExceptionsList: &esqlRule.ExceptionsList, + AlertSuppression: &esqlRule.AlertSuppression, + RiskScoreMapping: &esqlRule.RiskScoreMapping, + SeverityMapping: &esqlRule.SeverityMapping, + RelatedIntegrations: &esqlRule.RelatedIntegrations, + RequiredFields: &esqlRule.RequiredFields, + BuildingBlockType: &esqlRule.BuildingBlockType, + DataViewId: nil, // ESQL rules don't have DataViewId + Namespace: &esqlRule.Namespace, + RuleNameOverride: &esqlRule.RuleNameOverride, + TimestampOverride: &esqlRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &esqlRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &esqlRule.InvestigationFields, + Filters: nil, // ESQL rules don't support this field + }, &diags, client) + + // ESQL rules don't use index patterns as they use FROM clause in the query + + // Convert to union type + err := createProps.FromSecurityDetectionsAPIEsqlRuleCreateProps(esqlRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert ESQL rule properties: "+err.Error(), + ) + } + + return createProps, diags +} + +func (d SecurityDetectionRuleData) toEsqlRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + esqlRule := kbapi.SecurityDetectionsAPIEsqlRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIEsqlRuleUpdatePropsType("esql"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + Language: kbapi.SecurityDetectionsAPIEsqlQueryLanguage("esql"), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + esqlRule.RuleId = &ruleId + esqlRule.Id = nil // if rule_id is set, we cant send id + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &esqlRule.Actions, + ResponseActions: &esqlRule.ResponseActions, + RuleId: &esqlRule.RuleId, + Enabled: &esqlRule.Enabled, + From: &esqlRule.From, + To: &esqlRule.To, + Interval: &esqlRule.Interval, + Index: nil, // ESQL rules don't use index patterns + Author: &esqlRule.Author, + Tags: &esqlRule.Tags, + FalsePositives: &esqlRule.FalsePositives, + References: &esqlRule.References, + License: &esqlRule.License, + Note: &esqlRule.Note, + Setup: &esqlRule.Setup, + MaxSignals: &esqlRule.MaxSignals, + Version: &esqlRule.Version, + ExceptionsList: &esqlRule.ExceptionsList, + AlertSuppression: &esqlRule.AlertSuppression, + RiskScoreMapping: &esqlRule.RiskScoreMapping, + SeverityMapping: &esqlRule.SeverityMapping, + RelatedIntegrations: &esqlRule.RelatedIntegrations, + RequiredFields: &esqlRule.RequiredFields, + BuildingBlockType: &esqlRule.BuildingBlockType, + DataViewId: nil, // ESQL rules don't have DataViewId + Namespace: &esqlRule.Namespace, + RuleNameOverride: &esqlRule.RuleNameOverride, + TimestampOverride: &esqlRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &esqlRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &esqlRule.InvestigationFields, + Filters: nil, // ESQL rules don't have Filters + }, &diags, client) + + // ESQL rules don't use index patterns as they use FROM clause in the query + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPIEsqlRuleUpdateProps(esqlRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert ESQL rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} +func (d *SecurityDetectionRuleData) updateFromEsqlRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPIEsqlRule) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields (ESQL doesn't support DataViewId) + d.DataViewId = types.StringNull() + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + d.Query = types.StringValue(rule.Query) + d.Language = types.StringValue(string(rule.Language)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + + // Update read-only fields + d.CreatedAt = types.StringValue(rule.CreatedAt.Format("2006-01-02T15:04:05.000Z")) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = types.StringValue(rule.UpdatedAt.Format("2006-01-02T15:04:05.000Z")) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // ESQL rules don't use index patterns + d.Index = types.ListValueMust(types.StringType, []attr.Value{}) + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_from_api_type_utils.go b/internal/kibana/security_detection_rule/models_from_api_type_utils.go new file mode 100644 index 000000000..20053e2db --- /dev/null +++ b/internal/kibana/security_detection_rule/models_from_api_type_utils.go @@ -0,0 +1,996 @@ +package security_detection_rule + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes" + "github.com/hashicorp/terraform-plugin-framework-jsontypes/jsontypes" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +// Utilities to convert various API types to Terraform model types + +// convertActionsToModel converts kbapi.SecurityDetectionsAPIRuleAction slice to Terraform model +func convertActionsToModel(ctx context.Context, apiActions []kbapi.SecurityDetectionsAPIRuleAction) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if len(apiActions) == 0 { + return types.ListNull(getActionElementType()), diags + } + + actions := make([]ActionModel, 0) + + for _, apiAction := range apiActions { + action := ActionModel{ + ActionTypeId: types.StringValue(apiAction.ActionTypeId), + Id: types.StringValue(string(apiAction.Id)), + } + + // Convert params + if apiAction.Params != nil { + paramsMap := make(map[string]attr.Value) + for k, v := range apiAction.Params { + if v != nil { + paramsMap[k] = types.StringValue(fmt.Sprintf("%v", v)) + } + } + paramsValue, paramsDiags := types.MapValue(types.StringType, paramsMap) + diags.Append(paramsDiags...) + action.Params = paramsValue + } else { + action.Params = types.MapNull(types.StringType) + } + + // Set optional fields + action.Group = types.StringPointerValue(apiAction.Group) + + if apiAction.Uuid != nil { + action.Uuid = types.StringValue(string(*apiAction.Uuid)) + } else { + action.Uuid = types.StringNull() + } + + if apiAction.AlertsFilter != nil { + alertsFilterMap := make(map[string]attr.Value) + for k, v := range *apiAction.AlertsFilter { + if v != nil { + alertsFilterMap[k] = types.StringValue(fmt.Sprintf("%v", v)) + } + } + alertsFilterValue, alertsFilterDiags := types.MapValue(types.StringType, alertsFilterMap) + diags.Append(alertsFilterDiags...) + action.AlertsFilter = alertsFilterValue + } else { + action.AlertsFilter = types.MapNull(types.StringType) + } + + // Convert frequency + if apiAction.Frequency != nil { + var throttleStr string + if throttle0, err := apiAction.Frequency.Throttle.AsSecurityDetectionsAPIRuleActionThrottle0(); err == nil { + throttleStr = string(throttle0) + } else if throttle1, err := apiAction.Frequency.Throttle.AsSecurityDetectionsAPIRuleActionThrottle1(); err == nil { + throttleStr = string(throttle1) + } + + frequencyModel := ActionFrequencyModel{ + NotifyWhen: types.StringValue(string(apiAction.Frequency.NotifyWhen)), + Summary: types.BoolValue(apiAction.Frequency.Summary), + Throttle: types.StringValue(throttleStr), + } + + frequencyObj, frequencyDiags := types.ObjectValueFrom(ctx, getActionFrequencyType(), frequencyModel) + diags.Append(frequencyDiags...) + action.Frequency = frequencyObj + } else { + action.Frequency = types.ObjectNull(getActionFrequencyType()) + } + + actions = append(actions, action) + } + + listValue, listDiags := types.ListValueFrom(ctx, getActionElementType(), actions) + diags.Append(listDiags...) + return listValue, diags +} + +// convertExceptionsListToModel converts kbapi.SecurityDetectionsAPIRuleExceptionList slice to Terraform model +func convertExceptionsListToModel(ctx context.Context, apiExceptionsList []kbapi.SecurityDetectionsAPIRuleExceptionList) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if len(apiExceptionsList) == 0 { + return types.ListNull(getExceptionsListElementType()), diags + } + + exceptions := make([]ExceptionsListModel, 0) + + for _, apiException := range apiExceptionsList { + exception := ExceptionsListModel{ + Id: types.StringValue(apiException.Id), + ListId: types.StringValue(apiException.ListId), + NamespaceType: types.StringValue(string(apiException.NamespaceType)), + Type: types.StringValue(string(apiException.Type)), + } + + exceptions = append(exceptions, exception) + } + + listValue, listDiags := types.ListValueFrom(ctx, getExceptionsListElementType(), exceptions) + diags.Append(listDiags...) + return listValue, diags +} + +// convertRiskScoreMappingToModel converts kbapi.SecurityDetectionsAPIRiskScoreMapping to Terraform model +func convertRiskScoreMappingToModel(ctx context.Context, apiRiskScoreMapping kbapi.SecurityDetectionsAPIRiskScoreMapping) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if len(apiRiskScoreMapping) == 0 { + return types.ListNull(getRiskScoreMappingElementType()), diags + } + + mappings := make([]RiskScoreMappingModel, 0) + + for _, apiMapping := range apiRiskScoreMapping { + mapping := RiskScoreMappingModel{ + Field: types.StringValue(apiMapping.Field), + Operator: types.StringValue(string(apiMapping.Operator)), + Value: types.StringValue(apiMapping.Value), + } + + // Set optional risk score if provided + if apiMapping.RiskScore != nil { + mapping.RiskScore = types.Int64Value(int64(*apiMapping.RiskScore)) + } else { + mapping.RiskScore = types.Int64Null() + } + + mappings = append(mappings, mapping) + } + + listValue, listDiags := types.ListValueFrom(ctx, getRiskScoreMappingElementType(), mappings) + diags.Append(listDiags...) + return listValue, diags +} + +// convertInvestigationFieldsToModel converts kbapi.SecurityDetectionsAPIInvestigationFields to Terraform model +func convertInvestigationFieldsToModel(ctx context.Context, apiInvestigationFields *kbapi.SecurityDetectionsAPIInvestigationFields) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if apiInvestigationFields == nil || len(apiInvestigationFields.FieldNames) == 0 { + return types.ListNull(types.StringType), diags + } + + fieldNames := make([]string, len(apiInvestigationFields.FieldNames)) + for i, field := range apiInvestigationFields.FieldNames { + fieldNames[i] = string(field) + } + + return utils.SliceToListType_String(ctx, fieldNames, path.Root("investigation_fields"), &diags), diags +} + +// convertRelatedIntegrationsToModel converts kbapi.SecurityDetectionsAPIRelatedIntegrationArray to Terraform model +func convertRelatedIntegrationsToModel(ctx context.Context, apiRelatedIntegrations *kbapi.SecurityDetectionsAPIRelatedIntegrationArray) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if apiRelatedIntegrations == nil || len(*apiRelatedIntegrations) == 0 { + return types.ListNull(getRelatedIntegrationElementType()), diags + } + + integrations := make([]RelatedIntegrationModel, 0) + + for _, apiIntegration := range *apiRelatedIntegrations { + integration := RelatedIntegrationModel{ + Package: types.StringValue(string(apiIntegration.Package)), + Version: types.StringValue(string(apiIntegration.Version)), + } + + // Set optional integration field if provided + if apiIntegration.Integration != nil { + integration.Integration = types.StringValue(string(*apiIntegration.Integration)) + } else { + integration.Integration = types.StringNull() + } + + integrations = append(integrations, integration) + } + + listValue, listDiags := types.ListValueFrom(ctx, getRelatedIntegrationElementType(), integrations) + diags.Append(listDiags...) + return listValue, diags +} + +// convertRequiredFieldsToModel converts kbapi.SecurityDetectionsAPIRequiredFieldArray to Terraform model +func convertRequiredFieldsToModel(ctx context.Context, apiRequiredFields *kbapi.SecurityDetectionsAPIRequiredFieldArray) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if apiRequiredFields == nil || len(*apiRequiredFields) == 0 { + return types.ListNull(getRequiredFieldElementType()), diags + } + + fields := make([]RequiredFieldModel, 0) + + for _, apiField := range *apiRequiredFields { + field := RequiredFieldModel{ + Name: types.StringValue(apiField.Name), + Type: types.StringValue(apiField.Type), + Ecs: types.BoolValue(apiField.Ecs), + } + + fields = append(fields, field) + } + + listValue, listDiags := types.ListValueFrom(ctx, getRequiredFieldElementType(), fields) + diags.Append(listDiags...) + return listValue, diags +} + +// convertSeverityMappingToModel converts kbapi.SecurityDetectionsAPISeverityMapping to Terraform model +func convertSeverityMappingToModel(ctx context.Context, apiSeverityMapping *kbapi.SecurityDetectionsAPISeverityMapping) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if apiSeverityMapping == nil || len(*apiSeverityMapping) == 0 { + return types.ListNull(getSeverityMappingElementType()), diags + } + + mappings := make([]SeverityMappingModel, 0) + + for _, apiMapping := range *apiSeverityMapping { + mapping := SeverityMappingModel{ + Field: types.StringValue(apiMapping.Field), + Operator: types.StringValue(string(apiMapping.Operator)), + Value: types.StringValue(apiMapping.Value), + Severity: types.StringValue(string(apiMapping.Severity)), + } + + mappings = append(mappings, mapping) + } + + listValue, listDiags := types.ListValueFrom(ctx, getSeverityMappingElementType(), mappings) + diags.Append(listDiags...) + return listValue, diags +} + +// convertThreatMappingToModel converts kbapi.SecurityDetectionsAPIThreatMapping to the terraform model +func convertThreatMappingToModel(ctx context.Context, apiThreatMappings kbapi.SecurityDetectionsAPIThreatMapping) (types.List, diag.Diagnostics) { + var threatMappings []SecurityDetectionRuleTfDataItem + + for _, apiMapping := range apiThreatMappings { + var entries []SecurityDetectionRuleTfDataItemEntry + + for _, apiEntry := range apiMapping.Entries { + entries = append(entries, SecurityDetectionRuleTfDataItemEntry{ + Field: types.StringValue(string(apiEntry.Field)), + Type: types.StringValue(string(apiEntry.Type)), + Value: types.StringValue(string(apiEntry.Value)), + }) + } + + entriesListValue, diags := types.ListValueFrom(ctx, getThreatMappingEntryElementType(), entries) + if diags.HasError() { + return types.ListNull(getThreatMappingElementType()), diags + } + + threatMappings = append(threatMappings, SecurityDetectionRuleTfDataItem{ + Entries: entriesListValue, + }) + } + + listValue, diags := types.ListValueFrom(ctx, getThreatMappingElementType(), threatMappings) + return listValue, diags +} + +// convertResponseActionsToModel converts kbapi response actions array to the terraform model +func convertResponseActionsToModel(ctx context.Context, apiResponseActions *[]kbapi.SecurityDetectionsAPIResponseAction) (types.List, diag.Diagnostics) { + var diags diag.Diagnostics + + if apiResponseActions == nil || len(*apiResponseActions) == 0 { + return types.ListNull(getResponseActionElementType()), diags + } + + var responseActions []ResponseActionModel + + for _, apiResponseAction := range *apiResponseActions { + var responseAction ResponseActionModel + + // Use ValueByDiscriminator to get the concrete type + actionValue, err := apiResponseAction.ValueByDiscriminator() + if err != nil { + diags.AddError("Failed to get response action discriminator", fmt.Sprintf("Error: %s", err.Error())) + continue + } + + switch concreteAction := actionValue.(type) { + case kbapi.SecurityDetectionsAPIOsqueryResponseAction: + convertedAction, convertDiags := convertOsqueryResponseActionToModel(ctx, concreteAction) + diags.Append(convertDiags...) + if !convertDiags.HasError() { + responseAction = convertedAction + } + + case kbapi.SecurityDetectionsAPIEndpointResponseAction: + convertedAction, convertDiags := convertEndpointResponseActionToModel(ctx, concreteAction) + diags.Append(convertDiags...) + if !convertDiags.HasError() { + responseAction = convertedAction + } + + default: + diags.AddError("Unknown response action type", fmt.Sprintf("Unsupported response action type: %T", concreteAction)) + continue + } + + responseActions = append(responseActions, responseAction) + } + + listValue, listDiags := types.ListValueFrom(ctx, getResponseActionElementType(), responseActions) + if listDiags.HasError() { + diags.Append(listDiags...) + } + + return listValue, diags +} + +// convertOsqueryResponseActionToModel converts an Osquery response action to the terraform model +func convertOsqueryResponseActionToModel(ctx context.Context, osqueryAction kbapi.SecurityDetectionsAPIOsqueryResponseAction) (ResponseActionModel, diag.Diagnostics) { + var diags diag.Diagnostics + var responseAction ResponseActionModel + + responseAction.ActionTypeId = types.StringValue(string(osqueryAction.ActionTypeId)) + + // Convert osquery params + paramsModel := ResponseActionParamsModel{} + paramsModel.Query = types.StringPointerValue(osqueryAction.Params.Query) + if osqueryAction.Params.PackId != nil { + paramsModel.PackId = types.StringPointerValue(osqueryAction.Params.PackId) + } else { + paramsModel.PackId = types.StringNull() + } + if osqueryAction.Params.SavedQueryId != nil { + paramsModel.SavedQueryId = types.StringPointerValue(osqueryAction.Params.SavedQueryId) + } else { + paramsModel.SavedQueryId = types.StringNull() + } + if osqueryAction.Params.Timeout != nil { + paramsModel.Timeout = types.Int64Value(int64(*osqueryAction.Params.Timeout)) + } else { + paramsModel.Timeout = types.Int64Null() + } + + // Convert ECS mapping + if osqueryAction.Params.EcsMapping != nil { + ecsMappingAttrs := make(map[string]attr.Value) + for key, value := range *osqueryAction.Params.EcsMapping { + if value.Field != nil { + ecsMappingAttrs[key] = types.StringPointerValue(value.Field) + } else { + ecsMappingAttrs[key] = types.StringNull() + } + } + ecsMappingValue, ecsDiags := types.MapValue(types.StringType, ecsMappingAttrs) + if ecsDiags.HasError() { + diags.Append(ecsDiags...) + } else { + paramsModel.EcsMapping = ecsMappingValue + } + } else { + paramsModel.EcsMapping = types.MapNull(types.StringType) + } + + // Convert queries array + if osqueryAction.Params.Queries != nil { + var queries []OsqueryQueryModel + for _, apiQuery := range *osqueryAction.Params.Queries { + query := OsqueryQueryModel{ + Id: types.StringValue(apiQuery.Id), + Query: types.StringValue(apiQuery.Query), + } + if apiQuery.Platform != nil { + query.Platform = types.StringPointerValue(apiQuery.Platform) + } else { + query.Platform = types.StringNull() + } + if apiQuery.Version != nil { + query.Version = types.StringPointerValue(apiQuery.Version) + } else { + query.Version = types.StringNull() + } + if apiQuery.Removed != nil { + query.Removed = types.BoolPointerValue(apiQuery.Removed) + } else { + query.Removed = types.BoolNull() + } + if apiQuery.Snapshot != nil { + query.Snapshot = types.BoolPointerValue(apiQuery.Snapshot) + } else { + query.Snapshot = types.BoolNull() + } + + // Convert query ECS mapping + if apiQuery.EcsMapping != nil { + queryEcsMappingAttrs := make(map[string]attr.Value) + for key, value := range *apiQuery.EcsMapping { + if value.Field != nil { + queryEcsMappingAttrs[key] = types.StringPointerValue(value.Field) + } else { + queryEcsMappingAttrs[key] = types.StringNull() + } + } + queryEcsMappingValue, queryEcsDiags := types.MapValue(types.StringType, queryEcsMappingAttrs) + if queryEcsDiags.HasError() { + diags.Append(queryEcsDiags...) + } else { + query.EcsMapping = queryEcsMappingValue + } + } else { + query.EcsMapping = types.MapNull(types.StringType) + } + + queries = append(queries, query) + } + + queriesListValue, queriesDiags := types.ListValueFrom(ctx, getOsqueryQueryElementType(), queries) + if queriesDiags.HasError() { + diags.Append(queriesDiags...) + } else { + paramsModel.Queries = queriesListValue + } + } else { + paramsModel.Queries = types.ListNull(getOsqueryQueryElementType()) + } + + // Set remaining fields to null since this is osquery + paramsModel.Command = types.StringNull() + paramsModel.Comment = types.StringNull() + paramsModel.Config = types.ObjectNull(getEndpointProcessConfigType()) + + paramsObjectValue, paramsDiags := types.ObjectValueFrom(ctx, getResponseActionParamsType(), paramsModel) + if paramsDiags.HasError() { + diags.Append(paramsDiags...) + } else { + responseAction.Params = paramsObjectValue + } + + return responseAction, diags +} + +// convertEndpointResponseActionToModel converts an Endpoint response action to the terraform model +func convertEndpointResponseActionToModel(ctx context.Context, endpointAction kbapi.SecurityDetectionsAPIEndpointResponseAction) (ResponseActionModel, diag.Diagnostics) { + var diags diag.Diagnostics + var responseAction ResponseActionModel + + responseAction.ActionTypeId = types.StringValue(string(endpointAction.ActionTypeId)) + + // Convert endpoint params + paramsModel := ResponseActionParamsModel{} + + commandParams, err := endpointAction.Params.AsSecurityDetectionsAPIDefaultParams() + if err == nil { + switch commandParams.Command { + case "isolate": + defaultParams, err := endpointAction.Params.AsSecurityDetectionsAPIDefaultParams() + if err != nil { + diags.AddError("Failed to parse endpoint default params", fmt.Sprintf("Error: %s", err.Error())) + } else { + paramsModel.Command = types.StringValue(string(defaultParams.Command)) + if defaultParams.Comment != nil { + paramsModel.Comment = types.StringPointerValue(defaultParams.Comment) + } else { + paramsModel.Comment = types.StringNull() + } + paramsModel.Config = types.ObjectNull(getEndpointProcessConfigType()) + } + case "kill-process", "suspend-process": + processesParams, err := endpointAction.Params.AsSecurityDetectionsAPIProcessesParams() + if err != nil { + diags.AddError("Failed to parse endpoint processes params", fmt.Sprintf("Error: %s", err.Error())) + } else { + paramsModel.Command = types.StringValue(string(processesParams.Command)) + if processesParams.Comment != nil { + paramsModel.Comment = types.StringPointerValue(processesParams.Comment) + } else { + paramsModel.Comment = types.StringNull() + } + + // Convert config + configModel := EndpointProcessConfigModel{ + Field: types.StringValue(processesParams.Config.Field), + } + if processesParams.Config.Overwrite != nil { + configModel.Overwrite = types.BoolPointerValue(processesParams.Config.Overwrite) + } else { + configModel.Overwrite = types.BoolNull() + } + + configObjectValue, configDiags := types.ObjectValueFrom(ctx, getEndpointProcessConfigType(), configModel) + if configDiags.HasError() { + diags.Append(configDiags...) + } else { + paramsModel.Config = configObjectValue + } + } + } + } else { + diags.AddError("Unknown endpoint command", fmt.Sprintf("Unsupported endpoint command: %s. Error: %s", commandParams.Command, err.Error())) + } + + // Set osquery fields to null since this is endpoint + paramsModel.Query = types.StringNull() + paramsModel.PackId = types.StringNull() + paramsModel.SavedQueryId = types.StringNull() + paramsModel.Timeout = types.Int64Null() + paramsModel.EcsMapping = types.MapNull(types.StringType) + paramsModel.Queries = types.ListNull(getOsqueryQueryElementType()) + + paramsObjectValue, paramsDiags := types.ObjectValueFrom(ctx, getResponseActionParamsType(), paramsModel) + if paramsDiags.HasError() { + diags.Append(paramsDiags...) + } else { + responseAction.Params = paramsObjectValue + } + + return responseAction, diags +} + +// convertThresholdToModel converts kbapi.SecurityDetectionsAPIThreshold to the terraform model +func convertThresholdToModel(ctx context.Context, apiThreshold kbapi.SecurityDetectionsAPIThreshold) (types.Object, diag.Diagnostics) { + var diags diag.Diagnostics + + // Handle threshold field - can be single string or array + var fieldList types.List + if singleField, err := apiThreshold.Field.AsSecurityDetectionsAPIThresholdField0(); err == nil { + // Single field + fieldList = utils.SliceToListType_String(ctx, []string{string(singleField)}, path.Root("threshold").AtName("field"), &diags) + } else if multipleFields, err := apiThreshold.Field.AsSecurityDetectionsAPIThresholdField1(); err == nil { + // Multiple fields + fieldStrings := make([]string, len(multipleFields)) + for i, field := range multipleFields { + fieldStrings[i] = string(field) + } + fieldList = utils.SliceToListType_String(ctx, fieldStrings, path.Root("threshold").AtName("field"), &diags) + } else { + fieldList = types.ListValueMust(types.StringType, []attr.Value{}) + } + + // Handle cardinality (optional) + var cardinalityList types.List + if apiThreshold.Cardinality != nil && len(*apiThreshold.Cardinality) > 0 { + cardinalityList = utils.SliceToListType(ctx, *apiThreshold.Cardinality, getCardinalityType(), path.Root("threshold").AtName("cardinality"), &diags, + func(item struct { + Field string `json:"field"` + Value int `json:"value"` + }, meta utils.ListMeta) CardinalityModel { + return CardinalityModel{ + Field: types.StringValue(item.Field), + Value: types.Int64Value(int64(item.Value)), + } + }) + } else { + cardinalityList = types.ListNull(getCardinalityType()) + } + + thresholdModel := ThresholdModel{ + Field: fieldList, + Value: types.Int64Value(int64(apiThreshold.Value)), + Cardinality: cardinalityList, + } + + thresholdObject, objDiags := types.ObjectValueFrom(ctx, getThresholdType(), thresholdModel) + diags.Append(objDiags...) + return thresholdObject, diags +} + +// convertFiltersFromApi converts the API filters field back to the Terraform type +func (d *SecurityDetectionRuleData) updateFiltersFromApi(ctx context.Context, apiFilters *kbapi.SecurityDetectionsAPIRuleFilterArray) diag.Diagnostics { + var diags diag.Diagnostics + + if apiFilters == nil || len(*apiFilters) == 0 { + d.Filters = jsontypes.NewNormalizedNull() + return diags + } + + // Marshal the []interface{} to JSON string + jsonBytes, err := json.Marshal(*apiFilters) + if err != nil { + diags.AddError("Failed to marshal filters", err.Error()) + return diags + } + + // Create a NormalizedValue from the JSON string + d.Filters = jsontypes.NewNormalizedValue(string(jsonBytes)) + return diags +} + +// Helper function to update severity mapping from API response +func (d *SecurityDetectionRuleData) updateSeverityMappingFromApi(ctx context.Context, severityMapping *kbapi.SecurityDetectionsAPISeverityMapping) diag.Diagnostics { + var diags diag.Diagnostics + + if severityMapping != nil && len(*severityMapping) > 0 { + severityMappingValue, severityMappingDiags := convertSeverityMappingToModel(ctx, severityMapping) + diags.Append(severityMappingDiags...) + if !severityMappingDiags.HasError() { + d.SeverityMapping = severityMappingValue + } + } else { + d.SeverityMapping = types.ListNull(getSeverityMappingElementType()) + } + + return diags +} + +// Helper function to update index patterns from API response +func (d *SecurityDetectionRuleData) updateIndexFromApi(ctx context.Context, index *[]string) diag.Diagnostics { + var diags diag.Diagnostics + + if index != nil && len(*index) > 0 { + d.Index = utils.ListValueFrom(ctx, *index, types.StringType, path.Root("index"), &diags) + } else { + d.Index = types.ListValueMust(types.StringType, []attr.Value{}) + } + + return diags +} + +// Helper function to update author from API response +func (d *SecurityDetectionRuleData) updateAuthorFromApi(ctx context.Context, author []string) diag.Diagnostics { + var diags diag.Diagnostics + + if len(author) > 0 { + d.Author = utils.ListValueFrom(ctx, author, types.StringType, path.Root("author"), &diags) + } else { + d.Author = types.ListValueMust(types.StringType, []attr.Value{}) + } + + return diags +} + +// Helper function to update tags from API response +func (d *SecurityDetectionRuleData) updateTagsFromApi(ctx context.Context, tags []string) diag.Diagnostics { + var diags diag.Diagnostics + + if len(tags) > 0 { + d.Tags = utils.ListValueFrom(ctx, tags, types.StringType, path.Root("tags"), &diags) + } else { + d.Tags = types.ListValueMust(types.StringType, []attr.Value{}) + } + + return diags +} + +// Helper function to update false positives from API response +func (d *SecurityDetectionRuleData) updateFalsePositivesFromApi(ctx context.Context, falsePositives []string) diag.Diagnostics { + var diags diag.Diagnostics + + d.FalsePositives = utils.ListValueFrom(ctx, falsePositives, types.StringType, path.Root("false_positives"), &diags) + + return diags +} + +// Helper function to update references from API response +func (d *SecurityDetectionRuleData) updateReferencesFromApi(ctx context.Context, references []string) diag.Diagnostics { + var diags diag.Diagnostics + + if len(references) > 0 { + d.References = utils.ListValueFrom(ctx, references, types.StringType, path.Root("references"), &diags) + } else { + d.References = types.ListValueMust(types.StringType, []attr.Value{}) + } + + return diags +} + +// Helper function to update data view ID from API response +func (d *SecurityDetectionRuleData) updateDataViewIdFromApi(ctx context.Context, dataViewId *kbapi.SecurityDetectionsAPIDataViewId) diag.Diagnostics { + var diags diag.Diagnostics + + if dataViewId != nil { + d.DataViewId = types.StringValue(string(*dataViewId)) + } else { + d.DataViewId = types.StringNull() + } + + return diags +} + +// Helper function to update namespace from API response +func (d *SecurityDetectionRuleData) updateNamespaceFromApi(ctx context.Context, namespace *kbapi.SecurityDetectionsAPIAlertsIndexNamespace) diag.Diagnostics { + var diags diag.Diagnostics + + if namespace != nil { + d.Namespace = types.StringValue(string(*namespace)) + } else { + d.Namespace = types.StringNull() + } + + return diags +} + +// Helper function to update rule name override from API response +func (d *SecurityDetectionRuleData) updateRuleNameOverrideFromApi(ctx context.Context, ruleNameOverride *kbapi.SecurityDetectionsAPIRuleNameOverride) diag.Diagnostics { + var diags diag.Diagnostics + + if ruleNameOverride != nil { + d.RuleNameOverride = types.StringValue(string(*ruleNameOverride)) + } else { + d.RuleNameOverride = types.StringNull() + } + + return diags +} + +// Helper function to update timestamp override from API response +func (d *SecurityDetectionRuleData) updateTimestampOverrideFromApi(ctx context.Context, timestampOverride *kbapi.SecurityDetectionsAPITimestampOverride) diag.Diagnostics { + var diags diag.Diagnostics + + if timestampOverride != nil { + d.TimestampOverride = types.StringValue(string(*timestampOverride)) + } else { + d.TimestampOverride = types.StringNull() + } + + return diags +} + +// Helper function to update timestamp override fallback disabled from API response +func (d *SecurityDetectionRuleData) updateTimestampOverrideFallbackDisabledFromApi(ctx context.Context, timestampOverrideFallbackDisabled *kbapi.SecurityDetectionsAPITimestampOverrideFallbackDisabled) diag.Diagnostics { + var diags diag.Diagnostics + + if timestampOverrideFallbackDisabled != nil { + d.TimestampOverrideFallbackDisabled = types.BoolValue(bool(*timestampOverrideFallbackDisabled)) + } else { + d.TimestampOverrideFallbackDisabled = types.BoolNull() + } + + return diags +} + +// Helper function to update building block type from API response +func (d *SecurityDetectionRuleData) updateBuildingBlockTypeFromApi(ctx context.Context, buildingBlockType *kbapi.SecurityDetectionsAPIBuildingBlockType) diag.Diagnostics { + var diags diag.Diagnostics + + if buildingBlockType != nil { + d.BuildingBlockType = types.StringValue(string(*buildingBlockType)) + } else { + d.BuildingBlockType = types.StringNull() + } + + return diags +} + +// Helper function to update license from API response +func (d *SecurityDetectionRuleData) updateLicenseFromApi(ctx context.Context, license *kbapi.SecurityDetectionsAPIRuleLicense) diag.Diagnostics { + var diags diag.Diagnostics + + if license != nil { + d.License = types.StringValue(string(*license)) + } else { + d.License = types.StringNull() + } + + return diags +} + +// Helper function to update note from API response +func (d *SecurityDetectionRuleData) updateNoteFromApi(ctx context.Context, note *kbapi.SecurityDetectionsAPIInvestigationGuide) diag.Diagnostics { + var diags diag.Diagnostics + + if note != nil { + d.Note = types.StringValue(string(*note)) + } else { + d.Note = types.StringNull() + } + + return diags +} + +// Helper function to update setup from API response +func (d *SecurityDetectionRuleData) updateSetupFromApi(ctx context.Context, setup kbapi.SecurityDetectionsAPISetupGuide) diag.Diagnostics { + var diags diag.Diagnostics + + // Handle setup field - if empty, set to null to maintain consistency with optional schema + if string(setup) != "" { + d.Setup = types.StringValue(string(setup)) + } else { + d.Setup = types.StringNull() + } + + return diags +} + +// Helper function to update exceptions list from API response +func (d *SecurityDetectionRuleData) updateExceptionsListFromApi(ctx context.Context, exceptionsList []kbapi.SecurityDetectionsAPIRuleExceptionList) diag.Diagnostics { + var diags diag.Diagnostics + + if len(exceptionsList) > 0 { + exceptionsListValue, exceptionsListDiags := convertExceptionsListToModel(ctx, exceptionsList) + diags.Append(exceptionsListDiags...) + if !exceptionsListDiags.HasError() { + d.ExceptionsList = exceptionsListValue + } + } else { + d.ExceptionsList = types.ListNull(getExceptionsListElementType()) + } + + return diags +} + +// Helper function to update risk score mapping from API response +func (d *SecurityDetectionRuleData) updateRiskScoreMappingFromApi(ctx context.Context, riskScoreMapping kbapi.SecurityDetectionsAPIRiskScoreMapping) diag.Diagnostics { + var diags diag.Diagnostics + + if len(riskScoreMapping) > 0 { + riskScoreMappingValue, riskScoreMappingDiags := convertRiskScoreMappingToModel(ctx, riskScoreMapping) + diags.Append(riskScoreMappingDiags...) + if !riskScoreMappingDiags.HasError() { + d.RiskScoreMapping = riskScoreMappingValue + } + } else { + d.RiskScoreMapping = types.ListNull(getRiskScoreMappingElementType()) + } + + return diags +} + +// Helper function to update actions from API response +func (d *SecurityDetectionRuleData) updateActionsFromApi(ctx context.Context, actions []kbapi.SecurityDetectionsAPIRuleAction) diag.Diagnostics { + var diags diag.Diagnostics + + if len(actions) > 0 { + actionsListValue, actionDiags := convertActionsToModel(ctx, actions) + diags.Append(actionDiags...) + if !actionDiags.HasError() { + d.Actions = actionsListValue + } + } else { + d.Actions = types.ListNull(getActionElementType()) + } + + return diags +} + +func (d *SecurityDetectionRuleData) updateAlertSuppressionFromApi(ctx context.Context, apiSuppression *kbapi.SecurityDetectionsAPIAlertSuppression) diag.Diagnostics { + var diags diag.Diagnostics + + if apiSuppression == nil { + d.AlertSuppression = types.ObjectNull(getAlertSuppressionType()) + return diags + } + + model := AlertSuppressionModel{} + + // Convert group_by (required field according to API) + if len(apiSuppression.GroupBy) > 0 { + groupByList := make([]attr.Value, len(apiSuppression.GroupBy)) + for i, field := range apiSuppression.GroupBy { + groupByList[i] = types.StringValue(field) + } + model.GroupBy = types.ListValueMust(types.StringType, groupByList) + } else { + model.GroupBy = types.ListNull(types.StringType) + } + + // Convert duration (optional) + if apiSuppression.Duration != nil { + model.Duration = parseDurationFromApi(*apiSuppression.Duration) + } else { + model.Duration = customtypes.NewDurationNull() + } + + // Convert missing_fields_strategy (optional) + if apiSuppression.MissingFieldsStrategy != nil { + model.MissingFieldsStrategy = types.StringValue(string(*apiSuppression.MissingFieldsStrategy)) + } else { + model.MissingFieldsStrategy = types.StringNull() + } + + alertSuppressionObj, objDiags := types.ObjectValueFrom(ctx, getAlertSuppressionType(), model) + diags.Append(objDiags...) + + d.AlertSuppression = alertSuppressionObj + + return diags +} + +func (d *SecurityDetectionRuleData) updateThresholdAlertSuppressionFromApi(ctx context.Context, apiSuppression *kbapi.SecurityDetectionsAPIThresholdAlertSuppression) diag.Diagnostics { + var diags diag.Diagnostics + + if apiSuppression == nil { + d.AlertSuppression = types.ObjectNull(getAlertSuppressionType()) + return diags + } + + model := AlertSuppressionModel{} + + // Threshold alert suppression only has duration field, so we set group_by and missing_fields_strategy to null + model.GroupBy = types.ListNull(types.StringType) + model.MissingFieldsStrategy = types.StringNull() + + // Convert duration (always present in threshold alert suppression) + model.Duration = parseDurationFromApi(apiSuppression.Duration) + + alertSuppressionObj, objDiags := types.ObjectValueFrom(ctx, getAlertSuppressionType(), model) + diags.Append(objDiags...) + + d.AlertSuppression = alertSuppressionObj + + return diags +} + +// updateResponseActionsFromApi updates the ResponseActions field from API response +func (d *SecurityDetectionRuleData) updateResponseActionsFromApi(ctx context.Context, responseActions *[]kbapi.SecurityDetectionsAPIResponseAction) diag.Diagnostics { + var diags diag.Diagnostics + + if responseActions != nil && len(*responseActions) > 0 { + responseActionsValue, responseActionsDiags := convertResponseActionsToModel(ctx, responseActions) + diags.Append(responseActionsDiags...) + if !responseActionsDiags.HasError() { + d.ResponseActions = responseActionsValue + } + } else { + d.ResponseActions = types.ListNull(getResponseActionElementType()) + } + + return diags +} + +// Helper function to update investigation fields from API response +func (d *SecurityDetectionRuleData) updateInvestigationFieldsFromApi(ctx context.Context, investigationFields *kbapi.SecurityDetectionsAPIInvestigationFields) diag.Diagnostics { + var diags diag.Diagnostics + + investigationFieldsValue, investigationFieldsDiags := convertInvestigationFieldsToModel(ctx, investigationFields) + diags.Append(investigationFieldsDiags...) + if diags.HasError() { + return diags + } + d.InvestigationFields = investigationFieldsValue + + return diags +} + +// Helper function to update related integrations from API response +func (d *SecurityDetectionRuleData) updateRelatedIntegrationsFromApi(ctx context.Context, relatedIntegrations *kbapi.SecurityDetectionsAPIRelatedIntegrationArray) diag.Diagnostics { + var diags diag.Diagnostics + + if relatedIntegrations != nil && len(*relatedIntegrations) > 0 { + relatedIntegrationsValue, relatedIntegrationsDiags := convertRelatedIntegrationsToModel(ctx, relatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + if !relatedIntegrationsDiags.HasError() { + d.RelatedIntegrations = relatedIntegrationsValue + } + } else { + d.RelatedIntegrations = types.ListNull(getRelatedIntegrationElementType()) + } + + return diags +} + +// Helper function to update required fields from API response +func (d *SecurityDetectionRuleData) updateRequiredFieldsFromApi(ctx context.Context, requiredFields *kbapi.SecurityDetectionsAPIRequiredFieldArray) diag.Diagnostics { + var diags diag.Diagnostics + + if requiredFields != nil && len(*requiredFields) > 0 { + requiredFieldsValue, requiredFieldsDiags := convertRequiredFieldsToModel(ctx, requiredFields) + diags.Append(requiredFieldsDiags...) + if !requiredFieldsDiags.HasError() { + d.RequiredFields = requiredFieldsValue + } + } else { + d.RequiredFields = types.ListNull(getRequiredFieldElementType()) + } + + return diags +} + +// parseDurationFromApi converts an API duration to customtypes.Duration +func parseDurationFromApi(apiDuration kbapi.SecurityDetectionsAPIAlertSuppressionDuration) customtypes.Duration { + // Convert the API's Value + Unit format back to a duration string + durationStr := strconv.Itoa(apiDuration.Value) + string(apiDuration.Unit) + return customtypes.NewDurationValue(durationStr) +} diff --git a/internal/kibana/security_detection_rule/models_machine_learning.go b/internal/kibana/security_detection_rule/models_machine_learning.go new file mode 100644 index 000000000..f41b61282 --- /dev/null +++ b/internal/kibana/security_detection_rule/models_machine_learning.go @@ -0,0 +1,347 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type MachineLearningRuleProcessor struct{} + +func (m MachineLearningRuleProcessor) HandlesRuleType(t string) bool { + return t == "machine_learning" +} + +func (m MachineLearningRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return d.toMachineLearningRuleCreateProps(ctx, client) +} + +func (m MachineLearningRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return d.toMachineLearningRuleUpdateProps(ctx, client) +} + +func (m MachineLearningRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPIMachineLearningRule) + return ok +} + +func (m MachineLearningRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPIMachineLearningRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return d.updateFromMachineLearningRule(ctx, &value) +} + +func (m MachineLearningRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPIMachineLearningRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func (d SecurityDetectionRuleData) toMachineLearningRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + mlRule := kbapi.SecurityDetectionsAPIMachineLearningRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIMachineLearningRuleCreatePropsType("machine_learning"), + AnomalyThreshold: kbapi.SecurityDetectionsAPIAnomalyThreshold(d.AnomalyThreshold.ValueInt64()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // Set ML job ID(s) - can be single string or array + if utils.IsKnown(d.MachineLearningJobId) { + jobIds := utils.ListTypeAs[string](ctx, d.MachineLearningJobId, path.Root("machine_learning_job_id"), &diags) + if !diags.HasError() { + var mlJobId kbapi.SecurityDetectionsAPIMachineLearningJobId + err := mlJobId.FromSecurityDetectionsAPIMachineLearningJobId1(jobIds) + if err != nil { + diags.AddError("Error setting ML job IDs", err.Error()) + } else { + mlRule.MachineLearningJobId = mlJobId + } + } + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &mlRule.Actions, + ResponseActions: &mlRule.ResponseActions, + RuleId: &mlRule.RuleId, + Enabled: &mlRule.Enabled, + From: &mlRule.From, + To: &mlRule.To, + Interval: &mlRule.Interval, + Index: nil, // ML rules don't use index patterns + Author: &mlRule.Author, + Tags: &mlRule.Tags, + FalsePositives: &mlRule.FalsePositives, + References: &mlRule.References, + License: &mlRule.License, + Note: &mlRule.Note, + Setup: &mlRule.Setup, + MaxSignals: &mlRule.MaxSignals, + Version: &mlRule.Version, + ExceptionsList: &mlRule.ExceptionsList, + AlertSuppression: &mlRule.AlertSuppression, + RiskScoreMapping: &mlRule.RiskScoreMapping, + SeverityMapping: &mlRule.SeverityMapping, + RelatedIntegrations: &mlRule.RelatedIntegrations, + RequiredFields: &mlRule.RequiredFields, + BuildingBlockType: &mlRule.BuildingBlockType, + DataViewId: nil, // ML rules don't have DataViewId + Namespace: &mlRule.Namespace, + RuleNameOverride: &mlRule.RuleNameOverride, + TimestampOverride: &mlRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &mlRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &mlRule.InvestigationFields, + }, &diags, client) + + // ML rules don't use index patterns or query + + // Convert to union type + err := createProps.FromSecurityDetectionsAPIMachineLearningRuleCreateProps(mlRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert ML rule properties: "+err.Error(), + ) + } + + return createProps, diags +} +func (d SecurityDetectionRuleData) toMachineLearningRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + mlRule := kbapi.SecurityDetectionsAPIMachineLearningRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIMachineLearningRuleUpdatePropsType("machine_learning"), + AnomalyThreshold: kbapi.SecurityDetectionsAPIAnomalyThreshold(d.AnomalyThreshold.ValueInt64()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + mlRule.RuleId = &ruleId + mlRule.Id = nil // if rule_id is set, we cant send id + } + + // Set ML job ID(s) - can be single string or array + if utils.IsKnown(d.MachineLearningJobId) { + jobIds := utils.ListTypeAs[string](ctx, d.MachineLearningJobId, path.Root("machine_learning_job_id"), &diags) + if !diags.HasError() { + var mlJobId kbapi.SecurityDetectionsAPIMachineLearningJobId + err := mlJobId.FromSecurityDetectionsAPIMachineLearningJobId1(jobIds) + if err != nil { + diags.AddError("Error setting ML job IDs", err.Error()) + } else { + mlRule.MachineLearningJobId = mlJobId + } + } + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &mlRule.Actions, + ResponseActions: &mlRule.ResponseActions, + RuleId: &mlRule.RuleId, + Enabled: &mlRule.Enabled, + From: &mlRule.From, + To: &mlRule.To, + Interval: &mlRule.Interval, + Index: nil, // ML rules don't use index patterns + Author: &mlRule.Author, + Tags: &mlRule.Tags, + FalsePositives: &mlRule.FalsePositives, + References: &mlRule.References, + License: &mlRule.License, + Note: &mlRule.Note, + Setup: &mlRule.Setup, + MaxSignals: &mlRule.MaxSignals, + Version: &mlRule.Version, + ExceptionsList: &mlRule.ExceptionsList, + AlertSuppression: &mlRule.AlertSuppression, + RiskScoreMapping: &mlRule.RiskScoreMapping, + SeverityMapping: &mlRule.SeverityMapping, + RelatedIntegrations: &mlRule.RelatedIntegrations, + RequiredFields: &mlRule.RequiredFields, + BuildingBlockType: &mlRule.BuildingBlockType, + DataViewId: nil, // ML rules don't have DataViewId + Namespace: &mlRule.Namespace, + RuleNameOverride: &mlRule.RuleNameOverride, + TimestampOverride: &mlRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &mlRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &mlRule.InvestigationFields, + Filters: nil, // ML rules don't have Filters + }, &diags, client) + + // ML rules don't use index patterns or query + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPIMachineLearningRuleUpdateProps(mlRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert ML rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} + +func (d *SecurityDetectionRuleData) updateFromMachineLearningRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPIMachineLearningRule) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields (ML doesn't support DataViewId) + d.DataViewId = types.StringNull() + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + d.Enabled = types.BoolValue(bool(rule.Enabled)) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update read-only fields + d.CreatedAt = types.StringValue(rule.CreatedAt.Format("2006-01-02T15:04:05.000Z")) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = types.StringValue(rule.UpdatedAt.Format("2006-01-02T15:04:05.000Z")) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // ML rules don't use index patterns or query + d.Index = types.ListValueMust(types.StringType, []attr.Value{}) + d.Query = types.StringNull() + d.Language = types.StringNull() + + // ML-specific fields + d.AnomalyThreshold = types.Int64Value(int64(rule.AnomalyThreshold)) + + // Handle ML job ID(s) - can be single string or array + // Try to extract as single job ID first, then as array + if singleJobId, err := rule.MachineLearningJobId.AsSecurityDetectionsAPIMachineLearningJobId0(); err == nil { + // Single job ID + d.MachineLearningJobId = utils.ListValueFrom(ctx, []string{string(singleJobId)}, types.StringType, path.Root("machine_learning_job_id"), &diags) + } else if multipleJobIds, err := rule.MachineLearningJobId.AsSecurityDetectionsAPIMachineLearningJobId1(); err == nil { + // Multiple job IDs + jobIdStrings := make([]string, len(multipleJobIds)) + for i, jobId := range multipleJobIds { + jobIdStrings[i] = string(jobId) + } + d.MachineLearningJobId = utils.ListValueFrom(ctx, jobIdStrings, types.StringType, path.Root("machine_learning_job_id"), &diags) + } else { + d.MachineLearningJobId = types.ListValueMust(types.StringType, []attr.Value{}) + } + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_new_terms.go b/internal/kibana/security_detection_rule/models_new_terms.go new file mode 100644 index 000000000..0223f9d7d --- /dev/null +++ b/internal/kibana/security_detection_rule/models_new_terms.go @@ -0,0 +1,332 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type NewTermsRuleProcessor struct{} + +func (n NewTermsRuleProcessor) HandlesRuleType(t string) bool { + return t == "new_terms" +} + +func (n NewTermsRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return d.toNewTermsRuleCreateProps(ctx, client) +} + +func (n NewTermsRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return d.toNewTermsRuleUpdateProps(ctx, client) +} + +func (n NewTermsRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPINewTermsRule) + return ok +} + +func (n NewTermsRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPINewTermsRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return d.updateFromNewTermsRule(ctx, &value) +} + +func (n NewTermsRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPINewTermsRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func (d SecurityDetectionRuleData) toNewTermsRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + newTermsRule := kbapi.SecurityDetectionsAPINewTermsRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPINewTermsRuleCreatePropsType("new_terms"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + HistoryWindowStart: kbapi.SecurityDetectionsAPIHistoryWindowStart(d.HistoryWindowStart.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // Set new terms fields + if utils.IsKnown(d.NewTermsFields) { + newTermsFields := utils.ListTypeAs[string](ctx, d.NewTermsFields, path.Root("new_terms_fields"), &diags) + if !diags.HasError() { + newTermsRule.NewTermsFields = newTermsFields + } + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &newTermsRule.Actions, + ResponseActions: &newTermsRule.ResponseActions, + RuleId: &newTermsRule.RuleId, + Enabled: &newTermsRule.Enabled, + From: &newTermsRule.From, + To: &newTermsRule.To, + Interval: &newTermsRule.Interval, + Index: &newTermsRule.Index, + Author: &newTermsRule.Author, + Tags: &newTermsRule.Tags, + FalsePositives: &newTermsRule.FalsePositives, + References: &newTermsRule.References, + License: &newTermsRule.License, + Note: &newTermsRule.Note, + Setup: &newTermsRule.Setup, + MaxSignals: &newTermsRule.MaxSignals, + Version: &newTermsRule.Version, + ExceptionsList: &newTermsRule.ExceptionsList, + AlertSuppression: &newTermsRule.AlertSuppression, + RiskScoreMapping: &newTermsRule.RiskScoreMapping, + SeverityMapping: &newTermsRule.SeverityMapping, + RelatedIntegrations: &newTermsRule.RelatedIntegrations, + RequiredFields: &newTermsRule.RequiredFields, + BuildingBlockType: &newTermsRule.BuildingBlockType, + DataViewId: &newTermsRule.DataViewId, + Namespace: &newTermsRule.Namespace, + RuleNameOverride: &newTermsRule.RuleNameOverride, + TimestampOverride: &newTermsRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &newTermsRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &newTermsRule.InvestigationFields, + Filters: &newTermsRule.Filters, + }, &diags, client) + + // Set query language + newTermsRule.Language = d.getKQLQueryLanguage() + + // Convert to union type + err := createProps.FromSecurityDetectionsAPINewTermsRuleCreateProps(newTermsRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert new terms rule properties: "+err.Error(), + ) + } + + return createProps, diags +} +func (d SecurityDetectionRuleData) toNewTermsRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + newTermsRule := kbapi.SecurityDetectionsAPINewTermsRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPINewTermsRuleUpdatePropsType("new_terms"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + HistoryWindowStart: kbapi.SecurityDetectionsAPIHistoryWindowStart(d.HistoryWindowStart.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + newTermsRule.RuleId = &ruleId + newTermsRule.Id = nil // if rule_id is set, we cant send id + } + + // Set new terms fields + if utils.IsKnown(d.NewTermsFields) { + newTermsFields := utils.ListTypeAs[string](ctx, d.NewTermsFields, path.Root("new_terms_fields"), &diags) + if !diags.HasError() { + newTermsRule.NewTermsFields = newTermsFields + } + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &newTermsRule.Actions, + ResponseActions: &newTermsRule.ResponseActions, + RuleId: &newTermsRule.RuleId, + Enabled: &newTermsRule.Enabled, + From: &newTermsRule.From, + To: &newTermsRule.To, + Interval: &newTermsRule.Interval, + Index: &newTermsRule.Index, + Author: &newTermsRule.Author, + Tags: &newTermsRule.Tags, + FalsePositives: &newTermsRule.FalsePositives, + References: &newTermsRule.References, + License: &newTermsRule.License, + Note: &newTermsRule.Note, + InvestigationFields: &newTermsRule.InvestigationFields, + Setup: &newTermsRule.Setup, + MaxSignals: &newTermsRule.MaxSignals, + Version: &newTermsRule.Version, + ExceptionsList: &newTermsRule.ExceptionsList, + AlertSuppression: &newTermsRule.AlertSuppression, + RiskScoreMapping: &newTermsRule.RiskScoreMapping, + SeverityMapping: &newTermsRule.SeverityMapping, + RelatedIntegrations: &newTermsRule.RelatedIntegrations, + RequiredFields: &newTermsRule.RequiredFields, + BuildingBlockType: &newTermsRule.BuildingBlockType, + DataViewId: &newTermsRule.DataViewId, + Namespace: &newTermsRule.Namespace, + RuleNameOverride: &newTermsRule.RuleNameOverride, + TimestampOverride: &newTermsRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &newTermsRule.TimestampOverrideFallbackDisabled, + Filters: &newTermsRule.Filters, + }, &diags, client) + + // Set query language + newTermsRule.Language = d.getKQLQueryLanguage() + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPINewTermsRuleUpdateProps(newTermsRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert new terms rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} +func (d *SecurityDetectionRuleData) updateFromNewTermsRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPINewTermsRule) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields + diags.Append(d.updateDataViewIdFromApi(ctx, rule.DataViewId)...) + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + d.Query = types.StringValue(rule.Query) + d.Language = types.StringValue(string(rule.Language)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update read-only fields + d.CreatedAt = types.StringValue(rule.CreatedAt.Format("2006-01-02T15:04:05.000Z")) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = types.StringValue(rule.UpdatedAt.Format("2006-01-02T15:04:05.000Z")) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // Update index patterns + diags.Append(d.updateIndexFromApi(ctx, rule.Index)...) + + // New Terms-specific fields + d.HistoryWindowStart = types.StringValue(string(rule.HistoryWindowStart)) + if len(rule.NewTermsFields) > 0 { + d.NewTermsFields = utils.ListValueFrom(ctx, rule.NewTermsFields, types.StringType, path.Root("new_terms_fields"), &diags) + } else { + d.NewTermsFields = types.ListValueMust(types.StringType, []attr.Value{}) + } + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update filters field + filtersDiags := d.updateFiltersFromApi(ctx, rule.Filters) + diags.Append(filtersDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_query.go b/internal/kibana/security_detection_rule/models_query.go new file mode 100644 index 000000000..1f880a615 --- /dev/null +++ b/internal/kibana/security_detection_rule/models_query.go @@ -0,0 +1,338 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type QueryRuleProcessor struct{} + +func (q QueryRuleProcessor) HandlesRuleType(t string) bool { + return t == "query" +} + +func (q QueryRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return toQueryRuleCreateProps(ctx, client, d) +} + +func (q QueryRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return toQueryRuleUpdateProps(ctx, client, d) +} + +func (q QueryRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPIQueryRule) + return ok +} + +func (q QueryRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPIQueryRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return updateFromQueryRule(ctx, &value, d) +} + +func (q QueryRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPIQueryRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func toQueryRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + queryRuleQuery := kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()) + queryRule := kbapi.SecurityDetectionsAPIQueryRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIQueryRuleCreatePropsType("query"), + Query: &queryRuleQuery, + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &queryRule.Actions, + ResponseActions: &queryRule.ResponseActions, + RuleId: &queryRule.RuleId, + Enabled: &queryRule.Enabled, + From: &queryRule.From, + To: &queryRule.To, + Interval: &queryRule.Interval, + Index: &queryRule.Index, + Author: &queryRule.Author, + Tags: &queryRule.Tags, + FalsePositives: &queryRule.FalsePositives, + References: &queryRule.References, + License: &queryRule.License, + Note: &queryRule.Note, + Setup: &queryRule.Setup, + MaxSignals: &queryRule.MaxSignals, + Version: &queryRule.Version, + ExceptionsList: &queryRule.ExceptionsList, + AlertSuppression: &queryRule.AlertSuppression, + RiskScoreMapping: &queryRule.RiskScoreMapping, + SeverityMapping: &queryRule.SeverityMapping, + RelatedIntegrations: &queryRule.RelatedIntegrations, + RequiredFields: &queryRule.RequiredFields, + BuildingBlockType: &queryRule.BuildingBlockType, + DataViewId: &queryRule.DataViewId, + Namespace: &queryRule.Namespace, + RuleNameOverride: &queryRule.RuleNameOverride, + TimestampOverride: &queryRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &queryRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &queryRule.InvestigationFields, + Filters: &queryRule.Filters, + }, &diags, client) + + // Set query-specific fields + queryRule.Language = d.getKQLQueryLanguage() + + if utils.IsKnown(d.SavedId) { + savedId := kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()) + queryRule.SavedId = &savedId + } + + // Convert to union type + err := createProps.FromSecurityDetectionsAPIQueryRuleCreateProps(queryRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert query rule properties: "+err.Error(), + ) + } + + return createProps, diags +} + +func toQueryRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + queryRuleQuery := kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()) + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + queryRule := kbapi.SecurityDetectionsAPIQueryRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIQueryRuleUpdatePropsType("query"), + Query: &queryRuleQuery, + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + queryRule.RuleId = &ruleId + queryRule.Id = nil // if rule_id is set, we cant send id + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &queryRule.Actions, + ResponseActions: &queryRule.ResponseActions, + RuleId: &queryRule.RuleId, + Enabled: &queryRule.Enabled, + From: &queryRule.From, + To: &queryRule.To, + Interval: &queryRule.Interval, + Index: &queryRule.Index, + Author: &queryRule.Author, + Tags: &queryRule.Tags, + FalsePositives: &queryRule.FalsePositives, + References: &queryRule.References, + License: &queryRule.License, + Note: &queryRule.Note, + Setup: &queryRule.Setup, + MaxSignals: &queryRule.MaxSignals, + Version: &queryRule.Version, + ExceptionsList: &queryRule.ExceptionsList, + AlertSuppression: &queryRule.AlertSuppression, + RiskScoreMapping: &queryRule.RiskScoreMapping, + SeverityMapping: &queryRule.SeverityMapping, + RelatedIntegrations: &queryRule.RelatedIntegrations, + RequiredFields: &queryRule.RequiredFields, + BuildingBlockType: &queryRule.BuildingBlockType, + DataViewId: &queryRule.DataViewId, + Namespace: &queryRule.Namespace, + RuleNameOverride: &queryRule.RuleNameOverride, + TimestampOverride: &queryRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &queryRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &queryRule.InvestigationFields, + Filters: &queryRule.Filters, + }, &diags, client) + + // Set query-specific fields + queryRule.Language = d.getKQLQueryLanguage() + + if utils.IsKnown(d.SavedId) { + savedId := kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()) + queryRule.SavedId = &savedId + } + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPIQueryRuleUpdateProps(queryRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert query rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} +func updateFromQueryRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPIQueryRule, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields + dataViewIdDiags := d.updateDataViewIdFromApi(ctx, rule.DataViewId) + diags.Append(dataViewIdDiags...) + + namespaceDiags := d.updateNamespaceFromApi(ctx, rule.Namespace) + diags.Append(namespaceDiags...) + + ruleNameOverrideDiags := d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride) + diags.Append(ruleNameOverrideDiags...) + + timestampOverrideDiags := d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride) + diags.Append(timestampOverrideDiags...) + + timestampOverrideFallbackDisabledDiags := d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled) + diags.Append(timestampOverrideFallbackDisabledDiags...) + + d.Query = types.StringValue(rule.Query) + d.Language = types.StringValue(string(rule.Language)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update building block type + buildingBlockTypeDiags := d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType) + diags.Append(buildingBlockTypeDiags...) + + // Update read-only fields + d.CreatedAt = utils.TimeToStringValue(rule.CreatedAt) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = utils.TimeToStringValue(rule.UpdatedAt) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // Update index patterns + indexDiags := d.updateIndexFromApi(ctx, rule.Index) + diags.Append(indexDiags...) + + // Update author + authorDiags := d.updateAuthorFromApi(ctx, rule.Author) + diags.Append(authorDiags...) + + // Update tags + tagsDiags := d.updateTagsFromApi(ctx, rule.Tags) + diags.Append(tagsDiags...) + + // Update false positives + falsePositivesDiags := d.updateFalsePositivesFromApi(ctx, rule.FalsePositives) + diags.Append(falsePositivesDiags...) + + // Update references + referencesDiags := d.updateReferencesFromApi(ctx, rule.References) + diags.Append(referencesDiags...) + + // Update optional string fields + licenseDiags := d.updateLicenseFromApi(ctx, rule.License) + diags.Append(licenseDiags...) + + noteDiags := d.updateNoteFromApi(ctx, rule.Note) + diags.Append(noteDiags...) + + setupDiags := d.updateSetupFromApi(ctx, rule.Setup) + diags.Append(setupDiags...) + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update filters field + filtersDiags := d.updateFiltersFromApi(ctx, rule.Filters) + diags.Append(filtersDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_saved_query.go b/internal/kibana/security_detection_rule/models_saved_query.go new file mode 100644 index 000000000..55037531c --- /dev/null +++ b/internal/kibana/security_detection_rule/models_saved_query.go @@ -0,0 +1,322 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type SavedQueryRuleProcessor struct{} + +func (s SavedQueryRuleProcessor) HandlesRuleType(t string) bool { + return t == "saved_query" +} + +func (s SavedQueryRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return d.toSavedQueryRuleCreateProps(ctx, client) +} + +func (s SavedQueryRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return d.toSavedQueryRuleUpdateProps(ctx, client) +} + +func (s SavedQueryRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPISavedQueryRule) + return ok +} + +func (s SavedQueryRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPISavedQueryRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return d.updateFromSavedQueryRule(ctx, &value) +} + +func (s SavedQueryRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPISavedQueryRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func (d SecurityDetectionRuleData) toSavedQueryRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + savedQueryRule := kbapi.SecurityDetectionsAPISavedQueryRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPISavedQueryRuleCreatePropsType("saved_query"), + SavedId: kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &savedQueryRule.Actions, + ResponseActions: &savedQueryRule.ResponseActions, + RuleId: &savedQueryRule.RuleId, + Enabled: &savedQueryRule.Enabled, + From: &savedQueryRule.From, + To: &savedQueryRule.To, + Interval: &savedQueryRule.Interval, + Index: &savedQueryRule.Index, + Author: &savedQueryRule.Author, + Tags: &savedQueryRule.Tags, + FalsePositives: &savedQueryRule.FalsePositives, + References: &savedQueryRule.References, + License: &savedQueryRule.License, + Note: &savedQueryRule.Note, + Setup: &savedQueryRule.Setup, + MaxSignals: &savedQueryRule.MaxSignals, + Version: &savedQueryRule.Version, + ExceptionsList: &savedQueryRule.ExceptionsList, + AlertSuppression: &savedQueryRule.AlertSuppression, + RiskScoreMapping: &savedQueryRule.RiskScoreMapping, + SeverityMapping: &savedQueryRule.SeverityMapping, + RelatedIntegrations: &savedQueryRule.RelatedIntegrations, + RequiredFields: &savedQueryRule.RequiredFields, + BuildingBlockType: &savedQueryRule.BuildingBlockType, + DataViewId: &savedQueryRule.DataViewId, + Namespace: &savedQueryRule.Namespace, + RuleNameOverride: &savedQueryRule.RuleNameOverride, + TimestampOverride: &savedQueryRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &savedQueryRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &savedQueryRule.InvestigationFields, + Filters: &savedQueryRule.Filters, + }, &diags, client) + + // Set optional query for saved query rules + if utils.IsKnown(d.Query) { + query := kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()) + savedQueryRule.Query = &query + } + + // Set query language + savedQueryRule.Language = d.getKQLQueryLanguage() + + // Convert to union type + err := createProps.FromSecurityDetectionsAPISavedQueryRuleCreateProps(savedQueryRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert saved query rule properties: "+err.Error(), + ) + } + + return createProps, diags +} +func (d SecurityDetectionRuleData) toSavedQueryRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + savedQueryRule := kbapi.SecurityDetectionsAPISavedQueryRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPISavedQueryRuleUpdatePropsType("saved_query"), + SavedId: kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + savedQueryRule.RuleId = &ruleId + savedQueryRule.Id = nil // if rule_id is set, we cant send id + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &savedQueryRule.Actions, + ResponseActions: &savedQueryRule.ResponseActions, + RuleId: &savedQueryRule.RuleId, + Enabled: &savedQueryRule.Enabled, + From: &savedQueryRule.From, + To: &savedQueryRule.To, + Interval: &savedQueryRule.Interval, + Index: &savedQueryRule.Index, + Author: &savedQueryRule.Author, + Tags: &savedQueryRule.Tags, + FalsePositives: &savedQueryRule.FalsePositives, + References: &savedQueryRule.References, + License: &savedQueryRule.License, + Note: &savedQueryRule.Note, + InvestigationFields: &savedQueryRule.InvestigationFields, + Setup: &savedQueryRule.Setup, + MaxSignals: &savedQueryRule.MaxSignals, + Version: &savedQueryRule.Version, + ExceptionsList: &savedQueryRule.ExceptionsList, + AlertSuppression: &savedQueryRule.AlertSuppression, + RiskScoreMapping: &savedQueryRule.RiskScoreMapping, + SeverityMapping: &savedQueryRule.SeverityMapping, + RelatedIntegrations: &savedQueryRule.RelatedIntegrations, + RequiredFields: &savedQueryRule.RequiredFields, + BuildingBlockType: &savedQueryRule.BuildingBlockType, + DataViewId: &savedQueryRule.DataViewId, + Namespace: &savedQueryRule.Namespace, + RuleNameOverride: &savedQueryRule.RuleNameOverride, + TimestampOverride: &savedQueryRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &savedQueryRule.TimestampOverrideFallbackDisabled, + Filters: &savedQueryRule.Filters, + }, &diags, client) + + // Set optional query for saved query rules + if utils.IsKnown(d.Query) { + query := kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()) + savedQueryRule.Query = &query + } + + // Set query language + savedQueryRule.Language = d.getKQLQueryLanguage() + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPISavedQueryRuleUpdateProps(savedQueryRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert saved query rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} + +func (d *SecurityDetectionRuleData) updateFromSavedQueryRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPISavedQueryRule) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields + diags.Append(d.updateDataViewIdFromApi(ctx, rule.DataViewId)...) + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + d.SavedId = types.StringValue(string(rule.SavedId)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + d.From = types.StringValue(string(rule.From)) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update read-only fields + d.CreatedAt = types.StringValue(rule.CreatedAt.Format("2006-01-02T15:04:05.000Z")) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = types.StringValue(rule.UpdatedAt.Format("2006-01-02T15:04:05.000Z")) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // Update index patterns + diags.Append(d.updateIndexFromApi(ctx, rule.Index)...) + + // Optional query for saved query rules + d.Query = types.StringPointerValue(rule.Query) + + // Language for saved query rules (not a pointer) + d.Language = types.StringValue(string(rule.Language)) + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update filters field + filtersDiags := d.updateFiltersFromApi(ctx, rule.Filters) + diags.Append(filtersDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_test.go b/internal/kibana/security_detection_rule/models_test.go new file mode 100644 index 000000000..b0ebc16f1 --- /dev/null +++ b/internal/kibana/security_detection_rule/models_test.go @@ -0,0 +1,2261 @@ +package security_detection_rule + +import ( + "context" + "fmt" + "testing" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes" + "github.com/google/uuid" + "github.com/hashicorp/go-version" + "github.com/hashicorp/terraform-plugin-framework-jsontypes/jsontypes" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/hashicorp/terraform-plugin-framework/types/basetypes" + v2Diag "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/stretchr/testify/require" +) + +type mockApiClient struct { + serverVersion *version.Version + serverFlavor string + enforceResult bool +} + +func (m mockApiClient) EnforceMinVersion(ctx context.Context, minVersion *version.Version) (bool, v2Diag.Diagnostics) { + supported := m.serverVersion.GreaterThanOrEqual(minVersion) + return supported, nil +} + +// NewMockApiClient creates a new mock API client with default values that support response actions +// This can be used in tests where you need to pass a client to functions like toUpdateProps +func NewMockApiClient() clients.MinVersionEnforceable { + // Use version 8.16.0 by default to support response actions + v, _ := version.NewVersion("8.16.0") + + return mockApiClient{ + serverVersion: v, + serverFlavor: "default", + enforceResult: true, + } +} + +// NewMockApiClientWithVersion creates a mock API client with a specific version +// Use this when you need to test specific version behavior +func NewMockApiClientWithVersion(versionStr string) *mockApiClient { + v, err := version.NewVersion(versionStr) + if err != nil { + panic(fmt.Sprintf("Invalid version in test: %s", versionStr)) + } + return &mockApiClient{ + serverVersion: v, + serverFlavor: "default", + enforceResult: true, + } +} +func TestUpdateFromQueryRule(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + rule kbapi.SecurityDetectionsAPIQueryRule + spaceId string + expected SecurityDetectionRuleData + }{ + { + name: "complete query rule", + spaceId: "test-space", + rule: kbapi.SecurityDetectionsAPIQueryRule{ + Id: uuid.MustParse("12345678-1234-1234-1234-123456789012"), + RuleId: "test-rule-id", + Name: "Test Query Rule", + Type: "query", + Query: "user.name:test", + Language: "kuery", + Enabled: true, + From: "now-6m", + To: "now", + Interval: "5m", + Description: "Test description", + RiskScore: 75, + Severity: "medium", + MaxSignals: 100, + Version: 1, + Author: []string{"Test Author"}, + Tags: []string{"test", "detection"}, + Index: utils.Pointer([]string{"logs-*", "metrics-*"}), + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + FalsePositives: []string{"Known false positive"}, + References: []string{"https://example.com/test"}, + License: utils.Pointer(kbapi.SecurityDetectionsAPIRuleLicense("MIT")), + Note: utils.Pointer(kbapi.SecurityDetectionsAPIInvestigationGuide("Investigation note")), + Setup: "Setup instructions", + }, + expected: SecurityDetectionRuleData{ + Id: types.StringValue("test-space/12345678-1234-1234-1234-123456789012"), + SpaceId: types.StringValue("test-space"), + RuleId: types.StringValue("test-rule-id"), + Name: types.StringValue("Test Query Rule"), + Type: types.StringValue("query"), + Query: types.StringValue("user.name:test"), + Language: types.StringValue("kuery"), + Enabled: types.BoolValue(true), + From: types.StringValue("now-6m"), + To: types.StringValue("now"), + Interval: types.StringValue("5m"), + Description: types.StringValue("Test description"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + MaxSignals: types.Int64Value(100), + Version: types.Int64Value(1), + Author: utils.ListValueFrom(ctx, []string{"Test Author"}, types.StringType, path.Root("author"), &diags), + Tags: utils.ListValueFrom(ctx, []string{"test", "detection"}, types.StringType, path.Root("tags"), &diags), + Index: utils.ListValueFrom(ctx, []string{"logs-*", "metrics-*"}, types.StringType, path.Root("index"), &diags), + CreatedBy: types.StringValue("test-user"), + UpdatedBy: types.StringValue("test-user"), + Revision: types.Int64Value(1), + FalsePositives: utils.ListValueFrom(ctx, []string{"Known false positive"}, types.StringType, path.Root("false_positives"), &diags), + References: utils.ListValueFrom(ctx, []string{"https://example.com/test"}, types.StringType, path.Root("references"), &diags), + License: types.StringValue("MIT"), + Note: types.StringValue("Investigation note"), + Setup: types.StringValue("Setup instructions"), + }, + }, + { + name: "minimal query rule", + spaceId: "default", + rule: kbapi.SecurityDetectionsAPIQueryRule{ + Id: uuid.MustParse("87654321-4321-4321-4321-210987654321"), + RuleId: "minimal-rule", + Name: "Minimal Rule", + Type: "query", + Query: "*", + Language: "kuery", + Enabled: false, + From: "now-1h", + To: "now", + Interval: "1m", + Description: "Minimal test", + RiskScore: 1, + Severity: "low", + MaxSignals: 50, + Version: 1, + CreatedBy: "system", + UpdatedBy: "system", + Revision: 1, + }, + expected: SecurityDetectionRuleData{ + Id: types.StringValue("default/87654321-4321-4321-4321-210987654321"), + SpaceId: types.StringValue("default"), + RuleId: types.StringValue("minimal-rule"), + Name: types.StringValue("Minimal Rule"), + Type: types.StringValue("query"), + Query: types.StringValue("*"), + Language: types.StringValue("kuery"), + Enabled: types.BoolValue(false), + From: types.StringValue("now-1h"), + To: types.StringValue("now"), + Interval: types.StringValue("1m"), + Description: types.StringValue("Minimal test"), + RiskScore: types.Int64Value(1), + Severity: types.StringValue("low"), + MaxSignals: types.Int64Value(50), + Version: types.Int64Value(1), + CreatedBy: types.StringValue("system"), + UpdatedBy: types.StringValue("system"), + Revision: types.Int64Value(1), + Author: types.ListValueMust(types.StringType, []attr.Value{}), + Tags: types.ListValueMust(types.StringType, []attr.Value{}), + Index: types.ListValueMust(types.StringType, []attr.Value{}), + }, + }, + } + + require.Empty(t, diags) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := SecurityDetectionRuleData{ + SpaceId: types.StringValue(tt.spaceId), + } + + diags := updateFromQueryRule(ctx, &tt.rule, &data) + require.Empty(t, diags) + + // Compare key fields + require.Equal(t, tt.expected.Id, data.Id) + require.Equal(t, tt.expected.RuleId, data.RuleId) + require.Equal(t, tt.expected.Name, data.Name) + require.Equal(t, tt.expected.Type, data.Type) + require.Equal(t, tt.expected.Query, data.Query) + require.Equal(t, tt.expected.Language, data.Language) + require.Equal(t, tt.expected.Enabled, data.Enabled) + require.Equal(t, tt.expected.RiskScore, data.RiskScore) + require.Equal(t, tt.expected.Severity, data.Severity) + + // Verify list fields have correct length + require.Equal(t, len(tt.expected.Author.Elements()), len(data.Author.Elements())) + require.Equal(t, len(tt.expected.Tags.Elements()), len(data.Tags.Elements())) + require.Equal(t, len(tt.expected.Index.Elements()), len(data.Index.Elements())) + }) + } +} + +func TestToQueryRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + data SecurityDetectionRuleData + expectedName string + expectedType string + expectedQuery string + expectedRiskScore int64 + expectedSeverity string + shouldHaveLanguage bool + shouldHaveIndex bool + shouldHaveActions bool + shouldHaveRuleId bool + shouldError bool + }{ + { + name: "complete query rule create", + data: SecurityDetectionRuleData{ + Name: types.StringValue("Test Create Rule"), + Type: types.StringValue("query"), + Query: types.StringValue("process.name:malicious"), + Language: types.StringValue("kuery"), + RiskScore: types.Int64Value(85), + Severity: types.StringValue("high"), + Description: types.StringValue("Test rule description"), + Index: utils.ListValueFrom(ctx, []string{"winlogbeat-*"}, types.StringType, path.Root("index"), &diags), + Author: utils.ListValueFrom(ctx, []string{"Security Team"}, types.StringType, path.Root("author"), &diags), + Enabled: types.BoolValue(true), + RuleId: types.StringValue("custom-rule-id"), + }, + expectedName: "Test Create Rule", + expectedType: "query", + expectedQuery: "process.name:malicious", + expectedRiskScore: 85, + expectedSeverity: "high", + shouldHaveLanguage: true, + shouldHaveIndex: true, + shouldHaveRuleId: true, + }, + { + name: "minimal query rule create", + data: SecurityDetectionRuleData{ + Name: types.StringValue("Minimal Rule"), + Type: types.StringValue("query"), + Query: types.StringValue("*"), + RiskScore: types.Int64Value(1), + Severity: types.StringValue("low"), + Description: types.StringValue("Minimal description"), + }, + expectedName: "Minimal Rule", + expectedType: "query", + expectedQuery: "*", + expectedRiskScore: 1, + expectedSeverity: "low", + }, + } + + require.Empty(t, diags) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createProps, createDiags := toQueryRuleCreateProps(ctx, NewMockApiClient(), tt.data) + + if tt.shouldError { + require.NotEmpty(t, createDiags) + return + } + + require.Empty(t, createDiags) + + // Extract the concrete type from the union + queryRule, err := createProps.AsSecurityDetectionsAPIQueryRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, tt.expectedName, string(queryRule.Name)) + require.Equal(t, tt.expectedType, string(queryRule.Type)) + require.NotNil(t, queryRule.Query) + require.Equal(t, tt.expectedQuery, string(*queryRule.Query)) + require.Equal(t, tt.expectedRiskScore, int64(queryRule.RiskScore)) + require.Equal(t, tt.expectedSeverity, string(queryRule.Severity)) + + if tt.shouldHaveLanguage { + require.NotNil(t, queryRule.Language) + } + + if tt.shouldHaveIndex { + require.NotNil(t, queryRule.Index) + require.NotEmpty(t, *queryRule.Index) + } + + if tt.shouldHaveRuleId { + require.NotNil(t, queryRule.RuleId) + require.Equal(t, "custom-rule-id", string(*queryRule.RuleId)) + } + }) + } +} + +func TestToEqlRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Name: types.StringValue("EQL Test Rule"), + Type: types.StringValue("eql"), + Query: types.StringValue("process where process.name == \"cmd.exe\""), + RiskScore: types.Int64Value(60), + Severity: types.StringValue("medium"), + Description: types.StringValue("EQL rule description"), + TiebreakerField: types.StringValue("@timestamp"), + } + + createProps, createDiags := toEqlRuleCreateProps(ctx, NewMockApiClient(), data) + require.Empty(t, createDiags) + + eqlRule, err := createProps.AsSecurityDetectionsAPIEqlRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, "EQL Test Rule", string(eqlRule.Name)) + require.Equal(t, "eql", string(eqlRule.Type)) + require.Equal(t, "process where process.name == \"cmd.exe\"", string(eqlRule.Query)) + require.Equal(t, "eql", string(eqlRule.Language)) + require.Equal(t, int64(60), int64(eqlRule.RiskScore)) + require.Equal(t, "medium", string(eqlRule.Severity)) + + require.NotNil(t, eqlRule.TiebreakerField) + require.Equal(t, "@timestamp", string(*eqlRule.TiebreakerField)) + + require.Empty(t, diags) +} + +func TestToMachineLearningRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + data SecurityDetectionRuleData + expectedJobCount int + shouldHaveSingle bool + shouldHaveMultiple bool + }{ + { + name: "single ML job", + data: SecurityDetectionRuleData{ + Name: types.StringValue("ML Test Rule"), + Type: types.StringValue("machine_learning"), + RiskScore: types.Int64Value(70), + Severity: types.StringValue("high"), + Description: types.StringValue("ML rule description"), + AnomalyThreshold: types.Int64Value(50), + MachineLearningJobId: utils.ListValueFrom(ctx, []string{"suspicious_activity"}, types.StringType, path.Root("machine_learning_job_id"), &diags), + }, + expectedJobCount: 1, + shouldHaveMultiple: true, + }, + { + name: "multiple ML jobs", + data: SecurityDetectionRuleData{ + Name: types.StringValue("ML Multi Job Rule"), + Type: types.StringValue("machine_learning"), + RiskScore: types.Int64Value(80), + Severity: types.StringValue("critical"), + Description: types.StringValue("ML multi job rule"), + AnomalyThreshold: types.Int64Value(75), + MachineLearningJobId: utils.ListValueFrom(ctx, []string{"job1", "job2", "job3"}, types.StringType, path.Root("machine_learning_job_id"), &diags), + }, + expectedJobCount: 3, + shouldHaveMultiple: true, + }, + } + + require.Empty(t, diags) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createProps, createDiags := tt.data.toMachineLearningRuleCreateProps(ctx, NewMockApiClient()) + require.Empty(t, createDiags) + + mlRule, err := createProps.AsSecurityDetectionsAPIMachineLearningRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, tt.data.Name.ValueString(), string(mlRule.Name)) + require.Equal(t, "machine_learning", string(mlRule.Type)) + require.Equal(t, tt.data.AnomalyThreshold.ValueInt64(), int64(mlRule.AnomalyThreshold)) + + if tt.shouldHaveSingle { + ingleJobId, err := mlRule.MachineLearningJobId.AsSecurityDetectionsAPIMachineLearningJobId0() + require.NoError(t, err) + require.Equal(t, "suspicious_activity", string(ingleJobId)) + } + + if tt.shouldHaveMultiple { + multipleJobIds, err := mlRule.MachineLearningJobId.AsSecurityDetectionsAPIMachineLearningJobId1() + require.NoError(t, err) + require.Len(t, multipleJobIds, tt.expectedJobCount) + } + }) + } +} + +func TestToEsqlRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Type: types.StringValue("esql"), + Name: types.StringValue("Test ESQL Rule"), + Description: types.StringValue("Test ESQL rule description"), + Query: types.StringValue("FROM logs | WHERE user.name == \"suspicious_user\""), + RiskScore: types.Int64Value(85), + Severity: types.StringValue("high"), + Enabled: types.BoolValue(true), + From: types.StringValue("now-1h"), + To: types.StringValue("now"), + Interval: types.StringValue("10m"), + Author: utils.ListValueFrom(ctx, []string{"Security Team"}, types.StringType, path.Root("author"), &diags), + Tags: utils.ListValueFrom(ctx, []string{"esql", "test"}, types.StringType, path.Root("tags"), &diags), + } + + require.Empty(t, diags) + + createProps, createDiags := data.toEsqlRuleCreateProps(ctx, NewMockApiClient()) + require.Empty(t, createDiags) + + esqlRule, err := createProps.AsSecurityDetectionsAPIEsqlRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, "Test ESQL Rule", string(esqlRule.Name)) + require.Equal(t, "Test ESQL rule description", string(esqlRule.Description)) + require.Equal(t, "esql", string(esqlRule.Type)) + require.Equal(t, "FROM logs | WHERE user.name == \"suspicious_user\"", string(esqlRule.Query)) + require.Equal(t, "esql", string(esqlRule.Language)) + require.Equal(t, int64(85), int64(esqlRule.RiskScore)) + require.Equal(t, "high", string(esqlRule.Severity)) +} + +func TestToNewTermsRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Type: types.StringValue("new_terms"), + Name: types.StringValue("Test New Terms Rule"), + Description: types.StringValue("Test new terms rule description"), + Query: types.StringValue("user.name:*"), + Language: types.StringValue("kuery"), + NewTermsFields: utils.ListValueFrom(ctx, []string{"user.name", "host.name"}, types.StringType, path.Root("new_terms_fields"), &diags), + HistoryWindowStart: types.StringValue("now-7d"), + RiskScore: types.Int64Value(60), + Severity: types.StringValue("medium"), + Enabled: types.BoolValue(true), + From: types.StringValue("now-6m"), + To: types.StringValue("now"), + Interval: types.StringValue("5m"), + Index: utils.ListValueFrom(ctx, []string{"logs-*"}, types.StringType, path.Root("index"), &diags), + } + + require.Empty(t, diags) + + createProps, createDiags := data.toNewTermsRuleCreateProps(ctx, NewMockApiClient()) + require.Empty(t, createDiags) + + newTermsRule, err := createProps.AsSecurityDetectionsAPINewTermsRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, "Test New Terms Rule", string(newTermsRule.Name)) + require.Equal(t, "Test new terms rule description", string(newTermsRule.Description)) + require.Equal(t, "new_terms", string(newTermsRule.Type)) + require.Equal(t, "user.name:*", string(newTermsRule.Query)) + require.Equal(t, "now-7d", string(newTermsRule.HistoryWindowStart)) + require.Equal(t, int64(60), int64(newTermsRule.RiskScore)) + require.Equal(t, "medium", string(newTermsRule.Severity)) + require.Len(t, newTermsRule.NewTermsFields, 2) + require.Contains(t, newTermsRule.NewTermsFields, "user.name") + require.Contains(t, newTermsRule.NewTermsFields, "host.name") +} + +func TestToSavedQueryRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Type: types.StringValue("saved_query"), + Name: types.StringValue("Test Saved Query Rule"), + Description: types.StringValue("Test saved query rule description"), + SavedId: types.StringValue("my-saved-query-id"), + RiskScore: types.Int64Value(70), + Severity: types.StringValue("high"), + Enabled: types.BoolValue(true), + From: types.StringValue("now-30m"), + To: types.StringValue("now"), + Interval: types.StringValue("15m"), + Index: utils.ListValueFrom(ctx, []string{"auditbeat-*", "filebeat-*"}, types.StringType, path.Root("index"), &diags), + Author: utils.ListValueFrom(ctx, []string{"Security Team"}, types.StringType, path.Root("author"), &diags), + Tags: utils.ListValueFrom(ctx, []string{"saved-query", "detection"}, types.StringType, path.Root("tags"), &diags), + } + + require.Empty(t, diags) + + createProps, createDiags := data.toSavedQueryRuleCreateProps(ctx, NewMockApiClient()) + require.Empty(t, createDiags) + + savedQueryRule, err := createProps.AsSecurityDetectionsAPISavedQueryRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, "Test Saved Query Rule", string(savedQueryRule.Name)) + require.Equal(t, "Test saved query rule description", string(savedQueryRule.Description)) + require.Equal(t, "saved_query", string(savedQueryRule.Type)) + require.Equal(t, "my-saved-query-id", string(savedQueryRule.SavedId)) + require.Equal(t, int64(70), int64(savedQueryRule.RiskScore)) + require.Equal(t, "high", string(savedQueryRule.Severity)) +} + +func TestToThreatMatchRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Type: types.StringValue("threat_match"), + Name: types.StringValue("Test Threat Match Rule"), + Description: types.StringValue("Test threat match rule description"), + Query: types.StringValue("source.ip:*"), + Language: types.StringValue("kuery"), + ThreatIndex: utils.ListValueFrom(ctx, []string{"threat-intel-*"}, types.StringType, path.Root("threat_index"), &diags), + ThreatMapping: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItem{ + { + Entries: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItemEntry{ + { + Field: types.StringValue("source.ip"), + Type: types.StringValue("mapping"), + Value: types.StringValue("threat.indicator.ip"), + }, + }, getThreatMappingEntryElementType(), path.Root("threat_mapping").AtListIndex(0).AtName("entries"), &diags), + }, + }, getThreatMappingElementType(), path.Root("threat_mapping"), &diags), + RiskScore: types.Int64Value(90), + Severity: types.StringValue("critical"), + Enabled: types.BoolValue(true), + From: types.StringValue("now-1h"), + To: types.StringValue("now"), + Interval: types.StringValue("5m"), + Index: utils.ListValueFrom(ctx, []string{"logs-*"}, types.StringType, path.Root("index"), &diags), + } + + require.Empty(t, diags) + + createProps, createDiags := data.toThreatMatchRuleCreateProps(ctx, NewMockApiClient()) + require.Empty(t, createDiags) + + threatMatchRule, err := createProps.AsSecurityDetectionsAPIThreatMatchRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, "Test Threat Match Rule", string(threatMatchRule.Name)) + require.Equal(t, "Test threat match rule description", string(threatMatchRule.Description)) + require.Equal(t, "threat_match", string(threatMatchRule.Type)) + require.Equal(t, "source.ip:*", string(threatMatchRule.Query)) + require.Equal(t, int64(90), int64(threatMatchRule.RiskScore)) + require.Equal(t, "critical", string(threatMatchRule.Severity)) + require.Len(t, threatMatchRule.ThreatIndex, 1) + require.Equal(t, "threat-intel-*", threatMatchRule.ThreatIndex[0]) + require.Len(t, threatMatchRule.ThreatMapping, 1) +} + +func TestToThresholdRuleCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Type: types.StringValue("threshold"), + Name: types.StringValue("Test Threshold Rule"), + Description: types.StringValue("Test threshold rule description"), + Query: types.StringValue("event.action:login"), + Language: types.StringValue("kuery"), + Threshold: utils.ObjectValueFrom(ctx, &ThresholdModel{ + Field: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("threshold").AtName("field"), &diags), + Value: types.Int64Value(5), + Cardinality: types.ListNull(getCardinalityType()), + }, getThresholdType(), path.Root("threshold"), &diags), + RiskScore: types.Int64Value(80), + Severity: types.StringValue("high"), + Enabled: types.BoolValue(true), + From: types.StringValue("now-1h"), + To: types.StringValue("now"), + Interval: types.StringValue("5m"), + Index: utils.ListValueFrom(ctx, []string{"auditbeat-*"}, types.StringType, path.Root("index"), &diags), + } + + require.Empty(t, diags) + + createProps, createDiags := data.toThresholdRuleCreateProps(ctx, NewMockApiClient()) + require.Empty(t, createDiags) + + thresholdRule, err := createProps.AsSecurityDetectionsAPIThresholdRuleCreateProps() + require.NoError(t, err) + + require.Equal(t, "Test Threshold Rule", string(thresholdRule.Name)) + require.Equal(t, "Test threshold rule description", string(thresholdRule.Description)) + require.Equal(t, "threshold", string(thresholdRule.Type)) + require.Equal(t, "event.action:login", string(thresholdRule.Query)) + require.Equal(t, int64(80), int64(thresholdRule.RiskScore)) + require.Equal(t, "high", string(thresholdRule.Severity)) + + // Verify threshold configuration + require.NotNil(t, thresholdRule.Threshold) + require.Equal(t, int64(5), int64(thresholdRule.Threshold.Value)) + + // Check single field + singleField, err := thresholdRule.Threshold.Field.AsSecurityDetectionsAPIThresholdField0() + require.NoError(t, err) + require.Equal(t, "user.name", string(singleField)) +} + +func TestThresholdToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + data SecurityDetectionRuleData + expectedValue int64 + expectedFieldCount int + hasCardinality bool + }{ + { + name: "threshold with single field", + data: SecurityDetectionRuleData{ + Threshold: utils.ObjectValueFrom(ctx, &ThresholdModel{ + Field: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("threshold").AtName("field"), &diags), + Value: types.Int64Value(10), + Cardinality: types.ListNull(getCardinalityType()), + }, getThresholdType(), path.Root("threshold"), &diags), + }, + expectedValue: 10, + expectedFieldCount: 1, + }, + { + name: "threshold with multiple fields and cardinality", + data: SecurityDetectionRuleData{ + Threshold: utils.ObjectValueFrom(ctx, &ThresholdModel{ + Field: utils.ListValueFrom(ctx, []string{"user.name", "source.ip"}, types.StringType, path.Root("threshold").AtName("field"), &diags), + Value: types.Int64Value(5), + Cardinality: utils.ListValueFrom(ctx, []CardinalityModel{ + { + Field: types.StringValue("destination.ip"), + Value: types.Int64Value(2), + }, + }, getCardinalityType(), path.Root("threshold").AtName("cardinality"), &diags), + }, getThresholdType(), path.Root("threshold"), &diags), + }, + expectedValue: 5, + expectedFieldCount: 2, + hasCardinality: true, + }, + } + + require.Empty(t, diags) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + threshold := tt.data.thresholdToApi(ctx, &diags) + require.Empty(t, diags) + require.NotNil(t, threshold) + + require.Equal(t, tt.expectedValue, int64(threshold.Value)) + + // Check field count + if singleField, err := threshold.Field.AsSecurityDetectionsAPIThresholdField0(); err == nil { + require.Equal(t, 1, tt.expectedFieldCount) + require.NotEmpty(t, string(singleField)) + } else if multipleFields, err := threshold.Field.AsSecurityDetectionsAPIThresholdField1(); err == nil { + require.Equal(t, tt.expectedFieldCount, len(multipleFields)) + } + + if tt.hasCardinality { + require.NotNil(t, threshold.Cardinality) + require.NotEmpty(t, *threshold.Cardinality) + } + }) + } +} + +func TestAlertSuppressionToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + data SecurityDetectionRuleData + expectedGroupByCount int + hasDuration bool + hasMissingFieldsStrategy bool + }{ + { + name: "alert suppression with all fields", + data: SecurityDetectionRuleData{ + AlertSuppression: utils.ObjectValueFrom(ctx, &AlertSuppressionModel{ + GroupBy: utils.ListValueFrom(ctx, []string{"user.name", "source.ip"}, types.StringType, path.Root("alert_suppression").AtName("group_by"), &diags), + Duration: customtypes.NewDurationValue("10m"), + MissingFieldsStrategy: types.StringValue("suppress"), + }, getAlertSuppressionType(), path.Root("alert_suppression"), &diags), + }, + expectedGroupByCount: 2, + hasDuration: true, + hasMissingFieldsStrategy: true, + }, + { + name: "alert suppression minimal", + data: SecurityDetectionRuleData{ + AlertSuppression: utils.ObjectValueFrom(ctx, &AlertSuppressionModel{ + GroupBy: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("alert_suppression").AtName("group_by"), &diags), + Duration: customtypes.NewDurationNull(), + MissingFieldsStrategy: types.StringNull(), + }, getAlertSuppressionType(), path.Root("alert_suppression"), &diags), + }, + expectedGroupByCount: 1, + }, + } + + require.Empty(t, diags) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + alertSuppression := tt.data.alertSuppressionToApi(ctx, &diags) + require.Empty(t, diags) + require.NotNil(t, alertSuppression) + + require.Equal(t, tt.expectedGroupByCount, len(alertSuppression.GroupBy)) + + if tt.hasDuration { + require.NotNil(t, alertSuppression.Duration) + require.Equal(t, 10, alertSuppression.Duration.Value) + require.Equal(t, "m", string(alertSuppression.Duration.Unit)) + } + + if tt.hasMissingFieldsStrategy { + require.NotNil(t, alertSuppression.MissingFieldsStrategy) + require.Equal(t, "suppress", string(*alertSuppression.MissingFieldsStrategy)) + } + }) + } +} + +func TestThreatMappingToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + ThreatMapping: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItem{ + { + Entries: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItemEntry{ + { + Field: types.StringValue("source.ip"), + Type: types.StringValue("mapping"), + Value: types.StringValue("threat.indicator.ip"), + }, + { + Field: types.StringValue("user.name"), + Type: types.StringValue("mapping"), + Value: types.StringValue("threat.indicator.user.name"), + }, + }, getThreatMappingEntryElementType(), path.Root("threat_mapping").AtListIndex(0).AtName("entries"), &diags), + }, + }, getThreatMappingElementType(), path.Root("threat_mapping"), &diags), + } + + require.Empty(t, diags) + + threatMapping, threatMappingDiags := data.threatMappingToApi(ctx) + require.Empty(t, threatMappingDiags) + require.NotNil(t, threatMapping) + require.Len(t, threatMapping, 1) + + mapping := threatMapping[0] + require.Len(t, mapping.Entries, 2) + + require.Equal(t, "source.ip", string(mapping.Entries[0].Field)) + require.Equal(t, "mapping", string(mapping.Entries[0].Type)) + require.Equal(t, "threat.indicator.ip", string(mapping.Entries[0].Value)) + + require.Equal(t, "user.name", string(mapping.Entries[1].Field)) + require.Equal(t, "mapping", string(mapping.Entries[1].Type)) + require.Equal(t, "threat.indicator.user.name", string(mapping.Entries[1].Value)) +} + +func TestActionsToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + Actions: utils.ListValueFrom(ctx, []ActionModel{ + { + ActionTypeId: types.StringValue(".slack"), + Id: types.StringValue("slack-action-1"), + Params: utils.MapValueFrom(ctx, map[string]attr.Value{ + "message": types.StringValue("Alert triggered"), + "channel": types.StringValue("#security"), + }, types.StringType, path.Root("actions").AtListIndex(0).AtName("params"), &diags), + Group: types.StringValue("default"), + Uuid: types.StringNull(), + AlertsFilter: utils.MapValueFrom(ctx, map[string]attr.Value{ + "status": types.StringValue("open"), + "severity": types.StringValue("high"), + }, types.StringType, path.Root("actions").AtListIndex(0).AtName("alerts_filter"), &diags), + Frequency: utils.ObjectValueFrom(ctx, &ActionFrequencyModel{ + NotifyWhen: types.StringValue("onActionGroupChange"), + Summary: types.BoolValue(false), + Throttle: types.StringValue("1h"), + }, getActionFrequencyType(), path.Root("actions").AtListIndex(0).AtName("frequency"), &diags), + }, + }, getActionElementType(), path.Root("actions"), &diags), + } + + require.Empty(t, diags) + + actions, actionsDiags := data.actionsToApi(ctx) + require.Empty(t, actionsDiags) + require.Len(t, actions, 1) + + action := actions[0] + require.Equal(t, ".slack", action.ActionTypeId) + require.Equal(t, "slack-action-1", string(action.Id)) + require.NotNil(t, action.Params) + require.Contains(t, action.Params, "message") + require.Equal(t, "Alert triggered", action.Params["message"]) + require.NotNil(t, action.Group) + require.Equal(t, "default", string(*action.Group)) + require.NotNil(t, action.Frequency) +} + +func TestFiltersToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + filtersJSON := `[{"query": {"match": {"field": "value"}}}, {"range": {"timestamp": {"gte": "now-1h"}}}]` + + data := SecurityDetectionRuleData{ + Filters: jsontypes.NewNormalizedValue(filtersJSON), + } + + // Test filters conversion + filters, filtersDiags := data.filtersToApi(ctx) + require.Empty(t, filtersDiags) + require.NotNil(t, filters) + require.Len(t, *filters, 2) + + require.Empty(t, diags) +} + +func TestConvertActionsToModel(t *testing.T) { + ctx := context.Background() + + apiActions := []kbapi.SecurityDetectionsAPIRuleAction{ + { + ActionTypeId: ".email", + Id: "email-action-1", + Params: kbapi.SecurityDetectionsAPIRuleActionParams{ + "to": []string{"admin@example.com"}, + "subject": "Security Alert", + "message": "Alert details here", + }, + Group: utils.Pointer(kbapi.SecurityDetectionsAPIRuleActionGroup("default")), + Uuid: utils.Pointer(kbapi.SecurityDetectionsAPINonEmptyString("action-uuid-123")), + }, + } + + actionsList, diags := convertActionsToModel(ctx, apiActions) + require.Empty(t, diags) + require.False(t, actionsList.IsNull()) + + var actions []ActionModel + elemDiags := actionsList.ElementsAs(ctx, &actions, false) + require.Empty(t, elemDiags) + require.Len(t, actions, 1) + + action := actions[0] + require.Equal(t, ".email", action.ActionTypeId.ValueString()) + require.Equal(t, "email-action-1", action.Id.ValueString()) + require.Equal(t, "default", action.Group.ValueString()) + require.Equal(t, "action-uuid-123", action.Uuid.ValueString()) +} + +func TestUpdateFromRule_UnsupportedType(t *testing.T) { + ctx := context.Background() + data := &SecurityDetectionRuleData{} + + // Create a mock response that will fail to determine discriminator + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + + diags := data.updateFromRule(ctx, response) + require.NotEmpty(t, diags) + require.True(t, diags.HasError()) +} + +func TestUpdateFromRule(t *testing.T) { + ctx := context.Background() + testUUID := uuid.MustParse("12345678-1234-1234-1234-123456789012") + spaceId := "test-space" + + tests := []struct { + name string + setupRule func() *kbapi.SecurityDetectionsAPIRuleResponse + expectError bool + errorMessage string + validateData func(t *testing.T, data *SecurityDetectionRuleData) + }{ + { + name: "query rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + rule := kbapi.SecurityDetectionsAPIQueryRule{ + Id: testUUID, + RuleId: "test-query-rule", + Name: "Test Query Rule", + Type: "query", + Query: "user.name:test", + Language: "kuery", + Enabled: true, + RiskScore: 75, + Severity: "medium", + Version: 1, + Description: "Test query rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err := response.FromSecurityDetectionsAPIQueryRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-query-rule", data.RuleId.ValueString()) + require.Equal(t, "Test Query Rule", data.Name.ValueString()) + require.Equal(t, "query", data.Type.ValueString()) + require.Equal(t, "user.name:test", data.Query.ValueString()) + require.Equal(t, "kuery", data.Language.ValueString()) + require.Equal(t, true, data.Enabled.ValueBool()) + require.Equal(t, int64(75), data.RiskScore.ValueInt64()) + require.Equal(t, "medium", data.Severity.ValueString()) + }, + }, + { + name: "eql rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + rule := kbapi.SecurityDetectionsAPIEqlRule{ + Id: testUUID, + RuleId: "test-eql-rule", + Name: "Test EQL Rule", + Type: "eql", + Query: "process where process.name == \"cmd.exe\"", + Language: "eql", + Enabled: true, + RiskScore: 80, + Severity: "high", + Version: 1, + Description: "Test EQL rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err := response.FromSecurityDetectionsAPIEqlRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-eql-rule", data.RuleId.ValueString()) + require.Equal(t, "Test EQL Rule", data.Name.ValueString()) + require.Equal(t, "eql", data.Type.ValueString()) + require.Equal(t, "process where process.name == \"cmd.exe\"", data.Query.ValueString()) + require.Equal(t, "eql", data.Language.ValueString()) + require.Equal(t, int64(80), data.RiskScore.ValueInt64()) + require.Equal(t, "high", data.Severity.ValueString()) + }, + }, + { + name: "esql rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + rule := kbapi.SecurityDetectionsAPIEsqlRule{ + Id: testUUID, + RuleId: "test-esql-rule", + Name: "Test ESQL Rule", + Type: "esql", + Query: "FROM logs | WHERE user.name == \"suspicious_user\"", + Language: "esql", + Enabled: true, + RiskScore: 85, + Severity: "high", + Version: 1, + Description: "Test ESQL rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err := response.FromSecurityDetectionsAPIEsqlRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-esql-rule", data.RuleId.ValueString()) + require.Equal(t, "Test ESQL Rule", data.Name.ValueString()) + require.Equal(t, "esql", data.Type.ValueString()) + require.Equal(t, "FROM logs | WHERE user.name == \"suspicious_user\"", data.Query.ValueString()) + require.Equal(t, "esql", data.Language.ValueString()) + require.Equal(t, int64(85), data.RiskScore.ValueInt64()) + require.Equal(t, "high", data.Severity.ValueString()) + }, + }, + { + name: "machine_learning rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + mlJobId := kbapi.SecurityDetectionsAPIMachineLearningJobId{} + err := mlJobId.FromSecurityDetectionsAPIMachineLearningJobId0("suspicious_activity") + require.NoError(t, err) + + rule := kbapi.SecurityDetectionsAPIMachineLearningRule{ + Id: testUUID, + RuleId: "test-ml-rule", + Name: "Test ML Rule", + Type: "machine_learning", + MachineLearningJobId: mlJobId, + AnomalyThreshold: 50, + Enabled: true, + RiskScore: 70, + Severity: "medium", + Version: 1, + Description: "Test ML rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err = response.FromSecurityDetectionsAPIMachineLearningRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-ml-rule", data.RuleId.ValueString()) + require.Equal(t, "Test ML Rule", data.Name.ValueString()) + require.Equal(t, "machine_learning", data.Type.ValueString()) + require.Equal(t, int64(50), data.AnomalyThreshold.ValueInt64()) + require.Equal(t, int64(70), data.RiskScore.ValueInt64()) + require.Equal(t, "medium", data.Severity.ValueString()) + require.Len(t, data.MachineLearningJobId.Elements(), 1) + }, + }, + { + name: "new_terms rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + rule := kbapi.SecurityDetectionsAPINewTermsRule{ + Id: testUUID, + RuleId: "test-new-terms-rule", + Name: "Test New Terms Rule", + Type: "new_terms", + Query: "user.name:*", + Language: "kuery", + NewTermsFields: []string{"user.name", "host.name"}, + HistoryWindowStart: "now-7d", + Enabled: true, + RiskScore: 60, + Severity: "medium", + Version: 1, + Description: "Test new terms rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err := response.FromSecurityDetectionsAPINewTermsRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-new-terms-rule", data.RuleId.ValueString()) + require.Equal(t, "Test New Terms Rule", data.Name.ValueString()) + require.Equal(t, "new_terms", data.Type.ValueString()) + require.Equal(t, "user.name:*", data.Query.ValueString()) + require.Equal(t, "now-7d", data.HistoryWindowStart.ValueString()) + require.Equal(t, int64(60), data.RiskScore.ValueInt64()) + require.Equal(t, "medium", data.Severity.ValueString()) + require.Len(t, data.NewTermsFields.Elements(), 2) + }, + }, + { + name: "saved_query rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + rule := kbapi.SecurityDetectionsAPISavedQueryRule{ + Id: testUUID, + RuleId: "test-saved-query-rule", + Name: "Test Saved Query Rule", + Type: "saved_query", + SavedId: "my-saved-query-id", + Enabled: true, + RiskScore: 65, + Severity: "medium", + Version: 1, + Description: "Test saved query rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err := response.FromSecurityDetectionsAPISavedQueryRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-saved-query-rule", data.RuleId.ValueString()) + require.Equal(t, "Test Saved Query Rule", data.Name.ValueString()) + require.Equal(t, "saved_query", data.Type.ValueString()) + require.Equal(t, "my-saved-query-id", data.SavedId.ValueString()) + require.Equal(t, int64(65), data.RiskScore.ValueInt64()) + require.Equal(t, "medium", data.Severity.ValueString()) + }, + }, + { + name: "threat_match rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + rule := kbapi.SecurityDetectionsAPIThreatMatchRule{ + Id: testUUID, + RuleId: "test-threat-match-rule", + Name: "Test Threat Match Rule", + Type: "threat_match", + Query: "source.ip:*", + Language: "kuery", + ThreatIndex: []string{"threat-intel-*"}, + ThreatMapping: kbapi.SecurityDetectionsAPIThreatMapping{ + { + Entries: []kbapi.SecurityDetectionsAPIThreatMappingEntry{ + { + Field: "source.ip", + Type: "mapping", + Value: "threat.indicator.ip", + }, + }, + }, + }, + Enabled: true, + RiskScore: 90, + Severity: "critical", + Version: 1, + Description: "Test threat match rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err := response.FromSecurityDetectionsAPIThreatMatchRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-threat-match-rule", data.RuleId.ValueString()) + require.Equal(t, "Test Threat Match Rule", data.Name.ValueString()) + require.Equal(t, "threat_match", data.Type.ValueString()) + require.Equal(t, "source.ip:*", data.Query.ValueString()) + require.Equal(t, int64(90), data.RiskScore.ValueInt64()) + require.Equal(t, "critical", data.Severity.ValueString()) + require.Len(t, data.ThreatIndex.Elements(), 1) + require.Len(t, data.ThreatMapping.Elements(), 1) + }, + }, + { + name: "threshold rule type", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + thresholdField := kbapi.SecurityDetectionsAPIThresholdField{} + err := thresholdField.FromSecurityDetectionsAPIThresholdField0("user.name") + require.NoError(t, err) + + rule := kbapi.SecurityDetectionsAPIThresholdRule{ + Id: testUUID, + RuleId: "test-threshold-rule", + Name: "Test Threshold Rule", + Type: "threshold", + Query: "event.action:login", + Language: "kuery", + Threshold: kbapi.SecurityDetectionsAPIThreshold{ + Field: thresholdField, + Value: 5, + }, + Enabled: true, + RiskScore: 75, + Severity: "high", + Version: 1, + Description: "Test threshold rule description", + From: "now-6m", + To: "now", + Interval: "5m", + CreatedBy: "test-user", + UpdatedBy: "test-user", + Revision: 1, + } + response := &kbapi.SecurityDetectionsAPIRuleResponse{} + err = response.FromSecurityDetectionsAPIThresholdRule(rule) + require.NoError(t, err) + return response + }, + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + require.Equal(t, fmt.Sprintf("%s/%s", spaceId, testUUID.String()), data.Id.ValueString()) + require.Equal(t, "test-threshold-rule", data.RuleId.ValueString()) + require.Equal(t, "Test Threshold Rule", data.Name.ValueString()) + require.Equal(t, "threshold", data.Type.ValueString()) + require.Equal(t, "event.action:login", data.Query.ValueString()) + require.Equal(t, int64(75), data.RiskScore.ValueInt64()) + require.Equal(t, "high", data.Severity.ValueString()) + require.False(t, data.Threshold.IsNull()) + }, + }, + { + name: "discriminator error", + setupRule: func() *kbapi.SecurityDetectionsAPIRuleResponse { + // Create an empty response that will fail discriminator check + return &kbapi.SecurityDetectionsAPIRuleResponse{} + }, + expectError: true, + errorMessage: "Error determining rule processor", + validateData: func(t *testing.T, data *SecurityDetectionRuleData) { + // No validation needed for error case + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := &SecurityDetectionRuleData{ + SpaceId: types.StringValue(spaceId), + } + + response := tt.setupRule() + diags := data.updateFromRule(ctx, response) + + if tt.expectError { + require.True(t, diags.HasError()) + require.Contains(t, diags.Errors()[0].Summary(), tt.errorMessage) + } else { + require.Empty(t, diags) + tt.validateData(t, data) + } + }) + } +} + +func TestCompositeIdOperations(t *testing.T) { + tests := []struct { + name string + inputId string + expectedSpaceId string + expectedResourceId string + shouldError bool + }{ + { + name: "valid composite id", + inputId: "my-space/12345678-1234-1234-1234-123456789012", + expectedSpaceId: "my-space", + expectedResourceId: "12345678-1234-1234-1234-123456789012", + }, + { + name: "invalid composite id format", + inputId: "invalid-format", + shouldError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := SecurityDetectionRuleData{ + Id: types.StringValue(tt.inputId), + } + + compId, diags := clients.CompositeIdFromStrFw(data.Id.ValueString()) + + if tt.shouldError { + require.NotEmpty(t, diags) + return + } + + require.Empty(t, diags) + require.Equal(t, tt.expectedSpaceId, compId.ClusterId) + require.Equal(t, tt.expectedResourceId, compId.ResourceId) + }) + } +} + +func TestResponseActionsToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + data SecurityDetectionRuleData + actionType string + shouldError bool + }{ + { + name: "osquery response action", + data: SecurityDetectionRuleData{ + ResponseActions: utils.ListValueFrom(ctx, []ResponseActionModel{ + { + ActionTypeId: types.StringValue(".osquery"), + Params: utils.ObjectValueFrom(ctx, &ResponseActionParamsModel{ + Query: types.StringValue("SELECT * FROM processes"), + Timeout: types.Int64Value(300), + EcsMapping: types.MapNull(types.StringType), + Queries: types.ListNull(getOsqueryQueryElementType()), + PackId: types.StringNull(), + SavedQueryId: types.StringNull(), + Command: types.StringNull(), + Comment: types.StringNull(), + Config: types.ObjectNull(getEndpointProcessConfigType()), + }, getResponseActionParamsType(), path.Root("response_actions").AtListIndex(0).AtName("params"), &diags), + }, + }, getResponseActionElementType(), path.Root("response_actions"), &diags), + }, + actionType: ".osquery", + }, + { + name: "endpoint response action - isolate", + data: SecurityDetectionRuleData{ + ResponseActions: utils.ListValueFrom(ctx, []ResponseActionModel{ + { + ActionTypeId: types.StringValue(".endpoint"), + Params: utils.ObjectValueFrom(ctx, &ResponseActionParamsModel{ + Command: types.StringValue("isolate"), + Comment: types.StringValue("Isolating suspicious host"), + Config: types.ObjectNull(getEndpointProcessConfigType()), + Query: types.StringNull(), + PackId: types.StringNull(), + SavedQueryId: types.StringNull(), + Timeout: types.Int64Null(), + EcsMapping: types.MapNull(types.StringType), + Queries: types.ListNull(getOsqueryQueryElementType()), + }, getResponseActionParamsType(), path.Root("response_actions").AtListIndex(0).AtName("params"), &diags), + }, + }, getResponseActionElementType(), path.Root("response_actions"), &diags), + }, + actionType: ".endpoint", + }, + { + name: "unsupported response action type", + data: SecurityDetectionRuleData{ + ResponseActions: utils.ListValueFrom(ctx, []ResponseActionModel{ + { + ActionTypeId: types.StringValue(".unsupported"), + Params: utils.ObjectValueFrom(ctx, &ResponseActionParamsModel{ + Query: types.StringNull(), + PackId: types.StringNull(), + SavedQueryId: types.StringNull(), + Timeout: types.Int64Null(), + EcsMapping: types.MapNull(types.StringType), + Queries: types.ListNull(getOsqueryQueryElementType()), + Command: types.StringValue("unknown"), + Comment: types.StringNull(), + Config: types.ObjectNull(getEndpointProcessConfigType()), + }, getResponseActionParamsType(), path.Root("response_actions").AtListIndex(0).AtName("params"), &diags), + }, + }, getResponseActionElementType(), path.Root("response_actions"), &diags), + }, + actionType: ".unsupported", + shouldError: true, + }, + } + + require.Empty(t, diags) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + responseActions, responseActionsDiags := tt.data.responseActionsToApi(ctx, NewMockApiClient()) + + if tt.shouldError { + require.NotEmpty(t, responseActionsDiags) + return + } + + require.Empty(t, responseActionsDiags) + require.Len(t, responseActions, 1) + + // Verify the action type by checking discriminator + _, err := responseActions[0].ValueByDiscriminator() + require.NoError(t, err) + }) + } +} + +func TestResponseActionsToApiVersionCheck(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + // Test data with response actions + data := SecurityDetectionRuleData{ + ResponseActions: utils.ListValueFrom(ctx, []ResponseActionModel{ + { + ActionTypeId: types.StringValue(".osquery"), + Params: utils.ObjectValueFrom(ctx, &ResponseActionParamsModel{ + Query: types.StringValue("SELECT * FROM processes"), + Timeout: types.Int64Value(300), + EcsMapping: types.MapNull(types.StringType), + Queries: types.ListNull(getOsqueryQueryElementType()), + PackId: types.StringNull(), + SavedQueryId: types.StringNull(), + Command: types.StringNull(), + Comment: types.StringNull(), + Config: types.ObjectNull(getEndpointProcessConfigType()), + }, getResponseActionParamsType(), path.Root("response_actions").AtListIndex(0).AtName("params"), &diags), + }, + }, getResponseActionElementType(), path.Root("response_actions"), &diags), + } + + require.Empty(t, diags) + + responseActions, responseActionsDiags := data.responseActionsToApi(ctx, NewMockApiClient()) + + // Should work with the test client and return response actions + require.Empty(t, responseActionsDiags) + require.Len(t, responseActions, 1) + + // Verify the action type + actionValue, err := responseActions[0].ValueByDiscriminator() + require.NoError(t, err) + + // Verify it's an osquery action + osqueryAction, ok := actionValue.(kbapi.SecurityDetectionsAPIOsqueryResponseAction) + require.True(t, ok, "Expected osquery action") + require.Equal(t, kbapi.SecurityDetectionsAPIOsqueryResponseActionActionTypeId(".osquery"), osqueryAction.ActionTypeId) +} + +func TestKQLQueryLanguage(t *testing.T) { + tests := []struct { + name string + language types.String + expected *kbapi.SecurityDetectionsAPIKqlQueryLanguage + }{ + { + name: "kuery language", + language: types.StringValue("kuery"), + expected: utils.Pointer(kbapi.SecurityDetectionsAPIKqlQueryLanguage("kuery")), + }, + { + name: "lucene language", + language: types.StringValue("lucene"), + expected: utils.Pointer(kbapi.SecurityDetectionsAPIKqlQueryLanguage("lucene")), + }, + { + name: "unknown language defaults to kuery", + language: types.StringValue("unknown"), + expected: utils.Pointer(kbapi.SecurityDetectionsAPIKqlQueryLanguage("kuery")), + }, + { + name: "null language returns nil", + language: types.StringNull(), + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := SecurityDetectionRuleData{ + Language: tt.language, + } + + result := data.getKQLQueryLanguage() + + if tt.expected == nil { + require.Nil(t, result) + } else { + require.NotNil(t, result) + require.Equal(t, *tt.expected, *result) + } + }) + } +} + +func TestExceptionsListToApi(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + data := SecurityDetectionRuleData{ + ExceptionsList: utils.ListValueFrom(ctx, []ExceptionsListModel{ + { + Id: types.StringValue("exception-1"), + ListId: types.StringValue("trusted-processes"), + NamespaceType: types.StringValue("single"), + Type: types.StringValue("detection"), + }, + { + Id: types.StringValue("exception-2"), + ListId: types.StringValue("allow-list"), + NamespaceType: types.StringValue("agnostic"), + Type: types.StringValue("endpoint"), + }, + }, getExceptionsListElementType(), path.Root("exceptions_list"), &diags), + } + + require.Empty(t, diags) + + exceptionsList, exceptionsListDiags := data.exceptionsListToApi(ctx) + require.Empty(t, exceptionsListDiags) + require.Len(t, exceptionsList, 2) + + require.Equal(t, "exception-1", exceptionsList[0].Id) + require.Equal(t, "trusted-processes", exceptionsList[0].ListId) + require.Equal(t, "single", string(exceptionsList[0].NamespaceType)) + require.Equal(t, "detection", string(exceptionsList[0].Type)) + + require.Equal(t, "exception-2", exceptionsList[1].Id) + require.Equal(t, "allow-list", exceptionsList[1].ListId) + require.Equal(t, "agnostic", string(exceptionsList[1].NamespaceType)) + require.Equal(t, "endpoint", string(exceptionsList[1].Type)) +} + +func TestConvertThresholdToModel(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + apiThreshold kbapi.SecurityDetectionsAPIThreshold + expectedValue int64 + expectedFieldCount int + hasCardinality bool + }{ + { + name: "threshold with single field", + apiThreshold: func() kbapi.SecurityDetectionsAPIThreshold { + threshold := kbapi.SecurityDetectionsAPIThreshold{ + Value: 5, + } + err := threshold.Field.FromSecurityDetectionsAPIThresholdField0("user.name") + require.NoError(t, err) + return threshold + }(), + expectedValue: 5, + expectedFieldCount: 1, + }, + { + name: "threshold with multiple fields and cardinality", + apiThreshold: func() kbapi.SecurityDetectionsAPIThreshold { + threshold := kbapi.SecurityDetectionsAPIThreshold{ + Value: 10, + Cardinality: &kbapi.SecurityDetectionsAPIThresholdCardinality{ + {Field: "source.ip", Value: 3}, + }, + } + err := threshold.Field.FromSecurityDetectionsAPIThresholdField1([]string{"user.name", "process.name"}) + require.NoError(t, err) + return threshold + }(), + expectedValue: 10, + expectedFieldCount: 2, + hasCardinality: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + thresholdObj, diags := convertThresholdToModel(ctx, tt.apiThreshold) + require.Empty(t, diags) + require.False(t, thresholdObj.IsNull()) + + var thresholdModel ThresholdModel + objDiags := thresholdObj.As(ctx, &thresholdModel, basetypes.ObjectAsOptions{}) + require.Empty(t, objDiags) + + require.Equal(t, tt.expectedValue, thresholdModel.Value.ValueInt64()) + require.Equal(t, tt.expectedFieldCount, len(thresholdModel.Field.Elements())) + + if tt.hasCardinality { + require.False(t, thresholdModel.Cardinality.IsNull()) + require.NotEmpty(t, thresholdModel.Cardinality.Elements()) + } + }) + } +} + +func TestToCreateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + tests := []struct { + name string + ruleType string + shouldError bool + errorMsg string + setupData func() SecurityDetectionRuleData + }{ + { + name: "query rule type", + ruleType: "query", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("query"), + Name: types.StringValue("Test Query Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("user.name:test"), + Language: types.StringValue("kuery"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "eql rule type", + ruleType: "eql", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("eql"), + Name: types.StringValue("Test EQL Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("process where process.name == \"cmd.exe\""), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "esql rule type", + ruleType: "esql", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("esql"), + Name: types.StringValue("Test ESQL Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("FROM logs | WHERE user.name == \"suspicious_user\""), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "machine_learning rule type", + ruleType: "machine_learning", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("machine_learning"), + Name: types.StringValue("Test ML Rule"), + Description: types.StringValue("Test description"), + AnomalyThreshold: types.Int64Value(50), + MachineLearningJobId: utils.ListValueFrom(ctx, []string{"suspicious_activity"}, types.StringType, path.Root("machine_learning_job_id"), &diags), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "new_terms rule type", + ruleType: "new_terms", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("new_terms"), + Name: types.StringValue("Test New Terms Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("user.name:*"), + NewTermsFields: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("new_terms_fields"), &diags), + HistoryWindowStart: types.StringValue("now-7d"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "saved_query rule type", + ruleType: "saved_query", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("saved_query"), + Name: types.StringValue("Test Saved Query Rule"), + Description: types.StringValue("Test description"), + SavedId: types.StringValue("my-saved-query"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "threat_match rule type", + ruleType: "threat_match", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("threat_match"), + Name: types.StringValue("Test Threat Match Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("source.ip:*"), + ThreatIndex: utils.ListValueFrom(ctx, []string{"threat-intel-*"}, types.StringType, path.Root("threat_index"), &diags), + ThreatMapping: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItem{ + { + Entries: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItemEntry{ + { + Field: types.StringValue("source.ip"), + Type: types.StringValue("mapping"), + Value: types.StringValue("threat.indicator.ip"), + }, + }, getThreatMappingEntryElementType(), path.Root("threat_mapping").AtListIndex(0).AtName("entries"), &diags), + }, + }, getThreatMappingElementType(), path.Root("threat_mapping"), &diags), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "threshold rule type", + ruleType: "threshold", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("threshold"), + Name: types.StringValue("Test Threshold Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("event.action:login"), + Threshold: utils.ObjectValueFrom(ctx, &ThresholdModel{ + Field: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("threshold").AtName("field"), &diags), + Value: types.Int64Value(5), + Cardinality: types.ListNull(getCardinalityType()), + }, getThresholdType(), path.Root("threshold"), &diags), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "unsupported rule type", + ruleType: "unsupported_type", + shouldError: true, + errorMsg: "Rule type 'unsupported_type' is not supported", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Type: types.StringValue("unsupported_type"), + Name: types.StringValue("Test Unsupported Rule"), + Description: types.StringValue("Test description"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + } + + require.Empty(t, diags) // Check for any setup errors + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := tt.setupData() + + createProps, createDiags := data.toCreateProps(ctx, NewMockApiClient()) + + if tt.shouldError { + require.True(t, createDiags.HasError()) + require.Contains(t, createDiags.Errors()[0].Summary(), "Unsupported rule type") + require.Contains(t, createDiags.Errors()[0].Detail(), tt.errorMsg) + return + } + + require.Empty(t, createDiags) + + // Verify that the create props can be converted to the expected rule type and check values + switch tt.ruleType { + case "query": + queryRule, err := createProps.AsSecurityDetectionsAPIQueryRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test Query Rule", string(queryRule.Name)) + require.Equal(t, "Test description", string(queryRule.Description)) + require.Equal(t, "query", string(queryRule.Type)) + require.Equal(t, "user.name:test", string(*queryRule.Query)) + require.Equal(t, "kuery", string(*queryRule.Language)) + require.Equal(t, int64(75), int64(queryRule.RiskScore)) + require.Equal(t, "medium", string(queryRule.Severity)) + case "eql": + eqlRule, err := createProps.AsSecurityDetectionsAPIEqlRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test EQL Rule", string(eqlRule.Name)) + require.Equal(t, "Test description", string(eqlRule.Description)) + require.Equal(t, "eql", string(eqlRule.Type)) + require.Equal(t, "process where process.name == \"cmd.exe\"", string(eqlRule.Query)) + require.Equal(t, "eql", string(eqlRule.Language)) + require.Equal(t, int64(75), int64(eqlRule.RiskScore)) + require.Equal(t, "medium", string(eqlRule.Severity)) + case "esql": + esqlRule, err := createProps.AsSecurityDetectionsAPIEsqlRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test ESQL Rule", string(esqlRule.Name)) + require.Equal(t, "Test description", string(esqlRule.Description)) + require.Equal(t, "esql", string(esqlRule.Type)) + require.Equal(t, "FROM logs | WHERE user.name == \"suspicious_user\"", string(esqlRule.Query)) + require.Equal(t, "esql", string(esqlRule.Language)) + require.Equal(t, int64(75), int64(esqlRule.RiskScore)) + require.Equal(t, "medium", string(esqlRule.Severity)) + case "machine_learning": + mlRule, err := createProps.AsSecurityDetectionsAPIMachineLearningRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test ML Rule", string(mlRule.Name)) + require.Equal(t, "Test description", string(mlRule.Description)) + require.Equal(t, "machine_learning", string(mlRule.Type)) + require.Equal(t, int64(50), int64(mlRule.AnomalyThreshold)) + require.Equal(t, int64(75), int64(mlRule.RiskScore)) + require.Equal(t, "medium", string(mlRule.Severity)) + // Verify ML job ID is set correctly + jobId, err := mlRule.MachineLearningJobId.AsSecurityDetectionsAPIMachineLearningJobId1() + require.NoError(t, err) + require.Equal(t, []string{"suspicious_activity"}, jobId) + case "new_terms": + newTermsRule, err := createProps.AsSecurityDetectionsAPINewTermsRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test New Terms Rule", string(newTermsRule.Name)) + require.Equal(t, "Test description", string(newTermsRule.Description)) + require.Equal(t, "new_terms", string(newTermsRule.Type)) + require.Equal(t, "user.name:*", string(newTermsRule.Query)) + require.Equal(t, "now-7d", string(newTermsRule.HistoryWindowStart)) + require.Equal(t, int64(75), int64(newTermsRule.RiskScore)) + require.Equal(t, "medium", string(newTermsRule.Severity)) + require.Len(t, newTermsRule.NewTermsFields, 1) + require.Equal(t, "user.name", newTermsRule.NewTermsFields[0]) + case "saved_query": + savedQueryRule, err := createProps.AsSecurityDetectionsAPISavedQueryRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test Saved Query Rule", string(savedQueryRule.Name)) + require.Equal(t, "Test description", string(savedQueryRule.Description)) + require.Equal(t, "saved_query", string(savedQueryRule.Type)) + require.Equal(t, "my-saved-query", string(savedQueryRule.SavedId)) + require.Equal(t, int64(75), int64(savedQueryRule.RiskScore)) + require.Equal(t, "medium", string(savedQueryRule.Severity)) + case "threat_match": + threatMatchRule, err := createProps.AsSecurityDetectionsAPIThreatMatchRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test Threat Match Rule", string(threatMatchRule.Name)) + require.Equal(t, "Test description", string(threatMatchRule.Description)) + require.Equal(t, "threat_match", string(threatMatchRule.Type)) + require.Equal(t, "source.ip:*", string(threatMatchRule.Query)) + require.Equal(t, int64(75), int64(threatMatchRule.RiskScore)) + require.Equal(t, "medium", string(threatMatchRule.Severity)) + require.Len(t, threatMatchRule.ThreatIndex, 1) + require.Equal(t, "threat-intel-*", threatMatchRule.ThreatIndex[0]) + require.Len(t, threatMatchRule.ThreatMapping, 1) + case "threshold": + thresholdRule, err := createProps.AsSecurityDetectionsAPIThresholdRuleCreateProps() + require.NoError(t, err) + require.Equal(t, "Test Threshold Rule", string(thresholdRule.Name)) + require.Equal(t, "Test description", string(thresholdRule.Description)) + require.Equal(t, "threshold", string(thresholdRule.Type)) + require.Equal(t, "event.action:login", string(thresholdRule.Query)) + require.Equal(t, int64(75), int64(thresholdRule.RiskScore)) + require.Equal(t, "medium", string(thresholdRule.Severity)) + require.NotNil(t, thresholdRule.Threshold) + require.Equal(t, int64(5), int64(thresholdRule.Threshold.Value)) + // Check single field + singleField, err := thresholdRule.Threshold.Field.AsSecurityDetectionsAPIThresholdField0() + require.NoError(t, err) + require.Equal(t, "user.name", string(singleField)) + } + }) + } +} + +func TestToUpdateProps(t *testing.T) { + ctx := context.Background() + var diags diag.Diagnostics + + // Create a valid composite ID for testing + testUUID := uuid.New() + testSpaceId := "test-space" + validCompositeId := fmt.Sprintf("%s/%s", testSpaceId, testUUID.String()) + + tests := []struct { + name string + ruleType string + shouldError bool + errorMsg string + setupData func() SecurityDetectionRuleData + }{ + { + name: "query rule type", + ruleType: "query", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("query"), + Name: types.StringValue("Test Query Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("user.name:test"), + Language: types.StringValue("kuery"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "eql rule type", + ruleType: "eql", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("eql"), + Name: types.StringValue("Test EQL Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("process where process.name == \"cmd.exe\""), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "esql rule type", + ruleType: "esql", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("esql"), + Name: types.StringValue("Test ESQL Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("FROM logs | WHERE user.name == \"suspicious_user\""), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "machine_learning rule type", + ruleType: "machine_learning", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("machine_learning"), + Name: types.StringValue("Test ML Rule"), + Description: types.StringValue("Test description"), + AnomalyThreshold: types.Int64Value(50), + MachineLearningJobId: utils.ListValueFrom(ctx, []string{"suspicious_activity"}, types.StringType, path.Root("machine_learning_job_id"), &diags), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "new_terms rule type", + ruleType: "new_terms", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("new_terms"), + Name: types.StringValue("Test New Terms Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("user.name:*"), + NewTermsFields: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("new_terms_fields"), &diags), + HistoryWindowStart: types.StringValue("now-7d"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "saved_query rule type", + ruleType: "saved_query", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("saved_query"), + Name: types.StringValue("Test Saved Query Rule"), + Description: types.StringValue("Test description"), + SavedId: types.StringValue("my-saved-query"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "threat_match rule type", + ruleType: "threat_match", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("threat_match"), + Name: types.StringValue("Test Threat Match Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("source.ip:*"), + ThreatIndex: utils.ListValueFrom(ctx, []string{"threat-intel-*"}, types.StringType, path.Root("threat_index"), &diags), + ThreatMapping: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItem{ + { + Entries: utils.ListValueFrom(ctx, []SecurityDetectionRuleTfDataItemEntry{ + { + Field: types.StringValue("source.ip"), + Type: types.StringValue("mapping"), + Value: types.StringValue("threat.indicator.ip"), + }, + }, getThreatMappingEntryElementType(), path.Root("threat_mapping").AtListIndex(0).AtName("entries"), &diags), + }, + }, getThreatMappingElementType(), path.Root("threat_mapping"), &diags), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "threshold rule type", + ruleType: "threshold", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("threshold"), + Name: types.StringValue("Test Threshold Rule"), + Description: types.StringValue("Test description"), + Query: types.StringValue("event.action:login"), + Threshold: utils.ObjectValueFrom(ctx, &ThresholdModel{ + Field: utils.ListValueFrom(ctx, []string{"user.name"}, types.StringType, path.Root("threshold").AtName("field"), &diags), + Value: types.Int64Value(5), + Cardinality: types.ListNull(getCardinalityType()), + }, getThresholdType(), path.Root("threshold"), &diags), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + { + name: "unsupported rule type", + ruleType: "unsupported_type", + shouldError: true, + errorMsg: "Rule type 'unsupported_type' is not supported", + setupData: func() SecurityDetectionRuleData { + return SecurityDetectionRuleData{ + Id: types.StringValue(validCompositeId), + Type: types.StringValue("unsupported_type"), + Name: types.StringValue("Test Unsupported Rule"), + Description: types.StringValue("Test description"), + RiskScore: types.Int64Value(75), + Severity: types.StringValue("medium"), + } + }, + }, + } + + require.Empty(t, diags) // Check for any setup errors + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := tt.setupData() + + updateProps, updateDiags := data.toUpdateProps(ctx, NewMockApiClient()) + + if tt.shouldError { + require.True(t, updateDiags.HasError()) + require.Contains(t, updateDiags.Errors()[0].Summary(), "Unsupported rule type") + require.Contains(t, updateDiags.Errors()[0].Detail(), tt.errorMsg) + return + } + + require.Empty(t, updateDiags) + + // Verify that the update props can be converted to the expected rule type and check values + switch tt.ruleType { + case "query": + queryRule, err := updateProps.AsSecurityDetectionsAPIQueryRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test Query Rule", string(queryRule.Name)) + require.Equal(t, "Test description", string(queryRule.Description)) + require.Equal(t, "user.name:test", string(*queryRule.Query)) + require.Equal(t, "kuery", string(*queryRule.Language)) + require.Equal(t, int64(75), int64(queryRule.RiskScore)) + require.Equal(t, "medium", string(queryRule.Severity)) + case "eql": + eqlRule, err := updateProps.AsSecurityDetectionsAPIEqlRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test EQL Rule", string(eqlRule.Name)) + require.Equal(t, "Test description", string(eqlRule.Description)) + require.Equal(t, "process where process.name == \"cmd.exe\"", string(eqlRule.Query)) + require.Equal(t, int64(75), int64(eqlRule.RiskScore)) + require.Equal(t, "medium", string(eqlRule.Severity)) + case "esql": + esqlRule, err := updateProps.AsSecurityDetectionsAPIEsqlRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test ESQL Rule", string(esqlRule.Name)) + require.Equal(t, "Test description", string(esqlRule.Description)) + require.Equal(t, "FROM logs | WHERE user.name == \"suspicious_user\"", string(esqlRule.Query)) + require.Equal(t, int64(75), int64(esqlRule.RiskScore)) + require.Equal(t, "medium", string(esqlRule.Severity)) + case "machine_learning": + mlRule, err := updateProps.AsSecurityDetectionsAPIMachineLearningRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test ML Rule", string(mlRule.Name)) + require.Equal(t, "Test description", string(mlRule.Description)) + require.Equal(t, int64(50), int64(mlRule.AnomalyThreshold)) + require.Equal(t, int64(75), int64(mlRule.RiskScore)) + require.Equal(t, "medium", string(mlRule.Severity)) + // Verify ML job ID is set correctly + jobId, err := mlRule.MachineLearningJobId.AsSecurityDetectionsAPIMachineLearningJobId1() + require.NoError(t, err) + require.Equal(t, []string{"suspicious_activity"}, jobId) + case "new_terms": + newTermsRule, err := updateProps.AsSecurityDetectionsAPINewTermsRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test New Terms Rule", string(newTermsRule.Name)) + require.Equal(t, "Test description", string(newTermsRule.Description)) + require.Equal(t, "user.name:*", string(newTermsRule.Query)) + require.Equal(t, "now-7d", string(newTermsRule.HistoryWindowStart)) + require.Equal(t, int64(75), int64(newTermsRule.RiskScore)) + require.Equal(t, "medium", string(newTermsRule.Severity)) + require.Len(t, newTermsRule.NewTermsFields, 1) + require.Equal(t, "user.name", newTermsRule.NewTermsFields[0]) + case "saved_query": + savedQueryRule, err := updateProps.AsSecurityDetectionsAPISavedQueryRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test Saved Query Rule", string(savedQueryRule.Name)) + require.Equal(t, "Test description", string(savedQueryRule.Description)) + require.Equal(t, "my-saved-query", string(savedQueryRule.SavedId)) + require.Equal(t, int64(75), int64(savedQueryRule.RiskScore)) + require.Equal(t, "medium", string(savedQueryRule.Severity)) + case "threat_match": + threatMatchRule, err := updateProps.AsSecurityDetectionsAPIThreatMatchRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test Threat Match Rule", string(threatMatchRule.Name)) + require.Equal(t, "Test description", string(threatMatchRule.Description)) + require.Equal(t, "source.ip:*", string(threatMatchRule.Query)) + require.Equal(t, int64(75), int64(threatMatchRule.RiskScore)) + require.Equal(t, "medium", string(threatMatchRule.Severity)) + require.Len(t, threatMatchRule.ThreatIndex, 1) + require.Equal(t, "threat-intel-*", threatMatchRule.ThreatIndex[0]) + require.Len(t, threatMatchRule.ThreatMapping, 1) + case "threshold": + thresholdRule, err := updateProps.AsSecurityDetectionsAPIThresholdRuleUpdateProps() + require.NoError(t, err) + require.Equal(t, "Test Threshold Rule", string(thresholdRule.Name)) + require.Equal(t, "Test description", string(thresholdRule.Description)) + require.Equal(t, "event.action:login", string(thresholdRule.Query)) + require.Equal(t, int64(75), int64(thresholdRule.RiskScore)) + require.Equal(t, "medium", string(thresholdRule.Severity)) + require.NotNil(t, thresholdRule.Threshold) + require.Equal(t, int64(5), int64(thresholdRule.Threshold.Value)) + // Check single field + singleField, err := thresholdRule.Threshold.Field.AsSecurityDetectionsAPIThresholdField0() + require.NoError(t, err) + require.Equal(t, "user.name", string(singleField)) + } + }) + } +} + +func TestParseDurationToApi(t *testing.T) { + tests := []struct { + name string + duration customtypes.Duration + expectedVal int + expectedUnit kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnit + expectError bool + }{ + { + name: "valid seconds", + duration: customtypes.NewDurationValue("30s"), + expectedVal: 30, + expectedUnit: kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitS, + expectError: false, + }, + { + name: "valid minutes", + duration: customtypes.NewDurationValue("5m"), + expectedVal: 5, + expectedUnit: kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitM, + expectError: false, + }, + { + name: "valid hours", + duration: customtypes.NewDurationValue("2h"), + expectedVal: 2, + expectedUnit: kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitH, + expectError: false, + }, + { + name: "valid days converted to hours", + duration: customtypes.NewDurationValue("1d"), + expectedVal: 24, + expectedUnit: kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitH, + expectError: false, + }, + { + name: "multiple days converted to hours", + duration: customtypes.NewDurationValue("3d"), + expectedVal: 72, + expectedUnit: kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitH, + expectError: false, + }, + { + name: "invalid format - no unit", + duration: customtypes.NewDurationValue("30"), + expectError: true, + }, + { + name: "invalid format - non-numeric value", + duration: customtypes.NewDurationValue("ABCs"), + expectError: true, + }, + { + name: "invalid format - unsupported unit", + duration: customtypes.NewDurationValue("30w"), + expectError: true, + }, + { + name: "invalid format - empty string", + duration: customtypes.NewDurationValue(""), + expectError: true, + }, + { + name: "null duration", + duration: customtypes.NewDurationNull(), + expectError: true, + }, + { + name: "unknown duration", + duration: customtypes.NewDurationUnknown(), + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, diags := parseDurationToApi(tt.duration) + + if tt.expectError { + require.True(t, diags.HasError(), "Expected error but got none") + return + } + + require.False(t, diags.HasError(), "Unexpected error: %v", diags) + require.Equal(t, tt.expectedVal, result.Value) + require.Equal(t, tt.expectedUnit, result.Unit) + }) + } +} diff --git a/internal/kibana/security_detection_rule/models_threat_match.go b/internal/kibana/security_detection_rule/models_threat_match.go new file mode 100644 index 000000000..f0c73b330 --- /dev/null +++ b/internal/kibana/security_detection_rule/models_threat_match.go @@ -0,0 +1,430 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type ThreatMatchRuleProcessor struct{} + +func (t ThreatMatchRuleProcessor) HandlesRuleType(ruleType string) bool { + return ruleType == "threat_match" +} + +func (t ThreatMatchRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return d.toThreatMatchRuleCreateProps(ctx, client) +} + +func (t ThreatMatchRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return d.toThreatMatchRuleUpdateProps(ctx, client) +} + +func (t ThreatMatchRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPIThreatMatchRule) + return ok +} + +func (t ThreatMatchRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPIThreatMatchRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return d.updateFromThreatMatchRule(ctx, &value) +} + +func (t ThreatMatchRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPIThreatMatchRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func (d SecurityDetectionRuleData) toThreatMatchRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + threatMatchRule := kbapi.SecurityDetectionsAPIThreatMatchRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIThreatMatchRuleCreatePropsType("threat_match"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // Set threat index + if utils.IsKnown(d.ThreatIndex) { + threatIndex := utils.ListTypeAs[string](ctx, d.ThreatIndex, path.Root("threat_index"), &diags) + if !diags.HasError() { + threatMatchRule.ThreatIndex = threatIndex + } + } + + if utils.IsKnown(d.ThreatMapping) && len(d.ThreatMapping.Elements()) > 0 { + apiThreatMapping, threatMappingDiags := d.threatMappingToApi(ctx) + if !threatMappingDiags.HasError() { + threatMatchRule.ThreatMapping = apiThreatMapping + } + diags.Append(threatMappingDiags...) + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &threatMatchRule.Actions, + ResponseActions: &threatMatchRule.ResponseActions, + RuleId: &threatMatchRule.RuleId, + Enabled: &threatMatchRule.Enabled, + From: &threatMatchRule.From, + To: &threatMatchRule.To, + Interval: &threatMatchRule.Interval, + Index: &threatMatchRule.Index, + Author: &threatMatchRule.Author, + Tags: &threatMatchRule.Tags, + FalsePositives: &threatMatchRule.FalsePositives, + References: &threatMatchRule.References, + License: &threatMatchRule.License, + Note: &threatMatchRule.Note, + Setup: &threatMatchRule.Setup, + MaxSignals: &threatMatchRule.MaxSignals, + Version: &threatMatchRule.Version, + ExceptionsList: &threatMatchRule.ExceptionsList, + AlertSuppression: &threatMatchRule.AlertSuppression, + RiskScoreMapping: &threatMatchRule.RiskScoreMapping, + SeverityMapping: &threatMatchRule.SeverityMapping, + RelatedIntegrations: &threatMatchRule.RelatedIntegrations, + RequiredFields: &threatMatchRule.RequiredFields, + BuildingBlockType: &threatMatchRule.BuildingBlockType, + DataViewId: &threatMatchRule.DataViewId, + Namespace: &threatMatchRule.Namespace, + RuleNameOverride: &threatMatchRule.RuleNameOverride, + TimestampOverride: &threatMatchRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &threatMatchRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &threatMatchRule.InvestigationFields, + Filters: &threatMatchRule.Filters, + }, &diags, client) + + // Set threat-specific fields + if utils.IsKnown(d.ThreatQuery) { + threatMatchRule.ThreatQuery = kbapi.SecurityDetectionsAPIThreatQuery(d.ThreatQuery.ValueString()) + } + + if utils.IsKnown(d.ThreatIndicatorPath) { + threatIndicatorPath := kbapi.SecurityDetectionsAPIThreatIndicatorPath(d.ThreatIndicatorPath.ValueString()) + threatMatchRule.ThreatIndicatorPath = &threatIndicatorPath + } + + if utils.IsKnown(d.ConcurrentSearches) { + concurrentSearches := kbapi.SecurityDetectionsAPIConcurrentSearches(d.ConcurrentSearches.ValueInt64()) + threatMatchRule.ConcurrentSearches = &concurrentSearches + } + + if utils.IsKnown(d.ItemsPerSearch) { + itemsPerSearch := kbapi.SecurityDetectionsAPIItemsPerSearch(d.ItemsPerSearch.ValueInt64()) + threatMatchRule.ItemsPerSearch = &itemsPerSearch + } + + // Set query language + threatMatchRule.Language = d.getKQLQueryLanguage() + + if utils.IsKnown(d.SavedId) { + savedId := kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()) + threatMatchRule.SavedId = &savedId + } + + // Convert to union type + err := createProps.FromSecurityDetectionsAPIThreatMatchRuleCreateProps(threatMatchRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert threat match rule properties: "+err.Error(), + ) + } + + return createProps, diags +} +func (d SecurityDetectionRuleData) toThreatMatchRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + threatMatchRule := kbapi.SecurityDetectionsAPIThreatMatchRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIThreatMatchRuleUpdatePropsType("threat_match"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + threatMatchRule.RuleId = &ruleId + threatMatchRule.Id = nil // if rule_id is set, we cant send id + } + + // Set threat index + if utils.IsKnown(d.ThreatIndex) { + threatIndex := utils.ListTypeAs[string](ctx, d.ThreatIndex, path.Root("threat_index"), &diags) + if !diags.HasError() { + threatMatchRule.ThreatIndex = threatIndex + } + } + + if utils.IsKnown(d.ThreatMapping) && len(d.ThreatMapping.Elements()) > 0 { + apiThreatMapping, threatMappingDiags := d.threatMappingToApi(ctx) + if !threatMappingDiags.HasError() { + threatMatchRule.ThreatMapping = apiThreatMapping + } + diags.Append(threatMappingDiags...) + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &threatMatchRule.Actions, + ResponseActions: &threatMatchRule.ResponseActions, + RuleId: &threatMatchRule.RuleId, + Enabled: &threatMatchRule.Enabled, + From: &threatMatchRule.From, + To: &threatMatchRule.To, + Interval: &threatMatchRule.Interval, + Index: &threatMatchRule.Index, + Author: &threatMatchRule.Author, + Tags: &threatMatchRule.Tags, + FalsePositives: &threatMatchRule.FalsePositives, + References: &threatMatchRule.References, + License: &threatMatchRule.License, + Note: &threatMatchRule.Note, + InvestigationFields: &threatMatchRule.InvestigationFields, + Setup: &threatMatchRule.Setup, + MaxSignals: &threatMatchRule.MaxSignals, + Version: &threatMatchRule.Version, + ExceptionsList: &threatMatchRule.ExceptionsList, + AlertSuppression: &threatMatchRule.AlertSuppression, + RiskScoreMapping: &threatMatchRule.RiskScoreMapping, + SeverityMapping: &threatMatchRule.SeverityMapping, + RelatedIntegrations: &threatMatchRule.RelatedIntegrations, + RequiredFields: &threatMatchRule.RequiredFields, + BuildingBlockType: &threatMatchRule.BuildingBlockType, + DataViewId: &threatMatchRule.DataViewId, + Namespace: &threatMatchRule.Namespace, + RuleNameOverride: &threatMatchRule.RuleNameOverride, + TimestampOverride: &threatMatchRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &threatMatchRule.TimestampOverrideFallbackDisabled, + Filters: &threatMatchRule.Filters, + }, &diags, client) + + // Set threat-specific fields + if utils.IsKnown(d.ThreatQuery) { + threatMatchRule.ThreatQuery = kbapi.SecurityDetectionsAPIThreatQuery(d.ThreatQuery.ValueString()) + } + + if utils.IsKnown(d.ThreatIndicatorPath) { + threatIndicatorPath := kbapi.SecurityDetectionsAPIThreatIndicatorPath(d.ThreatIndicatorPath.ValueString()) + threatMatchRule.ThreatIndicatorPath = &threatIndicatorPath + } + + if utils.IsKnown(d.ConcurrentSearches) { + concurrentSearches := kbapi.SecurityDetectionsAPIConcurrentSearches(d.ConcurrentSearches.ValueInt64()) + threatMatchRule.ConcurrentSearches = &concurrentSearches + } + + if utils.IsKnown(d.ItemsPerSearch) { + itemsPerSearch := kbapi.SecurityDetectionsAPIItemsPerSearch(d.ItemsPerSearch.ValueInt64()) + threatMatchRule.ItemsPerSearch = &itemsPerSearch + } + + // Set query language + threatMatchRule.Language = d.getKQLQueryLanguage() + + if utils.IsKnown(d.SavedId) { + savedId := kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()) + threatMatchRule.SavedId = &savedId + } + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPIThreatMatchRuleUpdateProps(threatMatchRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert threat match rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} + +func (d *SecurityDetectionRuleData) updateFromThreatMatchRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPIThreatMatchRule) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields + diags.Append(d.updateDataViewIdFromApi(ctx, rule.DataViewId)...) + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + d.Query = types.StringValue(rule.Query) + d.Language = types.StringValue(string(rule.Language)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update read-only fields + d.CreatedAt = types.StringValue(rule.CreatedAt.Format("2006-01-02T15:04:05.000Z")) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = types.StringValue(rule.UpdatedAt.Format("2006-01-02T15:04:05.000Z")) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // Update index patterns + diags.Append(d.updateIndexFromApi(ctx, rule.Index)...) + + // Threat Match-specific fields + d.ThreatQuery = types.StringValue(string(rule.ThreatQuery)) + if len(rule.ThreatIndex) > 0 { + d.ThreatIndex = utils.ListValueFrom(ctx, rule.ThreatIndex, types.StringType, path.Root("threat_index"), &diags) + } else { + d.ThreatIndex = types.ListValueMust(types.StringType, []attr.Value{}) + } + + if rule.ThreatIndicatorPath != nil { + d.ThreatIndicatorPath = types.StringValue(string(*rule.ThreatIndicatorPath)) + } else { + d.ThreatIndicatorPath = types.StringNull() + } + + if rule.ConcurrentSearches != nil { + d.ConcurrentSearches = types.Int64Value(int64(*rule.ConcurrentSearches)) + } else { + d.ConcurrentSearches = types.Int64Null() + } + + if rule.ItemsPerSearch != nil { + d.ItemsPerSearch = types.Int64Value(int64(*rule.ItemsPerSearch)) + } else { + d.ItemsPerSearch = types.Int64Null() + } + + // Optional saved query ID + if rule.SavedId != nil { + d.SavedId = types.StringValue(string(*rule.SavedId)) + } else { + d.SavedId = types.StringNull() + } + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // Convert threat mapping + if len(rule.ThreatMapping) > 0 { + listValue, threatMappingDiags := convertThreatMappingToModel(ctx, rule.ThreatMapping) + diags.Append(threatMappingDiags...) + if !threatMappingDiags.HasError() { + d.ThreatMapping = listValue + } + } + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update filters field + filtersDiags := d.updateFiltersFromApi(ctx, rule.Filters) + diags.Append(filtersDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + alertSuppressionDiags := d.updateAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(alertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_threshold.go b/internal/kibana/security_detection_rule/models_threshold.go new file mode 100644 index 000000000..3590c8071 --- /dev/null +++ b/internal/kibana/security_detection_rule/models_threshold.go @@ -0,0 +1,357 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +type ThresholdRuleProcessor struct{} + +func (th ThresholdRuleProcessor) HandlesRuleType(t string) bool { + return t == "threshold" +} + +func (th ThresholdRuleProcessor) ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + return d.toThresholdRuleCreateProps(ctx, client) +} + +func (th ThresholdRuleProcessor) ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + return d.toThresholdRuleUpdateProps(ctx, client) +} + +func (th ThresholdRuleProcessor) HandlesAPIRuleResponse(rule any) bool { + _, ok := rule.(kbapi.SecurityDetectionsAPIThresholdRule) + return ok +} + +func (th ThresholdRuleProcessor) UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics { + var diags diag.Diagnostics + value, ok := rule.(kbapi.SecurityDetectionsAPIThresholdRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return diags + } + + return d.updateFromThresholdRule(ctx, &value) +} + +func (th ThresholdRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { + var diags diag.Diagnostics + value, ok := response.(kbapi.SecurityDetectionsAPIThresholdRule) + if !ok { + diags.AddError( + "Error extracting rule ID", + "Could not extract rule ID from response", + ) + return "", diags + } + return value.Id.String(), diags +} + +func (d SecurityDetectionRuleData) toThresholdRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + thresholdRule := kbapi.SecurityDetectionsAPIThresholdRuleCreateProps{ + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIThresholdRuleCreatePropsType("threshold"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // Set threshold - this is required for threshold rules + threshold := d.thresholdToApi(ctx, &diags) + if threshold != nil { + thresholdRule.Threshold = *threshold + } + + d.setCommonCreateProps(ctx, &CommonCreateProps{ + Actions: &thresholdRule.Actions, + ResponseActions: &thresholdRule.ResponseActions, + RuleId: &thresholdRule.RuleId, + Enabled: &thresholdRule.Enabled, + From: &thresholdRule.From, + To: &thresholdRule.To, + Interval: &thresholdRule.Interval, + Index: &thresholdRule.Index, + Author: &thresholdRule.Author, + Tags: &thresholdRule.Tags, + FalsePositives: &thresholdRule.FalsePositives, + References: &thresholdRule.References, + License: &thresholdRule.License, + Note: &thresholdRule.Note, + Setup: &thresholdRule.Setup, + MaxSignals: &thresholdRule.MaxSignals, + Version: &thresholdRule.Version, + ExceptionsList: &thresholdRule.ExceptionsList, + RiskScoreMapping: &thresholdRule.RiskScoreMapping, + SeverityMapping: &thresholdRule.SeverityMapping, + RelatedIntegrations: &thresholdRule.RelatedIntegrations, + RequiredFields: &thresholdRule.RequiredFields, + BuildingBlockType: &thresholdRule.BuildingBlockType, + DataViewId: &thresholdRule.DataViewId, + Namespace: &thresholdRule.Namespace, + RuleNameOverride: &thresholdRule.RuleNameOverride, + TimestampOverride: &thresholdRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &thresholdRule.TimestampOverrideFallbackDisabled, + InvestigationFields: &thresholdRule.InvestigationFields, + Filters: &thresholdRule.Filters, + AlertSuppression: nil, // Handle specially for threshold rule + }, &diags, client) + + // Handle threshold-specific alert suppression + if utils.IsKnown(d.AlertSuppression) { + alertSuppression := d.alertSuppressionToThresholdApi(ctx, &diags) + if alertSuppression != nil { + thresholdRule.AlertSuppression = alertSuppression + } + } + + // Set query language + thresholdRule.Language = d.getKQLQueryLanguage() + + if utils.IsKnown(d.SavedId) { + savedId := kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()) + thresholdRule.SavedId = &savedId + } + + // Convert to union type + err := createProps.FromSecurityDetectionsAPIThresholdRuleCreateProps(thresholdRule) + if err != nil { + diags.AddError( + "Error building create properties", + "Could not convert threshold rule properties: "+err.Error(), + ) + } + + return createProps, diags +} +func (d SecurityDetectionRuleData) toThresholdRuleUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) + diags.Append(resourceIdDiags...) + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return updateProps, diags + } + var id = kbapi.SecurityDetectionsAPIRuleObjectId(uid) + + thresholdRule := kbapi.SecurityDetectionsAPIThresholdRuleUpdateProps{ + Id: &id, + Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), + Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), + Type: kbapi.SecurityDetectionsAPIThresholdRuleUpdatePropsType("threshold"), + Query: kbapi.SecurityDetectionsAPIRuleQuery(d.Query.ValueString()), + RiskScore: kbapi.SecurityDetectionsAPIRiskScore(d.RiskScore.ValueInt64()), + Severity: kbapi.SecurityDetectionsAPISeverity(d.Severity.ValueString()), + } + + // For updates, we need to include the rule_id if it's set + if utils.IsKnown(d.RuleId) { + ruleId := kbapi.SecurityDetectionsAPIRuleSignatureId(d.RuleId.ValueString()) + thresholdRule.RuleId = &ruleId + thresholdRule.Id = nil // if rule_id is set, we cant send id + } + + // Set threshold - this is required for threshold rules + threshold := d.thresholdToApi(ctx, &diags) + if threshold != nil { + thresholdRule.Threshold = *threshold + } + + d.setCommonUpdateProps(ctx, &CommonUpdateProps{ + Actions: &thresholdRule.Actions, + ResponseActions: &thresholdRule.ResponseActions, + RuleId: &thresholdRule.RuleId, + Enabled: &thresholdRule.Enabled, + From: &thresholdRule.From, + To: &thresholdRule.To, + Interval: &thresholdRule.Interval, + Index: &thresholdRule.Index, + Author: &thresholdRule.Author, + Tags: &thresholdRule.Tags, + FalsePositives: &thresholdRule.FalsePositives, + References: &thresholdRule.References, + License: &thresholdRule.License, + Note: &thresholdRule.Note, + InvestigationFields: &thresholdRule.InvestigationFields, + Setup: &thresholdRule.Setup, + MaxSignals: &thresholdRule.MaxSignals, + Version: &thresholdRule.Version, + ExceptionsList: &thresholdRule.ExceptionsList, + RiskScoreMapping: &thresholdRule.RiskScoreMapping, + SeverityMapping: &thresholdRule.SeverityMapping, + RelatedIntegrations: &thresholdRule.RelatedIntegrations, + RequiredFields: &thresholdRule.RequiredFields, + BuildingBlockType: &thresholdRule.BuildingBlockType, + DataViewId: &thresholdRule.DataViewId, + Namespace: &thresholdRule.Namespace, + RuleNameOverride: &thresholdRule.RuleNameOverride, + TimestampOverride: &thresholdRule.TimestampOverride, + TimestampOverrideFallbackDisabled: &thresholdRule.TimestampOverrideFallbackDisabled, + Filters: &thresholdRule.Filters, + AlertSuppression: nil, // Handle specially for threshold rule + }, &diags, client) + + // Handle threshold-specific alert suppression + if utils.IsKnown(d.AlertSuppression) { + alertSuppression := d.alertSuppressionToThresholdApi(ctx, &diags) + if alertSuppression != nil { + thresholdRule.AlertSuppression = alertSuppression + } + } + + // Set query language + thresholdRule.Language = d.getKQLQueryLanguage() + + if utils.IsKnown(d.SavedId) { + savedId := kbapi.SecurityDetectionsAPISavedQueryId(d.SavedId.ValueString()) + thresholdRule.SavedId = &savedId + } + + // Convert to union type + err = updateProps.FromSecurityDetectionsAPIThresholdRuleUpdateProps(thresholdRule) + if err != nil { + diags.AddError( + "Error building update properties", + "Could not convert threshold rule properties: "+err.Error(), + ) + } + + return updateProps, diags +} + +func (d *SecurityDetectionRuleData) updateFromThresholdRule(ctx context.Context, rule *kbapi.SecurityDetectionsAPIThresholdRule) diag.Diagnostics { + var diags diag.Diagnostics + + compId := clients.CompositeId{ + ClusterId: d.SpaceId.ValueString(), + ResourceId: rule.Id.String(), + } + d.Id = types.StringValue(compId.String()) + + d.RuleId = types.StringValue(string(rule.RuleId)) + d.Name = types.StringValue(string(rule.Name)) + d.Type = types.StringValue(string(rule.Type)) + + // Update common fields + diags.Append(d.updateDataViewIdFromApi(ctx, rule.DataViewId)...) + diags.Append(d.updateNamespaceFromApi(ctx, rule.Namespace)...) + diags.Append(d.updateRuleNameOverrideFromApi(ctx, rule.RuleNameOverride)...) + diags.Append(d.updateTimestampOverrideFromApi(ctx, rule.TimestampOverride)...) + diags.Append(d.updateTimestampOverrideFallbackDisabledFromApi(ctx, rule.TimestampOverrideFallbackDisabled)...) + + d.Query = types.StringValue(rule.Query) + d.Language = types.StringValue(string(rule.Language)) + d.Enabled = types.BoolValue(bool(rule.Enabled)) + + // Update building block type + diags.Append(d.updateBuildingBlockTypeFromApi(ctx, rule.BuildingBlockType)...) + d.From = types.StringValue(string(rule.From)) + d.To = types.StringValue(string(rule.To)) + d.Interval = types.StringValue(string(rule.Interval)) + d.Description = types.StringValue(string(rule.Description)) + d.RiskScore = types.Int64Value(int64(rule.RiskScore)) + d.Severity = types.StringValue(string(rule.Severity)) + d.MaxSignals = types.Int64Value(int64(rule.MaxSignals)) + d.Version = types.Int64Value(int64(rule.Version)) + + // Update read-only fields + d.CreatedAt = types.StringValue(rule.CreatedAt.Format("2006-01-02T15:04:05.000Z")) + d.CreatedBy = types.StringValue(rule.CreatedBy) + d.UpdatedAt = types.StringValue(rule.UpdatedAt.Format("2006-01-02T15:04:05.000Z")) + d.UpdatedBy = types.StringValue(rule.UpdatedBy) + d.Revision = types.Int64Value(int64(rule.Revision)) + + // Update index patterns + diags.Append(d.updateIndexFromApi(ctx, rule.Index)...) + + // Threshold-specific fields + thresholdObj, thresholdDiags := convertThresholdToModel(ctx, rule.Threshold) + diags.Append(thresholdDiags...) + if !thresholdDiags.HasError() { + d.Threshold = thresholdObj + } + + // Optional saved query ID + if rule.SavedId != nil { + d.SavedId = types.StringValue(string(*rule.SavedId)) + } else { + d.SavedId = types.StringNull() + } + + // Update author + diags.Append(d.updateAuthorFromApi(ctx, rule.Author)...) + + // Update tags + diags.Append(d.updateTagsFromApi(ctx, rule.Tags)...) + + // Update false positives + diags.Append(d.updateFalsePositivesFromApi(ctx, rule.FalsePositives)...) + + // Update references + diags.Append(d.updateReferencesFromApi(ctx, rule.References)...) + + // Update optional string fields + diags.Append(d.updateLicenseFromApi(ctx, rule.License)...) + diags.Append(d.updateNoteFromApi(ctx, rule.Note)...) + diags.Append(d.updateSetupFromApi(ctx, rule.Setup)...) + + // Update actions + actionDiags := d.updateActionsFromApi(ctx, rule.Actions) + diags.Append(actionDiags...) + + // Update exceptions list + exceptionsListDiags := d.updateExceptionsListFromApi(ctx, rule.ExceptionsList) + diags.Append(exceptionsListDiags...) + + // Update risk score mapping + riskScoreMappingDiags := d.updateRiskScoreMappingFromApi(ctx, rule.RiskScoreMapping) + diags.Append(riskScoreMappingDiags...) + + // Update investigation fields + investigationFieldsDiags := d.updateInvestigationFieldsFromApi(ctx, rule.InvestigationFields) + diags.Append(investigationFieldsDiags...) + + // Update filters field + filtersDiags := d.updateFiltersFromApi(ctx, rule.Filters) + diags.Append(filtersDiags...) + + // Update severity mapping + severityMappingDiags := d.updateSeverityMappingFromApi(ctx, &rule.SeverityMapping) + diags.Append(severityMappingDiags...) + + // Update related integrations + relatedIntegrationsDiags := d.updateRelatedIntegrationsFromApi(ctx, &rule.RelatedIntegrations) + diags.Append(relatedIntegrationsDiags...) + + // Update required fields + requiredFieldsDiags := d.updateRequiredFieldsFromApi(ctx, &rule.RequiredFields) + diags.Append(requiredFieldsDiags...) + + // Update alert suppression + thresholdAlertSuppressionDiags := d.updateThresholdAlertSuppressionFromApi(ctx, rule.AlertSuppression) + diags.Append(thresholdAlertSuppressionDiags...) + + // Update response actions + responseActionsDiags := d.updateResponseActionsFromApi(ctx, rule.ResponseActions) + diags.Append(responseActionsDiags...) + + return diags +} diff --git a/internal/kibana/security_detection_rule/models_to_api_type_utils.go b/internal/kibana/security_detection_rule/models_to_api_type_utils.go new file mode 100644 index 000000000..241419a7e --- /dev/null +++ b/internal/kibana/security_detection_rule/models_to_api_type_utils.go @@ -0,0 +1,829 @@ +package security_detection_rule + +import ( + "context" + "fmt" + "regexp" + "strconv" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/elastic/terraform-provider-elasticstack/internal/diagutil" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" + "github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/types/basetypes" +) + +// getKQLQueryLanguage maps language string to kbapi.SecurityDetectionsAPIKqlQueryLanguage +func (d SecurityDetectionRuleData) getKQLQueryLanguage() *kbapi.SecurityDetectionsAPIKqlQueryLanguage { + if !utils.IsKnown(d.Language) { + return nil + } + var language kbapi.SecurityDetectionsAPIKqlQueryLanguage + switch d.Language.ValueString() { + case "kuery": + language = "kuery" + case "lucene": + language = "lucene" + default: + language = "kuery" + } + return &language +} + +// buildOsqueryResponseAction creates an Osquery response action from the terraform model +func (d SecurityDetectionRuleData) buildOsqueryResponseAction(ctx context.Context, params ResponseActionParamsModel) (kbapi.SecurityDetectionsAPIResponseAction, diag.Diagnostics) { + var diags diag.Diagnostics + + osqueryAction := kbapi.SecurityDetectionsAPIOsqueryResponseAction{ + ActionTypeId: kbapi.SecurityDetectionsAPIOsqueryResponseActionActionTypeId(".osquery"), + Params: kbapi.SecurityDetectionsAPIOsqueryParams{}, + } + + // Set osquery-specific params + if utils.IsKnown(params.Query) { + osqueryAction.Params.Query = params.Query.ValueStringPointer() + } + if utils.IsKnown(params.PackId) { + osqueryAction.Params.PackId = params.PackId.ValueStringPointer() + } + if utils.IsKnown(params.SavedQueryId) { + osqueryAction.Params.SavedQueryId = params.SavedQueryId.ValueStringPointer() + } + if utils.IsKnown(params.Timeout) { + timeout := float32(params.Timeout.ValueInt64()) + osqueryAction.Params.Timeout = &timeout + } + if utils.IsKnown(params.EcsMapping) { + + // Convert map to ECS mapping structure + ecsMappingElems := make(map[string]basetypes.StringValue) + elemDiags := params.EcsMapping.ElementsAs(ctx, &ecsMappingElems, false) + if !elemDiags.HasError() { + ecsMapping := make(kbapi.SecurityDetectionsAPIEcsMapping) + for key, value := range ecsMappingElems { + if stringVal := value; utils.IsKnown(value) { + ecsMapping[key] = struct { + Field *string `json:"field,omitempty"` + Value *kbapi.SecurityDetectionsAPIEcsMapping_Value `json:"value,omitempty"` + }{ + Field: stringVal.ValueStringPointer(), + } + } + } + osqueryAction.Params.EcsMapping = &ecsMapping + } else { + diags.Append(elemDiags...) + } + } + if utils.IsKnown(params.Queries) { + queries := make([]OsqueryQueryModel, len(params.Queries.Elements())) + queriesDiags := params.Queries.ElementsAs(ctx, &queries, false) + if !queriesDiags.HasError() { + apiQueries := make([]kbapi.SecurityDetectionsAPIOsqueryQuery, 0) + for _, query := range queries { + apiQuery := kbapi.SecurityDetectionsAPIOsqueryQuery{ + Id: query.Id.ValueString(), + Query: query.Query.ValueString(), + } + if utils.IsKnown(query.Platform) { + apiQuery.Platform = query.Platform.ValueStringPointer() + } + if utils.IsKnown(query.Version) { + apiQuery.Version = query.Version.ValueStringPointer() + } + if utils.IsKnown(query.Removed) { + apiQuery.Removed = query.Removed.ValueBoolPointer() + } + if utils.IsKnown(query.Snapshot) { + apiQuery.Snapshot = query.Snapshot.ValueBoolPointer() + } + if utils.IsKnown(query.EcsMapping) { + // Convert map to ECS mapping structure for queries + queryEcsMappingElems := make(map[string]basetypes.StringValue) + queryElemDiags := query.EcsMapping.ElementsAs(ctx, &queryEcsMappingElems, false) + if !queryElemDiags.HasError() { + queryEcsMapping := make(kbapi.SecurityDetectionsAPIEcsMapping) + for key, value := range queryEcsMappingElems { + if stringVal := value; utils.IsKnown(value) { + queryEcsMapping[key] = struct { + Field *string `json:"field,omitempty"` + Value *kbapi.SecurityDetectionsAPIEcsMapping_Value `json:"value,omitempty"` + }{ + Field: stringVal.ValueStringPointer(), + } + } + } + apiQuery.EcsMapping = &queryEcsMapping + } + } + apiQueries = append(apiQueries, apiQuery) + } + osqueryAction.Params.Queries = &apiQueries + } else { + diags = append(diags, queriesDiags...) + } + } + + var apiResponseAction kbapi.SecurityDetectionsAPIResponseAction + err := apiResponseAction.FromSecurityDetectionsAPIOsqueryResponseAction(osqueryAction) + if err != nil { + diags.AddError("Error converting osquery response action", err.Error()) + } + + return apiResponseAction, diags +} + +// buildEndpointResponseAction creates an Endpoint response action from the terraform model +func (d SecurityDetectionRuleData) buildEndpointResponseAction(ctx context.Context, params ResponseActionParamsModel) (kbapi.SecurityDetectionsAPIResponseAction, diag.Diagnostics) { + var diags diag.Diagnostics + + endpointAction := kbapi.SecurityDetectionsAPIEndpointResponseAction{ + ActionTypeId: kbapi.SecurityDetectionsAPIEndpointResponseActionActionTypeId(".endpoint"), + } + + // Determine the type of endpoint action based on the command + if utils.IsKnown(params.Command) { + command := params.Command.ValueString() + switch command { + case "isolate": + // Use DefaultParams for isolate command + defaultParams := kbapi.SecurityDetectionsAPIDefaultParams{ + Command: kbapi.SecurityDetectionsAPIDefaultParamsCommand("isolate"), + } + if utils.IsKnown(params.Comment) { + defaultParams.Comment = params.Comment.ValueStringPointer() + } + err := endpointAction.Params.FromSecurityDetectionsAPIDefaultParams(defaultParams) + if err != nil { + diags.AddError("Error setting endpoint default params", err.Error()) + return kbapi.SecurityDetectionsAPIResponseAction{}, diags + } + + case "kill-process", "suspend-process": + // Use ProcessesParams for process commands + processesParams := kbapi.SecurityDetectionsAPIProcessesParams{ + Command: kbapi.SecurityDetectionsAPIProcessesParamsCommand(command), + } + if utils.IsKnown(params.Comment) { + processesParams.Comment = params.Comment.ValueStringPointer() + } + + // Set config if provided + if utils.IsKnown(params.Config) { + config := utils.ObjectTypeToStruct(ctx, params.Config, path.Root("response_actions").AtName("params").AtName("config"), &diags, + func(item EndpointProcessConfigModel, meta utils.ObjectMeta) EndpointProcessConfigModel { + return item + }) + + processesParams.Config = struct { + Field string `json:"field"` + Overwrite *bool `json:"overwrite,omitempty"` + }{ + Field: config.Field.ValueString(), + } + if utils.IsKnown(config.Overwrite) { + processesParams.Config.Overwrite = config.Overwrite.ValueBoolPointer() + } + } + + err := endpointAction.Params.FromSecurityDetectionsAPIProcessesParams(processesParams) + if err != nil { + diags.AddError("Error setting endpoint processes params", err.Error()) + return kbapi.SecurityDetectionsAPIResponseAction{}, diags + } + default: + diags.AddError( + "Unsupported params type", + fmt.Sprintf("Params type '%s' is not supported", params.Command.ValueString()), + ) + } + } + + var apiResponseAction kbapi.SecurityDetectionsAPIResponseAction + err := apiResponseAction.FromSecurityDetectionsAPIEndpointResponseAction(endpointAction) + if err != nil { + diags.AddError("Error converting endpoint response action", err.Error()) + } + + return apiResponseAction, diags +} + +// Helper function to process threshold configuration for threshold rules +func (d SecurityDetectionRuleData) thresholdToApi(ctx context.Context, diags *diag.Diagnostics) *kbapi.SecurityDetectionsAPIThreshold { + if !utils.IsKnown(d.Threshold) { + return nil + } + + threshold := utils.ObjectTypeToStruct(ctx, d.Threshold, path.Root("threshold"), diags, + func(item ThresholdModel, meta utils.ObjectMeta) kbapi.SecurityDetectionsAPIThreshold { + threshold := kbapi.SecurityDetectionsAPIThreshold{ + Value: kbapi.SecurityDetectionsAPIThresholdValue(item.Value.ValueInt64()), + } + + // Handle threshold field(s) + if utils.IsKnown(item.Field) { + fieldList := utils.ListTypeToSlice_String(ctx, item.Field, meta.Path.AtName("field"), meta.Diags) + if len(fieldList) > 0 { + var thresholdField kbapi.SecurityDetectionsAPIThresholdField + if len(fieldList) == 1 { + err := thresholdField.FromSecurityDetectionsAPIThresholdField0(fieldList[0]) + if err != nil { + meta.Diags.AddError("Error setting threshold field", err.Error()) + } else { + threshold.Field = thresholdField + } + } else { + err := thresholdField.FromSecurityDetectionsAPIThresholdField1(fieldList) + if err != nil { + meta.Diags.AddError("Error setting threshold fields", err.Error()) + } else { + threshold.Field = thresholdField + } + } + } + } + + // Handle cardinality (optional) + if utils.IsKnown(item.Cardinality) { + cardinalityList := utils.ListTypeToSlice(ctx, item.Cardinality, meta.Path.AtName("cardinality"), meta.Diags, + func(item CardinalityModel, meta utils.ListMeta) struct { + Field string `json:"field"` + Value int `json:"value"` + } { + return struct { + Field string `json:"field"` + Value int `json:"value"` + }{ + Field: item.Field.ValueString(), + Value: int(item.Value.ValueInt64()), + } + }) + if len(cardinalityList) > 0 { + threshold.Cardinality = (*kbapi.SecurityDetectionsAPIThresholdCardinality)(&cardinalityList) + } + } + + return threshold + }) + + return threshold +} + +// Helper function to convert alert suppression from TF data to API type +func (d SecurityDetectionRuleData) alertSuppressionToApi(ctx context.Context, diags *diag.Diagnostics) *kbapi.SecurityDetectionsAPIAlertSuppression { + if !utils.IsKnown(d.AlertSuppression) { + return nil + } + + var model AlertSuppressionModel + objDiags := d.AlertSuppression.As(ctx, &model, basetypes.ObjectAsOptions{}) + diags.Append(objDiags...) + if diags.HasError() { + return nil + } + + suppression := &kbapi.SecurityDetectionsAPIAlertSuppression{} + + // Handle group_by (required) + if utils.IsKnown(model.GroupBy) { + groupByList := utils.ListTypeToSlice_String(ctx, model.GroupBy, path.Root("alert_suppression").AtName("group_by"), diags) + if len(groupByList) > 0 { + suppression.GroupBy = groupByList + } + } + + // Handle duration (optional) + if utils.IsKnown(model.Duration) { + duration, durationDiags := parseDurationToApi(model.Duration) + diags.Append(durationDiags...) + if !durationDiags.HasError() { + suppression.Duration = &duration + } + } + + // Handle missing_fields_strategy (optional) + if utils.IsKnown(model.MissingFieldsStrategy) { + strategy := kbapi.SecurityDetectionsAPIAlertSuppressionMissingFieldsStrategy(model.MissingFieldsStrategy.ValueString()) + suppression.MissingFieldsStrategy = &strategy + } + + return suppression +} + +// Helper function to convert alert suppression from TF data to threshold-specific API type +func (d SecurityDetectionRuleData) alertSuppressionToThresholdApi(ctx context.Context, diags *diag.Diagnostics) *kbapi.SecurityDetectionsAPIThresholdAlertSuppression { + if !utils.IsKnown(d.AlertSuppression) { + return nil + } + + var model AlertSuppressionModel + objDiags := d.AlertSuppression.As(ctx, &model, basetypes.ObjectAsOptions{}) + diags.Append(objDiags...) + if diags.HasError() { + return nil + } + + suppression := &kbapi.SecurityDetectionsAPIThresholdAlertSuppression{} + + // Handle duration (required for threshold alert suppression) + if !utils.IsKnown(model.Duration) { + diags.AddError( + "Duration required for threshold alert suppression", + "Threshold alert suppression requires a duration to be specified", + ) + return nil + } + + duration, durationDiags := parseDurationToApi(model.Duration) + diags.Append(durationDiags...) + if !durationDiags.HasError() { + suppression.Duration = duration + } + + // Note: Threshold alert suppression only supports duration field. + // GroupBy and MissingFieldsStrategy are not supported for threshold rules. + + return suppression +} + +// Helper function to process threat mapping configuration for threat match rules +func (d SecurityDetectionRuleData) threatMappingToApi(ctx context.Context) (kbapi.SecurityDetectionsAPIThreatMapping, diag.Diagnostics) { + var diags diag.Diagnostics + + threatMapping := make([]SecurityDetectionRuleTfDataItem, len(d.ThreatMapping.Elements())) + + threatMappingDiags := d.ThreatMapping.ElementsAs(ctx, &threatMapping, false) + if threatMappingDiags.HasError() { + diags.Append(threatMappingDiags...) + return nil, diags + } + + apiThreatMapping := make(kbapi.SecurityDetectionsAPIThreatMapping, 0) + for _, mapping := range threatMapping { + if !utils.IsKnown(mapping.Entries) { + continue + } + + entries := make([]SecurityDetectionRuleTfDataItemEntry, len(mapping.Entries.Elements())) + entryDiag := mapping.Entries.ElementsAs(ctx, &entries, false) + diags = append(diags, entryDiag...) + + apiThreatMappingEntries := make([]kbapi.SecurityDetectionsAPIThreatMappingEntry, 0) + for _, entry := range entries { + + apiMapping := kbapi.SecurityDetectionsAPIThreatMappingEntry{ + Field: kbapi.SecurityDetectionsAPINonEmptyString(entry.Field.ValueString()), + Type: kbapi.SecurityDetectionsAPIThreatMappingEntryType(entry.Type.ValueString()), + Value: kbapi.SecurityDetectionsAPINonEmptyString(entry.Value.ValueString()), + } + apiThreatMappingEntries = append(apiThreatMappingEntries, apiMapping) + + } + + apiThreatMapping = append(apiThreatMapping, struct { + Entries []kbapi.SecurityDetectionsAPIThreatMappingEntry `json:"entries"` + }{Entries: apiThreatMappingEntries}) + } + + return apiThreatMapping, diags +} + +// Helper function to process response actions configuration for all rule types +func (d SecurityDetectionRuleData) responseActionsToApi(ctx context.Context, client clients.MinVersionEnforceable) ([]kbapi.SecurityDetectionsAPIResponseAction, diag.Diagnostics) { + var diags diag.Diagnostics + + if client == nil { + diags.AddError( + "Client is not initialized", + "Response actions require a valid API client", + ) + return nil, diags + } + + if !utils.IsKnown(d.ResponseActions) || len(d.ResponseActions.Elements()) == 0 { + return nil, diags + } + + // Check version support for response actions + if supported, versionDiags := client.EnforceMinVersion(ctx, MinVersionResponseActions); versionDiags.HasError() { + diags.Append(diagutil.FrameworkDiagsFromSDK(versionDiags)...) + return nil, diags + } else if !supported { + // Version is not supported, return nil without error + diags.AddError("Response actions are unsupported", + fmt.Sprintf("Response actions require server version %s or higher", MinVersionResponseActions.String())) + return nil, diags + } + + apiResponseActions := utils.ListTypeToSlice(ctx, d.ResponseActions, path.Root("response_actions"), &diags, + func(responseAction ResponseActionModel, meta utils.ListMeta) kbapi.SecurityDetectionsAPIResponseAction { + + actionTypeId := responseAction.ActionTypeId.ValueString() + + params := utils.ObjectTypeToStruct(ctx, responseAction.Params, meta.Path.AtName("params"), &diags, + func(item ResponseActionParamsModel, meta utils.ObjectMeta) ResponseActionParamsModel { + return item + }) + + if params == nil { + return kbapi.SecurityDetectionsAPIResponseAction{} + } + + switch actionTypeId { + case ".osquery": + apiAction, actionDiags := d.buildOsqueryResponseAction(ctx, *params) + diags.Append(actionDiags...) + return apiAction + + case ".endpoint": + apiAction, actionDiags := d.buildEndpointResponseAction(ctx, *params) + diags.Append(actionDiags...) + return apiAction + + default: + diags.AddError( + "Unsupported action_type_id in response actions", + fmt.Sprintf("action_type_id '%s' is not supported", actionTypeId), + ) + return kbapi.SecurityDetectionsAPIResponseAction{} + } + }) + + return apiResponseActions, diags +} + +// Helper function to process actions configuration for all rule types +func (d SecurityDetectionRuleData) actionsToApi(ctx context.Context) ([]kbapi.SecurityDetectionsAPIRuleAction, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.Actions) || len(d.Actions.Elements()) == 0 { + return nil, diags + } + + apiActions := utils.ListTypeToSlice(ctx, d.Actions, path.Root("actions"), &diags, + func(action ActionModel, meta utils.ListMeta) kbapi.SecurityDetectionsAPIRuleAction { + apiAction := kbapi.SecurityDetectionsAPIRuleAction{ + ActionTypeId: action.ActionTypeId.ValueString(), + Id: kbapi.SecurityDetectionsAPIRuleActionId(action.Id.ValueString()), + } + + // Convert params map + if utils.IsKnown(action.Params) { + paramsStringMap := make(map[string]string) + paramsDiags := action.Params.ElementsAs(meta.Context, ¶msStringMap, false) + if !paramsDiags.HasError() { + paramsMap := make(map[string]interface{}) + for k, v := range paramsStringMap { + paramsMap[k] = v + } + apiAction.Params = kbapi.SecurityDetectionsAPIRuleActionParams(paramsMap) + } + meta.Diags.Append(paramsDiags...) + } + + // Set optional fields + if utils.IsKnown(action.Group) { + group := kbapi.SecurityDetectionsAPIRuleActionGroup(action.Group.ValueString()) + apiAction.Group = &group + } + + if utils.IsKnown(action.Uuid) { + uuid := kbapi.SecurityDetectionsAPINonEmptyString(action.Uuid.ValueString()) + apiAction.Uuid = &uuid + } + + if utils.IsKnown(action.AlertsFilter) { + alertsFilterStringMap := make(map[string]string) + alertsFilterDiags := action.AlertsFilter.ElementsAs(meta.Context, &alertsFilterStringMap, false) + if !alertsFilterDiags.HasError() { + alertsFilterMap := make(map[string]interface{}) + for k, v := range alertsFilterStringMap { + alertsFilterMap[k] = v + } + apiAlertsFilter := kbapi.SecurityDetectionsAPIRuleActionAlertsFilter(alertsFilterMap) + apiAction.AlertsFilter = &apiAlertsFilter + } + meta.Diags.Append(alertsFilterDiags...) + } + + // Handle frequency using ObjectTypeToStruct + if utils.IsKnown(action.Frequency) { + frequency := utils.ObjectTypeToStruct(meta.Context, action.Frequency, meta.Path.AtName("frequency"), meta.Diags, + func(frequencyModel ActionFrequencyModel, freqMeta utils.ObjectMeta) kbapi.SecurityDetectionsAPIRuleActionFrequency { + apiFreq := kbapi.SecurityDetectionsAPIRuleActionFrequency{ + NotifyWhen: kbapi.SecurityDetectionsAPIRuleActionNotifyWhen(frequencyModel.NotifyWhen.ValueString()), + Summary: frequencyModel.Summary.ValueBool(), + } + + // Handle throttle - can be string or specific values + if utils.IsKnown(frequencyModel.Throttle) { + throttleStr := frequencyModel.Throttle.ValueString() + var throttle kbapi.SecurityDetectionsAPIRuleActionThrottle + if throttleStr == "no_actions" || throttleStr == "rule" { + // Use the enum value + var throttle0 kbapi.SecurityDetectionsAPIRuleActionThrottle0 + if throttleStr == "no_actions" { + throttle0 = kbapi.SecurityDetectionsAPIRuleActionThrottle0NoActions + } else { + throttle0 = kbapi.SecurityDetectionsAPIRuleActionThrottle0Rule + } + err := throttle.FromSecurityDetectionsAPIRuleActionThrottle0(throttle0) + if err != nil { + freqMeta.Diags.AddError("Error setting throttle enum", err.Error()) + } + } else { + // Use the time interval string + throttle1 := kbapi.SecurityDetectionsAPIRuleActionThrottle1(throttleStr) + err := throttle.FromSecurityDetectionsAPIRuleActionThrottle1(throttle1) + if err != nil { + freqMeta.Diags.AddError("Error setting throttle interval", err.Error()) + } + } + apiFreq.Throttle = throttle + } + + return apiFreq + }) + + if frequency != nil { + apiAction.Frequency = frequency + } + } + + return apiAction + }) + + // Filter out empty actions (where ActionTypeId or Id was null) + validActions := make([]kbapi.SecurityDetectionsAPIRuleAction, 0) + for _, action := range apiActions { + if action.ActionTypeId != "" && action.Id != "" { + validActions = append(validActions, action) + } + } + + return validActions, diags +} + +// Helper function to process exceptions list configuration for all rule types +func (d SecurityDetectionRuleData) exceptionsListToApi(ctx context.Context) ([]kbapi.SecurityDetectionsAPIRuleExceptionList, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.ExceptionsList) || len(d.ExceptionsList.Elements()) == 0 { + return nil, diags + } + + apiExceptionsList := utils.ListTypeToSlice(ctx, d.ExceptionsList, path.Root("exceptions_list"), &diags, + func(exception ExceptionsListModel, meta utils.ListMeta) kbapi.SecurityDetectionsAPIRuleExceptionList { + + apiException := kbapi.SecurityDetectionsAPIRuleExceptionList{ + Id: exception.Id.ValueString(), + ListId: exception.ListId.ValueString(), + NamespaceType: kbapi.SecurityDetectionsAPIRuleExceptionListNamespaceType(exception.NamespaceType.ValueString()), + Type: kbapi.SecurityDetectionsAPIExceptionListType(exception.Type.ValueString()), + } + + return apiException + }) + + // Filter out empty exceptions (where required fields were null) + validExceptions := make([]kbapi.SecurityDetectionsAPIRuleExceptionList, 0) + for _, exception := range apiExceptionsList { + if exception.Id != "" && exception.ListId != "" { + validExceptions = append(validExceptions, exception) + } + } + + return validExceptions, diags +} + +// Helper function to process risk score mapping configuration for all rule types +func (d SecurityDetectionRuleData) riskScoreMappingToApi(ctx context.Context) (kbapi.SecurityDetectionsAPIRiskScoreMapping, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.RiskScoreMapping) || len(d.RiskScoreMapping.Elements()) == 0 { + return nil, diags + } + + apiRiskScoreMapping := utils.ListTypeToSlice(ctx, d.RiskScoreMapping, path.Root("risk_score_mapping"), &diags, + func(mapping RiskScoreMappingModel, meta utils.ListMeta) struct { + Field string `json:"field"` + Operator kbapi.SecurityDetectionsAPIRiskScoreMappingOperator `json:"operator"` + RiskScore *kbapi.SecurityDetectionsAPIRiskScore `json:"risk_score,omitempty"` + Value string `json:"value"` + } { + apiMapping := struct { + Field string `json:"field"` + Operator kbapi.SecurityDetectionsAPIRiskScoreMappingOperator `json:"operator"` + RiskScore *kbapi.SecurityDetectionsAPIRiskScore `json:"risk_score,omitempty"` + Value string `json:"value"` + }{ + Field: mapping.Field.ValueString(), + Operator: kbapi.SecurityDetectionsAPIRiskScoreMappingOperator(mapping.Operator.ValueString()), + Value: mapping.Value.ValueString(), + } + + // Set optional risk score if provided + if utils.IsKnown(mapping.RiskScore) { + riskScore := kbapi.SecurityDetectionsAPIRiskScore(mapping.RiskScore.ValueInt64()) + apiMapping.RiskScore = &riskScore + } + + return apiMapping + }) + + // Return the mappings (any empty mappings were filtered out during creation) + return apiRiskScoreMapping, diags +} + +// Helper function to process investigation fields configuration for all rule types +func (d SecurityDetectionRuleData) investigationFieldsToApi(ctx context.Context) (*kbapi.SecurityDetectionsAPIInvestigationFields, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.InvestigationFields) || len(d.InvestigationFields.Elements()) == 0 { + return nil, diags + } + + fieldNames := make([]string, len(d.InvestigationFields.Elements())) + fieldDiag := d.InvestigationFields.ElementsAs(ctx, &fieldNames, false) + if fieldDiag.HasError() { + diags.Append(fieldDiag...) + return nil, diags + } + + // Convert to API type + apiFieldNames := make([]kbapi.SecurityDetectionsAPINonEmptyString, len(fieldNames)) + for i, field := range fieldNames { + apiFieldNames[i] = kbapi.SecurityDetectionsAPINonEmptyString(field) + } + + return &kbapi.SecurityDetectionsAPIInvestigationFields{ + FieldNames: apiFieldNames, + }, diags +} + +// Helper function to process related integrations configuration for all rule types +func (d SecurityDetectionRuleData) relatedIntegrationsToApi(ctx context.Context) (*kbapi.SecurityDetectionsAPIRelatedIntegrationArray, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.RelatedIntegrations) || len(d.RelatedIntegrations.Elements()) == 0 { + return nil, diags + } + + apiRelatedIntegrations := utils.ListTypeToSlice(ctx, d.RelatedIntegrations, path.Root("related_integrations"), &diags, + func(integration RelatedIntegrationModel, meta utils.ListMeta) kbapi.SecurityDetectionsAPIRelatedIntegration { + + apiIntegration := kbapi.SecurityDetectionsAPIRelatedIntegration{ + Package: kbapi.SecurityDetectionsAPINonEmptyString(integration.Package.ValueString()), + Version: kbapi.SecurityDetectionsAPINonEmptyString(integration.Version.ValueString()), + } + + // Set optional integration field if provided + if utils.IsKnown(integration.Integration) { + integrationName := kbapi.SecurityDetectionsAPINonEmptyString(integration.Integration.ValueString()) + apiIntegration.Integration = &integrationName + } + + return apiIntegration + }) + + return &apiRelatedIntegrations, diags +} + +// Helper function to process required fields configuration for all rule types +func (d SecurityDetectionRuleData) requiredFieldsToApi(ctx context.Context) (*[]kbapi.SecurityDetectionsAPIRequiredFieldInput, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.RequiredFields) || len(d.RequiredFields.Elements()) == 0 { + return nil, diags + } + + apiRequiredFields := utils.ListTypeToSlice(ctx, d.RequiredFields, path.Root("required_fields"), &diags, + func(field RequiredFieldModel, meta utils.ListMeta) kbapi.SecurityDetectionsAPIRequiredFieldInput { + + return kbapi.SecurityDetectionsAPIRequiredFieldInput{ + Name: field.Name.ValueString(), + Type: field.Type.ValueString(), + } + }) + + return &apiRequiredFields, diags +} + +// Helper function to process severity mapping configuration for all rule types +func (d SecurityDetectionRuleData) severityMappingToApi(ctx context.Context) (*kbapi.SecurityDetectionsAPISeverityMapping, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.SeverityMapping) || len(d.SeverityMapping.Elements()) == 0 { + return nil, diags + } + + apiSeverityMapping := utils.ListTypeToSlice(ctx, d.SeverityMapping, path.Root("severity_mapping"), &diags, + func(mapping SeverityMappingModel, meta utils.ListMeta) struct { + Field string `json:"field"` + Operator kbapi.SecurityDetectionsAPISeverityMappingOperator `json:"operator"` + Severity kbapi.SecurityDetectionsAPISeverity `json:"severity"` + Value string `json:"value"` + } { + return struct { + Field string `json:"field"` + Operator kbapi.SecurityDetectionsAPISeverityMappingOperator `json:"operator"` + Severity kbapi.SecurityDetectionsAPISeverity `json:"severity"` + Value string `json:"value"` + }{ + Field: mapping.Field.ValueString(), + Operator: kbapi.SecurityDetectionsAPISeverityMappingOperator(mapping.Operator.ValueString()), + Severity: kbapi.SecurityDetectionsAPISeverity(mapping.Severity.ValueString()), + Value: mapping.Value.ValueString(), + } + }) + + // Convert to the expected slice type + severityMappingSlice := make(kbapi.SecurityDetectionsAPISeverityMapping, len(apiSeverityMapping)) + copy(severityMappingSlice, apiSeverityMapping) + + return &severityMappingSlice, diags +} + +// filtersToApi converts the Terraform filters field to the API type +func (d SecurityDetectionRuleData) filtersToApi(ctx context.Context) (*kbapi.SecurityDetectionsAPIRuleFilterArray, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(d.Filters) { + return nil, diags + } + + // Unmarshal the JSON string to []interface{} + var filters kbapi.SecurityDetectionsAPIRuleFilterArray + unmarshalDiags := d.Filters.Unmarshal(&filters) + diags.Append(unmarshalDiags...) + + if diags.HasError() { + return nil, diags + } + + return &filters, diags +} + +// parseDurationToApi converts a customtypes.Duration to the API structure +func parseDurationToApi(duration customtypes.Duration) (kbapi.SecurityDetectionsAPIAlertSuppressionDuration, diag.Diagnostics) { + var diags diag.Diagnostics + + if !utils.IsKnown(duration) { + diags.AddError("Duration Parse error", "duration string value is unknown") + return kbapi.SecurityDetectionsAPIAlertSuppressionDuration{}, diags + } + + // Get the raw duration string (e.g. "5m", "1h", "30s") + durationStr := duration.ValueString() + + // Parse the duration string using regex to extract value and unit + durationRegex := regexp.MustCompile(`^(\d+)([smhd])$`) + matches := durationRegex.FindStringSubmatch(durationStr) + + if len(matches) != 3 { + diags.AddError( + "Invalid duration format", + fmt.Sprintf("Duration '%s' is not in valid format. Expected format: number followed by unit (s, m, h)", durationStr), + ) + return kbapi.SecurityDetectionsAPIAlertSuppressionDuration{}, diags + } + + // Parse the numeric value + value, err := strconv.Atoi(matches[1]) + if err != nil { + diags.AddError( + "Invalid duration value", + fmt.Sprintf("Failed to parse duration value '%s': %s", matches[1], err.Error()), + ) + return kbapi.SecurityDetectionsAPIAlertSuppressionDuration{}, diags + } + + // Map the unit from the string to the API unit type + var unit kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnit + switch matches[2] { + case "s": + unit = kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitS + case "m": + unit = kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitM + case "h": + unit = kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitH + case "d": + // Convert days to hours since API doesn't support days unit + value = value * 24 + unit = kbapi.SecurityDetectionsAPIAlertSuppressionDurationUnitH + default: + diags.AddError( + "Unsupported duration unit", + fmt.Sprintf("Unit '%s' is not supported. Supported units: s, m, h", matches[2]), + ) + return kbapi.SecurityDetectionsAPIAlertSuppressionDuration{}, diags + } + + return kbapi.SecurityDetectionsAPIAlertSuppressionDuration{ + Value: value, + Unit: unit, + }, diags +} diff --git a/internal/kibana/security_detection_rule/read.go b/internal/kibana/security_detection_rule/read.go new file mode 100644 index 000000000..e0b86bd1b --- /dev/null +++ b/internal/kibana/security_detection_rule/read.go @@ -0,0 +1,117 @@ +package security_detection_rule + +import ( + "context" + "fmt" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +func (r *securityDetectionRuleResource) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) { + var data SecurityDetectionRuleData + + resp.Diagnostics.Append(req.State.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return + } + + // Parse ID to get space_id and rule_id + compId, diags := clients.CompositeIdFromStrFw(data.Id.ValueString()) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + // Use the extracted read method + readData, diags := r.read(ctx, compId.ResourceId, compId.ClusterId) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + // Check if the rule was found (nil data indicates 404) + if readData == nil { + // Rule was deleted outside of Terraform + resp.State.RemoveResource(ctx) + return + } + + // Set the composite ID and state + readData.Id = data.Id + resp.Diagnostics.Append(resp.State.Set(ctx, readData)...) +} + +// read extracts the core functionality of reading a security detection rule +func (r *securityDetectionRuleResource) read(ctx context.Context, resourceId, spaceId string) (*SecurityDetectionRuleData, diag.Diagnostics) { + var diags diag.Diagnostics + + data := &SecurityDetectionRuleData{} + data.initializeAllFieldsToDefaults(ctx, &diags) + + // Get the rule using kbapi client + kbClient, err := r.client.GetKibanaOapiClient() + if err != nil { + diags.AddError( + "Error getting Kibana client", + "Could not get Kibana OAPI client: "+err.Error(), + ) + return nil, diags + } + + // Read the rule + uid, err := uuid.Parse(resourceId) + if err != nil { + diags.AddError("ID was not a valid UUID", err.Error()) + return nil, diags + } + ruleObjectId := kbapi.SecurityDetectionsAPIRuleObjectId(uid) + params := &kbapi.ReadRuleParams{ + Id: &ruleObjectId, + } + + response, err := kbClient.API.ReadRuleWithResponse(ctx, params) + if err != nil { + diags.AddError( + "Error reading security detection rule", + "Could not read security detection rule: "+err.Error(), + ) + return nil, diags + } + + if response.StatusCode() == 404 { + // Rule was deleted - return nil to indicate this + return nil, diags + } + + if response.StatusCode() != 200 { + diags.AddError( + "Error reading security detection rule", + fmt.Sprintf("API returned status %d: %s", response.StatusCode(), string(response.Body)), + ) + return nil, diags + } + + // Parse the response + updateDiags := data.updateFromRule(ctx, response.JSON200) + diags.Append(updateDiags...) + if diags.HasError() { + return nil, diags + } + + // Ensure space_id is set correctly + data.SpaceId = types.StringValue(spaceId) + + compId := clients.CompositeId{ + ResourceId: resourceId, + ClusterId: spaceId, + } + + data.Id = types.StringValue(compId.String()) + + return data, diags +} diff --git a/internal/kibana/security_detection_rule/resource.go b/internal/kibana/security_detection_rule/resource.go new file mode 100644 index 000000000..cf4658b0f --- /dev/null +++ b/internal/kibana/security_detection_rule/resource.go @@ -0,0 +1,35 @@ +package security_detection_rule + +import ( + "context" + + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/resource" +) + +var _ resource.Resource = &securityDetectionRuleResource{} +var _ resource.ResourceWithConfigure = &securityDetectionRuleResource{} +var _ resource.ResourceWithImportState = &securityDetectionRuleResource{} + +func NewSecurityDetectionRuleResource() resource.Resource { + return &securityDetectionRuleResource{} +} + +type securityDetectionRuleResource struct { + client *clients.ApiClient +} + +func (r *securityDetectionRuleResource) Metadata(_ context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) { + resp.TypeName = req.ProviderTypeName + "_kibana_security_detection_rule" +} + +func (r *securityDetectionRuleResource) Configure(_ context.Context, req resource.ConfigureRequest, resp *resource.ConfigureResponse) { + client, diags := clients.ConvertProviderData(req.ProviderData) + resp.Diagnostics.Append(diags...) + r.client = client +} + +func (r *securityDetectionRuleResource) ImportState(ctx context.Context, request resource.ImportStateRequest, response *resource.ImportStateResponse) { + resource.ImportStatePassthroughID(ctx, path.Root("id"), request, response) +} diff --git a/internal/kibana/security_detection_rule/rule_processor.go b/internal/kibana/security_detection_rule/rule_processor.go new file mode 100644 index 000000000..cfe808c16 --- /dev/null +++ b/internal/kibana/security_detection_rule/rule_processor.go @@ -0,0 +1,124 @@ +package security_detection_rule + +import ( + "context" + "fmt" + + "github.com/elastic/terraform-provider-elasticstack/generated/kbapi" + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/hashicorp/terraform-plugin-framework/diag" +) + +type ruleProcessor interface { + HandlesRuleType(t string) bool + HandlesAPIRuleResponse(rule any) bool + ToCreateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) + ToUpdateProps(ctx context.Context, client clients.MinVersionEnforceable, d SecurityDetectionRuleData) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) + UpdateFromResponse(ctx context.Context, rule any, d *SecurityDetectionRuleData) diag.Diagnostics + ExtractId(response any) (string, diag.Diagnostics) +} + +func getRuleProcessors() []ruleProcessor { + return []ruleProcessor{ + QueryRuleProcessor{}, + EqlRuleProcessor{}, + EsqlRuleProcessor{}, + MachineLearningRuleProcessor{}, + NewTermsRuleProcessor{}, + SavedQueryRuleProcessor{}, + ThreatMatchRuleProcessor{}, + ThresholdRuleProcessor{}, + } +} + +func processorForType(t string) (ruleProcessor, bool) { + for _, proc := range getRuleProcessors() { + if proc.HandlesRuleType(t) { + return proc, true + } + } + + return nil, false +} + +func getProcessorForResponse(resp *kbapi.SecurityDetectionsAPIRuleResponse) (ruleProcessor, interface{}, diag.Diagnostics) { + var diags diag.Diagnostics + respValue, err := resp.ValueByDiscriminator() + if err != nil { + diags.AddError( + "Error determining rule processor", + "Could not determine the processor for the security detection rule from the API response: "+err.Error(), + ) + return nil, nil, diags + } + + for _, proc := range getRuleProcessors() { + if proc.HandlesAPIRuleResponse(respValue) { + return proc, respValue, diags + } + } + + diags.AddError( + "Error determining rule processor.", + "No processor found for rule", + ) + + return nil, nil, diags +} + +func (d SecurityDetectionRuleData) toCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var createProps kbapi.SecurityDetectionsAPIRuleCreateProps + + processorForType, ok := processorForType(d.Type.ValueString()) + if !ok { + diags.AddError( + "Unsupported rule type", + fmt.Sprintf("Rule type '%s' is not supported", d.Type.ValueString()), + ) + return createProps, diags + } + return processorForType.ToCreateProps(ctx, client, d) +} + +func (d SecurityDetectionRuleData) toUpdateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleUpdateProps, diag.Diagnostics) { + var diags diag.Diagnostics + var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps + + processorForType, ok := processorForType(d.Type.ValueString()) + if !ok { + diags.AddError( + "Unsupported rule type", + fmt.Sprintf("Rule type '%s' is not supported", d.Type.ValueString()), + ) + return updateProps, diags + } + return processorForType.ToUpdateProps(ctx, client, d) +} + +func (d *SecurityDetectionRuleData) updateFromRule(ctx context.Context, response *kbapi.SecurityDetectionsAPIRuleResponse) diag.Diagnostics { + var diags diag.Diagnostics + + // Get the processor for this rule type and use it to update the data + processorForType, respValue, responseDiags := getProcessorForResponse(response) + if responseDiags.HasError() { + diags.Append(responseDiags...) + return diags + } + + return processorForType.UpdateFromResponse(ctx, respValue, d) +} + +// Helper function to extract rule ID from any rule type +func extractId(response *kbapi.SecurityDetectionsAPIRuleResponse) (string, diag.Diagnostics) { + var diags diag.Diagnostics + + // Get the processor for this rule type and use it to update the data + processorForType, respValue, responseDiags := getProcessorForResponse(response) + if responseDiags.HasError() || processorForType == nil || respValue == nil { + diags.Append(responseDiags...) + return "", diags + } + + return processorForType.ExtractId(respValue) +} diff --git a/internal/kibana/security_detection_rule/schema.go b/internal/kibana/security_detection_rule/schema.go new file mode 100644 index 000000000..0a8b5696a --- /dev/null +++ b/internal/kibana/security_detection_rule/schema.go @@ -0,0 +1,896 @@ +package security_detection_rule + +import ( + "context" + "regexp" + + "github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes" + "github.com/hashicorp/terraform-plugin-framework-jsontypes/jsontypes" + "github.com/hashicorp/terraform-plugin-framework-validators/int64validator" + "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/resource/schema" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/booldefault" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/int64default" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/listdefault" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringdefault" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +func (r *securityDetectionRuleResource) Schema(_ context.Context, _ resource.SchemaRequest, resp *resource.SchemaResponse) { + resp.Schema = GetSchema() +} + +func GetSchema() schema.Schema { + return schema.Schema{ + MarkdownDescription: "Creates or updates a Kibana security detection rule. See the [rules API documentation](https://www.elastic.co/guide/en/security/current/rules-api-create.html) for more details.", + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + MarkdownDescription: "Internal identifier of the resource", + Computed: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + }, + }, + "space_id": schema.StringAttribute{ + MarkdownDescription: "An identifier for the space. If space_id is not provided, the default space is used.", + Optional: true, + Computed: true, + Default: stringdefault.StaticString("default"), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "rule_id": schema.StringAttribute{ + MarkdownDescription: "A stable unique identifier for the rule object. If omitted, a UUID is generated.", + Optional: true, + Computed: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplaceIfConfigured(), + }, + }, + "name": schema.StringAttribute{ + MarkdownDescription: "A human-readable name for the rule.", + Required: true, + Validators: []validator.String{ + stringvalidator.LengthBetween(1, 255), + }, + }, + "type": schema.StringAttribute{ + MarkdownDescription: "Rule type. Supported types: query, eql, esql, machine_learning, new_terms, saved_query, threat_match, threshold.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("query", "eql", "esql", "machine_learning", "new_terms", "saved_query", "threat_match", "threshold"), + }, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "data_view_id": schema.StringAttribute{ + MarkdownDescription: "Data view ID for the rule. Not supported for esql and machine_learning rule types.", + Optional: true, + }, + "namespace": schema.StringAttribute{ + MarkdownDescription: "Alerts index namespace. Available for all rule types.", + Optional: true, + }, + "rule_name_override": schema.StringAttribute{ + MarkdownDescription: "Override the rule name in Kibana. Available for all rule types.", + Optional: true, + }, + "timestamp_override": schema.StringAttribute{ + MarkdownDescription: "Field name to use for timestamp override. Available for all rule types.", + Optional: true, + }, + "timestamp_override_fallback_disabled": schema.BoolAttribute{ + MarkdownDescription: "Disables timestamp override fallback. Available for all rule types.", + Optional: true, + }, + "query": schema.StringAttribute{ + MarkdownDescription: "The query language definition.", + Optional: true, + }, + "language": schema.StringAttribute{ + MarkdownDescription: "The query language (KQL or Lucene).", + Optional: true, + Computed: true, + Validators: []validator.String{ + stringvalidator.OneOf("kuery", "lucene", "eql", "esql"), + }, + }, + "index": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Indices on which the rule functions.", + Optional: true, + Computed: true, + }, + "enabled": schema.BoolAttribute{ + MarkdownDescription: "Determines whether the rule is enabled.", + Optional: true, + Computed: true, + Default: booldefault.StaticBool(true), + }, + "from": schema.StringAttribute{ + MarkdownDescription: "Time from which data is analyzed each time the rule runs, using a date math range.", + Optional: true, + Computed: true, + Default: stringdefault.StaticString("now-6m"), + Validators: []validator.String{ + stringvalidator.RegexMatches(regexp.MustCompile(`^now-\d+[smhd]$`), "must be a valid date math expression like 'now-6m'"), + }, + }, + "to": schema.StringAttribute{ + MarkdownDescription: "Time to which data is analyzed each time the rule runs, using a date math range.", + Optional: true, + Computed: true, + Default: stringdefault.StaticString("now"), + }, + "interval": schema.StringAttribute{ + MarkdownDescription: "Frequency of rule execution, using a date math range.", + Optional: true, + Computed: true, + Default: stringdefault.StaticString("5m"), + Validators: []validator.String{ + stringvalidator.RegexMatches(regexp.MustCompile(`^\d+[smhd]$`), "must be a valid interval like '5m'"), + }, + }, + "description": schema.StringAttribute{ + MarkdownDescription: "The rule's description.", + Required: true, + }, + "risk_score": schema.Int64Attribute{ + MarkdownDescription: "A numerical representation of the alert's severity from 0 to 100.", + Optional: true, + Computed: true, + Default: int64default.StaticInt64(50), + Validators: []validator.Int64{ + int64validator.Between(0, 100), + }, + }, + "risk_score_mapping": schema.ListNestedAttribute{ + MarkdownDescription: "Array of risk score mappings to override the default risk score based on source event field values.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "field": schema.StringAttribute{ + MarkdownDescription: "Source event field used to override the default risk_score.", + Required: true, + }, + "operator": schema.StringAttribute{ + MarkdownDescription: "Operator to use for field value matching. Currently only 'equals' is supported.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("equals"), + }, + }, + "value": schema.StringAttribute{ + MarkdownDescription: "Value to match against the field.", + Required: true, + }, + "risk_score": schema.Int64Attribute{ + MarkdownDescription: "Risk score to use when the field matches the value (0-100). If omitted, uses the rule's default risk_score.", + Optional: true, + Validators: []validator.Int64{ + int64validator.Between(0, 100), + }, + }, + }, + }, + }, + "severity": schema.StringAttribute{ + MarkdownDescription: "Severity level of alerts produced by the rule.", + Optional: true, + Computed: true, + Default: stringdefault.StaticString("medium"), + Validators: []validator.String{ + stringvalidator.OneOf("low", "medium", "high", "critical"), + }, + }, + "severity_mapping": schema.ListNestedAttribute{ + MarkdownDescription: "Array of severity mappings to override the default severity based on source event field values.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "field": schema.StringAttribute{ + MarkdownDescription: "Source event field used to override the default severity.", + Required: true, + }, + "operator": schema.StringAttribute{ + MarkdownDescription: "Operator to use for field value matching. Currently only 'equals' is supported.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("equals"), + }, + }, + "value": schema.StringAttribute{ + MarkdownDescription: "Value to match against the field.", + Required: true, + }, + "severity": schema.StringAttribute{ + MarkdownDescription: "Severity level to use when the field matches the value.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("low", "medium", "high", "critical"), + }, + }, + }, + }, + }, + "author": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "The rule's author.", + Optional: true, + Computed: true, + Default: listdefault.StaticValue(types.ListValueMust(types.StringType, []attr.Value{})), + }, + "tags": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "String array containing words and phrases to help categorize, filter, and search rules.", + Optional: true, + Computed: true, + Default: listdefault.StaticValue(types.ListValueMust(types.StringType, []attr.Value{})), + }, + "license": schema.StringAttribute{ + MarkdownDescription: "The rule's license.", + Optional: true, + }, + "related_integrations": schema.ListNestedAttribute{ + MarkdownDescription: "Array of related integrations that provide additional context for the rule.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "package": schema.StringAttribute{ + MarkdownDescription: "Name of the integration package.", + Required: true, + }, + "version": schema.StringAttribute{ + MarkdownDescription: "Version of the integration package.", + Required: true, + }, + "integration": schema.StringAttribute{ + MarkdownDescription: "Name of the specific integration.", + Optional: true, + }, + }, + }, + }, + "required_fields": schema.ListNestedAttribute{ + MarkdownDescription: "Array of Elasticsearch fields and types that must be present in source indices for the rule to function properly.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "name": schema.StringAttribute{ + MarkdownDescription: "Name of the Elasticsearch field.", + Required: true, + }, + "type": schema.StringAttribute{ + MarkdownDescription: "Type of the Elasticsearch field.", + Required: true, + }, + "ecs": schema.BoolAttribute{ + MarkdownDescription: "Indicates whether the field is ECS-compliant. This is computed by the backend based on the field name and type.", + Computed: true, + }, + }, + }, + }, + "false_positives": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "String array used to describe common reasons why the rule may issue false-positive alerts.", + Optional: true, + Computed: true, + Default: listdefault.StaticValue(types.ListValueMust(types.StringType, []attr.Value{})), + }, + "references": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "String array containing references and URLs to sources of additional information.", + Optional: true, + Computed: true, + Default: listdefault.StaticValue(types.ListValueMust(types.StringType, []attr.Value{})), + }, + "investigation_fields": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Array of field names to include in alert investigation. Available for all rule types.", + Optional: true, + }, + "filters": schema.StringAttribute{ + MarkdownDescription: "Query and filter context array to define alert conditions as JSON. Supports complex filter structures including bool queries, term filters, range filters, etc. Available for all rule types.", + Optional: true, + CustomType: jsontypes.NormalizedType{}, + }, + "note": schema.StringAttribute{ + MarkdownDescription: "Notes to help investigate alerts produced by the rule.", + Optional: true, + }, + "setup": schema.StringAttribute{ + MarkdownDescription: "Setup guide with instructions on rule prerequisites.", + Optional: true, + }, + "max_signals": schema.Int64Attribute{ + MarkdownDescription: "Maximum number of alerts the rule can create during a single run.", + Optional: true, + Computed: true, + Default: int64default.StaticInt64(100), + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + }, + "version": schema.Int64Attribute{ + MarkdownDescription: "The rule's version number.", + Optional: true, + Computed: true, + Default: int64default.StaticInt64(1), + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + }, + + // Actions field (common across all rule types) + "actions": schema.ListNestedAttribute{ + MarkdownDescription: "Array of automated actions taken when alerts are generated by the rule.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "action_type_id": schema.StringAttribute{ + MarkdownDescription: "The action type used for sending notifications (e.g., .slack, .email, .webhook, .pagerduty, etc.).", + Required: true, + }, + "id": schema.StringAttribute{ + MarkdownDescription: "The connector ID.", + Required: true, + }, + "params": schema.MapAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Object containing the allowed connector fields, which varies according to the connector type.", + Required: true, + }, + "group": schema.StringAttribute{ + MarkdownDescription: "Optionally groups actions by use cases. Use 'default' for alert notifications.", + Optional: true, + }, + "uuid": schema.StringAttribute{ + MarkdownDescription: "A unique identifier for the action.", + Optional: true, + Computed: true, + }, + "alerts_filter": schema.MapAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Object containing an action's conditional filters.", + Optional: true, + }, + "frequency": schema.SingleNestedAttribute{ + MarkdownDescription: "The action frequency defines when the action runs.", + Optional: true, + Attributes: map[string]schema.Attribute{ + "notify_when": schema.StringAttribute{ + MarkdownDescription: "Defines how often rules run actions. Valid values: onActionGroupChange, onActiveAlert, onThrottleInterval.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("onActionGroupChange", "onActiveAlert", "onThrottleInterval"), + }, + }, + "summary": schema.BoolAttribute{ + MarkdownDescription: "Action summary indicates whether we will send a summary notification about all the generated alerts or notification per individual alert.", + Required: true, + }, + "throttle": schema.StringAttribute{ + MarkdownDescription: "Time interval for throttling actions (e.g., '1h', '30m', 'no_actions', 'rule').", + Required: true, + }, + }, + }, + }, + }, + }, + + // Response actions field (common across all rule types) + "response_actions": schema.ListNestedAttribute{ + MarkdownDescription: "Array of response actions to take when alerts are generated by the rule.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "action_type_id": schema.StringAttribute{ + MarkdownDescription: "The action type used for response actions (.osquery, .endpoint).", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf(".osquery", ".endpoint"), + }, + }, + "params": schema.SingleNestedAttribute{ + MarkdownDescription: "Parameters for the response action. Structure varies based on action_type_id.", + Required: true, + Attributes: map[string]schema.Attribute{ + // Osquery params + "query": schema.StringAttribute{ + MarkdownDescription: "SQL query to run (osquery only). Example: 'SELECT * FROM processes;'", + Optional: true, + }, + "pack_id": schema.StringAttribute{ + MarkdownDescription: "Query pack identifier (osquery only).", + Optional: true, + }, + "saved_query_id": schema.StringAttribute{ + MarkdownDescription: "Saved query identifier (osquery only).", + Optional: true, + }, + "timeout": schema.Int64Attribute{ + MarkdownDescription: "Timeout period in seconds (osquery only). Min: 60, Max: 900.", + Optional: true, + Validators: []validator.Int64{ + int64validator.Between(60, 900), + }, + }, + "ecs_mapping": schema.MapAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Map Osquery results columns to ECS fields (osquery only).", + Optional: true, + }, + "queries": schema.ListNestedAttribute{ + MarkdownDescription: "Array of queries to run (osquery only).", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + MarkdownDescription: "Query ID.", + Required: true, + }, + "query": schema.StringAttribute{ + MarkdownDescription: "Query to run.", + Required: true, + }, + "platform": schema.StringAttribute{ + MarkdownDescription: "Platform to run the query on.", + Optional: true, + }, + "version": schema.StringAttribute{ + MarkdownDescription: "Query version.", + Optional: true, + }, + "removed": schema.BoolAttribute{ + MarkdownDescription: "Whether the query is removed.", + Optional: true, + }, + "snapshot": schema.BoolAttribute{ + MarkdownDescription: "Whether this is a snapshot query.", + Optional: true, + }, + "ecs_mapping": schema.MapAttribute{ + ElementType: types.StringType, + MarkdownDescription: "ECS field mappings for this query.", + Optional: true, + }, + }, + }, + }, + // Endpoint params - common command and comment + "command": schema.StringAttribute{ + MarkdownDescription: "Command to run (endpoint only). Valid values: isolate, kill-process, suspend-process.", + Optional: true, + Validators: []validator.String{ + stringvalidator.OneOf("isolate", "kill-process", "suspend-process"), + }, + }, + "comment": schema.StringAttribute{ + MarkdownDescription: "Comment describing the action (endpoint only).", + Optional: true, + }, + // Endpoint process params - for kill-process and suspend-process commands + "config": schema.SingleNestedAttribute{ + MarkdownDescription: "Configuration for process commands (endpoint only).", + Optional: true, + Attributes: map[string]schema.Attribute{ + "field": schema.StringAttribute{ + MarkdownDescription: "Field to use instead of process.pid.", + Required: true, + }, + "overwrite": schema.BoolAttribute{ + MarkdownDescription: "Whether to overwrite field with process.pid.", + Optional: true, + Computed: true, + Default: booldefault.StaticBool(true), + }, + }, + }, + }, + }, + }, + }, + }, + + // Exceptions list field (common across all rule types) + "exceptions_list": schema.ListNestedAttribute{ + MarkdownDescription: "Array of exception containers to prevent the rule from generating alerts.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + MarkdownDescription: "The exception container ID.", + Required: true, + }, + "list_id": schema.StringAttribute{ + MarkdownDescription: "The exception container's list ID.", + Required: true, + }, + "namespace_type": schema.StringAttribute{ + MarkdownDescription: "The namespace type for the exception container.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("single", "agnostic"), + }, + }, + "type": schema.StringAttribute{ + MarkdownDescription: "The type of exception container.", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("detection", "endpoint", "endpoint_events", "endpoint_host_isolation_exceptions", "endpoint_blocklists", "endpoint_trusted_apps"), + }, + }, + }, + }, + }, + + // Alert suppression field (common across all rule types) + "alert_suppression": schema.SingleNestedAttribute{ + MarkdownDescription: "Defines alert suppression configuration to reduce duplicate alerts.", + Optional: true, + Attributes: map[string]schema.Attribute{ + "group_by": schema.ListAttribute{ + MarkdownDescription: "Array of field names to group alerts by for suppression.", + Optional: true, + ElementType: types.StringType, + }, + "duration": schema.StringAttribute{ + Description: "Duration for which alerts are suppressed.", + Optional: true, + CustomType: customtypes.DurationType{}, + }, + "missing_fields_strategy": schema.StringAttribute{ + MarkdownDescription: "Strategy for handling missing fields in suppression grouping: 'suppress' - only one alert will be created per suppress by bucket, 'doNotSuppress' - per each document a separate alert will be created.", + Optional: true, + Validators: []validator.String{ + stringvalidator.OneOf("suppress", "doNotSuppress"), + }, + }, + }, + }, + + // Building block type field (common across all rule types) + "building_block_type": schema.StringAttribute{ + MarkdownDescription: "Determines if the rule acts as a building block. If set, value must be `default`. Building-block alerts are not displayed in the UI by default and are used as a foundation for other rules.", + Optional: true, + Validators: []validator.String{ + stringvalidator.OneOf("default"), + }, + }, + + // Read-only fields + "created_at": schema.StringAttribute{ + MarkdownDescription: "The time the rule was created.", + Computed: true, + }, + "created_by": schema.StringAttribute{ + MarkdownDescription: "The user who created the rule.", + Computed: true, + }, + "updated_at": schema.StringAttribute{ + MarkdownDescription: "The time the rule was last updated.", + Computed: true, + }, + "updated_by": schema.StringAttribute{ + MarkdownDescription: "The user who last updated the rule.", + Computed: true, + }, + "revision": schema.Int64Attribute{ + MarkdownDescription: "The rule's revision number.", + Computed: true, + }, + + // EQL-specific fields + "tiebreaker_field": schema.StringAttribute{ + MarkdownDescription: "Sets the tiebreaker field. Required for EQL rules when event.dataset is not provided.", + Optional: true, + }, + + // Machine Learning-specific fields + "anomaly_threshold": schema.Int64Attribute{ + MarkdownDescription: "Anomaly score threshold above which the rule creates an alert. Valid values are from 0 to 100. Required for machine_learning rules.", + Optional: true, + Validators: []validator.Int64{ + int64validator.Between(0, 100), + }, + }, + "machine_learning_job_id": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Machine learning job ID(s) the rule monitors for anomaly scores. Required for machine_learning rules.", + Optional: true, + }, + + // New Terms-specific fields + "new_terms_fields": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Field names containing the new terms. Required for new_terms rules.", + Optional: true, + }, + "history_window_start": schema.StringAttribute{ + MarkdownDescription: "Start date to use when checking if a term has been seen before. Supports relative dates like 'now-30d'. Required for new_terms rules.", + Optional: true, + }, + + // Saved Query-specific fields + "saved_id": schema.StringAttribute{ + MarkdownDescription: "Identifier of the saved query used for the rule. Required for saved_query rules.", + Optional: true, + }, + + // Threat Match-specific fields + "threat_index": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Array of index patterns for the threat intelligence indices. Required for threat_match rules.", + Optional: true, + }, + "threat_query": schema.StringAttribute{ + MarkdownDescription: "Query used to filter threat intelligence data. Optional for threat_match rules.", + Optional: true, + }, + "threat_mapping": schema.ListNestedAttribute{ + MarkdownDescription: "Array of threat mappings that specify how to match events with threat intelligence. Required for threat_match rules.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "entries": schema.ListNestedAttribute{ + MarkdownDescription: "Array of mapping entries.", + Required: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "field": schema.StringAttribute{ + MarkdownDescription: "Event field to match.", + Required: true, + }, + "type": schema.StringAttribute{ + MarkdownDescription: "Type of match (mapping).", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf("mapping"), + }, + }, + "value": schema.StringAttribute{ + MarkdownDescription: "Threat intelligence field to match against.", + Required: true, + }, + }, + }, + }, + }, + }, + }, + "threat_filters": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Additional filters for threat intelligence data. Optional for threat_match rules.", + Optional: true, + }, + "threat_indicator_path": schema.StringAttribute{ + MarkdownDescription: "Path to the threat indicator in the indicator documents. Optional for threat_match rules.", + Optional: true, + Computed: true, + }, + "concurrent_searches": schema.Int64Attribute{ + MarkdownDescription: "Number of concurrent searches for threat intelligence. Optional for threat_match rules.", + Optional: true, + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + }, + "items_per_search": schema.Int64Attribute{ + MarkdownDescription: "Number of items to search for in each concurrent search. Optional for threat_match rules.", + Optional: true, + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + }, + + // Threshold-specific fields + "threshold": schema.SingleNestedAttribute{ + MarkdownDescription: "Threshold settings for the rule. Required for threshold rules.", + Optional: true, + Attributes: map[string]schema.Attribute{ + "field": schema.ListAttribute{ + ElementType: types.StringType, + MarkdownDescription: "Field(s) to use for threshold aggregation.", + Optional: true, + }, + "value": schema.Int64Attribute{ + MarkdownDescription: "The threshold value from which an alert is generated.", + Required: true, + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + }, + "cardinality": schema.ListNestedAttribute{ + MarkdownDescription: "Cardinality settings for threshold rule.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "field": schema.StringAttribute{ + MarkdownDescription: "The field on which to calculate and compare the cardinality.", + Required: true, + }, + "value": schema.Int64Attribute{ + MarkdownDescription: "The threshold cardinality value.", + Required: true, + Validators: []validator.Int64{ + int64validator.AtLeast(1), + }, + }, + }, + }, + }, + }, + }, + + // Optional timeline fields (common across multiple rule types) + "timeline_id": schema.StringAttribute{ + MarkdownDescription: "Timeline template ID for the rule.", + Optional: true, + }, + "timeline_title": schema.StringAttribute{ + MarkdownDescription: "Timeline template title for the rule.", + Optional: true, + }, + + // Threat field (common across multiple rule types) + "threat": schema.ListNestedAttribute{ + MarkdownDescription: "MITRE ATT&CK framework threat information.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "framework": schema.StringAttribute{ + MarkdownDescription: "Threat framework (typically 'MITRE ATT&CK').", + Required: true, + }, + "tactic": schema.SingleNestedAttribute{ + MarkdownDescription: "MITRE ATT&CK tactic information.", + Required: true, + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK tactic ID.", + Required: true, + }, + "name": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK tactic name.", + Required: true, + }, + "reference": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK tactic reference URL.", + Required: true, + }, + }, + }, + "technique": schema.ListNestedAttribute{ + MarkdownDescription: "MITRE ATT&CK technique information.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK technique ID.", + Required: true, + }, + "name": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK technique name.", + Required: true, + }, + "reference": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK technique reference URL.", + Required: true, + }, + "subtechnique": schema.ListNestedAttribute{ + MarkdownDescription: "MITRE ATT&CK sub-technique information.", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK sub-technique ID.", + Required: true, + }, + "name": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK sub-technique name.", + Required: true, + }, + "reference": schema.StringAttribute{ + MarkdownDescription: "MITRE ATT&CK sub-technique reference URL.", + Required: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } +} + +// func getCardinalityType() map[string]attr.Type { +func getCardinalityType() attr.Type { + return GetSchema().Attributes["threshold"].(schema.SingleNestedAttribute).Attributes["cardinality"].GetType().(attr.TypeWithElementType).ElementType() +} + +// getThresholdType returns the attribute types for threshold objects +func getThresholdType() map[string]attr.Type { + return GetSchema().Attributes["threshold"].GetType().(attr.TypeWithAttributeTypes).AttributeTypes() +} + +// getAlertSuppressionType returns the attribute types for alert suppression objects +func getAlertSuppressionType() map[string]attr.Type { + return GetSchema().Attributes["alert_suppression"].GetType().(attr.TypeWithAttributeTypes).AttributeTypes() +} + +// getThreatElementType returns the element type for threat objects (MITRE ATT&CK framework) +func getThreatElementType() attr.Type { + return GetSchema().Attributes["threat"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getThreatMappingElementType() attr.Type { + return GetSchema().Attributes["threat_mapping"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getThreatMappingEntryElementType() attr.Type { + threatMappingType := GetSchema().Attributes["threat_mapping"].GetType().(attr.TypeWithElementType).ElementType().(attr.TypeWithAttributeTypes) + return threatMappingType.AttributeTypes()["entries"].(attr.TypeWithElementType).ElementType() +} + +func getResponseActionElementType() attr.Type { + return GetSchema().Attributes["response_actions"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getResponseActionParamsType() map[string]attr.Type { + responseActionType := GetSchema().Attributes["response_actions"].GetType().(attr.TypeWithElementType).ElementType().(attr.TypeWithAttributeTypes) + return responseActionType.AttributeTypes()["params"].(attr.TypeWithAttributeTypes).AttributeTypes() +} + +func getOsqueryQueryElementType() attr.Type { + responseActionType := GetSchema().Attributes["response_actions"].GetType().(attr.TypeWithElementType).ElementType().(attr.TypeWithAttributeTypes) + paramsType := responseActionType.AttributeTypes()["params"].(attr.TypeWithAttributeTypes) + return paramsType.AttributeTypes()["queries"].(attr.TypeWithElementType).ElementType() +} + +func getEndpointProcessConfigType() map[string]attr.Type { + responseActionType := GetSchema().Attributes["response_actions"].GetType().(attr.TypeWithElementType).ElementType().(attr.TypeWithAttributeTypes) + paramsType := responseActionType.AttributeTypes()["params"].(attr.TypeWithAttributeTypes) + return paramsType.AttributeTypes()["config"].(attr.TypeWithAttributeTypes).AttributeTypes() +} + +func getActionElementType() attr.Type { + return GetSchema().Attributes["actions"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getActionFrequencyType() map[string]attr.Type { + actionType := GetSchema().Attributes["actions"].GetType().(attr.TypeWithElementType).ElementType().(attr.TypeWithAttributeTypes) + return actionType.AttributeTypes()["frequency"].(attr.TypeWithAttributeTypes).AttributeTypes() +} + +func getExceptionsListElementType() attr.Type { + return GetSchema().Attributes["exceptions_list"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getRiskScoreMappingElementType() attr.Type { + return GetSchema().Attributes["risk_score_mapping"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getRelatedIntegrationElementType() attr.Type { + return GetSchema().Attributes["related_integrations"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getRequiredFieldElementType() attr.Type { + return GetSchema().Attributes["required_fields"].GetType().(attr.TypeWithElementType).ElementType() +} + +func getSeverityMappingElementType() attr.Type { + return GetSchema().Attributes["severity_mapping"].GetType().(attr.TypeWithElementType).ElementType() +} diff --git a/internal/kibana/security_detection_rule/update.go b/internal/kibana/security_detection_rule/update.go new file mode 100644 index 000000000..50ea9feb4 --- /dev/null +++ b/internal/kibana/security_detection_rule/update.go @@ -0,0 +1,83 @@ +package security_detection_rule + +import ( + "context" + "fmt" + + "github.com/elastic/terraform-provider-elasticstack/internal/clients" + "github.com/google/uuid" + "github.com/hashicorp/terraform-plugin-framework/resource" +) + +func (r *securityDetectionRuleResource) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) { + var data SecurityDetectionRuleData + + resp.Diagnostics.Append(req.Plan.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return + } + + // Get the rule using kbapi client + kbClient, err := r.client.GetKibanaOapiClient() + if err != nil { + resp.Diagnostics.AddError( + "Error getting Kibana client", + "Could not get Kibana OAPI client: "+err.Error(), + ) + return + } + + // Build the update request + updateProps, diags := data.toUpdateProps(ctx, r.client) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + // Update the rule + response, err := kbClient.API.UpdateRuleWithResponse(ctx, updateProps) + if err != nil { + resp.Diagnostics.AddError( + "Error updating security detection rule", + "Could not update security detection rule: "+err.Error(), + ) + return + } + + if response.StatusCode() != 200 { + resp.Diagnostics.AddError( + "Error updating security detection rule", + fmt.Sprintf("API returned status %d: %s", response.StatusCode(), string(response.Body)), + ) + return + } + + // Parse ID to get space_id and rule_id + compId, resourceIdDiags := clients.CompositeIdFromStrFw(data.Id.ValueString()) + resp.Diagnostics.Append(resourceIdDiags...) + if resp.Diagnostics.HasError() { + return + } + + uid, err := uuid.Parse(compId.ResourceId) + if err != nil { + resp.Diagnostics.AddError("ID was not a valid UUID", err.Error()) + return + } + + readData, diags := r.read(ctx, uid.String(), data.SpaceId.ValueString()) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + if readData == nil { + resp.Diagnostics.AddError( + "Error reading updated security detection rule", + "Could not read security detection rule after update", + ) + return + } + + resp.Diagnostics.Append(resp.State.Set(ctx, readData)...) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 497db2836..15b6d423c 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -10,6 +10,7 @@ import ( "time" providerSchema "github.com/elastic/terraform-provider-elasticstack/internal/schema" + "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -224,3 +225,9 @@ func NonNilSlice[T any](s []T) []T { return s } + +// TimeToStringValue formats a time.Time to ISO 8601 format and returns a types.StringValue. +// This is a convenience function that combines FormatStrictDateTime and types.StringValue. +func TimeToStringValue(t time.Time) types.String { + return types.StringValue(FormatStrictDateTime(t)) +} diff --git a/provider/plugin_framework.go b/provider/plugin_framework.go index c48c303a0..b5c681d33 100644 --- a/provider/plugin_framework.go +++ b/provider/plugin_framework.go @@ -26,6 +26,7 @@ import ( "github.com/elastic/terraform-provider-elasticstack/internal/kibana/export_saved_objects" "github.com/elastic/terraform-provider-elasticstack/internal/kibana/import_saved_objects" "github.com/elastic/terraform-provider-elasticstack/internal/kibana/maintenance_window" + "github.com/elastic/terraform-provider-elasticstack/internal/kibana/security_detection_rule" "github.com/elastic/terraform-provider-elasticstack/internal/kibana/spaces" "github.com/elastic/terraform-provider-elasticstack/internal/kibana/synthetics" "github.com/elastic/terraform-provider-elasticstack/internal/kibana/synthetics/parameter" @@ -120,5 +121,6 @@ func (p *Provider) Resources(ctx context.Context) []func() resource.Resource { maintenance_window.NewResource, enrich.NewEnrichPolicyResource, role_mapping.NewRoleMappingResource, + security_detection_rule.NewSecurityDetectionRuleResource, } }