[Feature] Integrated Training and Inference -- Part 1 #532
base: main
Conversation
* support sequence
* add configs
* add sp example to custom dataset
* WIP
* add dispatch utils
* delete useless codes
* move xtuner/engine/sequence_parallel to xtuner/parallel/sequence
* fix lint
* fix lint
* add init_dist to xtuner and add trust_remote_code=True to AutoConfig
* add internlm2 custom_dataset sp4 config
* Sequence Parallel doc V1
* Sequence Parallel doc V1
* Sequence Parallel doc V1
* fix bugs in llama_varlen_attn_forward
* rename indexes to position_ids
* add attn_implementation to config
* delete useless codes
* fix lint
* refine default_collate_fn
* refine doc
* refine doc
* refine doc
* delete replace_internlm2_rote
* add repeat_kv_bshd
* fix apply_rotary_pos_emb bug
* add enable_sequence_parallel flag
* refine doc
* assert {'input_ids', 'labels'}.issubset(dataset.column_names)
* refine doc
```python
attn_kwargs = cls._flash_attn_kwargs(config)
kwargs.update(attn_kwargs)

if torch.cuda.is_bf16_supported():
```
Written this way, is there no way for the user to change the model type through the config or through input arguments?
xtuner/model/auto.py
Outdated
```python
        return model

    @staticmethod
    def _flash_attn_kwargs(config):
```
If a user adds a new LLM of their own, how should this field be modified? Or rather, how would the user even know it needs modifying?
This operation is mainly to guarantee that the attn_mask has the correct shape (flash_attn, SDPA, and vanilla attention may expect different attn_mask shapes). I think we could later move `_built_in_flash_attn_1` and `_built_in_flash_attn_2` somewhere else, and then write a doc covering what to consider when adding a new model.
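For reference, a minimal sketch of how such a helper could select the implementation; the `_BUILT_IN_FLASH_ATTN_2` list and the availability checks here are illustrative assumptions, not the PR's actual logic:

```python
import torch

# Hypothetical registry of config classes known to support FlashAttention 2.
_BUILT_IN_FLASH_ATTN_2 = {'LlamaConfig', 'InternLM2Config'}


def _flash_attn_kwargs(config):
    """Pick an `attn_implementation` the current environment can honor."""
    kwargs = {}
    try:
        import flash_attn  # noqa: F401
        has_flash2 = torch.cuda.is_available()
    except ImportError:
        has_flash2 = False
    if has_flash2 and type(config).__name__ in _BUILT_IN_FLASH_ATTN_2:
        kwargs['attn_implementation'] = 'flash_attention_2'
    else:
        # SDPA accepts the same attn_mask layout for most architectures.
        kwargs['attn_implementation'] = 'sdpa'
    return kwargs
```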
```python
from pydantic import BaseModel


class SampleParams(BaseModel):
```
Can this object be modified from the config? Also bear in mind that during evaluation this parameter differs across datasets, so it needs to be passed to the model on the fly at evaluation time.
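One way to support that is per-call overrides on top of a config-level default; a sketch, where the `SampleParams` fields are assumptions (the PR only shows the class declaration) and `model_copy` is the pydantic-v2 API:

```python
from typing import Optional

from pydantic import BaseModel


class SampleParams(BaseModel):
    # Illustrative fields; the real schema is defined in the PR.
    max_new_tokens: int = 512
    temperature: float = 0.8
    top_p: float = 0.75


def resolve_params(default: SampleParams,
                   overrides: Optional[dict] = None) -> SampleParams:
    """Merge evaluation-time overrides into the config-level default."""
    return default.model_copy(update=overrides or {})


# e.g. a benchmark that needs near-greedy decoding:
# params = resolve_params(default_params, {'temperature': 0.01})
```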
xtuner/model/auto.py
Outdated
```python
                checkpoint: str,
                config: Optional[str] = None,
                from_hub: bool = False):
    config = Config.fromfile(config)
```
Here, shouldn't there be a check for `config` being None, to match the if-else branch below?
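A sketch of the guard being suggested, assuming the `from_hub` branch can recover a config on its own:

```python
from typing import Optional

from mmengine.config import Config


def from_pretrained(checkpoint: str,
                    config: Optional[str] = None,
                    from_hub: bool = False):
    if config is not None:
        config = Config.fromfile(config)
    elif not from_hub:
        raise ValueError('`config` is required when `from_hub=False`.')
    ...  # build the model from `config` as in the PR
```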
If I want to train Alpaca and Alpaca-zh together, should I convert each of them first and then use ConcatDataset, or convert them together?
xtuner/model/text/finetune.py
Outdated
```python
position_ids.append(torch.arange(chunk_tokens))
position_ids = torch.cat(position_ids, dim=0).unsqueeze(0)

from mmengine import MessageHub
```
Code placement (should this import be here?)
xtuner/tools/convert_dataset.py
Outdated
```python
def main():
    args = parse_args()

    dataset = load_dataset(path=args.path)
```
There are many ways to load a dataset.
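A sketch of dispatching over the common entry points; the format detection below is an assumption about the expected inputs:

```python
import os

from datasets import load_dataset


def smart_load(path: str):
    if os.path.isfile(path):
        # a single local JSON/JSONL file
        return load_dataset('json', data_files=path, split='train')
    if os.path.isdir(path):
        # a directory of local data files
        return load_dataset('json', data_dir=path, split='train')
    # otherwise treat it as a Hugging Face Hub repo id, e.g. 'tatsu-lab/alpaca'
    return load_dataset(path, split='train')
```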
Support legacy users passing in their old config and converting the data to the new format.
xtuner/model/auto.py
Outdated
```python
else:
    raise RuntimeError

model: BaseAlgorithm = BUILDER.build(config.model)
```
This step automatically downloads the un-finetuned base model; we should find a way to avoid that.
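One possible approach: build the architecture from config alone (only config.json is fetched, not the weights) and then load the finetuned state dict. A sketch, with `model_name_or_path` and `checkpoint` as placeholders:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build the architecture from config only -- no base weights are downloaded.
hf_config = AutoConfig.from_pretrained(model_name_or_path,
                                       trust_remote_code=True)
llm = AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True)

# Then restore the finetuned weights.
state_dict = torch.load(checkpoint, map_location='cpu')
llm.load_state_dict(state_dict)
```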
```python
    assert eos_token_ids is not None, \
        'Please set eos_token for Qwen tokenizer!'
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
    eos_token_ids = tokenizer.eos_token_id
```
Is there actually any difference between this `if` branch and the `else` below?
```python
shard = converted.select(range(begin, end)).to_list()
with open(save_path, 'w') as f:
    json.dump(shard, f)
```
Suggested change:
```diff
- json.dump(shard, f)
+ json.dump(shard, f, indent=2)
```
```python
chat_template: Union[Dict, ChatTemplate],
sample_ratio: Union[float, List[float]] = 1.0,
max_length: int = 2048,
pack_to_max_length: bool = True,
```
Add a `shuffle_before_pack` parameter?
The current default is already shuffle-before-pack; is there any scenario where one would want no shuffling before packing?
In pretraining scenarios, data from the same context is often stored consecutively, and some users will want those samples to stay adjacent.
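A sketch of what such an opt-out flag could look like; the function and packing details are illustrative, not the PR's implementation:

```python
import random


def pack_samples(samples, max_length, shuffle_before_pack=True, seed=0):
    """Pack tokenized samples into fixed-length chunks.

    With `shuffle_before_pack=False`, consecutive pretraining documents
    stay adjacent inside the packed chunks.
    """
    if shuffle_before_pack:
        samples = samples[:]
        random.Random(seed).shuffle(samples)
    packed, buf = [], []
    for token_ids in samples:
        buf.extend(token_ids)
        while len(buf) >= max_length:
            packed.append(buf[:max_length])
            buf = buf[max_length:]
    return packed
```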
```python
if isinstance(sample_ratio, (list, tuple)):
    if len(sample_ratio) != len(data_files):
        raise ValueError('The length of `sample_ratio` '
                         f'({len(sample_ratio)}) should be the same '
                         'as the length of `data_files` '
                         f'({len(data_files)})')
```
If `data_files` is None and the data is passed via `data_dir`, this line raises an error. Consider converting `data_dir` into `data_files` before this point?
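A sketch of normalizing `data_dir` into `data_files` before the length check; the glob pattern is an assumption about the expected layout:

```python
from pathlib import Path

if data_files is None and data_dir is not None:
    data_files = sorted(str(p) for p in Path(data_dir).glob('*.json*'))
    if not data_files:
        raise FileNotFoundError(f'No data files found under {data_dir}')
```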
```python
        return dataset

    def filter_non_labels_data(self, dataset: List[dict]) -> List[dict]:
        """Filter the data which all labels are ignore.
```
"""Filter the data which all labels are ignore. | |
"""Filter out data that do not contain valid labels. |
```python
f'Filtered {ori_samples - new_samples} samples '
'(all labels are ignore)',
```
Suggested change:
```diff
- f'Filtered {ori_samples - new_samples} samples '
- '(all labels are ignore)',
+ f'Filtered {ori_samples - new_samples} samples '
+ 'that do not contain valid labels.',
```
```python
if torch.cuda.is_bf16_supported():
    kwargs.update(torch_dtype=torch.bfloat16)
else:
    kwargs.update(torch_dtype=torch.float16)
```
If DeepSpeed is not used and a plain AMP optimizer is used instead, this raises an error:

```
# set to bf16
RuntimeError: "_amp_foreach_non_finite_check_and_unscale_cuda" not implemented for 'BFloat16'
# set to fp16
ValueError: Attempting to unscale FP16 gradients.
```
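For plain (non-DeepSpeed) AMP the usual remedy is to keep the parameters in fp32 and let autocast handle the low-precision compute; a generic sketch, with `model`, `batch`, and `optimizer` as placeholders:

```python
import torch

model = model.float()  # master weights stay fp32, so unscaling is legal
scaler = torch.cuda.amp.GradScaler()  # only needed for fp16 autocast

with torch.autocast('cuda', dtype=torch.float16):
    loss = model(**batch).loss

scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales the fp32 grads internally, then steps
scaler.update()
```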
```python
        runner.logger.info(f'(ChatHook {position}){answer}')

    def before_train(self, runner: Union[Runner, FlexibleRunner]):
        runner.logger.info('before_train in EvaluateChatHook.')
```
Suggested change:
```diff
- runner.logger.info('before_train in EvaluateChatHook.')
+ runner.logger.info('before_train in ChatHook.')
```
Training with the config saved under work_dirs doesn't work; right now I'm stuck at …
```python
        super().__init__()

        self.llm = llm
        self.llm.cuda()
```
If the model is quantized, won't calling `.cuda()` directly cause problems?
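A possible guard, in the context of the snippet above; relying on the `is_quantized` attribute that recent transformers releases set on bitsandbytes models is an assumption:

```python
self.llm = llm
if not getattr(llm, 'is_quantized', False):
    # quantized weights are already placed on device via `device_map`;
    # calling .cuda() on them raises an error in transformers
    self.llm.cuda()
```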
```python
# PART 2 Model & Tokenizer #
#######################################################################
model = dict(
    type=TextFinetune,
```
Pass in `use_varlen_attn`.
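What that might look like in the config; fields other than the flag are placeholders, not the PR's actual schema:

```python
model = dict(
    type=TextFinetune,
    use_varlen_attn=True,  # forward the flag so varlen attention is dispatched
    # ... remaining model fields unchanged
)
```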
The changes from PR 567 need to be synced.
Model loading & chat example:
xtuner/model/auto.py
Training Alpaca