Skip to content

Commit

Permalink
added penalty function to make positive diagonal norm (tests not yet …
Browse files Browse the repository at this point in the history
…passing)
  • Loading branch information
jehicken committed Nov 30, 2023
1 parent 761cfde commit a8814d3
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 33 deletions.
197 changes: 165 additions & 32 deletions src/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,61 @@ function cell_quadrature(degree, xc, xq, wq, ::Val{Dim}) where {Dim}
num_nodes = size(xc, 2)
num_quad = size(xq, 2)
@assert( num_nodes >= num_basis, "fewer nodes than basis functions")
# apply an affine transformation to the points xc and xq
lower = minimum([real.(xc) real.(xq)], dims=2)
upper = maximum([real.(xc) real.(xq)], dims=2)
dx = upper - lower
xavg = 0.5*(upper + lower)
xavg .*= 0.0
dx[:] .*= 1.001
xc_trans = zero(xc)
for I in CartesianIndices(xc)
xc_trans[I] = (xc[I] - xavg[I[1]])/dx[I[1]] - 0.5
end
xq_trans = zero(xq)
for I in CartesianIndices(xq)
xq_trans[I] = (xq[I] - xavg[I[1]])/dx[I[1]] - 0.5
end
# evaluate the polynomial basis at the quadrature and node points
workc = zeros(eltype(xc), (Dim+1)*num_nodes)
V = zeros(eltype(xc), num_nodes, num_basis)
poly_basis!(V, degree, xc_trans, workc, Val(Dim))
workq = zeros((Dim+1)*num_quad)
Vq = zeros(num_quad, num_basis)
poly_basis!(Vq, degree, xq_trans, workq, Val(Dim))
# integrate the polynomial basis using the given quadrature
b = zeros(num_basis)
for i = 1:num_basis
b[i] = dot(Vq[:,i], wq)
end
# complex-step does not play nice with pseudo-inverse
# w = pinv(transpose(V))*b
#lambda = -(transpose(V)*V)\b
#w = -V*lambda
A = [diagm(ones(num_nodes)) V; transpose(V) zeros(num_basis, num_basis)]
c = [zeros(num_nodes); b]
y = A\c
w = y[1:num_nodes]
return w
end

"""
cell_quadrature_rev!(xc_bar, degree, xc, xq, wq, w_bar, Val(Dim))
Reverse mode differentiated `cell_quadrature`. Returns the derivatives of
`dot(w, w_bar)` with respect to `xc` in the array `xc_bar`. All other inputs are the same as `cell_quadrature`.
"""
function cell_quadrature_rev!(xc_bar, degree, xc, xq, wq, w_bar, ::Val{Dim}
) where {Dim}
@assert( size(xc,1) == size(xq,1) == Dim, "xc/xq/Dim are inconsistent")
@assert( size(xq,2) == size(wq,1), "xq and wq have inconsistent sizes")
num_basis = binomial(Dim + degree, Dim)
num_nodes = size(xc, 2)
num_quad = size(xq, 2)
@assert( num_nodes >= num_basis, "fewer nodes than basis functions")

# forward sweep

# apply an affine transformation to the points xc and xq
lower = minimum([real.(xc) real.(xq)], dims=2)
upper = maximum([real.(xc) real.(xq)], dims=2)
Expand All @@ -33,6 +88,8 @@ function cell_quadrature(degree, xc, xq, wq, ::Val{Dim}) where {Dim}
workc = zeros((Dim+1)*num_nodes)
V = zeros(num_nodes, num_basis)
poly_basis!(V, degree, xc_trans, workc, Val(Dim))
dV = zeros(num_nodes, num_basis, Dim)
poly_basis_derivatives!(dV, degree, xc_trans, Val(Dim))
workq = zeros((Dim+1)*num_quad)
Vq = zeros(num_quad, num_basis)
poly_basis!(Vq, degree, xq_trans, workq, Val(Dim))
Expand All @@ -41,60 +98,69 @@ function cell_quadrature(degree, xc, xq, wq, ::Val{Dim}) where {Dim}
for i = 1:num_basis
b[i] = dot(Vq[:,i], wq)
end
lambda = -(transpose(V)*V)\b
w = -V*lambda
#w = pinv(transpose(V))*b

# find the scaling factors
dx = 1e16*ones(num_nodes)
for i = 1:num_nodes
for j = 1:num_nodes
if i == j
continue
end
dist = norm(xc_trans[:,j] - xc_trans[:,i])
dist < dx[i] ? dx[i] = dist : nothing
end
# reverse sweep
#adj = -transpose(pinv(transpose(V)))*w_bar
adj1 = -(transpose(V)*V) \ (transpose(V)*w_bar) # size = num_basis
adj2 = -w_bar - V*adj1 # size = num_nodes
for d = 1:Dim
#xc_bar[d,:] += w .* (dV[:,:,d]*adj)/dx[d]
xc_bar[d,:] += (adj2 .* (dV[:,:,d]*lambda) + w .* (dV[:,:,d]*adj1))/dx[d]
end
dx = dx.^Dim
#dx = sqrt.(dx)
#dx = 1.0./dx
#dx = dx.^2

A = V'*diagm(dx)
w = A\b
return diagm(dx)*w

#w = pinv(V')*b
#w = V'\b
#R = diagm(ones(num_nodes).*(prod(dx)/num_nodes))
#w = R*V*((V'*R*V)\b)
#@assert( norm(V'*w - b) < (1e-15)*(10^degree),
# "quadrature is not accurate!" )
#if norm(V'*w - b) > (1e-15)*(10^degree)
# println("WARNING: quadrature is not accurate! res = ", norm(V'*w - b))
#end
#return w
return nothing
end

function diagonal_norm(root::Cell{Data, Dim, T, L}, points, degree
) where {Data, Dim, T, L}
num_nodes = size(points, 2)
H = zeros(T, num_nodes)
# find the maximum number of phi basis over all cells
H = zeros(eltype(points), num_nodes)
x1d, w1d = lg_nodes(degree+1) # could also use lgl_nodes
wq = zeros(length(w1d)^Dim)
xq = zeros(Dim, length(wq))
for cell in allleaves(root)
# get the nodes in this cell's stencil, and an accurate quaduature
nodes = copy(points[:, cell.data.points])
#nodes = copy(points[:, cell.data.points])
nodes = view(points, :, cell.data.points)
quadrature!(xq, wq, cell.boundary, x1d, w1d)
# get cell quadrature and add to global norm
w = cell_quadrature(degree, nodes, xq, wq, Val(Dim))
#w = nodes[Dim,:].*nodes[1,:]
for i = 1:length(cell.data.points)
H[cell.data.points[i]] += w[i]
end
end
return H
end

function diagonal_norm_rev!(points_bar, root::Cell{Data, Dim, T, L}, points,
degree, H_bar) where {Data, Dim, T, L}
num_nodes = size(points, 2)
fill!(points_bar, zero(T))
x1d, w1d = lg_nodes(degree+1) # could also use lgl_nodes
wq = zeros(length(w1d)^Dim)
xq = zeros(Dim, length(wq))
for cell in allleaves(root)
# get the nodes in this cell's stencil, and an accurate quaduature
#nodes = copy(points[:, cell.data.points])
nodes = view(points, :, cell.data.points)
quadrature!(xq, wq, cell.boundary, x1d, w1d)
w_bar = zeros(length(cell.data.points))
for i = 1:length(cell.data.points)
# H[cell.data.points[i]] += w[i]
w_bar[i] = H_bar[cell.data.points[i]]
end
nodes_bar = view(points_bar, :, cell.data.points)
# w = cell_quadrature(degree, nodes, xq, wq, Val(Dim))
cell_quadrature_rev!(nodes_bar, degree, nodes, xq, wq, w_bar, Val(Dim))
#nodes_bar[Dim,:] += w_bar[:].*nodes[1,:]
#nodes_bar[1,:] += w_bar[:].*nodes[Dim,:]
end
return nothing
end

"""
Z, wp = cell_null_and_part(degree, xc, xq, wq, Val(Dim))
Expand Down Expand Up @@ -257,3 +323,70 @@ function obj_norm_grad!(g::AbstractVector{T}, root::Cell{Data, Dim, T, L},
end
return nothing
end

function penalty(root::Cell{Data, Dim, T, L}, xc, xc_init, dist_ref, mu, degree
) where {Data, Dim, T, L}
num_nodes = size(xc, 2)
# compute the norm part of the penalty
phi = 0.0
for i = 1:num_nodes
dist = 0.0
for d = 1:Dim
dist += (xc[d,i] - xc_init[d,i])^2
end
phi += dist/(dist_ref[i]^2)
end

# compute the diagonal norm based on x
H = diagonal_norm(root, xc, degree)

# add the penalties
for i = 1:num_nodes
H_ref = dist_ref[i]^Dim
if real(H[i]) < H_ref
phi += mu*(H[i]/H_ref - 1)^2
end
end

phi *= 0.5
return phi
end

function penalty_grad!(g::AbstractVector{T}, root::Cell{Data, Dim, T, L},
xc, xc_init, dist_ref, mu, degree
) where {Data, Dim, T, L}
num_nodes = size(xc, 2)
# need the diagonal norm for the reverse sweep
H = diagonal_norm(root, xc, degree)

# start the reverse sweep
fill!(g, zero(T))
# return phi
# phi *= 0.5
phi_bar = 0.5
# add the penalties
H_bar = zero(H)
for i = 1:num_nodes
H_ref = dist_ref[i]^Dim
if real(H[i]) < H_ref
# phi += mu*(H[i]/H_ref - 1)^2
H_bar[i] += phi_bar*2.0*mu*(H[i]/H_ref - 1)/H_ref
end
end

# compute the diagonal norm based on x
#H = diagonal_norm(root, xc, degree)
xc_bar = reshape(g, size(xc))
diagonal_norm_rev!(xc_bar, root, xc, degree, H_bar)

# compute the norm part of the penalty
for i = 1:num_nodes
# phi += dist/(dist_ref[i]^2)
dist_bar = phi_bar/(dist_ref[i]^2)
for d = 1:Dim
# dist += (xc[d,i] - xc_init[d,i])^2
xc_bar[d,i] += dist_bar*2.0*(xc[d,i] - xc_init[d,i])
end
end
return nothing
end
97 changes: 96 additions & 1 deletion test/test_diag_norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,35 @@
end
end

@testset "test cell_quadrature_rev!: dimension $Dim, degree $degree" for Dim in 1:3, degree in 0:4
num_basis = binomial(Dim + degree, Dim)
num_nodes = 2*num_basis
x1d, w1d = CutDGD.lg_nodes(degree+1) # could also use lgl_nodes
num_quad = length(x1d)^Dim
xq = zeros(Dim, num_quad)
wq = zeros(num_quad)
cell = Cell(SVector(ntuple(i -> 0.0, Dim)),
SVector(ntuple(i -> 1.0, Dim)),
CellData(Vector{Int}(), Vector{Int}())) #Face{2,Float64}}()))
CutDGD.quadrature!(xq, wq, cell.boundary, x1d, w1d)
xc = randn(Dim, num_nodes) .+ 0.5

# compute the derivative of the (weighted) quad weights w.r.t. xc
w_bar = randn(num_nodes)
xc_bar = zero(xc)
CutDGD.cell_quadrature_rev!(xc_bar, degree, xc, xq, wq, w_bar, Val(Dim))
p = randn(size(xc_bar))
dot_prod = dot(vec(xc_bar), vec(p))

# now use complex step to approximate the same derivatives
ceps = 1e-60
xc_cmplx = complex.(xc, ceps*p)
w_cmplx = CutDGD.cell_quadrature(degree, xc_cmplx, xq, wq, Val(Dim))
dot_prod_cmplx = dot(w_bar, imag.(w_cmplx)/ceps)

@test isapprox(dot_prod, dot_prod_cmplx)
end

@testset "test diagonal_norm: dimension $Dim, degree $degree" for Dim in 1:3, degree in 0:1

# use a unit HyperRectangle
Expand Down Expand Up @@ -82,6 +111,38 @@ end

end

@testset "test diagonal_norm_rev!: dimension $Dim, degree $degree" for Dim in 1:3, degree in 0:4

# use a unit HyperRectangle
root = Cell(SVector(ntuple(i -> 0.0, Dim)),
SVector(ntuple(i -> 1.0, Dim)),
CellData(Vector{Int}(), Vector{Int}()))

# DGD dof locations
num_basis = binomial(Dim + degree, Dim)
num_nodes = 10*num_basis
points = rand(Dim, num_nodes)

# refine mesh, build sentencil, and evaluate norm
CutDGD.refine_on_points!(root, points)
CutDGD.build_nn_stencils!(root, points, degree)

# get the derivative of dot(H_bar, H) in direction p
H_bar = randn(num_nodes)
points_bar = zero(points)
CutDGD.diagonal_norm_rev!(points_bar, root, points, degree, H_bar)
p = randn(Dim, num_nodes)
dot_prod = dot(vec(points_bar), vec(p))

# get the derivative using complex step
ceps = 1e-60
points_cmplx = complex.(points, ceps*p)
H_cmplx = CutDGD.diagonal_norm(root, points_cmplx, degree)
dot_prod_cmplx = dot(H_bar, imag.(H_cmplx)/ceps)

@test isapprox(dot_prod, dot_prod_cmplx)
end

@testset "test obj_norm_grad!: dimension $Dim, degree $degree" for Dim in 1:3, degree in 0:4

# use a unit HyperRectangle
Expand Down Expand Up @@ -126,7 +187,41 @@ end

@test isapprox(gdotp, gdotp_cmplx)

end

# @testset "test penalty_grad!: dimension $Dim, degree $degree" for Dim in 1:3, degree in 0:4

# # use a unit HyperRectangle
# root = Cell(SVector(ntuple(i -> 0.0, Dim)),
# SVector(ntuple(i -> 1.0, Dim)),
# CellData(Vector{Int}(), Vector{Int}()))

# # DGD dof locations
# num_basis = binomial(Dim + degree, Dim)

# num_nodes = 5*num_basis
# points = rand(Dim, num_nodes)
# points_init = points + 0.05*rand(Dim, num_nodes)

# # refine mesh and build stencil
# CutDGD.refine_on_points!(root, points)
# CutDGD.build_nn_stencils!(root, points, degree)

# # compute the penalty gradient
# mu = 1.0
# g = zeros(Dim*num_nodes)
# dist_ref = ones(num_nodes)
# CutDGD.penalty_grad!(g, root, points, points_init, dist_ref, mu, degree)

# # compare against a complex-step based directional derivative
# p = randn(length(g))
# gdotp = dot(g, p)

# ceps = 1e-60
# points_cmplx = complex.(points, ceps.*reshape(p, size(points)))
# penalty = CutDGD.penalty(root, points_cmplx, points_init, dist_ref, mu, degree)
# gdotp_cmplx = imag(penalty)/ceps

# @test isapprox(gdotp, gdotp_cmplx)

end
# end

0 comments on commit a8814d3

Please sign in to comment.