From 9f5d6f256202e1c9259a772bab8b04fc57ba8c20 Mon Sep 17 00:00:00 2001
From: yqy2001 <178526723@qq.com>
Date: Sun, 31 Dec 2023 03:51:38 +0000
Subject: [PATCH 1/4] run ochat code locally: model placeholder & fix bf16 bugs after ln

---
 ochat/models/unpadded_mistral.py  |  5 +++++
 ochat/training_deepspeed/train.py |  5 ++++-
 run.sh                            | 11 +++++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)
 create mode 100644 run.sh

diff --git a/ochat/models/unpadded_mistral.py b/ochat/models/unpadded_mistral.py
index 2abb216..0a9e76c 100644
--- a/ochat/models/unpadded_mistral.py
+++ b/ochat/models/unpadded_mistral.py
@@ -218,6 +218,8 @@ def forward(
         residual = nz_hidden_states

         nz_hidden_states = self.input_layernorm(nz_hidden_states)
+        nz_hidden_states = nz_hidden_states.to(torch.bfloat16)
+
         nz_hidden_states = self.self_attn(
             cos_sin=cos_sin,
@@ -232,6 +234,8 @@ def forward(
         residual = nz_hidden_states

         nz_hidden_states = self.post_attention_layernorm(nz_hidden_states)
+        nz_hidden_states = nz_hidden_states.to(torch.bfloat16)
+
         nz_hidden_states = self.mlp(nz_hidden_states)
         nz_hidden_states = residual + nz_hidden_states

@@ -321,6 +325,7 @@ def forward(
         )

         nz_hidden_states = self.norm(nz_hidden_states)
+        nz_hidden_states = nz_hidden_states.to(torch.bfloat16)

         return nz_hidden_states

diff --git a/ochat/training_deepspeed/train.py b/ochat/training_deepspeed/train.py
index 1126087..08b5a73 100644
--- a/ochat/training_deepspeed/train.py
+++ b/ochat/training_deepspeed/train.py
@@ -86,7 +86,10 @@ def create_model(args):
     print(f"Loading model {args.model_type} from {args.model_path}...")

     # Create model + optimizer + lr scheduler
-    model = MODEL_CONFIG_MAP[args.model_type].model_create_for_training(args.model_path)
+    # model = MODEL_CONFIG_MAP[args.model_type].model_create_for_training(args.model_path)
+    import transformers, ochat
+    model=ochat.models.MistralForCausalLM(transformers.models.mistral.configuration_mistral.MistralConfig.from_pretrained(args.model_path))
+
     # Model to assigned cuda device
     model = model.to(args.local_rank)
     # Enable gradient checkpointing

diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..acb91db
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,11 @@
+NUM_GPUS=8
+
+deepspeed --num_gpus=$NUM_GPUS --master_port 12345 --module ochat.training_deepspeed.train \
+    --model_path /share/project/qiying/datasets/llava/Mistral_7B_with_EOT_token \
+    --data_prefix output \
+    --save_path ./output \
+    --batch_max_len 2048 \
+    --epochs 5 \
+    --save_every 1 \
+    --deepspeed \
+    --deepspeed_config ochat/training_deepspeed/deepspeed_config.json
\ No newline at end of file

From f2d4fd42bd47a10b7a1acc72ee491a685aaf791b Mon Sep 17 00:00:00 2001
From: yqy2001 <178526723@qq.com>
Date: Sun, 31 Dec 2023 08:55:30 +0000
Subject: [PATCH 2/4] dataset: convert llava data to openchat format

---
 ochat/config/__init__.py                     |   2 +-
 ochat/config/conversation_template.py        |   8 ++
 ochat/data/generate_dataset.py               |  17 ++-
 ochat/data/process_multimodal_data.py        |  39 +++++++
 ochat/models/__init__.py                     |   2 +-
 ochat/models/unpadded_mistral.py             |  61 +++++++++++
 ochat/training_deepspeed/openchat_dataset.py | 108 +++++++++++++++++++
 ochat/training_deepspeed/train.py            |  32 ++++--
 8 files changed, 254 insertions(+), 15 deletions(-)
 create mode 100644 ochat/data/process_multimodal_data.py

diff --git a/ochat/config/__init__.py b/ochat/config/__init__.py
index be8c935..29fc0b9 100644
--- a/ochat/config/__init__.py
+++ b/ochat/config/__init__.py
@@ -4,7 +4,7 @@
 import transformers

 from ochat.config.model_config import ModelConfig
-from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
+from ochat.config.conversation_template import Message, Conversation, ConversationTemplate, MultimodalConversation

 import ochat.models

diff --git a/ochat/config/conversation_template.py b/ochat/config/conversation_template.py
index 3e0e79b..b4aeebb 100644
--- a/ochat/config/conversation_template.py
+++ b/ochat/config/conversation_template.py
@@ -15,6 +15,14 @@ class Conversation(BaseModel):

     condition: str = ""
     system: str = ""
+
+
+class MultimodalConversation(BaseModel):
+    items: List[Message]
+
+    image: str = ""
+    condition: str = ""
+    system: str = ""


 class ConversationTemplate(BaseModel):

diff --git a/ochat/data/generate_dataset.py b/ochat/data/generate_dataset.py
index f1ccf39..a675073 100644
--- a/ochat/data/generate_dataset.py
+++ b/ochat/data/generate_dataset.py
@@ -33,7 +33,7 @@ def truncate_trailing_zero_weighted(tokens, weights):
     return tokens[:non_zero_index + 1], weights[:non_zero_index + 1]


-def add_single_conv(output, tokens, weights):
+def add_single_conv(output, tokens, weights, image_file=None):
     # truncate trailing zero weighted tokens
     tokens, weights = truncate_trailing_zero_weighted(tokens, weights)
     if not tokens:
@@ -55,6 +55,9 @@ def add_single_conv(output, tokens, weights):
         "nz_shifted_loss_weights": weights[1:] + [0.0]
     }
     results["num_seqs"] = sum(results["nz_shifted_loss_weights"])
+
+    if image_file is not None:
+        results['image_file'] = image_file

     for k, v in results.items():
         output[k].append(v)
@@ -62,7 +65,7 @@ def add_single_conv(output, tokens, weights):
 @ray.remote
 def convert_conversation_batch(model_type: str, model_path: str, batch: list, schema: pyarrow.Schema, per_sequence_loss: bool):
-    from ochat.config import MODEL_CONFIG_MAP, Conversation
+    from ochat.config import MODEL_CONFIG_MAP, Conversation, MultimodalConversation

     # Tokenization
     model_config = MODEL_CONFIG_MAP[model_type]
@@ -71,7 +74,8 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc
     # Decode data
     print ("Decoding JSON ...")
-    batch = [Conversation(**orjson.loads(json_line)) for json_line in batch]
+    ConversationCls = MultimodalConversation if "image" in batch[0] else Conversation
+    batch = [ConversationCls(**orjson.loads(json_line)) for json_line in batch]

     # Tokenize
     print ("Tokenizing ...")
@@ -82,7 +86,7 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc
     max_context = model_config.model_max_context

     outputs = {k: [] for k in schema.names}
-    for tokens, weights in zip(tokens_list, weights_list):
+    for tokens, weights, conv in zip(tokens_list, weights_list, batch):
         assert len(tokens) == len(weights)

         # Truncate to specified tokens
@@ -90,7 +94,7 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc
         weights = weights[:max_context]

         # Add to results
-        add_single_conv(outputs, tokens, weights)
+        add_single_conv(outputs, tokens, weights, conv.image if hasattr(conv, "image") else None)

     print ("Chunk finish")

@@ -112,6 +116,9 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_
         pyarrow.field(f"nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
         pyarrow.field(f"nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
     ]
+
+    if "image" in conversations[0]:  # multimodal data
+        schema.append(pyarrow.field(f"image_file", pyarrow.list_(pyarrow.string())))

     schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)})
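For reference, one line of the converted JSONL that this code consumes (and that the converter in the next file emits) looks roughly like the record below. The field names follow MultimodalConversation and convert_llava_data_to_ochat_format; the image path and message text are illustrative placeholders, not values taken from the dataset.

    {
        "image": "00453/004539375.jpg",
        "items": [
            {"role": "user", "content": "Render a clear and concise summary of the photo.\n<image>", "weight": 0.0},
            {"role": "assistant", "content": "select luxury furniture 3 - inch gel memory foam mattress topper", "weight": 1.0}
        ],
        "system": "hello boy"
    }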
diff --git a/ochat/data/process_multimodal_data.py b/ochat/data/process_multimodal_data.py
new file mode 100644
index 0000000..2f56363
--- /dev/null
+++ b/ochat/data/process_multimodal_data.py
@@ -0,0 +1,39 @@
+import os
+import json
+
+def convert_llava_data_to_ochat_format(file_path, save_path):
+    """
+    each item of llava data: id, image, conversations
+    """
+    with open(file_path, "r") as f:
+        data = json.load(f)
+
+    ochat_data_list = []
+    for sample in data:
+        ochat_sample = {}
+        ochat_sample['image'] = sample['image']
+
+        conversations = []
+        for uttr in sample['conversations']:
+            assert uttr['from'] in ['human', 'gpt']
+            conversations.append({
+                "role": "user" if uttr['from'] == 'human' else "assistant",
+                "content": uttr['value'],
+                "weight": 0.0 if uttr['from'] == 'human' else 1.0,
+            })
+        ochat_sample['items'] = conversations
+        ochat_sample['system'] = "hello boy"
+
+        ochat_data_list.append(ochat_sample)
+
+    with open(os.path.join(save_path, "blip_laion_cc_sbu_558k_ochat.jsonl"), "w") as f:
+        for entry in ochat_data_list:
+            json_string = json.dumps(entry)
+            f.write(json_string + '\n')
+
+file_path = "/share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k.json"
+save_path = "/share/project/qiying/datasets/llava"
+convert_llava_data_to_ochat_format(file_path, save_path)
+
+
+# python -m ochat.data.generate_dataset --model-type openchat_v3.2_mistral --model-path /share/project/qiying/datasets/llava/Mistral_7B_with_EOT_token --in-files /share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k_ochat.jsonl --out-prefix ./output
\ No newline at end of file

diff --git a/ochat/models/__init__.py b/ochat/models/__init__.py
index 7307e83..a66c8a1 100644
--- a/ochat/models/__init__.py
+++ b/ochat/models/__init__.py
@@ -1,2 +1,2 @@
 from ochat.models.unpadded_llama import LlamaForCausalLM
-from ochat.models.unpadded_mistral import MistralForCausalLM
+from ochat.models.unpadded_mistral import MistralForCausalLM, MultimodalMistralForCausalLM

diff --git a/ochat/models/unpadded_mistral.py b/ochat/models/unpadded_mistral.py
index 0a9e76c..e8426fa 100644
--- a/ochat/models/unpadded_mistral.py
+++ b/ochat/models/unpadded_mistral.py
@@ -390,6 +390,67 @@ def forward(
             logits=logits
         )

+class MultimodalMistralForCausalLM(UnpaddedMistralPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = UnpaddedMistralModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def forward(
+        self,
+        # Unpadded inputs
+        nz_input_ids: torch.Tensor,
+        nz_position_ids: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+        # Unpadded labels
+        nz_shifted_label_ids: Optional[torch.Tensor] = None,
+        nz_shifted_loss_weights: Optional[torch.Tensor] = None,
+        image_list = None
+    ) -> CausalLMOutputWithPast:
+        # Model logits
+        hidden_states = self.model(
+            nz_input_ids=nz_input_ids,
+            nz_position_ids=nz_position_ids,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen
+        )
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if nz_shifted_label_ids is not None:
+            assert nz_shifted_loss_weights is not None
+
+            loss = weighted_cross_entropy(logits, nz_shifted_label_ids, nz_shifted_loss_weights), \
+                weighted_token_accuracy(logits.detach(), nz_shifted_label_ids, nz_shifted_loss_weights)
+
+        return CausalLMOutputWithPast(
+            loss=loss,  # type: ignore
+            logits=logits
+        )
+

 class PaddedMistralForCausalLM(MistralForCausalLM):
     """Compat layer for padded inputs"""

diff --git a/ochat/training_deepspeed/openchat_dataset.py b/ochat/training_deepspeed/openchat_dataset.py
index dae3e78..b8e0628 100644
--- a/ochat/training_deepspeed/openchat_dataset.py
+++ b/ochat/training_deepspeed/openchat_dataset.py
@@ -1,3 +1,6 @@
+import json
+from PIL import Image
+
 import torch
 import numpy as np
 from torch.utils.data import IterableDataset, get_worker_info
@@ -108,3 +111,108 @@ def __iter__(self):

     def estimate_num_batches(self):
         return self.sampler.estimate_num_batches()
+
+
+class OpenchatMultimodalDataset(IterableDataset):
+    def __init__(self, dataset_filename, image_root, batch_max_length, rank, num_replicas):
+        super().__init__()
+        # Init constants
+        self.PAD_ID = 0
+        self.PAD_MULTIPLE = 64
+        self.BATCH_KEYS = {
+            "seqlens": torch.int32,
+            "nz_input_ids": torch.long,
+            "nz_position_ids": torch.long,
+            "nz_shifted_label_ids": torch.long,
+
+            "nz_shifted_loss_weights": torch.bfloat16
+        }
+
+        assert batch_max_length % self.PAD_MULTIPLE == 0, f"Batch size {batch_max_length} need to be multiples of {self.PAD_MULTIPLE}"
+
+        # Load data
+        # Convert parquet to numpy for fast random access
+        table = pq.read_table(dataset_filename, memory_map=True)
+        self.dataset = {k: v.to_numpy() for k, v in zip(table.column_names, table.columns)}
+
+        # read metadata
+        self.metadata = table.schema.metadata.get(b"metadata_json", None)
+        if self.metadata is not None:
+            self.metadata = orjson.loads(self.metadata)
+
+        # Free table space
+        del table
+
+        # Create sampler
+        self.sampler = MultipackDistributedSampler(
+            lengths=self.dataset["total_length"],
+            numseqs=self.dataset["num_seqs"],
+
+            batch_max_length=batch_max_length,
+
+            rank=rank,
+            num_replicas=num_replicas,
+            seed=0
+        )
+
+        # Init state
+        self._epoch = 0
+
+        self.image_root = image_root
+        # ''.join(self.dataset['image_file'][0])
+
+    def _load_batch(self, indices):
+        batch = {k: v[indices] for k, v in self.dataset.items()}
+        image_list = [''.join(i) for i in batch['image_file']]
+
+        # Concat batches
+        batch = {k: np.concatenate(batch[k], axis=0) for k in self.BATCH_KEYS.keys()}
+
+        # Pad an unused item to reach multiple of PAD_MULTIPLE, for faster GEMM
+        total_seqlen = batch["nz_input_ids"].size
+        pad_len = _find_multiple(total_seqlen, self.PAD_MULTIPLE) - total_seqlen
+
+        if pad_len > 0:
+            assert pad_len < self.PAD_MULTIPLE
+
+            # total length
+            padding_specs = {
+                "seqlens": (1, pad_len),
+
+                "nz_input_ids": (pad_len, self.PAD_ID),
+                "nz_position_ids": (pad_len, 0),
+                "nz_shifted_label_ids": (pad_len, self.PAD_ID),
+                "nz_shifted_loss_weights": (pad_len, 0),
+            }
+            for k, pad_spec in padding_specs.items():
+                batch[k] = np.concatenate((batch[k], np.full(*pad_spec, dtype=batch[k].dtype)), axis=0)
+
+        # to tensor
+        batch_tensor = {}
+        for k, dtype in self.BATCH_KEYS.items():
+            batch_tensor[k] = torch.from_numpy(batch[k]).to(dtype)
+
+        # cu seqlens
+        batch_tensor["cu_seqlens"] = torch.nn.functional.pad(batch_tensor["seqlens"].cumsum(-1, dtype=torch.int32), (1, 0))
+        # batch info
+        batch_info = {
+            "max_seqlen": torch.max(batch_tensor["seqlens"]).item(),
+            "image_list": image_list
+        }
+
+        # inputs
+        del batch_tensor["seqlens"]
+        return batch_tensor, batch_info
+
+    def __iter__(self):
+        worker_info = get_worker_info()
+        assert worker_info is None or worker_info.num_workers == 1
+
+        for indices, all_numseq, cur_numseq in self.sampler.iter(self._epoch):
+            yield self._load_batch(indices), all_numseq, cur_numseq
+
+        # Increase epoch count
+        self._epoch += 1
+
+    def estimate_num_batches(self):
+        return self.sampler.estimate_num_batches()

diff --git a/ochat/training_deepspeed/train.py b/ochat/training_deepspeed/train.py
index 08b5a73..1fdfc7e 100644
--- a/ochat/training_deepspeed/train.py
+++ b/ochat/training_deepspeed/train.py
@@ -13,7 +13,7 @@
 import numpy as np

 from ochat.config import MODEL_CONFIG_MAP
-from ochat.training_deepspeed.openchat_dataset import OpenchatDataset
+from ochat.training_deepspeed.openchat_dataset import OpenchatDataset, OpenchatMultimodalDataset

 try:
     import deepspeed
@@ -46,6 +46,10 @@ def parse_args():
     parser.add_argument("--beta1", type=float, default=0.9)
     parser.add_argument("--beta2", type=float, default=0.95)
     parser.add_argument("--eps", type=float, default=1e-5)
+
+    # multimodal
+    parser.add_argument("--multimodal", action='store_true', default=False)
+    parser.add_argument("--image_root", type=str, default="/share/project/qiying/datasets/llava/pretrain")

     # DeepSpeed parameters
     parser = deepspeed.add_config_arguments(parser)
@@ -62,13 +66,23 @@ def create_dataset_and_dataloader(args, split_name):
     # Create dataset and dataloader
     print(f"Loading {split_name} data from {filename}...")
-    dataset = OpenchatDataset(
-        dataset_filename=filename,
+    if args.multimodal:
+        dataset = OpenchatMultimodalDataset(
+            dataset_filename=filename,
+            image_root=args.image_root,
+            batch_max_length=args.batch_max_len,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size()
+        )
+    else:
+        dataset = OpenchatDataset(
+            dataset_filename=filename,
+
+            batch_max_length=args.batch_max_len,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size()
+        )
-        batch_max_length=args.batch_max_len,
-        rank=dist.get_rank(),
-        num_replicas=dist.get_world_size()
-    )
     dataloader = DataLoader(dataset,
                             batch_size=None,
@@ -88,7 +102,9 @@ def create_model(args):
     # Create model + optimizer + lr scheduler
     # model = MODEL_CONFIG_MAP[args.model_type].model_create_for_training(args.model_path)
     import transformers, ochat
-    model=ochat.models.MistralForCausalLM(transformers.models.mistral.configuration_mistral.MistralConfig.from_pretrained(args.model_path))
+    model=ochat.models.MultimodalMistralForCausalLM(transformers.models.mistral.configuration_mistral.MistralConfig.from_pretrained(args.model_path))
+
+    # MistralForCausalLM

     # Model to assigned cuda device
     model = model.to(args.local_rank)

From 9a2322abbdfdc26f3ead8e8f43294fd47b64eb0a Mon Sep 17 00:00:00 2001
From: yqy2001 <178526723@qq.com>
Date: Sun, 31 Dec 2023 09:53:35 +0000
Subject: [PATCH 3/4] add image preprocessor

---
 ochat/models/unpadded_mistral.py             |  2 +-
 ochat/training_deepspeed/openchat_dataset.py | 22 ++++++++++++++++----
 ochat/training_deepspeed/train.py            |  2 ++
 3 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/ochat/models/unpadded_mistral.py b/ochat/models/unpadded_mistral.py
index e8426fa..5ffd190 100644
--- a/ochat/models/unpadded_mistral.py
+++ b/ochat/models/unpadded_mistral.py
@@ -421,6 +421,7 @@ def get_decoder(self):
     def forward(
         self,
         # Unpadded inputs
+        image_tensor: torch.Tensor,
         nz_input_ids: torch.Tensor,
         nz_position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
@@ -428,7 +429,6 @@ def forward(
         max_seqlen: int,
         # Unpadded labels
         nz_shifted_label_ids: Optional[torch.Tensor] = None,
         nz_shifted_loss_weights: Optional[torch.Tensor] = None,
-        image_list = None
     ) -> CausalLMOutputWithPast:
         # Model logits
         hidden_states = self.model(

diff --git a/ochat/training_deepspeed/openchat_dataset.py b/ochat/training_deepspeed/openchat_dataset.py
index b8e0628..102291e 100644
--- a/ochat/training_deepspeed/openchat_dataset.py
+++ b/ochat/training_deepspeed/openchat_dataset.py
@@ -1,7 +1,9 @@
 import json
 from PIL import Image
+from pathlib import Path

 import torch
+import torchvision.transforms as transforms
 import numpy as np
 from torch.utils.data import IterableDataset, get_worker_info
@@ -114,7 +116,7 @@ def estimate_num_batches(self):

 class OpenchatMultimodalDataset(IterableDataset):
-    def __init__(self, dataset_filename, image_root, batch_max_length, rank, num_replicas):
+    def __init__(self, dataset_filename, image_root, image_size, batch_max_length, rank, num_replicas):
         super().__init__()
         # Init constants
         self.PAD_ID = 0
@@ -158,12 +160,23 @@ def __init__(self, dataset_filename, image_root, batch_max_length, rank, num_rep
         # Init state
         self._epoch = 0

-        self.image_root = image_root
+        self.image_root = Path(image_root)
+
+        OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+        OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+        self.image_transform = transforms.Compose([
+            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
+        ])
         # ''.join(self.dataset['image_file'][0])

     def _load_batch(self, indices):
         batch = {k: v[indices] for k, v in self.dataset.items()}
-        image_list = [''.join(i) for i in batch['image_file']]
+        # image process
+        image_list = [''.join(image) for image in batch['image_file']]
+        image_list = [self.image_transform(Image.open(self.image_root / image).convert('RGB')) for image in image_list]
+        image_tensor = torch.stack(image_list)  # [B, 3, 224, 224]

         # Concat batches
         batch = {k: np.concatenate(batch[k], axis=0) for k in self.BATCH_KEYS.keys()}
@@ -197,8 +210,9 @@ def _load_batch(self, indices):
         # batch info
         batch_info = {
             "max_seqlen": torch.max(batch_tensor["seqlens"]).item(),
-            "image_list": image_list
         }
+
+        batch_tensor['image_tensor'] = image_tensor

         # inputs
         del batch_tensor["seqlens"]

diff --git a/ochat/training_deepspeed/train.py b/ochat/training_deepspeed/train.py
index 1fdfc7e..ce5af7d 100644
--- a/ochat/training_deepspeed/train.py
+++ b/ochat/training_deepspeed/train.py
@@ -50,6 +50,7 @@ def parse_args():
     # multimodal
     parser.add_argument("--multimodal", action='store_true', default=False)
     parser.add_argument("--image_root", type=str, default="/share/project/qiying/datasets/llava/pretrain")
+    parser.add_argument("--image_size", type=int, default=224)

     # DeepSpeed parameters
     parser = deepspeed.add_config_arguments(parser)
@@ -70,6 +71,7 @@ def create_dataset_and_dataloader(args, split_name):
         dataset = OpenchatMultimodalDataset(
             dataset_filename=filename,
             image_root=args.image_root,
+            image_size=args.image_size,
             batch_max_length=args.batch_max_len,
             rank=dist.get_rank(),
             num_replicas=dist.get_world_size()
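Before patch 4 wires the images into the model, a quick standalone sanity check of the preprocessing introduced above can be useful. The sketch below is not part of the patch; it rebuilds the same torchvision transform and feeds it a synthetic PIL image instead of a real LLaVA file:

    import torch
    import torchvision.transforms as transforms
    from PIL import Image

    OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
    OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

    image_size = 224  # the new --image_size default
    image_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
    ])

    # stand-in for Image.open(self.image_root / image).convert('RGB')
    img = Image.new("RGB", (640, 480))
    image_tensor = torch.stack([image_transform(img), image_transform(img)])
    print(image_tensor.shape)  # torch.Size([2, 3, 224, 224]), the [B, 3, 224, 224] layout _load_batch stacks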
From 50905f9ee9aa076c38afd14d5d09711395621715 Mon Sep 17 00:00:00 2001
From: yqy2001 <178526723@qq.com>
Date: Sun, 31 Dec 2023 13:30:08 +0000
Subject: [PATCH 4/4] add multimodal model training and updatedataset

---
 ochat/data/generate_dataset.py        | 17 +++++++++---
 ochat/data/process_multimodal_data.py |  8 +++---
 ochat/models/unpadded_mistral.py      | 37 +++++++++++++++++++++++++--
 ochat/training_deepspeed/train.py     |  3 +++
 4 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/ochat/data/generate_dataset.py b/ochat/data/generate_dataset.py
index a675073..db03bd4 100644
--- a/ochat/data/generate_dataset.py
+++ b/ochat/data/generate_dataset.py
@@ -66,15 +66,20 @@ def add_single_conv(output, tokens, weights, image_file=None):
 @ray.remote
 def convert_conversation_batch(model_type: str, model_path: str, batch: list, schema: pyarrow.Schema, per_sequence_loss: bool):
     from ochat.config import MODEL_CONFIG_MAP, Conversation, MultimodalConversation
+
+    multimodal_flag = True if "image" in batch[0] else False

     # Tokenization
     model_config = MODEL_CONFIG_MAP[model_type]
     tokenizer = model_config.model_tokenizer_create(model_path)
+    if multimodal_flag:  # TODO the tokenization function does not tokenize special image tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
+        print(f"<image> special token id: {len(tokenizer) - 1}")
     conv_template = model_config.conversation_template(tokenizer=tokenizer)

     # Decode data
     print ("Decoding JSON ...")
-    ConversationCls = MultimodalConversation if "image" in batch[0] else Conversation
+    ConversationCls = MultimodalConversation if multimodal_flag else Conversation
     batch = [ConversationCls(**orjson.loads(json_line)) for json_line in batch]

     # Tokenize
@@ -94,7 +99,7 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc
         weights = weights[:max_context]

         # Add to results
-        add_single_conv(outputs, tokens, weights, conv.image if hasattr(conv, "image") else None)
+        add_single_conv(outputs, tokens, weights, conv.image if multimodal_flag else None)

     print ("Chunk finish")

@@ -138,12 +143,15 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_
     parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), f"{out_prefix}.{split_name}.parquet")


-def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_loss, seed, eval_ratio):
+def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_loss, seed, eval_ratio, tokens_per_image=None):
     # Load conversations
     conversations = []
     for filename in in_files:
         with open(filename, "rt") as f:
             conversations.extend(f.readlines())
+
+    if 'image' in conversations[0]:
+        conversations = [sample.replace("<image>", "<image>" * tokens_per_image) for sample in conversations]

     # Train-test split
     random.seed(seed)
@@ -169,6 +177,9 @@ def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_
     parser.add_argument("--per-sequence-loss", action="store_true")
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--eval-ratio", type=float, default=0.005)
+
+    # image tokens, 577 is the size of openai_clip_336
+    parser.add_argument("--tokens-per-image", type=int, default=577)

     args = parser.parse_args()
     generate_dataset(**vars(args))

diff --git a/ochat/data/process_multimodal_data.py b/ochat/data/process_multimodal_data.py
index 2f56363..ed41547 100644
--- a/ochat/data/process_multimodal_data.py
+++ b/ochat/data/process_multimodal_data.py
@@ -26,13 +26,15 @@ def convert_llava_data_to_ochat_format(file_path, save_path):

         ochat_data_list.append(ochat_sample)

-    with open(os.path.join(save_path, "blip_laion_cc_sbu_558k_ochat.jsonl"), "w") as f:
+    with open(save_path, "w") as f:
         for entry in ochat_data_list:
             json_string = json.dumps(entry)
             f.write(json_string + '\n')
-
+
+# llava pretrain data, downloaded from https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain
"/share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k.json" -save_path = "/share/project/qiying/datasets/llava" +# save to ochat jsonl format +save_path = "/share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k_ochat.jsonl" convert_llava_data_to_ochat_format(file_path, save_path) diff --git a/ochat/models/unpadded_mistral.py b/ochat/models/unpadded_mistral.py index 5ffd190..37b3d70 100644 --- a/ochat/models/unpadded_mistral.py +++ b/ochat/models/unpadded_mistral.py @@ -298,8 +298,16 @@ def forward( nz_position_ids: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, + image_embeds: torch.Tensor = None ) -> torch.Tensor: nz_hidden_states = self.embed_tokens(nz_input_ids) + + if image_embeds is not None: + if len(image_embeds.shape) == 3: + image_embeds = image_embeds.view(-1, image_embeds.shape[-1]) + idx = nz_input_ids == 32002 # TODO: special token id, no hard code + nz_hidden_states[idx] = image_embeds + cos_sin = self.rotary_emb() # decoder layers @@ -394,7 +402,18 @@ class MultimodalMistralForCausalLM(UnpaddedMistralPreTrainedModel): def __init__(self, config): super().__init__(config) self.model = UnpaddedMistralModel(config) - + + # vision encoder + from transformers import CLIPVisionModel, CLIPVisionConfig + self.vision_cfg = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14-336") + self.vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336") + # self.vision_encoder = CLIPVisionModel(self.vision_cfg) + self.vl_bridge = nn.Sequential( + nn.Linear(self.vision_cfg.hidden_size, self.model.config.hidden_size), + nn.GELU(), + nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size) + ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -418,6 +437,14 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def encode_image(self, images): + # test + # return torch.randn((2, 577, 4096), dtype=images.dtype, device=images.device) + + emb = self.vision_encoder(images, output_hidden_states=True).hidden_states[-2] # [b, n_patch, c] + emb = self.vl_bridge(emb) + return emb + def forward( self, # Unpadded inputs @@ -430,12 +457,18 @@ def forward( nz_shifted_label_ids: Optional[torch.Tensor] = None, nz_shifted_loss_weights: Optional[torch.Tensor] = None, ) -> CausalLMOutputWithPast: + """ + image_tensor: [B, 3, 224, 224] + """ # Model logits + image_embeds = self.encode_image(image_tensor) + hidden_states = self.model( nz_input_ids=nz_input_ids, nz_position_ids=nz_position_ids, cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen + max_seqlen=max_seqlen, + image_embeds=image_embeds ) logits = self.lm_head(hidden_states) diff --git a/ochat/training_deepspeed/train.py b/ochat/training_deepspeed/train.py index ce5af7d..4b34b91 100644 --- a/ochat/training_deepspeed/train.py +++ b/ochat/training_deepspeed/train.py @@ -253,6 +253,9 @@ def train(): # To device batch_tensor = {k: (v.to(args.device) if v is not None else None) for k, v in batch_tensor.items()} + + if batch_tensor.get('image_tensor') is not None: + batch_tensor['image_tensor'] = batch_tensor['image_tensor'].to(model_engine.dtype) # Update loss, acc = model_engine(**batch_tensor, **batch_info).loss