Implementing ordinary least squares (OLS) regression from scratch for the 1D case. This homework was assigned on 8 February 2023.
# Knit-time configuration: evaluate all chunks with the homework folder as the
# working directory, so the relative data path used below resolves there.
# NOTE(review): this is an absolute, machine-specific path; knitting on another
# machine will need it changed (or replaced by a project-relative path) - TODO confirm.
knitr::opts_knit$set(root.dir = "/Users/turx/Projects/machine-teaching-23sp/hw01-ols")
# Keep warnings and messages out of the rendered output for every chunk.
knitr::opts_chunk$set(warning = FALSE, message = FALSE)
library(readxl)
library(tidyverse)
library(lubridate)
library(ggplot2)
# Load the World Development Indicators extract and reshape it into one row per
# (country, year) holding that year's GDP in current US$.
data <- read_excel("Data_Extract_From_World_Development_Indicators.xlsx")
data <- data %>%
  # The extract marks missing observations with "..". Convert them to NA column
  # by column: calling na_if() on the whole data frame fails with dplyr >= 1.1.0,
  # which requires `y` to be castable to the type of `x`.
  mutate(across(where(is.character), ~ na_if(.x, ".."))) %>%
  rename(
    c_name = "Country Name",
    c_code = "Country Code",
    "2017" = "2017 [YR2017]",
    "2018" = "2018 [YR2018]"
  ) %>%
  # Year columns are selected by name; all_of() makes the "these exact columns
  # must exist" intent explicit.
  pivot_longer(
    cols = all_of(as.character(c(1990, 2000, 2010:2019))),
    names_to = "year",
    values_to = "gdp"
  ) %>%
  filter(`Series Name` == "GDP (current US$)", `Series Code` == "NY.GDP.MKTP.CD") %>%
  select(-`Series Name`, -`Series Code`) %>%
  drop_na() %>%
  # Values arrive as text because of the ".." markers; convert once they are gone.
  mutate(gdp = as.numeric(gdp), year = as.numeric(year)) %>%
  # Left grouped by country on purpose: later code ungroups explicitly.
  group_by(c_code)
#' Fit a one-predictor ordinary least squares line `y = m * x + b`.
#'
#' The slope is computed from mean-centred sums,
#' `m = sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)`,
#' which is algebraically the same as
#' `(mean(x * y) - mean(x) * mean(y)) / (mean(x^2) - mean(x)^2)` but does not
#' subtract two huge, nearly equal numbers. With inputs such as calendar years
#' and GDP figures the uncentred form loses most of its significant digits.
#'
#' @param x Predictor values: a numeric vector, or anything `as.matrix()` can
#'   turn into numbers (e.g. a one-column matrix or data frame).
#' @param y Response values, with the same number of elements as `x`.
#' @return A list with elements `m` (slope) and `b` (intercept).
ols_regression <- function(x, y) {
  # as.matrix() keeps accepting matrix / data-frame input; as.numeric() then
  # flattens to a plain vector so the sums below are ordinary vector sums.
  x <- as.numeric(as.matrix(x))
  y <- as.numeric(as.matrix(y))

  if (length(x) != length(y)) {
    stop("`x` and `y` must have the same number of elements.", call. = FALSE)
  }
  if (length(x) < 2L) {
    stop("At least two observations are needed to fit a line.", call. = FALSE)
  }

  x_bar <- mean(x)
  y_bar <- mean(y)
  x_dev <- x - x_bar
  sxx <- sum(x_dev^2)

  # isTRUE() lets NA inputs propagate to NA coefficients instead of failing here.
  if (isTRUE(sxx == 0)) {
    stop("`x` has zero variance; the slope is undefined.", call. = FALSE)
  }

  m <- sum(x_dev * (y - y_bar)) / sxx
  b <- y_bar - m * x_bar
  list(m = m, b = b)
}
# United States observations only, as a plain (ungrouped) two-column table of
# year and GDP - the inputs for the regression below.
usa <- data %>%
  ungroup() %>%
  filter(c_code == "USA") %>%
  select(all_of(c("year", "gdp")))
# Reference plot: US GDP by year with ggplot2's built-in linear-model fit, for
# comparison against the from-scratch fit.
usa %>%
  ggplot(aes(x = year, y = gdp)) +
  geom_point() +
  # The line colour is set on the geom itself. Inside labs(), `color = "blue"`
  # would only title a colour legend, and this plot has none.
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE, color = "blue") +
  labs(title = "GDP of the United States (with library)", x = "Year", y = "GDP") +
  scale_y_continuous(labels = scales::dollar)
# Fit GDP against year with the hand-written estimator, then draw the fitted
# line (in red) over the raw observations.
l_ols <- ols_regression(x = usa$year, y = usa$gdp)
ggplot(data = usa, mapping = aes(x = year, y = gdp)) +
  geom_point() +
  geom_abline(slope = l_ols$m, intercept = l_ols$b, color = "red") +
  labs(title = "GDP of the United States (without library)", x = "Year", y = "GDP") +
  scale_y_continuous(labels = scales::dollar)
The learned coefficients for \(y=mx+b\)
# Print the fitted coefficients: $m is the slope, $b the intercept.
l_ols
## $m
## [1] 507558820458
##
## $b
## [1] -1.004633e+15
Training set error: \(E_{\text{train}} = \sum_{i=1}^n (y_i - \hat{y}_i)^2\)
# Training error: sum over all observations of the squared difference between
# the observed GDP and the fitted line's prediction m * year + b.
error <- usa %>%
  summarise(error = sum((gdp - (l_ols$m * year + l_ols$b))^2)) %>%
  pull(error)
error
## [1] 2.213066e+24
MSE: \(MSE = \frac{1}{n} E_{\text{train}}\)
# Mean squared error: the training error divided by the number of observations
# (rows of `usa`).
MSE <- error / dim(usa)[1L]
MSE
## [1] 2.011879e+23