In this article, I explain how the gradient descent algorithm (GDA) works, using a quadratic function as the example:
f(x) = ax^2+bx+c
For a function of a single variable, the gradient is simply its first derivative, f'(x). The first derivative of f(x) is:
f'(x) = 2ax + b
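To sanity-check this derivative, the sketch below compares f'(x) = 2ax + b with a central finite-difference approximation. It uses the same default coefficients as the demo further down (a = 1, b = -5, c = 6). The helper numericDerivative and the step h are my own illustrative choices. Setting f'(x) = 0 also gives the exact minimum, x = -b / (2a) when a > 0, which is useful for checking where GDA ends up.

```javascript
// Coefficients matching the demo's defaults: f(x) = x^2 - 5x + 6
const a = 1, b = -5, c = 6;

const fx = (x) => a * x * x + b * x + c;   // f(x)
const dxfx = (x) => 2 * a * x + b;         // f'(x), derived analytically

// Hypothetical helper: central-difference approximation of f'(x)
function numericDerivative(f, x, h = 1e-5) {
  return (f(x + h) - f(x - h)) / (2 * h);
}

// The two columns should agree at every sample point
for (const x of [-3, 0, 2.5, 10]) {
  console.log(x, dxfx(x), numericDerivative(fx, x).toFixed(6));
}

// f'(x) = 0 at the vertex; for a > 0 the vertex is the minimum
const xMin = -b / (2 * a);
console.log("exact minimum at x =", xMin);
```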
GDA is an iterative algorithm. In each iteration it computes the gradient f'(x) at the current x, then obtains the next point by subtracting that gradient, scaled by a learning rate called alfa, from x. The learning rate controls how large each change of x is. So in each iteration, GDA calculates the next x value as:
x(t+1) = x(t) - alfa*f'(x(t))
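The update rule translates almost line for line into code. This is a minimal sketch that assumes the demo's default values (start at x = 200, alfa = 0.01, a = 1, b = -5, c = 6). The function name gradientDescent and the tolerance-based early stop are my own additions and are not part of the demo below.

```javascript
// f(x) = ax^2 + bx + c with the demo's default coefficients
const a = 1, b = -5, c = 6;
const fx = (x) => a * x * x + b * x + c;
const dxfx = (x) => 2 * a * x + b;

// Repeats x(t+1) = x(t) - alfa * f'(x(t)) until the step becomes tiny
// (hypothetical tolerance) or maxIteration is reached.
function gradientDescent(x, alfa, maxIteration, tolerance = 1e-10) {
  for (let t = 0; t < maxIteration; t++) {
    const step = alfa * dxfx(x);
    x = x - step;
    if (Math.abs(step) < tolerance) {
      return { x, iterations: t + 1 };
    }
  }
  return { x, iterations: maxIteration };
}

const result = gradientDescent(200, 0.01, 5000);
console.log(result.x, fx(result.x)); // close to 2.5 and -0.25
```

For this quadratic, each step multiplies the distance to the minimum by (1 - 2a * alfa). If alfa is too large, the magnitude of that factor exceeds 1 and the iterates diverge. With a = 1, this happens for any alfa > 1.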
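So why does GDA use the gradient of the function?
Imagine you are standing on a hill and want to get down. Where the slope is steep, you take a big step; where it is gentle, you take a small one. As you approach the bottom, the slope keeps decreasing until it reaches zero, so your steps get smaller and you stop at the lowest point.
If you think of the gradient as a vector, it points in the direction of steepest ascent, toward the top of the hill; its negative points toward the bottom. That is why GDA subtracts the gradient: every step moves downhill.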
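The hill analogy can be checked numerically. The sketch below uses the same default coefficients, a learning rate of 0.1, and starting points that I chose arbitrarily. The helper traceSteps is hypothetical. It records the step alfa * f'(x) for a few iterations on each side of the minimum at x = 2.5. The step always has the same sign as the slope, so subtracting it moves x downhill, and its size shrinks as the slope flattens.

```javascript
const a = 1, b = -5, c = 6;
const dxfx = (x) => 2 * a * x + b;   // slope of f at x
const alfa = 0.1;

// Hypothetical helper: returns the successive steps (alfa * f'(x)) that GDA subtracts from x
function traceSteps(x, count) {
  const steps = [];
  for (let t = 0; t < count; t++) {
    const step = alfa * dxfx(x);
    steps.push(step);
    x = x - step;  // positive slope => x decreases, negative slope => x increases
  }
  return steps;
}

// Start right of the minimum: slope > 0, so x moves left
console.log(traceSteps(10, 5));
// Start left of the minimum: slope < 0, so x moves right
console.log(traceSteps(-5, 5));
```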
Note that GDA finds the minimum of a convex function. To find the maximum of a concave function, use gradient ascent instead. For the quadratic above, f is convex when a > 0 and concave when a < 0.
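Gradient ascent differs from GDA only in the sign of the update: x(t+1) = x(t) + alfa * f'(x(t)). This is a minimal sketch with a concave example of my own choosing, f(x) = -x^2 + 4x + 1 (a = -1, b = 4, c = 1). Its maximum is at x = -b / (2a) = 2, where f(2) = 5. The function name gradientAscent and the starting point are illustrative assumptions.

```javascript
// Concave example (a < 0): f(x) = -x^2 + 4x + 1, maximum at x = 2
const a = -1, b = 4, c = 1;
const fx = (x) => a * x * x + b * x + c;
const dxfx = (x) => 2 * a * x + b;

// Gradient ascent: step along the gradient (note the +) to climb toward the maximum
function gradientAscent(x, alfa, maxIteration) {
  for (let t = 0; t < maxIteration; t++) {
    x = x + alfa * dxfx(x);
  }
  return x;
}

const xMax = gradientAscent(-50, 0.01, 5000);
console.log(xMax, fx(xMax)); // approximately 2 and 5
```

If you run plain gradient descent on this concave f instead, x moves away from 2 without bound, because there is no minimum to find.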
I prepared a working JavaScript example. It computes the minimum of the quadratic function and plots both the function and the points visited by GDA.
In each iteration, the current point (x, y) is drawn as a red dot.
JavaScript code:
<script>
  // Default values; start() replaces them with the values typed into the form.
  var x = 200;             // starting point
  var alfa = 0.01;         // learning rate
  var maxIteration = 5000;
  var a = 1;
  var b = -5;
  var c = 6;

  // f(x) = ax^2 + bx + c and its derivative f'(x) = 2ax + b
  var fx = (x) => a * x * x + b * x + c;
  var dxfx = (x) => 2 * a * x + b;

  function start() {
    setValues();
    drawCurve();
    var canvas = document.getElementById("g");
    let y = fx(x);
    for (var i = 0; i < maxIteration; i++) {
      // Gradient descent step: x(t+1) = x(t) - alfa * f'(x(t))
      x = x - alfa * dxfx(x);
      y = fx(x);
      // Draw the current point (x, y) as a red dot
      let circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
      circle.setAttributeNS(null, "cx", x);
      circle.setAttributeNS(null, "cy", y);
      circle.setAttributeNS(null, "r", 1);
      circle.setAttributeNS(null, "fill", "red");
      canvas.appendChild(circle);
    }
    let resultDiv = document.getElementById("resultDiv");
    resultDiv.innerHTML = "x:" + x + ", y:" + y;
  }

  // Read the starting point, learning rate, iteration count and coefficients from the inputs
  function setValues() {
    x = parseFloat(document.getElementById("x").value);
    alfa = parseFloat(document.getElementById("alfa").value);
    maxIteration = parseInt(document.getElementById("maxIteration").value, 10);
    a = parseFloat(document.getElementById("a").value);
    b = parseFloat(document.getElementById("b").value);
    c = parseFloat(document.getElementById("c").value);
  }

  // Draw f(x) in green as short straight segments, one per unit step of x, from x = 5000 down to x = -5001
  function drawCurve() {
    var canvas = document.getElementById("g");
    for (var i = 5000; i >= -5000; i -= 1) {
      let line = document.createElementNS('http://www.w3.org/2000/svg', 'line');
      line.setAttributeNS(null, "x1", i);
      line.setAttributeNS(null, "y1", fx(i));
      line.setAttributeNS(null, "x2", i - 1);
      line.setAttributeNS(null, "y2", fx(i - 1));
      line.setAttributeNS(null, "style", "stroke: rgb(0, 255, 0); stroke-width:1");
      canvas.appendChild(line);
    }
  }
</script>
HTML code:
<div> <input id="x" type="text" value="200" />->x </div>
<div> <input id="alfa" type="text" value="0.01" />->alfa </div>
<div> <input id="maxIteration" type="text" value="5000" />->maxIteration </div>
<div> <input id="a" type="text" value="1" />->a </div>
<div> <input id="b" type="text" value="-5" />->b </div>
<div> <input id="c" type="text" value="6" />->c </div>
<div style="margin-top:20px" id="resultDiv"></div>
<div style="margin-top:20px">
  <button type="button" onclick="start()">START</button>
</div>
<div style="margin-top:50px">
  <svg id="canvas" width="500" height="500">
    <g id="g" transform="translate(250 250) scale(1,-1)">
      <line x1="-1000" y1="0" x2="1000" y2="0" style="stroke: rgb(0, 0, 255); stroke-width:1"></line>
      <line x1="0" y1="-1000" x2="0" y2="1000" style="stroke: rgb(0, 0, 255); stroke-width:1"></line>
    </g>
  </svg>
</div>