fix more mypy issues
This commit is contained in:
parent
45b3d5b189
commit
d3d34c3381
|
@ -5,11 +5,11 @@ import argparse
|
|||
import logging
|
||||
import ipaddress
|
||||
from datetime import datetime
|
||||
from typing import Optional as O
|
||||
from typing import Optional as O, Any
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
from .dazy import Item
|
||||
from .dazy import Item, export_items, import_items
|
||||
from .wireguard import WireguardHostType
|
||||
from . import (
|
||||
WeegeeContext, WeegeeConfig,
|
||||
|
@ -114,68 +114,17 @@ def main():
|
|||
sync.set_defaults(func=do_sync, parser=sync)
|
||||
|
||||
def do_import(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> int:
|
||||
ret = 0
|
||||
res = import_items(ctx.instance, args.file.read(), duplicates=not args.force, replacements=not args.force, deletions=not args.no_delete)
|
||||
if not res:
|
||||
return 1
|
||||
|
||||
# First split...
|
||||
added = {}
|
||||
deleted = set()
|
||||
current = None
|
||||
for line in args.file:
|
||||
line = line.rstrip('\n')
|
||||
if line.startswith('#--- '):
|
||||
current = line[len('#--- '):].strip()
|
||||
if args.no_delete:
|
||||
logger.error(f'import: {name}: attempted delete but --no-delete given; aborting')
|
||||
ret = 1
|
||||
break
|
||||
deleted.add(current)
|
||||
current = None
|
||||
continue
|
||||
if line.startswith('#+++ '):
|
||||
current = line[len('#+++ '):].strip()
|
||||
if current in added:
|
||||
logger.warn(f'import: {name}: duplicate encountered')
|
||||
continue
|
||||
if not current:
|
||||
logger.warn(f'import: ignoring leading line: {line}')
|
||||
continue
|
||||
added.setdefault(current, []).append(line.rstrip('\n'))
|
||||
|
||||
if ret:
|
||||
return ret
|
||||
|
||||
# ... then process additions ...
|
||||
objects = {}
|
||||
success = True
|
||||
for name, definition in added.items():
|
||||
if Item.exists(ctx.instance, name):
|
||||
if not args.force:
|
||||
logger.error(f'import: {name}: already exists; rolling back')
|
||||
ret = 2
|
||||
break
|
||||
else:
|
||||
logger.info(f'import: {name}: replaced')
|
||||
else:
|
||||
logger.info(f'import: {name}: added')
|
||||
item = Item.parse(ctx.instance, name, definition)
|
||||
item.save()
|
||||
objects[name] = item
|
||||
|
||||
# ... maybe roll back ...
|
||||
if ret:
|
||||
for name, o in reversed(objects.items()):
|
||||
logger.info(f'import: {name}: deleted')
|
||||
o.delete()
|
||||
return ret
|
||||
|
||||
# ... and finally process deletions!
|
||||
added, deleted = res
|
||||
for name in deleted:
|
||||
if not Item.exists(ctx.instance, name):
|
||||
logger.warn(f'import: {name}: does not exist')
|
||||
continue
|
||||
Item.load(ctx.instance, name).delete()
|
||||
logger.info(f'import: {name}: deleted')
|
||||
for name in added:
|
||||
logger.info(f'import: {name}: added')
|
||||
|
||||
return ret
|
||||
return 0
|
||||
|
||||
import_cmd = commands.add_parser('import', help='import definitions')
|
||||
import_cmd.add_argument('-n', '--no-delete', action='store_true', help='do not delete items')
|
||||
|
@ -344,9 +293,9 @@ def main():
|
|||
|
||||
hook_names = host.list_hooks()
|
||||
|
||||
hook_reset = {}
|
||||
hook_added = {}
|
||||
hook_removed = {}
|
||||
hook_reset: dict[str, set[str]] = {}
|
||||
hook_added: dict[str, dict[str, list[str]]] = {}
|
||||
hook_removed: dict[str, dict[str, list[str]]] = {}
|
||||
for when in ('pre', 'post'):
|
||||
hook_reset[when] = set(getattr(args, f'reset_{when}_hooks'))
|
||||
unknown = hook_reset[when] - set(hook_names)
|
||||
|
@ -492,13 +441,7 @@ def main():
|
|||
def do_export(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> None:
|
||||
peer = args.objtype.load(ctx, args.name)
|
||||
interface = peer.interface.to_public()
|
||||
|
||||
print(f'#+++ {interface.item_name}')
|
||||
for l in interface.item.purify().dump():
|
||||
print(l)
|
||||
print(f'#+++ {peer.item_name}')
|
||||
for l in peer.item.purify().dump():
|
||||
print(l)
|
||||
print(export_items([interface.item.purify(), peer.item.purify()]))
|
||||
|
||||
export = sparser.add_parser('export', help='export peer')
|
||||
export.add_argument('name', help='name')
|
||||
|
@ -511,9 +454,9 @@ def main():
|
|||
)
|
||||
else:
|
||||
connection = WeegeeConnection.load(ctx, args.name)
|
||||
connection.peers += [WeegeePeer.load(ctx, peer).item for peer in args.peer]
|
||||
connection.peers += [WeegeePeer.load(ctx, peer) for peer in args.peer]
|
||||
if args.preshared_key:
|
||||
connection.preshared_key = preshared_key
|
||||
connection.preshared_key = args.preshared_key
|
||||
connection.save()
|
||||
|
||||
connect = sparser.add_parser('connect', help='connect peers')
|
||||
|
@ -592,7 +535,6 @@ def main():
|
|||
private_key=args.private_key, public_key=args.public_key,
|
||||
addresses=args.address, port=args.port, hosts=[WeegeeHost.load(ctx, name) for name in args.host or ctx.get_config().default_server_hosts],
|
||||
)
|
||||
interface.save()
|
||||
server = WeegeeServer.create(ctx, args.name, interface, routes=args.route, host=args.endpoint, extra=args.metadata)
|
||||
server.save()
|
||||
|
||||
|
@ -626,9 +568,8 @@ def main():
|
|||
private_key=args.private_key, public_key=args.public_key,
|
||||
addresses=args.address, hosts=[WeegeeHost.load(ctx, name) for name in args.host or ctx.get_config().default_client_hosts],
|
||||
)
|
||||
interface.save()
|
||||
client = WeegeeClient.create(ctx, args.name, server, preshared_key=args.preshared_key, interface=interface, routes=[], extra=args.metadata)
|
||||
client.save()
|
||||
client, connection = WeegeeClient.create(ctx, args.name, server, preshared_key=args.preshared_key, interface=interface, routes=[], extra=args.metadata)
|
||||
connection.save()
|
||||
|
||||
add_client = client_commands.add_parser('create', help='add new client')
|
||||
add_client.add_argument('-H', '--host', action='append', default=[], help='client host(s)')
|
||||
|
|
|
@ -317,11 +317,22 @@ class WeegeePublicInterface(WeegeeBase):
|
|||
return self.item.is_complete(WeegeeInterface.get_meta(self.context))
|
||||
|
||||
def to_full(self) -> 'WeegeeInterface':
|
||||
return WeegeeInterface.load(self.context, self.name)
|
||||
return WeegeeInterface(self.context, self.item)
|
||||
|
||||
def to_public(self) -> 'WeegeePublicInterface':
|
||||
return WeegeePublicInterface.load(self.context, self.name)
|
||||
|
||||
def save(self, path: O[str] = None) -> None:
|
||||
if self.is_full():
|
||||
if type(self) != WeegeeInterface:
|
||||
self.to_full().save(path=path)
|
||||
return
|
||||
else:
|
||||
if type(self) != WeegeePublicInterface:
|
||||
self.to_public().save(path=path)
|
||||
return
|
||||
super().save(path=path)
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class WeegeeInterface(WeegeePublicInterface):
|
||||
BASE = WEEGEE_INTERFACE
|
||||
|
@ -534,7 +545,9 @@ def sync_interface(interface: WeegeeInterface, auto=False, log=None) -> None:
|
|||
other_peers = do_sync_interface(interface, all_peers, all_connections, auto=auto)
|
||||
|
||||
for p in other_peers:
|
||||
sync_interface(p.interface, auto=auto, log=log)
|
||||
peer_interface = p.interface
|
||||
if peer_interface.is_full():
|
||||
sync_interface(peer_interface.to_full(), auto=auto, log=log)
|
||||
|
||||
def sync_all_interfaces(context: WeegeeContext, auto=False, log=None) -> None:
|
||||
if log is None:
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from __future__ import annotations
|
||||
from logging import getLogger
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
@ -13,6 +15,8 @@ from typing import Optional as O, Union as U, Type as TypeOf, Any, Iterator, Ite
|
|||
import jinja2
|
||||
import jinja2.meta
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
IPAddressType = U[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
IPNetworkType = U[ipaddress.IPv4Network, ipaddress.IPv6Network]
|
||||
IPInterfaceType = U[ipaddress.IPv4Interface, ipaddress.IPv6Interface]
|
||||
|
@ -731,6 +735,81 @@ class Template:
|
|||
return self.template.render(**templated.__dict__)
|
||||
|
||||
|
||||
EXPORT_ADD_MARKER = '#+++ '
|
||||
EXPORT_DEL_MARKER = '#--- '
|
||||
|
||||
def export_items(objects: list[Item]) -> str:
|
||||
lines = []
|
||||
for o in objects:
|
||||
if not o.exists(o.instance, o.name):
|
||||
lines.append(EXPORT_DEL_MARKER + o.name)
|
||||
else:
|
||||
lines.append(EXPORT_ADD_MARKER + o.name)
|
||||
lines.extend(o.dump())
|
||||
return '\n'.join(lines)
|
||||
|
||||
def import_items(instance: Instance, spec: str, duplicates: bool = True, replacements: bool = True, deletions: bool = True) -> O[tuple[set[str], set[str]]]:
|
||||
current = None
|
||||
|
||||
operations: list[tuple[str, O[list[str]]]]
|
||||
for line in spec.splitlines():
|
||||
line = line.rstrip('\n')
|
||||
if line.startswith(EXPORT_DEL_MARKER):
|
||||
current = line[len(EXPORT_DEL_MARKER):].strip()
|
||||
if not deletions:
|
||||
logger.warn(f'import: {current}: deletion found, but not allowed')
|
||||
continue
|
||||
if current in deleted:
|
||||
logger.warn(f'import: {current}: duplicate found')
|
||||
if not duplicates:
|
||||
return None
|
||||
operations.append((current, None))
|
||||
current = None
|
||||
continue
|
||||
if line.startswith(EXPORT_ADD_MARKER):
|
||||
current = line[len(EXPORT_ADD_MARKER):].strip()
|
||||
if Item.exists(instance, current) and not replacements:
|
||||
logger.warn(f'import: {current}: already exists')
|
||||
return None
|
||||
if current in added:
|
||||
logger.warn(f'import: {current}: duplicate found')
|
||||
if not duplicates:
|
||||
return None
|
||||
operations.append((current, []))
|
||||
continue
|
||||
if not current:
|
||||
logger.warn(f'import: ignoring leading line: {line}')
|
||||
continue
|
||||
item_spec = operations[-1][1]
|
||||
if item_spec is not None:
|
||||
item_spec.append(line)
|
||||
|
||||
added: dict[str, Item] = {}
|
||||
deleted: dict[str, Item] = {}
|
||||
for name, contents in operations:
|
||||
if name not in deleted and name not in added and Item.exists(instance, name):
|
||||
deleted[name] = Item.load(instance, name)
|
||||
try:
|
||||
if contents is None:
|
||||
deleted[name].delete()
|
||||
else:
|
||||
item = Item.parse(instance, name, spec)
|
||||
item.save()
|
||||
added[name] = item
|
||||
except:
|
||||
logger.error(f'import: {name}: failed, rolling back')
|
||||
break
|
||||
else:
|
||||
return set(added), set(deleted)
|
||||
|
||||
for name, item in added.items():
|
||||
item.delete()
|
||||
for name, item in deleted.items():
|
||||
item.save()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
import argparse
|
||||
|
|
|
@ -31,12 +31,11 @@ class WeegeeServer(WeegeePeer):
|
|||
@dataclass(eq=False, repr=False)
|
||||
class WeegeeClient(WeegeePeer):
|
||||
@classmethod
|
||||
def create(cls, ctx: WeegeeContext, name: str, server: WeegeeServer, preshared_key: O[str] = None, **kwargs) -> 'WeegeeClient':
|
||||
def create(cls, ctx: WeegeeContext, name: str, server: WeegeeServer, preshared_key: O[str] = None, **kwargs) -> tuple['WeegeeClient', 'WeegeeConnection']:
|
||||
name = server.get_client_name(name)
|
||||
peer = super().create(ctx, name, **kwargs)
|
||||
connection = WeegeeConnection.create(ctx, name, peers=[server, peer], preshared_key=preshared_key)
|
||||
connection.save()
|
||||
return cls(ctx, peer.item)
|
||||
return cls(ctx, peer.item), connection
|
||||
|
||||
@property
|
||||
def connection(self) -> WeegeeConnection:
|
||||
|
|
Loading…
Reference in New Issue