
Commit f600344

Truncating plots moved to Plot() for added flexibility.
1 parent fc88b71 commit f600344

10 files changed: +283 additions, −31 deletions
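
In practical terms: truncation of plot data is now applied per plot call instead of once, destructively, inside Simulator (see the R/simulator.R change below). A minimal usage sketch, assuming history is a History instance returned by Simulator$new(...)$run(); both new flags default to TRUE, which preserves the previous behavior:

# Plot all recorded data, with no truncation at all:
plot(history, type = "cumulative", regret = TRUE,
     trunc_over_agents = FALSE,  # keep agents that ran longer than the shortest agent
     trunc_per_agent   = FALSE)  # keep time steps not reached by every simulation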

R/functions_generic.R

Lines changed: 20 additions & 4 deletions
@@ -31,6 +31,14 @@ plot.History <- function(x, ...) {
     legend <- eval(args$legend)
   else
     legend <- TRUE
+  if ("trunc_per_agent" %in% names(args))
+    trunc_per_agent <- eval(args$trunc_per_agent)
+  else
+    trunc_per_agent <- TRUE
+  if ("trunc_over_agents" %in% names(args))
+    trunc_over_agents <- eval(args$trunc_over_agents)
+  else
+    trunc_over_agents <- TRUE
   if ("regret" %in% names(args))
     regret <- eval(args$regret)
   else

@@ -163,7 +171,9 @@ plot.History <- function(x, ...) {
       xlab = xlab,
       ylab = ylab,
       limit_agents = limit_agents,
-      limit_context = limit_context
+      limit_context = limit_context,
+      trunc_over_agents = trunc_over_agents,
+      trunc_per_agent = trunc_per_agent
     )
   } else if (type == "average") {
     Plot$new()$average(

@@ -193,7 +203,9 @@ plot.History <- function(x, ...) {
       ylab = ylab,
       cum_average = cum_average,
       limit_agents = limit_agents,
-      limit_context = limit_context
+      limit_context = limit_context,
+      trunc_over_agents = trunc_over_agents,
+      trunc_per_agent = trunc_per_agent
     )
   } else if (type == "optimal") {
     Plot$new()$optimal(

@@ -220,7 +232,9 @@ plot.History <- function(x, ...) {
       xlab = xlab,
       ylab = ylab,
       limit_agents = limit_agents,
-      limit_context = limit_context
+      limit_context = limit_context,
+      trunc_over_agents = trunc_over_agents,
+      trunc_per_agent = trunc_per_agent
     )
   } else if (type == "arms") {
     Plot$new()$arms(

@@ -240,7 +254,9 @@ plot.History <- function(x, ...) {
       xlab = xlab,
       ylab = ylab,
       limit_agents = limit_agents,
-      limit_context = limit_context
+      limit_context = limit_context,
+      trunc_over_agents = trunc_over_agents,
+      trunc_per_agent = trunc_per_agent

     )
   }
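
For reference, the fallback pattern used above: plot.History() captures its dots, then substitutes a default whenever a flag was not supplied. A standalone sketch of that pattern, where args stands in for the generic's captured argument list (hypothetical value):

args <- list(trunc_per_agent = FALSE)  # as if called via plot(h, trunc_per_agent = FALSE)
if ("trunc_per_agent" %in% names(args))
  trunc_per_agent <- eval(args$trunc_per_agent)
else
  trunc_per_agent <- TRUE
trunc_per_agent
#> [1] FALSE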

R/history.R

Lines changed: 12 additions & 7 deletions
@@ -343,6 +343,10 @@ History <- R6::R6Class(

       private$.cum_stats <- private$.data[, list(

+        sims = length(reward),
+        sqrt_sims = sqrt(length(reward)),
+
         regret_var = var(regret),
         regret_sd = sd(regret),
         regret = mean(regret),

@@ -373,14 +377,15 @@ History <- R6::R6Class(
       private$.cum_stats[, cum_regret_rate := cum_regret / t]

       qn <- qnorm(0.975)
-      sqrt_sim <- sqrt(self$get_simulation_count())

-      private$.cum_stats[, cum_regret_ci := cum_regret_sd / sqrt_sim * qn]
-      private$.cum_stats[, cum_reward_ci := cum_reward_sd / sqrt_sim * qn]
-      private$.cum_stats[, cum_regret_rate_ci := cum_regret_rate_sd / sqrt_sim * qn]
-      private$.cum_stats[, cum_reward_rate_ci := cum_reward_rate_sd / sqrt_sim * qn]
-      private$.cum_stats[, regret_ci := regret_sd / sqrt_sim * qn]
-      private$.cum_stats[, reward_ci := reward_sd / sqrt_sim * qn]
+      private$.cum_stats[, cum_regret_ci := cum_regret_sd / sqrt_sims * qn]
+      private$.cum_stats[, cum_reward_ci := cum_reward_sd / sqrt_sims * qn]
+      private$.cum_stats[, cum_regret_rate_ci := cum_regret_rate_sd / sqrt_sims * qn]
+      private$.cum_stats[, cum_reward_rate_ci := cum_reward_rate_sd / sqrt_sims * qn]
+      private$.cum_stats[, regret_ci := regret_sd / sqrt_sims * qn]
+      private$.cum_stats[, reward_ci := reward_sd / sqrt_sims * qn]
+
+      private$.cum_stats[, sqrt_sims := NULL]

       private$.data[, cum_reward_rate := cum_reward / t]
       private$.data[, cum_regret_rate := cum_regret / t]
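
The substance of this change: the 95% interval half-width is qnorm(0.975) * sd / sqrt(n), and n is now the per-(agent, t) simulation count sims computed inside the same grouped aggregation, rather than one global self$get_simulation_count(). Now that Simulator no longer truncates, time steps reached by fewer simulations correctly receive wider intervals. A self-contained sketch of the computation on toy data (hypothetical values):

library(data.table)
dt <- data.table(agent  = "a",
                 t      = rep(1:3, times = c(5, 5, 3)),  # t = 3 reached by only 3 sims
                 regret = runif(13))
stats <- dt[, .(regret    = mean(regret),
                regret_ci = qnorm(0.975) * sd(regret) / sqrt(.N)),  # .N = sims per group
            by = .(agent, t)]
stats  # the t = 3 half-width divides by sqrt(3) instead of sqrt(5)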

R/plot.R

Lines changed: 56 additions & 16 deletions
@@ -31,7 +31,9 @@ Plot <- R6::R6Class(
       legend_position = "topleft",
       legend_title = NULL,
       limit_agents = NULL,
-      limit_context = NULL) {
+      limit_context = NULL,
+      trunc_over_agents = TRUE,
+      trunc_per_agent = TRUE) {

       self$history <- history

@@ -85,7 +87,9 @@ Plot <- R6::R6Class(
         traces_max = traces_max,
         traces_alpha = traces_alpha,
         smooth = smooth,
-        rate = rate
+        rate = rate,
+        trunc_over_agents = trunc_over_agents,
+        trunc_per_agent = trunc_per_agent
       )

       invisible(recordPlot())

@@ -116,7 +120,9 @@ Plot <- R6::R6Class(
       legend_position = "topleft",
       legend_title = NULL,
       limit_agents = NULL,
-      limit_context = NULL) {
+      limit_context = NULL,
+      trunc_over_agents = TRUE,
+      trunc_per_agent = TRUE) {

       self$history <- history

@@ -149,7 +155,9 @@ Plot <- R6::R6Class(
         traces = traces,
         traces_max = traces_max,
         traces_alpha = traces_alpha,
-        smooth = smooth
+        smooth = smooth,
+        trunc_over_agents = trunc_over_agents,
+        trunc_per_agent = trunc_per_agent
       )

       invisible(recordPlot())

@@ -183,7 +191,9 @@ Plot <- R6::R6Class(
       legend_position = "topleft",
       legend_title = NULL,
       limit_agents = NULL,
-      limit_context = NULL) {
+      limit_context = NULL,
+      trunc_over_agents = TRUE,
+      trunc_per_agent = TRUE) {
       self$history <- history

       if (regret) {

@@ -225,7 +235,9 @@ Plot <- R6::R6Class(
         traces_max = traces_max,
         traces_alpha = traces_alpha,
         smooth = smooth,
-        rate = rate
+        rate = rate,
+        trunc_over_agents = trunc_over_agents,
+        trunc_per_agent = trunc_per_agent
       )

       invisible(recordPlot())

@@ -248,7 +260,9 @@ Plot <- R6::R6Class(
       legend_title = NULL,
       limit_context = NULL,
       smooth = FALSE,
-      limit_agents = NULL) {
+      limit_agents = NULL,
+      trunc_over_agents = TRUE,
+      trunc_per_agent = TRUE) {

       self$history <- history

@@ -314,7 +328,10 @@ Plot <- R6::R6Class(

       eg <- expand.grid(t = dt[sim == 1]$t, choice = seq(1.0, max_arm, 1))
       data <- merge(data, eg, all = TRUE)
-      data[is.na(data)] <- 0.0
+      # turn NA into 0
+      for (j in seq_len(ncol(data)))
+        set(data, which(is.na(data[[j]])), j, 0)
+
       data$dataum <- ave(data$arm_count, data$t, FUN = cumsum)
       data$zero <- 0.0
       min_ylim <- 0
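
The loop replacing data[is.na(data)] <- 0.0 above is the idiomatic data.table way to zero out NAs: set() updates one column at a time by reference, whereas the matrix-style assignment can copy the whole table or fail on a data.table, depending on version. A minimal sketch:

library(data.table)
d <- data.table(a = c(1, NA, 3), b = c(NA, 2, NA))
for (j in seq_len(ncol(d)))
  set(d, which(is.na(d[[j]])), j, 0)  # i = rows with NA, j = column index
d  # all NAs are now 0, no copies made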
@@ -445,7 +462,9 @@ Plot <- R6::R6Class(
       traces_alpha = 0.3,
       cum_average = FALSE,
       smooth = FALSE,
-      rate = FALSE) {
+      rate = FALSE,
+      trunc_over_agents = TRUE,
+      trunc_per_agent = TRUE) {

       cum_flip <- FALSE
       if((line_data_name=="reward" || line_data_name=="regret") && isTRUE(cum_average)) {

@@ -472,7 +491,7 @@ Plot <- R6::R6Class(
       disp_data_name <- gsub("none", disp, disp_data_name)
       data <-
         self$history$get_cumulative_data(
-          limit_cols = c("agent", "t", line_data_name, disp_data_name),
+          limit_cols = c("agent", "t", "sims", line_data_name, disp_data_name),
           limit_agents = limit_agents,
           interval = interval
         )

@@ -481,12 +500,28 @@ Plot <- R6::R6Class(
         disp <- NULL
         data <-
           self$history$get_cumulative_data(
-            limit_cols = c("agent", "t", line_data_name),
+            limit_cols = c("agent", "t", "sims", line_data_name),
             limit_agents = limit_agents,
             interval = interval
           )
       }

+      agent_levels <- levels(droplevels(data$agent))
+      n_agents <- length(agent_levels)
+
+      # turn NA into 0
+      for (j in seq_len(ncol(data)))
+        data.table::set(data, which(is.na(data[[j]])), j, 0)
+
+      if (isTRUE(trunc_per_agent)) {
+        data <- data[data$sims == max(data$sims)]
+      }
+
+      if (isTRUE(trunc_over_agents)) {
+        min_t_sim <- min(data[, max(t), by = c("agent")]$V1)
+        data <- data[t <= min_t_sim]
+      }
+
       if (!is.null(xlim)) {
         min_xlim <- xlim[1]
         max_xlim <- xlim[2]
@@ -495,9 +530,6 @@ Plot <- R6::R6Class(
         max_xlim <- data[, max(t)]
       }

-      agent_levels <- levels(droplevels(data$agent))
-      n_agents <- length(agent_levels)
-
       data.table::setorder(data, agent, t)

       if(cum_flip==TRUE) {
@@ -709,10 +741,10 @@ Plot <- R6::R6Class(
 #' Plot
 #'
-#' Generates plots from \code{History} data.
+#' Generates plots from \code{\link{History}} data.
 #'
 #' Usually not instantiated directly but invoked by calling the generic \code{plot(h)}, where \code{h}
-#' is an \code{History} class instance.
+#' is a \code{\link{History}} class instance.
 #'
 #' @name Plot
 #' @aliases average optimal arms do_plot gg_color_hue check_history_data
@@ -834,6 +866,14 @@ Plot <- R6::R6Class(
 #' \item{\code{ylab}}{
 #'    \code{(character, NULL)} a title for the y axis
 #' }
+#' \item{\code{trunc_over_agents}}{
+#'    \code{(logical, TRUE)} Truncate the chart to the agent with the fewest time steps t.
+#' }
+#' \item{\code{trunc_per_agent}}{
+#'    \code{(logical, TRUE)} Truncate every agent's plot to the number of time steps that have been fully
+#'    simulated. That is, time steps for which the number of simulations equals the number defined in
+#'    \code{\link{Simulator}}'s \code{simulations} parameter.
+#' }
 #' }
 #'
 #'

R/simulator.R

Lines changed: 0 additions & 3 deletions
@@ -251,9 +251,6 @@ Simulator <- R6::R6Class(
       self$internal_history$set_meta_data("sim_total_duration", formatted_duration)
       message(paste0("Completed simulation in ",formatted_duration))

-      # TODO: this should be optional, and maybe done at plotside?
-      self$internal_history$truncate()
-
       start_time_stats <- Sys.time()
       message("Computing statistics.")
       # update statistics TODO: not always necessary, add option arg to class?

contextual.Rproj

Lines changed: 1 addition & 1 deletion
@@ -19,6 +19,6 @@ BuildType: Package
 PackageUseDevtools: Yes
 PackageInstallArgs: --no-multiarch --with-keep.source
 PackageCheckArgs: --as-cran
-PackageRoxygenize: rd,collate,namespace,vignette
+PackageRoxygenize: vignette

 QuitChildProcessesOnExit: Yes
New file: 48 additions, 0 deletions
@@ -0,0 +1,48 @@
+#' @export
+OnlineOfflineContinuumBandit <- R6::R6Class(
+  inherit = Bandit,
+  class = FALSE,
+  private = list(
+    S = NULL
+  ),
+  public = list(
+    class_name = "OnlineOfflineContinuumBandit",
+    delta = NULL,
+    c1 = NULL,
+    c2 = NULL,
+    arm_function = NULL,
+    choice = NULL,
+    initialize = function(delta, horizon) {
+      self$c1 <- runif(1, 0.25, 0.75)
+      self$c2 <- runif(1, 0.25, 0.75)
+      self$arm_function <- function(x, c1 = 0.25, c2 = 0.75) {
+        -(x - c1) ^ 2 + c2 + rnorm(length(x), 0, 0.01)
+      }
+      self$delta <- delta
+      self$choice <- runif(horizon, min = 0, max = 1)
+      private$S <- data.frame(self$choice, self$arm_function(self$choice, self$c1, self$c2))
+      self$k <- 1
+    },
+    post_initialization = function() {
+      private$S <- private$S[sample(nrow(private$S)), ]
+      colnames(private$S) <- c('choice', 'reward')
+      #print(private$S)
+    },
+    get_context = function(index) {
+      context <- list()
+      context$k <- self$k
+      context
+    },
+    get_reward = function(index, context, action) {
+      reward_at_index <- as.double(private$S$reward[[index]])
+      if (abs(private$S$choice[[index]] - action$choice) < self$delta) {
+        reward <- list(
+          reward = reward_at_index,
+          optimal_reward = self$c2
+        )
+      } else {
+        NULL
+      }
+    }
+  )
+)
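
The get_reward() above is a continuous-armed variant of replay-style offline evaluation: a logged (choice, reward) pair is consumed only when the logged choice lies within delta of the action the policy proposes; otherwise the method returns NULL and the step yields no data point. This is exactly why per-step simulation counts can differ, and why the sims-aware confidence intervals introduced above matter. A sketch of just the acceptance test (hypothetical values):

delta         <- 0.1
logged_choice <- 0.42   # a logged action from the offline data in private$S
policy_choice <- 0.45   # action$choice proposed by the evaluated policy
abs(logged_choice - policy_choice) < delta  # TRUE: the logged reward is used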
New file: 49 additions, 0 deletions
@@ -0,0 +1,49 @@
+library(contextual)
+library(here)
+setwd(here("demo", "replication_kruijswijk_2019"))
+
+source("./bandit_continuum_offon.R")
+source("./policy_tbl.R")
+source("./policy_unifcont.R")
+source("./policy_efirst_regression.R")
+
+set.seed(100)
+
+horizon <- 10000
+simulations <- 10
+
+continuous_arms <- function(x, c1 = 0.25, c2 = 0.75) {
+  -(x - c1) ^ 2 + c2 + rnorm(length(x), 0, 0.01)
+}
+
+choice <- runif(horizon, min = 0, max = 1)
+reward <- continuous_arms(choice)
+offline_data <- data.frame(choice, reward)
+
+int_time <- 50
+amplitude <- 0.05
+learn_rate <- 1
+omega <- 1 # 2*pi/int_time
+x0_start <- runif(1) # 2.0
+
+bandit <- OnlineOfflineContinuumBandit$new(delta = 0.1, horizon = horizon)
+
+agents <- list(Agent$new(UniformRandomContinuousPolicy$new(), bandit),
+               Agent$new(ThompsonBayesianLinearPolicy$new(), bandit))
+               #Agent$new(LifPolicy$new(int_time, amplitude, learn_rate, omega, x0_start), bandit),
+               #Agent$new(EFirstRegressionPolicy$new(epsilon = 100), bandit))
+
+history <- Simulator$new(agents = agents,
+                         horizon = horizon,
+                         simulations = simulations,
+                         do_parallel = TRUE)$run()
+
+plot(history, type = "cumulative", regret = TRUE, rate = FALSE, disp = 'ci',
+     trunc_over_agents = FALSE, trunc_per_agent = FALSE)
