
Questions on the Gradients of LLM #41

Open
Schwartz-Zha opened this issue Nov 24, 2023 · 0 comments

Comments

@Schwartz-Zha

As I understand it, one of the core contributions claimed in the paper is that training does not require the derivatives of the LLM, which saves a lot of resources.

But how is this enforced in the code?

In LMAdaptorModel,

for param in self.generator.model.parameters():
    param.requires_grad = False
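
For what it's worth, here is a toy check of what that flag does on its own (a stand-in nn.Linear, not code from the repo):

import torch.nn as nn

# toy stand-in for self.generator.model, not the actual repo code
llm = nn.Linear(8, 8)
for param in llm.parameters():
    param.requires_grad = False

# every parameter of the frozen module is now marked as non-trainable ...
print(any(p.requires_grad for p in llm.parameters()))  # False
# ... so an optimizer built over trainable parameters only will skip the LLM
trainable = [p for p in llm.parameters() if p.requires_grad]
print(len(trainable))  # 0

As far as I can tell, this keeps the LLM's parameters out of the optimizer update, but it does not by itself decide whether the backward pass has to go through the LLM (more on that below).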

In PromptedClassificationReward, there is a no_grad decorator:

@torch.no_grad()
def _get_logits(
    self,
    texts: List[str]
) -> torch.Tensor:

But my experiments show that neither mechanism can really forbid the computation of gradients.

Denote the network blocks restricted by no_grad or requires_grad = False as a function $g$, and let $f$ be the network blocks attached before $g$, so the whole network looks like $g(f(x))$.

However, $f$ does require gradients, because $f$ needs to be updated. My experiments show that the gradients through $g$ are still computed in this case: for a loss $L$ on the output, the chain rule gives $\frac{\partial L}{\partial \theta_f} = \frac{\partial L}{\partial g} \cdot \frac{\partial g}{\partial f} \cdot \frac{\partial f}{\partial \theta_f}$, so there is no other way to compute the gradients of $f$. Thus no_grad / requires_grad = False has no effect here: the gradients are still computed.
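
Here is a minimal sketch of the kind of experiment I mean, with toy linear layers standing in for $f$ and $g$ (again not code from the repo):

import torch
import torch.nn as nn

f = nn.Linear(4, 4)   # upstream block that must be trained
g = nn.Linear(4, 1)   # downstream block we try to "freeze"
for param in g.parameters():
    param.requires_grad = False

x = torch.randn(2, 4)
loss = g(f(x)).sum()
loss.backward()

# backward still had to propagate through g in order to reach f ...
print(f.weight.grad is not None)  # True
# ... even though g's own parameters accumulate no gradient
print(g.weight.grad)              # None

So freezing $g$'s parameters only skips the gradients with respect to $g$'s own weights; the vector-Jacobian product through $g$ is still computed, because it is needed for the gradients of $f$.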

I wonder how exactly the authors arrange for the gradient computation of the LLM to never happen in this case, because the training runs much too fast for that computation to be taking place.
