struct: fix up and expose union type

This commit is contained in:
Shiz 2021-08-08 22:35:29 +02:00
parent c3d98171f3
commit 0c653fc3c3
2 changed files with 18 additions and 4 deletions

View File

@ -18,7 +18,7 @@ from .types.num import (
)
from .types.str import Str, CStr, cstr, wcstr, utf8cstr, utf16cstr
from .types.seq import Arr, Tuple
from .types.struct import StructType, Struct
from .types.struct import StructType, Struct, Union
from .types.transforms import Default, Transform, Mapped, Enum, Check, Fixed
from .types.control import Switch, If
from .types.io import Sized, Terminated, Ref, AlignTo, AlignedTo

View File

@ -11,7 +11,7 @@ from contextlib import contextmanager
import sx
from ..core import to_type
from ..core.base import Context, Type, PathElement
from ..core.io import Stream, Pos, add_sizes
from ..core.io import Stream, Pos, add_sizes, max_sizes
from ..core.util import indent, format_value, get_annot_locations
from ..core.meta import Generic, TypeSource
from ..core.expr import ProxyExpr
@ -151,7 +151,13 @@ class StructType(G[T], Type[T]):
def sizeof(self, context: Context, value: O[T]) -> O[Mapping[str, int]]:
with self.enter():
sizes = self.get_sizes(context, value, None)
return add_sizes(*sizes) if sizes else 0
if sizes:
if self.union:
return max_sizes(*sizes)
else:
return add_sizes(*sizes)
else:
return 0
def offsetof(self, context: Context, path: Sequence[PathElement], value: O[T]) -> O[int]:
if not path:
@ -166,7 +172,10 @@ class StructType(G[T], Type[T]):
child = self.fields[field]
with self.enter():
sizes = self.get_sizes(context, value, field)
if self.union:
sizes = []
else:
sizes = self.get_sizes(context, value, field)
if path:
with context.enter(field, child):
sizes.append(context.offsetof(child, path, getattr(value, field) if value is not None else None))
@ -223,6 +232,7 @@ class Struct:
if parent:
generics = parent.generics + generics
bound = parent.bound + bound
kwargs.setdefault('union', parent.union)
# Get all annotations
annots = {}
@ -315,3 +325,7 @@ class Struct:
def __repr__(self) -> str:
return self._format_(repr)
class Union(Struct, union=True, inject=False):
pass