Skip to content

Commit

Permalink
revert config to previous version when remove a config (#836)
Browse files Browse the repository at this point in the history
* revert config to previous version when remove a config

* update

* update

* update

---------

Co-authored-by: yangfan100 <[email protected]>
  • Loading branch information
echoyang7 and yangfan100 authored Apr 3, 2024
1 parent 05f22b4 commit a19a25a
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions lyrebird/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def update_base_config(self):
self.write_config()

def add_config(self, config_dict: dict, rank=0, type='', override_same_type=False, level=1, apply_now=False) -> None:
config = Config(config_dict)
config = Config(config_dict, level, self.config)
config.type = type

if override_same_type:
Expand Down Expand Up @@ -184,20 +184,20 @@ def add_each_config_item(self, config):
continue
CONFIG_FUNC_MAP[key].add(value)

def remove_config(self, config, type='', level=-1, apply_now=False):
def remove_config(self, config=None, type='', apply_now=False):
remove_config = None
recover_config = None
for c in self.config_list[::-1]:
if remove_config:
recover_config = c
if c.type == type and (config is None or config == c.config):
remove_config = c
break

if c.type == type:
remove_config = c
if remove_config is None:
logger.error("No matching config found in config_list.")
return

self.unmerge_config(self.config, remove_config.config, level=level, apply_now=apply_now)
self.unmerge_config(self.config, remove_config.config, level=remove_config.level, apply_now=apply_now)

self.merge_config(self.config, recover_config.config, level=level, apply_now=apply_now)
self.merge_config(self.config, remove_config.previous_config, level=remove_config.level, apply_now=apply_now)

# todo handle the config added after remove_config

Expand Down Expand Up @@ -260,10 +260,31 @@ def write_personal_config(self):


class Config:
def __init__(self, config):
def __init__(self, config, level=1, current_config={}):
self.rank = 0
self.type = ''
self.config = config
self.level = level
self.previous_config = self._get_previuos_config(current_config, level)

def _get_previuos_config(self, current_config, level):
if not current_config:
return {}
previous_config = {}
for key in self.config.keys():
self._get_previous_config_generator(key, self.config, current_config, previous_config, level)
return previous_config

def _get_previous_config_generator(self, key, config, current_config, previous_config, level):
if not current_config or key not in current_config:
return

if level != 1 and isinstance(config[key], dict):
previous_config[key] = {}
for key_child in config[key].keys():
self._get_previous_config_generator(key_child, config[key], current_config.get(key), previous_config[key], level-1)
else:
previous_config[key] = deepcopy(current_config.get(key))


class ConfigException(Exception):
Expand Down

0 comments on commit a19a25a

Please sign in to comment.