Predicting Mental Illness for Health Insurance
1. Business Understanding
Which group of people, characterized by age, gender, pre-existing conditions, and occupational attributes, should pay a proportionally higher insurance premium, so that sufficient coverage is guaranteed in the event of a claim or loss of work and the insurer can finance the additional treatment costs for mental illness?
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import plotly.express as px
%matplotlib inline
import seaborn as sns
sns.set()
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
# from sklearn.linear_model import LinearRegression
from sklearn.cluster import KMeans
from sklearn import metrics
from sklearn.metrics import confusion_matrix, classification_report
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn import svm, datasets
from sklearn.preprocessing import LabelEncoder
2.2. Read in the data
data = pd.read_csv('https://storage.googleapis.com/ml-service-repository-datastorage/Predicting_clicks_on_online_advertising_by_Facebook_data.csv') # read data
data.head(5)
Timestamp | Age | Gender | Country | state | self_employed | family_history | treatment | work_interfere | no_employees | ... | leave | mental_health_consequence | phys_health_consequence | coworkers | supervisor | mental_health_interview | phys_health_interview | mental_vs_physical | obs_consequence | comments;;;; | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2014-08-27 11:29:31 | 37 | Female | United States | IL | NaN | No | Yes | Often | 6-25 | ... | Somewhat easy | No | No | Some of them | Yes | No | Maybe | Yes | No | NA;;;; |
1 | 2014-08-27 11:29:37 | 44 | M | United States | IN | NaN | No | No | Rarely | More than 1000 | ... | Don't know | Maybe | No | No | No | No | No | Don't know | No | NA;;;; |
2 | 2014-08-27 11:29:44 | 32 | Male | Canada | NaN | NaN | No | No | Rarely | 6-25 | ... | Somewhat difficult | No | No | Yes | Yes | Yes | Yes | No | No | NA;;;; |
3 | 2014-08-27 11:29:46 | 31 | Male | United Kingdom | NaN | NaN | Yes | Yes | Often | 26-100 | ... | Somewhat difficult | Yes | Yes | Some of them | No | Maybe | Maybe | No | Yes | NA;;;; |
4 | 2014-08-27 11:30:22 | 31 | Male | United States | TX | NaN | No | No | Never | 100-500 | ... | Don't know | No | No | Some of them | Yes | Yes | Yes | Don't know | No | NA;;;; |
5 rows × 27 columns
2.3. Descriptive analysis
data.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 1259 entries, 0 to 1258 Data columns (total 27 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Timestamp 1259 non-null object 1 Age 1259 non-null int64 2 Gender 1259 non-null object 3 Country 1259 non-null object 4 state 745 non-null object 5 self_employed 1240 non-null object 6 family_history 1259 non-null object 7 treatment 1259 non-null object 8 work_interfere 995 non-null object 9 no_employees 1259 non-null object 10 remote_work 1259 non-null object 11 tech_company 1259 non-null object 12 benefits 1259 non-null object 13 care_options 1259 non-null object 14 wellness_program 1259 non-null object 15 seek_help 1259 non-null object 16 anonymity 1259 non-null object 17 leave 1259 non-null object 18 mental_health_consequence 1259 non-null object 19 phys_health_consequence 1259 non-null object 20 coworkers 1259 non-null object 21 supervisor 1259 non-null object 22 mental_health_interview 1259 non-null object 23 phys_health_interview 1259 non-null object 24 mental_vs_physical 1259 non-null object 25 obs_consequence 1259 non-null object 26 comments;;;; 1259 non-null object dtypes: int64(1), object(26) memory usage: 265.7+ KB
data.describe(include='all') # statistical summary of all columns
Timestamp | Age | Gender | Country | state | self_employed | family_history | treatment | work_interfere | no_employees | ... | leave | mental_health_consequence | phys_health_consequence | coworkers | supervisor | mental_health_interview | phys_health_interview | mental_vs_physical | obs_consequence | comments;;;; | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 1259 | 1.259000e+03 | 1259 | 1259 | 745 | 1240 | 1259 | 1259 | 995 | 1259 | ... | 1259 | 1259 | 1259 | 1259 | 1259 | 1259 | 1259 | 1259 | 1259 | 1259 |
unique | 1246 | NaN | 49 | 49 | 47 | 3 | 2 | 2 | 5 | 8 | ... | 6 | 5 | 3 | 4 | 3 | 3 | 3 | 4 | 3 | 161 |
top | 2014-08-27 12:44:51 | NaN | Male | United States | CA | No | No | Yes | Sometimes | 6-25 | ... | Don't know | No | No | Some of them | Yes | No | Maybe | Don't know | No | NA;;;; |
freq | 2 | NaN | 615 | 751 | 138 | 1094 | 767 | 636 | 464 | 289 | ... | 563 | 490 | 925 | 774 | 516 | 1008 | 556 | 575 | 1074 | 1095 |
mean | NaN | 7.942815e+07 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
std | NaN | 2.818299e+09 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
min | NaN | -1.726000e+03 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
25% | NaN | 2.700000e+01 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
50% | NaN | 3.100000e+01 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
75% | NaN | 3.600000e+01 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
max | NaN | 1.000000e+11 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
11 rows × 27 columns
def attribute_description(data):
longestColumnName = len(max(np.array(data.columns), key=len))
print("| Feature | Data Type|")
print("|-----|------|")
for col in data.columns:
description = ''
col_dropna = data[col].dropna()
example = col_dropna.sample(1).values[0]
if type(example) == str:
description = 'str '
if len(col_dropna.unique()) < 10:
description += '{'
description += '; '.join([ f'"{name}"' for name in col_dropna.unique()])
description += '}'
else:
description += '[ example: "'+ example + '" ]'
elif (type(example) == np.int32) and (len(col_dropna.unique()) < 10) :
description += 'dummy int32 {'
description += '; '.join([ f'{name}' for name in sorted(col_dropna.unique())])
description += '}'
else:
try:
description = example.dtype
except:
description = type(example)
print("|" + col.ljust(longestColumnName)+ f'| {description} |')
attribute_description(data)
Feature | Data Type |
---|---|
Timestamp | str [ example: "2014-08-27 20:52:31" ] |
Age | int64 |
Gender | str [ example: "Male" ] |
Country | str [ example: "China" ] |
state | str [ example: "OR" ] |
self_employed | str {"Yes"; "No"; "IL"} |
family_history | str {"No"; "Yes"} |
treatment | str {"Yes"; "No"} |
work_interfere | str {"Often"; "Rarely"; "Never"; "Sometimes"; "Yes"} |
no_employees | str {"6-25"; "More than 1000"; "26-100"; "100-500"; "1-5"; "500-1000"; "Often"; "Sometimes"} |
remote_work | str {"No"; "Yes"; "1-5"; "6-25"} |
tech_company | str {"Yes"; "No"} |
benefits | str {"Yes"; "Don't know"; "No"} |
care_options | str {"Not sure"; "No"; "Yes"; "Don't know"} |
wellness_program | str {"No"; "Don't know"; "Yes"} |
seek_help | str {"Yes"; "Don't know"; "No"} |
anonymity | str {"Yes"; "Don't know"; "No"} |
leave | str {"Somewhat easy"; "Don't know"; "Somewhat difficult"; "Very difficult"; "Very easy"; "Yes"} |
mental_health_consequence | str {"No"; "Maybe"; "Yes"; "Very easy"; "Don't know"} |
phys_health_consequence | str {"No"; "Yes"; "Maybe"} |
coworkers | str {"Some of them"; "No"; "Yes"; "Maybe"} |
supervisor | str {"Yes"; "No"; "Some of them"} |
mental_health_interview | str {"No"; "Yes"; "Maybe"} |
phys_health_interview | str {"Maybe"; "No"; "Yes"} |
mental_vs_physical | str {"Yes"; "Don't know"; "No"; "Maybe"} |
obs_consequence | str {"No"; "Yes"; "Don't know"} |
comments;;;; | str [ example: "NA;;;;" ] |
data[data.duplicated(keep=False)] # show duplicates
Timestamp | Age | Gender | Country | state | self_employed | family_history | treatment | work_interfere | no_employees | ... | leave | mental_health_consequence | phys_health_consequence | coworkers | supervisor | mental_health_interview | phys_health_interview | mental_vs_physical | obs_consequence | comments;;;; |
---|
0 rows × 27 columns
3.2 Remove missing data
data.isnull().sum() # count missing data
Timestamp 0 Age 0 Gender 0 Country 0 state 514 self_employed 19 family_history 0 treatment 0 work_interfere 264 no_employees 0 remote_work 0 tech_company 0 benefits 0 care_options 0 wellness_program 0 seek_help 0 anonymity 0 leave 0 mental_health_consequence 0 phys_health_consequence 0 coworkers 0 supervisor 0 mental_health_interview 0 phys_health_interview 0 mental_vs_physical 0 obs_consequence 0 comments;;;; 0 dtype: int64
data1 = data.drop(['Timestamp','state','comments;;;;'], axis=1)
# drop features that are not needed
data1['self_employed'] = data1['self_employed'].fillna(data1['self_employed'].mode().iloc[0])
# replace missing data in 'self_employed' with its mode, 'No'
data1.isnull().sum()
Age 0 Gender 0 Country 0 self_employed 0 family_history 0 treatment 0 work_interfere 264 no_employees 0 remote_work 0 tech_company 0 benefits 0 care_options 0 wellness_program 0 seek_help 0 anonymity 0 leave 0 mental_health_consequence 0 phys_health_consequence 0 coworkers 0 supervisor 0 mental_health_interview 0 phys_health_interview 0 mental_vs_physical 0 obs_consequence 0 dtype: int64
data1 = data1.dropna(axis=0) # remove rows with missing data (in 'work_interfere')
data1.isnull().sum() # make sure there is no missing data now
Age 0 Gender 0 Country 0 self_employed 0 family_history 0 treatment 0 work_interfere 0 no_employees 0 remote_work 0 tech_company 0 benefits 0 care_options 0 wellness_program 0 seek_help 0 anonymity 0 leave 0 mental_health_consequence 0 phys_health_consequence 0 coworkers 0 supervisor 0 mental_health_interview 0 phys_health_interview 0 mental_vs_physical 0 obs_consequence 0 dtype: int64
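The two strategies used above (mode imputation for 'self_employed', row removal for 'work_interfere') can be sketched on a hypothetical mini-frame; the column names mirror the dataset, but the values are invented:

```python
import pandas as pd

# invented mini-frame with the same kind of gaps as the survey data
df = pd.DataFrame({
    'self_employed': ['No', 'No', None, 'Yes'],
    'work_interfere': ['Often', None, 'Rarely', 'Sometimes'],
})

# strategy 1: fill a sparsely missing categorical column with its mode
df['self_employed'] = df['self_employed'].fillna(df['self_employed'].mode().iloc[0])

# strategy 2: drop the remaining rows that still contain missing values
df = df.dropna(axis=0)

assert df.isnull().sum().sum() == 0
```

Mode imputation keeps all rows but can bias a column toward its majority value; dropping rows is safer for the heavily missing 'work_interfere' but costs sample size.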
3.3 Remove unwanted features
data1.columns.values
array(['Age', 'Gender', 'Country', 'self_employed', 'family_history', 'treatment', 'work_interfere', 'no_employees', 'remote_work', 'tech_company', 'benefits', 'care_options', 'wellness_program', 'seek_help', 'anonymity', 'leave', 'mental_health_consequence', 'phys_health_consequence', 'coworkers', 'supervisor', 'mental_health_interview', 'phys_health_interview', 'mental_vs_physical', 'obs_consequence'], dtype=object)
data1 = data1.drop(['Country','wellness_program', 'seek_help', 'anonymity', 'leave',
'mental_health_consequence', 'phys_health_consequence',
'coworkers', 'supervisor', 'mental_health_interview',
'phys_health_interview', 'mental_vs_physical', 'obs_consequence'], axis = 1)
# remove features that are not relevant,
# and features that are relevant but cannot be used in the final model,
# as the data cannot be collected in the production environment
data1.columns.values # these features are left
array(['Age', 'Gender', 'self_employed', 'family_history', 'treatment', 'work_interfere', 'no_employees', 'remote_work', 'tech_company', 'benefits', 'care_options'], dtype=object)
data1.describe(include='all')
Age | Gender | self_employed | family_history | treatment | work_interfere | no_employees | remote_work | tech_company | benefits | care_options | |
---|---|---|---|---|---|---|---|---|---|---|---|
count | 9.950000e+02 | 995 | 995 | 995 | 995 | 995 | 995 | 995 | 995 | 995 | 995 |
unique | NaN | 44 | 3 | 2 | 2 | 5 | 8 | 4 | 2 | 3 | 4 |
top | NaN | Male | No | No | Yes | Sometimes | 26-100 | No | Yes | Yes | Yes |
freq | NaN | 481 | 870 | 546 | 632 | 464 | 229 | 690 | 815 | 406 | 393 |
mean | 1.005025e+08 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
std | 3.170213e+09 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
min | -1.726000e+03 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
25% | 2.700000e+01 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
50% | 3.100000e+01 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
75% | 3.600000e+01 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
max | 1.000000e+11 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
data1['Age'].values
array([ 37, 44, 32, 31, 31, 33, 35, 39, 42, 23, 31, 29, 42, 36, 27, 29, 23, 32, 46, 29, 31, 46, 41, 33, 35, 35, 34, 37, 32, 31, 30, 42, 40, 27, 29, 35, 24, 27, 18, 30, 38, 26, 30, 22, 32, 27, 24, 33, 44, 26, 27, 35, 40, 23, 36, 34, 28, 34, 23, 33, 31, 32, 28, 38, 23, 30, 27, 33, 39, 34, 29, 31, 40, 25, 29, 24, 31, 33, 30, 26, 44, 33, 29, 35, 35, 28, 34, 32, 22, 28, 45, 32, 26, 21, 27, 18, 29, 33, 36, 27, 27, 32, 31, 19, 33, 32, 27, 24, 39, 28, 39, 38, 37, 35, 37, 24, 23, 30, 32, 28, 36, 37, 25, 27, 26, 27, 25, 36, 25, 31, 26, 33, 34, 23, 24, 26, 31, 22, 34, 31, 32, 45, 29, 26, 28, 45, 43, 24, 35, 38, 28, 28, 35, 32, 31, 35, 26, 28, 27, 34, 41, 37, 32, 21, 30, 24, 37, 26, 32, 32, 27, 30, 29, 28, 28, 23, 32, 34, 24, 26, 36, 41, 38, 38, 25, 37, 37, 28, 34, 33, 27, 40, 21, 32, 29, 23, 31, 24, 29, 23, 42, 25, 27, 27, 30, 29, 41, 32, 37, 32, 30, 23, 34, 38, 28, 28, 23, 22, 27, 18, 35, 27, 26, 18, 38, 26, 30, 35, 45, 32, 56, 30, 33, 37, 23, 31, 26, 28, 37, 26, 30, 27, 25, 35, 36, 26, 25, 22, 41, 29, 32, 24, 25, 30, 25, 30, 33, 25, 45, 46, 30, 29, 33, 27, 33, 25, 23, 54, 22, 25, 29, 27, 30, 26, 25, 31, 33, 34, 34, 34, 26, 32, 329, 28, 36, 21, 21, 41, 55, 32, 21, 45, 27, 25, 34, 26, 41, 27, 31, 25, 26, 27, 42, 29, 25, 33, 99999999999, 40, 31, 26, 29, 35, 32, 29, 26, 28, 35, 29, 33, 22, 33, 31, 21, 31, 26, 30, 30, 34, 55, 28, 28, 32, 28, 21, 24, 28, 24, 33, 34, 27, 28, 23, 29, 26, 36, 41, 23, 39, 26, 24, 37, 43, 40, 30, 34, 27, 36, 27, 35, 32, 33, 28, 26, 27, 38, 57, 28, 26, 42, 31, 58, 29, 39, 34, 57, 27, 23, 23, 43, 18, 29, 48, 43, 28, 30, 26, 33, 30, 24, 23, 36, 25, 54, 34, 25, 35, 46, 42, 32, 47, 33, 25, 39, 38, 46, 38, 33, 34, 62, 25, 36, 41, 24, 51, 29, 31, 31, 27, 23, 21, 27, 39, 26, 27, 22, 26, 31, 32, 28, 28, 30, 36, 30, 32, 29, 21, 27, 32, 34, 22, 27, 33, 36, 40, 28, 39, 32, 31, 38, 23, 42, 27, 26, 50, 37, 33, 29, 34, 41, 29, 35, 27, 40, 27, 29, 31, 43, 34, 29, 19, 41, 29, 23, 24, 31, 29, 33, 30, 32, 50, 24, 27, 32, 42, 37, 30, 29, 30, 35, 35, 38, 22, 24, 22, 31, 
23, 31, 28, 37, 34, 32, 28, 24, 56, 31, 34, 35, 28, 36, 30, 49, 29, 57, 31, 37, 25, 30, 26, 22, 39, 29, 54, 34, 32, 29, 32, 30, 20, 27, 32, 26, 30, 30, 26, 26, 23, 26, 35, 28, 29, 45, 33, 38, 19, 29, 23, 33, 49, 27, 23, 29, 32, 33, 37, 23, 43, 32, 26, 32, 29, 30, 29, 32, -1726, 30, 25, 33, 31, 21, 30, 43, 37, 33, 33, 36, 37, 39, 31, 36, 30, 35, 19, 37, 40, 36, 29, 38, 26, 34, 21, 31, 37, 37, 38, 27, 33, 27, 36, 28, 39, 33, 37, 39, 43, 32, 43, 33, 34, 25, 25, 39, 29, 33, 37, 35, 22, 38, 32, 35, 29, 23, 28, 40, 41, 29, 29, 35, 28, 36, 39, 39, 44, 26, 35, 40, 35, 38, 48, 20, 40, 29, 35, 29, 40, 29, 29, 34, 44, 24, 36, 43, 36, 31, 35, 37, 34, 36, 40, 40, 42, 21, 26, 51, 32, 32, 26, 23, 33, 46, 35, 32, 56, 32, 30, 23, 31, 29, 30, 37, 36, 35, 41, 31, 39, 42, 32, 30, 40, 33, 34, 50, 24, 25, 43, 25, 51, 49, 25, 36, 48, 48, 53, 24, 33, 25, 30, 30, 34, 22, 28, 35, 28, 42, 29, 43, 31, 35, 34, 43, 38, 26, 38, 42, 32, 44, 28, 40, 31, 32, 28, 39, 43, 35, 40, 34, 24, 61, 36, 33, 30, 34, 26, 25, 35, 24, 55, 33, 26, 25, 45, 33, 43, 30, 40, 49, 38, 26, 28, 40, 37, 34, 28, 27, 29, 39, 28, 23, 8, 30, 20, 35, 39, 31, 32, 25, 42, 34, 26, 35, 34, 38, 34, 39, 33, 24, 38, 31, 46, 30, 25, 19, 30, 32, 37, 42, 25, 19, 40, 31, 40, 31, 36, 35, 26, 34, 28, 40, 26, 29, 26, 33, 28, 41, 39, 26, 23, 35, 36, 42, 39, 27, 33, 31, 28, 29, 27, 44, 25, 24, 25, 34, 26, 48, 39, 43, 41, 25, 31, 40, 43, 27, 37, 32, 25, 29, 30, 34, 32, 41, 38, 32, 28, 11, 43, 32, 25, 37, 36, 24, 40, 43, 26, 33, 35, 45, 25, 50, 26, 33, 30, 33, 29, 25, 24, 40, 46, 38, 34, 32, 44, 33, 45, 26, 20, -1, 37, 42, 36, 27, 27, 27, 25, 23, 21, 26, 29, 28, 23, 26, 38, 39, 35, 32, 32, 26, 38, 34, 39, 32, 37, 31, 30, 51, 29, 31, 26, 46, 32, 29, 34, 26, 32, 40, 23, 20, 26, 29, 40, 25, 32, 38, 72, 35, 28, 27, 56, 38, 40, 44, 34, 37, 38, 34, 35, 34, 32, 28, 28, 34, 32, 34, 23, 33, 29, 45, 34, 31, 33, 27, 42, 38, 46, 46, 41, 23, 24, 23, 32, 25, 23, 24, 25, 23, 24, 23, 60, 28, 28, 30, 31, 31, 28, 43, 22, 32, 36, 41, 30, 30, 36, 29, 36, 32, 34, 
25], dtype=int64)
3.4 Remove outliers
Some respondents entered troll values for "Age"; these rows must be removed.
Age should lie between 16 and 70 years.
sns.histplot(data = data1, x = 'Age', bins="sqrt")
<AxesSubplot:xlabel='Age', ylabel='Count'>
data2 = data1[(data1['Age'] > 16) & (data1['Age'] < 70)].copy() # keep plausible ages; .copy() prevents SettingWithCopyWarning in later cells
sns.histplot(data2['Age'])
<AxesSubplot:xlabel='Age', ylabel='Count'>
data2['Age'].median(axis = 0)
31.5
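On an invented frame containing the troll values visible in the array above, the age filter behaves like this (a minimal sketch, not the full dataset):

```python
import pandas as pd

# invented frame with the troll values seen in the survey ages
ages = pd.DataFrame({'Age': [37, 329, -1726, 25, 99999999999, 62]})

# one boolean mask instead of two chained filters;
# .copy() avoids SettingWithCopyWarning when the result is modified later
plausible = ages[(ages['Age'] > 16) & (ages['Age'] < 70)].copy()

assert plausible['Age'].tolist() == [37, 25, 62]
```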
data2.columns.values # these features are left
array(['Age', 'Gender', 'self_employed', 'family_history', 'treatment', 'work_interfere', 'no_employees', 'remote_work', 'tech_company', 'benefits', 'care_options'], dtype=object)
3.5 Clean the data
data2.columns.values
array(['Age', 'Gender', 'self_employed', 'family_history', 'treatment', 'work_interfere', 'no_employees', 'remote_work', 'tech_company', 'benefits', 'care_options'], dtype=object)
# clean Gender
# reduce options to only male or female
data2['Gender'] = data2['Gender'].str.lower()
male = ["male", "m", "male-ish", "maile", "mal", "male (cis)", "make", "male ", "man","msle", "mail", "malr","cis man", "cis male"]
trans = ["trans-female", "something kinda male?", "queer/she/they", "non-binary","nah", "all", "enby", "fluid", "genderqueer", "androgyne", "agender", "male leaning androgynous", "guy (-ish) ^_^", "trans woman", "neuter", "female (trans)", "queer", "ostensibly male, unsure what that really means"]
female = ["cis female", "f", "female", "woman", "femake", "female ","cis-female/femme", "female (cis)", "femail"]
data2['Gender'] = data2['Gender'].apply(lambda x:"Male" if x in male else x)
data2['Gender'] = data2['Gender'].apply(lambda x:"Female" if x in female else x)
data2['Gender'] = data2['Gender'].apply(lambda x:"Trans" if x in trans else x)
data2.drop(data2[data2.Gender == 'p'].index, inplace=True)
data2.drop(data2[data2.Gender == 'a little about you'].index, inplace=True)
plot_gender = data2['Gender'].value_counts().reset_index()
plot_gender.columns = ['Gender','count']
px.pie(plot_gender,values='count',names='Gender',template='ggplot2',title='Gender')
dropvalue = data2[ data2['Gender'] == 'ostensibly male' ].index
data2.drop(dropvalue , inplace=True)
dropvalue1 = data2[ data2['Gender'] == 'Trans' ].index
data2.drop(dropvalue1, inplace=True)
plot_gender = data2['Gender'].value_counts().reset_index()
plot_gender.columns = ['Gender','count']
px.pie(plot_gender,values='count',names='Gender',template='ggplot2',title='Gender')
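The three chained `apply` calls above can also be written as a single dictionary lookup with `Series.map`; a minimal sketch on an invented subset of the spellings (unmapped values become NaN, which makes leftover categories easy to spot and drop):

```python
import pandas as pd

# hypothetical subset of the free-text spellings seen in the survey
male = ['male', 'm', 'maile', 'cis male']
female = ['female', 'f', 'woman', 'femake']

mapping = {spelling: 'Male' for spelling in male}
mapping.update({spelling: 'Female' for spelling in female})

s = pd.Series(['M', 'female', 'Femake', 'maile'])
# lower-case first, then map; anything not in the dict becomes NaN
cleaned = s.str.lower().map(mapping)

assert cleaned.tolist() == ['Male', 'Female', 'Female', 'Male']
```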
plot_self = data2['self_employed'].value_counts().reset_index()
plot_self.columns = ['self_employed','count']
px.pie(plot_self,values='count',names='self_employed',template='ggplot2',title='Self employed')
# Self employed looks good
plot_fam = data2['family_history'].value_counts().reset_index()
plot_fam.columns = ['family_history','count']
px.pie(plot_fam,values='count',names='family_history',template='ggplot2',title='Family history')
# family history looks good
plot_treat = data2['treatment'].value_counts().reset_index()
plot_treat.columns = ['treatment','count']
px.pie(plot_treat,values='count',names='treatment',template='ggplot2',title='In treatment')
# in treatment looks good.
# this will be our target variable
plot_inter = data2['work_interfere'].value_counts().reset_index()
plot_inter.columns = ['work_interfere','count']
px.pie(plot_inter,values='count',names='work_interfere',template='ggplot2',title='Work interfere')
plot_numb = data2['no_employees'].value_counts().reset_index()
plot_numb.columns = ['no_employees','count']
px.pie(plot_numb,values='count',names='no_employees',template='ggplot2',title='Number of employees')
# no_employees looks good
plot_remo = data2['remote_work'].value_counts().reset_index()
plot_remo.columns = ['remote_work','count']
px.pie(plot_remo,values='count',names='remote_work',template='ggplot2',title='Remote work')
# remote work looks good
plot_tech = data2['tech_company'].value_counts().reset_index()
plot_tech.columns = ['tech_company','count']
px.pie(plot_tech,values='count',names='tech_company',template='ggplot2',title='Tech company')
plot_bene = data2['benefits'].value_counts().reset_index()
plot_bene.columns = ['benefits','count']
px.pie(plot_bene,values='count',names='benefits',template='ggplot2',title='Benefits')
plot_care = data2['care_options'].value_counts().reset_index()
plot_care.columns = ['care_options','count']
px.pie(plot_care,values='count',names='care_options',template='ggplot2',title='Care options')
data2.describe(include='all')
Age | Gender | self_employed | family_history | treatment | work_interfere | no_employees | remote_work | tech_company | benefits | care_options | |
---|---|---|---|---|---|---|---|---|---|---|---|
count | 971.000000 | 971 | 971 | 971 | 971 | 971 | 971 | 971 | 971 | 971 | 971 |
unique | NaN | 2 | 2 | 2 | 2 | 4 | 6 | 2 | 2 | 3 | 3 |
top | NaN | Male | No | No | Yes | Sometimes | 26-100 | No | Yes | Yes | Yes |
freq | NaN | 761 | 852 | 535 | 613 | 453 | 223 | 675 | 798 | 395 | 379 |
mean | 32.330587 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
std | 7.268977 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
min | 18.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
25% | 27.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
50% | 32.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
75% | 36.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
max | 62.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
plt.figure(figsize=(10,5))
sns.countplot(y="Gender", hue="treatment", data=data2)
plt.title("mental health vs Gender",fontsize=15,fontweight="normal")
plt.ylabel("")
plt.show()
plt.figure(figsize=(10,5))
sns.countplot(y="family_history", hue="treatment", data=data2)
plt.title("family history vs mental health ",
fontsize=15,fontweight="normal")
plt.ylabel("")
plt.show()
plt.figure(figsize=(10,5))
sns.countplot(y="work_interfere", hue="treatment", data=data2)
plt.title("Does the mental health condition interfere with daily work?",fontsize=15,fontweight="normal")
plt.ylabel("")
plt.show()
This feature appears to be a good predictor of the target variable treatment.
However, the information "does your mental health condition interfere with your work" cannot be collected in practice:
nobody in their right mind would answer this question truthfully to their health insurer.
The work-interference question therefore has to be dropped.
data2 = data2.drop('work_interfere', axis=1)
plt.figure(figsize=(10,5))
sns.countplot(x="no_employees", hue="treatment", data=data2)
plt.title("number of employees vs mental health",fontsize=18,fontweight="normal")
plt.ylabel("")
plt.show()
data2
Age | Gender | self_employed | family_history | treatment | no_employees | remote_work | tech_company | benefits | care_options | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 37 | Female | No | No | Yes | 6-25 | No | Yes | Yes | Not sure |
1 | 44 | Male | No | No | No | More than 1000 | No | No | Don't know | No |
2 | 32 | Male | No | No | No | 6-25 | No | Yes | No | No |
3 | 31 | Male | No | Yes | Yes | 26-100 | No | Yes | No | Yes |
4 | 31 | Male | No | No | No | 100-500 | Yes | Yes | Yes | No |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1252 | 29 | Male | No | Yes | Yes | 100-500 | Yes | Yes | Yes | Yes |
1253 | 36 | Male | No | Yes | No | More than 1000 | No | No | Don't know | No |
1255 | 32 | Male | No | Yes | Yes | 26-100 | Yes | Yes | Yes | Yes |
1256 | 34 | Male | No | Yes | Yes | More than 1000 | No | Yes | Yes | Yes |
1258 | 25 | Male | No | Yes | Yes | 26-100 | No | No | Yes | Yes |
971 rows × 10 columns
data3 = data2.reset_index(drop = True)
data3.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 971 entries, 0 to 970 Data columns (total 10 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Age 971 non-null int64 1 Gender 971 non-null object 2 self_employed 971 non-null object 3 family_history 971 non-null object 4 treatment 971 non-null object 5 no_employees 971 non-null object 6 remote_work 971 non-null object 7 tech_company 971 non-null object 8 benefits 971 non-null object 9 care_options 971 non-null object dtypes: int64(1), object(9) memory usage: 76.0+ KB
data3.describe(include="all")
Age | Gender | self_employed | family_history | treatment | no_employees | remote_work | tech_company | benefits | care_options | |
---|---|---|---|---|---|---|---|---|---|---|
count | 971.000000 | 971 | 971 | 971 | 971 | 971 | 971 | 971 | 971 | 971 |
unique | NaN | 2 | 2 | 2 | 2 | 6 | 2 | 2 | 3 | 3 |
top | NaN | Male | No | No | Yes | 26-100 | No | Yes | Yes | Yes |
freq | NaN | 761 | 852 | 535 | 613 | 223 | 675 | 798 | 395 | 379 |
mean | 32.330587 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
std | 7.268977 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
min | 18.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
25% | 27.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
50% | 32.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
75% | 36.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
max | 62.000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
3.6 Create dummy features for all string-based variables
data3 = pd.get_dummies(data3, drop_first=True) # 0-1 encoding for categorical values
data3.head()
Age | Gender_Male | self_employed_Yes | family_history_Yes | treatment_Yes | no_employees_100-500 | no_employees_26-100 | no_employees_500-1000 | no_employees_6-25 | no_employees_More than 1000 | remote_work_Yes | tech_company_Yes | benefits_No | benefits_Yes | care_options_Not sure | care_options_Yes | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 37 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 1 | 0 |
1 | 44 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 32 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 |
3 | 31 | 1 | 0 | 1 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 1 |
4 | 31 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 1 | 0 | 0 |
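One practical caveat with `get_dummies`: at prediction time a new record rarely contains every category, so its dummy columns must be aligned to the training layout. A hedged sketch with an invented single-column frame (encode the new record without `drop_first`, then `reindex` to the training columns):

```python
import pandas as pd

# hypothetical training frame with one categorical column
train = pd.DataFrame({'benefits': ['Yes', 'No', "Don't know"]})
train_dummies = pd.get_dummies(train, drop_first=True)
train_cols = train_dummies.columns  # remember the training layout

# a record arriving at prediction time shows only its own category;
# encode it without drop_first, then align to the training columns
new = pd.DataFrame({'benefits': ['Yes']})
new_dummies = pd.get_dummies(new).reindex(columns=train_cols, fill_value=0)

assert list(new_dummies.columns) == list(train_cols)
```

Without the `reindex`, a model fitted on the training dummies would reject the new record for having missing or extra columns.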
Y = data3['treatment_Yes']
X = data3.drop(['treatment_Yes'], axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=123) # 80-20 split into training and test data
scaler = StandardScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
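Fitting the scaler on the training split only, as above, avoids data leakage. The same guarantee can be packaged in a scikit-learn `Pipeline`, which also keeps any later cross-validation leak-free; a sketch on synthetic data (not the survey dataset):

```python
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# synthetic stand-in for the survey features
X, y = make_classification(n_samples=200, n_features=5, random_state=123)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)

# the pipeline refits the scaler on whatever data .fit() receives,
# so test statistics can never leak into the preprocessing
model = make_pipeline(StandardScaler(), LogisticRegression())
model.fit(X_train, y_train)
score = model.score(X_test, y_test)
```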
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)
print('train performance')
print(classification_report(y_train, tree.predict(X_train)))
print('-----------------------------------------------------')
print('test performance')
print(classification_report(y_test, tree.predict(X_test)))
train performance precision recall f1-score support 0 0.91 1.00 0.95 280 1 1.00 0.95 0.97 496 accuracy 0.97 776 macro avg 0.96 0.97 0.96 776 weighted avg 0.97 0.97 0.97 776 ----------------------------------------------------- test performance precision recall f1-score support 0 0.47 0.44 0.45 78 1 0.64 0.68 0.66 117 accuracy 0.58 195 macro avg 0.56 0.56 0.56 195 weighted avg 0.57 0.58 0.58 195
tree_depth = [1, 2, 3, 4] # to prevent overfitting
for i in tree_depth:
tree = DecisionTreeClassifier(max_depth=i)
tree.fit(X_train, y_train)
print('Max tree depth:', i)
print('Confusion Matrix: ', confusion_matrix(y_test, tree.predict(X_test)).ravel())
print('Train results:', classification_report(y_train, tree.predict(X_train), zero_division=0 ))
print('Test results:', classification_report(y_test, tree.predict(X_test), zero_division=0 ))
print('----------------------------------------------------------------------------')
Max tree depth: 1 Confusion Matrix: [ 0 78 0 117] Train results: precision recall f1-score support 0 0.00 0.00 0.00 280 1 0.64 1.00 0.78 496 accuracy 0.64 776 macro avg 0.32 0.50 0.39 776 weighted avg 0.41 0.64 0.50 776 Test results: precision recall f1-score support 0 0.00 0.00 0.00 78 1 0.60 1.00 0.75 117 accuracy 0.60 195 macro avg 0.30 0.50 0.37 195 weighted avg 0.36 0.60 0.45 195 ---------------------------------------------------------------------------- Max tree depth: 2 Confusion Matrix: [50 28 22 95] Train results: precision recall f1-score support 0 0.58 0.57 0.58 280 1 0.76 0.77 0.76 496 accuracy 0.70 776 macro avg 0.67 0.67 0.67 776 weighted avg 0.70 0.70 0.70 776 Test results: precision recall f1-score support 0 0.69 0.64 0.67 78 1 0.77 0.81 0.79 117 accuracy 0.74 195 macro avg 0.73 0.73 0.73 195 weighted avg 0.74 0.74 0.74 195 ---------------------------------------------------------------------------- Max tree depth: 3 Confusion Matrix: [45 33 19 98] Train results: precision recall f1-score support 0 0.62 0.55 0.58 280 1 0.76 0.81 0.78 496 accuracy 0.72 776 macro avg 0.69 0.68 0.68 776 weighted avg 0.71 0.72 0.71 776 Test results: precision recall f1-score support 0 0.70 0.58 0.63 78 1 0.75 0.84 0.79 117 accuracy 0.73 195 macro avg 0.73 0.71 0.71 195 weighted avg 0.73 0.73 0.73 195 ---------------------------------------------------------------------------- Max tree depth: 4 Confusion Matrix: [44 34 19 98] Train results: precision recall f1-score support 0 0.68 0.52 0.59 280 1 0.76 0.86 0.81 496 accuracy 0.74 776 macro avg 0.72 0.69 0.70 776 weighted avg 0.73 0.74 0.73 776 Test results: precision recall f1-score support 0 0.70 0.56 0.62 78 1 0.74 0.84 0.79 117 accuracy 0.73 195 macro avg 0.72 0.70 0.71 195 weighted avg 0.72 0.73 0.72 195 ----------------------------------------------------------------------------
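The manual depth loop above can be replaced by a cross-validated grid search, which picks the depth by validation score instead of by eyeballing the reports; a sketch on synthetic data, with the parameter grid mirroring the depths tried above:

```python
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification

# synthetic stand-in for the survey features
X, y = make_classification(n_samples=300, random_state=123)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=123)

# 5-fold cross-validation over the same candidate depths
grid = GridSearchCV(DecisionTreeClassifier(random_state=123),
                    param_grid={'max_depth': [1, 2, 3, 4]}, cv=5)
grid.fit(X_tr, y_tr)
best_depth = grid.best_params_['max_depth']
```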
Random Forest
rf = RandomForestClassifier(max_depth=2)
rf.fit(X_train, y_train)
print('Confusion Matrix: ', confusion_matrix(y_test, rf.predict(X_test)).ravel())
print('Train results: ', classification_report(y_train, rf.predict(X_train), zero_division=0 ))
print('Test results: ',classification_report(y_test, rf.predict(X_test), zero_division=0 ))
Confusion Matrix: [ 0 78 1 116] Train results: precision recall f1-score support 0 1.00 0.00 0.01 280 1 0.64 1.00 0.78 496 accuracy 0.64 776 macro avg 0.82 0.50 0.39 776 weighted avg 0.77 0.64 0.50 776 Test results: precision recall f1-score support 0 0.00 0.00 0.00 78 1 0.60 0.99 0.75 117 accuracy 0.59 195 macro avg 0.30 0.50 0.37 195 weighted avg 0.36 0.59 0.45 195
Logistic Regression
logreg = LogisticRegression()
logreg.fit(X_train, y_train)
print('train performance')
print(classification_report(y_train, logreg.predict(X_train)))
print('test performance')
print(classification_report(y_test, logreg.predict(X_test)))
train performance precision recall f1-score support 0 0.62 0.50 0.56 280 1 0.75 0.83 0.78 496 accuracy 0.71 776 macro avg 0.68 0.67 0.67 776 weighted avg 0.70 0.71 0.70 776 test performance precision recall f1-score support 0 0.69 0.51 0.59 78 1 0.72 0.85 0.78 117 accuracy 0.71 195 macro avg 0.71 0.68 0.68 195 weighted avg 0.71 0.71 0.70 195
model_logReg = LogisticRegression(penalty='l2', C=0.1)
model_logReg.fit(X_train, y_train)
y_pred = model_logReg.predict(X_test)
print('train performance')
print(classification_report(y_train, model_logReg.predict(X_train)))
print('test performance')
print(classification_report(y_test, model_logReg.predict(X_test)))
train performance precision recall f1-score support 0 0.62 0.50 0.55 280 1 0.75 0.83 0.79 496 accuracy 0.71 776 macro avg 0.68 0.66 0.67 776 weighted avg 0.70 0.71 0.70 776 test performance precision recall f1-score support 0 0.71 0.51 0.60 78 1 0.73 0.86 0.79 117 accuracy 0.72 195 macro avg 0.72 0.69 0.69 195 weighted avg 0.72 0.72 0.71 195
print("Accuracy", metrics.accuracy_score(y_test, y_pred))
Accuracy 0.7230769230769231
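The regularization strength `C=0.1` above was fixed by hand; smaller `C` means a stronger L2 penalty. A hedged sketch of choosing `C` by cross-validation instead, again on synthetic data:

```python
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score

# synthetic stand-in for the survey features
X, y = make_classification(n_samples=300, random_state=123)

# mean 5-fold accuracy per candidate penalty strength
scores = {C: cross_val_score(LogisticRegression(C=C, max_iter=1000), X, y, cv=5).mean()
          for C in [0.01, 0.1, 1.0, 10.0]}
best_C = max(scores, key=scores.get)
```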