diff --git a/.gitignore b/.gitignore
index b321fa2..c28469d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,5 @@ build/
 **/__pycache__
 info.sh
 tests/test.py
-claude.sh
\ No newline at end of file
+claude.sh
+*.txt
\ No newline at end of file
diff --git a/terminal_agent_x/tax.py b/terminal_agent_x/tax.py
index 64ad233..03f132c 100644
--- a/terminal_agent_x/tax.py
+++ b/terminal_agent_x/tax.py
@@ -63,10 +63,11 @@ def fetch_code(openai_key: str, model: str, prompt: str, url_option: str, chat_f
         command = f'{command} --ipv4' if model == 'claude' else command
     else:  # Linux
         command = f"curl -s --location '{url}' --header '{headers[0]}' --header '{headers[1]}' --data '{data}'"
-    # print(command)
+    print(command)

     try:
         res, err = run_command_with_timeout(command, 60)
+        print(res)
         # res = os.popen(command).read().encode('utf-8').decode('utf-8', 'ignore')
         if model.lower() == 'dalle':
             return json.loads(res)['data'][0]['url']
@@ -166,8 +167,7 @@ def chat(openai_key: str, model: str, url_option: str):
         if user_input == "exit":
             break
         conversation.append({"role": "user", "content": user_input})
-        response = fetch_code(openai_key, model, json.dumps(
-            conversation), url_option, True)
+        response = fetch_code(openai_key, model, json.dumps(conversation), url_option, True)
         print(f'Tax: {response}')
         conversation.append({"role": "assistant", "content": response.encode(
             'unicode-escape').decode('utf8').replace("'", "")})
@@ -175,23 +175,31 @@ def chat(openai_key: str, model: str, url_option: str):
     # print(conversation)


-def parallel_ask(data_prompts, chat_model, max_workers, output_file, **args):
-    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-        future_to_prompt = {executor.submit(chat_model, prompt=prompt, **args): prompt for prompt in data_prompts}
+def parallel_ask(data_prompts, chat_mode, max_workers, output_file, model, **args):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=int(max_workers)) as executor:
+        future_to_prompt = []
+        for prompt in data_prompts:
+            future_to_prompt.append(executor.submit(
+                chat_mode, prompt=prompt, model=model, **args))
+        results = []
         for future in concurrent.futures.as_completed(future_to_prompt):
-            prompt = future_to_prompt[future]
             try:
                 data = future.result()
             except Exception as exc:
-                print(f'{prompt} generated an exception: {exc}')
-            else:
-                # print(f'{prompt} generated {data}')
-                if output_file:
-                    with open(output_file, 'a', encoding='utf-8') as f:
-                        f.write(f'{prompt} : {data}\n\n')
-                        f.close()
+                data = str(type(exc))
+            results.append(data)
+        if output_file:
+            with open(output_file, 'w', encoding='utf-8') as f:
+                f.write('\n'.join(results))
+                f.close()


+def load_prompts_file(model, path: str) -> str:
+    with open(path, 'r', encoding='utf-8') as f:
+        text = f.readlines()
+    text = [line.strip() for line in text]
+    wrappers = [chat_data_wrapper(model, prompt, False) for prompt in text]
+    return wrappers


 def main() -> None:
@@ -212,6 +220,9 @@ def main() -> None:
                         help="URL for API request. Choose from ['openai_gfw', 'openai', 'claude'] or your custom url.")
     parser.add_argument('-a', '--show_all', action='store_true',
                         help='Show all contents in the response.')
+    parser.add_argument('-p', '--parallel', action='store_true',
+                        help='Parallel mode. If specified, the input file will be read line by line and the responses will be saved to the output file.')
+    parser.add_argument('--option', metavar='KEY=VALUE', action='append', help='Custom option')

     args = parser.parse_args()
     prompt = ' '.join(args.prompt)
@@ -232,6 +243,17 @@ def main() -> None:
     # res = get_model_response(openai_key, args.model, prompt)
     res = fetch_code(key, args.model, prompt, args.url, False)

+    if args.option and args.parallel:
+        custom_options = {}
+        for option in args.option:
+            key, value = option.split('=')
+            custom_options[key] = value
+
+        print(parallel_ask(data_prompts=load_prompts_file(args.model, args.input), output_file=args.output, model=args.model, **custom_options))
+        return
+
+    # tax -i input.txt -o output.txt -m gpt-3.5-turbo -u openai_gfw -k xxx --option chat_mode=fetch_code
+
     if args.output:
         with open(args.output, 'w', encoding='utf-8') as f:
             f.write(res)
@@ -252,5 +274,5 @@ def main() -> None:


 if __name__ == '__main__':
-    # main()
-    parallel_ask(data_prompts=['hi'], chat_model=fetch_code, max_workers=3, output_file='output.txt', openai_key='', model='gpt-3.5-turbo', url_option='openai', chat_flag=False)
+    main()
+    # parallel_ask(data_prompts=['hi'], chat_mode=fetch_code, max_workers=3, output_file='output.txt', openai_key='', model='gpt-3.5-turbo', url_option='openai', chat_flag=False)
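
Note on the rewritten parallel_ask: concurrent.futures.as_completed yields futures in completion order, not submission order, so lines in the output file may not line up with lines in the input file; the removed future_to_prompt dict previously preserved that prompt-to-result mapping. The following is a minimal, self-contained sketch of the same fan-out pattern under that caveat. fake_model is a hypothetical stand-in for fetch_code and is not part of tax.py.

import concurrent.futures


def fake_model(prompt: str, model: str) -> str:
    # Hypothetical placeholder for a real API call such as fetch_code().
    return f'{model}: {prompt}'


def parallel_ask_sketch(prompts, max_workers=3):
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=int(max_workers)) as executor:
        futures = [executor.submit(fake_model, prompt=p, model='gpt-3.5-turbo')
                   for p in prompts]
        # as_completed() yields in completion order, so results may be
        # reordered relative to the input prompts.
        for future in concurrent.futures.as_completed(futures):
            try:
                results.append(future.result())
            except Exception as exc:
                # Mirror the diff: record the exception type instead of aborting.
                results.append(str(type(exc)))
    return results


if __name__ == '__main__':
    print('\n'.join(parallel_ask_sketch(['hi', 'hello', 'hey'])))

If result order matters, executor.map(...) or keeping the original future-to-prompt dict would preserve the input order; the sketch follows the diff's list-based approach as written.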