[WIP] Parallel Mechanism
LyuLumos committed Oct 7, 2023
1 parent 792a17e commit 63250c4
Showing 2 changed files with 40 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -8,4 +8,5 @@ build/
 **/__pycache__
 info.sh
 tests/test.py
-claude.sh
+claude.sh
+*.txt
54 changes: 38 additions & 16 deletions terminal_agent_x/tax.py
@@ -63,10 +63,11 @@ def fetch_code(openai_key: str, model: str, prompt: str, url_option: str, chat_f
         command = f'{command} --ipv4' if model == 'claude' else command
     else:  # Linux
         command = f"curl -s --location '{url}' --header '{headers[0]}' --header '{headers[1]}' --data '{data}'"
-    # print(command)
+    print(command)
 
     try:
         res, err = run_command_with_timeout(command, 60)
+        print(res)
         # res = os.popen(command).read().encode('utf-8').decode('utf-8', 'ignore')
         if model.lower() == 'dalle':
             return json.loads(res)['data'][0]['url']
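
The new print calls expose the assembled curl command and its raw response for debugging. The helper run_command_with_timeout is not shown in this diff; a minimal sketch of what such a helper could look like, assuming it shells out via subprocess and returns a (stdout, stderr) pair. The body below is an illustration, not the repository's actual implementation:

import subprocess

def run_command_with_timeout(command: str, timeout: int):
    # Assumed behavior: run the shell command, kill it after `timeout`
    # seconds, and return (stdout, stderr) as text.
    try:
        proc = subprocess.run(command, shell=True, capture_output=True,
                              text=True, timeout=timeout)
        return proc.stdout, proc.stderr
    except subprocess.TimeoutExpired:
        return '', f'timed out after {timeout}s'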
@@ -166,32 +167,39 @@ def chat(openai_key: str, model: str, url_option: str):
         if user_input == "exit":
             break
         conversation.append({"role": "user", "content": user_input})
-        response = fetch_code(openai_key, model, json.dumps(
-            conversation), url_option, True)
+        response = fetch_code(openai_key, model, json.dumps(conversation), url_option, True)
         print(f'Tax: {response}')
         conversation.append({"role": "assistant", "content": response.encode(
             'unicode-escape').decode('utf8').replace("'", "")})
         # The bash command we send cannot contain single quotes, and escaping has no effect, so single quotes in the conversation are deleted; the user never sees them.
         # print(conversation)


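The trailing comment explains the quote stripping: the assistant reply is embedded in a single-quoted shell string on the next request, where escaping does not help. A tiny illustration of the transform (the sample string is invented):

reply = "It's a 'quoted' reply\n"
stored = reply.encode('unicode-escape').decode('utf8').replace("'", "")
print(stored)  # Its a quoted reply\n  (quotes removed, newline now a literal \n)
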
-def parallel_ask(data_prompts, chat_model, max_workers, output_file, **args):
-    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-        future_to_prompt = {executor.submit(chat_model, prompt=prompt, **args): prompt for prompt in data_prompts}
+def parallel_ask(data_prompts, chat_mode, max_workers, output_file, model, **args):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=int(max_workers)) as executor:
+        future_to_prompt = []
+        for prompt in data_prompts:
+            future_to_prompt.append(executor.submit(
+                chat_mode, prompt=prompt, model=model, **args))
+        results = []
         for future in concurrent.futures.as_completed(future_to_prompt):
-            prompt = future_to_prompt[future]
             try:
                 data = future.result()
             except Exception as exc:
-                print(f'{prompt} generated an exception: {exc}')
-            else:
-                # print(f'{prompt} generated {data}')
-                if output_file:
-                    with open(output_file, 'a', encoding='utf-8') as f:
-                        f.write(f'{prompt} : {data}\n\n')
-                        f.close()
+                data = str(type(exc))
+            results.append(data)
+        if output_file:
+            with open(output_file, 'w', encoding='utf-8') as f:
+                f.write('\n'.join(results))
+                f.close()


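One caveat in this WIP version: concurrent.futures.as_completed yields futures in completion order, so the lines written to output_file will not necessarily match the order of the input prompts. A sketch of an order-preserving variant using executor.map, shown as an alternative rather than what the commit does:

import concurrent.futures

def parallel_ask_ordered(data_prompts, chat_mode, max_workers, **args):
    # executor.map returns results in input order; wrap each call so a
    # failing prompt yields its exception type instead of killing the batch.
    def ask(prompt):
        try:
            return chat_mode(prompt=prompt, **args)
        except Exception as exc:
            return str(type(exc))
    with concurrent.futures.ThreadPoolExecutor(max_workers=int(max_workers)) as executor:
        return list(executor.map(ask, data_prompts))
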
+def load_prompts_file(model, path: str) -> list:
+    with open(path, 'r', encoding='utf-8') as f:
+        text = f.readlines()
+    text = [line.strip() for line in text]
+    wrappers = [chat_data_wrapper(model, prompt, False) for prompt in text]
+    return wrappers


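load_prompts_file reads one prompt per line and wraps each with chat_data_wrapper (defined elsewhere in the repository). Assuming a two-line input file, usage would look like this; the file name and contents are invented:

# input.txt contains, for example:
#   Translate 'hello' to French
#   Summarize the plot of Hamlet
prompts = load_prompts_file('gpt-3.5-turbo', 'input.txt')
print(len(prompts))  # 2, one wrapped payload per line
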
 def main() -> None:
@@ -212,6 +220,9 @@ def main() -> None:
                         help="URL for API request. Choose from ['openai_gfw', 'openai', 'claude'] or your custom url.")
     parser.add_argument('-a', '--show_all', action='store_true',
                         help='Show all contents in the response.')
+    parser.add_argument('-p', '--parallel', action='store_true',
+                        help='Parallel mode. If specified, the input file will be read line by line and the responses will be saved to the output file.')
+    parser.add_argument('--option', metavar='KEY=VALUE', action='append', help='Custom option')
     args = parser.parse_args()
 
     prompt = ' '.join(args.prompt)
@@ -232,6 +243,17 @@ def main() -> None:
         # res = get_model_response(openai_key, args.model, prompt)
         res = fetch_code(key, args.model, prompt, args.url, False)
 
+    if args.option and args.parallel:
+        custom_options = {}
+        for option in args.option:
+            key, value = option.split('=')
+            custom_options[key] = value
+
+        print(parallel_ask(data_prompts=load_prompts_file(args.model, args.input), output_file=args.output, model=args.model, **custom_options))
+        return
+
+    # tax -i input.txt -o output.txt -m gpt-3.5-turbo -u openai_gfw -k xxx --option chat_mode=fetch_code
+
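A wrinkle the commented example exposes: argparse delivers every KEY=VALUE option as a string, so --option chat_mode=fetch_code puts the string 'fetch_code' into custom_options, and executor.submit inside parallel_ask would fail because a string is not callable (max_workers survives only because parallel_ask applies int() to it, and since it has no default, the example command would also need something like --option max_workers=4). A sketch of the name-to-callable lookup this WIP path would still need; the CHAT_MODES dict is an assumption, not code from the commit:

# Hypothetical resolver, run before forwarding custom_options:
CHAT_MODES = {'fetch_code': fetch_code}
if 'chat_mode' in custom_options:
    custom_options['chat_mode'] = CHAT_MODES[custom_options['chat_mode']]
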
     if args.output:
         with open(args.output, 'w', encoding='utf-8') as f:
             f.write(res)
@@ -252,5 +274,5 @@ def main() -> None:
 
 
 if __name__ == '__main__':
-    # main()
-    parallel_ask(data_prompts=['hi'], chat_model=fetch_code, max_workers=3, output_file='output.txt', openai_key='', model='gpt-3.5-turbo', url_option='openai', chat_flag=False)
+    main()
+    # parallel_ask(data_prompts=['hi'], chat_mode=fetch_code, max_workers=3, output_file='output.txt', openai_key='', model='gpt-3.5-turbo', url_option='openai', chat_flag=False)
