Compare commits

...

2 Commits

Author SHA1 Message Date
Shiz dcdb37b0f4 union: fix dumping semantics 2021-08-08 23:00:18 +02:00
Shiz bfb2a17b71 struct: fix up and expose union type 2021-08-08 22:35:29 +02:00
2 changed files with 37 additions and 5 deletions

View File

@ -18,7 +18,7 @@ from .types.num import (
)
from .types.str import Str, CStr, cstr, wcstr, utf8cstr, utf16cstr
from .types.seq import Arr, Tuple
from .types.struct import StructType, Struct
from .types.struct import StructType, Struct, Union
from .types.transforms import Default, Transform, Mapped, Enum, Check, Fixed
from .types.control import Switch, If
from .types.io import Sized, Terminated, Ref, AlignTo, AlignedTo

View File

@ -11,7 +11,7 @@ from contextlib import contextmanager
import sx
from ..core import to_type
from ..core.base import Context, Type, PathElement
from ..core.io import Stream, Pos, add_sizes
from ..core.io import Stream, Pos, add_sizes, max_sizes
from ..core.util import indent, format_value, get_annot_locations
from ..core.meta import Generic, TypeSource
from ..core.expr import ProxyExpr
@ -110,7 +110,16 @@ class StructType(G[T], Type[T]):
pos = stream.tell()
with self.enter():
for name, type in self.fields.items():
if self.union:
if value._sx_lastset_:
fields = [value._sx_lastset_]
else:
fields = next(self.fields)
else:
fields = list(self.fields)
for name in fields:
type = self.fields[name]
with context.enter(name, type):
if self.union:
stream.seek(pos, os.SEEK_SET)
@ -151,7 +160,13 @@ class StructType(G[T], Type[T]):
def sizeof(self, context: Context, value: O[T]) -> O[Mapping[str, int]]:
with self.enter():
sizes = self.get_sizes(context, value, None)
return add_sizes(*sizes) if sizes else 0
if sizes:
if self.union:
return max_sizes(*sizes)
else:
return add_sizes(*sizes)
else:
return 0
def offsetof(self, context: Context, path: Sequence[PathElement], value: O[T]) -> O[int]:
if not path:
@ -166,7 +181,10 @@ class StructType(G[T], Type[T]):
child = self.fields[field]
with self.enter():
sizes = self.get_sizes(context, value, field)
if self.union:
sizes = []
else:
sizes = self.get_sizes(context, value, field)
if path:
with context.enter(field, child):
sizes.append(context.offsetof(child, path, getattr(value, field) if value is not None else None))
@ -223,6 +241,7 @@ class Struct:
if parent:
generics = parent.generics + generics
bound = parent.bound + bound
kwargs.setdefault('union', parent.union)
# Get all annotations
annots = {}
@ -315,3 +334,16 @@ class Struct:
def __repr__(self) -> str:
return self._format_(repr)
class Union(Struct, union=True, inject=False):
_sx_lastset_ = ''
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
if kwargs:
super().__setattr__('_sx_lastset_', list(kwargs)[-1])
def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)
super().__setattr__('_sx_lastset_', name)