From dcd45b209b81946efa6a57253545c03a618db28d Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 6 Aug 2024 11:38:09 -0700 Subject: [PATCH] Adds error handling and reverts LmExample class to original --- src/levanter/models/lm_model.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index edcbb59f9..c36e0e622 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -26,6 +26,15 @@ class MaskedLmExample(eqx.Module): def masked_lm( tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None ) -> "MaskedLmExample": + if tokens.ndim != 1: + raise ValueError("tokens must be a 1D array") + + if not jnp.issubdtype(tokens.dtype, jnp.integer): + raise ValueError("tokens must be an integer array") + + if tokens.shape != targets.shape: + raise ValueError("tokens and targets must have the same shape") + Pos = tokens.axes[0] mask = tokens.array != targets.array @@ -41,8 +50,7 @@ def masked_lm( class LmExample(eqx.Module): tokens: hax.NamedArray loss_mask: hax.NamedArray - attn_mask: hax.NamedArray - targets: Optional[hax.NamedArray] = None + attn_mask: AttentionMask | NamedArray = AttentionMask.causal() @staticmethod def causal( @@ -66,27 +74,6 @@ def causal( attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) - @staticmethod - def masked_lm( - tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None - ) -> "LmExample": - Pos = tokens.axes[0] - - mask = tokens.array != targets.array - loss_mask = mask.astype(jnp.float32) - - if ignore_id is not None: - ignore_mask = targets.array != ignore_id - loss_mask = loss_mask * ignore_mask.astype(jnp.float32) - - print(f"tokens shape: {tokens.shape}") - print(f"targets shape: {targets.shape}") - print(f"loss_mask shape: {loss_mask.shape}") - print(f"attn_mask shape: {attn_mask.shape}") - - return LmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) - - class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property