Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
MariaHei committed Nov 3, 2022
2 parents 73db85c + 7559259 commit e6c01d8
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,79 @@ function eval_SC_loose(SChat, SC, k, data, target_col; digits=4)
round(correct / total, digits=digits)
end


"""
eval_SC_loose(SChat, SC, SC_rest, k; digits=4)
Assess model accuracy on the basis of the correlations of row vectors of Chat and
C or Shat and S. Count it as correct if one of the top k candidates is correct.
Does not consider homophones.
Takes into account gold-standard vectors in both the actual targets (SC)
as well as in a second matrix (e.g. the training or validation data; SC_rest).
# Obligatory Arguments
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of the data under consideration
- `SC_rest::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of rest data
- `k`: top k candidates
# Optional Arguments
- `digits=4`: the specified number of digits after the decimal place (or before if negative)
```julia
eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, k)
eval_SC_loose(Shat_val, S_val, S_train, k)
```
"""
function eval_SC_loose(SChat, SC, SC_rest, k; digits=4)
SC_combined = vcat(SC, SC_rest)
eval_SC_loose(SChat, SC_combined, k, digits=digits)
end

"""
eval_SC_loose(SChat, SC, SC_rest, k, data, data_rest, target_col; digits=4)
Assess model accuracy on the basis of the correlations of row vectors of Chat and
C or Shat and S. Count it as correct if one of the top k candidates is correct.
Considers homophones.
Takes into account gold-standard vectors in both the actual targets (SC)
as well as in a second matrix (e.g. the training or validation data; SC_rest).
# Obligatory Arguments
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of the data under consideration
- `SC_rest::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of rest data
- `k`: top k candidates
- `data`: dataset under consideration
- `data_rest`: remaining dataset
- `target_col`: target column name
# Optional Arguments
- `digits=4`: the specified number of digits after the decimal place (or before if negative)
```julia
eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, k, latin_val, latin_train, :Word)
eval_SC_loose(Shat_val, S_val, S_train, k, latin_val, latin_train, :Word)
```
"""
function eval_SC_loose(SChat, SC, SC_rest, k, data, data_rest, target_col; digits=4)
SC_combined = vcat(SC, SC_rest)

n_data = size(data, 1)
n_data_rest = size(data_rest, 1)

if n_data > n_data_rest
data_combined = similar(data, 0)
else
data_combined = similar(data_rest, 0)
end

append!(data_combined, data)
append!(data_combined, data_rest)

eval_SC_loose(SChat, SC_combined, k, data_combined, target_col, digits=digits)
end

"""
eval_manual(res, data, i2f)
Expand Down
47 changes: 47 additions & 0 deletions test/eval_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,53 @@ end
@test JudiLing.eval_SC_loose(Chat, cue_obj.C, k, latin, :Word) == 1
@test JudiLing.eval_SC_loose(Shat, S, k, latin, :Word) == 1
end

latin_train = DataFrame(
Word = ["ABC", "BCD", "CDE", "BCD"],
Lexeme = ["A", "B", "C", "B"],
Person = ["B", "C", "D", "D"],
)

latin_val = DataFrame(
Word = ["ABC", "BCD"],
Lexeme = ["A", "B"],
Person = ["B", "C"],
)

cue_obj_train, cue_obj_val = JudiLing.make_combined_cue_matrix(
latin_train,
latin_val,
grams = 3,
target_col = :Word,
tokenized = false,
keep_sep = false,
)

n_features = size(cue_obj_train.C, 2)
S_train, S_val = JudiLing.make_combined_S_matrix(
latin_train,
latin_val,
[:Lexeme],
[:Person],
ncol = n_features,
add_noise = false
)

G = JudiLing.make_transform_matrix(S_train, cue_obj_train.C)
Chat_val = S_val * G
Chat_train = S_train * G
F = JudiLing.make_transform_matrix(cue_obj_train.C, S_train)
Shat_val = cue_obj_val.C * F
Shat_train = cue_obj_train.C * F

@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 1) >= 0.5
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 2) == 1
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 1) >= 0.5
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 2) == 1
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 1, latin_val, latin_train, :Word) == 1
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 2, latin_val, latin_train, :Word) == 1
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 1, latin_val, latin_train, :Word) == 1
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 2, latin_val, latin_train, :Word) == 1
end

@testset "accuracy_comprehension" begin
Expand Down

0 comments on commit e6c01d8

Please sign in to comment.