Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GNNFlux] Translate Traffic prediction Pluto notebook to Literate #572

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

aurorarossi
Copy link
Member

@aurorarossi aurorarossi commented Dec 26, 2024


# ## Dataset: METR-LA

# We use the `METR-LA` dataset from the paper [Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926.pdf), which contains traffic data from loop detectors in the highway of Los Angeles County. The dataset contains traffic speed data from March 1, 2012 to June 30, 2012. The data is collected every 5 minutes, resulting in 12 observations per hour, from 207 sensors. Each sensor is a node in the graph, and the edges represent the distances between the sensors.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean the edges represent the distances between the sensors? should be clarified

Comment on lines +66 to +67
train_loader = zip(features[1:200], targets[1:200]);
test_loader = zip(features[2001:2288], targets[2001:2288]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

motivate this choice of ranges

Comment on lines +85 to +99
for epoch in 1:100
for (x, y) in train_loader
x, y = (x, y)
grads = Flux.gradient(model) do model
ŷ = model(graph, x)
Flux.mae(ŷ, y)
end
Flux.update!(opt, model, grads[1])
end

if epoch % 10 == 0
loss = mean([Flux.mae(model(graph,x), y) for (x, y) in train_loader])
@show epoch, loss
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

identation

model = GNNChain(TGCN(2 => 100; add_self_loops = false), Dense(100, 1))

# ![](https://www.researchgate.net/profile/Haifeng-Li-3/publication/335353434/figure/fig4/AS:851870352437249@1580113127759/The-architecture-of-the-Gated-Recurrent-Unit-model.jpg)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here it would be useful to show the output of the model and how it is interpreted as a prediction

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants