-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds knn-1-twdtw train and predict methods for stars objects
- Loading branch information
Showing
20 changed files
with
464 additions
and
310 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,43 @@ | ||
#' Plot Patterns from Time Series Data | ||
#' | ||
#' This function takes a list of time series data and creates a multi-faceted plot | ||
#' where each facet corresponds to a different time series from the list. | ||
#' Within each facet, different attributes (columns of the time series) are | ||
#' plotted as lines with different colors. | ||
#' This function visualizes time series data patterns from the \code{"knn1_twdtw"} model. | ||
#' It produces a multi-faceted plot, | ||
#' where each facet represents a different time series pattern from the model's data. | ||
#' Within each facet, different attributes (e.g., bands or indices) are plotted as | ||
#' distinct lines, differentiated by color. | ||
#' | ||
#' @param x A list where each element is a data.frame representing a time series. | ||
#' Each data.frame should have the same number of rows and columns, | ||
#' with columns representing different attributes (e.g., bands or indices) | ||
#' and rows representing time points. | ||
#' The name of each element in the list will be used as the facet title. | ||
#' @param x A model of class \code{"knn1_twdtw"}. | ||
#' | ||
#' @param ... Not used. | ||
#' @param n An integer specifying the number of patterns to plot. Default is 12. | ||
#' | ||
#' @return A ggplot object displaying the time series patterns. | ||
#' @param ... Additional arguments. Currently not used. | ||
#' | ||
#' @details | ||
#' The function is designed for visual inspection of time series patterns from the model. | ||
#' Each data frame in the model's data represents a time series. Columns within the data frame | ||
#' correspond to different attributes (e.g., bands or indices), while rows represent time points. | ||
#' The facet title is derived from the name of each time series in the model's data. | ||
#' | ||
#' @return A \code{\link[ggplot2]{ggplot}} object displaying the time series patterns. | ||
#' | ||
#' @export | ||
plot_patterns = function(x, ...) { | ||
plot.knn1_twdtw <- function(x, n = 12, ...) { | ||
|
||
# Convert the list of time series data into a long-format data.frame | ||
df.p = do.call("rbind", lapply(names(x), function(p) { | ||
ts = x[[p]] | ||
# Create a new data.frame with a 'Time' column and a 'Pattern' column | ||
# representing the name of the current time series (facet name). | ||
data.frame(Time = 1:nrow(ts), ts, Pattern = p) # Assuming the time series are evenly spaced | ||
})) | ||
|
||
df <- unnest(x$data[1:n, ], cols = .data$observations) | ||
|
||
# Melt the data into long format suitable for ggplot2 | ||
df.p = melt(df.p, id.vars = c("Time", "Pattern")) | ||
df <- pivot_longer(df, !c(.data$label, .data$time), names_to = "band", values_to = "value") | ||
|
||
# Construct the ggplot | ||
gp = ggplot(df.p, aes(x = .data$Time, y = .data$value, colour = .data$variable)) + | ||
geom_line() + | ||
facet_wrap(~Pattern) + | ||
theme(legend.position = "bottom") + | ||
guides(colour = guide_legend(title = "Bands")) + | ||
ylab("Value") | ||
|
||
gp <- ggplot(df, aes(x = .data$time, y = .data$value, colour = .data$band)) + | ||
geom_line() + | ||
facet_wrap(~label) + | ||
theme(legend.position = "bottom") + | ||
guides(colour = guide_legend(title = "Bands")) + | ||
ylab("Value") + | ||
xlab("Time") | ||
|
||
return(gp) | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#' Predict using the knn1_twdtw model | ||
#' | ||
#' This function predicts the classes of new data using the Time Warped Dynamic Time Warping (TWDTW) | ||
#' method with a 1-nearest neighbor approach. The prediction is based on the minimum TWDTW distance | ||
#' to the known patterns stored in the `knn1_twdtw` model. | ||
#' | ||
#' @param object A `knn1_twdtw` model object generated by the `knn1_twdtw` function. | ||
#' @param newdata A data frame or similar object containing the new observations | ||
#' (time series data) to be predicted. | ||
#' @param ... Additional arguments passed to the \link[twdtw]{twdtw} function. | ||
#' | ||
#' @return A vector of predicted classes for the `newdata`. | ||
#' | ||
#' | ||
#' @export | ||
predict.knn1_twdtw <- function(object, newdata, ...){ | ||
|
||
|
||
# Convert newdata to time series | ||
newdata_ts <- prepare_time_series(newdata) | ||
|
||
# Compute TWDTW distances | ||
distances <- sapply(object$data$observations, function(pattern){ | ||
sapply(newdata_ts$observations, function(ts) { | ||
proxy::dist(x = as.data.frame(ts), y = as.data.frame(pattern), method = "twdtw", ...) | ||
}) | ||
}) | ||
|
||
# Find the nearest neighbor for each observation in newdata | ||
nearest_neighbor <- apply(distances, 1, which.min) | ||
|
||
# Return the predicted label for each observation based on the nearest neighbor | ||
return(factor(object$data$label[nearest_neighbor])) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#' Prepare a Time Series Tibble from a 2D stars Object with Bands and Time Attributes | ||
#' | ||
#' This function reshapes a data frame, which has been converted from a stars object, into a nested wide tibble format. | ||
#' The stars object conversion often results in columns named in formats like "band.YYYY.MM.DD", "XYYYY.MM.DD.band", or "YYYY.MM.DD.band". | ||
#' | ||
#' @param x A data frame derived from a stars object containing time series data in wide format. | ||
#' The column names should adhere to one of the following formats: "band.YYYY.MM.DD", "XYYYY.MM.DD.band", or "YYYY.MM.DD.band". | ||
#' | ||
#' @return A nested tibble in wide format. Each row of the tibble corresponds to a unique 'ts_id' that maintains the order from the original stars object. | ||
#' The nested structure contains observations (time series) for each 'ts_id', including the 'time' of each observation, and individual bands are presented as separate columns. | ||
#' | ||
#' | ||
prepare_time_series <- function(x) { | ||
|
||
# Remove the 'geom' column if it exists | ||
x$geom <- NULL | ||
var_names <- names(x) | ||
var_names <- var_names[!var_names %in% 'label'] | ||
|
||
# Extract date and band information from the column names | ||
date_band <- do.call(rbind, lapply(var_names, function(name) { | ||
|
||
# Replace any hyphen with periods for consistent processing | ||
name <- gsub("-", "\\.", name) | ||
|
||
# Extract date and band info based on different name patterns | ||
if (grepl("^.+\\.\\d{4}\\.\\d{2}\\.\\d{2}$", name)) { | ||
date_str <- gsub("^.*?(\\d{4}\\.\\d{2}\\.\\d{2})$", "\\1", name) | ||
band_str <- gsub("^(.+)\\.\\d{4}\\.\\d{2}\\.\\d{2}$", "\\1", name) | ||
} | ||
else if (grepl("^X\\d{4}\\.\\d{2}\\.\\d{2}\\..+$", name)) { | ||
date_str <- gsub("^X(\\d{4}\\.\\d{2}\\.\\d{2})\\..+$", "\\1", name) | ||
band_str <- gsub("^X\\d{4}\\.\\d{2}\\.\\d{2}\\.(.+)$", "\\1", name) | ||
} | ||
else if (grepl("^\\d{4}\\.\\d{2}\\.\\d{2}\\..+$", name)) { | ||
date_str <- gsub("^(\\d{4}\\.\\d{2}\\.\\d{2})\\..+$", "\\1", name) | ||
band_str <- gsub("^\\d{4}\\.\\d{2}\\.\\d{2}\\.(.+)$", "\\1", name) | ||
} | ||
else { | ||
stop(paste("Unrecognized format in:", name)) | ||
} | ||
|
||
# Convert the date string to Date format | ||
date <- to_date_time(gsub("\\.", "-", date_str)) | ||
|
||
return(data.frame(time = date, band = band_str)) | ||
})) | ||
|
||
# Construct tiem sereis | ||
ns <- nrow(x) | ||
x$ts_id <- 1:ns | ||
if (!'label' %in% names(x)) { | ||
x$label <- NA | ||
} | ||
x <- pivot_longer(x, !c(.data$ts_id, .data$label), names_to = "band_date", values_to = "value") | ||
x$band <- rep(date_band$band, ns) | ||
x$time <- rep(date_band$time, ns) | ||
x$band_date <- NULL | ||
result_df <- pivot_wider(x, id_cols = c(.data$ts_id, .data$label, .data$time), names_from = 'band', values_from = 'value') | ||
result_df <- nest(result_df, .by = c(.data$ts_id, .data$label), .key = "observations") | ||
|
||
return(result_df) | ||
|
||
} | ||
|
||
|
||
#### TO BE REMOVED: twdtw package will export this fucntion | ||
to_date_time <- function(x){ | ||
if (!inherits(x, c("Date", "POSIXt"))) { | ||
# check if all strings in the vector include hours, minutes, and seconds | ||
if (all(grepl("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", x))) { | ||
x <- try(as.POSIXct(x), silent = TRUE) | ||
} else { | ||
x <- try(as.Date(x), silent = TRUE) | ||
} | ||
if (inherits(x, "try-error")) { | ||
stop("Some elements of x could not be converted to a date or datetime format") | ||
} | ||
} | ||
return(x) | ||
} |
Oops, something went wrong.