Skip to content

Commit

Permalink
Adds knn-1-twdtw train and predict methods for stars objects
Browse files Browse the repository at this point in the history
  • Loading branch information
vwmaus committed Sep 7, 2023
1 parent d621cf7 commit 6329150
Show file tree
Hide file tree
Showing 20 changed files with 464 additions and 310 deletions.
7 changes: 3 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ Depends:
Imports:
mgcv,
stats,
scales,
reshape2,
rlang
tidyr,
rlang,
proxy
Suggests:
rbenchmark,
stringr,
testthat (>= 3.0.0)
Config/testthat/edition: 3
14 changes: 8 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Generated by roxygen2: do not edit by hand

export(create_patterns)
export(plot_patterns)
S3method(plot,knn1_twdtw)
S3method(predict,knn1_twdtw)
export(knn1_twdtw)
export(shift_dates)
import(ggplot2)
import(sf)
Expand All @@ -10,11 +11,12 @@ import(twdtw)
importFrom(mgcv,gam)
importFrom(mgcv,predict.gam)
importFrom(mgcv,s)
importFrom(reshape2,melt)
importFrom(proxy,dist)
importFrom(rlang,.data)
importFrom(scales,date_format)
importFrom(scales,percent)
importFrom(scales,pretty_breaks)
importFrom(stats,as.formula)
importFrom(stats,predict)
importFrom(stats,setNames)
importFrom(tidyr,nest)
importFrom(tidyr,pivot_longer)
importFrom(tidyr,pivot_wider)
importFrom(tidyr,unnest)
146 changes: 0 additions & 146 deletions R/create_patterns.R

This file was deleted.

62 changes: 32 additions & 30 deletions R/plot_patterns.R
Original file line number Diff line number Diff line change
@@ -1,41 +1,43 @@
#' Plot Patterns from Time Series Data
#'
#' This function takes a list of time series data and creates a multi-faceted plot
#' where each facet corresponds to a different time series from the list.
#' Within each facet, different attributes (columns of the time series) are
#' plotted as lines with different colors.
#' This function visualizes time series data patterns from the \code{"knn1_twdtw"} model.
#' It produces a multi-faceted plot,
#' where each facet represents a different time series pattern from the model's data.
#' Within each facet, different attributes (e.g., bands or indices) are plotted as
#' distinct lines, differentiated by color.
#'
#' @param x A list where each element is a data.frame representing a time series.
#' Each data.frame should have the same number of rows and columns,
#' with columns representing different attributes (e.g., bands or indices)
#' and rows representing time points.
#' The name of each element in the list will be used as the facet title.
#' @param x A model of class \code{"knn1_twdtw"}.
#'
#' @param ... Not used.
#' @param n An integer specifying the number of patterns to plot. Default is 12.
#'
#' @return A ggplot object displaying the time series patterns.
#' @param ... Additional arguments. Currently not used.
#'
#' @details
#' The function is designed for visual inspection of time series patterns from the model.
#' Each data frame in the model's data represents a time series. Columns within the data frame
#' correspond to different attributes (e.g., bands or indices), while rows represent time points.
#' The facet title is derived from the name of each time series in the model's data.
#'
#' @return A \code{\link[ggplot2]{ggplot}} object displaying the time series patterns.
#'
#' @export
# NOTE(review): this span looks like a diff view in which two versions of the
# plotting function are interleaved: a legacy `plot_patterns` body (working on
# `df.p` via `melt`) and a newer `plot.knn1_twdtw` S3 method (working on `df`
# via `unnest` / `pivot_longer`). As written, the second definition is nested
# inside the first. Confirm which lines belong to the current file before
# changing any code here.
plot_patterns = function(x, ...) {
plot.knn1_twdtw <- function(x, n = 12, ...) {

# Convert the list of time series data into a long-format data.frame
df.p = do.call("rbind", lapply(names(x), function(p) {
ts = x[[p]]
# Create a new data.frame with a 'Time' column and a 'Pattern' column
# representing the name of the current time series (facet name).
data.frame(Time = 1:nrow(ts), ts, Pattern = p) # Assuming the time series are evenly spaced
}))

# Take the first `n` rows of the model's `data` table and expand the nested
# `observations` column into one row per observation.
# NOTE(review): behaviour when `n` exceeds the number of rows in `x$data`
# is not handled here - confirm whether callers can pass such an `n`.
df <- unnest(x$data[1:n, ], cols = .data$observations)

# Melt the data into long format suitable for ggplot2
df.p = melt(df.p, id.vars = c("Time", "Pattern"))
# Reshape every column except `label` and `time` into ("band", "value") pairs.
df <- pivot_longer(df, !c(.data$label, .data$time), names_to = "band", values_to = "value")

# Construct the ggplot
gp = ggplot(df.p, aes(x = .data$Time, y = .data$value, colour = .data$variable)) +
geom_line() +
facet_wrap(~Pattern) +
theme(legend.position = "bottom") +
guides(colour = guide_legend(title = "Bands")) +
ylab("Value")

# One line per band, one facet per `label`; this assignment overwrites the
# `gp` object built from `df.p` above.
gp <- ggplot(df, aes(x = .data$time, y = .data$value, colour = .data$band)) +
geom_line() +
facet_wrap(~label) +
theme(legend.position = "bottom") +
guides(colour = guide_legend(title = "Bands")) +
ylab("Value") +
xlab("Time")

return(gp)
}

}
34 changes: 34 additions & 0 deletions R/predict.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#' Predict using the knn1_twdtw model
#'
#' Predicts the class of each time series in `newdata` with a 1-nearest
#' neighbour rule based on the Time-Weighted Dynamic Time Warping (TWDTW)
#' distance. Every new series is compared with every pattern stored in
#' `object$data$observations`, and it receives the label of the pattern with
#' the smallest distance.
#'
#' @param object A `knn1_twdtw` model object generated by the `knn1_twdtw`
#' function. Its `data` element must provide the columns `observations`
#' (list of time series) and `label`.
#' @param newdata A data frame or similar object containing the new
#' observations (time series data) to be predicted. It is converted with
#' `prepare_time_series()` before the distances are computed.
#' @param ... Additional arguments forwarded to `proxy::dist()` called with
#' `method = "twdtw"` (see \link[twdtw]{twdtw}).
#'
#' @return A factor with one predicted label per time series in `newdata`.
#' A series for which no distance could be computed (all distances `NA`)
#' gets `NA`.
#'
#' @export
predict.knn1_twdtw <- function(object, newdata, ...) {

  # Convert newdata to time series
  newdata_ts <- prepare_time_series(newdata)

  patterns <- object$data$observations
  new_series <- newdata_ts$observations
  n_new <- length(new_series)

  # Compute TWDTW distances: one column per pattern, one row per new series.
  # vapply() pins the shape of every intermediate result, and the explicit
  # matrix() call keeps two dimensions even when `newdata` holds a single
  # series (nested sapply() would collapse that case to a plain vector and
  # the row-wise search below would fail).
  distances <- vapply(patterns, function(pattern) {
    vapply(new_series, function(ts) {
      as.numeric(
        proxy::dist(x = as.data.frame(ts), y = as.data.frame(pattern), method = "twdtw", ...)
      )
    }, numeric(1))
  }, numeric(n_new))
  distances <- matrix(distances, nrow = n_new, ncol = length(patterns))

  # Find the nearest neighbor for each observation in newdata.
  # which.min() returns integer(0) when a row has no non-NA value; map that
  # to NA so the result stays an integer vector usable as an index.
  nearest_neighbor <- vapply(seq_len(n_new), function(i) {
    best <- which.min(distances[i, ])
    if (length(best) == 0L) NA_integer_ else best
  }, integer(1))

  # Return the predicted label for each observation based on the nearest neighbor
  return(factor(object$data$label[nearest_neighbor]))
}
81 changes: 81 additions & 0 deletions R/prepare_time_series.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#' Prepare a Time Series Tibble from a 2D stars Object with Bands and Time Attributes
#'
#' This function reshapes a data frame, which has been converted from a stars object, into a nested wide tibble format.
#' The stars object conversion often results in columns named in formats like "band.YYYY.MM.DD", "XYYYY.MM.DD.band", or "YYYY.MM.DD.band".
#' Hyphens in column names are treated like periods, so "band.YYYY-MM-DD" is accepted as well.
#'
#' @param x A data frame derived from a stars object containing time series data in wide format.
#' The column names should adhere to one of the following formats: "band.YYYY.MM.DD", "XYYYY.MM.DD.band", or "YYYY.MM.DD.band".
#' A column named 'geom' is dropped, and an optional column named 'label' is carried through unchanged.
#'
#' @return A nested tibble with one row per input row and the columns 'ts_id'
#' (row position in `x`), 'label' (`NA` when `x` has no 'label' column) and
#' 'observations'. Each element of 'observations' is a tibble holding the
#' 'time' of each observation and one column per band.
#'
#' @details
#' An error is raised when a column name does not match any supported format,
#' when the date part of a column name is not a valid date, or when `x` has
#' no time series columns.
#'
prepare_time_series <- function(x) {

  # Remove the 'geom' column if it exists
  x$geom <- NULL
  var_names <- names(x)
  var_names <- var_names[!var_names %in% "label"]

  if (length(var_names) == 0L) {
    stop("x has no time series columns to reshape")
  }

  # Replace any hyphen with periods for consistent processing
  clean_names <- gsub("-", ".", var_names, fixed = TRUE)

  # Classify every column name at once: date at the end ("band.YYYY.MM.DD")
  # or date at the start, optionally prefixed by "X" ("XYYYY.MM.DD.band").
  # The date-at-the-end layout takes precedence when both would match.
  date_rx <- "\\d{4}\\.\\d{2}\\.\\d{2}"
  date_last <- grepl(paste0("^.+\\.", date_rx, "$"), clean_names)
  date_first <- !date_last & grepl(paste0("^X?", date_rx, "\\..+$"), clean_names)

  unrecognized <- !(date_last | date_first)
  if (any(unrecognized)) {
    stop(paste("Unrecognized format in:", paste(clean_names[unrecognized], collapse = ", ")))
  }

  # Extract date and band info based on the detected name pattern
  date_str <- character(length(clean_names))
  band_str <- character(length(clean_names))
  date_str[date_last] <- sub(paste0("^.+\\.(", date_rx, ")$"), "\\1", clean_names[date_last])
  band_str[date_last] <- sub(paste0("^(.+)\\.", date_rx, "$"), "\\1", clean_names[date_last])
  date_str[date_first] <- sub(paste0("^X?(", date_rx, ")\\..+$"), "\\1", clean_names[date_first])
  band_str[date_first] <- sub(paste0("^X?", date_rx, "\\.(.+)$"), "\\1", clean_names[date_first])

  # Convert the date strings to Date format. A vectorised conversion turns
  # an invalid date that is not the first element into NA instead of raising
  # an error, so check for that explicitly.
  time <- to_date_time(gsub(".", "-", date_str, fixed = TRUE))
  if (anyNA(time)) {
    stop(paste("Invalid date in:", paste(clean_names[is.na(time)], collapse = ", ")))
  }

  # Construct time series
  ns <- nrow(x)
  x$ts_id <- seq_len(ns)
  if (!"label" %in% names(x)) {
    x$label <- rep(NA, ns)
  }

  # Column selections use all_of() on character vectors; `.data$col` inside
  # tidyselect arguments is deprecated and emits warnings.
  x <- pivot_longer(x, cols = tidyr::all_of(var_names), names_to = "band_date", values_to = "value")

  # Attach band and time by looking up each long-format row's original column
  # name rather than relying on the row order produced by pivot_longer().
  col_idx <- match(x$band_date, var_names)
  x$band <- band_str[col_idx]
  x$time <- time[col_idx]
  x$band_date <- NULL

  result_df <- pivot_wider(
    x,
    id_cols = tidyr::all_of(c("ts_id", "label", "time")),
    names_from = "band",
    values_from = "value"
  )
  result_df <- nest(result_df, .by = tidyr::all_of(c("ts_id", "label")), .key = "observations")

  return(result_df)

}


#### TO BE REMOVED: twdtw package will export this function
# Convert `x` to a date or date-time object.
#
# Objects that already inherit from "Date" or "POSIXt" are returned unchanged.
# Otherwise, if every element contains a "YYYY-MM-DD HH:MM:SS" pattern the
# input is converted with as.POSIXct(); in all other cases as.Date() is used.
#
# An error is raised when the conversion fails as a whole, or when any
# non-missing element of `x` turns into NA. as.Date() / as.POSIXct() only
# raise an error when the first non-NA element cannot be parsed; later
# unparseable elements silently become NA, so they are checked explicitly.
to_date_time <- function(x) {
  if (inherits(x, c("Date", "POSIXt"))) {
    return(x)
  }

  # check if all strings in the vector include hours, minutes, and seconds
  has_time <- all(grepl("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", x))

  converted <- tryCatch(
    if (has_time) as.POSIXct(x) else as.Date(x),
    error = function(e) NULL
  )

  if (is.null(converted) || any(is.na(converted) & !is.na(x))) {
    stop("Some elements of x could not be converted to a date or datetime format")
  }

  converted
}
Loading

0 comments on commit 6329150

Please sign in to comment.