Skip to content

Commit

Permalink
s
Browse files Browse the repository at this point in the history
  • Loading branch information
truskovskiyk committed Jun 11, 2024
1 parent 6e49c54 commit e0ff05d
Show file tree
Hide file tree
Showing 22 changed files with 122 additions and 62 deletions.
5 changes: 1 addition & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
FROM huggingface/transformers-pytorch-gpu:4.35.2

WORKDIR /app

COPY requirements.txt requirements.txt

RUN pip3 install --no-cache-dir -r requirements.txt
RUN MAX_JOBS=4 pip install flash-attn==2.5.9.post1 --no-build-isolation

RUN git clone https://github.com/philschmid/FastChat.git
RUN pip install -e "./FastChat[model_worker,llm_judge]"
RUN pip install matplotlib==3.7.3 tabulate==0.9.0


ENV DAGSTER_HOME /app/dagster_data
RUN mkdir -p $DAGSTER_HOME

ENV PYTHONPATH /app
RUN ln -s /usr/bin/python3 /usr/bin/python

Expand Down
9 changes: 0 additions & 9 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,3 @@ style_check:

style_fix:
ruff format rlhf_training/

docker_build:
docker build -t rlfh-dagster-modal:latest .

docker_ssh:
docker run -it --gpus all --ipc=host --net=host -v $PWD:/app rlfh-dagster-modal:latest /bin/bash

run_dev_dagster:
mkdir $DAGSTER_HOME && git clone https://github.com/philschmid/FastChat.git && dagster dev -f rlhf_training/__init__.py -p 3000 -h 0.0.0.0
50 changes: 48 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,52 @@
# rlhf-in-2024-with-dpo-and-hf
# RLHF with Dagster and Modal!

## Why

Re-usable & scalable RLHF training pipeline with Dagster and Modal.

![endstate](./docs/12.png)

Read full story in this blog post: [RLHF with Dagster and Modal](https://kyrylai.com/2024/06/10/rlhf-with-dagster-and-modal/)

## Access

You would need access to

- [HF](https://huggingface.co/docs/huggingface_hub/quick-start) to save your model.
- [Modal](https://modal.com/docs/guide#getting-started) to use GPU for training.
- [OpenAI](https://openai.com/index/openai-api/) to use GPT-4 for evaluation.

Make sure your .env file looks like this:

```
HF_TOKEN=hf_
MODAL_TOKEN_ID=ak-
MODAL_TOKEN_SECRET=as-
OPENAI_API_KEY=sk
```


## Setup

The recommended way is to use a prebuilt [Docker image](https://github.com/kyryl-opens-ml/rlfh-dagster-modal/pkgs/container/rlfh-dagster-modal).
```
docker pull ghcr.io/kyryl-opens-ml/rlfh-dagster-modal:main
docker run -it --env-file .env -p 3000:3000 ghcr.io/kyryl-opens-ml/rlfh-dagster-modal:main
```


## Deploy Modal functions

Make sure you depliyed training & inference function to [Modal](https://modal.com/).

```
modal deploy ./rlhf_training/serverless_functions.py
```

## Run Dagster end2end

Finally run Dagster.

```
modal deploy ./rlhf_training/smodal_functions.py
dagster dev -f rlhf_training/__init__.py -p 3000 -h 0.0.0.0
```
Binary file added docs/.DS_Store
Binary file not shown.
Binary file added docs/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.ruff]
line-length = 159
line-length = 120

[tool.ruff.lint]
# Add the `line-too-long` rule to the enforced rule set. By default, Ruff omits rules that
Expand Down
26 changes: 22 additions & 4 deletions rlhf_training/assets/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,42 @@ def create_triplets(example, tokenizer, default_system_message=DEFAULT_SYSTEM_ME
"chosen": tokenizer.apply_chat_template(chosen_messages, tokenize=False),
"rejected": tokenizer.apply_chat_template(rejected_messages, tokenize=False),
}
dataset = dataset.map(create_triplets, remove_columns=dataset.features, fn_kwargs={"tokenizer": tokenizer})

dataset = dataset.map(
create_triplets,
remove_columns=dataset.features,
fn_kwargs={"tokenizer": tokenizer},
)
dataset = dataset.train_test_split(test_size=config.eval_size)

dataset["train"].to_json(config.train_data_path, orient="records")
dataset["test"].to_json(config.eval_data_path, orient="records")

return {"train_path": config.train_data_path, "test_path": config.eval_data_path}
return {
"train_path": config.train_data_path,
"test_path": config.eval_data_path,
}


@asset(compute_kind="python")
def train_dataset(context: AssetExecutionContext, rlhf_dataset: Dict[str, str]) -> Dataset:
dataset = load_dataset("json", data_files=rlhf_dataset["train_path"], split="train")
context.add_output_metadata({"len": MetadataValue.int(len(dataset)), "sample": MetadataValue.json(dataset[randint(0, len(dataset))])})
context.add_output_metadata(
{
"len": MetadataValue.int(len(dataset)),
"sample": MetadataValue.json(dataset[randint(0, len(dataset))]),
}
)
return dataset


@asset(compute_kind="python")
def eval_dataset(context: AssetExecutionContext, rlhf_dataset: Dict[str, str]) -> Dataset:
dataset = load_dataset("json", data_files=rlhf_dataset["test_path"], split="train")
context.add_output_metadata({"len": MetadataValue.int(len(dataset)), "sample": MetadataValue.json(dataset[randint(0, len(dataset))])})
context.add_output_metadata(
{
"len": MetadataValue.int(len(dataset)),
"sample": MetadataValue.json(dataset[randint(0, len(dataset))]),
}
)
return dataset
28 changes: 14 additions & 14 deletions rlhf_training/assets/model.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
from dagster import Config, asset, MetadataValue, AssetExecutionContext
from huggingface_hub import hf_hub_download
from datasets import Dataset
from rlhf_training.utils import run_training, run_sample_inference
import modal


class ModelTrainingConfig(Config):
pretrained_model_id: str = "cognitivecomputations/dolphin-2.1-mistral-7b"
peft_model_id: str = "doplhin-dpo"
num_train_epochs: float = 0.001
peft_model_id: str = "doplhin-dpo-1-epoch"
num_train_epochs: float = 0.9


@asset(compute_kind="modal")
def trained_model(
context: AssetExecutionContext, config: ModelTrainingConfig, train_dataset: Dataset, eval_dataset: Dataset
context: AssetExecutionContext,
config: ModelTrainingConfig,
train_dataset: Dataset,
eval_dataset: Dataset,
) -> str:
# run_training_modal_function = modal.Function.lookup("fine-tune-llms-in-2024-with-trl", "run_training_modal")
# hub_model_id = run_training_modal_function.remote(
# train_data_pandas=train_data.to_pandas(),
# pretrained_model_id=config.pretrained_model_id,
# peft_model_id=config.peft_model_id,
# )
hub_model_id = run_training(
run_training_modal_function = modal.Function.lookup("rlfh-dagster-modal", "run_training_modal")
hub_model_id = run_training_modal_function.remote(
pretrained_model_id=config.pretrained_model_id,
rlhf_model_id=config.peft_model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset_pandas=train_dataset.to_pandas(),
eval_dataset_pands=eval_dataset.to_pandas(),
num_train_epochs=config.num_train_epochs,
)
context.add_output_metadata({"model_url": MetadataValue.url(f"https://huggingface.co/{hub_model_id}")})
Expand All @@ -48,7 +46,9 @@ def vibe_check(context: AssetExecutionContext, trained_model: str):
"How can i get rid of llamas in my backyard?",
]

inference_samples = run_sample_inference(prompts=prompts, hub_model_id=trained_model)
run_sample_inference_modal_function = modal.Function.lookup("rlfh-dagster-modal", "run_sample_inference_modal")

inference_samples = run_sample_inference_modal_function.remote(prompts=prompts, hub_model_id=trained_model)
context.add_output_metadata(
{
"samples": MetadataValue.json(inference_samples),
Expand Down
21 changes: 16 additions & 5 deletions rlhf_training/assets/mt_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class MTBenchConfig(Config):
sft_model_id: str = "my-sft"
rlhf_model_id: str = "my-rlhf"


@asset(compute_kind="python")
def mt_bench_questions(context: AssetExecutionContext, config: MTBenchConfig):
_mt_bench_questions = read_jsonl(config.mt_bench_questions_path)
Expand All @@ -32,33 +33,43 @@ def original_responses(context: AssetExecutionContext, config: MTBenchConfig, mt
context.add_output_metadata(
{
"original_responses": MetadataValue.json(_original_responses),
"cli_output": MetadataValue.text(result.stdout)
"cli_output": MetadataValue.text(result.stdout),
}
)
return config.original_responses_path


@asset(compute_kind="python")
def rlhf_responses(context: AssetExecutionContext, config: MTBenchConfig, mt_bench_questions, trained_model: str):
def rlhf_responses(
context: AssetExecutionContext,
config: MTBenchConfig,
mt_bench_questions,
trained_model: str,
):
cmd = f"python FastChat/fastchat/llm_judge/gen_model_answer.py --model-id {config.rlhf_model_id} --model-path {trained_model}"
result = subprocess.run(cmd.split(), check=True, capture_output=True, text=True)
_rlhf_responses = read_jsonl(config.rlhf_responses_path)

context.add_output_metadata(
{
"rlhf_responses": MetadataValue.json(_rlhf_responses),
"cli_output": MetadataValue.text(result.stdout)
"cli_output": MetadataValue.text(result.stdout),
}
)
return config.rlhf_responses_path


@asset(compute_kind="python")
def judgment_results(context: AssetExecutionContext, config: MTBenchConfig, original_responses, rlhf_responses):
def judgment_results(
context: AssetExecutionContext,
config: MTBenchConfig,
original_responses,
rlhf_responses,
):
cmd = f"python FastChat/fastchat/llm_judge/gen_judgment.py --model-list {config.sft_model_id} {config.rlhf_model_id} --judge-model gpt-4-1106-preview --mode pairwise-all"
result_gen_judgment = subprocess.run(cmd.split(), check=True, capture_output=True, text=True)

cmd = f"python FastChat/fastchat/llm_judge/show_result.py --input-file ./data/mt_bench/model_judgment/gpt-4-1106-preview_pair.jsonl --model-list mistral-dolphin-dpo mistral-dolphin-sft --judge-model gpt-4-1106-preview --mode pairwise-all"
cmd = f"python FastChat/fastchat/llm_judge/show_result.py --input-file ./data/mt_bench/model_judgment/gpt-4-1106-preview_pair.jsonl --model-list {config.sft_model_id} {config.rlhf_model_id} --judge-model gpt-4-1106-preview --mode pairwise-all"
result_show_result = subprocess.run(cmd.split(), check=True, capture_output=True, text=True)

image_path = "win_rate_gpt-4-1106-preview.png"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
app = modal.App("rlfh-dagster-modal")
env = {"HF_TOKEN": os.getenv("HF_TOKEN")}
custom_image = Image.from_registry("ghcr.io/kyryl-opens-ml/rlfh-dagster-modal:main").env(env)
mount = modal.Mount.from_local_python_packages("rlhf_training", "rlhf_training")
timeout = 6 * 60 * 60


@app.function(
image=custom_image,
gpu="A100",
mounts=[modal.Mount.from_local_python_packages("rlhf_training", "rlhf_training")],
timeout=15 * 60,
)
def run_training_modal(pretrained_model_id: str, rlhf_model_id: str, train_dataset_pandas: pd.DataFrame, eval_dataset_pands: pd.DataFrame, num_train_epochs: float):
@app.function(image=custom_image, gpu="A100", mounts=[mount], timeout=timeout)
def run_training_modal(
pretrained_model_id: str,
rlhf_model_id: str,
train_dataset_pandas: pd.DataFrame,
eval_dataset_pands: pd.DataFrame,
num_train_epochs: float,
):
from datasets import Dataset
from rlhf_training.utils import run_training

Expand All @@ -24,17 +27,14 @@ def run_training_modal(pretrained_model_id: str, rlhf_model_id: str, train_datas
rlhf_model_id=rlhf_model_id,
train_dataset=Dataset.from_pandas(train_dataset_pandas),
eval_dataset=Dataset.from_pandas(eval_dataset_pands),
num_train_epochs=num_train_epochs
num_train_epochs=num_train_epochs,
)
return model_url

@app.function(
image=custom_image,
gpu="A100",
mounts=[modal.Mount.from_local_python_packages("rlhf_training", "rlhf_training")],
timeout=15 * 60,
)
def run_sample_inference_model(prompts: List[str], hub_model_id: str) -> List[Dict[str, str]]:

@app.function(image=custom_image, gpu="A100", mounts=[mount], timeout=timeout)
def run_sample_inference_modal(prompts: List[str], hub_model_id: str) -> List[Dict[str, str]]:
from rlhf_training.utils import run_sample_inference

inference_samples = run_sample_inference(prompts=prompts, hub_model_id=hub_model_id)
return inference_samples
return inference_samples
11 changes: 4 additions & 7 deletions rlhf_training/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from dagster import Config, asset, MetadataValue, AssetExecutionContext
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from pathlib import Path
from datasets import Dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from random import randint
import torch
from transformers import (
AutoTokenizer,
Expand All @@ -15,7 +10,6 @@
)
import json
from trl import DPOTrainer
import subprocess
import base64
from typing import List, Dict

Expand Down Expand Up @@ -147,7 +141,10 @@ def run_sample_inference(prompts: List[str], hub_model_id: str) -> List[Dict[str
pad_token_id=tokenizer.pad_token_id,
)
inference_samples.append(
{"prompt": prompt, "generated-answer": outputs[0]["generated_text"][len(prompt) :].strip()}
{
"prompt": prompt,
"generated-answer": outputs[0]["generated_text"][len(prompt) :].strip(),
}
)
return inference_samples

Expand Down

0 comments on commit e0ff05d

Please sign in to comment.