From 13e9cad537a16108b0cb544ccdc24e7dc2ca33ae Mon Sep 17 00:00:00 2001 From: Marcel Date: Wed, 6 Sep 2023 21:19:17 +0200 Subject: [PATCH] Send the opentracing span information to appservices (#16227) --- changelog.d/16227.feature | 1 + synapse/appservice/api.py | 32 ++++++++++++++++++++++++-------- tests/appservice/test_api.py | 18 ++++++++++++------ 3 files changed, 37 insertions(+), 14 deletions(-) create mode 100644 changelog.d/16227.feature diff --git a/changelog.d/16227.feature b/changelog.d/16227.feature new file mode 100644 index 0000000000..510062b622 --- /dev/null +++ b/changelog.d/16227.feature @@ -0,0 +1 @@ +Add span information to requests sent to appservices. Contributed by MTRNord. \ No newline at end of file diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index de7a94bf26..b1523be208 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -40,6 +40,7 @@ from synapse.appservice import ( from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient, is_unknown_endpoint +from synapse.logging import opentracing from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache @@ -125,6 +126,17 @@ class ApplicationServiceApi(SimpleHttpClient): hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) + def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]: + """This makes sure we have always the auth header and opentracing headers set.""" + + # This is also ensured before in the functions. However this is needed to please + # the typechecks. + assert service.hs_token is not None + + headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]} + opentracing.inject_header_dict(headers, check_destination=False) + return headers + async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: return False @@ -136,10 +148,11 @@ class ApplicationServiceApi(SimpleHttpClient): args = None if self.config.use_appservice_legacy_authorization: args = {"access_token": service.hs_token} + response = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}", args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if response is not None: # just an empty json object return True @@ -162,10 +175,11 @@ class ApplicationServiceApi(SimpleHttpClient): args = None if self.config.use_appservice_legacy_authorization: args = {"access_token": service.hs_token} + response = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}", args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if response is not None: # just an empty json object return True @@ -203,10 +217,11 @@ class ApplicationServiceApi(SimpleHttpClient): **fields, b"access_token": service.hs_token, } + response = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}", args=args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if not isinstance(response, list): logger.warning( @@ -243,10 +258,11 @@ class ApplicationServiceApi(SimpleHttpClient): args = None if self.config.use_appservice_legacy_authorization: args = {"access_token": service.hs_token} + info = await self.get_json( f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}", args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if not _is_valid_3pe_metadata(info): @@ -283,7 +299,7 @@ class ApplicationServiceApi(SimpleHttpClient): await self.post_json_get_json( uri=f"{service.url}{APP_SERVICE_PREFIX}/ping", post_json={"transaction_id": txn_id}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) async def push_bulk( @@ -364,7 +380,7 @@ class ApplicationServiceApi(SimpleHttpClient): f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}", json_body=body, args=args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -437,7 +453,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.post_json_get_json( uri, body, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) except HttpResponseException as e: # The appservice doesn't support this endpoint. @@ -498,7 +514,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.post_json_get_json( uri, query, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) except HttpResponseException as e: # The appservice doesn't support this endpoint. diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 75fb5fae6b..366b6fd5f0 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -76,7 +76,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> List[JsonDict]: # Ensure the access token is passed as a header. - if not headers or not headers.get("Authorization"): + if not headers or not headers.get(b"Authorization"): raise RuntimeError("Access token not provided") # ... and not as a query param if b"access_token" in args: @@ -84,7 +84,9 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): "Access token should not be passed as a query param." ) - self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + self.assertEqual( + headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()] + ) self.request_url = url if url == URL_USER: return SUCCESS_RESULT_USER @@ -152,11 +154,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): # Ensure the access token is passed as a both a query param and in the headers. if not args.get(b"access_token"): raise RuntimeError("Access token should be provided in query params.") - if not headers or not headers.get("Authorization"): + if not headers or not headers.get(b"Authorization"): raise RuntimeError("Access token should be provided in auth headers.") self.assertEqual(args.get(b"access_token"), TOKEN) - self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + self.assertEqual( + headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()] + ) self.request_url = url if url == URL_USER: return SUCCESS_RESULT_USER @@ -208,10 +212,12 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> JsonDict: # Ensure the access token is passed as both a header and query arg. - if not headers.get("Authorization"): + if not headers.get(b"Authorization"): raise RuntimeError("Access token not provided") - self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + self.assertEqual( + headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()] + ) return RESPONSE # We assign to a method, which mypy doesn't like.