The Squares performance visualization for multiclass classifiers, implemented with Stardust.
The data shown here are the predictions and per-class scores of a 10-class classifier trained on the MNIST dataset.
<!DOCTYPE html>
<meta charset="utf-8" />
<!-- Shared page styles, D3 v3 (used for the SVG axis), the Stardust WebGL bundle, and shared
     helpers. loadData is used below but not defined on this page; presumably it comes from
     utils.js (TODO confirm). -->
<link rel="stylesheet" href="../common/style.css" type="text/css" />
<script src="//d3js.org/d3.v3.min.js" type="text/javascript"></script>
<script src="../common/stardust/stardust.bundle.js" type="text/javascript"></script>
<script src="../common/utils.js" type="text/javascript"></script>
<style>
/* The container is the positioning context for the stacked canvas + SVG layers. */
.squares-container {
position: relative;
}
/* Stack the WebGL canvas and the SVG overlay exactly on top of each other. */
.squares-container canvas,
.squares-container svg {
position: absolute;
left: 0;
top: 0;
}
/* The SVG only draws the axis; let mouse events fall through to the canvas for picking. */
.squares-container svg {
pointer-events: none;
}
/* Styling for the D3 axis drawn into the SVG overlay. */
.axis path,
.axis line {
fill: none;
stroke: black;
shape-rendering: crispEdges;
}
.axis text {
font-family: sans-serif;
font-size: 11px;
}
</style>
<!-- NOTE(review): SquaresVisualization appends its own <canvas> and <svg> to this container.
     The static #main-canvas / #main-svg below are not referenced by the code on this page;
     confirm they are unused (e.g. by utils.js) before removing them. -->
<div class="squares-container">
<canvas id="main-canvas"></canvas>
<svg id="main-svg"></svg>
</div>
<!-- Loading placeholder. Nothing on this page hides it; presumably loadData does (TODO confirm). -->
<div class="initializing"><p>Initializing...</p></div>
<!-- Defines the SquaresVisualization class used by the inline script below. -->
<script src="squares.js" type="text/javascript"></script>
<script type="text/javascript">
// Mount the visualization into the page container. These stay `var` so that, as before,
// `container` and `vis` are reachable as globals (e.g. from the dev console).
var container = d3.select(".squares-container");
var vis = new SquaresVisualization(container.node());

// Once the CSV is loaded, show the first 4000 rows with a layout tuned for 10 classes.
loadData("mnist.csv", function(DATA) {
  const shownRows = DATA.slice(0, 4000);
  const mnistLayout = {
    y0: 10,
    numberBins: 10,
    squaresPerBin: 11,
    squareSize: 3,
    squareSpacing: 4,
    xSpacing: 88
  };
  vis.layout();
  vis.setInstances(shownRows);
  // Passing numberBins makes setLayoutParameter re-bin the instances.
  vis.setLayoutParameter(mnistLayout);
});
</script>
/**
 * Squares performance visualization for multiclass classifiers.
 *
 * Instances are drawn as small squares with Stardust on a WebGL canvas. A D3 (v3 API) SVG
 * overlay draws the score axis. Class i owns a column at x = x0 + xSpacing * i:
 *   - solid squares to the right of that position are instances *assigned* to class i,
 *     colored by their true label;
 *   - outlined squares to the left are misclassified instances whose true *label* is i,
 *     colored by the class they were assigned.
 * Squares are stacked into score bins along y, with higher score bins drawn higher up.
 *
 * Interaction:
 *   - hovering selects the instance under the cursor;
 *   - mouse-down selects every instance with the same label, assigned class and score bin;
 *   - double-click selects every instance with the same label and assigned class.
 * Selected instances are outlined and also drawn as parallel-coordinate lines of their
 * per-class scores.
 */
class SquaresVisualization {
  /**
   * Builds the Stardust custom-mark spec for one side of the squares layout.
   *
   * @param {string} side - "right" places squares by `assigned`, growing to the right of the
   *   class position. Any other value places them by `label`, growing to the left.
   * @param {string} mode - Which rectangles each instance emits:
   *   - "solid": a white cell plus a colored square;
   *   - "outlined": the same, with a white inset that leaves a 0.5px colored border;
   *   - "selection": a transparent square plus a black outline.
   * @returns the custom mark spec (not yet bound to a platform).
   */
  makeSquaresMark(side, mode) {
    const squares = Stardust.mark
      .custom()
      .input("size", "float")
      .input("spacing", "float")
      .input("x0", "float")
      .input("xSpacing", "float")
      .input("y1", "float")
      .input("binSpacing", "float")
      .input("binIndex", "float")
      .input("binSquares", "float")
      .input("bin", "float")
      .input("color", "Color");
    if (side === "right") {
      squares.input("assigned", "float").variable("x", "x0 + xSpacing * assigned");
    } else {
      squares.input("label", "float").variable("x", "x0 + xSpacing * label");
    }
    // Within a bin, squares fill a vertical run of `binSquares` cells and then start a new run
    // one cell further away from the class position.
    squares
      .variable("y", "y1 - bin * binSpacing")
      .variable("binIx", "floor(binIndex / binSquares)")
      .variable("binIy", "(binIndex % binSquares)")
      .variable("bx", "binIx * spacing")
      .variable("by", "binIy * spacing");
    if (side === "right") {
      squares.variable("px", "x + bx").variable("py", "y + by");
    } else {
      // Mirror horizontally: cells grow leftwards, so shift by one cell width.
      squares.variable("px", "x - bx - spacing").variable("py", "y + by");
    }
    if (mode === "solid" || mode === "outlined") {
      // White cell background of the full `spacing` size, then the colored square on top.
      squares
        .add("P2D.Rectangle")
        .attr("p1", "Vector2(px, py)")
        .attr("p2", "Vector2(px + spacing, py + spacing)")
        .attr("color", "Color(1, 1, 1, 1)");
      squares
        .add("P2D.Rectangle")
        .attr("p1", "Vector2(px, py)")
        .attr("p2", "Vector2(px + size, py + size)")
        .attr("color", "color");
      if (mode === "outlined") {
        // A white inset leaves only a 0.5px colored border.
        squares
          .add("P2D.Rectangle")
          .attr("p1", "Vector2(px + 0.5, py + 0.5)")
          .attr("p2", "Vector2(px + size - 0.5, py + size - 0.5)")
          .attr("color", "Color(1, 1, 1, 1)");
      }
    }
    if (mode === "selection") {
      squares
        .add("P2D.Rectangle")
        .attr("p1", "Vector2(px, py)")
        .attr("p2", "Vector2(px + size, py + size)")
        .attr("color", "Color(0, 0, 0, 0)");
      squares
        .add("P2D.OutlinedRectangle")
        .attr("p1", "Vector2(px - 0.5, py - 0.5)")
        .attr("p2", "Vector2(px + size + 0.5, py + size + 0.5)")
        .attr("color", "Color(0, 0, 0, 1)");
    }
    return squares;
  }

  /**
   * Creates the canvas, the SVG axis overlay, all Stardust marks and the mouse handlers inside
   * `container`.
   * @param {Element} container - DOM node that receives the appended <canvas> and <svg>.
   */
  constructor(container) {
    const squares = this.makeSquaresMark("right", "solid");
    const squaresOutlined = this.makeSquaresMark("left", "outlined");
    const squaresSelection = this.makeSquaresMark("right", "selection");
    const squaresOutlinedSelection = this.makeSquaresMark("left", "selection");

    // One polyline per instance through its 10 per-class scores.
    const parallelCoordinates = Stardust.mark
      .custom()
      .input("color", "Color")
      .input("x0", "float")
      .input("xSpacing", "float");
    for (let i = 0; i < 10; i++) {
      parallelCoordinates.input(`y${i}`, "float");
      parallelCoordinates.variable(`x${i}`, `x0 + xSpacing * ${i}`);
      if (i < 9) {
        parallelCoordinates
          .add("P2D.Line")
          .attr("p1", `Vector2(x${i}, y${i})`)
          .attr("p2", `Vector2(x${i + 1}, y${i + 1})`)
          .attr("width", 2)
          .attr("color", `Color(color.r, color.g, color.b, 0.3)`);
      }
    }

    this._container = container;
    this._canvas = d3.select(container).append("canvas");
    this._canvasNode = this._canvas.node();
    this._svg = d3.select(container).append("svg");
    this._svgAxis = this._svg.append("g").classed("axis", true);
    const platform = Stardust.platform("webgl-2d", this._canvasNode);
    this._platform = platform;
    this._layout = {
      numberBins: 10,
      squaresPerBin: 10,
      squareSize: 2,
      squareSpacing: 3,
      x0: 80,
      xSpacing: 100,
      y0: 10,
      numberClasses: 10,
      // Canvas/SVG size in CSS pixels; can be changed via setLayoutParameter.
      width: 960,
      height: 500
    };

    // 10-color categorical palette, normalized to RGBA in [0, 1].
    const colors = [
      [31, 119, 180],
      [255, 127, 14],
      [44, 160, 44],
      [214, 39, 40],
      [148, 103, 189],
      [140, 86, 75],
      [227, 119, 194],
      [127, 127, 127],
      [188, 189, 34],
      [23, 190, 207]
    ].map(([r, g, b]) => [r / 255, g / 255, b / 255, 1]);

    const mark = Stardust.mark.create(squares, platform);
    mark
      .attr("color", d => colors[d.label])
      .attr("assigned", d => d.assigned)
      .attr("binIndex", d => d.binIndex)
      .attr("bin", d => d.scoreBin);
    const mark2 = Stardust.mark.create(squaresOutlined, platform);
    mark2
      .attr("color", d => colors[d.assigned])
      .attr("label", d => d.label)
      .attr("binIndex", d => d.binIndex2)
      .attr("bin", d => d.scoreBin);
    const markOverlay = Stardust.mark.create(squaresSelection, platform);
    markOverlay
      .attr("color", [0, 0, 0, 1])
      .attr("assigned", d => d.assigned)
      .attr("binIndex", d => d.binIndex)
      .attr("bin", d => d.scoreBin);
    const markOverlayOutlined = Stardust.mark.create(squaresOutlinedSelection, platform);
    markOverlayOutlined
      .attr("color", [0, 0, 0, 1])
      .attr("label", d => d.label)
      .attr("binIndex", d => d.binIndex2)
      .attr("bin", d => d.scoreBin);

    const markPC = Stardust.mark.create(parallelCoordinates, platform);
    // Initial range only; layoutConfigSquares() aligns it with the bins.
    const yScale = Stardust.scale
      .linear()
      .domain([0, 1])
      .range([500, 100]);
    markPC.attr("color", d => colors[d.label]);
    // `let` gives each iteration its own `i`, so the accessor captures the right index.
    for (let i = 0; i < 10; i++) {
      markPC.attr(`y${i}`, yScale(d => d.scores[i]));
    }

    this._marks = {
      squares: mark,
      squaresOutlined: mark2,
      squaresOverlay: markOverlay,
      squaresOverlayOutlined: markOverlayOutlined,
      parallelCoordinates: markPC,
      yScale: yScale
    };

    this._canvasNode.onmousemove = e => {
      const inst = this._pickInstance(e);
      this.setSelection(inst ? [inst] : []);
    };
    this._canvasNode.onmousedown = e => {
      const inst = this._pickInstance(e);
      this.setSelection(
        inst
          ? this._instances.filter(
              d => d.label === inst.label && d.assigned === inst.assigned && d.scoreBin === inst.scoreBin
            )
          : []
      );
    };
    this._canvasNode.ondblclick = e => {
      const inst = this._pickInstance(e);
      this.setSelection(
        inst ? this._instances.filter(d => d.label === inst.label && d.assigned === inst.assigned) : []
      );
    };
  }

  /**
   * Returns the instance under the mouse cursor, or null when nothing is hit.
   *
   * The picking buffer is sized from the canvas backing store (see render()), presumably in
   * device pixels. CSS-pixel mouse offsets are therefore scaled by the platform's pixel ratio
   * rather than a fixed factor of 2, which broke picking on non-2x displays.
   * @param {MouseEvent} e
   */
  _pickInstance(e) {
    const bounds = this._canvasNode.getBoundingClientRect();
    const ratio = this._platform.pixelRatio || window.devicePixelRatio || 1;
    const picked = this._platform.getPickingPixel(
      (e.clientX - bounds.left) * ratio,
      (e.clientY - bounds.top) * ratio
    );
    return picked ? picked[0].data()[picked[1]] : null;
  }

  /**
   * Highlights `instances`: overlay squares on both sides (left side only for misclassified
   * ones) plus their parallel-coordinate lines, then redraws.
   * @param {Array<Object>} instances - instance records produced by setInstances.
   */
  setSelection(instances) {
    this._marks.squaresOverlay.data(instances);
    this._marks.squaresOverlayOutlined.data(instances.filter(d => d.label !== d.assigned));
    this._marks.parallelCoordinates.data(instances);
    this.renderSelection();
  }

  /**
   * Converts raw rows into instance records, computes each instance's slot in its bin, binds
   * the data to the marks and redraws.
   *
   * Each row is expected to have `Label` and `Assigned` columns holding a one-character
   * prefix plus a class number (e.g. "C3"), and score columns C0..C9. `row[row.Assigned]`
   * is taken as the score of the assigned class.
   * @param {Array<Object>} DATA - parsed CSV rows. They are kept so a layout change can re-bin.
   */
  setInstances(DATA) {
    this._DATA = DATA;
    // Score columns C0..C9 are hard-coded below, so the class count is fixed at 10.
    this._layout.numberClasses = 10;
    const numberBins = this._layout.numberBins;
    const numberClasses = this._layout.numberClasses;

    const instances = DATA.map(d => ({
      label: parseInt(d.Label.slice(1), 10),
      assigned: parseInt(d.Assigned.slice(1), 10),
      score: d[d.Assigned],
      // Clamp so a score of exactly 1 falls into the top bin.
      scoreBin: Math.min(numberBins - 1, Math.max(0, Math.floor(parseFloat(d[d.Assigned]) * numberBins))),
      scores: [+d.C0, +d.C1, +d.C2, +d.C3, +d.C4, +d.C5, +d.C6, +d.C7, +d.C8, +d.C9]
    }));

    const zeros = (rows, cols) => Array.from({ length: rows }, () => new Array(cols).fill(0));
    const CM = zeros(numberClasses, numberClasses); // [label][assigned] -> count
    const rightBinCounts = zeros(numberClasses, numberBins); // [assigned][scoreBin] -> count
    const leftBinCounts = zeros(numberClasses, numberBins); // [label][scoreBin] -> count

    // Misclassified instances first, so they take the lowest slots in each bin. Correct ones
    // compare equal and keep their relative order (sort is stable). Then order by assigned
    // class (direction chosen by `assignedOrder`), label, and score.
    const misclassifiedFirst = assignedOrder => (a, b) => {
      const aCorrect = a.label === a.assigned;
      const bCorrect = b.label === b.assigned;
      if (aCorrect || bCorrect) return (aCorrect ? 1 : 0) - (bCorrect ? 1 : 0);
      if (a.assigned !== b.assigned) return assignedOrder * (a.assigned - b.assigned);
      if (a.label !== b.label) return a.label - b.label;
      return a.score - b.score;
    };

    // Right side: slot within the (assigned, scoreBin) bin.
    instances.sort(misclassifiedFirst(+1));
    instances.forEach(d => {
      d.CMIndex = CM[d.label][d.assigned]++;
      d.binIndex = rightBinCounts[d.assigned][d.scoreBin]++;
    });

    // Left side: slot within the (label, scoreBin) bin. This sort starts from the previous
    // order, so the result depends on both sorts.
    instances.sort(misclassifiedFirst(-1));
    instances.forEach(d => {
      d.binIndex2 = leftBinCounts[d.label][d.scoreBin]++;
      d.CMCount = CM[d.label][d.assigned];
    });

    instances.sort((a, b) => a.assigned - b.assigned);
    this._instances = instances;
    this._marks.squares.data(this._instances);
    this._marks.squaresOutlined.data(this._instances.filter(d => d.label !== d.assigned));
    this.layout();
    this.render();
  }

  /**
   * Pushes the current layout parameters into the square marks and the parallel-coordinates
   * mark. Also aligns the parallel-coordinates y scale with the bin stack: score 1 at y0,
   * score 0 at the bottom of the lowest bin.
   */
  layoutConfigSquares() {
    const binSpacing = this._layout.squareSpacing * this._layout.squaresPerBin + this._layout.squareSpacing;
    this._marks.yScale.range([this._layout.y0 + binSpacing * this._layout.numberBins, this._layout.y0]);
    [
      this._marks.squares,
      this._marks.squaresOutlined,
      this._marks.squaresOverlay,
      this._marks.squaresOverlayOutlined
    ].forEach(s =>
      s
        .attr("size", this._layout.squareSize)
        .attr("spacing", this._layout.squareSpacing)
        .attr("x0", this._layout.x0)
        .attr("xSpacing", this._layout.xSpacing)
        .attr(
          "y1",
          this._layout.y0 + binSpacing * this._layout.numberBins - binSpacing + this._layout.squareSpacing / 2
        )
        .attr("binSpacing", binSpacing)
        .attr("binSquares", this._layout.squaresPerBin)
    );
    this._marks.parallelCoordinates.attr("x0", this._layout.x0).attr("xSpacing", this._layout.xSpacing);
  }

  /**
   * Applies the layout to the marks, redraws the left score axis with D3, and resizes the SVG
   * and the Stardust canvas to `_layout.width` x `_layout.height`.
   */
  layout() {
    this.layoutConfigSquares();
    const d3yscale = d3.scale
      .linear()
      .domain(this._marks.yScale.domain())
      .range(this._marks.yScale.range());
    const axis = d3.svg
      .axis()
      .scale(d3yscale)
      .orient("left");
    this._svgAxis.attr("transform", "translate(30, 0)");
    this._svgAxis.call(axis);
    const width = this._layout.width;
    const height = this._layout.height;
    this._svg.attr("width", width).attr("height", height);
    this._platform.resize(width, height);
  }

  /** Redraws the picking buffer (squares only) and then the visible squares. */
  render() {
    this._platform.beginPicking(this._canvasNode.width, this._canvasNode.height);
    this._marks.squares.render();
    this._marks.squaresOutlined.render();
    this._platform.endPicking();
    this._platform.clear();
    this._marks.squares.render();
    this._marks.squaresOutlined.render();
  }

  /**
   * Redraws the visible layer with the selection overlays. The picking buffer is not redrawn,
   * so overlays are never pickable.
   */
  renderSelection() {
    this._platform.clear();
    this._marks.squares.render();
    this._marks.squaresOutlined.render();
    this._marks.squaresOverlay.render();
    this._marks.squaresOverlayOutlined.render();
    this._marks.parallelCoordinates.render();
  }

  /**
   * Merges `layout` into the current layout parameters, then re-lays out and redraws.
   * Changing numberBins or numberClasses re-bins the instances, but only once data has been
   * supplied via setInstances; before that, re-binning would have nothing to work on.
   * @param {Object} layout - subset of the layout parameters to override.
   */
  setLayoutParameter(layout) {
    let shouldRecompute = false;
    Object.keys(layout).forEach(p => {
      this._layout[p] = layout[p];
      if (p === "numberBins" || p === "numberClasses") {
        shouldRecompute = true;
      }
    });
    if (shouldRecompute && this._DATA) {
      this.setInstances(this._DATA);
    }
    this.layout();
    this.render();
  }
}