Skip to content

Commit

Permalink
Adds error handling and reverts LmExample class to original
Browse files Browse the repository at this point in the history
  • Loading branch information
prady-saligram committed Aug 6, 2024
1 parent 53fd8d2 commit dcd45b2
Showing 1 changed file with 10 additions and 23 deletions.
33 changes: 10 additions & 23 deletions src/levanter/models/lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ class MaskedLmExample(eqx.Module):
def masked_lm(
tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None
) -> "MaskedLmExample":
if tokens.ndim != 1:
raise ValueError("tokens must be a 1D array")

if not jnp.issubdtype(tokens.dtype, jnp.integer):
raise ValueError("tokens must be an integer array")

if tokens.shape != targets.shape:
raise ValueError("tokens and targets must have the same shape")

Pos = tokens.axes[0]

mask = tokens.array != targets.array
Expand All @@ -41,8 +50,7 @@ def masked_lm(
class LmExample(eqx.Module):
tokens: hax.NamedArray
loss_mask: hax.NamedArray
attn_mask: hax.NamedArray
targets: Optional[hax.NamedArray] = None
attn_mask: AttentionMask | NamedArray = AttentionMask.causal()

@staticmethod
def causal(
Expand All @@ -66,27 +74,6 @@ def causal(
attn_mask = AttentionMask.causal()
return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)

@staticmethod
def masked_lm(
tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None
) -> "LmExample":
Pos = tokens.axes[0]

mask = tokens.array != targets.array
loss_mask = mask.astype(jnp.float32)

if ignore_id is not None:
ignore_mask = targets.array != ignore_id
loss_mask = loss_mask * ignore_mask.astype(jnp.float32)

print(f"tokens shape: {tokens.shape}")
print(f"targets shape: {targets.shape}")
print(f"loss_mask shape: {loss_mask.shape}")
print(f"attn_mask shape: {attn_mask.shape}")

return LmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask)



class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore
@property
Expand Down

0 comments on commit dcd45b2

Please sign in to comment.