Skip to content

Commit

Permalink
Meta methods (#125)
Browse files Browse the repository at this point in the history
* Implemented lcMetaMethod and lcMetaConverged (#61)
* Added more tests and support for setting consecutive seeds (#61)
* Renamed lcMetaConverged to lcFitConverged (#61)
* Implemented lcFitRep and variants(#61)
* Added convergence status in messaging (#61)

Small changes:
* Added check for formula argument to LMKM method
* Added converged slot to lcModelPartition
* Made getLcMethod() generic
* Generic validate() checks for correct output length
* Fixed doc check issues
* Added workaround for erroneous R CMD check rmarkdown import note
  • Loading branch information
niekdt authored Nov 7, 2022
1 parent 44ed93a commit de45672
Show file tree
Hide file tree
Showing 23 changed files with 743 additions and 12 deletions.
4 changes: 4 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ Collate:
'make.R'
'matrix.R'
'method.R'
'meta-method.R'
'meta-fit.R'
'meta-fit-converged.R'
'meta-fit-rep.R'
'methodMatrix.R'
'methodAKMedoids.R'
'methodCrimCV.R'
Expand Down
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ export(latrendBatch)
export(latrendBoot)
export(latrendCV)
export(latrendRep)
export(lcFitConverged)
export(lcFitRep)
export(lcFitRepMax)
export(lcFitRepMin)
export(lcMethodAkmedoids)
export(lcMethodCrimCV)
export(lcMethodDtwclust)
Expand Down Expand Up @@ -173,6 +177,9 @@ export(validate)
export(weighted.meanNA)
export(which.weight)
exportClasses(lcApproxModel)
exportClasses(lcFitConverged)
exportClasses(lcFitRep)
exportClasses(lcMetaMethod)
exportClasses(lcMethod)
exportClasses(lcModel)
exportMethods("$")
Expand All @@ -188,6 +195,7 @@ exportMethods(fittedTrajectories)
exportMethods(getArgumentDefaults)
exportMethods(getArgumentExclusions)
exportMethods(getLabel)
exportMethods(getLcMethod)
exportMethods(getName)
exportMethods(getShortName)
exportMethods(idVariable)
Expand All @@ -198,10 +206,13 @@ exportMethods(plot)
exportMethods(plotClusterTrajectories)
exportMethods(plotFittedTrajectories)
exportMethods(plotTrajectories)
exportMethods(postFit)
exportMethods(postprob)
exportMethods(preFit)
exportMethods(predictAssignments)
exportMethods(predictForCluster)
exportMethods(predictPostprob)
exportMethods(prepareData)
exportMethods(qqPlot)
exportMethods(responseVariable)
exportMethods(strip)
Expand Down Expand Up @@ -233,6 +244,7 @@ importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(matrixStats,rowLogSumExps)
importFrom(matrixStats,rowMaxs)
importFrom(rmarkdown,html_vignette)
importFrom(stats,AIC)
importFrom(stats,BIC)
importFrom(stats,approx)
Expand Down
23 changes: 23 additions & 0 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,19 @@ setGeneric('getLabel', function(object, ...) {
label
})

# getLcMethod ####
#' @export
#' @name latrend-generics
setGeneric('getLcMethod', function(object, ...) {
method <- standardGeneric('getLcMethod')

assert_that(
is.lcMethod(method)
)

method
})


# getName ####
#' @export
Expand Down Expand Up @@ -720,7 +733,17 @@ setGeneric('trajectoryAssignments', function(object, ...) {
setGeneric('validate', function(method, data, envir, ...) {
validationResult <- standardGeneric('validate')

assert_that(
length(validationResult) == 1L,
msg = sprintf(
'implementation error in validate(%s): output should be length 1',
class(method)[1]
)
)

if (!isTRUE(validationResult)) {
stop(sprintf('%s validation failed: %s', class(method)[1], validationResult))
}

validationResult
})
90 changes: 90 additions & 0 deletions R/meta-fit-converged.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#' @include meta-fit.R

#' @export
#' @rdname lcFitMethods
#' @examples
#'
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time", nClusters = 2)
#' metaMethod <- lcFitConverged(method, maxRep = 10)
#' metaMethod
#' model <- latrend(metaMethod, latrendData)
setClass('lcFitConverged', contains = 'lcMetaMethod')

#' @export
#' @rdname lcFitMethods
#' @param method The `lcMethod` to use for fitting.
#' @param maxRep The maximum number of fit attempts
lcFitConverged = function(method, maxRep = Inf) {
mc = match.call.all()
mc$method = getCall(method)
mc$Class = 'lcFitConverged'
do.call(new, as.list(mc))
}


#' @rdname interface-metaMethods
setMethod('fit', 'lcFitConverged', function(method, data, envir, verbose) {
attempt = 1L

repeat {
enter(verbose, level = verboseLevels$fine, suffix = '')
model = fit(getLcMethod(method), data = data, envir = envir, verbose = verbose)
exit(verbose, level = verboseLevels$fine, suffix = '')

if (converged(model)) {
return (model)
} else if (attempt >= method$maxRep) {
warning(
sprintf(
'Failed to obtain converged result (got %s) for %s within %d attempts.\n\tReturning last model.',
converged(model),
class(getLcMethod(method))[1],
method$maxRep
),
immediate. = TRUE
)
return (model)
} else {
attempt = attempt + 1L

if (has_lcMethod_args(getLcMethod(method), 'seed')) {
seed = sample.int(.Machine$integer.max, 1L)
set.seed(seed)
# update fit method with new seed
method@arguments$method = update(getLcMethod(method), seed = seed, .eval = TRUE)
}

if (is.infinite(method$maxRep)) {
cat(
verbose,
sprintf(
'Method failed to converge (got %gs). Retrying... attempt %d',
converged(model),
attempt
)
)
} else {
cat(
verbose,
sprintf(
'Method failed to converge (got %s). Retrying... attempt %d / %d',
converged(model),
attempt,
method$maxRep
)
)
}
}
}
})

#' @rdname interface-metaMethods
setMethod('validate', 'lcFitConverged', function(method, data, envir = NULL, ...) {
callNextMethod()

validate_that(
has_lcMethod_args(method, 'maxRep'),
is.count(method$maxRep)
)
})
97 changes: 97 additions & 0 deletions R/meta-fit-rep.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#' @include meta-fit.R

#' @export
#' @rdname lcFitMethods
#' @examples
#'
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time", nClusters = 2)
#' repMethod <- lcFitRep(method, rep = 10, metric = "RSS", maximize = FALSE)
#' repMethod
#' model <- latrend(repMethod, latrendData)
#'
#' minMethod <- lcFitRepMin(method, rep = 10, metric = "RSS")
#'
#' maxMethod <- lcFitRepMax(method, rep = 10, metric = "ASW")
setClass('lcFitRep', contains = 'lcMetaMethod')

#' @export
#' @rdname lcFitMethods
#' @param rep The number of fits
#' @param metric The internal metric to assess the fit.
#' @param maximize Whether to maximize the metric. Otherwise, it is minimized.
lcFitRep = function(method, rep = 10, metric, maximize) {
mc = match.call.all()
mc$method = getCall(method)
mc$Class = 'lcFitRep'
do.call(new, as.list(mc))
}

#' @export
#' @rdname lcFitMethods
lcFitRepMin = function(method, rep = 10, metric) {
mc = match.call.all()
mc$method = getCall(method)
mc$maximize = FALSE
mc$Class = 'lcFitRep'
do.call(new, as.list(mc))
}

#' @export
#' @rdname lcFitMethods
lcFitRepMax = function(method, rep = 10, metric) {
mc = match.call.all()
mc$method = getCall(method)
mc$maximize = TRUE
mc$Class = 'lcFitRep'
do.call(new, as.list(mc))
}


#' @rdname interface-metaMethods
setMethod('fit', 'lcFitRep', function(method, data, envir, verbose) {
bestModel = NULL
mult = ifelse(method$maximize, 1, -1)
bestScore = -Inf

for (i in seq_len(method$rep)) {
cat(verbose, sprintf('Repeated fitting %d / %d', i, method$rep))
enter(verbose, level = verboseLevels$fine, suffix = '')
newModel = fit(getLcMethod(method), data = data, envir = envir, verbose = verbose)
newScore = metric(newModel, method$metric)
exit(verbose, level = verboseLevels$fine, suffix = '')

if (is.finite(newScore) && newScore * mult > bestScore) {
cat(
verbose,
sprintf('Found improved fit for %s = %g (previous is %g)', method$metric, newScore, mult * bestScore),
level = verboseLevels$fine
)
bestModel = newModel
bestScore = newScore
}

if (has_lcMethod_args(getLcMethod(method), 'seed')) {
# update seed for the next run
seed = sample.int(.Machine$integer.max, 1L)
set.seed(seed)
# update fit method with new seed
method@arguments$method = update(getLcMethod(method), seed = seed, .eval = TRUE)
}
}

bestModel
})

#' @rdname interface-metaMethods
setMethod('validate', 'lcFitRep', function(method, data, envir = NULL, ...) {
callNextMethod()

validate_that(
has_lcMethod_args(method, c('rep', 'metric', 'maximize')),
is.count(method$rep),
is.string(method$metric),
method$metric %in% getInternalMetricNames(),
is.flag(method$maximize)
)
})
12 changes: 12 additions & 0 deletions R/meta-fit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#' @include meta-method.R

#' @name lcFitMethods
#' @rdname lcFitMethods
#' @title Method fit modifiers
#' @description A collection of special methods that adapt the fitting procedure of the underlying longitudinal cluster method.
#' Supported fit methods:
#' * `lcFitConverged`: Fit a method until a converged result is obtained.
#' * `lcFitRep`: Repeatedly fit a method and return the best result based on a given internal metric.
#' * `lcFitRepMin`: Repeatedly fit a method and return the best result that minimizes the given internal metric.
#' * `lcFitRepMax`: Repeatedly fit a method and return the best result that maximizes the given internal metric.
NULL
84 changes: 84 additions & 0 deletions R/meta-method.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#' @include method.R

#' @export
#' @name interface-metaMethods
#' @rdname interface-metaMethods
#' @inheritParams lcMethod-class
#' @inheritParams getLcMethod
#' @inheritParams compose
#' @inheritParams preFit
#' @inheritParams postFit
#' @inheritParams prepareData
#' @inheritParams validate
#' @aliases lcMetaMethod-class
#' @title lcMetaMethod abstract class
#' @description Virtual class for internal use. Do not use.
setClass(
'lcMetaMethod',
contains = c('lcMethod', 'VIRTUAL')
)

as.character.lcMetaMethod = function(x, ...) {
c(
sprintf('%s encapsulating:', class(x)[1]),
paste0(' ', as.character(getLcMethod(x), ...)),
' with meta-method arguments:',
paste0(' ', tail(as.character.lcMethod(x), -2L))
)
}

#' @export
#' @rdname interface-metaMethods
setMethod('compose', 'lcMetaMethod', function(method, envir = NULL) {
newMethod = method
newMethod@arguments$method = evaluate.lcMethod(getLcMethod(method), try = FALSE, envir = envir)
newMethod
})

#' @export
#' @rdname interface-metaMethods
setMethod('getLcMethod', 'lcMetaMethod', function(object, ...) object$method)

#' @export
#' @rdname interface-metaMethods
setMethod('getName', 'lcMetaMethod', function(object, ...) getName(getLcMethod(object), ...))

#' @export
#' @rdname interface-metaMethods
setMethod('getShortName', 'lcMetaMethod', function(object, ...) getShortName(getLcMethod(object), ...))

#' @export
#' @rdname interface-metaMethods
setMethod('idVariable', 'lcMetaMethod', function(object, ...) idVariable(getLcMethod(object), ...))

#' @export
#' @rdname interface-metaMethods
setMethod('preFit', 'lcMetaMethod', function(method, data, envir, verbose) {
preFit(getLcMethod(method), data = data, envir = envir, verbose = verbose)
})

#' @export
#' @rdname interface-metaMethods
setMethod('prepareData', 'lcMetaMethod', function(method, data, verbose) {
prepareData(getLcMethod(method), data = data, verbose = verbose)
})

#' @export
#' @rdname interface-metaMethods
setMethod('postFit', 'lcMetaMethod', function(method, data, model, envir, verbose) {
postFit(getLcMethod(method), data = data, model = model, envir = envir, verbose = verbose)
})

#' @export
#' @rdname interface-metaMethods
setMethod('responseVariable', 'lcMetaMethod', function(object, ...) responseVariable(getLcMethod(object), ...))

#' @export
#' @rdname interface-metaMethods
setMethod('timeVariable', 'lcMetaMethod', function(object, ...) timeVariable(getLcMethod(object), ...))

#' @export
#' @rdname interface-metaMethods
setMethod('validate', 'lcMetaMethod', function(method, data, envir = NULL, ...) {
validate(getLcMethod(method), data = data, envir = envir, ...)
})
9 changes: 9 additions & 0 deletions R/methodLMKM.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,12 @@ setMethod('fit', signature('lcMethodLMKM'), function(method, data, envir, verbos
clusterNames = make.clusterNames(method$nClusters)
)
})

#' @rdname interface-featureBased
setMethod('validate', 'lcMethodLMKM', function(method, data, envir = NULL, ...) {
callNextMethod()

validate_that(
has_lcMethod_args(method, 'formula')
)
})
Loading

0 comments on commit de45672

Please sign in to comment.