
[Question]: Code-related question: Is the search just for the first batch of the dataset? #91

Open
unicorneeee opened this issue Dec 9, 2024 · 2 comments
Labels: question (Further information is requested)

Comments

@unicorneeee

Describe the issue

I hope this message finds you well!
I noticed the following in the MInference forward function:

if self.is_search:
    if os.path.exists(self.config_path):
        config_list = json.load(open(self.config_path))
        if self.config.num_hidden_layers == len(config_list):
            assert False, f"Search completed. The config is located in {self.config_path}."
    else:
        config_list = []
    config = {}
    print("Layer", self.layer_idx)
if q_len != 1:
    output = torch.empty_like(query_states)
    for head in range(query_states.size(1)):
        q = query_states[:, head, :, :].unsqueeze(1)
        k = key_states[:, head, :, :].unsqueeze(1)
        v = value_states[:, head, :, :].unsqueeze(1)
        if self.is_search and self.layer_idx >= len(config_list):
            config[head] = search_pattern(q, k, head)
        if self.layer_idx >= self.starting_layer and not self.is_search:
            attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
        elif is_flash_attn_2_available():
            attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
        else:
            attn_output = gather_qkv(q, k, v, attention_mask)
        output[:, head:head + 1] = attn_output
    if self.is_search:
        if len(config):
            config_list.append(config)
        with open(self.config_path, 'w') as json_file:
            json.dump(config_list, json_file)
else:
    output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, query_states.size(1), q_len, self.head_dim)
attn_output = output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value

If I set is_search = True and provide the config path, does the search find the best sparse attention pattern using only the first batch of data, rather than the whole dataset?
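In other words, the guard at the top of that block seems to be what limits the search. A minimal standalone sketch of that guard logic (the config path and layer count below are hypothetical, only to illustrate the behavior) would be:

import json
import os

# Hypothetical values, only for illustrating the guard in the snippet above.
config_path = "minference_head_patterns.json"
num_hidden_layers = 32

# Load whatever per-layer head patterns have been written so far.
config_list = json.load(open(config_path)) if os.path.exists(config_path) else []
if len(config_list) == num_hidden_layers:
    # Mirrors the `assert False, "Search completed. ..."` in the forward pass:
    # once every layer has an entry, no further forward pass runs the search.
    raise SystemExit(f"Search completed. The config is located in {config_path}.")

# Otherwise this forward pass would append one {head: pattern} dict for the
# current layer and rewrite the JSON file.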

@unicorneeee added the question (Further information is requested) label on Dec 9, 2024
@unicorneeee
Author

def minference_forward():
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, _ = hidden_states.size()

        if "q_proj" in self.__dict__["_modules"]:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        else:
            qkv = self.qkv_proj(hidden_states)
            query_pos = self.num_heads * self.head_dim
            key_value_pos = query_pos // self.num_key_value_groups
            query_states, key_states, value_states = torch.split(qkv, [query_pos, key_value_pos, key_value_pos], -1)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        set_rope_type(self)
        cos, sin = get_cos_sin(self, value_states, kv_seq_len, position_ids)
        if ROPE_TYPE == "max_seq_len":
            if cos.device != query_states.device:
                cos = cos.to(query_states.device)
            query_states = apply_rotary_pos_emb(query_states, cos)
            key_states = apply_rotary_pos_emb(key_states, cos)
        else:
            if position_ids is not None and position_ids.device != cos.device:
                position_ids = position_ids.to(cos.device)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        if self.is_search:
            if os.path.exists(self.config_path):
                config_list = json.load(open(self.config_path))
                if self.config.num_hidden_layers == len(config_list):
                    assert False, f"Search completed. The config is located in {self.config_path}."
            else:
                config_list = []
            config = {}
            print("Layer", self.layer_idx)
        if q_len != 1:
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                if self.is_search and self.layer_idx >= len(config_list):
                    config[head] = search_pattern(q, k, head)
                if self.layer_idx >= self.starting_layer and not self.is_search:
                    attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
                elif is_flash_attn_2_available():
                    attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
                else:
                    attn_output = gather_qkv(q, k, v, attention_mask)
                output[:, head:head + 1] = attn_output
            if self.is_search:
                if len(config):
                    config_list.append(config)
                with open(self.config_path, 'w') as json_file:
                    json.dump(config_list, json_file)
        else:
            output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, query_states.size(1), q_len, self.head_dim)
        attn_output = output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    return forward

@iofu728 self-assigned this on Dec 10, 2024
@iofu728
Contributor

iofu728 commented Dec 10, 2024

Hi @unicorneeee, thank you for your attention! Yes, in MInference, pattern search is conducted offline and uses only a single example. However, we found that even with this setup, it demonstrates excellent generalization.
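For intuition, the offline step can be pictured roughly like the sketch below. The model name, file name, and the patching helper are placeholders (not the actual MInference entry points): the idea is simply that one long example is run through the patched model once, and during that single prefill each layer appends its per-head pattern to the JSON config.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative names only; the real search entry point lives in the MInference package.
model_name = "gradientai/Llama-3-8B-Instruct-262k"
config_path = "minference_head_patterns.json"

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

# enable_search(model, config_path) stands in for whatever sets is_search=True and
# config_path on every attention module (hypothetical helper, shown for shape only).
# enable_search(model, config_path=config_path)

# One long-context example is enough: the search runs during this single prefill,
# appending one per-head pattern dict per layer to config_path.
prompt = open("one_long_example.txt").read()
inputs = tok(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    model(**inputs)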
