From 855fbfe145a4bd9cd16a3292d182e6d20defd703 Mon Sep 17 00:00:00 2001 From: ckhened Date: Tue, 9 Apr 2024 07:50:34 -0700 Subject: [PATCH] RAG end to end perf measurements using Langsmith (#60) Co-authored-by: Antony Vance --- ChatQnA/langchain/test/README.md | 31 ++ ChatQnA/langchain/test/end_to_end_rag_test.py | 248 +++++++++ ChatQnA/langchain/test/tgi_gaudi.ipynb | 496 ++++++++++++++++++ 3 files changed, 775 insertions(+) create mode 100644 ChatQnA/langchain/test/README.md create mode 100644 ChatQnA/langchain/test/end_to_end_rag_test.py create mode 100644 ChatQnA/langchain/test/tgi_gaudi.ipynb diff --git a/ChatQnA/langchain/test/README.md b/ChatQnA/langchain/test/README.md new file mode 100644 index 000000000..ccdb0f79f --- /dev/null +++ b/ChatQnA/langchain/test/README.md @@ -0,0 +1,31 @@ +## Performance measurement tests with langsmith + +Pre-requisite: Signup in langsmith [https://www.langchain.com/langsmith] and get the api token
+ +### Steps to run perf measurements with tgi_gaudi.ipynb jupyter notebook + +1. This dir is mounted at /test in qna-rag-redis-server +2. Make sure redis container and LLM serving is up and running +3. enter into qna-rag-redis-server container and start jupyter notebook server (can specify needed IP address and jupyter will run on port 8888) + ``` + docker exec -it qna-rag-redis-server bash + cd /test + jupyter notebook --allow-root --ip=X.X.X.X + ``` +4. Launch jupyter notebook in your browser and open the tgi_gaudi.ipynb notebook +5. Update all the configuration parameters in the second cell of the notebook +6. Clear all the cells and run all the cells +7. The output of the last cell which calls client.run_on_dataset() will run the langchain Q&A test and captures measurements in the langsmith server. The URL to access the test result can be obtained from the output of the command +

+ +### Steps to run perf measurements with end_to_end_rag_test.py python script + +1. This dir is mounted at /test in qna-rag-redis-server +2. Make sure redis container and LLM serving is up and running +3. enter into qna-rag-redis-server container and run the python script + ``` + docker exec -it qna-rag-redis-server bash + cd /test + python end_to_end_rag_test.py -l "" -e -m -ht "" -lt -dbs "" -dbu "" -dbi "" -d "" + ``` +4. Check the results in langsmith server diff --git a/ChatQnA/langchain/test/end_to_end_rag_test.py b/ChatQnA/langchain/test/end_to_end_rag_test.py new file mode 100644 index 000000000..bfaff3124 --- /dev/null +++ b/ChatQnA/langchain/test/end_to_end_rag_test.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python + +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import uuid +from operator import itemgetter +from typing import Any, List, Mapping, Optional, Sequence + +from langchain.prompts import ChatPromptTemplate +from langchain.schema.document import Document +from langchain.schema.output_parser import StrOutputParser +from langchain.schema.runnable.passthrough import RunnableAssign +from langchain_benchmarks import clone_public_dataset, registry +from langchain_benchmarks.rag import get_eval_config +from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings +from langchain_community.llms import HuggingFaceEndpoint +from langchain_community.vectorstores import Redis +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.prompt_values import ChatPromptValue +from langchain_openai import ChatOpenAI +from langsmith.client import Client +from transformers import AutoTokenizer, LlamaForCausalLM + +# Parameters and settings +ENDPOINT_URL_GAUDI2 = "http://localhost:8000" +ENDPOINT_URL_VLLM = "http://localhost:8001/v1" +TEI_ENDPOINT = "http://localhost:8002" +LANG_CHAIN_DATASET = "" +HF_MODEL_NAME = "" +PROMPT_TOKENS_LEN = 214 # Magic number for prompt template tokens +MAX_INPUT_TOKENS = 1024 +MAX_OUTPUT_TOKENS = 128 + +# Generate a unique run ID for this experiment +run_uid = uuid.uuid4().hex[:6] + +tokenizer = None + + +def crop_tokens(prompt, max_len): + inputs = tokenizer(prompt, return_tensors="pt") + inputs_cropped = inputs["input_ids"][0:, 0:max_len] + prompt_cropped = tokenizer.batch_decode( + inputs_cropped, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return prompt_cropped + + +# After the retriever fetches documents, this +# function formats them in a string to present for the LLM +def format_docs(docs: Sequence[Document]) -> str: + formatted_docs = [] + for i, doc in enumerate(docs): + doc_string = ( + f"\n" + f"{doc.metadata.get('source')}\n" + f"{doc.page_content[0:]}\n" + "" + ) + # Truncate the retrieval data based on the max tokens required + cropped = crop_tokens(doc_string, MAX_INPUT_TOKENS - PROMPT_TOKENS_LEN) + + formatted_docs.append(cropped) # doc_string + formatted_str = "\n".join(formatted_docs) + return f"\n{formatted_str}\n" + + +def ingest_dataset(args, langchain_docs): + clone_public_dataset(langchain_docs.dataset_id, dataset_name=langchain_docs.name) + docs = list(langchain_docs.get_docs()) + embedder = HuggingFaceHubEmbeddings(model=args.embedding_endpoint_url) + + _ = Redis.from_texts( + # appending this little bit can sometimes help with semantic retrieval + # especially with multiple companies + texts=[d.page_content for d in docs], + metadatas=[d.metadata for d in docs], + embedding=embedder, + index_name=args.db_index, + index_schema=args.db_schema, + redis_url=args.db_url, + ) + + +def GetLangchainDataset(args): + registry_retrieved = registry.filter(Type="RetrievalTask") + langchain_docs = registry_retrieved[args.langchain_dataset] + return langchain_docs + + +def buildchain(args): + embedder = HuggingFaceHubEmbeddings(model=args.embedding_endpoint_url) + vectorstore = Redis.from_existing_index( + embedding=embedder, index_name=args.db_index, schema=args.db_schema, redis_url=args.db_url + ) + retriever = vectorstore.as_retriever(search_kwargs={"k": 1}) + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are an AI assistant answering questions about LangChain." + "\n{context}\n" + "Respond solely based on the document content.", + ), + ("human", "{question}"), + ] + ) + + llm = None + match args.llm_service_api: + case "tgi-gaudi": + llm = HuggingFaceEndpoint( + endpoint_url=args.llm_endpoint_url, + max_new_tokens=MAX_OUTPUT_TOKENS, + top_k=10, + top_p=0.95, + typical_p=0.95, + temperature=1.0, + repetition_penalty=1.03, + streaming=False, + truncate=1024, + ) + case "vllm-openai": + llm = ChatOpenAI( + model=args.model_name, + openai_api_key="EMPTY", + openai_api_base=args.llm_endpoint_url, + max_tokens=MAX_OUTPUT_TOKENS, + temperature=1.0, + top_p=0.95, + streaming=False, + frequency_penalty=1.03, + ) + + response_generator = (prompt | llm | StrOutputParser()).with_config( + run_name="GenerateResponse", + ) + + # This is the final response chain. + # It fetches the "question" key from the input dict, + # passes it to the retriever, then formats as a string. + + chain = ( + RunnableAssign( + {"context": (itemgetter("question") | retriever | format_docs).with_config(run_name="FormatDocs")} + ) + # The "RunnableAssign" above returns a dict with keys + # question (from the original input) and + # context: the string-formatted docs. + # This is passed to the response_generator above + | response_generator + ) + return chain + + +def run_test(args, chain): + client = Client() + test_run = client.run_on_dataset( + dataset_name=args.langchain_dataset, + llm_or_chain_factory=chain, + evaluation=None, + project_name=f"{args.llm_service_api}-{args.model_name} op-{MAX_OUTPUT_TOKENS} cl-{args.concurrency} iter-{run_uid}", + project_metadata={ + "index_method": "basic", + }, + concurrency_level=args.concurrency, + verbose=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-l", + "--llm_endpoint_url", + type=str, + required=False, + default=ENDPOINT_URL_GAUDI2, + help="LLM Service Endpoint URL", + ) + parser.add_argument( + "-e", + "--embedding_endpoint_url", + type=str, + default=TEI_ENDPOINT, + required=False, + help="Embedding Service Endpoint URL", + ) + parser.add_argument("-m", "--model_name", type=str, default=HF_MODEL_NAME, required=False, help="Model Name") + parser.add_argument("-ht", "--huggingface_token", type=str, required=True, help="Huggingface API token") + parser.add_argument("-lt", "--langchain_token", type=str, required=True, help="langchain API token") + parser.add_argument( + "-d", + "--langchain_dataset", + type=str, + required=True, + help="langchain dataset name Refer: https://docs.smith.langchain.com/evaluation/quickstart ", + ) + + parser.add_argument("-c", "--concurrency", type=int, default=16, required=False, help="Concurrency Level") + + parser.add_argument( + "-lm", + "--llm_service_api", + type=str, + default="tgi-gaudi", + required=False, + help='Choose between "tgi-gaudi" or "vllm-openai"', + ) + + parser.add_argument( + "-ig", "--ingest_dataset", type=bool, default=False, required=False, help='Set True to ingest dataset"' + ) + + parser.add_argument("-dbu", "--db_url", type=str, required=True, help="Vector DB URL") + + parser.add_argument("-dbs", "--db_schema", type=str, required=True, help="Vector DB Schema") + + parser.add_argument("-dbi", "--db_index", type=str, required=True, help="Vector DB Index Name") + + args = parser.parse_args() + + if args.ingest_dataset: + langchain_doc = GetLangchainDataset(args) + ingest_dataset(args, langchain_doc) + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com" + os.environ["LANGCHAIN_API_KEY"] = args.langchain_token + os.environ["HUGGINGFACEHUB_API_TOKEN"] = args.huggingface_token + + chain = buildchain(args) + run_test(args, chain) diff --git a/ChatQnA/langchain/test/tgi_gaudi.ipynb b/ChatQnA/langchain/test/tgi_gaudi.ipynb new file mode 100644 index 000000000..ecc398196 --- /dev/null +++ b/ChatQnA/langchain/test/tgi_gaudi.ipynb @@ -0,0 +1,496 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7b419db2-6701-499c-abfa-1426f155fff5", + "metadata": {}, + "source": [ + "## Benchmarking RAG pipeline with Redis and LLM using langsmith\n", + "This notebook provides steps to Benchmark RAG pipeline using Langsmith. The RAG pipeline is implemented using Redis as vector database and llama2-70b-chat-hf model as LLM which is served by Huggingface TGI endpoint
\n", + "Langsmith documentation: https://docs.smith.langchain.com/" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e30e3f0f-6200-464e-b429-6b69c44e06b1", + "metadata": {}, + "outputs": [], + "source": [ + "#All imports\n", + "import os\n", + "import uuid\n", + "from operator import itemgetter\n", + "from typing import Sequence\n", + "\n", + "from langchain_benchmarks import clone_public_dataset, registry\n", + "from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings\n", + "from langchain_community.vectorstores import Redis\n", + "from langchain_community.llms import HuggingFaceEndpoint\n", + "from langchain.prompts import ChatPromptTemplate\n", + "from langchain.schema.document import Document\n", + "from langchain.schema.output_parser import StrOutputParser\n", + "from langchain.schema.runnable.passthrough import RunnableAssign\n", + "from transformers import AutoTokenizer, LlamaForCausalLM\n", + "\n", + "from langsmith.client import Client\n", + "from langchain_benchmarks.rag import get_eval_config\n" + ] + }, + { + "cell_type": "markdown", + "id": "c57bae87-2582-419d-8dcf-66c342594ae5", + "metadata": {}, + "source": [ + "### Configuration parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8989b9cf-ff52-4d10-941d-e43beb4678a9", + "metadata": {}, + "outputs": [], + "source": [ + "#Configuration parameters\n", + "\n", + "os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n", + "os.environ[\"LANGCHAIN_API_KEY\"] = \"add-your-langsmith-key\" # Your API key\n", + "\n", + "#Vector DB configuration\n", + "EMBED_MODEL = \"\" #Huggingface sentencetransformer model that you want to use. ex. \"BAAI/bge-base-en-v1.5\"\n", + "REDIS_INDEX_NAME = \"\" #Name of the index to be created in DB\n", + "REDIS_SERVER_URL = \"\" #Specify url of your redis server\n", + "REDIS_INDEX_SCHEMA = \"\" #path to redis schema yml file. Schema to stor data, vectors and desired metadata for every entry\n", + "\n", + "#Endpoints\n", + "TEI_ENDPOINT = \"Add your TEI endpoint\" #Huggingface TEI endpoint url for Embedding model serving. Make sure TEI is serving the same EMBED_MODEL specified above\n", + "TGI_ENDPOINT = \"Add your TGI endpoint\" #Huggingface TGI endpoint url for Embedding model serving\n", + "VLLM_ENDPOINT = \"Add your VLLM endpoint\" #vllm server endpoint (either this or TGI_ENDPOINT should be specified)\n", + "METHOD = \"\" #give \"tgi-gaudi\" to use TGI_ENDPOINT or \"vllm-openai\" to use VLLM_ENDPOINT\n", + "\n", + "#Test parameters\n", + "LANGSMITH_PROJECT_NAME = \"\" #The test result will be displayed in langsmith cloud with this project name and an unique uuid\n", + "CONCURRENCY_LEVEL = 16 #Number of concurrent queries to be sent to RAG chain\n", + "LANGCHAIN_DATASET_NAME = 'LangChain Docs Q&A' #Specify the Langchain dataset name (if using a dataset from langchain)\n", + "\n", + "#LLM parameters\n", + "LLM_MODEL_NAME = \"meta-llama/Llama-2-70b-chat-hf\"\n", + "MAX_OUTPUT_TOKENS = 128\n", + "PROMPT_TOKENS_LEN=214 # Magic number for prompt template tokens. This changes if prompt changes\n", + "MAX_INPUT_TOKENS=1024 #Use this and PROMPT_TOKENS_LEN if there is a need to limit input tokens." + ] + }, + { + "cell_type": "markdown", + "id": "e4e430ea-83e1-417f-a6fc-57a3ffdcac85", + "metadata": {}, + "source": [ + "### Selecting dataset\n", + "Below section covers selecting and using LangChain Docs Q&A dataset\n", + "This can be modified to use any dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "667c6870-10b7-492b-bc09-fddf0b1e3d76", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Name Type Dataset ID Description
LangChain Docs Q&A RetrievalTask452ccafc-18e1-4314-885b-edd735f17b9dQuestions and answers based on a snapshot of the LangChain python docs.\n", + "\n", + "The environment provides the documents and the retriever information.\n", + "\n", + "Each example is composed of a question and reference answer.\n", + "\n", + "Success is measured based on the accuracy of the answer relative to the reference answer.\n", + "We also measure the faithfulness of the model's response relative to the retrieved documents (if any).
Semi-structured ReportsRetrievalTaskc47d9617-ab99-4d6e-a6e6-92b8daf85a7dQuestions and answers based on PDFs containing tables and charts.\n", + "\n", + "The task provides the raw documents as well as factory methods to easily index them\n", + "and create a retriever.\n", + "\n", + "Each example is composed of a question and reference answer.\n", + "\n", + "Success is measured based on the accuracy of the answer relative to the reference answer.\n", + "We also measure the faithfulness of the model's response relative to the retrieved documents (if any).
Multi-modal slide decksRetrievalTask40afc8e7-9d7e-44ed-8971-2cae1eb59731This public dataset is a work-in-progress and will be extended over time.\n", + " \n", + "Questions and answers based on slide decks containing visual tables and charts.\n", + "\n", + "Each example is composed of a question and reference answer.\n", + "\n", + "Success is measured based on the accuracy of the answer relative to the reference answer.
" + ], + "text/plain": [ + "Registry(tasks=[RetrievalTask(name='LangChain Docs Q&A', dataset_id='https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d', description=\"Questions and answers based on a snapshot of the LangChain python docs.\\n\\nThe environment provides the documents and the retriever information.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=, retriever_factories={'basic': , 'parent-doc': , 'hyde': }, architecture_factories={'conversational-retrieval-qa': }), RetrievalTask(name='Semi-structured Reports', dataset_id='https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d', description=\"Questions and answers based on PDFs containing tables and charts.\\n\\nThe task provides the raw documents as well as factory methods to easily index them\\nand create a retriever.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=, retriever_factories={'basic': , 'parent-doc': , 'hyde': }, architecture_factories={}), RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#Langchain supported datasets for Retrieval task\n", + "registry = registry.filter(Type=\"RetrievalTask\")\n", + "registry" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "81c8d507-40d5-4f56-9b68-6f579aa6cce7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Name LangChain Docs Q&A
Type RetrievalTask
Dataset ID 452ccafc-18e1-4314-885b-edd735f17b9d
Description Questions and answers based on a snapshot of the LangChain python docs.\n", + "\n", + "The environment provides the documents and the retriever information.\n", + "\n", + "Each example is composed of a question and reference answer.\n", + "\n", + "Success is measured based on the accuracy of the answer relative to the reference answer.\n", + "We also measure the faithfulness of the model's response relative to the retrieved documents (if any).
Retriever Factories basic, parent-doc, hyde
Architecture Factoriesconversational-retrieval-qa
get_docs
" + ], + "text/plain": [ + "RetrievalTask(name='LangChain Docs Q&A', dataset_id='https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d', description=\"Questions and answers based on a snapshot of the LangChain python docs.\\n\\nThe environment provides the documents and the retriever information.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=, retriever_factories={'basic': , 'parent-doc': , 'hyde': }, architecture_factories={'conversational-retrieval-qa': })" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#Lets use LangChain Docs Q&A dataset for our benchmark\n", + "langchain_docs = registry[LANGCHAIN_DATASET_NAME]\n", + "langchain_docs" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bece8e4b-2fd7-4483-abce-2f37ebf858a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset LangChain Docs Q&A already exists. Skipping.\n", + "You can access the dataset at https://smith.langchain.com/o/9534e90b-1d2b-55ed-bf79-31dc5ff16722/datasets/3ce3b4a1-0640-4fbf-925e-2c03caceb5ac.\n" + ] + } + ], + "source": [ + "#Download the dataset locally\n", + "clone_public_dataset(langchain_docs.dataset_id, dataset_name=langchain_docs.name)" + ] + }, + { + "cell_type": "markdown", + "id": "d816db14-8175-4c9e-9f99-0b877553fdc9", + "metadata": {}, + "source": [ + "### Ingesting data into Redis vector DB\n", + "This section needs to be run only when the Redis server doesn't already contain the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80581246-0fcd-4da3-9d45-022b871a787a", + "metadata": {}, + "outputs": [], + "source": [ + "#Embedding model for ingestion \n", + "embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25ada515", + "metadata": {}, + "outputs": [], + "source": [ + "#Ingest the dataset into vector DB\n", + "_ = Redis.from_texts(\n", + " # appending this little bit can sometimes help with semantic retrieval\n", + " # especially with multiple companies\n", + " texts=[d.page_content for d in docs],\n", + " metadatas=[d.metadata for d in docs],\n", + " embedding=embedder,\n", + " index_name=REDIS_INDEX_NAME,\n", + " index_schema=REDIS_INDEX_SCHEMA,\n", + " redis_url=REDIS_SERVER_URL,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a4c2f599-da93-4700-aec3-4db2d43d1ef8", + "metadata": {}, + "source": [ + "### RAG pipeline\n", + "Initialize each component of RAG pipeline and setup the chain" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5756e0db", + "metadata": {}, + "outputs": [], + "source": [ + "#enable TEI endpoint to get high throughput high throughput queries\n", + "embedder = HuggingFaceHubEmbeddings(model=TEI_ENDPOINT)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7f906c62-2e59-481f-9159-b2dca087d802", + "metadata": {}, + "outputs": [], + "source": [ + "#Initialize retriever to be added in langchain RAG chain.\n", + "vectorstore = Redis.from_existing_index(\n", + " embedding=embedder, index_name=REDIS_INDEX_NAME, schema=REDIS_INDEX_SCHEMA, redis_url=REDIS_SERVER_URL\n", + ")\n", + "retriever = vectorstore.as_retriever()" + ] + }, + { + "cell_type": "markdown", + "id": "cdf91d5b-e7e9-41d9-830b-465d87ffc5b0", + "metadata": {}, + "source": [ + "**Note:** Prompt is specific to dataset. Modify the prompt accordingly based on the dataset selected.
\n", + "The below prompt is for Langchain Docs Q&A dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e9e77e72-fa65-48a6-aae9-1bec9875b124", + "metadata": {}, + "outputs": [], + "source": [ + "#Setup prompt\n", + "\n", + "#helper function to crop input tokens from retrieved doc from vector DB\n", + "#This can be used in format_docs function if there is a need to make sure\n", + "#number of input tokens doesn't exceed certain limit\n", + "tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)\n", + "def crop_tokens(prompt, max_len):\n", + " inputs = tokenizer(prompt, return_tensors=\"pt\")\n", + " inputs_cropped = inputs['input_ids'][0:,0:max_len]\n", + " prompt_cropped=tokenizer.batch_decode(inputs_cropped, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n", + " return prompt_cropped\n", + "\n", + "# After the retriever fetches documents, this\n", + "# function formats them in a string to present for the LLM\n", + "def format_docs(docs: Sequence[Document]) -> str:\n", + " formatted_docs = []\n", + " for i, doc in enumerate(docs):\n", + " doc_string = (\n", + " f\"\\n\"\n", + " f\"{doc.metadata.get('source')}\\n\"\n", + " f\"{doc.page_content}\\n\"\n", + " \"\"\n", + " )\n", + " # Truncate the retrieval data based on the max tokens required\n", + " cropped= crop_tokens(doc_string,MAX_INPUT_TOKENS-PROMPT_TOKENS_LEN) #remove this if there is not need of limiting INPUT tokens to LLM\n", + " formatted_docs.append(doc_string)\n", + " formatted_str = \"\\n\".join(formatted_docs)\n", + " return f\"\\n{formatted_str}\\n\"\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are an AI assistant answering questions about LangChain.\"\n", + " \"\\n{context}\\n\"\n", + " \"Respond solely based on the document content.\",\n", + " ),\n", + " (\"human\", \"{question}\"),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "eee03d33-7e7b-41b8-8717-b031a49c1a36", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n", + "Token is valid (permission: read).\n", + "Your token has been saved to /root/.cache/huggingface/token\n", + "Login successful\n" + ] + } + ], + "source": [ + "#Setup LLM \n", + "\n", + "llm = None\n", + "match METHOD:\n", + " case \"tgi-gaudi\":\n", + " llm = HuggingFaceEndpoint(\n", + " endpoint_url=TGI_ENDPOINT,\n", + " max_new_tokens=MAX_OUTPUT_TOKENS,\n", + " top_k=10,\n", + " top_p=0.95,\n", + " typical_p=0.95,\n", + " temperature=1.0,\n", + " repetition_penalty=1.03,\n", + " streaming=False,\n", + " truncate=1024\n", + " )\n", + " case \"vllm-openai\":\n", + " llm = ChatOpenAI(\n", + " model=LLM_MODEL_NAME,\n", + " openai_api_key=\"EMPTY\", \n", + " openai_api_base=VLLM_ENDPOINT,\n", + " max_tokens=MAX_OUTPUT_TOKENS,\n", + " temperature=1.0,\n", + " top_p=0.95,\n", + " streaming=False,\n", + " frequency_penalty=1.03\n", + " )\n", + "\n", + "response_generator = (prompt | llm | StrOutputParser()).with_config(\n", + " run_name=\"GenerateResponse\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7164b12b-11e9-4cec-946a-e1902be507da", + "metadata": {}, + "outputs": [], + "source": [ + "# This is the final response chain.\n", + "# It fetches the \"question\" key from the input dict,\n", + "# passes it to the retriever, then formats as a string.\n", + "\n", + "chain = (\n", + " RunnableAssign(\n", + " {\n", + " \"context\": (itemgetter(\"question\") | retriever | format_docs).with_config(\n", + " run_name=\"FormatDocs\"\n", + " )\n", + " }\n", + " )\n", + " # The \"RunnableAssign\" above returns a dict with keys\n", + " # question (from the original input) and\n", + " # context: the string-formatted docs.\n", + " # This is passed to the response_generator above\n", + " | response_generator\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9f92a667-9904-437f-88d7-45caf56afbb9", + "metadata": {}, + "source": [ + "### Setup and run Langsmith benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b39849d7-4b18-4dc5-a97c-f50f528bc980", + "metadata": {}, + "outputs": [], + "source": [ + "#Initialize Langchain client\n", + "client = Client()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9aabc054-7371-4daf-b2a9-3a8ed891e03b", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate a unique run ID for this experiment\n", + "run_uid = uuid.uuid4().hex[:6]\n", + "\n", + "#Run the test\n", + "test_run = client.run_on_dataset(\n", + " dataset_name=LANGCHAIN_DATASET_NAME,\n", + " llm_or_chain_factory=chain,\n", + " evaluation=None,\n", + " project_name=LANGSMITH_PROJECT_NAME+'_'+run_uid,\n", + " project_metadata={\n", + " \"index_method\": \"basic\",\n", + " },\n", + " concurrency_level=CONCURRENCY_LEVEL,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8198756-0797-49ad-8072-1a1d61536689", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}