Setup:
# Evaluate all chunks with the homework project folder as the working directory.
# NOTE(review): this is an absolute, machine-specific path; knitting elsewhere
# requires editing it.
knitr::opts_knit$set(root.dir = "/Users/turx/Projects/machine-teaching-23sp/hw02-estimate-truth")
# Keep warnings and messages out of the rendered chunk output.
knitr::opts_chunk$set(warning = FALSE, message = FALSE)
library(readxl)
library(tidyverse)
library(lubridate)
library(ggplot2)
The generated datasets are stored in `dfs`, a list of data frames.
# Simulate one dataset from the linear model y = a * x + b + eps.
#
# Args:
#   n: number of observations to draw.
#   a: true slope.
#   b: true intercept.
#   sigma_sq: variance of the zero-mean Gaussian noise eps (non-negative).
#
# Returns:
#   A tibble with n rows and columns `x` (drawn uniformly from [-1, 1]) and `y`.
gen_dataset <- function(n, a, b, sigma_sq) {
  x <- runif(n, -1, 1)
  # rnorm() is parameterised by the standard deviation, not the variance,
  # so the variance has to be converted before it is passed as `sd`.
  eps <- rnorm(n, mean = 0, sd = sqrt(sigma_sq))
  y <- a * x + b + eps
  tibble(x = x, y = y)
}
# Generate several independent datasets from the same linear model.
#
# Args:
#   n_datasets: number of datasets to generate (0 yields an empty list).
#   len_dataset: number of observations in each dataset.
#   a, b, sigma_sq: model parameters forwarded unchanged to gen_dataset().
#
# Returns:
#   A list of length n_datasets; element i is the result of the i-th
#   gen_dataset() call (calls are made in order, so the RNG stream is
#   consumed in the same order as an explicit for-loop would).
gen_datasets <- function(n_datasets, len_dataset, a, b, sigma_sq) {
  # seq_len() gives an empty sequence for n_datasets == 0, unlike 1:n_datasets
  # which would iterate over c(1, 0).
  lapply(seq_len(n_datasets), function(i) {
    gen_dataset(len_dataset, a, b, sigma_sq)
  })
}
# Draw 100 datasets of 10 points each with a = 2, b = -3 and sigma_sq = 1.
ds <- gen_datasets(n_datasets = 100, len_dataset = 10, a = 2, b = -3, sigma_sq = 1)
# Save the datasets to a CSV file in the working directory. Note that the whole
# list is handed to write.csv(), not a single data frame.
# NOTE(review): presumably the datasets end up side by side as columns --
# confirm the resulting CSV layout is what downstream readers expect.
write.csv(ds, "datasets-n10-s1.csv")
# Print the first dataset as a sanity check.
ds[[1]]
## # A tibble: 10 × 2
## x y
## <dbl> <dbl>
## 1 0.140 -2.29
## 2 -0.0237 -3.19
## 3 -0.992 -4.99
## 4 0.881 -0.290
## 5 0.711 0.141
## 6 -0.618 -4.17
## 7 0.855 -0.380
## 8 0.829 0.301
## 9 -0.569 -2.76
## 10 -0.107 -3.47
Definition of OLS Regression Function on 1D from Homework 01:
# Closed-form ordinary least squares fit of a line y = m * x + b in one
# dimension, computed from sample moments.
#
# Args:
#   x: predictor values (anything as.matrix() accepts).
#   y: response values, same length as x.
#
# Returns:
#   A list with elements `m` (slope) and `b` (intercept).
ols_regression <- function(x, y) {
  xs <- as.matrix(x)
  ys <- as.matrix(y)

  x_bar <- mean(xs)
  y_bar <- mean(ys)

  # Slope = (E[xy] - E[x]E[y]) / (E[x^2] - E[x]^2).
  cross_moment <- mean(xs * ys) - x_bar * y_bar
  spread <- mean(xs^2) - x_bar^2
  slope <- cross_moment / spread

  # The fitted line passes through the point of means.
  intercept <- y_bar - slope * x_bar

  list(m = slope, b = intercept)
}
Run OLS on the 100 datasets:
# Fit an OLS line to every dataset in a list and collect the estimates.
#
# Args:
#   ds: list of datasets, each with columns `x` and `y`.
#
# Returns:
#   A tibble with one row per element of `ds` and columns `m` (estimated
#   slope) and `b` (estimated intercept), in the same order as `ds`.
gen_ols_results <- function(ds) {
  # Iterate over whatever `ds` contains instead of assuming exactly 100
  # datasets, so shorter (or longer) lists are handled as well.
  ols_results <- lapply(ds, function(d) ols_regression(d$x, d$y))
  tibble(
    m = map_dbl(ols_results, "m"),
    b = map_dbl(ols_results, "b")
  )
}
# Fit OLS to the generated datasets and print the table of (m, b) estimates.
ols_results_df <- gen_ols_results(ds)
ols_results_df
## # A tibble: 100 × 2
## m b
## <dbl> <dbl>
## 1 2.63 -2.40
## 2 2.13 -2.90
## 3 2.80 -2.52
## 4 1.92 -3.13
## 5 2.21 -3.14
## 6 2.23 -2.95
## 7 1.97 -2.92
## 8 1.90 -2.95
## 9 0.630 -3.30
## 10 1.49 -2.93
## # … with 90 more rows
# Build two plots comparing the OLS estimates with the true parameters.
#
# Args:
#   ols_results_df: data frame with columns `m` (slope) and `b` (intercept),
#     one row per fitted dataset.
#   n_datasets, len_dataset, sigma_sq: values shown in the plot titles only.
#   true_slope, true_intercept: ground-truth line parameters drawn as the
#     "truth" layer; the defaults (2, -3) match the previously fixed values.
#
# Returns:
#   A list with two ggplot objects:
#     line_plot - every estimate drawn as a line, plus the true line;
#     pt_plot   - every estimate drawn as a point in (m, b) space, plus the
#                 true parameter point.
gen_plot_estimates <- function(ols_results_df, n_datasets, len_dataset, sigma_sq,
                               true_slope = 2, true_intercept = -3) {
  # One-row table for the ground truth, so the "truth" layer is drawn once
  # instead of once per row of `ols_results_df`.
  truth_df <- tibble(m = true_slope, b = true_intercept)

  line_plot <- ggplot(ols_results_df) +
    geom_abline(aes(slope = m, intercept = b, color = "estimate")) +
    geom_abline(
      data = truth_df,
      aes(slope = m, intercept = b, color = "truth")
    ) +
    xlim(-10, 10) +
    ylim(-10, 10) +
    ggtitle(bquote("OLS Estimates as Lines on" ~ .(n_datasets) ~ "Datasets with" ~ n == .(len_dataset) ~ "and" ~ sigma^2 == .(sigma_sq)))

  pt_plot <- ggplot(ols_results_df) +
    geom_point(aes(x = m, y = b, color = "estimate")) +
    geom_point(
      data = truth_df,
      aes(x = m, y = b, color = "truth")
    ) +
    xlim(-50, 50) +
    ylim(-50, 50) +
    ggtitle(bquote("OLS Estimates as Points on" ~ .(n_datasets) ~ "Datasets with" ~ n == .(len_dataset) ~ "and" ~ sigma^2 == .(sigma_sq)))

  list(line_plot = line_plot, pt_plot = pt_plot)
}
# Build both plots for the 100-dataset, n = 10, sigma_sq = 1 run and show the
# one that draws each estimate as a line.
plots <- gen_plot_estimates(ols_results_df, n_datasets = 100, len_dataset = 10, sigma_sq = 1)
plots$line_plot