Machine Learning mit Apache Spark 2

Entscheidungsbäume

Klassifizieren mit Entscheidungsbäumen

Entscheidungsbäume (Decision Trees) sind geordnete, gerichtete Bäume. Jeder Knoten steht für eine logische Regel und jedes Blatt für eine Entscheidung. Das erste Beispiel soll die Aufgabe mit einem DecisionTreeClassifier von Spark lösen.

Die ersten Zeilen des Programms erzeugen eine Spark-Instanz, die auf dem lokalen Rechner mit drei Threads laufen soll ("local[3]"). Mit der Methode master() ließe sich alternativ die URL zu einem Spark-Cluster übergeben, das die Ausführung übernimmt:

SparkSession spark = SparkSession
.builder()
.appName("JavaMNISTDT")
.master("local[3]")
.getOrCreate();

Die folgenden Befehle konfigurieren einen Reader, der die CSV-Daten mit den Test- und Traingsdaten liest:

DataFrameReader reader = spark.read()
.option("header", "true")
.option("delimiter", ",")
.option("inferSchema", true)
.format("com.databricks.spark.csv");
Dataset<Row> test = reader
.load(Const.BASE_DIR_DATASETS+"mnist_test2.csv")
.filter(e -> Math.random() > 0.00 );
Dataset<Row> train = reader
.load(Const.BASE_DIR_DATASETS+"mnist_train2.csv")
.filter(e -> Math.random() > 0.00 );

Der filter-Befehl selektiert aus dem gesamten Datenbestand eine zufällige Stichprobe, damit Experimente schneller ablaufen können. Der Wert 0.0 wählt alle Zeilen aus. Während der Entwicklung sind folgende Befehle zum Prüfen der Struktur und der eingelesen Daten hilfreich:

train.showSchema(); // Logs the schema.
train.show(2); // Logs first 2 data rows.

Leider versteht der DecisionTreeClassifier die vorliegenden Daten nicht direkt. Er erwartet eine Spalte mit dem Label und eine Spalte mit einem FeatureVector. Das benötigte Zusammenführen der Bildpunkte lässt sich von Hand erledigen, aber genau für solche Aufgaben bietet Spark eine Reihe von Transformatoren und andere Hilfsklassen an. Der VectorAssembler führt die Werte mehrerer Spalten in einem Vektor zusammen. Zuerst kopiert das Programm die Namen der relevanten Spalten (p0 bis p783) aus dem Schema in ein Feld. Der Assembler übernimmt das Feld und den Namen der aggregierten Zielspalte (features):

StructType schema = train.schema();
String[] inputFeatures = Arrays.copyOfRange(schema.fieldNames(), 1,
schema.fieldNames().length);
VectorAssembler assembler = new VectorAssembler()
.setInputCols( inputFeatures )
.setOutputCol("features");

Meist liegen die Klassennamen (Labels) nicht als numerische Werte vor, sondern könnten fachliche Bezeichnungen sein. Um die Informationen für den Algorithmus zu nutzen, hilft der StringIndexer, der aus allen Werten einen numerischen Index aufbaut. Die Eingabe besteht aus der zu indizierenden Spalte label und der Zielspalte IndexedLabel:

StringIndexerModel stringIndexer = new StringIndexer()
.setInputCol("label")
.setHandleInvalid("skip")
.setOutputCol("indexedLabel")
.fit(train);

Der StringIndexer untersucht alle Werte der Spalte vor dem Trainingslauf durch Aufruf der Methode fit().

Classifier benötigt die Spalten mit den numerischen Labels und den FeatureVector. Optional lässt sich die Tiefe des erzeugten Baums begrenzen – maximal sind 30 Ebenen erlaubt. In der Praxis bieten sich Werte zwischen 10 und 15 an. Die Ergebnisse der Klassifikation landen in der Spalte prediction:

DecisionTreeClassifier dt = new DecisionTreeClassifier()
.setMaxDepth(28).setSeed(12345L)
.setLabelCol(stringIndexer.getOutputCol())
.setFeaturesCol(assembler.getOutputCol());

Die Ergebnisse in prediction entsprechen der numerischen Darstellung. Ein IndexToString-Transformer wandelt auf die fachlichen Bezeichnungen und legt diese in der Spalte predictionLabel ab. Die Zuordnung steuert der StringIndexer bei.

IndexToString indexToString = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(stringIndexer.labels());

Eine Pipeline fasst die einzelnen Schritte als Ablauf zusammen:

Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {
assembler
, stringIndexer
, dt
, indexToString
});

Die Methode fit() startet das Training und erstellt ein Modell für die komplette Pipeline. Das Modell liefert mit der Methode transform(test) auf den unbekannten Testbestand ein neues Dataset inklusive der Vorhersage:

PipelineModel model = pipeline.fit(train);
// Use the model to evaluate the test set.
Dataset<Row> result = model.transform(test);

Eine Zeile genügt, um die Erkennungsrate zu ermitteln. Der Befehl filter wählt die Zeilen, bei denen Label und Vorhersage übereinstimmen:

String correct = "Correct:"+ 
(100.0 * result.filter("label = predictedLabel")
.count() / result.count());

Die Rate ist mit 88 Prozent akzeptabel. Das gesamte Training und der Test dauern keine zwei Minuten auf dem Rechner des Autors, einem MacBook Pro 2013 mit Intel i5.

Spark bietet eine Debug-Ansicht, um Entscheidungsbäume zu visualisieren. Entwickler müssen das Modell lediglich aus der Pipeline entnehmen:

String showTree = ((DecisionTreeModel)model.stages()[2])
.toDebugString();

Für das gewählte Beispiel ist der erstellte Baum mit 27 Ebenen leider zu unübersichtlich.

Ein Wald voller Bäume

Eine Erweiterung der Entscheidungsbäume sind Random Forests, die nicht nur einen Baum generieren, sondern einen ganzen Wald. Jeder Baum trägt mit einem Faktor zur endgültigen Entscheidung bei. Die einheitliche API erleichtert das Austauschen. Der RandomForestClassifier erhält die Anzahl der zu erzeugenden Bäume und ersetzt den alten Algorithmus in der Pipeline:

RandomForestClassifier rf = new RandomForestClassifier()
.setNumTrees(30)
.setLabelCol("indexedLabel")
.setFeaturesCol(assembler.getOutputCol());
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {
assembler
, stringIndexer
, rf
, indexToString
});

Einen ganzen Wald zu erstellen, zahlt sich aus. Zwar verdoppelt sich die Trainingszeit auf vier Minuten, dafür steigt die Erkennungsrate auf gute 96,3 Prozent.

Eine Kreuztabelle stellt die erkannten Ziffern (Zeilen) den erwünschten Antworten (Spalten) gegenüber. Die erste Zeile und die erste Spalte zeigen das jeweilige Label. Die Diagonale hoher Zahlen zeigt viele korrekt erkannte Ziffern.

Label

0 1 2 3 4 5 6 7 8 9
0 968 1 3 4 1 3
1 1122 4 3 1 4 1
2 8 998 5 4 1 3 7 4 2
3 8 974 7 1 9 5 6
4 1 1 3 1 940 6 1 8 21
5 3 3 17 3 840 5 5 12 4
6 9 3 3 1 5 7 926 4
7 1 5 20 2 1 984 1 14
8 2 7 12 5 4 4 4 926 10
9 7 5 4 11 14 4 1 4 6 953