diff --git a/changelog.d/18960.bugfix b/changelog.d/18960.bugfix
new file mode 100644
index 00000000000..909089f8092
--- /dev/null
+++ b/changelog.d/18960.bugfix
@@ -0,0 +1 @@
+Fix a bug in the database function for fetching state deltas that could result in unnecessarily long query times.
\ No newline at end of file
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 9c5e837ab07..4885268305d 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -683,7 +683,7 @@ async def get_current_state_deltas(
         # https://github.com/matrix-org/synapse/issues/13008

         return await self.stores.main.get_partial_current_state_deltas(
-            prev_stream_id, max_stream_id
+            prev_stream_id, max_stream_id, limit=100
         )

     @trace
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index cd8f286d085..a5d5407327c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -78,27 +78,41 @@ def __init__(
         )

     async def get_partial_current_state_deltas(
-        self, prev_stream_id: int, max_stream_id: int
+        self, prev_stream_id: int, max_stream_id: int, limit: int = 100
     ) -> tuple[int, list[StateDelta]]:
-        """Fetch a list of room state changes since the given stream id
+        """Fetch a list of room state changes since the given stream id.

         This may be the partial state if we're lazy joining the room.

+        This method takes care to handle state deltas that share the same
+        `stream_id`. That can happen when persisting state in a batch,
+        potentially as the result of state resolution (both adding new state and
+        undoing previous state).
+
+        State deltas are grouped by `stream_id`. When hitting the given `limit`
+        would return only part of a "group" of state deltas, that entire group
+        is omitted. Thus, this function may return *up to* `limit` state deltas,
+        or slightly more when a single group itself exceeds `limit`.
+
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
+            limit: the maximum number of rows to return.

         Returns:
             A tuple consisting of:
                - the stream id which these results go up to
                - list of current_state_delta_stream rows. If it is empty, we are
                  up to date.
-
-        A maximum of 100 rows will be returned.
         """
         prev_stream_id = int(prev_stream_id)

+        if limit <= 0:
+            raise ValueError(
+                "Invalid `limit` passed to `get_partial_current_state_deltas`"
+            )
+
         # check we're not going backwards
         assert prev_stream_id <= max_stream_id, (
             f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
@@ -115,45 +129,62 @@ async def get_partial_current_state_deltas(
         def get_current_state_deltas_txn(
             txn: LoggingTransaction,
         ) -> tuple[int, list[StateDelta]]:
-            # First we calculate the max stream id that will give us less than
-            # N results.
-            # We arbitrarily limit to 100 stream_id entries to ensure we don't
-            # select toooo many.
-            sql = """
-                SELECT stream_id, count(*)
+            # First we group state deltas by `stream_id` and calculate which
+            # groups can be returned without exceeding the provided `limit`.
+            sql_grouped = """
+                SELECT stream_id, COUNT(*) AS c
                 FROM current_state_delta_stream
                 WHERE stream_id > ? AND stream_id <= ?
                 GROUP BY stream_id
-                ORDER BY stream_id ASC
-                LIMIT 100
+                ORDER BY stream_id
+                LIMIT ?
""" - txn.execute(sql, (prev_stream_id, max_stream_id)) - - total = 0 - - for stream_id, count in txn: - total += count - if total > 100: - # We arbitrarily limit to 100 entries to ensure we don't - # select toooo many. - logger.debug( - "Clipping current_state_delta_stream rows to stream_id %i", - stream_id, - ) - clipped_stream_id = stream_id + group_limit = limit + 1 + txn.execute(sql_grouped, (prev_stream_id, max_stream_id, group_limit)) + grouped_rows = txn.fetchall() + + if not grouped_rows: + # Nothing to return in the range; we are up to date through max_stream_id. + return max_stream_id, [] + + # Always retrieve the first group, at the bare minimum. This ensures the + # caller always makes progress, even if a single group exceeds `limit`. + fetch_upto_stream_id, included_rows = grouped_rows[0] + + # Determine which other groups we can retrieve at the same time, + # without blowing the budget. + included_all_groups = True + for stream_id, count in grouped_rows[1:]: + if included_rows + count > limit: + included_all_groups = False break - else: - # if there's no problem, we may as well go right up to the max_stream_id - clipped_stream_id = max_stream_id + included_rows += count + fetch_upto_stream_id = stream_id + + # If we retrieved fewer groups than the limit *and* we didn't hit the + # `LIMIT ?` cap on the grouping query, we know we've caught up with + # the stream. + caught_up_with_stream = ( + included_all_groups and len(grouped_rows) < group_limit + ) + + # At this point we should have advanced, or bailed out early above. + assert fetch_upto_stream_id != prev_stream_id - # Now actually get the deltas - sql = """ + # 2) Fetch the actual rows for only the included stream_id groups. + sql_rows = """ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id FROM current_state_delta_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ - txn.execute(sql, (prev_stream_id, clipped_stream_id)) + txn.execute(sql_rows, (prev_stream_id, fetch_upto_stream_id)) + rows = txn.fetchall() + + clipped_stream_id = ( + max_stream_id if caught_up_with_stream else fetch_upto_stream_id + ) + return clipped_stream_id, [ StateDelta( stream_id=row[0], @@ -163,7 +194,7 @@ def get_current_state_deltas_txn( event_id=row[4], prev_event_id=row[5], ) - for row in txn.fetchall() + for row in rows ] return await self.db_pool.runInteraction( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 8e821c6d183..dbbede812d9 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -19,6 +19,7 @@ # # +import json import logging from typing import cast @@ -33,6 +34,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID from synapse.types.state import StateFilter from synapse.util.clock import Clock +from synapse.util.stringutils import random_string from tests.unittest import HomeserverTestCase @@ -643,3 +645,315 @@ def test_batched_state_group_storing(self) -> None: ), ) self.assertEqual(context.state_group_before_event, groups[0][0]) + + +class CurrentStateDeltaStreamTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.store = hs.get_datastores().main + self.storage = hs.get_storage_controllers() + self.state_datastore = self.storage.state.stores.state + self.event_creation_handler = hs.get_event_creation_handler() + self.event_builder_factory = hs.get_event_builder_factory() + + # Create a made-up room and a user. 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8e821c6d183..dbbede812d9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -19,6 +19,7 @@
 #
 #

+import json
 import logging
 from typing import cast

@@ -33,6 +34,7 @@
 from synapse.types import JsonDict, RoomID, StateMap, UserID
 from synapse.types.state import StateFilter
 from synapse.util.clock import Clock
+from synapse.util.stringutils import random_string

 from tests.unittest import HomeserverTestCase

@@ -643,3 +645,315 @@ def test_batched_state_group_storing(self) -> None:
             ),
         )
         self.assertEqual(context.state_group_before_event, groups[0][0])
+
+
+class CurrentStateDeltaStreamTestCase(HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
+        self.store = hs.get_datastores().main
+        self.storage = hs.get_storage_controllers()
+        self.state_datastore = self.storage.state.stores.state
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.event_builder_factory = hs.get_event_builder_factory()
+
+        # Create a made-up room and a user.
+        self.alice_user_id = UserID.from_string("@alice:test")
+        self.room = RoomID.from_string("!abc1234:test")
+
+        self.get_success(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
+        )
+
+    def inject_state_event(
+        self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+    ) -> EventBase:
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": typ,
+                "sender": sender.to_string(),
+                "state_key": state_key,
+                "room_id": room.to_string(),
+                "content": content,
+            },
+        )
+
+        event, unpersisted_context = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+
+        context = self.get_success(unpersisted_context.persist(event))
+
+        assert self.storage.persistence is not None
+        self.get_success(self.storage.persistence.persist_event(event, context))
+
+        return event
+
+    def test_get_partial_current_state_deltas_limit(self) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` returns at most `limit` rows.
+
+        Regression test for https://github.com/element-hq/synapse/pull/18960.
+        """
+        # Inject a create event which other events can auth with.
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        limit = 2
+
+        # Make N*2 state changes in the room, resulting in 2N+1 total state
+        # events (including the create event) in the room.
+        for i in range(limit * 2):
+            self.inject_state_event(
+                self.room,
+                self.alice_user_id,
+                EventTypes.Name,
+                "",
+                {"name": f"rename #{i}"},
+            )
+
+        # Call the function under test. This must return <= `limit` rows.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=0,
+                max_stream_id=max_stream_id,
+                limit=limit,
+            )
+        )
+
+        self.assertLessEqual(
+            len(deltas), limit, f"Returned {len(deltas)} rows, expected at most {limit}"
+        )
+
+        # Advancing from the clipped point should eventually drain the remainder.
+        # Make sure we make progress and don't get stuck.
+        if deltas:
+            next_prev = clipped_stream_id
+            next_clipped, next_deltas = self.get_success(
+                self.store.get_partial_current_state_deltas(
+                    prev_stream_id=next_prev, max_stream_id=max_stream_id, limit=limit
+                )
+            )
+            self.assertNotEqual(
+                next_clipped, clipped_stream_id, "Did not advance clipped_stream_id"
+            )
+            # Still should respect the limit.
+            self.assertLessEqual(len(next_deltas), limit)
+
+    def test_non_unique_stream_ids_in_current_state_delta_stream(self) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` always returns entire
+        groups of state deltas (grouped by `stream_id`), and never part of one.
+
+        We check by passing a `limit` to the function that, if followed
+        blindly, would split a group of state deltas that share a `stream_id`.
+        The test passes if that group is not returned at all (because doing so
+        would overshoot the limit of returned state deltas).
+
+        Regression test for https://github.com/element-hq/synapse/pull/18960.
+        """
+        # Inject a create event to start with.
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        # Then inject one "real" m.room.name event. This will give us a stream_id that
+        # we can create some more (fake) events with.
+        self.inject_state_event(
+            self.room,
+            self.alice_user_id,
+            EventTypes.Name,
+            "",
+            {"name": "rename #1"},
+        )
+
+        # Get the stream_id of the last-inserted event.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+
+        # Make 3 more state changes in the room, resulting in 5 total state
+        # events (including the create event, and the first name update) in
+        # the room.
+        #
+        # All of these state deltas have the same `stream_id` as the original name event.
+        # Do so by editing the table directly as that's the simplest way to have
+        # them all share the same `stream_id`.
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "current_state_delta_stream",
+                keys=(
+                    "stream_id",
+                    "room_id",
+                    "type",
+                    "state_key",
+                    "event_id",
+                    "prev_event_id",
+                    "instance_name",
+                ),
+                values=[
+                    (
+                        max_stream_id,
+                        self.room.to_string(),
+                        EventTypes.Name,
+                        "",
+                        f"${random_string(5)}:test",
+                        json.dumps({"name": f"rename #{i}"}),
+                        "master",
+                    )
+                    for i in range(3)
+                ],
+                desc="inject_room_name_state_events",
+            )
+        )
+
+        # Call the function under test with a limit of 4. Without the limit, we
+        # would return 5 state deltas:
+        #
+        # C N N N N
+        # 1 2 3 4 5
+        #
+        # C = m.room.create
+        # N = m.room.name
+        #
+        # With the limit, we should return only the create event, as returning 4
+        # state deltas would result in splitting a group:
+        #
+        # 2 3 3 3 3 - stream IDs (groups)
+        # C N N N N
+        # 1 2 3 4 X
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=0,
+                max_stream_id=max_stream_id,
+                limit=4,
+            )
+        )
+
+        # 2 is the stream ID of the m.room.create event.
+        self.assertEqual(clipped_stream_id, 2)
+        self.assertEqual(
+            len(deltas),
+            1,
+            f"Returned {len(deltas)} rows, expected only one (the create event): {deltas}",
+        )
+
+        # Advance once more with our limit of 4. We should now get all 4
+        # `m.room.name` state deltas as they can fit under the limit.
+        clipped_stream_id, next_deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=clipped_stream_id, max_stream_id=max_stream_id, limit=4
+            )
+        )
+        self.assertEqual(
+            clipped_stream_id, 3
+        )  # The stream ID of the 4 m.room.name events.
+
+        self.assertEqual(
+            len(next_deltas),
+            4,
+            f"Returned {len(next_deltas)} rows, expected all 4 m.room.name events: {next_deltas}",
+        )
+
+    def test_get_partial_current_state_deltas_does_not_enter_infinite_loop(
+        self,
+    ) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` does not repeatedly return
+        zero entries due to the passed `limit` parameter being less than the
+        size of the next group of state deltas from the given `prev_stream_id`.
+        """
+        # Inject a create event to start with.
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        # Then inject one "real" m.room.name event. This will give us a stream_id that
+        # we can create some more (fake) events with.
+        self.inject_state_event(
+            self.room,
+            self.alice_user_id,
+            EventTypes.Name,
+            "",
+            {"name": "rename #1"},
+        )
+
+        # Get the stream_id of the last-inserted event.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+
+        # Make 3 more state changes in the room, resulting in 5 total state
+        # events (including the create event, and the first name update) in
+        # the room.
+        #
+        # All of these state deltas have the same `stream_id` as the original name event.
+        # Do so by editing the table directly as that's the simplest way to have
+        # them all share the same `stream_id`.
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "current_state_delta_stream",
+                keys=(
+                    "stream_id",
+                    "room_id",
+                    "type",
+                    "state_key",
+                    "event_id",
+                    "prev_event_id",
+                    "instance_name",
+                ),
+                values=[
+                    (
+                        max_stream_id,
+                        self.room.to_string(),
+                        EventTypes.Name,
+                        "",
+                        f"${random_string(5)}:test",
+                        json.dumps({"name": f"rename #{i}"}),
+                        "master",
+                    )
+                    for i in range(3)
+                ],
+                desc="inject_room_name_state_events",
+            )
+        )
+
+        # Call the function under test starting *after* the create event
+        # (prev_stream_id=2), with a limit of 2. The next group is the 4
+        # m.room.name deltas sharing stream_id 3, which on its own exceeds
+        # the limit:
+        #
+        # 3 3 3 3 - stream IDs (a single group)
+        # N N N N
+        #
+        # N = m.room.name
+        #
+        # The whole group must still be returned so that the caller makes
+        # progress instead of looping forever on zero results.
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=2,  # Start after the create event (which has stream_id 2).
+                max_stream_id=max_stream_id,
+                limit=2,  # Less than the size of the next group (which is 4).
+            )
+        )
+
+        self.assertEqual(
+            clipped_stream_id, 3
+        )  # The stream ID of the 4 m.room.name events.
+
+        # We should get all 4 `m.room.name` state deltas, instead of 0, which
+        # would result in the caller entering an infinite loop.
+        self.assertEqual(
+            len(deltas),
+            4,
+            f"Returned {len(deltas)} rows, expected all 4 even though they exceed our limit: {deltas}",
+        )
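
Not part of this diff, but useful context for the contract above: a minimal sketch of how a caller might drain the stream with `get_partial_current_state_deltas`, assuming only the documented return value (the loop and the `drain_state_deltas` helper are illustrative, not code from Synapse). The guarantee that the first group is always returned, even when it exceeds `limit`, is what lets this loop terminate, which is exactly the situation the last test guards against.

async def drain_state_deltas(store, prev_stream_id: int, max_stream_id: int) -> None:
    """Illustrative paging loop over get_partial_current_state_deltas."""
    while True:
        clipped_stream_id, deltas = await store.get_partial_current_state_deltas(
            prev_stream_id, max_stream_id, limit=100
        )
        for delta in deltas:
            ...  # process each StateDelta here

        if clipped_stream_id == max_stream_id:
            # Caught up with everything persisted so far.
            break

        # Whole groups are always returned, so clipped_stream_id strictly
        # advances past prev_stream_id on every iteration; the loop terminates.
        prev_stream_id = clipped_stream_id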