Skip to content

Commit

Permalink
Fixes error in create pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
vwmaus committed Aug 27, 2023
1 parent c5d1a85 commit ccd6b22
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 36 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import(stars)
import(twdtw)
importFrom(mgcv,gam)
importFrom(mgcv,predict.gam)
importFrom(mgcv,s)
importFrom(reshape2,melt)
importFrom(rlang,.data)
importFrom(scales,date_format)
importFrom(scales,percent)
importFrom(scales,pretty_breaks)
importFrom(stats,as.formula)
importFrom(stats,predict)
importFrom(stats,setNames)
52 changes: 28 additions & 24 deletions R/create_patterns.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@
#'
#' This function creates a pattern based on Generalized Additive Models (GAM).
#' It uses the specified formula to fit the model and predict values.
#' It operates on stars objects (for spatial-temporal data) and sf objects (for spatial data).
#'
#' @param x A stars object representing spatial-temporal data.
#' @param y An sf object representing spatial data with associated attributes.
#' @param formula A formula for the GAM.
#' @param x A three dimensions stars object (x, y, time) with the satellite image time series.
#' @param y An sf object with the coordinates of the training points.
#' @param formula A formula for the GAM. Default is \code{band ~ \link[mgcv]{s}(time)}.
#' @param start_column Name of the column in y that indicates the start date. Default is 'start_date'.
#' @param end_column Name of the column in y that indicates the end date. Default is 'end_date'.
#' @param label_colum Name of the column in y that contains labels. Default is 'label'.
#' @param sampling_freq The time frequency for sampling. If NULL, the function will infer it.
#' @param label_colum Name of the column in y that contains land use labels. Default is 'label'.
#' @param sampling_freq The time frequency for sampling including unit, e.g '16 day'. If NULL, the function will infer it.
#' @param ... Additional arguments passed to the GAM function.
#'
#' @return A list containing the predicted values for each label.
#'
#'
#'
#' @export
create_patterns = function(x, y, formula, start_column = 'start_date',
create_patterns = function(x, y, formula = band ~ s(time), start_column = 'start_date',
end_column = 'end_date', label_colum = 'label',
sampling_freq = NULL, ...){

Expand All @@ -39,6 +38,11 @@ create_patterns = function(x, y, formula, start_column = 'start_date',
stop(paste("Missing required columns in y:", paste(missing_columns, collapse = ", ")))
}

# Check if formula has two
if(length(all.vars(formula)) != 2) {
stop("The formula should have only one predictor!")
}

# Convert columns to date-time
y[ , start_column] <- to_date_time(y[[start_column]])
y[ , end_column] <- to_date_time(y[[end_column]])
Expand Down Expand Up @@ -66,29 +70,27 @@ create_patterns = function(x, y, formula, start_column = 'start_date',
sampling_freq <- get_stars_time_freq(x)
}

# Extract variables from formula
vars <- all.vars(formula)

# Define GAM function
gam_fun = function(y, t, formula_vars, ...){
df = data.frame(y, t = as.numeric(t))
names(df) <- formula_vars
fit = mgcv::gam(data = df, formula = as.formula(paste(formula_vars[1], "~", formula_vars[2])), ...)
pred_t = data.frame(t = as.numeric(seq(min(t), max(t), by = sampling_freq)))
gam_fun <- function(y, t, formula, ...){
df <- data.frame(y, t = as.numeric(t))
df <- setNames(list(y, as.numeric(t)), all.vars(formula))
fit <- mgcv::gam(data = df, formula = formula, ...)
pred_t <- setNames(list(as.numeric(seq(min(t), max(t), by = sampling_freq))), all.vars(formula)[2])
predict(fit, newdata = pred_t)
}

# Apply GAM function
res <- lapply(y_ts, function(ts){
patterns <- lapply(y_ts, function(ts){
y_time <- ts$time
ts$time <- NULL
ts$id <- NULL
sapply(as.list(ts), function(y) {
gam_fun(y, y_time, vars)
gam_fun(y, y_time, formula, ...)
})
})

return(res)
return(patterns)

}


Expand Down Expand Up @@ -119,24 +121,26 @@ extract_time_series <- function(x, y) {
#' This function calculates the most common difference between consecutive time points in a stars object.
#' This can be useful for determining the sampling frequency of the time series data.
#'
#' @param stars_object A stars object containing time series data.
#' @param x A stars object containing time series data.
#'
#' @return A difftime object representing the most common time difference between consecutive samples.
#'
#'
get_stars_time_freq <- function(stars_object) {
get_stars_time_freq <- function(x) {

# Extract the time dimension
time_values <- st_get_dimension_values(stars_object, "time")
time_values <- st_get_dimension_values(x, "time")

# Compute the differences between consecutive time points
time_diffs <- diff(time_values)

# Convert differences to days
# Convert differences to days (while retaining the difftime class)
time_diffs <- as.difftime(time_diffs, units = "days")

# Compute the most common difference (mode)
freq <- unique(time_diffs)[which.max(tabulate(match(time_diffs, unique(time_diffs))))]
# Identify the mode
mode_val_index <- which.max(tabulate(match(time_diffs, unique(time_diffs))))
freq <- diff(time_values[mode_val_index:(mode_val_index+1)])

return(freq)
}

4 changes: 2 additions & 2 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#' @import sf
#' @import stars
#' @import ggplot2
#' @importFrom stats as.formula predict
#' @importFrom mgcv gam predict.gam
#' @importFrom stats as.formula predict setNames
#' @importFrom mgcv gam s predict.gam
#' @importFrom scales pretty_breaks date_format percent
#' @importFrom reshape2 melt
#' @importFrom rlang .data
Expand Down
13 changes: 6 additions & 7 deletions man/create_patterns.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/get_stars_time_freq.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-twdtw_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dc <- read_stars(tif_files, proxy = FALSE, along = list(time = acquisition_date)
dc <- dc[c("EVI", "NDVI", "RED", "BLUE", "NIR", "MIR")]

# Get temporal patters
ts_patterns <- create_patterns(x = dc, y = samples, formula = band ~ s(t), sampling_freq = 23)
ts_patterns <- create_patterns(x = dc, y = samples)

# Visualize patterns
plot_patterns(ts_patterns)

0 comments on commit ccd6b22

Please sign in to comment.