io: add rudimentary Terminated type and rewrite str

This commit is contained in:
Shiz 2021-07-08 04:02:48 +02:00
parent acefb8ae82
commit b75c4f47b0
4 changed files with 205 additions and 41 deletions

View File

@ -16,12 +16,12 @@ from .types.num import (
float32, float32le, float32be, binary32, binary32le, binary32be, float_,
float64, float64le, float64be, binary64, binary64le, binary64be, double,
)
from .types.str import Str, StrType
from .types.str import Str, CStr, cstr, wcstr, utf8cstr, utf16cstr
from .types.seq import Arr, Tuple
from .types.struct import StructType, Struct
from .types.transforms import Default, Transform, Mapped, Enum, Check, Fixed
from .types.control import Switch, If
from .types.io import Sized, Ref, AlignTo, AlignedTo
from .types.io import Sized, Terminated, Ref, AlignTo, AlignedTo
del types
__all__ = [k for k in globals() if not k.startswith('_')]

View File

@ -2,7 +2,7 @@ import os
import inspect
import ast
import collections
from typing import BinaryIO, Generator, Callable, Union as U, Tuple, Mapping, Any, cast
from typing import BinaryIO, Generator, Callable, Optional as O, Union as U, Tuple, Mapping, Any, cast
from contextlib import contextmanager
@ -93,3 +93,15 @@ def get_annot_locations(cls: type) -> Tuple[str, Mapping[str, int]]:
lines[t.id] = start + b.lineno - 2
return fn, lines
def find_overlap(haystack: bytes, needle: bytes, start: int = 0) -> O[int]:
    """ Find occurrence of `needle` in `haystack` or start of `needle` at the end of `haystack`.

    A complete occurrence is searched for at or after offset `start` and wins if present. If there is
    none, the longest proper prefix of `needle` that `haystack` ends with is looked for instead;
    `start` does not restrict this second check.

    Returns the offset in `haystack` at which the (possibly incomplete) match begins,
    or `None` if there is neither a complete match nor such a prefix.
    """
    # bytes.find() scans in place from `start`: no copy of the tail and no second pass to get the index.
    pos = haystack.find(needle, start)
    if pos >= 0:
        return pos

    # No complete match: try the prefixes of `needle` from longest to shortest.
    # endswith() is simply false when the prefix is longer than `haystack`.
    for n in range(len(needle) - 1, 0, -1):
        if haystack.endswith(needle[:n]):
            return len(haystack) - n
    return None

View File

@ -2,9 +2,9 @@ from typing import Any, Generic as G, TypeVar, Union as U, Optional as O, Sequen
import os
import errno
from ..core.base import Type, Context, PossibleDynamic, PathElement
from ..core.io import Stream, Segment, Pos
from ..core.io import Stream, Segment, Pos, add_sizes
from ..core.meta import Wrapper
from ..core.util import stretch, seeking
from ..core.util import stretch, seeking, find_overlap
class SizedStream:
@ -61,7 +61,7 @@ class SizedStream:
T = TypeVar('T')
class Sized(G[T], Wrapper[T]):
def __init__(self, child: Type[T], limit: U[Pos, PossibleDynamic]):
def __init__(self, child: Type[T], limit: U[Pos, PossibleDynamic]) -> None:
super().__init__(child)
self.limit = limit
@ -72,7 +72,7 @@ class Sized(G[T], Wrapper[T]):
stream.seek(start + limit, os.SEEK_SET)
return value
def dump(self, context: Context, stream: Stream, value: O[T]) -> None:
def dump(self, context: Context, stream: Stream, value: T) -> None:
limit = max(0, context.get(self.limit))
start = stream.tell()
super().dump(context, SizedStream(stream, limit), value)
@ -82,6 +82,135 @@ class Sized(G[T], Wrapper[T]):
return context.peek(self.limit)
class TerminatedStream:
    """ Stream wrapper that treats the first occurrence of a terminator as its end.

    Data read through the wrapper is scanned for `terminator`. Once it has been seen, the end position
    is remembered and reads, seeks and reported positions are clamped to it.
    Attributes that are not defined here are forwarded to the wrapped stream.
    """

    def __init__(self, stream: Stream, terminator: bytes, included: bool, blocksize: int = 8192) -> None:
        """ Wrap `stream`.

        `terminator` is the byte sequence that ends the data. If `included` is true the terminator is
        handed out as the last part of the data, otherwise the data stops right in front of it.
        `blocksize` is the amount requested from the wrapped stream per step for reads without a size.
        """
        self._stream = stream
        self._terminator = terminator
        self._included = included
        # position in the wrapped stream at which this stream ends; None until the terminator was seen
        self._end_pos = None
        self._blocksize = blocksize

    def read(self, n: int = -1, *, bits: bool = False) -> U[bytes, int]:
        """ Read `n` bytes, or everything up to the end if `n` is negative.

        Raises `ValueError` for bit-level reads, and `EOFError` if the end was already reached
        or if a positive `n` could not be satisfied before the end.
        """
        if bits:
            raise ValueError('terminated streams can not use bit-level reads')

        if self._end_pos is not None:
            # end already known: never hand out data past it
            if self._stream.tell() >= self._end_pos:
                raise EOFError
            remaining = self._end_pos - self._stream.tell()
            if n < 0:
                n = remaining
            else:
                n = min(n, remaining)

        tlen = len(self._terminator)
        value = b''
        while n < 0 or len(value) < n:
            p = self._stream.tell()
            try:
                v = self._stream.read(n if n >= 0 else self._blocksize)
            except EOFError:
                if n >= 0:
                    raise
                # less than a full block left: rewind and take whatever remains
                self._stream.seek(p, os.SEEK_SET)
                v = self._stream.read(-1)
            if not v:
                break
            value += v

            if self._end_pos is not None:
                continue

            # Find the full terminator or the start of the terminator at the end.
            # The search starts a little before the new chunk so that a terminator
            # straddling the previous chunk and this one is not missed.
            search_from = max(len(value) - len(v) - max(tlen - 1, 0), 0)
            termpos = find_overlap(value, self._terminator, search_from)
            if termpos is None:
                continue

            termrem = tlen - (len(value) - termpos)
            if termrem > 0:
                # Only the beginning of the terminator is at the end of the data: peek at what follows.
                p = self._stream.tell()
                try:
                    rest = self._stream.read(termrem)
                except EOFError:
                    rest = b''
                if not (value + rest).endswith(self._terminator):
                    # False alarm: undo the peek and keep the peeked bytes OUT of the result,
                    # they will be read again by the next regular read.
                    # NOTE(review): only the longest partial match is checked here; a shorter one that
                    # is completed by a later read() call is still not detected.
                    self._stream.seek(p, os.SEEK_SET)
                    continue
                value += rest

            # terminator found, reset overread data
            self._stream.seek(-(len(value) - (termpos + tlen)), os.SEEK_CUR)
            self._end_pos = self._stream.tell() - tlen
            if self._included:
                termpos += tlen
                self._end_pos += tlen
            value = value[:termpos]
            break

        if n > 0 and len(value) != n:
            raise EOFError
        return value

    def seek(self, pos: U[int, float], whence: int = os.SEEK_SET) -> None:
        """ Seek in the wrapped stream; once the end is known, targets are translated to absolute positions.

        Absolute and relative targets are clamped to the end, `os.SEEK_END` is relative to the end.
        """
        if self._end_pos is not None:
            if whence == os.SEEK_SET:
                pos = min(pos, self._end_pos)
            elif whence == os.SEEK_CUR:
                # relative to the position tell() reports: the wrapped stream itself
                # may be positioned behind the terminator, i.e. past our end
                pos = min(pos + self.tell(), self._end_pos)
            elif whence == os.SEEK_END:
                pos += self._end_pos
            whence = os.SEEK_SET
        return self._stream.seek(pos, whence)

    def tell(self) -> U[int, float]:
        """ Return the position in the wrapped stream, but never more than the end if it is known. """
        pos = self._stream.tell()
        if self._end_pos is not None:
            pos = min(pos, self._end_pos)
        return pos

    def __getattr__(self, attr: str) -> Any:
        """ Forward everything not defined on the wrapper to the wrapped stream. """
        return getattr(self._stream, attr)
class Terminated(G[T], Wrapper[T]):
    """ Wrapper that handles its child through a `TerminatedStream`.

    `terminator`, `required` and `included` may be dynamic and are resolved through the context
    on every operation; `blocksize` is passed on to the `TerminatedStream`.
    """

    def __init__(self, child: Type[T], terminator: U[bytes, PossibleDynamic], required: U[bool, PossibleDynamic] = True, included: U[bool, PossibleDynamic] = False, blocksize: int = 8192) -> None:
        super().__init__(child)
        self.blocksize = blocksize
        self.terminator = terminator
        self.required = required
        self.included = included

    def parse(self, context: Context, stream: Stream) -> T:
        """ Parse the child from a terminated view of `stream`.

        Raises `IOError` if the terminator is required but was not encountered while parsing.
        """
        term = context.get(self.terminator)
        must_exist = context.get(self.required)
        keep_term = context.get(self.included)

        limited = TerminatedStream(stream, term, keep_term, blocksize=self.blocksize)
        result = super().parse(context, limited)
        # the wrapper only learns its end position once it has seen the terminator
        if must_exist and limited._end_pos is None:
            raise IOError(f'terminator {term} not found in stream')
        return result

    def dump(self, context: Context, stream: Stream, value: T) -> None:
        """ Dump the child, then write the terminator if it is required and not part of the value. """
        term = context.get(self.terminator)
        must_exist = context.get(self.required)
        keep_term = context.get(self.included)

        limited = TerminatedStream(stream, term, keep_term, blocksize=self.blocksize)
        super().dump(context, limited, value)
        if not must_exist or keep_term:
            return
        stream.write(term)

    def sizeof(self, context: Context, value: O[T]) -> O[Pos]:
        """ Return the size of the child, plus the terminator unless it is included; `None` if unknown.

        The size is reported as unknown if the terminator is not required.
        """
        term = context.peek(self.terminator)
        must_exist = context.peek(self.required)
        keep_term = context.peek(self.included)

        if not must_exist:
            return None
        inner = super().sizeof(context, value)
        if inner is None or keep_term:
            return inner
        return add_sizes(inner, context.to_size(len(term)))
class Ref(G[T], Wrapper[T]):
def __init__(self, child: Type[T], pos: U[PossibleDynamic, Pos], whence: int = os.SEEK_SET, segment: O[Segment] = None) -> None:
super().__init__(child)

View File

@ -2,57 +2,80 @@ import enum
from typing import Optional as O, Union as U
from ..core.base import PossibleDynamic as D, Type, Context
from ..core.io import Stream
from ..core.util import stretch
from .num import uint8
from .io import Terminated
class StrType(enum.Enum):
    """ Kinds of string a `Str` can be asked to handle. """
    # no delimiter of its own
    Raw = enum.auto()
    # `C` is bound to the same value in the same statement, so it is a second name for this member
    ZeroTerminated = C = enum.auto()
    # likewise, `Pascal` is a second name for `LengthPrefixed`
    LengthPrefixed = Pascal = enum.auto()
class Str(Type[str]):
def __init__(self, type: U[D, StrType], length: U[D, O[int]] = None, encoding: U[D, str] = 'utf-8', char_size: U[D, int] = 1, length_type: U[D, Type] = uint8, terminator: U[D, bytes] = b'\x00') -> None:
self.type = type
def __init__(self, length: U[D, O[int]] = None, encoding: U[D, str] = 'utf-8', char_size: U[D, int] = 1) -> None:
self.length = length
self.encoding = encoding
self.char_size = char_size
self.length_type = length_type
self.terminator = terminator
def parse(self, context: Context, stream: Stream) -> str:
type = context.get(self.type)
length = context.get(self.length)
encoding = context.get(self.encoding)
char_size = context.get(self.char_size)
if type == StrType.Raw:
if length is None:
data = stream.read()
else:
data = stream.read(length * char_size)
elif type == StrType.C:
terminator = context.get(self.terminator)
if len(terminator) != char_size:
terminator = stretch(terminator, char_size)
data = b''
while True:
d = stream.read(char_size)
if d == terminator:
break
data += d
if length is not None and len(data) >= length * char_size:
break
elif type == StrType.Pascal:
length_type = context.get(self.length_type)
plength = context.parse(length_type, stream)
if length is None:
data = stream.read()
else:
data = stream.read(length * char_size)
return data.decode(encoding)
def dump(self, context: Context, stream: Stream, value: str) -> None:
pass
encoding = context.get(self.encoding)
char_size = context.get(self.char_size)
bs = value.encode(encoding)
length = len(bs) // char_size
stream.write(bs)
context.put(self.length, length)
def default(self, context: Context) -> str:
return ''
def sizeof(self, context: Context, value: O[str]) -> O[int]:
return 0
if value is not None:
return len(value.encode(context.peek(self.encoding)))
return None
def __str__(self) -> str:
    """ Return the short name of this type: `str`, followed by `.<encoding>` unless the
    encoding is UTF-8, followed by `(<length>)` if a length is set. """
    if self.length is not None:
        length = f'({self.length})'
    else:
        length = ''
    if self.encoding != 'utf-8':
        encoding = f'.{self.encoding}'
    else:
        encoding = ''
    # use the formatted suffix computed above, not the raw attribute
    return f'str{encoding}{length}'
def __repr__(self) -> str:
    """ Return a constructor-like representation listing length, encoding and character size. """
    settings = f'length={self.length!r}, encoding={self.encoding!r}, char_size={self.char_size!r}'
    return f'{__name__}.Str({settings})'
class CStr(Terminated[str]):
    """ Zero-terminated string: a `Str` wrapped in `Terminated`, with the NUL character
    encoded in the encoding of the string as required terminator. """

    def __init__(self, *args, **kwargs) -> None:
        """ Accepts the same arguments as `Str`. """
        child = Str(*args, **kwargs)
        # Take the encoding from the child instead of from `kwargs`, so that it is also
        # honoured when it was passed positionally.
        terminator = '\x00'.encode(child.encoding)
        super().__init__(child, terminator, required=True, blocksize=16)

    def __str__(self) -> str:
        """ Return the short name for well-known encodings, else the name of the wrapped string type. """
        if self.child.encoding == 'utf-8':
            return 'cstr'
        if self.child.encoding == 'utf-16le':
            return 'wcstr'
        if self.child.encoding == 'sjis':
            return 'jcstr'
        return str(self.child)

    def __repr__(self) -> str:
        return f'<{__name__}.CStr({self.child!r})>'
# Ready-made zero-terminated string types; each pair of names refers to one shared instance.
utf8cstr = CStr()
cstr = utf8cstr

utf16cstr = CStr(encoding='utf-16le', char_size=2)
wcstr = utf16cstr

sjiscstr = CStr(encoding='sjis')
jcstr = sjiscstr