@@ -14,7 +14,7 @@ matrix: | |||
- python: 2.7 | |||
env: TOX_ENV=packaging | |||
- python: 2.7 | |||
- python: 3.6 | |||
env: TOX_ENV=pep8 | |||
- python: 2.7 | |||
@@ -0,0 +1 @@ | |||
Make the Python scripts in the top-level scripts folders meet pep8 and pass flake8. |
@@ -1,21 +1,20 @@ | |||
from synapse.events import FrozenEvent | |||
from synapse.api.auth import Auth | |||
from mock import Mock | |||
from __future__ import print_function | |||
import argparse | |||
import itertools | |||
import json | |||
import sys | |||
from mock import Mock | |||
from synapse.api.auth import Auth | |||
from synapse.events import FrozenEvent | |||
def check_auth(auth, auth_chain, events): | |||
auth_chain.sort(key=lambda e: e.depth) | |||
auth_map = { | |||
e.event_id: e | |||
for e in auth_chain | |||
} | |||
auth_map = {e.event_id: e for e in auth_chain} | |||
create_events = {} | |||
for e in auth_chain: | |||
@@ -25,31 +24,26 @@ def check_auth(auth, auth_chain, events): | |||
for e in itertools.chain(auth_chain, events): | |||
auth_events_list = [auth_map[i] for i, _ in e.auth_events] | |||
auth_events = { | |||
(e.type, e.state_key): e | |||
for e in auth_events_list | |||
} | |||
auth_events = {(e.type, e.state_key): e for e in auth_events_list} | |||
auth_events[("m.room.create", "")] = create_events[e.room_id] | |||
try: | |||
auth.check(e, auth_events=auth_events) | |||
except Exception as ex: | |||
print "Failed:", e.event_id, e.type, e.state_key | |||
print "Auth_events:", auth_events | |||
print ex | |||
print json.dumps(e.get_dict(), sort_keys=True, indent=4) | |||
print("Failed:", e.event_id, e.type, e.state_key) | |||
print("Auth_events:", auth_events) | |||
print(ex) | |||
print(json.dumps(e.get_dict(), sort_keys=True, indent=4)) | |||
# raise | |||
print "Success:", e.event_id, e.type, e.state_key | |||
print("Success:", e.event_id, e.type, e.state_key) | |||
if __name__ == '__main__': | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument( | |||
'json', | |||
nargs='?', | |||
type=argparse.FileType('r'), | |||
default=sys.stdin, | |||
'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin | |||
) | |||
args = parser.parse_args() | |||
@@ -1,10 +1,15 @@ | |||
from synapse.crypto.event_signing import * | |||
from unpaddedbase64 import encode_base64 | |||
import argparse | |||
import hashlib | |||
import sys | |||
import json | |||
import logging | |||
import sys | |||
from unpaddedbase64 import encode_base64 | |||
from synapse.crypto.event_signing import ( | |||
check_event_content_hash, | |||
compute_event_reference_hash, | |||
) | |||
class dictobj(dict): | |||
@@ -24,27 +29,26 @@ class dictobj(dict): | |||
def main(): | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), | |||
default=sys.stdin) | |||
parser.add_argument( | |||
"input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin | |||
) | |||
args = parser.parse_args() | |||
logging.basicConfig() | |||
event_json = dictobj(json.load(args.input_json)) | |||
algorithms = { | |||
"sha256": hashlib.sha256, | |||
} | |||
algorithms = {"sha256": hashlib.sha256} | |||
for alg_name in event_json.hashes: | |||
if check_event_content_hash(event_json, algorithms[alg_name]): | |||
print "PASS content hash %s" % (alg_name,) | |||
print("PASS content hash %s" % (alg_name,)) | |||
else: | |||
print "FAIL content hash %s" % (alg_name,) | |||
print("FAIL content hash %s" % (alg_name,)) | |||
for algorithm in algorithms.values(): | |||
name, h_bytes = compute_event_reference_hash(event_json, algorithm) | |||
print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) | |||
print("Reference hash %s: %s" % (name, encode_base64(h_bytes))) | |||
if __name__=="__main__": | |||
main() | |||
if __name__ == "__main__": | |||
main() |
@@ -1,15 +1,15 @@ | |||
from signedjson.sign import verify_signed_json | |||
from signedjson.key import decode_verify_key_bytes, write_signing_keys | |||
from unpaddedbase64 import decode_base64 | |||
import urllib2 | |||
import argparse | |||
import json | |||
import logging | |||
import sys | |||
import urllib2 | |||
import dns.resolver | |||
import pprint | |||
import argparse | |||
import logging | |||
from signedjson.key import decode_verify_key_bytes, write_signing_keys | |||
from signedjson.sign import verify_signed_json | |||
from unpaddedbase64 import decode_base64 | |||
def get_targets(server_name): | |||
if ":" in server_name: | |||
@@ -23,6 +23,7 @@ def get_targets(server_name): | |||
except dns.resolver.NXDOMAIN: | |||
yield (server_name, 8448) | |||
def get_server_keys(server_name, target, port): | |||
url = "https://%s:%i/_matrix/key/v1" % (target, port) | |||
keys = json.load(urllib2.urlopen(url)) | |||
@@ -33,12 +34,14 @@ def get_server_keys(server_name, target, port): | |||
verify_keys[key_id] = verify_key | |||
return verify_keys | |||
def main(): | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument("signature_name") | |||
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), | |||
default=sys.stdin) | |||
parser.add_argument( | |||
"input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin | |||
) | |||
args = parser.parse_args() | |||
logging.basicConfig() | |||
@@ -48,24 +51,23 @@ def main(): | |||
for target, port in get_targets(server_name): | |||
try: | |||
keys = get_server_keys(server_name, target, port) | |||
print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) | |||
print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port)) | |||
write_signing_keys(sys.stdout, keys.values()) | |||
break | |||
except: | |||
except Exception: | |||
logging.exception("Error talking to %s:%s", target, port) | |||
json_to_check = json.load(args.input_json) | |||
print "Checking JSON:" | |||
print("Checking JSON:") | |||
for key_id in json_to_check["signatures"][args.signature_name]: | |||
try: | |||
key = keys[key_id] | |||
verify_signed_json(json_to_check, args.signature_name, key) | |||
print "PASS %s" % (key_id,) | |||
except: | |||
print("PASS %s" % (key_id,)) | |||
except Exception: | |||
logging.exception("Check for key %s failed" % (key_id,)) | |||
print "FAIL %s" % (key_id,) | |||
print("FAIL %s" % (key_id,)) | |||
if __name__ == '__main__': | |||
main() | |||
@@ -1,13 +1,21 @@ | |||
import psycopg2 | |||
import yaml | |||
import sys | |||
import hashlib | |||
import json | |||
import sys | |||
import time | |||
import hashlib | |||
from unpaddedbase64 import encode_base64 | |||
import six | |||
import psycopg2 | |||
import yaml | |||
from canonicaljson import encode_canonical_json | |||
from signedjson.key import read_signing_keys | |||
from signedjson.sign import sign_json | |||
from canonicaljson import encode_canonical_json | |||
from unpaddedbase64 import encode_base64 | |||
if six.PY2: | |||
db_type = six.moves.builtins.buffer | |||
else: | |||
db_type = memoryview | |||
def select_v1_keys(connection): | |||
@@ -39,7 +47,9 @@ def select_v2_json(connection): | |||
cursor.close() | |||
results = {} | |||
for server_name, key_id, key_json in rows: | |||
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8")) | |||
results.setdefault(server_name, {})[key_id] = json.loads( | |||
str(key_json).decode("utf-8") | |||
) | |||
return results | |||
@@ -47,10 +57,7 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate): | |||
return { | |||
"old_verify_keys": {}, | |||
"server_name": server_name, | |||
"verify_keys": { | |||
key_id: {"key": key} | |||
for key_id, key in keys.items() | |||
}, | |||
"verify_keys": {key_id: {"key": key} for key_id, key in keys.items()}, | |||
"valid_until_ts": valid_until, | |||
"tls_fingerprints": [fingerprint(certificate)], | |||
} | |||
@@ -65,7 +72,7 @@ def rows_v2(server, json): | |||
valid_until = json["valid_until_ts"] | |||
key_json = encode_canonical_json(json) | |||
for key_id in json["verify_keys"]: | |||
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json)) | |||
yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) | |||
def main(): | |||
@@ -87,7 +94,7 @@ def main(): | |||
result = {} | |||
for server in keys: | |||
if not server in json: | |||
if server not in json: | |||
v2_json = convert_v1_to_v2( | |||
server, valid_until, keys[server], certificates[server] | |||
) | |||
@@ -96,10 +103,7 @@ def main(): | |||
yaml.safe_dump(result, sys.stdout, default_flow_style=False) | |||
rows = list( | |||
row for server, json in result.items() | |||
for row in rows_v2(server, json) | |||
) | |||
rows = list(row for server, json in result.items() for row in rows_v2(server, json)) | |||
cursor = connection.cursor() | |||
cursor.executemany( | |||
@@ -107,7 +111,7 @@ def main(): | |||
" server_name, key_id, from_server," | |||
" ts_added_ms, ts_valid_until_ms, key_json" | |||
") VALUES (%s, %s, %s, %s, %s, %s)", | |||
rows | |||
rows, | |||
) | |||
connection.commit() | |||
@@ -1,8 +1,16 @@ | |||
#! /usr/bin/python | |||
from __future__ import print_function | |||
import argparse | |||
import ast | |||
import os | |||
import re | |||
import sys | |||
import yaml | |||
class DefinitionVisitor(ast.NodeVisitor): | |||
def __init__(self): | |||
super(DefinitionVisitor, self).__init__() | |||
@@ -42,15 +50,18 @@ def non_empty(defs): | |||
functions = {name: non_empty(f) for name, f in defs['def'].items()} | |||
classes = {name: non_empty(f) for name, f in defs['class'].items()} | |||
result = {} | |||
if functions: result['def'] = functions | |||
if classes: result['class'] = classes | |||
if functions: | |||
result['def'] = functions | |||
if classes: | |||
result['class'] = classes | |||
names = defs['names'] | |||
uses = [] | |||
for name in names.get('Load', ()): | |||
if name not in names.get('Param', ()) and name not in names.get('Store', ()): | |||
uses.append(name) | |||
uses.extend(defs['attrs']) | |||
if uses: result['uses'] = uses | |||
if uses: | |||
result['uses'] = uses | |||
result['names'] = names | |||
result['attrs'] = defs['attrs'] | |||
return result | |||
@@ -95,7 +106,6 @@ def used_names(prefix, item, defs, names): | |||
if __name__ == '__main__': | |||
import sys, os, argparse, re | |||
parser = argparse.ArgumentParser(description='Find definitions.') | |||
parser.add_argument( | |||
@@ -105,24 +115,28 @@ if __name__ == '__main__': | |||
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" | |||
) | |||
parser.add_argument( | |||
"--pattern", action="append", metavar="REGEXP", | |||
help="Search for a pattern" | |||
"--pattern", action="append", metavar="REGEXP", help="Search for a pattern" | |||
) | |||
parser.add_argument( | |||
"directories", nargs='+', metavar="DIR", | |||
help="Directories to search for definitions" | |||
"directories", | |||
nargs='+', | |||
metavar="DIR", | |||
help="Directories to search for definitions", | |||
) | |||
parser.add_argument( | |||
"--referrers", default=0, type=int, | |||
help="Include referrers up to the given depth" | |||
"--referrers", | |||
default=0, | |||
type=int, | |||
help="Include referrers up to the given depth", | |||
) | |||
parser.add_argument( | |||
"--referred", default=0, type=int, | |||
help="Include referred down to the given depth" | |||
"--referred", | |||
default=0, | |||
type=int, | |||
help="Include referred down to the given depth", | |||
) | |||
parser.add_argument( | |||
"--format", default="yaml", | |||
help="Output format, one of 'yaml' or 'dot'" | |||
"--format", default="yaml", help="Output format, one of 'yaml' or 'dot'" | |||
) | |||
args = parser.parse_args() | |||
@@ -162,7 +176,7 @@ if __name__ == '__main__': | |||
for used_by in entry.get("used", ()): | |||
referrers.add(used_by) | |||
for name, definition in names.items(): | |||
if not name in referrers: | |||
if name not in referrers: | |||
continue | |||
if ignore and any(pattern.match(name) for pattern in ignore): | |||
continue | |||
@@ -176,7 +190,7 @@ if __name__ == '__main__': | |||
for uses in entry.get("uses", ()): | |||
referred.add(uses) | |||
for name, definition in names.items(): | |||
if not name in referred: | |||
if name not in referred: | |||
continue | |||
if ignore and any(pattern.match(name) for pattern in ignore): | |||
continue | |||
@@ -185,12 +199,12 @@ if __name__ == '__main__': | |||
if args.format == 'yaml': | |||
yaml.dump(result, sys.stdout, default_flow_style=False) | |||
elif args.format == 'dot': | |||
print "digraph {" | |||
print("digraph {") | |||
for name, entry in result.items(): | |||
print name | |||
print(name) | |||
for used_by in entry.get("used", ()): | |||
if used_by in result: | |||
print used_by, "->", name | |||
print "}" | |||
print(used_by, "->", name) | |||
print("}") | |||
else: | |||
raise ValueError("Unknown format %r" % (args.format)) |
@@ -1,8 +1,11 @@ | |||
#!/usr/bin/env python2 | |||
import pymacaroons | |||
from __future__ import print_function | |||
import sys | |||
import pymacaroons | |||
if len(sys.argv) == 1: | |||
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],)) | |||
sys.exit(1) | |||
@@ -11,14 +14,14 @@ macaroon_string = sys.argv[1] | |||
key = sys.argv[2] if len(sys.argv) > 2 else None | |||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string) | |||
print macaroon.inspect() | |||
print(macaroon.inspect()) | |||
print "" | |||
print("") | |||
verifier = pymacaroons.Verifier() | |||
verifier.satisfy_general(lambda c: True) | |||
try: | |||
verifier.verify(macaroon, key) | |||
print "Signature is correct" | |||
print("Signature is correct") | |||
except Exception as e: | |||
print str(e) | |||
print(str(e)) |
@@ -18,21 +18,21 @@ | |||
from __future__ import print_function | |||
import argparse | |||
import base64 | |||
import json | |||
import sys | |||
from urlparse import urlparse, urlunparse | |||
import nacl.signing | |||
import json | |||
import base64 | |||
import requests | |||
import sys | |||
from requests.adapters import HTTPAdapter | |||
import srvlookup | |||
import yaml | |||
from requests.adapters import HTTPAdapter | |||
# uncomment the following to enable debug logging of http requests | |||
#from httplib import HTTPConnection | |||
#HTTPConnection.debuglevel = 1 | |||
# from httplib import HTTPConnection | |||
# HTTPConnection.debuglevel = 1 | |||
def encode_base64(input_bytes): | |||
"""Encode bytes as a base64 string without any padding.""" | |||
@@ -58,15 +58,15 @@ def decode_base64(input_string): | |||
def encode_canonical_json(value): | |||
return json.dumps( | |||
value, | |||
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes | |||
ensure_ascii=False, | |||
# Remove unecessary white space. | |||
separators=(',',':'), | |||
# Sort the keys of dictionaries. | |||
sort_keys=True, | |||
# Encode the resulting unicode as UTF-8 bytes. | |||
).encode("UTF-8") | |||
value, | |||
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes | |||
ensure_ascii=False, | |||
# Remove unecessary white space. | |||
separators=(',', ':'), | |||
# Sort the keys of dictionaries. | |||
sort_keys=True, | |||
# Encode the resulting unicode as UTF-8 bytes. | |||
).encode("UTF-8") | |||
def sign_json(json_object, signing_key, signing_name): | |||
@@ -88,6 +88,7 @@ def sign_json(json_object, signing_key, signing_name): | |||
NACL_ED25519 = "ed25519" | |||
def decode_signing_key_base64(algorithm, version, key_base64): | |||
"""Decode a base64 encoded signing key | |||
Args: | |||
@@ -143,14 +144,12 @@ def request_json(method, origin_name, origin_key, destination, path, content): | |||
authorization_headers = [] | |||
for key, sig in signed_json["signatures"][origin_name].items(): | |||
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( | |||
origin_name, key, sig, | |||
) | |||
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig) | |||
authorization_headers.append(bytes(header)) | |||
print ("Authorization: %s" % header, file=sys.stderr) | |||
print("Authorization: %s" % header, file=sys.stderr) | |||
dest = "matrix://%s%s" % (destination, path) | |||
print ("Requesting %s" % dest, file=sys.stderr) | |||
print("Requesting %s" % dest, file=sys.stderr) | |||
s = requests.Session() | |||
s.mount("matrix://", MatrixConnectionAdapter()) | |||
@@ -158,10 +157,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): | |||
result = s.request( | |||
method=method, | |||
url=dest, | |||
headers={ | |||
"Host": destination, | |||
"Authorization": authorization_headers[0] | |||
}, | |||
headers={"Host": destination, "Authorization": authorization_headers[0]}, | |||
verify=False, | |||
data=content, | |||
) | |||
@@ -171,50 +167,50 @@ def request_json(method, origin_name, origin_key, destination, path, content): | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description= | |||
"Signs and sends a federation request to a matrix homeserver", | |||
description="Signs and sends a federation request to a matrix homeserver" | |||
) | |||
parser.add_argument( | |||
"-N", "--server-name", | |||
"-N", | |||
"--server-name", | |||
help="Name to give as the local homeserver. If unspecified, will be " | |||
"read from the config file.", | |||
"read from the config file.", | |||
) | |||
parser.add_argument( | |||
"-k", "--signing-key-path", | |||
"-k", | |||
"--signing-key-path", | |||
help="Path to the file containing the private ed25519 key to sign the " | |||
"request with.", | |||
"request with.", | |||
) | |||
parser.add_argument( | |||
"-c", "--config", | |||
"-c", | |||
"--config", | |||
default="homeserver.yaml", | |||
help="Path to server config file. Ignored if --server-name and " | |||
"--signing-key-path are both given.", | |||
"--signing-key-path are both given.", | |||
) | |||
parser.add_argument( | |||
"-d", "--destination", | |||
"-d", | |||
"--destination", | |||
default="matrix.org", | |||
help="name of the remote homeserver. We will do SRV lookups and " | |||
"connect appropriately.", | |||
"connect appropriately.", | |||
) | |||
parser.add_argument( | |||
"-X", "--method", | |||
"-X", | |||
"--method", | |||
help="HTTP method to use for the request. Defaults to GET if --data is" | |||
"unspecified, POST if it is." | |||
"unspecified, POST if it is.", | |||
) | |||
parser.add_argument( | |||
"--body", | |||
help="Data to send as the body of the HTTP request" | |||
) | |||
parser.add_argument("--body", help="Data to send as the body of the HTTP request") | |||
parser.add_argument( | |||
"path", | |||
help="request path. We will add '/_matrix/federation/v1/' to this." | |||
"path", help="request path. We will add '/_matrix/federation/v1/' to this." | |||
) | |||
args = parser.parse_args() | |||
@@ -227,13 +223,15 @@ def main(): | |||
result = request_json( | |||
args.method, | |||
args.server_name, key, args.destination, | |||
args.server_name, | |||
key, | |||
args.destination, | |||
"/_matrix/federation/v1/" + args.path, | |||
content=args.body, | |||
) | |||
json.dump(result, sys.stdout) | |||
print ("") | |||
print("") | |||
def read_args_from_config(args): | |||
@@ -253,7 +251,7 @@ class MatrixConnectionAdapter(HTTPAdapter): | |||
return s, 8448 | |||
if ":" in s: | |||
out = s.rsplit(":",1) | |||
out = s.rsplit(":", 1) | |||
try: | |||
port = int(out[1]) | |||
except ValueError: | |||
@@ -263,7 +261,7 @@ class MatrixConnectionAdapter(HTTPAdapter): | |||
try: | |||
srv = srvlookup.lookup("matrix", "tcp", s)[0] | |||
return srv.host, srv.port | |||
except: | |||
except Exception: | |||
return s, 8448 | |||
def get_connection(self, url, proxies=None): | |||
@@ -272,10 +270,9 @@ class MatrixConnectionAdapter(HTTPAdapter): | |||
(host, port) = self.lookup(parsed.netloc) | |||
netloc = "%s:%d" % (host, port) | |||
print("Connecting to %s" % (netloc,), file=sys.stderr) | |||
url = urlunparse(( | |||
"https", netloc, parsed.path, parsed.params, parsed.query, | |||
parsed.fragment, | |||
)) | |||
url = urlunparse( | |||
("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) | |||
) | |||
return super(MatrixConnectionAdapter, self).get_connection(url, proxies) | |||
@@ -1,23 +1,31 @@ | |||
from synapse.storage.pdu import PduStore | |||
from synapse.storage.signatures import SignatureStore | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.federation.units import Pdu | |||
from synapse.crypto.event_signing import ( | |||
add_event_pdu_content_hash, compute_pdu_event_reference_hash | |||
) | |||
from synapse.api.events.utils import prune_pdu | |||
from unpaddedbase64 import encode_base64, decode_base64 | |||
from canonicaljson import encode_canonical_json | |||
from __future__ import print_function | |||
import sqlite3 | |||
import sys | |||
from unpaddedbase64 import decode_base64, encode_base64 | |||
from synapse.crypto.event_signing import ( | |||
add_event_pdu_content_hash, | |||
compute_pdu_event_reference_hash, | |||
) | |||
from synapse.federation.units import Pdu | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.pdu import PduStore | |||
from synapse.storage.signatures import SignatureStore | |||
class Store(object): | |||
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] | |||
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] | |||
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] | |||
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] | |||
_get_pdu_origin_signatures_txn = SignatureStore.__dict__[ | |||
"_get_pdu_origin_signatures_txn" | |||
] | |||
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] | |||
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] | |||
_store_pdu_reference_hash_txn = SignatureStore.__dict__[ | |||
"_store_pdu_reference_hash_txn" | |||
] | |||
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] | |||
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] | |||
@@ -26,9 +34,7 @@ store = Store() | |||
def select_pdus(cursor): | |||
cursor.execute( | |||
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC" | |||
) | |||
cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC") | |||
ids = cursor.fetchall() | |||
@@ -41,23 +47,30 @@ def select_pdus(cursor): | |||
for pdu in pdus: | |||
try: | |||
if pdu.prev_pdus: | |||
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus | |||
print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) | |||
for pdu_id, origin, hashes in pdu.prev_pdus: | |||
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] | |||
hashes[ref_alg] = encode_base64(ref_hsh) | |||
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) | |||
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus | |||
store._store_prev_pdu_hash_txn( | |||
cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh | |||
) | |||
print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) | |||
pdu = add_event_pdu_content_hash(pdu) | |||
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) | |||
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) | |||
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) | |||
store._store_pdu_reference_hash_txn( | |||
cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh | |||
) | |||
for alg, hsh_base64 in pdu.hashes.items(): | |||
print alg, hsh_base64 | |||
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) | |||
print(alg, hsh_base64) | |||
store._store_pdu_content_hash_txn( | |||
cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64) | |||
) | |||
except Exception: | |||
print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus) | |||
except: | |||
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus | |||
def main(): | |||
conn = sqlite3.connect(sys.argv[1]) | |||
@@ -65,5 +78,6 @@ def main(): | |||
select_pdus(cursor) | |||
conn.commit() | |||
if __name__=='__main__': | |||
if __name__ == '__main__': | |||
main() |
@@ -1,18 +1,17 @@ | |||
#! /usr/bin/python | |||
import ast | |||
import argparse | |||
import ast | |||
import os | |||
import sys | |||
import yaml | |||
PATTERNS_V1 = [] | |||
PATTERNS_V2 = [] | |||
RESULT = { | |||
"v1": PATTERNS_V1, | |||
"v2": PATTERNS_V2, | |||
} | |||
RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2} | |||
class CallVisitor(ast.NodeVisitor): | |||
def visit_Call(self, node): | |||
@@ -21,7 +20,6 @@ class CallVisitor(ast.NodeVisitor): | |||
else: | |||
return | |||
if name == "client_path_patterns": | |||
PATTERNS_V1.append(node.args[0].s) | |||
elif name == "client_v2_patterns": | |||
@@ -42,8 +40,10 @@ def find_patterns_in_file(filepath): | |||
parser = argparse.ArgumentParser(description='Find url patterns.') | |||
parser.add_argument( | |||
"directories", nargs='+', metavar="DIR", | |||
help="Directories to search for definitions" | |||
"directories", | |||
nargs='+', | |||
metavar="DIR", | |||
help="Directories to search for definitions", | |||
) | |||
args = parser.parse_args() | |||
@@ -1,8 +1,9 @@ | |||
import requests | |||
import collections | |||
import json | |||
import sys | |||
import time | |||
import json | |||
import requests | |||
Entry = collections.namedtuple("Entry", "name position rows") | |||
@@ -30,11 +31,11 @@ def parse_response(content): | |||
def replicate(server, streams): | |||
return parse_response(requests.get( | |||
server + "/_synapse/replication", | |||
verify=False, | |||
params=streams | |||
).content) | |||
return parse_response( | |||
requests.get( | |||
server + "/_synapse/replication", verify=False, params=streams | |||
).content | |||
) | |||
def main(): | |||
@@ -45,7 +46,7 @@ def main(): | |||
try: | |||
streams = { | |||
row.name: row.position | |||
for row in replicate(server, {"streams":"-1"})["streams"].rows | |||
for row in replicate(server, {"streams": "-1"})["streams"].rows | |||
} | |||
except requests.exceptions.ConnectionError as e: | |||
time.sleep(0.1) | |||
@@ -53,8 +54,8 @@ def main(): | |||
while True: | |||
try: | |||
results = replicate(server, streams) | |||
except: | |||
sys.stdout.write("connection_lost("+ repr(streams) + ")\n") | |||
except Exception: | |||
sys.stdout.write("connection_lost(" + repr(streams) + ")\n") | |||
break | |||
for update in results.values(): | |||
for row in update.rows: | |||
@@ -62,6 +63,5 @@ def main(): | |||
streams[update.name] = update.position | |||
if __name__=='__main__': | |||
if __name__ == '__main__': | |||
main() |
@@ -1,12 +1,10 @@ | |||
#!/usr/bin/env python | |||
import argparse | |||
import getpass | |||
import sys | |||
import bcrypt | |||
import getpass | |||
import yaml | |||
bcrypt_rounds=12 | |||
@@ -52,4 +50,3 @@ if __name__ == "__main__": | |||
password = prompt_for_pass() | |||
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) | |||
@@ -36,12 +36,9 @@ from __future__ import print_function | |||
import argparse | |||
import logging | |||
import sys | |||
import os | |||
import shutil | |||
import sys | |||
from synapse.rest.media.v1.filepath import MediaFilePaths | |||
@@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths): | |||
if not os.path.exists(original_file): | |||
logger.warn( | |||
"Original for %s/%s (%s) does not exist", | |||
origin_server, file_id, original_file, | |||
origin_server, | |||
file_id, | |||
original_file, | |||
) | |||
else: | |||
mkdir_and_move( | |||
original_file, | |||
dest_paths.remote_media_filepath(origin_server, file_id), | |||
original_file, dest_paths.remote_media_filepath(origin_server, file_id) | |||
) | |||
# now look for thumbnails | |||
original_thumb_dir = src_paths.remote_media_thumbnail_dir( | |||
origin_server, file_id, | |||
) | |||
original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id) | |||
if not os.path.exists(original_thumb_dir): | |||
return | |||
mkdir_and_move( | |||
original_thumb_dir, | |||
dest_paths.remote_media_thumbnail_dir(origin_server, file_id) | |||
dest_paths.remote_media_thumbnail_dir(origin_server, file_id), | |||
) | |||
@@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file): | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser( | |||
description=__doc__, | |||
formatter_class = argparse.RawDescriptionHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"-v", action='store_true', help='enable debug logging') | |||
parser.add_argument( | |||
"src_repo", | |||
help="Path to source content repo", | |||
) | |||
parser.add_argument( | |||
"dest_repo", | |||
help="Path to source content repo", | |||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter | |||
) | |||
parser.add_argument("-v", action='store_true', help='enable debug logging') | |||
parser.add_argument("src_repo", help="Path to source content repo") | |||
parser.add_argument("dest_repo", help="Path to source content repo") | |||
args = parser.parse_args() | |||
logging_config = { | |||
"level": logging.DEBUG if args.v else logging.INFO, | |||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" | |||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", | |||
} | |||
logging.basicConfig(**logging_config) | |||
@@ -14,6 +14,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from __future__ import print_function | |||
import argparse | |||
import getpass | |||
@@ -22,19 +23,23 @@ import hmac | |||
import json | |||
import sys | |||
import urllib2 | |||
from six import input | |||
import yaml | |||
def request_registration(user, password, server_location, shared_secret, admin=False): | |||
req = urllib2.Request( | |||
"%s/_matrix/client/r0/admin/register" % (server_location,), | |||
headers={'Content-Type': 'application/json'} | |||
headers={'Content-Type': 'application/json'}, | |||
) | |||
try: | |||
if sys.version_info[:3] >= (2, 7, 9): | |||
# As of version 2.7.9, urllib2 now checks SSL certs | |||
import ssl | |||
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) | |||
else: | |||
f = urllib2.urlopen(req) | |||
@@ -42,18 +47,15 @@ def request_registration(user, password, server_location, shared_secret, admin=F | |||
f.close() | |||
nonce = json.loads(body)["nonce"] | |||
except urllib2.HTTPError as e: | |||
print "ERROR! Received %d %s" % (e.code, e.reason,) | |||
print("ERROR! Received %d %s" % (e.code, e.reason)) | |||
if 400 <= e.code < 500: | |||
if e.info().type == "application/json": | |||
resp = json.load(e) | |||
if "error" in resp: | |||
print resp["error"] | |||
print(resp["error"]) | |||
sys.exit(1) | |||
mac = hmac.new( | |||
key=shared_secret, | |||
digestmod=hashlib.sha1, | |||
) | |||
mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1) | |||
mac.update(nonce) | |||
mac.update("\x00") | |||
@@ -75,30 +77,31 @@ def request_registration(user, password, server_location, shared_secret, admin=F | |||
server_location = server_location.rstrip("/") | |||
print "Sending registration request..." | |||
print("Sending registration request...") | |||
req = urllib2.Request( | |||
"%s/_matrix/client/r0/admin/register" % (server_location,), | |||
data=json.dumps(data), | |||
headers={'Content-Type': 'application/json'} | |||
headers={'Content-Type': 'application/json'}, | |||
) | |||
try: | |||
if sys.version_info[:3] >= (2, 7, 9): | |||
# As of version 2.7.9, urllib2 now checks SSL certs | |||
import ssl | |||
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) | |||
else: | |||
f = urllib2.urlopen(req) | |||
f.read() | |||
f.close() | |||
print "Success." | |||
print("Success.") | |||
except urllib2.HTTPError as e: | |||
print "ERROR! Received %d %s" % (e.code, e.reason,) | |||
print("ERROR! Received %d %s" % (e.code, e.reason)) | |||
if 400 <= e.code < 500: | |||
if e.info().type == "application/json": | |||
resp = json.load(e) | |||
if "error" in resp: | |||
print resp["error"] | |||
print(resp["error"]) | |||
sys.exit(1) | |||
@@ -106,35 +109,35 @@ def register_new_user(user, password, server_location, shared_secret, admin): | |||
if not user: | |||
try: | |||
default_user = getpass.getuser() | |||
except: | |||
except Exception: | |||
default_user = None | |||
if default_user: | |||
user = raw_input("New user localpart [%s]: " % (default_user,)) | |||
user = input("New user localpart [%s]: " % (default_user,)) | |||
if not user: | |||
user = default_user | |||
else: | |||
user = raw_input("New user localpart: ") | |||
user = input("New user localpart: ") | |||
if not user: | |||
print "Invalid user name" | |||
print("Invalid user name") | |||
sys.exit(1) | |||
if not password: | |||
password = getpass.getpass("Password: ") | |||
if not password: | |||
print "Password cannot be blank." | |||
print("Password cannot be blank.") | |||
sys.exit(1) | |||
confirm_password = getpass.getpass("Confirm password: ") | |||
if password != confirm_password: | |||
print "Passwords do not match" | |||
print("Passwords do not match") | |||
sys.exit(1) | |||
if admin is None: | |||
admin = raw_input("Make admin [no]: ") | |||
admin = input("Make admin [no]: ") | |||
if admin in ("y", "yes", "true"): | |||
admin = True | |||
else: | |||
@@ -146,42 +149,51 @@ def register_new_user(user, password, server_location, shared_secret, admin): | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser( | |||
description="Used to register new users with a given home server when" | |||
" registration has been disabled. The home server must be" | |||
" configured with the 'registration_shared_secret' option" | |||
" set.", | |||
" registration has been disabled. The home server must be" | |||
" configured with the 'registration_shared_secret' option" | |||
" set." | |||
) | |||
parser.add_argument( | |||
"-u", "--user", | |||
"-u", | |||
"--user", | |||
default=None, | |||
help="Local part of the new user. Will prompt if omitted.", | |||
) | |||
parser.add_argument( | |||
"-p", "--password", | |||
"-p", | |||
"--password", | |||
default=None, | |||
help="New password for user. Will prompt if omitted.", | |||
) | |||
admin_group = parser.add_mutually_exclusive_group() | |||
admin_group.add_argument( | |||
"-a", "--admin", | |||
"-a", | |||
"--admin", | |||
action="store_true", | |||
help="Register new user as an admin. Will prompt if --no-admin is not set either.", | |||
help=( | |||
"Register new user as an admin. " | |||
"Will prompt if --no-admin is not set either." | |||
), | |||
) | |||
admin_group.add_argument( | |||
"--no-admin", | |||
action="store_true", | |||
help="Register new user as a regular user. Will prompt if --admin is not set either.", | |||
help=( | |||
"Register new user as a regular user. " | |||
"Will prompt if --admin is not set either." | |||
), | |||
) | |||
group = parser.add_mutually_exclusive_group(required=True) | |||
group.add_argument( | |||
"-c", "--config", | |||
"-c", | |||
"--config", | |||
type=argparse.FileType('r'), | |||
help="Path to server config file. Used to read in shared secret.", | |||
) | |||
group.add_argument( | |||
"-k", "--shared-secret", | |||
help="Shared secret as defined in server config file.", | |||
"-k", "--shared-secret", help="Shared secret as defined in server config file." | |||
) | |||
parser.add_argument( | |||
@@ -189,7 +201,7 @@ if __name__ == "__main__": | |||
default="https://localhost:8448", | |||
nargs='?', | |||
help="URL to use to talk to the home server. Defaults to " | |||
" 'https://localhost:8448'.", | |||
" 'https://localhost:8448'.", | |||
) | |||
args = parser.parse_args() | |||
@@ -198,7 +210,7 @@ if __name__ == "__main__": | |||
config = yaml.safe_load(args.config) | |||
secret = config.get("registration_shared_secret", None) | |||
if not secret: | |||
print "No 'registration_shared_secret' defined in config." | |||
print("No 'registration_shared_secret' defined in config.") | |||
sys.exit(1) | |||
else: | |||
secret = args.shared_secret | |||
@@ -15,23 +15,23 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer, reactor | |||
from twisted.enterprise import adbapi | |||
from synapse.storage._base import LoggingTransaction, SQLBaseStore | |||
from synapse.storage.engines import create_engine | |||
from synapse.storage.prepare_database import prepare_database | |||
import argparse | |||
import curses | |||
import logging | |||
import sys | |||
import time | |||
import traceback | |||
import yaml | |||
from six import string_types | |||
import yaml | |||
from twisted.enterprise import adbapi | |||
from twisted.internet import defer, reactor | |||
from synapse.storage._base import LoggingTransaction, SQLBaseStore | |||
from synapse.storage.engines import create_engine | |||
from synapse.storage.prepare_database import prepare_database | |||
logger = logging.getLogger("synapse_port_db") | |||
@@ -105,6 +105,7 @@ class Store(object): | |||
*All* database interactions should go through this object. | |||
""" | |||
def __init__(self, db_pool, engine): | |||
self.db_pool = db_pool | |||
self.database_engine = engine | |||
@@ -135,7 +136,8 @@ class Store(object): | |||
txn = conn.cursor() | |||
return func( | |||
LoggingTransaction(txn, desc, self.database_engine, [], []), | |||
*args, **kwargs | |||
*args, | |||
**kwargs | |||
) | |||
except self.database_engine.module.DatabaseError as e: | |||
if self.database_engine.is_deadlock(e): | |||
@@ -158,22 +160,20 @@ class Store(object): | |||
def r(txn): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
return self.runInteraction("execute_sql", r) | |||
def insert_many_txn(self, txn, table, headers, rows): | |||
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( | |||
table, | |||
", ".join(k for k in headers), | |||
", ".join("%s" for _ in headers) | |||
", ".join("%s" for _ in headers), | |||
) | |||
try: | |||
txn.executemany(sql, rows) | |||
except: | |||
logger.exception( | |||
"Failed to insert: %s", | |||
table, | |||
) | |||
except Exception: | |||
logger.exception("Failed to insert: %s", table) | |||
raise | |||
@@ -206,7 +206,7 @@ class Porter(object): | |||
"table_name": table, | |||
"forward_rowid": 1, | |||
"backward_rowid": 0, | |||
} | |||
}, | |||
) | |||
forward_chunk = 1 | |||
@@ -221,10 +221,10 @@ class Porter(object): | |||
table, forward_chunk, backward_chunk | |||
) | |||
else: | |||
def delete_all(txn): | |||
txn.execute( | |||
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", | |||
(table,) | |||
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) | |||
) | |||
txn.execute("TRUNCATE %s CASCADE" % (table,)) | |||
@@ -232,11 +232,7 @@ class Porter(object): | |||
yield self.postgres_store._simple_insert( | |||
table="port_from_sqlite3", | |||
values={ | |||
"table_name": table, | |||
"forward_rowid": 1, | |||
"backward_rowid": 0, | |||
} | |||
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, | |||
) | |||
forward_chunk = 1 | |||
@@ -251,12 +247,16 @@ class Porter(object): | |||
) | |||
@defer.inlineCallbacks | |||
def handle_table(self, table, postgres_size, table_size, forward_chunk, | |||
backward_chunk): | |||
def handle_table( | |||
self, table, postgres_size, table_size, forward_chunk, backward_chunk | |||
): | |||
logger.info( | |||
"Table %s: %i/%i (rows %i-%i) already ported", | |||
table, postgres_size, table_size, | |||
backward_chunk+1, forward_chunk-1, | |||
table, | |||
postgres_size, | |||
table_size, | |||
backward_chunk + 1, | |||
forward_chunk - 1, | |||
) | |||
if not table_size: | |||
@@ -271,7 +271,9 @@ class Porter(object): | |||
return | |||
if table in ( | |||
"user_directory", "user_directory_search", "users_who_share_rooms", | |||
"user_directory", | |||
"user_directory_search", | |||
"users_who_share_rooms", | |||
"users_in_pubic_room", | |||
): | |||
# We don't port these tables, as they're a faff and we can regenreate | |||
@@ -283,37 +285,35 @@ class Porter(object): | |||
# We need to make sure there is a single row, `(X, null), as that is | |||
# what synapse expects to be there. | |||
yield self.postgres_store._simple_insert( | |||
table=table, | |||
values={"stream_id": None}, | |||
table=table, values={"stream_id": None} | |||
) | |||
self.progress.update(table, table_size) # Mark table as done | |||
return | |||
forward_select = ( | |||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" | |||
% (table,) | |||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,) | |||
) | |||
backward_select = ( | |||
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" | |||
% (table,) | |||
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,) | |||
) | |||
do_forward = [True] | |||
do_backward = [True] | |||
while True: | |||
def r(txn): | |||
forward_rows = [] | |||
backward_rows = [] | |||
if do_forward[0]: | |||
txn.execute(forward_select, (forward_chunk, self.batch_size,)) | |||
txn.execute(forward_select, (forward_chunk, self.batch_size)) | |||
forward_rows = txn.fetchall() | |||
if not forward_rows: | |||
do_forward[0] = False | |||
if do_backward[0]: | |||
txn.execute(backward_select, (backward_chunk, self.batch_size,)) | |||
txn.execute(backward_select, (backward_chunk, self.batch_size)) | |||
backward_rows = txn.fetchall() | |||
if not backward_rows: | |||
do_backward[0] = False | |||
@@ -325,9 +325,7 @@ class Porter(object): | |||
return headers, forward_rows, backward_rows | |||
headers, frows, brows = yield self.sqlite_store.runInteraction( | |||
"select", r | |||
) | |||
headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) | |||
if frows or brows: | |||
if frows: | |||
@@ -339,9 +337,7 @@ class Porter(object): | |||
rows = self._convert_rows(table, headers, rows) | |||
def insert(txn): | |||
self.postgres_store.insert_many_txn( | |||
txn, table, headers[1:], rows | |||
) | |||
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) | |||
self.postgres_store._simple_update_one_txn( | |||
txn, | |||
@@ -362,8 +358,9 @@ class Porter(object): | |||
return | |||
@defer.inlineCallbacks | |||
def handle_search_table(self, postgres_size, table_size, forward_chunk, | |||
backward_chunk): | |||
def handle_search_table( | |||
self, postgres_size, table_size, forward_chunk, backward_chunk | |||
): | |||
select = ( | |||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" | |||
" FROM event_search as es" | |||
@@ -373,8 +370,9 @@ class Porter(object): | |||
) | |||
while True: | |||
def r(txn): | |||
txn.execute(select, (forward_chunk, self.batch_size,)) | |||
txn.execute(select, (forward_chunk, self.batch_size)) | |||
rows = txn.fetchall() | |||
headers = [column[0] for column in txn.description] | |||
@@ -402,18 +400,21 @@ class Porter(object): | |||
else: | |||
rows_dict.append(d) | |||
txn.executemany(sql, [ | |||
( | |||
row["event_id"], | |||
row["room_id"], | |||
row["key"], | |||
row["sender"], | |||
row["value"], | |||
row["origin_server_ts"], | |||
row["stream_ordering"], | |||
) | |||
for row in rows_dict | |||
]) | |||
txn.executemany( | |||
sql, | |||
[ | |||
( | |||
row["event_id"], | |||
row["room_id"], | |||
row["key"], | |||
row["sender"], | |||
row["value"], | |||
row["origin_server_ts"], | |||
row["stream_ordering"], | |||
) | |||
for row in rows_dict | |||
], | |||
) | |||
self.postgres_store._simple_update_one_txn( | |||
txn, | |||
@@ -437,7 +438,8 @@ class Porter(object): | |||
def setup_db(self, db_config, database_engine): | |||
db_conn = database_engine.module.connect( | |||
**{ | |||
k: v for k, v in db_config.get("args", {}).items() | |||
k: v | |||
for k, v in db_config.get("args", {}).items() | |||
if not k.startswith("cp_") | |||
} | |||
) | |||
@@ -450,13 +452,11 @@ class Porter(object): | |||
def run(self): | |||
try: | |||
sqlite_db_pool = adbapi.ConnectionPool( | |||
self.sqlite_config["name"], | |||
**self.sqlite_config["args"] | |||
self.sqlite_config["name"], **self.sqlite_config["args"] | |||
) | |||
postgres_db_pool = adbapi.ConnectionPool( | |||
self.postgres_config["name"], | |||
**self.postgres_config["args"] | |||
self.postgres_config["name"], **self.postgres_config["args"] | |||
) | |||
sqlite_engine = create_engine(sqlite_config) | |||
@@ -465,9 +465,7 @@ class Porter(object): | |||
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) | |||
self.postgres_store = Store(postgres_db_pool, postgres_engine) | |||
yield self.postgres_store.execute( | |||
postgres_engine.check_database | |||
) | |||
yield self.postgres_store.execute(postgres_engine.check_database) | |||
# Step 1. Set up databases. | |||
self.progress.set_state("Preparing SQLite3") | |||
@@ -477,6 +475,7 @@ class Porter(object): | |||
self.setup_db(postgres_config, postgres_engine) | |||
self.progress.set_state("Creating port tables") | |||
def create_port_table(txn): | |||
txn.execute( | |||
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" | |||
@@ -501,9 +500,7 @@ class Porter(object): | |||
) | |||
try: | |||
yield self.postgres_store.runInteraction( | |||
"alter_table", alter_table | |||
) | |||
yield self.postgres_store.runInteraction("alter_table", alter_table) | |||
except Exception as e: | |||
pass | |||
@@ -514,11 +511,7 @@ class Porter(object): | |||
# Step 2. Get tables. | |||
self.progress.set_state("Fetching tables") | |||
sqlite_tables = yield self.sqlite_store._simple_select_onecol( | |||
table="sqlite_master", | |||
keyvalues={ | |||
"type": "table", | |||
}, | |||
retcol="name", | |||
table="sqlite_master", keyvalues={"type": "table"}, retcol="name" | |||
) | |||
postgres_tables = yield self.postgres_store._simple_select_onecol( | |||
@@ -545,18 +538,14 @@ class Porter(object): | |||
# Step 4. Do the copying. | |||
self.progress.set_state("Copying to postgres") | |||
yield defer.gatherResults( | |||
[ | |||
self.handle_table(*res) | |||
for res in setup_res | |||
], | |||
consumeErrors=True, | |||
[self.handle_table(*res) for res in setup_res], consumeErrors=True | |||
) | |||
# Step 5. Do final post-processing | |||
yield self._setup_state_group_id_seq() | |||
self.progress.done() | |||
except: | |||
except Exception: | |||
global end_error_exec_info | |||
end_error_exec_info = sys.exc_info() | |||
logger.exception("") | |||
@@ -566,9 +555,7 @@ class Porter(object): | |||
def _convert_rows(self, table, headers, rows): | |||
bool_col_names = BOOLEAN_COLUMNS.get(table, []) | |||
bool_cols = [ | |||
i for i, h in enumerate(headers) if h in bool_col_names | |||
] | |||
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] | |||
class BadValueException(Exception): | |||
pass | |||
@@ -577,18 +564,21 @@ class Porter(object): | |||
if j in bool_cols: | |||
return bool(col) | |||
elif isinstance(col, string_types) and "\0" in col: | |||
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) | |||
raise BadValueException(); | |||
logger.warn( | |||
"DROPPING ROW: NUL value in table %s col %s: %r", | |||
table, | |||
headers[j], | |||
col, | |||
) | |||
raise BadValueException() | |||
return col | |||
outrows = [] | |||
for i, row in enumerate(rows): | |||
try: | |||
outrows.append(tuple( | |||
conv(j, col) | |||
for j, col in enumerate(row) | |||
if j > 0 | |||
)) | |||
outrows.append( | |||
tuple(conv(j, col) for j, col in enumerate(row) if j > 0) | |||
) | |||
except BadValueException: | |||
pass | |||
@@ -616,9 +606,7 @@ class Porter(object): | |||
return headers, [r for r in rows if r[ts_ind] < yesterday] | |||
headers, rows = yield self.sqlite_store.runInteraction( | |||
"select", r, | |||
) | |||
headers, rows = yield self.sqlite_store.runInteraction("select", r) | |||
rows = self._convert_rows("sent_transactions", headers, rows) | |||
@@ -639,7 +627,7 @@ class Porter(object): | |||
txn.execute( | |||
"SELECT rowid FROM sent_transactions WHERE ts >= ?" | |||
" ORDER BY rowid ASC LIMIT 1", | |||
(yesterday,) | |||
(yesterday,), | |||
) | |||
rows = txn.fetchall() | |||
@@ -657,21 +645,17 @@ class Porter(object): | |||
"table_name": "sent_transactions", | |||
"forward_rowid": next_chunk, | |||
"backward_rowid": 0, | |||
} | |||
}, | |||
) | |||
def get_sent_table_size(txn): | |||
txn.execute( | |||
"SELECT count(*) FROM sent_transactions" | |||
" WHERE ts >= ?", | |||
(yesterday,) | |||
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) | |||
) | |||
size, = txn.fetchone() | |||
return int(size) | |||
remaining_count = yield self.sqlite_store.execute( | |||
get_sent_table_size | |||
) | |||
remaining_count = yield self.sqlite_store.execute(get_sent_table_size) | |||
total_count = remaining_count + inserted_rows | |||
@@ -680,13 +664,11 @@ class Porter(object): | |||
@defer.inlineCallbacks | |||
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): | |||
frows = yield self.sqlite_store.execute_sql( | |||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), | |||
forward_chunk, | |||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk | |||
) | |||
brows = yield self.sqlite_store.execute_sql( | |||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), | |||
backward_chunk, | |||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk | |||
) | |||
defer.returnValue(frows[0][0] + brows[0][0]) | |||
@@ -694,7 +676,7 @@ class Porter(object): | |||
@defer.inlineCallbacks | |||
def _get_already_ported_count(self, table): | |||
rows = yield self.postgres_store.execute_sql( | |||
"SELECT count(*) FROM %s" % (table,), | |||
"SELECT count(*) FROM %s" % (table,) | |||
) | |||
defer.returnValue(rows[0][0]) | |||
@@ -717,22 +699,21 @@ class Porter(object): | |||
def _setup_state_group_id_seq(self): | |||
def r(txn): | |||
txn.execute("SELECT MAX(id) FROM state_groups") | |||
next_id = txn.fetchone()[0]+1 | |||
txn.execute( | |||
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s", | |||
(next_id,), | |||
) | |||
next_id = txn.fetchone()[0] + 1 | |||
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) | |||
return self.postgres_store.runInteraction("setup_state_group_id_seq", r) | |||
############################################## | |||
###### The following is simply UI stuff ###### | |||
# The following is simply UI stuff | |||
############################################## | |||
class Progress(object): | |||
"""Used to report progress of the port | |||
""" | |||
def __init__(self): | |||
self.tables = {} | |||
@@ -758,6 +739,7 @@ class Progress(object): | |||
class CursesProgress(Progress): | |||
"""Reports progress to a curses window | |||
""" | |||
def __init__(self, stdscr): | |||
self.stdscr = stdscr | |||
@@ -801,7 +783,7 @@ class CursesProgress(Progress): | |||
duration = int(now) - int(self.start_time) | |||
minutes, seconds = divmod(duration, 60) | |||
duration_str = '%02dm %02ds' % (minutes, seconds,) | |||
duration_str = '%02dm %02ds' % (minutes, seconds) | |||
if self.finished: | |||
status = "Time spent: %s (Done!)" % (duration_str,) | |||
@@ -814,16 +796,12 @@ class CursesProgress(Progress): | |||
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) | |||
else: | |||
est_remaining_str = "Unknown" | |||
status = ( | |||
"Time spent: %s (est. remaining: %s)" | |||
% (duration_str, est_remaining_str,) | |||
status = "Time spent: %s (est. remaining: %s)" % ( | |||
duration_str, | |||
est_remaining_str, | |||
) | |||
self.stdscr.addstr( | |||
0, 0, | |||
status, | |||
curses.A_BOLD, | |||
) | |||
self.stdscr.addstr(0, 0, status, curses.A_BOLD) | |||
max_len = max([len(t) for t in self.tables.keys()]) | |||
@@ -831,9 +809,7 @@ class CursesProgress(Progress): | |||
middle_space = 1 | |||
items = self.tables.items() | |||
items.sort( | |||
key=lambda i: (i[1]["perc"], i[0]), | |||
) | |||
items.sort(key=lambda i: (i[1]["perc"], i[0])) | |||
for i, (table, data) in enumerate(items): | |||
if i + 2 >= rows: | |||
@@ -844,9 +820,7 @@ class CursesProgress(Progress): | |||
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) | |||
self.stdscr.addstr( | |||
i + 2, left_margin + max_len - len(table), | |||
table, | |||
curses.A_BOLD | color, | |||
i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color | |||
) | |||
size = 20 | |||
@@ -857,15 +831,13 @@ class CursesProgress(Progress): | |||
) | |||
self.stdscr.addstr( | |||
i + 2, left_margin + max_len + middle_space, | |||
i + 2, | |||
left_margin + max_len + middle_space, | |||
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), | |||
) | |||
if self.finished: | |||
self.stdscr.addstr( | |||
rows - 1, 0, | |||
"Press any key to exit...", | |||
) | |||
self.stdscr.addstr(rows - 1, 0, "Press any key to exit...") | |||
self.stdscr.refresh() | |||
self.last_update = time.time() | |||
@@ -877,29 +849,25 @@ class CursesProgress(Progress): | |||
def set_state(self, state): | |||
self.stdscr.clear() | |||
self.stdscr.addstr( | |||
0, 0, | |||
state + "...", | |||
curses.A_BOLD, | |||
) | |||
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) | |||
self.stdscr.refresh() | |||
class TerminalProgress(Progress): | |||
"""Just prints progress to the terminal | |||
""" | |||
def update(self, table, num_done): | |||
super(TerminalProgress, self).update(table, num_done) | |||
data = self.tables[table] | |||
print "%s: %d%% (%d/%d)" % ( | |||
table, data["perc"], | |||
data["num_done"], data["total"], | |||
print( | |||
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) | |||
) | |||
def set_state(self, state): | |||
print state + "..." | |||
print(state + "...") | |||
############################################## | |||
@@ -909,34 +877,38 @@ class TerminalProgress(Progress): | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser( | |||
description="A script to port an existing synapse SQLite database to" | |||
" a new PostgreSQL database." | |||
" a new PostgreSQL database." | |||
) | |||
parser.add_argument("-v", action='store_true') | |||
parser.add_argument( | |||
"--sqlite-database", required=True, | |||
"--sqlite-database", | |||
required=True, | |||
help="The snapshot of the SQLite database file. This must not be" | |||
" currently used by a running synapse server" | |||
" currently used by a running synapse server", | |||
) | |||
parser.add_argument( | |||
"--postgres-config", type=argparse.FileType('r'), required=True, | |||
help="The database config file for the PostgreSQL database" | |||
"--postgres-config", | |||
type=argparse.FileType('r'), | |||
required=True, | |||
help="The database config file for the PostgreSQL database", | |||
) | |||
parser.add_argument( | |||
"--curses", action='store_true', | |||
help="display a curses based progress UI" | |||
"--curses", action='store_true', help="display a curses based progress UI" | |||
) | |||
parser.add_argument( | |||
"--batch-size", type=int, default=1000, | |||
"--batch-size", | |||
type=int, | |||
default=1000, | |||
help="The number of rows to select from the SQLite table each" | |||
" iteration [default=1000]", | |||
" iteration [default=1000]", | |||
) | |||
args = parser.parse_args() | |||
logging_config = { | |||
"level": logging.DEBUG if args.v else logging.INFO, | |||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" | |||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", | |||
} | |||
if args.curses: | |||
@@ -568,7 +568,7 @@ def run(hs): | |||
clock.call_later(5 * 60, start_phone_stats_home) | |||
if hs.config.daemonize and hs.config.print_pidfile: | |||
print (hs.config.pid_file) | |||
print(hs.config.pid_file) | |||
_base.start_reactor( | |||
"synapse-homeserver", | |||
@@ -28,7 +28,7 @@ if __name__ == "__main__": | |||
sys.stderr.write("\n" + str(e) + "\n") | |||
sys.exit(1) | |||
print (getattr(config, key)) | |||
print(getattr(config, key)) | |||
sys.exit(0) | |||
else: | |||
sys.stderr.write("Unknown command %r\n" % (action,)) | |||
@@ -106,10 +106,7 @@ class Config(object): | |||
@classmethod | |||
def check_file(cls, file_path, config_name): | |||
if file_path is None: | |||
raise ConfigError( | |||
"Missing config for %s." | |||
% (config_name,) | |||
) | |||
raise ConfigError("Missing config for %s." % (config_name,)) | |||
try: | |||
os.stat(file_path) | |||
except OSError as e: | |||
@@ -128,9 +125,7 @@ class Config(object): | |||
if e.errno != errno.EEXIST: | |||
raise | |||
if not os.path.isdir(dir_path): | |||
raise ConfigError( | |||
"%s is not a directory" % (dir_path,) | |||
) | |||
raise ConfigError("%s is not a directory" % (dir_path,)) | |||
return dir_path | |||
@classmethod | |||
@@ -156,21 +151,20 @@ class Config(object): | |||
return results | |||
def generate_config( | |||
self, | |||
config_dir_path, | |||
server_name, | |||
is_generating_file, | |||
report_stats=None, | |||
self, config_dir_path, server_name, is_generating_file, report_stats=None | |||
): | |||
default_config = "# vim:ft=yaml\n" | |||
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( | |||
"default_config", | |||
config_dir_path=config_dir_path, | |||
server_name=server_name, | |||
is_generating_file=is_generating_file, | |||
report_stats=report_stats, | |||
)) | |||
default_config += "\n\n".join( | |||
dedent(conf) | |||
for conf in self.invoke_all( | |||
"default_config", | |||
config_dir_path=config_dir_path, | |||
server_name=server_name, | |||
is_generating_file=is_generating_file, | |||
report_stats=report_stats, | |||
) | |||
) | |||
config = yaml.load(default_config) | |||
@@ -178,23 +172,22 @@ class Config(object): | |||
@classmethod | |||
def load_config(cls, description, argv): | |||
config_parser = argparse.ArgumentParser( | |||
description=description, | |||
) | |||
config_parser = argparse.ArgumentParser(description=description) | |||
config_parser.add_argument( | |||
"-c", "--config-path", | |||
"-c", | |||
"--config-path", | |||
action="append", | |||
metavar="CONFIG_FILE", | |||
help="Specify config file. Can be given multiple times and" | |||
" may specify directories containing *.yaml files." | |||
" may specify directories containing *.yaml files.", | |||
) | |||
config_parser.add_argument( | |||
"--keys-directory", | |||
metavar="DIRECTORY", | |||
help="Where files such as certs and signing keys are stored when" | |||
" their location is given explicitly in the config." | |||
" Defaults to the directory containing the last config file", | |||
" their location is given explicitly in the config." | |||
" Defaults to the directory containing the last config file", | |||
) | |||
config_args = config_parser.parse_args(argv) | |||
@@ -203,9 +196,7 @@ class Config(object): | |||
obj = cls() | |||
obj.read_config_files( | |||
config_files, | |||
keys_directory=config_args.keys_directory, | |||
generate_keys=False, | |||
config_files, keys_directory=config_args.keys_directory, generate_keys=False | |||
) | |||
return obj | |||
@@ -213,38 +204,38 @@ class Config(object): | |||
def load_or_generate_config(cls, description, argv): | |||
config_parser = argparse.ArgumentParser(add_help=False) | |||
config_parser.add_argument( | |||
"-c", "--config-path", | |||
"-c", | |||
"--config-path", | |||
action="append", | |||
metavar="CONFIG_FILE", | |||
help="Specify config file. Can be given multiple times and" | |||
" may specify directories containing *.yaml files." | |||
" may specify directories containing *.yaml files.", | |||
) | |||
config_parser.add_argument( | |||
"--generate-config", | |||
action="store_true", | |||
help="Generate a config file for the server name" | |||
help="Generate a config file for the server name", | |||
) | |||
config_parser.add_argument( | |||
"--report-stats", | |||
action="store", | |||
help="Whether the generated config reports anonymized usage statistics", | |||
choices=["yes", "no"] | |||
choices=["yes", "no"], | |||
) | |||
config_parser.add_argument( | |||
"--generate-keys", | |||
action="store_true", | |||
help="Generate any missing key files then exit" | |||
help="Generate any missing key files then exit", | |||
) | |||
config_parser.add_argument( | |||
"--keys-directory", | |||
metavar="DIRECTORY", | |||
help="Used with 'generate-*' options to specify where files such as" | |||
" certs and signing keys should be stored in, unless explicitly" | |||
" specified in the config." | |||
" certs and signing keys should be stored in, unless explicitly" | |||
" specified in the config.", | |||
) | |||
config_parser.add_argument( | |||
"-H", "--server-name", | |||
help="The server name to generate a config file for" | |||
"-H", "--server-name", help="The server name to generate a config file for" | |||
) | |||
config_args, remaining_args = config_parser.parse_known_args(argv) | |||
@@ -257,8 +248,8 @@ class Config(object): | |||
if config_args.generate_config: | |||
if config_args.report_stats is None: | |||
config_parser.error( | |||
"Please specify either --report-stats=yes or --report-stats=no\n\n" + | |||
MISSING_REPORT_STATS_SPIEL | |||
"Please specify either --report-stats=yes or --report-stats=no\n\n" | |||
+ MISSING_REPORT_STATS_SPIEL | |||
) | |||
if not config_files: | |||
config_parser.error( | |||
@@ -287,26 +278,32 @@ class Config(object): | |||
config_dir_path=config_dir_path, | |||
server_name=server_name, | |||
report_stats=(config_args.report_stats == "yes"), | |||
is_generating_file=True | |||
is_generating_file=True, | |||
) | |||
obj.invoke_all("generate_files", config) | |||
config_file.write(config_str) | |||
print(( | |||
"A config file has been generated in %r for server name" | |||
" %r with corresponding SSL keys and self-signed" | |||
" certificates. Please review this file and customise it" | |||
" to your needs." | |||
) % (config_path, server_name)) | |||
print( | |||
( | |||
"A config file has been generated in %r for server name" | |||
" %r with corresponding SSL keys and self-signed" | |||
" certificates. Please review this file and customise it" | |||
" to your needs." | |||
) | |||
% (config_path, server_name) | |||
) | |||
print( | |||
"If this server name is incorrect, you will need to" | |||
" regenerate the SSL certificates" | |||
) | |||
return | |||
else: | |||
print(( | |||
"Config file %r already exists. Generating any missing key" | |||
" files." | |||
) % (config_path,)) | |||
print( | |||
( | |||
"Config file %r already exists. Generating any missing key" | |||
" files." | |||
) | |||
% (config_path,) | |||
) | |||
generate_keys = True | |||
parser = argparse.ArgumentParser( | |||
@@ -338,8 +335,7 @@ class Config(object): | |||
return obj | |||
def read_config_files(self, config_files, keys_directory=None, | |||
generate_keys=False): | |||
def read_config_files(self, config_files, keys_directory=None, generate_keys=False): | |||
if not keys_directory: | |||
keys_directory = os.path.dirname(config_files[-1]) | |||
@@ -364,8 +360,9 @@ class Config(object): | |||
if "report_stats" not in config: | |||
raise ConfigError( | |||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + | |||
MISSING_REPORT_STATS_SPIEL | |||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS | |||
+ "\n" | |||
+ MISSING_REPORT_STATS_SPIEL | |||
) | |||
if generate_keys: | |||
@@ -399,16 +396,16 @@ def find_config_files(search_paths): | |||
for entry in os.listdir(config_path): | |||
entry_path = os.path.join(config_path, entry) | |||
if not os.path.isfile(entry_path): | |||
print ( | |||
"Found subdirectory in config directory: %r. IGNORING." | |||
) % (entry_path, ) | |||
err = "Found subdirectory in config directory: %r. IGNORING." | |||
print(err % (entry_path,)) | |||
continue | |||
if not entry.endswith(".yaml"): | |||
print ( | |||
"Found file in config directory that does not" | |||
" end in '.yaml': %r. IGNORING." | |||
) % (entry_path, ) | |||
err = ( | |||
"Found file in config directory that does not end in " | |||
"'.yaml': %r. IGNORING." | |||
) | |||
print(err % (entry_path,)) | |||
continue | |||
files.append(entry_path) | |||
@@ -18,7 +18,7 @@ import threading | |||
import time | |||
from six import PY2, iteritems, iterkeys, itervalues | |||
from six.moves import intern, range | |||
from six.moves import builtins, intern, range | |||
from canonicaljson import json | |||
from prometheus_client import Histogram | |||
@@ -1233,7 +1233,7 @@ def db_to_json(db_content): | |||
# psycopg2 on Python 2 returns buffer objects, which we need to cast to | |||
# bytes to decode | |||
if PY2 and isinstance(db_content, buffer): | |||
if PY2 and isinstance(db_content, builtins.buffer): | |||
db_content = bytes(db_content) | |||
# Decode it to a Unicode string before feeding it to json.loads, so we | |||
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) | |||
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, | |||
# despite being deprecated and removed in favor of memoryview | |||
if six.PY2: | |||
db_binary_type = buffer | |||
db_binary_type = six.moves.builtins.buffer | |||
else: | |||
db_binary_type = memoryview | |||
@@ -29,7 +29,7 @@ from ._base import SQLBaseStore | |||
logger = logging.getLogger(__name__) | |||
if six.PY2: | |||
db_binary_type = buffer | |||
db_binary_type = six.moves.builtins.buffer | |||
else: | |||
db_binary_type = memoryview | |||
@@ -27,7 +27,7 @@ from ._base import SQLBaseStore | |||
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, | |||
# despite being deprecated and removed in favor of memoryview | |||
if six.PY2: | |||
db_binary_type = buffer | |||
db_binary_type = six.moves.builtins.buffer | |||
else: | |||
db_binary_type = memoryview | |||
@@ -30,7 +30,7 @@ from ._base import SQLBaseStore, db_to_json | |||
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, | |||
# despite being deprecated and removed in favor of memoryview | |||
if six.PY2: | |||
db_binary_type = buffer | |||
db_binary_type = six.moves.builtins.buffer | |||
else: | |||
db_binary_type = memoryview | |||
@@ -15,6 +15,8 @@ | |||
import logging | |||
from six import integer_types | |||
from sortedcontainers import SortedDict | |||
from synapse.util import caches | |||
@@ -47,7 +49,7 @@ class StreamChangeCache(object): | |||
def has_entity_changed(self, entity, stream_pos): | |||
"""Returns True if the entity may have been updated since stream_pos | |||
""" | |||
assert type(stream_pos) is int or type(stream_pos) is long | |||
assert type(stream_pos) in integer_types | |||
if stream_pos < self._earliest_known_stream_pos: | |||
self.metrics.inc_misses() | |||
@@ -76,8 +76,7 @@ def start(configfile): | |||
try: | |||
subprocess.check_call(args) | |||
write("started synapse.app.homeserver(%r)" % | |||
(configfile,), colour=GREEN) | |||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN) | |||
except subprocess.CalledProcessError as e: | |||
write( | |||
"error starting (exit code: %d); see above for logs" % e.returncode, | |||
@@ -86,21 +85,15 @@ def start(configfile): | |||
def start_worker(app, configfile, worker_configfile): | |||
args = [ | |||
sys.executable, "-B", | |||
"-m", app, | |||
"-c", configfile, | |||
"-c", worker_configfile | |||
] | |||
args = [sys.executable, "-B", "-m", app, "-c", configfile, "-c", worker_configfile] | |||
try: | |||
subprocess.check_call(args) | |||
write("started %s(%r)" % (app, worker_configfile), colour=GREEN) | |||
except subprocess.CalledProcessError as e: | |||
write( | |||
"error starting %s(%r) (exit code: %d); see above for logs" % ( | |||
app, worker_configfile, e.returncode, | |||
), | |||
"error starting %s(%r) (exit code: %d); see above for logs" | |||
% (app, worker_configfile, e.returncode), | |||
colour=RED, | |||
) | |||
@@ -120,9 +113,9 @@ def stop(pidfile, app): | |||
abort("Cannot stop %s: Unknown error" % (app,)) | |||
Worker = collections.namedtuple("Worker", [ | |||
"app", "configfile", "pidfile", "cache_factor", "cache_factors", | |||
]) | |||
Worker = collections.namedtuple( | |||
"Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"] | |||
) | |||
def main(): | |||
@@ -141,24 +134,20 @@ def main(): | |||
help="the homeserver config file, defaults to homeserver.yaml", | |||
) | |||
parser.add_argument( | |||
"-w", "--worker", | |||
metavar="WORKERCONFIG", | |||
help="start or stop a single worker", | |||
"-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker" | |||
) | |||
parser.add_argument( | |||
"-a", "--all-processes", | |||
"-a", | |||
"--all-processes", | |||
metavar="WORKERCONFIGDIR", | |||
help="start or stop all the workers in the given directory" | |||
" and the main synapse process", | |||
" and the main synapse process", | |||
) | |||
options = parser.parse_args() | |||
if options.worker and options.all_processes: | |||
write( | |||
'Cannot use "--worker" with "--all-processes"', | |||
stream=sys.stderr | |||
) | |||
write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr) | |||
sys.exit(1) | |||
configfile = options.configfile | |||
@@ -167,9 +156,7 @@ def main(): | |||
write( | |||
"No config file found\n" | |||
"To generate a config file, run '%s -c %s --generate-config" | |||
" --server-name=<server name>'\n" % ( | |||
" ".join(SYNAPSE), options.configfile | |||
), | |||
" --server-name=<server name>'\n" % (" ".join(SYNAPSE), options.configfile), | |||
stream=sys.stderr, | |||
) | |||
sys.exit(1) | |||
@@ -194,8 +181,7 @@ def main(): | |||
worker_configfile = options.worker | |||
if not os.path.exists(worker_configfile): | |||
write( | |||
"No worker config found at %r" % (worker_configfile,), | |||
stream=sys.stderr, | |||
"No worker config found at %r" % (worker_configfile,), stream=sys.stderr | |||
) | |||
sys.exit(1) | |||
worker_configfiles.append(worker_configfile) | |||
@@ -211,9 +197,9 @@ def main(): | |||
stream=sys.stderr, | |||
) | |||
sys.exit(1) | |||
worker_configfiles.extend(sorted(glob.glob( | |||
os.path.join(worker_configdir, "*.yaml") | |||
))) | |||
worker_configfiles.extend( | |||
sorted(glob.glob(os.path.join(worker_configdir, "*.yaml"))) | |||
) | |||
workers = [] | |||
for worker_configfile in worker_configfiles: | |||
@@ -223,14 +209,12 @@ def main(): | |||
if worker_app == "synapse.app.homeserver": | |||
# We need to special case all of this to pick up options that may | |||
# be set in the main config file or in this worker config file. | |||
worker_pidfile = ( | |||
worker_config.get("pid_file") | |||
or pidfile | |||
worker_pidfile = worker_config.get("pid_file") or pidfile | |||
worker_cache_factor = ( | |||
worker_config.get("synctl_cache_factor") or cache_factor | |||
) | |||
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor | |||
worker_cache_factors = ( | |||
worker_config.get("synctl_cache_factors") | |||
or cache_factors | |||
worker_config.get("synctl_cache_factors") or cache_factors | |||
) | |||
daemonize = worker_config.get("daemonize") or config.get("daemonize") | |||
assert daemonize, "Main process must have daemonize set to true" | |||
@@ -239,19 +223,27 @@ def main(): | |||
for key in worker_config: | |||
if key == "worker_app": # But we allow worker_app | |||
continue | |||
assert not key.startswith("worker_"), \ | |||
"Main process cannot use worker_* config" | |||
assert not key.startswith( | |||
"worker_" | |||
), "Main process cannot use worker_* config" | |||
else: | |||
worker_pidfile = worker_config["worker_pid_file"] | |||
worker_daemonize = worker_config["worker_daemonize"] | |||
assert worker_daemonize, "In config %r: expected '%s' to be True" % ( | |||
worker_configfile, "worker_daemonize") | |||
worker_configfile, | |||
"worker_daemonize", | |||
) | |||
worker_cache_factor = worker_config.get("synctl_cache_factor") | |||
worker_cache_factors = worker_config.get("synctl_cache_factors", {}) | |||
workers.append(Worker( | |||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor, | |||
worker_cache_factors, | |||
)) | |||
workers.append( | |||
Worker( | |||
worker_app, | |||
worker_configfile, | |||
worker_pidfile, | |||
worker_cache_factor, | |||
worker_cache_factors, | |||
) | |||
) | |||
action = options.action | |||
@@ -108,10 +108,10 @@ commands = | |||
[testenv:pep8] | |||
skip_install = True | |||
basepython = python2.7 | |||
basepython = python3.6 | |||
deps = | |||
flake8 | |||
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}" | |||
commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" | |||
[testenv:check_isort] | |||
skip_install = True | |||