Instead of wrapping the JSON into an object, this creates concrete instances for Transaction and Edu. This allows for improved type hints and simplified code.tags/v1.41.0rc1
@@ -0,0 +1 @@ | |||
Convert `Transaction` and `Edu` objects to attrs. |
@@ -195,13 +195,17 @@ class FederationServer(FederationBase): | |||
origin, room_id, versions, limit | |||
) | |||
res = self._transaction_from_pdus(pdus).get_dict() | |||
res = self._transaction_dict_from_pdus(pdus) | |||
return 200, res | |||
async def on_incoming_transaction( | |||
self, origin: str, transaction_data: JsonDict | |||
) -> Tuple[int, Dict[str, Any]]: | |||
self, | |||
origin: str, | |||
transaction_id: str, | |||
destination: str, | |||
transaction_data: JsonDict, | |||
) -> Tuple[int, JsonDict]: | |||
# If we receive a transaction we should make sure that kick off handling | |||
# any old events in the staging area. | |||
if not self._started_handling_of_staged_events: | |||
@@ -212,8 +216,14 @@ class FederationServer(FederationBase): | |||
# accurate as possible. | |||
request_time = self._clock.time_msec() | |||
transaction = Transaction(**transaction_data) | |||
transaction_id = transaction.transaction_id # type: ignore | |||
transaction = Transaction( | |||
transaction_id=transaction_id, | |||
destination=destination, | |||
origin=origin, | |||
origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore | |||
pdus=transaction_data.get("pdus"), # type: ignore | |||
edus=transaction_data.get("edus"), | |||
) | |||
if not transaction_id: | |||
raise Exception("Transaction missing transaction_id") | |||
@@ -221,9 +231,7 @@ class FederationServer(FederationBase): | |||
logger.debug("[%s] Got transaction", transaction_id) | |||
# Reject malformed transactions early: reject if too many PDUs/EDUs | |||
if len(transaction.pdus) > 50 or ( # type: ignore | |||
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore | |||
): | |||
if len(transaction.pdus) > 50 or len(transaction.edus) > 100: | |||
logger.info("Transaction PDU or EDU count too large. Returning 400") | |||
return 400, {} | |||
@@ -263,7 +271,7 @@ class FederationServer(FederationBase): | |||
# CRITICAL SECTION: the first thing we must do (before awaiting) is | |||
# add an entry to _active_transactions. | |||
assert origin not in self._active_transactions | |||
self._active_transactions[origin] = transaction.transaction_id # type: ignore | |||
self._active_transactions[origin] = transaction.transaction_id | |||
try: | |||
result = await self._handle_incoming_transaction( | |||
@@ -291,11 +299,11 @@ class FederationServer(FederationBase): | |||
if response: | |||
logger.debug( | |||
"[%s] We've already responded to this request", | |||
transaction.transaction_id, # type: ignore | |||
transaction.transaction_id, | |||
) | |||
return response | |||
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore | |||
logger.debug("[%s] Transaction is new", transaction.transaction_id) | |||
# We process PDUs and EDUs in parallel. This is important as we don't | |||
# want to block things like to device messages from reaching clients | |||
@@ -334,7 +342,7 @@ class FederationServer(FederationBase): | |||
report back to the sending server. | |||
""" | |||
received_pdus_counter.inc(len(transaction.pdus)) # type: ignore | |||
received_pdus_counter.inc(len(transaction.pdus)) | |||
origin_host, _ = parse_server_name(origin) | |||
@@ -342,7 +350,7 @@ class FederationServer(FederationBase): | |||
newest_pdu_ts = 0 | |||
for p in transaction.pdus: # type: ignore | |||
for p in transaction.pdus: | |||
# FIXME (richardv): I don't think this works: | |||
# https://github.com/matrix-org/synapse/issues/8429 | |||
if "unsigned" in p: | |||
@@ -436,10 +444,10 @@ class FederationServer(FederationBase): | |||
return pdu_results | |||
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): | |||
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None: | |||
"""Process the EDUs in a received transaction.""" | |||
async def _process_edu(edu_dict): | |||
async def _process_edu(edu_dict: JsonDict) -> None: | |||
received_edus_counter.inc() | |||
edu = Edu( | |||
@@ -452,7 +460,7 @@ class FederationServer(FederationBase): | |||
await concurrently_execute( | |||
_process_edu, | |||
getattr(transaction, "edus", []), | |||
transaction.edus, | |||
TRANSACTION_CONCURRENCY_LIMIT, | |||
) | |||
@@ -538,7 +546,7 @@ class FederationServer(FederationBase): | |||
pdu = await self.handler.get_persisted_pdu(origin, event_id) | |||
if pdu: | |||
return 200, self._transaction_from_pdus([pdu]).get_dict() | |||
return 200, self._transaction_dict_from_pdus([pdu]) | |||
else: | |||
return 404, "" | |||
@@ -879,18 +887,20 @@ class FederationServer(FederationBase): | |||
ts_now_ms = self._clock.time_msec() | |||
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) | |||
def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: | |||
def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict: | |||
"""Returns a new Transaction containing the given PDUs suitable for | |||
transmission. | |||
""" | |||
time_now = self._clock.time_msec() | |||
pdus = [p.get_pdu_json(time_now) for p in pdu_list] | |||
return Transaction( | |||
# Just need a dummy transaction ID and destination since it won't be used. | |||
transaction_id="", | |||
origin=self.server_name, | |||
pdus=pdus, | |||
origin_server_ts=int(time_now), | |||
destination=None, | |||
) | |||
destination="", | |||
).get_dict() | |||
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: | |||
"""Process a PDU received in a federation /send/ transaction. | |||
@@ -45,7 +45,7 @@ class TransactionActions: | |||
`None` if we have not previously responded to this transaction or a | |||
2-tuple of `(int, dict)` representing the response code and response body. | |||
""" | |||
transaction_id = transaction.transaction_id # type: ignore | |||
transaction_id = transaction.transaction_id | |||
if not transaction_id: | |||
raise RuntimeError("Cannot persist a transaction with no transaction_id") | |||
@@ -56,7 +56,7 @@ class TransactionActions: | |||
self, origin: str, transaction: Transaction, code: int, response: JsonDict | |||
) -> None: | |||
"""Persist how we responded to a transaction.""" | |||
transaction_id = transaction.transaction_id # type: ignore | |||
transaction_id = transaction.transaction_id | |||
if not transaction_id: | |||
raise RuntimeError("Cannot persist a transaction with no transaction_id") | |||
@@ -27,6 +27,7 @@ from synapse.logging.opentracing import ( | |||
tags, | |||
whitelisted_homeserver, | |||
) | |||
from synapse.types import JsonDict | |||
from synapse.util import json_decoder | |||
from synapse.util.metrics import measure_func | |||
@@ -104,13 +105,13 @@ class TransactionManager: | |||
len(edus), | |||
) | |||
transaction = Transaction.create_new( | |||
transaction = Transaction( | |||
origin_server_ts=int(self.clock.time_msec()), | |||
transaction_id=txn_id, | |||
origin=self._server_name, | |||
destination=destination, | |||
pdus=pdus, | |||
edus=edus, | |||
pdus=[p.get_pdu_json() for p in pdus], | |||
edus=[edu.get_dict() for edu in edus], | |||
) | |||
self._next_txn_id += 1 | |||
@@ -131,7 +132,7 @@ class TransactionManager: | |||
# FIXME (richardv): I also believe it no longer works. We (now?) store | |||
# "age_ts" in "unsigned" rather than at the top level. See | |||
# https://github.com/matrix-org/synapse/issues/8429. | |||
def json_data_cb(): | |||
def json_data_cb() -> JsonDict: | |||
data = transaction.get_dict() | |||
now = int(self.clock.time_msec()) | |||
if "pdus" in data: | |||
@@ -143,7 +143,7 @@ class TransportLayerClient: | |||
"""Sends the given Transaction to its destination | |||
Args: | |||
transaction (Transaction) | |||
transaction | |||
Returns: | |||
Succeeds when we get a 2xx HTTP response. The result | |||
@@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet): | |||
len(transaction_data.get("edus", [])), | |||
) | |||
# We should ideally be getting this from the security layer. | |||
# origin = body["origin"] | |||
# Add some extra data to the transaction dict that isn't included | |||
# in the request body. | |||
transaction_data.update( | |||
transaction_id=transaction_id, destination=self.server_name | |||
) | |||
except Exception as e: | |||
logger.exception(e) | |||
return 400, {"error": "Invalid transaction"} | |||
code, response = await self.handler.on_incoming_transaction( | |||
origin, transaction_data | |||
origin, transaction_id, self.server_name, transaction_data | |||
) | |||
return code, response | |||
@@ -17,18 +17,17 @@ server protocol. | |||
""" | |||
import logging | |||
from typing import Optional | |||
from typing import List, Optional | |||
import attr | |||
from synapse.types import JsonDict | |||
from synapse.util.jsonobject import JsonEncodedObject | |||
logger = logging.getLogger(__name__) | |||
@attr.s(slots=True) | |||
class Edu(JsonEncodedObject): | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class Edu: | |||
"""An Edu represents a piece of data sent from one homeserver to another. | |||
In comparison to Pdus, Edus are not persisted for a long time on disk, are | |||
@@ -36,10 +35,10 @@ class Edu(JsonEncodedObject): | |||
internal ID or previous references graph. | |||
""" | |||
edu_type = attr.ib(type=str) | |||
content = attr.ib(type=dict) | |||
origin = attr.ib(type=str) | |||
destination = attr.ib(type=str) | |||
edu_type: str | |||
content: dict | |||
origin: str | |||
destination: str | |||
def get_dict(self) -> JsonDict: | |||
return { | |||
@@ -55,14 +54,21 @@ class Edu(JsonEncodedObject): | |||
"destination": self.destination, | |||
} | |||
def get_context(self): | |||
def get_context(self) -> str: | |||
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") | |||
def strip_context(self): | |||
def strip_context(self) -> None: | |||
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" | |||
class Transaction(JsonEncodedObject): | |||
def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]: | |||
if edus is None: | |||
return [] | |||
return edus | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class Transaction: | |||
"""A transaction is a list of Pdus and Edus to be sent to a remote home | |||
server with some extra metadata. | |||
@@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject): | |||
""" | |||
valid_keys = [ | |||
"transaction_id", | |||
"origin", | |||
"destination", | |||
"origin_server_ts", | |||
"previous_ids", | |||
"pdus", | |||
"edus", | |||
] | |||
internal_keys = ["transaction_id", "destination"] | |||
required_keys = [ | |||
"transaction_id", | |||
"origin", | |||
"destination", | |||
"origin_server_ts", | |||
"pdus", | |||
] | |||
def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): | |||
"""If we include a list of pdus then we decode then as PDU's | |||
automatically. | |||
""" | |||
# If there's no EDUs then remove the arg | |||
if "edus" in kwargs and not kwargs["edus"]: | |||
del kwargs["edus"] | |||
super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) | |||
@staticmethod | |||
def create_new(pdus, **kwargs): | |||
"""Used to create a new transaction. Will auto fill out | |||
transaction_id and origin_server_ts keys. | |||
""" | |||
if "origin_server_ts" not in kwargs: | |||
raise KeyError("Require 'origin_server_ts' to construct a Transaction") | |||
if "transaction_id" not in kwargs: | |||
raise KeyError("Require 'transaction_id' to construct a Transaction") | |||
kwargs["pdus"] = [p.get_pdu_json() for p in pdus] | |||
return Transaction(**kwargs) | |||
# Required keys. | |||
transaction_id: str | |||
origin: str | |||
destination: str | |||
origin_server_ts: int | |||
pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) | |||
edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) | |||
def get_dict(self) -> JsonDict: | |||
"""A JSON-ready dictionary of valid keys which aren't internal.""" | |||
result = { | |||
"origin": self.origin, | |||
"origin_server_ts": self.origin_server_ts, | |||
"pdus": self.pdus, | |||
} | |||
if self.edus: | |||
result["edus"] = self.edus | |||
return result |
@@ -1,102 +0,0 @@ | |||
# Copyright 2014-2016 OpenMarket Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
class JsonEncodedObject: | |||
"""A common base class for defining protocol units that are represented | |||
as JSON. | |||
Attributes: | |||
unrecognized_keys (dict): A dict containing all the key/value pairs we | |||
don't recognize. | |||
""" | |||
valid_keys = [] # keys we will store | |||
"""A list of strings that represent keys we know about | |||
and can handle. If we have values for these keys they will be | |||
included in the `dictionary` instance variable. | |||
""" | |||
internal_keys = [] # keys to ignore while building dict | |||
"""A list of strings that should *not* be encoded into JSON. | |||
""" | |||
required_keys = [] | |||
"""A list of strings that we require to exist. If they are not given upon | |||
construction it raises an exception. | |||
""" | |||
def __init__(self, **kwargs): | |||
"""Takes the dict of `kwargs` and loads all keys that are *valid* | |||
(i.e., are included in the `valid_keys` list) into the dictionary` | |||
instance variable. | |||
Any keys that aren't recognized are added to the `unrecognized_keys` | |||
attribute. | |||
Args: | |||
**kwargs: Attributes associated with this protocol unit. | |||
""" | |||
for required_key in self.required_keys: | |||
if required_key not in kwargs: | |||
raise RuntimeError("Key %s is required" % required_key) | |||
self.unrecognized_keys = {} # Keys we were given not listed as valid | |||
for k, v in kwargs.items(): | |||
if k in self.valid_keys or k in self.internal_keys: | |||
self.__dict__[k] = v | |||
else: | |||
self.unrecognized_keys[k] = v | |||
def get_dict(self): | |||
"""Converts this protocol unit into a :py:class:`dict`, ready to be | |||
encoded as JSON. | |||
The keys it encodes are: `valid_keys` - `internal_keys` | |||
Returns | |||
dict | |||
""" | |||
d = { | |||
k: _encode(v) | |||
for (k, v) in self.__dict__.items() | |||
if k in self.valid_keys and k not in self.internal_keys | |||
} | |||
d.update(self.unrecognized_keys) | |||
return d | |||
def get_internal_dict(self): | |||
d = { | |||
k: _encode(v, internal=True) | |||
for (k, v) in self.__dict__.items() | |||
if k in self.valid_keys | |||
} | |||
d.update(self.unrecognized_keys) | |||
return d | |||
def __str__(self): | |||
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) | |||
def _encode(obj, internal=False): | |||
if type(obj) is list: | |||
return [_encode(o, internal=internal) for o in obj] | |||
if isinstance(obj, JsonEncodedObject): | |||
if internal: | |||
return obj.get_internal_dict() | |||
else: | |||
return obj.get_dict() | |||
return obj |