diff --git a/index.html b/index.html index 3f6060d6..49ea826c 100644 --- a/index.html +++ b/index.html @@ -274,9 +274,21 @@

Output

+
+
+ Select color that you would like to paint with. +
+ + +
-
Colors shows data, neuron and weight values. diff --git a/src/heatmap.ts b/src/heatmap.ts index 0bf21f66..3e979d93 100644 --- a/src/heatmap.ts +++ b/src/heatmap.ts @@ -88,7 +88,7 @@ export class HeatMap { position: "relative", top: `-${padding}px`, left: `-${padding}px` - }); + }) this.canvas = container.append("canvas") .attr("width", numSamples) .attr("height", numSamples) @@ -97,7 +97,7 @@ export class HeatMap { .style("position", "absolute") .style("top", `${padding}px`) .style("left", `${padding}px`); - + if (!this.settings.noSvg) { this.svg = container.append("svg").attr({ "width": width, diff --git a/src/playground.ts b/src/playground.ts index aeac0f9c..d37698b2 100644 --- a/src/playground.ts +++ b/src/playground.ts @@ -88,6 +88,7 @@ let HIDABLE_CONTROLS = [ ["Noise level", "noise"], ["Batch size", "batchSize"], ["# of hidden layers", "numHiddenLayers"], + ["Paint Platform", "paintPlatform"], ]; class Player { @@ -166,8 +167,6 @@ let colorScale = d3.scale.linear() .range(["#f59322", "#e8eaeb", "#0877bd"]) .clamp(true); let iter = 0; -let trainData: Example2D[] = []; -let testData: Example2D[] = []; let network: nn.Node[][] = null; let lossTrain = 0; let lossTest = 0; @@ -264,11 +263,44 @@ function makeGUI() { reset(); }); + // For changing state on different selections + d3.select("#select-orange").on("change", function() { + state.editColor = this.checked ? -1 : 1 + state.serialize() + userHasInteracted() + }); + + d3.select("#select-blue").on("change", function() { + state.editColor = this.checked ? 1 : -1 + state.serialize() + userHasInteracted() + }); + + // On drag, we want to paint our canvas with the dots. + let dragBehavior = d3.behavior.drag().on("drag", function() { + let isVisible = d3.select("#select-platform").style("display") === "block" + if(state.problem === Problem.CLASSIFICATION && isVisible) { + let [x, y] = d3.mouse(this) + let label = state.editColor + let padding = 20 + let maxScale = 5.0 + let factor = 23.07 + x -= padding + y -= padding + x = x/factor - maxScale + y = maxScale - y/factor + state.trainData.push({x, y, label}) + heatMap.updatePoints(state.trainData); + } + }); + + d3.select("#heatmap").call(dragBehavior); + let showTestData = d3.select("#show-test-data").on("change", function() { state.showTestData = this.checked; state.serialize(); userHasInteracted(); - heatMap.updateTestPoints(state.showTestData ? testData : []); + heatMap.updateTestPoints(state.showTestData ? state.testData : []); }); // Check/uncheck the checkbox according to the current state. showTestData.property("checked", state.showTestData); @@ -355,6 +387,7 @@ function makeGUI() { let problem = d3.select("#problem").on("change", function() { state.problem = problems[this.value]; + togglePaintSelection(); generateData(); drawDatasetThumbnails(); parametersChanged = true; @@ -908,7 +941,7 @@ function constructInput(x: number, y: number): number[] { function oneStep(): void { iter++; - trainData.forEach((point, i) => { + state.trainData.forEach((point, i) => { let input = constructInput(point.x, point.y); nn.forwardProp(network, input); nn.backProp(network, point.label, nn.Errors.SQUARE); @@ -917,8 +950,8 @@ function oneStep(): void { } }); // Compute the loss. - lossTrain = getLoss(network, trainData); - lossTest = getLoss(network, testData); + lossTrain = getLoss(network, state.trainData); + lossTest = getLoss(network, state.testData); updateUI(); } @@ -949,6 +982,11 @@ function reset(onStartup=false) { d3.select("#layers-label").text("Hidden layer" + suffix); d3.select("#num-layers").text(state.numHiddenLayers); + togglePaintSelection() + // Correct radio button on reset + let radioColor = state.editColor === - 1 ? "#select-orange" : "#select-blue"; + d3.select(radioColor).attr("checked", "checked") + // Make a simple network. iter = 0; let numInputs = constructInput(0 , 0).length; @@ -957,8 +995,8 @@ function reset(onStartup=false) { nn.Activations.LINEAR : nn.Activations.TANH; network = nn.buildNetwork(shape, state.activation, outputActivation, state.regularization, constructInputIds(), state.initZero); - lossTrain = getLoss(network, trainData); - lossTest = getLoss(network, testData); + lossTrain = getLoss(network, state.trainData); + lossTest = getLoss(network, state.testData); drawNetwork(network); updateUI(true); }; @@ -1064,6 +1102,11 @@ function hideControls() { .attr("href", window.location.href); } +function togglePaintSelection() { + let visiblity = state.problem === Problem.CLASSIFICATION ? "" : "none" + d3.select("#select-platform").style("display", visiblity); +} + function generateData(firstTime = false) { if (!firstTime) { // Change the seed. @@ -1081,10 +1124,10 @@ function generateData(firstTime = false) { shuffle(data); // Split into train and test data. let splitIndex = Math.floor(data.length * state.percTrainData / 100); - trainData = data.slice(0, splitIndex); - testData = data.slice(splitIndex); - heatMap.updatePoints(trainData); - heatMap.updateTestPoints(state.showTestData ? testData : []); + state.trainData = data.slice(0, splitIndex); + state.testData = data.slice(splitIndex); + heatMap.updatePoints(state.trainData); + heatMap.updateTestPoints(state.showTestData ? state.testData : []); } let firstInteraction = true; diff --git a/src/state.ts b/src/state.ts index 42dc8154..3c66d88e 100644 --- a/src/state.ts +++ b/src/state.ts @@ -130,7 +130,8 @@ export class State { {name: "tutorial", type: Type.STRING}, {name: "problem", type: Type.OBJECT, keyMap: problems}, {name: "initZero", type: Type.BOOLEAN}, - {name: "hideText", type: Type.BOOLEAN} + {name: "hideText", type: Type.BOOLEAN}, + {name: "editColor", type: Type.NUMBER} ]; [key: string]: any; @@ -160,8 +161,11 @@ export class State { sinX = false; cosY = false; sinY = false; + editColor = -1; dataset: dataset.DataGenerator = dataset.classifyCircleData; regDataset: dataset.DataGenerator = dataset.regressPlane; + trainData: dataset.Example2D[] = []; + testData: dataset.Example2D[] = []; seed: string; /**