Sets RobertaConfig as model architecture and creates default config file
prady-saligram committed Aug 26, 2024
1 parent dcd45b2 commit 027b176
Showing 3 changed files with 887 additions and 1 deletion.
38 changes: 38 additions & 0 deletions config/roberta.yaml
@@ -0,0 +1,38 @@
data:
  train_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://levanter-data/tokenized/openwebtext_roberta/"
  tokenizer: "roberta-base"

model:
  type: roberta
  vocab_size: 50265
  hidden_size: 768
  intermediate_size: 3072
  num_hidden_layers: 12
  num_attention_heads: 12
  max_position_embeddings: 512
  hidden_act: "gelu"
  hidden_dropout_prob: 0.1
  attention_probs_dropout_prob: 0.1
  gradient_checkpointing: true

trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: ["openwebtext", "roberta", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 32
  num_train_steps: 20000

optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
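
The `model:` block uses the same hyperparameter names as Hugging Face's `RobertaConfig`, with roberta-base-like values (note that HF's roberta-base sets `max_position_embeddings` to 514 rather than 512 because of its padding-offset convention). A minimal sketch of that correspondence is below; it is not part of the commit and assumes PyYAML and `transformers` are installed. The config itself would presumably be consumed by the MLM entry point (something like `python -m levanter.main.train_mlm --config_path config/roberta.yaml`, following Levanter's usual `--config_path` convention), though the exact invocation is not shown in this commit.

```python
# Sketch only (not part of this commit): load config/roberta.yaml and confirm that the
# `model:` keys map onto Hugging Face's RobertaConfig constructor arguments.
import yaml
from transformers import RobertaConfig as HFRobertaConfig

with open("config/roberta.yaml") as f:
    cfg = yaml.safe_load(f)

# Drop the Levanter-specific keys before handing the rest to the HF constructor.
model_keys = {k: v for k, v in cfg["model"].items()
              if k not in ("type", "gradient_checkpointing")}
hf_cfg = HFRobertaConfig(**model_keys)
print(hf_cfg.hidden_size, hf_cfg.num_hidden_layers, hf_cfg.max_position_embeddings)
# -> 768 12 512
```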
3 changes: 2 additions & 1 deletion src/levanter/main/train_mlm.py
@@ -20,6 +20,7 @@
 from levanter.models.gpt2 import Gpt2Config
 from levanter.models.llama import LlamaConfig
 from levanter.models.lm_model import LmConfig
+from levanter.models.roberta import RobertaConfig
 from levanter.optim import AdamConfig, OptimizerConfig
 from levanter.trainer import Trainer, TrainerConfig
 from levanter.utils.jax_utils import parameter_count
@@ -30,7 +31,7 @@
 class TrainMlmConfig:
     data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
     trainer: TrainerConfig = field(default_factory=TrainerConfig)
-    model: LmConfig = field(default_factory=LlamaConfig)
+    model: LmConfig = field(default_factory=RobertaConfig)
     optimizer: OptimizerConfig = field(default_factory=AdamConfig)

     # config related to continued pretraining
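
The third changed file (presumably `src/levanter/models/roberta.py`, which would account for most of the 887 added lines) is not expanded above, so the actual `RobertaConfig` definition is not visible here. A hypothetical sketch of what it might look like follows: the field names and defaults mirror the `model:` block in `config/roberta.yaml`, while the base class and the `register_subclass` hook are assumptions based on the pattern Levanter's other model configs (`Gpt2Config`, `LlamaConfig`) appear to follow, and may differ from the real implementation.

```python
# Hypothetical sketch only -- the real src/levanter/models/roberta.py added by this
# commit is not shown in the diff above. The registry decorator and base class are
# assumed, not confirmed by this diff.
from dataclasses import dataclass

from levanter.models.lm_model import LmConfig


@LmConfig.register_subclass("roberta")  # would let `type: roberta` in YAML select this config
@dataclass(frozen=True)
class RobertaConfig(LmConfig):
    vocab_size: int = 50265
    hidden_size: int = 768
    intermediate_size: int = 3072
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    max_position_embeddings: int = 512
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    gradient_checkpointing: bool = True
```

With a registration like this in place, the `default_factory` change above means that a `TrainMlmConfig` constructed without a `model:` section now defaults to a RoBERTa architecture instead of Llama, which is what the commit message describes.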