Data Science, Machine Learning und KI
Kontakt

Im ersten Artikel dieser Serie über die Klassifizierung von Automodellen haben wir ein Modell gebaut, das Transfer Learning verwendet, um das Automodell durch ein Bild eines Autos zu klassifizieren. Im zweiten Beitrag haben wir gezeigt, wie TensorFlow Serving verwendet werden kann, um ein TensorFlow-Modell am Beispiel des Automodell-Classifiers einzusetzen. Diesen dritten Beitrag widmen wir einem weiteren wesentlichen Aspekt von Deep Learning und maschinellem Lernen im Allgemeinen: der Erklärbarkeit von Modellvorhersagen (englisch: Explainable AI).

Wir beginnen mit einer kurzen allgemeinen Einführung in das Thema Erklärbarkeit beim maschinellen Lernen. Als nächstes werden wir kurz auf verbreitete Methoden eingehen, die zur Erklärung und Interpretation von CNN-Vorhersagen verwendet werden können. Anschließend werden wir Grad-CAM, eine gradientenbasierte Methode, ausführlich erklären, indem wir Schritt für Schritt eine Implementierung des Verfahrens durchgehen. Zum Schluss zeigen wir Ergebnisse, die wir mit unserer Grad-CAM-Implementierung für den Auto-Modell-Classifier berechnet haben.

Inhalt

Eine kurze Einführung in die Erklärbarkeit von Machine Learning Modellen

In den letzten Jahren war die Erklärbarkeit ein immer wiederkehrendes Thema – aber dennoch ein Nischenthema – im Machine Learning. In den letzten vier Jahren jedoch hat das Interesse an diesem Thema stark zugenommen. Stark dazu beigetragen hat unter anderem die steigende Anzahl von Machine Learning-Modellen in der Produktion. Einerseits führt dies zu einer wachsenden Zahl von Endnutzern, die verstehen müssen, wie die Modelle Entscheidungen treffen. Andererseits müssen immer mehr Entwickler*innen von Machine Learning verstehen, warum (oder warum nicht) ein Modell auf eine bestimmte Weise funktioniert.

Dieser steigende Bedarf an Erklärbarkeit führte in den letzten Jahren zu einigen sowohl methodisch als auch technisch bemerkenswerten Innovationen:

Methoden zur Erklärung von CNN-Outputs für Bilddaten

Deep Neural Networks (DNNs) und insbesondere komplexe Architekturen wie CNNs galten lange Zeit als reine Blackbox-Modelle. Wie oben beschrieben änderte sich dies in den letzten Jahren, und inzwischen gibt es verschiedene Methoden, um CNN-Outputs zu erklären. Zum Beispiel implementiert die hervorragende Bibliothek tf-explain eine breite Palette nützlicher Methoden für TensorFlow 2.x. Wir werden nun kurz auf die Ideen der verschiedenen Ansätze eingehen, bevor wir uns Grad-CAM zuwenden:

Activations Visualization

Activations Visualization ist die einfachste Visualisierungstechnik. Hierbei wird die Ausgabe einer bestimmten Layer innerhalb des Netzwerks während des Vorwärtsdurchlaufs ausgegeben. Diese kann hilfreich sein, um ein Gefühl für die extrahierten Features zu bekommen, da die meisten Activations während des Trainings gegen Null tendieren (bei Verwendung der ReLu-Activation). Ein Beispiel für die Ausgabe der ersten Faltungsschicht des Auto-Modell-Classifiers ist unten dargestellt:

Activations Beispielbild

Vanilla Gradients

Man kann die Vanilla-Gradients der Ausgabe der vorhergesagten Klassen für das Eingangsbild verwenden, um die Bedeutung der Eingangspixel abzuleiten.

Vanilla Gradients Beispielbild

Wir sehen hier, dass der hervorgehobene Bereich hauptsächlich auf das Auto fokussiert ist. Im Vergleich zu den unten besprochenen Methoden ist der diskriminierende Bereich viel weniger eingegrenzt.

Occlusion Sensitivity

Bei diesem Ansatz wird die Signifikanz bestimmter Teile des Eingangsbildes berechnet, indem die Vorhersage des Modells für verschiedene ausgeblendete Teile des Eingangsbildes bewertet wird. Teile des Bildes werden iterativ ausgeblendet, indem sie durch graue Pixel ersetzt werden. Je schwächer die Vorhersage wird, wenn ein Teil des Bildes ausgeblendet ist, desto wichtiger ist dieser Teil für die endgültige Vorhersage. Basierend auf der Unterscheidungskraft der Bildregionen kann eine Heatmap erstellt und dargestellt werden. Die Anwendung der Occlusion Sensitivity für unseren Auto-Modell-Classifier hat keine aussagekräftigen Ergebnisse geliefert. Daher zeigen wir das Beispielbild von tf-explain, welches das Ergebnis der Anwendung des Verfahrens der Occlusion Sensitivity für ein Katzenbild zeigt.

Occlusion Sensitivity Beispielbild

CNN Fixations

Ein weiterer interessanter Ansatz namens CNN Fixations wurde in diesem Paper vorgestellt . Die Idee dabei ist, zurück zu verfolgen, welche Neuronen in jeder Schicht signifikant waren, indem man die Activations aus der Vorwärtsrechnung und die Netzwerkgewichte betrachtet. Die Neuronen mit großem Einfluss werden als Fixations bezeichnet. Dieser Ansatz erlaubt es also, die wesentlichen Regionen für das Ergebnis zu finden, ohne wiederholte Modellvorhersagen berechnen zu müssen (wie dies z.B. für die oben erklärte Occlusion Sensitivity der Fall ist).

Das Verfahren kann wie folgt beschrieben werden: Der Knoten, der der Klasse entspricht, wird als Fixation in der Ausgabeschicht gewählt. Dann werden die Fixations für die vorherige Schicht bestimmt, indem berechnet wird, welche der Knoten den größten Einfluss auf die Fixations der nächsthöheren Ebene haben, die im letzten Schritt bestimmt wurden. Die Knotengewichtung wird durch Multiplikation von Activations und Netzwerk-Gewichten errechnet. Wenn ihr an den Details des Verfahrens interessiert seid, schaut euch das Paper oder das entsprechende Github Repo an. Dieses Backtracking wird so lange durchgeführt, bis das Eingabebild erreicht ist, was eine Menge von Pixeln mit beträchtlicher Unterscheidungskraft ergibt. Ein Beispiel aus dem Paper ist unten dargestellt.

CNN Fixations Beispielbild

CAM

Das in diesem Paper vorgestellte Class Activation Mapping (CAM) ist ein Verfahren, um die diskriminante(n) Region(en) für eine CNN-Vorhersage durch die Berechnung von sogenannten Class Activation Maps zu finden. Ein wesentlicher Nachteil dieses Verfahrens ist, dass das Netzwerk als letzten Schritt vor der Vorhersageschicht ein Global Average Pooling (GAP) verwenden muss. Es ist daher nicht möglich, diesen Ansatz für allgemeine CNN-Architekturen anzuwenden. Ein Beispiel ist in der folgenden Abbildung dargestellt (entnommen aus dem CAM paper):

CAM Beispielbild

Die Class Activation Map weist jeder Position (x, y) in der letzten Faltungsschicht eine Bedeutung zu, indem sie die Linearkombination der Activations – gewichtet mit den entsprechenden Ausgangsgewichten für die beobachtete Klasse (im obigen Beispiel „Australian Terrier“) – berechnet. Die resultierende Class Activation Mapping wird dann auf die Größe des Eingabebildes hochgerechnet. Dies wird durch die oben dargestellte Heatmap veranschaulicht. Aufgrund der Architektur von CNNs ist die Aktivierung, z. B. oben links für eine beliebige Schicht, direkt mit der oberen linken Seite des Eingabebildes verbunden. Deshalb können wir nur aus der Betrachtung der letzten CNN-Schicht schließen, welche Eingabebereiche wichtig sind.

Bei dem Grad-CAM-Verfahren, das wir unten im Detail besprechen werden, handelt es sich um eine Verallgemeinerung von CAM. Grad-CAM kann auf Netzwerke mit allgemeinen CNN-Architekturen angewendet werden, die mehrere fully connected Layers am Ausgang enthalten.

Grad-CAM

Grad-CAM erweitert die Anwendbarkeit des CAM-Verfahrens durch das Einbeziehen von Gradienteninformationen. Konkret bestimmt der Gradient der Loss-Funktion in Bezug auf die letzte Faltungsschicht das Gewicht für jede der entsprechenden Feature Maps. Wie beim obigen CAM-Verfahren bestehen die weiteren Schritte in der Berechnung der gewichteten Summe der Aktivierungen und dem anschließenden Upsampling des Ergebnisses auf die Bildgröße, um das Originalbild mit der erhaltenen Heatmap darzustellen. Wir werden nun den Code, der zur Ausführung von Grad-CAM verwendet werden kann, zeigen und diskutieren. Der vollständige Code ist hier auf GitHub verfügbar.

import pickle
import tensorflow as tf
import cv2
from car_classifier.modeling import TransferModel

INPUT_SHAPE = (224, 224, 3)

# Load list of targets
file = open('.../classes.pickle', 'rb')
classes = pickle.load(file)

# Load model
model = TransferModel('ResNet', INPUT_SHAPE, classes=classes)
model.load('...')

# Gradient model, takes the original input and outputs tuple with:
# - output of conv layer (in this case: conv5_block3_3_conv)
# - output of head layer (original output)
grad_model = tf.keras.models.Model([model.model.inputs],
                                   [model.model.get_layer('conv5_block3_3_conv').output,
                                    model.model.output])

# Run model and record outputs, loss, and gradients
with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(img)
    loss = predictions[:, label_idx]

# Output of conv layer
output = conv_outputs[0]

# Gradients of loss w.r.t. conv layer
grads = tape.gradient(loss, conv_outputs)[0]

# Guided Backprop (elimination of negative values)
gate_f = tf.cast(output > 0, 'float32')
gate_r = tf.cast(grads > 0, 'float32')
guided_grads = gate_f * gate_r * grads

# Average weight of filters
weights = tf.reduce_mean(guided_grads, axis=(0, 1))

# Class activation map (cam)
# Multiply output values of conv filters (feature maps) with gradient weights
cam = np.zeros(output.shape[0: 2], dtype=np.float32)
for i, w in enumerate(weights):
    cam += w * output[:, :, i]

# Or more elegant: 
# cam = tf.reduce_sum(output * weights, axis=2)

# Rescale to org image size and min-max scale
cam = cv2.resize(cam.numpy(), (224, 224))
cam = np.maximum(cam, 0)
heatmap = (cam - cam.min()) / (cam.max() - cam.min())

Detailbetrachtung des Codes

  • Der erste Schritt besteht darin, eine Instanz des Modells zu laden.
  • Dann erstellen wir eine neue keras.Model-Instanz, die zwei Ausgaben hat: Die Aktivierungen der letzten CNN-Schicht ('conv5_block3_3_conv') und die ursprüngliche Modellausgabe.
  • Als nächstes führen wir eine Vorwärtsrechnung für unser neues grad_model aus, wobei wir als Eingabe ein Bild ( img) der Form (1, 224, 224, 3) verwenden, das mit der Methode resnetv2.preprocess_input vorverarbeitet wurde. Zur Aufzeichnung der Gradienten wird tf.GradientTape angelegt und angewendet (die Gradienten werden hierbei im tapeObjekt gespeichert). Weiterhin werden die Ausgaben der Faltungsschicht (conv_outputs) und des heads (predictions) gespeichert. Schließlich können wir label_idx verwenden, um den Verlust zu erhalten, der dem Label entspricht, für das wir die diskriminierenden Regionen finden wollen.
  • Mit Hilfe der gradient-Methode kann man die gewünschten Gradienten aus tape extrahieren. In diesem Fall benötigen wir den Gradienten des Verlustes in Bezug auf die Ausgabe der Faltungsschicht.
  • In einem weiteren Schritt wird eine guided Backprop angewendet. Dabei werden nur Werte für die Gradienten behalten, bei denen sowohl die Aktivierungen als auch die Gradienten positiv sind. Dies bedeutet im Wesentlichen, dass die Aufmerksamkeit auf die Aktivierungen beschränkt wird, die positiv zu der gewünschten Ausgabevorhersage beitragen.
  • Die weights werden durch Mittelung der erhaltenen geführten Gradienten für jeden Filter berechnet.
  • Die Class Activation Map cam wird dann als gewichteter Durchschnitt der Aktivierungen der Feature Map (output) berechnet. Die Methode mit der obigen for-Schleife hilft zu verstehen, was die Funktion im Detail tut. Eine weniger einfache, aber effizientere Art, die CAM-Berechnung zu implementieren, ist die Verwendung von tf.reduce_mean und wird in der kommentierten Zeile unterhalb der Schleifenimplementierung gezeigt.
  • Schließlich wird das Resampling (Größenänderung) mit der resize-Methode von OpenCV2 durchgeführt, und die Heatmap wird so skaliert, dass sie Werte in [0, 1] enthält, um sie zu plotten.

Eine Version von Grad-CAM ist auch in tf-explain implementiert.

Beispiele für den Auto-Modell-Classifier

Wir verwenden nun die Grad-CAM-Implementierung, um die Vorhersagen des TransferModel für die Klassifizierung von Automodellen zu interpretieren und zu erklären. Wir beginnen mit der Betrachtung von Fahrzeugbildern, die von vorne aufgenommen wurden.

Grad-CAM für Fahrzeugaufnahmen von der Vorderseite
Grad-CAM für Fahrzeugaufnahmen von der Vorderseite

Die roten Regionen markieren die wichtigsten diskriminierenden Regionen, die blauen Regionen die unwichtigsten. Wir können sehen, dass sich das CNN bei Bildern von vorne auf den Kühlergrill des Autos und den Bereich des Logos konzentriert. Ist das Auto leicht gekippt, verschiebt sich der Fokus mehr auf den Rand des Fahrzeugs. Dies ist auch bei leicht gekippten Bildern von der Rückseite des Fahrzeugs der Fall, wie im mittleren Bild unten gezeigt.

Grad-CAM für Fahrzeugaufnahmen von der Rückseite
Grad-CAM für Fahrzeugaufnahmen von der Rückseite

Bei Bildern von der Rückseite des Autos liegt der wichtigste Unterscheidungsbereich in der Nähe des Nummernschilds. Wie bereits erwähnt, hat bei Autos, die aus einem Winkel betrachtet werden, die nächstgelegene Ecke die höchste Trennschärfe. Ein sehr interessantes Beispiel ist die Mercedes-Benz C-Klasse auf der rechten Seite, bei der sich das Modell nicht nur auf die Rückleuchten konzentriert, sondern auch die höchste Trennschärfe auf den Modellschriftzug legt.

Grad-CAM für Fahrzeugaufnahmen von der Seite
Grad-CAM für Fahrzeugaufnahmen von der Seite

Wenn wir Bilder von der Seite betrachten, stellen wir fest, dass die diskriminierende Region auf die untere Hälfte der Autos beschränkt ist. Auch hier bestimmt der Winkel, aus dem das Fahrzeugbild aufgenommen wurde, die Verschiebung der Region in Richtung der vorderen oder hinteren Ecke.

Im Allgemeinen ist die wichtigste Tatsache, dass die diskriminierenden Bereiche immer auf Teile der Autos beschränkt sind. Es gibt keine Bilder, bei denen der Hintergrund eine hohe Unterscheidungskraft hat. Die Betrachtung der Heatmaps und der zugehörigen diskriminierenden Regionen kann als Sanity-Check für CNN-Modelle verwendet werden.

Fazit

Wir haben mehrere Ansätze zur Erklärung von CNN-Klassifikatorausgaben diskutiert. Wir haben Grad-CAM im Detail vorgestellt, indem wir den Code untersucht und uns Beispiele für den Auto-Modell-Classifier angeschaut haben. Am auffälligsten ist, dass die durch das Grad-CAM-Verfahren hervorgehobenen diskriminierenden Regionen immer auf das Auto fokussiert sind und nie auf die Hintergründe der Bilder. Das Ergebnis zeigt, dass das Modell so funktioniert, wie wir es erwarten und spezifische Teile des Autos zur Unterscheidung zwischen verschiedenen Modellen verwendet werden.

Im vierten und letzten Teil dieser Blog-Serie werden wir zeigen, wie der Car Classifier mit Dash in eine Web-Anwendung eingebaut werden kann. Bis bald!

Im ersten Beitrag dieser Serie haben wir Transfer Learning im Detail besprochen und ein Modell zur Klassifizierung von Automodellen erstellt. In diesem Beitrag werden wir das Problem der Modellbereitstellung am Beispiel des im ersten Beitrags vorgestellten TransferModel diskutieren.

Ein Modell ist in der Praxis nutzlos, wenn es keine einfache Möglichkeit gibt, damit zu interagieren. Mit anderen Worten: Wir brauchen eine API für unsere Modelle. TensorFlow Serving wurde entwickelt, um diese Funktionalitäten für TensorFlow-Modelle bereitzustellen. In diesem Beitrag zeigen wir, wie ein TensorFlow Serving Server in einem Docker-Container gestartet werden kann und wie wir mit dem Server über HTTP-Anfragen interagieren können.

Wenn ihr noch nie mit Docker gearbeitet habt, empfehlen wir, dieses Tutorial von Docker durchzuarbeiten, bevor ihr diesen Artikel lest. Wenn ihr euch ein Beispiel für das Deployment in Docker ansehen möchtet, empfehlen wir euch, diesen Blogbeitrag von unserem Kollegen Oliver Guggenbühl zu lesen, in dem beschrieben wird, wie ein R-Skript in Docker ausgeführt werden kann.

Inhalt

Einführung in TensorFlow Serving

Zum Einstieg geben wir euch zunächst einen Überblick über TensorFlow Serving.

TensorFlow Serving ist das Serving-System von TensorFlow, das entwickelt wurde, um das Deployment von verschiedenen Modellen mit einer einheitlichen API zu ermöglichen. Unter Verwendung der Abstraktion von Servables, die im Grunde Objekte sind, mit denen Inferenz durchgeführt werden kann, ist es möglich, mehrere Versionen von deployten Modellen zu serven. Das ermöglicht zum Beispiel, dass eine neue Version eines Modells hochgeladen werden kann, während die vorherige Version noch für Kunden verfügbar ist. Im Großen und Ganzen sind sogenannte Manager für die Verwaltung des Lebenszyklus von Servables verantwortlich, d. h. für das Laden, Bereitstellen und Löschen.

In diesem Beitrag werden wir zeigen, wie eine einzelne Modellversion deployed werden kann. Die unten aufgeführten Code-Beispiele zeigen, wie ein Server in einem Docker-Container gestartet werden kann und wie die Predict API verwendet werden kann, um mit dem Modell zu interagieren. Um mehr über TensorFlow Serving zu erfahren, verweisen wir auf die TensorFlow-Website.

Implementierung

Wir werden nun die folgenden drei Schritte besprechen, die erforderlich sind, um das Modell einzusetzen und Requests zu senden.

  • Speichern eines Modells im richtigen Format und in der richtigen Ordnerstruktur mit TensorFlow SavedModel
  • Ausführen eines Serving-Servers innerhalb eines Docker-Containers
  • Interaktion mit dem Modell über REST Requests

Speichern von TensorFlow-Modellen

Für diejenigen, die den ersten Beitrag dieser Serie nicht gelesen haben, folgt nun eine kurze Zusammenfassung der wichtigsten Punkte, die zum Verständnis des nachfolgenden Codes notwendig sind:

Das TransferModel.model (unten im Code auch self.model) ist eine tf.keras.Model Instanz, also kann es mit der eingebauten save Methode gespeichert werden. Da das Modell auf im Internet gescrapten Daten trainiert wurde, können sich die Klassenbezeichnungen beim erneuten Scraping der Daten ändern. Wir speichern daher die Index-Klassen-Zuordnung beim Speichern des Modells in classes.pickle ab. TensorFlow Serving erfordert, dass das Modell im SavedModel Format gespeichert wird. Wenn Sie tf.keras.Model.save verwenden, muss der Pfad ein Ordnername sein, sonst wird das Modell in einem anderen, inkompatiblen Format (z.B. HDF5) gespeichert. Im Code unten enthält folderpath den Pfad des Ordners, in dem wir alle modellrelevanten Informationen speichern wollen. Das SavedModel wird unter folderpath/model gespeichert und das Class Mapping wird als folderpath/classes.pickle gespeichert.

def save(self, folderpath: str):
    """
    Save the model using tf.keras.model.save

    Args:
        folderpath: (Full) Path to folder where model should be stored
    """

    # Make sure folderpath ends on slash, else fix
    if not folderpath.endswith("/"):
        folderpath += "/"

    if self.model is not None:
        os.mkdir(folderpath)
        model_path = folderpath + "model"
        # Save model to model dir
        self.model.save(filepath=model_path)
        # Save associated class mapping
        class_df = pd.DataFrame({'classes': self.classes})
        class_df.to_pickle(folderpath + "classes.pickle")
    else:
        raise AttributeError('Model does not exist')

TensorFlow Serving im Docker Container starten

Nachdem wir das Modell auf der Festplatte gespeichert haben, müssen wir nun den TensorFlow Serving Server starten. Am schnellsten deployen kann man TensorFlow Serving mithilfe eines Docker-Containers. Der erste Schritt ist daher das Ziehen des TensorFlow Serving Images von DockerHub. Das kann im Terminal mit dem Befehl docker pull tensorflow/serving gemacht werden.

Dann können wir den unten stehenden Code verwenden, um einen TensorFlow Serving Container zu starten. Er führt den Shell-Befehl zum Starten eines Containers aus. Die in docker_run_cmd gesetzten Optionen sind die folgenden:

  • Das Serving-Image exponiert Port 8501 für die REST-API, die wir später zum Senden von Anfragen verwenden werden. Wir mappen mithilfe der -p– Flag also den Host-Port 8501 auf den Port 8501 des Containers.
  • Als nächstes binden wir unser Modell mit -v in den Container ein. Es ist wichtig, dass das Modell in einem versionierten Ordner gespeichert ist (hier MODEL_VERSION=1); andernfalls wird das Serving-Image das Modell nicht finden. Der model_path_guest muss also die Form <path>/<model name>/MODEL_VERSION haben, wobei MODEL_VERSION eine ganze Zahl ist.
  • Mit -e können wir die Umgebungsvariable MODEL_NAME setzen, die den Namen unseres Modells enthält.
  • Die Option --name tf_serving wird nur benötigt, um unserem neuen Docker-Container einen bestimmten Namen zuzuweisen.

Wenn wir versuchen, diese Datei zweimal hintereinander auszuführen, wird der Docker-Befehl beim zweiten Mal nicht ausgeführt, da bereits ein Container mit dem Namen tf_serving existiert. Um dieses Problem zu vermeiden, verwenden wir docker_run_cmd_cond. Hier prüfen wir zunächst, ob ein Container mit diesem spezifischen Namen bereits existiert und läuft. Wenn ja, lassen wir ihn gleich; wenn nicht, prüfen wir, ob eine beendete Version des Containers existiert. Wenn ja, wird diese gelöscht und ein neuer Container gestartet; wenn nicht, wird direkt ein neuer Container erstellt.

import os

MODEL_FOLDER = 'models'
MODEL_SAVED_NAME = 'resnet_unfreeze_all_filtered.tf'
MODEL_NAME = 'resnet_unfreeze_all_filtered'
MODEL_VERSION = '1'

# Define paths on host and guest system
model_path_host = os.path.join(os.getcwd(), MODEL_FOLDER, MODEL_SAVED_NAME, 'model')
model_path_guest = os.path.join('/models', MODEL_NAME, MODEL_VERSION)

# Container start command
docker_run_cmd = f'docker run ' 
                 f'-p 8501:8501 ' 
                 f'-v {model_path_host}:{model_path_guest} ' 
                 f'-e MODEL_NAME={MODEL_NAME} ' 
                 f'-d ' 
                 f'--name tf_serving ' 
                 f'tensorflow/serving'

# If container is not running, create a new instance and run it
docker_run_cmd_cond = f'if [ ! "(docker ps -q -f name=tf_serving)" ]; then n'                        f'   if [ "(docker ps -aq -f status=exited -f name=tf_serving)" ]; then 														n' 
                      f'   		docker rm tf_serving n' 
                      f'   fi n' 
                      f'   {docker_run_cmd} n' 
                      f'fi'

# Start container
os.system(docker_run_cmd_cond)

Anstatt das Modell von unserer lokalen Festplatte zu mounten, indem wir das -v-Flag im Docker-Befehl verwenden, könnten wir das Modell auch in das Docker-Image kopieren, so dass das Modell einfach durch das Ausführen eines Containers und die Angabe der Port-Zuweisungen bedient werden könnte. Es ist wichtig zu beachten, dass in diesem Fall das Modell mit der Ordnerstruktur Ordnerpfad/<Modellname>/1 gespeichert werden muss, wie oben erklärt. Wenn dies nicht der Fall ist, wird der TensorFlow Serving Container das Modell nicht finden. Wir werden hier nicht weiter auf diesen Fall eingehen. Wenn ihr daran interessiert seid, eure Modelle auf diese Weise zu deployen, verweisen wir auf diese Anleitung auf der TensorFlow Webseite.

REST Request

Da das Modell nun geserved ist und bereit zur Verwendung ist, brauchen wir einen Weg, um damit zu interagieren. TensorFlow Serving bietet zwei Optionen, um Anfragen an den Server zu senden: gRCP und REST API, welche beide an unterschiedlichen Ports verfügbar sind. Im folgenden Codebeispiel werden wir REST verwenden, um das Modell abzufragen.

Zuerst laden wir ein Bild von der Festplatte, für das wir eine Vorhersage machen wollen. Dies kann mit dem image Modul von TensorFlow gemacht werden. Als nächstes konvertieren wir das Bild in ein Numpy-Array mittels der img_to_array-Methode. Der nächste und letzte Schritt ist entscheidend für unseren Car Classifier Use Case: da wir das Trainingsbild vorverarbeitet haben, bevor wir unser Modell trainiert haben (z.B. Normalisierung), müssen wir die gleiche Transformation auf das Bild anwenden, das wir vorhersagen wollen. Die praktische Funktion „preprocess_input“ sorgt dafür, dass alle notwendigen Transformationen auf unser Bild angewendet werden.

from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet_v2 import preprocess_input

# Load image
img = image.load_img(path, target_size=(224, 224))
img = image.img_to_array(img)

# Preprocess and reshape data
img = preprocess_input(img)
img = img.reshape(-1, *img.shape)

Die RESTful API von TensorFlow Serving bietet mehrere Endpunkte. Im Allgemeinen akzeptiert die API Post-Requests der folgenden Struktur:

POST http://host:port/<URI>:<VERB>

URI: /v1/models/{MODEL_NAME}[/versions/{MODEL_VERSION}]
VERB: classify|regress|predict

Für unser Modell können wir die folgende URL für Vorhersagen verwenden: http://localhost:8501/v1/models/resnet_unfreeze_all_filtered:predict

Die Portnummer (hier 8501) ist der Port des Hosts, den wir oben angegeben haben, um ihn auf den Port 8501 des Serving-Images abzubilden. Wie oben erwähnt, ist 8501 der Port des Serving-Containers, der für die REST-API verwendet wird. Die Modellversion ist optional und wird standardmäßig auf die neueste Version gesetzt, wenn sie weggelassen wird.

In Python kann die Bibliothek requests verwendet werden, um HTTP-Anfragen zu senden. Wie in der Dokumentation angegeben, muss der Request-Body für die predict API ein JSON-Objekt mit den unten aufgeführten Schlüssel-Wert-Paaren sein:

  • signature_name – zu verwendende Signatur (weitere Informationen finden Sie in der Dokumentation)
  • instances – Modelleingabe im Zeilenformat
import json
import requests

# Send image as list to TF serving via json dump
request_url = 'http://localhost:8501/v1/models/resnet_unfreeze_all_filtered:predict'
request_body = json.dumps({"signature_name": "serving_default", "instances": img.tolist()})
request_headers = {"content-type": "application/json"}
json_response = requests.post(request_url, data=request_body, headers=request_headers)
response_body = json.loads(json_response.text)
predictions = response_body['predictions']

# Get label from prediction
y_hat_idx = np.argmax(predictions)
y_hat = classes[y_hat_idx]

Der Response-Body ist ebenfalls ein JSON-Objekt mit einem einzigen Schlüssel namens predictions. Da wir für jede Zeile in den Instanzen die Wahrscheinlichkeit für alle 300 Klassen erhalten, verwenden wir np.argmax, um die wahrscheinlichste Klasse zurückzugeben. Alternativ hätten wir auch die übergeordnete classify-API verwenden können.

Fazit

In diesem zweiten Blog-Artikel der Serie „Car Model Classification“ haben wir gelernt, wie ein TensorFlow-Modell zur Bilderkennung mittels TensorFlow Serving als RestAPI bereitgestellt werden kann, und wie damit Modellabfragen ausgeführt werden können.

Dazu haben wir zuerst das Modell im SavedModel Format abgespeichert. Als nächstes haben wir den TensorFlow Serving-Server in einem Docker-Container gestartet. Schließlich haben wir gezeigt, wie man Vorhersagen aus dem Modell mit Hilfe der API-Endpunkte und einem korrekt spezifizierten Request Body anfordert.

Ein Hauptkritikpunkt an Deep Learning Modellen jeglicher Art ist die fehlende Erklärbarkeit der Vorhersagen. Im dritten Beitrag werden wir zeigen, wie man Modellvorhersagen mit einer Methode namens Grad-CAM erklären kann.

Deep Learning ist eines der Themen im Bereich der künstlichen Intelligenz, die uns bei STATWORX besonders faszinieren. In dieser Blogserie möchten wir veranschaulichen, wie ein End-to-end Deep Learning Projekt implementiert werden kann. Dabei verwenden wir die TensorFlow 2.x Bibliothek für die Implementierung.

Die Themen der 4-teiligen Blogserie umfassen:

  • Transfer Learning für Computer Vision
  • Deployment über TensorFlow Serving
  • Interpretierbarkeit von Deep-Learning-Modellen mittels Grad-CAM
  • Integration des Modells in ein Dashboard

Im ersten Teil zeigen wir, wie man Transfer Learning nutzen kann, um die Marke eines Autos mittels Bildklassifizierung vorherzusagen. Wir beginnen mit einem kurzen Überblick über Transfer Learning und das ResNet und gehen dann auf die Details der Implementierung ein. Der vorgestellte Code ist in diesem Github Repository zu finden.

Table of Contents

Einführung: Transfer Learning & ResNet

Was ist Transfer Learning?

Beim traditionellen (Machine) Learning entwickeln wir ein Modell und trainieren es auf neuen Daten für jede neue Aufgabe, die ansteht. Transfer Learning unterscheidet sich von diesem Ansatz dadurch, dass das gesammelte Wissen von einer Aufgabe auf eine andere übertragen wird. Dieser Ansatz ist besonders nützlich, wenn einem zu wenige Trainingsdaten zur Verfügung stehen. Modelle, die für ein ähnliches Problem vortrainiert wurden, können als Ausgangspunkt für das Training neuer Modelle verwendet werden. Die vortrainierten Modelle werden als Basismodelle bezeichnet.

In unserem Beispiel kann ein Deep Learning-Modell, das auf dem ImageNet-Datensatz trainiert wurde, als Ausgangspunkt für die Erstellung eines Klassifikationsnetzwerks für Automodelle verwendet werden. Die Hauptidee hinter dem Transfer Learning für Deep Learning-Modelle ist, dass die ersten Layer eines Netzwerks verwendet werden, um wichtige High-Level-Features zu extrahieren, die für die jeweilige Art der behandelten Daten ähnlich bleiben. Die finalen Layer, auch „head“ genannt, des ursprünglichen Netzwerks werden durch einen benutzerdefinierten head ersetzt, der für das vorliegende Problem geeignet ist. Die Gewichte im head werden zufällig initialisiert, und das resultierende Netz kann für die spezifische Aufgabe trainiert werden.

Es gibt verschiedene Möglichkeiten, wie das Basismodell beim Training behandelt werden kann. Im ersten Schritt können seine Gewichte fixiert werden. Wenn der Lernfortschritt darauf schließen lässt, dass das Modell nicht flexibel genug ist, können bestimmte Layer oder das gesamte Basismodell auch mit trainiert werden. Ein weiterer wichtiger Aspekt, den es zu beachten gilt, ist, dass der Input die gleiche Dimensionalität haben muss wie die Daten, auf denen das Basismodell initial trainiert wurde – sofern die ersten Layer des Basismodells festgehalten werden sollen.

image-20200319174208670

Als nächstes stellen wir kurz das ResNet vor, eine beliebte und leistungsfähige CNN-Architektur für Bilddaten. Anschließend zeigen wir, wie wir Transfer Learning mit ResNet zur Klassifizierung von Automodellen eingesetzt haben.

Was ist ResNet?

Das Training von Deep Neural Networks kann aufgrund des sogenannten Vanishing Gradient-Problems schnell zur Herausforderung werden. Aber was sind Vanishing Gradients? Neuronale Netze werden in der Regel mit Back-Propagation trainiert. Dieser Algorithmus nutzt die Kettenregel der Differentialrechnung, um Gradienten in tieferen Layern des Netzes abzuleiten, indem Gradienten aus früheren Layern multipliziert werden. Da Gradienten in Deep Networks wiederholt multipliziert werden, können sie sich während der Backpropagation schnell infinitesimal kleinen Werten annähern.

ResNet ist ein CNN-Netz, welches das Problem des Vanishing Gradients mit sogenannten Residualblöcken löst (eine gute Erklärung, warum sie ‚Residual‘ heißen, findest du hier). Im Residualblock wird die unmodifizierte Eingabe an das nächste Layer weitergereicht, indem sie zum Ausgang eines Layers addiert wird (siehe Abbildung rechts). Diese Modifikation sorgt dafür, dass ein besserer Informationsfluss von der Eingabe zu den tieferen Layers möglich ist. Die gesamte ResNet-Architektur ist im rechten Netzwerk in der linken Abbildung unten dargestellt. Weiter sind daneben ein klassisches CNN und das VGG-19-Netzwerk, eine weitere Standard-CNN-Architektur, abgebildet.

Resnet-Architecture_Residual-Block

ResNet hat sich als leistungsfähige Netzarchitektur für Bildklassifikationsprobleme erwiesen. Zum Beispiel hat ein Ensemble von ResNets mit 152 Layern den ILSVRC 2015 Bildklassifikationswettbewerb gewonnen. Im Modul tensorflow.keras.application sind vortrainierte ResNet-Modelle unterschiedlicher Größe verfügbar, nämlich ResNet50, ResNet101, ResNet152 und die entsprechenden zweiten Versionen (ResNet50V2, …). Die Zahl hinter dem Modellnamen gibt die Anzahl der Layer an, über die die Netze verfügen. Die verfügbaren Gewichte sind auf dem ImageNet-Datensatz vortrainiert. Die Modelle wurden auf großen Rechenclustern unter Verwendung von spezialisierter Hardware (z.B. TPU) über signifikante Zeiträume trainiert. Transfer Learning ermöglicht es uns daher, diese Trainingsergebnisse zu nutzen und die erhaltenen Gewichte als Ausgangspunkt zu verwenden.

Klassifizierung von Automodellen

Als anschauliches Beispiel für die Anwendung von Transfer Learning behandeln wir das Problem der Klassifizierung des Automodells anhand eines Bildes des Autos. Wir beginnen mit der Beschreibung des verwendeten Datensatzes und wie wir unerwünschte Beispiele aus dem Datensatz herausfiltern können. Anschließend gehen wir darauf ein, wie eine Datenpipeline mit tensorflow.data eingerichtet werden kann. Im zweiten Abschnitt werden wir die Modellimplementierung durchgehen und aufzeigen, auf welche Aspekte ihr beim Training und bei der Inferenz besonders achten müsst.

Datenvorbereitung

Wir haben den Datensatz aus diesem GitHub Repo verwendet – dort könnt ihr den gesamten Datensatz herunterladen. Der Autor hat einen Datascraper gebaut, um alle Autobilder von der Car Connection Website zu scrapen. Er erklärt, dass viele Bilder aus dem Innenraum der Autos stammen. Da sie im Datensatz nicht erwünscht sind, filtern wir sie anhand der Pixelfarbe heraus. Der Datensatz enthält 64’467 jpg-Bilder, wobei die Dateinamen Informationen über die Automarke, das Modell, das Baujahr usw. enthalten. Für einen detaillierteren Einblick in den Datensatz empfehlen wir euch, das originale GitHub Repo zu konsultieren. Hier sind drei Beispielbilder:

Car Collage 01

Bei der Betrachtung der Daten haben wir festgestellt, dass im Datensatz noch viele unerwünschte Bilder enthalten waren, z.B. Bilder von Außenspiegeln, Türgriffen, GPS-Panels oder Leuchten. Beispiele für unerwünschte Bilder sind hier zu sehen:

Car Collage 02

Daher ist es von Vorteil, die Daten zusätzlich vorzufiltern, um mehr unerwünschte Bilder zu entfernen.

Filtern unerwünschter Bilder aus dem Datensatz

Es gibt mehrere mögliche Ansätze, um Nicht-Auto-Bilder aus dem Datensatz herauszufiltern:

  1. Verwendung eines vortrainierten Modells
  2. Ein anderes Modell trainieren, um Auto/Nicht-Auto zu klassifizieren
  3. Trainieren eines Generative Networks auf einem Auto-Datensatz und Verwendung des Diskriminatorteil des Netzwerks

Wir haben uns für den ersten Ansatz entschieden, da er der direkteste ist und ausgezeichnete, vortrainierte Modelle leicht verfügbar sind. Wenn ihr den zweiten oder dritten Ansatz verfolgen wollt, könnt ihr z. B. diesen Datensatz verwenden, um das Modell zu trainieren. Dieser Datensatz enthält nur Bilder von Autos, ist aber deutlich kleiner als der von uns verwendete Datensatz.

Unsere Wahl fiel auf das ResNet50V2 im Modul tensorflow.keras.applications mit den vortrainierten „imagenet“-Gewichten. In einem ersten Schritt müssen wir jetzt die Indizes und Klassennamen der imagenet-Labels herausfinden, die den Autobildern entsprechen.

# Class labels in imagenet corresponding to cars
CAR_IDX = [656, 627, 817, 511, 468, 751, 705, 757, 717, 734, 654, 675, 864, 609, 436]

CAR_CLASSES = ['minivan', 'limousine', 'sports_car', 'convertible', 'cab', 'racer', 'passenger_car', 'recreational_vehicle', 'pickup', 'police_van', 'minibus', 'moving_van', 'tow_truck', 'jeep', 'landrover', 'beach_wagon']

Als nächstes laden wir das vortrainierte ResNet50V2-Modell.

from tensorflow.keras.applications import ResNet50V2

model = ResNet50V2(weights='imagenet')

Wir können dieses Modell nun verwenden, um die Bilder zu klassifizieren. Die Bilder, die der Vorhersagemethode zugeführt werden, müssen identisch skaliert sein wie die Bilder, die zum Training verwendet wurden. Die verschiedenen ResNet-Modelle werden auf unterschiedlich skalierten Bilddaten trainiert. Es ist daher wichtig, das richtige Preprocessing anzuwenden.

from tensorflow.keras.applications.resnet_v2 import preprocess_input

image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.cast(image, tf.float32)
image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
image = preprocess_input(image)
predictions = model.predict(image)

Es gibt verschiedene Ideen, wie die erhaltenen Vorhersagen für die Autoerkennung verwendet werden können.

  • Ist eine der CAR_CLASSES unter den Top-k-Vorhersagen?
  • Ist die kumulierte Wahrscheinlichkeit der CAR_CLASSES in den Vorhersagen größer als ein definierter Schwellenwert?
  • Spezielle Behandlung unerwünschter Bilder (z. B. Erkennen und Herausfiltern von Rädern)?

Wir zeigen den Code für den Vergleich der kumulierten Wahrscheinlichkeitsmaße über die CAR_CLASSES.

def is_car_acc_prob(predictions, thresh=THRESH, car_idx=CAR_IDX):
    """
    Determine if car on image by accumulating probabilities of car prediction and comparing to threshold

    Args:
        predictions: (?, 1000) matrix of probability predictions resulting from ResNet with                                              imagenet weights
        thresh: threshold accumulative probability over which an image is considered a car
        car_idx: indices corresponding to cars

    Returns:
        np.array of booleans describing if car or not
    """
    predictions = np.array(predictions, dtype=float)
    car_probs = predictions[:, car_idx]
    car_probs_acc = car_probs.sum(axis=1)
    return car_probs_acc > thresh

Je höher der Schwellenwert eingestellt ist, desto strenger ist das Filterverfahren. Ein Wert für den Schwellenwert, der gute Ergebnisse liefert, ist THRESH = 0.1. Damit wird sichergestellt, dass nicht zu viele echte Bilder von Autos verloren gehen. Die Wahl eines geeigneten Schwellenwerts bleibt jedoch eine subjektive Angelegenheit.

Das Colab-Notebook, in dem die Funktion is_car_acc_prob zum Filtern des Datensatzes verwendet wird, ist im GitHub Repository verfügbar.

Bei der Abstimmung der Vorfilterung haben wir Folgendes beobachtet:

  • Viele der Autobilder mit hellem Hintergrund wurden als „Strandwagen“ klassifiziert. Wir haben daher entschieden, auch die Klasse „Strandwagen“ in imagenet als eine der CAR_CLASSES zu berücksichtigen.
  • Bilder, die die Front eines Autos zeigen, bekommen oft eine hohe Wahrscheinlichkeit der Klasse „Kühlergrill“ („grille“) zugeordnet, d.h. dem Gitter an der Front eines Autos, das zur Kühlung dient. Diese Zuordnung ist korrekt, führt aber dazu, dass die oben gezeigte Prozedur bestimmte Bilder von Autos nicht als Autos betrachtet, da wir „grille“ nicht in die CAR_CLASSES aufgenommen haben. Dieses Problem führt zu dem Kompromiss, entweder viele Nahaufnahmen von Autokühlergrills im Datensatz zu belassen oder einige Autobilder herauszufiltern. Wir haben uns für den zweiten Ansatz entschieden, da er einen saubereren Datensatz ergibt.

Nach der Vorfilterung der Bilder mit dem vorgeschlagenen Verfahren verbleiben zunächst 53’738 von 64’467 im Datensatz.

Übersicht über die endgültigen Datensätze

Der vorgefilterte Datensatz enthält Bilder von 323 Automodellen. Wir haben uns dazu entschieden, unsere Aufmerksamkeit auf die 300 häufigsten Klassen im Datensatz zu reduzieren. Das ist deshalb sinnvoll, da einige der am wenigsten häufigen Klassen weniger als zehn Repräsentanten haben und somit nicht sinnvoll in ein Trainings-, Validierungs- und Testset aufgeteilt werden können. Reduziert man den Datensatz auf die Bilder der 300 häufigsten Klassen, erhält man einen Datensatz mit 53.536 beschrifteten Bildern. Die Klassenvorkommen sind wie folgt verteilt:

Histogram

Die Anzahl der Bilder pro Klasse (Automodell) reicht von 24 bis knapp unter 500. Wir können sehen, dass der Datensatz sehr unausgewogen ist. Dies muss beim Training und bei der Auswertung des Modells unbedingt beachtet werden.

Aufbau von Datenpipelines mit tf.data

Selbst nach der Vorfilterung und der Reduktion auf die besten 300 Klassen bleiben immer noch zahlreiche Bilder übrig. Dies stellt ein potenzielles Problem dar, da wir nicht einfach alle Bilder auf einmal in den Speicher unserer GPU laden können. Um dieses Problem zu lösen, werden wir tf.data verwenden.

Mit tf.data und insbesondere der tf.data.Dataset API lassen sich elegante und gleichzeitig sehr effiziente Eingabe-Pipelines erstellen. Die API enthält viele allgemeine Methoden, die zum Laden und Transformieren potenziell großer Datensätze verwendet werden können. Die Methode tf.data.Dataset ist besonders nützlich, wenn Modelle auf GPU(s) trainiert werden. Es ermöglicht das Laden von Daten von der Festplatte, wendet on-the-fly Transformationen an und erstellt Batches, die dann an die GPU gesendet werden. Und das alles geschieht so, dass die GPU nie auf neue Daten warten muss.

Die folgenden Funktionen erstellen eine <code>tf.data.Dataset-Instanz für unseren konkreten Anwendungsfall:

def construct_ds(input_files: list,
                 batch_size: int,
                 classes: list,
                 label_type: str,
                 input_size: tuple = (212, 320),
                 prefetch_size: int = 10,
                 shuffle_size: int = 32,
                 shuffle: bool = True,
                 augment: bool = False):
    """
    Function to construct a tf.data.Dataset set from list of files

    Args:
        input_files: list of files
        batch_size: number of observations in batch
        classes: list with all class labels
        input_size: size of images (output size)
        prefetch_size: buffer size (number of batches to prefetch)
        shuffle_size: shuffle size (size of buffer to shuffle from)
        shuffle: boolean specifying whether to shuffle dataset
        augment: boolean if image augmentation should be applied
        label_type: 'make' or 'model'

    Returns:
        buffered and prefetched tf.data.Dataset object with (image, label) tuple
    """
    # Create tf.data.Dataset from list of files
    ds = tf.data.Dataset.from_tensor_slices(input_files)

    # Shuffle files
    if shuffle:
        ds = ds.shuffle(buffer_size=shuffle_size)

    # Load image/labels
    ds = ds.map(lambda x: parse_file(x, classes=classes, input_size=input_size,                                                                                                                                        label_type=label_type))

    # Image augmentation
    if augment and tf.random.uniform((), minval=0, maxval=1, dtype=tf.dtypes.float32, seed=None, name=None) < 0.7:
        ds = ds.map(image_augment)

    # Batch and prefetch data
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=prefetch_size)

    return ds

Wir werden nun die verwendeten tf.data-Methoden beschreiben:

  • from_tensor_slices() ist eine der verfügbaren Methoden für die Erstellung eines Datensatzes. Der erzeugte Datensatz enthält Slices des angegebenen Tensors, in diesem Fall die Dateinamen.
  • Als nächstes betrachtet die Methode shuffle() jeweils buffer_size-Elemente separat und mischt diese Elemente isoliert vom Rest des Datensatzes. Wenn das Mischen des gesamten Datensatzes erforderlich ist, muss buffer_size größer sein als die Anzahl der Einträge im Datensatz. Das Mischen wird nur durchgeführt, wenn shuffle=True gesetzt ist.
  • Mit map() lassen sich beliebige Funktionen auf den Datensatz anwenden. Wir haben eine Funktion parse_file() erstellt, die im GitHub Repo zu finden ist. Sie ist verantwortlich für das Lesen und die Größenänderung der Bilder, das Ableiten der Beschriftungen aus dem Dateinamen und die Kodierung der Beschriftungen mit einem One-Hot-Encoder. Wenn die Flag „augment“ gesetzt ist, wird das Verfahren zur Datenerweiterung aktiviert. Die Augmentierung wird nur in 70 % der Fälle angewendet, da es von Vorteil ist, das Modell auch auf nicht modifizierten Bildern zu trainieren. Die in image_augment verwendeten Augmentierungstechniken sind Flipping, Helligkeits- und Kontrastanpassungen.
  • Schließlich wird die Methode batch() verwendet, um den Datensatz in Batches der Größe batch_size zu gruppieren, und die Methode prefetch() ermöglicht die Vorbereitung späterer Batches, während der aktuelle Batch verarbeitet wird, und verbessert so die Leistung. Wenn die Methode nach einem Aufruf von batch() verwendet wird, werden prefetch_size-Batches vorab geholt.

Fine Tuning des Modells

Nachdem wir unsere Eingabe-Pipeline definiert haben, wenden wir uns nun dem Trainingsteil des Modells zu. Der Code unten zeigt auf, wie ein Modell basierend auf dem vortrainierten ResNet instanziiert werden kann:

from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D


class TransferModel:

    def __init__(self, shape: tuple, classes: list):
        """
        Class for transfer learning from ResNet

        Args:
            shape: Input shape as tuple (height, width, channels)
            classes: List of class labels
        """
        self.shape = shape
        self.classes = classes
        self.history = None
        self.model = None

        # Use pre-trained ResNet model
        self.base_model = ResNet50V2(include_top=False,
                                     input_shape=self.shape,
                                     weights='imagenet')

        # Allow parameter updates for all layers
        self.base_model.trainable = True

        # Add a new pooling layer on the original output
        add_to_base = self.base_model.output
        add_to_base = GlobalAveragePooling2D(data_format='channels_last', name='head_gap')(add_to_base)

        # Add new output layer as head
        new_output = Dense(len(self.classes), activation='softmax', name='head_pred')(add_to_base)

        # Define model
        self.model = Model(self.base_model.input, new_output)

Ein paar weitere Details zum oben stehenden Code:

  • Wir erzeugen zunächst eine Instanz der Klasse tf.keras.applications.ResNet50V2. Mit include_top=False weisen wir das vortrainierte Modell an, den ursprünglichen head des Modells (in diesem Fall für die Klassifikation von 1000 Klassen auf ImageNet ausgelegt) wegzulassen.
  • Mit base_model.trainable = True werden alle Layer trainierbar.
  • Mit der funktionalen API tf.keras stapeln wir dann ein neues Pooling-Layer auf den letzten Faltungsblock des ursprünglichen ResNet-Modells. Dies ist ein notwendiger Zwischenschritt, bevor die Ausgabe an die endgültigen Klassifizierungs-Layer weitergeleitet wird.
  • Die endgültigen Klassifizierungs-Layer wird dann mit „tf.keras.layers.Dense“ definiert. Wir definieren die Anzahl der Neuronen so, dass sie gleich der Anzahl der gewünschten Klassen ist. Und die Softmax-Aktivierungsfunktion sorgt dafür, dass die Ausgabe eine Pseudowahrscheinlichkeit im Bereich von (0,1] ist.

Die Vollversion von TransferModel (s. GitHub) enthält auch die Option, das Basismodell durch ein VGG16-Netzwerk zu ersetzen, ein weiteres Standard-CNN für die Bildklassifikation. Außerdem erlaubt es, nur bestimmte Layer freizugeben, d.h. wir können die entsprechenden Parameter trainierbar machen, während wir die anderen festgehalten werden. Standardmässig haben wir hier alle Parameter trainierbar gemacht.

Nachdem wir das Modell definiert haben, müssen wir es für das Training konfigurieren. Dies kann mit der compile()-Methode von tf.keras.Model gemacht werden:

def compile(self, **kwargs):
      """
    Compile method
    """
    self.model.compile(**kwargs)

Wir übergeben dann die folgenden Keyword-Argumente an unsere Methode:

  • loss = "categorical_crossentropy" für die Mehrklassen-Klassifikation,
  • optimizer = Adam(0.0001) für die Verwendung des Adam-Optimierers aus tf.keras.optimizers mit einer relativ kleinen Lernrate (mehr zur Lernrate weiter unten), und
  • metrics = ["categorical_accuracy"] für die Trainings- und Validierungsüberwachung.

Als Nächstes wollen wir uns das Trainingsverfahren ansehen. Dazu definieren wir eine train-Methode für unsere oben vorgestellte TransferModel-Klasse:

from tensorflow.keras.callbacks import EarlyStopping

def train(self,
          ds_train: tf.data.Dataset,
          epochs: int,
          ds_valid: tf.data.Dataset = None,
          class_weights: np.array = None):
    """
    Trains model in ds_train with for epochs rounds

    Args:
        ds_train: training data as tf.data.Dataset
        epochs: number of epochs to train
        ds_valid: optional validation data as tf.data.Dataset
        class_weights: optional class weights to treat unbalanced classes

    Returns
        Training history from self.history
    """

    # Define early stopping as callback
    early_stopping = EarlyStopping(monitor='val_loss',
                                   min_delta=0,
                                   patience=12,
                                   restore_best_weights=True)

    callbacks = [early_stopping]

    # Fitting
    self.history = self.model.fit(ds_train,
                                  epochs=epochs,
                                  validation_data=ds_valid,
                                  callbacks=callbacks,
                                  class_weight=class_weights)

    return self.history

Da unser Modell eine Instanz von tensorflow.keras.Model ist, können wir es mit der Methode fit trainieren. Um Overfitting zu verhindern, wird Early Stopping verwendet, indem es als Callback-Funktion an die fit-Methode übergeben wird. Der patience-Parameter kann eingestellt werden, um festzulegen, wie schnell das Early Stopping angewendet werden soll. Der Parameter steht für die Anzahl der Epochen, nach denen, wenn keine Abnahme des Validierungsverlustes registriert wird, das Training abgebrochen wird. Weiterhin können Klassengewichte an die Methode fit übergeben werden. Klassengewichte erlauben es, unausgewogene Daten zu behandeln, indem den verschiedenen Klassen unterschiedliche Gewichte zugewiesen werden, wodurch die Wirkung von Klassen mit weniger Trainingsbeispielen erhöht werden kann.

Wir können den Trainingsprozess mit einem vortrainierten Modell wie folgt beschreiben: Da die Gewichte im head zufällig initialisiert werden und die Gewichte des Basismodells vortrainiert sind, setzt sich das Training aus dem Training des heads von Grund auf und der Feinabstimmung der Gewichte des vortrainierten Modells zusammen. Es wird generell für Transfer Learning empfohlen, eine kleine Lernrate zu verwenden (z. B. 1e-4), da eine zu große Lernrate die nahezu optimalen vortrainierten Gewichte des Basismodells zerstören kann.

Der Trainingsvorgang kann beschleunigt werden, indem zunächst einige Epochen lang trainiert wird, ohne dass das Basismodell trainierbar ist. Der Zweck dieser ersten Epochen ist es, die Gewichte des heads an das Problem anzupassen. Dies beschleunigt das Training, da wenn nur der head trainiert wird, viel weniger Parameter trainierbar sind und somit für jeden Batch aktualisiert werden. Die resultierenden Modellgewichte können dann als Ausgangspunkt für das Training des gesamten Modells verwendet werden, wobei das Basismodell trainierbar ist. Für das hier betrachtete Autoklassifizierungsproblem führte die Anwendung dieses zweistufigen Trainings zu keiner nennenswerten Leistungsverbesserung.

Evaluation/Vorhersage der Modell Performance

Bei der Verwendung der API tf.data.Dataset muss man auf die Art der verwendeten Methoden achten. Die folgende Methode in unserer Klasse TransferModel kann als Vorhersagemethode verwendet werden.

def predict(self, ds_new: tf.data.Dataset, proba: bool = True):
    """
    Predict class probs or labels on ds_new
    Labels are obtained by taking the most likely class given the predicted probs

    Args:
        ds_new: New data as tf.data.Dataset
        proba: Boolean if probabilities should be returned

    Returns:
        class labels or probabilities
    """

    p = self.model.predict(ds_new)

    if proba:
        return p
    else:
        return [np.argmax(x) for x in p]

Es ist wichtig, dass der Datensatz ds_new nicht gemischt wird, sonst stimmen die erhaltenen Vorhersagen nicht mit den erhaltenen Bildern überein, wenn ein zweites Mal über den Datensatz iteriert wird. Dies ist der Fall, da die Flag reshuffle_each_iteration in der Implementierung der Methode shuffle standardmäßig auf True gesetzt ist. Ein weiterer Effekt des Shufflens ist, dass mehrere Aufrufe der Methode take nicht die gleichen Daten zurückgeben. Dies ist wichtig, wenn z. B. Vorhersagen für nur eine Charge überprüft werden sollen. Ein einfaches Beispiel, an dem dies zu sehen ist, ist:

# Use construct_ds method from above to create a shuffled dataset
ds = construct_ds(..., shuffle=True)

# Take 1 batch (e.g. 32 images) of dataset: This returns a new dataset
ds_batch = ds.take(1)

# Predict labels for one batch
predictions = model.predict(ds_batch)

# Predict labels again: The result will not be the same as predictions above due to shuffling
predictions_2 = model.predict(ds_batch)

Eine Funktion zum Plotten von Bildern, die mit den entsprechenden Vorhersagen beschriftet sind, könnte wie folgt aussehen:

def show_batch_with_pred(model, ds, classes, rescale=True, size=(10, 10), title=None):
      for image, label in ds.take(1):
        image_array = image.numpy()
        label_array = label.numpy()
        batch_size = image_array.shape[0]
        pred = model.predict(image, proba=False)
        for idx in range(batch_size):
            label = classes[np.argmax(label_array[idx])]
            ax = plt.subplot(np.ceil(batch_size / 4), 4, idx + 1)
            if rescale:
                plt.imshow(image_array[idx] / 255)
            else:
                plt.imshow(image_array[idx])
            plt.title("label: " + label + "n" 
                      + "prediction: " + classes[pred[idx]], fontsize=10)
            plt.axis('off')

Die Methode show_batch_with_pred funktioniert auch für gemischte Datensätze, da image und label demselben Aufruf der Methode take entsprechen.

Die Auswertung der Model-Performance kann mit der Methode evaluate von keras.Model durchgeführt werden.

Wie akkurat ist unser finales Modell?

Das Modell erreicht eine kategoriale Genauigkeit von etwas über 70 % für die Vorhersage des Automodells für Bilder aus 300 Modellklassen. Um die Vorhersagen des Modells besser zu verstehen, ist es hilfreich, die Konfusionsmatrix zu betrachten. Unten ist die Heatmap der Vorhersagen des Modells für den Validierungsdatensatz abgebildet.

heatmap

Wir haben die Heatmap auf Einträge der Konfusionsmatrix in [0, 5] beschränkt, da das Zulassen einer weiteren Spanne keine Region außerhalb der Diagonalen signifikant hervorgehoben hat. Wie in der Heatmap zu sehen ist, wird eine Klasse den Beispielen fast aller Klassen zugeordnet. Das ist an der dunkelroten vertikalen Linie zwei Drittel rechts in der Abbildung oben zu erkennen.

Abgesehen von der zuvor erwähnten Klasse gibt es keine offensichtlichen Verzerrungen in den Vorhersagen. Wir möchten an dieser Stelle betonen, dass die Accuracy im Allgemeinen nicht ausreicht, um die Leistung eines Modells zufriedenstellend zu beurteilen, insbesondere im Fall unausgewogener Klassen.

Fazit und nächste Schritte

In diesem Blog-Beitrag haben wir Transfer Learning mit dem ResNet50V2 angewendet, um das Fahrzeugmodell anhand von Bildern von Autos zu klassifizieren. Unser Modell erreicht 70% kategoriale Genauigkeit über 300 Klassen.

Wir haben festgestellt, dass das Trainieren des gesamten Basismodells und die Verwendung einer kleinen Lernrate die besten Ergebnisse erzielen. Ein cooles Auto-Klassifikationsmodell entwickelt zu haben ist großartig, aber wie können wir unser Modell in einer produktiven Umgebung einsetzen? Natürlich könnten wir unsere eigene Modell-API mit Flask oder FastAPI bauen… Aber gibt es vielleicht sogar einen einfacheren, standardisierten Weg?

Im zweiten Beitrag unserer Blog-Serie, „Deployment von TensorFlow-Modellen in Docker mit TensorFlow Serving“ zeigen wir Euch, wie dieses Modell mit TensorFlow Serving bereitgestellt werden kann.

Management Summary

Neuronale Netze haben sich in den letzten Jahren immer weiter zur Kerntechnologie im Bereich Machine Learning und AI entwickelt. Neben den klassischen Anwendungen der Daten-Klassifizierung und -Regression können sie auch dazu verwendet werden, neue, realistische Datenpunkte zu erzeugen. Die etablierte Modell-Architektur für die Datengenerierung sind sogenannte Generative Adversarial Networks (GANs). GANs bestehen aus zwei neuronalen Netzen, dem Generator- und dem Diskriminatornetzwerk. Die beiden Netzwerke werden iterativ gegeneinander trainiert. Hierbei versucht der Generator einen realistischen Datenpunkt zu erzeugen, während der Diskriminator lernt, echte und synthetische Datenpunkte zu unterscheiden. In diesem Minimax-Spiel verbessert der Generator seine Performance so weit, bis die erzeugten Datenpunkte nicht mehr von echten Datenpunkten zu unterscheiden sind.

Die Erzeugung synthetischer Daten mittels GANs hat eine Vielzahl von spannenden Anwendungsmöglichkeiten in der Praxis:

  1. Die Bild-Synthese
    Die ersten Erfolge von GANs wurden auf Basis von Bilddaten erreicht. Heute sind spezifische GAN-Architekturen fähig, Bilder von Gesichtern zu generieren, die kaum von echten unterscheidbar sind.
  2. Die Musik-Synthese
    Neben Bildern können GAN-Modelle auch verwendet werden, um Musik zu generieren. Im Bereich der Musik-Synthese wurden sowohl für das Erzeugen von realistischen Wellenformen als auch für die Generierung ganzer Melodien vielversprechende Resultate erreicht.
  3. Super-Resolution
    Ein weiteres Beispiel für den erfolgreichen Einsatz von GANs ist die Super-Resolution. Hierbei wird versucht, die Auflösung eines Bilds zu verbessern bzw. die Größe des Bildes verlustfrei zu erhöhen. Mittels GANs können schärfere Übergänge zwischen verschiedenen Bildbereichen als mit z.B. Interpolation erreicht und die allgemeine Bildqualität verbessert werden.
  4. Deepfakes
    Mittels Deepfake Algorithmen können Gesichter in Bildern und Videoaufnahmen durch andere Gesichter ersetzt werden. Dabei wird der Algorithmus auf möglichst vielen Daten der Zielperson trainiert. Danach kann damit die Mimik einer dritten Person auf die zu imitierende Person transferiert werden, und es entsteht ein „Deepfake“. 
  5. Generierung zusätzlicher Trainingsdaten
    GANs werden auch dazu verwendet, die im maschinellen Lernen oft nur in geringen Mengen verfügbaren Trainingsdaten zu vermehren (auch Augmentieren genannt). Dabei wird ein GAN auf dem Trainingsdatensatz kalibriert, damit anschließend beliebig viele, neue Trainingsbeispiele generiert werden können. Das Ziel ist, mit den zusätzlich generierten Daten die Performance bzw. Generalisierbarkeit von ML-Modellen zu verbessern.
  6. Anonymisierung von Daten
    Letztlich können mittels GANs auch Daten anonymisiert werden. Dies ist vor allem für Datensätze, die personenbezogene Daten enthalten, ein wichtiger Anwendungsfall. Der Vorteil von GANs gegenüber anderen Ansätzen zur Datenanonymisierung ist, dass die statistischen Eigenschaften des Datensets erhalten bleiben. Dies ist wiederum wichtig für die Performance anderer Machine Learning Modelle, die auf diesen Daten trainiert werden. Da das Verarbeiten von personenbezogenen Daten heute im Bereich des maschinellen Lernens eine große Hürde darstellt, sind GANs für die Anonymisierung von Datensätzen ein vielversprechender Ansatz und haben das Potential, für Firmen die Tür zur Verarbeitung personenbezogener Daten zu öffnen.

In diesem Artikel wird die Funktionsweise von GANs erläutert. Zudem werden Use Cases diskutiert, die mithilfe von GANs umgesetzt werden können, und aktuelle Trends vorgestellt, die sich im Bereich von generativen Netzwerken abzeichnen.

Einleitung

Die größten Fortschritte der letzten Jahre im Bereich der künstlichen Intelligenz wurden durch die Anwendung von neuronalen Netzen erzielt. Diese haben sich insbesondere bei der Verarbeitung von unstrukturierten Daten, wie Texten oder Bildern, als äußerst zuverlässiger Ansatz zur Regression und Klassifizierung erwiesen. Ein bekanntes Beispiel ist die Klassifizierung von Bildern in zwei oder mehr Gruppen. Auf dem sogenannten ImageNet-Datensatz [1] mit insgesamt 1000 Klassen erreichen Neuronale Netze eine Top-5 Accuracy von 98.7%. Das bedeutet, dass für 98.7% der Bilder die richtige Klasse in den fünf besten Modellvorhersagen enthalten ist. Auch im Bereich der natürlichen Sprachverabreitung wurden in den letzten Jahren bahnbrechende Resultate durch die Anwendung von neuronalen Netzen erzielt. Die Transformer-Architektur wurde beispielsweise sehr erfolgreich für verschiedene Sprachverarbeitungsprobleme wie Question-Answering, Named Entity Recognition und Sentimentanalyse verwendet.

Weniger bekannt ist, dass neuronale Netze auch dazu verwendet werden können die zugrundeliegende Verteilung einer Datenmenge zu erlernen, um somit neue, realistische Beispieldaten zu generieren. Als zielführende Modell-Architektur für diese Problemstellung haben sich in den letzten Jahren Generative Adversarial Networks (GANs) etabliert. Durch deren Anwendung ist es möglich, synthetische Datenpunkte zu erzeugen, die die gleichen statistischen Eigenschaften wie die zugrundeliegenden Trainingsdaten aufweisen. Dies eröffnet viele spannende Anwendungsfälle, von denen nachfolgend einige vorgestellt werden.

Generative Adversarial Networks

GANs sind eine neue Klasse von Algorithmen im Bereich des maschinellen Lernens. Wie zuvor erläutert, handelt es sich dabei um Modelle, die neue, realistische Datenpunkte generieren können, nachdem sie auf einem bestimmten Datensatz trainiert wurden. GANs bestehen im Kern aus zwei neuronalen Netzen, die für spezifische Tasks im Lernprozess zuständig sind. Der Generator ist dafür verantwortlich, einen zufällig generierten Input in ein realistisches Sample der zu lernenden Verteilung zu transformieren. Im Gegensatz dazu ist es die Aufgabe des Diskriminators, echte von generierten Datenpunkten zu unterscheiden.

Eine einfache Analogie ist das Wechselspiel eines Fälschers und eines Polizisten. Der Fälscher versucht möglichst realistische Münzfälschungen zu kreieren, während der Polizist die Fälschungen von echten Münzen unterscheiden will. Je besser die Nachbildungen des Fälschers sind, desto besser muss der Polizist die Münzen unterscheiden können, um die gefälschten Exemplare zu erkennen. Andererseits müssen die Fälschungen immer besser werden, da der Polizist mit seiner Erfahrung gefälschte Münzen zuverlässiger erkennen kann. Es ist einleuchtend, dass im Verlaufe dieses Prozesses sowohl der Fälscher als auch der Polizist in ihrer jeweiligen Aufgabe besser werden. Dieser iterative Prozess bildet auch die Grundlage für das Training der GAN-Modelle. Der Fälscher entspricht hier dem Generator, während der Polizist die Rolle des Diskriminators einnimmt.

Wie werden Generator- und Diskriminator-Netzwerke trainiert?

In einem ersten Schritt wird der Diskriminator fixiert. Das bedeutet, dass in diesem Schritt für das Diskriminatornetzwerk keine Anpassungen der Parameter vorgenommen werden. Für eine bestimmte Anzahl an Trainingsschritten wird dann der Generator trainiert. Der Generator wird, wie für neuronale Netze üblich, mittels Backpropagation trainiert. Sein Ziel ist, dass die aus dem zufälligen Input entstehenden „gefälschten“ Outputs von dem aktuellen Diskriminator als echte Beispiele klassifiziert werden. Es ist wichtig zu erwähnen, dass die Trainingsfortschritte vom aktuellen Stand des Diskriminators abhängen.

Im zweiten Schritt wird der Generator fixiert und ausschließlich der Diskriminator trainiert, indem er sowohl echte als auch generierte Beispiele mit entsprechenden Labels als Trainingsinput verarbeitet. Das Ziel des Trainingsprozesses ist, dass der Generator so realistische Beispiele zu erschaffen lernt, dass der Diskriminator sie nicht mehr von echten unterscheiden kann. Zum besseren Verständnis ist der GAN-Trainingsprozess in der Abbildung 1 schematisch abgebildet.

Abbildung 1 – Architektur eines GANs

Herausforderungen des Trainingsprozesses

Aufgrund der Natur des alternierenden Trainingsprozesses können verschiedene Probleme beim Training eines GANs auftreten. Eine häufige Herausforderung ist, dass das Feedback des Diskriminators schlechter wird, je besser der Generator im Verlauf des Trainings wird.

Man kann sich den Prozess folgendermaßen vorstellen: Wenn der Generator Beispiele generieren kann, die von echten Beispielen nicht mehr unterscheidbar sind, bleibt dem Diskriminator nichts anderes übrig, als zu raten aus welcher Klasse das jeweilige Beispiel stammt. Wenn man also das Training nicht rechtzeitig beendet, ist es möglich, dass die Qualität des Generators und des Diskriminators aufgrund der zufälligen Rückgabewerten des Diskriminators wieder abnimmt.

Ein weiteres häufig auftretendes Problem ist der sogenannte „Mode-Collapse“. Dieser tritt auf, wenn das Netzwerk anstatt die Eigenschaften der zugrundeliegenden Daten zu lernen, sich einzelne Beispiele dieser Daten merkt oder nur Beispiele mit geringer Variabilität generiert. Einige Ansätze, um diesem Problem entgegenzuwirken, sind das Verarbeiten mehrerer Beispiele gleichzeitig (in Batches) oder das simultane Zeigen vergangener Beispiele, damit der Diskriminator mangelnde Unterscheidbarkeit der Beispiele quantifizieren kann.

Use Cases für GANs

Die Erzeugung von synthetischen Daten mittels GANs hat eine Vielzahl von spannenden Anwendungsmöglichkeiten in der Praxis. Die ersten eindrucksvollen Resultate von GANs wurden auf Basis von Bilddaten erreicht, dicht gefolgt von der Generierung von Audiodaten. In den letzten Jahren wurden GANs jedoch auch erfolgreich auf andere Datentypen, wie tabellarische Daten, angewendet. Im Folgenden werden ausgewählte Anwendungen vorgestellt.

1.    Die Bild-Synthese

Eine der bekanntesten Anwendungen von GANs ist die Bild-Synthese. Hierbei wird ein GAN auf Basis eines großen Bilddatensatzes trainiert. Dabei erlernt das Generatornetzwerk die wichtigen gemeinsamen Merkmale und Strukturen der Bilder. Ein Beispiel ist die Generierung von Gesichtern. Die Portraits in der Abbildung 2 wurden mittels speziellen GAN-Netzwerken erzeugt, die für Bilddaten optimierte Faltungsnetzwerke, „Convolutional Neural Networks“ (CNN), verwenden. Natürlich können auch andere Arten von Bildern mittels GANs erzeugt werden, wie zum Beispiel handgeschriebene Zeichen, Fotos von Gegenständen oder Häusern. Die grundlegende Voraussetzung für gute Resultate ist ein ausreichend großer Datensatz.

Ein praktischer Anwendungsfall dieser Technologie findet sich im Online Retail. Hier werden mittels GAN Fotos von Models in spezifischen Kleidungsstücken oder bestimmten Posen erzeugt.

Abbildung 2 – GAN-generierte Gesichter

2.    Die Musik-Synthese

Im Gegensatz zur Erzeugung von Bildern beinhaltet die Musik-Synthese eine temporale Komponente. Da Audio-Wellenformen sehr periodisch sind und das menschliche Ohr sehr empfindlich gegen Abweichungen dieser Wellenformen ist, hat die Erhaltung der Signalperiodizität einen hohen Stellenwert für GANs, die Musik generieren. Dies hat zu GAN-Modellen geführt, die anstatt der Wellenform die Magnituden und Frequenzen des Tons generieren. Beispiele für GANs zur Musiksynthese sind das GANSynth [3], welches für das Generieren von realistischen Wellenformen optimiert ist, und das MuseGAN [4] –  seinerseits spezialisiert auf das Generieren von ganzen Melodiesequenzen. Im Internet finden sich diverse Quellen, die von GANs erzeugte Musikstücke präsentieren. [5, 6]  Es ist zu erwarten, dass GANs in den nächsten Jahren für eine hohe Disruption in der Musikbranche führen werden.

Die Abbildung 3 zeigt Ausschnitte der vom MuseGAN generierten Musik für verschiedene Trainingsschritte. Zu Beginn des Trainings (step 0) ist der Output rein zufällig und im letzten Schritt (step 7900) erkennt man, dass z.B. der Bass eine typische Basslinie spielt und die Streicher vorwiegend gehaltene Akkorde spielen.

Abbildung 3 – MuseGAN Trainingsverlauf

3.    Die Super-Resolution

Der Prozess der Wiedergewinnung eines hochauflösenden Bildes aus einem tiefer aufgelösten Bild heißt Super-Resolution. Die Schwierigkeit hierbei besteht darin, dass es sehr viele mögliche Lösungen für die Wiederherstellung des Bildes gibt.

Ein klassischer Ansatz für Super-Resolution ist Interpolation. Dabei werden die für die höhere Auflösung fehlenden Pixel mittels einer vorgegebenen Funktion von den benachbarten Pixeln abgeleitet. Dabei gehen aber typischerweise viele Details verloren und die Übergänge zwischen verschiedenen Bildteilen verschwimmen. Da GANs die Hauptmerkmale von Bilddaten gut erlernen können, sind sie auch für das Problem der Super-Resolution ein geeigneter Ansatz. Dabei wird für das Generator-Netzwerk nicht mehr ein zufälliger Input verwendet, sondern das niedrigaufgelöste Bild. Der Diskriminator lernt dann generierte Bilder höherer Auflösung von den echten Bildern höherer Auflösung zu unterscheiden. Das trainierte GAN-Netzwerk kann dann auf neuen tief aufgelösten Bildern verwendet werden.

Ein State of the Art GAN für Super-Resolution ist das SRGAN [7]. In der Abbildung 4 sind die Resultate für verschiedene Super-Resolution-Ansätze für ein Beispielbild abgebildet. Das SRGAN (zweites Bild von rechts) liefert das schärfste Resultat. Besonders Details wie Wassertropfen oder die Oberflächenstruktur des Kopfschmucks werden vom SRGAN erfolgreich rekonstruiert.

Abbildung 4 – Beispiel Super-Resolution für 4x Upscaling (von links nach rechts: bikubische Interpolation, SRResNet, SRGAN, Originalbild)

4.    Deepfakes: Audio und Video

Mittels Deepfake-Algorithmen können Gesichter in Bildern und Videoaufnahmen durch andere Gesichter ersetzt werden. Die Resultate sind mittlerweile so gut, dass die Deepfakes nur schwer von echten Aufnahmen unterschieden werden können. Um einen Deepfake herzustellen wird wird der Algorithmus zunächst auf möglichst vielen Daten der Zielperson trainiert, damit anschließend die Mimik einer dritten Person auf die zu imitierende Person transferiert werden kann. In der Abbildung 5 sehen Sie den Vorher-nachher-Vergleich eines Deepfakes. Dabei wurde das Gesicht des Schauspielers Matthew McConaughey (Bildausschnit aus dem Film „Interstellar“) (links) durch das Gesicht des Tech-Unternehmers Elon Musk ersetzt (rechts).

Diese Technik lässt sich beispielsweise auch auf Audiodaten anwenden. Dabei wird das Modell auf eine Zielstimme trainiert, um dann die Stimme in der Originalaufnahme durch die Zielstimme zu ersetzen.

Abbildung 5 – Deepfake Sample Images

5.    Generierung zusätzlicher Trainingsdaten

Um tiefe neuronale Netze mit vielen Parametern trainieren zu können, werden i.d.R. sehr große Datenmengen benötigt. Oft ist es nur schwer oder gar nicht möglich, an eine ausreichend große Datenmengen zu gelangen oder diese zu erheben. Ein Ansatz, um auch mit weniger Daten ein gutes Resultat zu erzielen, ist die Data Augmentation. Dabei werden die vorhandenen Datenpunkte leicht abgeändert, um somit neue Trainingsbeispiele zu kreieren. Häufig wird dies bspw. im Bereich von Computer Vision angewendet, indem Bilder rotiert oder mittels Zoom neue Bildausschnitte abgeleitet werden. Ein neuerer Ansatz für Data Augmentation ist das Verwenden von GANs. Die Idee ist, dass GANs die Verteilung der Trainingsdaten erlernen und somit theoretisch unendlich viele neue Beispiele generiert werden können. Für die Erkennung von Krankheiten auf Tomographie-Bildern wurde dieser Ansatz beispielsweise erfolgreich umgesetzt [9]. GANs können also dazu verwendet werden, neue Trainingsdaten für Probleme zu generieren, bei denen nicht genügend große Datenmengen für das Training vorhanden sind.

In vielen Fällen ist das Erheben größerer Datenmengen auch mit erheblichen Kosten verbunden, insbesondere dann, wenn die Daten für das Training manuell gekennzeichnet werden müssen. In diesen Situationen können GANs verwendet werden, um die anfallenden Kosten einer zusätzlichen Datenakquise zu reduzieren.

Ein weiteres technisches Problem beim Training von Machine Learning Modellen, bei dem GANs Abhilfe schaffen können, sind Imbalanced Datasets. Dabei handelt es sich um Datensätze mit verschiedenen Klassen, die unterschiedlich häufig repräsentiert sind. Um mehr Trainingsdaten der unterrepräsentierten Klasse zu erhalten, kann man für dieser Klasse ein GAN trainieren, um weitere synthetisch generierte Datenpunkte zu erzeugen, die dann im Training verwendet werden. Beispielsweise gibt es deutlich weniger Mikroskop-Aufnahmen von krebsbefallenen Zellen als von gesunden. GANs ermöglichen in diesem Fall, bessere Modelle zur Erkennung von Krebszellen zu trainieren und können so Ärzte bei der Krebsdiagnose unterstützen.

6.    Anonymisierung von Daten

Ein weiterer interessanter Anwendungsbereich von GANs ist die Anonymisierung von Datensätzen. Klassische Ansätze für die Anonymisierung sind das Entfernen von Identifikator-Spalten oder das zufällige Ändern ihrer Werte. Resultierende Probleme sind bspw., dass mit entsprechendem Vorwissen trotzdem Rückschlüsse auf die Personen hinter den personenbezogenen Daten gezogen werden können. Auch die Änderung der statistischen Eigenschaften des Datensatzes durch das Fehlen oder Abändern bestimmter Informationen kann den Nutzen der Daten schmälern. GANs können auch dazu verwendet werden, anonymisierte Datensätze zu generieren. Sie werden dabei so trainiert, dass die persönlichen Daten im generierten Datensatz nicht mehr identifiziert, aber Modelle trotzdem ähnlich gut damit trainiert werden können [10]. Erklären kann man diese gleichbleibende Qualität der Modelle damit, dass die zugrundeliegenden statistischen Eigenschaften des ursprünglichen Datensatzes von dem GAN gelernt, und somit auch erhalten werden.

Anonymisierung mithilfe von GANs kann neben tabellarischen Daten auch spezifisch für Bilder angewendet werden. Dabei werden Gesichter oder andere persönliche Merkmale/Elemente auf dem Bild durch generierte Varianten ersetzt. Dies erlaubt, Computer Vision Modelle mit realistisch aussehenden Daten zu trainieren, ohne mit Datenschutzproblemen konfrontiert zu werden. Oft wird die Qualität der Modelle deutlich schlechter, wenn wichtige Merkmale eines Bilds, wie beispielsweise ein Gesicht, verpixelt oder weichgezeichnet in den Trainingsprozess einbezogen werden.

Fazit und Ausblick

Durch die Anwendung von GANs kann die Verteilung von Datensätzen beliebiger Art gelernt werden. Wie oben erläutert werden GANs bereits auf verschiedenste Problemstellungen erfolgreich angewendet. Da GANs erst 2014 entdeckt wurden und großes Potential bewiesen haben, wird aktuell sehr intensiv daran geforscht. Die Lösung der oben erwähnten Probleme beim Training, wie z.B. „Mode-Collapse“, sind in der Forschung weit verbreitet. Es wird unter anderem an alternativen Loss-Funktionen und allgemein stabilisierenden Trainingsverfahren geforscht. Ein weiteres aktives Forschungsgebiet ist die Konvergenz von GAN-Netzwerken. Aufgrund der Fortschritte des Generators und Diskriminators während des Trainingsprozesses ist es sehr wichtig, zum richtigen Zeitpunkt das Training zu beenden, um den Generator nicht weiter auf schlechten Diskriminator-Ergebnissen zu trainieren. Um das Training zu stabilisieren, wird zudem an Ansätzen geforscht, die Diskriminator-Inputs mit Rauschen versehen, um die Anpassung des Diskriminators während des Trainings zu limitieren.

Ein modifizierter Ansatz zur Generierung neuer Trainingsdaten sind Generative Teaching Networks. Dabei wird der Fokus des Trainings nicht primär auf das Erlernen der Datenverteilung gelegt, sondern man versucht direkt zu lernen, welche Daten das Training am schnellsten vorantreiben, ohne zwingend Ähnlichkeit mit den Originaldaten vorzuschreiben. [11] Am Beispiel von handgeschriebenen Zahlen konnte gezeigt werden, dass neuronale Netzwerke mit diesen künstlichen Inputdaten schneller lernen können als mit den Originaldaten. Dieser Ansatz lässt sich auch auf andere Datenarten als Bilddaten anwenden. Im Bereich der Anonymisierung ist man bisher imstande, Teile eines Datensatzes mit Persönlichkeitsschutzgarantien zuverlässig zu generieren. Diese Anonymisierungsnetzwerke können weiterentwickelt werden, um mehr Typen von Datensätzen abzudecken.

Im Bereich von GANs werden sehr schnell theoretische Fortschritte gemacht, die bald den Weg in die Praxis finden werden. Da das Verarbeiten von personenbezogenen Daten heute im Bereich des maschinellen Lernens eine große Hürde darstellt, sind GANs für die Anonymisierung von Datensätzen ein vielversprechender Ansatz und haben das Potenzial, für Unternehmen die Tür zur Verarbeitung personenbezogener Daten zu öffnen.

Quellen

  1. http://www.image-net.org/
  2. https://arxiv.org/abs/1710.10196
  3. https://openreview.net/pdf?id=H1xQVn09FX
  4. https://arxiv.org/abs/1709.06298
  5. https://salu133445.github.io/musegan/results
  6. https://storage.googleapis.com/magentadata/papers/gansynth/index.html
  7. https://arxiv.org/abs/1609.04802
  8. https://github.com/iperov/DeepFaceLab
  9. https://arxiv.org/abs/1803.01229
  10. https://arxiv.org/abs/1806.03384
  11. https://eng.uber.com/generative-teaching-networks/

Management Summary

Neuronale Netze haben sich in den letzten Jahren immer weiter zur Kerntechnologie im Bereich Machine Learning und AI entwickelt. Neben den klassischen Anwendungen der Daten-Klassifizierung und -Regression können sie auch dazu verwendet werden, neue, realistische Datenpunkte zu erzeugen. Die etablierte Modell-Architektur für die Datengenerierung sind sogenannte Generative Adversarial Networks (GANs). GANs bestehen aus zwei neuronalen Netzen, dem Generator- und dem Diskriminatornetzwerk. Die beiden Netzwerke werden iterativ gegeneinander trainiert. Hierbei versucht der Generator einen realistischen Datenpunkt zu erzeugen, während der Diskriminator lernt, echte und synthetische Datenpunkte zu unterscheiden. In diesem Minimax-Spiel verbessert der Generator seine Performance so weit, bis die erzeugten Datenpunkte nicht mehr von echten Datenpunkten zu unterscheiden sind.

Die Erzeugung synthetischer Daten mittels GANs hat eine Vielzahl von spannenden Anwendungsmöglichkeiten in der Praxis:

  1. Die Bild-Synthese
    Die ersten Erfolge von GANs wurden auf Basis von Bilddaten erreicht. Heute sind spezifische GAN-Architekturen fähig, Bilder von Gesichtern zu generieren, die kaum von echten unterscheidbar sind.
  2. Die Musik-Synthese
    Neben Bildern können GAN-Modelle auch verwendet werden, um Musik zu generieren. Im Bereich der Musik-Synthese wurden sowohl für das Erzeugen von realistischen Wellenformen als auch für die Generierung ganzer Melodien vielversprechende Resultate erreicht.
  3. Super-Resolution
    Ein weiteres Beispiel für den erfolgreichen Einsatz von GANs ist die Super-Resolution. Hierbei wird versucht, die Auflösung eines Bilds zu verbessern bzw. die Größe des Bildes verlustfrei zu erhöhen. Mittels GANs können schärfere Übergänge zwischen verschiedenen Bildbereichen als mit z.B. Interpolation erreicht und die allgemeine Bildqualität verbessert werden.
  4. Deepfakes
    Mittels Deepfake Algorithmen können Gesichter in Bildern und Videoaufnahmen durch andere Gesichter ersetzt werden. Dabei wird der Algorithmus auf möglichst vielen Daten der Zielperson trainiert. Danach kann damit die Mimik einer dritten Person auf die zu imitierende Person transferiert werden, und es entsteht ein „Deepfake“. 
  5. Generierung zusätzlicher Trainingsdaten
    GANs werden auch dazu verwendet, die im maschinellen Lernen oft nur in geringen Mengen verfügbaren Trainingsdaten zu vermehren (auch Augmentieren genannt). Dabei wird ein GAN auf dem Trainingsdatensatz kalibriert, damit anschließend beliebig viele, neue Trainingsbeispiele generiert werden können. Das Ziel ist, mit den zusätzlich generierten Daten die Performance bzw. Generalisierbarkeit von ML-Modellen zu verbessern.
  6. Anonymisierung von Daten
    Letztlich können mittels GANs auch Daten anonymisiert werden. Dies ist vor allem für Datensätze, die personenbezogene Daten enthalten, ein wichtiger Anwendungsfall. Der Vorteil von GANs gegenüber anderen Ansätzen zur Datenanonymisierung ist, dass die statistischen Eigenschaften des Datensets erhalten bleiben. Dies ist wiederum wichtig für die Performance anderer Machine Learning Modelle, die auf diesen Daten trainiert werden. Da das Verarbeiten von personenbezogenen Daten heute im Bereich des maschinellen Lernens eine große Hürde darstellt, sind GANs für die Anonymisierung von Datensätzen ein vielversprechender Ansatz und haben das Potential, für Firmen die Tür zur Verarbeitung personenbezogener Daten zu öffnen.

In diesem Artikel wird die Funktionsweise von GANs erläutert. Zudem werden Use Cases diskutiert, die mithilfe von GANs umgesetzt werden können, und aktuelle Trends vorgestellt, die sich im Bereich von generativen Netzwerken abzeichnen.

Einleitung

Die größten Fortschritte der letzten Jahre im Bereich der künstlichen Intelligenz wurden durch die Anwendung von neuronalen Netzen erzielt. Diese haben sich insbesondere bei der Verarbeitung von unstrukturierten Daten, wie Texten oder Bildern, als äußerst zuverlässiger Ansatz zur Regression und Klassifizierung erwiesen. Ein bekanntes Beispiel ist die Klassifizierung von Bildern in zwei oder mehr Gruppen. Auf dem sogenannten ImageNet-Datensatz [1] mit insgesamt 1000 Klassen erreichen Neuronale Netze eine Top-5 Accuracy von 98.7%. Das bedeutet, dass für 98.7% der Bilder die richtige Klasse in den fünf besten Modellvorhersagen enthalten ist. Auch im Bereich der natürlichen Sprachverabreitung wurden in den letzten Jahren bahnbrechende Resultate durch die Anwendung von neuronalen Netzen erzielt. Die Transformer-Architektur wurde beispielsweise sehr erfolgreich für verschiedene Sprachverarbeitungsprobleme wie Question-Answering, Named Entity Recognition und Sentimentanalyse verwendet.

Weniger bekannt ist, dass neuronale Netze auch dazu verwendet werden können die zugrundeliegende Verteilung einer Datenmenge zu erlernen, um somit neue, realistische Beispieldaten zu generieren. Als zielführende Modell-Architektur für diese Problemstellung haben sich in den letzten Jahren Generative Adversarial Networks (GANs) etabliert. Durch deren Anwendung ist es möglich, synthetische Datenpunkte zu erzeugen, die die gleichen statistischen Eigenschaften wie die zugrundeliegenden Trainingsdaten aufweisen. Dies eröffnet viele spannende Anwendungsfälle, von denen nachfolgend einige vorgestellt werden.

Generative Adversarial Networks

GANs sind eine neue Klasse von Algorithmen im Bereich des maschinellen Lernens. Wie zuvor erläutert, handelt es sich dabei um Modelle, die neue, realistische Datenpunkte generieren können, nachdem sie auf einem bestimmten Datensatz trainiert wurden. GANs bestehen im Kern aus zwei neuronalen Netzen, die für spezifische Tasks im Lernprozess zuständig sind. Der Generator ist dafür verantwortlich, einen zufällig generierten Input in ein realistisches Sample der zu lernenden Verteilung zu transformieren. Im Gegensatz dazu ist es die Aufgabe des Diskriminators, echte von generierten Datenpunkten zu unterscheiden.

Eine einfache Analogie ist das Wechselspiel eines Fälschers und eines Polizisten. Der Fälscher versucht möglichst realistische Münzfälschungen zu kreieren, während der Polizist die Fälschungen von echten Münzen unterscheiden will. Je besser die Nachbildungen des Fälschers sind, desto besser muss der Polizist die Münzen unterscheiden können, um die gefälschten Exemplare zu erkennen. Andererseits müssen die Fälschungen immer besser werden, da der Polizist mit seiner Erfahrung gefälschte Münzen zuverlässiger erkennen kann. Es ist einleuchtend, dass im Verlaufe dieses Prozesses sowohl der Fälscher als auch der Polizist in ihrer jeweiligen Aufgabe besser werden. Dieser iterative Prozess bildet auch die Grundlage für das Training der GAN-Modelle. Der Fälscher entspricht hier dem Generator, während der Polizist die Rolle des Diskriminators einnimmt.

Wie werden Generator- und Diskriminator-Netzwerke trainiert?

In einem ersten Schritt wird der Diskriminator fixiert. Das bedeutet, dass in diesem Schritt für das Diskriminatornetzwerk keine Anpassungen der Parameter vorgenommen werden. Für eine bestimmte Anzahl an Trainingsschritten wird dann der Generator trainiert. Der Generator wird, wie für neuronale Netze üblich, mittels Backpropagation trainiert. Sein Ziel ist, dass die aus dem zufälligen Input entstehenden „gefälschten“ Outputs von dem aktuellen Diskriminator als echte Beispiele klassifiziert werden. Es ist wichtig zu erwähnen, dass die Trainingsfortschritte vom aktuellen Stand des Diskriminators abhängen.

Im zweiten Schritt wird der Generator fixiert und ausschließlich der Diskriminator trainiert, indem er sowohl echte als auch generierte Beispiele mit entsprechenden Labels als Trainingsinput verarbeitet. Das Ziel des Trainingsprozesses ist, dass der Generator so realistische Beispiele zu erschaffen lernt, dass der Diskriminator sie nicht mehr von echten unterscheiden kann. Zum besseren Verständnis ist der GAN-Trainingsprozess in der Abbildung 1 schematisch abgebildet.

Abbildung 1 – Architektur eines GANs

Herausforderungen des Trainingsprozesses

Aufgrund der Natur des alternierenden Trainingsprozesses können verschiedene Probleme beim Training eines GANs auftreten. Eine häufige Herausforderung ist, dass das Feedback des Diskriminators schlechter wird, je besser der Generator im Verlauf des Trainings wird.

Man kann sich den Prozess folgendermaßen vorstellen: Wenn der Generator Beispiele generieren kann, die von echten Beispielen nicht mehr unterscheidbar sind, bleibt dem Diskriminator nichts anderes übrig, als zu raten aus welcher Klasse das jeweilige Beispiel stammt. Wenn man also das Training nicht rechtzeitig beendet, ist es möglich, dass die Qualität des Generators und des Diskriminators aufgrund der zufälligen Rückgabewerten des Diskriminators wieder abnimmt.

Ein weiteres häufig auftretendes Problem ist der sogenannte „Mode-Collapse“. Dieser tritt auf, wenn das Netzwerk anstatt die Eigenschaften der zugrundeliegenden Daten zu lernen, sich einzelne Beispiele dieser Daten merkt oder nur Beispiele mit geringer Variabilität generiert. Einige Ansätze, um diesem Problem entgegenzuwirken, sind das Verarbeiten mehrerer Beispiele gleichzeitig (in Batches) oder das simultane Zeigen vergangener Beispiele, damit der Diskriminator mangelnde Unterscheidbarkeit der Beispiele quantifizieren kann.

Use Cases für GANs

Die Erzeugung von synthetischen Daten mittels GANs hat eine Vielzahl von spannenden Anwendungsmöglichkeiten in der Praxis. Die ersten eindrucksvollen Resultate von GANs wurden auf Basis von Bilddaten erreicht, dicht gefolgt von der Generierung von Audiodaten. In den letzten Jahren wurden GANs jedoch auch erfolgreich auf andere Datentypen, wie tabellarische Daten, angewendet. Im Folgenden werden ausgewählte Anwendungen vorgestellt.

1.    Die Bild-Synthese

Eine der bekanntesten Anwendungen von GANs ist die Bild-Synthese. Hierbei wird ein GAN auf Basis eines großen Bilddatensatzes trainiert. Dabei erlernt das Generatornetzwerk die wichtigen gemeinsamen Merkmale und Strukturen der Bilder. Ein Beispiel ist die Generierung von Gesichtern. Die Portraits in der Abbildung 2 wurden mittels speziellen GAN-Netzwerken erzeugt, die für Bilddaten optimierte Faltungsnetzwerke, „Convolutional Neural Networks“ (CNN), verwenden. Natürlich können auch andere Arten von Bildern mittels GANs erzeugt werden, wie zum Beispiel handgeschriebene Zeichen, Fotos von Gegenständen oder Häusern. Die grundlegende Voraussetzung für gute Resultate ist ein ausreichend großer Datensatz.

Ein praktischer Anwendungsfall dieser Technologie findet sich im Online Retail. Hier werden mittels GAN Fotos von Models in spezifischen Kleidungsstücken oder bestimmten Posen erzeugt.

Abbildung 2 – GAN-generierte Gesichter

2.    Die Musik-Synthese

Im Gegensatz zur Erzeugung von Bildern beinhaltet die Musik-Synthese eine temporale Komponente. Da Audio-Wellenformen sehr periodisch sind und das menschliche Ohr sehr empfindlich gegen Abweichungen dieser Wellenformen ist, hat die Erhaltung der Signalperiodizität einen hohen Stellenwert für GANs, die Musik generieren. Dies hat zu GAN-Modellen geführt, die anstatt der Wellenform die Magnituden und Frequenzen des Tons generieren. Beispiele für GANs zur Musiksynthese sind das GANSynth [3], welches für das Generieren von realistischen Wellenformen optimiert ist, und das MuseGAN [4] –  seinerseits spezialisiert auf das Generieren von ganzen Melodiesequenzen. Im Internet finden sich diverse Quellen, die von GANs erzeugte Musikstücke präsentieren. [5, 6]  Es ist zu erwarten, dass GANs in den nächsten Jahren für eine hohe Disruption in der Musikbranche führen werden.

Die Abbildung 3 zeigt Ausschnitte der vom MuseGAN generierten Musik für verschiedene Trainingsschritte. Zu Beginn des Trainings (step 0) ist der Output rein zufällig und im letzten Schritt (step 7900) erkennt man, dass z.B. der Bass eine typische Basslinie spielt und die Streicher vorwiegend gehaltene Akkorde spielen.

Abbildung 3 – MuseGAN Trainingsverlauf

3.    Die Super-Resolution

Der Prozess der Wiedergewinnung eines hochauflösenden Bildes aus einem tiefer aufgelösten Bild heißt Super-Resolution. Die Schwierigkeit hierbei besteht darin, dass es sehr viele mögliche Lösungen für die Wiederherstellung des Bildes gibt.

Ein klassischer Ansatz für Super-Resolution ist Interpolation. Dabei werden die für die höhere Auflösung fehlenden Pixel mittels einer vorgegebenen Funktion von den benachbarten Pixeln abgeleitet. Dabei gehen aber typischerweise viele Details verloren und die Übergänge zwischen verschiedenen Bildteilen verschwimmen. Da GANs die Hauptmerkmale von Bilddaten gut erlernen können, sind sie auch für das Problem der Super-Resolution ein geeigneter Ansatz. Dabei wird für das Generator-Netzwerk nicht mehr ein zufälliger Input verwendet, sondern das niedrigaufgelöste Bild. Der Diskriminator lernt dann generierte Bilder höherer Auflösung von den echten Bildern höherer Auflösung zu unterscheiden. Das trainierte GAN-Netzwerk kann dann auf neuen tief aufgelösten Bildern verwendet werden.

Ein State of the Art GAN für Super-Resolution ist das SRGAN [7]. In der Abbildung 4 sind die Resultate für verschiedene Super-Resolution-Ansätze für ein Beispielbild abgebildet. Das SRGAN (zweites Bild von rechts) liefert das schärfste Resultat. Besonders Details wie Wassertropfen oder die Oberflächenstruktur des Kopfschmucks werden vom SRGAN erfolgreich rekonstruiert.

Abbildung 4 – Beispiel Super-Resolution für 4x Upscaling (von links nach rechts: bikubische Interpolation, SRResNet, SRGAN, Originalbild)

4.    Deepfakes: Audio und Video

Mittels Deepfake-Algorithmen können Gesichter in Bildern und Videoaufnahmen durch andere Gesichter ersetzt werden. Die Resultate sind mittlerweile so gut, dass die Deepfakes nur schwer von echten Aufnahmen unterschieden werden können. Um einen Deepfake herzustellen wird wird der Algorithmus zunächst auf möglichst vielen Daten der Zielperson trainiert, damit anschließend die Mimik einer dritten Person auf die zu imitierende Person transferiert werden kann. In der Abbildung 5 sehen Sie den Vorher-nachher-Vergleich eines Deepfakes. Dabei wurde das Gesicht des Schauspielers Matthew McConaughey (Bildausschnit aus dem Film „Interstellar“) (links) durch das Gesicht des Tech-Unternehmers Elon Musk ersetzt (rechts).

Diese Technik lässt sich beispielsweise auch auf Audiodaten anwenden. Dabei wird das Modell auf eine Zielstimme trainiert, um dann die Stimme in der Originalaufnahme durch die Zielstimme zu ersetzen.

Abbildung 5 – Deepfake Sample Images

5.    Generierung zusätzlicher Trainingsdaten

Um tiefe neuronale Netze mit vielen Parametern trainieren zu können, werden i.d.R. sehr große Datenmengen benötigt. Oft ist es nur schwer oder gar nicht möglich, an eine ausreichend große Datenmengen zu gelangen oder diese zu erheben. Ein Ansatz, um auch mit weniger Daten ein gutes Resultat zu erzielen, ist die Data Augmentation. Dabei werden die vorhandenen Datenpunkte leicht abgeändert, um somit neue Trainingsbeispiele zu kreieren. Häufig wird dies bspw. im Bereich von Computer Vision angewendet, indem Bilder rotiert oder mittels Zoom neue Bildausschnitte abgeleitet werden. Ein neuerer Ansatz für Data Augmentation ist das Verwenden von GANs. Die Idee ist, dass GANs die Verteilung der Trainingsdaten erlernen und somit theoretisch unendlich viele neue Beispiele generiert werden können. Für die Erkennung von Krankheiten auf Tomographie-Bildern wurde dieser Ansatz beispielsweise erfolgreich umgesetzt [9]. GANs können also dazu verwendet werden, neue Trainingsdaten für Probleme zu generieren, bei denen nicht genügend große Datenmengen für das Training vorhanden sind.

In vielen Fällen ist das Erheben größerer Datenmengen auch mit erheblichen Kosten verbunden, insbesondere dann, wenn die Daten für das Training manuell gekennzeichnet werden müssen. In diesen Situationen können GANs verwendet werden, um die anfallenden Kosten einer zusätzlichen Datenakquise zu reduzieren.

Ein weiteres technisches Problem beim Training von Machine Learning Modellen, bei dem GANs Abhilfe schaffen können, sind Imbalanced Datasets. Dabei handelt es sich um Datensätze mit verschiedenen Klassen, die unterschiedlich häufig repräsentiert sind. Um mehr Trainingsdaten der unterrepräsentierten Klasse zu erhalten, kann man für dieser Klasse ein GAN trainieren, um weitere synthetisch generierte Datenpunkte zu erzeugen, die dann im Training verwendet werden. Beispielsweise gibt es deutlich weniger Mikroskop-Aufnahmen von krebsbefallenen Zellen als von gesunden. GANs ermöglichen in diesem Fall, bessere Modelle zur Erkennung von Krebszellen zu trainieren und können so Ärzte bei der Krebsdiagnose unterstützen.

6.    Anonymisierung von Daten

Ein weiterer interessanter Anwendungsbereich von GANs ist die Anonymisierung von Datensätzen. Klassische Ansätze für die Anonymisierung sind das Entfernen von Identifikator-Spalten oder das zufällige Ändern ihrer Werte. Resultierende Probleme sind bspw., dass mit entsprechendem Vorwissen trotzdem Rückschlüsse auf die Personen hinter den personenbezogenen Daten gezogen werden können. Auch die Änderung der statistischen Eigenschaften des Datensatzes durch das Fehlen oder Abändern bestimmter Informationen kann den Nutzen der Daten schmälern. GANs können auch dazu verwendet werden, anonymisierte Datensätze zu generieren. Sie werden dabei so trainiert, dass die persönlichen Daten im generierten Datensatz nicht mehr identifiziert, aber Modelle trotzdem ähnlich gut damit trainiert werden können [10]. Erklären kann man diese gleichbleibende Qualität der Modelle damit, dass die zugrundeliegenden statistischen Eigenschaften des ursprünglichen Datensatzes von dem GAN gelernt, und somit auch erhalten werden.

Anonymisierung mithilfe von GANs kann neben tabellarischen Daten auch spezifisch für Bilder angewendet werden. Dabei werden Gesichter oder andere persönliche Merkmale/Elemente auf dem Bild durch generierte Varianten ersetzt. Dies erlaubt, Computer Vision Modelle mit realistisch aussehenden Daten zu trainieren, ohne mit Datenschutzproblemen konfrontiert zu werden. Oft wird die Qualität der Modelle deutlich schlechter, wenn wichtige Merkmale eines Bilds, wie beispielsweise ein Gesicht, verpixelt oder weichgezeichnet in den Trainingsprozess einbezogen werden.

Fazit und Ausblick

Durch die Anwendung von GANs kann die Verteilung von Datensätzen beliebiger Art gelernt werden. Wie oben erläutert werden GANs bereits auf verschiedenste Problemstellungen erfolgreich angewendet. Da GANs erst 2014 entdeckt wurden und großes Potential bewiesen haben, wird aktuell sehr intensiv daran geforscht. Die Lösung der oben erwähnten Probleme beim Training, wie z.B. „Mode-Collapse“, sind in der Forschung weit verbreitet. Es wird unter anderem an alternativen Loss-Funktionen und allgemein stabilisierenden Trainingsverfahren geforscht. Ein weiteres aktives Forschungsgebiet ist die Konvergenz von GAN-Netzwerken. Aufgrund der Fortschritte des Generators und Diskriminators während des Trainingsprozesses ist es sehr wichtig, zum richtigen Zeitpunkt das Training zu beenden, um den Generator nicht weiter auf schlechten Diskriminator-Ergebnissen zu trainieren. Um das Training zu stabilisieren, wird zudem an Ansätzen geforscht, die Diskriminator-Inputs mit Rauschen versehen, um die Anpassung des Diskriminators während des Trainings zu limitieren.

Ein modifizierter Ansatz zur Generierung neuer Trainingsdaten sind Generative Teaching Networks. Dabei wird der Fokus des Trainings nicht primär auf das Erlernen der Datenverteilung gelegt, sondern man versucht direkt zu lernen, welche Daten das Training am schnellsten vorantreiben, ohne zwingend Ähnlichkeit mit den Originaldaten vorzuschreiben. [11] Am Beispiel von handgeschriebenen Zahlen konnte gezeigt werden, dass neuronale Netzwerke mit diesen künstlichen Inputdaten schneller lernen können als mit den Originaldaten. Dieser Ansatz lässt sich auch auf andere Datenarten als Bilddaten anwenden. Im Bereich der Anonymisierung ist man bisher imstande, Teile eines Datensatzes mit Persönlichkeitsschutzgarantien zuverlässig zu generieren. Diese Anonymisierungsnetzwerke können weiterentwickelt werden, um mehr Typen von Datensätzen abzudecken.

Im Bereich von GANs werden sehr schnell theoretische Fortschritte gemacht, die bald den Weg in die Praxis finden werden. Da das Verarbeiten von personenbezogenen Daten heute im Bereich des maschinellen Lernens eine große Hürde darstellt, sind GANs für die Anonymisierung von Datensätzen ein vielversprechender Ansatz und haben das Potenzial, für Unternehmen die Tür zur Verarbeitung personenbezogener Daten zu öffnen.

Quellen

  1. http://www.image-net.org/
  2. https://arxiv.org/abs/1710.10196
  3. https://openreview.net/pdf?id=H1xQVn09FX
  4. https://arxiv.org/abs/1709.06298
  5. https://salu133445.github.io/musegan/results
  6. https://storage.googleapis.com/magentadata/papers/gansynth/index.html
  7. https://arxiv.org/abs/1609.04802
  8. https://github.com/iperov/DeepFaceLab
  9. https://arxiv.org/abs/1803.01229
  10. https://arxiv.org/abs/1806.03384
  11. https://eng.uber.com/generative-teaching-networks/