parent
aa4c2eb60b
commit
7ee35a990e
6 changed files with 583 additions and 79 deletions
@ -0,0 +1,269 @@ |
||||
from dataclasses import dataclass, field |
||||
import json |
||||
import enum |
||||
import ipaddress |
||||
import inspect |
||||
from typing import Optional as O, Tuple, Dict, Set, Any |
||||
|
||||
from .wireguard import WireguardConnection, IPInterface, IPNetwork |
||||
|
||||
|
||||
class WeegeeRemoteErrorCode(enum.Enum): |
||||
InvalidCommand = -1 |
||||
InvalidParameter = -2 |
||||
NotFound = -3 |
||||
NotAllowed = -4 |
||||
Conflict = -5 |
||||
InternalError = -10 |
||||
|
||||
class WeegeeRemoteError(Exception): |
||||
code: WeegeeRemoteErrorCode |
||||
|
||||
def __init__(self, code: WeegeeRemoteErrorCode, message: str) -> None: |
||||
super().__init__(message) |
||||
self.code = code |
||||
|
||||
|
||||
REMOTE_PROTOCOL_VERSION = 0 |
||||
REMOTE_COMMAND_SETS = {} |
||||
REMOTE_COMMAND_SETS['status'] = {'key.convert', 'interface.query', 'interface.info', 'peer.query', 'peer.info'} |
||||
REMOTE_COMMAND_SETS['sync'] = REMOTE_COMMAND_SETS['status'] | {'key.create', 'interface.set', 'interface.configure'} |
||||
REMOTE_COMMAND_SETS['manage'] = REMOTE_COMMAND_SETS['sync'] | { |
||||
'interface.create', 'interface.destroy', |
||||
'address.query', 'address.create', 'address.destroy', |
||||
'route.query', 'route.create', 'route.destroy', |
||||
} |
||||
|
||||
|
||||
@dataclass |
||||
class WeegeeRemoteServer: |
||||
conn: WireguardConnection |
||||
allowed_commands: O[Set[str]] = None |
||||
allowed_interfaces: O[Set[str]] = None |
||||
allowed_addresses: O[Set[IPInterface]] = None |
||||
allowed_routes: O[Set[IPNetwork]] = None |
||||
handlers: Dict[str, Any] = field(init=False) |
||||
handler_sigs: Dict[str, inspect.Signature] = field(init=False) |
||||
|
||||
def __post_init__(self) -> None: |
||||
self.handlers = { |
||||
'interface.query': self.handle_interface_query, |
||||
'interface.create': self.handle_interface_create, |
||||
'interface.destroy': self.handle_interface_destroy, |
||||
'interface.set': self.handle_interface_set, |
||||
'interface.info': self.handle_interface_info, |
||||
'interface.configure': self.handle_interface_configure, |
||||
'address.query': self.handle_address_query, |
||||
'address.create': self.handle_address_create, |
||||
'address.destroy': self.handle_address_destroy, |
||||
'route.query': self.handle_route_query, |
||||
'route.create': self.handle_route_create, |
||||
'route.destroy': self.handle_route_destroy, |
||||
'peer.query': self.handle_peer_query, |
||||
'peer.info': self.handle_peer_info, |
||||
'key.create': self.handle_key_create, |
||||
'key.convert': self.handle_key_convert, |
||||
} |
||||
self.handler_sigs = {name: inspect.signature(fn) for name, fn in self.handlers.items()} |
||||
|
||||
|
||||
def _check_interface(self, interface: str) -> None: |
||||
if self.allowed_interfaces is not None and interface not in self.allowed_interfaces: |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control interface: {interface}') |
||||
|
||||
def _check_existing_interface(self, interface: str) -> None: |
||||
self._check_interface(interface) |
||||
if not self.conn.has_interface(interface): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'interface does not exist: {interface}') |
||||
|
||||
def _check_address(self, address: IPInterface) -> None: |
||||
if self.allowed_addresse is not None and not any(address in a for a in self.allowed_addresses): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control address: {address}') |
||||
|
||||
def _check_route(self, route: IPNetwork) -> None: |
||||
if self.allowed_routes is not None and not any(route in r for r in self.allowed_routes): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control route: {route}') |
||||
|
||||
|
||||
def handle_interface_query(self, interface: O[str] = None) -> Dict[str, Any]: |
||||
interfaces = self.conn.get_interfaces() |
||||
return {'interfaces': [i for i in (interfaces or []) if interface in (i, None)]} |
||||
|
||||
def handle_interface_create(self, interface: str) -> Dict[str, Any]: |
||||
self._check_interface(interface) |
||||
if self.conn.has_interface(interface): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'interface exists: {interface}') |
||||
self.conn.create_interface(interface) |
||||
return {} |
||||
|
||||
def handle_interface_destroy(self, interface: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
self.conn.destroy_interface(interface) |
||||
return {} |
||||
|
||||
def handle_interface_set(self, interface: str, up: O[bool] = None, mtu: O[int] = None) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
if up is not None: |
||||
if up: |
||||
self.conn.set_up(interface) |
||||
else: |
||||
self.conn.set_down(interface) |
||||
if mtu is not None: |
||||
self.conn.set_mtu(interface, mtu) |
||||
return {} |
||||
|
||||
def handle_interface_info(self, interface: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
return {'config': self.conn.get_config(interface), 'config_format': self.conn.get_config_format(interface).value} |
||||
|
||||
def handle_interface_configure(self, interface: str, config: O[str] = None, sync: O[bool] = True) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
if config is not None: |
||||
self.conn.set_config(name, config, sync) |
||||
return {} |
||||
|
||||
def handle_address_query(self, interface: str, address: O[str] = None) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
if address: |
||||
address = ipaddress.ip_interface(address) |
||||
return {'addresses': [str(x) for x in self.conn.get_addresses(interface) if address in (None, x)]} |
||||
|
||||
def handle_address_create(self, interface: str, address: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
address = ipaddress.ip_interface(address) |
||||
self._check_address(address) |
||||
if address in self.conn.get_addresses(interface): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'address exists on interface: {address}') |
||||
self.conn.add_address(interface, address) |
||||
return {} |
||||
|
||||
def handle_address_destroy(self, interface: str, address: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
address = ipaddress.ip_interface(address) |
||||
self._check_address(address) |
||||
if address not in self.conn.get_addresses(interface): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'address does not exist on interface: {address}') |
||||
self.conn.delete_address(interface, address) |
||||
return {} |
||||
|
||||
def handle_route_query(self, interface: str, route: O[str] = None) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
if route: |
||||
route = ipaddress.ip_network(route) |
||||
return {'routes': [str(x) for x in self.conn.get_routes(interface) if route in (None, x)]} |
||||
|
||||
def handle_route_create(self, interface: str, route: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
route = ipaddress.ip_network(route) |
||||
self._check_route(route) |
||||
if route in self.conn.get_routes(interface): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'route exists on interface: {route}') |
||||
self.conn.add_route(interface, route) |
||||
return {} |
||||
|
||||
def handle_route_destroy(self, interface: str, route: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
route = ipaddress.ip_network(route) |
||||
self._check_route(route) |
||||
if route not in self.conn.get_addresses(interface): |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'route does not exist on interface: {route}') |
||||
self.conn.delete_route(interface, route) |
||||
return {} |
||||
|
||||
def handle_peer_query(self, interface: str, peer: O[str] = None) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
return {'peers': [x for x in self.conn.get_peers(interface) if peer in (None, x)]} |
||||
|
||||
def handle_peer_info(self, interface: str, peer: str) -> Dict[str, Any]: |
||||
self._check_existing_interface(interface) |
||||
info = self.conn.get_peer_info(interface, peer) |
||||
if not info: |
||||
return None |
||||
return { |
||||
'latest_handshake': int(info.latest_handshake.timestamp()), |
||||
'endpoints': [f'{ip}:{port}' for (ip, port) in info.endpoints], |
||||
'allowed_ips': [str(x) for x in info.allowed_ips], |
||||
'sent_bytes': info.sent_bytes, |
||||
'recv_bytes': info.recv_bytes, |
||||
} |
||||
|
||||
def handle_key_create(self, type: str) -> Dict[str, Any]: |
||||
if type == 'preshared': |
||||
return {'key': self.conn.gen_preshared_key()} |
||||
if type == 'private': |
||||
return {'key': self.conn.gen_private_key()} |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}') |
||||
|
||||
def handle_key_convert(self, type: str, value: str) -> Dict[str, Any]: |
||||
if type == 'public': |
||||
return {'key': self.conn.get_public_key(value)} |
||||
raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}') |
||||
|
||||
|
||||
def _parse_req(self, data: str) -> O[Tuple[str, Dict[str, Any]]]: |
||||
d = json.loads(data) |
||||
if d['type'] != 'request': |
||||
return None |
||||
return d['command'], d['parameters'] |
||||
|
||||
def _make_banner(self) -> str: |
||||
return json.dumps({ |
||||
'type': 'banner', |
||||
'version': f'weegee/0.1', |
||||
'protocol': REMOTE_PROTOCOL_VERSION, |
||||
}) |
||||
|
||||
def _make_resp(self, command: str, error: bool, /, **kwargs: Any) -> str: |
||||
d = { |
||||
'type': 'response', |
||||
'command': command, |
||||
} |
||||
if error: |
||||
d['ok'] = False |
||||
d['error'] = kwargs |
||||
else: |
||||
d['ok'] = True |
||||
d['values'] = kwargs |
||||
return json.dumps(d) |
||||
|
||||
def serve(self, infile, outfile) -> None: |
||||
outfile.write(self._make_banner() + '\n') |
||||
while True: |
||||
line = infile.readline() |
||||
if not line: |
||||
break |
||||
req = self._parse_req(line) |
||||
if not req: |
||||
break |
||||
command, params = req |
||||
|
||||
if command not in self.handlers: |
||||
data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'invalid command: {command}'} |
||||
error = True |
||||
elif self.allowed_commands is not None and command not in self.allowed_commands: |
||||
data = {'code': WeegeeRemoteErrorCode.NotAllowed.value, 'msg': f'not allowed to request command: {command}'} |
||||
error = True |
||||
else: |
||||
signature = self.handler_sigs[command] |
||||
try: |
||||
signature.bind(**params) |
||||
except TypeError as e: |
||||
data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': str(e)} |
||||
error = True |
||||
else: |
||||
try: |
||||
data = self.handlers[command](**params) |
||||
error = False |
||||
except WeegeeRemoteError as e: |
||||
data = {'code': e.code.value, 'msg': str(e)} |
||||
error = True |
||||
except NotImplementedError as e: |
||||
data = {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'unimplemented command: {command}'} |
||||
error = True |
||||
except (ValueError, ipaddress.AddressValueError, ipaddress.NetmaskValueError): |
||||
data = {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': f'invalid parameter: {e}'} |
||||
error = True |
||||
except Exception as e: |
||||
data = {'code': WeegeeRemoteErrorCode.InternalError.value, 'msg': str(e)} |
||||
error = True |
||||
outfile.write(self._make_resp(command, error, **data) + '\n') |
Loading…
Reference in new issue