diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 836da6e68..c7f5b490a 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -1,6 +1,15 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate -from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax +from hls4ml.model.layers import ( + Activation, + BatchNormalization, + Dense, + HardActivation, + LayerNormalization, + ParametrizedActivation, + PReLU, + Softmax, +) from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT # Dense templates @@ -119,6 +128,59 @@ def format(self, node): return self.template.format(**params) +# LayerNormalization templates + +layernorm_config_template = """struct config{index} : nnet::layernorm_config {{ + static const unsigned n_in = {n_in}; + static const unsigned seq_len = {seq_len}; + static const unsigned table_size = {table_size}; + static constexpr double table_range = {table_range}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + static const bool store_weights_in_bram = false; + static constexpr double epsilon = {epsilon}; + typedef {bias_t.name} bias_t; + typedef {scale_t.name} scale_t; + typedef {mean_t.name} mean_t; + typedef {table_t.name} table_t; + template + using product = nnet::product::{product_type}; +}};\n""" + +layernorm_function_template = 'nnet::layernormalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});' + +layernorm_include_list = ['nnet_utils/nnet_layernorm.h'] + + +class LayerNormalizationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(LayerNormalization) + self.template = layernorm_config_template + + def format(self, node): + params = self._default_config_params(node) + params['n_in'] = node.get_input_variable().size_cpp() + params['seq_len'] = node.get_attr('seq_len') + params['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('scale').type.precision + ) + + return self.template.format(**params) + + +class LayerNormalizationFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(LayerNormalization, include_header=layernorm_include_list) + self.template = layernorm_function_template + + def format(self, node): + params = self._default_function_params(node) + params['scale'] = node.get_weights('scale').name + params['bias'] = node.get_weights('bias').name + + return self.template.format(**params) + + # Activation templates activ_config_template = """struct {type}_config{index} : nnet::activ_config {{ diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 117805dd8..8ebbbc999 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -21,6 +21,7 @@ GarNet, GarNetStack, Layer, + LayerNormalization, Pooling1D, Pooling2D, SeparableConv1D, @@ -558,6 +559,21 @@ def init_softmax(self, layer): len(layer.get_input_variable().shape) == 1 ), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.' 
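     # The optimizer added below fills in default table and accumulator precisions for
     # LayerNormalization layers. As a rough usage sketch only (not part of this patch;
     # the model topology, output_dir and backend below are placeholder assumptions),
     # the new layer support could be exercised end-to-end from the Python API like this:
     #
     #   import hls4ml
     #   from tensorflow.keras import layers, models
     #
     #   model = models.Sequential([
     #       layers.Input(shape=(8, 16)),         # (seq_len, features): 3D input is required
     #       layers.Dense(16),
     #       layers.LayerNormalization(axis=2),   # normalize over the feature axis
     #   ])
     #
     #   config = hls4ml.utils.config_from_keras_model(model, granularity='name')
     #   hls_model = hls4ml.converters.convert_from_keras_model(
     #       model, hls_config=config, output_dir='hls_layernorm_prj', backend='Vivado'
     #   )
     #   hls_model.compile()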
+ @layer_optimizer(LayerNormalization) + def init_layernormalization(self, layer): + if 'table_t' not in layer.attributes: + layer.set_attr( + 'table_t', NamedType(name=layer.name + '_table_t', precision=FixedPrecisionType(width=16, integer=6)) + ) + if 'table_size' not in layer.attributes: + layer.set_attr('table_size', 4096) # table size + if 'table_range' not in layer.attributes: + layer.set_attr('table_range', 1.0) # table range + if 'mean_t' not in layer.attributes: + layer.set_attr( + 'mean_t', NamedType(name=layer.name + '_mean_t', precision=FixedPrecisionType(width=19, integer=6)) + ) + @layer_optimizer(Embedding) def init_embed(self, layer): if layer.attributes['n_in'] is None: diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py index 637bb6d40..6e4e1ebb1 100644 --- a/hls4ml/converters/keras/core.py +++ b/hls4ml/converters/keras/core.py @@ -1,146 +1,174 @@ -from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer -from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer -from hls4ml.model.types import IntegerPrecisionType - - -@keras_handler('InputLayer') -def parse_input_layer(keras_layer, input_names, input_shapes, data_reader): - assert keras_layer['class_name'] == 'InputLayer' - - layer = parse_default_keras_layer(keras_layer, input_names) - - layer['input_shape'] = keras_layer['config']['batch_input_shape'][1:] - - dtype = keras_layer['config']['dtype'] - if dtype.startswith('int') or dtype.startswith('uint'): - layer['type_name'] = 'integer_input_t' - width = int(dtype[dtype.index('int') + 3 :]) - signed = not dtype.startswith('u') - layer['precision'] = IntegerPrecisionType(width=width, signed=signed) - # elif bool, q[u]int, ... - - output_shape = keras_layer['config']['batch_input_shape'] - - return layer, output_shape - - -dense_layers = ['Dense', 'BinaryDense', 'TernaryDense'] - - -@keras_handler(*dense_layers) -def parse_dense_layer(keras_layer, input_names, input_shapes, data_reader): - assert 'Dense' in keras_layer['class_name'] - - layer = parse_default_keras_layer(keras_layer, input_names) - - layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['kernel', 'bias']) - layer['n_in'] = layer['weight_data'].shape[0] - layer['n_out'] = layer['weight_data'].shape[1] - if 'Binary' in layer['class_name']: - layer['weight_quantizer'] = BinaryQuantizer(bits=2) - layer['bias_quantizer'] = BinaryQuantizer(bits=2) - elif 'Ternary' in layer['class_name']: - layer['weight_quantizer'] = TernaryQuantizer() - layer['bias_quantizer'] = TernaryQuantizer() - else: - layer['weight_quantizer'] = None - layer['bias_quantizer'] = None - output_shape = input_shapes[0][:] - output_shape[-1] = layer['n_out'] - - return layer, output_shape - - -activation_layers = ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU'] - - -@keras_handler(*activation_layers) -def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader): - assert keras_layer['class_name'] in activation_layers - - layer = parse_default_keras_layer(keras_layer, input_names) - - if layer['class_name'] != 'Activation': - layer['activation'] = layer['class_name'] - - if layer['activation'] == 'elu': - layer['class_name'] = 'ELU' # always use ELU type for elu, even if passed as activation - - if layer['class_name'] == 'LeakyReLU': - # the name changes for version 3 - layer['activ_param'] = keras_layer['config'].get('negative_slope', keras_layer['config'].get('alpha', 0.3)) 
- elif layer['class_name'] == 'ThresholdedReLU': - layer['activ_param'] = keras_layer['config'].get('theta', 1.0) - elif layer['class_name'] == 'ELU': - layer['activ_param'] = keras_layer['config'].get('alpha', 1.0) - elif layer['class_name'] == 'ReLU': - layer['class_name'] = 'Activation' - elif layer['class_name'] == 'PReLU': - layer['param_data'] = get_weights_data(data_reader, layer['name'], 'alpha') - - if layer['class_name'] == 'Activation' and layer['activation'] == 'softmax': - layer['class_name'] = 'Softmax' - if layer['class_name'] == 'Activation' and layer['activation'] == 'hard_sigmoid': - layer['class_name'] = 'HardActivation' - if layer['class_name'] == 'Softmax': - layer['axis'] = keras_layer['config'].get('axis', -1) - if layer['class_name'] == 'Activation' and layer['activation'] == 'leaky_relu': - layer['class_name'] = 'LeakyReLU' - # The parameter name changes for API v3; the default is different than in LeakyReLU layer - layer['activ_param'] = keras_layer['config'].get('negative_slope', keras_layer['config'].get('alpha', 0.2)) - - return layer, [shape for shape in input_shapes[0]] - - -@keras_handler('BatchNormalization') -def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader): - assert 'BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name'] - - layer = parse_default_keras_layer(keras_layer, input_names) - - in_size = 1 - for dim in input_shapes[0][1:]: - in_size *= dim - layer['n_in'] = in_size - layer['n_out'] = layer['n_in'] - if len(input_shapes[0]) == 2: - layer['n_filt'] = -1 - elif len(input_shapes[0]) == 3: - layer['n_filt'] = input_shapes[0][2] - elif len(input_shapes[0]) == 4: - layer['n_filt'] = input_shapes[0][3] - - layer['use_gamma'] = keras_layer['config']['scale'] - if layer['use_gamma']: - layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'gamma') - else: - layer['gamma_data'] = 1 - - layer['use_beta'] = keras_layer['config']['center'] - if layer['use_beta']: - layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'beta') - else: - layer['beta_data'] = 0 - - layer['mean_data'], layer['variance_data'] = get_weights_data( - data_reader, layer['name'], ['moving_mean', 'moving_variance'] - ) - - return layer, [shape for shape in input_shapes[0]] - - -@keras_handler('Embedding') -def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader): - assert 'Embedding' in keras_layer['class_name'] - - layer = parse_default_keras_layer(keras_layer, input_names) - - layer['n_in'] = input_shapes[0][1] - layer['vocab_size'] = keras_layer['config']['input_dim'] - layer['n_out'] = keras_layer['config']['output_dim'] - - layer['embeddings_data'] = get_weights_data(data_reader, layer['name'], 'embeddings') - - output_shape = input_shapes[0] + [layer['n_out']] - - return layer, output_shape +from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer +from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer +from hls4ml.model.types import IntegerPrecisionType + + +@keras_handler('InputLayer') +def parse_input_layer(keras_layer, input_names, input_shapes, data_reader): + assert keras_layer['class_name'] == 'InputLayer' + + layer = parse_default_keras_layer(keras_layer, input_names) + + layer['input_shape'] = keras_layer['config']['batch_input_shape'][1:] + + dtype = keras_layer['config']['dtype'] + if dtype.startswith('int') or dtype.startswith('uint'): + layer['type_name'] = 'integer_input_t' + width = 
int(dtype[dtype.index('int') + 3 :]) + signed = not dtype.startswith('u') + layer['precision'] = IntegerPrecisionType(width=width, signed=signed) + # elif bool, q[u]int, ... + + output_shape = keras_layer['config']['batch_input_shape'] + + return layer, output_shape + + +dense_layers = ['Dense', 'BinaryDense', 'TernaryDense'] + + +@keras_handler(*dense_layers) +def parse_dense_layer(keras_layer, input_names, input_shapes, data_reader): + assert 'Dense' in keras_layer['class_name'] + + layer = parse_default_keras_layer(keras_layer, input_names) + + layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['kernel', 'bias']) + layer['n_in'] = layer['weight_data'].shape[0] + layer['n_out'] = layer['weight_data'].shape[1] + if 'Binary' in layer['class_name']: + layer['weight_quantizer'] = BinaryQuantizer(bits=2) + layer['bias_quantizer'] = BinaryQuantizer(bits=2) + elif 'Ternary' in layer['class_name']: + layer['weight_quantizer'] = TernaryQuantizer() + layer['bias_quantizer'] = TernaryQuantizer() + else: + layer['weight_quantizer'] = None + layer['bias_quantizer'] = None + output_shape = input_shapes[0][:] + output_shape[-1] = layer['n_out'] + + return layer, output_shape + + +activation_layers = ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU'] + + +@keras_handler(*activation_layers) +def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader): + assert keras_layer['class_name'] in activation_layers + + layer = parse_default_keras_layer(keras_layer, input_names) + + if layer['class_name'] != 'Activation': + layer['activation'] = layer['class_name'] + + if layer['activation'] == 'elu': + layer['class_name'] = 'ELU' # always use ELU type for elu, even if passed as activation + + if layer['class_name'] == 'LeakyReLU': + # the name changes for version 3 + layer['activ_param'] = keras_layer['config'].get('negative_slope', keras_layer['config'].get('alpha', 0.3)) + elif layer['class_name'] == 'ThresholdedReLU': + layer['activ_param'] = keras_layer['config'].get('theta', 1.0) + elif layer['class_name'] == 'ELU': + layer['activ_param'] = keras_layer['config'].get('alpha', 1.0) + elif layer['class_name'] == 'ReLU': + layer['class_name'] = 'Activation' + elif layer['class_name'] == 'PReLU': + layer['param_data'] = get_weights_data(data_reader, layer['name'], 'alpha') + + if layer['class_name'] == 'Activation' and layer['activation'] == 'softmax': + layer['class_name'] = 'Softmax' + if layer['class_name'] == 'Activation' and layer['activation'] == 'hard_sigmoid': + layer['class_name'] = 'HardActivation' + if layer['class_name'] == 'Softmax': + layer['axis'] = keras_layer['config'].get('axis', -1) + if layer['class_name'] == 'Activation' and layer['activation'] == 'leaky_relu': + layer['class_name'] = 'LeakyReLU' + # The parameter name changes for API v3; the default is different than in LeakyReLU layer + layer['activ_param'] = keras_layer['config'].get('negative_slope', keras_layer['config'].get('alpha', 0.2)) + + return layer, [shape for shape in input_shapes[0]] + + +@keras_handler('BatchNormalization') +def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader): + assert 'BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name'] + + layer = parse_default_keras_layer(keras_layer, input_names) + + in_size = 1 + for dim in input_shapes[0][1:]: + in_size *= dim + layer['n_in'] = in_size + layer['n_out'] = layer['n_in'] + if len(input_shapes[0]) == 2: + 
layer['n_filt'] = -1 + elif len(input_shapes[0]) == 3: + layer['n_filt'] = input_shapes[0][2] + elif len(input_shapes[0]) == 4: + layer['n_filt'] = input_shapes[0][3] + + layer['use_gamma'] = keras_layer['config']['scale'] + if layer['use_gamma']: + layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'gamma') + else: + layer['gamma_data'] = 1 + + layer['use_beta'] = keras_layer['config']['center'] + if layer['use_beta']: + layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'beta') + else: + layer['beta_data'] = 0 + + layer['mean_data'], layer['variance_data'] = get_weights_data( + data_reader, layer['name'], ['moving_mean', 'moving_variance'] + ) + + return layer, [shape for shape in input_shapes[0]] + + +@keras_handler('LayerNormalization') +def parse_layernorm_layer(keras_layer, input_names, input_shapes, data_reader): + assert 'LayerNormalization' in keras_layer['class_name'] + + layer = parse_default_keras_layer(keras_layer, input_names) + + in_size = 1 + for dim in input_shapes[0][1:]: + in_size *= dim + layer['n_in'] = layer['n_out'] = in_size + + if not ((len(input_shapes[0])) == 3): + raise Exception('input size is not currently supported by hls4ml, only dim3 is supported') + layer['seq_len'] = input_shapes[0][-2] + + if not (keras_layer['config']['axis'][0] == 2): + raise Exception('assigning the axis is not currently supported by hls4ml, only axis 2 is supported') + + layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'gamma') + layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'beta') + + layer['epsilon'] = keras_layer['config']['epsilon'] + if layer['epsilon'] <= 0: + raise Exception('epsilon must be positive') + + return layer, [shape for shape in input_shapes[0]] + + +@keras_handler('Embedding') +def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader): + assert 'Embedding' in keras_layer['class_name'] + + layer = parse_default_keras_layer(keras_layer, input_names) + + layer['n_in'] = input_shapes[0][1] + layer['vocab_size'] = keras_layer['config']['input_dim'] + layer['n_out'] = keras_layer['config']['output_dim'] + + layer['embeddings_data'] = get_weights_data(data_reader, layer['name'], 'embeddings') + + output_shape = input_shapes[0] + [layer['n_out']] + + return layer, output_shape diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 2c05b7501..e4d99fe28 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -138,3 +138,32 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node layer['n_filt'] = input_shapes[0][1] # Always channel first for Pytorch return layer, [shape for shape in input_shapes[0]] + + +@pytorch_handler('LayerNorm') +def parse_layernorm_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert 'LayerNorm' in operation + + layer = {} + + layer['class_name'] = 'LayerNormalization' + layer['name'] = layer_name + layer['inputs'] = input_names + + in_size = 1 + for dim in input_shapes[0][1:]: + in_size *= dim + layer['n_in'] = layer['n_out'] = in_size + + if not ((len(input_shapes[0])) == 3): + raise Exception('input size is not currently supported by hls4ml, only dim3 is supported') + layer['seq_len'] = input_shapes[0][-2] + + layer['gamma_data'] = class_object.weight.data.numpy() + layer['beta_data'] = class_object.bias.data.numpy() + + layer['epsilon'] = class_object.eps + if layer['epsilon'] <= 0: + raise Exception('epsilon 
must be positive') + + return layer, [shape for shape in input_shapes[0]] diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 3847cda9c..f9324c1ee 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1058,6 +1058,30 @@ def add_bias(self, bias, quantizer=None, precision=None): self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer, precision=precision) +class LayerNormalization(Layer): + _expected_attributes = [ + Attribute('n_in'), + Attribute('seq_len'), + Attribute('epsilon', value_type=float, default=1e-3), + WeightAttribute('scale'), + WeightAttribute('bias'), + TypeAttribute('scale'), + TypeAttribute('bias'), + ] + + def initialize(self): + inp = self.get_input_variable() + shape = inp.shape + dims = inp.dim_names + self.add_output_variable(shape, dims) + + scale = self.get_attr('gamma_data') + bias = self.get_attr('beta_data') + + self.add_weights_variable(name='scale', var_name='s{index}', data=scale) + self.add_weights_variable(name='bias', var_name='b{index}', data=bias) + + class Merge(Layer): def initialize(self): assert len(self.inputs) == 2 @@ -1682,6 +1706,7 @@ def initialize(self): 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, 'SymbolicExpression': SymbolicExpression, + 'LayerNormalization': LayerNormalization, # TensorFlow-specific layers: 'BiasAdd': BiasAdd, } diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py index 0b5f12c00..8150d0a1f 100644 --- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py +++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py @@ -2,7 +2,7 @@ # Based on https://github.com/fastmachinelearning/qonnx/blob/ # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py -from hls4ml.model.layers import Concatenate, Dense, Input, Reshape, Transpose +from hls4ml.model.layers import Concatenate, Dense, Input, LayerNormalization, Reshape, Transpose from hls4ml.model.optimizer import OptimizerPass from hls4ml.model.types import WeightVariable @@ -45,6 +45,24 @@ def transform(self, model, node): node.get_output_variable().shape = input_shape dim_names = [f'N_INPUT_{i}_{node.index}' for i in range(1, len(input_shape) + 1)] node.get_output_variable().dim_names = dim_names + elif isinstance(node, LayerNormalization): + # LayerNorm only works on the last dimension in PyTorch + perm = [1, 0] + pre_transpose = model.make_node( + 'Transpose', f'pre_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.get_input_node().name] + ) + pre_transpose.channels_last_converted = True + model.insert_node(pre_transpose) + + # If not the output layer, transpose again + if not ( + node.get_attr('name') in model.outputs and model.config.config['HLSConfig']['Model']['TransposeOutputs'] + ): + post_transpose = model.make_node( + 'Transpose', f'post_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.name] + ) + post_transpose.channels_last_converted = True + model.insert_node(post_transpose) else: # Transpose weight tensors tensors = ['weight', 'depthwise', 'pointwise', 'zero_bias', 'scale', 'recurrent_weight'] diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index bd439e4a0..af97b4ccd 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -51,7 +51,7 @@ def _infer_precision(self, node, types_to_infer): if node_class in ['Dense']: return 
self._infer_dense_precision(node, types_to_infer) - if node_class in ['BatchNormalization', 'ApplyAlpha']: + if node_class in ['BatchNormalization', 'ApplyAlpha', 'LayerNormalization']: return self._infer_bn_precision(node, types_to_infer) if node_class in ['Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D', 'Conv2DBatchnorm']: diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 84a83de23..519e8fabc 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -1,700 +1,713 @@ -import json -import os -import shutil -import uuid -from collections import defaultdict - -import matplotlib.pyplot as plt -import numpy as np -import pandas -import seaborn as sb - -from hls4ml.model.graph import ModelGraph -from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D - -try: - import qkeras - from tensorflow import keras - - __tf_profiling_enabled__ = True -except ImportError: - __tf_profiling_enabled__ = False - -try: - import torch - - __torch_profiling_enabled__ = True -except ImportError: - __torch_profiling_enabled__ = False - - -def get_unoptimized_hlsmodel(model): - from hls4ml.converters import convert_from_config - - new_config = model.config.config.copy() - new_config['HLSConfig'] = json.loads(json.dumps(new_config['HLSConfig'])) - - new_output_dir = uuid.uuid4().hex - - while os.path.exists(new_output_dir): - new_output_dir = uuid.uuid4().hex - - if 'SkipOptimizers' in new_config['HLSConfig']: - del new_config['HLSConfig']['SkipOptimizers'] - - new_config['HLSConfig']['Optimizers'] = [] - new_config['OutputDir'] = new_output_dir - - return convert_from_config(new_config), new_output_dir - - -def array_to_summary(x, fmt='boxplot'): - if fmt == 'boxplot': - y = {'med': np.median(x), 'q1': np.percentile(x, 25), 'q3': np.percentile(x, 75), 'whislo': min(x), 'whishi': max(x)} - elif fmt == 'histogram': - # Power of 2 bins covering data range - high = np.ceil(np.log2(max(x))) + 1 - low = np.floor(np.log2(min(x))) - 1 - bits = np.arange(low, high, 1) - bins = 2**bits - h, b = np.histogram(x, bins=bins) - h = h * 1.0 / float(sum(h)) # normalize - y = {'h': h, 'b': np.log2(b)} - return y - - -def boxplot(data, fmt='longform'): - if fmt == 'longform': - f = plt.figure() # figsize=(3, 3)) - hue = 'layer' if 'layer' in data.keys() else None - vp = sb.boxplot(x='x', y='weight', hue=hue, data=data[data['x'] > 0], showfliers=False) - vp.set_yticklabels(vp.get_yticklabels(), rotation=45, ha='right') - if hue is not None: - vp.get_legend().remove() - vp.set_xscale('log', base=2) - return f - elif fmt == 'summary': - from matplotlib.patches import Rectangle - - medianprops = dict(linestyle='-', color='k') - f, ax = plt.subplots(1, 1) - data.reverse() - colors = sb.color_palette("Blues", len(data)) - bp = ax.bxp(data, showfliers=False, vert=False, medianprops=medianprops) - # add colored boxes - for line, color in zip(bp['boxes'], colors): - x = line.get_xdata() - xl, xh = min(x), max(x) - y = line.get_ydata() - yl, yh = min(y), max(y) - rect = Rectangle((xl, yl), (xh - xl), (yh - yl), fill=True, color=color) - ax.add_patch(rect) - ax.set_yticklabels([d['weight'] for d in data]) - ax.set_xscale('log', base=2) - plt.xlabel('x') - return f - else: - return None - - -def histogram(data, fmt='longform'): - f = plt.figure() - from matplotlib.ticker import MaxNLocator - - n = len(data) if fmt == 'summary' else len(data['weight'].unique()) - colors = sb.color_palette("husl", n) - if fmt == 'longform': - for i, weight in enumerate(data['weight'].unique()): - y 
= array_to_summary(data[data['weight'] == weight]['x'], fmt='histogram') - plt.bar(y['b'][:-1], y['h'], width=1, fill=False, label=weight, edgecolor=colors[i]) - elif fmt == 'summary': - for i, weight in enumerate(data): - plt.bar(weight['b'][:-1], weight['h'], width=1, fill=False, label=weight['weight'], edgecolor=colors[i]) - - plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True)) - plt.xlabel('log2(x)') - plt.ylabel('frequency') - plt.legend() - return f - - -plots = {'boxplot': boxplot, 'histogram': histogram} - - -def types_boxplot(data, fmt='longform'): - from matplotlib.patches import PathPatch, Rectangle - - ax = plt.gca() - _ = plt.gcf() - # Scale the data - data['low'] = 2.0 ** data['low'] - data['high'] = 2.0 ** data['high'] - - # Plot the custom precisions - ticks = np.array([tick.get_text() for tick in plt.yticks()[1]]) - # Get the coordinates of the boxes to place the markers - if fmt == 'longform': - # seaborn adjusts the box positions slightly in groups - boxes = [c.get_extents().inverse_transformed(ax.transData) for c in ax.get_children() if isinstance(c, PathPatch)] - ys = [(box.y0 + box.y1) / 2 for box in boxes] - ys = [(y, y) for y in ys] - elif fmt == 'summary': - ys = [(y, y) for y in plt.yticks()[0]] - for _irow, row in data[data['layer'] != 'model'].iterrows(): - if row['layer'] in ticks: - iy = np.argwhere(ticks == row['layer'])[0][0] # Determine which layer in the plot - rectangle = Rectangle( - (row['low'], ys[iy][0] - 0.4), row['high'] - row['low'], 0.8, fill=True, color='grey', alpha=0.2 - ) - ax.add_patch(rectangle) - - -def types_histogram(data, fmt='longform'): - ax = plt.gca() - layers = np.array(ax.get_legend_handles_labels()[1]) - colors = sb.color_palette("husl", len(layers)) - ylim = ax.get_ylim() - for _irow, row in data[data['layer'] != 'model'].iterrows(): - if row['layer'] in layers: - col = colors[np.argwhere(layers == row['layer'])[0][0]] - plt.plot((row['low'], row['low']), ylim, '--', color=col) - plt.plot((row['high'], row['high']), ylim, '--', color=col) - - -types_plots = {'boxplot': types_boxplot, 'histogram': types_histogram} - - -def ap_fixed_WIFS(dtype): - from hls4ml.backends import VivadoBackend - - dtype = VivadoBackend.convert_precision_string(dtype) - W, I, F, S = dtype.width, dtype.integer, dtype.fractional, dtype.signed - return W, I, F, S - - -def types_hlsmodel(model): - data = {'layer': [], 'low': [], 'high': []} - # Plot the default precision - default_precision = model.config.model_precision['default'] - W, I, F, S = ap_fixed_WIFS(default_precision) - data['layer'].append('model') - data['low'].append(-F) - data['high'].append(I - 1 if S else I) - - for layer in model.get_layers(): - if isinstance(layer, GRU) or isinstance(layer, LSTM): - suffix = ['w', 'rw', 'b', 'rb'] - elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D): - suffix = ['dw', 'pw', 'db', 'pb'] - else: - suffix = ['w', 'b'] - for iw, weight in enumerate(layer.get_weights()): - wname = f'{layer.name}/{suffix[iw]}' - T = weight.type - if T.name != 'model': - W, I, F, S = ap_fixed_WIFS(T.precision) - data['layer'].append(wname) - data['low'].append(-F) - data['high'].append(I - 1 if S else I) - data = pandas.DataFrame(data) - return data - - -def activation_types_hlsmodel(model): - data = {'layer': [], 'low': [], 'high': []} - # Get the default precision - default_precision = model.config.model_precision['default'] - W, I, F, S = ap_fixed_WIFS(default_precision) - data['layer'].append('model') - data['low'].append(-F) - 
data['high'].append(I - 1 if S else I) - for layer in model.get_layers(): - T = layer.get_output_variable().type.precision - W, I, F, S = ap_fixed_WIFS(T) - data['layer'].append(layer.name) - data['low'].append(-F) - data['high'].append(I - 1 if S else I) - data = pandas.DataFrame(data) - return data - - -def weights_hlsmodel(model, fmt='longform', plot='boxplot'): - if fmt == 'longform': - data = {'x': [], 'layer': [], 'weight': []} - elif fmt == 'summary': - data = [] - - for layer in model.get_layers(): - if isinstance(layer, GRU) or isinstance(layer, LSTM): - suffix = ['w', 'rw', 'b', 'rb'] - elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D): - suffix = ['dw', 'pw', 'db', 'pb'] - else: - suffix = ['w', 'b'] - name = layer.name - for iw, weight in enumerate(layer.get_weights()): - label = f'{name}/{suffix[iw]}' - w = weight.data.flatten() - w = abs(w[w != 0]) - n = len(w) - if n == 0: - print(f'Weights for {name} are only zeros, ignoring.') - break - if fmt == 'longform': - data['x'].extend(w.tolist()) - data['layer'].extend([name] * len(w)) - data['weight'].extend([label] * len(w)) - elif fmt == 'summary': - data.append(array_to_summary(w, fmt=plot)) - data[-1]['layer'] = name - data[-1]['weight'] = label - - if fmt == 'longform': - data = pandas.DataFrame(data) - return data - - -def _keras_batchnorm(layer): - weights = layer.get_weights() - epsilon = layer.epsilon - - gamma = weights[0] - beta = weights[1] - mean = weights[2] - var = weights[3] - - scale = gamma / np.sqrt(var + epsilon) - bias = beta - gamma * mean / np.sqrt(var + epsilon) - - return [scale, bias], ['s', 'b'] - - -def _keras_layer(layer): - return layer.get_weights(), ['w', 'b'] - - -def _keras_lstm(layer): - return layer.get_weights(), ['w', 'u', 'b'] - - -keras_process_layer_map = defaultdict( - lambda: _keras_layer, - { - 'BatchNormalization': _keras_batchnorm, - 'QBatchNormalization': _keras_batchnorm, - 'LSTM': _keras_lstm, - 'QLSTM': _keras_lstm, - }, -) - - -def activations_hlsmodel(model, X, fmt='summary', plot='boxplot'): - if fmt == 'longform': - raise NotImplementedError - elif fmt == 'summary': - data = [] - - _, trace = model.trace(np.ascontiguousarray(X)) - - if len(trace) == 0: - raise RuntimeError("ModelGraph must have tracing on for at least 1 layer (this can be set in its config)") - - for layer in trace.keys(): - print(f" {layer}") - - if fmt == 'summary': - y = trace[layer].flatten() - y = abs(y[y != 0]) - - if len(y) == 0: - print(f'Activations for {layer} are only zeros, ignoring.') - continue - - data.append(array_to_summary(y, fmt=plot)) - data[-1]['weight'] = layer - - return data - - -def weights_keras(model, fmt='longform', plot='boxplot'): - if fmt == 'longform': - data = {'x': [], 'layer': [], 'weight': []} - elif fmt == 'summary': - data = [] - for layer in model.layers: - name = layer.name - weights, suffix = keras_process_layer_map[type(layer).__name__](layer) - - for i, w in enumerate(weights): - label = f'{name}/{suffix[i]}' - w = w.flatten() - w = abs(w[w != 0]) - n = len(w) - if n == 0: - print(f'Weights for {name} are only zeros, ignoring.') - break - if fmt == 'longform': - data['x'].extend(w.tolist()) - data['layer'].extend([name] * n) - data['weight'].extend([label] * n) - elif fmt == 'summary': - data.append(array_to_summary(w, fmt=plot)) - data[-1]['layer'] = name - data[-1]['weight'] = label - - if fmt == 'longform': - data = pandas.DataFrame(data) - return data - - -def activations_keras(model, X, fmt='longform', plot='boxplot'): - # test layer 
by layer on data - if fmt == 'longform': - # return long form pandas dataframe for - # seaborn boxplot - data = {'x': [], 'weight': []} - elif fmt == 'summary': - # return summary statistics for matplotlib.axes.Axes.bxp - # or histogram bin edges and heights - data = [] - outputs = _get_outputs( - [layer for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], X, model.input - ) - outputs = dict(zip([layer.name for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], outputs)) - for layer_name, y in outputs.items(): - print(f" {layer_name}") - y = y.flatten() - y = abs(y[y != 0]) - if len(y) == 0: - print(f'Activations for {layer_name} are only zeros, ignoring.') - continue - if fmt == 'longform': - data['x'].extend(y.tolist()) - data['weight'].extend([layer_name for i in range(len(y))]) - elif fmt == 'summary': - data.append(array_to_summary(y, fmt=plot)) - data[-1]['weight'] = layer_name - - if fmt == 'longform': - data = pandas.DataFrame(data) - return data - - -def weights_torch(model, fmt='longform', plot='boxplot'): - suffix = ['w', 'b'] - if fmt == 'longform': - data = {'x': [], 'layer': [], 'weight': []} - elif fmt == 'summary': - data = [] - for layer in model.children(): - if isinstance(layer, torch.nn.Linear): - name = layer.__class__.__name__ - weights = list(layer.parameters()) - for i, w in enumerate(weights): - label = f'{name}/{suffix[i]}' - w = weights[i].detach().numpy() - w = w.flatten() - w = abs(w[w != 0]) - n = len(w) - if n == 0: - print(f'Weights for {name} are only zeros, ignoring.') - break - if fmt == 'longform': - data['x'].extend(w.tolist()) - data['layer'].extend([name] * n) - data['weight'].extend([label] * n) - elif fmt == 'summary': - data.append(array_to_summary(w, fmt=plot)) - data[-1]['layer'] = name - data[-1]['weight'] = label - - if fmt == 'longform': - data = pandas.DataFrame(data) - return data - - -def activations_torch(model, X, fmt='longform', plot='boxplot'): - X = torch.Tensor(X) - if fmt == 'longform': - data = {'x': [], 'weight': []} - elif fmt == 'summary': - data = [] - - partial_model = torch.nn.Sequential - layers = [] - for layer in model.children(): - lname = layer.__class__.__name__ - layers.append(layer) - pm = partial_model(*layers) - print(f" {lname}") - y = pm(X).flatten().detach().numpy() - y = abs(y[y != 0]) - if len(y) == 0: - print(f'Activations for {lname} are only zeros, ignoring.') - continue - if fmt == 'longform': - data['x'].extend(y.tolist()) - data['weight'].extend([lname for _ in range(len(y))]) - elif fmt == 'summary': - data.append(array_to_summary(y, fmt=plot)) - data[-1]['weight'] = lname - - if fmt == 'longform': - data = pandas.DataFrame(data) - return data - - -def numerical(model=None, hls_model=None, X=None, plot='boxplot'): - """Perform numerical profiling of a model. - - Args: - model (optional): Keras of PyTorch model. Defaults to None. - hls_model (ModelGraph, optional): The ModelGraph to profile. Defaults to None. - X (ndarray, optional): Test data on which to evaluate the model to profile activations. - Must be formatted suitably for the ``model.predict(X)``. Defaults to None. - plot (str, optional): The type of plot to produce. Options are: 'boxplot' (default), 'violinplot', 'histogram', - 'FacetGrid'. Defaults to 'boxplot'. - - Returns: - tuple: The quadruple of produced figures. First weights and biases - for the pre- and post-optimization models respectively, - then activations for the pre- and post-optimization models - respectively. 
(Optimizations are applied to an ModelGraph by hls4ml, - a post-optimization ModelGraph is a final model). - """ - wp, wph, ap, aph = None, None, None, None - - hls_model_present = hls_model is not None and isinstance(hls_model, ModelGraph) - model_present = model is not None - - if hls_model_present: - before = " (before optimization)" - after = " (final / after optimization)" - hls_model_unoptimized, tmp_output_dir = get_unoptimized_hlsmodel(hls_model) - else: - before = "" - after = "" - hls_model_unoptimized, tmp_output_dir = None, None - - print("Profiling weights" + before) - data = None - - if hls_model_present: - data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot) - elif model_present: - if __tf_profiling_enabled__ and isinstance(model, keras.Model): - data = weights_keras(model, fmt='summary', plot=plot) - elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): - data = weights_torch(model, fmt='summary', plot=plot) - - if data is None: - print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled") - - if hls_model_present and os.path.exists(tmp_output_dir): - shutil.rmtree(tmp_output_dir) - - return wp, wph, ap, aph - - wp = plots[plot](data, fmt='summary') # weight plot - - if hls_model_present and plot in types_plots: - t_data = types_hlsmodel(hls_model_unoptimized) - types_plots[plot](t_data, fmt='summary') - - plt.title("Distribution of (non-zero) weights" + before) - plt.tight_layout() - - if hls_model_present: - print("Profiling weights" + after) - - data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) - wph = plots[plot](data, fmt='summary') # weight plot - - if plot in types_plots: - t_data = types_hlsmodel(hls_model) - types_plots[plot](t_data, fmt='summary') - - plt.title("Distribution of (non-zero) weights" + after) - plt.tight_layout() - - if X is not None: - print("Profiling activations" + before) - data = None - if __tf_profiling_enabled__ and isinstance(model, keras.Model): - data = activations_keras(model, X, fmt='summary', plot=plot) - elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): - data = activations_torch(model, X, fmt='summary', plot=plot) - - if data is not None: - ap = plots[plot](data, fmt='summary') # activation plot - if hls_model_present and plot in types_plots: - t_data = activation_types_hlsmodel(hls_model_unoptimized) - types_plots[plot](t_data, fmt='summary') - plt.title("Distribution of (non-zero) activations" + before) - plt.tight_layout() - - if hls_model_present: - print("Profiling activations" + after) - data = activations_hlsmodel(hls_model, X, fmt='summary', plot=plot) - aph = plots[plot](data, fmt='summary') - - t_data = activation_types_hlsmodel(hls_model) - types_plots[plot](t_data, fmt='summary') - - plt.title("Distribution of (non-zero) activations (final / after optimization)") - plt.tight_layout() - - if hls_model_present and os.path.exists(tmp_output_dir): - shutil.rmtree(tmp_output_dir) - - return wp, wph, ap, aph - - -######### -# COMPARE OUTPUT IMPLEMENTATION -######### -def _is_ignored_layer(layer): - """Some layers need to be ingored during inference""" - if isinstance(layer, (keras.layers.InputLayer, keras.layers.Dropout)): - return True - return False - - -def _get_outputs(layers, X, model_input): - """Get outputs of intermediate layers""" - partial_models = keras.models.Model(inputs=model_input, outputs=[layer.output for layer in layers]) - y = partial_models.predict(X) - return y - - -def 
get_ymodel_keras(keras_model, X): - """Calculate each layer's ouput and put them into a dictionary. - - Args: - keras_model (_type_): A keras Model - X (ndarray): Test data on which to evaluate the model to profile activations. - Must be formatted suitably for the ``model.predict(X)``. - - Returns: - dict: A dictionary in the form {"layer_name": ouput array of layer}. - """ - ymodel = {} - traced_layers = [] - layer_names = [] - for layer in keras_model.layers: - if _is_ignored_layer(layer): - continue - # If the layer has activation integrated then separate them - # Note that if the layer is a standalone activation layer then skip this - name = layer.name - if ( - hasattr(layer, 'activation') - and layer.activation is not None - and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation)) - and layer.activation.__name__ != 'linear' - ): - tmp_activation = layer.activation - layer.activation = None - ymodel.update({layer.name: _get_outputs([layer], X, keras_model.input)}) - layer.activation = tmp_activation - name = layer.name + f"_{tmp_activation.__name__}" - traced_layers.append(layer) - layer_names.append(name) - outputs = _get_outputs(traced_layers, X, keras_model.input) - for name, output in zip(layer_names, outputs): - ymodel[name] = output - print("Done taking outputs for Keras model.") - return ymodel - - -def _norm_diff(ymodel, ysim): - """Calculate the square root of the sum of the squares of the differences""" - diff = {} - - for key in list(ysim.keys()): - diff[key] = np.linalg.norm(ysim[key] - ymodel[key]) - - # ---Bar Plot--- - f, ax = plt.subplots() - plt.bar(list(diff.keys()), list(diff.values())) - plt.title("layer-by-layer output differences") - ax.set_ylabel('Norm of difference vector') - plt.xticks(rotation=90) - plt.tight_layout() - return f - - -def _dist_diff(ymodel, ysim): - """ - Calculate the normalized distribution of the differences of the elements - of the output vectors. - If difference >= original value then the normalized difference will be set to 1, - meaning "very difference". - If difference < original value then the normalized difference would be difference/original. - """ - - diff = {} - - for key in list(ysim.keys()): - flattened_ysim = ysim[key].flatten() - flattened_ymodel = np.array(ymodel[key]).flatten() - - diff[key] = np.absolute(flattened_ymodel - flattened_ysim) / np.linalg.norm(flattened_ymodel - flattened_ysim) - diff_vector = np.absolute(flattened_ymodel - flattened_ysim) - abs_ymodel = np.absolute(flattened_ymodel) - - normalized_diff = np.zeros(diff_vector.shape) - normalized_diff[(diff_vector >= abs_ymodel) & (abs_ymodel > 0) & (diff_vector > 0)] = 1 - - # Fill out the rest - index = diff_vector < abs_ymodel - normalized_diff[index] = diff_vector[index] / abs_ymodel[index] - - diff[key] = normalized_diff - - # ---Box Plot--- - f, ax = plt.subplots() - pos = np.array(range(len(list(diff.values())))) + 1 - ax.boxplot(list(diff.values()), sym='k+', positions=pos) - - # --formatting - plt.title("Layer-by-layer distribution of output differences") - ax.set_xticklabels(list(diff.keys())) - ax.set_ylabel('Normalized difference') - ax.set_ylabel('Percent difference.') - plt.xticks(rotation=90) - plt.tight_layout() - - return f - - -def compare(keras_model, hls_model, X, plot_type="dist_diff"): - """Compare each layer's output in keras and hls model. Note that the hls_model should not be compiled before using this. - - Args: - keras_model: Original keras model. 
- hls_model (ModelGraph): Converted ModelGraph, with "Trace:True" in the configuration file. - X (ndarray): Input tensor for the model. - plot_type (str, optional): Different methods to visualize the y_model and y_sim differences. - Possible options include: - - 'norm_diff':: square root of the sum of the squares of the differences between each output vectors. - - 'dist_diff':: The normalized distribution of the differences of the elements between two output vectors. - Defaults to "dist_diff". - - Returns: - matplotlib figure: Plot object of the histogram depicting the difference in each layer's output. - """ - - # Take in output from both models - # Note that each y is a dictionary with structure {"layer_name": flattened ouput array} - ymodel = get_ymodel_keras(keras_model, X) - _, ysim = hls_model.trace(X) - - print("Plotting difference...") - f = plt.figure() - if plot_type == "norm_diff": - f = _norm_diff(ymodel, ysim) - elif plot_type == "dist_diff": - f = _dist_diff(ymodel, ysim) - - return f +import json +import os +import shutil +import uuid +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import pandas +import seaborn as sb + +from hls4ml.model.graph import ModelGraph +from hls4ml.model.layers import GRU, LSTM, SeparableConv1D, SeparableConv2D + +try: + import qkeras + from tensorflow import keras + + __tf_profiling_enabled__ = True +except ImportError: + __tf_profiling_enabled__ = False + +try: + import torch + + __torch_profiling_enabled__ = True +except ImportError: + __torch_profiling_enabled__ = False + + +def get_unoptimized_hlsmodel(model): + from hls4ml.converters import convert_from_config + + new_config = model.config.config.copy() + new_config['HLSConfig'] = json.loads(json.dumps(new_config['HLSConfig'])) + + new_output_dir = uuid.uuid4().hex + + while os.path.exists(new_output_dir): + new_output_dir = uuid.uuid4().hex + + if 'SkipOptimizers' in new_config['HLSConfig']: + del new_config['HLSConfig']['SkipOptimizers'] + + new_config['HLSConfig']['Optimizers'] = [] + new_config['OutputDir'] = new_output_dir + + return convert_from_config(new_config), new_output_dir + + +def array_to_summary(x, fmt='boxplot'): + if fmt == 'boxplot': + y = {'med': np.median(x), 'q1': np.percentile(x, 25), 'q3': np.percentile(x, 75), 'whislo': min(x), 'whishi': max(x)} + elif fmt == 'histogram': + # Power of 2 bins covering data range + high = np.ceil(np.log2(max(x))) + 1 + low = np.floor(np.log2(min(x))) - 1 + bits = np.arange(low, high, 1) + bins = 2**bits + h, b = np.histogram(x, bins=bins) + h = h * 1.0 / float(sum(h)) # normalize + y = {'h': h, 'b': np.log2(b)} + return y + + +def boxplot(data, fmt='longform'): + if fmt == 'longform': + f = plt.figure() # figsize=(3, 3)) + hue = 'layer' if 'layer' in data.keys() else None + vp = sb.boxplot(x='x', y='weight', hue=hue, data=data[data['x'] > 0], showfliers=False) + vp.set_yticklabels(vp.get_yticklabels(), rotation=45, ha='right') + if hue is not None: + vp.get_legend().remove() + vp.set_xscale('log', base=2) + return f + elif fmt == 'summary': + from matplotlib.patches import Rectangle + + medianprops = dict(linestyle='-', color='k') + f, ax = plt.subplots(1, 1) + data.reverse() + colors = sb.color_palette("Blues", len(data)) + bp = ax.bxp(data, showfliers=False, vert=False, medianprops=medianprops) + # add colored boxes + for line, color in zip(bp['boxes'], colors): + x = line.get_xdata() + xl, xh = min(x), max(x) + y = line.get_ydata() + yl, yh = min(y), max(y) + rect = Rectangle((xl, yl), (xh 
- xl), (yh - yl), fill=True, color=color) + ax.add_patch(rect) + ax.set_yticklabels([d['weight'] for d in data]) + ax.set_xscale('log', base=2) + plt.xlabel('x') + return f + else: + return None + + +def histogram(data, fmt='longform'): + f = plt.figure() + from matplotlib.ticker import MaxNLocator + + n = len(data) if fmt == 'summary' else len(data['weight'].unique()) + colors = sb.color_palette("husl", n) + if fmt == 'longform': + for i, weight in enumerate(data['weight'].unique()): + y = array_to_summary(data[data['weight'] == weight]['x'], fmt='histogram') + plt.bar(y['b'][:-1], y['h'], width=1, fill=False, label=weight, edgecolor=colors[i]) + elif fmt == 'summary': + for i, weight in enumerate(data): + plt.bar(weight['b'][:-1], weight['h'], width=1, fill=False, label=weight['weight'], edgecolor=colors[i]) + + plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True)) + plt.xlabel('log2(x)') + plt.ylabel('frequency') + plt.legend() + return f + + +plots = {'boxplot': boxplot, 'histogram': histogram} + + +def types_boxplot(data, fmt='longform'): + from matplotlib.patches import PathPatch, Rectangle + + ax = plt.gca() + _ = plt.gcf() + # Scale the data + data['low'] = 2.0 ** data['low'] + data['high'] = 2.0 ** data['high'] + + # Plot the custom precisions + ticks = np.array([tick.get_text() for tick in plt.yticks()[1]]) + # Get the coordinates of the boxes to place the markers + if fmt == 'longform': + # seaborn adjusts the box positions slightly in groups + boxes = [c.get_extents().inverse_transformed(ax.transData) for c in ax.get_children() if isinstance(c, PathPatch)] + ys = [(box.y0 + box.y1) / 2 for box in boxes] + ys = [(y, y) for y in ys] + elif fmt == 'summary': + ys = [(y, y) for y in plt.yticks()[0]] + for _irow, row in data[data['layer'] != 'model'].iterrows(): + if row['layer'] in ticks: + iy = np.argwhere(ticks == row['layer'])[0][0] # Determine which layer in the plot + rectangle = Rectangle( + (row['low'], ys[iy][0] - 0.4), row['high'] - row['low'], 0.8, fill=True, color='grey', alpha=0.2 + ) + ax.add_patch(rectangle) + + +def types_histogram(data, fmt='longform'): + ax = plt.gca() + layers = np.array(ax.get_legend_handles_labels()[1]) + colors = sb.color_palette("husl", len(layers)) + ylim = ax.get_ylim() + for _irow, row in data[data['layer'] != 'model'].iterrows(): + if row['layer'] in layers: + col = colors[np.argwhere(layers == row['layer'])[0][0]] + plt.plot((row['low'], row['low']), ylim, '--', color=col) + plt.plot((row['high'], row['high']), ylim, '--', color=col) + + +types_plots = {'boxplot': types_boxplot, 'histogram': types_histogram} + + +def ap_fixed_WIFS(dtype): + from hls4ml.backends import VivadoBackend + + dtype = VivadoBackend.convert_precision_string(dtype) + W, I, F, S = dtype.width, dtype.integer, dtype.fractional, dtype.signed + return W, I, F, S + + +def types_hlsmodel(model): + data = {'layer': [], 'low': [], 'high': []} + # Plot the default precision + default_precision = model.config.model_precision['default'] + W, I, F, S = ap_fixed_WIFS(default_precision) + data['layer'].append('model') + data['low'].append(-F) + data['high'].append(I - 1 if S else I) + + for layer in model.get_layers(): + if isinstance(layer, GRU) or isinstance(layer, LSTM): + suffix = ['w', 'rw', 'b', 'rb'] + elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D): + suffix = ['dw', 'pw', 'db', 'pb'] + else: + suffix = ['w', 'b'] + for iw, weight in enumerate(layer.get_weights()): + wname = f'{layer.name}/{suffix[iw]}' + T = weight.type + if T.name != 
'model': + W, I, F, S = ap_fixed_WIFS(T.precision) + data['layer'].append(wname) + data['low'].append(-F) + data['high'].append(I - 1 if S else I) + data = pandas.DataFrame(data) + return data + + +def activation_types_hlsmodel(model): + data = {'layer': [], 'low': [], 'high': []} + # Get the default precision + default_precision = model.config.model_precision['default'] + W, I, F, S = ap_fixed_WIFS(default_precision) + data['layer'].append('model') + data['low'].append(-F) + data['high'].append(I - 1 if S else I) + for layer in model.get_layers(): + T = layer.get_output_variable().type.precision + W, I, F, S = ap_fixed_WIFS(T) + data['layer'].append(layer.name) + data['low'].append(-F) + data['high'].append(I - 1 if S else I) + data = pandas.DataFrame(data) + return data + + +def weights_hlsmodel(model, fmt='longform', plot='boxplot'): + if fmt == 'longform': + data = {'x': [], 'layer': [], 'weight': []} + elif fmt == 'summary': + data = [] + + for layer in model.get_layers(): + if isinstance(layer, GRU) or isinstance(layer, LSTM): + suffix = ['w', 'rw', 'b', 'rb'] + elif isinstance(layer, SeparableConv1D) or isinstance(layer, SeparableConv2D): + suffix = ['dw', 'pw', 'db', 'pb'] + else: + suffix = ['w', 'b'] + name = layer.name + for iw, weight in enumerate(layer.get_weights()): + label = f'{name}/{suffix[iw]}' + w = weight.data.flatten() + w = abs(w[w != 0]) + n = len(w) + if n == 0: + print(f'Weights for {name} are only zeros, ignoring.') + break + if fmt == 'longform': + data['x'].extend(w.tolist()) + data['layer'].extend([name] * len(w)) + data['weight'].extend([label] * len(w)) + elif fmt == 'summary': + data.append(array_to_summary(w, fmt=plot)) + data[-1]['layer'] = name + data[-1]['weight'] = label + + if fmt == 'longform': + data = pandas.DataFrame(data) + return data + + +def _keras_batchnorm(layer): + weights = layer.get_weights() + epsilon = layer.epsilon + + gamma = weights[0] + beta = weights[1] + mean = weights[2] + var = weights[3] + + scale = gamma / np.sqrt(var + epsilon) + bias = beta - gamma * mean / np.sqrt(var + epsilon) + + return [scale, bias], ['s', 'b'] + + +def _keras_layer(layer): + return layer.get_weights(), ['w', 'b'] + + +def _keras_layernorm(layer): + weights = layer.get_weights() + + gamma = weights[0] + beta = weights[1] + + scale = gamma + bias = beta + + return [scale, bias], ['s', 'b'] + + +def _keras_lstm(layer): + return layer.get_weights(), ['w', 'u', 'b'] + + +keras_process_layer_map = defaultdict( + lambda: _keras_layer, + { + 'BatchNormalization': _keras_batchnorm, + 'QBatchNormalization': _keras_batchnorm, + 'LayerNormalization': _keras_layernorm, + 'LSTM': _keras_lstm, + 'QLSTM': _keras_lstm, + }, +) + + +def activations_hlsmodel(model, X, fmt='summary', plot='boxplot'): + if fmt == 'longform': + raise NotImplementedError + elif fmt == 'summary': + data = [] + + _, trace = model.trace(np.ascontiguousarray(X)) + + if len(trace) == 0: + raise RuntimeError("ModelGraph must have tracing on for at least 1 layer (this can be set in its config)") + + for layer in trace.keys(): + print(f" {layer}") + + if fmt == 'summary': + y = trace[layer].flatten() + y = abs(y[y != 0]) + + if len(y) == 0: + print(f'Activations for {layer} are only zeros, ignoring.') + continue + + data.append(array_to_summary(y, fmt=plot)) + data[-1]['weight'] = layer + + return data + + +def weights_keras(model, fmt='longform', plot='boxplot'): + if fmt == 'longform': + data = {'x': [], 'layer': [], 'weight': []} + elif fmt == 'summary': + data = [] + for layer in model.layers: + 
name = layer.name + weights, suffix = keras_process_layer_map[type(layer).__name__](layer) + + for i, w in enumerate(weights): + label = f'{name}/{suffix[i]}' + w = w.flatten() + w = abs(w[w != 0]) + n = len(w) + if n == 0: + print(f'Weights for {name} are only zeros, ignoring.') + break + if fmt == 'longform': + data['x'].extend(w.tolist()) + data['layer'].extend([name] * n) + data['weight'].extend([label] * n) + elif fmt == 'summary': + data.append(array_to_summary(w, fmt=plot)) + data[-1]['layer'] = name + data[-1]['weight'] = label + + if fmt == 'longform': + data = pandas.DataFrame(data) + return data + + +def activations_keras(model, X, fmt='longform', plot='boxplot'): + # test layer by layer on data + if fmt == 'longform': + # return long form pandas dataframe for + # seaborn boxplot + data = {'x': [], 'weight': []} + elif fmt == 'summary': + # return summary statistics for matplotlib.axes.Axes.bxp + # or histogram bin edges and heights + data = [] + outputs = _get_outputs( + [layer for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], X, model.input + ) + outputs = dict(zip([layer.name for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], outputs)) + for layer_name, y in outputs.items(): + print(f" {layer_name}") + y = y.flatten() + y = abs(y[y != 0]) + if len(y) == 0: + print(f'Activations for {layer_name} are only zeros, ignoring.') + continue + if fmt == 'longform': + data['x'].extend(y.tolist()) + data['weight'].extend([layer_name for i in range(len(y))]) + elif fmt == 'summary': + data.append(array_to_summary(y, fmt=plot)) + data[-1]['weight'] = layer_name + + if fmt == 'longform': + data = pandas.DataFrame(data) + return data + + +def weights_torch(model, fmt='longform', plot='boxplot'): + suffix = ['w', 'b'] + if fmt == 'longform': + data = {'x': [], 'layer': [], 'weight': []} + elif fmt == 'summary': + data = [] + for layer in model.children(): + if isinstance(layer, torch.nn.Linear): + name = layer.__class__.__name__ + weights = list(layer.parameters()) + for i, w in enumerate(weights): + label = f'{name}/{suffix[i]}' + w = weights[i].detach().numpy() + w = w.flatten() + w = abs(w[w != 0]) + n = len(w) + if n == 0: + print(f'Weights for {name} are only zeros, ignoring.') + break + if fmt == 'longform': + data['x'].extend(w.tolist()) + data['layer'].extend([name] * n) + data['weight'].extend([label] * n) + elif fmt == 'summary': + data.append(array_to_summary(w, fmt=plot)) + data[-1]['layer'] = name + data[-1]['weight'] = label + + if fmt == 'longform': + data = pandas.DataFrame(data) + return data + + +def activations_torch(model, X, fmt='longform', plot='boxplot'): + X = torch.Tensor(X) + if fmt == 'longform': + data = {'x': [], 'weight': []} + elif fmt == 'summary': + data = [] + + partial_model = torch.nn.Sequential + layers = [] + for layer in model.children(): + lname = layer.__class__.__name__ + layers.append(layer) + pm = partial_model(*layers) + print(f" {lname}") + y = pm(X).flatten().detach().numpy() + y = abs(y[y != 0]) + if len(y) == 0: + print(f'Activations for {lname} are only zeros, ignoring.') + continue + if fmt == 'longform': + data['x'].extend(y.tolist()) + data['weight'].extend([lname for _ in range(len(y))]) + elif fmt == 'summary': + data.append(array_to_summary(y, fmt=plot)) + data[-1]['weight'] = lname + + if fmt == 'longform': + data = pandas.DataFrame(data) + return data + + +def numerical(model=None, hls_model=None, X=None, plot='boxplot'): + """Perform numerical profiling of a model. 
+ + Args: + model (optional): Keras of PyTorch model. Defaults to None. + hls_model (ModelGraph, optional): The ModelGraph to profile. Defaults to None. + X (ndarray, optional): Test data on which to evaluate the model to profile activations. + Must be formatted suitably for the ``model.predict(X)``. Defaults to None. + plot (str, optional): The type of plot to produce. Options are: 'boxplot' (default), 'violinplot', 'histogram', + 'FacetGrid'. Defaults to 'boxplot'. + + Returns: + tuple: The quadruple of produced figures. First weights and biases + for the pre- and post-optimization models respectively, + then activations for the pre- and post-optimization models + respectively. (Optimizations are applied to an ModelGraph by hls4ml, + a post-optimization ModelGraph is a final model). + """ + wp, wph, ap, aph = None, None, None, None + + hls_model_present = hls_model is not None and isinstance(hls_model, ModelGraph) + model_present = model is not None + + if hls_model_present: + before = " (before optimization)" + after = " (final / after optimization)" + hls_model_unoptimized, tmp_output_dir = get_unoptimized_hlsmodel(hls_model) + else: + before = "" + after = "" + hls_model_unoptimized, tmp_output_dir = None, None + + print("Profiling weights" + before) + data = None + + if hls_model_present: + data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot) + elif model_present: + if __tf_profiling_enabled__ and isinstance(model, keras.Model): + data = weights_keras(model, fmt='summary', plot=plot) + elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): + data = weights_torch(model, fmt='summary', plot=plot) + + if data is None: + print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled") + + if hls_model_present and os.path.exists(tmp_output_dir): + shutil.rmtree(tmp_output_dir) + + return wp, wph, ap, aph + + wp = plots[plot](data, fmt='summary') # weight plot + + if hls_model_present and plot in types_plots: + t_data = types_hlsmodel(hls_model_unoptimized) + types_plots[plot](t_data, fmt='summary') + + plt.title("Distribution of (non-zero) weights" + before) + plt.tight_layout() + + if hls_model_present: + print("Profiling weights" + after) + + data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) + wph = plots[plot](data, fmt='summary') # weight plot + + if plot in types_plots: + t_data = types_hlsmodel(hls_model) + types_plots[plot](t_data, fmt='summary') + + plt.title("Distribution of (non-zero) weights" + after) + plt.tight_layout() + + if X is not None: + print("Profiling activations" + before) + data = None + if __tf_profiling_enabled__ and isinstance(model, keras.Model): + data = activations_keras(model, X, fmt='summary', plot=plot) + elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): + data = activations_torch(model, X, fmt='summary', plot=plot) + + if data is not None: + ap = plots[plot](data, fmt='summary') # activation plot + if hls_model_present and plot in types_plots: + t_data = activation_types_hlsmodel(hls_model_unoptimized) + types_plots[plot](t_data, fmt='summary') + plt.title("Distribution of (non-zero) activations" + before) + plt.tight_layout() + + if hls_model_present: + print("Profiling activations" + after) + data = activations_hlsmodel(hls_model, X, fmt='summary', plot=plot) + aph = plots[plot](data, fmt='summary') + + t_data = activation_types_hlsmodel(hls_model) + types_plots[plot](t_data, fmt='summary') + + plt.title("Distribution of (non-zero) activations 
+
+
+#########
+# COMPARE OUTPUT IMPLEMENTATION
+#########
+def _is_ignored_layer(layer):
+    """Some layers need to be ignored during inference"""
+    if isinstance(layer, (keras.layers.InputLayer, keras.layers.Dropout)):
+        return True
+    return False
+
+
+def _get_outputs(layers, X, model_input):
+    """Get outputs of intermediate layers"""
+    partial_models = keras.models.Model(inputs=model_input, outputs=[layer.output for layer in layers])
+    y = partial_models.predict(X)
+    return y
+
+
+def get_ymodel_keras(keras_model, X):
+    """Calculate each layer's output and put them into a dictionary.
+
+    Args:
+        keras_model (keras.Model): A Keras model.
+        X (ndarray): Test data on which to evaluate the model to profile activations.
+            Must be formatted suitably for ``model.predict(X)``.
+
+    Returns:
+        dict: A dictionary in the form {"layer_name": output array of layer}.
+    """
+    ymodel = {}
+    traced_layers = []
+    layer_names = []
+    for layer in keras_model.layers:
+        if _is_ignored_layer(layer):
+            continue
+        # If the layer has an integrated activation, separate it out,
+        # unless the layer is itself a standalone activation layer.
+        name = layer.name
+        if (
+            hasattr(layer, 'activation')
+            and layer.activation is not None
+            and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
+            and layer.activation.__name__ != 'linear'
+        ):
+            tmp_activation = layer.activation
+            layer.activation = None
+            ymodel.update({layer.name: _get_outputs([layer], X, keras_model.input)})
+            layer.activation = tmp_activation
+            name = layer.name + f"_{tmp_activation.__name__}"
+        traced_layers.append(layer)
+        layer_names.append(name)
+    outputs = _get_outputs(traced_layers, X, keras_model.input)
+    for name, output in zip(layer_names, outputs):
+        ymodel[name] = output
+    print("Done taking outputs for Keras model.")
+    return ymodel
+
+
+def _norm_diff(ymodel, ysim):
+    """Calculate the square root of the sum of the squares of the differences for each layer"""
+    diff = {}
+
+    for key in list(ysim.keys()):
+        diff[key] = np.linalg.norm(ysim[key] - ymodel[key])
+
+    # ---Bar Plot---
+    f, ax = plt.subplots()
+    plt.bar(list(diff.keys()), list(diff.values()))
+    plt.title("Layer-by-layer output differences")
+    ax.set_ylabel('Norm of difference vector')
+    plt.xticks(rotation=90)
+    plt.tight_layout()
+    return f
+
+
+def _dist_diff(ymodel, ysim):
+    """
+    Calculate the normalized distribution of the differences of the elements
+    of the output vectors.
+    If difference >= original value then the normalized difference is set to 1,
+    meaning "very different".
+    If difference < original value then the normalized difference is difference/original.
+    """
+
+    diff = {}
+
+    for key in list(ysim.keys()):
+        flattened_ysim = ysim[key].flatten()
+        flattened_ymodel = np.array(ymodel[key]).flatten()
+
+        diff_vector = np.absolute(flattened_ymodel - flattened_ysim)
+        abs_ymodel = np.absolute(flattened_ymodel)
+
+        normalized_diff = np.zeros(diff_vector.shape)
+        normalized_diff[(diff_vector >= abs_ymodel) & (abs_ymodel > 0) & (diff_vector > 0)] = 1
+
+        # Fill out the rest
+        index = diff_vector < abs_ymodel
+        normalized_diff[index] = diff_vector[index] / abs_ymodel[index]
+
+        diff[key] = normalized_diff
+
+    # ---Box Plot---
+    f, ax = plt.subplots()
+    pos = np.array(range(len(list(diff.values())))) + 1
+    ax.boxplot(list(diff.values()), sym='k+', positions=pos)
+
+    # --formatting
+    plt.title("Layer-by-layer distribution of output differences")
+    ax.set_xticklabels(list(diff.keys()))
+    ax.set_ylabel('Normalized difference')
+    plt.xticks(rotation=90)
+    plt.tight_layout()
+
+    return f
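+
+# Editorial worked example (not part of the original patch): for ymodel = [2.0, 1.0, 0.5] and
+# ysim = [1.8, 2.5, 0.5], _dist_diff forms |diff| = [0.2, 1.5, 0.0]; the second element, where
+# |diff| >= |ymodel|, is clipped to 1, and the others become |diff| / |ymodel| = [0.1, 0.0],
+# so the box plot summarizes per-element relative errors bounded to [0, 1].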
+
+
+def compare(keras_model, hls_model, X, plot_type="dist_diff"):
+    """Compare each layer's output in the Keras and hls4ml models.
+    Note that the hls_model should not be compiled before using this.
+
+    Args:
+        keras_model: Original Keras model.
+        hls_model (ModelGraph): Converted ModelGraph, with "Trace: True" in the configuration file.
+        X (ndarray): Input tensor for the model.
+        plot_type (str, optional): Different methods to visualize the y_model and y_sim differences.
+            Possible options include:
+            - 'norm_diff': square root of the sum of the squares of the differences between each pair of output vectors.
+            - 'dist_diff': the normalized distribution of the differences of the elements between the two output vectors.
+            Defaults to "dist_diff".
+
+    Returns:
+        matplotlib figure: Plot object of the histogram depicting the difference in each layer's output.
+    """
+
+    # Take in output from both models
+    # Note that each y is a dictionary with structure {"layer_name": flattened output array}
+    ymodel = get_ymodel_keras(keras_model, X)
+    _, ysim = hls_model.trace(X)
+
+    print("Plotting difference...")
+    f = plt.figure()
+    if plot_type == "norm_diff":
+        f = _norm_diff(ymodel, ysim)
+    elif plot_type == "dist_diff":
+        f = _dist_diff(ymodel, ysim)
+
+    return f
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_layernorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_layernorm.h
new file mode 100644
index 000000000..17b071234
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_layernorm.h
@@ -0,0 +1,138 @@
+#ifndef NNET_LAYERNORM_H_
+#define NNET_LAYERNORM_H_
+
+#include "hls_stream.h"
+#include "nnet_common.h"
+#include "nnet_dense.h"
+#include <cmath>
+
+#include "hls_math.h"
+
+namespace nnet {
+
+struct layernorm_config {
+    // Internal data type definitions
+    typedef float bias_t;
+    typedef float scale_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 20;
+    static const unsigned seq_len = 4;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+    static const unsigned n_zeros = 0;
+
+    template <class x_T, class y_T> using product = nnet::product::mult<x_T, y_T>;
+};
+
+template <typename CONFIG_T, int N_TABLE> void init_invert_sqr_table(typename CONFIG_T::table_t table_out[N_TABLE]) {
+    // Inversion function:
+    //   result = 1/sqrt(x)
+    float min_val = CONFIG_T::epsilon;
+    float max_val = CONFIG_T::table_range;
+    float step = max_val / (float)(N_TABLE);
+    for (int ii = 0; ii < N_TABLE; ii++) {
+        float in_val = min_val + step * ii;
+        table_out[ii] = (typename CONFIG_T::table_t)(1.0 / sqrt(in_val));
+    }
+}
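+
+// Editorial note (not part of the original patch): the table above samples 1/sqrt(x) uniformly
+// on [epsilon, table_range); layernorm_1d below maps the per-frame variance to an entry via
+// index = var * table_size / table_range, clamped to [0, table_size - 1], so table_range should
+// be chosen to cover the largest variance expected at this layer.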
+
+template <class data_T, class res_T, typename CONFIG_T>
+void layernorm_1d(data_T data[CONFIG_T::n_in / CONFIG_T::seq_len], res_T res[CONFIG_T::n_in / CONFIG_T::seq_len],
+                  typename CONFIG_T::scale_t scale[CONFIG_T::n_in / CONFIG_T::seq_len],
+                  typename CONFIG_T::bias_t bias[CONFIG_T::n_in / CONFIG_T::seq_len]) {
+    #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+    #pragma HLS ARRAY_PARTITION variable=data complete
+    #pragma HLS ARRAY_PARTITION variable=res complete
+    int inv_range_inv = (int)1 / CONFIG_T::table_range;
+    typename CONFIG_T::table_t deno_inver = 0;
+#ifdef __HLS_SYN__
+    bool initialized = false;
+    typename CONFIG_T::table_t invert_sqr_table[CONFIG_T::table_size];
+#else
+    static bool initialized = false;
+    static typename CONFIG_T::table_t invert_sqr_table[CONFIG_T::table_size];
+#endif
+    if (!initialized) {
+        init_invert_sqr_table<CONFIG_T, CONFIG_T::table_size>(invert_sqr_table);
+        initialized = true;
+    }
+
+    static const unsigned dim = CONFIG_T::n_in / CONFIG_T::seq_len;
+    typename CONFIG_T::mean_t sum_cache = 0;
+    typename CONFIG_T::mean_t sum_cache2 = 0;
+    typename CONFIG_T::mean_t var, mean, diff;
+    typename CONFIG_T::mean_t data_diff[dim];
+    typename CONFIG_T::mean_t var_epsilon = (typename CONFIG_T::mean_t)CONFIG_T::epsilon;
+
+    #pragma HLS ARRAY_PARTITION variable=data_diff complete
+
+    const typename CONFIG_T::mean_t k_inv = 1.0 / dim;
+
+LAYERNORM_1D_SUM:
+    for (int i = 0; i < dim; ++i) {
+        sum_cache += static_cast<typename CONFIG_T::mean_t>(data[i]);
+    }
+    mean = CONFIG_T::template product<typename CONFIG_T::mean_t, typename CONFIG_T::mean_t>::product(sum_cache, k_inv);
+
+LAYERNORM_1D_VAR:
+    for (int i = 0; i < dim; ++i) {
+        data_diff[i] = static_cast<typename CONFIG_T::mean_t>(data[i]) - mean;
+        diff = data_diff[i] * data_diff[i];
+        sum_cache2 += diff;
+    }
+    var = CONFIG_T::template product<typename CONFIG_T::mean_t, typename CONFIG_T::mean_t>::product(sum_cache2, k_inv);
+
+    int index = (var) * (CONFIG_T::table_size)*inv_range_inv;
+    if (CONFIG_T::table_range > 1)
+        index = (var) * (CONFIG_T::table_size) / (int)CONFIG_T::table_range;
+    if (index < 0)
+        index = 0;
+    if (index > CONFIG_T::table_size - 1)
+        index = CONFIG_T::table_size - 1;
+    deno_inver = invert_sqr_table[index];
+
+LAYERNORM_1D_RESULT:
+    for (int i = 0; i < dim; ++i) {
+        res[i] = data_diff[i] * deno_inver * scale[i] + bias[i];
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void layernormalize(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in],
+                    typename CONFIG_T::scale_t scale[CONFIG_T::n_in / CONFIG_T::seq_len],
+                    typename CONFIG_T::bias_t bias[CONFIG_T::n_in / CONFIG_T::seq_len]) {
+    static const unsigned dim = CONFIG_T::n_in / CONFIG_T::seq_len;
+    data_T in_val[dim];
+    res_T outval[dim];
+    // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
+    #pragma HLS function_instantiate variable=scale,bias
+
+    #pragma HLS ARRAY_PARTITION variable=scale complete
+    #pragma HLS ARRAY_PARTITION variable=bias complete
+    #pragma HLS ARRAY_PARTITION variable=in_val complete
+    #pragma HLS ARRAY_PARTITION variable=outval complete
+
+LAYERNORM_SEQ_LOOP:
+    for (int j = 0; j < CONFIG_T::seq_len; ++j) {
+        #pragma HLS PIPELINE
+    LAYERNORM_LOAD:
+        for (int i = 0; i < dim; ++i) {
+            #pragma HLS UNROLL
+            in_val[i] = data[j * dim + i];
+        }
+        layernorm_1d<data_T, res_T, CONFIG_T>(in_val, outval, scale, bias);
+    LAYERNORM_STORE:
+        for (int i = 0; i < dim; ++i) {
+            #pragma HLS UNROLL
+            res[j * dim + i] = outval[i];
+        }
+    }
+}
+
+} // namespace nnet
+
+#endif
diff --git a/test/pytest/test_layernorm.py b/test/pytest/test_layernorm.py
new file mode 100644
index 000000000..f3f0a5731
--- /dev/null
+++ b/test/pytest/test_layernorm.py
@@ -0,0 +1,42 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+from tensorflow.keras.layers import LayerNormalization
+from tensorflow.keras.models import Sequential
+
+import hls4ml
+
+test_root_path = Path(__file__).parent
+
+in_shape = (10, 8)
+atol = 5e-2
+
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    return np.random.rand(100, *in_shape)
+
+
+@pytest.fixture(scope='module')
+def model():
+    model = Sequential()
+    model.add(LayerNormalization(input_shape=in_shape))
+    model.compile()
+    return model
+
+
+# Currently only Vivado in io_parallel mode is supported
+def test_layernorm(model, data):
+    config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado')
+    output_dir = str(test_root_path / 'hls4mlprj_layernorm_Vivado_io_parallel')
+    hls_model = hls4ml.converters.convert_from_keras_model(
+        model, backend='Vivado', hls_config=config, io_type='io_parallel', output_dir=output_dir
+    )
+    hls_model.compile()
+
+    # Predict
+    y_keras = model.predict(data).flatten()
+    y_hls = hls_model.predict(data).flatten()
+    np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)
diff --git a/test/pytest/test_layernorm_pytorch.py b/test/pytest/test_layernorm_pytorch.py
new file mode 100644
index 000000000..d61b0c436
--- /dev/null
+++ b/test/pytest/test_layernorm_pytorch.py
@@ -0,0 +1,41 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from torch import nn
+
+import hls4ml
+
+test_root_path = Path(__file__).parent
+
+in_shape = (10, 8)
+atol = 5e-2
+
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    return np.random.rand(100, *in_shape)
+
+
+@pytest.fixture(scope='module')
+def model():
+    model = nn.Sequential(nn.LayerNorm(in_shape[-1]))
+    model.eval()
+    return model
+
+
+# Currently only Vivado in io_parallel mode is supported
+def test_layernorm(model, data):
+    config = hls4ml.utils.config_from_pytorch_model(model, in_shape, granularity='name', backend='Vivado')
+    output_dir = str(test_root_path / 'hls4mlprj_layernorm_pytorch_Vivado_io_parallel')
+    hls_model = hls4ml.converters.convert_from_pytorch_model(
+        model, backend='Vivado', hls_config=config, io_type='io_parallel', output_dir=output_dir
+    )
+    hls_model.compile()
+
+    # Predict
+    y_pytorch = model(torch.Tensor(data)).detach().numpy().flatten()
+    y_hls = hls_model.predict(data).flatten()
+    np.testing.assert_allclose(y_pytorch, y_hls, rtol=0, atol=atol, verbose=True)
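Editorial addendum (not part of the patch): a minimal, untested sketch of how the profiling utilities above could be combined with the new LayerNormalization support, assuming `model` and `data` as defined in test_layernorm.py. The Trace flag follows the `compare` docstring, and the output directory name is illustrative.

    import hls4ml
    from hls4ml.model.profiling import compare, numerical

    config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado')
    for layer_name in config['LayerName']:
        config['LayerName'][layer_name]['Trace'] = True  # needed so hls_model.trace() returns per-layer outputs

    hls_model = hls4ml.converters.convert_from_keras_model(
        model, backend='Vivado', hls_config=config, io_type='io_parallel', output_dir='hls4mlprj_layernorm_profiling'
    )

    # Weight/activation profiles before and after the hls4ml optimizations
    wp, wph, ap, aph = numerical(model=model, hls_model=hls_model, X=data)
    # Layer-by-layer comparison of Keras vs. HLS outputs
    fig = compare(model, hls_model, data, plot_type='dist_diff')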