
gradient-descent-step-sizes-lab-data-science's Introduction

Introduction

In this lab, we'll practice applying gradient descent. As we know, gradient descent begins with an initial regression line and moves toward a "best fit" regression line by changing the values of $m$ and $b$ and evaluating the RSS. So far, we have illustrated this technique by changing the value of $b$ and evaluating the RSS. In this lab, we will apply the same technique by changing the value of $m$. We'll have access to our graph library, linear equations library, and error library in completing this lab.

Setting up our initial regression line

Once again, we'll take a look at the budgets of shows to predict revenue.

first_show = {'budget': 100, 'revenue': 275}
second_show = {'budget': 200, 'revenue': 300}
third_show = {'budget': 400, 'revenue': 700}

shows = [first_show, second_show, third_show]

Using our data and our build_regression_line function, we get some values for an initial regression line.

from linear_equations import build_regression_line

budgets = list(map(lambda show: show['budget'], shows))
revenues = list(map(lambda show: show['revenue'], shows))

build_regression_line(budgets, revenues)
# {'b': 133.33333333333326, 'm': 1.4166666666666667}

def regression_line(x):
    return 1.417*x + 133.33

Now, using the residual_sum_squares function, we calculate the RSS. Let's take another look at it here:

def residual_sum_squares(x_values, y_values, m, b):
    return sum(squared_errors(x_values, y_values, m, b)) 
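
The squared_errors helper comes from the lab's error library and is not shown here. As a minimal sketch, assuming it returns one squared residual per data point, the two functions could look like the following. The body of squared_errors is an assumption for illustration, not the library's actual code.

```python
# Hypothetical stand-in for the error library's squared_errors helper.
# Assumption: it returns (actual - predicted)**2 for every data point.
def squared_errors(x_values, y_values, m, b):
    return [(y - (m * x + b)) ** 2 for x, y in zip(x_values, y_values)]

def residual_sum_squares(x_values, y_values, m, b):
    # Same definition as in the lab: add up the squared errors.
    return sum(squared_errors(x_values, y_values, m, b))

budgets = [100, 200, 400]
revenues = [275, 300, 700]

# RSS of the line y = 1.3x + 133.33 against the three shows.
rss = residual_sum_squares(budgets, revenues, 1.3, 133.33)
print(round(rss, 2))
```

With $m = 1.3$ and $b = 133.33$, this should agree with the RSS listed for $m = 1.3$ in the cost curve output later in the lab.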

Building a cost curve

Now let's use the RSS to build a cost curve. Keeping the $b$ value fixed at $133.33$, write a function called rss_values that takes the x_values and y_values of our dataset, a list of $m$ values to try, and an initial $b$ value. It outputs a dictionary with keys of m_values and rss_values, with each key pointing to a list of the corresponding values.

from error import residual_sum_squares
def rss_values(x_values, y_values, m_values, b):
    pass

budgets = list(map(lambda show: show['budget'], shows))
revenues = list(map(lambda show: show['revenue'], shows))
initial_m_values = list(range(8, 19, 1))
scaled_m_values = list(map(lambda initial_m_value: initial_m_value / 10, initial_m_values))
b_value = 133.33
rss_values(budgets, revenues, scaled_m_values, b_value)

# {'m_values': [0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8],
#  'rss_values': [64693.76669999998,
#   45559.96669999998,
#   30626.166699999987,
#   19892.36669999999,
#   13358.5667,
#   11024.766700000004,
#   12890.96670000001,
#   18957.166700000016,
#   29223.36670000002,
#   43689.566700000025,
#   62355.76670000004]}
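
One possible way to fill in rss_values is sketched below. It is only one approach, and the inline residual_sum_squares is a stand-in that assumes the error library computes the plain sum of squared residuals.

```python
# Stand-in for error.residual_sum_squares (assumed behavior).
def residual_sum_squares(x_values, y_values, m, b):
    return sum((y - (m * x + b)) ** 2 for x, y in zip(x_values, y_values))

def rss_values(x_values, y_values, m_values, b):
    """Evaluate the RSS at each candidate m while holding b fixed."""
    rss = [residual_sum_squares(x_values, y_values, m, b) for m in m_values]
    return {'m_values': list(m_values), 'rss_values': rss}

budgets = [100, 200, 400]
revenues = [275, 300, 700]
scaled_m_values = [m / 10 for m in range(8, 19)]

cost_chart = rss_values(budgets, revenues, scaled_m_values, 133.33)

# The candidate m with the smallest RSS sits at the bottom of the cost curve.
best_index = cost_chart['rss_values'].index(min(cost_chart['rss_values']))
print(cost_chart['m_values'][best_index])
```

Among these candidates, the smallest RSS falls at $m = 1.3$, which is where the cost curve bottoms out in the output above.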

Plotly provides a table chart, and we can pass it the values generated by our rss_values function to create a table.

from plotly.offline import iplot, init_notebook_mode
from graph import plot
import plotly.graph_objs as go

init_notebook_mode(connected=True)

def plot_table(headers, columns):
    trace_cost_chart = go.Table(
        header=dict(values=headers,
                    line = dict(color='#7D7F80'),
                    fill = dict(color='#a1c3d1'),
                    align = ['left'] * 5),
        cells=dict(values=columns,
                   line = dict(color='#7D7F80'),
                   fill = dict(color='#EDFAFF'),
                   align = ['left'] * 5))
    plot([trace_cost_chart])
cost_chart = rss_values(budgets, revenues, scaled_m_values, b_value) or {}
column_values = list(cost_chart.values())
plot_table(headers = ['M values', 'RSS values'], columns=column_values)

And let's plot this out using a line chart.

from plotly.offline import iplot, init_notebook_mode
init_notebook_mode(connected=True)
from graph import plot, trace_values

initial_m_values = list(range(8, 19, 1))
scaled_m_values = list(map(lambda initial_m_value: initial_m_value/10,initial_m_values))
cost_values = rss_values(budgets, revenues, scaled_m_values, 133.33)
rss_trace = trace_values(cost_values['m_values'], cost_values['rss_values'], mode = 'lines')
plot([rss_trace])

Changing our step size

In this section, we'll work up to building a gradient descent function that automatically changes our step size. To get you started, we'll provide a function called slope_at that calculates the slope of the cost curve at a given point. Here it is in action:

from helper import slope_at
slope_at(budgets, revenues, 1.7, 133.33333333333326)
# {'m': 1.7, 'slope': 165687.66666649026}
slope_at(budgets, revenues, 1.3, 133.33333333333326)
# {'m': 1.3, 'slope': -2312.3333333387563}

As you can see, the slope tracks the shape of the cost curve. Where the curve is steep, at $m = 1.7$, the slope is over 165,000. As we near the flat part of our cost curve, at $m = 1.3$, the slope has a much smaller magnitude, with a value of $-2312.3$.
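
The helper module is not shown in this lab, so the following is only a sketch of how a function like slope_at could work. It assumes a forward-difference approximation with a hypothetical step width delta of 0.0001, a value chosen because it comes close to reproducing the outputs above. The real helper may be implemented differently.

```python
# Stand-in for error.residual_sum_squares (assumed behavior).
def residual_sum_squares(x_values, y_values, m, b):
    return sum((y - (m * x + b)) ** 2 for x, y in zip(x_values, y_values))

def slope_at(x_values, y_values, m, b, delta=0.0001):
    """Approximate the slope of the cost curve at m.

    Assumption: a forward difference, (RSS(m + delta) - RSS(m)) / delta.
    """
    rss_here = residual_sum_squares(x_values, y_values, m, b)
    rss_ahead = residual_sum_squares(x_values, y_values, m + delta, b)
    return {'m': m, 'slope': (rss_ahead - rss_here) / delta}

budgets = [100, 200, 400]
revenues = [275, 300, 700]

steep = slope_at(budgets, revenues, 1.7, 133.33333333333326)
flat = slope_at(budgets, revenues, 1.3, 133.33333333333326)
print(steep['slope'], flat['slope'])
```

A positive slope means the RSS rises as $m$ increases, so the minimum lies to the left. A negative slope means the minimum lies to the right.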

Ok, now we're ready to write a function called updated_m. The function will allow us to move along our cost curve more efficiently by scaling each step to the slope of the curve. The updated_m function takes as arguments an initial value of $m$, a learning rate, and the slope of the cost curve at that value of $m$. It returns a float that equals the next value of $m$.

from error import residual_sum_squares

def updated_m(m, learning_rate, cost_curve_slope):
    pass
current_slope = slope_at(budgets, revenues, 1.7, 133.33333333333326)['slope']
updated_m(1.7, .000001, current_slope)
# 1.5343123333335096

current_slope = slope_at(budgets, revenues, 1.534, 133.33333333333326)['slope']
updated_m(1.534, .000001, current_slope)
# 1.43803233333338

current_slope = slope_at(budgets, revenues, 1.438, 133.33333333333326)['slope']
updated_m(1.438, .000001, current_slope)
# 1.3823523333332086
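
The expected outputs above are consistent with the usual gradient descent update, which moves $m$ against the slope by an amount proportional to the learning rate. A minimal sketch under that assumption:

```python
def updated_m(m, learning_rate, cost_curve_slope):
    # Step against the slope: a positive slope lowers m, a negative slope raises it.
    return m - learning_rate * cost_curve_slope

# Slope values copied from the slope_at outputs shown earlier in the lab.
next_m_from_steep = updated_m(1.7, 0.000001, 165687.66666649026)
next_m_from_flat = updated_m(1.3, 0.000001, -2312.3333333387563)
print(next_m_from_steep, next_m_from_flat)
```

Because the step is the learning rate times the slope, the move from $m = 1.7$ is large (about $0.166$), while the move from $m = 1.3$ is tiny (about $0.002$) and points in the opposite direction.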

Take a careful look at how we use the updated_m function. By feeding each updated value of $m$ (rounded to three decimal places) into the next call, we quickly converge towards an optimal value of $m$.

Now let's write another function called gradient_descent. Similar to our rss_values function, it reports the $m$ value and the RSS at each point it visits, but it returns a list of dictionaries, one per step, each with keys of m, rss, and slope. The inputs are now x_values, y_values, steps, b, learning_rate, and current_m. The steps argument is the number of steps the function takes before it stops.

def gradient_descent(x_values, y_values, steps, b, learning_rate, current_m):
    pass
descent_steps = gradient_descent(budgets, revenues, 12, 133.33, learning_rate = .000001, current_m = 0) or []
m_values = list(map(lambda step: step['m'],descent_steps))
rss_result_values = list(map(lambda step: step['rss'], descent_steps))
text_values = list(map(lambda step: 'cost curve slope: ' + str(step['slope']), descent_steps))
gradient_trace = trace_values(m_values, rss_result_values, text=text_values)
plot([gradient_trace])

Taking a look at a plot of our trace, you can get a nice visualization of how our gradient descent function works. It starts far away with $m = 0$, and the step size is relatively large, as is the slope of the cost curve. As the $m$ value updates and approaches a minimum of the RSS, the slope of the cost curve decreases, and the size of the step decreases with it.
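
The shrinking step size can also be checked without a plot. The sketch below is one possible gradient_descent, not a reference solution. It assumes each step records the current m, rss, and slope before updating m, and it uses stand-in versions of residual_sum_squares and slope_at (a forward difference with a hypothetical delta of 0.0001), since the lab's libraries are not shown.

```python
# Stand-ins for the lab's error and helper libraries (assumed behavior).
def residual_sum_squares(x_values, y_values, m, b):
    return sum((y - (m * x + b)) ** 2 for x, y in zip(x_values, y_values))

def slope_at(x_values, y_values, m, b, delta=0.0001):
    rss_here = residual_sum_squares(x_values, y_values, m, b)
    rss_ahead = residual_sum_squares(x_values, y_values, m + delta, b)
    return {'m': m, 'slope': (rss_ahead - rss_here) / delta}

def gradient_descent(x_values, y_values, steps, b, learning_rate, current_m):
    """Record m, rss and slope at each step, then move m against the slope."""
    descent_steps = []
    for _ in range(steps):
        slope = slope_at(x_values, y_values, current_m, b)['slope']
        rss = residual_sum_squares(x_values, y_values, current_m, b)
        descent_steps.append({'m': current_m, 'rss': rss, 'slope': slope})
        current_m = current_m - learning_rate * slope
    return descent_steps

budgets = [100, 200, 400]
revenues = [275, 300, 700]
descent_steps = gradient_descent(budgets, revenues, 12, 133.33,
                                 learning_rate=0.000001, current_m=0)

# Distance moved between consecutive recorded values of m.
m_values = [step['m'] for step in descent_steps]
step_sizes = [abs(later - earlier) for earlier, later in zip(m_values, m_values[1:])]
for step, size in zip(descent_steps, step_sizes):
    print(round(step['m'], 4), round(step['rss'], 2), round(size, 4))
```

In this sketch each printed step size is smaller than the one before it, and the recorded $m$ values climb from $0$ toward roughly $1.3$, in line with the cost curve built earlier.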
