Support turbomind bf16 #803

Merged · 23 commits · Dec 15, 2023
8 changes: 4 additions & 4 deletions CMakeLists.txt
@@ -17,10 +17,10 @@ project(TurboMind LANGUAGES CXX CUDA)

find_package(CUDA 10.2 REQUIRED)

# if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
# add_definitions("-DENABLE_BF16")
# message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
# endif()
if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
  add_definitions("-DENABLE_BF16")
  message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
endif()

# if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12"))
# add_definitions("-DENABLE_FP8")
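The build-time flag above only says the CUDA toolkit is new enough to compile bf16 kernels; whether the current GPU can actually run them is a separate, runtime question. A minimal sketch of such a runtime check using standard PyTorch calls (the helper name is made up; only `torch.cuda.get_device_capability` and `torch.cuda.is_bf16_supported` are real APIs):

```python
import torch

def can_use_bf16() -> bool:
    """Rough runtime counterpart of the build-time ENABLE_BF16 flag."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    # bf16 tensor-core support starts with Ampere (sm_80); PyTorch exposes a
    # similar check as torch.cuda.is_bf16_supported().
    return major >= 8 and torch.cuda.is_bf16_supported()

print('bf16 usable:', can_use_bf16())
```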
54 changes: 54 additions & 0 deletions lmdeploy/turbomind/deploy/converter.py
@@ -6,6 +6,7 @@
from pathlib import Path

import fire
import torch
from huggingface_hub import snapshot_download

from lmdeploy.model import MODELS
@@ -113,6 +114,55 @@ def copy_tokenizer(model_path: str, tokenizer_path: str,
                osp.join(triton_models_path, 'tokenizer'))


def update_output_format(model_name: str, model_format: str, model_path: str,
                         output_format: str):
    """Update output format according to model info."""
    TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16'}
    MODEL_NAME_MAP = {'qwen': 'bf16', 'llama': 'half'}
    model_name = model_name.split('-')[0]

    def _fix_device_support(output_format):
        """fix device support."""
        if output_format == 'bf16':
            if not torch.cuda.is_bf16_supported():
                # device does not support bf16
                print('Device does not support bf16.')
                output_format = 'fp16'
        return output_format

    def _infer_output_format(config):
        """_infer_output_format."""
        torch_dtype = getattr(config, 'torch_dtype', None)
        if torch_dtype:
            updated_output_format = TORCH_DTYPE_MAP.get(
                torch_dtype, output_format)
        else:
            # get model name prefix
            updated_output_format = MODEL_NAME_MAP.get(model_name,
                                                       output_format)
        return _fix_device_support(updated_output_format)

    if model_format in MODEL_NAME_MAP:
        updated_output_format = MODEL_NAME_MAP.get(model_name, output_format)
        return _fix_device_support(updated_output_format)
    else:
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return _infer_output_format(config)


def update_config_weight_type(output_format: str,
                              config: TurbomindModelConfig):
    WEIGHT_TYPE_MAP = {
        'fp32': 'fp32',
        'fp16': 'fp16',
        'bf16': 'bf16',
        'w4': 'int4',
        'w8': 'int8'
    }
    config.weight_type = WEIGHT_TYPE_MAP[output_format]


def pack_model_repository(workspace_path: str):
"""package the model repository.

@@ -215,6 +265,10 @@ def main(model_name: str,
        cfg.weight_type = 'int4'
        output_format = 'w4'
        assert group_size > 0, f'group_size: {group_size} should > 0'
    else:
        output_format = update_output_format(model_name, inferred_model_format,
                                             model_path, output_format)
        update_config_weight_type(output_format, cfg)

    # convert
    print('model_name ', model_name)
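In short, the converter now picks the serialized dtype from the checkpoint's `config.torch_dtype` when it is present, falls back to a per-model-family default (e.g. bf16 for Qwen), and finally downgrades bf16 to fp16 on devices that cannot run it. A standalone, simplified sketch of that selection (it skips the `model_format` branch; the `SimpleNamespace` configs and model names below are made-up examples, not taken from the PR):

```python
from types import SimpleNamespace

import torch

TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16'}
MODEL_NAME_MAP = {'qwen': 'bf16', 'llama': 'half'}


def resolve_output_format(model_name, hf_config, default='fp16'):
    """Simplified mirror of update_output_format(): config dtype first,
    then a per-family default, then a device-capability fallback."""
    prefix = model_name.split('-')[0]
    torch_dtype = getattr(hf_config, 'torch_dtype', None)
    fmt = (TORCH_DTYPE_MAP.get(torch_dtype, default) if torch_dtype else
           MODEL_NAME_MAP.get(prefix, default))
    if fmt == 'bf16' and not (torch.cuda.is_available()
                              and torch.cuda.is_bf16_supported()):
        fmt = 'fp16'  # downgrade on devices without bf16 support
    return fmt


# a bf16 checkpoint (e.g. Qwen) keeps bf16 on GPUs that support it
print(resolve_output_format('qwen-7b', SimpleNamespace(torch_dtype=torch.bfloat16)))
# an fp16 checkpoint stays fp16
print(resolve_output_format('llama-2-7b', SimpleNamespace(torch_dtype=torch.float16)))
```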
29 changes: 23 additions & 6 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -83,6 +83,14 @@ def valid(self):
        return True


_WEIGHT_DTYPE_MAP = dict(
    int4=torch.float16,
    fp16=torch.float16,
    fp32=torch.float16,
    bf16=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)


class BaseOutputModel(ABC):
"""Base output model."""

@@ -136,26 +144,35 @@ def export_config(self) -> None:

    def export_weight(self, param: torch.Tensor, name: str) -> None:
        """export turbomind weight."""

        def _tofile(tensor, path):
            """to file."""
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.view(torch.half)
            tensor.contiguous().cpu().numpy().tofile(path)

        if self.to_file:
            if param.dtype in [torch.float, torch.bfloat16]:
                param = param.half()
            torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type,
                                               torch.float16)
            param = param.to(torch_type)
            tprint(name, param.shape)
            param.contiguous().cpu().numpy().tofile(
                osp.join(self.out_dir, name))
            _tofile(param, osp.join(self.out_dir, name))
        elif len(self.tm_params) > 0:
            tm_params = self.tm_params
            weight_type = self.cfg.weight_type
            assert weight_type in ['fp16', 'fp32', 'int4']
            assert weight_type in ['fp16', 'fp32', 'bf16', 'int4']

            # currently, the tensor type should in
            # [torch.float, torch.half, torch.int32]
            # [torch.float, torch.half, torch.bfloat16, torch.int32]
            torch_tensor = param.cuda().contiguous()
            assert torch_tensor.dtype in [
                torch.int32, torch.float, torch.half, torch.bfloat16
            ]
            if torch_tensor.dtype != torch.int32:
                if weight_type in ['fp16', 'int4']:
                    torch_tensor = torch_tensor.half()
                elif weight_type == 'bf16':
                    torch_tensor = torch_tensor.bfloat16()
                else:
                    torch_tensor = torch_tensor.float()
            for tm_tensor in tm_params[name]:
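The `_tofile` helper views bf16 tensors as fp16 before calling `.numpy()`, presumably because NumPy has no native bfloat16 dtype: the 16-bit payload is written unchanged and only the type tag is reinterpreted. A small round-trip sketch of that trick (the file name is arbitrary and the reader below is an illustration, not the PR's loader):

```python
import numpy as np
import torch

w = torch.randn(4, 4).to(torch.bfloat16)

# write: reinterpret the 16-bit bf16 payload as fp16 so numpy can serialize it
w.view(torch.half).contiguous().cpu().numpy().tofile('w.bf16.bin')

# read back: load the raw 16-bit words as fp16, then view them as bf16 again
raw = torch.from_numpy(np.fromfile('w.bf16.bin', dtype=np.float16)).reshape(4, 4)
restored = raw.view(torch.bfloat16)

assert torch.equal(w, restored)  # bit-exact round trip
```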
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/target_model/fp.py
@@ -14,7 +14,7 @@ def transpose_tensor(input: List[torch.Tensor]):
    return output


@OUTPUT_MODELS.register_module(name='fp16')
@OUTPUT_MODELS.register_module(name=['fp16', 'bf16'])
class TurbomindModel(BaseOutputModel):
    """Export to turbomind fp16 format."""

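Registering the exporter under both names means a bf16 request resolves to the existing fp16 code path; only the dtype handled inside differs. A toy registry showing the name-list idea (a hypothetical stand-in, not lmdeploy's actual `OUTPUT_MODELS` implementation):

```python
class Registry:
    """Minimal name -> class registry; `name` may be a str or a list of names."""

    def __init__(self):
        self._modules = {}

    def register_module(self, name):
        names = [name] if isinstance(name, str) else list(name)

        def _decorator(cls):
            for n in names:
                self._modules[n] = cls  # the same class serves every alias
            return cls

        return _decorator

    def get(self, name):
        return self._modules[name]


OUTPUT_MODELS = Registry()


@OUTPUT_MODELS.register_module(name=['fp16', 'bf16'])
class TurbomindModel:
    pass


assert OUTPUT_MODELS.get('bf16') is OUTPUT_MODELS.get('fp16')
```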
9 changes: 8 additions & 1 deletion lmdeploy/turbomind/turbomind.py
@@ -22,7 +22,8 @@
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger

from .deploy.converter import get_model_format, supported_formats
from .deploy.converter import (get_model_format, supported_formats,
                               update_config_weight_type, update_output_format)
from .deploy.source_model.base import INPUT_MODELS
from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
@@ -246,6 +247,12 @@ def _from_hf(self,
            output_format = 'w4'
            data_type = 'int4'
            assert group_size > 0, f'group_size: {group_size} should > 0'
        else:
            output_format = update_output_format(model_name,
                                                 inferred_model_format,
                                                 model_path, output_format)
            data_type = output_format
            update_config_weight_type(output_format, cfg)

        self.config = cfg
        self.model_name = model_name
4 changes: 4 additions & 0 deletions src/turbomind/kernels/activation_kernels.cu
@@ -117,7 +117,11 @@ struct ReluActivation<__nv_bfloat162> {
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
    {
        const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        return turbomind::make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
#else
        return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
#endif
    }
};
#endif
@@ -590,6 +590,39 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
    return b;
}

#ifdef ENABLE_BF16
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
{
    return cuda_cast<__nv_bfloat162, float2>(a);
}
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, float4>(const float4& a)
{
    bf16_4_t b;
    float2 val;
    val.x = a.x;
    val.y = a.y;
    b.x = vec_conversion<__nv_bfloat162, float2>(val);

    val.x = a.z;
    val.y = a.w;
    b.y = vec_conversion<__nv_bfloat162, float2>(val);

    return b;
}
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
{
    bf16_8_t b;
    b.x = vec_conversion<__nv_bfloat162, float2>(a.x);
    b.y = vec_conversion<__nv_bfloat162, float2>(a.y);
    b.z = vec_conversion<__nv_bfloat162, float2>(a.z);
    b.w = vec_conversion<__nv_bfloat162, float2>(a.w);
    return b;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_KEY, typename K_vec, int N>
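The new specializations narrow float2/float4/Float8_ vectors to the matching bf16 vector types via `cuda_cast`. For intuition on what that narrowing costs: bf16 keeps fp32's 8-bit exponent but only 8 bits of significand, so the dynamic range is preserved while each value is rounded to within roughly 0.4%. A quick PyTorch illustration (not part of the PR):

```python
import torch

x = torch.tensor([1.0009765625, 3.14159265, 1e5], dtype=torch.float32)
y = x.to(torch.bfloat16).to(torch.float32)
print(x - y)  # small relative error everywhere; 1e5 would overflow to inf in fp16 but is fine in bf16
```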
@@ -1053,6 +1086,55 @@ inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
    int16[3] = quant(a.w, scale, zp);
    return int64;
}

// bfloat16 to int8
inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp)
{
    int8_t int8;
    float b = bfloat16_to_float(a);
    int8 = round(max(-128.f, min(127.f, (b - zp) / scale)));
    return int8;
}
// bfloat16x2 to int8x2
inline __device__ int16_t quant(__nv_bfloat162 a, const float scale, const float zp)
{
    union {
        int8_t int8[2];
        short int16;
    };
    float2 b = bfloat162_to_float2(a);

    int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale)));
    int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale)));
    return int16;
}
// bfloat16x4 to int8x4
inline __device__ int32_t quant(bf16_4_t a, const float scale, const float zp)
{
    union {
        int16_t int16[2];
        int32_t int32;
    };

    int16[0] = quant(a.x, scale, zp);
    int16[1] = quant(a.y, scale, zp);
    return int32;
}
// bfloat16x8 to int8x8
inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp)
{
    union {
        int16_t int16[4];
        int64_t int64;
    };

    int16[0] = quant(a.x, scale, zp);
    int16[1] = quant(a.y, scale, zp);
    int16[2] = quant(a.z, scale, zp);
    int16[3] = quant(a.w, scale, zp);
    return int64;
}

// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale, const float zp)
{
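The bf16 overloads reuse the same affine int8 mapping as the existing fp16/fp32 ones: clamp `(x - zp) / scale` to [-128, 127] and round. A worked NumPy version of that mapping and its approximate inverse (the `dequant` body is truncated above, so the inverse here is an assumption that mirrors the quant formula; all names and numbers are illustrative):

```python
import numpy as np

def quant_int8(x, scale, zp):
    """Mirror of the CUDA quant(): affine map to int8 with clamping."""
    q = np.round(np.clip((x - zp) / scale, -128.0, 127.0))
    return q.astype(np.int8)

def dequant_int8(q, scale, zp):
    """Assumed inverse: undo the affine map."""
    return q.astype(np.float32) * scale + zp

x = np.array([0.031, -1.7, 2.5, 0.0], dtype=np.float32)  # sample values
scale, zp = 0.02, 0.0
q = quant_int8(x, scale, zp)
print(q)                           # [  2 -85 125   0]
print(dequant_int8(q, scale, zp))  # close to x again; error <= scale/2 inside the clamp range
```

One difference worth noting: the CUDA `round()` rounds halfway cases away from zero while `np.round` rounds them to even, which only matters for values landing exactly on .5.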