
Pallas custom kernels: vmem error when using larger models. #8429

Open
dshalem opened this issue Dec 2, 2024 · 2 comments

Comments

@dshalem

dshalem commented Dec 2, 2024

🐛 Bug

I developed custom kernels using Pallas for rounding up and rounding down when casting float32 to bfloat16. I’ve tested these kernels and verified that they behave as intended, including running them on a small-scale version of our model, where they worked as expected.
However, when I used the kernels with a large-scale model, training crashed during the training steps.

The Code for the Kernels

This is how I defined my kernel:

from torch_xla.experimental.custom_kernel import jax_import_guard
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
jax_import_guard()
def round_down_and_up_bfloat16_kernel(x_ref, o_ref_down, o_ref_up):
    mask = jnp.array(0xFFFF0000, dtype=jnp.uint32).astype(jnp.int32)
    increment = jnp.array(0x10000, dtype=jnp.uint32).astype(jnp.int32)

    # Treat float32 as int32
    x_bits = x_ref[...].view(jnp.int32)

    # Compute rounded-down value
    bf16_towards_zero = x_bits & mask

    # Compute rounded-up value
    bf16_next = bf16_towards_zero + increment

    # Convert back to float32 and store results
    o_ref_down[...] = bf16_towards_zero.view(jnp.float32)
    o_ref_up[...] = bf16_next.view(jnp.float32)

@jax.jit
def round_down_and_up(x: jax.Array) -> tuple[jax.Array, jax.Array]:
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(
        round_down_and_up_bfloat16_kernel,
        out_shape=(out_shape, out_shape)  # Output shapes for `round_down` and `round_up`
    )(x)
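
As a quick sanity check of the intended behavior (the values below are arbitrary, and running this requires a TPU or Pallas interpret mode), both outputs should round-trip through bfloat16 exactly and bracket the input:

x = jnp.array([[1.0000001, -2.3456789, 3.14159, 0.5]], dtype=jnp.float32)  # arbitrary test values
lo, hi = round_down_and_up(x)

# Both outputs have their low 16 mantissa bits cleared, so they survive a
# round trip through bfloat16 unchanged.
assert jnp.all(lo.astype(jnp.bfloat16).astype(jnp.float32) == lo)
assert jnp.all(hi.astype(jnp.bfloat16).astype(jnp.float32) == hi)

# The input lies between the two outputs; the ordering depends on the sign,
# since the mask truncates toward zero.
assert jnp.all((jnp.minimum(lo, hi) <= x) & (x <= jnp.maximum(lo, hi)))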

The wrapper:

def get_round_down_and_up_bfloat16_pt():
    if "round_down_and_up_bfloat16" not in _kernel_cache:
        if "initialize_jax" not in _kernel_cache:
            initialize_jax()  # Initialize JAX only when required
            _kernel_cache["initialize_jax"] = True

        _kernel_cache["round_down_and_up_bfloat16"] = make_kernel_from_pallas(
            round_down_and_up,  # Reference to the JAX wrapper
            lambda x: [(x.shape, x.dtype), (x.shape, x.dtype)]  # Define shape and dtype for both outputs
            )
    return _kernel_cache["round_down_and_up_bfloat16"]

And the way I use it:

round_down_and_up_bfloat16_pt = get_round_down_and_up_bfloat16_pt()
x_low, x_high = round_down_and_up_bfloat16_pt(x_f32)

Stack Trace

Traceback (most recent call last):
  raise self._exception
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 96.30M of 63.94M vmem. Exceeded vmem capacity by 32.36M.
Program vmem requirement 96.30M:
  scoped 96.30M

Largest program allocations in vmem:

1. Size: 32.00M
   Shape: u8[33554432]{0}
   Unpadded size: 32.00M
   XLA label: name.711 = custom-call(get-tuple-element.2391), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[2048,4096]{1,0}}
   Allocation type: scoped
   ==========================

2. Size: 32.00M
   Shape: u8[33554432]{0}
   Unpadded size: 32.00M
   XLA label: name.711 = custom-call(get-tuple-element.2391), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[2048,4096]{1,0}}
   Allocation type: scoped
   ==========================

3. Size: 32.00M
   Shape: u8[33554432]{0}
   Unpadded size: 32.00M
   XLA label: name.711 = custom-call(get-tuple-element.2391), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[2048,4096]{1,0}}
   Allocation type: scoped
   ==========================

4. Size: 308.0K
   XLA label: register allocator spill slots call depth 2
   Allocation type: scoped
   ==========================

Expected behavior

I expected the training steps to run normally, as they did with the smaller model.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.4.0

Additional context

The machine and training code were exactly the same for the larger and smaller models; the only differences were the model size and the batch size. From the output it appears the failure comes from the 32-bit buffers the kernel allocates. Note, however, that this happens while training a bfloat16 model, and I can train in bfloat16 and even float32 without the kernel using the same model and batch size, so I do not think the problem is the size of the model itself.
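
For what it's worth, the numbers in the trace are consistent with the kernel staging the whole input plus both outputs in VMEM at once. A back-of-the-envelope check, assuming one f32[2048,4096] input and two outputs of the same shape:

bytes_per_buffer = 2048 * 4096 * 4   # one f32[2048,4096] buffer = 33,554,432 bytes = 32 MiB
total_bytes = 3 * bytes_per_buffer   # input + round-down output + round-up output
print(total_bytes / 2**20)           # 96.0 MiB, matching the ~96.30M vmem requirement (plus spill slots)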

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

There are 3 u8[33554432]{0} 32.00M allocations on VMEM, which correspond to the 96.30M that exceeds the VMEM size. I am not a Pallas expert, but you might want to look into https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html.
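
For example, a rough sketch of tiling this element-wise kernel with a grid and BlockSpec (the block size is arbitrary, keyword arguments are used because BlockSpec's positional order differs across JAX versions, and this assumes the row count divides evenly by the block size):

def round_down_and_up_tiled(x: jax.Array) -> tuple[jax.Array, jax.Array]:
    # Process the input in (256, n_cols) blocks so only one block per
    # input/output buffer is resident in VMEM at a time.
    block = (256, x.shape[1])
    grid = (x.shape[0] // block[0],)
    spec = pl.BlockSpec(block_shape=block, index_map=lambda i: (i, 0))
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(
        round_down_and_up_bfloat16_kernel,
        grid=grid,
        in_specs=[spec],
        out_specs=(spec, spec),
        out_shape=(out_shape, out_shape),
    )(x)

With (256, 4096) blocks, each buffer needs roughly 4 MiB of VMEM per grid step instead of the full 32 MiB.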

@dshalem
Author

dshalem commented Dec 9, 2024


OK, there is a bug with grid on TPU. I opened an issue here: #8469
