construct/tlv: Pass optional 'context' into construct decoder/encoder

The context is some opaque dictionary that can be used by the
constructs; let's allow the caller of parse_construct,  from_bytes,
from_tlv to specify it.

Also, when decoding a TLV_IE_Collection, pass the decode results of
existing siblings via the construct.

Change-Id: I021016aaa09cddf9d36521c1a54b468ec49ff54d
This commit is contained in:
Harald Welte
2023-12-17 10:07:01 +01:00
committed by laforge
parent 188869568a
commit caef0df663
2 changed files with 36 additions and 31 deletions

View File

@@ -200,13 +200,16 @@ def normalize_construct(c):
return r return r
def parse_construct(c, raw_bin_data: bytes, length: typing.Optional[int] = None, exclude_prefix: str = '_'): def parse_construct(c, raw_bin_data: bytes, length: typing.Optional[int] = None, exclude_prefix: str = '_', context: dict = {}):
"""Helper function to wrap around normalize_construct() and filter_dict().""" """Helper function to wrap around normalize_construct() and filter_dict()."""
if not length: if not length:
length = len(raw_bin_data) length = len(raw_bin_data)
parsed = c.parse(raw_bin_data, total_len=length) parsed = c.parse(raw_bin_data, total_len=length, **context)
return normalize_construct(parsed) return normalize_construct(parsed)
def build_construct(c, decoded_data, context: dict = {}):
"""Helper function to handle total_len."""
return c.build(decoded_data, total_len=None, **context)
# here we collect some shared / common definitions of data types # here we collect some shared / common definitions of data types
LV = Prefixed(Int8ub, HexAdapter(GreedyBytes)) LV = Prefixed(Int8ub, HexAdapter(GreedyBytes))

View File

@@ -26,7 +26,7 @@ from pySim.utils import comprehensiontlv_encode_tag, comprehensiontlv_parse_tag
from pySim.utils import bertlv_parse_one, comprehensiontlv_parse_one from pySim.utils import bertlv_parse_one, comprehensiontlv_parse_one
from pySim.utils import bertlv_parse_tag_raw, comprehensiontlv_parse_tag_raw from pySim.utils import bertlv_parse_tag_raw, comprehensiontlv_parse_tag_raw
from pySim.construct import parse_construct, LV, HexAdapter, BcdAdapter, BitsRFU, GsmStringAdapter from pySim.construct import build_construct, parse_construct, LV, HexAdapter, BcdAdapter, BitsRFU, GsmStringAdapter
from pySim.exceptions import * from pySim.exceptions import *
import inspect import inspect
@@ -84,15 +84,15 @@ class Transcodable(abc.ABC):
self.decoded = None self.decoded = None
self._construct = None self._construct = None
def to_bytes(self) -> bytes: def to_bytes(self, context: dict = {}) -> bytes:
"""Convert from internal representation to binary bytes. Store the binary result """Convert from internal representation to binary bytes. Store the binary result
in the internal state and return it.""" in the internal state and return it."""
if self.decoded == None: if self.decoded == None:
do = b'' do = b''
elif self._construct: elif self._construct:
do = self._construct.build(self.decoded, total_len=None) do = build_construct(self._construct, self.decoded, context)
elif self.__class__._construct: elif self.__class__._construct:
do = self.__class__._construct.build(self.decoded, total_len=None) do = build_construct(self.__class__._construct, self.decoded, context)
else: else:
do = self._to_bytes() do = self._to_bytes()
self.encoded = do self.encoded = do
@@ -102,16 +102,16 @@ class Transcodable(abc.ABC):
def _to_bytes(self): def _to_bytes(self):
raise NotImplementedError('%s._to_bytes' % type(self).__name__) raise NotImplementedError('%s._to_bytes' % type(self).__name__)
def from_bytes(self, do: bytes): def from_bytes(self, do: bytes, context: dict = {}):
"""Convert from binary bytes to internal representation. Store the decoded result """Convert from binary bytes to internal representation. Store the decoded result
in the internal state and return it.""" in the internal state and return it."""
self.encoded = do self.encoded = do
if self.encoded == b'': if self.encoded == b'':
self.decoded = None self.decoded = None
elif self._construct: elif self._construct:
self.decoded = parse_construct(self._construct, do) self.decoded = parse_construct(self._construct, do, context=context)
elif self.__class__._construct: elif self.__class__._construct:
self.decoded = parse_construct(self.__class__._construct, do) self.decoded = parse_construct(self.__class__._construct, do, context=context)
else: else:
self.decoded = self._from_bytes(do) self.decoded = self._from_bytes(do)
return self.decoded return self.decoded
@@ -174,27 +174,27 @@ class IE(Transcodable, metaclass=TlvMeta):
return False return False
@abc.abstractmethod @abc.abstractmethod
def to_ie(self) -> bytes: def to_ie(self, context: dict = {}) -> bytes:
"""Convert the internal representation to entire IE including IE header.""" """Convert the internal representation to entire IE including IE header."""
def to_bytes(self) -> bytes: def to_bytes(self, context: dict = {}) -> bytes:
"""Convert the internal representation *of the value part* to binary bytes.""" """Convert the internal representation *of the value part* to binary bytes."""
if self.is_constructed(): if self.is_constructed():
# concatenate the encoded IE of all children to form the value part # concatenate the encoded IE of all children to form the value part
out = b'' out = b''
for c in self.children: for c in self.children:
out += c.to_ie() out += c.to_ie(context=context)
return out return out
else: else:
return super().to_bytes() return super().to_bytes(context=context)
def from_bytes(self, do: bytes): def from_bytes(self, do: bytes, context: dict = {}):
"""Parse *the value part* from binary bytes to internal representation.""" """Parse *the value part* from binary bytes to internal representation."""
if self.nested_collection: if self.nested_collection:
self.children = self.nested_collection.from_bytes(do) self.children = self.nested_collection.from_bytes(do, context=context)
else: else:
self.children = [] self.children = []
return super().from_bytes(do) return super().from_bytes(do, context=context)
class TLV_IE(IE): class TLV_IE(IE):
@@ -226,15 +226,15 @@ class TLV_IE(IE):
"""Encode the length part assuming a certain binary value. Must be provided by """Encode the length part assuming a certain binary value. Must be provided by
derived (TLV format specific) class.""" derived (TLV format specific) class."""
def to_ie(self): def to_ie(self, context: dict = {}):
return self.to_tlv() return self.to_tlv(context=context)
def to_tlv(self): def to_tlv(self, context: dict = {}):
"""Convert the internal representation to binary TLV bytes.""" """Convert the internal representation to binary TLV bytes."""
val = self.to_bytes() val = self.to_bytes(context=context)
return self._encode_tag() + self._encode_len(val) + val return self._encode_tag() + self._encode_len(val) + val
def from_tlv(self, do: bytes): def from_tlv(self, do: bytes, context: dict = {}):
if len(do) == 0: if len(do) == 0:
return {}, b'' return {}, b''
(rawtag, remainder) = self.__class__._parse_tag_raw(do) (rawtag, remainder) = self.__class__._parse_tag_raw(do)
@@ -248,7 +248,7 @@ class TLV_IE(IE):
else: else:
value = do value = do
remainder = b'' remainder = b''
dec = self.from_bytes(value) dec = self.from_bytes(value, context=context)
return dec, remainder return dec, remainder
@@ -343,7 +343,7 @@ class TLV_IE_Collection(metaclass=TlvCollectionMeta):
else: else:
raise TypeError raise TypeError
def from_bytes(self, binary: bytes) -> List[TLV_IE]: def from_bytes(self, binary: bytes, context: dict = {}) -> List[TLV_IE]:
"""Create a list of TLV_IEs from the collection based on binary input data. """Create a list of TLV_IEs from the collection based on binary input data.
Args: Args:
binary : binary bytes of encoded data binary : binary bytes of encoded data
@@ -357,6 +357,7 @@ class TLV_IE_Collection(metaclass=TlvCollectionMeta):
first = next(iter(self.members_by_tag.values())) first = next(iter(self.members_by_tag.values()))
# iterate until no binary trailer is left # iterate until no binary trailer is left
while len(remainder): while len(remainder):
context['siblings'] = res
# obtain the tag at the start of the remainder # obtain the tag at the start of the remainder
tag, r = first._parse_tag_raw(remainder) tag, r = first._parse_tag_raw(remainder)
if tag == None: if tag == None:
@@ -365,7 +366,7 @@ class TLV_IE_Collection(metaclass=TlvCollectionMeta):
cls = self.members_by_tag[tag] cls = self.members_by_tag[tag]
# create an instance and parse accordingly # create an instance and parse accordingly
inst = cls() inst = cls()
dec, remainder = inst.from_tlv(remainder) dec, remainder = inst.from_tlv(remainder, context=context)
res.append(inst) res.append(inst)
else: else:
# unknown tag; create the related class on-the-fly using the same base class # unknown tag; create the related class on-the-fly using the same base class
@@ -376,7 +377,7 @@ class TLV_IE_Collection(metaclass=TlvCollectionMeta):
cls._to_bytes = lambda s: bytes.fromhex(s.decoded['raw']) cls._to_bytes = lambda s: bytes.fromhex(s.decoded['raw'])
# create an instance and parse accordingly # create an instance and parse accordingly
inst = cls() inst = cls()
dec, remainder = inst.from_tlv(remainder) dec, remainder = inst.from_tlv(remainder, context=context)
res.append(inst) res.append(inst)
self.children = res self.children = res
return res return res
@@ -413,17 +414,18 @@ class TLV_IE_Collection(metaclass=TlvCollectionMeta):
# self.__class__.__name__, but that is usually some meaningless auto-generated collection name. # self.__class__.__name__, but that is usually some meaningless auto-generated collection name.
return [x.to_dict() for x in self.children] return [x.to_dict() for x in self.children]
def to_bytes(self): def to_bytes(self, context: dict = {}):
out = b'' out = b''
context['siblings'] = self.children
for c in self.children: for c in self.children:
out += c.to_tlv() out += c.to_tlv(context=context)
return out return out
def from_tlv(self, do): def from_tlv(self, do, context: dict = {}):
return self.from_bytes(do) return self.from_bytes(do, context=context)
def to_tlv(self): def to_tlv(self, context: dict = {}):
return self.to_bytes() return self.to_bytes(context=context)
def flatten_dict_lists(inp): def flatten_dict_lists(inp):