diff --git a/repeng/tests.py b/repeng/tests.py index b0441fc..d32cd77 100644 --- a/repeng/tests.py +++ b/repeng/tests.py @@ -68,13 +68,17 @@ def gen(vector: ControlVector | None, strength_coeff: float | None = None): assert baseline == gen(happy_vector * 0.0) assert baseline == gen(happy_vector - happy_vector) - assert happy == "You are feeling great and happy. I'm" + assert happy == "You are feeling a little more relaxed and enjoying" # these should be identical assert happy == gen(happy_vector, 20.0) assert happy == gen(happy_vector * 20) assert happy == gen(-(happy_vector * -20)) - assert sad == "You are feeling the worst,\n—(" + assert sad == 'You are feeling the fucking damn goddamn worst,"' + # these should be identical + assert sad == gen(happy_vector, -50.0) + assert sad == gen(happy_vector * -50) + assert sad == gen(-(happy_vector * 50)) def test_train_llama_tinystories():