From f61c008c5d7c7e3af178cbf02af91baf08ae3cfc Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 14 Nov 2024 19:12:53 +0900 Subject: [PATCH] Add missing device transfer in gpt_generate.py (#436) --- ch05/01_main-chapter-code/gpt_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ch05/01_main-chapter-code/gpt_generate.py b/ch05/01_main-chapter-code/gpt_generate.py index 0a5b8141..92c00102 100644 --- a/ch05/01_main-chapter-code/gpt_generate.py +++ b/ch05/01_main-chapter-code/gpt_generate.py @@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size): token_ids = generate( model=gpt, - idx=text_to_token_ids(input_prompt, tokenizer), + idx=text_to_token_ids(input_prompt, tokenizer).to(device), max_new_tokens=25, context_size=gpt_config["context_length"], top_k=50,