Understanding Gradient Descent using R
This program implements a simple gradient descent algorithm to find the value of x that minimizes the following function:
f(x) = 1.4 * (x-2)^2 + 3.2
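Differentiating f by hand gives the gradient used in the script. It also shows where the minimum lies, which is a useful check on whatever the algorithm returns:

```latex
\begin{aligned}
f(x)  &= 1.4\,(x-2)^2 + 3.2 \\
f'(x) &= 2 \cdot 1.4\,(x-2) = 2.8\,(x-2) \\
f'(x) = 0 \;&\Longrightarrow\; x^* = 2, \quad f(x^*) = 3.2
\end{aligned}
```

Because the coefficient 1.4 is positive, the parabola opens upward, so x* = 2 is the global minimum. A correct run should therefore approach x ≈ 2 and f(x) ≈ 3.2.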
The gradient of f (its derivative with respect to x) is defined in the grad function. The script starts by initializing the parameters of the gradient descent algorithm: the maximum number of iterations, the stopping criterion (threshold), the initial weight (the starting value of x) and the learning rate (stepSize).
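One way to confirm that `grad` really is the derivative of `f` is to compare it with a central finite-difference approximation. This is a minimal sketch. The helper `numeric_grad`, the step `h = 1e-6` and the test points are illustrative choices, not part of the original script:

```r
# f and grad as defined in the post
f <- function(x) 1.4 * (x - 2)^2 + 3.2
grad <- function(x) 1.4 * 2 * (x - 2)

# Hypothetical helper: central finite-difference estimate of f'(x)
numeric_grad <- function(fn, x, h = 1e-6) {
  (fn(x + h) - fn(x - h)) / (2 * h)
}

# Compare analytic and numeric gradients at a few illustrative points
xs <- c(-5, 0, 2, 3.5, 10)
comparison <- data.frame(
  x        = xs,
  analytic = grad(xs),
  numeric  = numeric_grad(f, xs)
)
print(comparison)
```

For a quadratic like f, the central difference is exact apart from floating-point rounding, so the two columns should agree to many decimal places. A large mismatch would point to a mistake in `grad`.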
The function f(x) is plotted over a range of x values, and the algorithm then performs gradient descent to find its minimum. At each iteration the weight is updated by subtracting the learning rate times the gradient, and the updated weight and function value are appended to xtrace and ftrace. The algorithm stops when the change in f(x) between two consecutive iterations is less than the specified threshold. Red points joined by lines are drawn on the plot to show the progression of the weight across iterations.
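Because f is quadratic, the update performed in the loop can be written out exactly. This shows how the learning rate η (stepSize in the script) controls the progression of x:

```latex
\begin{aligned}
x_{k+1} &= x_k - \eta\, f'(x_k) = x_k - 2.8\,\eta\,(x_k - 2) \\
x_{k+1} - 2 &= (1 - 2.8\,\eta)\,(x_k - 2) \\
x_k - 2 &= (1 - 2.8\,\eta)^k\,(x_0 - 2)
\end{aligned}
```

With η = 0.05 and x_0 = -5, the factor is 1 - 2.8 × 0.05 = 0.86, so x_k = 2 - 7(0.86)^k. Each step closes 14% of the remaining gap to x* = 2.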
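To see the effect of the learning rate on the optimisation, try different values of stepSize. For this function, gradient descent converges only when stepSize is between 0 and 1/1.4 ≈ 0.714. Larger values (including anything close to 1) make x overshoot the minimum by more at every step, so the iterates diverge.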
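The sketch below compares several learning rates. It assumes the same f, grad and starting point x = -5 as the post, and runs a fixed 50 steps without early stopping. The helper `final_x` and the particular step sizes are illustrative, not from the original script:

```r
f <- function(x) 1.4 * (x - 2)^2 + 3.2
grad <- function(x) 1.4 * 2 * (x - 2)

# Hypothetical helper: run a fixed number of gradient-descent steps
# (no early stopping) and return the final x
final_x <- function(stepSize, x0 = -5, steps = 50) {
  x <- x0
  for (i in seq_len(steps)) x <- x - stepSize * grad(x)
  x
}

step_sizes <- c(0.05, 0.3, 1 / 2.8, 0.6, 0.7, 0.75, 1)
finals <- sapply(step_sizes, final_x)
results <- data.frame(
  stepSize = step_sizes,
  factor   = 1 - 2.8 * step_sizes,  # per-step multiplier of (x - 2)
  final_x  = finals,
  final_f  = f(finals)
)
print(results)
```

Step sizes whose factor has absolute value below 1 converge. Between 1/2.8 ≈ 0.357 and 0.714 the factor is negative, so x overshoots and oscillates around 2 while still converging. At stepSize = 0.7 the oscillation dies out slowly. The 0.75 and 1 runs diverge.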
The program uses two stopping criteria: a maximum number of iterations (iterations = 100) and the change in function value. It stops as soon as f(x) changes by less than 0.00001 (threshold = 1e-5) between consecutive iterations.
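To check which criterion actually ends a run, the loop can be wrapped in a function that skips the plotting and reports why it stopped. This is a minimal sketch: the function name `gd_with_stopping` and its return list are hypothetical. The defaults match the post's parameters.

```r
f <- function(x) 1.4 * (x - 2)^2 + 3.2
grad <- function(x) 1.4 * 2 * (x - 2)

# Hypothetical helper: gradient descent with the post's two stopping rules
# (maximum iterations, or |f(x_new) - f(x_old)| < threshold); no plotting
gd_with_stopping <- function(x0 = -5, stepSize = 0.05,
                             iterations = 100, threshold = 1e-5) {
  x <- x0
  fx <- f(x)
  for (iter in seq_len(iterations)) {
    x_new <- x - stepSize * grad(x)
    f_new <- f(x_new)
    change <- abs(f_new - fx)  # change between consecutive iterations
    x <- x_new
    fx <- f_new
    if (change < threshold) {
      return(list(x = x, f = fx, iterations = iter, reason = "threshold"))
    }
  }
  list(x = x, f = fx, iterations = iterations, reason = "max iterations")
}

res <- gd_with_stopping()
str(res)
```

With the default settings, the threshold rule fires well before the 100-iteration cap. A small change in f(x) does not guarantee that x equals the minimizer exactly. Because f is flat near its minimum, the run stops with x a few thousandths below 2.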
```r
# This R script defines two functions: `f` and `grad`.
# `f` calculates the value of the function 1.4 * (x-2)^2 + 3.2
# `grad` calculates the gradient (derivative) of `f`
f <- function(x) {
  1.4 * (x - 2)^2 + 3.2
}
grad <- function(x) {
  1.4 * 2 * (x - 2)
}

iterations <- 100  # maximum number of iterations
threshold <- 1e-5  # stop when f(x) changes by less than this
stepSize <- 0.05   # learning rate

# initialize x
x <- -5

# initialize vectors to store x and f(x)
xtrace <- x
ftrace <- f(x)

# generate a series of x values within some range and plot f over it
xs <- seq(-6, 10, length.out = 1000)
plot(xs, f(xs), type = "l", xlab = "X",
     ylab = expression(1.4 * (x - 2)^2 + 3.2))

for (iter in 1:iterations) {
  # move x against the gradient, scaled by the learning rate
  x <- x - stepSize * grad(x)
  xtrace <- c(xtrace, x)
  ftrace <- c(ftrace, f(x))
  # redraw the path so far as red points joined by lines
  points(xtrace, ftrace, type = "b", col = "red", pch = 1)
  # ftrace[iter + 1] is the newest value, ftrace[iter] the previous one
  if (abs(ftrace[iter + 1] - ftrace[iter]) < threshold) break
}

df <- data.frame(x = xtrace, f = ftrace)
```