diff --git a/src/vercel/cache/cache_in_memory.py b/src/vercel/cache/cache_in_memory.py index 4df434f..e5b43b0 100644 --- a/src/vercel/cache/cache_in_memory.py +++ b/src/vercel/cache/cache_in_memory.py @@ -96,3 +96,6 @@ async def expire_tag(self, tag: str | Sequence[str]) -> None: async def contains(self, key: str) -> bool: return key in self.cache + + def __contains__(self, key: str) -> bool: + return key in self.cache diff --git a/src/vercel/cache/types.py b/src/vercel/cache/types.py index 93e753c..7dfff66 100644 --- a/src/vercel/cache/types.py +++ b/src/vercel/cache/types.py @@ -37,6 +37,8 @@ async def expire_tag(self, tag: str | Sequence[str]) -> None: ... async def contains(self, key: str) -> bool: ... + def __contains__(self, key: str) -> bool: ... + class PurgeAPI(Protocol): """Protocol for the purge API object.""" diff --git a/src/vercel/headers.py b/src/vercel/headers.py index 11020e3..dee1123 100644 --- a/src/vercel/headers.py +++ b/src/vercel/headers.py @@ -44,6 +44,7 @@ class Geo(TypedDict, total=False): latitude: str | None longitude: str | None postalCode: str | None + requestId: str | None def _get_header(headers: _HeadersLike, key: str) -> str | None: @@ -83,4 +84,5 @@ def geolocation(request: _RequestLike) -> Geo: "latitude": _get_header(headers, LATITUDE_HEADER_NAME), "longitude": _get_header(headers, LONGITUDE_HEADER_NAME), "postalCode": _get_header(headers, POSTAL_CODE_HEADER_NAME), + "requestId": _get_header(headers, REQUEST_ID_HEADER_NAME), } diff --git a/tests/test_cache_inmemory.py b/tests/test_cache_inmemory.py new file mode 100644 index 0000000..8d911f6 --- /dev/null +++ b/tests/test_cache_inmemory.py @@ -0,0 +1,212 @@ +""" +Tests for vercel.cache.cache_in_memory — InMemoryCache and AsyncInMemoryCache. + +Covers set/get/delete, TTL expiry, tag-based invalidation, __contains__, +__getitem__, the new __contains__ on AsyncInMemoryCache, and the key +transformer utilities. +""" + +from __future__ import annotations + +import time + +import pytest + +from vercel.cache.cache_in_memory import AsyncInMemoryCache, InMemoryCache +from vercel.cache.utils import create_key_transformer, default_key_hash_function + +# --------------------------------------------------------------------------- +# InMemoryCache (sync) +# --------------------------------------------------------------------------- + + +class TestInMemoryCache: + def setup_method(self) -> None: + self.cache = InMemoryCache() + + # Basic set / get + def test_set_and_get(self) -> None: + self.cache.set("key1", {"hello": "world"}) + assert self.cache.get("key1") == {"hello": "world"} + + def test_get_missing_key_returns_none(self) -> None: + assert self.cache.get("nonexistent") is None + + def test_overwrite_value(self) -> None: + self.cache.set("k", "first") + self.cache.set("k", "second") + assert self.cache.get("k") == "second" + + # Delete + def test_delete_removes_key(self) -> None: + self.cache.set("del_me", 42) + self.cache.delete("del_me") + assert self.cache.get("del_me") is None + + def test_delete_nonexistent_key_is_no_op(self) -> None: + self.cache.delete("ghost") # should not raise + + # __contains__ + def test_contains_existing_key(self) -> None: + self.cache.set("present", True) + assert "present" in self.cache + + def test_not_contains_missing_key(self) -> None: + assert "absent" not in self.cache + + # __getitem__ + def test_getitem_returns_value(self) -> None: + self.cache.set("item", [1, 2, 3]) + assert self.cache["item"] == [1, 2, 3] + + def test_getitem_raises_key_error_for_missing(self) -> None: + with pytest.raises(KeyError): + _ = self.cache["no_such_key"] + + # TTL + def test_get_returns_none_after_ttl_expires(self) -> None: + self.cache.set("ttl_key", "val", {"ttl": 0}) + # TTL of 0 seconds — should be considered expired immediately + time.sleep(0.01) + assert self.cache.get("ttl_key") is None + + def test_contains_false_after_ttl_expires(self) -> None: + self.cache.set("ttl_check", "v", {"ttl": 0}) + time.sleep(0.01) + assert "ttl_check" not in self.cache + + def test_get_returns_value_before_ttl_expires(self) -> None: + self.cache.set("long_ttl", "valid", {"ttl": 3600}) + assert self.cache.get("long_ttl") == "valid" + + # Tag invalidation + def test_expire_tag_removes_tagged_entries(self) -> None: + self.cache.set("a", 1, {"tags": ["tag1"]}) + self.cache.set("b", 2, {"tags": ["tag2"]}) + self.cache.expire_tag("tag1") + assert self.cache.get("a") is None + assert self.cache.get("b") == 2 + + def test_expire_multiple_tags_at_once(self) -> None: + self.cache.set("x", 10, {"tags": ["alpha"]}) + self.cache.set("y", 20, {"tags": ["beta"]}) + self.cache.set("z", 30, {"tags": ["gamma"]}) + self.cache.expire_tag(["alpha", "gamma"]) + assert self.cache.get("x") is None + assert self.cache.get("y") == 20 + assert self.cache.get("z") is None + + def test_expire_tag_no_match_is_no_op(self) -> None: + self.cache.set("safe", 99, {"tags": ["keep"]}) + self.cache.expire_tag("other_tag") + assert self.cache.get("safe") == 99 + + def test_entry_with_no_tags_not_removed_by_expire_tag(self) -> None: + self.cache.set("untagged", "data") + self.cache.expire_tag("any_tag") + assert self.cache.get("untagged") == "data" + + +# --------------------------------------------------------------------------- +# AsyncInMemoryCache +# --------------------------------------------------------------------------- + + +class TestAsyncInMemoryCache: + def setup_method(self) -> None: + self.cache = AsyncInMemoryCache() + + async def test_set_and_get(self) -> None: + await self.cache.set("k", "v") + assert await self.cache.get("k") == "v" + + async def test_get_missing_returns_none(self) -> None: + assert await self.cache.get("missing") is None + + async def test_delete_removes_key(self) -> None: + await self.cache.set("d", 123) + await self.cache.delete("d") + assert await self.cache.get("d") is None + + async def test_expire_tag_removes_entries(self) -> None: + await self.cache.set("p", 1, {"tags": ["t"]}) + await self.cache.set("q", 2, {"tags": ["other"]}) + await self.cache.expire_tag("t") + assert await self.cache.get("p") is None + assert await self.cache.get("q") == 2 + + # contains() coroutine + async def test_contains_returns_true_for_existing(self) -> None: + await self.cache.set("here", True) + assert await self.cache.contains("here") is True + + async def test_contains_returns_false_for_missing(self) -> None: + assert await self.cache.contains("nowhere") is False + + # __contains__ dunder (new — sync sugar over the delegate) + async def test_dunder_contains_returns_true(self) -> None: + await self.cache.set("chk", "yes") + assert "chk" in self.cache # uses __contains__ + + async def test_dunder_contains_returns_false(self) -> None: + assert "nope" not in self.cache + + # Shared delegate + def test_shares_delegate_with_sync_cache(self) -> None: + sync_cache = InMemoryCache() + sync_cache.set("shared", "val") + async_cache = AsyncInMemoryCache(delegate=sync_cache) + assert "shared" in async_cache # __contains__ via delegate + + async def test_ttl_expiry_via_async(self) -> None: + await self.cache.set("ttl_async", "data", {"ttl": 0}) + time.sleep(0.01) + assert await self.cache.get("ttl_async") is None + + +# --------------------------------------------------------------------------- +# Key transformer utilities +# --------------------------------------------------------------------------- + + +class TestDefaultKeyHashFunction: + def test_deterministic(self) -> None: + h1 = default_key_hash_function("hello") + h2 = default_key_hash_function("hello") + assert h1 == h2 + + def test_different_inputs_produce_different_hashes(self) -> None: + assert default_key_hash_function("a") != default_key_hash_function("b") + + def test_returns_hex_string(self) -> None: + h = default_key_hash_function("test") + assert isinstance(h, str) + int(h, 16) # Should not raise — must be valid hex + + +class TestCreateKeyTransformer: + def test_no_namespace_hashes_key(self) -> None: + transform = create_key_transformer(None, None, None) + key = transform("greeting") + assert key == default_key_hash_function("greeting") + + def test_with_namespace_prefixes(self) -> None: + transform = create_key_transformer(None, "myns", None) + key = transform("greeting") + assert key.startswith("myns$") + assert key == f"myns${default_key_hash_function('greeting')}" + + def test_custom_separator(self) -> None: + transform = create_key_transformer(None, "ns", "::") + key = transform("k") + assert key.startswith("ns::") + + def test_custom_hash_function(self) -> None: + identity = lambda k: k # noqa: E731 + transform = create_key_transformer(identity, None, None) + assert transform("raw_key") == "raw_key" + + def test_custom_hash_with_namespace(self) -> None: + identity = lambda k: k # noqa: E731 + transform = create_key_transformer(identity, "ns", "-") + assert transform("raw_key") == "ns-raw_key" diff --git a/tests/test_headers.py b/tests/test_headers.py new file mode 100644 index 0000000..e95b731 --- /dev/null +++ b/tests/test_headers.py @@ -0,0 +1,217 @@ +""" +Tests for vercel.headers — geolocation, ip_address, flag helpers. + +These tests use lightweight stub objects (no HTTP framework dependency) +to cover all header-parsing paths including the newly added ``requestId``. +""" + +from __future__ import annotations + +from vercel.headers import ( + CITY_HEADER_NAME, + COUNTRY_HEADER_NAME, + IP_HEADER_NAME, + LATITUDE_HEADER_NAME, + LONGITUDE_HEADER_NAME, + POSTAL_CODE_HEADER_NAME, + REGION_HEADER_NAME, + REQUEST_ID_HEADER_NAME, + _get_flag, + _region_from_request_id, + geolocation, + ip_address, +) + +# --------------------------------------------------------------------------- +# Stubs +# --------------------------------------------------------------------------- + + +class _Headers: + """Minimal dict-backed headers stub that satisfies _HeadersLike.""" + + def __init__(self, data: dict[str, str]) -> None: + self._data = data + + def get(self, name: str) -> str | None: + return self._data.get(name) + + +class _Request: + """Minimal request stub that satisfies _RequestLike.""" + + def __init__(self, data: dict[str, str]) -> None: + self.headers = _Headers(data) + + +# --------------------------------------------------------------------------- +# ip_address() +# --------------------------------------------------------------------------- + + +class TestIpAddress: + def test_returns_ip_from_request(self) -> None: + req = _Request({IP_HEADER_NAME: "1.2.3.4"}) + assert ip_address(req) == "1.2.3.4" + + def test_accepts_headers_directly(self) -> None: + headers = _Headers({IP_HEADER_NAME: "10.0.0.1"}) + assert ip_address(headers) == "10.0.0.1" + + def test_returns_none_when_header_missing(self) -> None: + req = _Request({}) + assert ip_address(req) is None + + def test_accepts_headers_without_ip(self) -> None: + headers = _Headers({}) + assert ip_address(headers) is None + + +# --------------------------------------------------------------------------- +# _get_flag() +# --------------------------------------------------------------------------- + + +class TestGetFlag: + def test_us_flag(self) -> None: + flag = _get_flag("US") + assert flag == "🇺🇸" + + def test_gb_flag(self) -> None: + flag = _get_flag("GB") + assert flag == "🇬🇧" + + def test_lowercase_country_code(self) -> None: + # Should be case-insensitive (uppercased internally) + flag = _get_flag("in") + assert flag is not None + assert len(flag) == 2 # two regional-indicator symbols + + def test_none_input(self) -> None: + assert _get_flag(None) is None + + def test_empty_string(self) -> None: + assert _get_flag("") is None + + def test_too_long_code(self) -> None: + assert _get_flag("USA") is None + + def test_single_char_code(self) -> None: + assert _get_flag("U") is None + + def test_non_alpha_code(self) -> None: + assert _get_flag("1A") is None + + +# --------------------------------------------------------------------------- +# _region_from_request_id() +# --------------------------------------------------------------------------- + + +class TestRegionFromRequestId: + def test_extracts_region_prefix(self) -> None: + assert _region_from_request_id("iad1:abc123:xyz") == "iad1" + + def test_no_colon_returns_full_string(self) -> None: + assert _region_from_request_id("iad1") == "iad1" + + def test_none_returns_dev1(self) -> None: + assert _region_from_request_id(None) == "dev1" + + +# --------------------------------------------------------------------------- +# geolocation() +# --------------------------------------------------------------------------- + + +class TestGeolocation: + def _make_request(self, overrides: dict[str, str] | None = None) -> _Request: + defaults: dict[str, str] = { + CITY_HEADER_NAME: "New%20York", + COUNTRY_HEADER_NAME: "US", + REGION_HEADER_NAME: "NY", + LATITUDE_HEADER_NAME: "40.7128", + LONGITUDE_HEADER_NAME: "-74.0060", + POSTAL_CODE_HEADER_NAME: "10001", + REQUEST_ID_HEADER_NAME: "iad1:abc:def", + } + if overrides: + defaults.update(overrides) + return _Request(defaults) + + def test_returns_geo_typeddict(self) -> None: + geo = geolocation(self._make_request()) + # TypedDict is just a dict at runtime + assert isinstance(geo, dict) + + def test_city_is_url_decoded(self) -> None: + geo = geolocation(self._make_request()) + assert geo["city"] == "New York" + + def test_country(self) -> None: + geo = geolocation(self._make_request()) + assert geo["country"] == "US" + + def test_flag_derived_from_country(self) -> None: + geo = geolocation(self._make_request()) + assert geo["flag"] == "🇺🇸" + + def test_country_region(self) -> None: + geo = geolocation(self._make_request()) + assert geo["countryRegion"] == "NY" + + def test_region_derived_from_request_id(self) -> None: + geo = geolocation(self._make_request()) + assert geo["region"] == "iad1" + + def test_latitude_and_longitude(self) -> None: + geo = geolocation(self._make_request()) + assert geo["latitude"] == "40.7128" + assert geo["longitude"] == "-74.0060" + + def test_postal_code(self) -> None: + geo = geolocation(self._make_request()) + assert geo["postalCode"] == "10001" + + def test_request_id_exposed(self) -> None: + """requestId should be the raw x-vercel-id header value.""" + geo = geolocation(self._make_request()) + assert geo["requestId"] == "iad1:abc:def" + + def test_request_id_none_when_header_absent(self) -> None: + req = _Request({}) + geo = geolocation(req) + assert geo["requestId"] is None + + def test_all_fields_none_when_no_headers(self) -> None: + req = _Request({}) + geo = geolocation(req) + assert geo["city"] is None + assert geo["country"] is None + assert geo["flag"] is None + assert geo["countryRegion"] is None + assert geo["latitude"] is None + assert geo["longitude"] is None + assert geo["postalCode"] is None + assert geo["requestId"] is None + + def test_region_falls_back_to_dev1_without_request_id(self) -> None: + req = _Request({COUNTRY_HEADER_NAME: "US"}) + geo = geolocation(req) + assert geo["region"] == "dev1" + + def test_geo_keys_are_complete(self) -> None: + """Ensure the returned dict has exactly the expected top-level keys.""" + geo = geolocation(self._make_request()) + expected_keys = { + "city", + "country", + "flag", + "region", + "countryRegion", + "latitude", + "longitude", + "postalCode", + "requestId", + } + assert set(geo.keys()) == expected_keys