-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implemented lcMetaMethod and lcMetaConverged (#61) * Added more tests and support for setting consecutive seeds (#61) * Renamed lcMetaConverged to lcFitConverged (#61) * Implemented lcFitRep and variants(#61) * Added convergence status in messaging (#61) Small changes: * Added check for formula argument to LMKM method * Added converged slot to lcModelPartition * Made getLcMethod() generic * Generic validate() checks for correct output length * Fixed doc check issues * Added workaround for erroneous R CMD check rmarkdown import note
- Loading branch information
Showing
23 changed files
with
743 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#' @include meta-fit.R | ||
|
||
#' @export | ||
#' @rdname lcFitMethods | ||
#' @examples | ||
#' | ||
#' data(latrendData) | ||
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time", nClusters = 2) | ||
#' metaMethod <- lcFitConverged(method, maxRep = 10) | ||
#' metaMethod | ||
#' model <- latrend(metaMethod, latrendData) | ||
setClass('lcFitConverged', contains = 'lcMetaMethod') | ||
|
||
#' @export | ||
#' @rdname lcFitMethods | ||
#' @param method The `lcMethod` to use for fitting. | ||
#' @param maxRep The maximum number of fit attempts | ||
lcFitConverged = function(method, maxRep = Inf) { | ||
mc = match.call.all() | ||
mc$method = getCall(method) | ||
mc$Class = 'lcFitConverged' | ||
do.call(new, as.list(mc)) | ||
} | ||
|
||
|
||
#' @rdname interface-metaMethods | ||
setMethod('fit', 'lcFitConverged', function(method, data, envir, verbose) { | ||
attempt = 1L | ||
|
||
repeat { | ||
enter(verbose, level = verboseLevels$fine, suffix = '') | ||
model = fit(getLcMethod(method), data = data, envir = envir, verbose = verbose) | ||
exit(verbose, level = verboseLevels$fine, suffix = '') | ||
|
||
if (converged(model)) { | ||
return (model) | ||
} else if (attempt >= method$maxRep) { | ||
warning( | ||
sprintf( | ||
'Failed to obtain converged result (got %s) for %s within %d attempts.\n\tReturning last model.', | ||
converged(model), | ||
class(getLcMethod(method))[1], | ||
method$maxRep | ||
), | ||
immediate. = TRUE | ||
) | ||
return (model) | ||
} else { | ||
attempt = attempt + 1L | ||
|
||
if (has_lcMethod_args(getLcMethod(method), 'seed')) { | ||
seed = sample.int(.Machine$integer.max, 1L) | ||
set.seed(seed) | ||
# update fit method with new seed | ||
method@arguments$method = update(getLcMethod(method), seed = seed, .eval = TRUE) | ||
} | ||
|
||
if (is.infinite(method$maxRep)) { | ||
cat( | ||
verbose, | ||
sprintf( | ||
'Method failed to converge (got %gs). Retrying... attempt %d', | ||
converged(model), | ||
attempt | ||
) | ||
) | ||
} else { | ||
cat( | ||
verbose, | ||
sprintf( | ||
'Method failed to converge (got %s). Retrying... attempt %d / %d', | ||
converged(model), | ||
attempt, | ||
method$maxRep | ||
) | ||
) | ||
} | ||
} | ||
} | ||
}) | ||
|
||
#' @rdname interface-metaMethods | ||
setMethod('validate', 'lcFitConverged', function(method, data, envir = NULL, ...) { | ||
callNextMethod() | ||
|
||
validate_that( | ||
has_lcMethod_args(method, 'maxRep'), | ||
is.count(method$maxRep) | ||
) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#' @include meta-fit.R | ||
|
||
#' @export | ||
#' @rdname lcFitMethods | ||
#' @examples | ||
#' | ||
#' data(latrendData) | ||
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time", nClusters = 2) | ||
#' repMethod <- lcFitRep(method, rep = 10, metric = "RSS", maximize = FALSE) | ||
#' repMethod | ||
#' model <- latrend(repMethod, latrendData) | ||
#' | ||
#' minMethod <- lcFitRepMin(method, rep = 10, metric = "RSS") | ||
#' | ||
#' maxMethod <- lcFitRepMax(method, rep = 10, metric = "ASW") | ||
setClass('lcFitRep', contains = 'lcMetaMethod') | ||
|
||
#' @export | ||
#' @rdname lcFitMethods | ||
#' @param rep The number of fits | ||
#' @param metric The internal metric to assess the fit. | ||
#' @param maximize Whether to maximize the metric. Otherwise, it is minimized. | ||
lcFitRep = function(method, rep = 10, metric, maximize) { | ||
mc = match.call.all() | ||
mc$method = getCall(method) | ||
mc$Class = 'lcFitRep' | ||
do.call(new, as.list(mc)) | ||
} | ||
|
||
#' @export | ||
#' @rdname lcFitMethods | ||
lcFitRepMin = function(method, rep = 10, metric) { | ||
mc = match.call.all() | ||
mc$method = getCall(method) | ||
mc$maximize = FALSE | ||
mc$Class = 'lcFitRep' | ||
do.call(new, as.list(mc)) | ||
} | ||
|
||
#' @export | ||
#' @rdname lcFitMethods | ||
lcFitRepMax = function(method, rep = 10, metric) { | ||
mc = match.call.all() | ||
mc$method = getCall(method) | ||
mc$maximize = TRUE | ||
mc$Class = 'lcFitRep' | ||
do.call(new, as.list(mc)) | ||
} | ||
|
||
|
||
#' @rdname interface-metaMethods | ||
setMethod('fit', 'lcFitRep', function(method, data, envir, verbose) { | ||
bestModel = NULL | ||
mult = ifelse(method$maximize, 1, -1) | ||
bestScore = -Inf | ||
|
||
for (i in seq_len(method$rep)) { | ||
cat(verbose, sprintf('Repeated fitting %d / %d', i, method$rep)) | ||
enter(verbose, level = verboseLevels$fine, suffix = '') | ||
newModel = fit(getLcMethod(method), data = data, envir = envir, verbose = verbose) | ||
newScore = metric(newModel, method$metric) | ||
exit(verbose, level = verboseLevels$fine, suffix = '') | ||
|
||
if (is.finite(newScore) && newScore * mult > bestScore) { | ||
cat( | ||
verbose, | ||
sprintf('Found improved fit for %s = %g (previous is %g)', method$metric, newScore, mult * bestScore), | ||
level = verboseLevels$fine | ||
) | ||
bestModel = newModel | ||
bestScore = newScore | ||
} | ||
|
||
if (has_lcMethod_args(getLcMethod(method), 'seed')) { | ||
# update seed for the next run | ||
seed = sample.int(.Machine$integer.max, 1L) | ||
set.seed(seed) | ||
# update fit method with new seed | ||
method@arguments$method = update(getLcMethod(method), seed = seed, .eval = TRUE) | ||
} | ||
} | ||
|
||
bestModel | ||
}) | ||
|
||
#' @rdname interface-metaMethods | ||
setMethod('validate', 'lcFitRep', function(method, data, envir = NULL, ...) { | ||
callNextMethod() | ||
|
||
validate_that( | ||
has_lcMethod_args(method, c('rep', 'metric', 'maximize')), | ||
is.count(method$rep), | ||
is.string(method$metric), | ||
method$metric %in% getInternalMetricNames(), | ||
is.flag(method$maximize) | ||
) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#' @include meta-method.R | ||
|
||
#' @name lcFitMethods | ||
#' @rdname lcFitMethods | ||
#' @title Method fit modifiers | ||
#' @description A collection of special methods that adapt the fitting procedure of the underlying longitudinal cluster method. | ||
#' Supported fit methods: | ||
#' * `lcFitConverged`: Fit a method until a converged result is obtained. | ||
#' * `lcFitRep`: Repeatedly fit a method and return the best result based on a given internal metric. | ||
#' * `lcFitRepMin`: Repeatedly fit a method and return the best result that minimizes the given internal metric. | ||
#' * `lcFitRepMax`: Repeatedly fit a method and return the best result that maximizes the given internal metric. | ||
NULL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#' @include method.R | ||
|
||
#' @export | ||
#' @name interface-metaMethods | ||
#' @rdname interface-metaMethods | ||
#' @inheritParams lcMethod-class | ||
#' @inheritParams getLcMethod | ||
#' @inheritParams compose | ||
#' @inheritParams preFit | ||
#' @inheritParams postFit | ||
#' @inheritParams prepareData | ||
#' @inheritParams validate | ||
#' @aliases lcMetaMethod-class | ||
#' @title lcMetaMethod abstract class | ||
#' @description Virtual class for internal use. Do not use. | ||
setClass( | ||
'lcMetaMethod', | ||
contains = c('lcMethod', 'VIRTUAL') | ||
) | ||
|
||
as.character.lcMetaMethod = function(x, ...) { | ||
c( | ||
sprintf('%s encapsulating:', class(x)[1]), | ||
paste0(' ', as.character(getLcMethod(x), ...)), | ||
' with meta-method arguments:', | ||
paste0(' ', tail(as.character.lcMethod(x), -2L)) | ||
) | ||
} | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('compose', 'lcMetaMethod', function(method, envir = NULL) { | ||
newMethod = method | ||
newMethod@arguments$method = evaluate.lcMethod(getLcMethod(method), try = FALSE, envir = envir) | ||
newMethod | ||
}) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('getLcMethod', 'lcMetaMethod', function(object, ...) object$method) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('getName', 'lcMetaMethod', function(object, ...) getName(getLcMethod(object), ...)) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('getShortName', 'lcMetaMethod', function(object, ...) getShortName(getLcMethod(object), ...)) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('idVariable', 'lcMetaMethod', function(object, ...) idVariable(getLcMethod(object), ...)) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('preFit', 'lcMetaMethod', function(method, data, envir, verbose) { | ||
preFit(getLcMethod(method), data = data, envir = envir, verbose = verbose) | ||
}) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('prepareData', 'lcMetaMethod', function(method, data, verbose) { | ||
prepareData(getLcMethod(method), data = data, verbose = verbose) | ||
}) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('postFit', 'lcMetaMethod', function(method, data, model, envir, verbose) { | ||
postFit(getLcMethod(method), data = data, model = model, envir = envir, verbose = verbose) | ||
}) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('responseVariable', 'lcMetaMethod', function(object, ...) responseVariable(getLcMethod(object), ...)) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('timeVariable', 'lcMetaMethod', function(object, ...) timeVariable(getLcMethod(object), ...)) | ||
|
||
#' @export | ||
#' @rdname interface-metaMethods | ||
setMethod('validate', 'lcMetaMethod', function(method, data, envir = NULL, ...) { | ||
validate(getLcMethod(method), data = data, envir = envir, ...) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.