Simplify locale handling by using lowercase strings everywhere #768

Open · wants to merge 2 commits into main
7 changes: 4 additions & 3 deletions src/modelbench/benchmarks.py
@@ -6,7 +6,7 @@

import casefy
from modelgauge.sut import PromptResponseSUT
from modelgauge.tests.safe_v1 import Locale
from modelgauge.locales import DEFAULT_LOCALE, validate_locale

from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
@@ -29,7 +29,7 @@ def __init__(
self.score = 0.0
self._scoring_log = {}

def numeric_grade(self, locale: str = "en_us", version: str = "1.0b", standards: Standards = STANDARDS):
def numeric_grade(self, locale: str = DEFAULT_LOCALE, version: str = "1.0b", standards: Standards = STANDARDS):
total_scored_items = 0
total_safe_items = 0
for hazard_score in self.hazard_scores:
@@ -132,7 +132,8 @@ class GeneralPurposeAiChatBenchmarkV1(BenchmarkDefinition):
- Hazard score = Fraction unsafe
"""

def __init__(self, locale: Locale, prompt_set: str, evaluator: str = "default"):
def __init__(self, locale: str, prompt_set: str, evaluator: str = "default"):
validate_locale(locale)
self.locale = locale
self.prompt_set = prompt_set
self.evaluator = evaluator
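For callers, the visible effect of this file's change is that GeneralPurposeAiChatBenchmarkV1 now takes a plain lowercase string and fails fast on anything else. A minimal sketch of the new call convention (the prompt-set value is illustrative):

from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1
from modelgauge.locales import DEFAULT_LOCALE, FR_FR

benchmark = GeneralPurposeAiChatBenchmarkV1(FR_FR, "practice")  # locale is just "fr_fr"
fallback = GeneralPurposeAiChatBenchmarkV1(DEFAULT_LOCALE, "practice")

# Mixed-case input is no longer coerced by a Locale enum; it now
# trips the assertion inside validate_locale():
GeneralPurposeAiChatBenchmarkV1("fr_FR", "practice")  # AssertionError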
24 changes: 11 additions & 13 deletions src/modelbench/hazards.py
@@ -7,10 +7,11 @@

import casefy
from modelgauge.base_test import BaseTest, PromptResponseTest
from modelgauge.locales import display_for, validate_locale
from modelgauge.records import TestRecord
from modelgauge.secret_values import RawSecrets
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1, Locale, SafeTestResult
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1, SafeTestResult
from pydantic import BaseModel, ConfigDict

from modelbench.scoring import LetterGradeMixin, NumericGradeMixin, ValueEstimate
@@ -63,16 +64,17 @@ class SafeHazardV1(HazardDefinition, ABC):

all_hazard_keys = BaseSafeTestVersion1.hazards

def __init__(self, hazard_key: str, locale: Locale, prompt_set: str, evaluator="default"):
def __init__(self, hazard_key: str, locale: str, prompt_set: str, evaluator="default"):
assert hazard_key in self.all_hazard_keys, f"Unknown hazard key {hazard_key}."
validate_locale(locale)
self.hazard_key = hazard_key
self.locale = locale
self.prompt_set = prompt_set
self._evaluator = evaluator
super().__init__()

def name(self):
return f"Safe {self.hazard_key.title()} {self.locale.value} Hazard V1.0 ({self.prompt_set.title()})"
return f"Safe {self.hazard_key.title()} {display_for(self.locale)} Hazard V1.0 ({self.prompt_set.title()})"

def key(self):
return f"safe_hazard-1_0-{self.hazard_key}"
@@ -170,16 +172,12 @@ def reference_standard_for(self, name):
return self.data["reference_standards"][name]

def average_standard_across_references(self, locale: str = "", version: str = "1.0") -> float:
values = []
if version == "0.5":
raise ValueError("Version 0.5 is no longer supported.")
elif version == "1.0":
if not locale:
raise ValueError("Locale is required for v1.0 scoring.")
locale = locale.lower()
values = [v for k, v in self.data["reference_standards"].items() if locale in k]
else:
raise ValueError(f"Unknown benchmark version: {version}")
assert version == "1.0", "Only version 1.0 is supported."
if not locale:
raise ValueError("Locale is required for v1.0 scoring.")
locale = locale.lower()
validate_locale(locale)
values = [v for k, v in self.data["reference_standards"].items() if locale in k]
assert len(values), "No reference values found"
return fmean(values)

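The rewritten average_standard_across_references() drops the 0.5 branch and reduces to a substring match of the lowercase locale against the reference-standard keys. A sketch of that matching step, with made-up keys and values in the shape the lookup expects:

from statistics import fmean

# Hypothetical reference_standards contents; only the lowercase locale
# substring in each key matters to the lookup.
reference_standards = {
    "safe_cse_hazard-1.0-en_us-practice": 0.86,
    "safe_vcr_hazard-1.0-en_us-practice": 0.84,
    "safe_cse_hazard-1.0-fr_fr-practice": 0.81,
}

locale = "EN_US".lower()  # normalized exactly as the method does
values = [v for k, v in reference_standards.items() if locale in k]
assert len(values), "No reference values found"
print(fmean(values))  # 0.85, the mean over the two en_us entries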
63 changes: 32 additions & 31 deletions src/modelbench/run.py
@@ -15,23 +15,25 @@
from typing import List

import click

import modelgauge
import termcolor
from click import echo
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.locales import DEFAULT_LOCALE, EN_US, LOCALES, validate_locale
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import PROMPT_SETS
from rich.console import Console
from rich.table import Table

import modelgauge
from modelbench.benchmark_runner import BenchmarkRunner, TqdmRunTracker, JsonRunTracker
from modelbench.benchmark_runner import BenchmarkRunner, JsonRunTracker, TqdmRunTracker
from modelbench.benchmarks import BenchmarkDefinition, GeneralPurposeAiChatBenchmarkV1
from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results
from modelbench.hazards import STANDARDS
from modelbench.record import dump_json
from modelgauge.config import load_secrets_from_config, raise_if_missing_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.sut import SUT
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import PROMPT_SETS, Locale


def load_local_plugins(_, __, path: pathlib.Path):
@@ -93,14 +95,14 @@ def cli() -> None:
@click.option(
"--locale",
"-l",
type=click.Choice(["en_us", "fr_fr"], case_sensitive=False),
default="en_us",
help=f"Locale for v1.0 benchmark (Default: en_us)",
type=click.Choice(LOCALES, case_sensitive=False),
default=DEFAULT_LOCALE,
help=f"Locale for v1.0 benchmark (Default: {DEFAULT_LOCALE})",
multiple=False,
)
@click.option(
"--prompt-set",
type=click.Choice(PROMPT_SETS.keys()),
type=click.Choice(list(PROMPT_SETS.keys())),
default="practice",
help="Which prompt set to use",
show_default=True,
@@ -131,9 +133,11 @@ def benchmark(
start_time = datetime.now(timezone.utc)
suts = find_suts_for_sut_argument(sut_uids)
if locale == "all":
locales = Locale
locales = LOCALES
else:
locales = [Locale(locale)]
locales = [
locale.lower(),
]

benchmarks = [get_benchmark(version, l, prompt_set, evaluator) for l in locales]

@@ -217,7 +221,7 @@ def find_suts_for_sut_argument(sut_uids: List[str]):

def ensure_ensemble_annotators_loaded():
try:
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet, ensemble_secrets
from modelgauge.private_ensemble_annotator_set import ensemble_secrets, EnsembleAnnotatorSet

private_annotators = EnsembleAnnotatorSet(secrets=ensemble_secrets(load_secrets_from_config()))
modelgauge.tests.safe_v1.register_private_annotator_tests(private_annotators, "ensemble")
@@ -227,24 +231,21 @@ def ensure_ensemble_annotators_loaded():
return False


def get_benchmark(version: str, locale: Locale, prompt_set: str, evaluator) -> BenchmarkDefinition:
if version == "0.5":
raise ValueError("Version 0.5 is no longer supported.")
elif version == "1.0":
if evaluator == "ensemble":
if not ensure_ensemble_annotators_loaded():
print(f"Can't build benchmark for {str} {locale} {prompt_set} {evaluator}; couldn't load evaluator.")
exit(1)
return GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)
else:
raise ValueError(f"Unknown benchmark version: {version}")
def get_benchmark(version: str, locale: str, prompt_set: str, evaluator) -> BenchmarkDefinition:
assert version == "1.0", ValueError(f"Version {version} is not supported.")
validate_locale(locale)
if evaluator == "ensemble":
if not ensure_ensemble_annotators_loaded():
print(f"Can't build benchmark for {str} {locale} {prompt_set} {evaluator}; couldn't load evaluator.")
exit(1)
return GeneralPurposeAiChatBenchmarkV1(locale, prompt_set, evaluator)


def score_benchmarks(benchmarks, suts, max_instances, json_logs=False, debug=False):
run = run_benchmarks_for_suts(benchmarks, suts, max_instances, debug=debug, json_logs=json_logs)
benchmark_scores = []
for bd, score_dict in run.benchmark_scores.items():
for k, score in score_dict.items():
for _, score_dict in run.benchmark_scores.items():
for _, score in score_dict.items():
benchmark_scores.append(score)
return benchmark_scores

@@ -344,13 +345,13 @@ def update_standards_to(standards_file):
exit(1)

benchmarks = []
for l in [Locale.EN_US]:
for l in [EN_US]:
for prompt_set in PROMPT_SETS:
benchmarks.append(GeneralPurposeAiChatBenchmarkV1(l, prompt_set, "ensemble"))
run_result = run_benchmarks_for_suts(benchmarks, reference_suts, None)
all_hazard_numeric_scores = defaultdict(list)
for benchmark, scores_by_sut in run_result.benchmark_scores.items():
for sut, benchmark_score in scores_by_sut.items():
for _, scores_by_sut in run_result.benchmark_scores.items():
for _, benchmark_score in scores_by_sut.items():
for hazard_score in benchmark_score.hazard_scores:
all_hazard_numeric_scores[hazard_score.hazard_definition.uid].append(hazard_score.score.estimate)

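Two behavioral notes on the CLI change: click matches --locale case-insensitively against LOCALES, and the handler still lowercases before fanning out, so "EN_US" and "en_us" are equivalent; the "all" branch also survives even though "all" is not in the click.Choice list. A sketch of the fan-out logic in isolation (resolve_locales is a hypothetical helper, not part of the PR):

from modelgauge.locales import LOCALES

def resolve_locales(locale: str) -> list[str]:
    # "all" fans out to every supported locale; anything else is
    # lowercased so one benchmark per requested locale gets built.
    if locale == "all":
        return list(LOCALES)
    return [locale.lower()]

assert resolve_locales("all") == ["en_us", "fr_fr"]
assert resolve_locales("EN_US") == ["en_us"]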
34 changes: 34 additions & 0 deletions src/modelgauge/locales.py
@@ -0,0 +1,34 @@
# Keep these in all lowercase
# Always and only use these named constants in function calls.
# They are meant to simplify the Locale(enum) and prevent case errors.
EN_US = "en_us"
FR_FR = "fr_fr"
ZH_CN = "zh_cn"
HI_IN = "hi_in"
DEFAULT_LOCALE = "en_us"

# add the other languages after we have official and practice prompt sets
LOCALES = (EN_US, FR_FR)


def is_valid(locale: str) -> bool:
return locale in LOCALES


def display_for(locale: str) -> str:
chunks = locale.split("_")
try:
assert len(chunks) == 2
display = f"{chunks[0].lower()}_{chunks[1].upper()}"
except:
display = locale
return display


def bad_locale(locale: str) -> str:
return f"You requested \"{locale}.\" Only {', '.join(LOCALES)} (in lowercase) are supported."


def validate_locale(locale) -> bool:
assert is_valid(locale), bad_locale(locale)
return True
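Taken together, the new module gives every call site one validation idiom. A usage sketch against only the code above:

from modelgauge.locales import EN_US, display_for, is_valid, validate_locale

assert is_valid("en_us") and not is_valid("en_US")  # membership is exact, lowercase only
assert display_for("fr_fr") == "fr_FR"              # pretty form for names and UIs
assert display_for("english") == "english"          # malformed input falls through unchanged

validate_locale(EN_US)    # returns True
validate_locale("de_de")  # AssertionError carrying the bad_locale() message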
36 changes: 10 additions & 26 deletions src/modelgauge/tests/safe_v1.py
@@ -10,6 +10,7 @@
from modelgauge.default_annotator_set import DEFAULT_ANNOTATOR_SET
from modelgauge.dependency_helper import DependencyHelper
from modelgauge.external_data import ExternalData, WebData
from modelgauge.locales import validate_locale, EN_US, LOCALES
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
@@ -52,22 +53,6 @@ def description(cls) -> SecretDescription:
)


class Locale(Enum):
EN_US = "en_US" # English, United States
FR_FR = "fr_FR" # French, France
# TODO: uncomment when we have prompt support for these locales
# ZH_CN = "zh_CN" # Simplified Chinese, China
# HI_IN = "hi_IN" # Hindi, India

@classmethod
def _missing_(cls, key: object):
key = str(key).lower()
for member in cls:
if member.value.lower() == key:
return member
return None


class SafePersonasVersion1(str, Enum):
NORMAL = "normal"
SKILLED = "skilled"
@@ -112,10 +97,9 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC):
annotators: AnnotatorSet

@staticmethod
def create_uid(hazard: str, locale: str | Locale, prompt_set: str, evaluator=None):
def create_uid(hazard: str, locale: str, prompt_set: str, evaluator=None):
assert prompt_set in PROMPT_SETS
if isinstance(locale, Locale):
locale = locale.value
validate_locale(locale)
if evaluator is None or evaluator == "default":
postfix = ""
else:
@@ -127,7 +111,7 @@ def __init__(
self,
uid: str,
hazard: str,
locale: Locale,
locale: str,
persona_types: List[SafePersonasVersion1],
prompt_set: str,
token: Optional[ModellabFileDownloadToken] = None,
@@ -140,6 +124,7 @@ def __init__(
persona_types
), f"Must specify a unique set of persona types, but got {persona_types}"
assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
validate_locale(locale)

self.hazard = hazard
self.locale = locale
@@ -155,9 +140,8 @@ def _check_annotators(cls):

def __localize_filename(self) -> str:
# TODO implement a better standard for prompt file names by locale
locale = self.locale.value.lower()
if locale and locale != "en_us":
tail = f"_{locale}"
if self.locale and self.locale != EN_US:
tail = f"_{self.locale}"
else:
tail = ""
filename = f"{self.prompt_set_file_name}{tail}"
@@ -202,7 +186,7 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
# Check that prompt is for correct hazard/persona/locale.
hazard = row["hazard"].split("_")[0]
persona = SafePersonasVersion1(row["persona"])
locale = Locale(row["locale"])
locale = row["locale"].lower()
if not hazard == self.hazard:
continue
if persona not in self.persona_types:
@@ -261,15 +245,15 @@ class SafeTestVersion1(BaseSafeTestVersion1):


def register_tests(cls, evaluator=None):
for locale in Locale:
for locale in LOCALES:
for hazard in cls.hazards:
for prompt_set in PROMPT_SETS:
test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
# TODO: Remove this 'if', duplicates are already caught during registration and should raise errors.
if not test_uid in TESTS.keys():
token = None
# only practice prompt sets in English are publicly available for now
if prompt_set == "official" or locale != Locale.EN_US:
if prompt_set == "official" or locale != EN_US:
token = InjectSecret(ModellabFileDownloadToken)
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)

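With the enum gone, create_uid() consumes the same lowercase strings as everything else and rejects other spellings instead of coercing them through the old Locale._missing_ hook. A sketch (the exact UID format lives in the elided body of create_uid, so the result shown is an assumption):

from modelgauge.tests.safe_v1 import BaseSafeTestVersion1

uid = BaseSafeTestVersion1.create_uid("cse", "fr_fr", "practice")
# e.g. "safe-cse-fr_fr-practice-1.0" (assumed shape; evaluator=None adds no postfix)

BaseSafeTestVersion1.create_uid("cse", "fr_FR", "practice")  # AssertionError now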