To improve type safety & memory usage.tags/v1.95.0rc1
@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -14,17 +14,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Set, | |||
Tuple, | |||
) | |||
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple | |||
from synapse.api import errors | |||
from synapse.api.constants import EduTypes, EventTypes | |||
@@ -41,6 +31,7 @@ from synapse.metrics.background_process_metrics import ( | |||
run_as_background_process, | |||
wrap_as_background_process, | |||
) | |||
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo | |||
from synapse.types import ( | |||
JsonDict, | |||
JsonMapping, | |||
@@ -1008,14 +999,14 @@ class DeviceHandler(DeviceWorkerHandler): | |||
def _update_device_from_client_ips( | |||
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] | |||
device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo] | |||
) -> None: | |||
ip = client_ips.get((device["user_id"], device["device_id"]), {}) | |||
ip = client_ips.get((device["user_id"], device["device_id"])) | |||
device.update( | |||
{ | |||
"last_seen_user_agent": ip.get("user_agent"), | |||
"last_seen_ts": ip.get("last_seen"), | |||
"last_seen_ip": ip.get("ip"), | |||
"last_seen_user_agent": ip.user_agent if ip else None, | |||
"last_seen_ts": ip.last_seen if ip else None, | |||
"last_seen_ip": ip.ip if ip else None, | |||
} | |||
) | |||
@@ -15,6 +15,7 @@ | |||
import logging | |||
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast | |||
import attr | |||
from typing_extensions import TypedDict | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
@@ -42,7 +43,8 @@ logger = logging.getLogger(__name__) | |||
LAST_SEEN_GRANULARITY = 120 * 1000 | |||
class DeviceLastConnectionInfo(TypedDict): | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class DeviceLastConnectionInfo: | |||
"""Metadata for the last connection seen for a user and device combination""" | |||
# These types must match the columns in the `devices` table | |||
@@ -499,24 +501,29 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke | |||
device_id: If None fetches all devices for the user | |||
Returns: | |||
A dictionary mapping a tuple of (user_id, device_id) to dicts, with | |||
keys giving the column names from the devices table. | |||
A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo. | |||
""" | |||
keyvalues = {"user_id": user_id} | |||
if device_id is not None: | |||
keyvalues["device_id"] = device_id | |||
res = cast( | |||
List[DeviceLastConnectionInfo], | |||
await self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues=keyvalues, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
), | |||
res = await self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues=keyvalues, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
) | |||
return {(d["user_id"], d["device_id"]): d for d in res} | |||
return { | |||
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo( | |||
user_id=d["user_id"], | |||
device_id=d["device_id"], | |||
ip=d["ip"], | |||
user_agent=d["user_agent"], | |||
last_seen=d["last_seen"], | |||
) | |||
for d in res | |||
} | |||
async def _get_user_ip_and_agents_from_database( | |||
self, user: UserID, since_ts: int = 0 | |||
@@ -683,8 +690,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke | |||
device_id: If None fetches all devices for the user | |||
Returns: | |||
A dictionary mapping a tuple of (user_id, device_id) to dicts, with | |||
keys giving the column names from the devices table. | |||
A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo. | |||
""" | |||
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id) | |||
@@ -705,13 +711,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke | |||
continue | |||
if not device_id or did == device_id: | |||
ret[(user_id, did)] = { | |||
"user_id": user_id, | |||
"ip": ip, | |||
"user_agent": user_agent, | |||
"device_id": did, | |||
"last_seen": last_seen, | |||
} | |||
ret[(user_id, did)] = DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
ip=ip, | |||
user_agent=user_agent, | |||
device_id=did, | |||
last_seen=last_seen, | |||
) | |||
return ret | |||
async def get_user_ip_and_agents( | |||
@@ -24,7 +24,10 @@ import synapse.rest.admin | |||
from synapse.http.site import XForwardedForRequest | |||
from synapse.rest.client import login | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY | |||
from synapse.storage.databases.main.client_ips import ( | |||
LAST_SEEN_GRANULARITY, | |||
DeviceLastConnectionInfo, | |||
) | |||
from synapse.types import UserID | |||
from synapse.util import Clock | |||
@@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
r = result[(user_id, device_id)] | |||
self.assertLessEqual( | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"last_seen": 12345678000, | |||
}.items(), | |||
r.items(), | |||
self.assertEqual( | |||
DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id, | |||
ip="ip", | |||
user_agent="user_agent", | |||
last_seen=12345678000, | |||
), | |||
r, | |||
) | |||
def test_insert_new_client_ip_none_device_id(self) -> None: | |||
@@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual( | |||
result, | |||
{ | |||
(user_id, device_id): { | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"last_seen": 12345678000, | |||
}, | |||
(user_id, device_id): DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id, | |||
ip="ip", | |||
user_agent="user_agent", | |||
last_seen=12345678000, | |||
), | |||
}, | |||
) | |||
@@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual( | |||
result, | |||
{ | |||
(user_id, device_id_1): { | |||
"user_id": user_id, | |||
"device_id": device_id_1, | |||
"ip": "ip_1", | |||
"user_agent": "user_agent_1", | |||
"last_seen": 12345678000, | |||
}, | |||
(user_id, device_id_2): { | |||
"user_id": user_id, | |||
"device_id": device_id_2, | |||
"ip": "ip_2", | |||
"user_agent": "user_agent_3", | |||
"last_seen": 12345688000 + LAST_SEEN_GRANULARITY, | |||
}, | |||
(user_id, device_id_1): DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id_1, | |||
ip="ip_1", | |||
user_agent="user_agent_1", | |||
last_seen=12345678000, | |||
), | |||
(user_id, device_id_2): DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id_2, | |||
ip="ip_2", | |||
user_agent="user_agent_3", | |||
last_seen=12345688000 + LAST_SEEN_GRANULARITY, | |||
), | |||
}, | |||
) | |||
@@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
r = result[(user_id, device_id)] | |||
self.assertLessEqual( | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"ip": None, | |||
"user_agent": None, | |||
"last_seen": None, | |||
}.items(), | |||
r.items(), | |||
self.assertEqual( | |||
DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id, | |||
ip=None, | |||
user_agent=None, | |||
last_seen=None, | |||
), | |||
r, | |||
) | |||
# Register the background update to run again. | |||
@@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
r = result[(user_id, device_id)] | |||
self.assertLessEqual( | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"last_seen": 0, | |||
}.items(), | |||
r.items(), | |||
self.assertEqual( | |||
DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id, | |||
ip="ip", | |||
user_agent="user_agent", | |||
last_seen=0, | |||
), | |||
r, | |||
) | |||
def test_old_user_ips_pruned(self) -> None: | |||
@@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
r = result2[(user_id, device_id)] | |||
self.assertLessEqual( | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"last_seen": 0, | |||
}.items(), | |||
r.items(), | |||
self.assertEqual( | |||
DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id, | |||
ip="ip", | |||
user_agent="user_agent", | |||
last_seen=0, | |||
), | |||
r, | |||
) | |||
def test_invalid_user_agents_are_ignored(self) -> None: | |||
@@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): | |||
self.store.get_last_client_ip_by_device(self.user_id, device_id) | |||
) | |||
r = result[(self.user_id, device_id)] | |||
self.assertLessEqual( | |||
{ | |||
"user_id": self.user_id, | |||
"device_id": device_id, | |||
"ip": expected_ip, | |||
"user_agent": "Mozzila pizza", | |||
"last_seen": 123456100, | |||
}.items(), | |||
r.items(), | |||
self.assertEqual( | |||
DeviceLastConnectionInfo( | |||
user_id=self.user_id, | |||
device_id=device_id, | |||
ip=expected_ip, | |||
user_agent="Mozzila pizza", | |||
last_seen=123456100, | |||
), | |||
r, | |||
) |