From eb72c6a1792eaf9237cc620f6899b0e35405e377 Mon Sep 17 00:00:00 2001 From: ChenX17 Date: Sat, 13 Jul 2024 10:15:44 +0000 Subject: [PATCH 1/2] [Feature] Add LLaST --- xtuner/model/llast.py | 294 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 xtuner/model/llast.py diff --git a/xtuner/model/llast.py b/xtuner/model/llast.py new file mode 100644 index 000000000..246fd7cde --- /dev/null +++ b/xtuner/model/llast.py @@ -0,0 +1,294 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import torch +import torch.nn as nn +from mmengine.config import Config, ConfigDict +from mmengine.model import BaseModel +from peft import get_peft_model, prepare_model_for_kbit_training +from transformers import PretrainedConfig, PreTrainedModel +from transformers.activations import ACT2FN + +from xtuner.dataset.llast import prepare_inputs_labels_for_llast +from xtuner.registry import BUILDER +from .modules import dispatch_modules +from .utils import (LoadWoInit, find_all_linear_names, + get_peft_model_state_dict, guess_load_checkpoint, + make_inputs_require_grad, traverse_dict) + + +class AudioProjectorConfig(PretrainedConfig): + model_type = 'projector' + _auto_class = 'AutoConfig' + + def __init__( + self, + audio_hidden_size=4096, + llm_hidden_size=4096, + depth=2, + hidden_act='gelu', + bias=True, + **kwargs, + ): + self.audio_hidden_size = audio_hidden_size + self.llm_hidden_size = llm_hidden_size + self.depth = depth + self.hidden_act = hidden_act + self.bias = bias + super().__init__(**kwargs) + + +class AudioEncoder(PreTrainedModel): + _auto_class = 'AutoModel' + config_class = AudioProjectorConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + + def __init__(self, config: AudioProjectorConfig) -> None: + super().__init__(config) + self.gradient_checkpointing = False + print('*' * 30) + print(config.audio_hidden_size, config.llm_hidden_size) + modules = [nn.Linear(config.audio_hidden_size, config.llm_hidden_size)] + for _ in range(1, config.depth): + modules.append(ACT2FN[config.hidden_act]) + modules.append( + nn.Linear( + config.llm_hidden_size, + config.llm_hidden_size, + bias=config.bias)) + self.model = nn.Sequential(*modules) + + def enable_input_require_grads(self): + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.model.register_forward_hook(make_inputs_require_grad) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, AudioProjectorConfig): + module.gradient_checkpointing = value + + def forward(self, x): + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x) + else: + layer_outputs = self.model(x) + return layer_outputs + + +class LLaSTModel(BaseModel): + """Implementation of LLaST. 
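+
+    Wraps a speech encoder (e.g. a Whisper-style model), a lightweight MLP
+    audio projector and an LLM for end-to-end speech-to-text training.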
+ + Acknowledge: LLaVA: Visual Instruction Tuning + (https://llava-vl.github.io/) + """ + + def __init__( + self, + llm, + speech_encoder, + freeze_llm=False, + freeze_speech_encoder=False, + speech_select_layer=-1, + pretrained_pth=None, + projector_depth=2, + llm_lora=None, + speech_encoder_lora=None, + use_activation_checkpointing=True, + ): + super().__init__() + self.freeze_llm = freeze_llm + self.freeze_speech_encoder = freeze_speech_encoder + with LoadWoInit(): + self.llm = self._build_from_cfg_or_module(llm) + self.speech_encoder = self._build_from_cfg_or_module( + speech_encoder) + + self.llm.config.use_cache = False + dispatch_modules(self.llm) + + projector_config = AudioProjectorConfig( + audio_hidden_size=self.speech_encoder.config.hidden_size, + llm_hidden_size=self.llm.config.hidden_size, + depth=projector_depth) + self.projector = AudioEncoder(projector_config).to( + self.speech_encoder.dtype) + + if self.freeze_llm: + self.llm.requires_grad_(False) + if self.freeze_speech_encoder: + self.speech_encoder.requires_grad_(False) + + if use_activation_checkpointing: + # For backward compatibility + if hasattr(self.llm, 'enable_input_require_grads'): + self.llm.enable_input_require_grads() + else: + self.llm.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + if hasattr(self.speech_encoder, 'enable_input_require_grads'): + self.speech_encoder.enable_input_require_grads() + else: + self.speech_encoder.get_input_embeddings( + ).register_forward_hook(make_inputs_require_grad) + self.projector.enable_input_require_grads() + + # enable gradient (activation) checkpointing for memory efficiency + self.gradient_checkpointing_enable() + + self.use_llm_lora = llm_lora is not None + self.use_speech_encoder_lora = speech_encoder_lora is not None + + if self.use_llm_lora: + self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing) + if self.use_speech_encoder_lora: + self._prepare_speech_encoder_for_lora( + speech_encoder_lora, use_activation_checkpointing) + + if pretrained_pth is not None: + pretrained_state_dict = guess_load_checkpoint(pretrained_pth) + + out_str = self.load_state_dict(pretrained_state_dict, strict=False) + assert len(out_str.unexpected_keys) == 0, out_str.unexpected_keys + print(f'Load pretrained weight from {pretrained_pth}') + + self.speech_select_layer = speech_select_layer + + self._is_init = True + + def _parse_lora_config(self, lora_config): + if isinstance(lora_config, dict) or isinstance( + lora_config, Config) or isinstance(lora_config, ConfigDict): + lora_config = BUILDER.build(lora_config) + return lora_config + + def gradient_checkpointing_enable(self): + self.activation_checkpointing_enable() + + def activation_checkpointing_enable(self): + self.llm.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={'use_reentrant': False}) + self.speech_encoder.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={'use_reentrant': False}) + self.projector.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={'use_reentrant': False}) + + def gradient_checkpointing_disable(self): + self.activation_checkpointing_disable() + + def activation_checkpointing_disable(self): + self.llm.gradient_checkpointing_disable() + self.speech_encoder.gradient_checkpointing_disable() + self.projector.gradient_checkpointing_disable() + + def init_weights(self): + pass + + def state_dict(self, *args, **kwargs): + state_dict = super().state_dict(*args, **kwargs) + to_return = OrderedDict() + # Step 1. 
speech_encoder + if self.use_speech_encoder_lora: + to_return.update( + get_peft_model_state_dict( + self.speech_encoder, state_dict=state_dict)) + elif not self.freeze_speech_encoder: + to_return.update({ + k: v + for k, v in state_dict.items() if 'speech_encoder.' in k + }) + # Step 2. LLM + if self.use_llm_lora: + to_return.update( + get_peft_model_state_dict(self.llm, state_dict=state_dict)) + elif not self.freeze_llm: + to_return.update( + {k: v + for k, v in state_dict.items() if 'llm.' in k}) + # Step 3. Projector + to_return.update( + {k: v + for k, v in state_dict.items() if 'projector.' in k}) + return to_return + + def _build_from_cfg_or_module(self, cfg_or_mod): + if isinstance(cfg_or_mod, nn.Module): + return cfg_or_mod + elif isinstance(cfg_or_mod, dict): + traverse_dict(cfg_or_mod) + return BUILDER.build(cfg_or_mod) + else: + raise NotImplementedError + + def _prepare_llm_for_lora(self, + lora_config, + use_activation_checkpointing=True): + lora_config = self._parse_lora_config(lora_config) + self.llm = prepare_model_for_kbit_training( + self.llm, use_activation_checkpointing) + if lora_config.target_modules is None: + modules = find_all_linear_names(self.llm) + lora_config.target_modules = modules + self.llm = get_peft_model(self.llm, lora_config) + + def _prepare_speech_encoder_for_lora(self, + lora_config, + use_activation_checkpointing=True): + lora_config = self._parse_lora_config(lora_config) + if lora_config.target_modules is None: + modules = find_all_linear_names(self.speech_encoder) + lora_config.target_modules = modules + self.speech_encoder = get_peft_model(self.speech_encoder, lora_config) + + def forward(self, data, data_samples=None, mode='loss'): + if 'audio_tokens' in data: + data['audio_tokens'] = data['audio_tokens'].to( + self.speech_encoder.encoder.conv1.weight.dtype) + batch_size = data['audio_tokens'].shape[0] + decoder_input_ids = torch.tensor([ + [1] * batch_size + ]) * self.speech_encoder.config.decoder_start_token_id + + audio_outputs = self.speech_encoder( + data['audio_tokens'], + decoder_input_ids=decoder_input_ids.to( + data['audio_tokens'].device), + output_hidden_states=True).encoder_last_hidden_state + + audio_outputs = audio_outputs[:, :max(data['audio_lens']), :] + audio_tokens = self.projector(audio_outputs) + data['audio_tokens'] = audio_tokens + data = prepare_inputs_labels_for_llast(llm=self.llm, **data) + + if mode == 'loss': + return self.compute_loss(data, data_samples) + elif mode == 'predict': + return self.predict(data, data_samples) + elif mode == 'tensor': + return self._forward(data, data_samples) + else: + raise NotImplementedError + + def _forward(self, data, data_samples=None): + + outputs = self.llm(**data) + + return outputs + + def predict(self, data, data_samples=None): + outputs = self.llm(**data) + logits_dict = [{'logits': logits} for logits in outputs.logits] + return logits_dict + + def compute_loss(self, data, data_samples=None): + outputs = self.llm(**data) + loss_dict = {'loss': outputs.loss} + return loss_dict + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.llm, name) From 0958b5d6dfd0982b5aaca855bcee3f721641bd95 Mon Sep 17 00:00:00 2001 From: ChenX17 Date: Sat, 13 Jul 2024 11:51:27 +0000 Subject: [PATCH 2/2] Update LLaST implementation --- .pre-commit-config.yaml | 2 +- xtuner/dataset/__init__.py | 3 +- xtuner/dataset/collate_fns/__init__.py | 5 +- .../dataset/collate_fns/llast_collate_fn.py | 60 +++ xtuner/dataset/huggingface.py | 
19 +- xtuner/dataset/llast.py | 492 ++++++++++++++++++ xtuner/dataset/utils.py | 20 +- xtuner/engine/__init__.py | 5 +- xtuner/engine/runner/__init__.py | 3 +- xtuner/engine/runner/llast_loops.py | 166 ++++++ xtuner/evaluation/metrics/__init__.py | 3 +- xtuner/evaluation/metrics/sacrebleu.py | 182 +++++++ xtuner/evaluation/metrics/sacrebleu_metric.py | 112 ++++ xtuner/model/__init__.py | 3 +- xtuner/model/llast.py | 1 + xtuner/tools/process_llast_data.py | 33 ++ xtuner/utils/__init__.py | 9 +- xtuner/utils/constants.py | 3 + xtuner/utils/templates.py | 4 + 19 files changed, 1108 insertions(+), 17 deletions(-) create mode 100644 xtuner/dataset/collate_fns/llast_collate_fn.py create mode 100644 xtuner/dataset/llast.py create mode 100644 xtuner/engine/runner/llast_loops.py create mode 100644 xtuner/evaluation/metrics/sacrebleu.py create mode 100644 xtuner/evaluation/metrics/sacrebleu_metric.py create mode 100644 xtuner/tools/process_llast_data.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f6bbfd633..98a23e46e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: rev: 5.0.4 hooks: - id: flake8 - args: ["--exclude=xtuner/model/transformers_models/*"] + args: ["--exclude=xtuner/model/transformers_models/*,xtuner/evaluation/metrics/sacrebleu.py"] - repo: https://github.com/PyCQA/isort rev: 5.11.5 hooks: diff --git a/xtuner/dataset/__init__.py b/xtuner/dataset/__init__.py index 2ad3d7bd9..a8e2590b3 100644 --- a/xtuner/dataset/__init__.py +++ b/xtuner/dataset/__init__.py @@ -8,6 +8,7 @@ load_intern_repo_untokenized_dataset) from .internvl_dataset import InternVL_V1_5_Dataset from .json_dataset import load_json_file +from .llast import LLaSTDataset from .llava import LLaVADataset from .modelscope import process_ms_dataset from .moss_sft import MOSSSFTDataset @@ -25,5 +26,5 @@ 'load_intern_repo_tokenized_dataset', 'load_intern_repo_untokenized_dataset', 'build_packed_dataset', 'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset', - 'load_json_file', 'InternVL_V1_5_Dataset' + 'load_json_file', 'InternVL_V1_5_Dataset', 'LLaSTDataset' ] diff --git a/xtuner/dataset/collate_fns/__init__.py b/xtuner/dataset/collate_fns/__init__.py index 96652b259..86b8e29cf 100644 --- a/xtuner/dataset/collate_fns/__init__.py +++ b/xtuner/dataset/collate_fns/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .default_collate_fn import default_collate_fn +from .llast_collate_fn import llast_audiomask_mel_collate_fn from .mmlu_collate_fn import mmlu_collate_fn -__all__ = ['default_collate_fn', 'mmlu_collate_fn'] +__all__ = [ + 'default_collate_fn', 'mmlu_collate_fn', 'llast_audiomask_mel_collate_fn' +] diff --git a/xtuner/dataset/collate_fns/llast_collate_fn.py b/xtuner/dataset/collate_fns/llast_collate_fn.py new file mode 100644 index 000000000..01d926b3d --- /dev/null +++ b/xtuner/dataset/collate_fns/llast_collate_fn.py @@ -0,0 +1,60 @@ +# Copyright (c) LLaST. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. 
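+# Pads `input_ids` with the pad token index, `labels` with IGNORE_INDEX and
+# `audio_tokens` with LLAST_AUDIO_PADDING_TOKEN_INDEX so that variable-length
+# samples in a batch can be stacked into dense tensors.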
+from typing import Dict, Sequence + +import torch +from torch.nn.utils.rnn import pad_sequence + +from xtuner.utils import (DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX, + LLAST_AUDIO_PADDING_TOKEN_INDEX) + + +def llast_audiomask_mel_collate_fn( + instances: Sequence[Dict], + pad_index: int = DEFAULT_PAD_TOKEN_INDEX, + return_hf_format: bool = False) -> Dict[str, torch.Tensor]: + """Add audio tokens and conduct padding operation.""" + input_ids = [] + labels = [] + feats_lens = [] + has_audio = any(inst.get('audio_tokens') is not None for inst in instances) + + if has_audio: + audio_tokens = [] + for example in instances: + input_ids.append(torch.tensor(example['input_ids'])) + labels.append(torch.tensor(example['labels'])) + if has_audio: + audio_tokens.append(example['audio_tokens']) + feats_lens.append(torch.tensor(example['audio_lens'])) + if len(instances) > 1: + input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=pad_index) + labels = pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX) + # padding audio tokens + padded_audio_tokens = pad_sequence( + audio_tokens, + batch_first=True, + padding_value=LLAST_AUDIO_PADDING_TOKEN_INDEX) + + else: + input_ids = torch.stack(input_ids) + labels = torch.stack(labels) + padded_audio_tokens = torch.stack(audio_tokens) + + data_dict = { + 'input_ids': input_ids, + 'attention_mask': input_ids.ne(pad_index), + 'labels': labels + } + + if has_audio: + audio_lens = torch.stack(feats_lens) + data_dict['audio_tokens'] = padded_audio_tokens + data_dict['audio_lens'] = audio_lens + + if return_hf_format: + return data_dict + else: + return {'data': data_dict, 'data_samples': instances} diff --git a/xtuner/dataset/huggingface.py b/xtuner/dataset/huggingface.py index c44e88688..86d04641e 100644 --- a/xtuner/dataset/huggingface.py +++ b/xtuner/dataset/huggingface.py @@ -65,8 +65,8 @@ def add_template_to_dataset(dataset, template_map_fn, map_num_proc): def tokenize_dataset(dataset, tokenizer, max_length, with_image_token, - input_ids_with_output, remove_unused_columns, - map_num_proc): + with_audio_token, input_ids_with_output, + remove_unused_columns, map_num_proc): assert (tokenizer is not None) and (max_length is not None), \ f'({tokenizer}, {max_length})' if isinstance(tokenizer, dict) or isinstance( @@ -78,6 +78,7 @@ def tokenize_dataset(dataset, tokenizer, max_length, with_image_token, tokenizer=tokenizer, max_length=max_length, with_image_token=with_image_token, + with_audio_token=with_audio_token, input_ids_with_output=input_ids_with_output), remove_columns=list(dataset.column_names) if remove_unused_columns else None, @@ -112,6 +113,7 @@ def process(dataset, use_varlen_attn=False, input_ids_with_output=True, with_image_token=False, + with_audio_token=False, map_num_proc=32): """Post-process the dataset loaded from the Hugging Face Hub, or a local dataset. @@ -153,6 +155,9 @@ def process(dataset, with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to IMAGE_TOKEN_INDEX. Typically set it to True during the training of VLM. + with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to + LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the + training of SLM. map_num_proc: Max number of processes when mapping the dataset. 
""" if use_varlen_attn: @@ -197,7 +202,8 @@ def process(dataset, if do_dataset_tokenization: dataset = tokenize_dataset(dataset, tokenizer, max_length, - with_image_token, input_ids_with_output, + with_image_token, with_audio_token, + input_ids_with_output, remove_unused_columns, map_num_proc) if input_ids_with_output: @@ -213,7 +219,7 @@ def process(dataset, shuffle_before_pack, map_num_proc) # add 'length' - dataset = dataset.map(get_lengths, num_proc=map_num_proc) + dataset = dataset.map(get_lengths, num_proc=1) setattr(dataset, 'length', dataset['length']) return dataset @@ -234,6 +240,7 @@ def process_hf_dataset(dataset, use_varlen_attn=False, input_ids_with_output=True, with_image_token=False, + with_audio_token=False, map_num_proc=32): """Post-process the dataset loaded from the Hugging Face Hub, or a local dataset. @@ -275,6 +282,9 @@ def process_hf_dataset(dataset, with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to IMAGE_TOKEN_INDEX. Typically set it to True during the training of VLM. + with_audio_token: Whether to convert DEFAULT_AUDIO_TOKEN to + LLAST_AUDIO_TOKEN_INDEX. Typically set it to True during the + training of SLM. map_num_proc: Max number of processes when mapping the dataset. """ kwargs = dict( @@ -293,6 +303,7 @@ def process_hf_dataset(dataset, use_varlen_attn=use_varlen_attn, input_ids_with_output=input_ids_with_output, with_image_token=with_image_token, + with_audio_token=with_audio_token, map_num_proc=map_num_proc) if not (dist.is_available() and dist.is_initialized()): return process(**kwargs) diff --git a/xtuner/dataset/llast.py b/xtuner/dataset/llast.py new file mode 100644 index 000000000..de54f8fef --- /dev/null +++ b/xtuner/dataset/llast.py @@ -0,0 +1,492 @@ +# Copyright (c) LLaST. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. +import csv +import logging +import os +from typing import List, Optional + +import torch +import whisper +from datasets import Dataset as HFDataset +from datasets import DatasetDict, load_from_disk +from mmengine import print_log +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import PreTrainedModel + +from xtuner.dataset.huggingface import process_hf_dataset +from xtuner.utils import IGNORE_INDEX, LLAST_AUDIO_TOKEN_INDEX + +# Language Mapping +LANG_DICT = { + 'English': 'en', + 'French': 'fr', + 'Spanish': 'es', + 'Chinese': 'zh-CN', + 'German': 'de', + 'Japanese': 'ja', + 'Catalan': 'ca', + 'Italian': 'it', + 'Russian': 'ru', + 'Portuguese': 'pt', + 'Persian': 'fa', + 'Estonian': 'et', + 'Mongolian': 'mn', + 'Dutch': 'nl', + 'Turkish': 'tr', + 'Arabic': 'ar', + 'Swedish': 'sv-SE', + 'Latvian': 'lv', + 'Slovenian': 'sl', + 'Tamil': 'ta', + 'Indonesian': 'id', + 'Welsh': 'cy' +} + +SIM_LANG_DICT = {v: k for k, v in LANG_DICT.items()} + +device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu' + + +def convert_data( + data_path, + mode='s2t', + src_lang='French', + tgt_lang='English', + check_audio_path=False, + cv_dir='', + audio_folder='', + postfix='', +): + """ + Args: + mode (str): + 's2t': speech-to-text translation + 's2s': speech-to-speech translation + 'asr': speech-to-text recognition + src_lang (str): + tgt_lang (str): + + Return: + output_list (List) + ids_list (List) + """ + + assert src_lang in list( + LANG_DICT.keys()), 'src_languge: {} is not supported currently.' + assert tgt_lang in list( + LANG_DICT.keys()), 'tgt_language: {} is not supported currently.' 
+ + with open(data_path) as f: + reader = csv.DictReader( + f, + delimiter='\t', + quotechar=None, + doublequote=False, + lineterminator='\n', + quoting=csv.QUOTE_NONE, + ) + raw_data = [dict(e) for e in reader] + + output_list = [] + ids_list = [] + for item in raw_data: + tgt_lang_text = item['translation'] + src_lang_text = item['sentence'] + + if mode == 's2t': + conv = { + 'input': + '