weegee/weegee/remote.py

270 lines
12 KiB
Python

from dataclasses import dataclass, field
import json
import enum
import ipaddress
import inspect
from typing import Optional as O, Tuple, Dict, Set, Any
from .wireguard import WireguardConnection, IPInterface, IPNetwork
class WeegeeRemoteErrorCode(enum.Enum):
    """Numeric error codes sent in the ``code`` field of a failed response."""

    # The requested command is unknown, or the backend does not implement it.
    InvalidCommand = -1
    # A parameter was missing, unexpected, or malformed.
    InvalidParameter = -2
    # The referenced interface, address, or route does not exist.
    NotFound = -3
    # The server's allow-lists forbid the command or its target.
    NotAllowed = -4
    # The object a client asked to create already exists.
    Conflict = -5
    # Any other unexpected failure while handling a request.
    InternalError = -10
class WeegeeRemoteError(Exception):
    """Failure raised by a request handler, tagged with a protocol error code.

    The server reports it as an error response whose ``code`` is
    ``self.code.value`` and whose ``msg`` is ``str(self)``.
    """

    # Protocol-level category of this failure.
    code: WeegeeRemoteErrorCode

    def __init__(self, code: WeegeeRemoteErrorCode, message: str) -> None:
        # Record the code first; Exception.__init__ only stores ``message`` in ``args``.
        self.code = code
        super().__init__(message)
REMOTE_PROTOCOL_VERSION = 0

# Named permission levels. Each level contains every command of the level
# before it plus the extra commands listed here.
REMOTE_COMMAND_SETS: Dict[str, Set[str]] = {}
# Read-only inspection.
REMOTE_COMMAND_SETS['status'] = {
    'key.convert',
    'interface.query', 'interface.info',
    'peer.query', 'peer.info',
}
# Inspection plus key generation and (re)configuring existing interfaces.
REMOTE_COMMAND_SETS['sync'] = {
    *REMOTE_COMMAND_SETS['status'],
    'key.create',
    'interface.set', 'interface.configure',
}
# Everything, including creating/destroying interfaces, addresses and routes.
REMOTE_COMMAND_SETS['manage'] = {
    *REMOTE_COMMAND_SETS['sync'],
    'interface.create', 'interface.destroy',
    'address.query', 'address.create', 'address.destroy',
    'route.query', 'route.create', 'route.destroy',
}
@dataclass
class WeegeeRemoteServer:
    """Line-delimited JSON command server for managing WireGuard interfaces.

    After a banner line, each input line is a request
    ``{"type": "request", "command": ..., "parameters": {...}}`` and each
    output line is the matching response. The optional ``allowed_*`` sets
    restrict what a client may touch; ``None`` means unrestricted.
    """

    # Backend used to inspect and modify interfaces, addresses, routes and peers.
    conn: WireguardConnection
    # Command names clients may run (presumably one of REMOTE_COMMAND_SETS); None allows all.
    allowed_commands: O[Set[str]] = None
    # Interface names clients may control; None allows all.
    allowed_interfaces: O[Set[str]] = None
    # Address ranges clients may add/remove; None allows all.
    allowed_addresses: O[Set[IPInterface]] = None
    # Route ranges clients may add/remove; None allows all.
    allowed_routes: O[Set[IPNetwork]] = None
    # Command name -> bound handler method; built in __post_init__.
    handlers: Dict[str, Any] = field(init=False)
    # Command name -> handler signature, used to validate parameters before dispatch.
    handler_sigs: Dict[str, inspect.Signature] = field(init=False)

    def __post_init__(self) -> None:
        """Build the command dispatch table and cache each handler's signature."""
        self.handlers = {
            'interface.query': self.handle_interface_query,
            'interface.create': self.handle_interface_create,
            'interface.destroy': self.handle_interface_destroy,
            'interface.set': self.handle_interface_set,
            'interface.info': self.handle_interface_info,
            'interface.configure': self.handle_interface_configure,
            'address.query': self.handle_address_query,
            'address.create': self.handle_address_create,
            'address.destroy': self.handle_address_destroy,
            'route.query': self.handle_route_query,
            'route.create': self.handle_route_create,
            'route.destroy': self.handle_route_destroy,
            'peer.query': self.handle_peer_query,
            'peer.info': self.handle_peer_info,
            'key.create': self.handle_key_create,
            'key.convert': self.handle_key_convert,
        }
        # Signatures of the bound methods (``self`` already excluded) let
        # _dispatch reject bad parameter sets before calling a handler.
        self.handler_sigs = {name: inspect.signature(fn) for name, fn in self.handlers.items()}

    # -- permission checks ---------------------------------------------------

    def _check_interface(self, interface: str) -> None:
        """Raise NotAllowed if *interface* is not in ``allowed_interfaces``."""
        if self.allowed_interfaces is not None and interface not in self.allowed_interfaces:
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control interface: {interface}')

    def _check_existing_interface(self, interface: str) -> None:
        """Like _check_interface, and additionally raise NotFound if it does not exist."""
        self._check_interface(interface)
        if not self.conn.has_interface(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'interface does not exist: {interface}')

    @staticmethod
    def _address_within(address: IPInterface, allowed: IPInterface) -> bool:
        """Return True if *address*'s network lies inside *allowed*'s network.

        ``ipaddress`` interface objects do not support ``in``, so the networks
        they belong to are compared instead. Mixed IPv4/IPv6 never match.
        """
        return address.version == allowed.version and address.network.subnet_of(allowed.network)

    @staticmethod
    def _network_within(route: IPNetwork, allowed: IPNetwork) -> bool:
        """Return True if *route* equals *allowed* or is a subnet of it.

        ``network in network`` is always False in ``ipaddress``; ``subnet_of``
        is the containment test. Mixed IPv4/IPv6 never match (``subnet_of``
        would raise TypeError).
        """
        return route.version == allowed.version and route.subnet_of(allowed)

    def _check_address(self, address: IPInterface) -> None:
        """Raise NotAllowed unless *address* lies within some entry of ``allowed_addresses``."""
        # NOTE(review): assumes allowed_addresses holds ipaddress interface objects,
        # matching what handlers pass in (ipaddress.ip_interface) — confirm.
        if self.allowed_addresses is not None and not any(
            self._address_within(address, a) for a in self.allowed_addresses
        ):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control address: {address}')

    def _check_route(self, route: IPNetwork) -> None:
        """Raise NotAllowed unless *route* lies within some entry of ``allowed_routes``."""
        if self.allowed_routes is not None and not any(
            self._network_within(route, r) for r in self.allowed_routes
        ):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control route: {route}')

    # -- interface handlers --------------------------------------------------

    def handle_interface_query(self, interface: O[str] = None) -> Dict[str, Any]:
        """List interfaces, or only *interface* if given and present."""
        # NOTE(review): this does not filter by allowed_interfaces — confirm intended.
        interfaces = self.conn.get_interfaces()
        # ``interface in (i, None)`` matches everything when no filter is given.
        return {'interfaces': [i for i in (interfaces or []) if interface in (i, None)]}

    def handle_interface_create(self, interface: str) -> Dict[str, Any]:
        """Create *interface*; raise Conflict if it already exists."""
        self._check_interface(interface)
        if self.conn.has_interface(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'interface exists: {interface}')
        self.conn.create_interface(interface)
        return {}

    def handle_interface_destroy(self, interface: str) -> Dict[str, Any]:
        """Destroy an existing *interface*."""
        self._check_existing_interface(interface)
        self.conn.destroy_interface(interface)
        return {}

    def handle_interface_set(self, interface: str, up: O[bool] = None, mtu: O[int] = None) -> Dict[str, Any]:
        """Bring *interface* up/down and/or set its MTU; None leaves a setting unchanged."""
        self._check_existing_interface(interface)
        if up is not None:
            if up:
                self.conn.set_up(interface)
            else:
                self.conn.set_down(interface)
        if mtu is not None:
            self.conn.set_mtu(interface, mtu)
        return {}

    def handle_interface_info(self, interface: str) -> Dict[str, Any]:
        """Return the interface's config text and the ``.value`` of its config format."""
        self._check_existing_interface(interface)
        return {'config': self.conn.get_config(interface), 'config_format': self.conn.get_config_format(interface).value}

    def handle_interface_configure(self, interface: str, config: O[str] = None, sync: O[bool] = True) -> Dict[str, Any]:
        """Pass *config* and *sync* to ``conn.set_config`` for *interface*; no-op if *config* is None."""
        self._check_existing_interface(interface)
        if config is not None:
            self.conn.set_config(interface, config, sync)
        return {}

    # -- address handlers ----------------------------------------------------

    def handle_address_query(self, interface: str, address: O[str] = None) -> Dict[str, Any]:
        """List addresses on *interface*, or only *address* if given and present.

        Raises ValueError (reported as InvalidParameter) if *address* does not parse.
        """
        self._check_existing_interface(interface)
        if address:
            address = ipaddress.ip_interface(address)
        return {'addresses': [str(x) for x in self.conn.get_addresses(interface) if address in (None, x)]}

    def handle_address_create(self, interface: str, address: str) -> Dict[str, Any]:
        """Add *address* to *interface*; raise Conflict if it is already present."""
        self._check_existing_interface(interface)
        address = ipaddress.ip_interface(address)
        self._check_address(address)
        if address in self.conn.get_addresses(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'address exists on interface: {address}')
        self.conn.add_address(interface, address)
        return {}

    def handle_address_destroy(self, interface: str, address: str) -> Dict[str, Any]:
        """Remove *address* from *interface*; raise NotFound if it is not present."""
        self._check_existing_interface(interface)
        address = ipaddress.ip_interface(address)
        self._check_address(address)
        if address not in self.conn.get_addresses(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'address does not exist on interface: {address}')
        self.conn.delete_address(interface, address)
        return {}

    # -- route handlers ------------------------------------------------------

    def handle_route_query(self, interface: str, route: O[str] = None) -> Dict[str, Any]:
        """List routes on *interface*, or only *route* if given and present."""
        self._check_existing_interface(interface)
        if route:
            route = ipaddress.ip_network(route)
        return {'routes': [str(x) for x in self.conn.get_routes(interface) if route in (None, x)]}

    def handle_route_create(self, interface: str, route: str) -> Dict[str, Any]:
        """Add *route* via *interface*; raise Conflict if it already exists."""
        self._check_existing_interface(interface)
        route = ipaddress.ip_network(route)
        self._check_route(route)
        if route in self.conn.get_routes(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'route exists on interface: {route}')
        self.conn.add_route(interface, route)
        return {}

    def handle_route_destroy(self, interface: str, route: str) -> Dict[str, Any]:
        """Remove *route* from *interface*; raise NotFound if it does not exist."""
        self._check_existing_interface(interface)
        route = ipaddress.ip_network(route)
        self._check_route(route)
        # Look in the route table (not the address list) for the route to delete.
        if route not in self.conn.get_routes(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'route does not exist on interface: {route}')
        self.conn.delete_route(interface, route)
        return {}

    # -- peer handlers -------------------------------------------------------

    def handle_peer_query(self, interface: str, peer: O[str] = None) -> Dict[str, Any]:
        """List peers of *interface*, or only *peer* if given and present."""
        self._check_existing_interface(interface)
        return {'peers': [x for x in self.conn.get_peers(interface) if peer in (None, x)]}

    def handle_peer_info(self, interface: str, peer: str) -> Dict[str, Any]:
        """Return handshake time, endpoints, allowed IPs and traffic counters for *peer*.

        Raises NotFound if the backend returns no information for *peer*.
        """
        self._check_existing_interface(interface)
        info = self.conn.get_peer_info(interface, peer)
        if not info:
            # Returning None here would crash the response encoder (``**None``).
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'peer does not exist on interface: {peer}')
        return {
            'latest_handshake': int(info.latest_handshake.timestamp()),
            'endpoints': [f'{ip}:{port}' for (ip, port) in info.endpoints],
            'allowed_ips': [str(x) for x in info.allowed_ips],
            'sent_bytes': info.sent_bytes,
            'recv_bytes': info.recv_bytes,
        }

    # -- key handlers --------------------------------------------------------

    def handle_key_create(self, type: str) -> Dict[str, Any]:
        """Generate a new ``'preshared'`` or ``'private'`` key; other types are InvalidParameter."""
        if type == 'preshared':
            return {'key': self.conn.gen_preshared_key()}
        if type == 'private':
            return {'key': self.conn.gen_private_key()}
        raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')

    def handle_key_convert(self, type: str, value: str) -> Dict[str, Any]:
        """Derive a key of kind *type* from *value*; only ``'public'`` is supported."""
        if type == 'public':
            return {'key': self.conn.get_public_key(value)}
        raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')

    # -- wire protocol -------------------------------------------------------

    def _parse_req(self, data: str) -> O[Tuple[str, Dict[str, Any]]]:
        """Decode one request line into ``(command, parameters)``.

        Returns None if the message's ``type`` is not ``'request'``. Malformed
        JSON or missing keys raise (json.JSONDecodeError / KeyError).
        """
        d = json.loads(data)
        if d['type'] != 'request':
            return None
        return d['command'], d['parameters']

    def _make_banner(self) -> str:
        """Return the greeting line announcing the server and protocol version."""
        return json.dumps({
            'type': 'banner',
            'version': 'weegee/0.1',
            'protocol': REMOTE_PROTOCOL_VERSION,
        })

    def _make_resp(self, command: str, error: bool, /, **kwargs: Any) -> str:
        """Encode a response: ``kwargs`` go under ``error`` if *error*, else under ``values``."""
        # ``command`` and ``error`` are positional-only so handler results may
        # themselves contain keys with those names.
        d = {
            'type': 'response',
            'command': command,
        }
        if error:
            d['ok'] = False
            d['error'] = kwargs
        else:
            d['ok'] = True
            d['values'] = kwargs
        return json.dumps(d)

    def _dispatch(self, command: str, params: Any) -> Tuple[bool, Dict[str, Any]]:
        """Validate and run one command; return ``(error, data)`` for _make_resp."""
        if command not in self.handlers:
            return True, {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'invalid command: {command}'}
        if self.allowed_commands is not None and command not in self.allowed_commands:
            return True, {'code': WeegeeRemoteErrorCode.NotAllowed.value, 'msg': f'not allowed to request command: {command}'}
        try:
            # Also raises TypeError when ``params`` is not a mapping.
            self.handler_sigs[command].bind(**params)
        except TypeError as e:
            return True, {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': str(e)}
        try:
            data = self.handlers[command](**params)
        except WeegeeRemoteError as e:
            return True, {'code': e.code.value, 'msg': str(e)}
        except NotImplementedError:
            return True, {'code': WeegeeRemoteErrorCode.InvalidCommand.value, 'msg': f'unimplemented command: {command}'}
        except (ValueError, ipaddress.AddressValueError, ipaddress.NetmaskValueError) as e:
            # Typically raised while parsing address/route parameters.
            return True, {'code': WeegeeRemoteErrorCode.InvalidParameter.value, 'msg': f'invalid parameter: {e}'}
        except Exception as e:
            return True, {'code': WeegeeRemoteErrorCode.InternalError.value, 'msg': str(e)}
        # Guard against a handler returning None, which cannot be ``**``-expanded.
        return False, data if data is not None else {}

    def _send(self, outfile, message: str) -> None:
        """Write *message* as one line and flush it."""
        outfile.write(message + '\n')
        # Flush so a peer reading from a pipe sees each reply immediately rather
        # than blocking on buffered output.
        outfile.flush()

    def serve(self, infile, outfile) -> None:
        """Serve requests from *infile*, writing one response line per request to *outfile*.

        Sends the banner first. Stops at end of input or at the first message
        whose ``type`` is not ``'request'``.
        """
        self._send(outfile, self._make_banner())
        while True:
            line = infile.readline()
            if not line:
                break
            req = self._parse_req(line)
            if not req:
                break
            command, params = req
            error, data = self._dispatch(command, params)
            self._send(outfile, self._make_resp(command, error, **data))