#' Estimate the reproduction number of an epidemic
#'
#' @description A wrapper function for \code{\link[EpiEstim]{estimate_R}} from the \code{EpiEstim} library to estimate the reproduction number of epidemics to support short-term forecasts
#'
#'
#' @details \code{fit_epiestim_model} currently supports the following epidemics: Influenza, RSV and COVID-19. The default serial intervals for the estimation of R were retrieved from
#' Cowling et al., 2011, Vink et al., 2014 and Madewell et al., 2023 for Influenza A, Influenza B, RSV and COVID (BA.5 Omicron variant) respectively
#'
#' For every \code{type} other than \code{"custom"} the built-in serial
#' interval and prior values are always used. If any of \code{mean_si},
#' \code{std_si}, \code{mean_prior} or \code{std_prior} is supplied together
#' with a non-custom \code{type}, a warning is raised and the supplied values
#' are ignored.
#'
#' R is estimated on sliding windows of \code{window_size} time steps. The
#' first window starts at the second observation, because the estimate is
#' conditional on past observations.
#'
#'
#' @param data *data frame* containing exactly two columns: date and confirm (number of cases)
#' @param window_size *Integer* Length of the sliding windows used for R estimates.
#' @param type *character* Specifies type of epidemic. Must be one of "flu_a", "flu_b", "rsv", "sars_cov2" or "custom"
#' @param mean_si *Numeric* User specification of mean of parametric serial interval. Only honoured when \code{type = "custom"}
#' @param std_si *Numeric* User specification of standard deviation of parametric serial interval. Only honoured when \code{type = "custom"}
#' @param recon_opt Not implemented (currently unused by this function). One of "naive" or "match" to pass on to \code{\link[EpiEstim]{estimate_R}} (see help page)
#' @param method One of "non_parametric_si", "parametric_si", "uncertain_si", "si_from_data" or "si_from_sample" to pass on to \code{\link[EpiEstim]{estimate_R}} (see help page)
#' @param mean_prior *Numeric* positive number giving the mean of the common prior distribution for all reproduction numbers. Only honoured when \code{type = "custom"}
#' @param std_prior *Numeric* positive number giving the standard deviation of the common prior distribution for all reproduction numbers. Only honoured when \code{type = "custom"}
#'
#'
#' @return Object of class \code{\link[EpiEstim]{estimate_R}} (see \code{EpiEstim} help page)
#' @importFrom rlang .data

fit_epiestim_model <- function(data, window_size = 7L, type = NULL,
                               mean_si = NULL, std_si = NULL,
                               recon_opt = "match",
                               method = "parametric_si",
                               mean_prior = NULL, std_prior = NULL) {
  # Both columns must be present and no others: a subset check alone would
  # accept a data frame that is missing `date` or `confirm`.
  required_cols <- c("date", "confirm")
  if (!is.data.frame(data) || !setequal(colnames(data), required_cols)) {
    stop("Must pass a data frame with two columns: date and confirm")
  }

  # `type` may arrive as an explicit NULL (e.g. forwarded from a caller's
  # default), so test for NULL / non-scalar input before using `%in%`.
  valid_types <- c("flu_a", "flu_b", "sars_cov2", "rsv", "custom")
  if (is.null(type) || length(type) != 1L || !(type %in% valid_types)) {
    stop("Must specify the type of epidemic (flu_a, flu_b, sars_cov2, rsv or custom)")
  }

  user_param_unset <- c(
    is.null(mean_si), is.null(std_si), is.null(mean_prior), is.null(std_prior)
  )

  if (type == "custom") {
    if (any(user_param_unset)) {
      stop("Must specify mean_si, std_si, mean_prior and std_prior for type custom")
    }
  } else {
    if (!all(user_param_unset)) {
      warning("Custom mean_si, std_si, mean_prior and std_prior can only be specified with type set to custom. Default config values were used")
    }
    # Always fall back to the built-in configuration for known pathogens so
    # that the behaviour matches the warning above.
    defaults <- switch(type,
      "flu_a" = list(mean_si = 4, std_si = 2, mean_prior = 1, std_prior = 1),
      "flu_b" = list(mean_si = 3.7, std_si = 2.1, mean_prior = 1, std_prior = 1),
      "rsv" = list(mean_si = 7.5, std_si = 2.1, mean_prior = 1, std_prior = 1),
      "sars_cov2" = list(mean_si = 4, std_si = 4.75, mean_prior = 2, std_prior = 1)
    )
    mean_si <- defaults$mean_si
    std_si <- defaults$std_si
    mean_prior <- defaults$mean_prior
    std_prior <- defaults$std_prior
  }

  # EpiEstim expects an incidence table with columns `I` and `dates`,
  # ordered by date.
  incid <- dplyr::arrange(
    data.frame(I = data$confirm, dates = data$date),
    .data$dates
  )

  # Sliding estimation windows of length `window_size`; they start at 2
  # because the estimate is conditional on the past observations.
  n_t <- nrow(incid)
  t_start <- seq(2, max(n_t - (window_size - 1), 2))
  t_end <- pmin(t_start + window_size - 1, n_t)

  config <- EpiEstim::make_config(list(
    mean_si = mean_si,
    std_si = std_si,
    mean_prior = mean_prior,
    std_prior = std_prior,
    t_start = t_start,
    t_end = t_end
  ))

  # Warnings raised by EpiEstim during estimation are deliberately silenced.
  suppressWarnings(EpiEstim::estimate_R(
    incid = incid,
    method = method,
    config = config
  ))
}


#' Extract daily forecast samples
#'
#'
#' @description Function to produce short-term daily projections from objects of class \code{\link[EpiEstim]{estimate_R}}
#'
#' @details The reproduction number used for the projection is taken from the
#' last row of \code{model_fit$R}. A fixed pool of 1000 values is drawn from a
#' normal distribution with that row's mean and standard deviation, truncated
#' below at zero, and handed to \code{projections::project} together with the
#' serial interval distribution stored in \code{model_fit}.
#'
#' @param data *data frame* containing two columns: date and confirm (number of cases per day)
#' @param model_fit Object of class \code{\link[EpiEstim]{estimate_R}} generated by running \code{fit_epiestim_model}
#' @param n_days The number of days to run simulations for. Defaults to 7
#' @param n_sim The number of epicurves to simulate. Defaults to 1000
#'
#'
#'
#'
#' @return Data-frame of daily forecast samples from all simulations
#' \describe{
#'   \item{date}{date}
#'   \item{incidence}{projected number of daily confirmed cases}
#'   \item{sim}{simulation run number}
#' }
#'
project_epiestim_model <- function(data, model_fit, n_days = 7, n_sim = 1000) {
  # `confirm` is referenced through non-standard evaluation below; binding it
  # locally keeps static code checks from reporting an undefined global.
  confirm <- NULL

  # check valid days
  check_epiestim_format(data)
  check_min_days(data)

  # incidence expects data in linelist format: one date entry per case
  date_list <- data |>
    tidyr::uncount(confirm) |>
    dplyr::pull(date)

  incidence_obj <- incidence::incidence(
    date_list,
    first_date = min(data$date, na.rm = TRUE),
    last_date = max(data$date, na.rm = TRUE),
    standard = FALSE
  )

  # Most recent R estimate (last estimation window)
  r_latest <- utils::tail(model_fit$R, n = 1)
  r_mean <- r_latest$`Mean(R)`
  r_sd <- r_latest$`Std(R)`

  # Sample from a normal truncated at zero using inverse transform sampling:
  # uniforms are restricted to [P(R <= 0), 1] and mapped back through the
  # normal quantile function, so every draw is non-negative.
  n_r_samples <- 1000
  lower_prob <- stats::pnorm(0, mean = r_mean, sd = r_sd)
  r_dist <- stats::qnorm(
    stats::runif(n_r_samples, lower_prob, 1),
    mean = r_mean,
    sd = r_sd
  )

  # Use the project function. The first element of the serial interval
  # distribution is dropped (presumably the mass at day 0 -- confirm against
  # the EpiEstim and projections documentation).
  proj <- projections::project(
    incidence_obj,
    R = r_dist,
    si = model_fit$si_distr[-1],
    n_sim = n_sim,
    n_days = n_days,
    R_fix_within = FALSE
  )

  as.data.frame(proj, long = TRUE)
}


#' Forecast daily epidemic cases using EpiEstim
#'
#' @description
#' End-to-end forecasting helper: it cleans the input with
#' `clean_sample_data()`, optionally smooths the case counts, estimates the
#' reproduction number (\eqn{R_t}) with \code{\link{fit_epiestim_model}} and
#' projects daily confirmed cases forward with
#' \code{\link{project_epiestim_model}}.
#'
#' @details
#' The steps are, in order:
#' 1. `data` and `start_date` are passed to `clean_sample_data()`.
#' 2. If `smooth_data = TRUE`, the cleaned counts are replaced by the output
#'    of the internal smoothing helper before fitting.
#' 3. The (possibly smoothed) data is fitted and projected `n_days` ahead.
#'    Input validity (format, sufficient days) is checked during projection.
#' 4. The projected samples are summarised per date with
#'    `create_quantiles()`.
#'
#' @param data *data frame*
#'   Must contain two columns:
#'   - `date`: observation dates
#'   - `confirm`: daily confirmed cases
#'
#' @param start_date *Date*
#'   Date after which the epidemic is considered to have started. Passed to
#'   `clean_sample_data()`.
#'
#' @param window_size *Integer*
#'   Length of the sliding window (in days) used for reproduction number
#'   estimation. Default is 7.
#'
#' @param n_days *Integer*
#'   Number of future days to forecast. Default is 7.
#'
#' @param type *character*
#'   Type of epidemic. Must be one of `"flu_a"`, `"flu_b"`, `"rsv"`,
#'   `"sars_cov2"`, or `"custom"`. Passed to
#'   \code{\link{fit_epiestim_model}}.
#'
#' @param smooth_data *logical*
#'   Whether to smooth the input daily case counts before estimation. Default
#'   is `FALSE`.
#'
#' @param smoothing_cutoff *Integer*
#'   Cutoff parameter for smoothing. Only used if \code{smooth_data = TRUE}.
#'   Default is 10.
#'
#' @param ...
#'   Additional arguments passed to \code{\link{fit_epiestim_model}}.
#'
#' @return A named list with five elements:
#'   - `original_data`: the cleaned input data (when smoothing is enabled,
#'     the `original_data` element returned by the smoothing helper)
#'   - `smoothed_data`: the smoothed data, or `NULL` when
#'     `smooth_data = FALSE`
#'   - `smoothed_error`: the smoothing error, or `NULL` when
#'     `smooth_data = FALSE`
#'   - `forecast_res_quantiles`: per-date forecast summary produced by
#'     `create_quantiles()` (forecast quantiles `p50`, `p25`, `p75`, `p025`,
#'     `p975` and the range `min_sim`, `max_sim`)
#'   - `estimate_R`: the fitted object returned by
#'     \code{\link{fit_epiestim_model}}
#'
#' @seealso
#' \code{\link{fit_epiestim_model}} for reproduction number estimation,
#' \code{\link{project_epiestim_model}} for forward simulations.
#'
#' @importFrom incidence incidence
#' @importFrom rlang .data
#'
#' @export
#'
#' @examples
#'
#' # Create sample test rsv data
#' disease_type <- "rsv"
#' test_data <- simulate_data()
#' formatted_data <- get_aggregated_data(
#'   test_data,
#'   number_column = disease_type,
#'   date_column = "date",
#'   start_date = "2024-04-01",
#'   end_date = "2024-05-01"
#' )
#'
#' # Run a 7 day forecast without smoothing
#' res_smooth <- generate_forecast(
#'   data = formatted_data,
#'   start_date = "2024-04-01",
#'   n_days = 7,
#'   type = "rsv",
#'   smooth_data = FALSE
#' )
#'

generate_forecast <- function(
    data,
    start_date,
    window_size = 7,
    n_days = 7,
    type = NULL,
    smooth_data = FALSE,
    smoothing_cutoff = 10,
    ...
) {
  data <- clean_sample_data(data, start_date)

  # Values reported when no smoothing is requested
  original_data <- data
  smoothed_data <- NULL
  smoothed_error <- NULL

  if (smooth_data) {
    smoothing <- smooth_model_data(data, smoothing_cutoff = smoothing_cutoff)
    original_data <- smoothing[["original_data"]]
    smoothed_data <- smoothing[["data"]]
    smoothed_error <- smoothing[["error"]]
    # The model is fitted to the smoothed counts from here on
    data <- smoothed_data
  }

  # Estimate the reproduction number
  epiestim_estimates <- fit_epiestim_model(
    data = data,
    window_size = window_size,
    type = type,
    ...
  )

  # Simulate future daily incidence from the fitted model
  projected_samples <- project_epiestim_model(
    data = data,
    model_fit = epiestim_estimates,
    n_days = n_days
  )

  # Summarise the simulated incidence per date
  renamed_samples <- dplyr::rename(projected_samples, daily_incidence = incidence)
  forecast_res_quantiles <- create_quantiles(
    renamed_samples,
    .data$date,
    variable = "daily_incidence"
  )

  list(
    original_data = original_data,
    smoothed_data = smoothed_data,
    smoothed_error = smoothed_error,
    forecast_res_quantiles = forecast_res_quantiles,
    estimate_R = epiestim_estimates
  )
}



#' Smooth Model Data Using P-Spline GAM
#'
#' Applies P-spline smoothing using a Generalized Additive Model (GAM) to the input model data.
#' If the number of rows in `model_data` is greater than or equal to `smoothing_cutoff`,
#' the function fits a GAM, estimates confidence intervals, and computes uncertainty.
#' Otherwise the data is returned unchanged with an error of `0`.
#'
#' The uncertainty is obtained by simulation: `n_reps` coefficient vectors are
#' drawn from a multivariate normal centred on the fitted coefficients with
#' the model's coefficient covariance matrix, residual noise is added to the
#' resulting fitted values, and the width of the 95% interval of the simulated
#' responses at each time step is converted to a standard-error scale. That
#' value is added to the standard error of the fitted smooth.
#'
#' @param model_data A data frame containing the model data, including a column `confirm` representing observed values.
#' @param smoothing_cutoff minimum number of rows required to smooth
#' @param n_reps Number of replicates to calculate error
#'
#' @return A list with three elements:
#'  - `original_data`: A copy of `model_data`, returned unmodified.
#'  - `data`: A data frame with smoothed `confirm` values (rounded and
#'    floored at zero), or `model_data` itself when no smoothing is done.
#'  - `error`: Estimated uncertainty of the smoothing process: a one-column
#'    data frame (`smoothed_error`) when smoothing is done, the number `0`
#'    otherwise.
#'
#' @importFrom mgcv gam
#' @importFrom stats coef qnorm rnorm quantile
#' @importFrom dplyr mutate
#' @noRd
smooth_model_data <- function(model_data, smoothing_cutoff = 10, n_reps = 10000) {
  smoothed_model_data <- model_data
  n_obs <- nrow(model_data)

  # Too few observations to smooth: hand the data back untouched.
  if (n_obs < smoothing_cutoff) {
    return(list(original_data = model_data, data = smoothed_model_data, error = 0))
  }

  ##### P-spline smoothing with GAM over the observation index ################
  # Fit on a private copy so the helper `index` column does not leak into the
  # `original_data` element returned to the caller.
  index <- seq_len(n_obs)
  fit_data <- model_data
  fit_data$index <- index
  basis_dim <- round(n_obs / 2, 0)
  model_smooth <- mgcv::gam(
    confirm ~ s(index, bs = "ps", k = basis_dim),
    data = fit_data
  )

  # Draw coefficient vectors: with Vb = t(R) %*% R (R from chol), the draws
  # t(R) %*% z + coef have covariance Vb.
  coef_hat <- stats::coef(model_smooth)
  coef_cov <- stats::vcov(model_smooth)
  chol_upper <- chol(coef_cov)
  n_coef <- length(coef_hat)
  coef_draws <- t(chol_upper) %*%
    matrix(stats::rnorm(n_reps * n_coef), n_coef, n_reps) + coef_hat

  # Map coefficient draws to fitted values through the linear predictor matrix
  lp_matrix <- suppressWarnings(
    stats::predict(model_smooth, newdata = data.frame(index = index), type = "lpmatrix")
  )
  fitted_draws <- lp_matrix %*% coef_draws

  # Add observation noise. `sig2` is the estimated variance (scale) of the
  # fit, so its square root is the standard deviation `rnorm()` expects.
  response_draws <- matrix(
    stats::rnorm(length(fitted_draws), mean = fitted_draws, sd = sqrt(model_smooth$sig2)),
    nrow = nrow(fitted_draws), ncol = ncol(fitted_draws)
  )

  # Width of the simulated 95% interval per time step, rescaled to a
  # standard-error equivalent under a normal approximation.
  pred_int <- apply(response_draws, 1, stats::quantile, probs = c(0.025, 0.975))
  interval_width <- pred_int[2, ] - pred_int[1, ]
  uncertainty_se <- interval_width / (stats::qnorm(1 - 0.05 / 2) * 2)

  # Smoothed counts: rounded to whole cases and never negative
  smoothed_estimates <- stats::predict(model_smooth, type = "response", se.fit = TRUE)
  smoothed_model_data$confirm <- pmax(round(smoothed_estimates$fit, 0), 0)

  smoothed_error <- data.frame(
    smoothed_error = smoothed_estimates$se.fit + uncertainty_se
  )

  list(original_data = model_data, data = smoothed_model_data, error = smoothed_error)
}



