🐛 Bug
I developed custom kernels using Pallas for rounding up and rounding down when casting float32 to bfloat16. I’ve tested these kernels and verified that they behave as intended, including running them on a small-scale version of our model, where they worked as expected.
However, when I tried to use the kernels with a large-scale model, training crashed during the training steps.
The Code For Kernels
This is how I defined my kernel:
from torch_xla.experimental.custom_kernel import jax_import_guard
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
jax_import_guard()
def round_down_and_up_bfloat16_kernel(x_ref, o_ref_down, o_ref_up):
    mask = jnp.array(0xFFFF0000, dtype=jnp.uint32).astype(jnp.int32)
    increment = jnp.array(0x10000, dtype=jnp.uint32).astype(jnp.int32)
    # Treat float32 as int32
    x_bits = x_ref[...].view(jnp.int32)
    # Compute rounded-down value
    bf16_towards_zero = x_bits & mask
    # Compute rounded-up value
    bf16_next = bf16_towards_zero + increment
    # Convert back to float32 and store results
    o_ref_down[...] = bf16_towards_zero.view(jnp.float32)
    o_ref_up[...] = bf16_next.view(jnp.float32)

@jax.jit
def round_down_and_up(x: jax.Array) -> (jax.Array, jax.Array):
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(
        round_down_and_up_bfloat16_kernel,
        out_shape=(out_shape, out_shape),  # Output shapes for `round_down` and `round_up`
    )(x)
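As a quick, purely illustrative check (arbitrary shape and values, and it assumes a TPU is attached since the kernel is not run in interpret mode), round_down_and_up can be compared against the native round-to-nearest bfloat16 cast:

    # Illustrative sanity check only; shape and values are arbitrary.
    x = jnp.linspace(-4.0, 4.0, 8 * 128, dtype=jnp.float32).reshape(8, 128)
    down, up = round_down_and_up(x)
    # `down` keeps only the upper 16 bits (truncation toward zero) and `up` is one
    # bfloat16 ulp further from zero, so the nearest-rounding cast must land on one of them.
    nearest = x.astype(jnp.bfloat16).astype(jnp.float32)
    lo, hi = jnp.minimum(down, up), jnp.maximum(down, up)
    assert bool(jnp.all((nearest >= lo) & (nearest <= hi)))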
The wrapper:
def get_round_down_and_up_bfloat16_pt():
    if "round_down_and_up_bfloat16" not in _kernel_cache:
        if "initialize_jax" not in _kernel_cache:
            initialize_jax()  # Initialize JAX only when required
            _kernel_cache["initialize_jax"] = True
        _kernel_cache["round_down_and_up_bfloat16"] = make_kernel_from_pallas(
            round_down_and_up,  # Reference to the JAX wrapper
            lambda x: [(x.shape, x.dtype), (x.shape, x.dtype)]  # Define shape and dtype for both outputs
        )
    return _kernel_cache["round_down_and_up_bfloat16"]
And the way I use it:
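(The real call site lives in the training loop; the snippet below is only an illustrative sketch, with the tensor shape taken from the stack trace and the variable names and device setup assumed.)

    import torch
    import torch_xla.core.xla_model as xm

    # Hypothetical call site -- names, shape, and device setup are illustrative.
    rounder = get_round_down_and_up_bfloat16_pt()
    x = torch.randn(2048, 4096, dtype=torch.float32, device=xm.xla_device())
    x_down, x_up = rounder(x)  # float32 tensors holding the rounded-down/up bfloat16 values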
Stack Trace
Traceback (most recent call last):
  raise self._exception
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 96.30M of 63.94M vmem. Exceeded vmem capacity by 32.36M.

Program vmem requirement 96.30M:
    scoped 96.30M

Largest program allocations in vmem:
 1. Size: 32.00M
    Shape: u8[33554432]{0}
    Unpadded size: 32.00M
    XLA label: name.711 = custom-call(get-tuple-element.2391), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[2048,4096]{1,0}}
    Allocation type: scoped
 ==========================
 2. Size: 32.00M
    Shape: u8[33554432]{0}
    Unpadded size: 32.00M
    XLA label: name.711 = custom-call(get-tuple-element.2391), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[2048,4096]{1,0}}
    Allocation type: scoped
 ==========================
 3. Size: 32.00M
    Shape: u8[33554432]{0}
    Unpadded size: 32.00M
    XLA label: name.711 = custom-call(get-tuple-element.2391), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[2048,4096]{1,0}}
    Allocation type: scoped
 ==========================
 4. Size: 308.0K
    XLA label: register allocator spill slots call depth 2
    Allocation type: scoped
 ==========================
Expected behavior
I expected the training steps to run as they did with the smaller model.
Environment
Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
torch_xla version: 2.4.0
Additional context
The machine and training code were exactly the same for the larger and smaller models; the only differences were the model size and batch size. From the output, the failure appears to come from the float32 buffers the kernel allocates: each f32[2048,4096] buffer in the trace is 2048 × 4096 × 4 B = 32 MiB, and the input plus the two outputs account for essentially all of the 96.30M of scoped VMEM reported against the ~64M capacity. I will note that this happens while training a bfloat16 model, and I can train in bfloat16 and even float32 without the kernel using the same model and batch size, so I do not think the problem is the size of the model itself.
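For reference, a blocked variant of the pallas_call along the following lines would stage the input and both outputs through VMEM one tile at a time instead of materializing the full f32[2048,4096] buffers at once. This is only a sketch: the 256-row block size is an arbitrary assumption, and I have not validated it at the failing scale.

    def round_down_and_up_blocked(x: jax.Array, block_rows: int = 256):
        # Sketch only: tile the leading dimension; assumes x.shape[0] % block_rows == 0.
        rows, cols = x.shape
        out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
        spec = pl.BlockSpec(block_shape=(block_rows, cols), index_map=lambda i: (i, 0))
        return pl.pallas_call(
            round_down_and_up_bfloat16_kernel,  # same elementwise kernel as above
            grid=(rows // block_rows,),
            in_specs=[spec],
            out_specs=(spec, spec),
            out_shape=(out_shape, out_shape),
        )(x)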