# FUNCTIONS FOR PRINTING/SUMMARIZING --------------------------------------

#' @rdname load_fits_ids
#' @export
print.fits_ids_dm <- function(x, ...) {
  # Print a short overview of a fits_ids_dm object: the fit procedure name,
  # the class vector of the fitted model, the time of the (last) call, and
  # the number of individual fits stored in `all_fits`.
  fit_info <- x$drift_dm_fit_info
  model_classes <- paste(class(fit_info$drift_dm_obj), collapse = ", ")
  n_individuals <- length(x$all_fits)

  cat("Fit procedure name:", fit_info$fit_procedure_name)
  cat("\n")
  cat("Fitted model type:", model_classes)
  cat("\n")
  cat("Time of (last) call:", fit_info$time_call)
  cat("\n")
  cat("N Individuals:", n_individuals, "\n")

  # return the input unchanged so the method composes like other print methods
  invisible(x)
}


#' @rdname summary.fits_ids_dm
#' @export
print.summary.fits_ids_dm <- function(x, ...,
                                      round_digits = drift_dm_default_rounding()) {
  # Header: procedure name and number of individuals
  cat("Fit Procedure Name:", x$fit_procedure_name)
  cat("\n")
  cat("N Individuals:", x$N, "\n\n")

  # One rounded statistics table per condition, in the order stored in `stats`
  cond_names <- names(x$stats)
  for (cond in cond_names) {
    cat("Parameter Summary:", cond, "\n")
    rounded_tab <- round(x$stats[[cond]], round_digits)
    print(rounded_tab)
    cat("\n")
  }
  cat("\n")

  # Search space: a two-row matrix (lower/upper), columns named after `upper`
  cat("Parameter Space:\n")
  bounds <- rbind(x$lower, x$upper)
  rownames(bounds) <- c("lower", "upper")
  colnames(bounds) <- names(x$upper)
  print(bounds)

  # Footer: model class information and time of the (last) call
  cat("\n-------\n")
  cat("Fitted Model Type:", x$model_type)
  cat("\n")
  cat("Time of (Last) Call:", x$time_call)
  cat("\n")
  invisible(x)
}


#' Summary and Printing for fits_ids_dm Objects
#'
#' Methods for summarizing and printing objects of the class `fits_ids_dm`,
#' which contain multiple fits across individuals.
#'
#' @param object an object of class `fits_ids_dm`, generated by a call
#'   to [dRiftDM::load_fits_ids].
#' @param x an object of class `summary.fits_ids_dm`.
#' @param round_digits integer, specifying the number of decimal places for
#'   rounding in the printed summary. Defaults to
#'   `drift_dm_default_rounding()`.
#' @param ... additional arguments
#'
#' @details
#' The `summary.fits_ids_dm` function creates a summary object containing:
#' - **fit_procedure_name**: The name of the fit procedure used.
#' - **time_call**: Timestamp of the last fit procedure call.
#' - **lower** and **upper**: Lower and upper bounds of the search space.
#' - **model_type**: Description of the model type, based on class information.
#' - **prms**: All parameter values across all conditions (essentially a call
#'   to coef() with the argument select_unique = FALSE).
#' - **stats**: A named list with one data frame per condition. Each data frame
#'   has the rows `mean` and `std_err` (standard error of the mean) and one
#'   column per parameter.
#' - **N**: The number of individuals.
#'
#' The `print.summary.fits_ids_dm` function displays the summary object in a
#' formatted manner.
#'
#' @return
#' `summary.fits_ids_dm()` returns a list of class `summary.fits_ids_dm` (see
#' the Details section summarizing each entry of this list).
#'
#' `print.summary.fits_ids_dm()` returns invisibly the `summary.fits_ids_dm`
#'  object.
#'
#' @examples
#' # get an auxiliary object of type fits_ids_dm for demonstration purpose
#' all_fits <- get_example_fits_ids()
#' sum_obj <- summary(all_fits)
#' print(sum_obj, round_digits = 2)
#'
#' @export
summary.fits_ids_dm <- function(object, ...) {
  fits_ids <- object
  fit_info <- fits_ids$drift_dm_fit_info

  ans <- list()
  ans$fit_procedure_name <- fit_info$fit_procedure_name
  ans$time_call <- fit_info$time_call

  # search space bounds as resolved by get_lower_upper_smart()
  l_u <- get_lower_upper_smart(
    drift_dm_obj = fit_info$drift_dm_obj,
    lower = fit_info$lower,
    upper = fit_info$upper
  )
  ans$lower <- l_u$lower
  ans$upper <- l_u$upper
  ans$model_type <- paste(class(fit_info$drift_dm_obj), collapse = ", ")

  # every column other than ID and Cond is treated as a model parameter
  all_prms <- coef(fits_ids, select_unique = FALSE)
  ans$prms <- all_prms
  prm_names <- colnames(all_prms)[!(colnames(all_prms) %in% c("ID", "Cond"))]

  by_cond <- all_prms["Cond"]
  means <- stats::aggregate(all_prms[prm_names], by = by_cond, FUN = mean)
  std_errs <- stats::aggregate(
    all_prms[prm_names],
    by = by_cond,
    FUN = \(x) stats::sd(x) / sqrt(length(x))
  )

  # aggregate() puts the grouping column first, hence the `-1`. drop = FALSE
  # keeps a one-row data frame when there is only a single parameter;
  # otherwise `[` would collapse the selection to an unnamed numeric vector.
  # Each aggregate is filtered by its own Cond column.
  ans$stats <- lapply(
    stats::setNames(nm = conds(fits_ids)),
    function(one_cond) {
      stat_tab <- rbind(
        means[means$Cond == one_cond, -1, drop = FALSE],
        std_errs[std_errs$Cond == one_cond, -1, drop = FALSE]
      )
      rownames(stat_tab) <- c("mean", "std_err")
      stat_tab
    }
  )
  ans$N <- length(fits_ids$all_fits)

  class(ans) <- "summary.fits_ids_dm"
  ans
}
