-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
81 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from torch import nn | ||
|
||
|
||
class Callback: | ||
# NOTE: add more events | ||
# NOTE: READING | ||
# + Pytorch lightning's Callback | ||
|
||
def on_fit_start(self, trainer: "pipegoose.Trainer", pl_module: nn.Module) -> None: | ||
"""Called when fit begins.""" | ||
|
||
def on_fit_end(self, trainer: "pipegoose.Trainer", pl_module: nn.Module) -> None: | ||
"""Called when fit ends.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from pipegoose.distributed import ParallelContext | ||
|
||
|
||
class DistributedLogger: | ||
LEVELS = ["warning", ...] | ||
|
||
def __init__(self, parallel_context: ParallelContext): | ||
pass | ||
|
||
def set_level(self): | ||
pass | ||
|
||
def log(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from enum import Enum | ||
|
||
|
||
class TrainerStatus(Enum): | ||
INITIALIZING = "initializing" | ||
RUNNING = "running" | ||
FINISHED = "finished" | ||
|
||
|
||
class TrainerStage(Enum): | ||
TRAINING = "train" | ||
VALIDATING = "validate" | ||
TESTING = "test" | ||
PREDICTING = "predict" | ||
|
||
|
||
class TrainerState(Enum): | ||
status: TrainerStatus | ||
stage: TrainerStage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import List | ||
|
||
from torch import nn | ||
from torch.optim import Optimizer | ||
from torch.utils.data import DataLoader | ||
|
||
from pipegoose.distributed.parallel_context import ParallelContext | ||
from pipegoose.trainer.callback import Callback | ||
from pipegoose.trainer.logger import DistributedLogger | ||
from pipegoose.trainer.state import TrainerState | ||
|
||
|
||
class Trainer: | ||
def __init__( | ||
self, | ||
module: nn.Module, | ||
train_loader: DataLoader, | ||
eval_loader: DataLoader, | ||
optim: Optimizer, | ||
num_epochs: int, | ||
callbacks: List[Callback] = [], | ||
loggers: List[DistributedLogger] = [], | ||
parallel_context: ParallelContext = None, | ||
): | ||
# NOTE: based on the data_parallel_size, tensor_parallel_size, and pipeline_parallel_size | ||
# in the parallel_context, we do the correspond parallel model. | ||
self.state = TrainerState() | ||
|
||
def fit(self): | ||
# NOTE: both train and validation | ||
pass | ||
|
||
def train(self): | ||
# NOTE: only train | ||
pass |