diff --git a/weegee/__init__.py b/weegee/__init__.py index 2caacd1..a55aebe 100644 --- a/weegee/__init__.py +++ b/weegee/__init__.py @@ -14,6 +14,7 @@ from .core import ( do_discover_interface, do_config_interface, config_interface, do_sync_interface, sync_interface, sync_all_interfaces, ) +from .remote import WeegeeRemoteServer, REMOTE_COMMAND_SETS from .extra import WeegeeServer, WeegeeClient logger = getLogger(__name__) diff --git a/weegee/__main__.py b/weegee/__main__.py index 3564008..29fa06b 100644 --- a/weegee/__main__.py +++ b/weegee/__main__.py @@ -12,9 +12,11 @@ logging.basicConfig(level=logging.DEBUG) from .dazy import Item, export_items, import_items from .wireguard import WireguardHostType, WireguardConfigFormat from . import ( + REMOTE_COMMAND_SETS, WeegeeContext, WeegeeConfig, WeegeeHook, WeegeeHost, WeegeePublicInterface, WeegeeInterface, WeegeePeer, WeegeeConnection, WeegeeServer, WeegeeClient, + WeegeeRemoteServer, find_interface_other_peers, config_interface, sync_all_interfaces, setup, @@ -94,12 +96,12 @@ def main(): transfer = wg_peer.get_total_transfer() rx = siify(transfer[0]) if transfer else 'nothing' tx = siify(transfer[1]) if transfer else 'nothing' - handshake = wg_peer.last_handshake() - last_handshake = timestampify(handshake) if handshake else 'never' - endpoint = wg_peer.last_endpoint() - last_endpoint = f'{endpoint[0]}:{endpoint[1]}' if endpoint else '' + handshake = wg_peer.latest_handshake() + latest_handshake = timestampify(handshake) if handshake else 'never' + endpoint = wg_peer.latest_endpoint() + latest_endpoint = f'{endpoint[0]}:{endpoint[1]}' if endpoint else '' print(f' {peer.name} ({", ".join(str(x) for x in peer.interface.addresses)}, {wg_peer.name})') - print(f' last handshake {last_handshake}{f" from {last_endpoint}" if last_endpoint else ""}, {tx} sent, {rx} received') + print(f' latest handshake {latest_handshake}{f" from {latest_endpoint}" if latest_endpoint else ""}, {tx} sent, {rx} received') for wg_peer in wg_peers.values(): print(f' ({wg_peer.name}): found but not expected') success = False @@ -133,6 +135,41 @@ def main(): import_cmd.add_argument('file', type=argparse.FileType('r', encoding='utf-8'), nargs='?', default=sys.stdin) import_cmd.set_defaults(func=do_import, parser=import_cmd) + def do_remote(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> None: + if not args.host: + config = ctx.get_config() + args.host = config.default_remote_host or WeegeeHost.LOCAL_HOST_NAME + + commands = set() + for c in args.command: + if c.startswith('@'): + c = c[1:] + if c not in REMOTE_COMMAND_SETS: + parser.error(f'unknown command set: @{c}') + commands.update(REMOTE_COMMAND_SETS[c]) + else: + commands.add(c) + + host = WeegeeHost.load(ctx, args.host) + interfaces = set(args.interface) + addresses = set(args.address) + routes = set(args.route) + server = WeegeeRemoteServer(host.conn.conn, + allowed_commands=commands or None, + allowed_interfaces=interfaces or None, + allowed_addresses=addresses or None, + allowed_routes=routes or None, + ) + server.serve(sys.stdin, sys.stdout) + + remote = commands.add_parser('remote', help='remote server') + remote.add_argument('-c', '--command', type=str, action='append', default=[], help='allowed command(s) or @-prefixed command set(s)') + remote.add_argument('-i', '--interface', type=str, action='append', default=[], help='allowed interface(s)') + remote.add_argument('-a', '--address', type=ipaddress.ip_interface, action='append', default=[], help='allowed interface address(es)') + remote.add_argument('-r', '--route', type=ipaddress.ip_network, action='append', default=[], help='allowed interface route(s)') + remote.add_argument('host', help='host', nargs='?', default=None) + remote.set_defaults(func=do_remote, parser=remote) + # System commands diff --git a/weegee/core.py b/weegee/core.py index cfa80f2..1a81739 100644 --- a/weegee/core.py +++ b/weegee/core.py @@ -228,9 +228,8 @@ class WeegeeHost(WeegeeBase): 'configure', ] - @property - def config_format(self) -> O[WireguardConfigFormat]: - return self.conn.conn.CONFIG_FORMAT + def get_config_format(self, name: str) -> O[WireguardConfigFormat]: + return self.conn.conn.get_config_format(name) def get_hooks(self, name: str, **kwargs) -> WeegeeHookRunner: return WeegeeHookRunner([WeegeeHook(self.context, x) for x in getattr(self, f'{name}_hooks')], self.conn.conn, kwargs) @@ -264,18 +263,12 @@ class WeegeeHookedWireguardConnection(WireguardConnection): host: WeegeeHost inner: WireguardConnection - def __post_init__(self) -> None: - self.CONFIG_FORMAT = self.inner.CONFIG_FORMAT - - def _run(self, *args, **kwargs) -> str: - return self.inner._run(*args, **kwargs) - @staticmethod def _get_family(obj: U[IPAddress, IPInterface, IPNetwork]) -> str: return f'ipv{obj.version}' - def has_interface(self, name: str) -> bool: - return self.inner.has_interface(name) + def __getattr__(self, name: str) -> Any: + return getattr(self.inner, name) def create_interface(self, name: str) -> None: with self.host.get_hooks('interface_add', i=name): @@ -285,18 +278,6 @@ class WeegeeHookedWireguardConnection(WireguardConnection): with self.host.get_hooks('interface_del', i=name): return self.inner.destroy_interface(name) - def set_mtu(self, name: str, mtu: int) -> None: - return self.inner.set_mtu(name, mtu) - - def set_up(self, name: str) -> None: - return self.inner.set_up(name) - - def set_down(self, name: str) -> None: - return self.inner.set_down(name) - - def get_addresses(self, name: str) -> List[IPInterface]: - return self.inner.get_addresses(name) - def add_address(self, name: str, address: IPInterface) -> None: family_hook = f'address_{self._get_family(address)}_add' with self.host.get_hooks('address_add', i=name, a=str(address)), self.host.get_hooks(family_hook, i=name, a=str(address)): @@ -311,9 +292,6 @@ class WeegeeHookedWireguardConnection(WireguardConnection): for address in self.get_addresses(name): self.delete_address(name, address) - def get_routes(self, name: str) -> List[IPNetwork]: - return self.inner.get_routes(name) - def add_route(self, name: str, route: IPNetwork) -> None: family_hook = f'route_{self._get_family(route)}_add' with self.host.get_hooks('route_add', i=name, r=str(route)), self.host.get_hooks(family_hook, i=name, r=str(route)): @@ -601,7 +579,7 @@ def do_sync_interface(interface: WeegeeInterface, peers: Set[WeegeePeer], connec for host in interface.hosts: if auto and not host.autosync: continue - config = do_config_interface(interface, host.config_format, peers, connections, host) + config = do_config_interface(interface, host.get_config_format(interface.interface_name), peers, connections, host) if not host.sync_interface(interface.interface_name, interface.mtu, interface.addresses, routes, config): raise ValueError(f'host {host.name} failed to sync interface!') return other_peers diff --git a/weegee/desc.py b/weegee/desc.py index 0187d8e..ce2324e 100644 --- a/weegee/desc.py +++ b/weegee/desc.py @@ -102,8 +102,9 @@ WEEGEE_CONFIG = WeegeeMeta( name='wg/config', version=1, spec=[ - f'default_server_hosts: [str] = []', - f'default_client_hosts: [str] = []', + 'default_server_hosts: [str] = []', + 'default_client_hosts: [str] = []', + 'default_remote_host: ?str = ', 'log_level: str = "info"', 'meta_path: str = "."', 'data_path: ?str = ', diff --git a/weegee/remote.py b/weegee/remote.py new file mode 100644 index 0000000..caa7055 --- /dev/null +++ b/weegee/remote.py @@ -0,0 +1,269 @@ +from dataclasses import dataclass, field +import json +import enum +import ipaddress +import inspect +from typing import Optional as O, Tuple, Dict, Set, Any + +from .wireguard import WireguardConnection, IPInterface, IPNetwork + + +class WeegeeRemoteErrorCode(enum.Enum): + InvalidCommand = -1 + InvalidParameter = -2 + NotFound = -3 + NotAllowed = -4 + Conflict = -5 + InternalError = -10 + +class WeegeeRemoteError(Exception): + code: WeegeeRemoteErrorCode + + def __init__(self, code: WeegeeRemoteErrorCode, message: str) -> None: + super().__init__(message) + self.code = code + + +REMOTE_PROTOCOL_VERSION = 0 +REMOTE_COMMAND_SETS = {} +REMOTE_COMMAND_SETS['status'] = {'key.convert', 'interface.query', 'interface.info', 'peer.query', 'peer.info'} +REMOTE_COMMAND_SETS['sync'] = REMOTE_COMMAND_SETS['status'] | {'key.create', 'interface.set', 'interface.configure'} +REMOTE_COMMAND_SETS['manage'] = REMOTE_COMMAND_SETS['sync'] | { + 'interface.create', 'interface.destroy', + 'address.query', 'address.create', 'address.destroy', + 'route.query', 'route.create', 'route.destroy', +} + + +@dataclass +class WeegeeRemoteServer: + conn: WireguardConnection + allowed_commands: O[Set[str]] = None + allowed_interfaces: O[Set[str]] = None + allowed_addresses: O[Set[IPInterface]] = None + allowed_routes: O[Set[IPNetwork]] = None + handlers: Dict[str, Any] = field(init=False) + handler_sigs: Dict[str, inspect.Signature] = field(init=False) + + def __post_init__(self) -> None: + self.handlers = { + 'interface.query': self.handle_interface_query, + 'interface.create': self.handle_interface_create, + 'interface.destroy': self.handle_interface_destroy, + 'interface.set': self.handle_interface_set, + 'interface.info': self.handle_interface_info, + 'interface.configure': self.handle_interface_configure, + 'address.query': self.handle_address_query, + 'address.create': self.handle_address_create, + 'address.destroy': self.handle_address_destroy, + 'route.query': self.handle_route_query, + 'route.create': self.handle_route_create, + 'route.destroy': self.handle_route_destroy, + 'peer.query': self.handle_peer_query, + 'peer.info': self.handle_peer_info, + 'key.create': self.handle_key_create, + 'key.convert': self.handle_key_convert, + } + self.handler_sigs = {name: inspect.signature(fn) for name, fn in self.handlers.items()} + + + def _check_interface(self, interface: str) -> None: + if self.allowed_interfaces is not None and interface not in self.allowed_interfaces: + raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control interface: {interface}') + + def _check_existing_interface(self, interface: str) -> None: + self._check_interface(interface) + if not self.conn.has_interface(interface): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'interface does not exist: {interface}') + + def _check_address(self, address: IPInterface) -> None: + if self.allowed_addresse is not None and not any(address in a for a in self.allowed_addresses): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control address: {address}') + + def _check_route(self, route: IPNetwork) -> None: + if self.allowed_routes is not None and not any(route in r for r in self.allowed_routes): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control route: {route}') + + + def handle_interface_query(self, interface: O[str] = None) -> Dict[str, Any]: + interfaces = self.conn.get_interfaces() + return {'interfaces': [i for i in (interfaces or []) if interface in (i, None)]} + + def handle_interface_create(self, interface: str) -> Dict[str, Any]: + self._check_interface(interface) + if self.conn.has_interface(interface): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'interface exists: {interface}') + self.conn.create_interface(interface) + return {} + + def handle_interface_destroy(self, interface: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + self.conn.destroy_interface(interface) + return {} + + def handle_interface_set(self, interface: str, up: O[bool] = None, mtu: O[int] = None) -> Dict[str, Any]: + self._check_existing_interface(interface) + if up is not None: + if up: + self.conn.set_up(interface) + else: + self.conn.set_down(interface) + if mtu is not None: + self.conn.set_mtu(interface, mtu) + return {} + + def handle_interface_info(self, interface: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + return {'config': self.conn.get_config(interface), 'config_format': self.conn.get_config_format(interface).value} + + def handle_interface_configure(self, interface: str, config: O[str] = None, sync: O[bool] = True) -> Dict[str, Any]: + self._check_existing_interface(interface) + if config is not None: + self.conn.set_config(name, config, sync) + return {} + + def handle_address_query(self, interface: str, address: O[str] = None) -> Dict[str, Any]: + self._check_existing_interface(interface) + if address: + address = ipaddress.ip_interface(address) + return {'addresses': [str(x) for x in self.conn.get_addresses(interface) if address in (None, x)]} + + def handle_address_create(self, interface: str, address: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + address = ipaddress.ip_interface(address) + self._check_address(address) + if address in self.conn.get_addresses(interface): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'address exists on interface: {address}') + self.conn.add_address(interface, address) + return {} + + def handle_address_destroy(self, interface: str, address: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + address = ipaddress.ip_interface(address) + self._check_address(address) + if address not in self.conn.get_addresses(interface): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'address does not exist on interface: {address}') + self.conn.delete_address(interface, address) + return {} + + def handle_route_query(self, interface: str, route: O[str] = None) -> Dict[str, Any]: + self._check_existing_interface(interface) + if route: + route = ipaddress.ip_network(route) + return {'routes': [str(x) for x in self.conn.get_routes(interface) if route in (None, x)]} + + def handle_route_create(self, interface: str, route: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + route = ipaddress.ip_network(route) + self._check_route(route) + if route in self.conn.get_routes(interface): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'route exists on interface: {route}') + self.conn.add_route(interface, route) + return {} + + def handle_route_destroy(self, interface: str, route: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + route = ipaddress.ip_network(route) + self._check_route(route) + if route not in self.conn.get_addresses(interface): + raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'route does not exist on interface: {route}') + self.conn.delete_route(interface, route) + return {} + + def handle_peer_query(self, interface: str, peer: O[str] = None) -> Dict[str, Any]: + self._check_existing_interface(interface) + return {'peers': [x for x in self.conn.get_peers(interface) if peer in (None, x)]} + + def handle_peer_info(self, interface: str, peer: str) -> Dict[str, Any]: + self._check_existing_interface(interface) + info = self.conn.get_peer_info(interface, peer) + if not info: + return None + return { + 'latest_handshake': int(info.latest_handshake.timestamp()), + 'endpoints': [f'{ip}:{port}' for (ip, port) in info.endpoints], + 'allowed_ips': [str(x) for x in info.allowed_ips], + 'sent_bytes': info.sent_bytes, + 'recv_bytes': info.recv_bytes, + } + + def handle_key_create(self, type: str) -> Dict[str, Any]: + if type == 'preshared': + return {'key': self.conn.gen_preshared_key()} + if type == 'private': + return {'key': self.conn.gen_private_key()} + raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}') + + def handle_key_convert(self, type: str, value: str) -> Dict[str, Any]: + if type == 'public': + return {'key': self.conn.get_public_key(value)} + raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}') + + + def _parse_req(self, data: str) -> O[Tuple[str, Dict[str, Any]]]: + d = json.loads(data) + if d['type'] != 'request': + return None + return d['command'], d['parameters'] + + def _make_banner(self) -> str: + return json.dumps({ + 'type': 'banner', + 'version': f'weegee/0.1', + 'protocol': REMOTE_PROTOCOL_VERSION, + }) + + def _make_resp(self, command: str, error: bool, /, **kwargs: Any) -> str: + d = { + 'type': 'response', + 'command': command, + } + if error: + d['ok'] = False + d['error'] = kwargs + else: + d['ok'] = True + d['values'] = kwargs + return json.dumps(d) + + def serve(self, infile, outfile) -> None: + outfile.write(self._make_banner() + '\n') + while True: + line = infile.readline() + if not line: + break + req = self._parse_req(line) + if not req: + break + command, params = req + + if command not in self.handlers: + data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'invalid command: {command}'} + error = True + elif self.allowed_commands is not None and command not in self.allowed_commands: + data = {'code': WeegeeRemoteErrorCode.NotAllowed.value, 'msg': f'not allowed to request command: {command}'} + error = True + else: + signature = self.handler_sigs[command] + try: + signature.bind(**params) + except TypeError as e: + data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': str(e)} + error = True + else: + try: + data = self.handlers[command](**params) + error = False + except WeegeeRemoteError as e: + data = {'code': e.code.value, 'msg': str(e)} + error = True + except NotImplementedError as e: + data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'unimplemented command: {command}'} + error = True + except (ValueError, ipaddress.AddressValueError, ipaddress.NetmaskValueError): + data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': f'invalid parameter: {e}'} + error = True + except Exception as e: + data = {'code': WeegeeRemoteErrorCode.InternalError.value, 'msg': str(e)} + error = True + outfile.write(self._make_resp(command, error, **data) + '\n') diff --git a/weegee/wireguard.py b/weegee/wireguard.py index d4194a2..2122dc2 100644 --- a/weegee/wireguard.py +++ b/weegee/wireguard.py @@ -4,6 +4,7 @@ from datetime import datetime import shlex import ipaddress import enum +import json from typing import Optional as O, Union as U, Type, Tuple, List import subprocess from logging import getLogger @@ -16,6 +17,14 @@ IPNetwork = U[ipaddress.IPv4Network, ipaddress.IPv6Network] IPInterface = U[ipaddress.IPv4Interface, ipaddress.IPv6Interface] +@dataclass +class WireguardPeerInfo: + latest_handshake: O[datetime] + endpoints: List[Tuple[IPAddress, int]] + allowed_ips: List[IPNetwork] + sent_bytes: int + recv_bytes: int + class WireguardConfigFormat(enum.Enum): WG = 'wg' WGQuick = 'wg-quick' @@ -23,13 +32,8 @@ class WireguardConfigFormat(enum.Enum): class WireguardConnection: - CONFIG_FORMAT: O[WireguardConfigFormat] = None - - def _run(self, *args, shell=False, stdin=None, **kwargs) -> str: - raise NotImplementedError - - def _run_wg(self, *args, **kwargs) -> str: - return self._run('wg', *args, **kwargs) + def get_interfaces(self) -> O[List[str]]: + return None def has_interface(self, name: str) -> bool: raise NotImplementedError @@ -40,6 +44,9 @@ class WireguardConnection: def destroy_interface(self, name: str) -> None: raise NotImplementedError + def get_config_format(self, name: str) -> WireguardConfigFormat: + return None + def set_mtu(self, name: str, mtu: int) -> None: raise NotImplementedError @@ -50,7 +57,7 @@ class WireguardConnection: raise NotImplementedError def get_addresses(self, name: str) -> List[IPInterface]: - raise NotADirectoryError + raise NotImplementedError def add_address(self, name: str, address: IPInterface) -> None: raise NotImplementedError @@ -75,6 +82,29 @@ class WireguardConnection: for route in self.get_routes(name): self.delete_route(name, route) + def gen_preshared_key(self) -> str: + raise NotImplementedError + + def gen_private_key(self) -> str: + raise NotImplementedError + + def get_public_key(self, private_key: str) -> str: + raise NotImplementedError + + def get_peers(self, name: str) -> List[str]: + raise NotImplementedError + + def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]: + raise NotImplementedError + + def get_config(self, name: str) -> str: + raise NotImplementedError + + def set_config(self, name: str, config: str, sync: bool = True) -> None: + raise NotImplementedError + + def clear_config(self, name: str) -> None: + return self.set_config(name, '', sync=False) @dataclass class WireguardConnectionBase(WireguardConnection): @@ -82,7 +112,7 @@ class WireguardConnectionBase(WireguardConnection): user: O[str] elevate_user: O[str] - def _run(self, *args, shell=False, stdin=None, **kwargs) -> str: + def _open(self, *args, shell=False, **kwargs) -> str: cmd = [] logger.debug(f'#{"[" + self.host + "]" if self.host else ""} {args[0] if shell else shlex.join(args)}') if self.host: @@ -97,20 +127,27 @@ class WireguardConnectionBase(WireguardConnection): else: cmd += list(args) p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='ascii', **kwargs) + + def _run(self, *args, stdin=None, **kwargs) -> str: + p = self._open(*args, **kwargs) stdout, stderr = p.communicate(stdin) if p.returncode: raise subprocess.CalledProcessError(p.returncode, cmd, output=stdout, stderr=stderr) return stdout.strip() -class WireguardLinuxConnection(WireguardConnectionBase): - CONFIG_FORMAT = WireguardConfigFormat.WG + def _run_wg(self, *args, **kwargs) -> str: + return self._run('wg', *args, **kwargs) +class WireguardLinuxConnection(WireguardConnectionBase): def _run_ip(self, *args) -> str: return self._run('ip', *args) def _run_ip46(self, *args) -> str: return (self._run_ip('-4', *args) + '\n' + self._run_ip('-6', *args)).strip() + def get_interfaces(self) -> List[str]: + return self._run_wg('show', 'interfaces').split() + def has_interface(self, name: str) -> bool: try: self._run_ip('link', 'show', name) @@ -118,6 +155,9 @@ class WireguardLinuxConnection(WireguardConnectionBase): except subprocess.CalledProcessError: return False + def get_config_format(self, name: str) -> WireguardConfigFormat: + return WireguardConfigFormat.WG + def create_interface(self, name: str) -> None: self._run_ip('link', 'add', name, 'type', 'wireguard') @@ -163,42 +203,218 @@ class WireguardLinuxConnection(WireguardConnectionBase): def delete_route(self, name: str, route: IPNetwork) -> None: self._run_ip('route', 'del', str(route), 'dev', name) + def get_peers(self, name: str) -> List[str]: + return self._run_wg('show', name, 'peers').split() + + def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]: + for line in self.conn._run_wg('show', name, 'dump').splitlines(): + line_peer, *parts = line.split() + if line_peer != peer: + continue + + psk, endpoints_str, allowed_ips_str, latest_handshake, rx, tx, keepalive, *_ = parts + if endpoints_str == '(none)': + endpoints = [] + else: + endpoints = [x.split(':', maxsplit=1) for x in endpoints_str.split(',')] + if allowed_ips_str == '(none)': + allowed_ips = [] + else: + allowed_ips = allowed_ips_str.split(',') + latest_handshake_ts = int(latest_handshake) + return WireguardPeerInfo( + datetime.fromtimestamp(latest_handshake_ts) if latest_handshake_ts > 0 else None, + [(ipaddress.ip_address(ip.strip('[]')), port) for (ip, port) in endpoints], + [ipaddress.ip_network(x) for x in allowed_ips], + int(tx), + int(rx), + ) + + return None + + def get_config(self, name: str) -> str: + return self._run_wg('getconf', name) + + def set_config(self, name: str, config: str, sync: bool = True) -> None: + if sync: + self._run_wg('syncconf', name, '/dev/stdin', stdin=config) + else: + self._run_wg('setconf', name, '/dev/stdin', stdin=config) + + def gen_preshared_key(self) -> str: + return self._run_wg('genpsk') + + def gen_private_key(self) -> str: + return self._run_wg('genkey') + + def get_public_key(self, private_key: str) -> str: + return self._run_wg('pubkey', stdin=private_key) + +class WireguardRemoteConnection(WireguardConnectionBase): + process: subprocess.Popen + PROTOCOL_VERSION = 0 + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.process = None + + def _make_req(self, name: str, /, **kwargs) -> str: + return json.dumps({ + 'type': 'request', + 'command': name, + 'parameters': kwargs, + }) + + def _parse_banner(self, data: str) -> O[str]: + d = json.loads(data) + if d['type'] != 'banner': + return None + return d['version'], d['protocol'] + + def _parse_resp(self, name: str, data: str) -> O[Dict[str, Any]]: + try: + d = json.loads(data) + except json.JSONDecodeError as e: + raise subprocess.CalledProcessError(255, '', stderr=repr(e)) + if d['type'] != 'response' or d['command'] != command: + return None + if not d['ok']: + raise subprocess.CalledProcessError(d['error']['code'], '', stderr=d['error']['msg']) + return d['values'] + + def _run(self, name: str, /, **kwargs) -> Dict[str, Any]: + first = True + while i < 10: + if not first and self.process: + self.process.stdin.close() + self.process.terminate() + self.process = None + first = False + self.process = self._open('weegee', 'remote', encoding='utf-8', timeout=5) + line = p.stdout.readline() + if not line: + continue + banner, protocol = self._parse_banner(line) + if banner is None: + continue + if protocol != self.PROTOCOL_VERSION: + raise subprocess.CalledProcessError(255, '', stderr=f'protocol version mismatch: expected {self.PROTOCOL_VERSION}, got {protocol}') + logger.debug('connected to remote %r', banner) + p.stdin.write(self._make_req(name, **kwargs) + '\n') + line = p.stdout.readline() + if not line: + continue + resp = self._parse_resp(name, line) + if resp is None: + continue + break + return resp + + def get_interfaces(self) -> O[List[str]]: + resp = self._run('interface.query') + return resp['interfaces'] + + def has_interface(self, name: str) -> bool: + resp = self._run('interface.query', interface=name) + return bool(resp['interfaces']) + + def get_config_format(self, name: str) -> WireguardConfigFormat: + resp = self._run('interface.info', interface=name) + return WireguardConfigFormat(resp['config_format']) + + def create_interface(self, name: str) -> None: + self._run('interface.create', interface=name) + + def destroy_interface(self, name: str) -> None: + self._run('interface.destroy', interface=name) + + def set_mtu(self, name: str, mtu: int) -> None: + self._run('interface.set', interface=name, mtu=mtu) + + def set_up(self, name: str) -> None: + self._run('interface.set', interface=name, up=True) + + def set_down(self, name: str) -> None: + self._run('interface.set', interface=name, up=False) + + def get_addresses(self, name: str) -> List[IPInterface]: + resp = self._run('address.query', interface=name) + return [ipaddress.ip_interface(x) for x in resp['addresses']] + + def add_address(self, name: str, address: IPInterface) -> None: + self._run('address.create', interface=name, address=str(address)) + + def delete_address(self, name: str, address: IPInterface) -> None: + self._run('address.destroy', interface=name, address=str(address)) + + def get_routes(self, name: str) -> List[IPNetwork]: + resp = self._run('route.query', interface=name) + return [ipaddress.ip_network(x) for x in resp['routes']] + + def add_route(self, name: str, route: IPNetwork) -> None: + self._run('route.create', interface=name, route=str(route)) + + def delete_route(self, name: str, route: IPNetwork) -> None: + self._run('route.destroy', interface=name, route=str(route)) + + def get_peers(self, name: str) -> List[str]: + resp = self._run('peer.query', interface=name) + return resp['peers'] + + def get_peer_info(self, name: str, peer: str) -> O[WireguardPeerInfo]: + resp = self._run('peer.info', interface=name, peer=peer) + if not resp: + return None + return WireguardPeerInfo( + datetime.fromtimestamp(resp['latest_handshake']) if resp['latest_handshake'] > 0 else None, + [(ipaddress.ip_address(x.split(':', 1)[0].strip('[]')), int(x.split(':', 1)[1])) for x in resp['endpoints']], + [ipaddress.ip_network(x) for x in resp['allowed_ips']], + resp['sent_bytes'], + resp['recv_bytes'], + ) + + def get_config(self, name: str) -> str: + resp = self._run('interface.info', interface=name) + return resp['config'] + + def set_config(self, name: str, config: str, sync: bool = True) -> None: + self._run('interface.configure', interface=name, config=config, sync=sync) + + def gen_preshared_key(self) -> str: + resp = self._run('key.create', type='preshared') + return resp['key'] + + def gen_private_key(self) -> str: + resp = self._run('key.create', type='private') + return resp['key'] + + def get_public_key(self, private_key: str) -> str: + resp = self._run('key.convert', type='public', value=private_key) + return resp['key'] + @dataclass class WireguardPeer: conn: WireguardConnection interface: str name: str - def _filter_list(self, cmd: str) -> O[List[str]]: - for line in self.conn._run_wg('show', self.interface, cmd).splitlines(): - peer, *parts = line.split() - if peer == self.name: - return parts - return None + def latest_handshake(self) -> O[datetime]: + info = self.conn.get_peer_info(self.interface, self.name) + if not info: + return None + return info.latest_handshake - def last_handshake(self) -> O[datetime]: - parts = self._filter_list('latest-handshakes') - if not parts: + def latest_endpoint(self) -> O[Tuple[IPAddress, int]]: + info = self.conn.get_peer_info(self.interface, self.name) + if not info or not info.endpoints: return None - ts = int(parts[0]) - if not ts: - return None - return datetime.fromtimestamp(ts) - - def last_endpoint(self) -> O[Tuple[U[ipaddress.IPv4Address, ipaddress.IPv6Address], int]]: - parts = self._filter_list('endpoints') - if not parts: - return None - if parts[0] == '(none)': - return None - addr, port = parts[0].rsplit(':', maxsplit=1) - return (ipaddress.ip_address(addr.strip('[]')), int(port)) + return info.endpoints[-1] def get_total_transfer(self) -> O[Tuple[int, int]]: - parts = self._filter_list('transfer') - if not parts: + info = self.conn.get_peer_info(self.interface, self.name) + if not info: return None - return (int(parts[0]), int(parts[1])) + return (info.recv_bytes, info.sent_bytes) @dataclass class WireguardInterface: @@ -206,10 +422,10 @@ class WireguardInterface: name: str def list_peers(self) -> List[WireguardPeer]: - return [WireguardPeer(self.conn, self.name, x) for x in self.conn._run_wg('show', self.name, 'peers').split()] + return [WireguardPeer(self.conn, self.name, x) for x in self.conn.get_peers(self.name)] def get_config(self) -> str: - return self.conn._run_wg('getconf', self.name) + return self.conn.get_config(self.name) def sync(self, addresses: List[IPInterface], routes: List[IPNetwork], mtu: O[int] = None) -> None: routes += [addr.network for addr in addresses] @@ -234,13 +450,13 @@ class WireguardInterface: self.conn.add_route(self.name, wanted_routes[route]) def set_config(self, config: str) -> None: - self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin=config) - + return self.conn.set_config(self.name, config, sync=False) + def sync_config(self, config: str) -> None: - self.conn._run_wg('syncconf', self.name, '/dev/stdin', stdin=config) + return self.conn.set_config(self.name, config, sync=True) def clear_config(self) -> None: - self.conn._run_wg('setconf', self.name, '/dev/stdin', stdin='') + return self.conn.clear_config(self.name) def delete(self) -> None: self.conn.delete_all_routes(self.name) @@ -253,7 +469,7 @@ class WireguardHost: conn: WireguardConnection def list_interfaces(self) -> List[str]: - return self.conn._run_wg('show', 'interfaces').split() + return self.conn.get_interfaces() def get_interface(self, name: str) -> O[WireguardInterface]: for interface in self.list_interfaces(): @@ -267,20 +483,22 @@ class WireguardHost: return self.get_interface(name) def gen_preshared_key(self) -> str: - return self.conn._run_wg('genpsk') + return self.conn.gen_preshared_key() def gen_private_key(self) -> str: - return self.conn._run_wg('genkey') + return self.conn.gen_private_key() def get_public_key(self, privkey: str) -> str: - return self.conn._run_wg('pubkey', stdin=privkey) + return self.conn.get_public_key(privkey) class WireguardHostType(enum.Enum): Unsupported = None Linux = 'linux' + Remote = 'remote' def which_connection(type: WireguardHostType) -> Type[WireguardConnectionBase]: return { WireguardHostType.Linux: WireguardLinuxConnection, + WireguardHostType.Remote: WireguardRemoteConnection, }.get(type, WireguardConnectionBase)