Notebook zur Vorhersage von Werbung¶
In diesem Notizbuch wird der Advertising-Datensatz von Kaggle analysiert (https://www.kaggle.com/fayomi/advertising). Er besteht aus 10 Spalten mit insgesamt 1000 Zeilen. Der Anwendungsfall besteht in der Vorhersage, ob ein Website-Besucher auf eine Werbung klicken wird oder nicht, basierend auf seinen demografischen Daten und Daten zur Internetnutzung.
Der Ansatz des Notebooks basiert auf dem CRISP-DM-Modell, das die Phasen in einem Data-Science-Projekt klar unterteilt.
1. Business Understanding¶
Aus wirtschaftlicher Sicht ist es für Facebook notwendig, die Kunden so lange wie möglich auf seiner Streaming-Plattform zu halten. Nun stellt sich aber die Frage, wie dies erreicht werden kann. Das Problem dabei ist, dass Netflix zwar über eine gute Datenbasis verfügt, diese aber erst einmal aufbereitet und zu einem digitalen Angebot entwickelt werden muss. Außerdem muss festgelegt werden, welche Daten überhaupt vorhanden sind, welche Dienste auf Basis dieser Daten realisiert werden können und welchen Mehrwert der Kunde und Netflix selbst daraus ziehen. Als Service geht es in diesem Beitrag um ein Empfehlungsmodell, das dem Nutzer zu jedem angebotenen Film oder jeder Serie eine Auswahl an ähnlichen Angeboten vorschlägt.
2. Daten und Datenverständnis¶
In diesem Notebook wird der Anzeigendatensatz von Kaggle analysiert. Er besteht aus 10 Spalten mit insgesamt 1000 Zeilen. Der Anwendungsfall besteht in der Vorhersage, ob ein Website-Besucher auf der Grundlage seiner demografischen Daten und seiner Internetnutzungsdaten auf eine Anzeige klicken wird oder nicht. Der Zielwert für angeklickte Werbung ist perfekt zwischen den beiden Kategorien ausgeglichen (0,1), da der Mittelwert genau 0,5 beträgt. Dies bedeutet, dass es für beide Kategorien die gleiche Anzahl von Werten gibt (jeweils 500). Darüber hinaus können wir feststellen, dass die Merkmale "Ad Topic Line" und "City" sehr viele eindeutige Werte aufweisen (1000 bzw. 969 "eindeutige" Werte), was bedeutet, dass. Es ist zu erkennen, dass es erhebliche Unterschiede zwischen den Nutzerprofilen gibt. Nutzer, die auf eine Anzeige klicken (Clicked on Ad=1), verbringen im Durchschnitt weniger Zeit auf der Website, sind älter (ca. 40), haben ein geringeres Einkommen und nutzen das Internet deutlich weniger. Aus diesen Informationen lässt sich bereits ein grobes Nutzerprofil ableiten, das auch für Marketing und Vertrieb eines Unternehmens relevant sein könnte, um ihre Maßnahmen auf Basis der Nutzerprofile zu optimieren.
2.1. Import von relevanten Modulen¶
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
# Set style for the visualization libra
%matplotlib inline
sns.set_style('whitegrid')
plt.style.use("fivethirtyeight")
2. 2. Daten einlesen¶
# Load the CSV-file in a DataFrame
data = pd.read_csv('https://storage.googleapis.com/ml-service-repository-datastorage/Predicting_clicks_on_online_advertising_by_Facebook_data.csv')
data.head()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | 35 | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 2016-03-27 00:53:11 | 0 |
1 | 80.23 | 31 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 2016-04-04 01:39:02 | 0 |
2 | 69.47 | 26 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 2016-03-13 20:35:42 | 0 |
3 | 74.15 | 29 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 2016-01-10 02:31:19 | 0 |
4 | 68.37 | 35 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 2016-06-03 03:36:18 | 0 |
3. Datenanalyse¶
Zweck dieses Kapitels ist die Überprüfung, Analyse und Aufbereitung der Daten.
# Info of the DataFrame
data.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 1000 entries, 0 to 999 Data columns (total 10 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Daily Time Spent on Site 1000 non-null float64 1 Age 1000 non-null int64 2 Area Income 1000 non-null float64 3 Daily Internet Usage 1000 non-null float64 4 Ad Topic Line 1000 non-null object 5 City 1000 non-null object 6 Male 1000 non-null int64 7 Country 1000 non-null object 8 Timestamp 1000 non-null object 9 Clicked on Ad 1000 non-null int64 dtypes: float64(3), int64(3), object(4) memory usage: 78.2+ KB
# Compute different metrics for each column
data.describe(include="all")
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
count | 1000.000000 | 1000.000000 | 1000.000000 | 1000.000000 | 1000 | 1000 | 1000.000000 | 1000 | 1000 | 1000.00000 |
unique | NaN | NaN | NaN | NaN | 1000 | 969 | NaN | 237 | 1000 | NaN |
top | NaN | NaN | NaN | NaN | Cloned 5thgeneration orchestration | Lisamouth | NaN | France | 2016-03-27 00:53:11 | NaN |
freq | NaN | NaN | NaN | NaN | 1 | 3 | NaN | 9 | 1 | NaN |
mean | 65.000200 | 36.009000 | 55000.000080 | 180.000100 | NaN | NaN | 0.481000 | NaN | NaN | 0.50000 |
std | 15.853615 | 8.785562 | 13414.634022 | 43.902339 | NaN | NaN | 0.499889 | NaN | NaN | 0.50025 |
min | 32.600000 | 19.000000 | 13996.500000 | 104.780000 | NaN | NaN | 0.000000 | NaN | NaN | 0.00000 |
25% | 51.360000 | 29.000000 | 47031.802500 | 138.830000 | NaN | NaN | 0.000000 | NaN | NaN | 0.00000 |
50% | 68.215000 | 35.000000 | 57012.300000 | 183.130000 | NaN | NaN | 0.000000 | NaN | NaN | 0.50000 |
75% | 78.547500 | 42.000000 | 65470.635000 | 218.792500 | NaN | NaN | 1.000000 | NaN | NaN | 1.00000 |
max | 91.430000 | 61.000000 | 79484.800000 | 269.960000 | NaN | NaN | 1.000000 | NaN | NaN | 1.00000 |
Fehlende Werte¶
# Number of missing values in each column
data.isnull().sum()
Daily Time Spent on Site 0 Age 0 Area Income 0 Daily Internet Usage 0 Ad Topic Line 0 City 0 Male 0 Country 0 Timestamp 0 Clicked on Ad 0 dtype: int64
Duplikate¶
# Displays duplicate records
data.duplicated().sum()
0
3.1 Explorative Datenanalyse¶
In diesem Kapitel werden erste Analysen und Visualisierungen vorgenommen.
# Create Pairplots
sns.pairplot(data, hue='Clicked on Ad')
<seaborn.axisgrid.PairGrid at 0x7f88e00349a0>
# For each label, count the occurence
data['Clicked on Ad'].value_counts()
0 500 1 500 Name: Clicked on Ad, dtype: int64
Es ist zu erkennen, dass der Datensatz perfekt ausgeglichen ist, d.h. es gibt genau 500 Datensätze für beide Klassen.
# User profile analysis
data.groupby('Clicked on Ad')['Daily Time Spent on Site', 'Age', 'Area Income',
'Daily Internet Usage'].mean()
# Scatterplot: Daily Time Spent on Site vs. Age in context of Clicked on Ad
sns.scatterplot(x="Daily Time Spent on Site", y="Age", data=data, hue="Clicked on Ad")
# Scatterplot: Daily Time Spent on Site vs. Area Income in context of Clicked on Ad
sns.scatterplot(x="Daily Time Spent on Site", y="Area Income", data=data, hue="Clicked on Ad")
# Scatterplot: Daily Time Spent on Site vs. Daily Internet Usage in context of Clicked on Ad
sns.scatterplot(x="Daily Time Spent on Site", y="Daily Internet Usage", data=data, hue="Clicked on Ad")
# Scatterplot: Age vs. Daily Internet Usage in context of Clicked on Ad
sns.scatterplot(x="Age", y="Daily Internet Usage", data=data, hue="Clicked on Ad")
3.2 Verteilungsdiagramme für alle Merkmale mit numerischen Werten¶
Verteilungsdiagramme werden erstellt, um Ausreißer in den Daten zu identifizieren und die Daten besser zu verstehen.
# Distribution plot of Age
sns.distplot(data["Age"])
plt.title("Age Distribution")
# Cut the left 1% and right 99% quantile to avoid outliers
q_small = data["Age"].quantile(0.01)
q_big = data["Age"].quantile(0.99)
data = data[(data["Age"]>q_small) & (data["Age"]<q_big)]
# Distribution plot of Daily Time Spent on Site
sns.distplot(data["Daily Time Spent on Site"])
plt.title("Daily Time Spent on Site Distribution")
# Cut the left 1% and right 99% quantile to avoid outliers
q_small = data["Daily Time Spent on Site"].quantile(0.01)
q_big = data["Daily Time Spent on Site"].quantile(0.99)
data = data[(data["Daily Time Spent on Site"]>q_small) & (data["Daily Time Spent on Site"]<q_big)]
# Distribution plot of Area Income
sns.distplot(data["Area Income"])
plt.title("Area Income Distribution")
# Cut the left 1% and right 99% quantile to avoid outliers
q_small = data["Area Income"].quantile(0.01)
q_big = data["Area Income"].quantile(0.99)
data = data[(data["Area Income"]>q_small) & (data["Area Income"]<q_big)]
# Distribution plot of Area Income with method Boxcox and lambda = 1.5
# The other functions have also been tried out, but the boxcox method fits the best
from scipy.stats import boxcox
#function = lambda x: 1/x or np.log(x) or np.sqrt(x)
#function = lambda x: np.log(x)
#log_data = data["Area Income"].apply(function)
data['Area Income'] = boxcox(data['Area Income'], lmbda=1.5)
sns.distplot(data['Area Income'])
plt.title("Area Income: Boxcox")
# Distribution plot of Daily Internet Usage
sns.distplot(data["Daily Internet Usage"])
plt.title("Daily Internet Usage Distribution")
# Cut the left 1% and right 99% quantile to avoid outliers
q_small = data["Daily Internet Usage"].quantile(0.01)
q_big = data["Daily Internet Usage"].quantile(0.99)
data = data[(data["Daily Internet Usage"]>q_small) & (data["Daily Internet Usage"]<q_big)]
# Distribution plot of Clicked on Ad
sns.distplot(data["Clicked on Ad"])
plt.title("Clicked on Ad Distribution")
4. Korrelationen¶
Nun werden die Korrelationen aller numerischen Merkmale berechnet und in einer Korrelationsmatrix dargestellt.
# Create heatmap
sns.heatmap(data.corr(), annot=True)
Es ist zu erkennen, dass die tägliche Verweildauer auf der Website und die tägliche Internetnutzung korrelieren. Es besteht auch eine starke negative Korrelation zwischen der täglichen Internetnutzung / der täglichen Verweildauer auf der Website und den angeklickten Anzeigen. Signifikante Korrelationen, die zur Streichung eines Merkmals führen, gibt es jedoch nicht (Annahme: wenn die Korrelation größer als 0,9 ist).
5. Datenaufbereitung¶
In diesem Abschnitt wird der Datensatz für das maschinelle Lernen vorbereitet.
5.1 Feature Engineering¶
In diesem Abschnitt wird das Feature Engineering durchgeführt. Hier werden wichtige Informationen aus den Rohdaten extrahiert.
5.1.1 Zeitstempel¶
Der Datensatz enthält ein Zeitstempel-Merkmal. Dies könnte für die Vorhersage wichtig sein, da es eine Korrelation zwischen dem Klick des Nutzers und der Uhrzeit geben kann.
# Extract datetime variables using timestamp column
data['Timestamp'] = pd.to_datetime(data['Timestamp'])
# Converting timestamp column into datatime object in order to extract new features
data['Month'] = data['Timestamp'].dt.month
# Creates a new column called Month
data['Day'] = data['Timestamp'].dt.day
# Creates a new column called Day
data['Hour'] = data['Timestamp'].dt.hour
# Creates a new column called Hour
data["Weekday"] = data['Timestamp'].dt.dayofweek
# Creates a new column called Weekday with sunday as 6 and monday as 0
data = data.drop(['Timestamp'], axis=1) # deleting timestamp
In diesem Abschnitt wird das Feature Engineering durchgeführt. Hier werden wichtige Informationen aus den Rohdaten extrahiert.
# Look at first 5 rows of the newly created DataFrame
data.head()
# Create heatmap
sns.set(rc={'figure.figsize':(14,14)})
sns.heatmap(data.corr(), annot=True)
# Barplots for the Weekday feature in context of the Clicked on Ad
ax = sns.barplot(x="Weekday", y="Clicked on Ad", data=data, estimator=sum)
# Creating pairplot to check effect of datetime variables on target variable (variables which were created)
pp = sns.pairplot(data, hue= 'Clicked on Ad', vars = ['Month', 'Day', 'Hour', 'Weekday'], palette= 'husl')
Wahrscheinlich gibt es im Laufe der Zeit keine nennenswerten Auswirkungen.
# Info of the dataframe
data.info()
# Reset the index
data.reset_index(drop=True, inplace=True)
# Creating Bins on Age column
data['Age_bins'] = pd.cut(data['Age'], bins=[0, 18, 30, 45, 70], labels=['Young', 'Adult','Mid', 'Elder'])
# Count for each category of Age_bins
data['Age_bins'].value_counts()
# Dummy encoding on Age_bins column
data = pd.concat([data, pd.get_dummies(data['Age_bins'], prefix='Age', drop_first=True)], axis=1)
5.2 Erstellung des endgültigen Datensatzes¶
# Remove redundant and no predictive power features
data.drop(['Country', 'Ad Topic Line', 'City', 'Day', 'Month', 'Weekday',
'Hour', 'Age', 'Age_bins'], axis = 1, inplace = True)
5.3 Datensatzaufteilung und Standardisierung¶
Teilen Sie den Datensatz in Merkmale (X) und Zielvariable (y) auf.
# First 5 rows of the dataset
data.head()
# Prepare and split data for prediction
from sklearn.model_selection import train_test_split
X = data.drop(['Clicked on Ad'],1)
y = data['Clicked on Ad']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
# Standardization of the Features
from sklearn.preprocessing import StandardScaler
stdsc = StandardScaler()
X_train_std = stdsc.fit_transform(X_train)
X_test_std = stdsc.transform(X_test)
# Dimensions of the different splits (rows -> number of samples, columns -> number of features)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
# Import required libraries for the model creation
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, classification_report
6.1 Logistische Regression¶
# Sample sigmoid curve
x = np.linspace(-6, 6, num=1000)
plt.figure(figsize=(10, 6))
plt.plot(x, (1 / (1 + np.exp(-x))))
plt.title("Sigmoid Function")
from sklearn.linear_model import LogisticRegression
# Create a Logistic Regression Classifier
lr = LogisticRegression(penalty="l2", C= 0.1, random_state=42)
lr.fit(X_train_std, y_train)
# Predict and evaluate using model
lr_training_pred = lr.predict(X_train_std)
lr_test_pred = lr.predict(X_test_std)
lr_training_prediction = accuracy_score(y_train, lr_training_pred)
lr_test_prediction = accuracy_score(y_test, lr_test_pred)
print( "Accuracy of Logistic regression training set:", round(lr_training_prediction,3))
print( "Accuracy of Logistic regression test set:", round(lr_test_prediction,3))
print(classification_report(y_test, lr.predict(X_test_std)))
tn, fp, fn, tp = confusion_matrix(y_test, lr_test_pred).ravel()
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1_score = 2*((precision*recall)/(precision+recall))
print("True Positive: %i" %tp)
print("False Positive: %i" %fp)
print("True Negative: %i" %tn)
print("False Negative: %i" %fn)
print(f"Precision: {precision:.2%}")
print(f"Recall: {recall:.2%}")
print(f"F1-Score: {f1_score:.2%}")
print('Intercept:', lr.intercept_)
weights = pd.Series(lr.coef_[0],
index=X.columns.values)
weights.sort_values(ascending = False)
Insbesondere die tägliche Zeit, die vor Ort verbracht wird, die tägliche Internetnutzung und das Gebietseinkommen haben einen größeren Einfluss.
6.2 Entscheidungsbaum¶
from sklearn.tree import DecisionTreeClassifier
# Create a Decision Tree Classifier
estimator = DecisionTreeClassifier(max_leaf_nodes=4, random_state=0)
# Predict and evaluate using model
estimator.fit(X_train_std,y_train)
# Predict and evaluate using model
rf_training_pred = estimator.predict(X_train_std)
rf_test_pred = estimator.predict(X_test_std)
rf_training_prediction = accuracy_score(y_train, rf_training_pred)
rf_test_prediction = accuracy_score(y_test, rf_test_pred)
print("Accuracy of Decision Tree training set:", round(rf_training_prediction,3))
print("Accuracy of Decision Tree test set:", round(rf_test_prediction,3))
print(classification_report(y_test, lr.predict(X_test_std)))
tn, fp, fn, tp = confusion_matrix(y_test, rf_test_pred).ravel()
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1_score = 2*((precision*recall)/(precision+recall))
print("True Positive: %i" %tp)
print("False Positive: %i" %fp)
print("True Negative: %i" %tn)
print("False Negative: %i" %fn)
print(f"Precision: {precision:.2%}")
print(f"Recall: {recall:.2%}")
print(f"F1-Score: {f1_score:.2%}")