From 9d23f1fdc049202c0665e7bec184e5be9f234c33 Mon Sep 17 00:00:00 2001 From: PsicoThePato Date: Fri, 10 May 2024 02:56:24 -0300 Subject: [PATCH] using asyncio and multiple questions per cycle --- src/synthia/validator/text_validator.py | 113 +++++++++++++++--------- 1 file changed, 73 insertions(+), 40 deletions(-) diff --git a/src/synthia/validator/text_validator.py b/src/synthia/validator/text_validator.py index 0aaba78..56f9efd 100644 --- a/src/synthia/validator/text_validator.py +++ b/src/synthia/validator/text_validator.py @@ -1,10 +1,10 @@ import asyncio -import concurrent.futures import re import time -from functools import partial import random from enum import Enum +from dataclasses import dataclass + import numpy as np import requests @@ -27,7 +27,8 @@ # TODO: make it match ipv6 IP_REGEX = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+") - +NUM_QUESTIONS_PER_CYCLE=5 +MINIMUM_DATASET_SCORE=0.7 def set_weights( score_dict: dict[int, float], netuid: int, client: CommuneClient, key: Keypair @@ -132,6 +133,20 @@ class ClaudeProviders(Enum): ANTHROPIC = "anthropic" OPENROUTER = "openrouter" +@dataclass +class ValidationDataset: + prompt: str + val_answer: str + criteria: Criteria + question_age: float + chosen_subject: str + embedded_val_answer: list[float] + +@dataclass +class ModuleInfo: + uid: int + address: list[str] #actually a tuple[str, str] but as a list + key: Ss58Address class TextValidator(Module): """A class for validating text data using a Synthia network. @@ -197,7 +212,7 @@ def get_modules(self, client: CommuneClient, netuid: int) -> dict[int, str]: module_addreses = client.query_map_address(netuid) return module_addreses - def _get_validation_dataset(self, settings: ValidatorSettings): + def _get_validation_dataset(self, settings: ValidatorSettings, size: int): # TODO: make ValidatorSettings and the miners settings inherit from a # common protocol @@ -220,28 +235,40 @@ def _get_validation_dataset(self, settings: ValidatorSettings): retrier = retry(4, [Exception]) generate_explanations = retrier(ig.gen_explanation) - explanations, prompt, criteria = generate_explanations() - - dataset: tuple[str, str] = (prompt, explanations) - questions_age = time.time() - return dataset, criteria, questions_age + validation_list: list[ValidationDataset] = [] + for _ in range(size): + explanations, prompt, criteria = generate_explanations() + questions_age = time.time() + subject, val_answer = self._split_val_subject(explanations) + embedded_val_answer = self.embedder.get_embedding(val_answer) + val_dataset = ValidationDataset( + prompt=prompt, criteria=criteria, question_age=questions_age, + val_answer=val_answer, chosen_subject=subject, + embedded_val_answer=embedded_val_answer + ) + validation_list.append(val_dataset) + return validation_list - def _get_miner_prediction( + async def _get_miner_prediction( self, - question: str, + val_info: ValidationDataset, miner_info: tuple[list[str], Ss58Address], - ) -> str | None: + ) -> tuple[str | None, ValidationDataset]: connection, miner_key = miner_info module_ip, module_port = connection + question = get_miner_prompt( + val_info.criteria, + val_info.chosen_subject, + len(val_info.val_answer) + ) client = ModuleClient(module_ip, int(module_port), self.key) try: - miner_answer = asyncio.run( - client.call( + miner_answer = await client.call( "generate", miner_key, {"prompt": question}, timeout=self.call_timeout - ) - ) + ) + miner_answer = miner_answer["answer"] except Exception as e: @@ -249,7 +276,7 @@ def _get_miner_prediction( print(e) miner_answer = None - return miner_answer + return miner_answer, val_info def _get_unit_euclid_distance( self, embedded_miner_answer: list[float], embbeded_val_answer: list[float] @@ -321,37 +348,42 @@ async def validate_step( raise RuntimeError( f"validator key {val_ss58} is not registered in subnet" ) - modules_info: dict[int, tuple[list[str], Ss58Address]] = {} + modules_info: dict[int, ModuleInfo] = {} modules_filtered_address = get_ip_port(modules_adresses) for module_id in modules_keys.keys(): module_addr = modules_filtered_address.get(module_id, None) if not module_addr: continue - modules_info[module_id] = (module_addr, modules_keys[module_id]) + modules_info[module_id] = ModuleInfo( + module_id, module_addr, modules_keys[module_id] + ) response_cache: list[str] = [] score_dict: dict[int, float] = {} hf_data_list: list[dict[str, str]] = [] # == Validation loop / Scoring == - - dataset, criteria, _ = self._get_validation_dataset(settings) - _, val_answer = dataset - subject, val_answer = self._split_val_subject(val_answer) - miner_prompt = get_miner_prompt(criteria, subject, len(val_answer)) - embedded_val_answer = self.embedder.get_embedding(val_answer) + val_dataset = self._get_validation_dataset(settings, NUM_QUESTIONS_PER_CYCLE) + - get_miner_prediction = partial(self._get_miner_prediction, miner_prompt) log(f"Selected the following miners: {modules_info.keys()}") - with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: - it = executor.map(get_miner_prediction, modules_info.values()) - miner_answers = [*it] + futures: list[asyncio.Task[tuple[str | None, ValidationDataset]]] = [] + for mod_info in modules_info.values(): + val_info = random.choice(val_dataset) + future = asyncio.create_task( + self._get_miner_prediction( + val_info, (mod_info.address, mod_info.key) + ) + ) + futures.append(future) + miner_answers = await asyncio.gather(*futures) + for uid, miner_response in zip(modules_info.keys(), miner_answers): - miner_answer = miner_response + miner_answer, val_info = miner_response if not miner_answer: log(f"Skipping miner {uid} that didn't answer") continue - score = self._score_miner(miner_answer, embedded_val_answer) + score = self._score_miner(miner_answer, val_info.embedded_val_answer) for answer in response_cache: similarity = fuzz.ratio(answer, miner_answer) # type: ignore log(f"similarity: {similarity}") @@ -360,18 +392,19 @@ async def validate_step( # score has to be lower or eq to 1, as one is the best score assert score <= 1 score_dict[uid] = score - hf_data = self._to_hf_data( - criteria, - subject, - miner_answer, - score, - ) - hf_data_list.append(hf_data) + if score >= MINIMUM_DATASET_SCORE: + hf_data = self._to_hf_data( + val_info.criteria, + val_info.chosen_subject, + miner_answer, + score, + ) + hf_data_list.append(hf_data) if not score_dict: log("No miner managed to give a valid answer") return [] - _ = set_weights(score_dict, self.netuid, self.client, self.key) + _ = set_weights(score_dict, self.netuid, self.client, self.key) return hf_data_list def upload_data( @@ -402,7 +435,7 @@ def upload_data( log(f"Upload attempt {attempt} failed: {e}") attempt += 1 if attempt > max_attempts: - print("Could not upload data. ") + log("Could not upload data. ") break def validation_loop(self, settings: ValidatorSettings | None = None) -> None: