Skip to content

Commit

Permalink
Merge pull request #13 from una-auxme/dev
Browse files Browse the repository at this point in the history
Refactored Code with SciML Style Guide
  • Loading branch information
JulianTrommer authored Jul 22, 2024
2 parents 016149e + 3e13fdf commit 24bf893
Show file tree
Hide file tree
Showing 12 changed files with 1,039 additions and 908 deletions.
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "sciml"
separate_kwargs_with_semicolon = true
9 changes: 9 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"
labels:
- "dependencies"
- "github-actions"
36 changes: 36 additions & 0 deletions .github/workflows/Formatter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Formatter

on:
pull_request:
push:
branches:
- 'main'
tags: '*'

jobs:
build:
runs-on: ubuntu-latest
steps:
- name: "Check out repository"
uses: actions/checkout@v4

- name: "Set up Julia"
uses: julia-actions/setup-julia@v1
with:
version: '1.10'
arch: x64

- name: Install JuliaFormatter and format
run: julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter")); using JuliaFormatter; format(".", verbose=true)'

- name: Format check
run: |
julia -e '
out = Cmd(`git diff`) |> read |> String
if out == ""
exit(0)
else
@error "Some files are not formatted! Please use the SciMLStyle to format your files!"
write(stdout, out)
exit(1)
end'
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

[![Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://una-auxme.github.io/MeshGraphNets.jl/dev)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

[*GraphNetCore.jl*](https://github.com/una-auxme/GraphNetCore.jl) is a software package for the Julia programming language that provides an the core functionality of the [*MeshGraphNets.jl*](https://github.com/una-auxme/MeshGraphNets.jl) package. Some parts are based on the implementation of the [MeshGraphNets](https://arxiv.org/abs/2010.03409) framework by [Google DeepMind](https://deepmind.google/) for simulating mesh-based physical systems via graph neural networks:

Expand Down
3 changes: 2 additions & 1 deletion src/GraphNetCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ export FeatureGraph
# graph_network.jl
export GraphNetwork
# normaliser.jl
export NormaliserOffline, NormaliserOfflineMinMax, NormaliserOfflineMeanStd, NormaliserOnline
export NormaliserOffline, NormaliserOfflineMinMax, NormaliserOfflineMeanStd,
NormaliserOnline

# graph_network.jl
export build_model, step!, save!, load
Expand Down
158 changes: 80 additions & 78 deletions src/feature_graph.jl
Original file line number Diff line number Diff line change
@@ -1,78 +1,80 @@
#
# Copyright (c) 2023 Julian Trommer
# Licensed under the MIT license. See LICENSE file in the project root for details.
#

"""
FeatureGraph(nf, ef, senders, receivers)
Data structure that is used as an input for the [`GraphNetwork`](@ref).
## Arguments
- `nf`: Node features of the graph.
- `ef`: edge features of the graph.
- `senders`: List of nodes in the mesh where graph edges start.
- `receivers`: List of nodes in the mesh where graph edges end.
"""
mutable struct FeatureGraph
nf
ef
senders
receivers
end

"""
update_features!(g; nf, ef)
Updates the node and edge features of the given [`FeatureGraph`](@ref).
## Arguments
- `g`: [`FeatureGraph`](@ref) that should be updated.
## Keyword Arguments
- `nf`: Updated node features.
- `ef`: Updated edge features.
## Returns
- Updated graph as a [`FeatureGraph`](@ref) struct.
"""
function update_features!(g::FeatureGraph; nf, ef)
g.nf = nf
g.ef = ef
return g
end

"""
aggregate_edge_features(graph)
Aggregates the edge features based on the senders and receivers of the given [`FeatureGraph`](@ref).
## Arguments
- `graph`: [`FeatureGraph`](@ref) which node and edge features are used.
## Returns
- Two-dimensional array with the
- 1. dimension containing the concatenated features as new edge features and the
- 2. dimension representing the individual edges.
"""
@inline function aggregate_edge_features(graph::FeatureGraph)
return vcat(graph.nf[:, graph.senders], graph.nf[:, graph.receivers], graph.ef)
end

"""
aggregate_node_features(graph, updated_edge_features)
Aggregates the node features based on the given [`FeatureGraph`](@ref) and updated edge features.
## Arguments
- `graph`: [`FeatureGraph`](@ref) which node features are used.
- `updated_edge_features`: New edge features that were calculated in a previous step.
## Returns
- Two dimensional array with the
- 1. dimension containing the concatenated features as new node features and the
- 2. dimension representing the individual nodes.
"""
@inline function aggregate_node_features(graph::FeatureGraph, updated_edge_features)
return vcat(graph.nf, NNlib.scatter(+, updated_edge_features, graph.receivers, dstsize = size(graph.nf)))
end
#
# Copyright (c) 2023 Julian Trommer
# Licensed under the MIT license. See LICENSE file in the project root for details.
#

"""
FeatureGraph(nf, ef, senders, receivers)
Data structure that is used as an input for the [`GraphNetwork`](@ref).
## Arguments
- `nf`: Node features of the graph.
- `ef`: edge features of the graph.
- `senders`: List of nodes in the mesh where graph edges start.
- `receivers`: List of nodes in the mesh where graph edges end.
"""
mutable struct FeatureGraph{F <: AbstractArray, T <: AbstractArray}
nf::F
ef::F
senders::T
receivers::T
end

"""
update_features!(g; nf, ef)
Updates the node and edge features of the given [`FeatureGraph`](@ref).
## Arguments
- `g`: [`FeatureGraph`](@ref) that should be updated.
## Keyword Arguments
- `nf`: Updated node features.
- `ef`: Updated edge features.
## Returns
- Updated graph as a [`FeatureGraph`](@ref) struct.
"""
function update_features!(g::FeatureGraph; nf, ef)
g.nf = nf
g.ef = ef

return g
end

"""
aggregate_edge_features(graph)
Aggregates the edge features based on the senders and receivers of the given [`FeatureGraph`](@ref).
## Arguments
- `graph`: [`FeatureGraph`](@ref) which node and edge features are used.
## Returns
- Two-dimensional array with the
- 1. dimension containing the concatenated features as new edge features and the
- 2. dimension representing the individual edges.
"""
@inline function aggregate_edge_features(graph::FeatureGraph)
return vcat(graph.nf[:, graph.senders], graph.nf[:, graph.receivers], graph.ef)
end

"""
aggregate_node_features(graph, updated_edge_features)
Aggregates the node features based on the given [`FeatureGraph`](@ref) and updated edge features.
## Arguments
- `graph`: [`FeatureGraph`](@ref) which node features are used.
- `updated_edge_features`: New edge features that were calculated in a previous step.
## Returns
- Two dimensional array with the
- 1. dimension containing the concatenated features as new node features and the
- 2. dimension representing the individual nodes.
"""
@inline function aggregate_node_features(graph::FeatureGraph, updated_edge_features)
return vcat(graph.nf,
NNlib.scatter(+, updated_edge_features, graph.receivers; dstsize = size(graph.nf)))
end
Loading

0 comments on commit 24bf893

Please sign in to comment.