Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Initial implementation of MSC3981.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 24, 2023
1 parent e6af49f commit d41e1f6
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 14 deletions.
1 change: 1 addition & 0 deletions changelog.d/15315.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support to recursively provide relations per [MSC3981](https://github.com/matrix-org/matrix-spec-proposals/pull/3981).
5 changes: 5 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,8 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:

# MSC2659: Application service ping endpoint
self.msc2659_enabled = experimental.get("msc2659_enabled", False)

# MSC3981: Recurse relations
self.msc3981_recurse_relations = experimental.get(
"msc3981_recurse_relations", False
)
3 changes: 3 additions & 0 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def get_relations(
event_id: str,
room_id: str,
pagin_config: PaginationConfig,
recurse: bool,
include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
Expand All @@ -98,6 +99,7 @@ async def get_relations(
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
pagin_config: The pagination config rules to apply, if any.
recurse: Whether to recursively find relations.
include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
Expand Down Expand Up @@ -132,6 +134,7 @@ async def get_relations(
direction=pagin_config.direction,
from_token=pagin_config.from_token,
to_token=pagin_config.to_token,
recurse=recurse,
)

events = await self._main_store.get_events_as_list(
Expand Down
10 changes: 9 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
self._support_recurse = hs.config.experimental.msc3981_recurse_relations

async def on_GET(
self,
Expand All @@ -63,6 +64,12 @@ async def on_GET(
pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
if self._support_recurse:
recurse = parse_boolean(
request, "org.matrix.msc3981.recurse", default=False
)
else:
recurse = False

# The unstable version of this API returns an extra field for client
# compatibility, see https://github.com/matrix-org/synapse/issues/12930.
Expand All @@ -75,6 +82,7 @@ async def on_GET(
event_id=parent_id,
room_id=room_id,
pagin_config=pagination_config,
recurse=recurse,
include_original_event=include_original_event,
relation_type=relation_type,
event_type=event_type,
Expand Down
52 changes: 39 additions & 13 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def get_relations_for_event(
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
recurse: bool = False,
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Expand All @@ -186,6 +187,7 @@ async def get_relations_for_event(
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
recurse: Whether to recursively find relations.
Returns:
A tuple of:
Expand All @@ -200,7 +202,7 @@ async def get_relations_for_event(
# Ensure bad limits aren't being passed in.
assert limit >= 0

where_clause = ["relates_to_id = ?", "room_id = ?"]
where_clause = ["room_id = ?"]
where_args: List[Union[str, int]] = [event.event_id, room_id]
is_redacted = event.internal_metadata.is_redacted()

Expand Down Expand Up @@ -229,18 +231,42 @@ async def get_relations_for_event(
if pagination_clause:
where_clause.append(pagination_clause)

sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)
# TODO This needs to be a recursive query that maybe still matches some of the times above.
if recurse:
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relation_type, relates_to_id, 0 depth
FROM event_relations
WHERE relates_to_id = ?
UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.event_id = e.relates_to_id
WHERE depth <= 3
)
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM related_events
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?;
""" % (
" AND ".join(where_clause),
order,
order,
)
else:
sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)

def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
Expand Down

0 comments on commit d41e1f6

Please sign in to comment.