diff --git a/optimizely/odp/lru_cache.py b/optimizely/odp/lru_cache.py index e7fc32af..073973e6 100644 --- a/optimizely/odp/lru_cache.py +++ b/optimizely/odp/lru_cache.py @@ -91,6 +91,11 @@ def peek(self, key: K) -> Optional[V]: element = self.map.get(key) return element.value if element is not None else None + def remove(self, key: K) -> None: + """Remove the element associated with the provided key from the cache.""" + with self.lock: + self.map.pop(key, None) + @dataclass class CacheElement(Generic[V]): diff --git a/tests/test_lru_cache.py b/tests/test_lru_cache.py index cc4dfdb1..b30617b3 100644 --- a/tests/test_lru_cache.py +++ b/tests/test_lru_cache.py @@ -130,6 +130,82 @@ def test_reset(self): cache.save('cow', 'crate') self.assertEqual(cache.lookup('cow'), 'crate') + def test_remove_non_existent_key(self): + cache = LRUCache(3, 1000) + cache.save("1", 100) + cache.save("2", 200) + + cache.remove("3") # Doesn't exist + + self.assertEqual(cache.lookup("1"), 100) + self.assertEqual(cache.lookup("2"), 200) + + def test_remove_existing_key(self): + cache = LRUCache(3, 1000) + + cache.save("1", 100) + cache.save("2", 200) + cache.save("3", 300) + + self.assertEqual(cache.lookup("1"), 100) + self.assertEqual(cache.lookup("2"), 200) + self.assertEqual(cache.lookup("3"), 300) + + cache.remove("2") + + self.assertEqual(cache.lookup("1"), 100) + self.assertIsNone(cache.lookup("2")) + self.assertEqual(cache.lookup("3"), 300) + + def test_remove_from_zero_sized_cache(self): + cache = LRUCache(0, 1000) + cache.save("1", 100) + cache.remove("1") + + self.assertIsNone(cache.lookup("1")) + + def test_remove_and_add_back(self): + cache = LRUCache(3, 1000) + cache.save("1", 100) + cache.save("2", 200) + cache.save("3", 300) + + cache.remove("2") + cache.save("2", 201) + + self.assertEqual(cache.lookup("1"), 100) + self.assertEqual(cache.lookup("2"), 201) + self.assertEqual(cache.lookup("3"), 300) + + def test_thread_safety(self): + import threading + + max_size = 100 + cache = LRUCache(max_size, 1000) + + for i in range(1, max_size + 1): + cache.save(str(i), i * 100) + + def remove_key(k): + cache.remove(str(k)) + + threads = [] + for i in range(1, (max_size // 2) + 1): + thread = threading.Thread(target=remove_key, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + for i in range(1, max_size + 1): + if i <= max_size // 2: + self.assertIsNone(cache.lookup(str(i))) + else: + self.assertEqual(cache.lookup(str(i)), i * 100) + + self.assertEqual(len(cache.map), max_size // 2) + # type checker test # confirm that LRUCache matches OptimizelySegmentsCache protocol _: OptimizelySegmentsCache = LRUCache(0, 0)