Skip to content

Commit

Permalink
[Feature] WIP Trainer demo
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Oct 26, 2023
1 parent 48e8256 commit 46e1b0c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pipegoose/trainer/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch import nn


class Callback:
    """Hook interface for observing ``Trainer`` lifecycle events.

    Subclass this and override the ``on_*`` hooks you care about; every
    default implementation is a no-op, so subclasses only implement the
    events they need.

    NOTE: more events should be added over time (see PyTorch Lightning's
    ``Callback`` for the reference hook set).
    """

    def on_fit_start(self, trainer: "pipegoose.Trainer", pl_module: nn.Module) -> None:
        """Called when fit begins. No-op by default."""
        return None

    def on_fit_end(self, trainer: "pipegoose.Trainer", pl_module: nn.Module) -> None:
        """Called when fit ends. No-op by default."""
        return None
14 changes: 14 additions & 0 deletions pipegoose/trainer/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pipegoose.distributed import ParallelContext


class DistributedLogger:
    """Severity-filtered logger for use inside a distributed training job.

    Messages below the current threshold (default ``"info"``) are dropped.

    NOTE(review): ``parallel_context`` is stored but not yet consulted;
    presumably only one rank (e.g. global rank 0) should actually emit —
    confirm against the ``ParallelContext`` API before relying on this in
    multi-process runs.
    """

    # Ordered least -> most severe; list index doubles as the priority.
    # (The original placeholder ``["warning", ...]`` contained a literal
    # Ellipsis, which made severity comparison impossible.)
    LEVELS = ["debug", "info", "warning", "error"]

    def __init__(self, parallel_context: "ParallelContext"):
        # Keep a handle for future rank-aware filtering.
        self.parallel_context = parallel_context
        # Minimum severity that log() will emit.
        self.level = "info"

    def set_level(self, level: str = "info"):
        """Set the minimum severity that :meth:`log` will emit.

        Raises:
            ValueError: if ``level`` is not one of :attr:`LEVELS`.
        """
        if level not in self.LEVELS:
            raise ValueError(f"unknown level: {level!r}, expected one of {self.LEVELS}")
        self.level = level

    def log(self, message: str = "", level: str = "info"):
        """Emit ``message`` at ``level`` if it meets the current threshold.

        Returns:
            The formatted line that was emitted, or ``None`` if the message
            was filtered out (useful for testing and for callers that want
            to forward the line elsewhere).

        Raises:
            ValueError: if ``level`` is not one of :attr:`LEVELS`.
        """
        if level not in self.LEVELS:
            raise ValueError(f"unknown level: {level!r}, expected one of {self.LEVELS}")
        if self.LEVELS.index(level) < self.LEVELS.index(self.level):
            return None
        line = f"[{level}] {message}"
        print(line)
        return line
19 changes: 19 additions & 0 deletions pipegoose/trainer/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from enum import Enum


class TrainerStatus(Enum):
    """Lifecycle phase of the trainer as a whole."""

    INITIALIZING = "initializing"
    RUNNING = "running"
    FINISHED = "finished"


class TrainerStage(Enum):
    """Which loop the trainer is currently executing."""

    TRAINING = "train"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


class TrainerState:
    """Mutable snapshot of the trainer's current ``status`` and ``stage``.

    This was originally declared as an ``Enum`` subclass with only annotated
    attributes — an *empty* enum — so ``TrainerState()`` (as called by
    ``Trainer.__init__``) raised at runtime. A plain data holder with
    enum-valued fields is what the annotations intended.
    """

    def __init__(
        self,
        status: TrainerStatus = TrainerStatus.INITIALIZING,
        stage: TrainerStage = TrainerStage.TRAINING,
    ) -> None:
        # Trainer starts out initializing; the stage defaults to training
        # and is expected to be updated as fit/validate/test/predict run.
        self.status = status
        self.stage = stage

    def __repr__(self) -> str:
        return f"TrainerState(status={self.status!r}, stage={self.stage!r})"
35 changes: 35 additions & 0 deletions pipegoose/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import List

from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.trainer.callback import Callback
from pipegoose.trainer.logger import DistributedLogger
from pipegoose.trainer.state import TrainerState


class Trainer:
    """Entry point that orchestrates (distributed) training of a module.

    NOTE: based on the ``data_parallel_size``, ``tensor_parallel_size``, and
    ``pipeline_parallel_size`` in the ``parallel_context``, the corresponding
    model parallelization should be applied (not yet implemented — WIP).
    """

    def __init__(
        self,
        module: nn.Module,
        train_loader: DataLoader,
        eval_loader: DataLoader,
        optim: Optimizer,
        num_epochs: int,
        callbacks: List[Callback] = None,
        loggers: List[DistributedLogger] = None,
        parallel_context: ParallelContext = None,
    ):
        """Store the training configuration.

        Args:
            module: the model to train.
            train_loader: batches for the training loop.
            eval_loader: batches for validation.
            optim: optimizer stepping ``module``'s parameters.
            num_epochs: number of passes over ``train_loader``.
            callbacks: lifecycle hooks; defaults to no callbacks.
            loggers: distributed loggers; defaults to no loggers.
            parallel_context: describes the parallelism topology.
        """
        # The original __init__ discarded every argument; keep them so
        # fit()/train() can actually use them.
        self.module = module
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.optim = optim
        self.num_epochs = num_epochs
        # None sentinels replace the shared mutable defaults ([]), which
        # would have been aliased across every Trainer instance.
        self.callbacks = [] if callbacks is None else callbacks
        self.loggers = [] if loggers is None else loggers
        self.parallel_context = parallel_context
        self.state = TrainerState()

    def fit(self):
        """Run both the training and the validation loops (WIP)."""
        pass

    def train(self):
        """Run only the training loop (WIP)."""
        pass

0 comments on commit 46e1b0c

Please sign in to comment.