在本文中,我們將從sklearn庫中了解StratifiedShuffleSplit交叉驗證器,該驗證器提供train-test索引以將數據分為train-test集。
什麽是StratifiedShuffleSplit?
分層洗牌拆分是ShuffleSplit和StratifiedKFold的組合。使用StratifiedShuffleSplit,訓練和測試數據集之間的類標簽分布比例幾乎相等。 StratifiedShuffleSplit和StratifiedKFold(shuffle = True)之間的主要區別在於,在StratifiedKFold中,數據集僅在開始時被隨機洗一次,然後分成指定的折疊數。這將丟棄train-test集重疊的任何機會。但是,在StratifiedShuffleSplit中,每次在拆分完成之前都會對數據進行混洗,這就是為什麽在train-test集之間有更大可能重疊的原因。
用法:sklearn.model_selection.StratifiedShuffleSplit(n_splits=10, *, test_size=None, train_size=None, random_state=None)
參數:
n_splits:int,默認= 10
re-shuffling和拆分迭代的次數。
test_size:float或int,默認為None
如果為float,則應在0.0到1.0之間,並且代表要包含在測試拆分中的數據集的比例。
train_size:float或int,默認為None
如果為float,則應在0.0到1.0之間,並且代表要包含在火車分割中的數據集的比例。
random_state:整型
控製所產生的訓練和測試指標的隨機性。
下麵是實現。
步驟1)導入所需的模塊。
Python3
# import the libraries
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit
步驟2)加載數據集並標識因變量和自變量。
數據集可從此處下載。
Python3
# convert data set into dataframe
churn_df = pd.read_csv(r"ChurnData.csv")
# assign dependent and indepenedent variables
X = churn_df[['tenure', 'age', 'address', 'income',
'ed', 'employ', 'equip', 'callcard', 'wireless']]
y = churn_df['churn'].astype('int')
步驟3)預處理數據。
Python3
# data pre-processing
X = preprocessing.StandardScaler().fit(X).transform(X)
步驟4)創建StratifiedShuffleSplit類的對象。
Python3
# use StratifiedShuffleSplit()
sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5,
random_state=0)
sss.get_n_splits(X, y)
輸出:
步驟5)調用實例並將數據幀分為訓練樣本和測試樣本。 split()函數返回train-test個樣本的索引。使用回歸算法並比較每個預測值的準確性。
Python3
scores = []
# using regression to get predicted data
rf = RandomForestClassifier(n_estimators=40, max_depth=7)
for train_index, test_index in sss.split(X, y):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
rf.fit(X_train, y_train)
pred = rf.predict(X_test)
scores.append(accuracy_score(y_test, pred))
# get accurracy of each prediction
print(scores)
輸出:
注:本文由純淨天空篩選整理自 Sklearn.StratifiedShuffleSplit() function in Python。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。