diff --git a/apps/hip-3-pusher/tests/test_listeners.py b/apps/hip-3-pusher/tests/test_listeners.py new file mode 100644 index 0000000000..e045fb39a4 --- /dev/null +++ b/apps/hip-3-pusher/tests/test_listeners.py @@ -0,0 +1,472 @@ +import datetime + +import pytest + +from pusher.config import Config, LazerConfig, HermesConfig, HyperliquidConfig +from pusher.hermes_listener import HermesListener +from pusher.hyperliquid_listener import HyperliquidListener +from pusher.lazer_listener import LazerListener +from pusher.seda_listener import SedaListener +from pusher.price_state import PriceSourceState + + +def get_base_config(): + """Create a base config for testing listeners.""" + config: Config = Config.model_construct() + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = "pyth" + config.hyperliquid.hyperliquid_ws_urls = ["wss://test.example.com/ws"] + config.hyperliquid.asset_context_symbols = ["BTC", "ETH"] + config.lazer = LazerConfig.model_construct() + config.lazer.lazer_urls = ["wss://lazer.example.com"] + config.lazer.lazer_api_key = "test-api-key" + config.lazer.feed_ids = [1, 8] + config.hermes = HermesConfig.model_construct() + config.hermes.hermes_urls = ["wss://hermes.example.com"] + config.hermes.feed_ids = ["hermes_feed1", "hermes_feed2"] + return config + + +class TestHermesListener: + """Tests for HermesListener message parsing.""" + + def test_get_subscribe_request(self): + """Test subscribe request format.""" + config = get_base_config() + hermes_state = PriceSourceState("hermes") + listener = HermesListener(config, hermes_state) + + request = listener.get_subscribe_request() + + assert request["type"] == "subscribe" + assert request["ids"] == ["hermes_feed1", "hermes_feed2"] + assert request["verbose"] is False + assert request["binary"] is True + assert request["allow_out_of_order"] is False + assert request["ignore_invalid_price_ids"] is False + + def test_parse_hermes_message_valid(self): + """Test parsing valid Hermes price update message.""" + config = get_base_config() + hermes_state = PriceSourceState("hermes") + listener = HermesListener(config, hermes_state) + + message = { + "type": "price_update", + "price_feed": { + "id": "test_feed_id", + "price": { + "price": "12345678", + "expo": -8, + "publish_time": 1700000000 + } + } + } + + listener.parse_hermes_message(message) + + update = hermes_state.get("test_feed_id") + assert update is not None + assert update.price == "12345678" + + def test_parse_hermes_message_non_price_update(self): + """Test that non-price_update messages are ignored.""" + config = get_base_config() + hermes_state = PriceSourceState("hermes") + listener = HermesListener(config, hermes_state) + + message = { + "type": "subscription_response", + "data": {} + } + + listener.parse_hermes_message(message) + + assert hermes_state.state == {} + + def test_parse_hermes_message_missing_type(self): + """Test that messages without type are ignored.""" + config = get_base_config() + hermes_state = PriceSourceState("hermes") + listener = HermesListener(config, hermes_state) + + message = { + "price_feed": { + "id": "test_feed_id", + "price": {"price": "12345678"} + } + } + + listener.parse_hermes_message(message) + + assert hermes_state.state == {} + + def test_parse_hermes_message_malformed(self): + """Test that malformed messages result in no state mutation. + + Note: parse_hermes_message swallows exceptions internally, so this test + verifies the observable behavior (no state change) rather than exception handling. + """ + config = get_base_config() + hermes_state = PriceSourceState("hermes") + listener = HermesListener(config, hermes_state) + + message = { + "type": "price_update", + "invalid_key": "invalid_value" + } + + listener.parse_hermes_message(message) + + assert hermes_state.state == {} + + +class TestLazerListener: + """Tests for LazerListener message parsing.""" + + def test_get_subscribe_request(self): + """Test subscribe request format.""" + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + request = listener.get_subscribe_request(subscription_id=42) + + assert request["type"] == "subscribe" + assert request["subscriptionId"] == 42 + assert request["priceFeedIds"] == [1, 8] + assert request["properties"] == ["price"] + assert request["formats"] == [] + assert request["deliveryFormat"] == "json" + assert request["channel"] == "fixed_rate@200ms" + assert request["parsed"] is True + assert request["jsonBinaryEncoding"] == "base64" + + def test_parse_lazer_message_valid(self): + """Test parsing valid Lazer streamUpdated message.""" + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + message = { + "type": "streamUpdated", + "parsed": { + "priceFeeds": [ + {"priceFeedId": 1, "price": "11050000000000"}, + {"priceFeedId": 8, "price": "99000000"} + ] + } + } + + listener.parse_lazer_message(message) + + update1 = lazer_state.get(1) + assert update1 is not None + assert update1.price == "11050000000000" + + update8 = lazer_state.get(8) + assert update8 is not None + assert update8.price == "99000000" + + def test_parse_lazer_message_non_stream_updated(self): + """Test that non-streamUpdated messages are ignored.""" + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + message = { + "type": "subscribed", + "subscriptionId": 1 + } + + listener.parse_lazer_message(message) + + assert lazer_state.state == {} + + def test_parse_lazer_message_missing_feed_id(self): + """Test that feeds without priceFeedId are skipped.""" + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + message = { + "type": "streamUpdated", + "parsed": { + "priceFeeds": [ + {"price": "11050000000000"}, + {"priceFeedId": 8, "price": "99000000"} + ] + } + } + + listener.parse_lazer_message(message) + + assert lazer_state.get(1) is None + update8 = lazer_state.get(8) + assert update8 is not None + assert update8.price == "99000000" + + def test_parse_lazer_message_missing_price(self): + """Test that feeds without price key are skipped.""" + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + message = { + "type": "streamUpdated", + "parsed": { + "priceFeeds": [ + {"priceFeedId": 1}, + {"priceFeedId": 8, "price": "99000000"} + ] + } + } + + listener.parse_lazer_message(message) + + assert lazer_state.get(1) is None + update8 = lazer_state.get(8) + assert update8 is not None + + def test_parse_lazer_message_null_price(self): + """Test that feeds with explicit null price are skipped. + + Lazer emits null prices when no aggregation is available for a feed. + """ + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + message = { + "type": "streamUpdated", + "parsed": { + "priceFeeds": [ + {"priceFeedId": 1, "price": None}, + {"priceFeedId": 8, "price": "99000000"} + ] + } + } + + listener.parse_lazer_message(message) + + assert lazer_state.get(1) is None + update8 = lazer_state.get(8) + assert update8 is not None + assert update8.price == "99000000" + + def test_parse_lazer_message_malformed(self): + """Test that malformed messages don't crash.""" + config = get_base_config() + lazer_state = PriceSourceState("lazer") + listener = LazerListener(config, lazer_state) + + message = { + "type": "streamUpdated", + "invalid_key": "invalid_value" + } + + listener.parse_lazer_message(message) + + assert lazer_state.state == {} + + +class TestHyperliquidListener: + """Tests for HyperliquidListener message parsing.""" + + def test_get_subscribe_request(self): + """Test subscribe request format.""" + config = get_base_config() + hl_oracle_state = PriceSourceState("hl_oracle") + hl_mark_state = PriceSourceState("hl_mark") + hl_mid_state = PriceSourceState("hl_mid") + listener = HyperliquidListener(config, hl_oracle_state, hl_mark_state, hl_mid_state) + + request = listener.get_subscribe_request("BTC") + + assert request["method"] == "subscribe" + assert request["subscription"]["type"] == "activeAssetCtx" + assert request["subscription"]["coin"] == "BTC" + + def test_parse_active_asset_ctx_update(self): + """Test parsing activeAssetCtx update message.""" + config = get_base_config() + hl_oracle_state = PriceSourceState("hl_oracle") + hl_mark_state = PriceSourceState("hl_mark") + hl_mid_state = PriceSourceState("hl_mid") + listener = HyperliquidListener(config, hl_oracle_state, hl_mark_state, hl_mid_state) + + message = { + "channel": "activeAssetCtx", + "data": { + "coin": "BTC", + "ctx": { + "oraclePx": "100000.0", + "markPx": "99500.0" + } + } + } + + listener.parse_hyperliquid_active_asset_ctx_update(message) + + oracle_update = hl_oracle_state.get("BTC") + assert oracle_update is not None + assert oracle_update.price == "100000.0" + + mark_update = hl_mark_state.get("BTC") + assert mark_update is not None + assert mark_update.price == "99500.0" + + def test_parse_active_asset_ctx_update_malformed(self): + """Test that malformed activeAssetCtx messages don't crash.""" + config = get_base_config() + hl_oracle_state = PriceSourceState("hl_oracle") + hl_mark_state = PriceSourceState("hl_mark") + hl_mid_state = PriceSourceState("hl_mid") + listener = HyperliquidListener(config, hl_oracle_state, hl_mark_state, hl_mid_state) + + message = { + "channel": "activeAssetCtx", + "data": {} + } + + listener.parse_hyperliquid_active_asset_ctx_update(message) + + assert hl_oracle_state.state == {} + assert hl_mark_state.state == {} + + def test_parse_all_mids_update(self): + """Test parsing allMids update message.""" + config = get_base_config() + hl_oracle_state = PriceSourceState("hl_oracle") + hl_mark_state = PriceSourceState("hl_mark") + hl_mid_state = PriceSourceState("hl_mid") + listener = HyperliquidListener(config, hl_oracle_state, hl_mark_state, hl_mid_state) + + message = { + "channel": "allMids", + "data": { + "mids": { + "pyth:BTC": "100250.0", + "pyth:ETH": "3050.0" + } + } + } + + listener.parse_hyperliquid_all_mids_update(message) + + btc_mid = hl_mid_state.get("pyth:BTC") + assert btc_mid is not None + assert btc_mid.price == "100250.0" + + eth_mid = hl_mid_state.get("pyth:ETH") + assert eth_mid is not None + assert eth_mid.price == "3050.0" + + def test_parse_all_mids_update_malformed(self): + """Test that malformed allMids messages don't crash.""" + config = get_base_config() + hl_oracle_state = PriceSourceState("hl_oracle") + hl_mark_state = PriceSourceState("hl_mark") + hl_mid_state = PriceSourceState("hl_mid") + listener = HyperliquidListener(config, hl_oracle_state, hl_mark_state, hl_mid_state) + + message = { + "channel": "allMids", + "data": {} + } + + listener.parse_hyperliquid_all_mids_update(message) + + assert hl_mid_state.state == {} + + def test_parse_all_mids_update_empty_mids(self): + """Test parsing allMids with empty mids dict.""" + config = get_base_config() + hl_oracle_state = PriceSourceState("hl_oracle") + hl_mark_state = PriceSourceState("hl_mark") + hl_mid_state = PriceSourceState("hl_mid") + listener = HyperliquidListener(config, hl_oracle_state, hl_mark_state, hl_mid_state) + + message = { + "channel": "allMids", + "data": { + "mids": {} + } + } + + listener.parse_hyperliquid_all_mids_update(message) + + assert hl_mid_state.state == {} + + +class MockSedaListener: + """Mock SedaListener for testing _parse_seda_message directly.""" + + def __init__(self, seda_state: PriceSourceState): + self.seda_state = seda_state + + +@pytest.fixture +def seda_listener_fixture(): + """Create a mock SEDA listener with fresh state for testing.""" + seda_state = PriceSourceState("seda") + listener = MockSedaListener(seda_state) + return listener, seda_state + + +class TestSedaListener: + """Tests for SedaListener message parsing.""" + + def test_parse_seda_message_valid(self, seda_listener_fixture): + """Test parsing valid SEDA message.""" + listener, seda_state = seda_listener_fixture + + message = { + "data": { + "result": '{"composite_rate": "42.5", "timestamp": "2024-01-15T12:00:00.000Z"}' + } + } + + SedaListener._parse_seda_message(listener, "custom_feed", message) + + update = seda_state.get("custom_feed") + assert update is not None + assert update.price == "42.5" + expected_timestamp = datetime.datetime.fromisoformat("2024-01-15T12:00:00.000Z").timestamp() + assert update.timestamp == expected_timestamp + + def test_parse_seda_message_different_timestamp_format(self, seda_listener_fixture): + """Test parsing SEDA message with microseconds in timestamp. + + Python's fromisoformat() handles various ISO 8601 formats including + timestamps with microseconds, which SEDA may emit. + """ + listener, seda_state = seda_listener_fixture + + message = { + "data": { + "result": '{"composite_rate": "100.25", "timestamp": "2024-06-20T15:30:45.123456Z"}' + } + } + + SedaListener._parse_seda_message(listener, "another_feed", message) + + update = seda_state.get("another_feed") + assert update is not None + assert update.price == "100.25" + + def test_parse_seda_message_numeric_rate(self, seda_listener_fixture): + """Test parsing SEDA message with numeric composite_rate.""" + listener, seda_state = seda_listener_fixture + + message = { + "data": { + "result": '{"composite_rate": 123.456, "timestamp": "2024-01-15T12:00:00.000Z"}' + } + } + + SedaListener._parse_seda_message(listener, "numeric_feed", message) + + update = seda_state.get("numeric_feed") + assert update is not None + assert update.price == 123.456 diff --git a/apps/hip-3-pusher/tests/test_price_state.py b/apps/hip-3-pusher/tests/test_price_state.py index 52e023119a..26ae69d9b6 100644 --- a/apps/hip-3-pusher/tests/test_price_state.py +++ b/apps/hip-3-pusher/tests/test_price_state.py @@ -1,8 +1,9 @@ +import pytest import time from pusher.config import Config, LazerConfig, HermesConfig, PriceConfig, PriceSource, SingleSourceConfig, \ - PairSourceConfig, HyperliquidConfig -from pusher.price_state import PriceState, PriceUpdate + PairSourceConfig, HyperliquidConfig, ConstantSourceConfig, OracleMidAverageConfig +from pusher.price_state import PriceState, PriceUpdate, PriceSourceState, OracleUpdate DEX = "pyth" SYMBOL = "BTC" @@ -12,6 +13,7 @@ def get_config(): config: Config = Config.model_construct() config.stale_price_threshold_seconds = 5 config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX config.hyperliquid.asset_context_symbols = [SYMBOL] config.lazer = LazerConfig.model_construct() config.lazer.feed_ids = [1, 8] @@ -44,8 +46,8 @@ def test_good_hl_price(): now = time.time() price_state.hl_oracle_state.put(SYMBOL, PriceUpdate("110000.0", now - price_state.stale_price_threshold_seconds / 2.0)) - oracle_px, _, _ = price_state.get_all_prices(DEX) - assert oracle_px == {f"{DEX}:{SYMBOL}": "110000.0"} + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "110000.0"} def test_fallback_lazer(): @@ -59,8 +61,8 @@ def test_fallback_lazer(): price_state.lazer_state.put(1, PriceUpdate("11050000000000", now - price_state.stale_price_threshold_seconds / 2.0)) price_state.lazer_state.put(8, PriceUpdate("99000000", now - price_state.stale_price_threshold_seconds / 2.0)) - oracle_px, _, _ = price_state.get_all_prices(DEX) - assert oracle_px == {f"{DEX}:{SYMBOL}": "111616.16"} + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "111616.16"} @@ -79,8 +81,8 @@ def test_fallback_hermes(): price_state.hermes_state.put("2b89b9dc8fdf9f34709a5b106b472f0f39bb6ca9ce04b0fd7f2e971688e2e53b", PriceUpdate("98000000", now - price_state.stale_price_threshold_seconds / 2.0)) - oracle_px, _, _ = price_state.get_all_prices(DEX) - assert oracle_px == {f"{DEX}:{SYMBOL}": "113265.31"} + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "113265.31"} def test_all_fail(): @@ -98,5 +100,588 @@ def test_all_fail(): price_state.hermes_state.put("2b89b9dc8fdf9f34709a5b106b472f0f39bb6ca9ce04b0fd7f2e971688e2e53b", PriceUpdate("98000000", now - price_state.stale_price_threshold_seconds - 1.0)) - oracle_px, _, _ = price_state.get_all_prices(DEX) - assert oracle_px == {} + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {} + + +class TestPriceUpdate: + """Tests for the PriceUpdate dataclass.""" + + def test_time_diff(self): + """Test time_diff calculation.""" + update = PriceUpdate(price="100.0", timestamp=1000.0) + assert update.time_diff(1005.0) == 5.0 + + def test_time_diff_negative(self): + """Test time_diff with future timestamp (negative diff).""" + update = PriceUpdate(price="100.0", timestamp=1010.0) + assert update.time_diff(1005.0) == -5.0 + + def test_price_can_be_float(self): + """Test that price can be a float.""" + update = PriceUpdate(price=100.5, timestamp=1000.0) + assert update.price == 100.5 + + def test_price_can_be_string(self): + """Test that price can be a string.""" + update = PriceUpdate(price="100.5", timestamp=1000.0) + assert update.price == "100.5" + + +class TestPriceSourceState: + """Tests for the PriceSourceState class.""" + + def test_init(self): + """Test initialization.""" + state = PriceSourceState("test_source") + assert state.name == "test_source" + assert state.state == {} + + def test_put_and_get(self): + """Test put and get operations.""" + state = PriceSourceState("test_source") + update = PriceUpdate(price="100.0", timestamp=1000.0) + state.put("BTC", update) + assert state.get("BTC") == update + + def test_get_missing_key(self): + """Test get returns None for missing key.""" + state = PriceSourceState("test_source") + assert state.get("MISSING") is None + + def test_repr(self): + """Test string representation.""" + state = PriceSourceState("test_source") + update = PriceUpdate(price="100.0", timestamp=1000.0) + state.put("BTC", update) + repr_str = repr(state) + assert "test_source" in repr_str + assert "BTC" in repr_str + + def test_overwrite_value(self): + """Test that put overwrites existing values.""" + state = PriceSourceState("test_source") + state.put("BTC", PriceUpdate(price="100.0", timestamp=1000.0)) + state.put("BTC", PriceUpdate(price="200.0", timestamp=2000.0)) + assert state.get("BTC").price == "200.0" + + +class TestOracleUpdate: + """Tests for the OracleUpdate dataclass.""" + + def test_init(self): + """Test initialization.""" + update = OracleUpdate( + oracle={"pyth:BTC": "100.0"}, + mark={"pyth:BTC": "99.0"}, + external={"pyth:ETH": "3000.0"} + ) + assert update.oracle == {"pyth:BTC": "100.0"} + assert update.mark == {"pyth:BTC": "99.0"} + assert update.external == {"pyth:ETH": "3000.0"} + + def test_empty_init(self): + """Test initialization with empty dicts.""" + update = OracleUpdate(oracle={}, mark={}, external={}) + assert update.oracle == {} + assert update.mark == {} + assert update.external == {} + + +class TestConstantSourceConfig: + """Tests for constant price source configuration.""" + + def test_constant_source(self): + """Test that constant source returns configured value.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + "STABLE": [ + ConstantSourceConfig(source_type="constant", value="1.0") + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:STABLE": "1.0"} + + def test_constant_source_with_fallback(self): + """Test constant source as fallback when primary source is stale.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = ["STABLE"] + config.price = PriceConfig( + oracle={ + "STABLE": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="STABLE", exponent=None)), + ConstantSourceConfig(source_type="constant", value="1.0") + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put("STABLE", PriceUpdate("0.99", now - 10.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:STABLE": "1.0"} + + +class TestOracleMidAverageConfig: + """Tests for oracle-mid average price source configuration.""" + + def test_oracle_mid_average(self): + """Test oracle-mid average calculation.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [SYMBOL] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id=SYMBOL, exponent=None)), + ] + }, + mark={ + SYMBOL: [ + OracleMidAverageConfig(source_type="oracle_mid_average", symbol=f"{DEX}:{SYMBOL}") + ] + }, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put(SYMBOL, PriceUpdate("100.0", now - 1.0)) + price_state.hl_mid_state.put(f"{DEX}:{SYMBOL}", PriceUpdate("102.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "100.0"} + assert oracle_update.mark == {f"{DEX}:{SYMBOL}": "101.0"} + + def test_oracle_mid_average_missing_oracle(self): + """Test oracle-mid average returns None when oracle price is missing.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [SYMBOL] + config.price = PriceConfig( + oracle={}, + mark={ + SYMBOL: [ + OracleMidAverageConfig(source_type="oracle_mid_average", symbol=f"{DEX}:{SYMBOL}") + ] + }, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_mid_state.put(f"{DEX}:{SYMBOL}", PriceUpdate("102.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.mark == {} + + def test_oracle_mid_average_missing_mid(self): + """Test oracle-mid average returns None when mid price is missing.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [SYMBOL] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id=SYMBOL, exponent=None)), + ] + }, + mark={ + SYMBOL: [ + OracleMidAverageConfig(source_type="oracle_mid_average", symbol=f"{DEX}:{SYMBOL}") + ] + }, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put(SYMBOL, PriceUpdate("100.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "100.0"} + assert oracle_update.mark == {} + + def test_oracle_mid_average_stale_mid(self): + """Test oracle-mid average returns None when mid price is stale.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [SYMBOL] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id=SYMBOL, exponent=None)), + ] + }, + mark={ + SYMBOL: [ + OracleMidAverageConfig(source_type="oracle_mid_average", symbol=f"{DEX}:{SYMBOL}") + ] + }, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put(SYMBOL, PriceUpdate("100.0", now - 1.0)) + price_state.hl_mid_state.put(f"{DEX}:{SYMBOL}", PriceUpdate("102.0", now - 10.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "100.0"} + assert oracle_update.mark == {} + + +class TestMarkAndExternalPrices: + """Tests for mark and external price configurations.""" + + def test_mark_prices(self): + """Test mark prices are returned correctly.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [SYMBOL] + config.price = PriceConfig( + oracle={}, + mark={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_mark", source_id=SYMBOL, exponent=None)), + ] + }, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_mark_state.put(SYMBOL, PriceUpdate("99500.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.mark == {f"{DEX}:{SYMBOL}": "99500.0"} + + def test_external_prices(self): + """Test external prices are returned correctly.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = ["ETH"] + config.price = PriceConfig( + oracle={}, + mark={}, + external={ + "ETH": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="ETH", exponent=None)), + ] + } + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put("ETH", PriceUpdate("3000.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.external == {f"{DEX}:ETH": "3000.0"} + + +class TestMultipleSymbols: + """Tests for multiple symbols in configuration.""" + + def test_multiple_oracle_symbols(self): + """Test multiple symbols in oracle config.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = ["BTC", "ETH", "SOL"] + config.price = PriceConfig( + oracle={ + "BTC": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="BTC", exponent=None)), + ], + "ETH": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="ETH", exponent=None)), + ], + "SOL": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="SOL", exponent=None)), + ], + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put("BTC", PriceUpdate("100000.0", now - 1.0)) + price_state.hl_oracle_state.put("ETH", PriceUpdate("3000.0", now - 1.0)) + price_state.hl_oracle_state.put("SOL", PriceUpdate("200.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == { + f"{DEX}:BTC": "100000.0", + f"{DEX}:ETH": "3000.0", + f"{DEX}:SOL": "200.0", + } + + def test_partial_symbols_available(self): + """Test when only some symbols have fresh prices.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = ["BTC", "ETH"] + config.price = PriceConfig( + oracle={ + "BTC": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="BTC", exponent=None)), + ], + "ETH": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id="ETH", exponent=None)), + ], + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put("BTC", PriceUpdate("100000.0", now - 1.0)) + price_state.hl_oracle_state.put("ETH", PriceUpdate("3000.0", now - 10.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:BTC": "100000.0"} + + +class TestPairSourceEdgeCases: + """Tests for pair source edge cases.""" + + def test_pair_source_base_missing(self): + """Test pair source returns None when base price is missing.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + PairSourceConfig(source_type="pair", + base_source=PriceSource(source_name="lazer", source_id=1, exponent=-8), + quote_source=PriceSource(source_name="lazer", source_id=8, exponent=-8)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.lazer_state.put(8, PriceUpdate("99000000", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {} + + def test_pair_source_quote_missing(self): + """Test pair source returns None when quote price is missing.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + PairSourceConfig(source_type="pair", + base_source=PriceSource(source_name="lazer", source_id=1, exponent=-8), + quote_source=PriceSource(source_name="lazer", source_id=8, exponent=-8)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.lazer_state.put(1, PriceUpdate("11050000000000", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {} + + def test_pair_source_base_stale(self): + """Test pair source returns None when base price is stale.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + PairSourceConfig(source_type="pair", + base_source=PriceSource(source_name="lazer", source_id=1, exponent=-8), + quote_source=PriceSource(source_name="lazer", source_id=8, exponent=-8)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.lazer_state.put(1, PriceUpdate("11050000000000", now - 10.0)) + price_state.lazer_state.put(8, PriceUpdate("99000000", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {} + + +class TestSedaSource: + """Tests for SEDA price source.""" + + def test_seda_source(self): + """Test SEDA source returns price correctly.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + "CUSTOM": [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="seda", source_id="custom_feed", exponent=None)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.seda_state.put("custom_feed", PriceUpdate("42.5", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:CUSTOM": "42.5"} + + +class TestExponentConversion: + """Tests for exponent conversion in price sources.""" + + def test_positive_exponent(self): + """Test price conversion with positive exponent. + + Formula: price / (10 ** -exponent) + With exponent=2 and price=100: 100 / (10 ** -2) = 100 / 0.01 = 10000.0 + """ + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="lazer", source_id=1, exponent=2)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.lazer_state.put(1, PriceUpdate("100", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "10000.0"} + + def test_negative_exponent(self): + """Test price conversion with negative exponent.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="lazer", source_id=1, exponent=-8)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.lazer_state.put(1, PriceUpdate("10000000000000", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "100000.0"} + + def test_no_exponent(self): + """Test price pass-through with no exponent.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [SYMBOL] + config.price = PriceConfig( + oracle={ + SYMBOL: [ + SingleSourceConfig(source_type="single", source=PriceSource(source_name="hl_oracle", source_id=SYMBOL, exponent=None)), + ] + }, + mark={}, + external={} + ) + + price_state = PriceState(config) + now = time.time() + price_state.hl_oracle_state.put(SYMBOL, PriceUpdate("100000.0", now - 1.0)) + + oracle_update = price_state.get_all_prices() + assert oracle_update.oracle == {f"{DEX}:{SYMBOL}": "100000.0"} + + +class TestInvalidSourceConfig: + """Tests for invalid source configuration handling.""" + + def test_invalid_source_type_raises(self): + """Test that invalid source type raises ValueError.""" + config: Config = Config.model_construct() + config.stale_price_threshold_seconds = 5 + config.hyperliquid = HyperliquidConfig.model_construct() + config.hyperliquid.market_name = DEX + config.hyperliquid.asset_context_symbols = [] + config.price = PriceConfig(oracle={}, mark={}, external={}) + + price_state = PriceState(config) + + class InvalidConfig: + source_type = "invalid" + + with pytest.raises(ValueError): + price_state.get_price(InvalidConfig(), OracleUpdate({}, {}, {}))