diff --git a/README.md b/README.md
index 6b988ad..41f4f13 100644
--- a/README.md
+++ b/README.md
@@ -112,13 +112,6 @@ optional arguments:
 Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C to clear the conversation, `:r` or Ctrl-R to re-generate the last response.
 To enter multi-line mode, enter a backslash `\` followed by a new line. Exit the multi-line mode by pressing ESC and then Enter.
 
-You can override the model parameters using `--model`, `--temperature` and `--top_p` arguments at the end of your prompt. For example:
-
-```
-> What is the meaning of life? --model gpt-4 --temperature 2.0
-The meaning of life is subjective and can be different for diverse human beings and unique-phil ethics.org/cultuties-/ it that reson/bdstals89im3_jrf334;mvs-bread99ef=g22me
-```
-
 The `dev` assistant is instructed to be an expert in software development and provide short responses.
 
 ```bash
@@ -197,7 +190,6 @@ assistants:
     - { role: system, content: !include "pirate.txt" }
 ```
 
-
 ### Customize OpenAI API URL
 
 If you are using other models compatible with the OpenAI Python SDK, you can configure them by modifying the `openai_base_url` setting in the config file or using the `OPENAI_BASE_URL` environment variable.
diff --git a/gptcli/assistant.py b/gptcli/assistant.py
index d6932a0..fdbe6ea 100644
--- a/gptcli/assistant.py
+++ b/gptcli/assistant.py
@@ -7,7 +7,6 @@
 from gptcli.completion import (
     CompletionEvent,
     CompletionProvider,
-    ModelOverrides,
     Message,
 )
 from gptcli.providers.google import GoogleCompletionProvider
@@ -107,28 +106,20 @@ def from_config(cls, name: str, config: AssistantConfig):
     def init_messages(self) -> List[Message]:
         return self.config.get("messages", [])[:]
 
-    def supported_overrides(self) -> List[str]:
-        return ["model", "temperature", "top_p"]
-
-    def _param(self, param: str, override_params: ModelOverrides) -> Any:
-        # If the param is in the override_params, use that value
-        # Otherwise, use the value from the config
+    def _param(self, param: str) -> Any:
+        # Use the value from the config if it exists
         # Otherwise, use the default value
-        return override_params.get(
-            param, self.config.get(param, CONFIG_DEFAULTS[param])
-        )
+        return self.config.get(param, CONFIG_DEFAULTS[param])
 
-    def complete_chat(
-        self, messages, override_params: ModelOverrides = {}, stream: bool = True
-    ) -> Iterator[CompletionEvent]:
-        model = self._param("model", override_params)
+    def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]:
+        model = self._param("model")
         completion_provider = get_completion_provider(model)
 
         return completion_provider.complete(
             messages,
             {
                 "model": model,
-                "temperature": float(self._param("temperature", override_params)),
-                "top_p": float(self._param("top_p", override_params)),
+                "temperature": float(self._param("temperature")),
+                "top_p": float(self._param("top_p")),
             },
             stream,
         )
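After this change, sampling parameters resolve from the assistant config alone, with `CONFIG_DEFAULTS` as the fallback, and `complete_chat` no longer accepts per-prompt overrides. A minimal sketch of the new call path (the config values and the prompt are illustrative, not taken from the repository):

```python
# Minimal sketch of the simplified call path; the "dev" config values
# below are hypothetical, not from the repo.
from gptcli.assistant import Assistant

assistant = Assistant.from_config("dev", {"model": "gpt-4", "temperature": 0.7})

# complete_chat() no longer takes override_params: model, temperature,
# and top_p come from the assistant config, falling back to CONFIG_DEFAULTS.
for event in assistant.complete_chat(
    [{"role": "user", "content": "What is the meaning of life?"}],
    stream=True,
):
    print(event)
```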
diff --git a/gptcli/cli.py b/gptcli/cli.py
index 65e1040..62f10e5 100644
--- a/gptcli/cli.py
+++ b/gptcli/cli.py
@@ -1,5 +1,4 @@
-import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Optional
 
 from openai import BadRequestError, OpenAIError
 from prompt_toolkit import PromptSession
@@ -11,9 +10,16 @@
 from rich.markdown import Markdown
 from rich.text import Text
 
-from gptcli.session import (ALL_COMMANDS, COMMAND_CLEAR, COMMAND_QUIT,
-                            COMMAND_RERUN, ChatListener, InvalidArgumentError,
-                            ResponseStreamer, UserInputProvider)
+from gptcli.session import (
+    ALL_COMMANDS,
+    COMMAND_CLEAR,
+    COMMAND_QUIT,
+    COMMAND_RERUN,
+    ChatListener,
+    InvalidArgumentError,
+    ResponseStreamer,
+    UserInputProvider,
+)
 
 TERMINAL_WELCOME = """
 Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear
@@ -113,43 +119,6 @@ def response_streamer(self) -> ResponseStreamer:
         return CLIResponseStreamer(self.console, self.markdown)
 
 
-def parse_args(input: str) -> Tuple[str, Dict[str, Any]]:
-    # Extract parts enclosed in specific delimiters (triple backticks, triple quotes, single backticks)
-    extracted_parts = []
-    delimiters = ['```', '"""', '`']
-
-    def replacer(match):
-        for i, delimiter in enumerate(delimiters):
-            part = match.group(i + 1)
-            if part is not None:
-                extracted_parts.append((part, delimiter))
-                break
-        return f"__EXTRACTED_PART_{len(extracted_parts) - 1}__"
-
-    # Construct the regex pattern dynamically from the delimiters list
-    pattern_fragments = [re.escape(d) + '(.*?)' + re.escape(d) for d in delimiters]
-    pattern = re.compile('|'.join(pattern_fragments), re.DOTALL)
-
-    input = pattern.sub(replacer, input)
-
-    # Parse the remaining string for arguments
-    args = {}
-    regex = r'--(\w+)(?:=(\S+)|\s+(\S+))?'
-    matches = re.findall(regex, input)
-
-    if matches:
-        for key, value1, value2 in matches:
-            value = value1 if value1 else value2 if value2 else ''
-            args[key] = value.strip("\"'")
-        input = re.sub(regex, "", input).strip()
-
-    # Add back the extracted parts, with enclosing backticks or quotes
-    for i, (part, delimiter) in enumerate(extracted_parts):
-        input = input.replace(f"__EXTRACTED_PART_{i}__", f"{delimiter}{part.strip()}{delimiter}")
-
-    return input, args
-
-
 class CLIFileHistory(FileHistory):
     def append_string(self, string: str) -> None:
         if string in ALL_COMMANDS:
@@ -163,12 +132,11 @@ def __init__(self, history_filename) -> None:
             history=CLIFileHistory(history_filename)
         )
 
-    def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
+    def get_user_input(self) -> str:
         while (next_user_input := self._request_input()) == "":
             pass
 
-        user_input, args = self._parse_input(next_user_input)
-        return user_input, args
+        return next_user_input
 
     def prompt(self, multiline=False):
         bindings = KeyBindings()
@@ -219,7 +187,3 @@ def _request_input(self):
             return line
 
         return self.prompt(multiline=True)
-
-    def _parse_input(self, input: str) -> Tuple[str, Dict[str, Any]]:
-        input, args = parse_args(input)
-        return input, args
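With `parse_args` and `_parse_input` gone, the CLI no longer strips `--key value` text out of prompts; the raw string reaches the model verbatim. A hedged before/after sketch (the provider class name is assumed from the `__init__`/`get_user_input` hunks above, and the history path is made up):

```python
# Hypothetical illustration; CLIUserInputProvider is assumed to be the
# class whose hunks appear above, and the history path is invented.
from gptcli.cli import CLIUserInputProvider

provider = CLIUserInputProvider(history_filename="/tmp/gptcli_history")

# Before: parse_args("What is life? --temperature 2.0")
#         -> ("What is life?", {"temperature": "2.0"})
# After: the same prompt comes back as one raw string, flags included,
# and is forwarded to the model as-is.
user_input = provider.get_user_input()
```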
diff --git a/gptcli/completion.py b/gptcli/completion.py
index a0348f9..e2c3aa0 100644
--- a/gptcli/completion.py
+++ b/gptcli/completion.py
@@ -9,12 +9,6 @@ class Message(TypedDict):
     content: str
 
 
-class ModelOverrides(TypedDict, total=False):
-    model: str
-    temperature: float
-    top_p: float
-
-
 class Pricing(TypedDict):
     prompt: float
     response: float
diff --git a/gptcli/composite.py b/gptcli/composite.py
index 6f09e94..fde031f 100644
--- a/gptcli/composite.py
+++ b/gptcli/composite.py
@@ -1,4 +1,4 @@
-from gptcli.completion import Message, ModelOverrides, UsageEvent
+from gptcli.completion import Message, UsageEvent
 from gptcli.session import ChatListener, ResponseStreamer
 
 
@@ -56,8 +56,7 @@ def on_chat_response(
         self,
         messages: List[Message],
         response: Message,
-        overrides: ModelOverrides,
         usage: Optional[UsageEvent],
     ):
         for listener in self.listeners:
-            listener.on_chat_response(messages, response, overrides, usage)
+            listener.on_chat_response(messages, response, usage)
diff --git a/gptcli/cost.py b/gptcli/cost.py
index d8f1142..7b527c1 100644
--- a/gptcli/cost.py
+++ b/gptcli/cost.py
@@ -1,5 +1,5 @@
 from gptcli.assistant import Assistant
-from gptcli.completion import Message, ModelOverrides, UsageEvent
+from gptcli.completion import Message, UsageEvent
 from gptcli.session import ChatListener
 
 from rich.console import Console
@@ -22,13 +22,12 @@
     def on_chat_response(
         self,
         messages: List[Message],
         response: Message,
-        args: ModelOverrides,
         usage: Optional[UsageEvent] = None,
     ):
         if usage is None:
             return
 
-        model = self.assistant._param("model", args)
+        model = self.assistant._param("model")
         num_tokens = usage.total_tokens
         cost = usage.cost
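With `ModelOverrides` removed from `gptcli.completion`, every listener now implements the narrower three-argument `on_chat_response`. A minimal conforming listener, using only names visible in this diff (the token logging itself is illustrative):

```python
from typing import List, Optional

from gptcli.completion import Message, UsageEvent
from gptcli.session import ChatListener


class UsageLoggingListener(ChatListener):
    """Illustrative listener matching the new three-argument callback."""

    def on_chat_response(
        self,
        messages: List[Message],
        response: Message,
        usage: Optional[UsageEvent] = None,
    ):
        # usage.total_tokens is the same field the cost listener reads.
        if usage is not None:
            print(f"tokens used: {usage.total_tokens}")
```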
diff --git a/gptcli/session.py b/gptcli/session.py
index e107c55..c0364be 100644
--- a/gptcli/session.py
+++ b/gptcli/session.py
@@ -1,14 +1,12 @@
 from abc import abstractmethod
-from typing_extensions import TypeGuard
 from gptcli.assistant import Assistant
 from gptcli.completion import (
     Message,
-    ModelOverrides,
     CompletionError,
     BadRequestError,
     UsageEvent,
 )
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional
 
 
 class ResponseStreamer:
@@ -45,7 +43,6 @@ def on_chat_response(
         self,
         messages: List[Message],
         response: Message,
-        overrides: ModelOverrides,
         usage: Optional[UsageEvent] = None,
     ):
         pass
@@ -53,7 +50,7 @@
 
 class UserInputProvider:
     @abstractmethod
-    def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
+    def get_user_input(self) -> str:
         pass
 
 
@@ -85,7 +82,7 @@ def __init__(
     ):
         self.assistant = assistant
         self.messages: List[Message] = assistant.init_messages()
-        self.user_prompts: List[Tuple[Message, ModelOverrides]] = []
+        self.user_prompts: List[Message] = []
         self.listener = listener
         self.stream = stream
 
@@ -103,10 +100,9 @@ def _rerun(self):
         self.messages = self.messages[:-1]
         self.listener.on_chat_rerun(True)
 
-        _, args = self.user_prompts[-1]
-        self._respond(args)
+        self._respond()
 
-    def _respond(self, overrides: ModelOverrides) -> bool:
+    def _respond(self) -> bool:
         """
         Respond to the user's input and return whether the assistant's response was saved.
         """
@@ -114,7 +110,7 @@ def _respond(self, overrides: ModelOverrides) -> bool:
         usage: Optional[UsageEvent] = None
         try:
             completion_iter = self.assistant.complete_chat(
-                self.messages, override_params=overrides, stream=self.stream
+                self.messages, stream=self.stream
             )
 
             with self.listener.response_streamer() as stream:
@@ -137,28 +133,16 @@
 
         next_message: Message = {"role": "assistant", "content": next_response}
         self.listener.on_chat_message(next_message)
-        self.listener.on_chat_response(self.messages, next_message, overrides, usage)
+        self.listener.on_chat_response(self.messages, next_message, usage)
 
         self.messages = self.messages + [next_message]
         return True
 
-    def _validate_args(self, args: Dict[str, Any]) -> TypeGuard[ModelOverrides]:
-        for key in args:
-            supported_overrides = self.assistant.supported_overrides()
-            if key not in supported_overrides:
-                self.listener.on_error(
-                    InvalidArgumentError(
-                        f"Invalid argument: {key}. Allowed arguments: {supported_overrides}"
-                    )
-                )
-                return False
-        return True
-
-    def _add_user_message(self, user_input: str, args: ModelOverrides):
+    def _add_user_message(self, user_input: str):
         user_message: Message = {"role": "user", "content": user_input}
         self.messages = self.messages + [user_message]
         self.listener.on_chat_message(user_message)
-        self.user_prompts.append((user_message, args))
+        self.user_prompts.append(user_message)
 
     def _rollback_user_message(self):
         self.messages = self.messages[:-1]
@@ -168,13 +152,10 @@ def _print_help(self):
         with self.listener.response_streamer() as stream:
             stream.on_next_token(COMMANDS_HELP)
 
-    def process_input(self, user_input: str, args: Dict[str, Any]):
+    def process_input(self, user_input: str):
         """
         Process the user's input and return whether the session should continue.
         """
-        if not self._validate_args(args):
-            return True
-
         if user_input in COMMAND_QUIT:
             return False
         elif user_input in COMMAND_CLEAR:
@@ -187,8 +168,8 @@
             self._print_help()
             return True
 
-        self._add_user_message(user_input, args)
-        response_saved = self._respond(args)
+        self._add_user_message(user_input)
+        response_saved = self._respond()
 
         if not response_saved:
             self._rollback_user_message()
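With `_validate_args` gone, `get_user_input` yields a plain string and `process_input` takes only that string, still returning whether the session should continue. A rough sketch of the resulting driver loop (the `ChatSession` class name and constructor wiring are assumptions; only the method signatures come from the hunks above):

```python
# Hypothetical driver loop over the simplified session API. ChatSession
# and its constructor arguments are assumed; get_user_input() -> str and
# process_input(str) -> bool are the shapes introduced by this diff.
from gptcli.assistant import Assistant
from gptcli.session import ChatListener, ChatSession, UserInputProvider


def run_repl(
    assistant: Assistant,
    listener: ChatListener,
    input_provider: UserInputProvider,
) -> None:
    session = ChatSession(assistant, listener, stream=True)
    while True:
        user_input = input_provider.get_user_input()
        if not session.process_input(user_input):  # False once the user enters :q
            break
```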
diff --git a/tests/test_session.py b/tests/test_session.py
index a72b98d..871cf01 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -33,14 +33,15 @@ def test_simple_input():
     assistant_mock.complete_chat.return_value = [MessageDeltaEvent(expected_response)]
 
     user_input = "user message"
-    should_continue = session.process_input(user_input, {})
+    should_continue = session.process_input(user_input)
     assert should_continue
 
     user_message = {"role": "user", "content": user_input}
     assistant_message = {"role": "assistant", "content": expected_response}
     assistant_mock.complete_chat.assert_called_once_with(
-        [system_message, user_message], override_params={}, stream=True,
+        [system_message, user_message],
+        stream=True,
     )
     listener_mock.on_chat_message.assert_has_calls(
         [mock.call(user_message), mock.call(assistant_message)]
     )
@@ -49,7 +50,7 @@
 def test_quit():
     _, _, session = setup_session()
 
-    should_continue = session.process_input(":q", {})
+    should_continue = session.process_input(":q")
 
     assert not should_continue
 
@@ -61,12 +62,12 @@ def test_clear():
 
     assistant_mock.complete_chat.return_value = [MessageDeltaEvent("assistant_message")]
 
-    should_continue = session.process_input("user_message", {})
+    should_continue = session.process_input("user_message")
     assert should_continue
 
     assistant_mock.complete_chat.assert_called_once_with(
         [system_message, {"role": "user", "content": "user_message"}],
-        override_params={}, stream=True,
+        stream=True,
     )
     listener_mock.on_chat_message.assert_has_calls(
         [
@@ -77,7 +78,7 @@ def test_clear():
     assistant_mock.complete_chat.reset_mock()
     listener_mock.on_chat_message.reset_mock()
 
-    should_continue = session.process_input(":c", {})
+    should_continue = session.process_input(":c")
     assert should_continue
 
     assistant_mock.init_messages.assert_called_once()
@@ -88,12 +89,12 @@ def test_clear():
         MessageDeltaEvent("assistant_message_1")
     ]
 
-    should_continue = session.process_input("user_message_1", {})
+    should_continue = session.process_input("user_message_1")
     assert should_continue
 
     assistant_mock.complete_chat.assert_called_once_with(
         [system_message, {"role": "user", "content": "user_message_1"}],
-        override_params={}, stream=True,
+        stream=True,
     )
     listener_mock.on_chat_message.assert_has_calls(
         [
@@ -110,7 +111,7 @@ def test_rerun():
     assistant_mock.init_messages.reset_mock()
 
     # Re-run before any input shouldn't do anything
-    should_continue = session.process_input(":r", {})
+    should_continue = session.process_input(":r")
     assert should_continue
 
     assistant_mock.init_messages.assert_not_called()
@@ -123,12 +124,12 @@ def test_rerun():
     # Now proper re-run
    assistant_mock.complete_chat.return_value = [MessageDeltaEvent("assistant_message")]
 
-    should_continue = session.process_input("user_message", {})
+    should_continue = session.process_input("user_message")
     assert should_continue
 
     assistant_mock.complete_chat.assert_called_once_with(
         [system_message, {"role": "user", "content": "user_message"}],
-        override_params={}, stream=True,
+        stream=True,
     )
     listener_mock.on_chat_message.assert_has_calls(
         [
@@ -143,14 +144,14 @@ def test_rerun():
         MessageDeltaEvent("assistant_message_1")
     ]
 
-    should_continue = session.process_input(":r", {})
+    should_continue = session.process_input(":r")
     assert should_continue
 
     listener_mock.on_chat_rerun.assert_called_once_with(True)
 
     assistant_mock.complete_chat.assert_called_once_with(
         [system_message, {"role": "user", "content": "user_message"}],
-        override_params={}, stream=True,
+        stream=True,
     )
     listener_mock.on_chat_message.assert_has_calls(
         [
@@ -159,43 +160,6 @@
 
 
-def test_args():
-    assistant_mock, listener_mock, session = setup_session()
-
-    assistant_mock.supported_overrides.return_value = ["arg1"]
-
-    expected_response = "assistant message"
-    assistant_mock.complete_chat.return_value = [MessageDeltaEvent(expected_response)]
-
-    user_input = "user message"
-    should_continue = session.process_input(user_input, {"arg1": "value1"})
-    assert should_continue
-
-    user_message = {"role": "user", "content": user_input}
-    assistant_message = {"role": "assistant", "content": expected_response}
-
-    assistant_mock.complete_chat.assert_called_once_with(
-        [system_message, user_message], override_params={"arg1": "value1"}, stream=True,
-    )
-    listener_mock.on_chat_message.assert_has_calls(
-        [mock.call(user_message), mock.call(assistant_message)]
-    )
-
-    # Now test that rerun reruns with the same args
-    assistant_mock.complete_chat.reset_mock()
-    listener_mock.on_chat_message.reset_mock()
-
-    assistant_mock.complete_chat.return_value = [MessageDeltaEvent(expected_response)]
-
-    should_continue = session.process_input(":r", {})
-    assert should_continue
-
-    assistant_mock.complete_chat.assert_called_once_with(
-        [system_message, user_message], override_params={"arg1": "value1"}, stream=True,
-    )
-    listener_mock.on_chat_message.assert_has_calls([mock.call(assistant_message)])
-
-
 def test_invalid_request_error():
     assistant_mock, listener_mock, session = setup_session()
 
@@ -203,7 +167,7 @@
     assistant_mock.complete_chat.side_effect = error
 
     user_input = "user message"
-    should_continue = session.process_input(user_input, {})
+    should_continue = session.process_input(user_input)
     assert should_continue
 
     user_message = {"role": "user", "content": user_input}
@@ -215,7 +179,7 @@
     listener_mock.on_chat_message.reset_mock()
     listener_mock.on_error.reset_mock()
 
-    should_continue = session.process_input(":r", {})
+    should_continue = session.process_input(":r")
     assert should_continue
 
     assistant_mock.complete_chat.assert_not_called()
@@ -231,7 +195,7 @@
     assistant_mock.complete_chat.side_effect = error
 
     user_input = "user message"
-    should_continue = session.process_input(user_input, {})
+    should_continue = session.process_input(user_input)
     assert should_continue
 
     user_message = {"role": "user", "content": user_input}
@@ -246,11 +210,12 @@
     assistant_mock.complete_chat.side_effect = None
     assistant_mock.complete_chat.return_value = [MessageDeltaEvent("assistant message")]
 
-    should_continue = session.process_input(":r", {})
+    should_continue = session.process_input(":r")
     assert should_continue
 
     assistant_mock.complete_chat.assert_called_once_with(
-        [system_message, user_message], override_params={}, stream=True,
+        [system_message, user_message],
+        stream=True,
    )
     listener_mock.on_chat_message.assert_has_calls(
         [
@@ -268,7 +233,7 @@ def test_stream():
     response_streamer_mock = listener_mock.response_streamer.return_value
 
-    session.process_input("user message", {})
+    session.process_input("user message")
 
     response_streamer_mock.assert_has_calls(
         [mock.call.on_next_token(token) for token in assistant_message]
     )
diff --git a/tests/test_term_utils.py b/tests/test_term_utils.py
deleted file mode 100644
index fee724e..0000000
--- a/tests/test_term_utils.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from gptcli.cli import parse_args
-
-
-def test_parse_args():
-    assert parse_args("foo") == ("foo", {})
-    assert parse_args("foo bar") == ("foo bar", {})
-    assert parse_args("this is a prompt --bar 1.0") == (
-        "this is a prompt",
-        {"bar": "1.0"},
-    )
-    assert parse_args("this is a prompt --bar 1.0 --baz 2.0") == (
-        "this is a prompt",
-        {"bar": "1.0", "baz": "2.0"},
-    )
-    assert parse_args("this is a prompt --bar=1.0 --baz=2.0") == (
-        "this is a prompt",
-        {"bar": "1.0", "baz": "2.0"},
-    )
-    assert parse_args("this is a prompt --bar 1.0") == (
-        "this is a prompt",
-        {"bar": "1.0"},
-    )
-
-
-def test_parse_with_escape_blocks():
-    test_cases = [
-        (
-            # escaped text at end of prompt
-            "this is a prompt --bar=1.0 {start}--baz=2.0{end}",
-            "this is a prompt {start}--baz=2.0{end}",
-            {"bar": "1.0"},
-        ),
-        (
-            # escaped text in middle of prompt with equal assignment
-            "this is a prompt {start}--bar=1.0{end} --baz=2.0",
-            "this is a prompt {start}--bar=1.0{end}",
-            {"baz": "2.0"},
-        ),
-        (
-            # escaped text in middle of prompt with space assignment
-            "this is a prompt {start}--bar 1.0{end} --baz 2.0",
-            "this is a prompt {start}--bar 1.0{end}",
-            {"baz": "2.0"},
-        ),
-        (
-            # escaped text in multiple escape sequences
-            'this is a prompt --bar=1.0 {start}my first context block{end} and then ```my second context block``` --baz=2.0',
-            'this is a prompt {start}my first context block{end} and then ```my second context block```',
-            {'bar': '1.0', 'baz': '2.0'},
-        ),
-        (
-            # entire prompt is escaped
-            "{start}this is a prompt --bar=1.0 --baz=2.0{end}",
-            "{start}this is a prompt --bar=1.0 --baz=2.0{end}",
-            {},
-        ),
-        (
-            # multi-line escaped text
-            "this is a prompt \n--bar=1.0 --baz=2.0\n{start}--foo=3.0 \n another line \nmy final line{end}",
-            "this is a prompt \n \n{start}--foo=3.0 \n another line \nmy final line{end}",
-            {'bar': '1.0', 'baz': '2.0'},
-        )
-
-    ]
-
-    delimiters = ["```", '"""', "`"]
-
-    for start, end in [(d, d) for d in delimiters]:
-        for prompt, expected_prompt, expected_args in test_cases:
-            formatted_prompt = prompt.format(start=start, end=end)
-            formatted_expected_prompt = expected_prompt.format(start=start, end=end)
-            assert parse_args(formatted_prompt) == (
-                formatted_expected_prompt,
-                expected_args,
-            )