fix baichuan mapping
grimoire committed Jan 17, 2024
1 parent b319dce commit f1ed657
Showing 2 changed files with 110 additions and 14 deletions.
120 changes: 110 additions & 10 deletions lmdeploy/pytorch/models/baichuan.py
@@ -205,6 +205,92 @@ def _qkv_proj(hidden_states):
 
 class BaichuanModel(nn.Module):
 
+    def _continuous_batching_forward_7b(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """Rewrite implementation of LlamaModel.forward."""
+        output_attentions = (output_attentions if output_attentions is not None
+                             else self.config.output_attentions)
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else
+                                self.config.output_hidden_states)
+
+        if use_cache is None:
+            use_cache = self.config.use_cache
+
+        return_dict = (return_dict if return_dict is not None else
+                       self.config.use_return_dict)
+
+        assert (
+            position_ids is not None
+        ), 'position_ids can not be none when using continuous batching mode.'
+        assert position_ids.dim() == 2
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # Attention mask is not necessary in continuous batching
+        attention_mask = None
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states, )
+
+            past_key_value = (past_key_values[idx]
+                              if past_key_values is not None else None)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (
+                    layer_outputs[2 if output_attentions else 1], )
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1], )
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states, )
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v for v in
+                [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
     def _continuous_batching_forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -279,6 +365,7 @@ def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
@@ -287,13 +374,26 @@ def forward(
         return_dict: Optional[bool] = True,
     ):
         """Rewrite of BaichuanModel.forward."""
-        return self._continuous_batching_forward(
-            input_ids,
-            attention_mask,
-            past_key_values,
-            inputs_embeds,
-            use_cache,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
-        )
+        if position_ids is not None:
+            return self._continuous_batching_forward_7b(
+                input_ids,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                inputs_embeds,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                return_dict,
+            )
+        else:
+            return self._continuous_batching_forward(
+                input_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                return_dict,
+            )
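
Baichuan-7B applies rotary position embeddings, so the new 7B continuous-batching path needs explicit 2-D position_ids (one row per sequence, with positions continuing from each sequence's cached history), while Baichuan-13B uses ALiBi and supplies none; forward therefore dispatches on whether position_ids is present. A minimal sketch of building such a tensor — make_position_ids, seq_lens, and history_lens are illustrative names, not lmdeploy API:

# Sketch only: batch-major position_ids matching the dim() == 2 assertion
# in _continuous_batching_forward_7b. Helper names here are hypothetical.
import torch

def make_position_ids(seq_lens, history_lens):
    """Positions for each new token, offset by the tokens already cached."""
    max_len = max(seq_lens)
    ids = torch.zeros(len(seq_lens), max_len, dtype=torch.long)
    for row, (new, cached) in enumerate(zip(seq_lens, history_lens)):
        ids[row, :new] = torch.arange(cached, cached + new)
    return ids

position_ids = make_position_ids(seq_lens=[3, 5], history_lens=[4, 0])
print(position_ids)
# tensor([[4, 5, 6, 0, 0],
#         [0, 1, 2, 3, 4]])
assert position_ids.dim() == 2  # mirrors the assertion in the 7B path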
4 changes: 0 additions & 4 deletions lmdeploy/pytorch/models/module_map.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME
 
 LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch.models'
 
@@ -52,9 +51,6 @@
 MODULE_MAP.update({
     'modeling_baichuan.Model':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',  # noqa
-    (f'{TRANSFORMERS_DYNAMIC_MODULE_NAME}.Baichuan2-7B-Chat'
-     '.modeling_baichuan.BaichuanModel'):
-    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel',  # noqa
     'modeling_baichuan.BaichuanModel':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel',  # noqa
     'modeling_baichuan.Attention':
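
The removed MODULE_MAP entry had routed the remote-code class transformers_modules.Baichuan2-7B-Chat.modeling_baichuan.BaichuanModel to the generic llama.LlamaModel rewrite, which lacks the 7B dispatch added above; dropping it lets the broader 'modeling_baichuan.BaichuanModel' key take effect. A minimal sketch of the idea, assuming rewrites are resolved by suffix-matching a class's qualified name against MODULE_MAP keys (find_rewrite is a hypothetical helper, not lmdeploy's actual resolver):

# Hypothetical suffix-match lookup illustrating why the generic key is
# now sufficient for the dynamically loaded Baichuan2-7B-Chat class.
MODULE_MAP = {
    'modeling_baichuan.BaichuanModel':
    'lmdeploy.pytorch.models.baichuan.BaichuanModel',
}

def find_rewrite(qualname):
    # Prefer the longest registered key that suffix-matches (assumption).
    hits = [key for key in MODULE_MAP if qualname.endswith(key)]
    return MODULE_MAP[max(hits, key=len)] if hits else None

qualname = ('transformers_modules.Baichuan2-7B-Chat'
            '.modeling_baichuan.BaichuanModel')
print(find_rewrite(qualname))
# lmdeploy.pytorch.models.baichuan.BaichuanModel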
