From c49c2e9d181e554161049d2d9f8f2f0858ac08f1 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Thu, 6 Jun 2024 16:07:46 -0400 Subject: [PATCH] refactor: create experiment uuid outside run method --- parea/experiment/experiment.py | 15 +++++++++------ pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index f217a538..1fd3f942 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -89,8 +89,8 @@ async def experiment( data: Union[str, int, Iterable[dict]], func: Callable, p: Parea, + experiment_uuid: str, n_trials: int = 1, - metadata: Optional[Dict[str, str]] = None, dataset_level_evals: Optional[List[Callable]] = None, n_workers: int = 10, ) -> ExperimentStatsSchema: @@ -102,8 +102,8 @@ async def experiment( If it is a list of dictionaries, the key "target" is reserved for the target/expected output of that sample. param func: The function to run. This function should accept inputs that match the keys of the data field. param p: The Parea instance to use for running the experiment. + param experiment_uuid: The UUID of the experiment. This is used to associate traces with the experiment. param n_trials: The number of times to run the experiment on the same data. - param metadata: A dictionary of metadata to attach to the experiment. param dataset_level_evals: A list of functions to run on the dataset level. Each function should accept a list of EvaluatedLogs and return a float or a EvaluationResult. If a float is returned, the name of the function will be used as the name of the evaluation. param n_workers: The number of workers to use for running the experiment. @@ -122,8 +122,6 @@ async def experiment( len_test_cases = len(data) if isinstance(data, list) else 0 print(f"Running {n_trials} trials of the experiment \n") - experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=experiment_name, run_name=run_name, metadata=metadata)) - experiment_uuid = experiment_schema.uuid os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid sem = asyncio.Semaphore(n_workers) @@ -205,6 +203,7 @@ class Experiment: p: Parea = field(default=None) experiment_name: str = field run_name: str = field(init=False) + experiment_uuid: str = field(init=False, default=None) n_workers: int = field(default=10) # The number of times to run the experiment on the same data. n_trials: int = field(default=1) @@ -238,8 +237,10 @@ def run(self, run_name: Optional[str] = None) -> None: try: self._gen_run_name_if_none(run_name) + experiment_schema: ExperimentSchema = self.p.create_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=run_name, metadata=self.metadata)) + self.experiment_uuid = experiment_schema.uuid self.experiment_stats = asyncio.run( - experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.n_trials, self.metadata, self.dataset_level_evals, self.n_workers) + experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers) ) except Exception as e: import traceback @@ -258,8 +259,10 @@ async def arun(self, run_name: Optional[str] = None) -> None: try: self._gen_run_name_if_none(run_name) + experiment_schema: ExperimentSchema = await self.p.acreate_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=run_name, metadata=self.metadata)) + self.experiment_uuid = experiment_schema.uuid self.experiment_stats = await experiment( - self.experiment_name, self.run_name, self.data, self.func, self.p, self.n_trials, self.metadata, self.dataset_level_evals, self.n_workers + self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers ) except Exception as e: import traceback diff --git a/pyproject.toml b/pyproject.toml index aef82259..8b49a62a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.167" +version = "0.2.168" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]