Skip to content

[FSSDK-11148] update: Implement CMAB Client #453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions optimizely/cmab/cmab_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright 2025 Optimizely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import requests
import math
from typing import Dict, Any, Optional
from optimizely import logger as _logging
from optimizely.helpers.enums import Errors

# CMAB_PREDICTION_ENDPOINT is the endpoint for CMAB predictions
CMAB_PREDICTION_ENDPOINT = "https://prediction.cmab.optimizely.com/predict/%s"

# Default constants for CMAB requests
DEFAULT_MAX_RETRIES = 3
DEFAULT_INITIAL_BACKOFF = 0.1 # in seconds (100 ms)
DEFAULT_MAX_BACKOFF = 10 # in seconds
DEFAULT_BACKOFF_MULTIPLIER = 2.0
MAX_WAIT_TIME = 10.0


class CmabRetryConfig:
"""Configuration for retrying CMAB requests.

Contains parameters for maximum retries, backoff intervals, and multipliers.
"""
def __init__(
self,
max_retries: int = DEFAULT_MAX_RETRIES,
initial_backoff: float = DEFAULT_INITIAL_BACKOFF,
max_backoff: float = DEFAULT_MAX_BACKOFF,
backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
):
self.max_retries = max_retries
self.initial_backoff = initial_backoff
self.max_backoff = max_backoff
self.backoff_multiplier = backoff_multiplier


class DefaultCmabClient:
"""Client for interacting with the CMAB service.

Provides methods to fetch decisions with optional retry logic.
"""
def __init__(self, http_client: Optional[requests.Session] = None,
retry_config: Optional[CmabRetryConfig] = None,
logger: Optional[_logging.Logger] = None):
"""Initialize the CMAB client.

Args:
http_client (Optional[requests.Session]): HTTP client for making requests.
retry_config (Optional[CmabRetryConfig]): Configuration for retry logic.
logger (Optional[_logging.Logger]): Logger for logging messages.
"""
self.http_client = http_client or requests.Session()
self.retry_config = retry_config
self.logger = _logging.adapt_logger(logger or _logging.NoOpLogger())

def fetch_decision(
self,
rule_id: str,
user_id: str,
attributes: Dict[str, Any],
cmab_uuid: str
) -> Optional[str]:
"""Fetch a decision from the CMAB prediction service.

Args:
rule_id (str): The rule ID for the experiment.
user_id (str): The user ID for the request.
attributes (Dict[str, Any]): User attributes for the request.
cmab_uuid (str): Unique identifier for the CMAB request.

Returns:
Optional[str]: The variation ID if successful, None otherwise.
"""
url = CMAB_PREDICTION_ENDPOINT % rule_id
cmab_attributes = [
{"id": key, "value": value, "type": "custom_attribute"}
for key, value in attributes.items()
]

request_body = {
"instances": [{
"visitorId": user_id,
"experimentId": rule_id,
"attributes": cmab_attributes,
"cmabUUID": cmab_uuid,
}]
}

try:
if self.retry_config:
variation_id = self._do_fetch_with_retry(url, request_body, self.retry_config)
else:
variation_id = self._do_fetch(url, request_body)
return variation_id

except requests.RequestException as e:
self.logger.error(Errors.CMAB_FETCH_FAILED.format(str(e)))
return None

def _do_fetch(self, url: str, request_body: Dict[str, Any]) -> Optional[str]:
"""Perform a single fetch request to the CMAB prediction service.

Args:
url (str): The endpoint URL.
request_body (Dict[str, Any]): The request payload.

Returns:
Optional[str]: The variation ID if successful, None otherwise.
"""
headers = {'Content-Type': 'application/json'}
try:
response = self.http_client.post(url, data=json.dumps(request_body), headers=headers, timeout=MAX_WAIT_TIME)
except requests.exceptions.RequestException as e:
self.logger.exception(Errors.CMAB_FETCH_FAILED.format(str(e)))
return None

if not 200 <= response.status_code < 300:
self.logger.exception(Errors.CMAB_FETCH_FAILED.format(str(response.status_code)))
return None

try:
body = response.json()
except json.JSONDecodeError:
self.logger.exception(Errors.INVALID_CMAB_FETCH_RESPONSE)
return None

if not self.validate_response(body):
self.logger.exception(Errors.INVALID_CMAB_FETCH_RESPONSE)
return None

return str(body['predictions'][0]['variation_id'])

def validate_response(self, body: Dict[str, Any]) -> bool:
"""Validate the response structure from the CMAB service.

Args:
body (Dict[str, Any]): The response body to validate.

Returns:
bool: True if the response is valid, False otherwise.
"""
return (
isinstance(body, dict) and
'predictions' in body and
isinstance(body['predictions'], list) and
len(body['predictions']) > 0 and
isinstance(body['predictions'][0], dict) and
"variation_id" in body["predictions"][0]
)

def _do_fetch_with_retry(
self,
url: str,
request_body: Dict[str, Any],
retry_config: CmabRetryConfig
) -> Optional[str]:
"""Perform a fetch request with retry logic.

Args:
url (str): The endpoint URL.
request_body (Dict[str, Any]): The request payload.
retry_config (CmabRetryConfig): Configuration for retry logic.

Returns:
Optional[str]: The variation ID if successful, None otherwise.
"""
backoff = retry_config.initial_backoff
for attempt in range(retry_config.max_retries + 1):
variation_id = self._do_fetch(url, request_body)
if variation_id:
return variation_id
if attempt < retry_config.max_retries:
self.logger.info(f"Retrying CMAB request (attempt: {attempt + 1}) after {backoff} seconds...")
time.sleep(backoff)
backoff = min(backoff * math.pow(retry_config.backoff_multiplier, attempt + 1),
retry_config.max_backoff)
self.logger.error(Errors.CMAB_FETCH_FAILED.format('Exhausted all retries for CMAB request.'))
return None
2 changes: 2 additions & 0 deletions optimizely/helpers/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class Errors:
ODP_INVALID_DATA: Final = 'ODP data is not valid.'
ODP_INVALID_ACTION: Final = 'ODP action is not valid (cannot be empty).'
MISSING_SDK_KEY: Final = 'SDK key not provided/cannot be found in the datafile.'
CMAB_FETCH_FAILED: Final = 'CMAB decision fetch failed with status: {}'
INVALID_CMAB_FETCH_RESPONSE = 'Invalid CMAB fetch response'


class ForcedDecisionLogs:
Expand Down
161 changes: 161 additions & 0 deletions tests/test_cmab_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import unittest
import json
from unittest.mock import MagicMock, patch
from optimizely.cmab.cmab_client import DefaultCmabClient, CmabRetryConfig
from requests.exceptions import RequestException
from optimizely.helpers.enums import Errors


class TestDefaultCmabClient_do_fetch(unittest.TestCase):
def setUp(self):
self.mock_http_client = MagicMock()
self.mock_logger = MagicMock()
self.client = DefaultCmabClient(http_client=self.mock_http_client, logger=self.mock_logger)

def test_do_fetch_success(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'abc123'}]
}
self.mock_http_client.post.return_value = mock_response

result = self.client._do_fetch('http://fake-url', {'some': 'data'})
self.assertEqual(result, 'abc123')

def test_do_fetch_http_exception(self):
self.mock_http_client.post.side_effect = RequestException('Connection error')
result = self.client._do_fetch('http://fake-url', {'some': 'data'})
self.assertIsNone(result)
self.mock_logger.exception.assert_called_with(Errors.CMAB_FETCH_FAILED.format('Connection error'))

def test_do_fetch_non_2xx_status(self):
mock_response = MagicMock()
mock_response.status_code = 500
self.mock_http_client.post.return_value = mock_response
result = self.client._do_fetch('http://fake-url', {'some': 'data'})
self.assertIsNone(result)
self.mock_logger.exception.assert_called_with(Errors.CMAB_FETCH_FAILED.format(str(mock_response.status_code)))

def test_do_fetch_invalid_json(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
self.mock_http_client.post.return_value = mock_response
result = self.client._do_fetch('http://fake-url', {'some': 'data'})
self.assertIsNone(result)
self.mock_logger.exception.assert_called_with(Errors.INVALID_CMAB_FETCH_RESPONSE)

def test_do_fetch_invalid_response_structure(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {'no_predictions': []}
self.mock_http_client.post.return_value = mock_response
result = self.client._do_fetch('http://fake-url', {'some': 'data'})
self.assertIsNone(result)
self.mock_logger.exception.assert_called_with(Errors.INVALID_CMAB_FETCH_RESPONSE)


class TestDefaultCmabClientWithRetry(unittest.TestCase):
def setUp(self):
self.mock_http_client = MagicMock()
self.mock_logger = MagicMock()
self.retry_config = CmabRetryConfig(max_retries=2, initial_backoff=0.01, max_backoff=1, backoff_multiplier=2)
self.client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
retry_config=self.retry_config
)

@patch("time.sleep", return_value=None)
def test_do_fetch_with_retry_success_on_first_try(self, _):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"predictions": [{"variation_id": "abc123"}]
}
self.mock_http_client.post.return_value = mock_response

result = self.client._do_fetch_with_retry("http://fake-url", {}, self.retry_config)
self.assertEqual(result, "abc123")
self.assertEqual(self.mock_http_client.post.call_count, 1)

@patch("time.sleep", return_value=None)
def test_do_fetch_with_retry_success_on_retry(self, _):
# First call fails, second call succeeds
failure_response = MagicMock()
failure_response.status_code = 500

success_response = MagicMock()
success_response.status_code = 200
success_response.json.return_value = {
"predictions": [{"variation_id": "xyz456"}]
}

self.mock_http_client.post.side_effect = [
failure_response,
success_response
]

result = self.client._do_fetch_with_retry("http://fake-url", {}, self.retry_config)
self.assertEqual(result, "xyz456")
self.assertEqual(self.mock_http_client.post.call_count, 2)
self.mock_logger.info.assert_called_with("Retrying CMAB request (attempt: 1) after 0.01 seconds...")

@patch("time.sleep", return_value=None)
def test_do_fetch_with_retry_exhausts_all_attempts(self, _):
failure_response = MagicMock()
failure_response.status_code = 500

self.mock_http_client.post.return_value = failure_response

result = self.client._do_fetch_with_retry("http://fake-url", {}, self.retry_config)
self.assertIsNone(result)
self.assertEqual(self.mock_http_client.post.call_count, 3) # 1 original + 2 retries
self.mock_logger.error.assert_called_with(
Errors.CMAB_FETCH_FAILED.format("Exhausted all retries for CMAB request."))


class TestDefaultCmabClientFetchDecision(unittest.TestCase):
def setUp(self):
self.mock_http_client = MagicMock()
self.mock_logger = MagicMock()
self.retry_config = CmabRetryConfig(max_retries=2, initial_backoff=0.01, max_backoff=1, backoff_multiplier=2)
self.client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
retry_config=self.retry_config
)
self.rule_id = 'test_rule'
self.user_id = 'user123'
self.attributes = {'attr1': 'value1'}
self.cmab_uuid = 'uuid-1234'

@patch.object(DefaultCmabClient, '_do_fetch', return_value='var-abc')
def test_fetch_decision_success_no_retry(self, mock_do_fetch):
result = self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)
self.assertEqual(result, 'var-abc')
mock_do_fetch.assert_called_once()

@patch.object(DefaultCmabClient, '_do_fetch_with_retry', return_value='var-xyz')
def test_fetch_decision_success_with_retry(self, mock_do_fetch_with_retry):
client_with_retry = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
retry_config=self.retry_config
)
result = client_with_retry.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)
self.assertEqual(result, 'var-xyz')
mock_do_fetch_with_retry.assert_called_once()

@patch.object(DefaultCmabClient, '_do_fetch', side_effect=RequestException("Network error"))
def test_fetch_decision_request_exception(self, mock_do_fetch):
result = self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)
self.assertIsNone(result)
self.mock_logger.error.assert_called_with(Errors.CMAB_FETCH_FAILED.format("Network error"))

@patch.object(DefaultCmabClient, '_do_fetch', return_value=None)
def test_fetch_decision_invalid_response(self, mock_do_fetch):
result = self.client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)
self.assertIsNone(result)
self.mock_logger.error.assert_called_once()