Imaginings of a Livestock Geneticist

Machine Learning Algorithms to Estimate Model Parameters

In the context of prediction, a linear model, although very simple, is a robust prediction machine that can achieve similar or even better predictive ability than more complex models such as random forests, support vector machines, or neural networks. A linear regression model prediction is summarized as:

$$ \hat y = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \theta_3 x_3 + \cdots + \theta_n x_n $$

The equation above can be written in an alternative vectorized form, similar to many familiar equations, where 'T' represents the transpose and x is the vector of feature values for an observation:

$$ \hat y = \theta^T x $$

When training a prediction model, the model parameters (e.g., Θ) are chosen so that the model "best" fits the training data. The training dataset contains observations that are not part of the sample for which one wishes to obtain predictions. In the context of animal breeding, one can think of the training set as older phenotyped animals and the test set as young animals lacking phenotypes. One metric to define the "best" fit is the average of the squared differences between the prediction (e.g., θ^T x) and the actual value, also known as the Mean Square Error (MSE). The MSE measures the average amount by which the model's predictions deviate from the actual values. The MSE is outlined below, where n refers to the number of observations:

$$ MSE = \frac{1}{n} \sum^n_{i=1} (\theta^T x_i - y_i)^2 $$
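
As a concrete illustration, the prediction and the MSE can be computed in a couple of lines of R. The sketch below assumes X is an n by p matrix of features and y a length-n vector of observations; the function names are only illustrative.

```r
# Vectorized prediction and MSE for a linear model
# (assumes X is an n x p numeric matrix and y a numeric vector of length n)
predict_linear <- function(theta, X) {
  as.vector(X %*% theta)                      # one prediction per row of X
}

mse <- function(theta, X, y) {
  mean((predict_linear(theta, X) - y)^2)      # average squared prediction error
}
```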

In the first year of graduate school in animal breeding, a statistics class will likely describe the ordinary least squares (OLS) method to estimate the Θ values that minimize the MSE in a training set. In most cases, OLS is the preferred method of estimating the Θ values, and as outlined below the solution is rather simple to obtain.

$$ \theta = (X^T X)^{-1} X^T y $$
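
In R, a minimal sketch of the normal equations looks like the following; solve() is used with two arguments so that the linear system is solved directly rather than forming the inverse explicitly, and the function name is only illustrative.

```r
# Normal equations: theta = (X'X)^{-1} X'y
# solve(A, b) solves A %*% theta = b without explicitly inverting A
ols_theta <- function(X, y) {
  solve(t(X) %*% X, t(X) %*% y)
}
```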

The case where OLS becomes computationally intensive is when the number of features becomes large and generating the inverse of X^T X (a square matrix with dimension equal to the number of features) becomes too expensive. An alternative iterative method that doesn't involve an inverse calculation is gradient descent. Gradient descent is an optimization algorithm that minimizes the MSE, or, in the machine learning world, the cost function. Models with a high MSE are performing poorly on the training dataset compared to models with a lower MSE. The primary objective of the optimization algorithm is to find the Θ values that result in the lowest possible MSE. The algorithm figures out the direction and rate at which the Θ values should change across iterations by taking the derivative (gradient) of the MSE. The gradient tells us the direction the Θ values should move and contributes to how big of a step to take; the learning rate, which the user controls, further scales the step size. The gradient of the MSE is outlined below, where n refers to the number of observations, X is the matrix of features, and y is the vector of observations:

$$ \nabla_{\theta} \, MSE = \frac{2}{n} X^T (X\theta - y) $$
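
In R, this gradient is a single matrix expression; the sketch below mirrors the formula above (the function name is only illustrative).

```r
# Gradient of the MSE with respect to theta: (2/n) * X'(X theta - y)
mse_gradient <- function(theta, X, y) {
  n <- nrow(X)
  (2 / n) * as.vector(t(X) %*% (X %*% theta - y))
}
```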

All three variations of the gradient descent algorithm use the same formula outlined above to determine the gradient, which is then multiplied by the learning rate, and the result of this is subtracted from the current Θ. This update step is identical across the variants; the major difference between the variations is the number of observations used when calculating the gradient:

- Batch gradient descent uses all n observations in the training set for every update.
- Stochastic gradient descent uses a single observation, taken in shuffled order, for each update.
- Mini-batch gradient descent uses a small random subset (a mini-batch) of observations for each update.

For each method, it is essential to center and scale the variables before running the algorithm in order to ensure convergence. For stochastic and mini-batch gradient descent, it is also important to shuffle the dataset to remove any inherent ordering in the data. Furthermore, one of the advantages of the stochastic and mini-batch versions is that the path to the 'best' Θ is noisier compared to the batch version. A noisier path can be important because it allows the algorithm to potentially move out of a local minimum and reach the global minimum. Outlined below is R code for each version.
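
The listings below are a minimal sketch of each version, assuming X is the centered and scaled feature matrix and y the vector of observations; the learning rates, iteration counts, and batch size are illustrative defaults rather than tuned values.

```r
# Batch gradient descent: all n observations are used for every update
batch_gd <- function(X, y, lr = 0.01, iters = 1000) {
  n <- nrow(X)
  theta <- rep(0, ncol(X))
  for (it in 1:iters) {
    grad  <- (2 / n) * as.vector(t(X) %*% (X %*% theta - y))
    theta <- theta - lr * grad                 # update step: gradient x learning rate
  }
  theta
}

# Stochastic gradient descent: one observation per update, in shuffled order
stochastic_gd <- function(X, y, lr = 0.01, epochs = 50) {
  n <- nrow(X)
  theta <- rep(0, ncol(X))
  for (ep in 1:epochs) {
    for (i in sample(n)) {                     # shuffle the data each epoch
      xi    <- X[i, , drop = FALSE]
      grad  <- 2 * as.vector(t(xi) %*% (xi %*% theta - y[i]))
      theta <- theta - lr * grad               # update step: gradient x learning rate
    }
  }
  theta
}

# Mini-batch gradient descent: a small random subset of observations per update
minibatch_gd <- function(X, y, lr = 0.01, epochs = 50, batch_size = 32) {
  n <- nrow(X)
  theta <- rep(0, ncol(X))
  for (ep in 1:epochs) {
    idx <- sample(n)                           # shuffle the data each epoch
    for (start in seq(1, n, by = batch_size)) {
      rows  <- idx[start:min(start + batch_size - 1, n)]
      Xb    <- X[rows, , drop = FALSE]
      grad  <- (2 / length(rows)) * as.vector(t(Xb) %*% (Xb %*% theta - y[rows]))
      theta <- theta - lr * grad               # update step: gradient x learning rate
    }
  }
  theta
}
```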

The data files containing X and y can be utilized to test the algorithms. The X matrix has already been centered and scaled. Outlined below are the solutions for the different solving methods. Look at the MSE across the different algorithms and the impact of the learning rate and number of iterations on the MSE.

Θ estimates by solving method:

Θ       Normal Equations    Batch GD        Stochastic GD    Mini-Batch GD
X1          -55650.412        -56056.885       -55517.102      -56022.5244
X2          -56716.452        -57176.502       -56104.809      -57131.3560
X3           13732.838         13764.920        13910.009       13743.7264
X4           -1933.128         -2051.110        -2188.414       -2165.1129
X5            7330.041          8302.964        15368.485        8422.1033
X6          -45708.263        -45600.936       -45037.204      -45552.7020
X7           45455.475         44486.073        36938.226       44230.6017
X8           74714.391         74715.421        74998.867       74928.3918
X9            6605.128          6543.062         5873.385        6614.2208
X10           1042.957          1030.306         1172.407         873.8993
X11           9249.759          9201.119         8746.217        8771.1297
X12         218898.473        218825.557       219200.981      218774.2882
X13         181695.845        181877.675       181216.612      181869.2788
X14         347272.781          8333.394         2515.239        8334.0142
X15         214435.157        214366.650       212413.533      214391.5062
X16         222272.730        222133.798       222520.157      222173.9754
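
As one way to make that comparison, the short sketch below computes the training MSE for each method, assuming X and y have been loaded into the session and the functions sketched above are available; the default tuning values are arbitrary starting points.

```r
# Compare training MSE across the four solving methods
theta_ols  <- ols_theta(X, y)
theta_bgd  <- batch_gd(X, y)
theta_sgd  <- stochastic_gd(X, y)
theta_mbgd <- minibatch_gd(X, y)

# Re-run with different lr, iters/epochs, and batch_size values to see their impact
round(sapply(list(OLS = theta_ols, Batch = theta_bgd,
                  Stochastic = theta_sgd, MiniBatch = theta_mbgd),
             mse, X = X, y = y), 2)
```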

Improved Gradient Descent Algorithm

In the previous implementations of gradient descent, the same learning rate is applied to all parameters, and depending on the local loss landscape the convergence may be slow. A method to tackle the slow convergence is to incorporate a gradient accumulation or averaging mechanism that utilizes information from more than just the most recent gradient, commonly referred to as 'momentum'. In the code outlined below, the new parameters introduced, 'm' and 'v', are exponentially decaying averages of past gradients and past squared gradients, respectively. A number of different methods (momentum, AdaGrad, RMSprop) have been developed to improve the initial gradient descent update, but the most common one is referred to as Adam. The code is outlined below, and tracking the MSE across iterations shows the improvement in convergence.
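
Below is a minimal sketch of Adam applied to the same linear model, reusing the batch gradient from above; the defaults for the learning rate, beta1, beta2, and epsilon follow commonly used settings and are only illustrative.

```r
# Adam: gradient descent with exponentially decaying averages of past
# gradients (m) and past squared gradients (v), plus bias correction
adam_gd <- function(X, y, lr = 0.05, iters = 1000,
                    beta1 = 0.9, beta2 = 0.999, eps = 1e-8) {
  n <- nrow(X)
  theta <- rep(0, ncol(X))
  m <- rep(0, ncol(X))                # first moment: average of past gradients
  v <- rep(0, ncol(X))                # second moment: average of past squared gradients
  for (it in 1:iters) {
    grad  <- (2 / n) * as.vector(t(X) %*% (X %*% theta - y))
    m     <- beta1 * m + (1 - beta1) * grad
    v     <- beta2 * v + (1 - beta2) * grad^2
    m_hat <- m / (1 - beta1^it)       # bias-corrected first moment
    v_hat <- v / (1 - beta2^it)       # bias-corrected second moment
    theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
  }
  theta
}
```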