Skip to content

Commit

Permalink
Merge pull request #14 from bebechien/main
Browse files Browse the repository at this point in the history
gemma 2 quickstart
  • Loading branch information
random-forests authored Jun 27, 2024
2 parents bbfe5b5 + 01ddfd4 commit 7c37343
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 909 deletions.
150 changes: 75 additions & 75 deletions Gemma/Keras_Gemma_2_Quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,13 @@
"id": "PXNm5_p_oxMF"
},
"source": [
"This is a quick demo of Gemma running on KerasNLP. To run this you will need:\n",
"- To be added to a private github repo for Gemma.\n",
"- To be added to a private Kaggle model for weights.\n",
"This is a quick demo of Gemma 2 running on KerasNLP.\n",
"\n",
"Note that you will need a large GPU (e.g. A100) to run this.\n",
"\n",
"General Keras reading:\n",
"- [Getting started with Keras](https://keras.io/getting_started/)\n",
"- [Getting started with KerasNLP](https://keras.io/guides/keras_nlp/getting_started/)\n",
"- [Generation and fine-tuning guide for GPT2](https://keras.io/guides/keras_nlp/getting_started/)\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
Expand Down Expand Up @@ -76,7 +73,9 @@
"from google.colab import userdata\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"tensorflow\" or \"torch\"."
]
},
{
Expand All @@ -90,36 +89,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"id": "bMboT70Xop8G"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m21.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m72.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.2/311.2 kB\u001b[0m \u001b[31m35.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m589.8/589.8 MB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m95.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m76.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m107.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Building wheel for keras-nlp (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.16.1 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
],
"outputs": [],
"source": [
"# Install all deps\n",
"!pip install keras\n",
"!pip install keras-nlp"
"!pip install -U keras-nlp\n",
"!pip install -U keras==3.3.3"
]
},
{
Expand All @@ -137,42 +115,29 @@
"metadata": {
"id": "ww83zI9ToPso"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"tensorflow\" or \"torch\".\n",
"\n",
"import keras_nlp\n",
"import keras\n",
"\n",
"# Run at half precision.\n",
"keras.config.set_floatx(\"bfloat16\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "yygIK9DEIldp"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/metadata.json...\n",
"100%|██████████| 143/143 [00:00<00:00, 179kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/task.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/config.json...\n",
"100%|██████████| 780/780 [00:00<00:00, 895kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/model.weights.h5...\n",
"100%|██████████| 17.2G/17.2G [18:34<00:00, 16.6MB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/preprocessor.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/tokenizer.json...\n",
"100%|██████████| 315/315 [00:00<00:00, 431kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/assets/tokenizer/vocabulary.spm...\n",
"100%|██████████| 4.04M/4.04M [00:01<00:00, 2.41MB/s]\n"
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors.index.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/metadata.json...\n",
"100%|██████████| 143/143 [00:00<00:00, 153kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/task.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/config.json...\n",
"100%|██████████| 780/780 [00:00<00:00, 884kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors.index.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.weights.h5...\n",
"100%|██████████| 17.2G/17.2G [04:22<00:00, 70.5MB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors.index.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/preprocessor.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/tokenizer.json...\n",
"100%|██████████| 315/315 [00:00<00:00, 434kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/assets/tokenizer/vocabulary.spm...\n",
"100%|██████████| 4.04M/4.04M [00:00<00:00, 14.6MB/s]\n"
]
},
{
Expand Down Expand Up @@ -300,34 +265,69 @@
}
],
"source": [
"# Connect using the default `gemma2_9b_keras` or through huggingface weights `hf://google/gemma-2-9b-keras`\n",
"import keras_nlp\n",
"import keras\n",
"\n",
"# Run at half precision.\n",
"keras.config.set_floatx(\"bfloat16\")\n",
"\n",
"# using 9B base model\n",
"gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_9b_en\")\n",
"gemma_lm.summary()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {
"id": "aae5GHrdpj2_"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'What is the meaning of life?\\n\\n[Answer 1]\\n\\nThe meaning of life is to live it.\\n\\n[Answer 2]\\n\\nThe'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"It was a dark and stormy night.\n",
"\n",
"The wind was howling, the rain was pouring, and the thunder was rumbling.\n",
"\n",
"I was sitting in my living room, watching the storm rage outside.\n",
"\n",
"Suddenly, I heard a knock at the door.\n",
"\n",
"I got up and opened it, and there stood a man in a black cloak.\n",
"\n",
"He had a strange look in his eyes, and he was holding a lantern.\n",
"\n",
"\"Who are you?\" I asked.\n",
"\n",
"\"I am the storm,\" he replied.\n",
"\n",
"\"And I have come to take you away.\"\n",
"\n",
"I was terrified, but I couldn't move.\n",
"\n",
"The man in the black cloak grabbed my arm and pulled me out into the storm.\n",
"\n",
"We walked for what seemed like hours, until we came to a clearing in the woods.\n",
"\n",
"There, the man in the black cloak stopped and turned to me.\n",
"\n",
"\"You are mine now,\" he said.\n",
"\n",
"\"And I will take you to my castle.\"\n",
"\n",
"I tried to fight him off, but he was too strong.\n",
"\n",
"He dragged me into the castle, and I was never seen again.\n",
"\n",
"The end.\n"
]
}
],
"source": [
"gemma_lm.generate(\"What is the meaning of life?\", max_length=32)"
"result = gemma_lm.generate(\"It was a dark and stormy night.\", max_length=256)\n",
"print(result)"
]
}
],
Expand Down
Loading

0 comments on commit 7c37343

Please sign in to comment.