Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienDarve committed Sep 12, 2024
1 parent b5d8e14 commit 8717c3f
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/levanter/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,11 +578,25 @@ def create_position_ids_from_inputs_embeds(self, input_axes, PosInput):

@named_call
def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None):
print(input_ids.dtype)
print(token_type_ids.dtype)
print(position_ids.dtype)
print(input_embeds.dtype)
if input_ids is not None:
jax.debug.print(f"input_ids: {input_ids.dtype}")
else:
jax.debug.print(f"input_ids: None")

if token_type_ids is not None:
jax.debug.print(f"token_type_ids: {token_type_ids.dtype}")
else:
jax.debug.print(f"token_type_ids: None")

if position_ids is not None:
jax.debug.print(f"position_ids: {position_ids.dtype}")
else:
jax.debug.print(f"position_ids: None")

if input_embeds is not None:
jax.debug.print(f"input_embeds: {input_embeds.dtype}")
else:
jax.debug.print(f"input_embeds: None")
"""
Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed"
for compatibility with the position_id creation function. Make sures its length is not equal to
Expand Down

0 comments on commit 8717c3f

Please sign in to comment.