Skip to content

Commit

Permalink
Merge pull request #925 from parea-ai/refactor-create-exp-uuid-outsid…
Browse files Browse the repository at this point in the history
…e-run

refactor: create experiment uuid outside run method
  • Loading branch information
joschkabraun authored Jun 6, 2024
2 parents 9f21a61 + c49c2e9 commit 2f0db49
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions parea/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ async def experiment(
data: Union[str, int, Iterable[dict]],
func: Callable,
p: Parea,
experiment_uuid: str,
n_trials: int = 1,
metadata: Optional[Dict[str, str]] = None,
dataset_level_evals: Optional[List[Callable]] = None,
n_workers: int = 10,
) -> ExperimentStatsSchema:
Expand All @@ -102,8 +102,8 @@ async def experiment(
If it is a list of dictionaries, the key "target" is reserved for the target/expected output of that sample.
param func: The function to run. This function should accept inputs that match the keys of the data field.
param p: The Parea instance to use for running the experiment.
param experiment_uuid: The UUID of the experiment. This is used to associate traces with the experiment.
param n_trials: The number of times to run the experiment on the same data.
param metadata: A dictionary of metadata to attach to the experiment.
param dataset_level_evals: A list of functions to run on the dataset level. Each function should accept a list of EvaluatedLogs and return a float or a
EvaluationResult. If a float is returned, the name of the function will be used as the name of the evaluation.
param n_workers: The number of workers to use for running the experiment.
Expand All @@ -122,8 +122,6 @@ async def experiment(
len_test_cases = len(data) if isinstance(data, list) else 0
print(f"Running {n_trials} trials of the experiment \n")

experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=experiment_name, run_name=run_name, metadata=metadata))
experiment_uuid = experiment_schema.uuid
os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid

sem = asyncio.Semaphore(n_workers)
Expand Down Expand Up @@ -205,6 +203,7 @@ class Experiment:
p: Parea = field(default=None)
experiment_name: str = field
run_name: str = field(init=False)
experiment_uuid: str = field(init=False, default=None)
n_workers: int = field(default=10)
# The number of times to run the experiment on the same data.
n_trials: int = field(default=1)
Expand Down Expand Up @@ -238,8 +237,10 @@ def run(self, run_name: Optional[str] = None) -> None:

try:
self._gen_run_name_if_none(run_name)
experiment_schema: ExperimentSchema = self.p.create_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=run_name, metadata=self.metadata))
self.experiment_uuid = experiment_schema.uuid
self.experiment_stats = asyncio.run(
experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.n_trials, self.metadata, self.dataset_level_evals, self.n_workers)
experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers)
)
except Exception as e:
import traceback
Expand All @@ -258,8 +259,10 @@ async def arun(self, run_name: Optional[str] = None) -> None:

try:
self._gen_run_name_if_none(run_name)
experiment_schema: ExperimentSchema = await self.p.acreate_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=run_name, metadata=self.metadata))
self.experiment_uuid = experiment_schema.uuid
self.experiment_stats = await experiment(
self.experiment_name, self.run_name, self.data, self.func, self.p, self.n_trials, self.metadata, self.dataset_level_evals, self.n_workers
self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers
)
except Exception as e:
import traceback
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.167"
version = "0.2.168"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit 2f0db49

Please sign in to comment.