5.3 Database Backends
In mlr3, Tasks store their data in an abstract data format, the DataBackend.
The default backend uses data.table via the DataBackendDataTable as an in-memory database.
For larger data, or when working with many tasks in parallel, it can be advantageous to interface with an out-of-memory database. We use the excellent R package dbplyr, which extends dplyr to work on many popular databases such as MariaDB, PostgreSQL or SQLite.
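For comparison, the default in-memory backend can be created directly from any data.frame with as_data_backend() from mlr3. A minimal sketch using the built-in mtcars data (not part of the flights example below):

```r
library("mlr3")

# convert a plain data.frame into the default in-memory backend;
# a primary key column of row ids is added automatically
b = as_data_backend(mtcars)

class(b)[1L]  # "DataBackendDataTable"
b$nrow        # 32, one row per car
```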
5.3.1 Use Case: NYC Flights
To generate a halfway realistic scenario, we use the NYC flights data set from package nycflights13:
# load data
requireNamespace("DBI")
## Loading required namespace: DBI
requireNamespace("RSQLite")
## Loading required namespace: RSQLite
requireNamespace("nycflights13")
## Loading required namespace: nycflights13
data("flights", package = "nycflights13")
str(flights)
## tibble [336,776 × 19] (S3: tbl_df/tbl/data.frame)
## $ year : int [1:336776] 2013 2013 2013 2013 2013 2013 2013 2013 2013 2013 ...
## $ month : int [1:336776] 1 1 1 1 1 1 1 1 1 1 ...
## $ day : int [1:336776] 1 1 1 1 1 1 1 1 1 1 ...
## $ dep_time : int [1:336776] 517 533 542 544 554 554 555 557 557 558 ...
## $ sched_dep_time: int [1:336776] 515 529 540 545 600 558 600 600 600 600 ...
## $ dep_delay : num [1:336776] 2 4 2 -1 -6 -4 -5 -3 -3 -2 ...
## $ arr_time : int [1:336776] 830 850 923 1004 812 740 913 709 838 753 ...
## $ sched_arr_time: int [1:336776] 819 830 850 1022 837 728 854 723 846 745 ...
## $ arr_delay : num [1:336776] 11 20 33 -18 -25 12 19 -14 -8 8 ...
## $ carrier : chr [1:336776] "UA" "UA" "AA" "B6" ...
## $ flight : int [1:336776] 1545 1714 1141 725 461 1696 507 5708 79 301 ...
## $ tailnum : chr [1:336776] "N14228" "N24211" "N619AA" "N804JB" ...
## $ origin : chr [1:336776] "EWR" "LGA" "JFK" "JFK" ...
## $ dest : chr [1:336776] "IAH" "IAH" "MIA" "BQN" ...
## $ air_time : num [1:336776] 227 227 160 183 116 150 158 53 140 138 ...
## $ distance : num [1:336776] 1400 1416 1089 1576 762 ...
## $ hour : num [1:336776] 5 5 5 5 6 5 6 6 6 6 ...
## $ minute : num [1:336776] 15 29 40 45 0 58 0 0 0 0 ...
## $ time_hour : POSIXct[1:336776], format: "2013-01-01 05:00:00" "2013-01-01 05:00:00" ...
# add column of unique row ids
flights$row_id = 1:nrow(flights)
# create sqlite database in temporary file
path = tempfile("flights", fileext = ".sqlite")
con = DBI::dbConnect(RSQLite::SQLite(), path)
DBI::dbWriteTable(con, "flights", as.data.frame(flights))
DBI::dbDisconnect(con)
# remove in-memory data
rm(flights)
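As a quick sanity check (not part of the original workflow), one can briefly reconnect with DBI and confirm that the table landed in the SQLite file:

```r
# reconnect, list tables and count rows, then disconnect again
con = DBI::dbConnect(RSQLite::SQLite(), path)
DBI::dbListTables(con)  # should include "flights"
DBI::dbGetQuery(con, "SELECT COUNT(*) AS n FROM flights")
DBI::dbDisconnect(con)
```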
5.3.2 Preprocessing with dplyr
With the SQLite database stored in path, we now re-establish a connection and switch to dplyr/dbplyr for some essential preprocessing.
# establish connection
con = DBI::dbConnect(RSQLite::SQLite(), path)
# select the "flights" table, enter dplyr
library("dplyr")
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library("dbplyr")
##
## Attaching package: 'dbplyr'
## The following objects are masked from 'package:dplyr':
##
## ident, sql
tbl = tbl(con, "flights")
First, we select a subset of columns to work on:
keep = c("row_id", "year", "month", "day", "hour", "minute", "dep_time",
  "arr_time", "carrier", "flight", "air_time", "distance", "arr_delay")
tbl = select(tbl, keep)
Additionally, we remove those observations where the arrival delay (arr_delay) has a missing value:
tbl = filter(tbl, !is.na(arr_delay))
To keep runtime reasonable for this toy example, we filter the data to only use every second row:
tbl = filter(tbl, row_id %% 2 == 0)
The factor levels of the feature carrier are merged so that infrequent carriers are replaced by level "other":
tbl = mutate(tbl, carrier = case_when(
  carrier %in% c("OO", "HA", "YV", "F9", "AS", "FL", "VX", "WN") ~ "other",
  TRUE ~ carrier))
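Note that none of these dplyr verbs have touched the database yet: dbplyr collects them lazily and only translates them to SQL when results are requested. The accumulated query for the tbl object from above can be inspected with show_query():

```r
# print the SQL that dbplyr has built up so far; nothing has been
# executed against the SQLite database at this point
show_query(tbl)
```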
5.3.3 DataBackendDplyr
The processed table is now used to create an mlr3db::DataBackendDplyr from mlr3db:
library("mlr3db")
b = as_data_backend(tbl, primary_key = "row_id")
We can now use the interface of DataBackend to query some basic information about the data:
b$nrow
## [1] 163707
b$ncol
## [1] 13
b$head()
## row_id year month day hour minute dep_time arr_time carrier flight air_time
## 1: 2 2013 1 1 5 29 533 850 UA 1714 227
## 2: 4 2013 1 1 5 45 544 1004 B6 725 183
## 3: 6 2013 1 1 5 58 554 740 UA 1696 150
## 4: 8 2013 1 1 6 0 557 709 EV 5708 53
## 5: 10 2013 1 1 6 0 558 753 AA 301 138
## 6: 12 2013 1 1 6 0 558 853 B6 71 158
## distance arr_delay
## 1: 1416 20
## 2: 1576 -18
## 3: 719 12
## 4: 229 -14
## 5: 733 8
## 6: 1005 -3
Note that the DataBackendDplyr does not know about any rows or columns we have filtered out with dplyr before; it simply operates on the view we provided.
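To illustrate this view semantics, the backend can be queried for arbitrary subsets via $data(); row ids that were filtered out upstream are simply absent. A small sketch, assuming the backend b from above:

```r
# query a few rows and two columns through the backend's view
ids = b$rownames[1:3]
b$data(rows = ids, cols = c("carrier", "arr_delay"))

# row id 1 is odd and was filtered out with dplyr, so it is not
# part of the view: the query returns a table with zero rows
b$data(rows = 1, cols = "carrier")
```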
5.3.4 Model fitting
We create the following mlr3 objects:
- A regression task, based on the previously created mlr3db::DataBackendDplyr.
- A regression learner (regr.rpart).
- A resampling strategy: 3 times repeated subsampling using 2% of the observations for training ("subsampling").
- Measures "mse", "time_train" and "time_predict".
task = TaskRegr$new("flights_sqlite", b, target = "arr_delay")
learner = lrn("regr.rpart")
measures = mlr_measures$mget(c("regr.mse", "time_train", "time_predict"))
resampling = rsmp("subsampling")
resampling$param_set$values = list(repeats = 3, ratio = 0.02)
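Equivalently, the resampling parameters can be passed directly to the rsmp() sugar function in a single call:

```r
library("mlr3")

# same configuration as setting $param_set$values in a second step
resampling = rsmp("subsampling", repeats = 3, ratio = 0.02)
resampling$param_set$values
```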
We pass all these objects to resample() to perform a simple resampling with three iterations.
In each iteration, only the required subset of the data is queried from the SQLite database and passed to rpart::rpart():
rr = resample(task, learner, resampling)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: flights_sqlite
## * Learner: regr.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
rr$aggregate(measures)
## regr.mse time_train time_predict
## 1256 0 0
5.3.5 Cleanup
Finally, we remove the tbl object and close the connection.
rm(tbl)
DBI::dbDisconnect(con)