Skip to content

Commit

Permalink
Includes cross validation class and methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
vwmaus committed Nov 29, 2016
1 parent edb45e2 commit 463e0e8
Show file tree
Hide file tree
Showing 15 changed files with 460 additions and 175 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Collate:
'plotChanges.R'
'plotClassification.R'
'plotCostMatrix.R'
'plotCrossValidation.R'
'plotDistance.R'
'plotMaps.R'
'plotMatches.R'
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export(plotArea)
export(plotChanges)
export(plotClassification)
export(plotCostMatrix)
export(plotCrossValidation)
export(plotDistance)
export(plotMaps)
export(plotMatches)
Expand Down Expand Up @@ -56,8 +57,8 @@ exportMethods(res)
exportMethods(resampleTimeSeries)
exportMethods(shiftDates)
exportMethods(show)
exportMethods(splitDataset)
exportMethods(subset)
exportMethods(summary)
exportMethods(twdtwApply)
exportMethods(twdtwClassify)
exportMethods(twdtwCrossValidation)
Expand Down Expand Up @@ -103,6 +104,7 @@ importFrom(sp,over)
importFrom(sp,spTransform)
importFrom(stats,ave)
importFrom(stats,na.omit)
importFrom(stats,sd)
importFrom(stats,window)
importFrom(stats,xtabs)
useDynLib(dtwSat,bestmatches)
Expand Down
1 change: 1 addition & 0 deletions R/class-twdtwTimeSeries.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ setGeneric(name = "twdtwTimeSeries",

#' @inheritParams twdtwTimeSeries-class
#' @aliases twdtwTimeSeries-create
#'
#' @describeIn twdtwTimeSeries Create object of class twdtwTimeSeries.
#'
#' @examples
Expand Down
4 changes: 2 additions & 2 deletions R/createPatterns.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
#' \code{\link[dtwSat]{twdtwApply}}
#'
#' @export
setGeneric("createPatterns", function(x, ...) standardGeneric("createPatterns"))
setGeneric("createPatterns", function(x, from=NULL, to=NULL, freq=1, attr=NULL, split=TRUE, formula, ...) standardGeneric("createPatterns"))

#' @rdname createPatterns
#' @aliases createPatterns-twdtwMatches
Expand Down Expand Up @@ -86,7 +86,7 @@ setGeneric("createPatterns", function(x, ...) standardGeneric("createPatterns"))
#' }
#' @export
setMethod("createPatterns", "twdtwTimeSeries",
function(x, from=NULL, to=NULL, freq=1, attr=NULL, split=TRUE, formula, ...) {
function(x, from, to, freq, attr, split, formula, ...) {

# Get formula variables
if(!is(formula, "formula"))
Expand Down
208 changes: 133 additions & 75 deletions R/crossValidation.R
Original file line number Diff line number Diff line change
@@ -1,110 +1,168 @@
#' @title Cross-validation
#' @name Cross-validation
#'
###############################################################
# #
# (c) Victor Maus <[email protected]> #
# Institute for Geoinformatics (IFGI) #
# University of Muenster (WWU), Germany #
# #
# Earth System Science Center (CCST) #
# National Institute for Space Research (INPE), Brazil #
# #
# #
# R Package dtwSat - 2016-11-27 #
# #
###############################################################


#' @title class "twdtwCrossValidation"
#' @name twdtwCrossValidation-class
#' @aliases twdtwCrossValidation
#' @author Victor Maus, \email{vwmaus1@@gmail.com}
#'
#' @description These functions create data partitions and compute cross-validation metrics.
#' @description This class stores the cross-validation.
#'
#' @param object an object of class \code{\link[dtwSat]{twdtwTimeSeries}} or
#' \code{\link[dtwSat]{twdtwMatches}}.
#' @param object an object of class \code{\link[dtwSat]{twdtwTimeSeries}}.
#'
#' @param times Number of partitions to create.
#'
#' @param p the percentage of data that goes to training.
#' See \code{\link[caret]{createDataPartition}} for details.
#'
#' @param ... Other arguments to be passed to \code{\link[dtwSat]{createPatterns}}.
#' @param conf.int specifies the confidence level (0-1) for interval estimation of the
#' population mean. For more details see \code{\link[ggplot2]{mean_cl_boot}}.
#'
#' @param matrix logical. If TRUE retrieves the confusion matrix.
#' FALSE retrieves User's Accuracy (UA) and Producer's Accuracy (PA).
#' Default is FALSE.
#' @param ... Other arguments to be passed to \code{\link[dtwSat]{createPatterns}} and
#' to \code{\link[dtwSat]{twdtwApply}}.
#'
#' @details
#' \describe{
#' \item{\code{splitDataset}:}{This function splits a set of time
#' series into training and validation. The function uses stratified
#' sampling and a simple random sampling for each stratum. Each data partition
#' returned by this function has the temporal patterns and a set of time series for
#' validation.}
#' \item{\code{twdtwCrossValidation}:}{The function \code{splitDataset} performs the Cross-validation of
#' the classification based on the labels of the classified time series
#' (Reference) and the labels of the classification (Predicted). This function
#' returns a data.frame with User's and Producer's Accuracy or a list of confusion
#' matrices.}
#' }
#'
#' @seealso
#' \code{\link[dtwSat]{twdtwMatches-class}},
#' \code{\link[dtwSat]{twdtwApply}}, and
#' \code{\link[dtwSat]{twdtwClassify}}.
#' \code{\link[dtwSat]{createPatterns}}, and
#' \code{\link[dtwSat]{twdtwApply}}.
#'
#' @section Slots :
#' \describe{
#' \item{\code{partitions}:}{A list with the indices of time series used for training.}
#' \item{\code{accuracy}:}{A list with the accuracy and other TWDTW information for each
#' data partition.}
#' }
#'
#' @examples
#' \dontrun{
#' load(system.file("lucc_MT/field_samples_ts.RData", package="dtwSat"))
#' set.seed(1)
#' partitions = splitDataset(field_samples_ts, p=0.1, times=5,
#' freq = 8, formula = y ~ s(x, bs="cc"))
#' log_fun = logisticWeight(alpha=-0.1, beta=50)
#' twdtw_res = lapply(partitions, function(x){
#' res = twdtwApply(x = x$ts, y = x$patterns, weight.fun = log_fun, n=1)
#' twdtwClassify(x = res)
#' })
#' cross_validation = twdtwCrossValidation(twdtw_res)
# head(cross_validation, 5)
#'
#' }
NULL
setClass(
  # S4 container for cross-validation results:
  #   partitions - list with the indices of the time series used for training
  #   accuracy   - list with the accuracy results of each data partition
  Class = "twdtwCrossValidation",
  slots = c(partitions = "list", accuracy = "list"),
  validity = function(object){
    # The error messages name this class (they previously named
    # "twdtwTimeSeries", a copy-paste slip that pointed users at the wrong
    # class when validation failed).
    if(!is(object@partitions, "list")){
      stop("[twdtwCrossValidation: validation] Invalid partitions, class different from list.")
    }
    if(!is(object@accuracy, "list")){
      stop("[twdtwCrossValidation: validation] Invalid accuracy, class different from list.")
    }
    return(TRUE)
  }
)

setGeneric("splitDataset", function(object, times, p, ...) standardGeneric("splitDataset"))
setMethod(
  f = "initialize",
  signature = "twdtwCrossValidation",
  definition = function(.Object, partitions, accuracy){
    # Use the caller's value for each slot when supplied; otherwise fall back
    # to an empty placeholder with the expected element names.
    .Object@partitions = if(missing(partitions)){
      list(Resample1=NULL)
    }else{
      partitions
    }
    .Object@accuracy = if(missing(accuracy)){
      list(OverallAccuracy=NULL, UsersAccuracy=NULL, ProducersAccuracy=NULL,
           error.matrix=table(NULL), data=data.frame(NULL))
    }else{
      accuracy
    }
    # Run the class validity check before handing the object back.
    validObject(.Object)
    return(.Object)
  }
)

#' @rdname Cross-validation
#' @aliases splitDataset
setGeneric(
  # Generic entry point; dispatches on `object`, extra arguments travel in `...`.
  "twdtwCrossValidation",
  function(object, ...) standardGeneric("twdtwCrossValidation")
)

#' @inheritParams twdtwCrossValidation-class
#' @aliases twdtwCrossValidation-create
#'
#' @describeIn twdtwCrossValidation Splits the set of time
#' series into training and validation. The function uses stratified
#' sampling and a simple random sampling for each stratum. For each data partition
#' this function performs a TWDTW analysis and returns the Overall Accuracy,
#' User's Accuracy, Producer's Accuracy, error matrix (confusion matrix), and a
#' \code{\link[base]{data.frame}} with the classification (Predicted), the
#' reference classes (Reference), and some TWDTW information.
#'
#' @examples
#' \dontrun{
#' # Data folder
#' data_folder = system.file("lucc_MT/data", package = "dtwSat")
#'
#' # Read dates
#' dates = scan(paste(data_folder,"timeline", sep = "/"), what = "dates")
#'
#' # Read raster time series
#' evi = brick(paste(data_folder,"evi.tif", sep = "/"))
#' raster_timeseries = twdtwRaster(evi, timeline = dates)
#'
#' # Read field samples
#' field_samples = read.csv(paste(data_folder,"samples.csv", sep = "/"))
#' table(field_samples[["label"]])
#'
#' # Read field samples projection
#' proj_str = scan(paste(data_folder,"samples_projection", sep = "/"),
#' what = "character")
#'
#' # Get sample time series from raster time series
#' field_samples_ts = getTimeSeries(raster_timeseries,
#' y = field_samples, proj4string = proj_str)
#' field_samples_ts
#'
#' # Run cross validation
#' set.seed(1)
#' # Define TWDTW weight function
#' log_fun = logisticWeight(alpha=-0.1, beta=50)
#' cross_validation = twdtwCrossValidation(field_samples_ts, times=3, p=0.1,
#' freq = 8, formula = y ~ s(x, bs="cc"), weight.fun = log_fun)
#' cross_validation
#' }
#' @export
setMethod("splitDataset", "twdtwTimeSeries",
function(object, times=1, p=0.1, ...) splitDataset.twdtwTimeSeries(object, times=times, p=p, ...))
setMethod(
  f = "twdtwCrossValidation",
  # NOTE(review): no signature is supplied here; the documentation describes
  # `object` as a twdtwTimeSeries - confirm whether a signature is intended.
  definition = function(object, times, p, ...){
    # Thin wrapper: forward everything to the internal implementation.
    twdtwCrossValidation.twdtwTimeSeries(object, times, p, ...)
  }
)

splitDataset.twdtwTimeSeries = function(object, times, p, ...){
twdtwCrossValidation.twdtwTimeSeries = function(object, times, p, ...){

partitions = createDataPartition(y = labels(object), times, p, list = TRUE)

res = lapply(partitions, function(I){
training_ts = subset(object, I)
validation_ts = subset(object, -I)
# patt = createPatterns(training_ts, freq = 8, formula = y ~ s(x, bs="cc"))
patt = createPatterns(training_ts, ...)
list(patterns=patt, ts=validation_ts)
})

res
}

setGeneric("twdtwCrossValidation", function(object, matrix=FALSE) standardGeneric("twdtwCrossValidation"))

#' @rdname Cross-validation
#' @aliases twdtwCrossValidation
#' @export
setMethod("twdtwCrossValidation", "list",
function(object, matrix) twdtwCrossValidation.twdtwTimeSeries(object, matrix=matrix))

twdtwCrossValidation.twdtwTimeSeries = function(object, matrix){

res = lapply(object, function(x){
ref = labels(x)$timeseries
levels = sort(as.character(unique(ref)))
# twdtw_res = twdtwApply(x = validation_ts, y = patt, n=1, weight.fun = log_fun)
twdtw_res = twdtwApply(x = validation_ts, y = patt, n=1, ...)
ref = as.character(labels(twdtw_res)$timeseries)
levels = sort(unique(ref))
labels = levels
# pred = factor(do.call("rbind", x[])$label, levels, labels)
pred = do.call("rbind", lapply(x[], function(xx) as.character(xx$label[which.min(xx$distance)])) )
df = do.call("rbind", lapply(twdtw_res[], function(xx) xx[which.min(xx$distance),]) )
ref = factor(ref, levels, labels)
table(Reference=ref, Predicted=pred)
pred = factor(as.character(df$label), levels, labels)
data = data.frame(Reference=ref, Predicted=pred, df[,!names(df)%in%"labels"])
error.matrix = table(Reference=ref, Predicted=pred)
UA = diag(error.matrix) / colSums(error.matrix)
PA = diag(error.matrix) / rowSums(error.matrix)
O = sum(diag(error.matrix)) / sum(rowSums(error.matrix))
list(OverallAccuracy=O, UsersAccuracy=UA, ProducersAccuracy=PA, error.matrix=error.matrix, data=data)
})

if(!matrix){
res = do.call("rbind", lapply(seq_along(res), function(i){
x = res[[i]]
Users = diag(x) / rowSums(x)
Producers = diag(x) / colSums(x)
data.frame(resample=i,label=names(Users), UA = Users, PA = Producers, row.names=NULL)
}))
}

res
new("twdtwCrossValidation", partitions=partitions, accuracy=res)

}






72 changes: 51 additions & 21 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,55 @@ show.twdtwRaster = function(object){
invisible(NULL)
}

# Show objects of class twdtwCrossValidation
show.twdtwCrossValidation = function(object){
  # Summarise the object at a fixed 95% confidence level and round every
  # summary table to two digits before printing.
  rounded = lapply(summary(object, conf.int=.95), round, digits = 2)
  n_partitions = length(object@partitions)
  cat("An object of class \"twdtwCrossValidation\"\n")
  cat("Number of data partitions:", n_partitions, "\n")
  cat("Bootstrap simulation (CI .95)\n")
  print(rounded)
  # Printing is the only purpose; nothing useful to return.
  invisible(NULL)
}

#' @inheritParams twdtwCrossValidation-class
#' @rdname twdtwCrossValidation-class
#' @export
setMethod(
  "show",
  signature = signature(object = "twdtwCrossValidation"),
  definition = show.twdtwCrossValidation
)

summary.twdtwCrossValidation = function(object, conf.int=.95, ...){
  # Summarise the accuracy of all data partitions stored in `object@accuracy`.
  #
  # Arguments:
  #   object    object with an `accuracy` slot; each element is read for
  #             `OverallAccuracy`, `UsersAccuracy` and `ProducersAccuracy`.
  #   conf.int  confidence level passed on to `mean_cl_boot`.
  #   ...       further arguments passed on to `mean_cl_boot`.
  #
  # Returns a list of three data.frames (OverallAccuracy, UsersAccuracy,
  # ProducersAccuracy), each with the mean, `sd`, `CImin` and `CImax`. The
  # per-label tables have one row per label, with the label as row name.

  # Mean and confidence limits as a plain numeric vector (mean, lower, upper).
  # Flattening with unlist() avoids stripping names from the returned object
  # and works whether `mean_cl_boot` gives a data.frame or a named vector.
  boot_ci = function(v){
    unname(unlist(mean_cl_boot(x = v, conf.int = conf.int, ...)))
  }

  # One overall-accuracy value per data partition.
  ov = unlist(lapply(object@accuracy, function(x) x$OverallAccuracy), use.names = FALSE)

  # Long table with one row per (partition, label).
  uapa = do.call("rbind", lapply(object@accuracy, function(x){
    data.frame(label=names(x$UsersAccuracy), UA=x$UsersAccuracy, PA=x$ProducersAccuracy,
               row.names = NULL, stringsAsFactors = FALSE)
  }))

  # Labels are taken from the values themselves: `levels()` returns NULL for a
  # character column, which is what data.frame() creates by default since
  # R 4.0, and that left the per-label loops with nothing to iterate over.
  l_names = sort(unique(as.character(uapa$label)))
  names(l_names) = l_names

  # Build the per-label table for one accuracy column ("UA" or "PA"). The sd
  # and the confidence limits are computed from the same subset, so the rows
  # cannot get out of step with each other.
  assess_label = function(column, value_name){
    ci = t(vapply(l_names, function(i) boot_ci(uapa[[column]][uapa$label==i]), numeric(3)))
    sds = vapply(l_names, function(i) sd(uapa[[column]][uapa$label==i]), numeric(1))
    out = data.frame(ci[, 1], sds, ci[, 2], ci[, 3], row.names = l_names)
    names(out) = c(value_name, "sd", "CImin", "CImax")
    out
  }

  ic_ov = boot_ci(ov)
  assess_ov = data.frame(OverallAccuracy=ic_ov[1], sd=sd(ov), CImin=ic_ov[2], CImax=ic_ov[3])
  assess_ua = assess_label("UA", "UsersAccuracy")
  assess_pa = assess_label("PA", "ProducersAccuracy")

  list(OverallAccuracy=assess_ov, UsersAccuracy=assess_ua, ProducersAccuracy=assess_pa)
}

#' @inheritParams twdtwCrossValidation-class
#' @rdname twdtwCrossValidation-class
#' @export
setMethod(
  "summary",
  signature = signature(object = "twdtwCrossValidation"),
  definition = summary.twdtwCrossValidation
)

#' @inheritParams twdtwTimeSeries-class
#' @rdname twdtwTimeSeries-class
#' @export
Expand Down Expand Up @@ -479,25 +528,6 @@ setMethod("is.twdtwMatches", "ANY",
#' @export
setMethod("is.twdtwRaster", "ANY",
function(x) is(x, "twdtwRaster"))

# #' @aliases summary
# #' @inheritParams twdtwMatches-class
# #' @describeIn twdtwMatches Summary of objects of class twdtwMatches.
# #' @export
# setMethod("summary",
# signature(object = "twdtwMatches"),
# function(object, labels=NULL, ...){
# summary.twdtw(object, labels=labels, ...)
# }
# )
#
# summary.twdtw = function(object, ...){
# lapply(as.list(object), function(obj){
# res = lapply(labels(obj)$patterns, function(l){
# m = subset(obj, patterns.labels=l)[[1]]
# c(labels=l, N.Matches=nrow(m), summary(m$distance))
# })
# data.frame(res)
# })
# }



Loading

0 comments on commit 463e0e8

Please sign in to comment.