
Commit

updates
chhwang committed Sep 27, 2023
1 parent eaf0b3e commit b879fb5
Showing 2 changed files with 9 additions and 10 deletions.
17 changes: 8 additions & 9 deletions examples/llama/model.py
@@ -82,15 +82,17 @@ def __init__(
         super().__init__()
         self.eps = eps
         self.dtype = dtype
-        self.weight = ark.parameter([dim], dtype)
+        self.weight = ark.parameter([dim], ark.fp32)

     def _norm(self, x):
         # x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
         return ark.rmsnorm(x)

     def forward(self, x):
+        x = ark.cast(x, ark.fp32)
         output = self._norm(x)
-        return ark.mul(output, ark.reshape(self.weight, [1, 1, -1]))
+        output = ark.mul(output, ark.reshape(self.weight, [1, 1, -1]))
+        return ark.cast(output, self.dtype)


 class ColumnParallelLinear(ark.Module):
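
The hunk above keeps the RMSNorm arithmetic in fp32 even when the model itself runs in a lower precision: the weight is now created as an fp32 parameter, the input is upcast before normalization, and the result is cast back to the model dtype at the end of forward(). A minimal NumPy sketch of the intended numerics (rmsnorm_ref and the eps default are illustrative assumptions, not ARK's implementation):

import numpy as np

def rmsnorm_ref(x, weight_fp32, eps=1e-5):
    # Upcast to fp32, mirroring ark.cast(x, ark.fp32).
    x32 = x.astype(np.float32)
    # RMS-normalize over the last axis, mirroring the commented torch expression.
    norm = x32 / np.sqrt((x32 ** 2).mean(-1, keepdims=True) + eps)
    # Scale by the fp32 weight, then cast back to the input dtype.
    out = norm * weight_fp32.reshape(1, 1, -1)
    return out.astype(x.dtype)

x = np.random.randn(1, 4, 8).astype(np.float16)
w = np.ones(8, dtype=np.float32)
assert rmsnorm_ref(x, w).dtype == np.float16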
@@ -369,18 +371,16 @@ def forward(
             )
         if freqs_cis is not None:
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

-        # TODO: enable kv cache and mask later
+        # TODO: enable kv cache later
         keys = xk
         values = xv
         # (bs, n_local_heads, seqlen, head_dim)
         xq = ark.transpose(xq, [0, 2, 1, 3])
-        keys = ark.transpose(keys, [0, 2, 1, 3])
         values = ark.transpose(values, [0, 2, 1, 3])

         # (bs, n_local_heads, head_dim, seqlen)
-        keys_transpose = ark.transpose(keys, [0, 1, 3, 2])
-        scores = ark.matmul(xq, keys_transpose)
+        keys = ark.transpose(keys, [0, 2, 3, 1])
+        scores = ark.matmul(xq, keys)
         scores = ark.scale(scores, 1.0 / math.sqrt(self.head_dim))

         if mask is not None:
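
The two-step key layout in the old code (permute to (bs, n_local_heads, seqlen, head_dim), then swap the last two axes) composes into the single permutation [0, 2, 3, 1] used above, saving one transpose. A quick NumPy check of that equivalence (the shape below is an arbitrary placeholder):

import numpy as np

keys = np.random.randn(2, 5, 3, 4)  # (bs, seqlen, n_local_heads, head_dim)

# Old path: two transposes, ending at (bs, n_local_heads, head_dim, seqlen).
old = keys.transpose(0, 2, 1, 3).transpose(0, 1, 3, 2)
# New path: one fused transpose to the same layout.
new = keys.transpose(0, 2, 3, 1)

assert np.array_equal(old, new)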
@@ -394,8 +394,7 @@ def forward(
         output = ark.reshape(
             output, [bsz, seqlen, self.head_dim * self.n_local_heads]
         )
-        output = self.wo(output)
-        return output
+        return self.wo(output)


 class TransformerBlock(ark.Module):
2 changes: 1 addition & 1 deletion examples/llama/model_test.py
@@ -12,7 +12,7 @@
 import llama.model as model_pt
 import model as model_ark
 import numpy as np
-from typing import Any, Dict, List
+from typing import Dict, List
 from model import ModelArgs, ModelArgs7B, ModelArgs13B, ModelArgs70B

