fix more mypy issues

This commit is contained in:
Shiz 2021-12-16 03:30:12 +01:00
parent 45b3d5b189
commit d3d34c3381
4 changed files with 114 additions and 82 deletions

View File

@ -5,11 +5,11 @@ import argparse
import logging
import ipaddress
from datetime import datetime
from typing import Optional as O
from typing import Optional as O, Any
logging.basicConfig(level=logging.DEBUG)
from .dazy import Item
from .dazy import Item, export_items, import_items
from .wireguard import WireguardHostType
from . import (
WeegeeContext, WeegeeConfig,
@ -114,68 +114,17 @@ def main():
sync.set_defaults(func=do_sync, parser=sync)
def do_import(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> int:
ret = 0
res = import_items(ctx.instance, args.file.read(), duplicates=not args.force, replacements=not args.force, deletions=not args.no_delete)
if not res:
return 1
# First split...
added = {}
deleted = set()
current = None
for line in args.file:
line = line.rstrip('\n')
if line.startswith('#--- '):
current = line[len('#--- '):].strip()
if args.no_delete:
logger.error(f'import: {name}: attempted delete but --no-delete given; aborting')
ret = 1
break
deleted.add(current)
current = None
continue
if line.startswith('#+++ '):
current = line[len('#+++ '):].strip()
if current in added:
logger.warn(f'import: {name}: duplicate encountered')
continue
if not current:
logger.warn(f'import: ignoring leading line: {line}')
continue
added.setdefault(current, []).append(line.rstrip('\n'))
if ret:
return ret
# ... then process additions ...
objects = {}
success = True
for name, definition in added.items():
if Item.exists(ctx.instance, name):
if not args.force:
logger.error(f'import: {name}: already exists; rolling back')
ret = 2
break
else:
logger.info(f'import: {name}: replaced')
else:
logger.info(f'import: {name}: added')
item = Item.parse(ctx.instance, name, definition)
item.save()
objects[name] = item
# ... maybe roll back ...
if ret:
for name, o in reversed(objects.items()):
logger.info(f'import: {name}: deleted')
o.delete()
return ret
# ... and finally process deletions!
added, deleted = res
for name in deleted:
if not Item.exists(ctx.instance, name):
logger.warn(f'import: {name}: does not exist')
continue
Item.load(ctx.instance, name).delete()
logger.info(f'import: {name}: deleted')
for name in added:
logger.info(f'import: {name}: added')
return ret
return 0
import_cmd = commands.add_parser('import', help='import definitions')
import_cmd.add_argument('-n', '--no-delete', action='store_true', help='do not delete items')
@ -344,9 +293,9 @@ def main():
hook_names = host.list_hooks()
hook_reset = {}
hook_added = {}
hook_removed = {}
hook_reset: dict[str, set[str]] = {}
hook_added: dict[str, dict[str, list[str]]] = {}
hook_removed: dict[str, dict[str, list[str]]] = {}
for when in ('pre', 'post'):
hook_reset[when] = set(getattr(args, f'reset_{when}_hooks'))
unknown = hook_reset[when] - set(hook_names)
@ -492,13 +441,7 @@ def main():
def do_export(parser: argparse.ArgumentParser, args: argparse.Namespace, ctx: WeegeeContext) -> None:
peer = args.objtype.load(ctx, args.name)
interface = peer.interface.to_public()
print(f'#+++ {interface.item_name}')
for l in interface.item.purify().dump():
print(l)
print(f'#+++ {peer.item_name}')
for l in peer.item.purify().dump():
print(l)
print(export_items([interface.item.purify(), peer.item.purify()]))
export = sparser.add_parser('export', help='export peer')
export.add_argument('name', help='name')
@ -511,9 +454,9 @@ def main():
)
else:
connection = WeegeeConnection.load(ctx, args.name)
connection.peers += [WeegeePeer.load(ctx, peer).item for peer in args.peer]
connection.peers += [WeegeePeer.load(ctx, peer) for peer in args.peer]
if args.preshared_key:
connection.preshared_key = preshared_key
connection.preshared_key = args.preshared_key
connection.save()
connect = sparser.add_parser('connect', help='connect peers')
@ -592,7 +535,6 @@ def main():
private_key=args.private_key, public_key=args.public_key,
addresses=args.address, port=args.port, hosts=[WeegeeHost.load(ctx, name) for name in args.host or ctx.get_config().default_server_hosts],
)
interface.save()
server = WeegeeServer.create(ctx, args.name, interface, routes=args.route, host=args.endpoint, extra=args.metadata)
server.save()
@ -626,9 +568,8 @@ def main():
private_key=args.private_key, public_key=args.public_key,
addresses=args.address, hosts=[WeegeeHost.load(ctx, name) for name in args.host or ctx.get_config().default_client_hosts],
)
interface.save()
client = WeegeeClient.create(ctx, args.name, server, preshared_key=args.preshared_key, interface=interface, routes=[], extra=args.metadata)
client.save()
client, connection = WeegeeClient.create(ctx, args.name, server, preshared_key=args.preshared_key, interface=interface, routes=[], extra=args.metadata)
connection.save()
add_client = client_commands.add_parser('create', help='add new client')
add_client.add_argument('-H', '--host', action='append', default=[], help='client host(s)')

View File

@ -317,11 +317,22 @@ class WeegeePublicInterface(WeegeeBase):
return self.item.is_complete(WeegeeInterface.get_meta(self.context))
def to_full(self) -> 'WeegeeInterface':
return WeegeeInterface.load(self.context, self.name)
return WeegeeInterface(self.context, self.item)
def to_public(self) -> 'WeegeePublicInterface':
return WeegeePublicInterface.load(self.context, self.name)
def save(self, path: O[str] = None) -> None:
if self.is_full():
if type(self) != WeegeeInterface:
self.to_full().save(path=path)
return
else:
if type(self) != WeegeePublicInterface:
self.to_public().save(path=path)
return
super().save(path=path)
@dataclass(eq=False, repr=False)
class WeegeeInterface(WeegeePublicInterface):
BASE = WEEGEE_INTERFACE
@ -534,7 +545,9 @@ def sync_interface(interface: WeegeeInterface, auto=False, log=None) -> None:
other_peers = do_sync_interface(interface, all_peers, all_connections, auto=auto)
for p in other_peers:
sync_interface(p.interface, auto=auto, log=log)
peer_interface = p.interface
if peer_interface.is_full():
sync_interface(peer_interface.to_full(), auto=auto, log=log)
def sync_all_interfaces(context: WeegeeContext, auto=False, log=None) -> None:
if log is None:

View File

@ -1,4 +1,6 @@
from __future__ import annotations
from logging import getLogger
import os
from copy import deepcopy
from enum import Enum
@ -13,6 +15,8 @@ from typing import Optional as O, Union as U, Type as TypeOf, Any, Iterator, Ite
import jinja2
import jinja2.meta
logger = getLogger(__name__)
IPAddressType = U[ipaddress.IPv4Address, ipaddress.IPv6Address]
IPNetworkType = U[ipaddress.IPv4Network, ipaddress.IPv6Network]
IPInterfaceType = U[ipaddress.IPv4Interface, ipaddress.IPv6Interface]
@ -731,6 +735,81 @@ class Template:
return self.template.render(**templated.__dict__)
EXPORT_ADD_MARKER = '#+++ '
EXPORT_DEL_MARKER = '#--- '
def export_items(objects: list[Item]) -> str:
lines = []
for o in objects:
if not o.exists(o.instance, o.name):
lines.append(EXPORT_DEL_MARKER + o.name)
else:
lines.append(EXPORT_ADD_MARKER + o.name)
lines.extend(o.dump())
return '\n'.join(lines)
def import_items(instance: Instance, spec: str, duplicates: bool = True, replacements: bool = True, deletions: bool = True) -> O[tuple[set[str], set[str]]]:
current = None
operations: list[tuple[str, O[list[str]]]]
for line in spec.splitlines():
line = line.rstrip('\n')
if line.startswith(EXPORT_DEL_MARKER):
current = line[len(EXPORT_DEL_MARKER):].strip()
if not deletions:
logger.warn(f'import: {current}: deletion found, but not allowed')
continue
if current in deleted:
logger.warn(f'import: {current}: duplicate found')
if not duplicates:
return None
operations.append((current, None))
current = None
continue
if line.startswith(EXPORT_ADD_MARKER):
current = line[len(EXPORT_ADD_MARKER):].strip()
if Item.exists(instance, current) and not replacements:
logger.warn(f'import: {current}: already exists')
return None
if current in added:
logger.warn(f'import: {current}: duplicate found')
if not duplicates:
return None
operations.append((current, []))
continue
if not current:
logger.warn(f'import: ignoring leading line: {line}')
continue
item_spec = operations[-1][1]
if item_spec is not None:
item_spec.append(line)
added: dict[str, Item] = {}
deleted: dict[str, Item] = {}
for name, contents in operations:
if name not in deleted and name not in added and Item.exists(instance, name):
deleted[name] = Item.load(instance, name)
try:
if contents is None:
deleted[name].delete()
else:
item = Item.parse(instance, name, spec)
item.save()
added[name] = item
except:
logger.error(f'import: {name}: failed, rolling back')
break
else:
return set(added), set(deleted)
for name, item in added.items():
item.delete()
for name, item in deleted.items():
item.save()
return None
if __name__ == '__main__':
import sys
import argparse

View File

@ -31,12 +31,11 @@ class WeegeeServer(WeegeePeer):
@dataclass(eq=False, repr=False)
class WeegeeClient(WeegeePeer):
@classmethod
def create(cls, ctx: WeegeeContext, name: str, server: WeegeeServer, preshared_key: O[str] = None, **kwargs) -> 'WeegeeClient':
def create(cls, ctx: WeegeeContext, name: str, server: WeegeeServer, preshared_key: O[str] = None, **kwargs) -> tuple['WeegeeClient', 'WeegeeConnection']:
name = server.get_client_name(name)
peer = super().create(ctx, name, **kwargs)
connection = WeegeeConnection.create(ctx, name, peers=[server, peer], preshared_key=preshared_key)
connection.save()
return cls(ctx, peer.item)
return cls(ctx, peer.item), connection
@property
def connection(self) -> WeegeeConnection: