Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for infinite output model fallback #2631

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aider/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,12 @@ def get_parser(default_config_files, git_root):
default=None,
help="Specify the edit format for the editor model (default: depends on editor model)",
)
group.add_argument(
"--infinite-output-model",
metavar="INFINITE_OUTPUT_MODEL",
default=None,
help="Specify the model to use for continuing long responses (default: None)",
)
group.add_argument(
"--show-model-warnings",
action=argparse.BooleanOptionalAction,
Expand Down
20 changes: 17 additions & 3 deletions aider/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def get_announcements(self):
# Model
main_model = self.main_model
weak_model = main_model.weak_model
infinite_output_model = main_model.infinite_output_model

if weak_model is not main_model:
prefix = "Main model"
Expand All @@ -210,6 +211,10 @@ def get_announcements(self):
output = f"Weak model: {weak_model.name}"
lines.append(output)

if infinite_output_model and infinite_output_model is not main_model:
output = f"Infinite output model: {infinite_output_model.name}"
lines.append(output)

# Repo
if self.repo:
rel_repo_dir = self.repo.get_rel_repo_dir()
Expand Down Expand Up @@ -1275,9 +1280,14 @@ def send_message(self, inp):
break
except FinishReasonLength:
# We hit the output limit!
if not self.main_model.info.get("supports_assistant_prefill"):
exhausted = True
break
if self.main_model.info.get("supports_assistant_prefill"):
use_model = self.main_model
else:
# Try to get an infinite output model
use_model = self.main_model.infinite_output_model
if not use_model or not use_model.info.get("supports_assistant_prefill"):
exhausted = True
break

self.multi_response_content = self.get_multi_response_content()

Expand All @@ -1287,6 +1297,10 @@ def send_message(self, inp):
messages.append(
dict(role="assistant", content=self.multi_response_content, prefix=True)
)

# Switch to the infinite output model if needed
if use_model != self.main_model:
self.main_model = use_model
except Exception as err:
self.mdstream = None
lines = traceback.format_exception(type(err), err, err.__traceback__)
Expand Down
1 change: 1 addition & 0 deletions aider/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ def get_io(pretty):
weak_model=args.weak_model,
editor_model=args.editor_model,
editor_edit_format=args.editor_edit_format,
infinite_output_model=args.infinite_output_model,
)

if args.copy_paste and args.edit_format is None:
Expand Down
27 changes: 26 additions & 1 deletion aider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class ModelSettings:
name: str
edit_format: str = "whole"
weak_model_name: Optional[str] = None
infinite_output_model_name: Optional[str] = None
use_repo_map: bool = False
send_undo_reply: bool = False
lazy: bool = False
Expand Down Expand Up @@ -857,7 +858,7 @@ def get_model_info(self, model):


class Model(ModelSettings):
def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format=None):
def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format=None, infinite_output_model=None):
# Map any alias to its canonical name
model = MODEL_ALIASES.get(model, model)

Expand All @@ -866,6 +867,7 @@ def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format
self.max_chat_history_tokens = 1024
self.weak_model = None
self.editor_model = None
self.infinite_output_model = None

# Find the extra settings
self.extra_model_settings = next(
Expand Down Expand Up @@ -896,6 +898,11 @@ def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format
else:
self.get_editor_model(editor_model, editor_edit_format)

if infinite_output_model is False:
self.infinite_output_model_name = None
else:
self.get_infinite_output_model(infinite_output_model)

def get_model_info(self, model):
return model_info_manager.get_model_info(model)

Expand Down Expand Up @@ -1015,6 +1022,24 @@ def get_weak_model(self, provided_weak_model_name):
def commit_message_models(self):
return [self.weak_model, self]

def get_infinite_output_model(self, provided_infinite_output_model_name):
    """Resolve the fallback model used to continue over-long responses.

    Args:
        provided_infinite_output_model_name: Optional model name that
            overrides ``self.infinite_output_model_name`` when truthy.

    Returns:
        ``None`` when no infinite output model name is configured, ``self``
        when the configured name is this model's own name, otherwise a newly
        constructed ``Model`` for the configured name. In the latter two
        cases the result is also stored on ``self.infinite_output_model``.
    """
    # If infinite_output_model_name is provided, override the model settings
    if provided_infinite_output_model_name:
        self.infinite_output_model_name = provided_infinite_output_model_name

    if not self.infinite_output_model_name:
        return None

    if self.infinite_output_model_name == self.name:
        # Record the self-reference so the attribute agrees with the return
        # value; callers compare `infinite_output_model is main_model`.
        self.infinite_output_model = self
        return self

    # The fallback model only continues a response, so it needs no helper
    # models of its own. Passing infinite_output_model=False also stops the
    # child from resolving its own fallback, which could otherwise recurse
    # without bound when two models name each other.
    self.infinite_output_model = Model(
        self.infinite_output_model_name,
        weak_model=False,
        editor_model=False,
        infinite_output_model=False,
    )
    return self.infinite_output_model

def get_editor_model(self, provided_editor_model_name, editor_edit_format):
# If editor_model_name is provided, override the model settings
if provided_editor_model_name:
Expand Down