using asyncio and multiple questions per cycle
PsicoThePato committed May 10, 2024
1 parent ebbbe9d commit 9d23f1f
Showing 1 changed file with 73 additions and 40 deletions.
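The change has two parts: miner calls are now issued as asyncio tasks and awaited together with asyncio.gather (instead of a ThreadPoolExecutor wrapping asyncio.run), and each cycle generates a batch of NUM_QUESTIONS_PER_CYCLE validation questions, one of which is picked at random per miner. Below is a minimal, self-contained sketch of that fan-out pattern; query_miner and fan_out are hypothetical stand-ins for ModuleClient.call and the validate_step loop, not code from this repository.

import asyncio
import random

async def query_miner(miner_id: int, question: str) -> tuple[int, str | None]:
    # Stand-in for ModuleClient.call: network calls can fail, so return None instead of raising.
    try:
        await asyncio.sleep(0.1)  # simulate network latency
        return miner_id, f"answer from miner {miner_id} to {question!r}"
    except Exception:
        return miner_id, None

async def fan_out(miners: list[int], questions: list[str]) -> list[tuple[int, str | None]]:
    # One task per miner, each with a randomly chosen question from the pre-generated batch.
    tasks = [asyncio.create_task(query_miner(m, random.choice(questions))) for m in miners]
    return await asyncio.gather(*tasks)

if __name__ == "__main__":
    print(asyncio.run(fan_out([0, 1, 2], ["question A", "question B"])))

Running every call on one event loop avoids spawning a thread per miner and keeps the timeout handling inside each coroutine.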
113 changes: 73 additions & 40 deletions src/synthia/validator/text_validator.py
@@ -1,10 +1,10 @@
import asyncio
import concurrent.futures
import re
import time
from functools import partial
import random
from enum import Enum
from dataclasses import dataclass


import numpy as np
import requests
@@ -27,7 +27,8 @@

# TODO: make it match ipv6
IP_REGEX = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+")

NUM_QUESTIONS_PER_CYCLE=5
MINIMUM_DATASET_SCORE=0.7

def set_weights(
    score_dict: dict[int, float], netuid: int, client: CommuneClient, key: Keypair
@@ -132,6 +133,20 @@ class ClaudeProviders(Enum):
    ANTHROPIC = "anthropic"
    OPENROUTER = "openrouter"

@dataclass
class ValidationDataset:
    prompt: str
    val_answer: str
    criteria: Criteria
    question_age: float
    chosen_subject: str
    embedded_val_answer: list[float]

@dataclass
class ModuleInfo:
    uid: int
    address: list[str] #actually a tuple[str, str] but as a list
    key: Ss58Address

class TextValidator(Module):
    """A class for validating text data using a Synthia network.
@@ -197,7 +212,7 @@ def get_modules(self, client: CommuneClient, netuid: int) -> dict[int, str]:
        module_addreses = client.query_map_address(netuid)
        return module_addreses

    def _get_validation_dataset(self, settings: ValidatorSettings):
    def _get_validation_dataset(self, settings: ValidatorSettings, size: int):

        # TODO: make ValidatorSettings and the miners settings inherit from a
        # common protocol
@@ -220,36 +235,48 @@ def _get_validation_dataset(self, settings: ValidatorSettings):
        retrier = retry(4, [Exception])
        generate_explanations = retrier(ig.gen_explanation)

        explanations, prompt, criteria = generate_explanations()

        dataset: tuple[str, str] = (prompt, explanations)
        questions_age = time.time()
        return dataset, criteria, questions_age
        validation_list: list[ValidationDataset] = []
        for _ in range(size):
            explanations, prompt, criteria = generate_explanations()
            questions_age = time.time()
            subject, val_answer = self._split_val_subject(explanations)
            embedded_val_answer = self.embedder.get_embedding(val_answer)
            val_dataset = ValidationDataset(
                prompt=prompt, criteria=criteria, question_age=questions_age,
                val_answer=val_answer, chosen_subject=subject,
                embedded_val_answer=embedded_val_answer
            )
            validation_list.append(val_dataset)
        return validation_list

    def _get_miner_prediction(
    async def _get_miner_prediction(
        self,
        question: str,
        val_info: ValidationDataset,
        miner_info: tuple[list[str], Ss58Address],
    ) -> str | None:
    ) -> tuple[str | None, ValidationDataset]:
        connection, miner_key = miner_info
        module_ip, module_port = connection

        question = get_miner_prompt(
            val_info.criteria,
            val_info.chosen_subject,
            len(val_info.val_answer)
        )
        client = ModuleClient(module_ip, int(module_port), self.key)
        try:
            miner_answer = asyncio.run(
                client.call(
            miner_answer = await client.call(
                "generate", miner_key,
                {"prompt": question}, timeout=self.call_timeout
                )
            )
            )

            miner_answer = miner_answer["answer"]

        except Exception as e:
            log(f"Miner {module_ip}:{module_port} failed to generate an answer")
            print(e)
            miner_answer = None

        return miner_answer
        return miner_answer, val_info

    def _get_unit_euclid_distance(
        self, embedded_miner_answer: list[float], embbeded_val_answer: list[float]
Expand Down Expand Up @@ -321,37 +348,42 @@ async def validate_step(
            raise RuntimeError(
                f"validator key {val_ss58} is not registered in subnet"
            )
        modules_info: dict[int, tuple[list[str], Ss58Address]] = {}
        modules_info: dict[int, ModuleInfo] = {}

        modules_filtered_address = get_ip_port(modules_adresses)
        for module_id in modules_keys.keys():
            module_addr = modules_filtered_address.get(module_id, None)
            if not module_addr:
                continue
            modules_info[module_id] = (module_addr, modules_keys[module_id])
            modules_info[module_id] = ModuleInfo(
                module_id, module_addr, modules_keys[module_id]
            )

        response_cache: list[str] = []
        score_dict: dict[int, float] = {}
        hf_data_list: list[dict[str, str]] = []
        # == Validation loop / Scoring ==

        dataset, criteria, _ = self._get_validation_dataset(settings)
        _, val_answer = dataset
        subject, val_answer = self._split_val_subject(val_answer)
        miner_prompt = get_miner_prompt(criteria, subject, len(val_answer))
        embedded_val_answer = self.embedder.get_embedding(val_answer)
        val_dataset = self._get_validation_dataset(settings, NUM_QUESTIONS_PER_CYCLE)


        get_miner_prediction = partial(self._get_miner_prediction, miner_prompt)
        log(f"Selected the following miners: {modules_info.keys()}")
        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
            it = executor.map(get_miner_prediction, modules_info.values())
            miner_answers = [*it]
        futures: list[asyncio.Task[tuple[str | None, ValidationDataset]]] = []
        for mod_info in modules_info.values():
            val_info = random.choice(val_dataset)
            future = asyncio.create_task(
                self._get_miner_prediction(
                    val_info, (mod_info.address, mod_info.key)
                )
            )
            futures.append(future)
        miner_answers = await asyncio.gather(*futures)

        for uid, miner_response in zip(modules_info.keys(), miner_answers):
            miner_answer = miner_response
            miner_answer, val_info = miner_response
            if not miner_answer:
                log(f"Skipping miner {uid} that didn't answer")
                continue
            score = self._score_miner(miner_answer, embedded_val_answer)
            score = self._score_miner(miner_answer, val_info.embedded_val_answer)
            for answer in response_cache:
                similarity = fuzz.ratio(answer, miner_answer) # type: ignore
                log(f"similarity: {similarity}")
@@ -360,18 +392,19 @@ async def validate_step(
            # score has to be lower or eq to 1, as one is the best score
            assert score <= 1
            score_dict[uid] = score
            hf_data = self._to_hf_data(
                criteria,
                subject,
                miner_answer,
                score,
            )
            hf_data_list.append(hf_data)
            if score >= MINIMUM_DATASET_SCORE:
                hf_data = self._to_hf_data(
                    val_info.criteria,
                    val_info.chosen_subject,
                    miner_answer,
                    score,
                )
                hf_data_list.append(hf_data)
        if not score_dict:
            log("No miner managed to give a valid answer")
            return []
        _ = set_weights(score_dict, self.netuid, self.client, self.key)

        _ = set_weights(score_dict, self.netuid, self.client, self.key)
        return hf_data_list

    def upload_data(
@@ -402,7 +435,7 @@ def upload_data(
                log(f"Upload attempt {attempt} failed: {e}")
                attempt += 1
                if attempt > max_attempts:
                    print("Could not upload data. ")
                    log("Could not upload data. ")
                    break

    def validation_loop(self, settings: ValidatorSettings | None = None) -> None:
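To make the other half of the change concrete, here is a small, self-contained sketch of generating the question batch up front and keeping only answers that clear MINIMUM_DATASET_SCORE for the dataset upload; make_question and score_answer are invented placeholders for the generator and the embedding-based scorer, not this module's real helpers.

import random
import time
from dataclasses import dataclass

NUM_QUESTIONS_PER_CYCLE = 5
MINIMUM_DATASET_SCORE = 0.7

@dataclass
class Question:
    prompt: str
    reference_answer: str
    created_at: float

def make_question(i: int) -> Question:
    # Placeholder for the real generator, which asks a language model for an explanation.
    return Question(prompt=f"explain topic {i}", reference_answer=f"reference text {i}", created_at=time.time())

def score_answer(answer: str, reference: str) -> float:
    # Placeholder for the real scorer, which embeds both texts and compares them.
    return random.random()

def run_cycle(miner_answers: list[str]) -> list[dict[str, str]]:
    questions = [make_question(i) for i in range(NUM_QUESTIONS_PER_CYCLE)]
    kept: list[dict[str, str]] = []
    for answer in miner_answers:
        question = random.choice(questions)   # each miner is scored against one question from the batch
        score = score_answer(answer, question.reference_answer)
        if score >= MINIMUM_DATASET_SCORE:    # low-scoring answers never enter the dataset
            kept.append({"prompt": question.prompt, "answer": answer})
    return kept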
