diff --git a/.gitignore b/.gitignore index 321634843..92e7c7929 100644 --- a/.gitignore +++ b/.gitignore @@ -121,3 +121,5 @@ venv.bak/ # Srun *.out batchscript-* +work_dir +mmdeploy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 491ddaa78..cd73ef928 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,4 +68,5 @@ repos: ^test | ^docs | ^configs + | ^.*/configs* ) diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index 38abd2fcc..3de024635 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -66,6 +66,7 @@ def __init__(self, dict, Type[MutableChannelUnit]] = SequentialMutableChannelUnit, parse_cfg: Dict = dict( + _scope_='mmrazor', type='ChannelAnalyzer', demo_input=(1, 3, 224, 224), tracer_type='BackwardTracer'), diff --git a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py index 75a1db293..63c60e8cb 100644 --- a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py @@ -6,6 +6,7 @@ from mmrazor.registry import TASK_UTILS from mmrazor.utils import get_placeholder +from ...algorithms.base import BaseAlgorithm from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput, DefaultMMDemoInput, DefaultMMDetDemoInput, DefaultMMPoseDemoInput, DefaultMMRotateDemoInput, @@ -70,8 +71,12 @@ def get_default_demo_input_class(model, scope): def defaul_demo_inputs(model, input_shape, training=False, scope=None): """Get demo input according to a model and scope.""" - demo_input = get_default_demo_input_class(model, scope) - return demo_input().get_data(model, input_shape, training) + if isinstance(model, BaseAlgorithm): + return defaul_demo_inputs(model.architecture, input_shape, training, + scope) + else: + demo_input = get_default_demo_input_class(model, scope) + return demo_input().get_data(model, input_shape, training) @TASK_UTILS.register_module() diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index 8664f3a2d..ab0dfb4b5 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -51,7 +51,9 @@ def _get_data(self, model, input_shape=None, training=None): return data def _get_mm_data(self, model, input_shape, training=False): - return {'inputs': torch.rand(input_shape), 'data_samples': None} + data = {'inputs': torch.rand(input_shape), 'data_samples': None} + data = model.data_preprocessor(data, training) + return data @TASK_UTILS.register_module() @@ -132,7 +134,7 @@ def _get_mm_data(self, model, input_shape, training=False): from mmpose.models import TopdownPoseEstimator from .mmpose_demo_input import demo_mmpose_inputs - assert isinstance(model, TopdownPoseEstimator) + assert isinstance(model, TopdownPoseEstimator), f'{type(model)}' data = demo_mmpose_inputs(model, input_shape) return data diff --git a/projects/cores/__init__.py b/projects/cores/__init__.py new file mode 100644 index 000000000..7aafea3e9 --- /dev/null +++ b/projects/cores/__init__.py @@ -0,0 +1,3 @@ +from .counters import * # noqa +from .hooks import * # noqa +from .models import * # noqa diff --git a/projects/cores/counters.py b/projects/cores/counters.py new file mode 100644 index 000000000..42f7b921f --- /dev/null +++ 
b/projects/cores/counters.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn + +from mmrazor.models.task_modules.estimators.counters import (Conv2dCounter, + LinearCounter) +from mmrazor.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class DynamicConv2dCounter(Conv2dCounter): + """Flop counter for DynamicConv2d.""" + + @staticmethod + def add_count_hook(module: nn.Conv2d, input: Tuple[torch.Tensor], + output: torch.Tensor) -> None: + """Count the flops and params of a DynamicConv2d. + + Args: + module (nn.Conv2d): A Conv2d module. + input (Tuple[torch.Tensor]): Input of this module. + output (torch.Tensor): Output of this module. + """ + batch_size = input[0].shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(module.kernel_size) + + out_channels = module.mutable_attrs['out_channels'].activated_channels + in_channels = module.mutable_attrs['in_channels'].activated_channels + + groups = module.groups + + filters_per_channel = out_channels / groups + conv_per_position_flops = int( + np.prod(kernel_dims)) * in_channels * filters_per_channel + + active_elements_count = batch_size * int(np.prod(output_dims)) + + overall_conv_flops = conv_per_position_flops * active_elements_count + overall_params = conv_per_position_flops + + bias_flops = 0 + if module.bias is not None: + bias_flops = out_channels * active_elements_count + overall_params += out_channels + + overall_flops = overall_conv_flops + bias_flops + + module.__flops__ += overall_flops + module.__params__ += int(overall_params) + + +@TASK_UTILS.register_module() +class DynamicLinearCounter(LinearCounter): + pass diff --git a/projects/cores/expandable_ops/__init__.py b/projects/cores/expandable_ops/__init__.py new file mode 100644 index 000000000..456fd0a54 --- /dev/null +++ b/projects/cores/expandable_ops/__init__.py @@ -0,0 +1 @@ +"""This module is used to expand the channels of a supernet.""" diff --git a/projects/cores/expandable_ops/ops.py b/projects/cores/expandable_ops/ops.py new file mode 100644 index 000000000..345743669 --- /dev/null +++ b/projects/cores/expandable_ops/ops.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn + +from mmrazor.models.architectures import dynamic_ops +from mmrazor.models.mutables import MutableChannelContainer + + +class ExpandableMixin: + """This mixin cooperates with dynamic ops. + + It defines interfaces to expand the channels of ops. We can get a wider + network than the original supernet with it. + """ + + def expand(self, zero=False): + """Expand the op. + + Args: + zero (bool, optional): whether to set new weights to zero. Defaults + to False. + """ + return self.get_expand_op( + self.expanded_in_channel, + self.expanded_out_channel, + zero=zero, + ) + + def get_expand_op(self, in_c, out_c, zero=False): + """Get an expanded op. + + Args: + in_c (int): New input channels. + out_c (int): New output channels. + zero (bool, optional): Whether to zero new weights. Defaults to + False.
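+ A minimal usage sketch (illustrative only; assumes ``conv`` is an ExpandableConv2d, defined below, whose mutable channels were already widened by a unit): + >>> wider = conv.expand()  # a plain, wider nn.Conv2d carrying the old weights + >>> wider_zero = conv.get_expand_op(in_c=8, out_c=16, zero=True)  # example channel numbers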
+ """ + pass + + @property + def _original_in_channel(self): + """Return original in channel.""" + raise NotImplementedError() + + @property + def _original_out_channel(self): + """Return original out channel.""" + + @property + def expanded_in_channel(self): + """Return expanded in channel number.""" + if self.in_mutable is not None: + return self.in_mutable.current_mask.numel() + else: + return self._original_in_channel + + @property + def expanded_out_channel(self): + """Return expanded out channel number.""" + if self.out_mutable is not None: + return self.out_mutable.current_mask.numel() + else: + return self._original_out_channel + + @property + def mutable_in_mask(self): + """Return the mutable in mask.""" + if self.in_mutable is not None: + return self.in_mutable.current_mask + else: + if hasattr(self, 'weight'): + return self.weight.new_ones([self.expanded_in_channel]) + else: + return torch.ones([self.expanded_in_channel]) + + @property + def mutable_out_mask(self): + """Return the mutable out mask.""" + if self.out_mutable is not None: + return self.out_mutable.current_mask + else: + if hasattr(self, 'weight'): + return self.weight.new_ones([self.expanded_out_channel]) + else: + return torch.ones([self.expanded_out_channel]) + + @property + def in_mutable(self) -> MutableChannelContainer: + """In channel mask.""" + return self.get_mutable_attr('in_channels') # type: ignore + + @property + def out_mutable(self) -> MutableChannelContainer: + """Out channel mask.""" + return self.get_mutable_attr('out_channels') # type: ignore + + def zero_weight_(self: nn.Module): + """Zero all weights.""" + for p in self.parameters(): + p.data.zero_() + + @torch.no_grad() + def expand_matrix(self, weight: torch.Tensor, old_weight: torch.Tensor): + """Expand weight matrix.""" + assert len(weight.shape) == 3 # out in c + assert len(old_weight.shape) == 3 # out in c + mask = self.mutable_out_mask.float().unsqueeze( + -1) * self.mutable_in_mask.float().unsqueeze(0) + mask = mask.unsqueeze(-1).expand(*weight.shape) + weight.data.masked_scatter_(mask.bool(), old_weight) + return weight + + @torch.no_grad() + def expand_vector(self, weight: torch.Tensor, old_weight: torch.Tensor): + """Expand weight vector.""" + assert len(weight.shape) == 2 # out c + assert len(old_weight.shape) == 2 # out c + mask = self.mutable_out_mask + mask = mask.unsqueeze(-1).expand(*weight.shape) + weight.data.masked_scatter_(mask.bool(), old_weight) + return weight + + @torch.no_grad() + def expand_bias(self, bias: torch.Tensor, old_bias: torch.Tensor): + """Expand bias.""" + assert len(bias.shape) == 1 # out c + assert len(old_bias.shape) == 1 # out c + return self.expand_vector(bias.unsqueeze(-1), + old_bias.unsqueeze(-1)).squeeze(1) + + +class ExpandableConv2d(dynamic_ops.DynamicConv2d, ExpandableMixin): + + @property + def _original_in_channel(self): + return self.in_channels + + @property + def _original_out_channel(self): + return self.out_channels + + def get_expand_op(self, in_c, out_c, zero=False): + module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride, + self.padding, self.dilation, self.groups, self.bias + is not None, self.padding_mode) + if zero: + ExpandableMixin.zero_weight_(module) + + weight = self.expand_matrix( + module.weight.flatten(2), self.weight.flatten(2)) + module.weight.data = weight.reshape(module.weight.shape) + if module.bias is not None and self.bias is not None: + bias = self.expand_vector( + module.bias.unsqueeze(-1), self.bias.unsqueeze(-1)) + module.bias.data = 
bias.reshape(module.bias.shape) + return module + + +class ExpandLinear(dynamic_ops.DynamicLinear, ExpandableMixin): + + @property + def _original_in_channel(self): + return self.in_features + + @property + def _original_out_channel(self): + return self.out_features + + def get_expand_op(self, in_c, out_c, zero=False): + module = nn.Linear(in_c, out_c, self.bias is not None) + if zero: + ExpandableMixin.zero_weight_(module) + + weight = self.expand_matrix( + module.weight.unsqueeze(-1), self.weight.unsqueeze(-1)) + module.weight.data = weight.reshape(module.weight.shape) + if module.bias is not None: + bias = self.expand_vector( + module.bias.unsqueeze(-1), self.bias.unsqueeze(-1)) + module.bias.data = bias.reshape(module.bias.shape) + return module + + +class ExpandableBatchNorm2d(dynamic_ops.DynamicBatchNorm2d, ExpandableMixin): + + @property + def _original_in_channel(self): + return self.num_features + + @property + def _original_out_channel(self): + return self.num_features + + def get_expand_op(self, in_c, out_c, zero=False): + assert in_c == out_c + module = nn.BatchNorm2d(in_c, self.eps, self.momentum, self.affine, + self.track_running_stats) + if zero: + ExpandableMixin.zero_weight_(module) + + if module.running_mean is not None: + module.running_mean.data = self.expand_bias( + module.running_mean, self.running_mean) + + if module.running_var is not None: + module.running_var.data = self.expand_bias(module.running_var, + self.running_var) + module.weight.data = self.expand_bias(module.weight, self.weight) + module.bias.data = self.expand_bias(module.bias, self.bias) + return module diff --git a/projects/cores/expandable_ops/unit.py b/projects/cores/expandable_ops/unit.py new file mode 100644 index 000000000..d38d76272 --- /dev/null +++ b/projects/cores/expandable_ops/unit.py @@ -0,0 +1,80 @@ +import copy + +import torch +import torch.nn as nn + +from mmrazor.models.mutables import (L1MutableChannelUnit, + MutableChannelContainer) +from mmrazor.models.mutators import ChannelMutator +from .ops import (ExpandableBatchNorm2d, ExpandableConv2d, ExpandableMixin, + ExpandLinear) + + +def expand_static_model(model: nn.Module, divisor): + """Expand the channels of a model. + + Args: + model (nn.Module): the model to be expanded. + divisor (_type_): the divisor to make the channels divisible. + + Returns: + nn.Module: an expanded model. 
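+ Example (illustrative; ``model`` is any supported static model): + >>> expand_static_model(model, divisor=8)  # every unit's channel number becomes divisible by 8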
+ """ + from projects.cores.expandable_ops.unit import (ExpandableUnit, + expand_dynamic_model) + state_dict = model.state_dict() + mutator = ChannelMutator[ExpandableUnit](channel_unit_cfg=ExpandableUnit) + mutator.prepare_from_supernet(model) + model.load_state_dict(state_dict) + for unit in mutator.mutable_units: + num = unit.current_choice + if num % divisor == 0: + continue + else: + num = (num // divisor + 1) * divisor + num = max(num, unit.num_channels) + unit.expand_to(num) + expand_dynamic_model(model, zero=True) + + mutator = ChannelMutator[ExpandableUnit](channel_unit_cfg=ExpandableUnit) + mutator.prepare_from_supernet(copy.deepcopy(model)) + structure = mutator.choice_template + return structure + + +def expand_dynamic_model(model: nn.Module, zero=False) -> None: + """Expand a dynamic model.""" + + def traverse_children(module: nn.Module) -> None: + for name, mutable in module.items(): + if isinstance(mutable, ExpandableMixin): + module[name] = mutable.expand(zero=zero) + if hasattr(mutable, '_modules'): + traverse_children(mutable._modules) + + if isinstance(model, ExpandableMixin): + raise RuntimeError('Root model can not be a dynamic op.') + + if hasattr(model, '_modules'): + traverse_children(model._modules) + + +class ExpandableUnit(L1MutableChannelUnit): + + def prepare_for_pruning(self, model: nn.Module): + self._replace_with_dynamic_ops( + model, { + nn.Conv2d: ExpandableConv2d, + nn.BatchNorm2d: ExpandableBatchNorm2d, + nn.Linear: ExpandLinear, + }) + self._register_channel_container(model, MutableChannelContainer) + self._register_mutable_channel(self.mutable_channel) + + def expand(self, num): + expand_mask = self.mutable_channel.mask.new_zeros([num]) + mask = torch.cat([self.mutable_channel.mask, expand_mask]) + self.mutable_channel.mask = mask + + def expand_to(self, num): + self.expand(num - self.num_channels) diff --git a/projects/cores/hooks/__init__.py b/projects/cores/hooks/__init__.py new file mode 100644 index 000000000..5966486c4 --- /dev/null +++ b/projects/cores/hooks/__init__.py @@ -0,0 +1,3 @@ +from .prune_hook import PruningStructureHook, ResourceInfoHook + +__all__ = ['PruningStructureHook', 'ResourceInfoHook'] diff --git a/projects/cores/hooks/prune_hook.py b/projects/cores/hooks/prune_hook.py new file mode 100644 index 000000000..0b39eae5f --- /dev/null +++ b/projects/cores/hooks/prune_hook.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.dist import master_only +from mmengine.hooks import Hook +from mmengine.runner import Runner, save_checkpoint + +from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput +from mmrazor.models.task_modules.estimators import ResourceEstimator +from mmrazor.registry import HOOKS, TASK_UTILS +from mmrazor.utils import print_log +from ..utils import RuntimeInfo, get_model_from_runner, is_pruning_algorithm + + +@HOOKS.register_module() +class PruningStructureHook(Hook): + """This hook is used to display the structure information during pruning. + + Args: + by_epoch (bool, optional): Whether to display structure information + iteratively by epoch. Defaults to True. + interval (int, optional): The interval between two structure + information displays. Defaults to 1.
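+ Example (a config entry mirroring how the configs in this project register the hook; the shown values are the defaults): + custom_hooks = [dict(type='mmrazor.PruningStructureHook', by_epoch=True, interval=1)]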
+ """ + + def __init__(self, by_epoch=True, interval=1) -> None: + + super().__init__() + self.by_epoch = by_epoch + self.interval = interval + + def show_unit_info(self, algorithm): + """Show unit information of an algorithm.""" + if is_pruning_algorithm(algorithm): + choices = algorithm.mutator.choice_template + import json + print_log(json.dumps(choices, indent=4)) + + for unit in algorithm.mutator.mutable_units: + if hasattr(unit, 'importance'): + imp = unit.importance() + print_log( + f'{unit.name}: \t{imp.min().item()}\t{imp.max().item()}' # noqa + ) + + @master_only + def show(self, runner): + """Show pruning algorithm information of a runner.""" + algorithm = get_model_from_runner(runner) + if is_pruning_algorithm(algorithm): + self.show_unit_info(algorithm) + + # hook points + + def after_train_epoch(self, runner) -> None: + if self.by_epoch and RuntimeInfo.epoch() % self.interval == 0: + self.show(runner) + + def after_train_iter(self, runner, batch_idx: int, data_batch, + outputs) -> None: + if not self.by_epoch and RuntimeInfo.iter() % self.interval == 0: + self.show(runner) + + +@HOOKS.register_module() +class ResourceInfoHook(Hook): + """This hook is used to display the resource-related information and save + checkpoints according to the given thresholds during pruning. + + Args: + demo_input (dict, optional): the demo input for ResourceEstimator. + Defaults to DefaultDemoInput([1, 3, 224, 224]). + interval (int, optional): the interval to check the resource. Defaults + to 10. + resource_type (str, optional): the type of resource to check. + Defaults to 'flops'. + save_ckpt_thr (list, optional): the thresholds to save checkpoints. + Defaults to [0.5]. + early_stop (bool, optional): whether to stop when all checkpoints have + been saved according to save_ckpt_thr. Defaults to True.
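+ Example (an illustrative ``custom_hooks`` entry; keyword names follow the signature above, the input shape is only a sample): + custom_hooks = [dict(type='mmrazor.ResourceInfoHook', interval=10, + demo_input=dict(type='mmrazor.DefaultDemoInput', input_shape=[1, 3, 224, 224]), + save_ckpt_thr=[0.5])]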
+ """ + + def __init__(self, + demo_input=DefaultDemoInput([1, 3, 224, 224]), + interval=10, + resource_type='flops', + save_ckpt_thr=[0.5], + early_stop=True) -> None: + + super().__init__() + if isinstance(demo_input, dict): + demo_input = TASK_UTILS.build(demo_input) + + self.demo_input = demo_input + self.save_ckpt_thr = sorted( + save_ckpt_thr, reverse=True) # big to small + self.resource_type = resource_type + self.early_stop = early_stop + self.estimator: ResourceEstimator = TASK_UTILS.build( + dict( + _scope_='mmrazor', + type='ResourceEstimator', + flops_params_cfg=dict( + input_shape=tuple(demo_input.input_shape), ))) + self.interval = interval + self.origin_delta = None + + def before_run(self, runner) -> None: + """Init original_resource.""" + model = get_model_from_runner(runner) + original_resource = self._evaluate(model) + print_log(f'get original resource: {original_resource}') + + self.origin_delta = original_resource[self.resource_type] + + # save checkpoint + + def after_train_iter(self, + runner: Runner, + batch_idx: int, + data_batch=None, + outputs=None) -> None: + """Check resource after train iteration.""" + if RuntimeInfo.iter() % self.interval == 0 and len( + self.save_ckpt_thr) > 0: + model = get_model_from_runner(runner) + current_delta = self._evaluate(model)[self.resource_type] + percent = current_delta / self.origin_delta + if percent < self.save_ckpt_thr[0]: + self._save_checkpoint(model, runner.work_dir, + self.save_ckpt_thr.pop(0)) + if self.early_stop and len(self.save_ckpt_thr) == 0: + exit() + + # show info + + @master_only + def after_train_epoch(self, runner) -> None: + """Check resource after train epoch.""" + model = get_model_from_runner(runner) + current_delta = self._evaluate(model)[self.resource_type] + print_log( + f'current {self.resource_type}: {current_delta} / {self.origin_delta}' # noqa + ) + + # + + def _evaluate(self, model: nn.Module): + """Evaluate the resource required by a model.""" + with torch.no_grad(): + training = model.training + model.eval() + res = self.estimator.estimate(model) + if training: + model.train() + return res + + @master_only + def _save_checkpoint(self, model, path, delta_percent): + """Save the checkpoint of a model.""" + ckpt = {'state_dict': model.state_dict()} + save_path = f'{path}/{self.resource_type}_{delta_percent:.2f}.pth' + save_checkpoint(ckpt, save_path) + print_log( + f'Save checkpoint to {save_path} with {self._evaluate(model)}' # noqa + ) diff --git a/projects/cores/models/__init__.py b/projects/cores/models/__init__.py new file mode 100644 index 000000000..76b714cb5 --- /dev/null +++ b/projects/cores/models/__init__.py @@ -0,0 +1,3 @@ +from .prune_deploy_wrapper import PruneDeployWrapper, PruneFinetuneWrapper + +__all__ = ['PruneDeployWrapper', 'PruneFinetuneWrapper'] diff --git a/projects/cores/models/prune_deploy_wrapper.py b/projects/cores/models/prune_deploy_wrapper.py new file mode 100644 index 000000000..683d60b6b --- /dev/null +++ b/projects/cores/models/prune_deploy_wrapper.py @@ -0,0 +1,120 @@ +import json +import types + +import torch.nn as nn +from mmengine.model import BaseModel, BaseModule + +from mmrazor.models import BaseAlgorithm +from mmrazor.models.mutators import ChannelMutator +from mmrazor.registry import MODELS +from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet, + load_fix_subnet) +from mmrazor.utils import print_log +from ..expandable_ops.unit import ExpandableUnit + + +def clean_params_init_info(model: nn.Module): + """""" + if hasattr(model, 
'_params_init_info'): + delattr(model, '_params_init_info') + for module in model.modules(): + if hasattr(module, '_params_init_info'): + delattr(module, '_params_init_info') + + +def clean_init_cfg(model: BaseModule): + for module in model.modules(): + if module is model: + continue + if isinstance(module, BaseModule): + module.init_cfg = {} + + +def empty_init_weights(model): + pass + + +def to_static_model(algorithm: BaseAlgorithm): + if hasattr(algorithm, 'to_static'): + model = algorithm.to_static() + else: + mutables = export_fix_subnet(algorithm.architecture)[0] + load_fix_subnet(algorithm.architecture, mutables) + model = algorithm.architecture + + model.data_preprocessor = algorithm.data_preprocessor + if isinstance(model, BaseModel): + model.init_cfg = None + model.init_weights = types.MethodType(empty_init_weights, model) + return model + + +@MODELS.register_module() +def PruneFinetuneWrapper(algorithm, data_preprocessor=None): + """A model wrapper for pruning algorithm. + + Args: + algorithm (_type_): _description_ + data_preprocessor (_type_, optional): _description_. Defaults to None. + + Returns: + nn.Module: a static model. + """ + algorithm: BaseAlgorithm = MODELS.build(algorithm) + algorithm.init_weights() + clean_params_init_info(algorithm) + print_log(json.dumps(algorithm.mutator.choice_template, indent=4)) + + if hasattr(algorithm, 'to_static'): + model = algorithm.to_static() + else: + mutables = export_fix_subnet(algorithm.architecture)[0] + load_fix_subnet(algorithm.architecture, mutables) + model = algorithm.architecture + + model.data_preprocessor = algorithm.data_preprocessor + if isinstance(model, BaseModel): + model.init_cfg = None + model.init_weights = types.MethodType(empty_init_weights, model) + return model + + +@MODELS.register_module() +def PruneDeployWrapper(architecture, + mutable_cfg={}, + divisor=1, + data_preprocessor=None, + init_cfg=None): + """A deploy wrapper for a pruned model. + + Args: + architecture (_type_): the model to be pruned. + mutable_cfg (dict, optional): the channel remaining ratio for each + unit. Defaults to {}. + divisor (int, optional): the divisor to make the channel number + divisible. Defaults to 1. + data_preprocessor (_type_, optional): Defaults to None. + init_cfg (_type_, optional): Defaults to None. + + Returns: + BaseModel: a BaseModel of mmengine. 
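+ Example (an illustrative config modelled on the vgg deploy config in this project; unit names and channel numbers depend on your own pruned model, and ``architecture`` is the config of the model to be pruned): + model = dict(_scope_='mmrazor', type='PruneDeployWrapper', architecture=architecture, + mutable_cfg={'backbone.features.conv0_(0, 64)_64': 21}, divisor=8)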
+ """ + if isinstance(architecture, dict): + architecture = MODELS.build(architecture) + assert isinstance(architecture, nn.Module) + + # to dynamic model + mutator = ChannelMutator[ExpandableUnit](channel_unit_cfg=ExpandableUnit) + mutator.prepare_from_supernet(architecture) + for unit in mutator.mutable_units: + if unit.name in mutable_cfg: + unit.current_choice = mutable_cfg[unit.name] + print_log(json.dumps(mutator.choice_template, indent=4)) + + mutables = export_fix_subnet(architecture)[0] + load_fix_subnet(architecture, mutables) + + if divisor != 1: + setattr(architecture, '_razor_divisor', divisor) + + return architecture diff --git a/projects/cores/utils.py b/projects/cores/utils.py new file mode 100644 index 000000000..834415c3d --- /dev/null +++ b/projects/cores/utils.py @@ -0,0 +1,59 @@ +import math + +from mmengine.logging import MessageHub +from torch import distributed as torch_dist + +from mmrazor.models import BaseAlgorithm +from mmrazor.models.mutators import ChannelMutator + + +def is_pruning_algorithm(algorithm): + """Check whether a model is a pruning algorithm.""" + return isinstance(algorithm, BaseAlgorithm) \ + and isinstance(getattr(algorithm, 'mutator', None), ChannelMutator) # noqa + + +def get_model_from_runner(runner): + """Get the model from a runner.""" + if torch_dist.is_initialized(): + return runner.model.module + else: + return runner.model + + +class RuntimeInfo(): + """A tools to get runtime info in MessageHub.""" + + @classmethod + def get_info(cls, key): + hub = MessageHub.get_current_instance() + if key in hub.runtime_info: + return hub.runtime_info[key] + else: + raise KeyError(key) + + @classmethod + def epoch(cls): + return cls.get_info('epoch') + + @classmethod + def max_epochs(cls): + return cls.get_info('max_epochs') + + @classmethod + def iter(cls): + return cls.get_info('iter') + + @classmethod + def max_iters(cls): + return cls.get_info('max_iters') + + @classmethod + def iter_by_epoch(cls): + iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs()) + return cls.iter() % iter_per_epoch + + @classmethod + def iter_pre_epoch(cls): + iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs()) + return iter_per_epoch diff --git a/projects/group_fisher/__init__.py b/projects/group_fisher/__init__.py new file mode 100644 index 000000000..9e6824a83 --- /dev/null +++ b/projects/group_fisher/__init__.py @@ -0,0 +1,11 @@ +from .modules.group_fisher_algorthm import GroupFisherAlgorithm +from .modules.group_fisher_channel_mutator import GroupFisherChannelMutator +from .modules.group_fisher_channel_unit import GroupFisherChannelUnit +from .modules.group_fisher_ops import GroupFisherMixin + +__all__ = [ + 'GroupFisherChannelMutator', + 'GroupFisherAlgorithm', + 'GroupFisherMixin', + 'GroupFisherChannelUnit', +] diff --git a/projects/group_fisher/configs/README.md b/projects/group_fisher/configs/README.md new file mode 100644 index 000000000..1997ef541 --- /dev/null +++ b/projects/group_fisher/configs/README.md @@ -0,0 +1,49 @@ +# Group_fisher pruning + +> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf) + +## Abstract + +Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. 
In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobile-friendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy. + +![pipeline](https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png) + +## Results and models + +### Detection + +| Dataset | Detector | Backbone | Baseline (mAP) | Pruned & Finetuned (mAP) | Model | log | +| :-----: | :-------: | :------: | :------------: | :----------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -------------------------- | +| COCO | RetinaNet | R-50-FPN | 36.5 | 36.5 (50% FLOPs) | [Baseline](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth)/[Pruned](<>)/[Finetuned](<>) | [Prune](<>)/[Finetune](<>) | + +## Citation + +```latex +@InProceedings{liu2021group, +title = {Group Fisher Pruning for Practical Network Compression}, +author = {Liu, Liyang and Zhang, Shilong and Kuang, Zhanghui and Zhou, Aojun and Xue, Jing-Hao and Wang, Xinjiang and Chen, Yimin and Yang, Wenming and Liao, Qingmin and Zhang, Wayne}, +booktitle = {Proceedings of the 38th International Conference on Machine Learning}, +year = {2021}, +series = {Proceedings of Machine Learning Research}, +month = {18--24 Jul}, +publisher = {PMLR}, +} +``` + +## Get Started + +### Pruning + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \ + configs/pruning/mmdet/group_fisher/group-fisher-pruning_retinanet_resnet50_8xb2_coco.py 8 \ + --work-dir $WORK_DIR +``` + +### Finetune + +Update `pruned_path` to the local path of the pruned checkpoint saved by the pruning step.
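+For reference, the finetune config consumes this checkpoint roughly as follows (a sketch; the path below is only the example used by the retinanet finetune config in this project): + +```python +# in the finetune config: load the pruned weights into the algorithm +pruned_path = './work_dirs/group-fisher-pruning_retinanet_resnet50_8xb2_coco/flops_0.50.pth' +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) +```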
+ +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \ + configs/pruning/mmdet/group_fisher/group-fisher-finetune_retinanet_resnet50_8xb2_coco.py 8 \ + --work-dir $WORK_DIR +``` diff --git a/projects/group_fisher/configs/mmcls/mobilenet/exp.md b/projects/group_fisher/configs/mmcls/mobilenet/exp.md new file mode 100644 index 000000000..2927a02ac --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/exp.md @@ -0,0 +1,60 @@ +| name | flop | param | finetune | +| ----------- | ----- | ----- | -------- | +| baseline | 0.319 | 3.5 | 71.86 | +| fisher_act | 0.20 | 3.14 | 70.79 | +| fisher_flop | 0.20 | 2.78 | 70.87 | + +fisher_act +{ +"backbone.conv1.conv\_(0, 32)_32": 21, +"backbone.layer1.0.conv.1.conv_(0, 16)_16": 10, +"backbone.layer2.0.conv.0.conv_(0, 96)_96": 45, +"backbone.layer2.0.conv.2.conv_(0, 24)_24": 24, +"backbone.layer2.1.conv.0.conv_(0, 144)_144": 73, +"backbone.layer3.0.conv.0.conv_(0, 144)_144": 85, +"backbone.layer3.0.conv.2.conv_(0, 32)_32": 32, +"backbone.layer3.1.conv.0.conv_(0, 192)_192": 95, +"backbone.layer3.2.conv.0.conv_(0, 192)_192": 76, +"backbone.layer4.0.conv.0.conv_(0, 192)_192": 160, +"backbone.layer4.0.conv.2.conv_(0, 64)_64": 64, +"backbone.layer4.1.conv.0.conv_(0, 384)_384": 204, +"backbone.layer4.2.conv.0.conv_(0, 384)_384": 200, +"backbone.layer4.3.conv.0.conv_(0, 384)_384": 217, +"backbone.layer5.0.conv.0.conv_(0, 384)_384": 344, +"backbone.layer5.0.conv.2.conv_(0, 96)_96": 96, +"backbone.layer5.1.conv.0.conv_(0, 576)_576": 348, +"backbone.layer5.2.conv.0.conv_(0, 576)_576": 338, +"backbone.layer6.0.conv.0.conv_(0, 576)_576": 543, +"backbone.layer6.0.conv.2.conv_(0, 160)_160": 160, +"backbone.layer6.1.conv.0.conv_(0, 960)_960": 810, +"backbone.layer6.2.conv.0.conv_(0, 960)_960": 803, +"backbone.layer7.0.conv.0.conv_(0, 960)_960": 944, +"backbone.layer7.0.conv.2.conv_(0, 320)\_320": 320 +} +fisher_flop +{ +"backbone.conv1.conv\_(0, 32)_32": 27, +"backbone.layer1.0.conv.1.conv_(0, 16)_16": 16, +"backbone.layer2.0.conv.0.conv_(0, 96)_96": 77, +"backbone.layer2.0.conv.2.conv_(0, 24)_24": 24, +"backbone.layer2.1.conv.0.conv_(0, 144)_144": 85, +"backbone.layer3.0.conv.0.conv_(0, 144)_144": 115, +"backbone.layer3.0.conv.2.conv_(0, 32)_32": 32, +"backbone.layer3.1.conv.0.conv_(0, 192)_192": 102, +"backbone.layer3.2.conv.0.conv_(0, 192)_192": 95, +"backbone.layer4.0.conv.0.conv_(0, 192)_192": 181, +"backbone.layer4.0.conv.2.conv_(0, 64)_64": 64, +"backbone.layer4.1.conv.0.conv_(0, 384)_384": 169, +"backbone.layer4.2.conv.0.conv_(0, 384)_384": 176, +"backbone.layer4.3.conv.0.conv_(0, 384)_384": 180, +"backbone.layer5.0.conv.0.conv_(0, 384)_384": 308, +"backbone.layer5.0.conv.2.conv_(0, 96)_96": 96, +"backbone.layer5.1.conv.0.conv_(0, 576)_576": 223, +"backbone.layer5.2.conv.0.conv_(0, 576)_576": 241, +"backbone.layer6.0.conv.0.conv_(0, 576)_576": 511, +"backbone.layer6.0.conv.2.conv_(0, 160)_160": 160, +"backbone.layer6.1.conv.0.conv_(0, 960)_960": 467, +"backbone.layer6.2.conv.0.conv_(0, 960)_960": 510, +"backbone.layer7.0.conv.0.conv_(0, 960)_960": 771, +"backbone.layer7.0.conv.2.conv_(0, 320)\_320": 320 +} diff --git a/projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_finetune_flop.py b/projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_finetune_flop.py new file mode 100644 index 000000000..83e830dec --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_finetune_flop.py @@ -0,0 +1,37 @@ +_base_ = 
'./mobilenet_v2_group_fisher_prune_flop.py' +custom_imports = dict(imports=['projects']) + +algorithm = _base_.model +pruned_path = './work_dirs/mobilenet_v2_group_fisher_prune_flop/flops_0.65.pth' +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper', + algorithm=algorithm, +) + +# restore optimizer + +optim_wrapper = dict( + _delete_=True, + optimizer=dict( + type='SGD', + lr=0.045, + momentum=0.9, + weight_decay=4e-05, + _scope_='mmcls')) +param_scheduler = dict( + _delete_=True, + type='StepLR', + by_epoch=True, + step_size=1, + gamma=0.98, + _scope_='mmcls') + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_prune_flop.py b/projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_prune_flop.py new file mode 100644 index 000000000..d3361062f --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_prune_flop.py @@ -0,0 +1,4 @@ +_base_ = '../mobilenet_v2_group_fisher_prune.py' +model = dict( + mutator=dict( + channel_unit_cfg=dict(default_args=dict(detla_type='flop', ), ), ), ) diff --git a/projects/group_fisher/configs/mmcls/mobilenet/flop/run.sh b/projects/group_fisher/configs/mmcls/mobilenet/flop/run.sh new file mode 100644 index 000000000..9f1a4619e --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/flop/run.sh @@ -0,0 +1,2 @@ +bash ./tools/dist_train.sh projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_prune_flop.py 8 +bash ./tools/dist_train.sh projects/group_fisher/configs/mmcls/mobilenet/flop/mobilenet_v2_group_fisher_finetune_flop.py 8 diff --git a/projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_finetune.py b/projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_finetune.py new file mode 100644 index 000000000..410358d27 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_finetune.py @@ -0,0 +1,37 @@ +_base_ = './mobilenet_v2_group_fisher_prune.py' +custom_imports = dict(imports=['projects']) + +algorithm = _base_.model +pruned_path = './work_dirs/mobilenet_v2_group_fisher_prune/flops_0.65.pth' +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper', + algorithm=algorithm, +) + +# restore optimizer + +optim_wrapper = dict( + _delete_=True, + optimizer=dict( + type='SGD', + lr=0.045, + momentum=0.9, + weight_decay=4e-05, + _scope_='mmcls')) +param_scheduler = dict( + _delete_=True, + type='StepLR', + by_epoch=True, + step_size=1, + gamma=0.98, + _scope_='mmcls') + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_prune.py b/projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_prune.py new file mode 100644 index 000000000..2a7553bf6 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_prune.py @@ -0,0 +1,59 @@ +_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py' +custom_imports = dict(imports=['projects']) +architecture = _base_.model +pretrained_path = 
'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture.update({ + 'data_preprocessor': _base_.data_preprocessor, +}) +data_preprocessor = None + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=25, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args=dict(detla_type='act', ), + ), + ), +) +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) +# update optimizer + +optim_wrapper = dict(optimizer=dict(lr=0.004, )) +param_scheduler = None + +custom_hooks = [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=25, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=[1, 3, 224, 224], + ), + save_ckpt_delta_thr=[0.65, 0.33], + ), +] + +# original +""" +optim_wrapper = dict( + optimizer=dict( + type='SGD', + lr=0.045, + momentum=0.9, + weight_decay=4e-05, + _scope_='mmcls')) +param_scheduler = dict( + type='StepLR', by_epoch=True, step_size=1, gamma=0.98, _scope_='mmcls') +""" diff --git a/projects/group_fisher/configs/mmcls/mobilenet/run.sh b/projects/group_fisher/configs/mmcls/mobilenet/run.sh new file mode 100644 index 000000000..ea93d88d1 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/mobilenet/run.sh @@ -0,0 +1,2 @@ +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_prune.py 8 +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmcls/mobilenet/mobilenet_v2_group_fisher_finetune.py 8 diff --git a/projects/group_fisher/configs/mmcls/resnet50/exp.md b/projects/group_fisher/configs/mmcls/resnet50/exp.md new file mode 100644 index 000000000..338561031 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/resnet50/exp.md @@ -0,0 +1,81 @@ +| name | flop | param | finetune | +| ----------- | ---- | ----- | -------- | +| fisher_act | 2.05 | 16.22 | 75.2 | +| fisher_flop | 2.05 | 16.22 | 75.6 | + +act: +"backbone.conv1\_(0, 64)_64": 61, +"backbone.layer1.0.conv1_(0, 64)_64": 27, +"backbone.layer1.0.conv2_(0, 64)_64": 35, +"backbone.layer1.0.conv3_(0, 256)_256": 241, +"backbone.layer1.1.conv1_(0, 64)_64": 32, +"backbone.layer1.1.conv2_(0, 64)_64": 29, +"backbone.layer1.2.conv1_(0, 64)_64": 27, +"backbone.layer1.2.conv2_(0, 64)_64": 42, +"backbone.layer2.0.conv1_(0, 128)_128": 87, +"backbone.layer2.0.conv2_(0, 128)_128": 107, +"backbone.layer2.0.conv3_(0, 512)_512": 512, +"backbone.layer2.1.conv1_(0, 128)_128": 44, +"backbone.layer2.1.conv2_(0, 128)_128": 50, +"backbone.layer2.2.conv1_(0, 128)_128": 52, +"backbone.layer2.2.conv2_(0, 128)_128": 81, +"backbone.layer2.3.conv1_(0, 128)_128": 47, +"backbone.layer2.3.conv2_(0, 128)_128": 50, +"backbone.layer3.0.conv1_(0, 256)_256": 210, +"backbone.layer3.0.conv2_(0, 256)_256": 206, +"backbone.layer3.0.conv3_(0, 1024)_1024": 1024, +"backbone.layer3.1.conv1_(0, 256)_256": 107, +"backbone.layer3.1.conv2_(0, 256)_256": 108, +"backbone.layer3.2.conv1_(0, 256)_256": 86, +"backbone.layer3.2.conv2_(0, 256)_256": 126, +"backbone.layer3.3.conv1_(0, 256)_256": 91, +"backbone.layer3.3.conv2_(0, 256)_256": 112, +"backbone.layer3.4.conv1_(0, 256)_256": 98, +"backbone.layer3.4.conv2_(0, 256)_256": 110, +"backbone.layer3.5.conv1_(0, 
256)_256": 112, +"backbone.layer3.5.conv2_(0, 256)_256": 115, +"backbone.layer4.0.conv1_(0, 512)_512": 397, +"backbone.layer4.0.conv2_(0, 512)_512": 427, +"backbone.layer4.1.conv1_(0, 512)_512": 373, +"backbone.layer4.1.conv2_(0, 512)_512": 348, +"backbone.layer4.2.conv1_(0, 512)_512": 433, +"backbone.layer4.2.conv2_(0, 512)\_512": 384 + +flop: +"backbone.conv1\_(0, 64)_64": 61, +"backbone.layer1.0.conv1_(0, 64)_64": 28, +"backbone.layer1.0.conv2_(0, 64)_64": 35, +"backbone.layer1.0.conv3_(0, 256)_256": 242, +"backbone.layer1.1.conv1_(0, 64)_64": 31, +"backbone.layer1.1.conv2_(0, 64)_64": 28, +"backbone.layer1.2.conv1_(0, 64)_64": 26, +"backbone.layer1.2.conv2_(0, 64)_64": 41, +"backbone.layer2.0.conv1_(0, 128)_128": 90, +"backbone.layer2.0.conv2_(0, 128)_128": 107, +"backbone.layer2.0.conv3_(0, 512)_512": 509, +"backbone.layer2.1.conv1_(0, 128)_128": 42, +"backbone.layer2.1.conv2_(0, 128)_128": 50, +"backbone.layer2.2.conv1_(0, 128)_128": 51, +"backbone.layer2.2.conv2_(0, 128)_128": 84, +"backbone.layer2.3.conv1_(0, 128)_128": 49, +"backbone.layer2.3.conv2_(0, 128)_128": 51, +"backbone.layer3.0.conv1_(0, 256)_256": 210, +"backbone.layer3.0.conv2_(0, 256)_256": 207, +"backbone.layer3.0.conv3_(0, 1024)_1024": 1024, +"backbone.layer3.1.conv1_(0, 256)_256": 103, +"backbone.layer3.1.conv2_(0, 256)_256": 108, +"backbone.layer3.2.conv1_(0, 256)_256": 90, +"backbone.layer3.2.conv2_(0, 256)_256": 124, +"backbone.layer3.3.conv1_(0, 256)_256": 94, +"backbone.layer3.3.conv2_(0, 256)_256": 114, +"backbone.layer3.4.conv1_(0, 256)_256": 99, +"backbone.layer3.4.conv2_(0, 256)_256": 111, +"backbone.layer3.5.conv1_(0, 256)_256": 108, +"backbone.layer3.5.conv2_(0, 256)_256": 111, +"backbone.layer4.0.conv1_(0, 512)_512": 400, +"backbone.layer4.0.conv2_(0, 512)_512": 421, +"backbone.layer4.1.conv1_(0, 512)_512": 377, +"backbone.layer4.1.conv2_(0, 512)_512": 347, +"backbone.layer4.2.conv1_(0, 512)_512": 443, +"backbone.layer4.2.conv2_(0, 512)\_512": 376 +} diff --git a/projects/group_fisher/configs/mmcls/resnet50/resnet_group_fisher_finetune.py b/projects/group_fisher/configs/mmcls/resnet50/resnet_group_fisher_finetune.py new file mode 100644 index 000000000..ba3dd6fc0 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/resnet50/resnet_group_fisher_finetune.py @@ -0,0 +1,37 @@ +_base_ = './resnet_group_fisher_prune.py' +custom_imports = dict(imports=['projects']) + +algorithm = _base_.model +pruned_path = './work_dirs/resnet_group_fisher_prune/flops_0.50.pth' +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper', + algorithm=algorithm, +) + +custom_hooks = [ + dict(type='mmrazor.PruneHook'), +] + +# restore optimizer + +optim_wrapper = dict( + optimizer=dict( + type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001, + _scope_='mmcls')) +param_scheduler = dict( + _delete_=True, + type='MultiStepLR', + by_epoch=True, + milestones=[30, 60, 90], + gamma=0.1, + _scope_='mmcls') + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/projects/group_fisher/configs/mmcls/resnet50/resnet_group_fisher_prune.py b/projects/group_fisher/configs/mmcls/resnet50/resnet_group_fisher_prune.py new file mode 100644 index 000000000..dedff7ea2 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/resnet50/resnet_group_fisher_prune.py @@ -0,0 +1,63 @@ +_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py' +custom_imports = dict(imports=['projects']) +architecture 
= _base_.model +pretrained_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture.update({ + 'data_preprocessor': _base_.data_preprocessor, +}) +data_preprocessor = None + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=25, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args=dict(detla_type='flop', ), + ), + ), +) +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) +# update optimizer + +optim_wrapper = dict(optimizer=dict(lr=0.004, )) +param_scheduler = None + +custom_hooks = [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=25, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=[1, 3, 224, 224], + ), + save_ckpt_delta_thr=[0.75, 0.50], + ), +] + +# original +""" +optim_wrapper = dict( + optimizer=dict( + type='SGD', + lr=0.1, + momentum=0.9, + weight_decay=0.0001, + _scope_='mmcls')) +param_scheduler = dict( + type='MultiStepLR', + by_epoch=True, + milestones=[30, 60, 90], + gamma=0.1, + _scope_='mmcls') +""" diff --git a/projects/group_fisher/configs/mmcls/resnet50/run.sh b/projects/group_fisher/configs/mmcls/resnet50/run.sh new file mode 100644 index 000000000..73242cb2d --- /dev/null +++ b/projects/group_fisher/configs/mmcls/resnet50/run.sh @@ -0,0 +1,2 @@ +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmcls/resnet_group_fisher_prune.py 8 +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmcls/resnet_group_fisher_finetune.py 8 diff --git a/projects/group_fisher/configs/mmcls/vgg/deploy.sh b/projects/group_fisher/configs/mmcls/vgg/deploy.sh new file mode 100644 index 000000000..c4c24e0d6 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/deploy.sh @@ -0,0 +1,24 @@ +python mmdeploy/tools/deploy.py \ + mmdeploy/configs/mmcls/classification_onnxruntime_static.py \ + projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_deploy.py \ + ./work_dirs/vgg_group_fisher_finetune/best_accuracy/top1_epoch_142.pth \ + ./mmdeploy/demo/resources/face.png \ + --work-dir work_dirs/mmdeploy_model/ \ + --device cpu \ + --dump-info + +python mmdeploy/tools/profiler.py \ + mmdeploy/configs/mmcls/classification_onnxruntime_static.py \ + projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_deploy.py \ + mmdeploy/resources/ \ + --model ./work_dirs/mmdeploy_model/end2end.onnx \ + --shape 32x32 \ + --device cpu \ + --warmup 50 \ + --num-iter 200 + +python mmdeploy/tools/test.py \ + mmdeploy/configs/mmcls/classification_onnxruntime_static.py \ + projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_deploy.py \ + --model ./work_dirs/mmdeploy_model/end2end.onnx \ + --device cpu \ diff --git a/projects/group_fisher/configs/mmcls/vgg/exp.md b/projects/group_fisher/configs/mmcls/vgg/exp.md new file mode 100644 index 000000000..339f39be9 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/exp.md @@ -0,0 +1,18 @@ +| name | flop | param | prune | finetune | +| ------------- | ---- | ----- | ----- | -------- | +| fisher_act_4 | 94 | 2.09 | 91.51 | 93.42 | +| fisher_flop_4 | 94 | 1.61 | 91.51 | 93.13 | + +fisher_act_4 +"backbone.features.conv0\_(0, 64)_64": 21, +"backbone.features.conv1_(0, 64)_64": 42, 
+"backbone.features.conv3_(0, 128)_128": 86, +"backbone.features.conv4_(0, 128)_128": 110, +"backbone.features.conv6_(0, 256)_256": 203, +"backbone.features.conv7_(0, 256)_256": 170, +"backbone.features.conv8_(0, 256)_256": 145, +"backbone.features.conv10_(0, 512)_512": 138, +"backbone.features.conv11_(0, 512)_512": 84, +"backbone.features.conv12_(0, 512)_512": 54, +"backbone.features.conv14_(0, 512)_512": 94, +"backbone.features.conv15_(0, 512)\_512": 108 diff --git a/projects/group_fisher/configs/mmcls/vgg/run.sh b/projects/group_fisher/configs/mmcls/vgg/run.sh new file mode 100644 index 000000000..e0b434de0 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/run.sh @@ -0,0 +1,5 @@ +# python ./tools/train.py ./projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune.py +# python ./tools/train.py ./projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune.py + +python ./tools/train.py ./projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune_flop.py +python ./tools/train.py ./projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune_flop.py diff --git a/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_deploy.py b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_deploy.py new file mode 100644 index 000000000..808bded95 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_deploy.py @@ -0,0 +1,28 @@ +_base_ = '../../../../models/vgg/configs/vgg_pretrain.py' +custom_imports = dict(imports=['projects']) + +architecture = _base_.model +architecture.update({'data_preprocessor': _base_.data_preprocessor}) +data_preprocessor = {'_delete_': True} + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper2', + architecture=architecture, + mutable_cfg={ + 'backbone.features.conv0_(0, 64)_64': 21, + 'backbone.features.conv1_(0, 64)_64': 42, + 'backbone.features.conv3_(0, 128)_128': 86, + 'backbone.features.conv4_(0, 128)_128': 110, + 'backbone.features.conv6_(0, 256)_256': 203, + 'backbone.features.conv7_(0, 256)_256': 170, + 'backbone.features.conv8_(0, 256)_256': 145, + 'backbone.features.conv10_(0, 512)_512': 138, + 'backbone.features.conv11_(0, 512)_512': 84, + 'backbone.features.conv12_(0, 512)_512': 54, + 'backbone.features.conv14_(0, 512)_512': 94, + 'backbone.features.conv15_(0, 512)_512': 108 + }, + divisor=8, +) diff --git a/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune.py b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune.py new file mode 100644 index 000000000..710fa01f0 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune.py @@ -0,0 +1,19 @@ +_base_ = './vgg_group_fisher_prune.py' +custom_imports = dict(imports=['projects']) + +algorithm = _base_.model +# `pruned_path` need to be updated. 
+pruned_path = './work_dirs/vgg_group_fisher_prune/flops_0.30.pth' +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=0.01)) +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] diff --git a/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune_flop.py b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune_flop.py new file mode 100644 index 000000000..3eb12b6e7 --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_finetune_flop.py @@ -0,0 +1,19 @@ +_base_ = './vgg_group_fisher_prune.py' +custom_imports = dict(imports=['projects']) + +algorithm = _base_.model +# `pruned_path` need to be updated. +pruned_path = './work_dirs/vgg_group_fisher_prune_flop/flops_0.30.pth' +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=0.01)) +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] diff --git a/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune.py b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune.py new file mode 100644 index 000000000..023de259f --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune.py @@ -0,0 +1,43 @@ +_base_ = '../../../../models/vgg/configs/vgg_pretrain.py' +custom_imports = dict(imports=['projects']) + +pretrained_path = './work_dirs/pretrained/vgg_pretrained.pth' # noqa + +architecture = _base_.model +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture.update({'data_preprocessor': _base_.data_preprocessor}) +data_preprocessor = {'_delete_': True} + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=4, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args={'detla_type': 'act'}, + ), + ), +) +custom_hooks = [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=4, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=[1, 3, 32, 32], + ), + save_ckpt_delta_thr=[0.5, 0.4, 0.3], + ), +] +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict(optimizer=dict(lr=0.0001)) diff --git a/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune_flop.py b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune_flop.py new file mode 100644 index 000000000..df662f35d --- /dev/null +++ b/projects/group_fisher/configs/mmcls/vgg/vgg_group_fisher_prune_flop.py @@ -0,0 +1,43 @@ +_base_ = '../../../../models/vgg/configs/vgg_pretrain.py' +custom_imports = dict(imports=['projects']) + +pretrained_path = './work_dirs/pretrained/vgg_pretrained.pth' # noqa + +architecture = _base_.model +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture.update({'data_preprocessor': _base_.data_preprocessor}) +data_preprocessor = None + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=4, + mutator=dict( + 
type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args={'detla_type': 'flop'}, + ), + ), +) +custom_hooks = [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=4, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=[1, 3, 32, 32], + ), + save_ckpt_delta_thr=[0.5, 0.4, 0.3], + ), +] +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict(optimizer=dict(lr=0.0001)) diff --git a/projects/group_fisher/configs/mmdet/retinanet/group-fisher-finetune_retinanet_resnet50_8xb2_coco.py b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-finetune_retinanet_resnet50_8xb2_coco.py new file mode 100644 index 000000000..cff80c074 --- /dev/null +++ b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-finetune_retinanet_resnet50_8xb2_coco.py @@ -0,0 +1,21 @@ +_base_ = './group-fisher-pruning_retinanet_resnet50_8xb2_coco.py' + +algorithm = _base_.model +# `pruned_path` need to be updated. +pruned_path = './work_dirs/group-fisher-pruning_retinanet_resnet50_8xb2_coco/flops_0.50.pth' # noqa +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneDeployWrapper', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=0.01)) +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/projects/group_fisher/configs/mmdet/retinanet/group-fisher-finetune_retinanet_resnet50_8xb2_coco_act.py b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-finetune_retinanet_resnet50_8xb2_coco_act.py new file mode 100644 index 000000000..3d374b918 --- /dev/null +++ b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-finetune_retinanet_resnet50_8xb2_coco_act.py @@ -0,0 +1,7 @@ +_base_ = './group-fisher-finetune_retinanet_resnet50_8xb2_coco.py' + +pruned_path = './work_dirs/group-fisher-pruning_retinanet_resnet50_8xb2_coco_act/flops_0.50.pth' # noqa + +model = dict( + algorithm=dict(init_cfg=dict(type='Pretrained', + checkpoint=pruned_path), ), ) diff --git a/projects/group_fisher/configs/mmdet/retinanet/group-fisher-pruning_retinanet_resnet50_8xb2_coco.py b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-pruning_retinanet_resnet50_8xb2_coco.py new file mode 100644 index 000000000..b58939f5f --- /dev/null +++ b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-pruning_retinanet_resnet50_8xb2_coco.py @@ -0,0 +1,49 @@ +_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py' +custom_imports = dict(imports=['projects']) + +architecture = _base_.model + +architecture.backbone.frozen_stages = -1 + +if hasattr(_base_, 'data_preprocessor'): + architecture.update({'data_preprocessor': _base_.data_preprocessor}) + data_preprocessor = None + +pretrained_path = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=10, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + 
default_args=dict(detla_type='flop', ), + ), + ), +) + +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict(optimizer=dict(lr=0.002)) + +custom_hooks = [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=10, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=[1, 3, 1333, 800], + ), + save_ckpt_delta_thr=[0.75, 0.5], + ), +] diff --git a/projects/group_fisher/configs/mmdet/retinanet/group-fisher-pruning_retinanet_resnet50_8xb2_coco_act.py b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-pruning_retinanet_resnet50_8xb2_coco_act.py new file mode 100644 index 000000000..2d160bbf3 --- /dev/null +++ b/projects/group_fisher/configs/mmdet/retinanet/group-fisher-pruning_retinanet_resnet50_8xb2_coco_act.py @@ -0,0 +1,5 @@ +_base_ = './group-fisher-pruning_retinanet_resnet50_8xb2_coco.py' + +model = dict( + mutator=dict( + channel_unit_cfg=dict(default_args=dict(detla_type='act', ), ), ), ) diff --git a/projects/group_fisher/configs/mmdet/retinanet/run.sh b/projects/group_fisher/configs/mmdet/retinanet/run.sh new file mode 100644 index 000000000..6405abf53 --- /dev/null +++ b/projects/group_fisher/configs/mmdet/retinanet/run.sh @@ -0,0 +1,6 @@ +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmdet/group-fisher-pruning_retinanet_resnet50_8xb2_coco.py 8 +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmdet/group-fisher-finetune_retinanet_resnet50_8xb2_coco.py 8 + + +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmdet/group-fisher-pruning_retinanet_resnet50_8xb2_coco_act.py 8 +bash ./tools/dist_train.sh ./projects/group_fisher/configs/mmdet/group-fisher-finetune_retinanet_resnet50_8xb2_coco_act.py 8 diff --git a/projects/group_fisher/modules/group_fisher_algorthm.py b/projects/group_fisher/modules/group_fisher_algorthm.py new file mode 100644 index 000000000..f7174ed44 --- /dev/null +++ b/projects/group_fisher/modules/group_fisher_algorthm.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from mmengine.logging import print_log +from mmengine.model import BaseModel, MMDistributedDataParallel + +from mmrazor.models.algorithms.base import BaseAlgorithm +from mmrazor.registry import MODEL_WRAPPERS, MODELS +from ...cores.utils import RuntimeInfo # type: ignore +from .group_fisher_channel_mutator import GroupFisherChannelMutator + + +@MODELS.register_module() +class GroupFisherAlgorithm(BaseAlgorithm): + """`Group Fisher Pruning for Practical Network Compression`. + https://arxiv.org/pdf/2108.00708.pdf. + + Args: + architecture (Union[BaseModel, Dict]): The model to be pruned. + mutator (Union[Dict, ChannelMutator], optional): The config + of a mutator. Defaults to dict( type='GroupFisherChannelMutator', + channel_unit_cfg=dict( type='GroupFisherChannelUnit')). + interval (int): The interval of pruning two channels. Defaults to 10. + data_preprocessor (Optional[Union[Dict, nn.Module]], optional): + Defaults to None. + init_cfg (Optional[Dict], optional): init config for the model. + Defaults to None. 
+ """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator: Union[Dict, GroupFisherChannelMutator] = dict( + type='GroupFisherChannelMutator', + channel_unit_cfg=dict(type='GroupFisherChannelUnit')), + interval: int = 10, + data_preprocessor: Optional[Union[Dict, nn.Module]] = None, + init_cfg: Optional[Dict] = None) -> None: + + super().__init__(architecture, data_preprocessor, init_cfg) + + self.interval = interval + + # using sync bn or normal bn + if dist.is_initialized(): + print_log('Convert Bn to SyncBn.') + self.architecture = nn.SyncBatchNorm.convert_sync_batchnorm( + self.architecture) + else: + from mmengine.model import revert_sync_batchnorm + self.architecture = revert_sync_batchnorm(self.architecture) + + # mutator + self.mutator: GroupFisherChannelMutator = MODELS.build(mutator) + self.mutator.prepare_from_supernet(self.architecture) + + def train_step(self, data: Union[dict, tuple, list], + optim_wrapper) -> Dict[str, torch.Tensor]: + algorithm = self + algorithm.mutator.start_record_info() + res = super().train_step(data, optim_wrapper) + algorithm.mutator.end_record_info() + + algorithm.mutator.update_imp() + algorithm.mutator.reset_recorded_info() + + if RuntimeInfo.iter() % algorithm.interval == 0: + algorithm.mutator.try_prune() + algorithm.mutator.reset_imp() + + return res + + +@MODEL_WRAPPERS.register_module() +class GroupFisherDDP(MMDistributedDataParallel): + """Train step for group fisher.""" + + def train_step(self, data: Union[dict, tuple, list], + optim_wrapper) -> Dict[str, torch.Tensor]: + algorithm = self.module + algorithm.mutator.start_record_info() + res = super().train_step(data, optim_wrapper) + algorithm.mutator.end_record_info() + + algorithm.mutator.update_imp() + algorithm.mutator.reset_recorded_info() + + if RuntimeInfo.iter() % algorithm.interval == 0: + algorithm.mutator.try_prune() + algorithm.mutator.reset_imp() + + return res diff --git a/projects/group_fisher/modules/group_fisher_channel_mutator.py b/projects/group_fisher/modules/group_fisher_channel_mutator.py new file mode 100644 index 000000000..8db1403c2 --- /dev/null +++ b/projects/group_fisher/modules/group_fisher_channel_mutator.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Dict, List, Type, Union + +from mmengine.dist import dist + +from mmrazor.models.mutators.channel_mutator.channel_mutator import \ + ChannelMutator +from mmrazor.registry import MODELS +from mmrazor.utils import print_log +from .group_fisher_channel_unit import GroupFisherChannelUnit + + +@MODELS.register_module() +class GroupFisherChannelMutator(ChannelMutator[GroupFisherChannelUnit]): + """Channel mutator for GroupFisher Pruning Algorithm. + + Args: + channel_unit_cfg (Union[dict, Type[ChannelUnitType]], optional): + Config of MutableChannelUnits. Defaults to + dict(type='GroupFisherChannelUnit', + default_args=dict(choice_mode='ratio')). + parse_cfg (Dict): The config of the tracer to parse the model. + Defaults to dict(type='ChannelAnalyzer', + demo_input=(1, 3, 224, 224), + tracer_type='FxTracer'). 
+ """ + + def __init__(self, + channel_unit_cfg: Union[dict, + Type[GroupFisherChannelUnit]] = dict( + type='GroupFisherChannelUnit', + default_args=dict( + choice_mode='ratio')), + parse_cfg: Dict = dict( + type='ChannelAnalyzer', + demo_input=(1, 3, 224, 224), + tracer_type='FxTracer'), + min_ratio=0.0, + min_channel=0, + **kwargs) -> None: + super().__init__(channel_unit_cfg, parse_cfg, **kwargs) + self.mutable_units: List[GroupFisherChannelUnit] + self.min_ratio = min_ratio + self.min_channel = min_channel + + def start_record_info(self) -> None: + """Start recording the related information.""" + for unit in self.mutable_units: + unit.start_record_fisher_info() + + def end_record_info(self) -> None: + """Stop recording the related information.""" + for unit in self.mutable_units: + unit.end_record_fisher_info() + + def reset_recorded_info(self) -> None: + """Reset the related information.""" + for unit in self.mutable_units: + unit.reset_recorded() + + def try_prune(self) -> None: + """Prune the channel with the minimum fisher unless it is the last + channel of the current layer.""" + min_imp = 1e5 + min_unit = self.mutable_units[0] + for unit in self.mutable_units: + if unit.mutable_channel.activated_channels > max( + self.min_channel, (unit.num_channels * self.min_ratio), 0): + imp = unit.importance() + if imp.isnan().any(): + if dist.get_rank() == 0: + print_log( + f'{unit.name} detects nan in importance, this pruning skips.' # noqa + ) + return + if imp.min() < min_imp: + min_imp = imp.min().item() + min_unit = unit + if min_unit.try_to_prune_min_channel(): + if dist.get_rank() == 0: + print_log( + f'{min_unit.name} prunes a channel with min imp = {min_imp}' # noqa + ) + + def update_imp(self) -> None: + """Update the fisher information of each unit.""" + for unit in self.mutable_units: + unit.update_fisher_info() + + def reset_imp(self) -> None: + """Reset the fisher information of each unit.""" + for unit in self.mutable_units: + unit.reset_fisher_info() diff --git a/projects/group_fisher/modules/group_fisher_channel_unit.py b/projects/group_fisher/modules/group_fisher_channel_unit.py new file mode 100644 index 000000000..41f5be83c --- /dev/null +++ b/projects/group_fisher/modules/group_fisher_channel_unit.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmengine.model.utils import _BatchNormXd +from mmengine.utils.dl_utils.parrots_wrapper import \ + SyncBatchNorm as EngineSyncBatchNorm +from torch import distributed as dist + +import mmrazor.models.architectures.dynamic_ops as dynamic_ops +from mmrazor.models.mutables.mutable_channel.mutable_channel_container import \ + MutableChannelContainer +from mmrazor.models.mutables.mutable_channel.units.l1_mutable_channel_unit import \ + L1MutableChannelUnit # noqa +from mmrazor.registry import MODELS +from .group_fisher_ops import GroupFisherConv2d, GroupFisherLinear + + +@MODELS.register_module() +class GroupFisherChannelUnit(L1MutableChannelUnit): + """ChannelUnit for GroupFisher Pruning Algorithm. + + Args: + num_channels (int): Number of channels. + detla_type (str): Type of delta, which is one of 'flop', 'act' or + 'none'. Defaults to 'flop'. + mutate_linear (bool): Whether to prune linear layers. 
+ """ + + def __init__(self, + num_channels: int, + detla_type: str = 'flop', + mutate_linear=False, + *args) -> None: + super().__init__(num_channels, *args) + _fisher_info = torch.zeros([self.num_channels]) + self.register_buffer('normalized_fisher_info', _fisher_info) + self.normalized_fisher_info: torch.Tensor + + self.hook_handles: List = [] + assert detla_type in ['flop', 'act', 'none'] + self.delta_type = detla_type + + self.mutate_linear = mutate_linear + + def prepare_for_pruning(self, model: nn.Module) -> None: + """Prepare for pruning, including register mutable channels. + + Args: + model (nn.Module): The model need to be pruned. + """ + # register MutableMask + self._replace_with_dynamic_ops( + model, { + nn.Conv2d: GroupFisherConv2d, + nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d, + nn.Linear: GroupFisherLinear, + nn.SyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm, + EngineSyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm, + _BatchNormXd: dynamic_ops.DynamicBatchNormXd, + }) + self._register_channel_container(model, MutableChannelContainer) + self._register_mutable_channel(self.mutable_channel) + + # prune + def try_to_prune_min_channel(self) -> bool: + """Prune the channel with the minimum value of fisher information.""" + if self.mutable_channel.activated_channels > 1: + imp = self.importance() + index = imp.argmin() + self.mutable_channel.mask.scatter_(0, index, 0.0) + return True + else: + return False + + @property + def is_mutable(self) -> bool: + """Whether the unit is mutable.""" + mutable = super().is_mutable + if self.mutate_linear: + return mutable + else: + has_linear = False + for layer in self.input_related: + if isinstance(layer.module, nn.Linear): + has_linear = True + return mutable and (not has_linear) + + # fisher information recorded + + def start_record_fisher_info(self) -> None: + """Start recording the related fisher info of each channel.""" + for channel in self.input_related + self.output_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + module.start_record() + + def end_record_fisher_info(self) -> None: + """Stop recording the related fisher info of each channel.""" + for channel in self.input_related + self.output_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + module.end_record() + + def reset_recorded(self) -> None: + """Reset the recorded info of each channel.""" + for channel in self.input_related + self.output_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + module.reset_recorded() + + # fisher related computation + + def importance(self): + """The importance of each channel.""" + fisher = self.normalized_fisher_info.clone() + mask = self.mutable_channel.current_mask + n_mask = (1 - mask.float()).bool() + fisher.masked_fill_(n_mask, fisher.max() + 1) + return fisher + + def reset_fisher_info(self) -> None: + """Reset the related fisher info.""" + self.normalized_fisher_info.zero_() + + @torch.no_grad() + def update_fisher_info(self) -> None: + """Update the fisher info of each channel.""" + batch_fisher_sum = 0.0 + for channel in self.input_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + batch_fisher = self.current_batch_fisher + batch_fisher_sum = batch_fisher_sum + batch_fisher + assert isinstance(batch_fisher_sum, torch.Tensor) + if dist.is_initialized(): + dist.all_reduce(batch_fisher_sum) + batch_fisher_sum = self._get_normalized_fisher_info( + batch_fisher_sum, self.delta_type) + self.normalized_fisher_info = 
self.normalized_fisher_info + batch_fisher_sum # noqa + + @property + def current_batch_fisher(self) -> torch.Tensor: + """Accumulate the unit's fisher info of this batch.""" + with torch.no_grad(): + fisher: torch.Tensor = 0 + for channel in self.input_related: + if isinstance(channel.module, GroupFisherConv2d): + fisher = fisher + self._fisher_of_a_module(channel.module) + return (fisher**2).sum(0) + + @torch.no_grad() + def _fisher_of_a_module(self, module: GroupFisherConv2d) -> torch.Tensor: + """Calculate the fisher info of one module. + + Args: + module (GroupFisherConv2d): A `GroupFisherConv2d` module. + """ + assert len(module.recorded_input) > 0 and \ + len(module.recorded_input) == len(module.recorded_grad) + fisher_sum: torch.Tensor = 0 + for input, grad_input in zip(module.recorded_input, + module.recorded_grad): + fisher: torch.Tensor = input * grad_input + fisher = fisher.sum(dim=[i for i in range(2, len(fisher.shape))]) + fisher_sum = fisher_sum + fisher + + # expand to full num_channel + batch_size = fisher_sum.shape[0] + mask = self.mutable_channel.current_mask.unsqueeze(0).expand( + [batch_size, self.num_channels]) + zeros = fisher_sum.new_zeros([batch_size, self.num_channels]) + fisher_sum = zeros.masked_scatter_(mask, fisher_sum) + return fisher_sum + + @property + def _delta_flop_of_a_channel(self) -> torch.Tensor: + """Calculate the flops of a channel.""" + delta_flop = 0 + for channel in self.output_related: + if isinstance(channel.module, GroupFisherConv2d): + delta_flop += channel.module.delta_flop_of_a_out_channel + for channel in self.input_related: + if isinstance(channel.module, GroupFisherConv2d): + delta_flop += channel.module.delta_flop_of_a_in_channel + return delta_flop + + @property + def _delta_memory_of_a_channel(self) -> torch.Tensor: + """Calculate the memory of a channel.""" + delta_memory = 0 + for channel in self.output_related: + if isinstance(channel.module, GroupFisherConv2d): + delta_memory += channel.module.delta_memory_of_a_out_channel + return delta_memory + + @torch.no_grad() + def _get_normalized_fisher_info(self, + fisher_info, + delta_type='flop') -> torch.Tensor: + """Get the normalized fisher info. + + Args: + delta_type (str): Type of delta. Defaults to 'flop'. + """ + fisher = fisher_info.double() + if delta_type == 'flop': + delta_flop = self._delta_flop_of_a_channel + assert delta_flop > 0 + fisher = fisher / (float(delta_flop) / 1e9) + elif delta_type == 'act': + delta_memory = self._delta_memory_of_a_channel + assert delta_memory > 0 + fisher = fisher / (float(delta_memory) / 1e6) + elif delta_type == 'none': + pass + else: + raise NotImplementedError(delta_type) + return fisher diff --git a/projects/group_fisher/modules/group_fisher_ops.py b/projects/group_fisher/modules/group_fisher_ops.py new file mode 100644 index 000000000..21f84ab0c --- /dev/null +++ b/projects/group_fisher/modules/group_fisher_ops.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List + +import torch + +from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_conv import \ + DynamicConv2d +from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_linear import \ + DynamicLinear +from mmrazor.registry import TASK_UTILS +from ...cores.counters import DynamicConv2dCounter # type: ignore +from ...cores.counters import DynamicLinearCounter # type: ignore + + +class GroupFisherMixin: + """The mixin class for GroupFisher ops.""" + + def _init(self) -> None: + self.handlers: list = [] + self.recorded_input: List = [] + self.recorded_grad: List = [] + self.recorded_out_shape: List = [] + + def forward_hook_wrapper(self): + """Wrap the hook used in forward.""" + + def forward_hook(module: GroupFisherMixin, input, output): + module.recorded_out_shape.append(output.shape) + module.recorded_input.append(input[0]) + + return forward_hook + + def backward_hook_wrapper(self): + """Wrap the hook used in backward.""" + + def backward_hook(module: GroupFisherMixin, grad_in, grad_out): + module.recorded_grad.insert(0, grad_in[0]) + + return backward_hook + + def start_record(self: torch.nn.Module) -> None: + """Start recording information during forward and backward.""" + self.end_record() # ensure to run start_record only once + self.handlers.append( + self.register_forward_hook(self.forward_hook_wrapper())) + self.handlers.append( + self.register_backward_hook(self.backward_hook_wrapper())) + + def end_record(self): + """Stop recording information during forward and backward.""" + for handle in self.handlers: + handle.remove() + self.handlers = [] + + def reset_recorded(self): + """Reset the recorded information.""" + self.recorded_input = [] + self.recorded_grad = [] + self.recorded_out_shape = [] + + @property + def delta_flop_of_a_out_channel(self): + raise NotImplementedError() + + @property + def delta_flop_of_a_in_channel(self): + raise NotImplementedError() + + @property + def delta_memory_of_a_out_channel(self): + raise NotImplementedError() + + +class GroupFisherConv2d(DynamicConv2d, GroupFisherMixin): + """The Dynamic Conv2d operation used in GroupFisher Algorithm.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._init() + + @property + def delta_flop_of_a_out_channel(self) -> torch.Tensor: + """Calculate the summation of flops when prune an out_channel.""" + delta_flop_sum = 0 + for shape in self.recorded_out_shape: + _, _, h, w = shape + in_c = int(self.mutable_attrs['in_channels'].current_mask.float(). 
+ sum().item()) + delta_flop = h * w * self.kernel_size[0] * self.kernel_size[ + 1] * in_c + delta_flop_sum += delta_flop + return delta_flop_sum + + @property + def delta_flop_of_a_in_channel(self): + """Calculate the summation of flops when prune an in_channel.""" + delta_flop_sum = 0 + for shape in self.recorded_out_shape: + _, out_c, h, w = shape + delta_flop = out_c * h * w * self.kernel_size[ + 0] * self.kernel_size[1] + delta_flop_sum += delta_flop + return delta_flop_sum + + @property + def delta_memory_of_a_out_channel(self): + """Calculate the summation of memory when prune a channel.""" + delta_flop_sum = 0 + for shape in self.recorded_out_shape: + _, _, h, w = shape + delta_flop_sum += h * w + return delta_flop_sum + + +class GroupFisherLinear(DynamicLinear, GroupFisherMixin): + """The Dynamic Linear operation used in GroupFisher Algorithm.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._init() + + @property + def delta_flop_of_a_out_channel(self): + """Calculate the summation of flops when prune an out_channel.""" + in_c = self.mutable_attrs['in_channels'].current_mask.float().sum() + return in_c * len(self.recorded_out_shape) + + @property + def delta_flop_of_a_in_channel(self): + """Calculate the summation of flops when prune an in_channel.""" + out_c = self.mutable_attrs['out_channels'].current_mask.float().sum() + return out_c * len(self.recorded_out_shape) + + @property + def delta_memory_of_a_out_channel(self): + """Calculate the summation of memory when prune a channel.""" + return 1 * len(self.recorded_out_shape) + + +@TASK_UTILS.register_module() +class GroupFisherConv2dCounter(DynamicConv2dCounter): + """Counter of GroupFisherConv2d.""" + pass + + +@TASK_UTILS.register_module() +class GroupFisherLinearCounter(DynamicLinearCounter): + """Counter of GroupFisherLinear.""" + pass diff --git a/projects/models/__init__.py b/projects/models/__init__.py new file mode 100644 index 000000000..a1ce01a29 --- /dev/null +++ b/projects/models/__init__.py @@ -0,0 +1 @@ +from .vgg import * # noqa diff --git a/projects/models/vgg/__init__.py b/projects/models/vgg/__init__.py new file mode 100644 index 000000000..d386f3439 --- /dev/null +++ b/projects/models/vgg/__init__.py @@ -0,0 +1,3 @@ +from .vgg_cifar import VGGCifar + +__all__ = ['VGGCifar'] diff --git a/projects/models/vgg/configs/cifar10_bs128.py b/projects/models/vgg/configs/cifar10_bs128.py new file mode 100644 index 000000000..35c32c415 --- /dev/null +++ b/projects/models/vgg/configs/cifar10_bs128.py @@ -0,0 +1,12 @@ +# optimizer +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.005)) +# learning policy +param_scheduler = dict( + type='MultiStepLR', by_epoch=True, milestones=[50, 100], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=150) +# val_cfg = dict(interval=1) # validate every epoch +val_cfg = dict() # validate every epoch +test_cfg = dict() diff --git a/projects/models/vgg/configs/cifar10_bs16.py b/projects/models/vgg/configs/cifar10_bs16.py new file mode 100644 index 000000000..ea9d4a193 --- /dev/null +++ b/projects/models/vgg/configs/cifar10_bs16.py @@ -0,0 +1,48 @@ +# dataset settings +dataset_type = 'CIFAR10' +data_preprocessor = dict( + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type='RandomCrop', crop_size=32, padding=4), 
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackClsInputs'), +] + +test_pipeline = [ + dict(type='CenterCrop', crop_size=32), + dict(type='PackClsInputs'), +] + +train_dataloader = dict( + batch_size=256, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=256, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='Accuracy', topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/projects/models/vgg/configs/default_runtime.py b/projects/models/vgg/configs/default_runtime.py new file mode 100644 index 000000000..e78b3d753 --- /dev/null +++ b/projects/models/vgg/configs/default_runtime.py @@ -0,0 +1,48 @@ +# defaults to use registries in mmcls +default_scope = 'mmcls' + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type='IterTimerHook'), + + # print log every 100 iterations. + logger=dict(type='LoggerHook', interval=100), + + # enable the parameter scheduler. + param_scheduler=dict(type='ParamSchedulerHook'), + + # save checkpoint per epoch. + checkpoint=dict(type='CheckpointHook', interval=50, save_best='auto'), + + # set sampler seed in distributed evrionment. + sampler_seed=dict(type='DistSamplerSeedHook'), + + # validation results visualization, set True to enable it. + visualization=dict(type='VisualizationHook', enable=False), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict(type='ClsVisualizer', vis_backends=vis_backends) + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False diff --git a/projects/models/vgg/configs/deploy.sh b/projects/models/vgg/configs/deploy.sh new file mode 100644 index 000000000..b03ab3cac --- /dev/null +++ b/projects/models/vgg/configs/deploy.sh @@ -0,0 +1,19 @@ + +python mmdeploy/tools/deploy.py \ + mmdeploy/configs/mmcls/classification_onnxruntime_static.py \ + ./projects/models/vgg/configs/vgg_pretrain.py \ + ./work_dirs/pretrained/vgg_pretrained.pth \ + ./mmdeploy/demo/resources/face.png \ + --work-dir work_dirs/mmdeploy_model/ \ + --device cpu \ + --dump-info + +python mmdeploy/tools/profiler.py \ + mmdeploy/configs/mmcls/classification_onnxruntime_static.py \ + ./projects/models/vgg/configs/vgg_pretrain.py \ + mmdeploy/resources/ \ + --model ./work_dirs/mmdeploy_model/end2end.onnx \ + --shape 32x32 \ + --device cpu \ + --warmup 50 \ + --num-iter 200 diff --git a/projects/models/vgg/configs/vgg_model.py b/projects/models/vgg/configs/vgg_model.py new file mode 100644 index 000000000..5ee6735cc --- /dev/null +++ b/projects/models/vgg/configs/vgg_model.py @@ -0,0 +1,10 @@ +model = dict( + type='mmcls.ImageClassifier', + backbone=dict(type='mmrazor.VGGCifar', num_classes=10), + head=dict( + type='mmcls.LinearClsHead', + num_classes=10, + 
in_channels=512, + loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0), + ), +) diff --git a/projects/models/vgg/configs/vgg_pretrain.py b/projects/models/vgg/configs/vgg_pretrain.py new file mode 100644 index 000000000..5c9ed34f7 --- /dev/null +++ b/projects/models/vgg/configs/vgg_pretrain.py @@ -0,0 +1,5 @@ +_base_ = [ + './vgg_model.py', './cifar10_bs16.py', './cifar10_bs128.py', + './default_runtime.py' +] +custom_imports = dict(imports=['projects']) diff --git a/projects/models/vgg/configs/vgg_resource.py b/projects/models/vgg/configs/vgg_resource.py new file mode 100644 index 000000000..082d44ef4 --- /dev/null +++ b/projects/models/vgg/configs/vgg_resource.py @@ -0,0 +1,26 @@ +_base_ = ['./vgg_pretrain.py'] + +target_pruning_ratio = { + 'backbone.features.conv0_(0, 64)_64': 22, + 'backbone.features.conv1_(0, 64)_64': 43, + 'backbone.features.conv3_(0, 128)_128': 85, + 'backbone.features.conv4_(0, 128)_128': 104, + 'backbone.features.conv6_(0, 256)_256': 201, + 'backbone.features.conv7_(0, 256)_256': 166, + 'backbone.features.conv8_(0, 256)_256': 144, + 'backbone.features.conv10_(0, 512)_512': 147, + 'backbone.features.conv11_(0, 512)_512': 88, + 'backbone.features.conv12_(0, 512)_512': 80, + 'backbone.features.conv14_(0, 512)_512': 146, + 'backbone.features.conv15_(0, 512)_512': 179 +} +data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'} +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='PruneWrapper', + architecture=architecture, + target_pruning_ratio=target_pruning_ratio, +) diff --git a/projects/models/vgg/vgg_cifar.py b/projects/models/vgg/vgg_cifar.py new file mode 100644 index 000000000..8fecace9a --- /dev/null +++ b/projects/models/vgg/vgg_cifar.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import torch.nn as nn +from mmengine.model import BaseModel + +from mmrazor.registry import MODELS + +defaultcfg = [ + 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, + 512, 512 +] +relucfg = [2, 6, 9, 13, 16, 19, 23, 26, 29, 33, 36, 39] + + +@MODELS.register_module() +class VGGCifar(BaseModel): + """VGG16 with bn for pruning on cifar10 . + + It's modified from https://github.com/lmbxmu/HRank. 
+ """ + + def __init__(self, cfg=None, num_classes=10): + super().__init__() + + if cfg is None: + cfg = defaultcfg + self.relucfg = relucfg + + self.features = self._make_layers(cfg) + self.classifier = nn.Sequential( + OrderedDict([ + ('linear1', nn.Linear(cfg[-2], cfg[-1])), + ('norm1', nn.BatchNorm1d(cfg[-1])), + ('relu1', nn.ReLU(inplace=True)), + # ('linear2', nn.Linear(cfg[-1], num_classes)), + ])) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + def _make_layers(self, cfg): + + layers = nn.Sequential() + in_channels = 3 + cnt = 0 + + for i, x in enumerate(cfg): + if x == 'M': + layers.add_module('pool%d' % i, + nn.MaxPool2d(kernel_size=2, stride=2)) + else: + cnt += 1 + conv2d = nn.Conv2d(in_channels, x, kernel_size=3, padding=1) + layers.add_module('conv%d' % i, conv2d) + layers.add_module('norm%d' % i, nn.BatchNorm2d(x)) + layers.add_module('relu%d' % i, nn.ReLU(inplace=True)) + in_channels = x + + return layers + + def forward(self, x): + x = self.features(x) + + x = self.avg_pool(x) + x = x.flatten(1) + x = self.classifier(x) + return (x, ) diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py index 00d615815..ce27d8293 100644 --- a/tests/test_models/test_algorithms/test_prune_algorithm.py +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -11,6 +11,9 @@ from mmrazor.models.algorithms.pruning.ite_prune_algorithm import ( ItePruneAlgorithm, ItePruneConfigManager) from mmrazor.registry import MODELS +from projects.group_fisher.modules.group_fisher_algorthm import \ + GroupFisherAlgorithm +from projects.group_fisher.modules.group_fisher_ops import GroupFisherConv2d from ...utils.set_dist_env import SetDistEnv @@ -262,3 +265,63 @@ def test_resume(self): print(algorithm2.mutator.current_choices) self.assertDictEqual(algorithm.mutator.current_choices, algorithm2.mutator.current_choices) + + +class TestGroupFisherPruneAlgorithm(TestItePruneAlgorithm): + + def test_group_fisher_prune(self): + data = self.fake_cifar_data() + + MUTATOR_CONFIG = dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict(type='GroupFisherChannelUnit')) + + epoch = 2 + interval = 1 + delta = 'flops' + + algorithm = GroupFisherAlgorithm( + MODEL_CFG, + pruning=True, + mutator=MUTATOR_CONFIG, + delta=delta, + interval=interval, + save_ckpt_delta_thr=[1.1]).to(DEVICE) + mutator = algorithm.mutator + + ckpt_path = os.path.dirname(__file__) + f'/{delta}_0.99.pth' + + fake_cfg_path = os.path.dirname(__file__) + '/cfg.py' + self.gen_fake_cfg(fake_cfg_path) + self.assertTrue(os.path.exists(fake_cfg_path)) + + message_hub = MessageHub.get_current_instance() + cfg_str = open(fake_cfg_path).read() + message_hub.update_info('cfg', cfg_str) + + for e in range(epoch): + for ite in range(10): + self._set_epoch_ite(e, ite, epoch) + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + self.gen_fake_grad(mutator) + self.assertEqual(delta, algorithm.delta) + self.assertEqual(interval, algorithm.interval) + self.assertTrue(os.path.exists(ckpt_path)) + os.remove(ckpt_path) + os.remove(fake_cfg_path) + self.assertTrue(not os.path.exists(ckpt_path)) + self.assertTrue(not os.path.exists(fake_cfg_path)) + + def gen_fake_grad(self, mutator): + for unit in mutator.mutable_units: + for channel in unit.input_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + module.recorded_grad = module.recorded_input + + def 
gen_fake_cfg(self, fake_cfg_path): + with open(fake_cfg_path, 'a', encoding='utf-8') as cfg: + cfg.write(f'work_dir = \'{os.path.dirname(__file__)}\'') + cfg.write('\n') diff --git a/tests/test_projects/__init__.py b/tests/test_projects/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_projects/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_projects/test_expand/__init__.py b/tests/test_projects/test_expand/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_projects/test_expand/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_projects/test_expand/test_expand.py b/tests/test_projects/test_expand/test_expand.py new file mode 100644 index 000000000..0408a4c7b --- /dev/null +++ b/tests/test_projects/test_expand/test_expand.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmrazor.models.mutables import SimpleMutableChannel +from mmrazor.models.mutators import ChannelMutator +from projects.cores.expandable_ops.ops import ExpandLinear +from projects.cores.expandable_ops.unit import (ExpandableUnit, + expand_dynamic_model, + expand_static_model) +from ...data.models import MultiConcatModel, SingleLineModel + + +class TestExpand(unittest.TestCase): + + def test_expand(self): + x = torch.rand([1, 3, 224, 224]) + model = MultiConcatModel() + print(model) + mutator = ChannelMutator[ExpandableUnit]( + channel_unit_cfg=ExpandableUnit) + mutator.prepare_from_supernet(model) + print(mutator.choice_template) + print(model) + y1 = model(x) + + for unit in mutator.mutable_units: + unit.expand(10) + print(unit.mutable_channel.mask.shape) + expand_dynamic_model(model, zero=True) + print(model) + y2 = model(x) + self.assertTrue((y1 - y2).abs().max() < 1e-3) + + def test_expand_static_model(self): + x = torch.rand([1, 3, 224, 224]) + model = SingleLineModel() + y1 = model(x) + expand_static_model(model, divisor=4) + y2 = model(x) + print(y1.reshape([-1])[:5]) + print(y2.reshape([-1])[:5]) + self.assertTrue((y1 - y2).abs().max() < 1e-3) + + def test_ExpandConv2d(self): + linear = ExpandLinear(3, 3) + mutable_in = SimpleMutableChannel(3) + mutable_out = SimpleMutableChannel(3) + linear.register_mutable_attr('in_channels', mutable_in) + linear.register_mutable_attr('out_channels', mutable_out) + + print(linear.weight) + + mutable_in.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0]) + mutable_out.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0]) + linear_ex = linear.expand(zero=True) + print(linear_ex.weight)
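

The new Group Fisher modules above revolve around three steps: record activations and gradients with hooks, turn them into a per-channel Fisher score, and periodically mask the least important channel. As a reading aid, the following is a minimal, self-contained sketch of the first two steps as implemented by GroupFisherMixin and GroupFisherChannelUnit._fisher_of_a_module / current_batch_fisher: record each layer input and the gradient w.r.t. that input, reduce over the spatial dimensions, square, and sum over the batch. It is not part of the PR; it uses register_full_backward_hook instead of the deprecated register_backward_hook used in the mixin, and the names toy_conv / fisher_per_channel are illustrative only.

import torch
import torch.nn as nn

toy_conv = nn.Conv2d(8, 16, kernel_size=3, padding=1)
recorded_input, recorded_grad = [], []

# mirrors GroupFisherMixin.start_record(): one forward hook, one backward hook
toy_conv.register_forward_hook(
    lambda m, inp, out: recorded_input.append(inp[0].detach()))
toy_conv.register_full_backward_hook(
    lambda m, grad_in, grad_out: recorded_grad.insert(0, grad_in[0].detach()))

x = torch.rand(4, 8, 32, 32, requires_grad=True)
toy_conv(x).sum().backward()

# mirrors _fisher_of_a_module / current_batch_fisher for a single module
fisher = torch.zeros(4, 8)                   # [batch, in_channels]
for inp, grad in zip(recorded_input, recorded_grad):
    fisher += (inp * grad).sum(dim=[2, 3])   # reduce the spatial dims
fisher_per_channel = (fisher ** 2).sum(0)    # square, then sum over the batch
print(fisher_per_channel)                    # lower score => better pruning candidate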
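
Before channels are compared, the accumulated score is normalized by an estimate of what removing one channel saves, as in _get_normalized_fisher_info: delta_type='flop' divides by GFLOPs (via delta_flop_of_a_out_channel / delta_flop_of_a_in_channel), while delta_type='act' divides by feature-map memory. The sketch below reproduces the arithmetic of the FLOP variant with made-up shapes; in the real unit the two delta terms come from the convs that produce and consume the channel rather than from one layer, so treat it purely as an illustration.

import torch

kernel_h = kernel_w = 3
recorded_out_shapes = [(4, 16, 32, 32)]      # one recorded forward: N, out_c, H, W
active_in_channels = 8

# delta_flop_of_a_out_channel: per-position cost of one output channel
delta_out = sum(h * w * kernel_h * kernel_w * active_in_channels
                for _, _, h, w in recorded_out_shapes)
# delta_flop_of_a_in_channel: cost one input channel contributes downstream
delta_in = sum(out_c * h * w * kernel_h * kernel_w
               for _, out_c, h, w in recorded_out_shapes)

fisher = torch.rand(active_in_channels).double()
normalized = fisher / ((delta_out + delta_in) / 1e9)   # delta_type == 'flop'
print(normalized)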
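
The last step is the greedy selection in GroupFisherChannelMutator.try_prune, driven from GroupFisherAlgorithm.train_step (or GroupFisherDDP) every `interval` iterations: among all units that still have more than the minimum number of active channels, find the single channel with the lowest normalized score and zero it in that unit's mask. A dependency-free sketch of that selection follows; unit names and scores here are invented, not taken from the configs.

import torch

units = {  # hypothetical units: name -> [accumulated fisher, channel mask]
    'backbone.conv1': [torch.tensor([0.9, 0.2, 0.7]), torch.ones(3)],
    'backbone.conv2': [torch.tensor([0.5, 0.6]), torch.ones(2)],
}

def try_prune(units, min_channel=1):
    best_name, best_idx, best_imp = None, None, float('inf')
    for name, (fisher, mask) in units.items():
        if int(mask.sum()) <= min_channel:   # keep at least min_channel channels
            continue
        imp = fisher.clone()
        imp[mask == 0] = imp.max() + 1       # already-pruned channels never win
        if imp.min() < best_imp:
            best_imp, best_name, best_idx = imp.min().item(), name, imp.argmin()
    if best_name is not None:
        units[best_name][1][best_idx] = 0.0
        print(f'{best_name} prunes channel {int(best_idx)} with min imp = {best_imp:.2f}')

try_prune(units)  # -> backbone.conv1 prunes channel 1 with min imp = 0.20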