io: add align= to Terminated Stream, fix UCS-2 cstr, fix small errors

This commit is contained in:
Shiz 2022-05-15 06:46:23 +02:00
parent 910be3c5ec
commit 1304a56dcd
3 changed files with 24 additions and 13 deletions

View File

@ -54,7 +54,7 @@ class Switch(G[T, V], Type[T]):
def default(self, context: Context) -> T:
child = self.get_value(context, peek=True)
with context.enter(None, child):
return context.default(ChildProcessError)
return context.default(child)
def __str__(self) -> str:
return f'{format_value(self.options, str)}[{self.selector}]'
@ -106,7 +106,7 @@ class If(G[T,V], Type[U[T, V]]):
else:
child = to_type(self.false)
with context.enter(None, child):
return context.default(self.child)
return context.default(child)
def __str__(self) -> str:
return f'({self.cond} ? {self.true} : {self.false})'

View File

@ -95,11 +95,12 @@ class Sized(G[T], Wrapper[T]):
class TerminatedStream:
def __init__(self, stream: Stream, terminator: bytes, included: bool, blocksize: int = 8192) -> None:
def __init__(self, stream: Stream, terminator: bytes, included: bool, align: int = 1, blocksize: int = 8192) -> None:
self._stream = stream
self._terminator = terminator
self._included = included
self._end_pos = None
self._align = align
self._blocksize = blocksize
def read(self, n: int = -1, *, bits: bool = False) -> U[bytes, int]:
@ -115,6 +116,7 @@ class TerminatedStream:
n = min(n, remaining)
value = b''
off = 0
while n < 0 or len(value) < n:
try:
p = self._stream.tell()
@ -131,7 +133,14 @@ class TerminatedStream:
# find full terminator or start of terminator at the end
if self._end_pos is None:
termpos = find_overlap(value, self._terminator, len(value) - len(v))
while True:
termpos = find_overlap(value[off:], self._terminator, 0)
if termpos is None:
break
termpos += off
if termpos % self._align == 0:
break
off = termpos + 1
if termpos is not None:
# need to read more data?
termrem = len(self._terminator) - (len(value) - termpos)
@ -183,18 +192,19 @@ class TerminatedStream:
return getattr(self._stream, attr)
class Terminated(G[T], Wrapper[T]):
def __init__(self, child: Type[T], terminator: U[bytes, PossibleDynamic], required: U[bool, PossibleDynamic] = True, included: U[bool, PossibleDynamic] = False, blocksize: int = 8192) -> None:
def __init__(self, child: Type[T], terminator: U[bytes, PossibleDynamic], align=1, required: U[bool, PossibleDynamic] = True, included: U[bool, PossibleDynamic] = False, blocksize: int = 8192) -> None:
super().__init__(child)
self.terminator = terminator
self.required = required
self.included = included
self.blocksize = blocksize
self.align = align
def parse(self, context: Context, stream: Stream) -> T:
terminator = context.get(self.terminator)
required = context.get(self.required)
included = context.get(self.included)
tstream = TerminatedStream(stream, terminator, included, blocksize=self.blocksize)
tstream = TerminatedStream(stream, terminator, included, align=self.align, blocksize=self.blocksize)
value = super().parse(context, tstream)
if required and tstream._end_pos is None:
raise EOFError(f'terminator {terminator} not found in stream')
@ -204,7 +214,7 @@ class Terminated(G[T], Wrapper[T]):
terminator = context.get(self.terminator)
required = context.get(self.required)
included = context.get(self.included)
tstream = TerminatedStream(stream, terminator, included, blocksize=self.blocksize)
tstream = TerminatedStream(stream, terminator, included, align=self.align, blocksize=self.blocksize)
super().dump(context, tstream, value)
if required and not included:
stream.write(terminator)
@ -374,8 +384,9 @@ class AlignTo(G[T], Wrapper[T]):
align = context.get(self.alignment)
adjustment = stream.tell() % align
padding = stretch(context.get(self.value), align - adjustment)
stream.write(padding)
if adjustment:
padding = stretch(context.get(self.value), align - adjustment)
stream.write(padding)
def sizeof(self, context: Context, value: O[T]) -> O[Pos]:
# TODO
@ -406,9 +417,9 @@ class AlignedTo(G[T], Wrapper[T]):
def dump(self, context: Context, stream: Stream, value: T) -> None:
align = context.get(self.alignment)
adjustment = stream.tell() % align
padding = stretch(context.get(self.value), align - adjustment)
stream.write(padding)
if adjustment:
padding = stretch(context.get(self.value), align - adjustment)
stream.write(padding)
super().dump(context, stream, value)
def sizeof(self, context: Context, value: O[T]) -> O[Pos]:

View File

@ -59,7 +59,7 @@ class Str(Type[str]):
class CStr(Terminated[str]):
def __init__(self, *args, terminator_required=True, **kwargs) -> None:
terminator = '\x00'.encode(kwargs.get('encoding', 'utf-8'))
super().__init__(Str(*args, **kwargs), terminator, required=terminator_required, blocksize=16)
super().__init__(Str(*args, **kwargs), terminator, required=terminator_required, align=len(terminator), blocksize=16)
def __str__(self) -> str:
if self.child.encoding == 'utf-8':