In [1]:
import pandas as pd
from sklearn.svm import SVC
from sklearn.datasets import *
from collections import Counter
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import cross_val_predict

%pylab inline
Populating the interactive namespace from numpy and matplotlib
In [2]:
def plot_roc(y_score, y_true):
    """Draw the ROC curve for a set of scores and show its AUC in the legend.

    Parameters
    ----------
    y_score : array-like
        Scores for the positive class; passed unchanged to ``roc_curve``
        and ``roc_auc_score``.
    y_true : array-like
        Ground-truth labels; passed unchanged to ``roc_curve`` and
        ``roc_auc_score``.

    Returns
    -------
    None
        The figure is rendered with ``plt.show()``.
    """
    fpr, tpr, _ = roc_curve(y_score=y_score, y_true=y_true)
    # Compute the AUC once, up front, so the legend label stays readable.
    auc = roc_auc_score(y_score=y_score, y_true=y_true)

    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, label='AUC: ' + str(np.round(auc, 3)))
    # A small margin around [0, 1] keeps the curve's end points off the frame.
    plt.xlim(-0.05, 1.05)
    plt.ylim(-0.05, 1.05)
    # Dashed diagonal (tpr == fpr) as a visual reference line.
    plt.plot([-1, 2], [-1, 2], '--')
    # Label the axes so the figure can be read on its own.
    plt.xlabel('False positive rate', fontsize=15)
    plt.ylabel('True positive rate', fontsize=15)
    plt.legend(fontsize=15)
    plt.show()
In [3]:
data = make_circles(random_state=0, noise=0.1, n_samples=200, factor=0.8)
data
Out[3]:
(array([[ 0.37688311,  0.86683481],
        [-0.22124013, -0.92898008],
        [ 0.63218216,  0.32870102],
        [ 0.94748697, -0.33966878],
        [ 0.10127036, -0.84951402],
        [ 0.388714  , -0.78127584],
        [ 0.92580976,  0.41300167],
        [-0.77721448,  0.20789414],
        [ 0.66953589,  0.67397743],
        [-0.71941219, -0.17576474],
        [-0.78434151, -0.84991602],
        [-0.10300188,  0.74825481],
        [ 0.23303359, -0.88849323],
        [-0.6267152 , -0.03302425],
        [ 0.30903096, -0.80047268],
        [ 0.91527914,  0.40217498],
        [-0.39981017,  0.91659232],
        [-0.39103778,  0.7879701 ],
        [-0.75360742,  0.66081533],
        [-0.21401358, -0.77833763],
        [ 0.48952044, -0.70436895],
        [-0.75428108,  0.27379217],
        [-0.8358723 ,  0.32282128],
        [-0.69024289, -0.26027869],
        [-0.71183378, -0.42701785],
        [ 0.07455381,  0.97467353],
        [ 0.28818676, -0.78452414],
        [ 0.4325125 , -0.69774337],
        [ 0.01363794,  0.8991425 ],
        [-0.99984546,  0.19128976],
        [ 0.56986852,  0.91725209],
        [-1.0598919 , -0.59679487],
        [-0.86614777,  0.07059284],
        [ 0.58055081,  0.58313744],
        [ 0.18374466,  1.04117867],
        [ 0.53844711, -0.46230552],
        [-0.1155639 ,  0.9062458 ],
        [ 0.9084123 ,  0.19397351],
        [ 0.54271875, -0.80041346],
        [ 0.62522927,  0.38743959],
        [-0.37396623,  0.58850311],
        [-0.413484  , -0.85757096],
        [-0.11328014,  0.95355095],
        [ 0.5853513 ,  0.51946425],
        [-0.35109875, -0.73896584],
        [-0.6000845 , -0.63810439],
        [ 0.97454884,  0.36365859],
        [ 0.00899438, -1.07062449],
        [ 0.34584593,  0.58970373],
        [-0.58556741,  0.37931472],
        [-0.38193521, -0.91820141],
        [ 0.15352398,  0.78619902],
        [ 0.61192522, -0.68203566],
        [ 1.16866492, -0.1997504 ],
        [-0.67638151,  0.04957368],
        [ 0.12216503,  0.9316288 ],
        [-0.36975927,  0.79461949],
        [-0.82058779, -0.53178463],
        [-0.90448038,  0.56290476],
        [-0.93528524, -0.33991756],
        [-0.48393145,  0.82765227],
        [ 0.92695195, -0.11299709],
        [ 0.1747337 , -0.91152575],
        [ 0.46924789,  0.85696369],
        [-0.80755332, -0.09690338],
        [ 0.68671137, -0.44928513],
        [-1.00135065, -0.19838313],
        [-0.44708276, -0.52570632],
        [-0.84235872, -0.10468235],
        [ 0.5964523 , -0.65887579],
        [-0.64150852,  0.46841638],
        [ 0.74510838, -0.33800849],
        [-0.8481505 ,  0.49577518],
        [ 0.53340246,  0.16263296],
        [-0.76518489,  0.4076828 ],
        [ 0.77423655, -0.50714896],
        [ 0.70045239,  0.36407467],
        [ 0.68454359,  0.73070291],
        [-0.22125475,  1.04620201],
        [-0.68328689,  0.61204071],
        [ 0.65010701, -0.59134165],
        [-1.2313376 ,  0.21304411],
        [-0.59303519,  0.41759814],
        [ 0.78971917, -0.10207584],
        [ 0.68325353,  0.36841225],
        [-0.67395591, -0.47491903],
        [ 0.90443441,  0.08203405],
        [-0.92891781, -0.51945352],
        [-0.48009932,  0.64017755],
        [ 0.77024932, -0.23585624],
        [-0.65192581,  0.3413308 ],
        [ 0.8980832 ,  0.66498563],
        [ 0.74585255, -0.45282466],
        [-0.11778877, -0.93290335],
        [ 0.79894482, -0.27860662],
        [ 0.14376549, -0.77931984],
        [ 0.02486498, -0.87547306],
        [ 0.11971072,  0.75914672],
        [ 0.81857753, -0.3175337 ],
        [ 0.56819668,  0.44615845],
        [-0.63215008, -0.61951377],
        [ 0.68807637, -0.30856753],
        [-0.90370031,  0.15900664],
        [ 1.06172312, -0.04778164],
        [ 1.10340443, -0.07105457],
        [ 0.70872705,  0.67692255],
        [ 0.95881084, -0.3542489 ],
        [-0.69717947, -0.63963654],
        [-0.34681523, -0.70176442],
        [-0.9106265 ,  0.40112217],
        [-0.37764448, -0.95254926],
        [-0.95478306,  0.10948449],
        [-1.12614284,  0.2676192 ],
        [ 0.68716498, -0.87066808],
        [ 0.70504088,  0.68101407],
        [-0.46517197, -0.6446231 ],
        [-0.15274054,  0.98316732],
        [ 0.56575335, -0.54083807],
        [-0.42240669,  0.64182301],
        [ 0.18482686,  0.78520699],
        [ 0.55244808,  0.97466776],
        [ 0.3385726 , -1.07296768],
        [ 0.6971448 ,  0.1250291 ],
        [-0.89331579, -0.12798051],
        [ 0.67373307, -0.15611651],
        [-0.03114487, -0.96448144],
        [ 0.83331209,  0.08069024],
        [ 0.53224773,  0.73859367],
        [-0.69937586, -0.61275345],
        [ 0.76244154, -0.01884182],
        [ 0.96801692,  0.53279114],
        [-0.40302395, -1.04481451],
        [ 0.3196955 , -0.8162529 ],
        [ 0.23922831,  0.70355133],
        [-0.6794012 , -0.09483852],
        [-0.25802235, -0.58138664],
        [-0.83655049, -0.14506656],
        [ 0.76177486, -0.60585098],
        [ 0.83574051, -0.04396389],
        [ 0.66792533,  0.73372557],
        [ 0.14527072,  0.8538113 ],
        [ 0.7648258 ,  0.16841252],
        [-0.58458839,  0.90894941],
        [-1.14717799, -0.40301831],
        [-0.52014384, -0.80444037],
        [ 1.02971721,  0.13308403],
        [ 0.3198755 ,  0.96973097],
        [-0.6538897 , -0.59333924],
        [-0.83935614,  0.57923554],
        [ 0.72399386,  0.06160893],
        [-0.3356021 ,  0.68274386],
        [-0.09585005, -1.02001426],
        [ 0.4570725 ,  0.76154485],
        [-0.75836593,  0.62853034],
        [-0.21057559,  0.73152132],
        [-1.1169767 , -0.24222743],
        [-0.67806895, -0.34500209],
        [-0.21794343,  0.97765587],
        [-0.42775347,  0.97894256],
        [-0.16367645,  0.98433643],
        [ 0.63119811,  0.57018877],
        [-0.67000796, -0.06555333],
        [-0.26854662,  0.93983501],
        [-0.46371909, -0.59785075],
        [ 0.0053233 ,  0.56765914],
        [-0.0954778 , -0.65144424],
        [-0.52920433,  0.91632128],
        [-0.81742514,  0.48418834],
        [-0.3262346 , -0.81188181],
        [-0.74707281, -0.00656294],
        [-0.23190324,  1.05373594],
        [ 0.85523158, -0.00940421],
        [ 0.3155662 , -0.87965313],
        [ 0.36274103, -1.11792417],
        [ 0.54166392,  0.62310685],
        [-0.81128764,  0.20141276],
        [ 0.44980084, -0.15008246],
        [-0.21420679, -1.11207276],
        [ 0.18493852, -1.07865224],
        [-0.20002014,  1.02564578],
        [-0.46601629, -0.63512731],
        [ 0.27543151, -0.99376778],
        [ 0.56163295, -0.49956157],
        [ 0.12028626, -0.92244726],
        [ 0.6308322 , -0.37527211],
        [-0.75521585,  0.70856413],
        [-0.68105402, -0.62530978],
        [-0.51209297,  0.34502593],
        [ 0.8343438 , -0.63466008],
        [-0.35552746, -0.96083597],
        [ 0.68114731, -0.69719584],
        [-0.61929686,  0.81662473],
        [ 0.29381154,  0.80426366],
        [ 0.80536053,  0.5145903 ],
        [ 0.74459003,  0.08740938],
        [-0.49997044, -0.85522079],
        [ 0.51582075, -0.45607009],
        [ 0.45287299,  0.71393812],
        [-0.89231508,  0.20947499],
        [-0.21657517, -0.72090116]]),
 array([0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1,
        0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1,
        1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,
        0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0,
        1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0,
        0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0,
        1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1,
        1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1,
        0, 1]))
In [4]:
# Pack the (coordinates, labels) tuple into a single frame: the two coordinate
# columns become x1 / x2 and the label vector becomes y.
data = pd.DataFrame(data[0], columns=['x1', 'x2']).assign(y=data[1])
data.head()
Out[4]:
x1 x2 y
0 0.376883 0.866835 0
1 -0.221240 -0.928980 1
2 0.632182 0.328701 1
3 0.947487 -0.339669 0
4 0.101270 -0.849514 1
In [5]:
# Scatter the two classes separately so each one gets its own colour
# from the default colour cycle (class 0 first, then class 1).
plt.figure(figsize=(8, 8))
plt.scatter(data.loc[data['y'] == 0, 'x1'], data.loc[data['y'] == 0, 'x2'])
plt.scatter(data.loc[data['y'] == 1, 'x1'], data.loc[data['y'] == 1, 'x2'])
plt.show()
In [7]:
preds[:5]
Out[7]:
array([[0.6488572 , 0.3511428 ],
       [0.39357147, 0.60642853],
       [0.61290164, 0.38709836],
       [0.56522105, 0.43477895],
       [0.43123194, 0.56876806]])
In [6]:
linSVM = SVC(kernel='linear', random_state=0, probability=True)
#linSVM.fit(data[['x1', 'x2']], data.y)
preds = cross_val_predict(estimator=linSVM, X=data[['x1', 'x2']], y=data.y, cv=5, method='predict_proba')

plot_roc(y_score=preds[:, 1], y_true=data.y)
In [8]:
# Add squared coordinates as extra features, then repeat the linear-SVM
# experiment on the four-column feature set.
data['x1^2'] = data['x1'].pow(2)
data['x2^2'] = data['x2'].pow(2)

linSVM = SVC(probability=True, kernel='linear', random_state=0)
preds = cross_val_predict(
    linSVM,
    data[['x1', 'x2', 'x1^2', 'x2^2']],
    data['y'],
    cv=5,
    method='predict_proba',
)

# Score with the class-1 probability column.
plot_roc(y_true=data['y'], y_score=preds[:, 1])
In [9]:
# RBF-kernel SVM on the two raw coordinates only, evaluated with
# out-of-fold probabilities from 5-fold cross-validation.
rbfSVM = SVC(C=1, kernel='rbf', probability=True, random_state=0)
preds = cross_val_predict(
    rbfSVM,
    data[['x1', 'x2']],
    data['y'],
    cv=5,
    method='predict_proba',
)

# Score with the class-1 probability column.
plot_roc(y_true=data['y'], y_score=preds[:, 1])
In [10]:
# Sweep the RBF-SVM penalty parameter C and record the out-of-fold ROC AUC
# for each value; Cs and AUCs are plotted against each other afterwards.
Cs = [1e-6, 1e-5, 1e-4, 1e-2, 0.1, 0.15, 0.2, 0.5, 0.7, 1.5, 10, 1e2, 1e3, 1e4, 1e5]
AUCs = []
for c in Cs:
    rbfSVM = SVC(kernel='rbf', random_state=0, probability=True, C=c)
    # No rbfSVM.fit(...) on the full data here: cross_val_predict fits its
    # own clone on every training fold, so that extra fit never influenced
    # preds and only added run time. rbfSVM is therefore left unfitted.
    preds = cross_val_predict(estimator=rbfSVM,
                              X=data[['x1', 'x2']],
                              y=data.y, cv=5, method='predict_proba')

    # AUC is computed from the class-1 probability column.
    AUCs.append(roc_auc_score(y_score=preds[:, 1], y_true=data.y))
In [11]:
# Cross-validated AUC as a function of C; C spans many orders of magnitude,
# hence the logarithmic x axis.
plt.plot(Cs, AUCs, marker='o', linestyle='-')
plt.xlabel('C', fontsize=15)
plt.ylabel('AUC', fontsize=15)
plt.xscale('log')
plt.grid()