Skip to content

Commit

Permalink
code clarification
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoisot committed Oct 12, 2023
1 parent 8405755 commit a4aa313
Showing 1 changed file with 46 additions and 7 deletions.
53 changes: 46 additions & 7 deletions code/confusion.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
"""
ConfusionMatrix
This defines a confusion matrix with four fields, in order: true positives, true
negatives, false positives, and false negatives. The types are meant to store
`Int` information (*i.e.* this is a proper contingency table).
"""
struct ConfusionMatrix
tp::Int
tn::Int
fp::Int
fn::Int
end

"""
ConfusionMatrix(pred::Vector{Bool}, truth::Vector{Bool})
Returns a `ConfusionMatrix` based on two `Vector{Bool}`, where the first is the
predictions, and the second in the observations.
"""
function ConfusionMatrix(pred::Vector{Bool}, truth::Vector{Bool})
tp = sum(pred .& truth)
tn = sum(.!pred .& .!truth)
Expand All @@ -13,18 +26,44 @@ function ConfusionMatrix(pred::Vector{Bool}, truth::Vector{Bool})
return ConfusionMatrix(tp, tn, fp, fn)
end

"""
ConfusionMatrix(pred::Vector{T}, truth::Vector{Bool}, τ::T) where {T <: Number}
Returns a `ConfusionMatrix` based on a vector of quantitative predictions, a
vector of Boolean observations, and a threshold. A prediction is counted as a
positive whenever it is larger than the threshold.
"""
function ConfusionMatrix(pred::Vector{T}, truth::Vector{Bool}, τ::T) where {T <: Number}
return ConfusionMatrix(convert(Vector{Bool}, pred .>= τ), truth)
end

"""
ConfusionMatrix(pred::Vector{T}, truth::Vector{Bool}) where {T <: Number}
Returns a `ConfusionMatrix` based on a vector of quantitative predictions, a
vector of Boolean observations, and a threshold assumed to be one half. A
prediction is counted as a positive whenever it is larger than the threshold.
This method is mostly here as a shortcut to use for untuned NBC.
"""
function ConfusionMatrix(pred::Vector{T}, truth::Vector{Bool}) where {T <: Number}
return ConfusionMatrix(pred, truth, 0.5)
end

function Base.Matrix(c::ConfusionMatrix)
return [c.tp c.fp; c.fn c.tn]
end
"""
Base.Matrix(c::ConfusionMatrix)
Returns the matrix representation of a `ConfusionMatrix`.
"""
Base.Matrix(c::ConfusionMatrix) = [c.tp c.fp; c.fn c.tn]

"""
Base.zero(ConfusionMatrix)
Returns an empty confusion matrix, *i.e.* a matrix where all the entries are set
to 0. This is useful in order to pre-allocate an array of matrices, using *e.g.*
`zeros(ConfusioMatrix, 10)`; note that the matrices themselves are immutable, so
the entries in this array will need to be overwritten.
"""
Base.zero(ConfusionMatrix) = ConfusionMatrix(0, 0, 0, 0)

tpr(M::ConfusionMatrix) = M.tp / (M.tp + M.fn)
Expand Down Expand Up @@ -61,13 +100,13 @@ function auc(x::Array{T}, y::Array{T}) where {T<:Number}
end

function rocauc(C::Vector{ConfusionMatrix})
x = fpr.(C)
y = tpr.(C)
x = [0., fpr.(C)..., 1.]
y = [0., tpr.(C)..., 1.]
return auc(x, y)
end

function praux(C::Vector{ConfusionMatrix})
x = tpr.(C)
y = ppv.(C)
x = [0., tpr.(C)..., 1.]
y = [1., ppv.(C)..., 0.]
return aux(x, y)
end

0 comments on commit a4aa313

Please sign in to comment.