From b879fb5cd30461bc6881fb096d7a753e909847c0 Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Wed, 27 Sep 2023 15:47:12 +0000
Subject: [PATCH] llama: compute RMSNorm in fp32, fuse the key transpose, drop
 unused imports

---
 examples/llama/model.py      | 17 ++++++++---------
 examples/llama/model_test.py |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/examples/llama/model.py b/examples/llama/model.py
index 3ea616774..069257de2 100644
--- a/examples/llama/model.py
+++ b/examples/llama/model.py
@@ -82,15 +82,17 @@ def __init__(
         super().__init__()
         self.eps = eps
         self.dtype = dtype
-        self.weight = ark.parameter([dim], dtype)
+        self.weight = ark.parameter([dim], ark.fp32)
 
     def _norm(self, x):
         # x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
         return ark.rmsnorm(x)
 
     def forward(self, x):
+        x = ark.cast(x, ark.fp32)
         output = self._norm(x)
-        return ark.mul(output, ark.reshape(self.weight, [1, 1, -1]))
+        output = ark.mul(output, ark.reshape(self.weight, [1, 1, -1]))
+        return ark.cast(output, self.dtype)
 
 
 class ColumnParallelLinear(ark.Module):
@@ -369,18 +371,16 @@ def forward(
         )
         if freqs_cis is not None:
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-        # TODO: enable kv cache and mask later
+        # TODO: enable kv cache later
         keys = xk
         values = xv
 
         # (bs, n_local_heads, seqlen, head_dim)
         xq = ark.transpose(xq, [0, 2, 1, 3])
-        keys = ark.transpose(keys, [0, 2, 1, 3])
         values = ark.transpose(values, [0, 2, 1, 3])
         # (bs, n_local_heads, head_dim, seqlen)
-        keys_transpose = ark.transpose(keys, [0, 1, 3, 2])
-        scores = ark.matmul(xq, keys_transpose)
+        keys = ark.transpose(keys, [0, 2, 3, 1])
+        scores = ark.matmul(xq, keys)
         scores = ark.scale(scores, 1.0 / math.sqrt(self.head_dim))
 
         if mask is not None:
@@ -394,8 +394,7 @@ def forward(
         output = ark.reshape(
             output, [bsz, seqlen, self.head_dim * self.n_local_heads]
         )
-        output = self.wo(output)
-        return output
+        return self.wo(output)
 
 
 class TransformerBlock(ark.Module):
diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py
index 457599c20..40fa2c354 100644
--- a/examples/llama/model_test.py
+++ b/examples/llama/model_test.py
@@ -12,7 +12,7 @@
 import llama.model as model_pt
 import model as model_ark
 import numpy as np
-from typing import Any, Dict, List
+from typing import Dict, List
 from model import ModelArgs, ModelArgs7B, ModelArgs13B, ModelArgs70B
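
Reviewer note (not part of the applied patch): the RMSNorm hunk moves the
normalization to fp32 and casts the result back to the module's dtype, with
the scale weight now allocated in fp32. A minimal PyTorch sketch of the same
numerical pattern; rmsnorm_fp32 is an illustrative standalone helper, not
ARK's API:

import torch

def rmsnorm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Upcast so the mean-of-squares and rsqrt run at full precision,
    # mirroring the ark.cast(x, ark.fp32) added in forward().
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    # The weight is kept in fp32 (ark.parameter([dim], ark.fp32));
    # cast the scaled result back to the caller's dtype, e.g. fp16.
    return (normed * weight).to(x.dtype)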
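
The attention hunk replaces two transposes of the keys with one fused
permutation: [0, 2, 1, 3] followed by [0, 1, 3, 2] composes to [0, 2, 3, 1].
A quick NumPy check of that equivalence, with illustrative shapes:

import numpy as np

bs, seqlen, n_heads, head_dim = 2, 4, 3, 8
xk = np.random.rand(bs, seqlen, n_heads, head_dim).astype(np.float32)

# Old path: (bs, seqlen, n_heads, head_dim) -> (bs, n_heads, seqlen, head_dim)
#           -> (bs, n_heads, head_dim, seqlen), via two transposes.
two_step = xk.transpose(0, 2, 1, 3).transpose(0, 1, 3, 2)
# New path: the same layout from a single fused permutation.
fused = xk.transpose(0, 2, 3, 1)
assert np.array_equal(two_step, fused)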