-
Colors shows data, neuron and weight values.
diff --git a/src/heatmap.ts b/src/heatmap.ts
index 0bf21f66..3e979d93 100644
--- a/src/heatmap.ts
+++ b/src/heatmap.ts
@@ -88,7 +88,7 @@ export class HeatMap {
position: "relative",
top: `-${padding}px`,
left: `-${padding}px`
- });
+ })
this.canvas = container.append("canvas")
.attr("width", numSamples)
.attr("height", numSamples)
@@ -97,7 +97,7 @@ export class HeatMap {
.style("position", "absolute")
.style("top", `${padding}px`)
.style("left", `${padding}px`);
-
+
if (!this.settings.noSvg) {
this.svg = container.append("svg").attr({
"width": width,
diff --git a/src/playground.ts b/src/playground.ts
index aeac0f9c..d37698b2 100644
--- a/src/playground.ts
+++ b/src/playground.ts
@@ -88,6 +88,7 @@ let HIDABLE_CONTROLS = [
["Noise level", "noise"],
["Batch size", "batchSize"],
["# of hidden layers", "numHiddenLayers"],
+ ["Paint Platform", "paintPlatform"],
];
class Player {
@@ -166,8 +167,6 @@ let colorScale = d3.scale.linear()
.range(["#f59322", "#e8eaeb", "#0877bd"])
.clamp(true);
let iter = 0;
-let trainData: Example2D[] = [];
-let testData: Example2D[] = [];
let network: nn.Node[][] = null;
let lossTrain = 0;
let lossTest = 0;
@@ -264,11 +263,44 @@ function makeGUI() {
reset();
});
+ // For changing state on different selections
+ d3.select("#select-orange").on("change", function() {
+ state.editColor = this.checked ? -1 : 1
+ state.serialize()
+ userHasInteracted()
+ });
+
+ d3.select("#select-blue").on("change", function() {
+ state.editColor = this.checked ? 1 : -1
+ state.serialize()
+ userHasInteracted()
+ });
+
+ // On drag, we want to paint our canvas with the dots.
+ let dragBehavior = d3.behavior.drag().on("drag", function() {
+ let isVisible = d3.select("#select-platform").style("display") === "block"
+ if(state.problem === Problem.CLASSIFICATION && isVisible) {
+ let [x, y] = d3.mouse(this)
+ let label = state.editColor
+ let padding = 20
+ let maxScale = 5.0
+ let factor = 23.07
+ x -= padding
+ y -= padding
+ x = x/factor - maxScale
+ y = maxScale - y/factor
+ state.trainData.push({x, y, label})
+ heatMap.updatePoints(state.trainData);
+ }
+ });
+
+ d3.select("#heatmap").call(dragBehavior);
+
let showTestData = d3.select("#show-test-data").on("change", function() {
state.showTestData = this.checked;
state.serialize();
userHasInteracted();
- heatMap.updateTestPoints(state.showTestData ? testData : []);
+ heatMap.updateTestPoints(state.showTestData ? state.testData : []);
});
// Check/uncheck the checkbox according to the current state.
showTestData.property("checked", state.showTestData);
@@ -355,6 +387,7 @@ function makeGUI() {
let problem = d3.select("#problem").on("change", function() {
state.problem = problems[this.value];
+ togglePaintSelection();
generateData();
drawDatasetThumbnails();
parametersChanged = true;
@@ -908,7 +941,7 @@ function constructInput(x: number, y: number): number[] {
function oneStep(): void {
iter++;
- trainData.forEach((point, i) => {
+ state.trainData.forEach((point, i) => {
let input = constructInput(point.x, point.y);
nn.forwardProp(network, input);
nn.backProp(network, point.label, nn.Errors.SQUARE);
@@ -917,8 +950,8 @@ function oneStep(): void {
}
});
// Compute the loss.
- lossTrain = getLoss(network, trainData);
- lossTest = getLoss(network, testData);
+ lossTrain = getLoss(network, state.trainData);
+ lossTest = getLoss(network, state.testData);
updateUI();
}
@@ -949,6 +982,11 @@ function reset(onStartup=false) {
d3.select("#layers-label").text("Hidden layer" + suffix);
d3.select("#num-layers").text(state.numHiddenLayers);
+ togglePaintSelection()
+ // Correct radio button on reset
+ let radioColor = state.editColor === - 1 ? "#select-orange" : "#select-blue";
+ d3.select(radioColor).attr("checked", "checked")
+
// Make a simple network.
iter = 0;
let numInputs = constructInput(0 , 0).length;
@@ -957,8 +995,8 @@ function reset(onStartup=false) {
nn.Activations.LINEAR : nn.Activations.TANH;
network = nn.buildNetwork(shape, state.activation, outputActivation,
state.regularization, constructInputIds(), state.initZero);
- lossTrain = getLoss(network, trainData);
- lossTest = getLoss(network, testData);
+ lossTrain = getLoss(network, state.trainData);
+ lossTest = getLoss(network, state.testData);
drawNetwork(network);
updateUI(true);
};
@@ -1064,6 +1102,11 @@ function hideControls() {
.attr("href", window.location.href);
}
+function togglePaintSelection() {
+ let visiblity = state.problem === Problem.CLASSIFICATION ? "" : "none"
+ d3.select("#select-platform").style("display", visiblity);
+}
+
function generateData(firstTime = false) {
if (!firstTime) {
// Change the seed.
@@ -1081,10 +1124,10 @@ function generateData(firstTime = false) {
shuffle(data);
// Split into train and test data.
let splitIndex = Math.floor(data.length * state.percTrainData / 100);
- trainData = data.slice(0, splitIndex);
- testData = data.slice(splitIndex);
- heatMap.updatePoints(trainData);
- heatMap.updateTestPoints(state.showTestData ? testData : []);
+ state.trainData = data.slice(0, splitIndex);
+ state.testData = data.slice(splitIndex);
+ heatMap.updatePoints(state.trainData);
+ heatMap.updateTestPoints(state.showTestData ? state.testData : []);
}
let firstInteraction = true;
diff --git a/src/state.ts b/src/state.ts
index 42dc8154..3c66d88e 100644
--- a/src/state.ts
+++ b/src/state.ts
@@ -130,7 +130,8 @@ export class State {
{name: "tutorial", type: Type.STRING},
{name: "problem", type: Type.OBJECT, keyMap: problems},
{name: "initZero", type: Type.BOOLEAN},
- {name: "hideText", type: Type.BOOLEAN}
+ {name: "hideText", type: Type.BOOLEAN},
+ {name: "editColor", type: Type.NUMBER}
];
[key: string]: any;
@@ -160,8 +161,11 @@ export class State {
sinX = false;
cosY = false;
sinY = false;
+ editColor = -1;
dataset: dataset.DataGenerator = dataset.classifyCircleData;
regDataset: dataset.DataGenerator = dataset.regressPlane;
+ trainData: dataset.Example2D[] = [];
+ testData: dataset.Example2D[] = [];
seed: string;
/**