-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Includes cross validation class and methods.
- Loading branch information
Showing
15 changed files
with
460 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,110 +1,168 @@ | ||
#' @title Cross-validation | ||
#' @name Cross-validation | ||
#' | ||
############################################################### | ||
# # | ||
# (c) Victor Maus <[email protected]> # | ||
# Institute for Geoinformatics (IFGI) # | ||
# University of Muenster (WWU), Germany # | ||
# # | ||
# Earth System Science Center (CCST) # | ||
# National Institute for Space Research (INPE), Brazil # | ||
# # | ||
# # | ||
# R Package dtwSat - 2016-11-27 # | ||
# # | ||
############################################################### | ||
|
||
|
||
#' @title class "twdtwCrossValidation" | ||
#' @name twdtwCrossValidation-class | ||
#' @aliases twdtwCrossValidation | ||
#' @author Victor Maus, \email{vwmaus1@@gmail.com} | ||
#' | ||
#' @description This functions create data partitions and compute Cross-validation metrics. | ||
#' @description This class stores the cross-validation. | ||
#' | ||
#' @param object an object of class \code{\link[dtwSat]{twdtwTimeSeries}} or | ||
#' \code{\link[dtwSat]{twdtwMatches}}. | ||
#' @param object an object of class \code{\link[dtwSat]{twdtwTimeSeries}}. | ||
#' | ||
#' @param times Number of partitions to create. | ||
#' | ||
#' @param p the percentage of data that goes to training. | ||
#' See \code{\link[caret]{createDataPartition}} for details. | ||
#' | ||
#' @param ... Other arguments to be passed to \code{\link[dtwSat]{createPatterns}}. | ||
#' @param conf.int specifies the confidence level (0-1) for interval estimation of the | ||
#' population mean. For more details see \code{\link[ggplot2]{mean_cl_boot}}. | ||
#' | ||
#' @param matrix logical. If TRUE retrieves the confusion matrix. | ||
#' FALSE retrieves User's Accuracy (UA) and Producer's Accuracy (PA). | ||
#' Dafault is FALSE. | ||
#' @param ... Other arguments to be passed to \code{\link[dtwSat]{createPatterns}} and | ||
#' to \code{\link[dtwSat]{twdtwApply}}. | ||
#' | ||
#' @details | ||
#' \describe{ | ||
#' \item{\code{splitDataset}:}{This function splits the a set of time | ||
#' series into training and validation. The function uses stratified | ||
#' sampling and a simple random sampling for each stratum. Each data partition | ||
#' returned by this function has the temporal patterns and a set of time series for | ||
#' validation.} | ||
#' \item{\code{twdtwCrossValidation}:}{The function \code{splitDataset} performs the Cross-validation of | ||
#' the classification based on the labels of the classified time series | ||
#' (Reference) and the labels of the classification (Predicted). This function | ||
#' returns a data.frame with User's and Produce's Accuracy or a list for confusion | ||
#' matrices.} | ||
#' } | ||
#' | ||
#' @seealso | ||
#' \code{\link[dtwSat]{twdtwMatches-class}}, | ||
#' \code{\link[dtwSat]{twdtwApply}}, and | ||
#' \code{\link[dtwSat]{twdtwClassify}}. | ||
#' \code{\link[dtwSat]{createPatterns}}, and | ||
#' \code{\link[dtwSat]{twdtwApply}}. | ||
#' | ||
#' @section Slots : | ||
#' \describe{ | ||
#' \item{\code{partitions}:}{A list with the indices of time series used for training.} | ||
#' \item{\code{accuracy}:}{A list with the accuracy and other TWDTW information for each | ||
#' data partitions.} | ||
#' } | ||
#' | ||
#' @examples | ||
#' \dontrun{ | ||
#' load(system.file("lucc_MT/field_samples_ts.RData", package="dtwSat")) | ||
#' set.seed(1) | ||
#' partitions = splitDataset(field_samples_ts, p=0.1, times=5, | ||
#' freq = 8, formula = y ~ s(x, bs="cc")) | ||
#' log_fun = logisticWeight(alpha=-0.1, beta=50) | ||
#' twdtw_res = lapply(partitions, function(x){ | ||
#' res = twdtwApply(x = x$ts, y = x$patterns, weight.fun = log_fun, n=1) | ||
#' twdtwClassify(x = res) | ||
#' }) | ||
#' cross_validation = twdtwCrossValidation(twdtw_res) | ||
# head(cross_validation, 5) | ||
#' | ||
#' } | ||
NULL | ||
setClass( | ||
Class = "twdtwCrossValidation", | ||
slots = c(partitions = "list", accuracy = "list"), | ||
validity = function(object){ | ||
if(!is(object@partitions, "list")){ | ||
stop("[twdtwTimeSeries: validation] Invalid partitions, class different from list.") | ||
}else{} | ||
if(!is(object@accuracy, "list")){ | ||
stop("[twdtwTimeSeries: validation] Invalid accuracy, class different from list.") | ||
}else{} | ||
return(TRUE) | ||
} | ||
) | ||
|
||
setGeneric("splitDataset", function(object, times, p, ...) standardGeneric("splitDataset")) | ||
setMethod("initialize", | ||
signature = "twdtwCrossValidation", | ||
definition = | ||
function(.Object, partitions, accuracy){ | ||
.Object@partitions = list(Resample1=NULL) | ||
.Object@accuracy = list(OverallAccuracy=NULL, UsersAccuracy=NULL, ProducersAccuracy=NULL, | ||
error.matrix=table(NULL), data=data.frame(NULL)) | ||
if(!missing(partitions)) | ||
.Object@partitions = partitions | ||
if(!missing(accuracy)) | ||
.Object@accuracy = accuracy | ||
validObject(.Object) | ||
return(.Object) | ||
} | ||
) | ||
|
||
#' @rdname Cross-validation | ||
#' @aliases splitDataset | ||
setGeneric("twdtwCrossValidation", | ||
def = function(object, ...) standardGeneric("twdtwCrossValidation") | ||
) | ||
|
||
#' @inheritParams twdtwCrossValidation-class | ||
#' @aliases twdtwCrossValidation-create | ||
#' | ||
#' @describeIn twdtwCrossValidation Splits the set of time | ||
#' series into training and validation. The function uses stratified | ||
#' sampling and a simple random sampling for each stratum. For each data partition | ||
#' this function performs a TWDTW analysis and returns the Overall Accuracy, | ||
#' User's Accuracy, Produce's Accuracy, error matrix (confusion matrix), and a | ||
#' \code{\link[base]{data.frame}} with the classification (Predicted), the | ||
#' reference classes (Reference), and some TWDTW information. | ||
#' | ||
#' @examples | ||
#' \dontrun{ | ||
#' # Data folder | ||
#' data_folder = system.file("lucc_MT/data", package = "dtwSat") | ||
#' | ||
#' # Read dates | ||
#' dates = scan(paste(data_folder,"timeline", sep = "/"), what = "dates") | ||
#' | ||
#' # Read raster time series | ||
#' evi = brick(paste(data_folder,"evi.tif", sep = "/")) | ||
#' raster_timeseries = twdtwRaster(evi, timeline = dates) | ||
#' | ||
#' # Read field samples | ||
#' field_samples = read.csv(paste(data_folder,"samples.csv", sep = "/")) | ||
#' table(field_samples[["label"]]) | ||
#' | ||
#' # Read field samples projection | ||
#' proj_str = scan(paste(data_folder,"samples_projection", sep = "/"), | ||
#' what = "character") | ||
#' | ||
#' # Get sample time series from raster time series | ||
#' field_samples_ts = getTimeSeries(raster_timeseries, | ||
#' y = field_samples, proj4string = proj_str) | ||
#' field_samples_ts | ||
#' | ||
#' # Run cross validation | ||
#' set.seed(1) | ||
#' # Define TWDTW weight function | ||
#' log_fun = logisticWeight(alpha=-0.1, beta=50) | ||
#' cross_validation = twdtwCrossValidation(field_samples_ts, times=3, p=0.1, | ||
#' freq = 8, formula = y ~ s(x, bs="cc"), weight.fun = log_fun) | ||
#' cross_validation | ||
#' } | ||
#' @export | ||
setMethod("splitDataset", "twdtwTimeSeries", | ||
function(object, times=1, p=0.1, ...) splitDataset.twdtwTimeSeries(object, times=times, p=p, ...)) | ||
setMethod(f = "twdtwCrossValidation", | ||
definition = function(object, times, p, ...) twdtwCrossValidation.twdtwTimeSeries(object, times, p, ...)) | ||
|
||
splitDataset.twdtwTimeSeries = function(object, times, p, ...){ | ||
twdtwCrossValidation.twdtwTimeSeries = function(object, times, p, ...){ | ||
|
||
partitions = createDataPartition(y = labels(object), times, p, list = TRUE) | ||
|
||
res = lapply(partitions, function(I){ | ||
training_ts = subset(object, I) | ||
validation_ts = subset(object, -I) | ||
# patt = createPatterns(training_ts, freq = 8, formula = y ~ s(x, bs="cc")) | ||
patt = createPatterns(training_ts, ...) | ||
list(patterns=patt, ts=validation_ts) | ||
}) | ||
|
||
res | ||
} | ||
|
||
setGeneric("twdtwCrossValidation", function(object, matrix=FALSE) standardGeneric("twdtwCrossValidation")) | ||
|
||
#' @rdname Cross-validation | ||
#' @aliases twdtwCrossValidation | ||
#' @export | ||
setMethod("twdtwCrossValidation", "list", | ||
function(object, matrix) twdtwCrossValidation.twdtwTimeSeries(object, matrix=matrix)) | ||
|
||
twdtwCrossValidation.twdtwTimeSeries = function(object, matrix){ | ||
|
||
res = lapply(object, function(x){ | ||
ref = labels(x)$timeseries | ||
levels = sort(as.character(unique(ref))) | ||
# twdtw_res = twdtwApply(x = validation_ts, y = patt, n=1, weight.fun = log_fun) | ||
twdtw_res = twdtwApply(x = validation_ts, y = patt, n=1, ...) | ||
ref = as.character(labels(twdtw_res)$timeseries) | ||
levels = sort(unique(ref)) | ||
labels = levels | ||
# pred = factor(do.call("rbind", x[])$label, levels, labels) | ||
pred = do.call("rbind", lapply(x[], function(xx) as.character(xx$label[which.min(xx$distance)])) ) | ||
df = do.call("rbind", lapply(twdtw_res[], function(xx) xx[which.min(xx$distance),]) ) | ||
ref = factor(ref, levels, labels) | ||
table(Reference=ref, Predicted=pred) | ||
pred = factor(as.character(df$label), levels, labels) | ||
data = data.frame(Reference=ref, Predicted=pred, df[,!names(df)%in%"labels"]) | ||
error.matrix = table(Reference=ref, Predicted=pred) | ||
UA = diag(error.matrix) / colSums(error.matrix) | ||
PA = diag(error.matrix) / rowSums(error.matrix) | ||
O = sum(diag(error.matrix)) / sum(rowSums(error.matrix)) | ||
list(OverallAccuracy=O, UsersAccuracy=UA, ProducersAccuracy=PA, error.matrix=error.matrix, data=data) | ||
}) | ||
|
||
if(!matrix){ | ||
res = do.call("rbind", lapply(seq_along(res), function(i){ | ||
x = res[[i]] | ||
Users = diag(x) / rowSums(x) | ||
Producers = diag(x) / colSums(x) | ||
data.frame(resample=i,label=names(Users), UA = Users, PA = Producers, row.names=NULL) | ||
})) | ||
} | ||
|
||
res | ||
new("twdtwCrossValidation", partitions=partitions, accuracy=res) | ||
|
||
} | ||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.