diff --git a/memsave_torch/nn/LayerNorm.py b/memsave_torch/nn/LayerNorm.py index d792deb..6a65f0d 100644 --- a/memsave_torch/nn/LayerNorm.py +++ b/memsave_torch/nn/LayerNorm.py @@ -168,6 +168,8 @@ def from_existing(cls, ln): Returns: obj: The MemSaveRMSLayerNorm object """ + if ln.variance_epsilon is not None: # T5LayerNorm + ln.eps = ln.variance_epsilon obj = cls( ln.weight.shape, ln.eps,