1. Business Understanding¶
Kundenabwanderung ist die Entscheidung eines Kunden, eine bestimmte Unternehmensdienstleistung nicht mehr zu kaufen. Sie stellt somit das Gegenstück zur langfristigen Kundenbindung dar. Um die Kundenbindung zu fördern, müssen Unternehmen Analysen einsetzen, die frühzeitig erkennen, ob ein Kunde das Unternehmen verlassen will. So können Marketing- und Vertriebsmaßnahmen eingeleitet werden, bevor es zum eigentlichen Kundenverlust kommt. In diesem Zusammenhang beantwortet der Service konkret diese beiden Fragen: Wie hoch ist die Wahrscheinlichkeit, dass anhand historischer Daten vorhergesagt werden kann, ob ein Kunde zu einem anderen Anbieter abwandert? Welche Faktoren führen zur Kundenabwanderung?
2. Daten und Datenverständnis¶
Zur Visualisierung und Implementierung des Dienstes wird der Datensatz eines fiktiven Telekommunikationsunternehmens verwendet. Dieser besteht aus 7.043 Zeilen. Jede Zeile beschreibt einen Kunden mit 21 Spalten. Jede Spalte definiert verschiedene Merkmale (Attribute) der Kunden. Anhand der Daten soll klassifiziert werden, ob ein Kunde das Unternehmen verlässt oder nicht. Zu diesem Zweck enthalten die historischen Daten die Zielvariable "Churn", die Auskunft darüber gibt, ob ein Kunde abgewandert ist oder nicht.ot.
2.1. Import von relevanten Modulen¶
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import warnings
import imblearn
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from imblearn.under_sampling import InstanceHardnessThreshold
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
sns.set()
# remove warnings
warnings.filterwarnings('ignore')
2.2. Daten einlesen¶
data_raw = pd.read_csv("https://storage.googleapis.com/ml-service-repository-datastorage/Customer_Churn_Prediction_data.csv")
data_raw.head()
customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.5 | No |
2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
5 rows × 21 columns
data_raw.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 7043 entries, 0 to 7042 Data columns (total 21 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 customerID 7043 non-null object 1 gender 7043 non-null object 2 SeniorCitizen 7043 non-null int64 3 Partner 7043 non-null object 4 Dependents 7043 non-null object 5 tenure 7043 non-null int64 6 PhoneService 7043 non-null object 7 MultipleLines 7043 non-null object 8 InternetService 7043 non-null object 9 OnlineSecurity 7043 non-null object 10 OnlineBackup 7043 non-null object 11 DeviceProtection 7043 non-null object 12 TechSupport 7043 non-null object 13 StreamingTV 7043 non-null object 14 StreamingMovies 7043 non-null object 15 Contract 7043 non-null object 16 PaperlessBilling 7043 non-null object 17 PaymentMethod 7043 non-null object 18 MonthlyCharges 7043 non-null float64 19 TotalCharges 7043 non-null object 20 Churn 7043 non-null object dtypes: float64(1), int64(2), object(18) memory usage: 1.1+ MB
Der Datensatz besteht aus 7.043 Zeilen und 21 Attributen:
Zu prognostizierendes Attribut: Abwanderung
Numerische Attribute: Vertragsdauer, Monatsgebühren und Gesamtgebühren.
Kategorische Attribute: CustomerID, Gender, SeniorCitizen, Partner, Angehörige, PhoneService, MultipleLines, InternetService, OnlineSecurity, OnlineBackup, DeviceProtection, TechSupport, StreamingTV, StreamingMovies, Contract, PaperlessBilling, PaymentMethod. Es wurden nicht alle Datentypen korrekt eingelesen:
TotalCharges muss ein numerischer Wert sein -> in Float umwandeln
# test for duplicates
data_raw[data_raw.duplicated(keep=False)]
customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
---|
0 rows × 21 columns
Keine Duplikate im Datensatz
2.3. Datenbereinigung¶
Hier sollten die ersten Lesefehler korrigiert werden, bevor die eigentliche Datenaufbereitung erfolgt.
# convert total charges
data_raw['TotalCharges'] = pd.to_numeric(data_raw['TotalCharges'], errors='coerce')
data_raw.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 7043 entries, 0 to 7042 Data columns (total 21 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 customerID 7043 non-null object 1 gender 7043 non-null object 2 SeniorCitizen 7043 non-null int64 3 Partner 7043 non-null object 4 Dependents 7043 non-null object 5 tenure 7043 non-null int64 6 PhoneService 7043 non-null object 7 MultipleLines 7043 non-null object 8 InternetService 7043 non-null object 9 OnlineSecurity 7043 non-null object 10 OnlineBackup 7043 non-null object 11 DeviceProtection 7043 non-null object 12 TechSupport 7043 non-null object 13 StreamingTV 7043 non-null object 14 StreamingMovies 7043 non-null object 15 Contract 7043 non-null object 16 PaperlessBilling 7043 non-null object 17 PaymentMethod 7043 non-null object 18 MonthlyCharges 7043 non-null float64 19 TotalCharges 7032 non-null float64 20 Churn 7043 non-null object dtypes: float64(2), int64(2), object(17) memory usage: 1.1+ MB
Die Konvertierung der TotalCharges hat zu Nullwerten geführt. Diese Nullwerte müssen korrigiert werden.
# Remove zero values
# axis = 0 rows / axis = 1 columns
data_no_mv = data_raw.dropna(axis=0)
2.4. Deskriptive Analytik¶
In diesem Teil des Notebooks soll das Datenverständnis mit Hilfe der deskriptiven Analytik berücksichtigt werden. Nach dem Entfernen der Nullwerte besteht der Datensatz aus 7032 Zeilen, von denen jeweils eine einen Kunden beschreibt, und 21 Spalten, die die Attribute des Kunden definieren. Mit Hilfe dieser Daten soll versucht werden, zu klassifizieren, ob ein Kunde abwandert oder nicht. Zu diesem Zweck enthalten die historischen Daten die Zielvariable "Churn", die Auskunft darüber gibt, ob ein Kunde abgewandert ist.
2.4.1. Kontinuierliche Merkmale¶
Zunächst werden die Verteilungen der kontinuierlichen Merkmale einzeln untersucht und in einem zweiten Schritt die kategorialen Merkmale in Zusammenhang mit der Zielvariablen gesetzt.
# load continous features
numeric_data = data_no_mv.select_dtypes(include=[np.number])
Besitz¶
sns.displot(numeric_data["tenure"])
<seaborn.axisgrid.FacetGrid at 0x23e942418e0>
- Keine Normalverteilung erkennbar.
- Keine Ausreißer erkennbar.
- Kunden sind potentiell gleichmäßig über die einzelnen Monate verteilt, aber eine große Anzahl von Kunden ist noch nicht lange im Unternehmen.
sns.distplot(data_no_mv[data_no_mv.Churn == 'No']["tenure"],
bins=10,
color='orange',
label='Non-Churn',
kde=True)
sns.distplot(data_no_mv[data_no_mv.Churn == 'Yes']["tenure"],
bins=10,
color='blue',
label='Churn',
kde=True)
<AxesSubplot:xlabel='tenure', ylabel='Density'>
Kunden, die noch nicht lange bei dem Unternehmen sind, werden eher abwandern als langjährige Kunden.
Monatliche Kosten¶
sns.distplot(numeric_data["MonthlyCharges"])
<AxesSubplot:xlabel='MonthlyCharges', ylabel='Density'>
- Es ist keine Normalverteilung erkennbar.
- Die meisten Kunden befinden sich im vorderen Teil der Verteilung und zahlen relativ niedrige monatliche Gebühren.
- Dennoch verläuft die Kurve gleichmäßig mit einem erneuten Anstieg nach hinten und dementsprechend können keine Ausreißer identifiziert werden.
sns.distplot(data_no_mv[data_no_mv.Churn == 'No']["MonthlyCharges"],
bins=10,
color='orange',
label='Non-Churn',
kde=True)
sns.distplot(data_no_mv[data_no_mv.Churn == 'Yes']["MonthlyCharges"],
bins=10,
color='blue',
label='Churn',
kde=True)
<AxesSubplot:xlabel='MonthlyCharges', ylabel='Density'>
- Kunden mit niedrigen monatlichen Gebühren sind eher abwanderungsbereit.
- Der Abwanderungstrend zwischen Kunden, die abwandern, und Kunden, die nicht abwandern, gleicht sich an, wenn die monatlichen Gebühren steigen.
Gesamtkosten¶
sns.distplot(numeric_data["TotalCharges"])
<AxesSubplot:xlabel='TotalCharges', ylabel='Density'>
- Die Kurve flacht nach hinten hin extrem stark ab.
- Es sind Ähnlichkeiten mit der Exponentialverteilung zu erkennen. -> Test der logarithmischen Transformation zur Erreichung einer Normalverteilung.
- Es ist fraglich, ob es Ausreißer im hinteren Teil gibt. -> Boxplot
# Boxplot für TotalCharges erstellen, um sicherzustellen, dass keine Ausreißer vorhanden sind.
plt.boxplot(numeric_data["TotalCharges"])
plt.show()
- Boxplot zeigt keine Ausreißer.
- Dies bedeutet, dass auch bei den Gesamtkosten keine Ausreißer festgestellt werden können.
# logarithmic transformation
log_charges = np.log(data_no_mv["TotalCharges"])
sns.distplot(log_charges)
<AxesSubplot:xlabel='TotalCharges', ylabel='Density'>
- Auch die Transformation mit Hilfe des Logarithmus führt nicht zu einer Normalverteilung.
- Vor weiteren Transformationen sollte zunächst die Korrelation mit anderen Variablen untersucht werden.
sns.distplot(data_no_mv[data_no_mv.Churn == 'No']["TotalCharges"],
bins=10,
color='orange',
label='Non-Churn',
kde=True)
sns.distplot(data_no_mv[data_no_mv.Churn == 'Yes']["TotalCharges"],
bins=10,
color='blue',
label='Churn',
kde=True)
<AxesSubplot:xlabel='TotalCharges', ylabel='Density'>
Die Verteilung ist über die gesamte Bandbreite der Kosten sowohl bei den abwandernden als auch bei den nicht abwandernden Kunden nahezu identisch.
Korrelationsanalyse¶
# correlation between continous features
feature_corr = numeric_data.drop("SeniorCitizen", axis=1).corr()
sns.heatmap(feature_corr, annot=True, cmap='coolwarm')
<AxesSubplot:>
Die Korrelationsmatrix zeigt, dass die Attribute "Tenure" und "TotalCharges" eine kritische positive Korrelation von über 0,8 aufweisen. Diese Beziehung wird später im Zusammenhang mit der Multikollinearität erneut untersucht und muss entfernt werden.
Streudiagramme mit kontinuierlichen Merkmalen und Ziel¶
sns.scatterplot(data=data_no_mv, x="tenure", y="MonthlyCharges", hue="Churn")
<AxesSubplot:xlabel='tenure', ylabel='MonthlyCharges'>
Das Streudiagramm deutet darauf hin, dass Kunden im oberen linken Bereich, d. h. Kunden mit hohen monatlichen Kosten und kurzer Betriebszugehörigkeit, am ehesten abwandern.
sns.scatterplot(data=data_no_mv, x="tenure", y="TotalCharges", hue="Churn")
<AxesSubplot:xlabel='tenure', ylabel='TotalCharges'>
Es besteht eine rein logische, lineare Beziehung zwischen der Dauer der Betriebszugehörigkeit und den in Rechnung gestellten Gesamtkosten. Je länger eine Person Kunde ist, desto mehr monatliche Beträge musste sie bereits zahlen.
2.4.2. Kategorische Merkmale¶
Abwanderung (Ziel)¶
Zunächst wird die Verteilung der Zielvariablen Churn untersucht.
# produce pie chart for churn
# generate procentage relationship
churn_rate = data_no_mv.Churn.value_counts() / len(data_no_mv.Churn)
# Plot
labels = 'Keine Abwanderung', 'Abwanderung'
fig, ax = plt.subplots()
ax.pie(churn_rate, labels=labels, autopct='%.f%%')
ax.set_title('Abwanderung im Vergleich zur Nicht-Abwanderung')
Text(0.5, 1.0, 'Abwanderung im Vergleich zur Nicht-Abwanderung')
- Die Abwanderungen machen etwa 27 % des gesamten Datensatzes aus, während die Nicht-Abwanderungen etwa 73 % ausmachen.
- Dies ist ein unausgewogener Datensatz und eine andere Metrik muss in der Bewertungsphase verwendet werden.
Geschlecht¶
sns.countplot(x="gender", hue="Churn", data=data_no_mv)
plt.show()
Die Abwanderungsrate zwischen Männern und Frauen ist ungefähr gleich hoch.
Senioren¶
sns.countplot(x="SeniorCitizen", hue="Churn", data=data_no_mv)
plt.show()
Bei Kunden, die als Senioren eingestuft werden, ist die Wahrscheinlichkeit höher, dass sie abwandern.
Partner¶
sns.countplot(x="Partner", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die keinen Partner haben, sind eher bereit, abzuwandern.
Angehörige¶
sns.countplot(x="Dependents", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die Verwandte haben, sind eher bereit, abzuwandern.
Mehrere Anschlüsse¶
sns.countplot(x="MultipleLines", hue="Churn", data=data_no_mv)
plt.show()
Bei Kunden, die mehrere Anschlüsse haben, ist die Wahrscheinlichkeit einer Abwanderung geringer.
Internet Service¶
sns.countplot(x="InternetService", hue="Churn", data=data_no_mv)
plt.show()
Wenn ein Kunde einen Glasfaseranschluss hat, ist es wahrscheinlicher, dass er ausfällt als ein Kunde mit DSL.
Online-Sicherheit¶
sns.countplot(x="OnlineSecurity", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die den Internet-Sicherheitsdienst nicht nutzen, werden eher abwandern.
Online Backup¶
sns.countplot(x="OnlineBackup", hue="Churn", data=data_no_mv)
plt.show()
Personen, die keine Online-Datensicherung nutzen, sind eher bereit, umzuziehen.
Geräteschutz¶
sns.countplot(x="DeviceProtection", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die keinen zusätzlichen Geräteschutz erworben haben, werden mit größerer Wahrscheinlichkeit migrieren.
Technischer Support¶
sns.countplot(x="TechSupport", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die keinen technischen Support in Anspruch nehmen, werden eher abwandern.
Streaming-TV/ Streaming-Filme¶
for col in ["StreamingTV", "StreamingMovies"]:
sns.countplot(x=col, hue='Churn', data=data_no_mv)
plt.show()
Die Hinzunahme von Film- und TV-Streaming-Angeboten hat kaum Auswirkungen auf die Abwanderungsrate.
Papierlose Abrechnung¶
sns.countplot(x="PaperlessBilling", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die ohne Rechnung bezahlen, werden eher abwandern.
Zahlungsmethode¶
sns.countplot(x="PaymentMethod", hue="Churn", data=data_no_mv)
plt.show()
Kunden, die mit elektronischen Schecks bezahlen, wandern deutlich häufiger ab als Kunden, die eine andere Zahlungsmethode verwenden.
Vertrag¶
sns.countplot(x="Contract", hue="Churn", data=data_no_mv)
plt.show()
Bei Kunden mit kurzfristigen Verträgen ist die Wahrscheinlichkeit größer, dass sie das Unternehmen verlassen, als bei Kunden mit längerfristigen Verträgen.
3. Aufbereitung der Daten¶
3.1. Reduzieren der Kunden-ID¶
# Removing the Customer ID, it does not add value to the model
data_prep = data_no_mv.drop("customerID", axis = 1)
3.2. Umkodierung der kategorialen Variablen¶
# Convert binary variables to 1 and 0 with Yes and No
bin_var = ["Partner","Dependents","PhoneService","PaperlessBilling","Churn"]
def binaer_umwandeln(x):
return x.map({'Yes':1,'No':0})
data_prep[bin_var]=data_prep[bin_var].apply(binaer_umwandeln)
data_prep.head()
gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | OnlineBackup | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Female | 0 | 1 | 0 | 1 | 0 | No phone service | DSL | No | Yes | No | No | No | No | Month-to-month | 1 | Electronic check | 29.85 | 29.85 | 0 |
1 | Male | 0 | 0 | 0 | 34 | 1 | No | DSL | Yes | No | Yes | No | No | No | One year | 0 | Mailed check | 56.95 | 1889.50 | 0 |
2 | Male | 0 | 0 | 0 | 2 | 1 | No | DSL | Yes | Yes | No | No | No | No | Month-to-month | 1 | Mailed check | 53.85 | 108.15 | 1 |
3 | Male | 0 | 0 | 0 | 45 | 0 | No phone service | DSL | Yes | No | Yes | Yes | No | No | One year | 0 | Bank transfer (automatic) | 42.30 | 1840.75 | 0 |
4 | Female | 0 | 0 | 0 | 2 | 1 | No | Fiber optic | No | No | No | No | No | No | Month-to-month | 1 | Electronic check | 70.70 | 151.65 | 1 |
# create dummies
data_enc = pd.get_dummies(data_prep, drop_first=True)
data_enc.head()
SeniorCitizen | Partner | Dependents | tenure | PhoneService | PaperlessBilling | MonthlyCharges | TotalCharges | Churn | gender_Male | ... | TechSupport_Yes | StreamingTV_No internet service | StreamingTV_Yes | StreamingMovies_No internet service | StreamingMovies_Yes | Contract_One year | Contract_Two year | PaymentMethod_Credit card (automatic) | PaymentMethod_Electronic check | PaymentMethod_Mailed check | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 1 | 0 | 1 | 0 | 1 | 29.85 | 29.85 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
1 | 0 | 0 | 0 | 34 | 1 | 0 | 56.95 | 1889.50 | 0 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
2 | 0 | 0 | 0 | 2 | 1 | 1 | 53.85 | 108.15 | 1 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
3 | 0 | 0 | 0 | 45 | 0 | 0 | 42.30 | 1840.75 | 0 | 1 | ... | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
4 | 0 | 0 | 0 | 2 | 1 | 1 | 70.70 | 151.65 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
5 rows × 31 columns
# Dropping of dummies that also contain No phone service and No Internet service
dup_variables = ["OnlineSecurity_No internet service","OnlineBackup_No internet service", "TechSupport_No internet service","StreamingTV_No internet service","StreamingMovies_No internet service", "DeviceProtection_No internet service","MultipleLines_No phone service"]
data_enc.drop(dup_variables, axis=1, inplace=True)
3.3. Test auf Multikollinearität¶
Um ein korrektes Funktionieren der späteren Regression zu gewährleisten, darf keine Multikollinearität zwischen den Variablen bestehen. Das Vorhandensein einer solchen wird mit Hilfe der Bibliothek Statsmodel überprüft.
# independent variables
vif_test = data_enc.drop("Churn", axis=1)
# VIF dataframe
vif_data = pd.DataFrame()
vif_data["feature"] = vif_test.columns
# VIF for each Feature
vif_data["VIF"] = [variance_inflation_factor(vif_test.values, i)
for i in range(len(vif_test.columns))]
print(vif_data)
feature VIF 0 SeniorCitizen 1.376564 1 Partner 2.824725 2 Dependents 1.969391 3 tenure 20.482153 4 PhoneService 47.244378 5 PaperlessBilling 2.956951 6 MonthlyCharges 212.353073 7 TotalCharges 21.374002 8 gender_Male 2.021331 9 MultipleLines_Yes 2.861614 10 InternetService_Fiber optic 17.695260 11 InternetService_No 8.234451 12 OnlineSecurity_Yes 2.682712 13 OnlineBackup_Yes 2.909898 14 DeviceProtection_Yes 2.992570 15 TechSupport_Yes 2.758343 16 StreamingTV_Yes 4.928957 17 StreamingMovies_Yes 5.090603 18 Contract_One year 2.056188 19 Contract_Two year 3.487502 20 PaymentMethod_Credit card (automatic) 1.984196 21 PaymentMethod_Electronic check 2.955994 22 PaymentMethod_Mailed check 2.383290
"MonthlyCharges" hat den höchsten VIF und wird aus dem Datensatz entfernt.
data_enc.drop("MonthlyCharges", axis=1, inplace=True)
# the independent variables set
vif_test = data_enc.drop("Churn", axis=1)
# VIF dataframe
vif_data = pd.DataFrame()
vif_data["feature"] = vif_test.columns
# VIF for each Feature
vif_data["VIF"] = [variance_inflation_factor(vif_test.values, i)
for i in range(len(vif_test.columns))]
print(vif_data)
feature VIF 0 SeniorCitizen 1.366018 1 Partner 2.817414 2 Dependents 1.961947 3 tenure 17.073930 4 PhoneService 9.277446 5 PaperlessBilling 2.796488 6 TotalCharges 18.028499 7 gender_Male 1.942509 8 MultipleLines_Yes 2.514269 9 InternetService_Fiber optic 4.186492 10 InternetService_No 3.473225 11 OnlineSecurity_Yes 1.986701 12 OnlineBackup_Yes 2.182678 13 DeviceProtection_Yes 2.299462 14 TechSupport_Yes 2.099655 15 StreamingTV_Yes 2.749724 16 StreamingMovies_Yes 2.771330 17 Contract_One year 2.056169 18 Contract_Two year 3.468149 19 PaymentMethod_Credit card (automatic) 1.820729 20 PaymentMethod_Electronic check 2.535918 21 PaymentMethod_Mailed check 1.982063
"TotalCharges" hat den höchsten VIF und wird aus dem Datensatz entfernt.
data_enc.drop("TotalCharges", axis=1, inplace=True)
# the independent variables set
vif_test = data_enc.drop("Churn", axis=1)
# VIF dataframe
vif_data = pd.DataFrame()
vif_data["feature"] = vif_test.columns
# calculating VIF for each feature
vif_data["VIF"] = [variance_inflation_factor(vif_test.values, i)
for i in range(len(vif_test.columns))]
print(vif_data)
feature VIF 0 SeniorCitizen 1.363244 1 Partner 2.816895 2 Dependents 1.956413 3 tenure 7.530356 4 PhoneService 9.260839 5 PaperlessBilling 2.757816 6 gender_Male 1.931277 7 MultipleLines_Yes 2.426699 8 InternetService_Fiber optic 3.581328 9 InternetService_No 3.321342 10 OnlineSecurity_Yes 1.947904 11 OnlineBackup_Yes 2.093763 12 DeviceProtection_Yes 2.241375 13 TechSupport_Yes 2.060410 14 StreamingTV_Yes 2.636855 15 StreamingMovies_Yes 2.661529 16 Contract_One year 2.055971 17 Contract_Two year 3.456061 18 PaymentMethod_Credit card (automatic) 1.794059 19 PaymentMethod_Electronic check 2.401970 20 PaymentMethod_Mailed check 1.967082
Keine der Variablen hat jetzt einen VIF von mehr als 10.
3.4. Merkmalsskalierung¶
# Separate target variable and predictors
y = data_enc["Churn"]
X = data_enc.drop(labels = ["Churn"], axis = 1)
# Scaling the variables
num_features = ['tenure']
scaler = StandardScaler()
X[num_features] = scaler.fit_transform(X[num_features])
X.head()
SeniorCitizen | Partner | Dependents | tenure | PhoneService | PaperlessBilling | gender_Male | MultipleLines_Yes | InternetService_Fiber optic | InternetService_No | ... | OnlineBackup_Yes | DeviceProtection_Yes | TechSupport_Yes | StreamingTV_Yes | StreamingMovies_Yes | Contract_One year | Contract_Two year | PaymentMethod_Credit card (automatic) | PaymentMethod_Electronic check | PaymentMethod_Mailed check | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 1 | 0 | -1.280248 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
1 | 0 | 0 | 0 | 0.064303 | 1 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
2 | 0 | 0 | 0 | -1.239504 | 1 | 1 | 1 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
3 | 0 | 0 | 0 | 0.512486 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 1 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
4 | 0 | 0 | 0 | -1.239504 | 1 | 1 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
5 rows × 21 columns
3.5. Undersampling¶
iht = InstanceHardnessThreshold(random_state=0,estimator=LogisticRegression (solver='lbfgs', multi_class='auto'))
X_resampled, y_resampled = iht.fit_resample(X, y)
3.6. Erstellen von Test- & Trainingsdaten¶
# Split dataset in train and test datasets
# The default value of 80% to 20% is used.
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, random_state=110)
4. Modellierung und Auswertung¶
4.1. Logistische Regression¶
Zur Lösung des Problems wird die logistische Regression verwendet. Hierfür werden die beiden Bibliotheken Statsmodels und Scikit-Learn verwendet. Die komplette Auswertung des Modells findet erst im Unterkapitel zu Scikit-Learn statt.
Statistische Modelle¶
Training und Vorhersage¶
# add constant
X_const = sm.add_constant(X_train)
# create model
log_reg = sm.Logit(y_train, X_const).fit()
print(log_reg.summary())
Optimization terminated successfully. Current function value: 0.082006 Iterations 11 Logit Regression Results ============================================================================== Dep. Variable: Churn No. Observations: 2803 Model: Logit Df Residuals: 2781 Method: MLE Df Model: 21 Date: Thu, 21 Oct 2021 Pseudo R-squ.: 0.8817 Time: 15:00:28 Log-Likelihood: -229.86 converged: True LL-Null: -1942.4 Covariance Type: nonrobust LLR p-value: 0.000 ========================================================================================================= coef std err z P>|z| [0.025 0.975] --------------------------------------------------------------------------------------------------------- const 5.1912 0.828 6.266 0.000 3.567 6.815 SeniorCitizen 0.4609 0.457 1.008 0.313 -0.435 1.357 Partner -0.4112 0.302 -1.362 0.173 -1.003 0.181 Dependents -0.5746 0.294 -1.952 0.051 -1.151 0.002 tenure -2.9281 0.309 -9.468 0.000 -3.534 -2.322 PhoneService -1.2307 0.544 -2.261 0.024 -2.298 -0.164 PaperlessBilling 1.2621 0.288 4.385 0.000 0.698 1.826 gender_Male -0.1334 0.255 -0.524 0.600 -0.633 0.366 MultipleLines_Yes 1.0865 0.336 3.231 0.001 0.427 1.746 InternetService_Fiber optic 3.1681 0.400 7.916 0.000 2.384 3.952 InternetService_No -2.8314 0.567 -4.992 0.000 -3.943 -1.720 OnlineSecurity_Yes -1.7901 0.321 -5.581 0.000 -2.419 -1.161 OnlineBackup_Yes -0.3203 0.309 -1.036 0.300 -0.926 0.286 DeviceProtection_Yes 0.4336 0.331 1.312 0.190 -0.214 1.082 TechSupport_Yes -0.8710 0.329 -2.648 0.008 -1.516 -0.226 StreamingTV_Yes 1.1971 0.351 3.414 0.001 0.510 1.884 StreamingMovies_Yes 1.4263 0.374 3.815 0.000 0.693 2.159 Contract_One year -3.5720 0.488 -7.317 0.000 -4.529 -2.615 Contract_Two year -6.5206 0.584 -11.164 0.000 -7.665 -5.376 PaymentMethod_Credit card (automatic) -0.0720 0.313 -0.230 0.818 -0.686 0.542 PaymentMethod_Electronic check 1.2794 0.406 3.154 0.002 0.484 2.075 PaymentMethod_Mailed check -0.3240 0.398 -0.813 0.416 -1.105 0.457 ========================================================================================================= Possibly complete quasi-separation: A fraction 0.37 of observations can be perfectly predicted. This might indicate that there is complete quasi-separation. In this case some parameters will not be identified.
Das trainierte Modell zeigt statistisch nicht-signifikante Variablen an. Dies ist gegeben, wenn der Wert P>|z| größer als 0,05 ist und es sich nicht um die Konstante handelt.
# Removing the statistically non-significant features (P>|z|> 0.05)
insignificant_features = ["Partner", "gender_Male", "OnlineBackup_Yes", "DeviceProtection_Yes", "PaymentMethod_Credit card (automatic)","PaymentMethod_Mailed check"]
X_train.drop(insignificant_features, axis=1, inplace=True)
X_test.drop(insignificant_features, axis=1, inplace=True)
Nun kann ein zweites Modell erstellt werden:
# new model
X_const = sm.add_constant(X_train)
log_reg2 = sm.Logit(y_train, X_const).fit()
print(log_reg2.summary())
Optimization terminated successfully. Current function value: 0.083077 Iterations 11 Logit Regression Results ============================================================================== Dep. Variable: Churn No. Observations: 2803 Model: Logit Df Residuals: 2787 Method: MLE Df Model: 15 Date: Thu, 21 Oct 2021 Pseudo R-squ.: 0.8801 Time: 15:00:28 Log-Likelihood: -232.87 converged: True LL-Null: -1942.4 Covariance Type: nonrobust LLR p-value: 0.000 ================================================================================================== coef std err z P>|z| [0.025 0.975] -------------------------------------------------------------------------------------------------- const 4.7119 0.718 6.566 0.000 3.305 6.118 SeniorCitizen 0.3954 0.458 0.864 0.387 -0.501 1.292 Dependents -0.7328 0.262 -2.797 0.005 -1.246 -0.219 tenure -2.9242 0.297 -9.845 0.000 -3.506 -2.342 PhoneService -1.2073 0.540 -2.235 0.025 -2.266 -0.149 PaperlessBilling 1.2161 0.285 4.273 0.000 0.658 1.774 MultipleLines_Yes 1.0989 0.331 3.320 0.001 0.450 1.748 InternetService_Fiber optic 3.1159 0.391 7.966 0.000 2.349 3.883 InternetService_No -2.8462 0.529 -5.381 0.000 -3.883 -1.809 OnlineSecurity_Yes -1.7441 0.313 -5.576 0.000 -2.357 -1.131 TechSupport_Yes -0.8357 0.325 -2.569 0.010 -1.473 -0.198 StreamingTV_Yes 1.2193 0.348 3.508 0.000 0.538 1.901 StreamingMovies_Yes 1.4394 0.368 3.908 0.000 0.717 2.161 Contract_One year -3.4572 0.471 -7.337 0.000 -4.381 -2.534 Contract_Two year -6.3299 0.557 -11.372 0.000 -7.421 -5.239 PaymentMethod_Electronic check 1.3103 0.362 3.623 0.000 0.601 2.019 ================================================================================================== Possibly complete quasi-separation: A fraction 0.36 of observations can be perfectly predicted. This might indicate that there is complete quasi-separation. In this case some parameters will not be identified.
Keine statistisch nicht signifikanten Variablen mehr. Das endgültige Modell wurde modelliert:
# final model
X_const = sm.add_constant(X_train)
log_reg_final = sm.Logit(y_train, X_const).fit()
print(log_reg_final.summary())
Optimization terminated successfully. Current function value: 0.083077 Iterations 11 Logit Regression Results ============================================================================== Dep. Variable: Churn No. Observations: 2803 Model: Logit Df Residuals: 2787 Method: MLE Df Model: 15 Date: Thu, 21 Oct 2021 Pseudo R-squ.: 0.8801 Time: 15:00:28 Log-Likelihood: -232.87 converged: True LL-Null: -1942.4 Covariance Type: nonrobust LLR p-value: 0.000 ================================================================================================== coef std err z P>|z| [0.025 0.975] -------------------------------------------------------------------------------------------------- const 4.7119 0.718 6.566 0.000 3.305 6.118 SeniorCitizen 0.3954 0.458 0.864 0.387 -0.501 1.292 Dependents -0.7328 0.262 -2.797 0.005 -1.246 -0.219 tenure -2.9242 0.297 -9.845 0.000 -3.506 -2.342 PhoneService -1.2073 0.540 -2.235 0.025 -2.266 -0.149 PaperlessBilling 1.2161 0.285 4.273 0.000 0.658 1.774 MultipleLines_Yes 1.0989 0.331 3.320 0.001 0.450 1.748 InternetService_Fiber optic 3.1159 0.391 7.966 0.000 2.349 3.883 InternetService_No -2.8462 0.529 -5.381 0.000 -3.883 -1.809 OnlineSecurity_Yes -1.7441 0.313 -5.576 0.000 -2.357 -1.131 TechSupport_Yes -0.8357 0.325 -2.569 0.010 -1.473 -0.198 StreamingTV_Yes 1.2193 0.348 3.508 0.000 0.538 1.901 StreamingMovies_Yes 1.4394 0.368 3.908 0.000 0.717 2.161 Contract_One year -3.4572 0.471 -7.337 0.000 -4.381 -2.534 Contract_Two year -6.3299 0.557 -11.372 0.000 -7.421 -5.239 PaymentMethod_Electronic check 1.3103 0.362 3.623 0.000 0.601 2.019 ================================================================================================== Possibly complete quasi-separation: A fraction 0.36 of observations can be perfectly predicted. This might indicate that there is complete quasi-separation. In this case some parameters will not be identified.
# prediction
y_hat = log_reg_final.predict(sm.add_constant(X_test))
# Statsmodel only gives the probabilities, therefore rounding is required.
prediction = list(map(round, y_hat))
4.1. Auswertung¶
Zur Evaluation sollen mehrere Metriken verwendet werden, die komfortabler mittels Scikit-Learn erzeugt werden können. Deshalb wird das identische Modell wie mit Statsmodels nochmals in Scikit-Learn erzeugt.
Scikit-Learn¶
Training und Vorhersage¶
# C is needed to build the exact same model as with Statsmodels; source: https://www.kdnuggets.com/2016/06/regularization-logistic-regression.html
logistic_model = LogisticRegression(random_state=0, C=1e8)
# prediction with testdata
result = logistic_model.fit(X_train,y_train)
prediction_test = logistic_model.predict(X_test)
prediction_train = logistic_model.predict(X_train)
Evaluation¶
# Accuracy Score
acc = metrics.accuracy_score(y_test, prediction_test)
print('Accuracy with testdata: {}'.format(acc))
Accuracy with testdata: 0.9882352941176471
Die Genauigkeit deutet auf ein überdurchschnittliches Modell hin. Allerdings handelt es sich um einen unausgewogenen Datensatz. Daher müssen weitere Metriken analysiert werden.
# classification report
print("traindata:")
print(classification_report(y_train,prediction_train))
print("testdata:")
print(classification_report(y_test,prediction_test))
traindata: precision recall f1-score support 0 0.96 1.00 0.98 1374 1 1.00 0.96 0.98 1429 accuracy 0.98 2803 macro avg 0.98 0.98 0.98 2803 weighted avg 0.98 0.98 0.98 2803 testdata: precision recall f1-score support 0 0.98 1.00 0.99 495 1 1.00 0.98 0.99 440 accuracy 0.99 935 macro avg 0.99 0.99 0.99 935 weighted avg 0.99 0.99 0.99 935
Höhere Genauigkeit für das Training als für den Testdatensatz. Insgesamt sind die Werte für den Test- und den Trainingsdatensatz sehr ähnlich. Daher sollte nicht von einem Overfitting oder Underfitting ausgegangen werden.
# Confusion matrix testdata
cm = confusion_matrix(y_test,prediction_test)
df_cm = pd.DataFrame(cm, index=['No Churn','Churn'], columns=['No Churn', 'Churn'],)
fig = plt.figure(figsize=[10,7])
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=14)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=14)
plt.ylabel('True label')
plt.xlabel('Predicted label')
Text(0.5, 39.5, 'Predicted label')
# metrics from confusion matrix
tn, fp, fn, tp = cm.ravel()
recall = tp/(fn+tp)
precision = tp/(tp+fp)
print("True Negatives: " + str(tn))
print("False Positives: " + str(fp))
print("False Negatives: " + str(fn))
print("True Positives: " + str(tp))
print("Recall: " + str(recall))
print("Precision: " + str(precision))
True Negatives: 493 False Positives: 2 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 0.9953810623556582
Präzision und Recall vermitteln ein viel realistischeres Bild des Modells. Es erreicht eine Präzision von rund 68 % und eine Wiederauffindbarkeit von 52 %. Der Recall ist für den Anwendungsfall eindeutig wichtiger und muss daher auf Kosten der Präzision verbessert werden.
# ROC-Kurve, AUC
fig, ax = plt.subplots(figsize=(8,6))
ax.set_title('ROC Kurve')
plot = metrics.plot_roc_curve(logistic_model, X_test, y_test, ax=ax);
ax.plot([0,1], [0,1], '--');
Der AUC der ROC-Kurve ergibt einen guten Wert von 0,84. Daraus lässt sich schließen, dass durch die Optimierung des Schwellenwertes Optimierungspotenzial besteht.
4.3. Interpretation¶
Zunächst sollen jedoch die Ergebnisse für das Unternehmen veranschaulicht werden und es soll geklärt werden, welche Kunden zur Abwanderung führen und welche gegen eine Abwanderung sprechen.
# Read out regression coefficients and thus find out importance of individual attributes
weights = pd.Series(logistic_model.coef_[0],
index=X_train.columns.values)
weights.sort_values(ascending = False)
InternetService_Fiber optic 3.115901 StreamingMovies_Yes 1.439381 PaymentMethod_Electronic check 1.310265 StreamingTV_Yes 1.219198 PaperlessBilling 1.216093 MultipleLines_Yes 1.098867 SeniorCitizen 0.395488 Dependents -0.732812 TechSupport_Yes -0.835712 PhoneService -1.207319 OnlineSecurity_Yes -1.744166 InternetService_No -2.846463 tenure -2.924275 Contract_One year -3.457173 Contract_Two year -6.329852 dtype: float64
# Graphical representation of key features that lead to churn.
weights = pd.Series(logistic_model.coef_[0],
index=X_train.columns.values)
print (weights.sort_values(ascending = False)[:7].plot(kind='bar'))
AxesSubplot(0.125,0.125;0.775x0.755)
Die drei Hauptmerkmale, die zur Abwanderung führen, sind:
- Der Glasfaserdienst (InternetService_Glasfaser),
- Die Online-Zahlungen (PaperlessBilling) und
- Das Abonnement des zusätzlichen Filmstreamingdienstes (StreamingMovies_Yes).
# Most important features that keep customers from churning
print(weights.sort_values(ascending = False)[-8:].plot(kind='bar'))
AxesSubplot(0.125,0.125;0.775x0.755)
Die drei wichtigsten Merkmale, die Kunden von der Abwanderung abhalten, sind:
- Die Verträge, die für zwei Jahre gekündigt werden können (Contract_Two year),
- Die Zeit, die man Kunde eines Unternehmens ist (Tenure) und
- kein Abonnement für den Internetdienst (InternetService_No).
4.4. Modell-Optimierung¶
Die Recall-Rate ist als Zielmetrik zu niedrig und muss daher erhöht werden. Daher werden die Metriken bei verschiedenen Schwellenwerten der logistischen Regression analysiert.
# Testing the metrics at different thresholds
threshold_list = [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,.7,.75,.8,.85,.9,.95,.99]
pred_proba_df = y_hat
for i in threshold_list:
print ('\n******** For a Threshold about {} ******'.format(i))
# Round up if value is above threshold
y_test_pred = pred_proba_df.apply(lambda x: 1 if x>i else 0)
# read metrics
test_accuracy = metrics.accuracy_score(y_test, y_test_pred)
print("Accuracy: {}".format(test_accuracy))
# Confusion matrix
c = confusion_matrix(y_test, y_test_pred)
tn, fp, fn, tp = c.ravel()
recall = tp/(fn+tp)
precision = tp/(tp+fp)
# print metrics
print("True Negatives: " + str(tn))
print("False Positives: " + str(fp))
print("False Negatives: " + str(fn))
print("True Positives: " + str(tp))
print("Recall: " + str(recall))
print("Precision: " + str(precision))
******** For a Threshold about 0.05 ****** Accuracy: 0.8588235294117647 True Negatives: 367 False Positives: 128 False Negatives: 4 True Positives: 436 Recall: 0.990909090909091 Precision: 0.7730496453900709 ******** For a Threshold about 0.1 ****** Accuracy: 0.9144385026737968 True Negatives: 420 False Positives: 75 False Negatives: 5 True Positives: 435 Recall: 0.9886363636363636 Precision: 0.8529411764705882 ******** For a Threshold about 0.15 ****** Accuracy: 0.9422459893048128 True Negatives: 446 False Positives: 49 False Negatives: 5 True Positives: 435 Recall: 0.9886363636363636 Precision: 0.8987603305785123 ******** For a Threshold about 0.2 ****** Accuracy: 0.9657754010695188 True Negatives: 468 False Positives: 27 False Negatives: 5 True Positives: 435 Recall: 0.9886363636363636 Precision: 0.9415584415584416 ******** For a Threshold about 0.25 ****** Accuracy: 0.9786096256684492 True Negatives: 481 False Positives: 14 False Negatives: 6 True Positives: 434 Recall: 0.9863636363636363 Precision: 0.96875 ******** For a Threshold about 0.3 ****** Accuracy: 0.9818181818181818 True Negatives: 486 False Positives: 9 False Negatives: 8 True Positives: 432 Recall: 0.9818181818181818 Precision: 0.9795918367346939 ******** For a Threshold about 0.35 ****** Accuracy: 0.986096256684492 True Negatives: 490 False Positives: 5 False Negatives: 8 True Positives: 432 Recall: 0.9818181818181818 Precision: 0.988558352402746 ******** For a Threshold about 0.4 ****** Accuracy: 0.9871657754010695 True Negatives: 491 False Positives: 4 False Negatives: 8 True Positives: 432 Recall: 0.9818181818181818 Precision: 0.9908256880733946 ******** For a Threshold about 0.45 ****** Accuracy: 0.9893048128342246 True Negatives: 493 False Positives: 2 False Negatives: 8 True Positives: 432 Recall: 0.9818181818181818 Precision: 0.9953917050691244 ******** For a Threshold about 0.5 ****** Accuracy: 0.9882352941176471 True Negatives: 493 False Positives: 2 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 0.9953810623556582 ******** For a Threshold about 0.55 ****** Accuracy: 0.9882352941176471 True Negatives: 493 False Positives: 2 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 0.9953810623556582 ******** For a Threshold about 0.6 ****** Accuracy: 0.9893048128342246 True Negatives: 494 False Positives: 1 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 0.9976851851851852 ******** For a Threshold about 0.65 ****** Accuracy: 0.9893048128342246 True Negatives: 494 False Positives: 1 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 0.9976851851851852 ******** For a Threshold about 0.7 ****** Accuracy: 0.9903743315508021 True Negatives: 495 False Positives: 0 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 1.0 ******** For a Threshold about 0.75 ****** Accuracy: 0.9903743315508021 True Negatives: 495 False Positives: 0 False Negatives: 9 True Positives: 431 Recall: 0.9795454545454545 Precision: 1.0 ******** For a Threshold about 0.8 ****** Accuracy: 0.9893048128342246 True Negatives: 495 False Positives: 0 False Negatives: 10 True Positives: 430 Recall: 0.9772727272727273 Precision: 1.0 ******** For a Threshold about 0.85 ****** Accuracy: 0.9882352941176471 True Negatives: 495 False Positives: 0 False Negatives: 11 True Positives: 429 Recall: 0.975 Precision: 1.0 ******** For a Threshold about 0.9 ****** Accuracy: 0.9871657754010695 True Negatives: 495 False Positives: 0 False Negatives: 12 True Positives: 428 Recall: 0.9727272727272728 Precision: 1.0 ******** For a Threshold about 0.95 ****** Accuracy: 0.9807486631016042 True Negatives: 495 False Positives: 0 False Negatives: 18 True Positives: 422 Recall: 0.9590909090909091 Precision: 1.0 ******** For a Threshold about 0.99 ****** Accuracy: 0.9497326203208556 True Negatives: 495 False Positives: 0 False Negatives: 47 True Positives: 393 Recall: 0.8931818181818182 Precision: 1.0
Ein Schwellenwert von 0,3 bietet ein besseres Ergebnis für die Anwendung. Er erhöht die Wiederauffindbarkeit auf ein zufriedenstellendes Niveau von 73,21 %, was zu Lasten der Präzision geht. Die Präzision ist jedoch vernachlässigbar.
Daraus ergeben sich die folgenden Werte:
# Threshold about 0,3
y_test_pred = pred_proba_df.apply(lambda x: 1 if x>0.30 else 0)
test_accuracy = metrics.accuracy_score(y_test, y_test_pred)
c = confusion_matrix(y_test, y_test_pred)
# read values from confusion matrix
tn, fp, fn, tp = c.ravel()
recall = tp/(fn+tp)
precision = tp/(tp+fp)
print(classification_report(y_test,y_test_pred))
# create confusion matrix
print("Confusion matrix for the new threshold:")
df_cm = pd.DataFrame(c, index=['No Churn','Churn'], columns=['No Churn', 'Churn'],)
fig = plt.figure(figsize=[10,7])
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=14)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=14)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
print(" ")
# print metrics
print("Metrics for the new threshold:")
print("Accuracy: {}".format(test_accuracy))
print("True Negatives: " + str(tn))
print("False Positives: " + str(fp))
print("False Negatives: " + str(fn))
print("True Positives: " + str(tp))
print("Recall: " + str(recall))
print("Precision: " + str(precision))
precision recall f1-score support 0 0.98 0.98 0.98 495 1 0.98 0.98 0.98 440 accuracy 0.98 935 macro avg 0.98 0.98 0.98 935 weighted avg 0.98 0.98 0.98 935 Confusion matrix for the new threshold:
Metrics for the new threshold: Accuracy: 0.9818181818181818 True Negatives: 486 False Positives: 9 False Negatives: 8 True Positives: 432 Recall: 0.9818181818181818 Precision: 0.9795918367346939
Erwartungsgemäß steigt die Rate der fälschlicherweise als abgewandert eingestuften Kunden. Im Gegenzug steigt aber auch die Anzahl der Kunden, die korrekt als Abwanderer vorhergesagt werden (True Positives). Wie in der Hausarbeit ausgeführt, ist dies essentiell, denn im Zweifelsfall würde ein Kunde fälschlicherweise vom Serviceteam angerufen werden und diesen Anruf sogar als guten Service wahrnehmen und längerfristig an das Unternehmen binden.
5. Deployment¶
# Separate individual (scaled) customer
customer_df = X_test.iloc[896]
# Overview about the customer
customer_df
SeniorCitizen 0.000000 Dependents 0.000000 tenure -0.302393 PhoneService 1.000000 PaperlessBilling 0.000000 MultipleLines_Yes 0.000000 InternetService_Fiber optic 0.000000 InternetService_No 1.000000 OnlineSecurity_Yes 0.000000 TechSupport_Yes 0.000000 StreamingTV_Yes 0.000000 StreamingMovies_Yes 0.000000 Contract_One year 0.000000 Contract_Two year 1.000000 PaymentMethod_Electronic check 0.000000 Name: 1544, dtype: float64
# execute prediction
cust_pred = logistic_model.predict([customer_df])
# evaluate results
def check_prediction(pred):
if pred[0] == 1:
print("The customer will probably churn! Inform Customer Relationship Management!")
else:
print("The customer probably will not churn.")
check_prediction(cust_pred)
The customer probably will not churn.
Zusammenfassung¶
Das Notebook hat gezeigt, wie eine logistische Regression verwendet werden kann, um die Abwanderung von Kunden im Telekommunikationssegment vorherzusagen.