【scikit-learn】 2値分類の際のpredictで任意の閾値を設定する
scikit-learn
Lastmod: 2021-05-08

はじめに

scikit-learnにおける2値分類でpredictをするとデフォルトの閾値0.5で分類されますよね。今回はこの閾値を任意で設定する方法を紹介します。

結論

方法は以下の通り。

threshold = 0.6 #閾値を設定
predict = (clf.predict_proba(X)[:,1] >= threshold).astype(int)

挙動確認

ここでは上述した方法を使ってしっかりできているかを確認していきます。方法のみを知りたかった方は読まなくて大丈夫です。

データセットの用意

まず確認作業に使うデータセットを用意します。ここではmake_classificationを使います。またtrain_test_splitでデータを分割しておきます。

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


X, y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, n_classes=2)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

学習

今回はLogisticRegressionを使って学習をさせます。

from sklearn.linear_model import LogisticRegression

clf = LogisticRegression()
clf.fit(X_train, y_train)

予測

早速予測させてみます。まずいつものpredictから。

predict = clf.predict(X_test)

print(predict)
print(X_test)

### 出力
# array([0, 0, 1, 0, 1, 0, 0, 1, 0, 0])
# array([0, 0, 1, 0, 1, 0, 0, 1, 0, 0])

しっかり分類できています。

それでは閾値を変えてみます。まずpredict_probaで確率を確認してみます。

predict_proba = clf.predict_proba(X_test)

print(predict_proba)

### 出力
# array([[0.94640156, 0.05359844],
#       [0.98731181, 0.01268819],
#       [0.11593201, 0.88406799],
#       [0.99392087, 0.00607913],
#       [0.12034518, 0.87965482],
#       [0.9844181 , 0.0155819 ],
#       [0.8665066 , 0.1334934 ],
#       [0.06603681, 0.93396319],
#       [0.96047255, 0.03952745],
#       [0.86216062, 0.13783938]])

結構しっかり分類できていそうなので本来なら閾値を調整する必要はなさそうです。今回はあえてthreshold=0.1として挙動を確認します。

threshold = 0.1
predict_custom = (clf.predict_proba(X_test)[:,1] >= threshold).astype(int)

print(predict_proba)
print(predict_custom)
print(predict)

### 出力
# array([[0.94640156, 0.05359844],
#       [0.98731181, 0.01268819],
#       [0.11593201, 0.88406799],
#       [0.99392087, 0.00607913],
#       [0.12034518, 0.87965482],
#       [0.9844181 , 0.0155819 ],
#       [0.8665066 , 0.1334934 ],
#       [0.06603681, 0.93396319],
#       [0.96047255, 0.03952745],
#       [0.86216062, 0.13783938]])
#
# array([0, 0, 1, 0, 1, 0, 1, 1, 0, 1])
# 
# array([0, 0, 1, 0, 1, 0, 0, 1, 0, 0])

挙動OK