io: add rudimentary Terminated type and rewrite str
This commit is contained in:
parent
acefb8ae82
commit
b75c4f47b0
|
@ -16,12 +16,12 @@ from .types.num import (
|
|||
float32, float32le, float32be, binary32, binary32le, binary32be, float_,
|
||||
float64, float64le, float64be, binary64, binary64le, binary64be, double,
|
||||
)
|
||||
from .types.str import Str, StrType
|
||||
from .types.str import Str, CStr, cstr, wcstr, utf8cstr, utf16cstr
|
||||
from .types.seq import Arr, Tuple
|
||||
from .types.struct import StructType, Struct
|
||||
from .types.transforms import Default, Transform, Mapped, Enum, Check, Fixed
|
||||
from .types.control import Switch, If
|
||||
from .types.io import Sized, Ref, AlignTo, AlignedTo
|
||||
from .types.io import Sized, Terminated, Ref, AlignTo, AlignedTo
|
||||
del types
|
||||
|
||||
__all__ = [k for k in globals() if not k.startswith('_')]
|
||||
|
|
|
@ -2,7 +2,7 @@ import os
|
|||
import inspect
|
||||
import ast
|
||||
import collections
|
||||
from typing import BinaryIO, Generator, Callable, Union as U, Tuple, Mapping, Any, cast
|
||||
from typing import BinaryIO, Generator, Callable, Optional as O, Union as U, Tuple, Mapping, Any, cast
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
|
@ -93,3 +93,15 @@ def get_annot_locations(cls: type) -> Tuple[str, Mapping[str, int]]:
|
|||
lines[t.id] = start + b.lineno - 2
|
||||
|
||||
return fn, lines
|
||||
|
||||
|
||||
def find_overlap(haystack: bytes, needle: bytes, start: int = 0) -> O[int]:
|
||||
""" Find occurrence of `needle` in `haystack` or start of `needle` at the end of `haystack` """
|
||||
if needle in haystack[start:]:
|
||||
return haystack.index(needle, start)
|
||||
n = len(needle) - 1
|
||||
while n > 0:
|
||||
if haystack[-n:] == needle[:n]:
|
||||
return len(haystack) - n
|
||||
n -= 1
|
||||
return None
|
||||
|
|
137
sx/types/io.py
137
sx/types/io.py
|
@ -2,9 +2,9 @@ from typing import Any, Generic as G, TypeVar, Union as U, Optional as O, Sequen
|
|||
import os
|
||||
import errno
|
||||
from ..core.base import Type, Context, PossibleDynamic, PathElement
|
||||
from ..core.io import Stream, Segment, Pos
|
||||
from ..core.io import Stream, Segment, Pos, add_sizes
|
||||
from ..core.meta import Wrapper
|
||||
from ..core.util import stretch, seeking
|
||||
from ..core.util import stretch, seeking, find_overlap
|
||||
|
||||
|
||||
class SizedStream:
|
||||
|
@ -61,7 +61,7 @@ class SizedStream:
|
|||
T = TypeVar('T')
|
||||
|
||||
class Sized(G[T], Wrapper[T]):
|
||||
def __init__(self, child: Type[T], limit: U[Pos, PossibleDynamic]):
|
||||
def __init__(self, child: Type[T], limit: U[Pos, PossibleDynamic]) -> None:
|
||||
super().__init__(child)
|
||||
self.limit = limit
|
||||
|
||||
|
@ -72,7 +72,7 @@ class Sized(G[T], Wrapper[T]):
|
|||
stream.seek(start + limit, os.SEEK_SET)
|
||||
return value
|
||||
|
||||
def dump(self, context: Context, stream: Stream, value: O[T]) -> None:
|
||||
def dump(self, context: Context, stream: Stream, value: T) -> None:
|
||||
limit = max(0, context.get(self.limit))
|
||||
start = stream.tell()
|
||||
super().dump(context, SizedStream(stream, limit), value)
|
||||
|
@ -82,6 +82,135 @@ class Sized(G[T], Wrapper[T]):
|
|||
return context.peek(self.limit)
|
||||
|
||||
|
||||
class TerminatedStream:
|
||||
def __init__(self, stream: Stream, terminator: bytes, included: bool, blocksize: int = 8192) -> None:
|
||||
self._stream = stream
|
||||
self._terminator = terminator
|
||||
self._included = included
|
||||
self._end_pos = None
|
||||
self._blocksize = blocksize
|
||||
|
||||
def read(self, n: int = -1, *, bits: bool = False) -> U[bytes, int]:
|
||||
if bits:
|
||||
raise ValueError('terminated streams can not use bit-level reads')
|
||||
if self._end_pos is not None:
|
||||
if self._stream.tell() >= self._end_pos:
|
||||
raise EOFError
|
||||
remaining = self._end_pos - self._stream.tell()
|
||||
if n < 0:
|
||||
n = remaining
|
||||
else:
|
||||
n = min(n, remaining)
|
||||
|
||||
value = b''
|
||||
while n < 0 or len(value) < n:
|
||||
try:
|
||||
p = self._stream.tell()
|
||||
v = self._stream.read(n if n >= 0 else self._blocksize)
|
||||
except EOFError:
|
||||
if n < 0:
|
||||
self._stream.seek(p, os.SEEK_SET)
|
||||
v = self._stream.read(-1)
|
||||
else:
|
||||
raise
|
||||
if not v:
|
||||
break
|
||||
value += v
|
||||
|
||||
# find full terminator or start of terminator at the end
|
||||
if self._end_pos is None:
|
||||
termpos = find_overlap(value, self._terminator, len(value) - len(v))
|
||||
if termpos is not None:
|
||||
# need to read more data?
|
||||
termrem = len(self._terminator) - (len(value) - termpos)
|
||||
|
||||
if termrem > 0:
|
||||
p = self._stream.tell()
|
||||
try:
|
||||
value += self._stream.read(termrem)
|
||||
except EOFError:
|
||||
pass
|
||||
if value.endswith(self._terminator):
|
||||
termrem = 0
|
||||
else:
|
||||
self._stream.seek(p, os.SEEK_SET)
|
||||
|
||||
if termrem <= 0:
|
||||
# terminator found, reset overread data
|
||||
self._stream.seek(-(len(value) - (termpos + len(self._terminator))), os.SEEK_CUR)
|
||||
self._end_pos = self._stream.tell() - len(self._terminator)
|
||||
if self._included:
|
||||
termpos += len(self._terminator)
|
||||
self._end_pos += len(self._terminator)
|
||||
value = value[:termpos]
|
||||
break
|
||||
|
||||
if n > 0 and len(value) != n:
|
||||
raise EOFError
|
||||
|
||||
return value
|
||||
|
||||
def seek(self, pos: U[int, float], whence: int = os.SEEK_SET) -> None:
|
||||
if self._end_pos is not None:
|
||||
if whence == os.SEEK_SET:
|
||||
pos = min(pos, self._end_pos)
|
||||
elif whence == os.SEEK_CUR:
|
||||
pos = min(pos + self._stream.tell(), self._end_pos)
|
||||
elif whence == os.SEEK_END:
|
||||
pos += self._end_pos
|
||||
whence = os.SEEK_SET
|
||||
return self._stream.seek(pos, whence)
|
||||
|
||||
def tell(self) -> U[int, float]:
|
||||
pos = self._stream.tell()
|
||||
if self._end_pos is not None:
|
||||
pos = min(pos, self._end_pos)
|
||||
return pos
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
return getattr(self._stream, attr)
|
||||
|
||||
class Terminated(G[T], Wrapper[T]):
|
||||
def __init__(self, child: Type[T], terminator: U[bytes, PossibleDynamic], required: U[bool, PossibleDynamic] = True, included: U[bool, PossibleDynamic] = False, blocksize: int = 8192) -> None:
|
||||
super().__init__(child)
|
||||
self.terminator = terminator
|
||||
self.required = required
|
||||
self.included = included
|
||||
self.blocksize = blocksize
|
||||
|
||||
def parse(self, context: Context, stream: Stream) -> T:
|
||||
terminator = context.get(self.terminator)
|
||||
required = context.get(self.required)
|
||||
included = context.get(self.included)
|
||||
tstream = TerminatedStream(stream, terminator, included, blocksize=self.blocksize)
|
||||
value = super().parse(context, tstream)
|
||||
if required and tstream._end_pos is None:
|
||||
raise IOError(f'terminator {terminator} not found in stream')
|
||||
return value
|
||||
|
||||
def dump(self, context: Context, stream: Stream, value: T) -> None:
|
||||
terminator = context.get(self.terminator)
|
||||
required = context.get(self.required)
|
||||
included = context.get(self.included)
|
||||
tstream = TerminatedStream(stream, terminator, included, blocksize=self.blocksize)
|
||||
super().dump(context, tstream, value)
|
||||
if required and not included:
|
||||
stream.write(terminator)
|
||||
|
||||
def sizeof(self, context: Context, value: O[T]) -> O[Pos]:
|
||||
terminator = context.peek(self.terminator)
|
||||
required = context.peek(self.required)
|
||||
included = context.peek(self.included)
|
||||
if not required:
|
||||
return None
|
||||
size = super().sizeof(context, value)
|
||||
if size is None:
|
||||
return None
|
||||
if not included:
|
||||
size = add_sizes(size, context.to_size(len(terminator)))
|
||||
return size
|
||||
|
||||
|
||||
class Ref(G[T], Wrapper[T]):
|
||||
def __init__(self, child: Type[T], pos: U[PossibleDynamic, Pos], whence: int = os.SEEK_SET, segment: O[Segment] = None) -> None:
|
||||
super().__init__(child)
|
||||
|
|
|
@ -2,57 +2,80 @@ import enum
|
|||
from typing import Optional as O, Union as U
|
||||
from ..core.base import PossibleDynamic as D, Type, Context
|
||||
from ..core.io import Stream
|
||||
from ..core.util import stretch
|
||||
from .num import uint8
|
||||
from .io import Terminated
|
||||
|
||||
|
||||
class StrType(enum.Enum):
|
||||
Raw = enum.auto()
|
||||
ZeroTerminated = C = enum.auto()
|
||||
LengthPrefixed = Pascal = enum.auto()
|
||||
|
||||
class Str(Type[str]):
|
||||
def __init__(self, type: U[D, StrType], length: U[D, O[int]] = None, encoding: U[D, str] = 'utf-8', char_size: U[D, int] = 1, length_type: U[D, Type] = uint8, terminator: U[D, bytes] = b'\x00') -> None:
|
||||
self.type = type
|
||||
def __init__(self, length: U[D, O[int]] = None, encoding: U[D, str] = 'utf-8', char_size: U[D, int] = 1) -> None:
|
||||
self.length = length
|
||||
self.encoding = encoding
|
||||
self.char_size = char_size
|
||||
self.length_type = length_type
|
||||
self.terminator = terminator
|
||||
|
||||
def parse(self, context: Context, stream: Stream) -> str:
|
||||
type = context.get(self.type)
|
||||
length = context.get(self.length)
|
||||
encoding = context.get(self.encoding)
|
||||
char_size = context.get(self.char_size)
|
||||
|
||||
if type == StrType.Raw:
|
||||
if length is None:
|
||||
data = stream.read()
|
||||
else:
|
||||
data = stream.read(length * char_size)
|
||||
elif type == StrType.C:
|
||||
terminator = context.get(self.terminator)
|
||||
if len(terminator) != char_size:
|
||||
terminator = stretch(terminator, char_size)
|
||||
data = b''
|
||||
while True:
|
||||
d = stream.read(char_size)
|
||||
if d == terminator:
|
||||
break
|
||||
data += d
|
||||
if length is not None and len(data) >= length * char_size:
|
||||
break
|
||||
elif type == StrType.Pascal:
|
||||
length_type = context.get(self.length_type)
|
||||
plength = context.parse(length_type, stream)
|
||||
if length is None:
|
||||
data = stream.read()
|
||||
else:
|
||||
data = stream.read(length * char_size)
|
||||
|
||||
return data.decode(encoding)
|
||||
|
||||
def dump(self, context: Context, stream: Stream, value: str) -> None:
|
||||
pass
|
||||
encoding = context.get(self.encoding)
|
||||
char_size = context.get(self.char_size)
|
||||
|
||||
bs = value.encode(encoding)
|
||||
length = len(bs) // char_size
|
||||
stream.write(bs)
|
||||
|
||||
context.put(self.length, length)
|
||||
|
||||
def default(self, context: Context) -> str:
|
||||
return ''
|
||||
|
||||
def sizeof(self, context: Context, value: O[str]) -> O[int]:
|
||||
return 0
|
||||
if value is not None:
|
||||
return len(value.encode(context.peek(self.encoding)))
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.length is not None:
|
||||
length = f'({self.length})'
|
||||
else:
|
||||
length = ''
|
||||
if self.encoding != 'utf-8':
|
||||
encoding = f'.{self.encoding}'
|
||||
else:
|
||||
encoding = ''
|
||||
return f'str{self.encoding}{length}'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{__name__}.Str(length={self.length!r}, encoding={self.encoding!r}, char_size={self.char_size!r})'
|
||||
|
||||
|
||||
class CStr(Terminated[str]):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
terminator = '\x00'.encode(kwargs.get('encoding', 'utf-8'))
|
||||
super().__init__(Str(*args, **kwargs), terminator, required=True, blocksize=16)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.child.encoding == 'utf-8':
|
||||
return 'cstr'
|
||||
if self.child.encoding == 'utf-16le':
|
||||
return 'wcstr'
|
||||
if self.child.encoding == 'sjis':
|
||||
return 'jcstr'
|
||||
return str(self.child)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{__name__}.CStr({self.child!r})>'
|
||||
|
||||
cstr = \
|
||||
utf8cstr = CStr()
|
||||
wcstr = \
|
||||
utf16cstr = CStr(encoding='utf-16le', char_size=2)
|
||||
jcstr = \
|
||||
sjiscstr = CStr(encoding='sjis')
|
||||
|
|
Loading…
Reference in New Issue