Skip to content

Commit

Permalink
Update the tests accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Nov 17, 2023
1 parent cfa13da commit ad12d56
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examol/score/rdkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _unpack(inputs: InputType) -> tuple[list[str], np.ndarray | None]:
smiles, values = zip(*inputs)
if any(v is None for v in values):
return smiles, None
return smiles, np.ndarray(values)
return smiles, np.array(values)


@dataclass
Expand Down
10 changes: 5 additions & 5 deletions tests/score/test_rdkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def scorer() -> RDKitScorer:


def test_transform(training_set, scorer, recipe):
assert scorer.transform_inputs(training_set)[0] == ['C', 'CC', 'CCC']
assert scorer.transform_inputs(training_set)[0][0] == 'C'
assert np.isclose(scorer.transform_outputs(training_set, recipe), [1, 2, 3]).all()


Expand Down Expand Up @@ -67,9 +67,8 @@ def test_gpr(training_set, scorer, recipe):
def test_multifi(training_set, multifi_recipes, scorer, pipeline, bootstrap):
# Test conversion to multi-fidelity
inputs = scorer.transform_inputs(training_set, multifi_recipes)
smiles, values = inputs
assert smiles[0] == training_set[0].identifier.smiles
assert values is not None
assert inputs[0][0] == training_set[0].identifier.smiles
assert inputs[0][1] is not None

# Test training
model_msg = scorer.prepare_message(pipeline, training=True)
Expand All @@ -85,4 +84,5 @@ def test_multifi(training_set, multifi_recipes, scorer, pipeline, bootstrap):
model_msg = scorer.prepare_message(pipeline, training=False)
outputs = scorer.score(model_msg, inputs)
assert outputs.shape == (len(training_set),)
assert np.isclose(outputs, values[:, -1]).all() # Should give exact result, since all values are known
values = [v[-1] for _, v in inputs]
assert np.isclose(outputs, values).all() # Should give exact result, since all values are known

0 comments on commit ad12d56

Please sign in to comment.