-
Notifications
You must be signed in to change notification settings - Fork 0
/
flashback.metal
258 lines (185 loc) · 9.91 KB
/
flashback.metal
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#include <metal_stdlib>
// Backward pass of a FlashAttention-style attention kernel, one (batch, head)
// pair per threadgroup. The forward softmax probabilities are NOT stored;
// they are recomputed tile-by-tile from Q, K and the saved per-row softmax
// statistics (ROW_MAX_VALS / ROW_SUMS), then the gradients dQ, dK and dV are
// accumulated and written to out_dQ / out_dK / out_dV.
//
// The indexing below implies these layouts (NOTE(review): confirm vs. host):
//   query/key/value/out/dO : [batch][head][seq_len][embed_dim], row-major
//   ROW_SUMS/ROW_MAX_VALS  : [batch][head][seq_len]
// Dispatch assumption (NOTE(review): verify against host dispatch): one
// threadgroup per (head = tgid.x, batch = tgid.y), with seq_len/query_size
// threads along tid.y — each thread owns one block of `query_size` query rows.
kernel void backprop_attention(const device float* query[[buffer(0)]], const device float* key[[buffer(1)]], const device float* value[[buffer(2)]],
device float* out [[buffer(3)]], device float* dO [[buffer(4)]], device float* out_dQ [[buffer(5)]], device float* out_dK [[buffer(6)]], device float* out_dV [[buffer(7)]], device float* ROW_SUMS [[buffer(8)]], device float* ROW_MAX_VALS [[buffer(9)]], uint2 gid [[thread_position_in_grid]], uint2 tid [[thread_position_in_threadgroup]], uint2 tgid [[threadgroup_position_in_grid]]) {
// Problem shape is hard-coded for this kernel build; each thread processes a
// (query_size x key_size) tile of the attention-score matrix per key block.
const unsigned int query_size = 32;
const unsigned int key_size = 32;
const unsigned int score_size = query_size * key_size;
const unsigned int embed_dim = 8;
const unsigned int seq_len = 1024;
// Number of key blocks swept by the outer k-loop.
const unsigned int num_keys = seq_len / key_size;
const unsigned int num_heads = 8;
// Flattened strides: elements per batch and per head in the Q/K/V/O tensors.
const unsigned int num_values_batch = num_heads * seq_len * embed_dim;
const unsigned int num_values_head = seq_len * embed_dim;
const unsigned int batch_index = tgid.y * num_values_batch;
const unsigned int head_index = tgid.x * num_values_head;
const unsigned int batch_plus_head_index = batch_index + head_index;
// Elements in one K/V block and in one thread's Q block.
const unsigned int num_el_kv = key_size * embed_dim;
const unsigned int num_el_query = query_size * embed_dim;
const unsigned int dV_elements = key_size * embed_dim;
// One thread per query block in the threadgroup.
const unsigned int num_threads = seq_len / query_size;
// Offset into the per-row softmax statistics for this thread's query rows.
// NOTE(review): the tid.y offset is scaled by key_size here; with
// key_size == query_size (both 32) this equals tid.y * query_size, which is
// what the per-query-row reads below require — confirm if sizes ever diverge.
const unsigned int row_val_offset = (tgid.y * num_heads * seq_len) + (tgid.x * seq_len) + (key_size*tid.y);
// Each thread adds num_el sets of elements, where num_el is number of elements in dV_i = embed_dim * key_size, divided by the total number of threads = (seq_len / query_size)
// hence num_el = query_size * key_size * embed_dim / seq_len
const unsigned int num_el = dV_elements / num_threads;
// IMPORTANT
// tid.y contains the query index. Each threadgroup contains B_q query blocks, which compute a particular attention score. Each threadgroup contributes for one head in a single batch
// tgid.x contains the current head dimension
// tgid.y contains the current batch dimension
// Thread-private (register/thread memory) buffers -- deliberately NOT in
// threadgroup SRAM, since each thread's Q block, dO block, O block and dQ
// accumulator are never shared with other threads.
float QUERY_LOCAL[num_el_query];
float OUTPUT_LOCAL[key_size * query_size];
float dO_LOCAL[num_el_query];
float O_LOCAL[num_el_query];
float dQ[num_el_query];
float dK[num_el_kv];
// Softmax temperature: scores are divided by sqrt(embed_dim).
float dim_factor = metal::sqrt((float)embed_dim);
// Copy this thread's query rows, upstream gradient dO and forward output O
// from device memory into thread-private storage; zero the dQ accumulator
// (it is accumulated across all key blocks and written once at the end).
unsigned int elements_to_copy = num_el_query;
const unsigned int copying_index = batch_plus_head_index + tid.y*elements_to_copy;
for(unsigned int i = 0; i < elements_to_copy; i++) {
QUERY_LOCAL[i] = query[copying_index + i];
dO_LOCAL[i] = dO[copying_index + i];
O_LOCAL[i] = out[copying_index + i];
dQ[i] = 0.0;
}
// NOTE(review): only thread-private data was written above, so this barrier
// looks unnecessary; presumably kept for safety — confirm before removing.
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// Cooperative-copy partitioning: each of the num_threads threads stages an
// equal slice of the shared K/V block into threadgroup memory.
unsigned int elements_key_copy = (key_size * key_size * embed_dim) / seq_len;
const unsigned int local_offset_kv = tid.y * elements_key_copy;
const unsigned int total_offset_bhkv = batch_index + head_index + local_offset_kv;
threadgroup float KEY_SRAM[num_el_kv];
threadgroup float VALUE_SRAM[num_el_kv];
// Threadgroup accumulator reused for both the dV and dK reductions below.
threadgroup float dKV_acc[dV_elements];
float dV[dV_elements];
float dP[score_size];
// Sweep every key block; for each block recompute the probability tile and
// accumulate this block's contribution to dV, dK (reduced across threads)
// and to this thread's dQ.
for(unsigned int k = 0; k < num_keys; k++) {
const unsigned int local_offset_bhkv = total_offset_bhkv + k*num_el_kv;
// copy from HBM, each thread copies a little bit of the shared key/value block into SRAM.
for(unsigned int i = 0; i < elements_key_copy; i++) {
KEY_SRAM[local_offset_kv + i] = key[local_offset_bhkv + i];
VALUE_SRAM[local_offset_kv + i] = value[local_offset_bhkv + i];
}
// All threads must finish staging K/V before anyone reads them.
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// Recompute the probability tile P = softmax(Q K^T / sqrt(d)) for this
// (query block, key block) pair, using the saved row max/sum statistics.
// Outer loop: rows of the Q block; inner loop: rows of the (transposed) K block.
for(unsigned int i = 0; i < query_size; i++) {
const unsigned int query_row_index = i*embed_dim;
// inner loop is each row in K-block. Should be column but it's transposed
for(unsigned int j = 0; j < key_size; j++) {
const unsigned int key_row_index = j*embed_dim;
// Causal (lower-triangular) mask: zero out entries where the global
// key position exceeds the global query position.
if((tid.y * query_size + i) < (k * key_size + j)) {
OUTPUT_LOCAL[i*query_size + j] = 0.0;
continue;
}
// Dot product of one query row with one key row.
float total_dot = 0.0;
for(unsigned int el = 0; el < embed_dim; el++) {
total_dot += QUERY_LOCAL[query_row_index + el] * KEY_SRAM[key_row_index + el];
}
// Scale, then apply the stored softmax statistics:
// P_ij = exp(s_ij - rowmax_i) / rowsum_i.
OUTPUT_LOCAL[i*query_size + j] = total_dot / dim_factor;
OUTPUT_LOCAL[i*query_size + j] = metal::exp(OUTPUT_LOCAL[i*query_size + j] - ROW_MAX_VALS[row_val_offset + i]) / ROW_SUMS[row_val_offset + i];
// out[(tid.y * query_size + i) * seq_len + k*key_size + j] = OUTPUT_LOCAL[i*query_size + j];
}
}
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// compute dV_part = P^T dO
// == each column of OUTPUT_LOCAL dotted with each row of dO_LOCAL
// P is (query_size, key_size) and dO is (query_size, embed_dim), so the
// result dV contribution is (key_size, embed_dim).
for(unsigned int o_col = 0; o_col < query_size; o_col++) {
const unsigned int o_col_by_embed_dim = o_col * embed_dim;
for(unsigned int dO_col = 0; dO_col < embed_dim; dO_col++) {
// dot product
float total_dot = 0.0;
for(unsigned int el = 0; el < query_size; el++) {
total_dot += OUTPUT_LOCAL[el*key_size + o_col] * dO_LOCAL[el*embed_dim + dO_col];
}
dV[o_col_by_embed_dim + dO_col] = total_dot;
}
}
// all threads must finish computation
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// Multi-threaded reduction of the per-thread dV tiles into dKV_acc.
// The accumulator is chunked into num_threads blocks of num_el elements;
// on iteration i, thread t adds its block ((t+i) % num_threads), so at any
// instant each thread owns a distinct block and no two threads write the
// same element. A barrier separates the rotation steps.
// iteration zero -> thread zero writes to zeroth block here (each block is num_el elements)
// thread one writes to first etc etc
// iteration one -> thread zero write to first block, first block of its own matrix to first block of acc matrix
// we get a pattern where tid.y is used to index into the appropriate block for each iteration
// for zeroth iteration, we have tid.y*num_el to index into, and then we just add the current iteration counter so ((tid.y+i) % num_threads) * num_el
// equivalent to reshaping matrix as (num_el, num_threads)
for(unsigned int i = 0; i < num_el; i++) dKV_acc[tid.y * num_el + i] = 0.0;
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
for(unsigned int i = 0; i < num_threads; i++) {
unsigned int rotated_block_index = ((tid.y + i) % num_threads) * num_el;
for(unsigned int el_acc = 0; el_acc < num_el; el_acc++) {
dKV_acc[rotated_block_index + el_acc] += dV[rotated_block_index + el_acc];
// NOTE(review): this barrier sits inside the per-element loop; control
// flow is uniform so it is legal, but a single barrier per rotation
// step (outside the el_acc loop) would appear sufficient — confirm.
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
}
}
// Each thread writes its slice of the reduced dV for this key block.
for(unsigned int i = 0; i < num_el; i++) out_dV[batch_plus_head_index + k*dV_elements + tid.y * num_el + i] = dKV_acc[tid.y * num_el + i];
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// dP = dO V^T : (query_size, embed_dim) x (embed_dim, key_size).
// V is stored row-major as (key_size, embed_dim), so rows of V serve as
// columns of V^T.
for(unsigned int dO_row = 0; dO_row < query_size; dO_row++) {
for(unsigned int VT_col = 0; VT_col < key_size; VT_col++) {
float total_dot = 0.0;
for(unsigned int i = 0; i < embed_dim; i++) {
total_dot += dO_LOCAL[dO_row * embed_dim + i] * VALUE_SRAM[VT_col * embed_dim + i];
}
dP[dO_row * key_size + VT_col] = total_dot;
}
}
// Softmax backward: dS_ij = P_ij * (dP_ij - D_i) / sqrt(d), where
// D_i = rowsum(dO_i * O_i). The 1/sqrt(d) score scaling is folded in here
// so it propagates to both the dQ and dK products below. dS overwrites dP.
for(unsigned int o_row = 0; o_row < query_size; o_row++) {
const unsigned int dP_index = o_row*key_size;
const unsigned int out_local_index = o_row * query_size;
float total_acc = 0.0;
for(unsigned int row_el = 0; row_el < embed_dim; row_el++) {
unsigned int idx = o_row * embed_dim + row_el;
total_acc += dO_LOCAL[idx] * O_LOCAL[idx];
}
for(unsigned int i = 0; i < key_size; i++) {
dP[dP_index + i] -= total_acc;
dP[dP_index + i] *= (OUTPUT_LOCAL[out_local_index + i]) * (1/dim_factor);
}
}
// dQ += dS K : dS = (query_size, key_size), K_SRAM = (key_size, embed_dim),
// dQ = (query_size, embed_dim). Accumulated across all key blocks.
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
for(unsigned int dp_row = 0; dp_row < query_size; dp_row++) {
const unsigned int dp_row_abs_index = dp_row * embed_dim;
const unsigned int dp_row_abs_index_att = dp_row * key_size;
for(unsigned int k_col = 0; k_col < embed_dim; k_col++) {
for(unsigned int i = 0; i < key_size; i++) {
dQ[dp_row_abs_index + k_col] += (dP[dp_row_abs_index_att + i] * KEY_SRAM[k_col + i*embed_dim]);
}
}
}
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// dK = dS^T Q : dS^T = (key_size, query_size), Q = (query_size, embed_dim).
for(unsigned int dp_col = 0; dp_col < key_size; dp_col++) {
for(unsigned int q_col = 0; q_col < embed_dim; q_col++) {
float total_dot = 0.0;
for(unsigned int i = 0; i < query_size; i++) {
total_dot += dP[i * key_size + dp_col] * QUERY_LOCAL[i * embed_dim + q_col];
}
dK[dp_col * embed_dim + q_col] = total_dot;
}
}
// Reduce the per-thread dK tiles across the threadgroup using the same
// rotated-block scheme as the dV reduction above, then write this key
// block's slice of dK to device memory.
for(unsigned int i = 0; i < num_el; i++) dKV_acc[tid.y * num_el + i] = 0.0;
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
for(unsigned int i = 0; i < num_threads; i++) {
unsigned int rotated_block_index = ((tid.y + i) % num_threads) * num_el;
for(unsigned int el_acc = 0; el_acc < num_el; el_acc++) {
dKV_acc[rotated_block_index + el_acc] += dK[rotated_block_index + el_acc];
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
}
}
for(unsigned int i = 0; i < num_el; i++) out_dK[batch_plus_head_index + k*dV_elements + tid.y * num_el + i] = dKV_acc[tid.y * num_el + i];
}
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
// dQ was accumulated privately over every key block; write it out once.
for(unsigned int i = 0; i < num_el_query; i++) out_dQ[batch_plus_head_index + tid.y * num_el_query + i] = dQ[i];
}