270 lines
12 KiB
Python
270 lines
12 KiB
Python
from dataclasses import dataclass, field
|
|
import json
|
|
import enum
|
|
import ipaddress
|
|
import inspect
|
|
from typing import Optional as O, Tuple, Dict, Set, Any
|
|
|
|
from .wireguard import WireguardConnection, IPInterface, IPNetwork
|
|
|
|
|
|
class WeegeeRemoteErrorCode(enum.Enum):
    """Numeric codes identifying the category of a remote command failure.

    The integer ``value`` of a member is what is placed in the ``code`` field
    of an error response.
    """

    # The request itself was malformed or unknown.
    InvalidCommand = -1
    InvalidParameter = -2

    # The request was understood but cannot be applied to its target.
    NotFound = -3
    NotAllowed = -4
    Conflict = -5

    # Unexpected failure while executing the command.
    InternalError = -10
|
|
|
|
class WeegeeRemoteError(Exception):
    """Exception carrying a ``WeegeeRemoteErrorCode`` alongside its message.

    Attributes:
        code: The error category associated with this failure.
    """

    code: WeegeeRemoteErrorCode

    def __init__(self, code: WeegeeRemoteErrorCode, message: str) -> None:
        """Store ``code`` and pass ``message`` on to ``Exception``.

        Args:
            code: Error category for this failure.
            message: Human-readable description; becomes ``str(exc)``.
        """
        self.code = code
        super().__init__(message)
|
|
|
|
|
|
# Protocol revision number of the remote command protocol.
REMOTE_PROTOCOL_VERSION = 0

# Named groups of command names. Each tier contains every command of the
# previous one: 'status' <= 'sync' <= 'manage'.
REMOTE_COMMAND_SETS = {
    'status': {
        'key.convert',
        'interface.query', 'interface.info',
        'peer.query', 'peer.info',
    },
}
REMOTE_COMMAND_SETS['sync'] = {
    *REMOTE_COMMAND_SETS['status'],
    'key.create',
    'interface.set', 'interface.configure',
}
REMOTE_COMMAND_SETS['manage'] = REMOTE_COMMAND_SETS['sync'].union({
    'interface.create', 'interface.destroy',
    'address.query', 'address.create', 'address.destroy',
    'route.query', 'route.create', 'route.destroy',
})
|
|
|
|
|
|
@dataclass
class WeegeeRemoteServer:
    """Line-oriented JSON command server operating on a wireguard connection.

    ``serve`` writes a banner line, then reads one JSON request per line and
    answers each with one JSON response line. A request names a command from
    ``handlers`` and supplies its keyword parameters.

    The ``allowed_*`` attributes optionally restrict what a client may do;
    ``None`` means "no restriction" for that category.

    Project types are referenced as forward-reference strings in annotations.
    """

    # Backend object that performs the interface/address/route/key operations.
    conn: 'WireguardConnection'
    # Command names a client may request (None: any known command).
    allowed_commands: O[Set[str]] = None
    # Interface names a client may act on (None: any interface).
    allowed_interfaces: O[Set[str]] = None
    # Entries whose network an address must fall into to be created/destroyed.
    allowed_addresses: O[Set['IPInterface']] = None
    # Networks a route must be contained in to be created/destroyed.
    allowed_routes: O[Set['IPNetwork']] = None
    # Command name -> bound handler method; filled in by __post_init__.
    handlers: Dict[str, Any] = field(init=False)
    # Command name -> signature of its handler, used to validate parameters.
    handler_sigs: Dict[str, inspect.Signature] = field(init=False)

    def __post_init__(self) -> None:
        """Build the command dispatch table and cache each handler's signature."""
        self.handlers = {
            'interface.query': self.handle_interface_query,
            'interface.create': self.handle_interface_create,
            'interface.destroy': self.handle_interface_destroy,
            'interface.set': self.handle_interface_set,
            'interface.info': self.handle_interface_info,
            'interface.configure': self.handle_interface_configure,
            'address.query': self.handle_address_query,
            'address.create': self.handle_address_create,
            'address.destroy': self.handle_address_destroy,
            'route.query': self.handle_route_query,
            'route.create': self.handle_route_create,
            'route.destroy': self.handle_route_destroy,
            'peer.query': self.handle_peer_query,
            'peer.info': self.handle_peer_info,
            'key.create': self.handle_key_create,
            'key.convert': self.handle_key_convert,
        }
        self.handler_sigs = {name: inspect.signature(fn) for name, fn in self.handlers.items()}

    # -- permission / existence checks -------------------------------------

    def _check_interface(self, interface: str) -> None:
        """Raise ``NotAllowed`` if ``interface`` is outside ``allowed_interfaces``."""
        if self.allowed_interfaces is not None and interface not in self.allowed_interfaces:
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control interface: {interface}')

    def _check_existing_interface(self, interface: str) -> None:
        """Like ``_check_interface``, and also raise ``NotFound`` if it does not exist."""
        self._check_interface(interface)
        if not self.conn.has_interface(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'interface does not exist: {interface}')

    def _check_address(self, address: 'IPInterface') -> None:
        """Raise ``NotAllowed`` unless ``address`` lies in an allowed network.

        Does nothing when ``allowed_addresses`` is ``None``.
        """
        if self.allowed_addresses is None:
            return
        # ipaddress interface objects do not support ``in`` against each other,
        # so compare the host address with the network each allowed entry
        # covers. NOTE(review): assumes IPInterface/allowed entries are
        # ``ipaddress`` interface (or network) objects - confirm in .wireguard.
        host = getattr(address, 'ip', address)
        for allowed in self.allowed_addresses:
            if host in getattr(allowed, 'network', allowed):
                return
        raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control address: {address}')

    def _check_route(self, route: 'IPNetwork') -> None:
        """Raise ``NotAllowed`` unless ``route`` is contained in an allowed network.

        Does nothing when ``allowed_routes`` is ``None``.
        """
        if self.allowed_routes is None:
            return
        # ``network in network`` is always False for ipaddress networks, so
        # containment has to be tested with subnet_of(); that method raises
        # TypeError on mixed IP versions, hence the explicit version guard.
        permitted = any(
            route.version == allowed.version and route.subnet_of(allowed)
            for allowed in self.allowed_routes
        )
        if not permitted:
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to control route: {route}')

    # -- interface commands ------------------------------------------------

    def handle_interface_query(self, interface: O[str] = None) -> Dict[str, Any]:
        """List interfaces, optionally filtered to the one named ``interface``."""
        interfaces = self.conn.get_interfaces()
        return {'interfaces': [i for i in (interfaces or []) if interface in (i, None)]}

    def handle_interface_create(self, interface: str) -> Dict[str, Any]:
        """Create ``interface``; raise ``Conflict`` if it already exists."""
        self._check_interface(interface)
        if self.conn.has_interface(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'interface exists: {interface}')
        self.conn.create_interface(interface)
        return {}

    def handle_interface_destroy(self, interface: str) -> Dict[str, Any]:
        """Destroy an existing, permitted ``interface``."""
        self._check_existing_interface(interface)
        self.conn.destroy_interface(interface)
        return {}

    def handle_interface_set(self, interface: str, up: O[bool] = None, mtu: O[int] = None) -> Dict[str, Any]:
        """Set link state and/or MTU; parameters left as ``None`` are not touched."""
        self._check_existing_interface(interface)
        if up is not None:
            if up:
                self.conn.set_up(interface)
            else:
                self.conn.set_down(interface)
        if mtu is not None:
            self.conn.set_mtu(interface, mtu)
        return {}

    def handle_interface_info(self, interface: str) -> Dict[str, Any]:
        """Return the interface's configuration and the ``value`` of its config format."""
        self._check_existing_interface(interface)
        return {
            'config': self.conn.get_config(interface),
            'config_format': self.conn.get_config_format(interface).value,
        }

    def handle_interface_configure(self, interface: str, config: O[str] = None, sync: O[bool] = True) -> Dict[str, Any]:
        """Apply ``config`` to ``interface``; a ``None`` config is a no-op."""
        self._check_existing_interface(interface)
        if config is not None:
            # The target is the ``interface`` parameter (the original passed
            # an undefined ``name`` here).
            self.conn.set_config(interface, config, sync)
        return {}

    # -- address commands --------------------------------------------------

    def handle_address_query(self, interface: str, address: O[str] = None) -> Dict[str, Any]:
        """List addresses on ``interface``, optionally filtered to ``address``."""
        self._check_existing_interface(interface)
        # A falsy ``address`` is left unparsed, exactly as passed in.
        wanted = ipaddress.ip_interface(address) if address else address
        return {'addresses': [str(x) for x in self.conn.get_addresses(interface) if wanted in (None, x)]}

    def handle_address_create(self, interface: str, address: str) -> Dict[str, Any]:
        """Add ``address`` to ``interface``; raise ``Conflict`` if already present."""
        self._check_existing_interface(interface)
        parsed = ipaddress.ip_interface(address)
        self._check_address(parsed)
        if parsed in self.conn.get_addresses(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'address exists on interface: {parsed}')
        self.conn.add_address(interface, parsed)
        return {}

    def handle_address_destroy(self, interface: str, address: str) -> Dict[str, Any]:
        """Remove ``address`` from ``interface``; raise ``NotFound`` if absent."""
        self._check_existing_interface(interface)
        parsed = ipaddress.ip_interface(address)
        self._check_address(parsed)
        if parsed not in self.conn.get_addresses(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'address does not exist on interface: {parsed}')
        self.conn.delete_address(interface, parsed)
        return {}

    # -- route commands ----------------------------------------------------

    def handle_route_query(self, interface: str, route: O[str] = None) -> Dict[str, Any]:
        """List routes on ``interface``, optionally filtered to ``route``."""
        self._check_existing_interface(interface)
        # A falsy ``route`` is left unparsed, exactly as passed in.
        wanted = ipaddress.ip_network(route) if route else route
        return {'routes': [str(x) for x in self.conn.get_routes(interface) if wanted in (None, x)]}

    def handle_route_create(self, interface: str, route: str) -> Dict[str, Any]:
        """Add ``route`` to ``interface``; raise ``Conflict`` if already present."""
        self._check_existing_interface(interface)
        parsed = ipaddress.ip_network(route)
        self._check_route(parsed)
        if parsed in self.conn.get_routes(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.Conflict, f'route exists on interface: {parsed}')
        self.conn.add_route(interface, parsed)
        return {}

    def handle_route_destroy(self, interface: str, route: str) -> Dict[str, Any]:
        """Remove ``route`` from ``interface``; raise ``NotFound`` if absent."""
        self._check_existing_interface(interface)
        parsed = ipaddress.ip_network(route)
        self._check_route(parsed)
        # Existence must be checked against the interface's routes (the
        # original looked the route up in the address list).
        if parsed not in self.conn.get_routes(interface):
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'route does not exist on interface: {parsed}')
        self.conn.delete_route(interface, parsed)
        return {}

    # -- peer commands -----------------------------------------------------

    def handle_peer_query(self, interface: str, peer: O[str] = None) -> Dict[str, Any]:
        """List peers on ``interface``, optionally filtered to ``peer``."""
        self._check_existing_interface(interface)
        return {'peers': [x for x in self.conn.get_peers(interface) if peer in (None, x)]}

    def handle_peer_info(self, interface: str, peer: str) -> Dict[str, Any]:
        """Return handshake time, endpoints, allowed IPs and byte counters of ``peer``.

        Raises:
            WeegeeRemoteError: ``NotFound`` when the connection has no info for
                the peer. (Returning ``None`` here would break ``serve``, which
                unpacks the handler result as keyword arguments.)
        """
        self._check_existing_interface(interface)
        info = self.conn.get_peer_info(interface, peer)
        if not info:
            raise WeegeeRemoteError(WeegeeRemoteErrorCode.NotFound, f'peer does not exist on interface: {peer}')
        return {
            'latest_handshake': int(info.latest_handshake.timestamp()),
            'endpoints': [f'{ip}:{port}' for (ip, port) in info.endpoints],
            'allowed_ips': [str(x) for x in info.allowed_ips],
            'sent_bytes': info.sent_bytes,
            'recv_bytes': info.recv_bytes,
        }

    # -- key commands ------------------------------------------------------

    # NOTE: the parameter is named ``type`` because handler parameter names
    # are the request's parameter keys (requests are passed as **kwargs).
    def handle_key_create(self, type: str) -> Dict[str, Any]:
        """Generate a new ``'preshared'`` or ``'private'`` key.

        Raises:
            WeegeeRemoteError: ``InvalidParameter`` for any other ``type``.
        """
        if type == 'preshared':
            return {'key': self.conn.gen_preshared_key()}
        if type == 'private':
            return {'key': self.conn.gen_private_key()}
        raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')

    def handle_key_convert(self, type: str, value: str) -> Dict[str, Any]:
        """Convert key ``value`` to the requested ``type`` (only ``'public'``).

        Raises:
            WeegeeRemoteError: ``InvalidParameter`` for any other ``type``.
        """
        if type == 'public':
            return {'key': self.conn.get_public_key(value)}
        raise WeegeeRemoteError(WeegeeRemoteErrorCode.InvalidParameter, f'unknown key type: {type}')

    # -- wire protocol -----------------------------------------------------

    def _parse_req(self, data: str) -> O[Tuple[str, Dict[str, Any]]]:
        """Decode one request line into ``(command, parameters)``.

        Returns ``None`` when the message's ``type`` is not ``'request'``.

        Raises:
            json.JSONDecodeError: ``data`` is not valid JSON.
            KeyError: a required key (``type``, ``command``, ``parameters``)
                is missing.
        """
        d = json.loads(data)
        if d['type'] != 'request':
            return None
        return d['command'], d['parameters']

    def _make_banner(self) -> str:
        """Return the JSON banner announcing software and protocol version."""
        return json.dumps({
            'type': 'banner',
            'version': 'weegee/0.1',
            'protocol': REMOTE_PROTOCOL_VERSION,
        })

    def _make_resp(self, command: str, error: bool, /, **kwargs: Any) -> str:
        """Serialize a response for ``command``.

        ``kwargs`` is emitted under ``'error'`` when ``error`` is true and
        under ``'values'`` otherwise; ``'ok'`` is the negation of ``error``.
        """
        d = {
            'type': 'response',
            'command': command,
        }
        if error:
            d['ok'] = False
            d['error'] = kwargs
        else:
            d['ok'] = True
            d['values'] = kwargs
        return json.dumps(d)

    def _execute(self, command: str, params: Any) -> Tuple[bool, Dict[str, Any]]:
        """Run one command and return ``(error, payload)`` for ``_make_resp``."""

        def failure(code: Any, msg: str) -> Tuple[bool, Dict[str, Any]]:
            return True, {'code': code.value, 'msg': msg}

        if command not in self.handlers:
            return failure(WeegeeRemoteErrorCode.InvalidCommand, f'invalid command: {command}')
        if self.allowed_commands is not None and command not in self.allowed_commands:
            return failure(WeegeeRemoteErrorCode.NotAllowed, f'not allowed to request command: {command}')

        # Validate the parameter names against the handler's signature before
        # calling it, so a bad request is reported as InvalidParameter rather
        # than as an internal error.
        try:
            self.handler_sigs[command].bind(**params)
        except TypeError as e:
            return failure(WeegeeRemoteErrorCode.InvalidParameter, str(e))

        try:
            return False, self.handlers[command](**params)
        except WeegeeRemoteError as e:
            return failure(e.code, str(e))
        except NotImplementedError:
            return failure(WeegeeRemoteErrorCode.InvalidCommand, f'unimplemented command: {command}')
        except (ValueError, ipaddress.AddressValueError, ipaddress.NetmaskValueError) as e:
            return failure(WeegeeRemoteErrorCode.InvalidParameter, f'invalid parameter: {e}')
        except Exception as e:
            # Top-level boundary: report anything unexpected to the client
            # instead of tearing down the session.
            return failure(WeegeeRemoteErrorCode.InternalError, str(e))

    def serve(self, infile, outfile) -> None:
        """Serve requests read line by line from ``infile``, replying on ``outfile``.

        Writes the banner first. The loop ends at end of input (an empty
        ``readline`` result) or at the first message that is not a request.
        """
        outfile.write(self._make_banner() + '\n')
        while True:
            line = infile.readline()
            if not line:
                break
            req = self._parse_req(line)
            if not req:
                break
            command, params = req
            error, data = self._execute(command, params)
            outfile.write(self._make_resp(command, error, **data) + '\n')
|