Implemented Legacy Inference Support and Base Conversion Logic for `torch.nn.RNN` Layers (#134)

* Implemented legacy inference support and base conversion logic for `torch.nn.RNN` layers
* Updated conftest.py
* 🎨 Enforced Python/C++/CUDA Code Formatting with Black and Clang (#135)

Co-authored-by: coreylammie <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: coreylammie <[email protected]>
1 parent 30251e0 · commit 305c3a7 · 5 changed files with 336 additions and 18 deletions.
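The new module reimplements the Elman RNN recurrence in pure Python for the legacy forward path: at each time step t of each layer, h_t = act(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh), where act is tanh or ReLU. For bidirectional layers, the same recurrence is additionally evaluated over the time-reversed sequence, and the forward and reverse hidden states are concatenated along the feature dimension. The memristive (crossbar) forward path is not yet wired up; only this legacy path is enabled.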
import math
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules import conv

import memtorch
from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul
from memtorch.bh.crossbar.Tile import tiled_inference
from memtorch.map.Input import naive_scale
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map

class RNN(nn.RNN):
    """nn.RNN equivalent.

    Parameters
    ----------
    rnn_layer : torch.nn.RNN
        RNN layer to patch.
    memristor_model : memtorch.bh.memristor.Memristor.Memristor
        Memristor model.
    memristor_model_params : **kwargs
        Memristor model keyword arguments.
    mapping_routine : function
        Mapping routine to use.
    transistor : bool
        Used to determine if a 1T1R (True) or 1R (False) arrangement is simulated.
    programming_routine : function
        Programming routine to use.
    programming_routine_params : **kwargs
        Programming routine keyword arguments.
    p_l : float
        If not None, the proportion of weights to retain.
    scheme : memtorch.bh.Scheme
        Weight representation scheme.
    tile_shape : (int, int)
        Tile shape used to store weights. If None, modular tiles are not used.
    max_input_voltage : float
        Maximum input voltage used to encode inputs. If None, inputs are unbounded.
    scaling_routine : function
        Scaling routine used to scale batch inputs.
    scaling_routine_params : **kwargs
        Scaling routine keyword arguments.
    source_resistance : float
        The resistance between word/bit line voltage sources and crossbar(s).
    line_resistance : float
        The interconnect line resistance between adjacent cells.
    ADC_resolution : int
        ADC resolution (bit width). If None, quantization noise is not accounted for.
    ADC_overflow_rate : float
        Overflow rate threshold for linear quantization (if ADC_resolution is not None).
    quant_method : string
        Quantization method. Must be in ['linear', 'log', 'log_minmax', 'minmax', 'tanh'], or None.
    use_bindings : bool
        Used to determine if C++/CUDA bindings are used (True) or not (False).
    random_crossbar_init : bool
        Determines if the crossbar is initialized with random values between Ron and Roff.
    verbose : bool
        Used to determine if verbose output is enabled (True) or disabled (False).
    """

    def __init__(
        self,
        rnn_layer,
        memristor_model,
        memristor_model_params,
        mapping_routine=naive_map,
        transistor=True,
        programming_routine=None,
        programming_routine_params={},
        p_l=None,
        scheme=memtorch.bh.Scheme.DoubleColumn,
        tile_shape=None,
        max_input_voltage=None,
        scaling_routine=naive_scale,
        scaling_routine_params={},
        source_resistance=None,
        line_resistance=None,
        ADC_resolution=None,
        ADC_overflow_rate=0.0,
        quant_method=None,
        use_bindings=True,
        random_crossbar_init=False,
        verbose=True,
        *args,
        **kwargs,
    ):
        assert isinstance(rnn_layer, nn.RNN), "rnn_layer is not an instance of nn.RNN."
        # MemTorch encodes the build target in its version string ("cpu" for CPU-only builds).
        self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
        self.transistor = transistor
        self.scheme = scheme
        self.tile_shape = tile_shape
        self.max_input_voltage = max_input_voltage
        self.scaling_routine = scaling_routine
        self.scaling_routine_params = scaling_routine_params
        self.source_resistance = source_resistance
        self.line_resistance = line_resistance
        self.ADC_resolution = ADC_resolution
        self.ADC_overflow_rate = ADC_overflow_rate
        if "cpu" not in memtorch.__version__:
            self.cuda_malloc_heap_size = 50
        else:
            self.cuda_malloc_heap_size = None

        if not transistor:
            assert (
                source_resistance is not None and source_resistance >= 0.0
            ), "Source resistance is invalid."
            assert (
                line_resistance is not None and line_resistance >= 0.0
            ), "Line resistance is invalid."

        if quant_method in memtorch.bh.Quantize.quant_methods:
            self.quant_method = quant_method
        else:
            self.quant_method = None

        if quant_method is not None:
            assert (
                ADC_resolution is not None
                and type(ADC_resolution) == int
                and ADC_resolution > 0
            ), "ADC resolution is invalid."
            assert (
                ADC_overflow_rate is not None
            ), "ADC_overflow_rate must be specified if quant_method is not None."

        self.use_bindings = use_bindings
        self.verbose = verbose
        self.forward_legacy_enabled = True
        super(RNN, self).__init__(
            input_size=rnn_layer.input_size,
            hidden_size=rnn_layer.hidden_size,
            num_layers=rnn_layer.num_layers,
            nonlinearity=rnn_layer.nonlinearity,
            bias=rnn_layer.bias,
            batch_first=False,  # To add support.
            dropout=0.0,  # To add support.
            bidirectional=rnn_layer.bidirectional,
            **kwargs,
        )
        if rnn_layer.nonlinearity in ["tanh", "relu"]:
            if rnn_layer.nonlinearity == "tanh":
                self.nonlinearity = torch.tanh
            elif rnn_layer.nonlinearity == "relu":
                self.nonlinearity = torch.relu
        else:
            raise Exception("Nonlinearity must be either tanh or relu")

        self.w_ih = []
        self.w_hh = []
        if rnn_layer.bias:
            self.b_ih = []
            self.b_hh = []

        if rnn_layer.bidirectional:
            self.w_ih_reverse = []
            self.w_hh_reverse = []
            if rnn_layer.bias:
                self.b_ih_reverse = []
                self.b_hh_reverse = []

        self.zero_grad()
        for i in range(rnn_layer.num_layers):
            self.w_ih.append(rnn_layer._parameters[f"weight_ih_l{i}"].data)
            self.w_ih[i].requires_grad = False
            self.w_hh.append(rnn_layer._parameters[f"weight_hh_l{i}"].data)
            self.w_hh[i].requires_grad = False
            if rnn_layer.bias:
                self.b_ih.append(rnn_layer._parameters[f"bias_ih_l{i}"].data)
                self.b_ih[i].requires_grad = False
                self.b_hh.append(rnn_layer._parameters[f"bias_hh_l{i}"].data)
                self.b_hh[i].requires_grad = False

            if rnn_layer.bidirectional:
                self.w_ih_reverse.append(
                    rnn_layer._parameters[f"weight_ih_l{i}_reverse"].data
                )
                self.w_ih_reverse[i].requires_grad = False
                self.w_hh_reverse.append(
                    rnn_layer._parameters[f"weight_hh_l{i}_reverse"].data
                )
                self.w_hh_reverse[i].requires_grad = False
                if rnn_layer.bias:
                    self.b_ih_reverse.append(
                        rnn_layer._parameters[f"bias_ih_l{i}_reverse"].data
                    )
                    self.b_ih_reverse[i].requires_grad = False
                    self.b_hh_reverse.append(
                        rnn_layer._parameters[f"bias_hh_l{i}_reverse"].data
                    )
                    self.b_hh_reverse[i].requires_grad = False

        warnings.warn(
            "RNN layers are not fully supported. Only legacy forward passes are currently enabled."
        )

    def forward(self, input, h_0=None):
        """Method to perform forward propagations.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor.
        h_0 : torch.Tensor
            The initial hidden state for the input sequence batch.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        # h_0 is stored as (num_directions, num_layers, batch_size, hidden_size).
        if h_0 is None:
            if self.bidirectional:
                h_0 = torch.zeros(2, self.num_layers, input.shape[1], self.hidden_size)
            else:
                h_0 = torch.zeros(1, self.num_layers, input.shape[1], self.hidden_size)

        if self.bidirectional:
            output = torch.zeros(input.shape[0], input.shape[1], 2 * self.hidden_size)
        else:
            output = torch.zeros(input.shape[0], input.shape[1], self.hidden_size)

        inp = input
        for layer in range(self.num_layers):
            # Forward direction: h_t = act(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh).
            h_t = h_0[0, layer]
            for t in range(inp.shape[0]):
                if self.bias:
                    h_t = (
                        torch.matmul(inp[t], self.w_ih[layer].T)
                        + self.b_ih[layer]
                        + torch.matmul(h_t, self.w_hh[layer].T)
                        + self.b_hh[layer]
                    )
                else:
                    h_t = torch.matmul(inp[t], self.w_ih[layer].T) + torch.matmul(
                        h_t, self.w_hh[layer].T
                    )

                h_t = self.nonlinearity(h_t)
                output[t, :, : self.hidden_size] = h_t

            if self.bidirectional:
                # Reverse direction: the same recurrence, traversing the sequence
                # from the last time step to the first.
                h_t_reverse = h_0[1, layer]
                for t in range(inp.shape[0]):
                    if self.bias:
                        h_t_reverse = (
                            torch.matmul(inp[-1 - t], self.w_ih_reverse[layer].T)
                            + self.b_ih_reverse[layer]
                            + torch.matmul(h_t_reverse, self.w_hh_reverse[layer].T)
                            + self.b_hh_reverse[layer]
                        )
                    else:
                        h_t_reverse = torch.matmul(
                            inp[-1 - t], self.w_ih_reverse[layer].T
                        ) + torch.matmul(h_t_reverse, self.w_hh_reverse[layer].T)

                    h_t_reverse = self.nonlinearity(h_t_reverse)
                    output[-1 - t, :, self.hidden_size :] = h_t_reverse

            # The (concatenated) output of this layer is the input to the next.
            inp = output.clone()

        return output

    def tune(self):
        """Tuning method."""
        pass  # To be implemented.

    def __str__(self):
        return (
            "bh.RNN(input_size=%d, hidden_size=%d, num_layers=%d, nonlinearity=%s, bias=%s, bidirectional=%s)"
            % (
                self.input_size,
                self.hidden_size,
                self.num_layers,
                self.nonlinearity,
                self.bias,
                self.bidirectional,
            )
        )
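
A minimal usage sketch of the legacy forward pass follows. The import path and the choice of memtorch.bh.memristor.VTEAM are illustrative assumptions; this layer does not yet consume the memristor model, so any valid model class can be passed.

import torch
import torch.nn as nn

import memtorch
from memtorch.mn.RNN import RNN  # hypothetical import path; adjust to this module's location

# Patch a standard torch.nn.RNN layer.
rnn_layer = nn.RNN(input_size=4, hidden_size=8, num_layers=2, nonlinearity="tanh")
patched_rnn = RNN(
    rnn_layer,
    memristor_model=memtorch.bh.memristor.VTEAM,  # illustrative model choice; unused by the legacy path
    memristor_model_params={},
)

# Legacy forward pass over a (seq_len, batch_size, input_size) sequence.
x = torch.rand(10, 3, 4)
output = patched_rnn(x)
print(output.shape)  # torch.Size([10, 3, 8])

Since only the legacy path is enabled, the patched layer's outputs should match those of the original rnn_layer with the same weights and a zero-initialized hidden state.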