From 29674db676d095f615e5d0fe55869084c14341ff Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Fri, 15 Nov 2024 05:05:42 +0000
Subject: [PATCH] general einsum support for io_parallel and latency

---
 hls4ml/backends/vivado/passes/einsum_dense.py | 120 +++++++++
 .../vivado/passes/reshaping_templates.py      |   4 +-
 hls4ml/converters/keras_v3/__init__.py        |   1 +
 hls4ml/converters/keras_v3/einsum_dense.py    |  72 ++++++
 hls4ml/model/layers.py                        |  66 ++++-
 .../vivado/nnet_utils/nnet_einsum_dense.h     |  78 ++++++
 hls4ml/utils/einsum_utils.py                  | 241 ++++++++++++++++++
 7 files changed, 579 insertions(+), 3 deletions(-)
 create mode 100644 hls4ml/backends/vivado/passes/einsum_dense.py
 create mode 100644 hls4ml/converters/keras_v3/einsum_dense.py
 create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h
 create mode 100644 hls4ml/utils/einsum_utils.py

diff --git a/hls4ml/backends/vivado/passes/einsum_dense.py b/hls4ml/backends/vivado/passes/einsum_dense.py
new file mode 100644
index 000000000..fb5287381
--- /dev/null
+++ b/hls4ml/backends/vivado/passes/einsum_dense.py
@@ -0,0 +1,120 @@
+from hls4ml.backends.backend import get_backend
+from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
+from hls4ml.model.layers import EinsumDense
+
+from .reshaping_templates import transpose_config_gen
+
+# Shared Dense template
+
+conv_dense_config_template = """struct config{index}_dense : nnet::dense_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned strategy = nnet::{strategy};
+    static const unsigned n_zeros = {nzeros};
+    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {weight_t.name} weight_t;
+    template <class data_T, class res_T, class CONFIG_T>
+    using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
+    template <class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n"""
+
+# EinsumDense template
+
+einsum_dense_config_template = '''
+struct config{index} {{
+    typedef config{index}_tpose_inp tpose_inp_conf;
+    typedef config{index}_tpose_out tpose_out_conf;
+    typedef config{index}_dense dense_conf;
+
+    // Layer Sizes
+    static const unsigned n_free_data = {n_free_data};
+    static const unsigned n_free_kernel = {n_free_kernel};
+    static const unsigned n_contract = {n_contract};
+    static const unsigned n_inplace = {n_inplace};
+
+    // Resource reuse info
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned strategy = nnet::{strategy};
+    static const unsigned reuse_factor = {reuse_factor};
+    static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1
+    static const bool store_weights_in_bram = false; // NOT USED
+}};
+'''
+
+einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
+
+einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h']
+
+
+class EinsumDenseConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(EinsumDense)
+        self.template = einsum_dense_config_template
+        self.dense_template = conv_dense_config_template
+
+    def format(self, node: EinsumDense):
+        default_params = self._default_config_params(node)
+
+        strategy = node.model.config.get_strategy(node)
+        io_type = node.model.config.get_config_value('IOType')
+
+        assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now'
+        assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now'
+
+        # EinsumDense config
+        params = default_params.copy()
+        params['strategy'] = strategy
+        params['n_free_data'] = node.attributes.attributes['n_free_data']
+        params['n_free_kernel'] = node.attributes.attributes['n_free_kernel']
+        params['n_contract'] = node.attributes.attributes['n_contract']
+        params['n_inplace'] = node.attributes.attributes['n_inplace']
+        params['parallelization_factor'] = node.attributes.attributes['parallelization_factor']
+
+        einsum_conf = self.template.format(**params)
+
+        # inp/out transpose config
+        inp_shape = node.attributes.attributes['inp_shape']
+        out_interpert_shape = node.attributes.attributes['out_interpert_shape']
+        inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs']
+        out_tpose_idxs = node.attributes.attributes['out_tpose_idxs']
+        tpose_inp_conf_name = f'config{node.index}_tpose_inp'
+        tpose_out_conf_name = f'config{node.index}_tpose_out'
+
+        inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs)
+        out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)
+
+        # Dense config
+        dense_params = default_params.copy()
+        dense_params['strategy'] = strategy
+        dense_params['n_in'] = node.attributes.attributes['n_contract']
+        dense_params['n_out'] = node.attributes.attributes['n_free_kernel']
+        if node.attributes.attributes['n_inplace'] == 1:
+            dense_params['nzeros'] = node.get_weights('weight').nzeros  # type: ignore
+        else:
+            dense_params['nzeros'] = '-1; // Not making sense when kernels are switching'
+        dense_params['product_type'] = get_backend('vivado').product_type(
+            node.get_input_variable().type.precision, node.get_weights('weight').type.precision  # type: ignore
+        )
+
+        dense_params['dense_function'] = 'DenseLatency'  # Latency only for now
+
+        dense_config = self.dense_template.format(**dense_params)
+
+        return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf))
+
+
+class EinsumDenseFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(EinsumDense, include_header=einsum_dense_include_list)
+        self.template = einsum_dense_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['w'] = node.get_weights('weight').name
+        params['b'] = node.get_weights('bias').name
+
+        return self.template.format(**params)
diff --git a/hls4ml/backends/vivado/passes/reshaping_templates.py b/hls4ml/backends/vivado/passes/reshaping_templates.py
index f43d394cd..e59d81c8c 100644
--- a/hls4ml/backends/vivado/passes/reshaping_templates.py
+++ b/hls4ml/backends/vivado/passes/reshaping_templates.py
@@ -127,7 +127,7 @@ def format(self, node):
 transpose_function_template = 'nnet::transpose<{input_t}, {output_t}, {config_name}>({input}, {output});'
 
 
-def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
+def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
     new_shape = tuple(shape[i] for i in perm)
     strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
     perm_strides = tuple(int(strides[i]) for i in perm)
@@ -151,7 +151,7 @@ def format(self, node):
         shape = tuple(node.get_input_variable().shape)
         perm = tuple(node.get_attr('perm'))
         name = f'config{node.index}'
-        return permute_config_gen(name, shape, perm)
+        return transpose_config_gen(name, shape, perm)
 
 
 class TransposeFunctionTemplate(FunctionCallTemplate):
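Note: the renamed transpose_config_gen helper above only precomputes the permuted shape and the per-output-axis strides of the input tensor. A standalone numpy sketch of that stride computation follows; the perm_strides name and the example shape are illustrative only, not part of the patch.

import numpy as np

def perm_strides(shape: tuple[int, ...], perm: tuple[int, ...]) -> tuple[int, ...]:
    # row-major strides of the input tensor, reordered by the permutation,
    # so element (i0, i1, ...) of the transposed output can be gathered with a flat index
    strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
    return tuple(int(strides[i]) for i in perm)

# transposing a (2, 3, 4) tensor with perm (2, 0, 1): new shape is (4, 2, 3), strides are (1, 12, 4)
assert perm_strides((2, 3, 4), (2, 0, 1)) == (1, 12, 4)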
diff --git a/hls4ml/converters/keras_v3/__init__.py b/hls4ml/converters/keras_v3/__init__.py
index f658faa1f..6dffcb71d 100644
--- a/hls4ml/converters/keras_v3/__init__.py
+++ b/hls4ml/converters/keras_v3/__init__.py
@@ -1,5 +1,6 @@
 from . import conv  # noqa: F401
 from . import core  # noqa: F401
+from . import einsum_dense  # noqa: F401
 from ._base import registry as layer_handlers
 
 __all__ = ['layer_handlers']
diff --git a/hls4ml/converters/keras_v3/einsum_dense.py b/hls4ml/converters/keras_v3/einsum_dense.py
new file mode 100644
index 000000000..f0f4c7223
--- /dev/null
+++ b/hls4ml/converters/keras_v3/einsum_dense.py
@@ -0,0 +1,72 @@
+import typing
+from typing import Sequence
+
+from ._base import KerasV3LayerHandler, register
+
+if typing.TYPE_CHECKING:
+    import keras
+    from keras.api import KerasTensor
+
+
+def strip_batch_dim(equation: str):
+    """Remove the batch dimension from the equation.
+
+    Args:
+        equation (str): The einsum equation.
+
+    Returns:
+        str: The einsum equation without the batch dimension.
+    """
+
+    _inps, out = equation.split('->')
+    inp0, inp1 = _inps.split(',')
+    if inp0.startswith('...'):
+        assert out.startswith('...'), f'Error in eq: {equation}: Batch dim mismatch for the input and output.'
+    else:
+        assert inp0[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the input and output.'
+        assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.'
+        inp0, out = inp0[1:], out[1:]
+    return f'{inp0},{inp1}->{out}'
+
+
+@register
+class KV3EinsumDenseHandler(KerasV3LayerHandler):
+    handles = ('keras.src.layers.core.einsum_dense.EinsumDense',)
+
+    def handle(
+        self,
+        layer: 'keras.layers.EinsumDense',
+        in_tensors: Sequence['KerasTensor'],
+        out_tensors: Sequence['KerasTensor'],
+    ):
+        import keras
+
+        assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor'
+        assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor'
+
+        inp_shape: tuple[int, ...] = in_tensors[0].shape[1:]  # type: ignore
+        out_shape: tuple[int, ...] = out_tensors[0].shape[1:]  # type: ignore
+
+        # fmt: off
+        assert all(d is not None for d in inp_shape), \
+            f'Error when processing {layer.name}: EinsumDense layer requires fully defined input shapes'
+        assert all(d is not None for d in out_shape), \
+            f'Error when processing {layer.name}: EinsumDense layer requires fully defined output shapes'
+        # fmt: on
+
+        equation = strip_batch_dim(layer.equation)
+
+        kernel = keras.ops.convert_to_numpy(layer.kernel)
+
+        bias = None
+        if layer.bias_axes:
+            bias = keras.ops.convert_to_numpy(layer.bias)
+
+        return {
+            'class_name': 'EinsumDense',
+            'equation': equation,
+            'weight_data': kernel,
+            'bias_data': bias,
+            'inp_shape': inp_shape,
+            'out_shape': out_shape,
+        }
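For illustration, strip_batch_dim above maps a keras EinsumDense equation with an explicit batch subscript to its per-sample form; a tiny example with a made-up equation:

# 'b' is the batch axis: it must not appear in the kernel, and is dropped from data and output
assert strip_batch_dim('btd,dh->bth') == 'td,dh->th'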
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index aac11cc7a..5392e2ffe 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -27,10 +27,12 @@
     find_minimum_width,
 )
 from hls4ml.utils import attribute_descriptions as descriptions
+from hls4ml.utils.einsum_utils import parse_einsum
 from hls4ml.utils.string_utils import convert_to_snake_case
-
 # TODO move this to some utility module
+
+
 class classproperty:
     def __init__(self, func):
         self.func = func
 
@@ -1618,6 +1620,67 @@ def initialize(self):
         self.add_output_variable([len(self.get_attr('expression'))], [f'N_OUTPUTS_{self.index}'], var_name='y')
 
 
+class EinsumDense(Layer):
+    _expected_attributes = [
+        WeightAttribute('weight'),
+        WeightAttribute('bias'),
+        TypeAttribute('weight'),
+        TypeAttribute('bias'),
+        TypeAttribute('accum'),
+        Attribute('equation', value_type=str),
+        Attribute('inp_shape', value_type=tuple),
+        Attribute('out_shape', value_type=tuple),
+    ]
+
+    def initialize(self):
+        out_shape = self.attributes['out_shape']
+        if len(out_shape) > 1:
+            dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)]
+        else:
+            dims = [f'N_LAYER_{self.index}']
+        self.add_output_variable(list(out_shape), dims)
+
+        kernel: np.ndarray = self.attributes.attributes['weight_data']
+        bias: np.ndarray | None = self.attributes.attributes['bias_data']
+        equation = self.attributes['equation']
+        inp_shape = self.attributes['inp_shape']
+        out_shape = self.attributes['out_shape']
+
+        recipe = parse_einsum(equation, inp_shape, kernel.shape)
+        inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs']
+        out_tpose_idxs = recipe['out_transpose_idxs']
+
+        # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though.
+        # hls4ml dense acts like i,ij->j
+        # parser assumes ij,j->i, so we need to transpose the kernel to match
+        kernel = kernel.transpose(ker_tpose_idxs)
+        kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1)
+
+        # TODO: for weights-in-BRAM mode (resource strategy), broadcasting the bias here should be avoided.
+        if bias is not None:
+            bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs))
+        else:
+            # The automatically created bias is just the last dimension of the output shape,
+            # which is in general too small for EinsumDense.
+            # The transpose only matches the shape of a real bias; it has no other effect.
+            bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs))
+
+        self.attributes.attributes['weight_data'] = kernel
+        self.attributes.attributes['bias_data'] = bias
+        self.attributes['inp_tpose_idxs'] = inp_tpose_idxs
+        self.attributes['out_tpose_idxs'] = out_tpose_idxs
+        self.attributes['out_interpert_shape'] = recipe['out_interpert_shape']
+        self.attributes['n_free_data'] = recipe['L0']
+        self.attributes['n_free_kernel'] = recipe['L1']
+        self.attributes['n_inplace'] = recipe['I']
+        self.attributes['n_contract'] = recipe['C']
+        pf = self.attributes.attributes.get('parallelization_factor', recipe['L0'])
+        self.attributes['parallelization_factor'] = pf
+
+        self.add_weights(compression=self.model.config.get_compression(self))
+        self.add_bias()
+
+
 layer_map = {
     'Input': Input,
     'InputLayer': Input,
@@ -1686,6 +1749,7 @@ def initialize(self):
     'SymbolicExpression': SymbolicExpression,
     # TensorFlow-specific layers:
     'BiasAdd': BiasAdd,
+    'EinsumDense': EinsumDense,
 }
 
 
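The kernel pre-transpose above only reconciles two conventions: the einsum parser phrases each block as a matrix-vector product 'ij,j->i' with an (n_free_kernel, n_contract) kernel, while hls4ml's dense computes 'i,ij->j' with an (n_contract, n_free_kernel) weight array. A small numpy sketch of that equivalence, with illustrative sizes:

import numpy as np

x = np.random.rand(16)       # one data vector, length n_contract
w = np.random.rand(24, 16)   # parser convention: (n_free_kernel, n_contract), i.e. 'ij,j->i'
# hls4ml dense stores the kernel as w.T with shape (n_contract, n_free_kernel) and computes 'i,ij->j'
assert np.allclose(w @ x, x @ w.T)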
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h
new file mode 100644
index 000000000..1abb7c5d0
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_einsum_dense.h
@@ -0,0 +1,78 @@
+#ifndef NNET_EINSUM_DENSE_H_
+#define NNET_EINSUM_DENSE_H_
+
+#include "hls_stream.h"
+#include "nnet_common.h"
+#include "nnet_dense_latency.h"
+#include "nnet_dense_resource.h"
+#include "nnet_function_stubs.h"
+#include "nnet_helpers.h"
+#include "nnet_mult.h"
+#include "nnet_transpose.h"
+
+namespace nnet {
+
+struct einsum_dense_config {
+    // Internal data type definitions
+
+    typedef void tpose_inp_conf;
+    typedef void tpose_out_conf;
+    typedef void dense_conf;
+
+    // Layer Sizes
+    static const unsigned n_free_data = 1;
+    static const unsigned n_free_kernel = 1;
+    static const unsigned n_contract = 1;
+    static const unsigned n_inplace = 1;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned strategy = latency;
+    static const unsigned reuse_factor = 1;
+    static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1
+    static const bool store_weights_in_bram = false;     // NOT USED
+
+    // Product function to use
+    template <class x_T, class y_T> using product = nnet::product::mult<x_T, y_T>;
+};
+
+template <class data_T, class res_T, typename CONFIG_T>
+void einsum_dense(
+    data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace],
+    res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace],
+    typename CONFIG_T::dense_conf::weight_t weights[CONFIG_T::n_free_kernel * CONFIG_T::n_contract * CONFIG_T::n_inplace],
+    typename CONFIG_T::dense_conf::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) {
+    data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace];
+    res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace];
+    res_T out_buffer[CONFIG_T::n_free_kernel];
+    #pragma HLS ARRAY_PARTITION variable = inp_tpose complete
+    #pragma HLS ARRAY_PARTITION variable = out_tpose complete
+
+    nnet::transpose<data_T, data_T, typename CONFIG_T::tpose_inp_conf>(data, inp_tpose);
+
+    constexpr unsigned L0 = CONFIG_T::n_free_data;
+    constexpr unsigned L1 = CONFIG_T::n_free_kernel;
+    constexpr unsigned C = CONFIG_T::n_contract;
+    constexpr unsigned I = CONFIG_T::n_inplace;
+
+    for (unsigned l0 = 0; l0 < L0; l0++) {
+        #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor
+        for (unsigned i = 0; i < I; i++) {
+            #pragma HLS UNROLL
+            // even w/o explicit distributed arithmetic optimization, latency kernels are partially implemented as such
+            // so reusing the same multiplier for different weights doesn't really help... only full unrolling for now
+            dense<data_T, res_T, typename CONFIG_T::dense_conf>(&inp_tpose[(i * L0 + l0) * C], out_buffer,
+                                                                &weights[(i * L1 * C)], &biases[((i * L0 + l0) * L1)]);
+            for (unsigned j = 0; j < L1; j++) {
+                #pragma HLS UNROLL
+                out_tpose[(i * L0 + l0) * L1 + j] = out_buffer[j];
+            }
+        }
+    }
+
+    nnet::transpose<res_T, res_T, typename CONFIG_T::tpose_out_conf>(out_tpose, res);
+}
+
+} // namespace nnet
+
+#endif
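The flat indexing in the loop above assumes inp_tpose is laid out as (n_inplace, n_free_data, n_contract) and out_tpose as (n_inplace, n_free_data, n_free_kernel), with one (n_contract, n_free_kernel) weight block per in-place slice. A numpy sketch that mirrors the same indexing, with illustrative sizes and not the HLS code itself:

import numpy as np

I, L0, L1, C = 2, 3, 4, 5                # n_inplace, n_free_data, n_free_kernel, n_contract
inp = np.random.rand(I, L0, C)
w = np.random.rand(I, C, L1)             # one (C, L1) dense kernel per in-place block
ref = np.einsum('ilc,icj->ilj', inp, w)

flat_in, flat_w = inp.ravel(), w.ravel()
flat_out = np.zeros(I * L0 * L1)
for l0 in range(L0):
    for i in range(I):
        x = flat_in[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C]
        kern = flat_w[i * L1 * C : (i + 1) * L1 * C].reshape(C, L1)
        flat_out[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = x @ kern
assert np.allclose(flat_out.reshape(I, L0, L1), ref)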
diff --git a/hls4ml/utils/einsum_utils.py b/hls4ml/utils/einsum_utils.py
new file mode 100644
index 000000000..7d4253f76
--- /dev/null
+++ b/hls4ml/utils/einsum_utils.py
@@ -0,0 +1,241 @@
+from math import prod
+from typing import TypedDict
+
+import numpy as np
+
+
+class EinsumRecipe(TypedDict):
+    in_transpose_idxs: tuple[tuple[int, ...], tuple[int, ...]]
+    L0: int
+    L1: int
+    I: int
+    C: int
+    out_interpert_shape: tuple[int, ...]
+    out_transpose_idxs: tuple[int, ...]
+
+
+def _validate_einsum_expr(fn: str, shape0: tuple[int, ...], shape1: tuple[int, ...]):
+    """Validate, resolve broadcasting, and compute the output shape for an einsum string
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    shape0 : tuple[int,...]
+        shape of input0
+    shape1 : tuple[int,...]
+        shape of input1
+
+    Returns
+    -------
+    tuple[str, tuple[int,...]]
+        einsum string without broadcasting, and the output shape
+
+    Raises
+    ------
+    ValueError
+        If the einsum string is invalid, or if it is incompatible with the input shapes
+    """
+    inp, out = map(str.strip, fn.split('->'))
+    in0, in1 = map(str.strip, inp.split(','))
+    alphabets = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    s_alphabets = set(alphabets)
+
+    # Invalid characters
+    if not (s_alphabets >= set(in0.replace('...', '') + in1.replace('...', '') + out.replace('...', ''))):
+        raise ValueError(f"einsum string {fn} is invalid: subscripts should be in [a-zA-Z] and '...' only")
+
+    in0 = in0.replace('...', '0')
+    in1 = in1.replace('...', '0')
+    out = out.replace('...', '0')
+    ax_in0, ax_in1, ax_out = list(in0), list(in1), list(out)
+    sax_in0, sax_in1, sax_out = set(ax_in0), set(ax_in1), set(ax_out)
+    free_indices = ''.join(sorted(s_alphabets - sax_in0 - sax_in1 - sax_out))
+
+    # Repeated indices
+    if len(sax_in0) != len(ax_in0):
+        for a in in0:
+            if in0.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: input0 subscripts include '{a}' multiple times")
+    if len(sax_in1) != len(ax_in1):
+        for a in in1:
+            if in1.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: input1 subscripts include '{a}' multiple times")
+    if len(sax_out) != len(ax_out):
+        for a in out:
+            if out.count(a) == 1:
+                continue
+            a = a if a != '0' else '...'
+            raise ValueError(f"einsum string {fn} is invalid: output subscripts include '{a}' multiple times")
+
+    # Invalid broadcasting
+    if '0' in sax_in0 or '0' in sax_in1 or '0' in sax_out:
+        if '0' in sax_in0 and '0' in sax_in1:
+            raise ValueError(f"einsum string {fn} is invalid: both input0 and input1 allow broadcasting")
+        if '0' not in sax_out:
+            raise ValueError(f"einsum string {fn} is invalid: output does not allow broadcasting, but inputs do")
+        if '0' not in sax_in0 and '0' not in sax_in1:
+            raise ValueError(f"einsum string {fn} is invalid: output allows broadcasting, but inputs do not")
+
+    # Output index out of nowhere
+    if remaining := sax_out - sax_in0 - sax_in1:
+        raise ValueError(f"einsum string {fn} is invalid: output subscripts {remaining} not found in inputs")
+
+    _common_in = sax_in0 & sax_in1
+
+    # Invalid input dimensions
+    if '0' in sax_in0:
+        if len(sax_in0) - 1 > len(shape0):
+            raise ValueError(f"Input0 requires at least {len(sax_in0)-1} dimensions, but only {len(shape0)} given")
+        # Replace broadcasting indices with free indices
+        n_broadcast = len(shape0) - len(sax_in0) + 1
+        in0 = in0.replace('0', free_indices[:n_broadcast])
+        out = out.replace('0', free_indices[:n_broadcast])
+        ax_in0 = list(in0)
+        ax_out = list(out)
+    else:
+        if len(sax_in0) != len(shape0):
+            raise ValueError(f"Input0 requires {len(sax_in0)} dimensions, but {len(shape0)} given")
+    if '0' in sax_in1:
+        if len(sax_in1) - 1 > len(shape1):
+            raise ValueError(f"Input1 requires at least {len(sax_in1)-1} dimensions, but only {len(shape1)} given")
+        # Replace broadcasting indices with free indices
+        n_broadcast = len(shape1) - len(sax_in1) + 1
+        in1 = in1.replace('0', free_indices[:n_broadcast])
+        out = out.replace('0', free_indices[:n_broadcast])
+        ax_in1 = list(in1)
+        ax_out = list(out)
+    else:
+        if len(sax_in1) != len(shape1):
+            raise ValueError(f"Input1 requires {len(sax_in1)} dimensions, but {len(shape1)} given")
+
+    # Input dimension mismatch
+    for a in _common_in:
+        ax_0 = ax_in0.index(a)
+        ax_1 = ax_in1.index(a)
+        if shape0[ax_0] != shape1[ax_1]:
+            raise ValueError(
+                f"Input dimension size mismatch for common subscript '{a}': {shape0[ax_0]} and {shape1[ax_1]}"
+            )
+
+    out_shape = tuple(shape0[ax_in0.index(a)] if a in ax_in0 else shape1[ax_in1.index(a)] for a in ax_out)
+    return f'{in0},{in1}->{out}', out_shape
+
+
+def parse_einsum(fn: str, input_shape0: tuple[int, ...], input_shape1: tuple[int, ...]) -> EinsumRecipe:
+    """Parse an einsum expression into a recipe of transposes and block sizes
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    input_shape0 : tuple[int,...]
+        shape of input0, the first operand
+    input_shape1 : tuple[int,...]
+        shape of input1, the second operand
+
+    Returns
+    -------
+    EinsumRecipe
+        recipe describing how to execute the einsum
+    """
+
+    fn, _ = _validate_einsum_expr(fn, input_shape0, input_shape1)
+
+    _in, _out = fn.split('->')
+    _in0, _in1 = _in.split(',')
+
+    in0, in1, out = list(_in0), list(_in1), list(_out)
+    s_in0, s_in1, s_out = set(in0), set(in1), set(out)
+    _common = s_in0 & s_in1
+    _contract = _common - s_out
+    _inplace = _common & s_out
+    contract = sorted(_contract, key=lambda x: in1.index(x))
+    inplace = sorted(_inplace, key=lambda x: in1.index(x))
+    invariant0 = sorted((s_out - _common) & s_in0, key=lambda x: in0.index(x))
+    invariant1 = sorted((s_out - _common) & s_in1, key=lambda x: in1.index(x))
+
+    contract_idxs = tuple(map(in0.index, contract)), tuple(map(in1.index, contract))
+    inplace_idxs = tuple(map(in0.index, inplace)), tuple(map(in1.index, inplace))
+    invariant_idxs = tuple(map(in0.index, invariant0)), tuple(map(in1.index, invariant1))
+
+    inplace_shape = tuple(input_shape0[i] for i in inplace_idxs[0])
+    inplace_size = prod(inplace_shape)
+    contract_size = prod(input_shape0[i] for i in contract_idxs[0])
+    invariant_shape0 = tuple(input_shape0[i] for i in invariant_idxs[0])
+    invariant_shape1 = tuple(input_shape1[i] for i in invariant_idxs[1])
+    invariant_size0, invariant_size1 = prod(invariant_shape0), prod(invariant_shape1)
+
+    transpose_idx0 = inplace_idxs[0] + invariant_idxs[0] + contract_idxs[0]
+    transpose_idx1 = inplace_idxs[1] + invariant_idxs[1] + contract_idxs[1]
+
+    out_shape_pretranspose = inplace_shape + invariant_shape0 + invariant_shape1
+    _out_transpose_idx = np.argsort(tuple(map(out.index, inplace + invariant0 + invariant1)))
+    out_transpose_idx = tuple(int(i) for i in _out_transpose_idx)
+
+    return EinsumRecipe(
+        in_transpose_idxs=(transpose_idx0, transpose_idx1),
+        out_interpert_shape=out_shape_pretranspose,
+        out_transpose_idxs=out_transpose_idx,
+        L0=invariant_size0,
+        L1=invariant_size1,
+        I=inplace_size,
+        C=contract_size,
+    )
+
+
+def _exec_einsum(recipe: EinsumRecipe, input0: np.ndarray, input1: np.ndarray) -> np.ndarray:
+    """Execute an einsum recipe on two input arrays
+
+    Parameters
+    ----------
+    recipe : EinsumRecipe
+        einsum recipe
+    input0 : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    np.ndarray
+        output array
+    """
+    input0 = input0.transpose(recipe['in_transpose_idxs'][0]).ravel()
+    input1 = input1.transpose(recipe['in_transpose_idxs'][1]).ravel()
+    output = np.zeros(recipe['L0'] * recipe['L1'] * recipe['I'], dtype=input0.dtype)
+
+    L0, L1, I, C = recipe['L0'], recipe['L1'], recipe['I'], recipe['C']
+
+    for l0 in range(L0):
+        for i in range(I):
+            output[(i * L0 + l0) * L1 : (i * L0 + l0 + 1) * L1] = (
+                input1[i * L1 * C : (i + 1) * L1 * C].reshape((L1, C)) @ input0[(i * L0 + l0) * C : (i * L0 + l0 + 1) * C]
+            )
+
+    return output.reshape(recipe['out_interpert_shape']).transpose(recipe['out_transpose_idxs'])
+
+
+def einsum(fn: str, input0: np.ndarray, input1: np.ndarray) -> np.ndarray:
+    """Execute an einsum operation on two input arrays
+
+    Parameters
+    ----------
+    fn : str
+        einsum string, e.g. 'ij,jk->ik'
+    input0 : np.ndarray
+        input0, the first input array
+    input1 : np.ndarray
+        input1, the second input array
+
+    Returns
+    -------
+    np.ndarray
+        output array
+    """
+    recipe = parse_einsum(fn, input0.shape, input1.shape)
+    return _exec_einsum(recipe, input0, input1)
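As a final sanity check, the pure-python einsum/parse_einsum helpers defined above can be compared against numpy; the equation and shapes below are illustrative only:

import numpy as np
from hls4ml.utils.einsum_utils import einsum, parse_einsum

x = np.random.rand(4, 16)
w = np.random.rand(4, 16, 24)

# per-position kernels: each of the 4 positions contracts 16 inputs into 24 outputs
recipe = parse_einsum('ti,tio->to', x.shape, w.shape)
assert (recipe['L0'], recipe['L1'], recipe['I'], recipe['C']) == (1, 24, 4, 16)

# the reference executor agrees with numpy's einsum
assert np.allclose(einsum('ti,tio->to', x, w), np.einsum('ti,tio->to', x, w))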