|
5 | 5 |
|
6 | 6 | load_dotenv() |
7 | 7 |
|
| 8 | +# Default values for Entra ID authentication |
| 9 | +DEFAULT_TOKEN_EXPIRATION_REFRESH_RATIO = 0.9 |
| 10 | +DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 30000 # 30 seconds |
| 11 | +DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_MS = 10000 # 10 seconds |
| 12 | +DEFAULT_RETRY_MAX_ATTEMPTS = 3 |
| 13 | +DEFAULT_RETRY_DELAY_MS = 100 |
| 14 | + |
8 | 15 | REDIS_CFG = { |
9 | 16 | "host": os.getenv("REDIS_HOST", "127.0.0.1"), |
10 | 17 | "port": int(os.getenv("REDIS_PORT", 6379)), |
|
20 | 27 | "db": int(os.getenv("REDIS_DB", 0)), |
21 | 28 | } |
22 | 29 |
|
| 30 | +# Entra ID Authentication Configuration |
| 31 | +ENTRAID_CFG = { |
| 32 | + # Authentication flow selection |
| 33 | + "auth_flow": os.getenv( |
| 34 | + "REDIS_ENTRAID_AUTH_FLOW", None |
| 35 | + ), # service_principal, managed_identity, default_credential |
| 36 | + # Service Principal Authentication |
| 37 | + "client_id": os.getenv("REDIS_ENTRAID_CLIENT_ID", None), |
| 38 | + "client_secret": os.getenv("REDIS_ENTRAID_CLIENT_SECRET", None), |
| 39 | + "tenant_id": os.getenv("REDIS_ENTRAID_TENANT_ID", None), |
| 40 | + # Managed Identity Authentication |
| 41 | + "identity_type": os.getenv( |
| 42 | + "REDIS_ENTRAID_IDENTITY_TYPE", "system_assigned" |
| 43 | + ), # system_assigned, user_assigned |
| 44 | + "user_assigned_identity_client_id": os.getenv( |
| 45 | + "REDIS_ENTRAID_USER_ASSIGNED_CLIENT_ID", None |
| 46 | + ), |
| 47 | + # Default Azure Credential Authentication |
| 48 | + "scopes": os.getenv("REDIS_ENTRAID_SCOPES", "https://redis.azure.com/.default"), |
| 49 | + # Token lifecycle configuration |
| 50 | + "token_expiration_refresh_ratio": float( |
| 51 | + os.getenv( |
| 52 | + "REDIS_ENTRAID_TOKEN_EXPIRATION_REFRESH_RATIO", |
| 53 | + DEFAULT_TOKEN_EXPIRATION_REFRESH_RATIO, |
| 54 | + ) |
| 55 | + ), |
| 56 | + "lower_refresh_bound_millis": int( |
| 57 | + os.getenv( |
| 58 | + "REDIS_ENTRAID_LOWER_REFRESH_BOUND_MILLIS", |
| 59 | + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, |
| 60 | + ) |
| 61 | + ), |
| 62 | + "token_request_execution_timeout_ms": int( |
| 63 | + os.getenv( |
| 64 | + "REDIS_ENTRAID_TOKEN_REQUEST_EXECUTION_TIMEOUT_MS", |
| 65 | + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_MS, |
| 66 | + ) |
| 67 | + ), |
| 68 | + # Retry configuration |
| 69 | + "retry_max_attempts": int( |
| 70 | + os.getenv("REDIS_ENTRAID_RETRY_MAX_ATTEMPTS", DEFAULT_RETRY_MAX_ATTEMPTS) |
| 71 | + ), |
| 72 | + "retry_delay_ms": int( |
| 73 | + os.getenv("REDIS_ENTRAID_RETRY_DELAY_MS", DEFAULT_RETRY_DELAY_MS) |
| 74 | + ), |
| 75 | + # Resource configuration |
| 76 | + "resource": os.getenv("REDIS_ENTRAID_RESOURCE", "https://redis.azure.com/"), |
| 77 | +} |
| 78 | + |
23 | 79 |
|
24 | 80 | def parse_redis_uri(uri: str) -> dict: |
25 | 81 | """Parse a Redis URI and return connection parameters.""" |
@@ -99,3 +155,77 @@ def set_redis_config_from_cli(config: dict): |
99 | 155 | else: |
100 | 156 | # Convert other values to strings |
101 | 157 | REDIS_CFG[key] = str(value) if value is not None else None |
| 158 | + |
| 159 | + |
| 160 | +def set_entraid_config_from_cli(config: dict): |
| 161 | + """Update Entra ID configuration from CLI parameters.""" |
| 162 | + for key, value in config.items(): |
| 163 | + if value is not None: |
| 164 | + if key in ["token_expiration_refresh_ratio"]: |
| 165 | + # Keep float values as floats |
| 166 | + ENTRAID_CFG[key] = float(value) |
| 167 | + elif key in [ |
| 168 | + "lower_refresh_bound_millis", |
| 169 | + "token_request_execution_timeout_ms", |
| 170 | + "retry_max_attempts", |
| 171 | + "retry_delay_ms", |
| 172 | + ]: |
| 173 | + # Keep integer values as integers |
| 174 | + ENTRAID_CFG[key] = int(value) |
| 175 | + else: |
| 176 | + # Convert other values to strings |
| 177 | + ENTRAID_CFG[key] = str(value) |
| 178 | + |
| 179 | + |
| 180 | +def is_entraid_auth_enabled() -> bool: |
| 181 | + """Check if Entra ID authentication is enabled.""" |
| 182 | + return ENTRAID_CFG["auth_flow"] is not None |
| 183 | + |
| 184 | + |
| 185 | +def get_entraid_auth_flow() -> str: |
| 186 | + """Get the configured Entra ID authentication flow.""" |
| 187 | + return ENTRAID_CFG["auth_flow"] |
| 188 | + |
| 189 | + |
| 190 | +def validate_entraid_config() -> tuple[bool, str]: |
| 191 | + """Validate Entra ID configuration based on the selected auth flow. |
| 192 | +
|
| 193 | + Returns: |
| 194 | + tuple: (is_valid, error_message) |
| 195 | + """ |
| 196 | + auth_flow = ENTRAID_CFG["auth_flow"] |
| 197 | + |
| 198 | + if not auth_flow: |
| 199 | + return True, "" # No Entra ID auth configured, which is valid |
| 200 | + |
| 201 | + if auth_flow == "service_principal": |
| 202 | + required_fields = ["client_id", "client_secret", "tenant_id"] |
| 203 | + missing_fields = [field for field in required_fields if not ENTRAID_CFG[field]] |
| 204 | + if missing_fields: |
| 205 | + return ( |
| 206 | + False, |
| 207 | + f"Service principal authentication requires: {', '.join(missing_fields)}", |
| 208 | + ) |
| 209 | + |
| 210 | + elif auth_flow == "managed_identity": |
| 211 | + identity_type = ENTRAID_CFG["identity_type"] |
| 212 | + if ( |
| 213 | + identity_type == "user_assigned" |
| 214 | + and not ENTRAID_CFG["user_assigned_identity_client_id"] |
| 215 | + ): |
| 216 | + return ( |
| 217 | + False, |
| 218 | + "User-assigned managed identity requires user_assigned_identity_client_id", |
| 219 | + ) |
| 220 | + |
| 221 | + elif auth_flow == "default_credential": |
| 222 | + # Default credential doesn't require specific configuration |
| 223 | + pass |
| 224 | + |
| 225 | + else: |
| 226 | + return ( |
| 227 | + False, |
| 228 | + f"Invalid auth_flow: {auth_flow}. Must be one of: service_principal, managed_identity, default_credential", |
| 229 | + ) |
| 230 | + |
| 231 | + return True, "" |
0 commit comments