Machine Learning im Browser mit TensorFlow.js

Anpassung über Regularisierung

Regularisierung heißt der Vorgang, um sich dem gewünschten Resultat zu nähern. TensorFlow.js bietet dafür dieselben Mechanismen wie TensorFlow. Indem Entwickler bei jedem Trainingsdurchgang nur einen gewissen Teil der Neuronen trainieren, verringern sie zum einen die Kapazität des Modells und erzeugen zum anderen effektiv ein Ensemble schwacher Netze. Das Vorgehen hat sich im klassischen Machine Learning als sinnvoll erwiesen.

In TensorFlow.js dient dazu ein Dropout-Layer, den folgender Code mit einem relativ hohen Wert von 0,7 konfiguriert – der optimale Wert lässt sich über Experimente herausfinden. Häufig liegt es um 0,5 oder etwas darüber.

model.add(tf.layers.dropout({ rate: 0.7 }));

Neben diesem Verfahren hat sich die Normalisierung der Ausgabe eines Layers als zweckmäßig erwiesen. Die Intuition ist dabei weniger deutlich als bei Dropout. Vereinfacht gesagt fließt durch die Normalisierung ein gewisser Störfaktor in das Training ein, der dem Modell Robustheit verleiht. Folgende Zeile konfiguriert die BatchNormalization als eigenen Layer:

model.add(tf.layers.batchNormalization());

Die Konfiguration des Netzes inklusive Regularisierung besteht in den meisten Fällen mehr aus Ausprobieren als strukturiertem Vorgehen.

Damit erreicht das Modell die in Abbildung 3 gezeigte Genauigkeit von etwa 72 Prozent und weist nahezu kein Overfitting auf. Ob ein Wert gut ist, lässt sich nicht allgemein sagen: Die Daten bestimmen die Möglichkeiten und der Anwendungsfall die Notwendigkeiten. Für eine Vorhersage der Schadensklasse mag der Wert ausreichen, für das Erkennen eines Fußgängers durch ein selbstfahrendes Auto wäre das Ergebnis deutlich zu schlecht.

Auswertung und Vorhersagen

Das Modell lässt sich nach dem Training durch Aufruf der predict-Methode für Vorhersagen nutzen:

model.predict(tf.tensor([[100, 48, 10]])).print();

Da das Training kein deterministischer Vorgang ist, fallen die Werte bei jedem Training unterschiedlich aus. Das Modell des Autors gibt für einen 48 Jahre alten Fahrer, dessen Auto eine Höchstgeschwindigkeit von 100 Mph (160 kmh) hat und der pro Jahr 10.000 Meilen fährt, folgende Wahrscheinlichkeiten aus

  • viele Unfälle 1 Prozent,
  • wenig Unfälle 87 Prozent,
  • im mittleren Bereich 12 Prozent.

Das erscheint plausibel und somit hat das Modell seinen ersten, eher anekdotischen Test bestanden.

Neben der Kurve, die den Trainingsverlauf beschreibt, bietet TensorFlow.js weitere Möglichkeiten zur Auswertung des Trainings. Hilfreich ist eine sogenannte Confusion Matrix. Sie fasst die Verwechslungen zwischen den Kategorien zusammen. Um sie zu erstellen, gilt es zunächst eine Vorhersage für alle bekannten Daten zu treffen, und diese mit den bekannten und richtigen Bewertungen zu vergleichen. Der Code für TensorFlow.js sieht folgendermaßen aus:

// die bekannten, richtigen Bewertungen
const yTrue = tf.tensor(ys);
// die Vorhersagen
const yPred = model.predict(X).argMax([-1]);
const confusionMatrix =
await tfvis.metrics.confusionMatrix(yTrue, yPred);

Anders als oben ist nur die jeweils höchste Wahrscheinlichkeit von Interesse, da die tatsächlichen Kategorien nur in der Form angegeben sind. Im Code erledigt das der Aufruf der argMax-Methode.

Folgende Zeilen stellen die Matrix anschaulich dar:

const matrixContainer = 
document.getElementById("matrix-surface");
const classNames = ["many accidents",
"few or no accidents",
"in the middle"];
tfvis.show.confusionMatrix(matrixContainer,
confusionMatrix,
classNames);

Abbildung 4 zeigt eine solche Matrix. Im Idealfall würden nur Einträge auf der Diagonalen existieren als Zeichen dafür, dass es keine Verwechslungen zwischen den Gruppen gibt. In der Praxis schaut das Ergebnis meist anders aus und erlaubt Rückschlüsse auf die Qualität des Modells.

Die Confusion Matrix sieht in der Praxis anders aus als der Idealfall (Abb. 4).

Im konkreten Fall sehen die beiden Kategorien für viele und wenige Unfälle recht gut aus. Allerdings fällt der im Diagramm unten dargestellte mittlere Bereich deutlich ab. Es fällt auf, dass besonders viele Daten, die im unteren rechten Feld landen sollten, als "wenig Unfälle" kategorisiert sind – im Kasten in der Mitte unten.

Um die Frage zu beantworten, ob das Ergebnis auf einen Fehler bei den Trainingsdaten oder im Training hindeutet, hilft ein Blick auf den Übersichtsplot in Abbildung 1. Er zeigt, dass die beiden Bereiche stark ineinander verschränkt sind und das Modell sie wahrscheinlich nicht besser auseinanderhalten kann. Da in den vermischten Bereichen die wenigen Unfälle klar in der Überzahl sind, wird das Modell für Vorhersage diesen Bereich bevorzugen. Somit scheinen die Ergebnisse insgesamt nachvollziehbar zu sein, was den Abweichungen zum Trotz für eine gute Qualität des Modells spricht.

Das Modell in Produktion bringen

Somit ist das Modell reif für den Einsatz in der Produktion. Dass es dafür nicht den Browser des Nutzers verlassen muss, ist bei sensiblen Daten durchaus ein Vorteil. In der Praxis erfolgt die Überführung über die save-Methode des Modells:

https://js.tensorflow.org/api/latest/#tf.LayersModel.save

In der URL lässt sich der Speicherort des Modells angeben. localstorage ist die einfachste Option, die jedoch bei großen Modellen eventuell nicht funktioniert. Daher ist für größere Datenmengen `indexeddb die bessere Wahl:

model.save("indexeddb://insurance");

Anschließend lässt sich das Modell mit den Entwicklerwerkzeugen des Browsers betrachten. Dabei zeigt sich, dass vor allem die Parameter der Neuronen viel Platz einnehmen.

Modell im Browser, im Application Tab der Chrome Dev Tools

Wer das Modell später in einer anderen Anwendung auf demselben Host nutzen möchte, kann es wieder laden, worauf es sich wie vor dem Speichern verhält.

model = await tf.loadLayersModel('indexeddb://insurance');
Beispielhafte Anwendung zur Risikoabschätzung über das trainierte Modell (Abb. 6)

Das Modell verhält sich beim Einbetten in andere Anwendungen aus deren Sicht wie eine API, die im konkreten Beispiel die Risikoabschätzung übernimmt. Sachbearbeiter könnten die Anwendung schließlich nutzen, um reale Vorhersagen zu treffen, wie Abbildung 6 zeigt.

Zusätzliche Ressourcen

Der online verfügbare Crash Risk Caclulator läuft in Chrome, Safari, Firefox und vermutlich in Kürze im Edge Browser. Letzterer unterstützte beim Verfassen des Artikels den TextDecoder noch nicht, den die Anwendung zum Laden der Daten nutzt. Der komplette Quellcode des Beispiels ist auf GitHub abgelegt.

Weitere Demos, Tutorials und die API-Beschreibung finden sich auf der TensorFlow-Site. Dort existiert zudem eine Anleitung zum Umwandeln eines TensorFlow-Modells, um es mit TensorFlow.js im Browser in Produktion zu bringen. Außerdem sind einige vortrainierte Modelle verfügbar, die sich in eine bestehende JavaScript-Anwendung integrieren lassen.

Die Site ml5.js hat sich zum Ziel gesetzt, den Einsatz von TensorFlow.js weiter zu vereinfachen und bietet dazu einige gute Beispiele und Modelle. Auf Stack Overflow gibt es Antworten auf viele Fragen rund um TensorFlow.js.