setup_model.py and setup_optimizer.py moved to separate files and added type hints #264
base: main
Conversation
Signed-off-by: malinjawi <[email protected]>
Good idea moving out the setup logic. We probably don't need a separate file for each function, so I recommend just moving all of the setup functions to a single file. We already have `setup_accelerator.py`, so maybe we can move these all there & simply rename the file as `setup_objects.py` or something similar? (Trying to avoid using `setup.py`.)
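To make that suggestion concrete, here is a rough sketch of what the consolidated module could look like. The module name follows the comment above; the `setup_accelerator` signature, the return types, and the placeholder bodies are illustrative assumptions, not code from this PR:

```python
# setup_objects.py: hypothetical consolidated module following the
# suggestion above; the bodies are placeholders, not this PR's code.
import argparse

import torch
import torch.utils.data
import transformers


def setup_accelerator(args: argparse.Namespace) -> None:
    """The existing accelerator setup would move here unchanged
    (signature assumed for illustration)."""


def setup_optimizer(args: argparse.Namespace, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Optimizer construction, moved in from setup_optimizer.py."""
    raise NotImplementedError


def setup_model(
    args: argparse.Namespace,
    tokenizer: transformers.PreTrainedTokenizer,
    train_loader: torch.utils.data.DataLoader,
    grad_accum: int,
) -> torch.nn.Module:
    """Model construction, moved in from setup_model.py
    (return type assumed)."""
    raise NotImplementedError
```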
```python
from instructlab.training.config import DistributedBackend
...
def setup_optimizer(args: Any, model: torch.nn.Module) -> torch.optim.Optimizer:
```
`args` here is actually not a `typing.Any` but rather an `argparse.Namespace` object:
```diff
-def setup_optimizer(args: Any, model: torch.nn.Module) -> torch.optim.Optimizer:
+def setup_optimizer(args: argparse.Namespace, model: torch.nn.Module) -> torch.optim.Optimizer:
```
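Worth noting alongside this suggestion (an observation, not part of the review): the `argparse.Namespace` annotation requires importing `argparse` at the top of the file. A minimal sketch of the annotated function, where `AdamW` and `args.learning_rate` are placeholder assumptions:

```python
import argparse

import torch


def setup_optimizer(args: argparse.Namespace, model: torch.nn.Module) -> torch.optim.Optimizer:
    # Placeholder body: AdamW and args.learning_rate are assumptions made
    # only to keep this sketch runnable; the PR's actual logic may differ.
    return torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
```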
```python
def setup_model(
    args: Any, tokenizer: Any, train_loader: Any, grad_accum: int
```
Here are the proper types for each arg:

- `args: argparse.Namespace`
- `tokenizer: transformers.PreTrainedTokenizer`
- `train_loader: torch.utils.data.DataLoader`
```diff
-    args: Any, tokenizer: Any, train_loader: Any, grad_accum: int
+    args: argparse.Namespace, tokenizer: transformers.PreTrainedTokenizer, train_loader: torch.utils.data.DataLoader, grad_accum: int
```
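As with the optimizer change, these annotations imply new imports (a practical note beyond the review itself), e.g.:

```python
import argparse

import torch.utils.data  # makes torch.utils.data.DataLoader resolvable
import transformers
```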
Thanks for this PR, Mohammad! I left a few comments, but this looks good so far.
Hi @malinjawi, could you also rebase on the latest main branch when applying the review feedback? Thanks!
Description:
This PR addresses issue #225 by refactoring the model and optimizer setup functions into `setup_model.py` and `setup_optimizer.py`. Key changes include:
- Moved the setup functions to separate files for better organization.
- Added type hints for clarity and improved type checking.
These changes improve code maintainability and readability. Please review and test!