Skip to content

Commit

Permalink
Rename class. Fixes stars conversin bug. Improves plot fucntion.
Browse files Browse the repository at this point in the history
  • Loading branch information
vwmaus committed Sep 18, 2023
1 parent 8ec86c7 commit ad5ed5b
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 99 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
Collate:
'plot_patterns.R'
'plot.R'
'predict.R'
'prepare_time_series.R'
'shift_dates.R'
Expand Down
6 changes: 3 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,knn1_twdtw)
S3method(predict,knn1_twdtw)
export(knn1_twdtw)
S3method(plot,twdtw_knn1)
S3method(predict,twdtw_knn1)
export(shift_dates)
export(twdtw_knn1)
import(ggplot2)
import(sf)
import(stars)
Expand Down
48 changes: 48 additions & 0 deletions R/plot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#' Plot Patterns from twdtw-knn1 model
#'
#' This function visualizes time series patterns from the \code{"twdtw_knn1"} model.
#' It produces a multi-faceted plot, where each facet represents a different time series
#' label from the model's data. Within each facet, different bands or indices (attributes)
#' are plotted as distinct lines, differentiated by color.
#'
#' @param x A model of class \code{"twdtw_knn1"}.
#'
#' @param bands A character vector specifying the bands or indices to plot.
#' If NULL (default), all available bands or indices in the data will be plotted.
#'
#' @param ... Additional arguments passed to \code{\link[ggplot2]{ggplot}}. Currently not used.
#'
#' @return A \code{\link[ggplot2]{ggplot}} object displaying the time series patterns.
#'
#' @seealso twdtw_knn1
#'
#' @inherit twdtw_knn1 examples
#'
#' @export
plot.twdtw_knn1 <- function(x, bands = NULL, ...) {

# Convert the list of time series data into a long-format data.frame
df <- x$data
df$id <- 1:nrow(df)
df <- unnest(df, cols = 'observations')

# Select bands
if(!is.null(bands)){
df <- df[c('id', 'time', 'label', bands)]
}

# Pivote data into long format for ggplot2
df <- pivot_longer(df, !c('id', 'label', 'time'), names_to = "band", values_to = "value")

# Construct the ggplot
gp <- ggplot(df, aes(x = .data$time, y = .data$value, colour = .data$band, group = interaction(.data$id, .data$band))) +
geom_line() +
facet_wrap(~label) +
theme(legend.position = "bottom") +
guides(colour = guide_legend(title = "Bands")) +
ylab("Value") +
xlab("Time")

return(gp)

}
47 changes: 0 additions & 47 deletions R/plot_patterns.R

This file was deleted.

14 changes: 7 additions & 7 deletions R/predict.R
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
#' Predict using the knn1_twdtw model
#' Predict using the twdtw_knn1 model
#'
#' This function predicts the classes of new data using the Time Warped Dynamic Time Warping (TWDTW)
#' method with a 1-nearest neighbor approach. The prediction is based on the minimum TWDTW distance
#' to the known patterns stored in the `knn1_twdtw` model.
#' to the known patterns stored in the `twdtw_knn1` model.
#'
#' @param object A `knn1_twdtw` model object generated by the `knn1_twdtw` function.
#' @param object A `twdtw_knn1` model object generated by the `twdtw_knn1` function.
#' @param newdata A data frame or similar object containing the new observations
#' (time series data) to be predicted.
#' @param ... Additional arguments passed to the \link[twdtw]{twdtw} function.
#'
#' @return A vector of predicted classes for the `newdata`.
#'
#' @seealso knn1_twdtw
#' @seealso twdtw_knn1
#'
#' @inherit knn1_twdtw examples
#' @inherit twdtw_knn1 examples
#'
#' @export
predict.knn1_twdtw <- function(object, newdata, ...){
predict.twdtw_knn1 <- function(object, newdata, ...){


# Convert newdata to time series
Expand All @@ -25,7 +25,7 @@ predict.knn1_twdtw <- function(object, newdata, ...){
# Compute TWDTW distances
distances <- sapply(object$data$observations, function(pattern){
sapply(newdata_ts$observations, function(ts) {
proxy::dist(x = as.data.frame(ts), y = as.data.frame(pattern), method = "twdtw", ...)
proxy::dist(x = as.data.frame(ts), y = as.data.frame(pattern), method = 'twdtw', ...)
})
})

Expand Down
2 changes: 2 additions & 0 deletions R/prepare_time_series.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ prepare_time_series <- function(x) {

# Remove the 'geom' column if it exists
x$geom <- NULL
x$x <- NULL
x$y <- NULL
var_names <- names(x)
var_names <- var_names[!var_names %in% 'label']

Expand Down
10 changes: 6 additions & 4 deletions R/train.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
#' then samples of the same label (land cover class) will be resampled using GAM.
#' Resampling can significantly reduce prediction processing time.
#'
#' @return A 'knn1_twdtw' model containing the trained model information and the data used.
#' @return A 'twdtw_knn1' model containing the trained model information and the data used.
#'
#' @examples
#' \dontrun{
#'
#' # Read training samples
#' samples <-
#' system.file("mato_grosso_brazil/samples.gpkg", package = "dtwSat") |>
Expand All @@ -46,7 +47,7 @@
#' split(c("time"))
#'
#' # Create a knn1-twdtw model
#' m <- knn1_twdtw(x = dc,
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' formula = band ~ s(time))
#'
Expand All @@ -62,9 +63,10 @@
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#' )
#'
#' }
#' @export
knn1_twdtw <- function(x, y, formula = NULL, start_column = 'start_date',
twdtw_knn1 <- function(x, y, formula = NULL, start_column = 'start_date',
end_column = 'end_date', label_colum = 'label',
sampling_freq = NULL, ...){

Expand Down Expand Up @@ -138,7 +140,7 @@ knn1_twdtw <- function(x, y, formula = NULL, start_column = 'start_date',
model$call <- match.call()
model$formula <- formula
model$data <- ts_data
class(model) <- "knn1_twdtw"
class(model) <- "twdtw_knn1"

return(model)

Expand Down
38 changes: 17 additions & 21 deletions man/plot.knn1_twdtw.Rd → man/plot.twdtw_knn1.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions man/predict.knn1_twdtw.Rd → man/predict.twdtw_knn1.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions man/knn1_twdtw.Rd → man/twdtw_knn1.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ad5ed5b

Please sign in to comment.