Commit
general einsum support for io_parallel and latency
Showing 7 changed files with 579 additions and 3 deletions.
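For context, this is roughly the use case the commit targets: a Keras EinsumDense layer converted with io_parallel and the Latency strategy. The model, shapes, equation, and conversion calls below are illustrative assumptions, not code from this commit.

import keras
import hls4ml

# Toy model: a single EinsumDense contracting the last axis (shapes and equation are made up).
inp = keras.Input(shape=(8, 16))
out = keras.layers.EinsumDense('btd,de->bte', output_shape=(8, 4), bias_axes='e')(inp)
model = keras.Model(inp, out)

# Convert with the only combination supported by this commit: io_parallel + Latency (the default strategy).
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    io_type='io_parallel',
    backend='Vivado',
)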
@@ -0,0 +1,120 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import EinsumDense

from .reshaping_templates import transpose_config_gen

# Shared Dense template

conv_dense_config_template = """struct config{index}_dense : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned reuse_factor = {reuse};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned n_zeros = {nzeros};
    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    template<class data_T, class res_T, class CONFIG_T>
    using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

# EinsumDense template

einsum_dense_config_template = '''
struct config{index} {{
    typedef config{index}_tpose_inp tpose_inp_conf;
    typedef config{index}_tpose_out tpose_out_conf;
    typedef config{index}_dense dense_conf;

    // Layer Sizes
    static const unsigned n_free_data = {n_free_data};
    static const unsigned n_free_kernel = {n_free_kernel};
    static const unsigned n_contract = {n_contract};
    static const unsigned n_inplace = {n_inplace};

    // Resource reuse info
    static const unsigned io_type = nnet::{iotype};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned reuse_factor = {reuse_factor};
    static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1
    static const bool store_weights_in_bram = false; // NOT USED
}};
'''

einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h']


class EinsumDenseConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(EinsumDense)
        self.template = einsum_dense_config_template
        self.dense_template = conv_dense_config_template

    def format(self, node: EinsumDense):
        default_params = self._default_config_params(node)

        strategy = node.model.config.get_strategy(node)
        io_type = node.model.config.get_config_value('IOType')

        assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now'
        assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now'

        # EinsumDense config
        params = default_params.copy()
        params['strategy'] = strategy
        params['n_free_data'] = node.attributes.attributes['n_free_data']
        params['n_free_kernel'] = node.attributes.attributes['n_free_kernel']
        params['n_contract'] = node.attributes.attributes['n_contract']
        params['n_inplace'] = node.attributes.attributes['n_inplace']
        params['parallelization_factor'] = node.attributes.attributes['parallelization_factor']

        einsum_conf = self.template.format(**params)

        # inp/out transpose config
        inp_shape = node.attributes.attributes['inp_shape']
        out_interpert_shape = node.attributes.attributes['out_interpert_shape']
        inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs']
        out_tpose_idxs = node.attributes.attributes['out_tpose_idxs']
        tpose_inp_conf_name = f'config{node.index}_tpose_inp'
        tpose_out_conf_name = f'config{node.index}_tpose_out'

        inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs)
        out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)

        # Dense config
        dense_params = default_params.copy()
        dense_params['strategy'] = strategy
        dense_params['n_in'] = node.attributes.attributes['n_contract']
        dense_params['n_out'] = node.attributes.attributes['n_free_kernel']
        if node.attributes.attributes['n_inplace'] == 1:
            dense_params['nzeros'] = node.get_weights('weight').nzeros  # type: ignore
        else:
            dense_params['nzeros'] = '-1; // Not meaningful when kernels are switching'
        dense_params['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision  # type: ignore
        )

        dense_params['dense_function'] = 'DenseLatency'  # Latency only for now

        dense_config = self.dense_template.format(**dense_params)

        return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf))


class EinsumDenseFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(EinsumDense, include_header=einsum_dense_include_list)
        self.template = einsum_dense_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name

        return self.template.format(**params)
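To make the config template above concrete, here is what filling einsum_dense_config_template by hand could produce. The values are hypothetical and simply mirror the toy 'btd,de->bte' example earlier (8 free data positions for t, 4 kernel outputs for e, 16 contracted elements for d), assuming that is how the layer attributes get populated.

print(einsum_dense_config_template.format(
    index=3,
    n_free_data=8,        # free data dims (t)
    n_free_kernel=4,      # free kernel dims (e)
    n_contract=16,        # contracted dims (d)
    n_inplace=1,          # number of independent kernel slices (1 here)
    iotype='io_parallel',
    strategy='latency',
    reuse_factor=1,
    parallelization_factor=8,
))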
@@ -1,5 +1,6 @@
from . import conv  # noqa: F401
from . import core  # noqa: F401
from . import einsum_dense  # noqa: F401
from ._base import registry as layer_handlers

__all__ = ['layer_handlers']
@@ -0,0 +1,72 @@
import typing
from typing import Sequence

from ._base import KerasV3LayerHandler, register

if typing.TYPE_CHECKING:
    import keras
    from keras.api import KerasTensor


def strip_batch_dim(equation: str):
    """Remove the batch dimension from the equation.

    Args:
        equation (str): The einsum equation.

    Returns:
        str: The einsum equation without the batch dimension.
    """

    _inps, out = equation.split('->')
    inp0, inp1 = _inps.split(',')
    if inp0.startswith('...'):
        assert out.startswith('...'), f'Error in eq: {equation}: Batch dim mismatch for the input and output.'
    else:
        assert inp0[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the input and output.'
        assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.'
        inp0, out = inp0[1:], out[1:]
    return f'{inp0},{inp1}->{out}'


@register
class KV3EinsumDenseHandler(KerasV3LayerHandler):
    handles = ('keras.src.layers.core.einsum_dense.EinsumDense',)

    def handle(
        self,
        layer: 'keras.layers.EinsumDense',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        import keras

        assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor'
        assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor'

        inp_shape: tuple[int, ...] = in_tensors[0].shape[1:]  # type: ignore
        out_shape: tuple[int, ...] = out_tensors[0].shape[1:]  # type: ignore

        # fmt: off
        assert all(d is not None for d in inp_shape), \
            f'Error when processing {layer.name}: EinsumDense layer requires fully defined input shapes'
        assert all(d is not None for d in out_shape), \
            f'Error when processing {layer.name}: EinsumDense layer requires fully defined output shapes'
        # fmt: on

        equation = strip_batch_dim(layer.equation)

        kernel = keras.ops.convert_to_numpy(layer.kernel)

        bias = None
        if layer.bias_axes:
            bias = keras.ops.convert_to_numpy(layer.bias)

        return {
            'class_name': 'EinsumDense',
            'equation': equation,
            'weight_data': kernel,
            'bias_data': bias,
            'inp_shape': inp_shape,
            'out_shape': out_shape,
        }
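A couple of illustrative calls (not part of the commit) showing the behaviour of strip_batch_dim as implemented above:

assert strip_batch_dim('btd,de->bte') == 'td,de->te'        # explicit batch index 'b' is dropped
assert strip_batch_dim('...d,de->...e') == '...d,de->...e'  # ellipsis batch dims are validated but left in place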
@@ -0,0 +1,78 @@
#ifndef NNET_EINSUM_DENSE_H_
#define NNET_EINSUM_DENSE_H_

#include "hls_stream.h"
#include "nnet_common.h"
#include "nnet_dense_latency.h"
#include "nnet_dense_resource.h"
#include "nnet_function_stubs.h"
#include "nnet_helpers.h"
#include "nnet_mult.h"
#include "nnet_transpose.h"

namespace nnet {

struct einsum_dense_config {
    // Internal data type definitions

    typedef void tpose_inp_conf;
    typedef void tpose_out_conf;
    typedef void dense_conf;

    // Layer Sizes
    static const unsigned n_free_data = 1;
    static const unsigned n_free_kernel = 1;
    static const unsigned n_contract = 1;
    static const unsigned n_inplace = 1;

    // Resource reuse info
    static const unsigned io_type = io_parallel;
    static const unsigned strategy = latency;
    static const unsigned reuse_factor = 1;
    static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1
    static const bool store_weights_in_bram = false;     // NOT USED

    // Product function to use
    template <class x_T, class y_T> using product = nnet::product::mult<x_T, y_T>;
};

template <class data_T, class res_T, typename CONFIG_T>
void einsum_dense(
    data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace],
    res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace],
    typename CONFIG_T::dense_conf::weight_t weights[CONFIG_T::n_free_kernel * CONFIG_T::n_contract * CONFIG_T::n_inplace],
    typename CONFIG_T::dense_conf::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) {
    data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace];
    res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace];
    res_T out_buffer[CONFIG_T::n_free_kernel];
    #pragma HLS ARRAY_PARTITION variable = inp_tpose complete
    #pragma HLS ARRAY_PARTITION variable = out_tpose complete

    nnet::transpose<data_T, data_T, typename CONFIG_T::tpose_inp_conf>(data, inp_tpose);

    constexpr unsigned L0 = CONFIG_T::n_free_data;
    constexpr unsigned L1 = CONFIG_T::n_free_kernel;
    constexpr unsigned C = CONFIG_T::n_contract;
    constexpr unsigned I = CONFIG_T::n_inplace;

    for (unsigned l0 = 0; l0 < L0; l0++) {
        #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor
        for (unsigned i = 0; i < I; i++) {
            #pragma HLS UNROLL
            // even w/o explicit distributed arithmetic optimization, latency kernels are partially implemented as such
            // so reusing the same multiplier for different weights doesn't really help... only full unrolling for now
            dense<data_T, res_T, typename CONFIG_T::dense_conf>(&inp_tpose[(i * L0 + l0) * C], out_buffer,
                                                                &weights[(i * L1 * C)], &biases[((i * L0 + l0) * L1)]);
            for (unsigned j = 0; j < L1; j++) {
                #pragma HLS UNROLL
                out_tpose[(i * L0 + l0) * L1 + j] = out_buffer[j];
            }
        }
    }

    nnet::transpose<res_T, res_T, typename CONFIG_T::tpose_out_conf>(out_tpose, res);
}

} // namespace nnet

#endif
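Reading the loop nest above: the input is first transposed so each dense call sees a contiguous block of n_contract elements, one dense of size n_contract -> n_free_kernel runs per (inplace slice, free data position) pair, and the result is transposed back to the layer's output order. The NumPy sketch below mirrors that indexing; the per-slice weight orientation is an assumption (whatever nnet::dense actually expects), so treat it as a reading aid, not a bit-exact reference.

import numpy as np

def einsum_dense_ref(inp_tpose, weights, biases, I, L0, L1, C):
    # I = n_inplace, L0 = n_free_data, L1 = n_free_kernel, C = n_contract
    x = inp_tpose.reshape(I, L0, C)   # matches index (i * L0 + l0) * C
    w = weights.reshape(I, C, L1)     # assumed orientation of each kernel slice
    b = biases.reshape(I, L0, L1)     # matches index (i * L0 + l0) * L1
    out = np.empty((I, L0, L1))
    for i in range(I):                # each slice i has its own weight block
        for l0 in range(L0):          # free data positions reuse that block
            out[i, l0] = x[i, l0] @ w[i] + b[i, l0]
    return out.reshape(-1)            # flattened out_tpose; the caller transposes it back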