
Management Summary

Deploying and monitoring machine learning projects is a complex undertaking. In addition to the consistent documentation of model parameters and the associated evaluation metrics, the main challenge is to transfer the desired model into a productive environment. If several people are involved in the development, additional synchronization problems arise concerning the models’ development environments and version statuses. For this reason, tools are required for the efficient management of everything from model results to extensive training and inference pipelines. In this article, we present the typical challenges along the machine learning workflow and describe MLflow as a possible solution platform. In addition, we present three different scenarios that can be used to professionalize machine learning workflows:

  1. Entry-Level Variant: Model parameters and performance metrics are logged via an R/Python API and clearly presented in a GUI. In addition, the trained models are stored as artifacts and can be made available via APIs.
  2. Advanced Model Management: In addition to tracking parameters and metrics, certain models are logged and versioned. This enables consistent monitoring and simplifies the deployment of selected model versions.
  3. Collaborative Workflow Management: Encapsulating Machine Learning projects as packages or Git repositories and the accompanying local reproducibility of development environments enable smooth development of Machine Learning projects with multiple stakeholders.

Depending on the maturity of your machine learning project, these three scenarios can serve as inspiration for a potential machine learning workflow. We have elaborated each scenario in detail for better understanding and provide recommendations regarding the APIs and deployment environments to use.

Challenges Along the Machine Learning Workflow

Training machine learning models is becoming easier and easier. Meanwhile, a variety of open-source tools enable efficient data preparation as well as increasingly simple model training and deployment.

The added value for companies comes primarily from the systematic interaction of model training, in the form of model identification, hyperparameter tuning and fitting on the training data, and deployment, i.e., making the model available for inference tasks. This interaction is often not established as a continuous process, especially in the early phases of machine learning initiative development. However, a model can only generate added value in the long term if a stable production process is implemented from model training, through its validation, to testing and deployment. If this process is not implemented correctly, complex dependencies and costly long-term maintenance work can arise during the operational start-up of the model [2]. The following risks are particularly noteworthy in this regard.

1. Ensuring Synchronicity

Often, in an exploratory context, data preparation and modeling workflows are developed locally. Different configurations of development environments, or even the use of different technologies, make it difficult to reproduce results, especially between developers or teams. In addition, there are potential dangers concerning the compatibility of the workflow if several scripts must be executed in a logical sequence. Without appropriate version control logic, synchronizing the work afterward can only be achieved with great effort.

2. Documentation Effort

To evaluate the performance of the model, model metrics are often calculated following training. These depend on various factors, such as the parameterization of the model or the influencing factors used. This meta-information about the model is often not stored centrally. However, for systematic further development and improvement of a model, it is mandatory to have an overview of the parameterization and performance of all past training runs.

3. Heterogeneity of Model Formats

In addition to managing model parameters and results, there is the challenge of subsequently transferring the model to the production environment. If different models from multiple packages are used for training, deployment can quickly become cumbersome and error-prone due to different packages and versions.

4. Recovery of Prior Results

In a typical machine learning project, the situation often arises that a model is developed over a long period of time. For example, new features may be used, or entirely new architectures may be evaluated. These experiments do not necessarily lead to better results. If experiments are not versioned cleanly, there is a risk that old results can no longer be reproduced.

Various tools have been developed in recent years to solve these and other challenges in the handling and management of machine learning workflows, such as TensorFlow TFX, cortex, Marvin, or MLflow. The latter, in particular, is currently one of the most widely used solutions.

MLflow is an open-source project with the goal of combining the best of existing ML platforms to make the integration with existing ML libraries, algorithms, and deployment tools as straightforward as possible [3]. In the following, we will introduce the main MLflow modules and discuss how machine learning workflows can be mapped via MLflow.

MLflow Services

MLflow consists of four components: MLflow Tracking, MLflow Models, MLflow Projects, and MLflow Registry. Depending on the requirements of the experimental and deployment scenario, all services can be used together, or individual components can be isolated.

With MLflow Tracking, all hyperparameters, metrics (model performance), and artifacts, such as charts, can be logged. MLflow Tracking provides the ability to collect presets, parameters, and results for collective monitoring for each training or scoring run of a model. The logged results can be visualized in a GUI or alternatively accessed via a REST API.
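As a rough illustration, a single training run could be logged from Python as follows; the tracking URI, experiment name, parameter names, and all values are assumptions for demonstration purposes:
import mlflow

mlflow.set_tracking_uri("http://localhost:5000")  # assumed address of a locally hosted tracking server
mlflow.set_experiment("fraud_classification")     # assumed experiment name

with mlflow.start_run():
    # presets and hyperparameters of this training run
    mlflow.log_param("model_type", "svm")
    mlflow.log_param("C", 1.0)
    # model performance metrics
    mlflow.log_metric("accuracy", 0.93)
    # arbitrary artifacts, e.g. a chart that was saved to disk beforehand (assumed to exist)
    mlflow.log_artifact("confusion_matrix.png")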

The MLflow Models module acts as an interface between technologies and enables simplified deployment. Depending on its type, a model is stored as a binary, e.g., a pure Python function, or as a Keras or H2O model. These formats are referred to as model flavors. Furthermore, MLflow Models provides support for model deployment on various machine learning cloud services, e.g., for AzureML and Amazon Sagemaker.
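To illustrate the flavor concept, the following sketch (not taken from a specific project) logs a scikit-learn model and loads it back as a generic Python function:
import mlflow
import mlflow.pyfunc
import mlflow.sklearn
from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
model = SVC().fit(X, y)

with mlflow.start_run() as run:
    # stored with the sklearn flavor: model.pkl, the MLmodel metadata file, and a conda.yaml
    mlflow.sklearn.log_model(model, artifact_path="model")

# the same artifact can later be loaded technology-agnostically via the pyfunc flavor
loaded = mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")
print(loaded.predict(X[:5]))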

MLflow Projects are used to encapsulate individual ML projects in a package or Git repository. The basic configuration of the respective environment is defined via a YAML file. This can be used, for example, to control exactly how the conda environment that is created when MLflow is executed should be parameterized. MLflow Projects allows experiments that have been developed locally to be executed on other computers in the same environment. This is an advantage, for example, when developing in smaller teams.

MLflow Registry provides centralized model management. Selected MLflow Models can be registered and versioned there. A staging workflow enables a controlled transfer of models into the productive environment. The entire process can be controlled via a GUI or a REST API.

Examples of Machine Learning Pipelines Using MLflow

In the following, three different ML workflow scenarios are presented using the above MLflow modules. These increase in complexity from scenario to scenario. In all scenarios, a dataset is loaded into a development environment using a Python script, processed, and a machine learning model is trained. The last step in all scenarios is a deployment of the ML model in an exemplary production environment.

Scenario 1 – Entry-Level Variant

Scenario 1 – Simple Metrics Tracking

Scenario 1 uses the MLflow Tracking and MLflow Models modules. Using the Python API, the model parameters and metrics of the individual runs can be stored on the MLflow Tracking Server Backend Store, and the corresponding MLflow Model File can be stored as an artifact on the MLflow Tracking Server Artifact Store. Each run is assigned to an experiment. For example, an experiment could be called ‘fraud_classification’, and a run would be a specific ML model with a certain hyperparameter configuration and the corresponding metrics. Each run is stored with a unique RunID.

[Image: MLflow Tool 01]

In the screenshot above, the MLflow Tracking UI is shown as an example after executing a model training. The server is hosted locally in this example. Of course, it is also possible to host the server remotely, for example in a Docker container within a virtual machine. In addition to the parameters and model metrics, the time of the model training, as well as the user and the name of the underlying script, are also logged. Clicking on a specific run also displays additional information, such as the RunID and the model training duration.

[Image: MLflow Tool 02]

If you have logged other artifacts in addition to the metrics, such as the model, the MLflow Model Artifact is also displayed in the Run view. In the example, a model from the sklearn.svm package was used. The MLmodel file contains metadata with information about how the model should be loaded. In addition to this, a conda.yaml is created that contains all the package dependencies of the environment at training time. The model itself is located as a serialized version under model.pkl and contains the model parameters optimized on the training data.

[Image: MLflow Tool 03]

The deployment of the trained model can now be done in several ways. For example, suppose one wants to deploy the model with the best accuracy metric. In that case, the MLflow tracking server can be accessed via the Python API mlflow.list_run_infos to identify the RunID of the desired model. Now, the path to the desired artifact can be assembled, and the model loaded via, for example, the Python package pickle. This workflow can now be triggered via a Dockerfile, allowing flexible deployment to the infrastructure of your choice. MLflow offers additional separate APIs for deployment on Microsoft Azure and AWS. For example, if the model is to be deployed on AzureML, an Azure ML container image can be created using the Python API mlflow.azureml.build_image, which can be deployed as a web service to Azure Container Instances or Azure Kubernetes Service. In addition to the MLflow Tracking Server, it is also possible to use other storage systems for the artifact, such as Amazon S3, Azure Blob Storage, Google Cloud Storage, SFTP Server, NFS, and HDFS.
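As a rough sketch of this retrieval step (using mlflow.search_runs, which returns all runs of an experiment as a DataFrame, as an alternative to mlflow.list_run_infos; the experiment name and metric are the ones assumed above):
import mlflow
import mlflow.sklearn

# collect all runs of the experiment, best accuracy first
experiment = mlflow.get_experiment_by_name("fraud_classification")
runs = mlflow.search_runs(
    experiment_ids=[experiment.experiment_id],
    order_by=["metrics.accuracy DESC"],
)
best_run_id = runs.loc[0, "run_id"]

# load the MLflow Model logged under the artifact path "model" for inference
best_model = mlflow.sklearn.load_model(f"runs:/{best_run_id}/model")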

Scenario 2 – Advanced Model Management

Scenario 2 – Advanced Model Management

Scenario 2 includes, in addition to the modules used in scenario 1, the MLflow Model Registry as a model management component. Here, models logged in specific runs can be registered and managed further. These steps can be controlled via the API or the GUI. A basic requirement for using the Model Registry is that the MLflow Tracking Server backend store is deployed as a database backend store. To register a model via the GUI, select a specific run and scroll to the artifact overview.

[Image: MLflow Tool 04]

Clicking on Register Model opens a new window in which a model can be registered. If you want to register a new version of an already existing model, select the desired model from the dropdown field. Otherwise, a new model can be created at any time. After clicking the Register button, the previously registered model appears in the Models tab with corresponding versioning.

[Image: MLflow Tool 05]

Each model includes an overview page that shows all past versions. This is useful, for example, to track which models were in production when.

[Image: MLflow Tool 06]

If you now select a model version, you will get to an overview where, for example, a model description can be added. The Source Run link also takes you to the run from which the model was registered. Here you will also find the associated artifact, which can be used later for deployment.

[Image: MLflow Tool 07]

In addition, individual model versions can be categorized into defined phases in the Stage area. This feature can be used, for example, to determine which model is currently being used in production or is to be transferred there. For deployment, in contrast to scenario 1, versioning and staging status can be used to identify and deploy the appropriate model. For this, the Python API MlflowClient().search_model_versions can be used, for example, to filter the desired model and its associated RunID. Similar to scenario 1, deployment can then be completed to, for example, AWS Sagemaker or AzureML via the respective Python APIs.
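A minimal sketch of this lookup, assuming a registered model named "fraud_classification" and a hypothetical version number:
import mlflow.sklearn
from mlflow.tracking import MlflowClient

client = MlflowClient()

# optionally promote a specific (assumed) version to the Production stage first
client.transition_model_version_stage(
    name="fraud_classification", version="3", stage="Production"
)

# filter the registered versions for the one currently in Production and get its RunID
versions = client.search_model_versions("name='fraud_classification'")
production_version = [v for v in versions if v.current_stage == "Production"][0]
run_id = production_version.run_id

# registered models can also be loaded directly via the models URI and their stage
model = mlflow.sklearn.load_model("models:/fraud_classification/Production")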

Scenario 3 – Collaborative Workflow Management

Scenario 3 – Full Workflow Management

In addition to the modules used in scenario 2, scenario 3 also includes the MLflow Projects module. As already explained, MLflow Projects are particularly well suited for collaborative work. Any Git repository or local environment can act as a project and be controlled by an MLproject file. Package dependencies can be recorded in a conda.yaml that is referenced by the MLproject file when the project is started. The corresponding conda environment with all dependencies is then created before the model is trained and logged. This avoids the need for manual alignment of the development environments of all developers involved and also guarantees standardized and comparable results across all runs. The latter in particular is necessary in the deployment context, since it cannot be guaranteed that different package versions produce the same model artifacts. Instead of a conda environment, a Docker environment can also be defined using a Dockerfile. This offers the advantage that package dependencies that are independent of Python can also be defined. Likewise, MLflow Projects allows the use of different commit hashes or branch names to run other project states, provided a Git repository is used.

An interesting use case is the modularized development of machine learning training pipelines [4]. For example, data preparation can be decoupled from model training and developed in parallel, while another team uses a different branch name to train the model. In this case, only a different branch name must be passed as a parameter when starting the project. The final data preparation can then be pushed to the same branch name used for model training and would thus already be fully integrated into the training pipeline. The deployment can also be controlled as a sub-module within the project pipeline through a Python script via the MLproject file and can be carried out analogously to scenario 1 or 2 on a platform of your choice.
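The following sketch shows how such a project run could be started from Python; the repository URL, entry point, branch name, and parameter are placeholders:
import mlflow

mlflow.run(
    uri="https://github.com/example-org/fraud-pipeline",  # placeholder Git repository
    entry_point="main",                                   # entry point defined in the MLproject file
    version="feature/data-prep",                          # placeholder branch name or commit hash
    parameters={"max_depth": 5},                          # placeholder parameter declared in the MLproject file
)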

Conclusion and Outlook

MLflow offers a flexible way to make the machine learning workflow robust against the typical challenges in the daily life of a data scientist, such as synchronization problems due to different development environments or missing model management. Depending on the maturity level of the existing machine learning workflow, various services from the MLflow portfolio can be used to achieve a higher level of professionalization.

In the article, three machine learning workflows of increasing complexity were presented as examples. From the simple logging of results in an interactive GUI to complex, modular modeling pipelines, MLflow services can provide flexible support. Naturally, there are also synergies with tools outside the MLflow ecosystem, such as Docker/Kubernetes for model scaling or Jenkins for CI/CD pipeline control. If there is further interest in MLOps challenges and best practices, I refer you to the webinar on MLOps by our CEO Sebastian Heinz, which we provide free of charge.

Management Summary

OCR (Optical Character Recognition) is a major challenge for many companies. The OCR market is populated by various open-source and commercial providers. A well-known open-source tool for OCR is Tesseract, which is now provided by Google. Tesseract is currently available in version 4, which performs OCR extraction using recurrent neural networks. However, Tesseract's OCR performance remains volatile and depends on various factors. A particular challenge is applying Tesseract to documents composed of different structures, e.g., texts, tables, and images. Invoices are one such document type, and they still pose particular challenges for OCR tools of all providers. This article demonstrates how fine-tuning the Tesseract OCR (Optical Character Recognition) engine on a small sample of data can already yield a considerable improvement in OCR performance on invoice documents. The process presented is not limited to invoices but applies to arbitrary document types. A use case is defined that aims at the correct extraction of the entire text (words and numbers) from a fictitious but realistic German invoice document. It is assumed that the extracted information is intended for downstream accounting purposes. Therefore, a correct extraction of the numbers and the euro sign is considered critical. The OCR performance of two Tesseract models for the German language is compared: the standard model (not fine-tuned) and a fine-tuned variant. The standard model is obtained from the Tesseract OCR GitHub repository. The fine-tuned model is developed using the steps described in this article. A second German invoice similar to the first is used for fine-tuning. Both the standard and the fine-tuned model are evaluated on the same out-of-sample invoice to ensure a fair comparison. The OCR performance of the standard Tesseract model on numbers is comparatively poor. This is especially true for digits that resemble the digits 1 and 7. The euro symbol is recognized incorrectly in 50% of the cases, making the result unsuitable for any downstream accounting application. The fine-tuned model shows a similar OCR performance for German words. The OCR performance on numbers, however, improves considerably: all numbers and every euro symbol are extracted correctly. It turns out that fine-tuning with minimal effort and a small amount of training data can achieve a large improvement in recognition performance. This makes Tesseract OCR, with its open-source licensing, an attractive solution compared to proprietary OCR software. Finally, recommendations for fine-tuning Tesseract LSTM models are given for the case that more training data is available.

Download of the Tesseract Docker Container

The entire fine-tuning process of Tesseract's LSTM model is discussed in detail below. Since installing and using Tesseract can get complicated, we have prepared a Docker container that already includes all necessary installations.

Introduction

Tesseract 4 with its LSTM engine already works quite well out-of-the-box for simple texts. However, there are scenarios for which the standard model performs poorly. Examples include exotic fonts, images with backgrounds, or text in tables. Fortunately, Tesseract offers a way to fine-tune the LSTM engine to improve OCR performance for more specific use cases.

Why OCR on Invoices Is Challenging

Even though OCR is considered a solved problem in some areas, the error-free extraction of a large text corpus remains a challenge. This is especially true for OCR on documents with high structural variance, such as invoices. They often consist of very different elements that pose challenges for Tesseract's OCR engine:
  1. Colored backgrounds and table structures are a challenge for page segmentation.
  2. Invoices typically contain rare characters such as the EUR or USD sign.
  3. Numbers cannot be validated against a language dictionary.
In addition, the margin for error is small: an exact extraction of the numeric data is often of utmost importance for downstream process steps. Problem (1) can usually be solved by choosing one of the 14 page segmentation modes provided by Tesseract. The latter two problems can often be solved by fine-tuning the LSTM engine on examples of similar documents.

Use Case Objective and Data

Two similar sample invoices are examined in this article. The invoice shown in Figure 1 is used to evaluate the OCR performance of both the standard and the fine-tuned Tesseract model. Particular attention is paid to the correct extraction of numbers. The second invoice, shown in Figure 2, is used for fine-tuning the LSTM model. Most invoice documents are written in a very legible font such as "Arial". To illustrate the benefits of fine-tuning, the initial OCR problem is made harder by considering invoices written in the "Impact" font. "Impact" is a font that differs considerably from normal sans-serif fonts and leads to a higher recognition error rate for Tesseract. It is shown below that, after fine-tuning on a very small amount of data, Tesseract delivers very satisfactory results despite this difficult font.
Figure 1: Invoice 1, used to evaluate the OCR performance of both models
Figure 2: Invoice 2, used to fine-tune the LSTM engine

Using the Tesseract 4.0 Docker Container

The setup for fine-tuning the Tesseract LSTM engine currently only works on Linux and can be a bit tricky. Therefore, a Docker container with a pre-installed Tesseract 4.0, the compiled training tools, and scripts is provided along with this article. Load the Docker image from the provided archive file, or pull the container image via the provided link:
docker load -i docker/tesseract_image.tar
Once the image is loaded, start the container in detached mode:
docker run -d --rm --name tesseract_container tesseract:latest
Access the shell of the running container to replicate the following commands in this article:
docker exec -it tesseract_container /bin/bash

General Improvements of the OCR Performance

There are three ways in which Tesseract's OCR performance can be improved even before fine-tuning the LSTM engine.

1. Image Preprocessing

Scanned documents can have a skewed orientation if they were not placed correctly on the scanner. Rotated images should be deskewed to optimize Tesseract's line segmentation performance. Additionally, scanning can introduce image noise, which should be removed by a denoising algorithm. Note that, by default, Tesseract performs thresholding using Otsu's algorithm to binarize grayscale images into black and white pixels. A detailed treatment of image preprocessing is beyond the scope of this article and is not necessary to achieve satisfactory results for the given use case. The Tesseract documentation offers a practical overview.
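A minimal preprocessing sketch with OpenCV, assuming hypothetical file names, could look as follows (denoising plus an explicit Otsu binarization; deskewing would be an additional step):
import cv2

img = cv2.imread("scanned_invoice.png", cv2.IMREAD_GRAYSCALE)

# remove scanning noise
denoised = cv2.fastNlMeansDenoising(img, h=10)

# binarize into black and white pixels with Otsu's method
# (Tesseract applies Otsu thresholding internally as well)
_, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

cv2.imwrite("scanned_invoice_clean.png", binary)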

2. Page Segmentation

During page segmentation, Tesseract tries to identify rectangular text regions. Only these regions are selected for OCR in the next step. It is therefore important to capture all regions containing text so that no information is lost. Tesseract allows choosing from 14 different page segmentation methods, which can be displayed with the following command:
tesseract --help-psm
The default segmentation method expects an image similar to a book page. However, due to the additional tabular structures in invoice documents, this mode cannot identify all text regions correctly. A better segmentation method is given by option 4: "Assume a single column of text of variable sizes". To illustrate the importance of a suitable page segmentation method, consider the result of using the default method "Fully automatic page segmentation, but no OSD" in Figure 3:
Figure 3: The default segmentation method fails to detect all text regions
Note that the texts "Rechnungsinformationen:", "Pos." and "Produkt" were not segmented. In Figure 4, a more suitable method leads to a perfect segmentation of the page.

3. Using Dictionaries, Word Lists, and Patterns for the Text

The LSTM models used by Tesseract were trained on large amounts of text in a specific language. This command displays the languages currently available to Tesseract:
tesseract --list-langs 
Further language models are available by downloading the corresponding language.tessdata and placing it in the tessdata folder of the local Tesseract installation. The Tesseract repository on GitHub provides three variants of language models: normal, fast, and best. Only the fast and the best variants can be used for fine-tuning. As the names suggest, these are the fastest and the most accurate model variants, respectively. Further models have also been trained for special use cases, such as recognizing digits and punctuation only, and are listed in the references. Since the language of the invoices in this use case is German, the Docker image accompanying this article ships with the deu.tessdata model. For a given language, Tesseract's word list can be extended further or restricted to certain words or even characters. This topic is beyond the scope of this article, as it is not necessary to achieve satisfactory results for the use case at hand.

Setting Up the Fine-Tuning Process

Three file types must be created for fine-tuning:

1. TIFF Files

Tagged Image File Format, or TIFF, is an uncompressed image file format (in contrast to JPG or PNG, which are compressed file formats). TIFF files can be obtained from PNG or JPG formats with a conversion tool. Although Tesseract can work with PNG and JPG images, the TIFF format is recommended.

2. Box Files

To train the LSTM model, Tesseract uses so-called box files with the extension ".box". A box file contains the recognized text together with the coordinates of the bounding box in which the text is located. Box files contain six columns corresponding to symbol, left, bottom, right, top, and page:
P 157 2566 1465 2609 0
r 157 2566 1465 2609 0
o 157 2566 1465 2609 0
d 157 2566 1465 2609 0
u 157 2566 1465 2609 0
k 157 2566 1465 2609 0
t 157 2566 1465 2609 0
  157 2566 1465 2609 0
P 157 2566 1465 2609 0
r 157 2566 1465 2609 0
e 157 2566 1465 2609 0
i 157 2566 1465 2609 0
s 157 2566 1465 2609 0
  157 2566 1465 2609 0
( 157 2566 1465 2609 0
N 157 2566 1465 2609 0
e 157 2566 1465 2609 0
t 157 2566 1465 2609 0
t 157 2566 1465 2609 0
o 157 2566 1465 2609 0
) 157 2566 1465 2609 0
  157 2566 1465 2609 0
Each character is on a separate line in the box file. The LSTM model accepts either the coordinates of individual characters or of an entire text line. In the example box file above, the text "Produkt Preis (Netto)" is located on the same line in the document. All characters have the same coordinates, namely the coordinates of the bounding box around this text line. Using line-level coordinates is much easier and is provided by default when the box file is generated with the following command:
cd /home/fine_tune/train
tesseract train_invoice.tiff train_invoice --psm 4 -l best/deu lstmbox
The first argument is the image file to extract from, the second argument is the file name of the box file. The language parameter -l instructs Tesseract to use the German model for OCR. The --psm parameter instructs Tesseract to use the fourth page segmentation method. It is almost inevitable that the generated OCR box files contain errors in the symbol column. Every symbol in the training box file therefore has to be checked by hand. This is a tedious process, since the box file of the demo invoice contains almost a thousand lines (one for each character in the invoice). To simplify the correction, the Docker container provides a Python script that draws the bounding boxes together with the OCR text onto the original invoice image, making it easier to compare the box file with the document. The result is shown in Figure 4. The Docker container already contains the corrected box files, marked by the suffix "_correct".
Figure 4: Extracted text when applying the Tesseract model "deu"
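The following is a minimal sketch of what such a visualization script might look like (it is not the script shipped in the container); note that box-file coordinates use the lower-left corner of the image as their origin, so the y-axis has to be flipped for image libraries whose origin is the upper-left corner:
from PIL import Image, ImageDraw

image = Image.open("train_invoice.tiff").convert("RGB")
draw = ImageDraw.Draw(image)
height = image.height

with open("train_invoice.box", encoding="utf-8") as box_file:
    for line in box_file:
        fields = line.rstrip("\n").split(" ")
        if len(fields) < 6:
            continue
        symbol = " ".join(fields[:-5]) or " "  # the recognized glyph (may be a space)
        try:
            left, bottom, right, top = map(int, fields[-5:-1])
        except ValueError:
            continue
        # flip the y-axis: box files count from the bottom, PIL counts from the top
        draw.rectangle([left, height - top, right, height - bottom], outline="red")
        draw.text((left, height - top - 12), symbol, fill="red")

image.save("train_invoice_boxes.tiff")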

3. lstmf Files

During fine-tuning, Tesseract extracts the text from the TIFF file and checks the prediction against the coordinates and the symbol in the box file. Tesseract does not use the TIFF and box files directly but expects a so-called lstmf file created from the two previous files. Note that to create the lstmf file, the TIFF and box file must have the same name, e.g., train_invoice.tiff and train_invoice.box. The following command creates an lstmf file for the training invoice:
cd /home/fine_tune/train
tesseract train_invoice.tiff train_invoice lstm.train 
All lstmf files relevant for training must be specified by their relative path in a text file named deu.training_files.txt. In this use case, only one lstmf file is used for training, so deu.training_files.txt contains only one line, namely: train/train_invoice_correct.lstmf. It is recommended to also create an lstmf file for the evaluation invoice. This way, the model's performance can be assessed during the training process:
cd /home/fine_tune/eval
tesseract eval_invoice_correct.tiff eval_invoice_correct lstm.train

Evaluation of the Standard LSTM Model

OCR predictions from the standard German model "deu" are used as a benchmark. An accurate overview of the OCR performance of the standard German model can be obtained by generating a box file for the evaluation invoice and visualizing the OCR text with the Python script mentioned above. This script, which produces the file "eval_invoice_ocr deu.tiff", is located in the provided container at "/home/fine_tune/src/draw_box_file_data.py". The script expects as arguments the path to a TIFF file, the corresponding box file, and a name for the output TIFF file. The OCR text extracted by the standard German model is saved as eval/eval_invoice_ocr_deu.tiff and is shown in Figure 1. At first glance, the text extracted by OCR looks good. The model correctly extracts German characters such as ä, ö, ü, and ß. In fact, there are only three cases in which words contain errors:
OCR | Truth
Jessel GmbH 8 Co | Jessel GmbH & Co
11 Glasbehälter | 1l Glasbehälter
Zeki64@hloch.com | Zeki64@bloch.com
The model already performs well on common German words but has difficulties with isolated symbols such as "&" and "l" as well as with words like "bloch" that are not in the model's word list. Prices and numbers are a much bigger challenge for the model. Here, extraction errors occur considerably more often:
OCR | Truth
159,16 | 159,1€
1% | 7%
1305.816 | 1305.81€
227.66 | 227.6€
341.51 | 347.57€
1115.16 | 1115.7€
242.86 | 242.8€
1456.86 | 1456.8€
51.46 | 54.1€
1954.719€ | 1954.79€
The standard German model fails to extract the euro symbol € correctly in 9 out of 18 cases. This corresponds to an error rate of 50%.

Fine-Tuning the Standard LSTM Model

The standard LSTM model is now fine-tuned on the invoice shown in Figure 2. Afterwards, the OCR performance is evaluated on the evaluation invoice shown in Figure 1, which was also used to benchmark the standard German model. To fine-tune the LSTM model, it first has to be extracted from the file deu.traineddata. The following command extracts the LSTM model from the standard German model into the directory lstm_model:
cd /home/fine_tune
combine_tessdata -e tesseract/tessdata/best/deu.traineddata lstm_model/deu.lstm
Next, all files required for fine-tuning are gathered. The files are also included in the Docker container:
  1. The training files train_invoice_correct.lstmf and deu.training_files.txt in the train directory.
  2. The evaluation files eval_invoice_correct.lstmf and deu.training_files.txt in the eval directory.
  3. The extracted LSTM model deu.lstm in the lstm_model directory.
The Docker container includes the script src/fine_tune.sh, which starts the fine-tuning process. Its content is:
/usr/bin/lstmtraining \
 --model_output output/fine_tuned \
 --continue_from lstm_model/deu.lstm \
 --traineddata tesseract/tessdata/best/deu.traineddata \
 --train_listfile train/deu.training_files.txt \
 --eval_listfile eval/deu.training_files.txt \
 --max_iterations 400
This command fine-tunes the extracted model deu.lstm on the file train_invoice.lstmf specified in train/deu.training_files.txt. Fine-tuning the LSTM model requires language-specific information that is contained in the deu.tessdata folder. The file eval_invoice.lstmf, specified in eval/deu.training_files.txt, is used to compute the OCR performance during training. Fine-tuning stops after 400 iterations. The whole training takes less than two minutes. The following command runs the script and logs the output to a file:
cd /home/fine_tune
sh src/fine_tune.sh > output/fine_tune.log 2>&1
The content of the log file after training is shown below:
src/fine_tune.log
Loaded file lstm_model/deu.lstm, unpacking...
Warning: LSTMTrainer deserialized an LSTMRecognizer!
Continuing from lstm_model/deu.lstm
Loaded 20/20 lines (1-20) of document train/train_invoice_correct.lstmf
Loaded 24/24 lines (1-24) of document eval/eval_invoice_correct.lstmf
2 Percent improvement time=69, best error was 100 @ 0
At iteration 69/100/100, Mean rms=1.249%, delta=2.886%, char train=8.17%, word train=22.249%, skip ratio=0%, New best char error = 8.17 Transitioned to stage 1 wrote best model:output/deu_fine_tuned8.17_69.checkpoint wrote checkpoint.
-----
2 Percent improvement time=62, best error was 8.17 @ 69
At iteration 131/200/200, Mean rms=1.008%, delta=2.033%, char train=5.887%, word train=20.832%, skip ratio=0%, New best char error = 5.887 wrote best model:output/deu_fine_tuned5.887_131.checkpoint wrote checkpoint.
-----
2 Percent improvement time=112, best error was 8.17 @ 69
At iteration 181/300/300, Mean rms=0.88%, delta=1.599%, char train=4.647%, word train=17.388%, skip ratio=0%, New best char error = 4.647 wrote best model:output/deu_fine_tuned4.647_181.checkpoint wrote checkpoint.
-----
2 Percent improvement time=159, best error was 8.17 @ 69
At iteration 228/400/400, Mean rms=0.822%, delta=1.416%, char train=4.144%, word train=16.126%, skip ratio=0%, New best char error = 4.144 wrote best model:output/deu_fine_tuned4.144_228.checkpoint wrote checkpoint.
-----
Finished! Error rate = 4.144
During training, Tesseract saves a so-called model checkpoint after each iteration. The performance of the model at this checkpoint is tested on the evaluation data and compared with the current best result. If the result improves, i.e., the OCR error decreases, a labeled copy of the checkpoint is saved. The first number in the checkpoint's file name represents the character error, and the second number the training iteration. The last step is to recompile the fine-tuned LSTM model so that a "traineddata" model is obtained again. Assuming the checkpoint from the 181st iteration is selected, the following command converts the selected checkpoint "deu_fine_tuned4.647_181.checkpoint" into a fully functional Tesseract model "deu_fine_tuned.traineddata":
cd /home/fine_tune
/usr/bin/lstmtraining \
 --stop_training \
 --continue_from output/deu_fine_tuned4.647_181.checkpoint \
 --traineddata tesseract/tessdata/best/deu.traineddata \
 --model_output output/deu_fine_tuned.traineddata
This model must be copied into the tessdata directory of the local Tesseract installation to make it available to Tesseract. This has already been done in the Docker container. Verify that the fine-tuned model is available in Tesseract:
tesseract --list-langs

Evaluation of the Fine-Tuned LSTM Model

The fine-tuned model is evaluated in the same way as the standard model: a box file of the evaluation invoice is created, and the OCR text is displayed on the image of the evaluation invoice using the Python script. The command for generating the box files has to be modified so that the fine-tuned model "deu_fine_tuned" is used instead of the standard model "deu":
cd /home/fine_tune/eval
tesseract eval_invoice.tiff eval_invoice --psm 4 -l deu_fine_tuned lstmbox
The OCR text extracted by the fine-tuned model is shown in Figure 5 below.
Figure 5: OCR results of the fine-tuned LSTM model
As with the standard German model, the performance on German words remains good but not perfect. To improve the performance on rare words, the model's word list could be extended with additional words.
OCR | Truth
Jessel GmbH 8 Co | Jessel GmbH & Co
1! Glasbehälte | 1l Glasbehälter
Zeki64@hloch.com | Zeki64@bloch.com
More importantly, the OCR performance on numbers has improved considerably: the fine-tuned model extracted all numbers and every occurrence of the € sign correctly.
OCR | Truth
159,1€ | 159,1€
7% | 7%
1305.81€ | 1305.81€
227.6€ | 227.6€
347.57€ | 347.57€
1115.7€ | 1115.7€
242.8€ | 242.8€
1456.8€ | 1456.8€
54.1€ | 54.1€
1954.79€ | 1954.79€

Conclusion and Outlook

This article has shown that Tesseract's OCR performance can be improved considerably through fine-tuning. Especially for non-standard use cases, such as text extraction from invoice documents, OCR performance can be improved significantly in this way. Besides the open-source licensing, the possibility of fine-tuning Tesseract's LSTM engine for specific use cases makes the framework an attractive tool even for more demanding OCR scenarios. To further improve the result, it can be useful to fine-tune the model for additional iterations. In this use case, the number of iterations was deliberately limited because only one document was used for fine-tuning. More iterations potentially increase the risk of overfitting the LSTM model on certain symbols, which in turn increases the error rate on other symbols. In practice, it is desirable to increase the number of iterations provided that sufficient training data is available. The final OCR performance should always be validated on a separate but representative set of documents.

References

  • Tesseract training: https://tesseract-ocr.github.io/tessdoc/TrainingTesseract-4.00.html
  • Image processing overview: https://tesseract-ocr.github.io/tessdoc/ImproveQuality#image-processing
  • Otsu thresholding: https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_thresholding/py_thresholding.html
  • Tesseract digits comma model: https://github.com/Shreeshrii/tessdata_shreetest
Here at STATWORX, a Data Science and AI consulting company, we thrive on creating data-driven solutions that can be acted on quickly and translate into real business value. We provide many of our solutions in some form of web application to our customers, to allow and support them in their data-driven decision-making.

Containerization Allows Flexible Solutions

At the start of a project, we typically need to decide where and how to implement the solution we are going to create. There are several good reasons to deploy the designed solutions directly into our customer IT infrastructure instead of acquiring an external solution. Often our data science solutions use sensitive data. By deploying directly to the customers’ infrastructure, we make sure to avoid data-related compliance or security issues. Furthermore, it allows us to build pipelines that automatically extract new data points from the source and incorporate them into the solution so that it is always up to date. However, this also imposes some constraints on us. We need to work with the infrastructure provided by our customers. On the one hand, that requires us to develop solutions that can exist in all sorts of different environments. On the other hand, we need to adapt to changes in the customers’ infrastructure quickly and efficiently. All of this can be achieved by containerizing our solutions.

The Advantages of Containerization

Containerization has evolved as a lightweight alternative to virtualization. It involves packaging up software code and all its dependencies in a “container” so that the software can run on practically any infrastructure. Traditionally, an application was developed in a specific computing development environment and then transferred to the production environment, often resulting in many bugs and errors, especially when these environments did not mirror each other. For example, this happens when an application is transferred from a local desktop computer to a virtual machine or from a Linux to a Windows operating system. A container platform like Docker allows us to store the whole application with all the necessary code, system tools, libraries, and settings in a container that can be shipped to and work uniformly in any environment. We can develop our applications in a dockerized way and do not have to worry about the specific infrastructure environment provided by our customers.
[Figure: Docker Today]
There are some other advantages that come with using Docker in comparison to traditional virtual machines for the deployment of data science applications.
  • Efficiency – As the container shares the machine’s OS kernel and does not require a guest OS per application, it uses the provided infrastructure more efficiently, resulting in lower infrastructure costs.
  • Speed – The start of a container does not require a Guest OS reboot; it can be started, stopped, replicated, and destroyed in seconds. That speeds up the development process, the time to market, and the operational speed. Releasing new software or updates has never been so fast: Bugs can be fixed, and new features implemented in hours or days.
  • Scalability – Horizontal scaling allows additional containers to be started and stopped depending on current demand.
  • Security – Docker provides the strongest default isolation capabilities in the industry. Containers run isolated from each other, which means that if one crashes, other containers serving the same applications will still be running.

The Key Benefits of a Microservices Architecture

In connection with the use of Docker for delivering data science solutions, we use another emerging method. Instead of providing a monolithic application that comes with all the required functionalities, we create small, independent services that communicate with each other and together form the complete application. Usually, we develop WebApps for our customers. As shown in the graphic, the WebApp communicates directly with the different backend microservices. Each one is designed for a specific task and has an exposed REST API that allows for different HTTP requests. Furthermore, the backend microservices are indirectly exposed to the mobile app. An API Gateway routes the requests to the desired microservices. It can also provide an API endpoint that invokes several backend microservices and aggregates the results. Moreover, it can be used for access control, caching, and load balancing. If suitable, you might also decide to place an API Gateway between the WebApp and the backend microservices.
[Figure: Microservices]
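As a purely illustrative sketch (service names, ports, and routes are hypothetical and not part of any specific project), a very small API gateway could look like this in Flask:
import requests
from flask import Flask, jsonify

app = Flask(__name__)

# hypothetical internal addresses of two backend microservices
FORECAST_SERVICE = "http://forecast-service:5001"
REPORTING_SERVICE = "http://reporting-service:5002"

@app.route("/api/forecast/<product_id>")
def forecast(product_id):
    # plain routing: forward the request to the responsible microservice
    resp = requests.get(f"{FORECAST_SERVICE}/forecast/{product_id}")
    return jsonify(resp.json()), resp.status_code

@app.route("/api/dashboard/<product_id>")
def dashboard(product_id):
    # aggregation: one gateway endpoint invokes several backend services and combines the results
    forecast_data = requests.get(f"{FORECAST_SERVICE}/forecast/{product_id}").json()
    report_data = requests.get(f"{REPORTING_SERVICE}/report/{product_id}").json()
    return jsonify({"forecast": forecast_data, "report": report_data})

if __name__ == "__main__":
    app.run(port=8080)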
In summary, splitting the application into small microservices has several advantages for us:
  • Agility – As services operate independently, we can update or fix bugs for a specific microservice without redeploying the entire application.
  • Technology freedom – Different microservices can be based on different technologies or languages, thus allowing us to use the best of all worlds.
  • Fault isolation – If an individual microservice becomes unavailable, it will not crash the entire application. Only the function provided by the specific microservice will not be provided.
  • Scalability – Services can be scaled independently. It is possible to scale the services which do the work without scaling the application.
  • Reusability of services – Often, the functionalities of the services we create are also requested by other departments and for other use cases. We then expose application programming interfaces so that the services can also be used independently of the focal application.

Containerized Microservices – The Best of Both Worlds!

The combination of Docker with a clean microservices architecture allows us to combine the advantages mentioned above. Each microservice lives in its own Docker container. We deliver fast solutions that are consistent across environments, efficient in terms of resource consumption, and easily scalable and updatable. We are not bound to a specific infrastructure and can adjust to changes quickly and efficiently.
[Figure: Containerized Microservices]

Conclusion

Often the deployment of a data science solution is one of the most challenging tasks within data science projects. But without a proper deployment, there won’t be any business value created. Hopefully, I was able to help you figure out how to optimize the implementation of your data science application. If you need further help bringing your data science solution into production, feel free to contact us!

Did you ever want to make your machine learning model available to other people, but didn’t know how? Or maybe you just heard about the term API, and want to know what’s behind it? Then this post is for you! Here at STATWORX, we use and write APIs daily. For this article, I wrote down how you can build your own API for a machine learning model that you create and the meaning of some of the most important concepts like REST. After reading this short article, you will know how to make requests to your API within a Python program. So have fun reading and learning!

What is an API?

API is short for Application Programming Interface. It allows users to interact with the underlying functionality of some written code by accessing the interface. There is a multitude of APIs, and chances are good that you have already heard about the type of API we are going to talk about in this blog post: the web API. This specific type of API allows users to interact with functionality over the internet. In this example, we are building an API that will provide predictions through our trained machine learning model. In a real-world setting, this kind of API could be embedded in some type of application, where a user enters new data and receives a prediction in return. APIs are very flexible and easy to maintain, making them a handy tool in the daily work of a Data Scientist or Data Engineer. An example of a publicly available machine learning API is Time Door. It provides Time Series tools that you can integrate into your applications. APIs can also be used to make data available, not only machine learning models.
[Figure: API Illustration]

And what is REST?

Representational State Transfer (or REST) is an approach that entails a specific style of communication through web services. When using some of the REST best practices to implement an API, we call that API a “REST API”. There are other approaches to web communication, too (such as the Simple Object Access Protocol: SOAP), but REST generally runs on less bandwidth, making it preferable to serve your machine learning models. In a REST API, the four most important types of requests are:
  • GET
  • PUT
  • POST
  • DELETE
For our little machine learning application, we will mostly focus on the POST method, since it is very versatile and many clients can’t send a request body with GET requests. It’s important to mention that APIs are stateless. This means that they don’t save the inputs you give during an API call, so they don’t preserve the state. That’s significant because it allows multiple users and applications to use the API at the same time, without one user request interfering with another.

The Model

For this How-To-article, I decided to serve a machine learning model trained on the famous iris dataset. If you don’t know the dataset, you can check it out here. When making predictions, we will have four input parameters: sepal length, sepal width, petal length, and finally, petal width. Those will help to decide which type of iris flower the input is. For this example I used the scikit-learn implementation of a simple KNN (K-nearest neighbor) algorithm to predict the type of iris:
# model.py
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import joblib  # use the standalone joblib package (sklearn.externals.joblib is deprecated)
import numpy as np


def train(X,y):

    # train test split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

    knn = KNeighborsClassifier(n_neighbors=1)

    # fit the model
    knn.fit(X_train, y_train)
    preds = knn.predict(X_test)
    acc = accuracy_score(y_test, preds)
    print(f'Successfully trained model with an accuracy of {acc:.2f}')

    return knn

if __name__ == '__main__':

    iris_data = datasets.load_iris()
    X = iris_data['data']
    y = iris_data['target']

    labels = {0 : 'iris-setosa',
              1 : 'iris-versicolor',
              2 : 'iris-virginica'}

    # rename integer labels to actual flower names
    y = np.vectorize(labels.__getitem__)(y)

    mdl = train(X,y)

    # serialize model
    joblib.dump(mdl, 'iris.mdl')
As you can see, I trained the model with 70% of the data and then validated it with the remaining 30% of out-of-sample test data. After the model training has taken place, I serialize the model with the joblib library. Joblib is basically an alternative to pickle that is particularly well suited for persisting scikit-learn estimators containing large numpy arrays (such as the KNN model, which stores all the training data). After the file is saved as a joblib file (the file extension is not important, by the way, so don’t be confused if some people call it .model or .joblib), it can be loaded again later in our application.

The API with Python and Flask

To build an API from our trained model, we will be using the popular web development packages Flask and Flask-RESTful. Furthermore, we import joblib to load our model and numpy to handle the input and output data. In a new script, namely app.py, we can now set up an instance of a Flask app and an API and load the trained model (this requires saving the model in the same directory as the script):
from flask import Flask
from flask_restful import Api, Resource, reqparse
import joblib
import numpy as np

APP = Flask(__name__)
API = Api(APP)

IRIS_MODEL = joblib.load('iris.mdl')
The second step now is to create a class, which is responsible for our prediction. This class will be a child class of the Flask-RESTful class Resource. This lets our class inherit the respective class methods and allows Flask to do the work behind your API without needing to implement everything. In this class, we can also define the methods (REST requests) that we talked about before. So now we implement a Predict class with the .post() method we talked about earlier. The post method allows the user to send a body along with the default API parameters. Usually, we want the body to be in JSON format. Since this body is not delivered directly in the URL, but as text, we have to parse this text and fetch the arguments. The flask_restful package offers the RequestParser class for that. We simply add all the arguments we expect to find in the JSON input with the .add_argument() method and parse them into a dictionary. We then convert it into an array and return the prediction of our model as JSON.
class Predict(Resource):

    @staticmethod
    def post():
        parser = reqparse.RequestParser()
        parser.add_argument('petal_length')
        parser.add_argument('petal_width')
        parser.add_argument('sepal_length')
        parser.add_argument('sepal_width')

        args = parser.parse_args()  # creates dict

        X_new = np.fromiter(args.values(), dtype=float)  # convert input to array

        out = {'Prediction': IRIS_MODEL.predict([X_new])[0]}

        return out, 200
You might be wondering what the 200 is that we are returning at the end: For APIs, some HTTP status codes are displayed when sending requests. You all might be familiar with the famous 404 - page not found code. 200 just means that the request has been received successfully. You basically let the user know that everything went according to plan. In the end, you just have to add the Predict class as a resource to the API, and write the main function:
API.add_resource(Predict, '/predict')

if __name__ == '__main__':
    APP.run(debug=True, port='1080')
The '/predict' you see in the .add_resource() call, is the so-called API endpoint. Through this endpoint, users of your API will be able to access and send (in this case) POST requests. If you don’t define a port, port 5000 will be the default. You can see the whole code for the app again here:
# app.py
from flask import Flask
from flask_restful import Api, Resource, reqparse
import joblib
import numpy as np

APP = Flask(__name__)
API = Api(APP)

IRIS_MODEL = joblib.load('iris.mdl')


class Predict(Resource):

    @staticmethod
    def post():
        parser = reqparse.RequestParser()
        parser.add_argument('petal_length')
        parser.add_argument('petal_width')
        parser.add_argument('sepal_length')
        parser.add_argument('sepal_width')

        args = parser.parse_args()  # creates dict

        X_new = np.fromiter(args.values(), dtype=float)  # convert input to array

        out = {'Prediction': IRIS_MODEL.predict([X_new])[0]}

        return out, 200


API.add_resource(Predict, '/predict')

if __name__ == '__main__':
    APP.run(debug=True, port='1080')

Run the API

Now it’s time to run and test our API! To run the app, simply open a terminal in the same directory as your app.py script and run this command.
python app.py
You should now get a notification that the API is running on your localhost on the port you defined. There are several ways of accessing the API once it is deployed. For debugging and testing purposes, I usually use tools like Postman. We can also access the API from within a Python application, just like another user might want to do to use your model in their code. We use the requests module, by first defining the URL to access and the body to send along with our HTTP request:
import requests

url = 'http://127.0.0.1:1080/predict'  # localhost and the defined port + endpoint
body = {
    "petal_length": 2,
    "sepal_length": 2,
    "petal_width": 0.5,
    "sepal_width": 3
}
response = requests.post(url, data=body)
response.json()
The output should look something like this:
Out[1]: {'Prediction': 'iris-versicolor'}
That’s how easy it is to include an API call in your Python code! Please note that this API is just running on your localhost. You would have to deploy the API to a live server (e.g., on AWS) for others to access it.

Conclusion

In this blog article, you got a brief overview of how to build a REST API to serve your machine learning model with a web interface. Further, you now understand how to integrate simple API requests into your Python code. For the next step, maybe try securing your APIs? If you are interested in learning how to build an API with R, you should check out this post. I hope that this gave you a solid introduction to the concept and that you will be building your own APIs immediately. Happy coding!  

Introduction

When working on data science projects in R, exporting internal R objects as files on your hard drive is often necessary to facilitate collaboration. Here at STATWORX, we regularly export R objects (such as outputs of a machine learning model) as .RDS files and put them on our internal file server. Our co-workers can then pick them up for further usage down the line of the data science workflow (such as visualizing them in a dashboard together with inputs from other colleagues). Over the last couple of months, I came to work a lot with RDS files and noticed a crucial shortcoming: The base R saveRDS function does not allow for any kind of archiving of existing same-named files on your hard drive. In this blog post, I will explain why this might be very useful by introducing the basics of serialization first and then showcasing my proposed solution: A wrapper function around the existing base R serialization framework.

Be wary of silent file replacements!

In base R, you can easily export any object from the environment to an RDS file with:
saveRDS(object = my_object, file = "path/to/dir/my_object.RDS")
However, including such a line somewhere in your script can carry unintended consequences: When calling saveRDS multiple times with identical file names, R silently overwrites existing, identically named .RDS files in the specified directory. If the object you are exporting is not what you expect it to be — for example due to some bug in newly edited code — your working copy of the RDS file is simply overwritten in-place. Needless to say, this can prove undesirable. If you are familiar with this pitfall, you probably used to forestall such potentially troublesome side effects by commenting out the respective lines, then carefully checking each time whether the R object looked fine, then executing the line manually. But even when there is nothing wrong with the R object you seek to export, it can make sense to retain an archived copy of previous RDS files: Think of a dataset you run through a data prep script, and then you get an update of the raw data, or you decide to change something in the data prep (like removing a variable). You may wish to archive an existing copy in such cases, especially with complex data prep pipelines with long execution time.

Don’t get tangled up in manual renaming

You could manually move or rename the existing file each time you plan to create a new one, but that’s tedious, error-prone, and does not allow for unattended execution and scalability. For this reason, I set out to write a carefully designed wrapper function around the existing saveRDS call, which is pretty straightforward: As a first step, it checks if the file you attempt to save already exists in the specified location. If it does, the existing file is renamed/archived (with customizable options), and the “updated” file will be saved under the originally specified name. This approach has the crucial advantage that the existing code that depends on the file name remaining identical (such as readRDS calls in other scripts) will continue to work with the latest version without any needs for adjustment! No more saving your objects as “models_2020-07-12.RDS”, then combing through the other scripts to replace the file name, only to repeat this process the next day. At the same time, an archived copy of the — otherwise overwritten — file will be kept.
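To make the pain point concrete, here is a minimal sketch of what such a manual archiving step could look like before each export (the file name is purely illustrative):
if (file.exists("my_object.RDS")) {
  # manually rename the old file before it gets overwritten
  file.rename(from = "my_object.RDS",
              to   = paste0("my_object_", Sys.Date(), ".RDS"))
}
saveRDS(object = my_object, file = "my_object.RDS")
Doing this by hand for every export is exactly the kind of chore the wrapper described below takes care of automatically.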

What are RDS files anyways?

Before I walk you through my proposed solution, let’s first examine the basics of serialization, the underlying process behind high-level functions like saveRDS.
Simply speaking, serialization is the “process of converting an object into a stream of bytes so that it can be transferred over a network or stored in a persistent storage.” Stack Overflow: What is serialization?
There is also a low-level R interface, serialize, which you can use to explore (un-)serialization first-hand: Simply fire up R and run something like serialize(object = c(1, 2, 3), connection = NULL). This call serializes the specified vector and prints the output right to the console. The result is an odd-looking raw vector, with each byte separately represented as a pair of hex digits. Now let’s see what happens if we revert this process:
s <- serialize(object = c(1, 2, 3), connection = NULL)
print(s)
# >  [1] 58 0a 00 00 00 03 00 03 06 00 00 03 05 00 00 00 00 05 55 54 46 2d 38 00 00 00 0e 00
# > [29] 00 00 03 3f f0 00 00 00 00 00 00 40 00 00 00 00 00 00 00 40 08 00 00 00 00 00 00

unserialize(s)
# > 1 2 3
The length of this raw vector increases rapidly with the complexity of the stored information: For instance, serializing the famous, although not too large, iris dataset results in a raw vector consisting of 5959 pairs of hex digits! Besides the already mentioned saveRDS function, there is also the more generic save function. The former saves a single R object to a file. It allows us to restore the object from that file (with the counterpart readRDS), possibly under a different variable name: That is, you can assign the contents of a call to readRDS to another variable. By contrast, save allows for saving multiple R objects, but when reading back in (with load), they are simply restored in the environment under the object names they were saved with. (That’s also what happens automatically when you answer “Yes” to the notorious question of whether to “save the workspace image to ~/.RData” when quitting RStudio.)
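To illustrate the difference between the two interfaces, here is a minimal sketch (object and file names are purely illustrative):
# saveRDS/readRDS: a single object that can be re-assigned to any name
saveRDS(iris, file = "iris_data.RDS")
my_copy <- readRDS("iris_data.RDS")  # restored under a new name

# save/load: multiple objects, restored under their original names
x <- 1:3
y <- letters[1:3]
save(x, y, file = "objects.RData")
rm(x, y)
load("objects.RData")  # x and y reappear in the environment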

Creating the archives

Obviously, it’s great to have the possibility to save internal R objects to a file and then be able to re-import them in a clean session or on a different machine. This is especially true for the results of long and computationally heavy operations such as fitting machine learning models. But as we learned earlier, one wrong keystroke can potentially erase that one precious 3-hour-fit fine-tuned XGBoost model you ran and carefully saved to an RDS file yesterday.

Digging into the wrapper

So, how did I go about fixing this? Let’s take a look at the code. First, I define the arguments and their defaults: The object and file arguments are taken directly from the wrapped function, the remaining arguments allow the user to customize the archiving process: Append the archive file name with either the date the original file was archived or last modified, add an additional timestamp (not just the calendar date), or save the file to a dedicated archive directory. For more details, please check the documentation here. I also include the ellipsis ... for additional arguments to be passed down to saveRDS. Additionally, I do some basic input handling (not included here).
save_rds_archive <- function(object,
                             file = "",
                             archive = TRUE,
                             last_modified = FALSE,
                             with_time = FALSE,
                             archive_dir_path = NULL,
                             ...) {
The main body of the function is basically a series of if/else statements. I first check if the archive argument (which controls whether the file should be archived in the first place) is set to TRUE, and then if the file we are trying to save already exists (note that “file” here actually refers to the whole file path). If it does, I call the internal helper function create_archived_file, which eliminates redundancy and allows for concise code.
if (archive) {

    # check if file exists
    if (file.exists(file)) {

      archived_file <- create_archived_file(file = file,
                                            last_modified = last_modified,
                                            with_time = with_time)

Composing the new file name

In this function, I create the new name for the file which is to be archived, depending on user input: If last_modified is set, then the mtime of the file is accessed. Otherwise, the current system date/time (= the date of archiving) is taken instead. Then the spaces and special characters are replaced with underscores, and, depending on the value of the with_time argument, the actual time information (not just the calendar date) is kept or not. To make it easier to identify directly from the file name what exactly (date of archiving vs. date of modification) the indicated date/time refers to, I also add appropriate information to the file name. Then I save the file extension for easier replacement (note that “.RDS”, “.Rds”, and “.rds” are all valid file extensions for RDS files). Lastly, I replace the current file extension with a concatenated string containing the type info, the new date/time suffix, and the original file extension. Note here that I add a “$” sign to the regex which is to be matched by gsub to only match the end of the string: If I did not do that and the file name would be something like “my_RDS.RDS”, then both matches would be replaced.
# create_archived_file.R

create_archived_file <- function(file, last_modified, with_time) {

  # create main suffix depending on type
  suffix_main <- ifelse(last_modified,
                        as.character(file.info(file)$mtime),
                        as.character(Sys.time()))

  if (with_time) {

    # create clean date-time suffix
    suffix <- gsub(pattern = " ", replacement = "_", x = suffix_main)
    suffix <- gsub(pattern = ":", replacement = "-", x = suffix)

    # add "at" between date and time
    suffix <- paste0(substr(suffix, 1, 10), "_at_", substr(suffix, 12, 19))

  } else {

    # create date suffix
    suffix <- substr(suffix_main, 1, 10)

  }

  # create info to paste depending on type
  type_info <- ifelse(last_modified,
                      "_MODIFIED_on_",
                      "_ARCHIVED_on_")

  # get file extension (could be any of "RDS", "Rds", "rds", etc.)
  ext <- paste0(".", tools::file_ext(file))

  # replace extension with suffix
  archived_file <- gsub(pattern = paste0(ext, "$"),
                        replacement = paste0(type_info,
                                             suffix,
                                             ext),
                        x = file)

  return(archived_file)

}
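As a quick usage illustration (assuming the call happens on 2020-07-12 at 11:31:43, matching the example in the next paragraph):
create_archived_file(file = "models.RDS", last_modified = FALSE, with_time = TRUE)
#> [1] "models_ARCHIVED_on_2020-07-12_at_11-31-43.RDS"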

Archiving the archives?

By way of example, with last_modified = FALSE and with_time = TRUE, this function would turn the character file name “models.RDS” into “models_ARCHIVED_on_2020-07-12_at_11-31-43.RDS”. However, this is just a character vector for now — the file itself is not renamed yet. For this, we need to call the base R file.rename function, which provides a direct interface to your machine’s file system. I first check, however, whether a file with the same name as the newly created archived file string already exists: This could well be the case if one appends only the date (with_time = FALSE) and calls this function several times per day (or potentially on the same file if last_modified = TRUE). Somehow, we are back to the old problem in this case. However, I decided that it was not a good idea to archive files that are themselves archived versions of another file since this would lead to too much confusion (and potentially too much disk space being occupied). Therefore, only the most recent archived version will be kept. (Note that if you still want to keep multiple archived versions of a single file, you can set with_time = TRUE. This will append a timestamp to the archived file name up to the second, virtually eliminating the possibility of duplicated file names.) A warning is issued, and then the already existing archived file will be overwritten with the current archived version.

The last puzzle piece: Renaming the original file

To do this, I call the file.rename function, renaming the “file” originally passed by the user call to the string returned by the helper function. The file.rename function always returns a boolean indicating if the operation succeeded, which I save to a variable temp to inspect later. Under some circumstances, the renaming process may fail, for instance due to missing permissions or OS-specific restrictions. We did set up a CI pipeline with GitHub Actions and continuously test our code on Windows, Linux, and MacOS machines with different versions of R. So far, we didn’t run into any problems. Still, it’s better to provide in-built checks.

It’s an error! Or is it?

The problem here is that if renaming the file on disk fails, file.rename merely raises a warning, not an error. Since any causes of these warnings most likely originate from the local file system, there is no sense in continuing the function if the renaming failed. That's why I wrapped it into a tryCatch call that captures the warning message and passes it to the stop call, which then terminates the function with the appropriate message. Just to be on the safe side, I check the value of the temp variable, which should be TRUE if the renaming succeeded, and also check if the archived version of the file (that is, the result of our renaming operation) exists. If both of these conditions hold, I simply call saveRDS with the original specifications (now that our existing copy has been renamed, nothing will be overwritten if we save the new file with the original name), passing along further arguments with ....
        if (file.exists(archived_file)) {
          warning("Archived copy already exists - will overwrite!")
        }

        # rename existing file with the new name
        # save return value of the file.rename function
        # (returns TRUE if successful) and wrap in tryCatch
        temp <- tryCatch({file.rename(from = file,
                                      to = archived_file)
        },
        warning = function(e) {
          stop(e)
        })

      }

      # check return value and if archived file exists
      if (temp & file.exists(archived_file)) {
        # then save new file under specified name
        saveRDS(object = object, file = file, ...)
      }

    }
These code snippets represent the cornerstones of my function. I also skipped some portions of the source code for reasons of brevity, chiefly the creation of the “archive directory” (if one is specified) and the process of copying the archived file into it. Please refer to our GitHub for the complete source code of the main and the helper function. Finally, to illustrate, let’s see what this looks like in action:
x <- 5
y <- 10
z <- 20

## save to RDS
saveRDS(x, "temp.RDS")
saveRDS(y, "temp.RDS")

## "temp.RDS" is silently overwritten with y
## previous version is lost
readRDS("temp.RDS")
#> [1] 10

save_rds_archive(z, "temp.RDS")
## current version is updated
readRDS("temp.RDS")
#> [1] 20

## previous version is archived
readRDS("temp_ARCHIVED_on_2020-07-12.RDS")
#> [1] 10

Great, how can I get this?

The function save_rds_archive is now included in the newly refactored helfRlein package (now available in version 1.0.0!) which you can install directly from GitHub:
# install.packages("devtools")
devtools::install_github("STATWORX/helfRlein")
Feel free to check out additional documentation and the source code there. If you have any inputs or feedback on how the function could be improved, please do not hesitate to contact me or raise an issue on our GitHub.

Conclusion

That’s it! No more manually renaming your precious RDS files — with this function in place, you can automate this tedious task and easily keep a comprehensive archive of previous versions. You will be able to take another look at that one model you ran last week (and then discarded again) in the blink of an eye. I hope you enjoyed reading my post — maybe the function will come in handy for you someday!

Introduction

Sometimes here at STATWORX, we have impromptu discussions about statistical methods. In one such discussion, one of my colleagues decided to declare (albeit jokingly) Bayesian statistics unnecessary. That made me ask myself: Why would I ever use Bayesian models in the context of a standard regression problem? Existing approaches such as Ridge Regression are just as good, if not better. However, the Bayesian approach has the advantage that it lets you regularize your model to prevent overfitting and meaningfully interpret the regularization parameters. Contrary to the usual way of looking at ridge regression, the regularization parameters are no longer abstract numbers but can be interpreted through the Bayesian paradigm as derived from prior beliefs. In this post, I’ll show you the formal similarity between a generalized ridge estimator and the Bayesian equivalent.

A (very brief) Primer on Bayesian Stats

To understand the Bayesian regression estimator, a minimal amount of knowledge about Bayesian statistics is necessary, so here's what you need to know (if you don't already): In Bayesian statistics, we think about model parameters (i.e., regression coefficients) probabilistically. In other words, the data given to us is fixed, and the parameters are considered random. That runs counter to the standard frequentist perspective, in which the underlying model parameters are treated as fixed while the data are considered random realizations of the stochastic process driven by those fixed model parameters. The end goal of Bayesian analysis is to find the posterior distribution, which you may remember from Bayes' Rule:

    \[p(\theta \mid y) = \frac{p(y \mid \theta)\, p(\theta)}{p(y)}\]

While p(y|theta) is our likelihood and p(y) is a normalizing constant, p(theta) is our prior, which does not depend on the data, y. In classical statistics, p(theta) is set to 1 (an improper reference prior) so that when the posterior 'probability' is maximized, really just the likelihood is maximized, because it's the only part that still depends on theta. However, in Bayesian statistics, we use an actual probability distribution in place of p(theta), a Normal distribution, for example. So let's consider the case of a regression problem, and we'll assume that our target, y, and our prior follow normal distributions. That leads us to conjugate Bayesian analysis, in which we can neatly write down an equation for the posterior distribution. In many cases, this is not possible, and for this reason, Markov Chain Monte Carlo methods were invented to sample from the posterior – taking a frequentist approach, ironically. We'll make the usual assumption about the data: y_i is i.i.d. N(x_i beta, sigma^2) for all observations i. This gives us our standard likelihood for the Normal distribution. Now we can specify the prior for the parameters we're trying to estimate, (beta, sigma^2). If we choose a Normal prior (conditional on the variance, sigma^2) for the vector of weights beta, i.e., N(b_0, sigma^2 B_0), and an inverse-Gamma prior over the variance parameter, it can be shown that the posterior distribution for beta is Normally distributed with mean

    \[\hat{\beta}_{Bayesian} = (B_0^{-1} + X'X)^{-1}(B_0^{-1} b_0 + X'X \hat{\beta})\]

If you’re interested in a proof of this result check out Jackman (2009, p.526). Let’s look at it piece by piece:
  • hatbeta is our standard OLS estimator, (X'X)^{-1}X'y
  • b_0 is the mean vector of the (multivariate normal) prior distribution, so it lets us specify what we think the average values of each of our model parameters are
  • B_0 is the covariance matrix and contains our respective uncertainties about the model parameters. The inverse of the variance is called the precision
What we can see from the equation is that the mean of our posterior is a precision-weighted average of our prior mean (information not based on data) and the OLS estimator (based solely on the data). The second term in parentheses indicates that we are taking the precision-weighted prior mean, B_0^{-1} b_0, and adding it to the weighted OLS estimator, X'X hatbeta. Imagine for a moment that B_0^{-1} = 0. Then

    \[\hat{\beta}_{Bayesian} = (X'X)^{-1}(X'X \hat{\beta}) = \hat{\beta}\]

That would mean that we are infinitely uncertain about our prior beliefs, so the mean vector of our prior distribution vanishes from the formula and contributes nothing to our posterior! Likewise, if our uncertainty decreases (and the precision thus increases), the prior mean, b_0, contributes more to the posterior mean. After this short primer on Bayesian statistics, we can now formally compare the Ridge estimator with the above Bayesian estimator. But first, we need to take a look at a more general version of the Ridge estimator.
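For intuition (this worked example is not part of the original derivation), consider a single standardized predictor so that X'X reduces to a scalar n and the prior precision B_0^{-1} to a scalar tau. The posterior mean then becomes a plain weighted average of the prior mean and the OLS estimate:

    \[\hat{\beta}_{Bayesian} = \frac{\tau\, b_0 + n\, \hat{\beta}}{\tau + n} = \frac{\tau}{\tau + n}\, b_0 + \frac{n}{\tau + n}\, \hat{\beta}\]

Setting tau = 0 (infinite prior uncertainty) recovers the OLS estimate, while letting tau grow pulls the posterior mean towards b_0.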

Generalizing the Ridge estimator

A standard tool in many regression problems, the Ridge estimator is derived by minimizing the following penalized least-squares loss function:

    \[L(\beta,\lambda) = \frac{1}{2}\sum(y-X\beta)^2 + \frac{1}{2} \lambda \|\beta\|^2\]

While minimizing this gives us the standard Ridge estimator you have probably seen in textbooks on the subject, there’s a slightly more general version of this loss function:

    \[L(\beta,\lambda,\mu) = \frac{1}{2}\sum(y-X\beta)^2 + \frac{1}{2} \lambda \|\beta - \mu\|^2\]

Let’s derive the estimator by first re-writing the loss function in terms of matrices:

    \[\begin{aligned} L(\beta,\lambda,\mu) &= \frac{1}{2}(y - X\beta)^{T}(y - X\beta) + \frac{1}{2} \lambda\|\beta - \mu\|^2 \\ &= \frac{1}{2} y^T y - \beta^T X^T y + \frac{1}{2} \beta^T X^T X \beta + \frac{1}{2} \lambda\|\beta - \mu\|^2 \end{aligned}\]

Differentiating with respect to the parameter vector, we end up with this expression for the gradient:

    \[\nabla_{\beta} L(\beta, \lambda, \mu) = -X^T y + X^T X \beta + \lambda (\beta - \mu)\]

Setting the gradient to zero and solving for beta, we get this expression for the generalized Ridge estimator:

    \[\hat{\beta}_{Ridge} = (X'X + \lambda I)^{-1}(\lambda \mu + X'y)\]

The standard Ridge estimator can be recovered by setting mu = 0. Usually, we regard lambda as an abstract parameter that regulates the penalty size and mu as a vector of values (one for each predictor) from which deviations of the coefficients increase the loss. When mu = 0, the coefficients are pulled towards zero. Let's take a look at how the estimator behaves when the parameters mu and lambda change. We'll define a meaningful 'prior' for our example and then vary the penalty parameter. As an example, we'll use the diamonds dataset from the ggplot2 package and model the price of each diamond as a linear function of its carat, depth, table, x, y, and z attributes.
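The original post does not show the code behind the following plots, but the closed-form estimator derived above can be evaluated along these lines; the variable names, the penalty grid, and the omitted intercept and scaling are illustrative assumptions:
library(ggplot2)  # provides the diamonds dataset

# Design matrix and target (intercept and scaling omitted for brevity)
X <- as.matrix(diamonds[, c("carat", "depth", "table", "x", "y", "z")])
y <- diamonds$price

# Generalized ridge estimator: (X'X + lambda * I)^(-1) (lambda * mu + X'y)
ridge_general <- function(X, y, lambda, mu = rep(0, ncol(X))) {
  solve(t(X) %*% X + lambda * diag(ncol(X)), lambda * mu + t(X) %*% y)
}

# Coefficient paths over a grid of penalties, here with the default mu = 0
lambdas <- 10^seq(0, 6, length.out = 25)
coef_paths <- sapply(lambdas, function(l) ridge_general(X, y, lambda = l))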
As we can see from the plot, both with and without a prior, the coefficient estimates change rapidly for the first few increases in the penalty size. We also see that the ‘shrinkage’ effect holds from the upper plot: as the penalty increases, the coefficients tend towards zero, some faster than others. The plot on the right shows how the coefficients change when we set a sensible ‘prior’. The coefficients still change, but they now tend towards the ‘prior’ we specified. That’s because lambda penalizes deviations from our mu, which means that larger values for the penalty pull the coefficients towards mu. You might be asking yourself how this compares to the Bayesian estimator. Let’s find out!

Comparing the Ridge and Bayesian Estimator

Now that we've seen both the Ridge and the Bayesian estimators, it's time to compare them. We discovered that the Bayesian estimator contains the OLS estimator. Since we know its form, let's substitute it and see what happens:

    \[\begin{aligned} \hat{\beta}_{Bayesian} &= (X'X + B_0^{-1})^{-1}(B_0^{-1} b_0 + X'X \hat{\beta}) \\ &= (X'X + B_0^{-1})^{-1}(B_0^{-1} b_0 + X'X (X'X)^{-1}X'y) \\ &= (X'X + B_0^{-1})^{-1}(B_0^{-1} b_0 + X'y) \end{aligned}\]

This form makes the analogy much clearer:
  • lambda I corresponds to B_0^{-1}, the matrix of precisions. In other words, since I is the identity matrix, the ridge estimator assumes no covariances between the regression coefficients and a constant precision across all coefficients (recall that lambda is a scalar)
  • lambda mu corresponds to B_0^{-1} b_0, which makes sense, since the vector b_0 is the mean of our prior distribution, which essentially pulls the estimator towards it, just like mu ‘shrinks’ the coefficients towards its values. This ‘pull’ depends on the uncertainty captured by B_0^{-1} or lambda I in the ridge estimator.
That's all well and good, but let's see how changing the uncertainty in the Bayesian case compares to the behavior of the ridge estimator. Using the same data and the same model specification as above, we'll set the prior covariance matrix B_0 equal to lambda I and then vary lambda. Remember, smaller values of lambda now imply a more significant contribution of the prior (less uncertainty), and increasing lambda makes the prior less important.
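A minimal sketch of how this closed-form posterior mean could be evaluated, reusing the X and y objects from the ridge sketch above (the function and argument names are illustrative):
# Posterior mean with B_0 = lambda * I; b0 is the prior mean vector
bayes_posterior_mean <- function(X, y, lambda, b0 = rep(0, ncol(X))) {
  B0_inv <- diag(ncol(X)) / lambda  # precision implied by B_0 = lambda * I
  solve(t(X) %*% X + B0_inv, B0_inv %*% b0 + t(X) %*% y)
}

# Small lambda = high precision = the prior dominates
bayes_posterior_mean(X, y, lambda = 0.001)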
The above plots match our understanding so far: With a prior mean of zeros, the coefficients are shrunken towards zero, as in the ridge regression case, when the prior dominates, i.e., when the precision is high. And when a prior mean is set, the coefficients tend towards it as the precision increases. So much for the coefficients, but what about the performance? Let's have a look!

Performance comparison

Lastly, we'll compare the predictive performance of the two models. Although we could treat the parameters in the model as hyperparameters that we would need to tune, this would defy the purpose of using prior knowledge. Instead, let's choose a prior specification for both models and then compare the performance on a holdout set (30% of the data). While we can use the simple X hatbeta as our predictor for the Ridge model, the Bayesian model provides us with a full posterior predictive distribution, which we can sample from to get model predictions. To estimate the Bayesian model, I used the brms package.
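A sketch of what such a brms fit and the posterior predictions could look like; the training/holdout objects and the prior shown here are illustrative assumptions, not the exact specification behind the table below:
library(brms)

model_bayes <- brm(
  price ~ carat + depth + table + x + y + z,
  data   = train_data,                         # hypothetical 70% training split
  family = gaussian(),
  prior  = set_prior("normal(0, 10)", class = "b"),
  chains = 4, iter = 2000, seed = 123
)

# Posterior predictive draws for the 30% holdout set
pred_bayes <- posterior_predict(model_bayes, newdata = test_data)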
                        RMSE      MAE       MAPE
Bayesian Linear Model   1625.38   1091.36   44.15
Ridge Estimator         1756.01   1173.50   43.44
Overall, both models perform similarly, although some error metrics slightly favor one model over the other. Judging by these errors, we could certainly improve our models by specifying a more appropriate probability distribution for our target variable. After all, prices can not be negative, yet our models can and do produce negative predictions.

Recap

In this post, I've shown you how the ridge estimator compares to the Bayesian conjugate linear model. Now you understand the connection between the two models and how a Bayesian approach can provide a more readily interpretable way of regularizing your model. Normally, lambda would be considered a penalty size, but now it can be interpreted as a measure of prior uncertainty. Similarly, the parameter vector mu can be seen as a vector of prior means for our model parameters in the extended ridge model. As far as the Bayesian approach goes, we can also use prior distributions to incorporate expert knowledge into the estimation process. This regularizes the model and allows external information to enter the estimation. If you are interested in the code, check it out at our GitHub page!

References

  • Jackman S. 2009. Bayesian Analysis for the Social Sciences. West Sussex: Wiley.
In this blog post, I want to present two models of how to secure a REST API. Both models work with JSON Web Tokens (JWT). We suppose that token creation has already happened somewhere else, e.g., on the customer side. We at STATWORX are often interested in the verification part itself. The first model extends a running unsecured interface without changing any code in it. This model does not even rely on Web Tokens and can be applied in a broader context. The second model demonstrates how token verification can be implemented inside a Flask application.

About JWT

JSON Web Tokens are more or less JSON objects. They contain authority information like the issuer, the type of signing, information on their validity span, and custom information like the username, user groups, and further personal display data. To be more precise, a JWT consists of three parts: the header, the payload, and the signature. The header and the payload are JSON objects; all three parts are base64url encoded and put together in the following fashion, where dots separate the three parts:
base64url(header).base64url(payload).base64url(signature)
All information is transparent and readable to whoever has the token (just decode the parts). Its security and validity stem from its digital signature, which guarantees that a Web Token comes from its issuer. There are two different ways of digital signing: with a secret or with a public-private key pair, in various forms and algorithms. This choice also influences the verification possibilities: a symmetric algorithm with a password can only be verified by the owner of the password – the issuer – whereas, for asymmetric algorithms, the public part can be distributed and, therefore, used for verification. Besides this feature, JWT offers additional advantages and features. Here is a brief overview:
  • Decoupling: The application or the API will not need to implement a secure way to exchange passwords and verify them against a backend.
  • Purpose oriented: Tokens can be issued for a particular purpose (defined within the payload) and are only valid for this purpose.
  • Less password exchange: Tokens can be used multiple times without password interactions.
  • Expiration: Tokens expire. Even in the event of theft and criminal intent, the amount of information that can be obtained is limited to its validity span.
More information about JWT and also a debugger can be found at https://jwt.io

Authorization Header

Let's move on to how authorization information is exchanged within HTTP requests. To authenticate against a server, you can use the request header Authorization, of which various types exist. Below are two common examples:
  • Basic
  Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==
where the latter part is a base64 encoded string of user:password
  • Bearer
  Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.eyJpbmZvIjoiSSdtIGEgc2lnbmVkIHRva2VuIn0.rjnRMAKcaRamEHnENhg0_Fqv7Obo-30U4bcI_v-nfEM
where the latter part is a JWT. Note that this information is totally transparent on transfer unless HTTPS protocol is used.

Pattern 1 – Sidecar Verification

Let's start with the first pattern that I'd like to introduce. In this section, I'll describe a setup with nginx that uses its sub-request feature. This is especially useful in cases where the implementation of an API cannot be changed. Nginx is a versatile tool: it acts as a web server, a load balancer, a content cache, and a reverse proxy. According to the Web Server Survey of April 2020, nginx is the most used web server for public websites. With nginx and sub-requests, each incoming request has to go through nginx, and the header, including the Authorization part, is passed to a sub-module (let's call it the Auth Service). This sub-module then checks the validity of the Authorization before passing the request through to the actual REST API or declining it. This decision process depends on the following status codes of the Auth Service:
  • 20x: nginx passes the request over to the actual resource service
  • 401 or 403: nginx denies the access and sends the response from the authentication service instead
As you might have noticed, this setup does not exclusively build on JWTs, and therefore other authorization types can be used.
Sub-request model – schema of request handling: 1) the request is passed to the Auth Service for verification; 2) it is forwarded to the REST API if the verification was successful.

Nginx configuration

Two configuration changes are needed to set up nginx for token validation:
  • Add an (internal) directive to the authorization service which verifies the request
  • Add authentication parameters to the directive that needs to be secured
The authorization directive is configured like this
location /auth {
        internal;
        proxy_pass                  http://auth-service/verify;
        proxy_pass_request_body     off;
        proxy_set_header            Content-Length "";
        proxy_set_header            X-Original-URI $request_uri;
}
This cuts off the body and sends the rest of the request to the authorization service. Next, the configuration of the directive which forwards to your REST API:
location /protected {
        auth_request    /auth;
        proxy_pass      http://rest-service;
}
Note that auth_request points to the authorization directive we have set up before. Please see the Nginx documentation for additional information on the sub-request pattern.

Basic Implementation of the Auth Service

In the sub-module, we use the jwcrypto library for the token operations. It offers all features and algorithms needed for the authentication task. Many other libraries can do similar things; an overview can be found at jwt.io. Suppose the token was created with an asymmetric cryptographic algorithm like RSA and you have access to the public key (here called public.pem). Following our nginx configuration, we add the directive /verify that does the verification job:
from jwcrypto import jwt, jwk
from flask import Flask, request

app = Flask(__name__)

# Load the public key
with open('public.pem', 'rb') as pemfile:
    public_key = jwk.JWK.from_pem(pemfile.read())


def parse_token(auth_header):
    # your implementation of Authorization header extraction
    pass


@app.route("/verify")
def verify():
    token_raw = parse_token(request.headers["Authorization"])
    try:
        # decoding validates the signature; an exception is raised if it is invalid
        token = jwt.JWT(jwt=token_raw, key=public_key)
        return "", 200
    except Exception:
        return "", 403
The script consists of three parts: reading the public key when the API starts, extracting the header information (not given here), and the actual verification, which is embedded in a try/except block.

Pattern 2 – Verify within the API

In this section, we will implement the verification within our Flask API. There are packages available in the Flask universe, such as flask_jwt; however, they do not offer the full scope of features we need, in particular for our case here. Instead, we again use the jwcrypto library as before. It is further assumed that you have access to the public key again, but this time a specific directive is secured (here called /secured) – and access should only be granted to admin users.
import json

from jwcrypto import jwt, jwk
from flask import Flask, request

app = Flask(__name__)

# Load the public key
with open('public.pem', 'rb') as pemfile:
    public_key = jwk.JWK.from_pem(pemfile.read())


@app.route("/secured")
def secured():
    token_raw = parse_token(request.headers["Authorization"])  # as defined in the previous example
    try:
        # decoding validates the signature; an exception is raised if it is invalid
        token = jwt.JWT(jwt=token_raw, key=public_key)
        claims = json.loads(token.claims)  # the claims are returned as a JSON string
        if "admin" in claims.get("groups", []):
            return "this is a secret information", 200
        else:
            return "", 403
    except Exception:
        return "", 403
The setup is the same as in the previous example, but an additional check was added: whether the user is part of the admin group.

Summary

Two patterns were discussed, both of which use Web Tokens to authenticate the request. While the first pattern's charm lies in the decoupling of authorization and API functionality, the second approach is more compact. The second approach fits situations where the number of APIs is low or the overhead caused by a separate service is too high. The sidecar pattern, by contrast, is perfect when you provide several APIs and would like to unify the authorization in one separate service. Web Tokens are used in popular authorization schemes like OAuth and OpenID Connect. In another blog post, I will focus on how they work and how they connect to the verification I presented. I hope you have enjoyed the post. Happy coding!

“There is no way you know Thomas! What a coincidence! He’s my best friend’s saxophone teacher! This cannot be true. Here we are, at the other end of the world and we meet? What are the odds?” Surely, we at STATWORX are not the only ones who have experienced similar situations, be it in a hotel’s lobby, on a faraway hiking trail, or in a pub in a city you are completely new to. However, the very fact that this story is so suspiciously relatable might indicate that the chances of being socially connected to a stranger by a short chain of friends of friends aren’t too low after all. Lots of research has been done in this field, one particularly popular result being the 6-Handshake-Rule. It states that most people living on this planet are connected by a chain of six handshakes or less. In the general setting of graphs, in which edges connect nodes, this is often referred to as the small-world effect. That is to say, the typical number of edges needed to get from node A to node B grows logarithmically in population size (i.e., the number of nodes). Note that, up until now, no geographic distance has been included in our considerations, which seems inadequate as it plays a significant role in social networks. When analyzing data from social networks such as Facebook or Instagram, three observations are especially striking:
  • Individuals who are geographically farther away from each other are less likely to connect, i.e., people from the same city are more likely to connect.
  • Few individuals have extremely many connections. Their number of connections follows a heavy-tailed Pareto distribution. Such individuals act as hubs in the network. That could be a celebrity or just a really popular kid from school.
  • Connected individuals tend to share a set of other individuals they are both connected to (e.g., “friend cliques”). This is called the clustering property.

A model that explains these observations

Clearly, due to the characteristics of social networks mentioned above, only a model that includes the geographic distances between individuals makes sense. Also, to account for the occurrence of hubs, research has shown that reasonable models attach a random weight to each node (which can be regarded as the social attractiveness of the respective individual). A model that accounts for all three properties is the following: First, randomly place nodes in space with a certain intensity nu, which can be done with a Poisson process. Then, with an independent uniformly distributed weight U_x attached to each node x, any two nodes x and y get connected by an edge with probability

    \[p_{xy} = \mathbb{P}(x \text{ is connected to } y) := \varphi\!\left(\frac{1}{\beta} U_x^{\gamma} U_y^{\gamma} \vert x-y\vert^d\right)\]

where d is the dimension of the model (here: d = 2, as we'll simulate the model on the plane), model parameter gamma in [0,1] controls the impact of the weights, and model parameter beta > 0 squishes the overall input to the profile function varphi, which is a monotonically decreasing, normalized function that returns a value between 0 and 1. That is, of course, what we want because its output shall be a probability. Take a moment to go through the effects of different beta and gamma on p_{xy}. A higher beta yields a smaller input value for varphi and thereby a higher connection probability. Similarly, a high gamma entails a lower U^gamma (as U in [0,1]) and thus a higher connection probability. All this comprises a scale-free random connection model, which can be seen as a generalization of the model by Deprez and Würthrich. So much about the theory. Now that we have a model, we can use it to generate synthetic data that should look similar to real-world data. So let's simulate!

Obtain data through simulation

From here on, the simulation is pretty straightforward. Don't worry about the specific numbers at this point.
library(tidyverse)
library(fields)
library(ggraph)
library(tidygraph)
library(igraph)
library(Matrix)

# Create a vector with plane dimensions. The random nodes will be placed on the plane.
plane <- c(1000, 1000)

poisson_para <- .5 * 10^(-3) # Poisson intensity parameter
beta <- .5 * 10^3
gamma <- .4

# Number of nodes is Poisson distributed with mean = intensity * area
n_nodes <- rpois(1, poisson_para * plane[1] * plane[2])
weights <- runif(n_nodes) # Uniformly distributed weights

# The Poisson process locally yields node positions that are completely random.
x = plane[1] * runif(n_nodes)
y = plane[2] * runif(n_nodes)

phi <- function(z) { # Connection function
  pmin(z^(-1.8), 1)
} 
What we need next is some information on which nodes are connected. That means, we need to first get the connection probability by evaluating varphi for each pair of nodes and then flipping a biased coin, accordingly. This yields a 0-1 encoding, where 1 means that the two respective nodes are connected and 0 that they’re not. We can gather all the information for all pairs in a matrix that is commonly known as the adjacency matrix.
# Distance matrix needed as input
dist_matrix <- rdist(tibble(x, y))

weight_matrix <- outer(weights, weights, FUN = "*")  # Weight matrix

con_matrix_prob <- phi(1 / beta * weight_matrix^gamma * dist_matrix^2)  # Evaluation

con_matrix <- Matrix(rbernoulli(length(con_matrix_prob), con_matrix_prob), sparse = TRUE)  # Sampling: one independent draw per pair
con_matrix <- con_matrix * upper.tri(con_matrix)  # Keep each pair only once (upper triangle)
adjacency_matrix <- con_matrix + t(con_matrix)

Visualization with ggraph

In an earlier post we praised visNetwork as our go-to package for beautiful interactive graph visualization in R. While this remains true, we also have lots of love for tidyverse, and ggraph (spoken “g-giraffe”) as an extension of ggplot2 proves to be a comfortable alternative for non-interactive graph plots, especially when you’re already familiar with the grammar of graphics. In combination with tidygraph, which lets us describe a graph as two tidy data frames (one for the nodes and one for the edges), we obtain a full-fledged tidyverse experience. Note that tidygraph is based on a graph manipulation library called igraph from which it inherits all functionality and “exposes it in a tidy manner”. So before we get cracking with the visualization in ggraph, let’s first tidy up our data with tidygraph!

Make graph data tidy again!

Let's attach some new columns to the node dataframe, which will be useful for visualization. After we have created the tidygraph object, this can be done in the usual dplyr fashion, using activate(nodes) and activate(edges) to access the respective dataframes.
# Create Igraph object
graph <- graph_from_adjacency_matrix(adjacency_matrix, mode="undirected")

# Make a tidygraph object from it. Igraph methods can still be called on it.
tbl_graph <- as_tbl_graph(graph)

hub_id <- which.max(degree(graph))

# Add spatial positions, hub distance and degree information to the nodes.
tbl_graph <- tbl_graph %>%
  activate(nodes) %>%
  mutate(
    x = x,
    y = y,
    hub_dist = replace_na(bfs_dist(root = hub_id), Inf),
    degree = degree(graph),
    friends_of_friends = replace_na(local_ave_degree(), 0),
    cluster = as.factor(group_infomap())
  )
Tidygraph supports most of igraph's methods, either directly or in the form of wrappers. This also applies to most of the functions used above. For example, breadth-first search is implemented as the bfs_* family, wrapping igraph::bfs(); the group_* family wraps igraph's clustering functions; and local_ave_degree() wraps igraph::knn().

Let’s visualize!

GGraph is essentially built around three components: Nodes, Edges and Layouts. Nodes that are connected by edges compose a graph, which can be created as an igraph object. Visualizing the igraph object can be done in numerous ways: Remember that nodes usually are not endowed with any coordinates. Therefore, arranging them in space can be done pretty much arbitrarily. In fact, there’s a specific research branch called graph drawing that deals with finding a good layout for a graph for a given purpose. Usually, the main criteria of a good layout are aesthetics (which is often interchangeable with clearness) and capturing specific graph properties. For example, a layout may force the nodes to form a circle, a star, two parallel lines, or a tree (if the graph’s data allows for it). Other times you might want to have a layout with a minimal number of intersecting edges. Fortunately, in ggraph all the layouts from igraph can be used. We start with a basic plot by passing the data and the layout to ggraph(), similar to what you would do with ggplot() in ggplot2. We can then add layers to the plot. Nodes can be created by using geom_node_point() and edges by using geom_edge_link(). From then on, it’s full-on ggplot2-style.
# Add coord_fixed() for fixed axis ratio!
basic <- tbl_graph %>%
  ggraph(layout = tibble(x = V(.)$x, y = V(.)$y)) +
  geom_edge_link(width = .1) +
  geom_node_point(aes(size = degree, color = degree)) +
  scale_color_gradient(low = "dodgerblue2", high = "firebrick4") +
  coord_fixed() +
  guides(size = FALSE)
To see more clearly what nodes are essential to the network, the degree, which is the number of edges a node is connected with, was highlighted for each node. Another way of getting a good overview of the graph is to show a visual decomposition of the components. Nothing easier than that!
cluster <- tbl_graph %>%
  ggraph(layout = tibble(x = V(.)$x, y = V(.)$y)) +
  geom_edge_link(width = .1) +
  geom_node_point(aes(size = degree, color = cluster)) +
  coord_fixed() +
  theme(legend.position = "none")
Wouldn’t it be interesting to visualize the reach of a hub node? Let’s do it with a facet plot:
# Copy of tbl_graph with columns that indicate whether a node is within n-reach of the hub.
reach_graph <- function(n) {
  tbl_graph %>%
    activate(nodes) %>%
    mutate(
      reach = n,
      reachable = ifelse(hub_dist <= n, "reachable", "non_reachable"),
      reachable = ifelse(hub_dist == 0, "Hub", reachable)
    )
}
# Tidygraph allows to bind graphs. This means binding rows of the node and edge dataframes.
evolving_graph <- bind_graphs(reach_graph(0), reach_graph(1), reach_graph(2), reach_graph(3))

evol <- evolving_graph %>%
  ggraph(layout = tibble(x = V(.)$x, y = V(.)$y)) +
  geom_edge_link(width = .1, alpha = .2) +
  geom_node_point(aes(size = degree, color = reachable)) +
  scale_size(range = c(.5, 2)) +
  scale_color_manual(values = c("Hub" = "firebrick4",
                                "non_reachable" = "#00BFC4",
                                "reachable" = "#F8766D")) +
  coord_fixed() +
  facet_nodes(~reach, ncol = 4, nrow = 1, labeller = label_both) +
  theme(legend.position = "none")

A curious observation

At this point, there are many graph properties (including the three above but also cluster sizes and graph distances) that are worth taking a closer look at, but this is beyond the scope of this blogpost. However, let’s look at one last thing. Somebody just recently told me about a very curious fact about social networks that seems paradoxical at first: Your average friend on Facebook (or Instagram) has way more friends than the average user of that platform. It sounds odd, but if you think about it for a second, it is not too surprising. Sampling from the pool of your friends is very different from sampling from all users on the platform (entirely at random). It’s exactly those very prominent people who have a much higher probability of being among your friends. Hence, when calculating the two averages, we receive very different results.
As can be seen, the model also reflects that property: In the small excerpt of the graph that we simulate, the average node has a degree of around 5 (blue intercept). The degree of connected nodes is over 10 on average (red intercept).
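One way to check this on the simulated graph is to compare the mean degree over all nodes with the mean degree of nodes' neighbors (which is what local_ave_degree() computes per node); this short sketch reproduces the flavor of the two intercepts, although the exact numbers depend on the random draw:
deg <- degree(graph)
mean(deg)                              # "average user"

# igraph::knn() returns, for every node, the mean degree of its neighbours
mean(knn(graph)$knn, na.rm = TRUE)     # "average friend", typically larger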

Conclusion

In the first part, I introduced a model that describes the features of real-life social network data well. In the second part, we obtained artificial data from that model and used it to create an igraph object (by means of the adjacency matrix). The latter can then be transformed into a tidygraph object, allowing us to easily manipulate the node and edge tibbles and calculate any graph statistic (e.g., the degree) we like. Further, the tidygraph object is then used for conveniently visualizing the network through ggraph. I hope that this post has sparked your interest in network modeling and has given you an idea of how seamlessly graph manipulation and visualization with tidygraph and ggraph merge into the usual tidyverse workflow. Have a wonderful day!

In my previous blog post, I showed you how to run your R-scripts inside a Docker container. For many of the projects we work on here at STATWORX, we end up using the RShiny framework to build our R-scripts into interactive applications. Using containerization for the deployment of ShinyApps has a multitude of advantages. There are the usual suspects such as easy cloud deployment, scalability, and easy scheduling, but it also addresses one of RShiny’s essential drawbacks: Shiny creates only a single R session per app, meaning that if multiple users access the same app, they all work with the same R session, leading to a multitude of problems. With the help of Docker, we can address this issue and start a container instance for every user, circumventing this problem by giving every user access to their own instance of the app and their individual corresponding R session. If you’re not familiar with building R-scripts into a docker image or with Docker terminology, I would recommend you first read my previous blog post. So let’s move on from simple R-scripts and run entire ShinyApps in Docker now!

The Setup

Setting up a project

It is highly advisable to use RStudio’s project setup when working with ShinyApps, especially when using Docker. Not only do projects make it easy to keep your RStudio neat and tidy, but they also allow us to use the renv package to set up a package library for our specific project. This will come in especially handy when installing the needed packages for our app to the Docker image. For demonstration purposes, I decided to use an example app created in a previous blog post, which you can clone from the STATWORX GitHub repository. It is located in the “example-app” subfolder and consists of the three typical scripts used by ShinyApps (global.R, ui.R, and server.R) as well as files belonging to the renv package library. If you choose to use the example app linked above, you won’t have to set up your own RStudio project; you can instead open “example-app.Rproj”, which opens the project context I have already set up. If you choose to work along with an app of your own and haven’t created a project for it yet, you can instead set up your own by following the instructions provided by RStudio.

Setting up a package library

The RStudio project I provided already comes with a package library stored in the renv.lock file. If you prefer to work with your own app, you can create your own renv.lock file by installing the renv package from within your RStudio project and executing renv::init(). This initializes renv for your project and creates a renv.lock file in your project root folder. You can find more information on renv over at RStudio’s introduction article on it.
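A minimal sketch of that initialization workflow from within the project (the later snapshot call is an assumption about a typical workflow, not part of the example repository):
install.packages("renv")
renv::init()       # sets up a project library and writes an initial renv.lock
# later, after installing or updating packages for the app:
renv::snapshot()   # records the current library state in renv.lock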

The Dockerfile

The Dockerfile is once again the central piece of creating a Docker image. Where we previously built only a single script into an image, we now aim to repeat this process for an entire app. The step from a single script to a folder with multiple scripts is small, but there are some significant changes needed to make our app run smoothly.
# Base image https://hub.docker.com/u/rocker/
FROM rocker/shiny:latest

# system libraries of general use
## install debian packages
RUN apt-get update -qq && apt-get -y --no-install-recommends install \
    libxml2-dev \
    libcairo2-dev \
    libsqlite3-dev \
    libmariadbd-dev \
    libpq-dev \
    libssh2-1-dev \
    unixodbc-dev \
    libcurl4-openssl-dev \
    libssl-dev

## update system libraries
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get clean

# copy necessary files
## app folder
COPY /example-app ./app
## renv.lock file
COPY /example-app/renv.lock ./renv.lock

# install renv & restore packages
RUN Rscript -e 'install.packages("renv")'
RUN Rscript -e 'renv::consent(provided = TRUE)'
RUN Rscript -e 'renv::restore()'

# expose port
EXPOSE 3838

# run app on container start
CMD ["R", "-e", "shiny::runApp('/app', host = '0.0.0.0', port = 3838)"]

The base image

The first difference is in the base image. Because we’re dockerizing a ShinyApp here, we can save ourselves a lot of work by using the rocker/shiny base image. This image handles the necessary dependencies for running a ShinyApp and comes with multiple R packages already pre-installed.

Necessary files

It is necessary to copy all relevant scripts and files for your app to your Docker image, so the Dockerfile does precisely that by copying the entire folder containing the app to the image. We can also make use of renv to handle package installation for us. This is why we first copy the renv.lock file to the image separately. We also need to install the renv package separately by using the Dockerfile’s ability to execute R-code by prefacing it with RUN Rscript -e. This package installation allows us to then call renv directly and restore our package library inside the image with renv::restore(). Now our entire project package library will be installed in our Docker image, with the exact same version and source of all the packages as in your local development environment. All this with just a few lines of code in our Dockerfile.

Starting the App at Runtime

At the very end of our Dockerfile, we tell the container to execute the following R-command:
shiny::runApp('/app', host = '0.0.0.0', port = 3838)
The first argument allows us to specify the file path to our scripts, which in our case is /app. For the exposed port, I have chosen 3838, as this is the default for Shiny Server, but it can be freely changed to whatever suits you best. With the final command in place, every container based on this image will start the app in question automatically at runtime (and of course close it again once it’s been terminated).

The Finishing Touches

With the Dockerfile set up we’re now almost finished. All that remains is building the image and starting a container of said image.

Building the image

We open the terminal, navigate to the folder containing our new Dockerfile, and start the building process:
docker build -t my-shinyapp-image . 

Starting a container

After the building process has finished, we can now test our newly built image by starting a container:
docker run -d --rm -p 3838:3838 my-shinyapp-image
And there it is, running on localhost:3838.

Outlook

Now that you have your ShinyApp running inside a Docker container, it is ready for deployment! Having containerized our app already makes this process a lot easier; there are further tools we can employ to ensure state-of-the-art security, scalability, and seamless deployment. Stay tuned until next time, when we’ll go deeper into the full range of RShiny and Docker capabilities by introducing ShinyProxy.

As a data scientist, it is always tempting to focus on the newest technology, the latest release of your favorite deep learning network, or a fancy statistical test you recently heard of. While all of this is very important, and we here at STATWORX are proud to use the latest open-source machine learning tools, it is often more important to take a step back and have a closer look at the problem we want to solve. In this article, I want to show you the importance of framing your business question in a different way – the data science way. Once the problem is clearly defined, we are more than happy to apply the newest fancy algorithm. But let’s start from the beginning!

The Initial Problem

Management View

Let’s assume for a moment that you are a data scientist here at STATWORX. Monday morning, at 10 o’clock, the telephone rings, and a manager of an international bank is on the phone. After a bit of back and forth, the bank manager explains that they have a problem with defaulting loans and need a program that predicts which loans are going to default in the future. Unfortunately, he must end the call now, but he’ll catch up with you later. In the meantime, you start to make sense of the problem.

Data Scientist View

While the bank manager is convinced he has given you all the necessary information, you grab another cup of coffee, lean back in your chair, and recap the problem:
  • The bank lends money to customers today
  • The customer promises the bank to pay back the loan bit by bit over the next couple of months/years
  • Unfortunately, some of the customers are not able to do so and are going to default on the loan
So far, everything is fine. The bank will give you data from the past, and you are instructed to make a prediction. Fair enough, but what exactly are you supposed to predict? Do they need to know whether every single loan is going to default or not? Are they more concerned about the default trend across the whole bank?

Data Science Explanation

From a data science perspective, we differentiate between two sorts of problems: Classification and Regression tasks. The way we prepare the data and the models we apply are inherently different between the two. Classification problems, as the name suggests, assign data points to a specific category. For bank loans, one approach could be to construct two categories:
  • The loan defaulted
  • The loan is still performing
On the other hand, the output of a Regression problem is a continuous variable. In this case, this could be:
  • The percentage of loans which are going to default in a given month
  • The total amount of money the bank will lose in a given month
From here on, it’s paramount to clarify with the client which problem they actually want to solve. While it’s a lot of fun to play around with the best tech stack, it is of the highest importance never to forget about the business needs of the client. I’ll present two possible scenarios, one for the classification and one for the regression case.
Regression and Classification Task

Scenario Classification Problem

Management View

For the next day, you set up a phone conference with the manager and decision-makers of the bank to discuss the overall direction of the project. The management board of the bank decided that it is more important to focus on the default prediction of single loans than on the overall default trend. Now you know that you have to solve a classification problem. Further, you ask the board what exactly they expect from the model:
  • Manager A: I want to have the best-performing model possible!
  • Manager B: As long as it predicts reality as accurately as possible, I’m happy 🙂
  • Manager C: As long as it catches every defaulted loan for sure…
  • Manager A: … but of course, it should not misclassify too many loans!

Data Scientist View

You try to match every requirement from the bank. Understandably, the bank wants a perfect model that makes few or no mistakes. Unfortunately, there will always be errors, and you are still unsure which kind of error is worse for the bank. To properly continue your work, it is important to define with the client which problem exactly to solve and, therefore, which error to minimize. Some options could be:
  • Catch every loan that will default
  • Make sure the model does not classify a performing loan as a defaulted loan
  • Some kind of weighted average between both of them
Have a look at the right chart above to see what this could look like.

Data Science Explanation

To generate predictions, you have to train a model on the given data. To tell the model how well it performed and to penalize it for mistakes, it is necessary to define an error metric. The choice of the error metric always depends on the business case. From a technical point of view, it is possible to model nearly every business case; however, four metrics are used in most classification problems.

    \[Accuracy = \frac{\#\ \text{correctly classified loans}}{\#\ \text{loans}}\]

This metric measures, as the name suggests, how accurately the model can predict the loan status. While this is the most basic metric one can think of, it’s also a dangerous one. Let’s say the bank tells us that roughly 5% of the loans on the balance sheet default. Suppose, for some reason, our model never predicts a default, i.e., it classifies every loan as non-defaulting. Its accuracy is then immediately 95/100 = 95%. For datasets where the classes are highly imbalanced, it is usually a good idea to discard accuracy.
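
To make this pitfall concrete, here is a small sketch in R with made-up numbers (1,000 loans, 5% of which default):

# A model that never predicts a default still reaches 95 % accuracy on imbalanced data.
actual    <- c(rep("default", 50), rep("performing", 950))
predicted <- rep("performing", 1000)
accuracy  <- mean(predicted == actual)
accuracy  # 0.95, even though not a single defaulted loan was caught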

    \[Recall = \frac{\#\ \text{correctly classified defaulted loans}}{\#\ \text{loans that actually defaulted}}\]

Optimizing the machine learning algorithm for recall ensures that it catches as many defaulted loans as possible. On the flip side, an algorithm that correctly flags every defaulted loan often achieves this by predicting too many loans as defaults: many loans that are not going to default get flagged as well.

    \[Precision = \frac{\#\ \text{correctly classified defaulted loans}}{\#\ \text{loans predicted as default}}\]

High precision ensures that the loans the algorithm flags as defaults are indeed classified correctly. This comes at the expense of the overall number of loans that are flagged as defaults. It might therefore not be possible to flag every loan that is going to default, but the loans that are flagged as defaults are very likely to actually default.

    \[F_\beta = (1+\beta^2) \cdot \frac{Precision \cdot Recall}{\beta^2 \cdot Precision + Recall}\]

Empirically, an increase in recall is almost always associated with a decrease in precision, and vice versa. Often, one wants to balance precision and recall; this can be done with the F-beta score, where the parameter beta controls how much more weight recall receives relative to precision (beta = 1 yields the familiar F1 score).
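
To illustrate how recall, precision, and the F-beta score relate, here is a minimal R sketch on toy labels; the data and the choice of beta are purely illustrative:

# Toy labels: 3 defaulted loans, 5 performing loans.
actual    <- c("default", "default", "default", "performing", "performing", "performing", "performing", "performing")
predicted <- c("default", "default", "performing", "default", "performing", "performing", "performing", "performing")

tp <- sum(predicted == "default" & actual == "default")     # true positives
fp <- sum(predicted == "default" & actual == "performing")  # false positives
fn <- sum(predicted == "performing" & actual == "default")  # missed defaults

recall    <- tp / (tp + fn)  # share of actual defaults that were caught
precision <- tp / (tp + fp)  # share of predicted defaults that were correct

beta   <- 1  # beta > 1 weights recall more heavily, beta < 1 weights precision more heavily
f_beta <- (1 + beta^2) * precision * recall / (beta^2 * precision + recall)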

Scenario Regression Problem

Management View

During the phone conference (the same one as in the classification scenario), the decision-makers from the bank announced that they want to predict the overall default trend. While that’s already important information, you evaluate with the client what exactly their business need is. In the end, you’ll have a list of requirements:
  • Manager A: It’s important to match the overall trend as closely as possible.
  • Manager B: During normal times, I won’t pay too much attention to the model. However, it is absolutely necessary that the model performs well in extreme market situations.
  • Manager C: To make it as easy and convenient to use as possible, and to be able to explain it to the regulating agency, it has to be as explainable as possible.

Data Science View

Similar to the last scenario, there is again a tradeoff, and it is a business decision which error is worse. Is every deviation from the ground truth equally bad? Is a certain stability of the prediction error important? Does the client care about the volatility of the forecast? Does a baseline exist? Have a look at the left chart above to see what this could look like.

Data Science Explanation

Once again, there are several metrics to choose from. The best metric always depends on the business need. Here are the most common ones:

    \[\text{Mean Absolute Error} = \frac{1}{n} \sum \left| \text{actual output} - \text{predicted output} \right|\]

The Mean Absolute Error (MAE) measures, as the name suggests, how far the predictions are off in absolute terms. While the number is easy to interpret, it treats every deviation the same way: over a 100-day interval, being off by 1 unit every day yields the same MAE as predicting every day perfectly except for a single day that is off by 100 units.

    \[\text{Mean Squared Error} = \frac{1}{n} \sum \left( \text{actual output} - \text{predicted output} \right)^2\]

The Mean Squared Error (MSE) also measures the difference between the actual and the predicted output, but this time the deviations are squared: a few extreme misses weigh more heavily than many small errors.
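
A quick R sketch with made-up error series shows the difference for the 100-day example mentioned above:

errors_small <- rep(1, 100)         # off by 1 unit every single day
errors_spike <- c(rep(0, 99), 100)  # perfect except for one day that is off by 100 units

mean(abs(errors_small))  # MAE = 1
mean(abs(errors_spike))  # MAE = 1  -> identical under MAE
mean(errors_small^2)     # MSE = 1
mean(errors_spike^2)     # MSE = 100 -> the single large miss dominates under MSE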

    \[R^2 = 1 - \frac{MSE(\text{model})}{MSE(\text{baseline})}\]

The R^2 compares the model under evaluation against a simple baseline model. The advantage is that the output is easy to interpret: a value of 1 describes a perfect model, while a value close to 0 (or even negative) describes a model with room for improvement. This metric is widely used among economists and econometricians and is therefore worth considering in some industries. However, it is also relatively easy to obtain a high R^2, which makes models hard to compare on this metric alone.

    \[\text{Mean Absolute Percentage Error} = \frac{1}{n} \sum \left| \frac{\text{actual output} - \text{predicted output}}{\text{actual output}} \right| \cdot 100\]

The Mean Absolute Percentage Error (MAPE) also measures the absolute deviations between predicted and actual values but, in contrast to the MAE, expresses them relative to the actual values, which makes it very easy to interpret and to compare. The MAPE has its own set of drawbacks and caveats; fortunately, my colleague Jan has already written an article about them. Check it out here if you want to learn more.
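
To close the section, a short R sketch with invented monthly default rates shows how R^2 (against a naive mean baseline) and MAPE can be computed for the same forecast:

actual    <- c(4.8, 5.1, 5.0, 6.2, 5.4, 5.9)  # observed default rates in percent (made up)
predicted <- c(5.0, 5.0, 5.2, 5.8, 5.5, 6.1)  # hypothetical model forecasts

mse_model    <- mean((actual - predicted)^2)
mse_baseline <- mean((actual - mean(actual))^2)  # baseline: always predict the historical mean
r_squared    <- 1 - mse_model / mse_baseline

mape <- mean(abs((actual - predicted) / actual)) * 100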

Conclusion

Whether we are in the classification or the regression case, the “right” answer depends on how the problem is actually defined. Before applying the latest machine learning algorithm, it is crucial that the business question is well defined. A strong collaboration with the client team is the key to achieving the best result, and there is no one-size-fits-all data science solution. Even though the underlying problem is the same for every stakeholder in the bank, it might be worth training a separate model for each department. It all boils down to the business needs! We still haven’t covered several other questions that might arise in subsequent steps:
  • How is the default of a loan defined?
  • What is the prediction horizon?
  • Do we have enough data to cover all business cycles?
  • Is the model only used internally, or do we have to explain it to a regulating agency?
  • Should we optimize the model for some kind of internal resource constraints?
To discuss this and more, feel free to reach out to me at dominique.lade@statworx.com or send me a message via LinkedIn.