We'll see how linear regression works, but first we need a dataset. Experiment
with the values below to generate a dataset.
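If you want to follow along outside the interactive widget, here is a minimal sketch of how such a dataset could be generated with NumPy. The true coefficient, intercept, and noise level are arbitrary stand-ins for whatever values you set above.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Arbitrary "true" parameters standing in for the values set above.
true_w, true_b = 2.0, -1.0
m = 100  # number of training examples

# Sample inputs uniformly and add Gaussian noise to the linear outputs.
x = rng.uniform(-5.0, 5.0, size=m)
y = true_w * x + true_b + rng.normal(0.0, 1.0, size=m)
```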
Problem Statement
Given a set of $m$ training examples $(x^{(i)}, y^{(i)})$ for all
$i \in [1, \ldots, m]$, a linear regression model is a supervised learning
regression model which expresses $f(x^{(i)}) = \hat{y}^{(i)}$ as a linear function
of $x^{(i)}$ given parameters $w$ and $b$ as follows:

$$\hat{y}^{(i)} = w \cdot x^{(i)} + b$$
The model is a supervised learning model because the training set contains the
"right" or expected output value of the target variable for every $x^{(i)}$,
and a regression model because the model outputs a continuous value. Linear
regression with one variable is known as univariate linear regression, i.e.,
$x^{(i)} \in \mathbb{R}$, while multivariate linear regression is one wherein
$x^{(i)} \in \mathbb{R}^n$ for $n > 1$ features.
The goal of the linear regression model then is to find optimal parameters $w$
and $b$ such that for every training example $(x^{(i)}, y^{(i)})$ for $i \in [1, \ldots, m]$,

$$f(x^{(i)}) \approx y^{(i)}$$
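As a concrete sketch, the model itself is a single line of NumPy; the helper below is reused in the later snippets:

```python
def predict(x, w, b):
    """Linear model f(x) = w * x + b, broadcast over an array of inputs."""
    return w * x + b
```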
Cost Function
It helps to have a precise measure of how far off the predicted outputs are
from the expected outputs. This can be formalized as a cost function $J$. A
useful cost function for linear regression is the mean squared error,

$$J(w, b) = \frac{1}{2m} \sum_{i=1}^{m} \left( \hat{y}^{(i)} - y^{(i)} \right)^2$$
One of the benefits of the $\frac{1}{m}$ factor is that without it, as $m$ gets larger,
the value of the cost function gets larger as well. Averaging therefore helps to build
a cost function that does not change significantly just because the size of the
dataset increases. The $2$ in $\frac{1}{2m}$ simplifies the calculation of the gradient
later on, given that the minimum of $f(x)$ is the same as the minimum of
$\frac{1}{2} f(x)$. Given the cost function, the goal is to find optimal
parameters $w^*$ and $b^*$ which minimize $J(w, b)$.
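A direct NumPy translation of this cost function might look like the following sketch, using the `predict` helper from above:

```python
def mse_cost(x, y, w, b):
    """Mean squared error with the 1/(2m) convention used above."""
    m = len(x)
    errors = predict(x, w, b) - y
    return np.sum(errors ** 2) / (2 * m)
```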
Take a look at the contour plot of the loss function below. Observe that the
plot looks somewhat convex, with the deepest point being wherever the true
coefficient and intercept were set above. This is precisely the optimal point
that minimizes the mean squared error cost function.
TODO: PLOT
The convex nature of the function is more apparent and easily visualized in
the 3-dimensional plot below. As before, the lowest point corresponds to
whatever true values you may have set above. Any other guess by the model will
incur additional loss.
TODO: PLOT
Why Use Mean Squared Error?
The mean squared error cost function comes from a statistical method known as
maximum likelihood estimation. Recall that the underlying assumption for
linear regression is that the data is accurately modelled with a linear
function. The error terms are assumed to be i.i.d. draws from a normal
distribution with mean $0$ and constant variance $\sigma^2$, i.e., for all
$i \in [1, \ldots, m]$,

$$y^{(i)} = w \cdot x^{(i)} + b + \epsilon^{(i)}, \quad \epsilon^{(i)} \sim \mathcal{N}(0, \sigma^2)$$
Thus the probability density function of the error term for each sample
$\epsilon^{(i)}$ for $i \in [1, \ldots, m]$ is given by a Gaussian or
normal distribution with parameters $\mu = 0$ and variance $\sigma^2$ as
follows:

$$p(\epsilon^{(i)}) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{(\epsilon^{(i)})^2}{2\sigma^2} \right)$$
The likelihood represents the joint probability density function of observing
the data that was gathered. Assuming that all samples are independent, this is
simply the product of all the probability density functions for $\epsilon$:

$$L(w, b) = \prod_{i=1}^{m} p(\epsilon^{(i)})$$
The goal is to find the parameters which maximize the likelihood of
seeing the data that was gathered. In this case, given the probability density
function of the normally distributed error terms and
$\epsilon^{(i)} = y^{(i)} - (w \cdot x^{(i)} + b)$,

$$L(w, b) = \prod_{i=1}^{m} \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{\left( y^{(i)} - (w \cdot x^{(i)} + b) \right)^2}{2\sigma^2} \right)$$
Given that $\log(x)$ is a strictly increasing monotonic function, i.e., for
every $x_1 < x_2$, $\log(x_1) < \log(x_2)$, the maximizer of $L(w, b)$ is also
a maximizer of $\log(L(w, b))$. This makes the derivation more convenient
because the properties of logarithms turn the product into a sum:

$$\log L(w, b) = m \log \frac{1}{\sqrt{2\pi\sigma^2}} - \frac{1}{2\sigma^2} \sum_{i=1}^{m} \left( y^{(i)} - (w \cdot x^{(i)} + b) \right)^2$$

The first term and the factor $\frac{1}{2\sigma^2}$ do not depend on $w$ and $b$, so
maximizing the log-likelihood is equivalent to minimizing
$\sum_{i=1}^{m} (\hat{y}^{(i)} - y^{(i)})^2$, which is exactly the mean squared
error cost up to the constant factor $\frac{1}{2m}$.
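As a quick numerical sanity check of this equivalence, the sketch below evaluates both the negative log-likelihood and the MSE cost over a grid of candidate slopes, holding $b$ fixed at its true value for simplicity; the grid range and the fixed $\sigma$ are arbitrary choices for illustration.

```python
sigma = 1.0  # assumed noise level, fixed for illustration
ws = np.linspace(0.0, 4.0, 401)  # candidate slopes

def neg_log_likelihood(w):
    """Negative log-likelihood of the data under Gaussian errors."""
    errors = y - (w * x + true_b)
    return (m * np.log(np.sqrt(2 * np.pi) * sigma)
            + np.sum(errors ** 2) / (2 * sigma ** 2))

nll = np.array([neg_log_likelihood(w) for w in ws])
mse = np.array([mse_cost(x, y, w, true_b) for w in ws])

# Both objectives are minimized at the same slope.
assert np.argmin(nll) == np.argmin(mse)
```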
Gradient Descent
Gradient descent is an iterative algorithm that allows a linear regression
model (and many other more complex deep learning models) to arrive at a
minimum of the loss function.
The idea is that given $w$ and $b$, initialized in any manner, the
algorithm continuously updates these parameters in the direction of steepest
descent (the negative of the gradient) towards a minimum of the cost function.
This iteration is repeated until convergence.
Note that since the mean squared error cost function is convex, any local
minimum is also the global minimum. This is not necessarily true for other cost
functions, which may have multiple local minima; the minimum reached then
depends on the initial starting point.
The parameters $w$ and $b$ on the $k$-th iteration are updated
(simultaneously) as follows, where $\alpha$ is the learning rate:

$$w^{(k+1)} = w^{(k)} - \alpha \frac{\partial J}{\partial w} = w^{(k)} - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \left( \hat{y}^{(i)} - y^{(i)} \right) x^{(i)}$$

$$b^{(k+1)} = b^{(k)} - \alpha \frac{\partial J}{\partial b} = b^{(k)} - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \left( \hat{y}^{(i)} - y^{(i)} \right)$$
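Putting this together, a minimal batch gradient descent loop might look like the following sketch; the learning rate and iteration count are arbitrary choices and would need tuning in practice.

```python
def gradient_descent(x, y, alpha=0.01, iterations=1000):
    """Fit w and b by batch gradient descent on the MSE cost."""
    m = len(x)
    w, b = 0.0, 0.0  # arbitrary initialization
    for _ in range(iterations):
        errors = predict(x, w, b) - y
        # Simultaneous update: compute both gradients before stepping.
        grad_w = np.sum(errors * x) / m
        grad_b = np.sum(errors) / m
        w -= alpha * grad_w
        b -= alpha * grad_b
    return w, b

w_fit, b_fit = gradient_descent(x, y, alpha=0.05, iterations=2000)
```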
Experiment with the values below to modify the parameters of the linear
regression model. Click on the 'Train Model' button to start training the linear
model.
Analytical Solution
Given that the mean squared error cost function is a convex function, the
minimum can also be solved for directly by setting the derivatives to zero and
solving the resulting system of two equations in two unknowns:

$$\frac{\partial J}{\partial w} = \frac{1}{m} \sum_{i=1}^{m} \left( w \cdot x^{(i)} + b - y^{(i)} \right) x^{(i)} = 0, \quad \frac{\partial J}{\partial b} = \frac{1}{m} \sum_{i=1}^{m} \left( w \cdot x^{(i)} + b - y^{(i)} \right) = 0$$

Solving this system yields the familiar closed-form least squares estimates,
where $\bar{x}$ and $\bar{y}$ denote the sample means:

$$w^* = \frac{\sum_{i=1}^{m} \left( x^{(i)} - \bar{x} \right) \left( y^{(i)} - \bar{y} \right)}{\sum_{i=1}^{m} \left( x^{(i)} - \bar{x} \right)^2}, \quad b^* = \bar{y} - w^* \bar{x}$$
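In code, this closed-form solution is only a couple of lines, and it can serve as a check on the gradient descent result above:

```python
def fit_closed_form(x, y):
    """Ordinary least squares for univariate linear regression."""
    x_bar, y_bar = np.mean(x), np.mean(y)
    w = np.sum((x - x_bar) * (y - y_bar)) / np.sum((x - x_bar) ** 2)
    b = y_bar - w * x_bar
    return w, b

w_star, b_star = fit_closed_form(x, y)
# Gradient descent should land close to the analytical optimum.
print(w_star, b_star, w_fit, b_fit)
```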