Problem implementing SVM One-vs-All in Python

I am trying to verify that I correctly understand how SVM One-vs-All (OVA, also called One-vs-Rest) works by comparing sklearn's OneVsRestClassifier against my own implementation.

In the code below, I train num_classes binary classifiers in the training phase, then evaluate all of them on the test set and pick, for each sample, the classifier that returns the highest probability value.

import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score,classification_report
from sklearn.preprocessing import scale

# Read dataset 
df = pd.read_csv('In/winequality-white.csv',  delimiter=';')
X = df.loc[:, df.columns != 'quality']
Y = df.loc[:, df.columns == 'quality']
my_classes = np.unique(Y)
num_classes = len(my_classes)

# Train-test split
np.random.seed(42)
msk = np.random.rand(len(df)) <= 0.8
train = df[msk]
test = df[~msk]

# From dataset to features and labels
X_train = train.loc[:, train.columns != 'quality']
Y_train = train.loc[:, train.columns == 'quality']
X_test = test.loc[:, test.columns != 'quality']
Y_test = test.loc[:, test.columns == 'quality']

# Models
clf =  [None] * num_classes
for k in np.arange(0,num_classes):
    my_model = SVC(gamma='auto', C=1000, kernel='rbf', class_weight='balanced', probability=True)
    clf[k] = my_model.fit(X_train, Y_train==my_classes[k])

# Prediction
prob_table = np.zeros((len(Y_test), num_classes))
for k in np.arange(0,num_classes):
    p = clf[k].predict_proba(X_test)
    prob_table[:,k] = p[:,list(clf[k].classes_).index(True)]
Y_pred = prob_table.argmax(axis=1)

print("Test accuracy = ", accuracy_score( Y_test, Y_pred) * 100,"nn") 

The test accuracy is equal to 0.21, whereas using OneVsRestClassifier it returns 0.59. For completeness, I also report the other code (the preprocessing steps are the same as above):

....
clf = OneVsRestClassifier(SVC(gamma='auto', C=1000, kernel='rbf', class_weight='balanced'))
clf.fit(X_train, Y_train)
Y_pred = clf.predict(X_test)
print("Test accuracy = ", accuracy_score( Y_test, Y_pred) * 100,"nn")

Is there something wrong with my own SVM OVA implementation?

Answer

Is there something wrong with my own SVM OVA implementation?

Your unique classes are array([3, 4, 5, 6, 7, 8, 9]), but the line Y_pred = prob_table.argmax(axis=1) assumes that they are 0-indexed.
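
To see the mismatch concretely, here is a minimal sketch; the probability values below are made up, only the shape mirrors the question:

import numpy as np

my_classes = np.array([3, 4, 5, 6, 7, 8, 9])
# Hypothetical per-class probabilities for two test samples
prob_table = np.array([[0.10, 0.10, 0.60, 0.10, 0.05, 0.03, 0.02],
                       [0.05, 0.05, 0.10, 0.10, 0.60, 0.05, 0.05]])

col_idx = prob_table.argmax(axis=1)
print(col_idx)               # [2 4]  -> column positions (0..6), not quality labels
print(my_classes[col_idx])   # [5 7]  -> the actual predicted quality labels

Comparing the raw column indices (0..6) against Y_test values (3..9) is what pushes the accuracy down to 0.21.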

Try restructuring your code so that it relies less on such assumptions:

import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split

df = pd.read_csv('winequality-white.csv', delimiter=';')
y = df["quality"]
my_classes = np.unique(y)
X = df.drop("quality", axis=1)

X_train, X_test, Y_train, Y_test = train_test_split(X,y, random_state=42)

# Models
clfs =  []

for k in my_classes:
    my_model = SVC(gamma='auto', C=1000, kernel='rbf', class_weight='balanced'
                   , probability=True, random_state=42)
    clfs.append(my_model.fit(X_train, Y_train==k))

# Prediction
prob_table = np.zeros((len(X_test),len(my_classes)))

for i,clf in enumerate(clfs):
    probs = clf.predict_proba(X_test)[:,1]
    prob_table[:,i] = probs
    
Y_pred = my_classes[prob_table.argmax(1)]
print("Test accuracy = ", accuracy_score(Y_test, Y_pred) * 100,)

from sklearn.multiclass import OneVsRestClassifier
clf = OneVsRestClassifier(SVC(gamma='auto', C=1000, kernel='rbf'
                              ,class_weight='balanced', random_state=42))
clf.fit(X_train, Y_train)
Y_pred = clf.predict(X_test)
print("Test accuracy = ", accuracy_score(Y_test, Y_pred) * 100,)

Test accuracy =  61.795918367346935
Test accuracy =  58.93877551020408

Note the difference with the probability-based OVR, which is more fine-grained than a label-based OVR and produces a better result here.
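
If you want to avoid the overhead of probability=True (it triggers Platt scaling inside SVC), a similar ranking can be built on SVC.decision_function instead. A minimal sketch, assuming the clfs list, my_classes, X_test and Y_test from the code above:

# Rank classes by the signed distance to each binary hyperplane instead of by probability
score_table = np.zeros((len(X_test), len(my_classes)))
for i, clf in enumerate(clfs):
    score_table[:, i] = clf.decision_function(X_test)

Y_pred_score = my_classes[score_table.argmax(1)]
print("Test accuracy (decision_function) = ", accuracy_score(Y_test, Y_pred_score) * 100)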

For further experimentation, you may want to wrap the procedure into a reusable class:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin


class OVRBinomial(BaseEstimator, ClassifierMixin):
    """One-vs-Rest meta-classifier: one binary `cls` estimator per class."""

    def __init__(self, cls):
        self.cls = cls

    def fit(self, X, y, **kwargs):
        # Train one binary classifier per class (class vs. rest)
        self.classes_ = np.unique(y)
        self.clfs_ = []
        for c in self.classes_:
            clf = self.cls(**kwargs)
            clf.fit(X, y == c)
            self.clfs_.append(clf)
        return self

    def predict(self, X, **kwargs):
        # For each sample, pick the class whose binary classifier is most confident
        probs = np.zeros((len(X), len(self.classes_)))
        for i, c in enumerate(self.classes_):
            prob = self.clfs_[i].predict_proba(X, **kwargs)[:, 1]
            probs[:, i] = prob
        idx_max = np.argmax(probs, 1)
        return self.classes_[idx_max]
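
A possible usage sketch, reusing X_train, Y_train, X_test and Y_test from the refactored code above (probability=True must be passed because predict relies on predict_proba):

# Hypothetical usage of the wrapper; hyperparameters mirror the earlier SVC call
ovr = OVRBinomial(SVC)
ovr.fit(X_train, Y_train, gamma='auto', C=1000, kernel='rbf',
        class_weight='balanced', probability=True, random_state=42)
Y_pred_ovr = ovr.predict(X_test)
print("Test accuracy = ", accuracy_score(Y_test, Y_pred_ovr) * 100)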


Answer

The prediction part of your code has an error. With the command Y_pred = prob_table.argmax(axis=1) you get the index of the column with the highest probability, but you want the class with the highest probability, not the column index:

Y_pred = my_classes[prob_table.argmax(axis=1)]

