[Question]: Code-related question: Is the search just for the first batch of the dataset? #91
Describe the issue
I hope this message finds you well!
I noticed the following in the minference_forward function:
if self.is_search:
    # Resume any partially finished search; once every layer has an entry,
    # the search is done and further runs are refused.
    if os.path.exists(self.config_path):
        config_list = json.load(open(self.config_path))
        if self.config.num_hidden_layers == len(config_list):
            assert False, f"Search completed. The config is located in {self.config_path}."
    else:
        config_list = []
    config = {}
    print("Layer", self.layer_idx)
if q_len != 1:  # prefill: handle each head separately
    output = torch.empty_like(query_states)
    for head in range(query_states.size(1)):
        q = query_states[:, head, :, :].unsqueeze(1)
        k = key_states[:, head, :, :].unsqueeze(1)
        v = value_states[:, head, :, :].unsqueeze(1)
        if self.is_search and self.layer_idx >= len(config_list):
            # Search mode: find the best sparse pattern for this head.
            config[head] = search_pattern(q, k, head)
        if self.layer_idx >= self.starting_layer and not self.is_search:
            # Inference mode: apply the previously searched sparse pattern.
            attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
        elif is_flash_attn_2_available():
            attn_output = flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                0.0, softmax_scale=None, causal=q_len != 1,
            ).view(bsz, 1, q_len, self.head_dim)
        else:
            attn_output = gather_qkv(q, k, v, attention_mask)
        output[:, head:head + 1] = attn_output
    if self.is_search:
        # Persist one {head: pattern} dict per layer.
        if len(config):
            config_list.append(config)
        with open(self.config_path, 'w') as json_file:
            json.dump(config_list, json_file)
else:
    # Decoding step (q_len == 1): plain dense attention.
    output = flash_attn_func(
        query_states.transpose(1, 2), key_states.transpose(1, 2),
        value_states.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1,
    ).view(bsz, query_states.size(1), q_len, self.head_dim)
attn_output = output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
If I set is_search=True and provide a config_path, does the search find the best sparse attention pattern only for the first batch of data, rather than for the whole dataset?
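To check my reading of the bookkeeping, here is a self-contained toy version of the flow above (the path, the layer/head counts, and fake_search_pattern are placeholders, not MInference code). One {head: pattern} dict is appended per layer, so a single pass over all layers fills the config:

import json
import os

NUM_LAYERS, NUM_HEADS = 4, 8        # hypothetical model size
CONFIG_PATH = "head_patterns.json"  # hypothetical path

def fake_search_pattern(layer_idx, head_idx):
    # Stand-in for MInference's search_pattern(q, k, head).
    return f"pattern_layer{layer_idx}_head{head_idx}"

# One pass over the layers, i.e. the forward pass of a single example.
for layer_idx in range(NUM_LAYERS):
    config_list = json.load(open(CONFIG_PATH)) if os.path.exists(CONFIG_PATH) else []
    assert len(config_list) < NUM_LAYERS, f"Search completed: {CONFIG_PATH}"
    if layer_idx >= len(config_list):
        config = {h: fake_search_pattern(layer_idx, h) for h in range(NUM_HEADS)}
        config_list.append(config)
        with open(CONFIG_PATH, "w") as f:
            json.dump(config_list, f)

# config_list now holds one entry per layer, so running a second example
# would hit the "Search completed" assert immediately.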
Comments

Hi @unicorneeee, thank you for your attention! Yes, in MInference, the pattern search is conducted offline and uses only a single example. However, we found that even with this setup, it demonstrates excellent generalization.
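For intuition, the per-head search can be sketched as below. This is a rough, self-contained illustration, not MInference's actual search_pattern: the candidate mask shapes and the recall-style scoring are stand-ins, though the real search likewise assigns one pattern (e.g. A-shape or vertical-slash) per head from a single example:

import torch

def toy_search_pattern(q, k, candidate_masks):
    # Score each candidate sparse mask by the fraction of the head's full
    # causal attention mass it retains on this single example.
    # q, k: [seq_len, head_dim]; masks: bool [seq_len, seq_len].
    seq_len, head_dim = q.shape
    scores = (q @ k.T) / head_dim ** 0.5
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    attn = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
    recall = {name: (attn * mask).sum() / attn.sum()
              for name, mask in candidate_masks.items()}
    return max(recall, key=recall.get)  # keep the highest-recall pattern

seq_len, head_dim = 256, 64
rows = torch.arange(seq_len)[:, None]
cols = torch.arange(seq_len)[None, :]
candidates = {
    # initial tokens plus a local window ("A-shape"-like)
    "a_shape": (cols < 32) | ((rows - cols).abs() < 64),
    # a few global columns plus diagonal stripes ("vertical-slash"-like)
    "vertical_slash": (cols % 64 == 0) | ((rows - cols) % 48 < 8),
}
q, k = torch.randn(seq_len, head_dim), torch.randn(seq_len, head_dim)
print(toy_search_pattern(q, k, candidates))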