diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..5cc0062b1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,99 @@ +# Change Log + +## Versions 0.6.0 / 0.6.1 + + * Models + * ChatGLM3 + * InternLM (contributed by @wangruohui) + * Mistral 7B (developed in collaboration with Mistral.AI) + * MQA/GQA support to MPT (and GPT) models (contributed by @bheilbrun) + * Qwen (contributed by @Tlntin and @zhaohb) + * Replit Code V-1.5 3B (external contribution) + * T5, mT5, Flan-T5 (Python runtime only) + + * Features + * Add runtime statistics related to active requests and KV cache + utilization from the batch manager (see + the [batch manager](docs/source/batch_manager.md) documentation) + * Add `sequence_length` tensor to support proper lengths in beam-search + (when beam-width > 1 - see + [tensorrt_llm/batch_manager/GptManager.h](cpp/include/tensorrt_llm/batch_manager/GptManager.h)) + * BF16 support for encoder-decoder models (Python runtime - see + [examples/enc_dec](examples/enc_dec/README.md)) + * Improvements to memory utilization (CPU and GPU - including memory + leaks) + * Improved error reporting and memory consumption + * Improved support for stop and bad words + * INT8 SmoothQuant and INT8 KV Cache support for the Baichuan models (see + [examples/baichuan](examples/baichuan/README.md)) + * INT4 AWQ Tensor Parallelism support and INT8 KV cache + AWQ/weight-only + support for the GPT-J model (see [examples/gptj](examples/gptj/README.md)) + * INT4 AWQ support for the Falcon models + (see [examples/falcon](examples/falcon/README.md)) + * LoRA support (functional preview only - limited to the Python runtime, + only QKV support and not optimized in terms of runtime performance) for + the GPT model (see the + [Run LoRA with the Nemo checkpoint](examples/gpt/README.md#Run-LoRA-with-the-Nemo-checkpoint) + in the GPT example) + * Multi-GPU support for encoder-decoder models (Python runtime - see + [examples/enc_dec](examples/enc_dec/README.md)) + * New heuristic for launching the Multi-block Masked MHA kernel (similar + to FlashDecoding - see + [decoderMaskedMultiheadAttentionLaunch.h](cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h)) + * Prompt-Tuning support for GPT and LLaMA models (see the + [Prompt-tuning](examples/gpt/README.md#Prompt-tuning) Section in the GPT example) + * Performance optimizations in various CUDA kernels + * Possibility to exclude input tokens from the output (see `excludeInputInOutput` in + [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h)) + * Python binding for the C++ runtime (GptSession - see [`pybind`](cpp/tensorrt_llm/pybind)) + * Support for different micro batch sizes for context and generation + phases with pipeline parallelism (see `GptSession::Config::ctxMicroBatchSize` and + `GptSession::Config::genMicroBatchSize` in + [tensorrt_llm/runtime/gptSession.h](cpp/include/tensorrt_llm/runtime/gptSession.h)) + * Support for "remove input padding" for encoder-decoder models (see + [examples/enc_dec](examples/enc_dec/README.md)) + * Support for context and generation logits (see `mComputeContextLogits` and + `mComputeGenerationLogits` in + [tensorrt_llm/runtime/gptModelConfig.h](cpp/include/tensorrt_llm/runtime/gptModelConfig.h)) + * Support for `logProbs` and `cumLogProbs` (see `"output_log_probs"` and + `"cum_log_probs"` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h)) + * Update to CUTLASS 3.x + + * Bug fixes + * Fix for ChatGLM2 #93 and #138 + * Fix 
tensor names error "RuntimeError: Tensor names + (`host_max_kv_cache_length`) in engine are not the same as expected in + the main branch" #369 + * Fix weights split issue in BLOOM when `world_size = 2` ("array split + does not result in an equal division") #374 + * Fix SmoothQuant multi-GPU failure when tensor parallelism is 2 #267 + * Fix a crash in GenerationSession if stream keyword argument is not None + #202 + * Fix a typo when calling PyNVML API [BUG] code bug #410 + * Fix bugs related to the improper management of the `end_id` for various + models [C++ and Python] + * Fix memory leaks [C++ code and Python models] + * Fix the std::bad_alloc error when running the gptManagerBenchmark -- issue + gptManagerBenchmark std::bad_alloc error #66 + * Fix a bug in pipeline parallelism when beam-width > 1 + * Fix a bug with Llama GPTQ due to improper support of GQA + * Fix issue #88 + * Fix an issue with the Huggingface Transformers version #16 + * Fix link jump in windows readme.md #30 - by @yuanlehome + * Fix typo in batchScheduler.h #56 - by @eltociear + * Fix typo #58 - by @RichardScottOZ + * Fix Multi-block MMHA: Difference between `max_batch_size` in the engine + builder and `max_num_sequences` in TrtGptModelOptionalParams? #65 + * Fix the log message to be more accurate on KV cache #224 + * Fix Windows release wheel installation: Failed to install the release + wheel for Windows using pip #261 + * Fix missing torch dependencies: [BUG] The batch_manage.a choice error + in --cpp-only when torch's cxx_abi version is different with gcc #151 + * Fix linking error during compiling google-test & benchmarks #277 + * Fix logits dtype for Baichuan and ChatGLM: segmentation fault caused by + the lack of bfloat16 #335 + * Minor bug fixes + +## Version 0.5.0 + + * TensorRT-LLM v0.5.0 is the first public release. diff --git a/README.md b/README.md index 96bff448b..fde2f02bf 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ TensorRT-LLM [![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.2-green)](https://developer.nvidia.com/cuda-downloads) [![trt](https://img.shields.io/badge/TRT-9.2-green)](https://developer.nvidia.com/tensorrt) -[![version](https://img.shields.io/badge/release-0.7.0-green)](./setup.py) +[![version](https://img.shields.io/badge/release-0.7.1-green)](./setup.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/architecture.md)   |   [Results](./docs/source/performance.md)   |   [Examples](./examples/)   |   [Documentation](./docs/source/) @@ -108,9 +108,7 @@ concepts used in TensorRT-LLM, we recommend you to read the following ## Installation -*For Windows installation, see [`Windows`](windows/README.md).* - -TensorRT-LLM must be built from source, instructions can be found +The documentation for installing TensorRT-LLM can be found [here](./docs/source/installation.md). An image of a Docker container with TensorRT-LLM and its Triton Inference Server Backend will be made available soon. @@ -118,6 +116,8 @@ soon. The remaining commands in that document must be executed from the TensorRT-LLM container.
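+After opening a shell in the container, a quick sanity check is to import the package and print its version; this mirrors the check in the installation documentation, and 0.7.1 is the version shipped with this release:
+
+```python
+# Confirm that tensorrt_llm is importable and report the installed version.
+import tensorrt_llm
+
+print(tensorrt_llm.__version__)  # expected to print 0.7.1 for this release
+```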
+*For Windows installation, see [`Windows`](windows/README.md).* + ## Quick Start To create a TensorRT engine for an existing model, there are 3 steps: @@ -379,103 +379,43 @@ For example: `mpirun -n 1 python3 examples/gpt/build.py ...` ### Change Log -#### Version 0.6.1 - - * Models - * ChatGLM3 - * InternLM (contributed by @wangruohui) - * Mistral 7B (developed in collaboration with Mistral.AI) - * MQA/GQA support to MPT (and GPT) models (contributed by @bheilbrun) - * Qwen (contributed by @Tlntin and @zhaohb) - * Replit Code V-1.5 3B (external contribution) - * T5, mT5, Flan-T5 (Python runtime only) - - * Features - * Add runtime statistics related to active requests and KV cache - utilization from the batch manager (see - the [batch manager](docs/source/batch_manager.md) documentation) - * Add `sequence_length` tensor to support proper lengths in beam-search - (when beam-width > 1 - see - [tensorrt_llm/batch_manager/GptManager.h](cpp/include/tensorrt_llm/batch_manager/GptManager.h)) - * BF16 support for encoder-decoder models (Python runtime - see - [examples/enc_dec](examples/enc_dec/README.md)) - * Improvements to memory utilization (CPU and GPU - including memory - leaks) - * Improved error reporting and memory consumption - * Improved support for stop and bad words - * INT8 SmoothQuant and INT8 KV Cache support for the Baichuan models (see - [examples/baichuan](examples/baichuan/README.md)) - * INT4 AWQ Tensor Parallelism support and INT8 KV cache + AWQ/weight-only - support for the GPT-J model (see [examples/gptj](examples/gptj/README.md)) - * INT4 AWQ support for the Falcon models - (see [examples/falcon](examples/falcon/README.md)) - * LoRA support (functional preview only - limited to the Python runtime, - only QKV support and not optimized in terms of runtime performance) for - the GPT model (see the - [Run LoRA with the Nemo checkpoint](examples/gpt/README.md#Run-LoRA-with-the-Nemo-checkpoint) - in the GPT example) - * Multi-GPU support for encoder-decoder models (Python runtime - see - [examples/enc_dec](examples/enc_dec/README.md)) - * New heuristic for launching the Multi-block Masked MHA kernel (similar - to FlashDecoding - see - [decoderMaskedMultiheadAttentionLaunch.h](cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h)) - * Prompt-Tuning support for GPT and LLaMA models (see the - [Prompt-tuning](examples/gpt/README.md#Prompt-tuning) Section in the GPT example) - * Performance optimizations in various CUDA kernels - * Possibility to exclude input tokens from the output (see `excludeInputInOutput` in - [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h)) - * Python binding for the C++ runtime (GptSession - see [`pybind`](cpp/tensorrt_llm/pybind)) - * Support for different micro batch sizes for context and generation - phases with pipeline parallelism (see `GptSession::Config::ctxMicroBatchSize` and - `GptSession::Config::genMicroBatchSize` in - [tensorrt_llm/runtime/gptSession.h](cpp/include/tensorrt_llm/runtime/gptSession.h)) - * Support for "remove input padding" for encoder-decoder models (see - [examples/enc_dec](examples/enc_dec/README.md)) - * Support for context and generation logits (see `mComputeContextLogits` and - `mComputeGenerationLogits` in - [tensorrt_llm/runtime/gptModelConfig.h](cpp/include/tensorrt_llm/runtime/gptModelConfig.h)) - * Support for `logProbs` and `cumLogProbs` (see `"output_log_probs"` and - `"cum_log_probs"` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h)) 
- * Update to CUTLASS 3.x - - * Bug fixes - * Fix for ChatGLM2 #93 and #138 - * Fix tensor names error "RuntimeError: Tensor names - (`host_max_kv_cache_length`) in engine are not the same as expected in - the main branch" #369 - * Fix weights split issue in BLOOM when `world_size = 2` ("array split - does not result in an equal division") #374 - * Fix SmoothQuant multi-GPU failure with tensor parallelism is 2 #267 - * Fix a crash in GenerationSession if stream keyword argument is not None - #202 - * Fix a typo when calling PyNVML API [BUG] code bug #410 - * Fix bugs related to the improper management of the `end_id` for various - models [C++ and Python] - * Fix memory leaks [C++ code and Python models] - * Fix the std::alloc error when running the gptManagerBenchmark -- issue - gptManagerBenchmark std::bad_alloc error #66 - * Fix a bug in pipeline parallelism when beam-width > 1 - * Fix a bug with Llama GPTQ due to improper support of GQA - * Fix issue #88 - * Fix an issue with the Huggingface Transformers version #16 - * Fix link jump in windows readme.md #30 - by @yuanlehome - * Fix typo in batchScheduler.h #56 - by @eltociear - * Fix typo #58 - by @RichardScottOZ - * Fix Multi-block MMHA: Difference between `max_batch_size` in the engine - builder and `max_num_sequences` in TrtGptModelOptionalParams? #65 - * Fix the log message to be more accurate on KV cache #224 - * Fix Windows release wheel installation: Failed to install the release - wheel for Windows using pip #261 - * Fix missing torch dependencies: [BUG] The batch_manage.a choice error - in --cpp-only when torch's cxx_abi version is different with gcc #151 - * Fix linking error during compiling google-test & benchmarks #277 - * Fix logits dtype for Baichuan and ChatGLM: segmentation fault caused by - the lack of bfloat16 #335 - * Minor bug fixes - -#### Version 0.5.0 - - * TensorRT-LLM v0.5.0 is the first public release. 
+#### Versions 0.7.0 / 0.7.1 + +* Models + - BART and mBART support in encoder-decoder models + - FairSeq Neural Machine Translation (NMT) family + - Mixtral-8x7B model + - Support weight loading for HuggingFace Mixtral model + - OpenAI Whisper + - Mixture of Experts support + - MPT - Int4 AWQ / SmoothQuant support + - Baichuan FP8 quantization support +* Features + - [Preview] Speculative decoding + - Add Python binding for `GptManager` + - Add a Python class `ModelRunnerCpp` that wraps C++ `gptSession` + - System prompt caching + - Enable split-k for weight-only cutlass kernels + - FP8 KV cache support for XQA kernel + - New Python builder API and `trtllm-build` command (already applied to [blip2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/blip2) and [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt#3-build-tensorrt-engines)) + - Support `StoppingCriteria` and `LogitsProcessor` in the Python generate API (thanks to the contribution from @zhang-ge-hao; see the sketch below) + - fMHA support for chunked attention and paged kv cache +* Bug fixes + - Fix tokenizer usage in quantize.py #288, thanks to the contribution from @0xymoro + - Fix LLaMA with LoRA error #637 + - Fix LLaMA GPTQ failure #580 + - Fix Python binding for InferenceRequest issue #528 + - Fix CodeLlama SQ accuracy issue #453 +* Performance + - MMHA optimization for MQA and GQA + - LoRA optimization: cutlass grouped gemm + - Optimize Hopper warp specialized kernels + - Optimize AllReduce for parallel attention on Falcon and GPT-J + - Enable split-k for weight-only cutlass kernel when SM>=75 +* Documentation + - Add [documentation for new builder workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/new_workflow.md) + +#### For the change log of earlier versions, please see [CHANGELOG.md](./CHANGELOG.md).
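+The `StoppingCriteria` / `LogitsProcessor` item above can be illustrated with a short, self-contained sketch. Only the HuggingFace-style `(input_ids, logits)` callable convention is shown here; the exact TensorRT-LLM class names, import paths, and `generate` keyword are assumptions and may differ:
+
+```python
+# Illustrative sketch only: a logits processor that bans a set of token ids by
+# pushing their logits to -inf before sampling.
+import torch
+
+
+class BanTokensLogitsProcessor:
+
+    def __init__(self, banned_token_ids):
+        self.banned_token_ids = list(banned_token_ids)
+
+    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
+        logits[..., self.banned_token_ids] = float("-inf")
+        return logits
+
+
+# Hypothetical usage with the Python generate API (the argument name is an assumption):
+#   outputs = runner.generate(batch_input_ids,
+#                             logits_processor=BanTokensLogitsProcessor([13]))
+```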
### Known Issues diff --git a/benchmarks/python/allowed_configs.py b/benchmarks/python/allowed_configs.py index 9562c7929..86a623216 100644 --- a/benchmarks/python/allowed_configs.py +++ b/benchmarks/python/allowed_configs.py @@ -232,6 +232,7 @@ class ModelConfig: builder_opt=None, pre_norm=False, do_layer_norm_before=False, + use_custom_all_reduce=False, )), "opt_2.7b": ModelConfig(name="opt_2.7b", @@ -250,6 +251,7 @@ class ModelConfig: builder_opt=None, pre_norm=False, do_layer_norm_before=True, + use_custom_all_reduce=False, )), "opt_6.7b": ModelConfig(name="opt_6.7b", @@ -268,6 +270,7 @@ class ModelConfig: builder_opt=None, pre_norm=False, do_layer_norm_before=True, + use_custom_all_reduce=False, )), "opt_66b": ModelConfig(name="opt_66b", @@ -286,6 +289,7 @@ class ModelConfig: builder_opt=None, pre_norm=True, do_layer_norm_before=True, + use_custom_all_reduce=False, )), "llama_7b": ModelConfig(name="llama_7b", @@ -512,6 +516,7 @@ class ModelConfig: max_output_len=200, builder_opt=None, remove_input_padding=False, + use_custom_all_reduce=False, )), "bloom_560m": ModelConfig(name="bloom_560m", @@ -528,6 +533,7 @@ class ModelConfig: max_input_len=1024, max_output_len=1024, builder_opt=None, + use_custom_all_reduce=False, )), "bloom_176b": ModelConfig(name="bloom_176b", @@ -544,6 +550,7 @@ class ModelConfig: max_input_len=1024, max_output_len=1024, builder_opt=None, + use_custom_all_reduce=False, )), "bert_base": ModelConfig(name="bert_base", diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 3871f74b6..284a97fb8 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:717c7aac842fe8d8cc52e07740d6a158889ab1ae07d02e6575e1eb3e640848c1 +oid sha256:c98f8854a1d8967775c94bb96a5a37dca190f1fa808b3f846db870c30cce2bfd size 1801434 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 4d10726c8..46f4f526e 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d9f22df17f665e526b1db997e6de4dfa2ca5e22c1a13cb125fc02e07389e43f +oid sha256:ed514ea9c0634d4fc95a0a53e7719f72ec8e2b0a596d1e8a60516652f66b8ca2 size 1819282 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index c0cdafbbb..0c8a09832 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -516ff2db1e17536e92150b0c05200589 libtensorrt_llm_batch_manager_static.a -428a500536705184a1aad8aaf5c9c0ca libtensorrt_llm_batch_manager_static.pre_cxx11.a -33b6139e3bb108df093aab3a6de38a87f1f1e2dd commit +ffe001b0bf9ee66b3e3696423d6d09a2 libtensorrt_llm_batch_manager_static.a +3657ea3400959a64be77c12d8598dd72 libtensorrt_llm_batch_manager_static.pre_cxx11.a +9a775b3dbb20444f130f13f90e675cc971fe7e15 commit diff --git 
a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index a32518f19..8af84f92d 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4b4c8f559dddb001f8355a0162423af385e6803376d2cb4f9b9c37f7840659e0 +oid sha256:542ccb1497c91d82048eb9bec07527317c702e9c7466923d8b61e12374e087fb size 1722062 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index b7bcd4f46..4570a54e8 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b156f4fbdafcb12ae7c39be35da04b10fc42cd911f15f03f892c2d118ec3825a +oid sha256:51a4dc2d8e2b7624976fb5f8370b8f44e1c25f038bd3915e1c31eb63c60b7c22 size 1715766 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt index cca01d95f..b621b1581 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -0403e89a23fd77aed43cac0ecd8136cf libtensorrt_llm_batch_manager_static.a -9fa2a1c18860eaf226a6ce61a8e3ed5d libtensorrt_llm_batch_manager_static.pre_cxx11.a +bb69bf376c5f955c327e867049639d78 libtensorrt_llm_batch_manager_static.a +14b107676c74ce17bfc8ce950b36a984 libtensorrt_llm_batch_manager_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp index 79fd4b245..89b987f29 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp @@ -121,7 +121,8 @@ class FusedMHARunnerV2::mhaImpl if (mLaunchParams.useKernelWithoutAlibi) { // The kernel adopts the log2f optimziation. - set_alpha(params.scale_bmm1, scale_bmm1 * float(M_LOG2E), DATA_TYPE_FP32); + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + set_alpha(params.scale_bmm1, scale_bmm1 * float(kLog2e), DATA_TYPE_FP32); } else { diff --git a/docs/source/installation.md b/docs/source/installation.md index 6cbc185a7..aed68dbaa 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -1,6 +1,7 @@ -# Build TensorRT-LLM +# TensorRT-LLM Installation - [Overview](#overview) +- [Install From the Wheel Package](#install-from-the-wheel-package) - [Fetch the Sources](#fetch-the-sources) - [Build TensorRT-LLM in One Step](#build-tensorrt-llm-in-one-step) - [Build Step-by-step](#build-step-by-step) @@ -13,16 +14,28 @@ ## Overview -This document contains instructions to build TensorRT-LLM from sources. TensorRT-LLM depends on the latest versions of -TensorRT and -[Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy) -which are distributed separately, and should be copied into this repository. - +This document contains instructions to install TensorRT-LLM. 
We recommend the use of [Docker](https://www.docker.com) to build and run TensorRT-LLM. Instructions to install an environment to run Docker containers for the NVIDIA platform can be found [here](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). +## Install From the Wheel Package + +After installing CUDA 12.2 according to the [instructions](https://developer.nvidia.com/cuda-toolkit), +please execute the following commands to install TensorRT-LLM. + +```bash +# Install dependencies, TensorRT-LLM requires Python 3.10 +apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev +# Install the latest version of TensorRT-LLM +pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com +# Check installation +python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)" +``` + +Note that users who have debugging needs or use the GNU C++11 ABI need to compile TensorRT-LLM from source. + ## Fetch the Sources The first step to build TensorRT-LLM is to fetch the sources: diff --git a/examples/bloom/convert_checkpoint.py b/examples/bloom/convert_checkpoint.py index 9f21c4da3..b2dcafc0c 100644 --- a/examples/bloom/convert_checkpoint.py +++ b/examples/bloom/convert_checkpoint.py @@ -4,17 +4,23 @@ import os import time from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, Iterable, Optional, Union import numpy as np import safetensors import torch import torch.nn as nn from tqdm import tqdm -from transformers import BloomForCausalLM, BloomTokenizerFast +from transformers import BloomConfig, BloomForCausalLM, BloomTokenizerFast from transformers.models.bloom.modeling_bloom import BloomBlock from transformers.pytorch_utils import Conv1D +# isort: off import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm.quantization import QuantMode +# isort: on @torch.no_grad() @@ -150,7 +156,7 @@ def reorder_torch_qkv_weight_or_bias(v, model, is_bias=False): def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--model_dir', type=str, default=None) + parser.add_argument('--model_dir', type=Path, default=None) parser.add_argument('--world_size', type=int, default=1, @@ -202,7 +208,7 @@ def parse_arguments(): 'Note: the flag might not take effect when the criteria are not met.') parser.add_argument( '--output_dir', - type=str, + type=Path, default='baichuan_tllm_checkpoint', help='The path to save the baichuan TensorRT-LLM checkpoint') parser.add_argument( @@ -236,6 +242,12 @@ def parse_arguments(): help= 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' ) + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers to convert checkpoint in parallel') + parser.add_argument('--log_level', type=str, default='info') args = parser.parse_args() return args @@ -359,9 +371,9 @@ def split(v, tp_size, idx, dim=0): if tp_size == 1: return v if len(v.shape) == 1: - return torch.chunk(v, tp_size)[idx].contiguous() + return torch.chunk(v, tp_size)[idx].clone() else: - return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() + return torch.chunk(v, tp_size, dim=dim)[idx].clone() def reorder_qkv_weight_or_bias(v, n_head, n_hidden, is_bias=False): @@ -449,6 +461,32 @@ def get_tllm_linear_weight(weight, return results +def add_tllm_weight( + weights: Dict[str, torch.Tensor], + name: str, + param: torch.Tensor, + quant_mode: QuantMode = QuantMode(0), +): + assert name not in weights, f'{name} is already added.' 
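+    # For weight-only quantization, ".weight" tensors are stored as a pair of
+    # (quantized weight, per-channel scale); all other parameters are stored as-is.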
+ + if name.endswith('.weight') and quant_mode.is_weight_only(): + if quant_mode.is_int8_weight_only(): + quant_dtype = torch.int8 + elif quant_mode.is_int4_weight_only(): + quant_dtype = torch.quint4x2 + else: + raise ValueError( + f'Invalid configuration, got quant_mode={quant_mode}') + processed_torch_weights, torch_weight_scales = \ + torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + param.t().contiguous(), quant_dtype) + weights[name] = processed_torch_weights + scale_name = name.replace('.weight', '.per_channel_scale') + weights[scale_name] = torch_weight_scales + else: + weights[name] = param.contiguous() + + @torch.no_grad() def smooth_bloom_model(model, scales, alpha, bloom_qkv_param, bloom_smoother): # Smooth the activation and weights with smoother = $\diag{s}$ @@ -736,7 +774,6 @@ def convert_hf_bloom(hf_bloom, rank=rank, cat_dim=-1)) else: - split_v = split_matrix_tp(mlp_fc_weight, tensor_parallel, rank, @@ -822,7 +859,170 @@ def convert_hf_bloom(hf_bloom, return weights -if __name__ == '__main__': +def rename_hf_to_tllm(name: str): + """ Rename a HF parameter name by the corresponding TRT-LLM style name. """ + if 'word_embeddings_layernorm.' in name: + name = name.replace('word_embeddings_layernorm', 'ln_embed') + if not name.startswith('transformer.'): + name = f'transformer.{name}' + elif 'word_embeddings.' in name: + name = name.replace('word_embeddings', 'embedding') + if name.startswith(('ln_embed.', 'embedding.', 'ln_f.')): + name = f'transformer.{name}' + + # Parameter names in layers + if name.startswith(('transformer.h.', 'h.')): + import re + name = re.sub(r'^(transformer.h.|h.)', 'transformer.layers.', name, 1) + if 'post_attention_layernorm' in name: + name = name.replace('post_attention_layernorm', 'post_layernorm') + elif 'self_attention.query_key_value' in name: + name = name.replace('self_attention.query_key_value', 'attention.qkv') + elif 'self_attention.dense' in name: + name = name.replace('self_attention.dense', 'attention.dense') + elif 'mlp.dense_h_to_4h' in name: + name = name.replace('mlp.dense_h_to_4h', 'mlp.fc') + elif 'mlp.dense_4h_to_h' in name: + name = name.replace('mlp.dense_4h_to_h', 'mlp.proj') + return name + + +def contain_any(name: str, words: Iterable[str]): + for word in words: + if word in name: + return True + return False + + +def convert_from_hf_checkpoint( + model_dir: Union[str, Path], + rank=0, + tensor_parallel=1, + dtype: Union[str, torch.dtype] = torch.float32, + use_parallel_embedding: bool = False, + sharding_dim: int = 0, + share_embedding_table: bool = False, + use_weight_only: bool = False, + plugin_weight_only_quant_type: torch.dtype = torch.int8, + use_smooth_quant: bool = False, + bloom_qkv_param: Optional[Dict] = None, + smooth_act_range: Optional[Any] = None, + smoother: Optional[Any] = None, + per_channel: bool = False, + per_token: bool = False, + int8_kv_cache: bool = False, +): + logger.info('Loading weights from HF BLOOM...') + tik = time.time() + + weights = {} + hf_config = BloomConfig.from_pretrained(model_dir) + num_heads = hf_config.n_head + hidden_size = hf_config.hidden_size + if isinstance(dtype, str): + dtype = tensorrt_llm.str_dtype_to_torch(dtype) + tp_rank = rank + tp_size = tensor_parallel + + if use_smooth_quant: + quant_mode = QuantMode.use_smooth_quant(per_token, per_channel) + elif use_weight_only: + quant_mode = QuantMode.from_description( + quantize_weights=True, + quantize_activations=False, + per_token=False, + per_channel=False, + use_int8_kv_cache=int8_kv_cache, + 
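+            # Whether int4 or int8 weight-only quantization is used is inferred
+            # from the requested plugin dtype below.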
use_int4_weights=plugin_weight_only_quant_type == torch.quint4x2) + else: + quant_mode = QuantMode(0) + + def is_bias(_name): + return 'bias' in _name + + # Load examples/common/utils.py + import sys + sys.path.append(str(Path(__file__).parent.parent)) + from common import utils + + for model_file in utils.iterate_shard_files(model_dir, tp_rank): + logger.debug(f'Loading file {str(model_file)}...') + model_params = utils.load_state_dict(model_file, dtype=dtype) + for name, param in model_params.items(): + logger.debug(f'Converting weight {name}...') + tllm_name = rename_hf_to_tllm(name) + param = param.detach().cpu() + + # TODO: Support SmmothQuant. + + if 'self_attention.query_key_value' in name: + if not is_bias(name): + param = split_qkv_tp(param, num_heads, hidden_size, tp_size, + tp_rank) + # TODO: Add KV scalers when quantizing KV cache. + else: + param = split_qkv_bias_tp(param, num_heads, hidden_size, + tp_size, tp_rank) + add_tllm_weight(weights, tllm_name, param, quant_mode) + elif 'self_attention.dense' in name: + if not is_bias(name): + param = split_matrix_tp(param, tp_size, tp_rank, dim=1) + add_tllm_weight(weights, tllm_name, param, quant_mode) + elif 'mlp.dense_h_to_4h' in name: + if not is_bias(name): + param = split_matrix_tp(param, tp_size, tp_rank, dim=0) + else: + param = split_matrix_tp(param, tp_size, tp_rank, dim=0) + add_tllm_weight(weights, tllm_name, param, quant_mode) + elif 'mlp.dense_4h_to_h' in name: + if not is_bias(name): + param = split_matrix_tp(param, tp_size, tp_rank, dim=1) + add_tllm_weight(weights, tllm_name, param, quant_mode) + elif 'word_embeddings.' in name: + if not share_embedding_table: + # TODO: safetensor doesn't allow to save a shared tensor. + # Currently, we clone the weight but to save the disk, it + # would be better to skip saving lm_head weights and + # handle it at the loading phase. + lm_head = split_matrix_tp(param, tp_size, tp_rank, dim=0) + weights['lm_head.weight'] = lm_head.clone() + if not use_parallel_embedding: + weights[tllm_name] = param + else: + assert hf_config.vocab_size % tp_size == 0 + weights[tllm_name] = split_matrix_tp(param, + tp_size, + tp_rank, + dim=sharding_dim) + elif contain_any(name, + ('input_layernorm', 'post_attention_layernorm', + 'word_embeddings_layernorm.', 'ln_f.')): + weights[tllm_name] = param + del model_params + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}') + return weights + + +def do_convert_from_ckpt(args): + return (args.model_dir.exists() and args.smoothquant is None + and not args.use_weight_only) + + +def convert(worker_rank, args, convert_args): + convert_from_ckpt = do_convert_from_ckpt(args) + for rank in range(worker_rank, args.world_size, args.workers): + if convert_from_ckpt: + weights = convert_from_hf_checkpoint(rank=rank, **convert_args) + else: + weights = convert_hf_bloom(rank=rank, **convert_args) + safetensors.torch.save_file(weights, + args.output_dir / f'rank{rank}.safetensors') + + +def main(): # TODO(qijun): Currently, the convert script depends on a torch op: # torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix, # which is included in tensorrt_llm Python package. 
Otherwise, the convert @@ -831,17 +1031,12 @@ def convert_hf_bloom(hf_bloom, print(tensorrt_llm.__version__) args = parse_arguments() + logger.set_level(args.log_level) tik = time.time() - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) + args.output_dir.mkdir(exist_ok=True, parents=True) - hf_bloom = BloomForCausalLM.from_pretrained( - args.model_dir, - torch_dtype="auto", - device_map="auto" if not args.use_weight_only else None, - trust_remote_code=True) - hf_config = hf_bloom.config + hf_config = BloomConfig.from_pretrained(args.model_dir) config = { 'architecture': hf_config.architectures[0], 'dtype': args.dtype, @@ -870,6 +1065,24 @@ def convert_hf_bloom(hf_bloom, 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, } + + with (args.output_dir / 'config.json').open('w') as f: + json.dump(config, f, indent=4) + + # TODO: convert_from_hf_checkpoint is memory efficient but has not + # supported quantization yet. Will enable once implemented. + convert_from_ckpt = do_convert_from_ckpt(args) + if not convert_from_ckpt: + logger.info(f'Convert by using model') + hf_bloom = BloomForCausalLM.from_pretrained( + args.model_dir, + torch_dtype="auto", + device_map="auto" if not args.use_weight_only else None, + trust_remote_code=True) + else: + logger.info(f'Convert by using checkpoint') + hf_bloom = None + act_range = {} bloom_qkv_param = {} bloom_smoother = {} @@ -887,36 +1100,47 @@ def convert_hf_bloom(hf_bloom, smooth_bloom_model(hf_bloom, act_range, args.smoothquant, bloom_qkv_param, bloom_smoother) - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - if args.weight_only_precision == 'int8': plugin_weight_only_quant_type = torch.int8 elif args.weight_only_precision == 'int4': plugin_weight_only_quant_type = torch.quint4x2 + else: + plugin_weight_only_quant_type = None + + convert_args = dict( + tensor_parallel=args.world_size, + dtype=args.dtype, + use_weight_only=args.use_weight_only, + plugin_weight_only_quant_type=plugin_weight_only_quant_type, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_smooth_quant=args.smoothquant, + smooth_act_range=act_range, + bloom_qkv_param=bloom_qkv_param, + smoother=bloom_smoother, + per_channel=args.per_channel, + per_token=args.per_token, + int8_kv_cache=args.int8_kv_cache, + ) + if convert_from_ckpt: + convert_args['model_dir'] = args.model_dir + else: + convert_args['hf_bloom'] = hf_bloom - for rank in range(args.world_size): - weights = convert_hf_bloom( - hf_bloom, - rank, - args.world_size, - dtype=args.dtype, - use_weight_only=args.use_weight_only, - plugin_weight_only_quant_type=plugin_weight_only_quant_type, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - use_smooth_quant=args.smoothquant, - smooth_act_range=act_range, - bloom_qkv_param=bloom_qkv_param, - smoother=bloom_smoother, - per_channel=args.per_channel, - per_token=args.per_token, - int8_kv_cache=args.int8_kv_cache) - - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + if args.workers == 1: + convert(0, args, convert_args) + else: + if args.workers > args.world_size: + args.workers = args.world_size + logger.info(f'Convert checkpoint using {args.workers} workers.') + import torch.multiprocessing as mp + 
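+        # Each spawned worker i converts ranks i, i + args.workers, i + 2 * args.workers, ...
+        # (mp.spawn passes the worker index as the first argument of convert).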
mp.spawn(convert, nprocs=args.workers, args=(args, convert_args)) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/examples/quantization/README.md b/examples/quantization/README.md index 2f536a173..0b48fdfe2 100644 --- a/examples/quantization/README.md +++ b/examples/quantization/README.md @@ -21,11 +21,7 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --shm-size=20g -it &1 | awk '{print $2}' | awk -F. '{print $1$2}') -# Download and install the AMMO package from the DevZone. -wget https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz -tar -xzf nvidia_ammo-0.5.0.tar.gz -pip install nvidia_ammo-0.5.0/nvidia_ammo-0.5.0-cp$python_version-cp$python_version-linux_x86_64.whl +pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo~=0.5.0 # Install the additional requirements cd pip install -r requirements.txt diff --git a/requirements-dev.txt b/requirements-dev.txt index 8ee66c697..d74866670 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,5 +12,4 @@ pytest-cov pytest-forked pytest-xdist rouge_score -# tensorrt>=8.6.0 typing-extensions==4.8.0 diff --git a/requirements.txt b/requirements.txt index 7952d5f95..d1e847b48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/cu121 +--extra-index-url https://pypi.nvidia.com accelerate==0.20.3 build colored @@ -9,7 +11,7 @@ numpy onnx>=1.12.0 polygraphy sentencepiece>=0.1.99 -tensorrt>=8.6.0 +tensorrt==9.2.0.post12.dev5 torch transformers==4.33.1 wheel diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index edd811f2b..692913f56 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -196,8 +196,7 @@ def main(build_type: str = "Release", def get_pybind_lib(): pybind_build_dir = (build_dir / "tensorrt_llm" / "pybind") if platform.system() == "Windows": - pybind_lib = list( - (pybind_build_dir / str(build_type)).glob("bindings.*.pyd")) + pybind_lib = list(pybind_build_dir.glob("bindings.*.pyd")) else: pybind_lib = list(pybind_build_dir.glob("bindings.*.so")) diff --git a/tensorrt_llm/models/quantized/quant.py b/tensorrt_llm/models/quantized/quant.py index 787d0c63f..747a24deb 100644 --- a/tensorrt_llm/models/quantized/quant.py +++ b/tensorrt_llm/models/quantized/quant.py @@ -106,8 +106,9 @@ def _smooth_quantize_llama(model, quant_mode): bias=False) assert hasattr(layer, "mlp"), "The layer has no mlp" - assert not model.moe_config.has_moe( - ), "MOE does not support smooth quant" + if hasattr(model, "moe_config"): + assert not model.moe_config.has_moe( + ), "MOE does not support smooth quant" layer.mlp = SmoothQuantGatedMLP(hidden_size=model.hidden_size, ffn_hidden_size=layer.mlp_hidden_size, hidden_act=layer.hidden_act, diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index 60328d501..8e6e29b93 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.7.0" +__version__ = "0.7.1" diff --git a/windows/README.md b/windows/README.md index 89f894967..03bf66f9d 100644 --- a/windows/README.md +++ b/windows/README.md @@ -60,7 +60,7 @@ It may be useful to create a single folder for holding TensorRT-LLM and its depe Clone TensorRT-LLM: ``` -git clone https://github.com/NVIDIA/TensorRT-LLM.git +git clone --branch rel https://github.com/NVIDIA/TensorRT-LLM.git cd TensorRT-LLM git submodule update --init --recursive ```