diff --git a/hytek/hy3/file.py b/hytek/hy3/file.py index 15c5869..feb11ed 100644 --- a/hytek/hy3/file.py +++ b/hytek/hy3/file.py @@ -1,7 +1,7 @@ import zipfile from logging import warning -from .records import Hy3Record, RECORD_TYPES +from .records import Hy3Record, Hy3RecordGroup, RECORD_TYPES __all__ = ("Hy3File",) @@ -32,7 +32,6 @@ def read(self) -> None: if self.file is None: self.file = open(self.filepath, "rb") parents = [] - last_record = None records = [] data = self.file.read() for i, line in enumerate(data.splitlines()): @@ -46,33 +45,31 @@ def read(self) -> None: records.append(record) continue - if last_record: - # Hy3RecordGroup - if last_record.record_type[0] == record.record_type[0] and last_record.record_type[1] < record.record_type[1]: - last_record = last_record.add_to_group(record) - if parents: - parent = parents[-2] - parent.children[-1] = last_record - parents[-1] = last_record - else: - parents.append(last_record) - records.append(last_record) - continue - last_record = record - + done = False while parents: parent = parents[-1] - if parent.record_type >= record.record_type: - parents.pop() - else: + if parent._child_allowed(record): # child record + parent.append_child(record) + parents.append(record) + done = True break - if parents: - parent = parents[-1] - parent.append_child(record) - parents.append(record) - else: + elif parent._group_allowed(record): # group record? + group = parent.add_to_group(record) + parents[-1] = group + # get the group's parent + # should be the record immediately before + # the group record + group_parent = parents[-2] + # update children of the group's parent + group_parent.children[-1] = group + done = True + break + parents.pop() + if not done: records.append(record) parents.append(record) + pass + self.records = records self.file.close() self.file = None diff --git a/hytek/hy3/records.py b/hytek/hy3/records.py index 3cbb029..a240e34 100644 --- a/hytek/hy3/records.py +++ b/hytek/hy3/records.py @@ -30,6 +30,16 @@ ) +ALLOWED_CHILDREN = { + "A": ["B"], + "B": ["C"], + "C": ["D", "F"], + "D": ["E"], + "E": ["G", "H"], + "F": ["G", "H"], +} + + class Field: def __init__(self, *, name: str, length: int, type: Callable[[str], Any] = str): self.name = name @@ -46,9 +56,12 @@ def __init__(self, *records: Type[Hy3Record]): def __repr__(self): return f"" + def _child_allowed(self, child: Hy3Record | Hy3RecordGroup) -> bool: + return child.record_type[0] in ALLOWED_CHILDREN.get(self.record_type[0], []) + def append_child(self, child: Hy3Record | Hy3RecordGroup) -> None: - if child.record_type < self.group_type: - raise Exception("Child record type is less than parent") + if not self._child_allowed(child): + raise Exception("Child record type is not allowed in this parent.") self.children.append(child) def remove_child(self, child: Hy3Record | Hy3RecordGroup) -> None: @@ -74,9 +87,16 @@ def to_json(self) -> dict: def record_type(self) -> str: return self.records[-1].record_type - def add_to_group(self, record: Type[Hy3Record]) -> Self: + def _group_allowed(self, record: Type[Hy3Record]) -> bool: if record.record_type[0] != self.group_type: - raise Exception("Record type does not match group type") + return False + if int(record.record_type[1]) <= int(self.records[-1].record_type[1]): + return False + return True + + def add_to_group(self, record: Type[Hy3Record]) -> Self: + if not self._group_allowed(record): + raise Exception("Cannot group records") self.records.append(record) return self @@ -108,9 +128,12 @@ def __getitem__(self, item: str) -> Any: def __setitem__(self, item: str, value: Any) -> None: self.data[item] = value + def _child_allowed(self, child: Type[Hy3Record | Hy3RecordGroup]) -> bool: + return child.record_type[0] in ALLOWED_CHILDREN.get(self.record_type[0], []) + def append_child(self, child: Type[Hy3RecordGroup | Hy3Record]) -> None: - if child.record_type < self.record_type: - raise Exception("Child record type is less than parent") + if not self._child_allowed(child): + raise Exception("Child record type is not allowed in this parent.") self.children.append(child) def remove_child(self, child: Type['Hy3Record']) -> None: @@ -124,9 +147,16 @@ def print_tree(self, depth: int = 0) -> None: for child in self.children: child.print_tree(depth + 1) - def add_to_group(self, record: Type[Hy3Record]) -> Hy3RecordGroup: + def _group_allowed(self, record: Type[Hy3Record]) -> bool: if record.record_type[0] != self.record_type[0]: - raise Exception("Record type does not match group type") + return False + if int(record.record_type[1]) <= int(self.record_type[1]): + return False + return True + + def add_to_group(self, record: Type[Hy3Record]) -> Hy3RecordGroup: + if not self._group_allowed(record): + raise Exception("Cannot group records") group = Hy3RecordGroup(self, record) return group