Draft - Multimodal OpenChat Training #145

Open · wants to merge 4 commits into master
2 changes: 1 addition & 1 deletion ochat/config/__init__.py
@@ -4,7 +4,7 @@
import transformers

from ochat.config.model_config import ModelConfig
from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
from ochat.config.conversation_template import Message, Conversation, ConversationTemplate, MultimodalConversation
import ochat.models


8 changes: 8 additions & 0 deletions ochat/config/conversation_template.py
@@ -15,6 +15,14 @@ class Conversation(BaseModel):

condition: str = ""
system: str = ""


class MultimodalConversation(BaseModel):
items: List[Message]

image: str = ""
condition: str = ""
system: str = ""


class ConversationTemplate(BaseModel):
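For reference, a minimal sketch of one training sample as the new `MultimodalConversation` would parse it, assuming `Message` exposes the `role`/`content`/`weight` fields used by the converter script further down; all values are illustrative:

```python
# Illustrative only: values are made up, but the keys follow the
# MultimodalConversation / Message usage in this PR (requires the PR applied).
from ochat.config import MultimodalConversation

sample = MultimodalConversation(
    image="images/0001.jpg",  # path or key of the image file
    items=[
        {"role": "user", "content": "<image>\nWhat is shown here?", "weight": 0.0},
        {"role": "assistant", "content": "A short caption of the image.", "weight": 1.0},
    ],
)
print(sample.image, len(sample.items))
```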
30 changes: 24 additions & 6 deletions ochat/data/generate_dataset.py
@@ -33,7 +33,7 @@ def truncate_trailing_zero_weighted(tokens, weights):
return tokens[:non_zero_index + 1], weights[:non_zero_index + 1]


def add_single_conv(output, tokens, weights):
def add_single_conv(output, tokens, weights, image_file=None):
# truncate trailing zero weighted tokens
tokens, weights = truncate_trailing_zero_weighted(tokens, weights)
if not tokens:
@@ -55,23 +55,32 @@ def add_single_conv(output, tokens, weights):
"nz_shifted_loss_weights": weights[1:] + [0.0]
}
results["num_seqs"] = sum(results["nz_shifted_loss_weights"])

if image_file is not None:
results['image_file'] = image_file

for k, v in results.items():
output[k].append(v)


@ray.remote
def convert_conversation_batch(model_type: str, model_path: str, batch: list, schema: pyarrow.Schema, per_sequence_loss: bool):
from ochat.config import MODEL_CONFIG_MAP, Conversation
from ochat.config import MODEL_CONFIG_MAP, Conversation, MultimodalConversation

multimodal_flag = True if "image" in batch[0] else False

# Tokenization
model_config = MODEL_CONFIG_MAP[model_type]
tokenizer = model_config.model_tokenizer_create(model_path)
if multimodal_flag: # TODO the tokenization function does not tokenize special image tokens
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
print(f"<image> special token id: {len(tokenizer) - 1}")
conv_template = model_config.conversation_template(tokenizer=tokenizer)

# Decode data
print ("Decoding JSON ...")
batch = [Conversation(**orjson.loads(json_line)) for json_line in batch]
ConversationCls = MultimodalConversation if multimodal_flag else Conversation
batch = [ConversationCls(**orjson.loads(json_line)) for json_line in batch]

# Tokenize
print ("Tokenizing ...")
@@ -82,15 +91,15 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, schema: pyarrow.Schema, per_sequence_loss: bool):
max_context = model_config.model_max_context

outputs = {k: [] for k in schema.names}
for tokens, weights in zip(tokens_list, weights_list):
for tokens, weights, conv in zip(tokens_list, weights_list, batch):
assert len(tokens) == len(weights)

# Truncate to specified tokens
tokens = tokens[:max_context]
weights = weights[:max_context]

# Add to results
add_single_conv(outputs, tokens, weights)
add_single_conv(outputs, tokens, weights, conv.image if multimodal_flag else None)

print ("Chunk finish")

@@ -112,6 +121,9 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_
pyarrow.field(f"nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
pyarrow.field(f"nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
]

if "image" in conversations[0]: # multimodal data
schema.append(pyarrow.field(f"image_file", pyarrow.list_(pyarrow.string())))

schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)})

@@ -131,12 +143,15 @@
parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), f"{out_prefix}.{split_name}.parquet")


def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_loss, seed, eval_ratio):
def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_loss, seed, eval_ratio, tokens_per_image=None):
# Load conversations
conversations = []
for filename in in_files:
with open(filename, "rt") as f:
conversations.extend(f.readlines())

if 'image' in conversations[0]:
conversations = [sample.replace("<image>", "<image>" * tokens_per_image) for sample in conversations]

# Train-test split
random.seed(seed)
@@ -162,6 +177,9 @@ def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_
parser.add_argument("--per-sequence-loss", action="store_true")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--eval-ratio", type=float, default=0.005)

# image tokens, 577 is the size of openai_clip_336
parser.add_argument("--tokens-per-image", type=int, default=577)
args = parser.parse_args()

generate_dataset(**vars(args))
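A standalone illustration of the `<image>` expansion performed in `generate_dataset` above: each placeholder in the raw JSONL line is repeated `tokens_per_image` times so that, after tokenization, there is one `<image>` position per visual embedding (577 for `openai/clip-vit-large-patch14-336`, i.e. 24×24 patches plus the CLS token). The sample line is a dummy:

```python
# Illustration only; the JSONL line is a dummy value.
tokens_per_image = 577  # default --tokens-per-image, matches CLIP ViT-L/14-336
raw_line = '{"image": "images/0001.jpg", "items": [{"role": "user", "content": "<image>\\nDescribe this.", "weight": 0.0}]}'
expanded = raw_line.replace("<image>", "<image>" * tokens_per_image)
assert expanded.count("<image>") == tokens_per_image
```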
41 changes: 41 additions & 0 deletions ochat/data/process_multimodal_data.py
@@ -0,0 +1,41 @@
import os
import json

def convert_llava_data_to_ochat_format(file_path, save_path):
"""
each item of llava data: id, image, conversations
"""
with open(file_path, "r") as f:
data = json.load(f)

ochat_data_list = []
for sample in data:
ochat_sample = {}
ochat_sample['image'] = sample['image']

conversations = []
for uttr in sample['conversations']:
assert uttr['from'] in ['human', 'gpt']
conversations.append({
"role": "user" if uttr['from'] == 'human' else "assistant",
"content": uttr['value'],
"weight": 0.0 if uttr['from'] == 'human' else 1.0,
})
ochat_sample['items'] = conversations
ochat_sample['system'] = "hello boy"

ochat_data_list.append(ochat_sample)

with open(save_path, "w") as f:
for entry in ochat_data_list:
json_string = json.dumps(entry)
f.write(json_string + '\n')

# llava pretrain data, downloaded from https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain
file_path = "/share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k.json"
# save to ochat jsonl format
save_path = "/share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k_ochat.jsonl"
convert_llava_data_to_ochat_format(file_path, save_path)


# python -m ochat.data.generate_dataset --model-type openchat_v3.2_mistral --model-path /share/project/qiying/datasets/llava/Mistral_7B_with_EOT_token --in-files /share/project/qiying/datasets/llava/blip_laion_cc_sbu_558k_ochat.jsonl --out-prefix ./output
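For reference, one line of the JSONL produced by `convert_llava_data_to_ochat_format` has roughly this shape; the keys, roles, and weights reflect the converter above, while the values are placeholders rather than real LLaVA data:

```python
# Placeholder example of the output format (values made up).
example_output_line = {
    "image": "00000/000000001.jpg",
    "items": [
        {"role": "user", "content": "<image>\nGive a brief caption.", "weight": 0.0},
        {"role": "assistant", "content": "A short caption of the image.", "weight": 1.0},
    ],
    "system": "hello boy",  # placeholder system prompt hard-coded by the script above
}
```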
2 changes: 1 addition & 1 deletion ochat/models/__init__.py
@@ -1,2 +1,2 @@
from ochat.models.unpadded_llama import LlamaForCausalLM
from ochat.models.unpadded_mistral import MistralForCausalLM
from ochat.models.unpadded_mistral import MistralForCausalLM, MultimodalMistralForCausalLM
99 changes: 99 additions & 0 deletions ochat/models/unpadded_mistral.py
@@ -218,6 +218,8 @@ def forward(
residual = nz_hidden_states

nz_hidden_states = self.input_layernorm(nz_hidden_states)
nz_hidden_states = nz_hidden_states.to(torch.bfloat16)

nz_hidden_states = self.self_attn(
cos_sin=cos_sin,

@@ -232,6 +234,8 @@ def forward(
residual = nz_hidden_states

nz_hidden_states = self.post_attention_layernorm(nz_hidden_states)
nz_hidden_states = nz_hidden_states.to(torch.bfloat16)

nz_hidden_states = self.mlp(nz_hidden_states)
nz_hidden_states = residual + nz_hidden_states

@@ -294,8 +298,16 @@ def forward(
nz_position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
image_embeds: torch.Tensor = None
) -> torch.Tensor:
nz_hidden_states = self.embed_tokens(nz_input_ids)

if image_embeds is not None:
if len(image_embeds.shape) == 3:
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
idx = nz_input_ids == 32002 # TODO: <image> special token id, no hard code
nz_hidden_states[idx] = image_embeds

cos_sin = self.rotary_emb()

# decoder layers
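A toy sketch of the embedding scatter a few lines above: every unpadded position holding the hard-coded `<image>` id (32002 in this draft) is overwritten with one row of the flattened image embeddings, so the number of `<image>` tokens in the batch must equal the number of rows produced by the vision encoder. Shapes below mirror the logic, not the real model:

```python
# Toy tensors only; mirrors the scatter into <image> positions above.
import torch

hidden_size, image_token_id = 8, 32002
nz_input_ids = torch.tensor([1, image_token_id, image_token_id, 5, 7])  # two <image> positions
nz_hidden_states = torch.zeros(5, hidden_size)                          # unpadded token embeddings
image_embeds = torch.randn(1, 2, hidden_size)                           # [b, n_patch, c]

image_embeds = image_embeds.view(-1, image_embeds.shape[-1])            # flatten to [b * n_patch, c]
idx = nz_input_ids == image_token_id
nz_hidden_states[idx] = image_embeds                                     # row count must match <image> count
```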
@@ -321,6 +333,7 @@ def forward(
)

nz_hidden_states = self.norm(nz_hidden_states)
nz_hidden_states = nz_hidden_states.to(torch.bfloat16)

return nz_hidden_states

@@ -385,6 +398,92 @@ def forward(
logits=logits
)

class MultimodalMistralForCausalLM(UnpaddedMistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = UnpaddedMistralModel(config)

# vision encoder
from transformers import CLIPVisionModel, CLIPVisionConfig
self.vision_cfg = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14-336")
self.vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
# self.vision_encoder = CLIPVisionModel(self.vision_cfg)
self.vl_bridge = nn.Sequential(
nn.Linear(self.vision_cfg.hidden_size, self.model.config.hidden_size),
nn.GELU(),
nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size)
)

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def set_decoder(self, decoder):
self.model = decoder

def get_decoder(self):
return self.model

def encode_image(self, images):
# test
# return torch.randn((2, 577, 4096), dtype=images.dtype, device=images.device)

emb = self.vision_encoder(images, output_hidden_states=True).hidden_states[-2] # [b, n_patch, c]
emb = self.vl_bridge(emb)
return emb

def forward(
self,
# Unpadded inputs
image_tensor: torch.Tensor,
nz_input_ids: torch.Tensor,
nz_position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
# Unpadded labels
nz_shifted_label_ids: Optional[torch.Tensor] = None,
nz_shifted_loss_weights: Optional[torch.Tensor] = None,
) -> CausalLMOutputWithPast:
"""
image_tensor: [B, 3, 336, 336] (input resolution of openai/clip-vit-large-patch14-336)
"""
# Model logits
image_embeds = self.encode_image(image_tensor)

hidden_states = self.model(
nz_input_ids=nz_input_ids,
nz_position_ids=nz_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
image_embeds=image_embeds
)
logits = self.lm_head(hidden_states)

loss = None
if nz_shifted_label_ids is not None:
assert nz_shifted_loss_weights is not None

loss = weighted_cross_entropy(logits, nz_shifted_label_ids, nz_shifted_loss_weights), \
weighted_token_accuracy(logits.detach(), nz_shifted_label_ids, nz_shifted_loss_weights)

return CausalLMOutputWithPast(
loss=loss, # type: ignore
logits=logits
)


class PaddedMistralForCausalLM(MistralForCausalLM):
"""Compat layer for padded inputs"""
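Finally, a shape-level usage sketch of `MultimodalMistralForCausalLM` under the assumptions of this draft (577 `<image>` positions per image, token id 32002). The checkpoint path is hypothetical and the unpadded attention path is expected to need a bf16-capable GPU, so treat this as an illustration of the interface rather than a test:

```python
# Interface sketch; checkpoint path and tensor values are illustrative.
import torch
from ochat.models import MultimodalMistralForCausalLM

model = MultimodalMistralForCausalLM.from_pretrained(
    "path/to/mistral_checkpoint",            # hypothetical checkpoint directory
    torch_dtype=torch.bfloat16,
).cuda()

tokens_per_image, image_token_id = 577, 32002
nz_input_ids = torch.tensor([image_token_id] * tokens_per_image + [1, 2, 3], device="cuda")
seq_len = nz_input_ids.numel()

out = model(
    image_tensor=torch.randn(1, 3, 336, 336, dtype=torch.bfloat16, device="cuda"),
    nz_input_ids=nz_input_ids,
    nz_position_ids=torch.arange(seq_len, device="cuda"),
    cu_seqlens=torch.tensor([0, seq_len], dtype=torch.int32, device="cuda"),
    max_seqlen=seq_len,
)
print(out.logits.shape)  # (seq_len, vocab_size)
```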