In [2]:
import pandas as pd

from sklearn.linear_model import Lasso, Ridge
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

%pylab inline
Populating the interactive namespace from numpy and matplotlib

Predicting age from DNA methylation pattern

https://en.wikipedia.org/wiki/Epigenetic_clock

By Mariuswalter - Own work, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=54318073

Most important publication: Horvath, S. DNA methylation age of human tissues and cell types. Genome Biol 14, 3156 (2013). https://doi.org/10.1186/gb-2013-14-10-r115

Here we will use only a small example dataset (656 samples, 500 CpG sites), so our results will not match the accuracy reported in the literature.

With full-scale data, a prediction accuracy of roughly ±3 years is achievable (Horvath, 2013).

Read and quickly look at the data!

In [3]:
# Load the methylation dataset: one row per sample, with CpG-site beta
# values plus an `id` column and the chronological `age` target.
data = pd.read_csv('methylation_data.tsv', sep='\t')
print(data.shape)
data.head()
(656, 502)
Out[3]:
id cg13869341 cg14008030 cg12045430 cg20826792 cg00381604 cg20253340 cg21870274 cg03130891 cg24335620 ... cg19734301 cg08470135 cg05035248 cg05041517 cg15207999 cg26338735 cg24437834 cg15565004 cg19999567 age
0 X1001 0.849261 0.505916 0.072590 0.186961 0.036803 0.661391 0.777891 0.119538 0.782193 ... 0.747433 0.873424 0.985385 0.619894 0.849958 0.859900 0.854188 0.712632 0.888114 67
1 X1002 0.897434 0.476842 0.079020 0.228201 0.053161 0.545065 0.776407 0.063938 0.787890 ... 0.733304 0.857438 0.993911 0.639916 0.864043 0.854608 0.828381 0.715557 0.916687 89
2 X1003 0.751596 0.487245 0.089230 0.237660 0.045588 0.560305 0.774234 0.113279 0.788896 ... 0.760751 0.858997 0.992959 0.617711 0.834260 0.875883 0.825482 0.722971 0.890517 66
3 X1004 0.871313 0.466692 0.076666 0.253624 0.032824 0.509904 0.772503 0.067163 0.800081 ... 0.760267 0.865483 0.990998 0.613343 0.849927 0.883713 0.833676 0.725210 0.901819 64
4 X1005 0.775703 0.490255 0.079986 0.220404 0.035804 0.549847 0.730345 0.131637 0.759171 ... 0.763401 0.847928 0.988508 0.640416 0.858024 0.882822 0.832511 0.729157 0.908057 62

5 rows × 502 columns

In [4]:
# Split off the non-feature columns: `age` is the prediction target,
# and the sample identifier carries no signal, so it is discarded.
y = data.pop('age')
data.pop('id');
In [5]:
# Sanity check: methylation beta values are fractions, so the global
# maximum should stay below 1.
data.max().max()
Out[5]:
0.9992396000000001
In [6]:
# Sanity check: the global minimum should be >= 0.
data.min().min()
Out[6]:
0.0
In [7]:
# Count missing values across the whole table — expect 0.
data.isna().sum().sum()
Out[7]:
0
In [8]:
# Age distribution of the cohort, before any modelling.
fig, ax = plt.subplots()
ax.hist(y)
ax.set_xlabel('age', fontsize=20)
plt.show()

Try different linear models!

Linear regression

In [9]:
# Deterministic 50/50 split: even-indexed rows train, odd-indexed rows test.
train_x, test_x = data.values[0::2], data.values[1::2]
train_y, test_y = y[0::2], y[1::2]
In [10]:
# Fit an unregularized linear regression. With 500 features and only
# ~328 training samples the system is underdetermined, so a perfect
# fit on the training split is expected.
lr = LinearRegression()
lr.fit(train_x, train_y)
lr_pred = lr.predict(test_x)
lr_pred_train = lr.predict(train_x)
In [11]:
# Mean absolute error on the training split — essentially zero, i.e.
# the model interpolates the training data (a symptom of overfitting).
(lr_pred_train - train_y).abs().mean()
Out[11]:
1.9011350759239755e-13
In [12]:
def plot_preds(train_pred, train_y, test_pred, test_y,
               pred_label='linear regression prediction'):
    """Scatter predicted vs. actual age for the test (left) and train (right) splits.

    Parameters
    ----------
    train_pred, train_y : predictions and true ages for the training split.
    test_pred, test_y : predictions and true ages for the test split.
    pred_label : str, optional
        X-axis label. Defaults to the original hard-coded text so existing
        calls render unchanged, but lets ridge/lasso calls be labelled
        correctly (the original always said 'linear regression prediction').
    """
    def _panel(pred, actual, title):
        # One predicted-vs-actual panel with the identity (perfect) line.
        plt.title(title, fontsize=20)
        plt.scatter(pred, actual)
        plt.xlim(15, 105)
        plt.ylim(15, 105)
        plt.xlabel(pred_label, fontsize=20)
        plt.ylabel('actual age', fontsize=20)
        plt.plot([0, 150], [0, 150], '--', c='k', label='perfect prediction')
        plt.legend(fontsize=15)

    plt.figure(figsize=(16, 8))
    plt.subplot(121)
    _panel(test_pred, test_y, 'Test data')
    plt.subplot(122)
    _panel(train_pred, train_y, 'Train data')
    plt.show()
In [13]:
# Mean absolute error on the held-out split (~12.6 years): the
# unregularized model generalizes poorly.
(lr_pred - test_y).abs().mean()
Out[13]:
12.613022885843598
In [14]:
# Visualize the gap: perfect on train (right), noisy on test (left).
plot_preds(lr_pred_train, train_y, lr_pred, test_y)
In [15]:
# Pearson correlation between actual and predicted test ages (~0.60).
np.corrcoef(test_y, lr_pred)
Out[15]:
array([[1.        , 0.60339756],
       [0.60339756, 1.        ]])

Ridge regression (L2 regularization)

In [16]:
# Ridge (L2-penalized) regression with a small penalty.
r = Ridge(alpha=0.2)
r.fit(train_x, train_y)
r_pred = r.predict(test_x)
r_pred_train = r.predict(train_x)
In [17]:
# Test MAE drops to ~8.1 years with L2 regularization.
(r_pred - test_y).abs().mean()
Out[17]:
8.089739775443295
In [18]:
# Correlation on the test split improves as well (~0.74).
np.corrcoef(test_y, r_pred)
Out[18]:
array([[1.        , 0.73634028],
       [0.73634028, 1.        ]])
In [19]:
# NOTE(review): the x-axis label still reads 'linear regression
# prediction' here even though these are ridge predictions.
plot_preds(r_pred_train, train_y, r_pred, test_y)

Try different regularization strengths!

In [20]:
def get_mae_for_alpha(alpha, x_train=None, y_train=None, x_test=None, y_test=None):
    """Fit a ridge model with the given penalty and return its test MAE.

    Parameters
    ----------
    alpha : float
        L2 regularization strength passed to ``Ridge``.
    x_train, y_train, x_test, y_test : optional
        Data to fit and evaluate on. Default to the notebook-level
        even/odd split (``train_x``/``train_y``/``test_x``/``test_y``),
        so existing calls keep working unchanged.

    Returns
    -------
    float
        Mean absolute error (in years) on the test split.
    """
    if x_train is None:
        x_train, y_train = train_x, train_y
    if x_test is None:
        x_test, y_test = test_x, test_y
    model = Ridge(alpha=alpha)
    model.fit(x_train, y_train)
    preds = model.predict(x_test)
    # y_test is a pandas Series, hence the .abs().mean() chain.
    return (preds - y_test).abs().mean()
In [21]:
# Coarse scan of penalty strengths from 0 (no regularization) to 10.
alphas = list(np.linspace(0, 10, 20))
maes = [get_mae_for_alpha(a) for a in alphas]
/home/pataki/.conda/envs/fastai/lib/python3.6/site-packages/sklearn/linear_model/_ridge.py:188: LinAlgWarning: Ill-conditioned matrix (rcond=2.63028e-18): result may not be accurate.
  overwrite_a=False)
In [22]:
# Test MAE as a function of the ridge penalty strength.
plt.plot(alphas, maes)
Out[22]:
[<matplotlib.lines.Line2D at 0x7f50481f8b00>]
In [23]:
# Finer scan around the minimum of the previous curve.
alphas = list(np.linspace(0.05, 2, 100))
maes = [get_mae_for_alpha(a) for a in alphas]

plt.plot(alphas, maes)
plt.grid()

Parameter values vs regularization strength

In [24]:
# Coefficients of the alpha=0.2 ridge fitted above (the `r` created
# inside get_mae_for_alpha is local to the function and does not
# overwrite this one).
r.coef_
Out[24]:
array([-3.38105342e+00, -2.86229825e+00,  6.56102431e+00, -9.46628329e+00,
       -6.27168449e+00,  4.31180032e+00,  4.60446229e+00, -3.01469631e+00,
        1.38633363e+01,  1.82447673e+01, -1.97458906e-01,  9.84569191e+00,
        1.25943354e+01,  1.94474712e+01, -1.51144187e+01,  1.84121094e+00,
        9.25285724e+00,  3.37221109e-01, -3.07067483e+00, -1.36309293e+01,
       -2.66894244e+00,  1.65493554e+00,  6.94042382e+00,  7.55774046e+00,
        1.03792931e+01,  7.77525578e-01, -2.31631743e+00, -5.38359898e+00,
       -5.29341473e+00,  1.38235305e+01, -6.62979643e+00, -5.12117847e+00,
        7.05978811e+00,  6.21736630e+00,  5.45767588e+00,  9.04315102e-01,
        7.41605317e-01, -7.26673103e-01, -4.68102245e+00,  1.09552695e+01,
       -2.59597173e+00,  3.76056185e+00,  4.96927921e+00, -2.46337165e+00,
        2.61778372e+00, -9.64721918e-01, -4.00889478e-01, -3.13121713e+00,
       -5.85614842e-01,  3.29349599e-01, -1.96688824e+00,  1.70874339e+00,
        4.47422784e+00, -7.11888011e-01,  2.75008963e-01, -1.25655908e-01,
        1.26792535e+01,  1.69825353e+01,  1.41536760e+01,  2.48263502e+00,
       -5.36814361e-01,  1.28001553e+01, -5.79269181e+00, -9.18075963e-01,
        1.27269193e+01, -1.87217973e+01, -1.87014537e+00, -6.76746486e+00,
       -4.92279647e+00, -4.61340170e+00, -2.79159316e+00,  1.84691883e+01,
       -6.09160849e+00, -4.85321402e+00,  7.40634401e+00,  1.72534895e+01,
        3.37652198e+00,  4.84060486e+00,  3.89460544e+00,  5.93088954e-01,
        9.81111064e+00, -2.74776426e+00,  7.90442789e+00, -4.33916730e+00,
        1.24871732e+01,  2.35192804e-01, -3.64940910e+00,  4.54288466e+00,
        1.59626311e+00,  3.86012640e+00, -2.30208067e+00, -1.46695686e+00,
       -1.08262832e+01,  2.00604805e-01,  6.14916359e+00, -6.45169910e+00,
       -3.83587265e+00, -9.82132842e+00,  1.44443771e+01,  1.46066935e+01,
        4.93108052e+00, -1.85743550e+00, -5.62516010e+00,  8.08232751e+00,
        1.10502823e+01,  9.71050400e+00,  1.07565995e+01,  3.10459313e+00,
       -6.16164844e+00,  1.08559207e+01,  5.50288954e+00, -7.72275235e+00,
        1.69842901e+01,  2.05669589e+00,  1.49975645e+01, -7.09310837e+00,
        1.23426268e+00, -4.94952558e+00,  6.67791562e+00,  2.11076527e+00,
        7.31604830e+00, -4.93661518e+00, -1.12418825e+01, -2.53281016e+01,
        1.55693407e+00,  9.69961228e+00, -1.56191545e+01,  2.01023987e-01,
       -6.07882751e-01, -3.76459475e+00,  8.88576221e+00, -7.07215195e+00,
        1.69232432e+01,  1.49072592e+00, -1.31813998e+00, -7.28409177e-01,
        7.90487214e+00,  3.68298433e+00,  4.54076589e+00,  4.22221335e+00,
        1.10724625e+00,  1.66664200e+00, -5.43656056e+00, -2.20660885e+01,
       -9.60883600e+00, -4.80056588e+00,  1.27802600e+01, -1.32282151e+01,
       -7.64650894e+00, -5.11414985e+00, -4.90177486e-01, -1.03420596e+00,
       -2.70586976e+00,  1.09455472e+01,  4.95560203e+00, -1.15129138e-01,
       -5.28848578e+00,  3.12407706e+00,  8.74165461e+00, -1.35886484e+00,
        4.09095139e+00,  3.79691234e+00,  8.50367763e+00,  1.56726590e+01,
        1.29071312e+01, -2.52422726e+01, -2.33651572e+01, -1.19493405e+00,
        1.15858284e+01, -2.31074203e+00, -5.63821386e+00, -1.42332247e+00,
       -5.57654512e-01, -1.43850578e+00,  7.20853168e-01, -2.07633913e+00,
        1.31071020e+00, -3.87487934e+00, -2.07526195e+00,  2.54276537e+00,
       -2.31768606e+00, -4.71994750e+00, -1.75499619e+01,  1.10852135e+01,
        5.70424770e+00, -2.68290463e+00,  1.69668378e+00, -5.90568488e+00,
        1.12899702e+01, -7.32826524e-01, -1.05589231e+01,  3.06985268e+00,
        2.50264073e+00, -2.00951095e+00,  5.54857897e+00, -2.82627568e+00,
        5.53551253e-01,  4.89711302e-01,  1.20746933e+00,  2.99882830e+00,
        2.92219713e+00, -2.03439184e+00,  7.36547418e+00,  4.54988687e+00,
        1.26985028e-01,  1.13377695e+01,  1.50307724e+01,  9.07749958e-01,
        5.75195653e-01, -1.44585823e+00,  3.40868774e+00,  1.99890062e+00,
       -2.19297269e+00, -3.14889891e+00,  1.45553931e+01, -5.35836313e+00,
        6.19598756e+00, -2.38046308e+01, -9.45403985e-01,  6.11695036e+00,
        5.58783408e+00,  9.14418077e+00, -2.49349416e+00, -6.75497311e+00,
        6.35983281e+00, -1.49143376e+00, -8.53229128e+00,  3.62102719e+00,
       -3.41060989e+01,  8.98455055e+00,  6.82882411e+00, -7.87391182e+00,
        3.84197392e+00, -9.51393946e+00, -5.74618374e+00, -2.96546471e+00,
        6.25091436e+00, -3.54869916e+00, -1.25077878e+01,  1.92368647e+01,
       -9.13291141e+00,  6.46067966e+00,  1.16066519e+01, -7.78800318e-02,
        1.63714474e-01, -1.54032836e+01, -8.09161021e+00, -9.68648326e+00,
        4.02719417e+00,  7.65962537e+00,  1.18875730e+00, -2.14531178e+01,
        9.08671020e+00,  1.28177732e+01,  1.19174841e+01, -7.56245816e+00,
       -1.21724200e+01, -1.07520062e+01, -1.85225781e+01, -2.05809854e+01,
       -6.85024293e+00, -1.14434182e+00, -7.87829090e-01,  1.98111438e+00,
       -2.01123468e-04, -1.54164383e+00, -3.20481197e+00,  1.77519709e+00,
        1.05295174e+00,  1.70922400e+01,  3.88377523e+00,  9.77629967e+00,
       -4.09409721e+00,  1.58172639e+01,  3.25593898e+01,  1.25067522e+01,
        1.36723203e+01,  4.07791409e-01, -4.97713044e+00, -1.66525781e+00,
        4.44504668e+00,  5.27262843e+00,  1.87747668e+01,  1.68984680e+00,
        4.59715167e+00,  1.28610602e+00,  3.82145884e+00,  4.64242804e-01,
        8.73766512e+00, -1.65839019e+00,  1.09156546e+01,  1.18165257e+00,
       -1.62721547e+00,  1.24442502e+01,  7.87047674e+00, -1.81194715e+00,
       -5.24765987e+00, -1.83959006e+00, -4.83193651e+00,  5.56055932e+00,
       -2.44969814e+00,  2.34246200e+00,  2.45594699e+00,  2.50158612e+00,
       -2.98160017e-01, -2.52143917e+01,  1.13984463e+01, -6.54805419e+00,
       -5.52120431e+00,  1.11109221e+01,  5.34762568e+00,  2.54223830e+00,
       -8.61591889e-01,  2.10923179e+00,  2.10749252e+00,  4.31778725e+00,
       -1.29023359e+01,  3.18603390e+00, -2.61879359e-01, -5.70027013e+00,
        3.00287913e+00,  1.81064678e+01, -4.14359339e+00, -2.39076666e+00,
       -1.11257477e+01,  4.97280244e+00,  2.22191984e+00,  5.36357895e+00,
       -6.33144212e+00,  1.89575388e+01,  1.22147162e-02, -4.19925746e+00,
        3.71308611e+00, -4.60706387e+00, -7.67642055e-02, -6.77559199e-01,
       -4.27871347e+00,  9.86334535e-01,  8.38451386e+00,  9.38763369e+00,
       -3.08680932e-01, -1.15603158e+00,  3.04045692e-02, -3.21516671e+00,
       -7.10446090e+00, -1.49608873e-01, -9.22437563e+00,  3.35088597e-01,
       -8.44339663e-01,  4.62592882e+00, -1.73486409e+00,  6.17099459e+00,
       -8.76543410e-01, -8.75530451e-01, -2.54305112e+00, -5.29491542e+00,
        2.89297841e+00,  3.93785878e+00,  1.09775666e+01, -4.62345701e+00,
       -1.66928614e+01,  6.58280995e+00, -6.47878879e-01,  6.77701584e+00,
       -7.38553295e+00, -2.86335022e+00, -1.26332465e+00,  3.25693007e+00,
        9.93473575e+00, -1.09336951e+00,  8.67281924e+00,  6.44205053e+00,
       -2.14065792e+00,  4.13340319e+00, -4.41400751e+00, -4.58466032e+00,
       -3.11921752e+00, -5.60162236e+00,  5.13346251e+00,  7.93384485e+00,
       -5.75383510e+00,  1.92288636e+00, -2.10213660e+01, -9.13660100e-01,
        4.85424807e+00,  4.43097491e+00,  1.28341490e+00, -1.15216530e+00,
        1.20740770e-01,  4.55128891e+00,  1.35510443e+00, -2.38014837e-01,
        4.31120599e-01,  1.44749976e-01, -3.11754209e+00, -5.00998604e+00,
       -1.13747752e+00,  9.72191267e+00,  4.59538966e-01, -6.78516382e+00,
       -5.76800709e-01, -1.57690499e+00, -3.04971391e+00,  6.65741895e+00,
       -2.23550353e+00, -3.06132188e-01,  1.13093847e+01,  4.52432133e+00,
       -9.89622346e+00, -2.91933298e+01, -1.12503490e+01, -1.54127957e+00,
        1.85831632e+00,  4.93860752e+00, -2.54738743e+00,  6.49160756e+00,
        6.56782508e+00, -6.24684576e+00,  1.16489953e+01,  1.27212237e+01,
       -4.54813006e+00,  3.66817539e+00,  2.75184498e+00,  1.69680974e+00,
        1.55364146e+00,  9.22548972e+00, -8.20756459e+00,  6.44329448e-01,
        6.10114209e+00,  4.95976709e+00,  6.93752908e+00,  2.07491964e-01,
        5.11996423e+00,  1.19171078e+01, -4.39878023e+00,  2.35351788e+00,
       -1.83368057e+00,  5.32413509e+00, -3.07527228e+00, -1.43748870e+01,
        1.21562624e+01, -1.28672827e+01,  3.88090650e+00,  2.87892544e+00,
        2.06930559e+00,  9.34802046e+00,  9.29072846e+00,  1.63674207e+00,
       -4.42429367e+00, -1.03384604e+01,  1.00752023e+01,  1.39635239e+01,
       -1.05367575e+01, -5.50068083e+00, -8.53234178e+00,  2.77684370e+00,
       -4.58627648e+00,  4.54951702e+00, -8.20328926e+00,  1.40708531e+01,
        3.45348029e+00, -2.58677016e+00, -2.43531321e+00, -2.10795620e+00,
       -7.02110224e+00, -3.09709899e+00, -1.40442046e+01, -2.05754088e+01,
       -9.68001787e+00, -1.46027703e+01, -3.63613727e+00, -2.02651652e+01,
        1.31506290e+01,  1.39313737e+01, -3.09739512e+00,  9.52337888e+00,
        1.55217879e+00,  6.65799076e+00,  5.64967635e+00, -1.22226391e+00,
        5.22438888e+00, -2.82244343e+00, -8.38688860e-01, -1.82644092e+00,
        1.88227867e+00,  1.23407635e+00,  1.07334738e+01, -2.55916134e+00,
        3.12671717e-01,  5.56379646e+00, -2.01842601e+00,  1.38518554e+00,
       -1.90286789e+00,  1.56656295e+00,  6.89580954e+00, -1.08211850e+01,
       -7.31732768e+00, -3.20752083e+00, -3.34452192e+00,  1.18521468e+01])
In [25]:
# Refit ridge across a grid of penalties and record every coefficient
# vector, to visualize how L2 shrinks the weights toward zero.
alphas = list(np.linspace(0.1, 10, 100))
all_coeffs = []
for a in alphas:
    r = Ridge(alpha=a).fit(train_x, train_y)
    all_coeffs.append(r.coef_)

all_coeffs = np.array(all_coeffs)
In [26]:
# Regularization path: one line per coefficient, shrinking as alpha grows.
plt.plot(alphas, all_coeffs)
plt.show()

Lasso, L1 regularization

In [27]:
# Same coefficient-path experiment with L1 (lasso) regularization:
# unlike ridge, lasso drives coefficients exactly to zero.
alphas = list(np.linspace(0.1, 0.8, 100))
all_coeffs = []
for a in alphas:
    l = Lasso(alpha=a).fit(train_x, train_y)
    all_coeffs.append(l.coef_)

all_coeffs = np.array(all_coeffs)

plt.plot(alphas, all_coeffs)
plt.show()