#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(cpp11)]]

using namespace Rcpp;
using namespace arma;

//--------------------------------------------------------------
// Element-wise soft-thresholding, i.e. the proximal map of a weighted
// L1 penalty:
//   X(i,j) <- sign(X(i,j)) * max(|X(i,j)| - lambda(i,j), 0)
// X is overwritten in place; lambda must have the same dimensions as X.
//--------------------------------------------------------------
inline void soft_threshold_inplace(arma::mat& X, const arma::mat& lambda) {
  const arma::mat floor0(arma::size(X), arma::fill::zeros);
  const arma::mat shrunk = arma::max(arma::abs(X) - lambda, floor0);  // magnitudes after shrinkage
  X = arma::sign(X) % shrunk;                                         // restore original signs
}

//--------------------------------------------------------------
// Cheap spectral-norm estimate for a symmetric matrix A via a few
// power iterations (intended for xtx and Theta).
//
// Returns |x' A x| (Rayleigh quotient) as soon as two successive
// estimates agree to relative tolerance `tol`. Otherwise it returns
// max(|last Rayleigh quotient|, ||A||_inf). The inf-norm (max absolute
// row sum) bounds the spectral radius from above, so the non-converged
// result never underestimates it. A converged Rayleigh quotient can
// slightly underestimate ||A||_2, so callers needing a strict upper bound
// should add a margin or backtrack.
//
// The start vector is deterministic on purpose. Under RcppArmadillo,
// randu() draws from R's RNG, which would silently advance the user's
// random stream and make downstream results seed-dependent.
//--------------------------------------------------------------
inline double power_norm2_sym(const arma::mat& A, const int maxit = 8, const double tol = 1e-6) {
  const uword m = A.n_rows;
  if (m == 0) return 0.0;
  
  // Strictly positive, non-constant start: unlikely to be orthogonal to
  // the dominant eigenvector, and reproducible.
  arma::vec x = linspace<vec>(1.0, 2.0, m);
  x /= norm(x, 2);
  
  // A*x is computed once per iteration. The product used for the Rayleigh
  // quotient doubles as the next power step.
  arma::vec Ax = A * x;
  double prev = 0.0;
  
  for (int it = 0; it < maxit; ++it) {
    const double nAx = norm(Ax, 2);
    if (!(nAx > 0.0)) break;  // A*x is zero or NaN: nothing to follow
    x  = Ax / nAx;
    Ax = A * x;
    const double rayleigh = dot(x, Ax);
    if (std::abs(rayleigh - prev) <= tol * std::max(1.0, std::abs(rayleigh))) {
      return std::abs(rayleigh);
    }
    prev = rayleigh;
  }
  // Not converged: fall back to the row-sum norm as a safe upper bound.
  return std::max(std::abs(prev), norm(A, "inf"));
}

//--------------------------------------------------------------
// Smooth objective
//   f(B) = (1/n) trace((B' xtx B - 2 B' xty) Theta)
// evaluated from a precomputed temp = xtx*B - xty.
//
// Since B' temp - B' xty = B' xtx B - 2 B' xty, a single p x q residual
// R = temp - xty (= xtx*B - 2*xty) suffices:
//   f(B) = (1/n) trace(B' R Theta) = (1/n) sum_ij (B' R)_ij Theta_ji.
// The trace is taken element-wise, so the q x q x q product
// (B' R) * Theta is never formed.
//--------------------------------------------------------------
inline double compute_objective_from_temp(const arma::mat& B,
                                          const arma::mat& temp,  // xtx*B - xty
                                          const arma::mat& xty,
                                          const arma::mat& Theta,
                                          const double inv_n,
                                          const bool theta_is_diag) {
  const arma::mat R = temp - xty;  // p x q
  
  if (theta_is_diag) {
    // Only diag(B' R) is needed: column-wise dot products b_k' r_k.
    const arma::vec col_dots = sum(B % R, 0).t();  // q x 1
    const arma::vec th       = Theta.diag();
    return inv_n * dot(col_dots, th);
  }
  
  // trace(M * Theta) == accu(M % Theta'): O(q^2) instead of O(q^3).
  const arma::mat M = B.t() * R;  // q x q
  return inv_n * accu(M % Theta.t());
}

//--------------------------------------------------------------
// Smooth objective f(B) = (1/n) trace((B' xtx B - 2 B' xty) Theta),
// computed from scratch for general calls.
//
// Builds temp = xtx*B - xty once and delegates to
// compute_objective_from_temp, so both entry points share one formula.
// This also avoids forming the q x q matrix B' xtx B just to take its
// diagonal or trace.
//--------------------------------------------------------------
inline double compute_objective(const arma::mat& B, const arma::mat& xtx, const arma::mat& xty,
                                const arma::mat& Theta, const double inv_n,
                                const bool theta_is_diag) {
  const arma::mat temp = xtx * B - xty;  // p x q
  return compute_objective_from_temp(B, temp, xty, Theta, inv_n, theta_is_diag);
}

//--------------------------------------------------------------
// Gradient of the smooth loss: grad = (2/n) (xtx*B - xty) * Theta.
//
// When theta_is_diag, multiplying by Theta reduces to scaling column k
// by Theta(k,k). If temp_out is non-null, the p x q intermediate
// xtx*B - xty is moved into *temp_out so the caller can reuse it
// (e.g. to evaluate the objective without another product).
//--------------------------------------------------------------
inline void compute_gradient(arma::mat& grad, const arma::mat& B, const arma::mat& xtx,
                             const arma::mat& xty, const arma::mat& Theta,
                             const double two_inv_n, const bool theta_is_diag,
                             arma::mat* temp_out = nullptr) {
  arma::mat resid = xtx * B;  // p x q
  resid -= xty;
  
  if (!theta_is_diag) {
    grad = two_inv_n * (resid * Theta);
  } else {
    const arma::rowvec col_scale = Theta.diag().t();
    grad  = resid.each_row() % col_scale;
    grad *= two_inv_n;
  }
  
  if (temp_out != nullptr) {
    *temp_out = std::move(resid);
  }
}

//--------------------------------------------------------------
// Accelerated proximal gradient (FISTA) update of the coefficient
// matrix B for the weighted-Lasso problem
//
//   min_B  F(B) = f(B) + g(B),
//   f(B) = (1/n) trace((B' xtx B - 2 B' xty) Theta),
//   g(B) = sum_ij lamB(i,j) * |B(i,j)|.
//
// The gradient used, (2/n)(xtx B - xty) Theta, equals grad f only when
// xtx and Theta are symmetric. That is assumed here (presumably a Gram
// matrix and a precision matrix; TODO confirm with callers).
//
// Each iteration does the following:
//   * Step proposal: a Barzilai–Borwein BB1 step from successive
//     extrapolated points when available. Otherwise the previous step is
//     reused, growing 10% back toward the initial spectral-estimate step
//     after earlier backtracking.
//   * Backtracking on the quadratic upper bound of f (shrink by eta).
//   * O'Donoghue–Candès adaptive restart: a function-value test on the
//     full composite F, plus the momentum/gradient-mapping orthogonality
//     test. F is used rather than f alone because the L1 term can trade
//     smooth fit for sparsity.
//   * Stop when the relative change in B, the gradient-mapping norm, or
//     the relative change in F falls below tolin.
// Line-search warnings are issued at most once per call.
//
// Returns a List with:
//   Bhat      - final coefficient matrix (p x q)
//   it.final  - iteration counter at exit; equals maxitrin + 1 when the
//               limit is reached without convergence
//   converged - whether a stopping test fired
//--------------------------------------------------------------
// [[Rcpp::export]]
List updateBeta(const arma::mat& Theta, const arma::mat& B0,
                const int n, const arma::mat& xtx, const arma::mat& xty,
                const arma::mat& lamB, const double eta = 0.8,
                const double tolin = 1e-4, const int maxitrin = 1000,
                const bool adaptive_restart = true, const double restart_tol = 1e-6) {
  
  if (n <= 0) stop("n must be a positive integer");
  
  // Nothing to optimize. Returning early also avoids norm/max on an
  // empty Theta diagonal below.
  if (B0.n_elem == 0) {
    return List::create(
      Named("Bhat")      = B0,
      Named("it.final")  = 1,
      Named("converged") = true
    );
  }
  
  const uword p = B0.n_rows;
  const uword q = B0.n_cols;
  
  const double inv_n     = 1.0 / n;
  const double two_inv_n = 2.0 * inv_n;
  
  // Clamp eta to a sensible range
  const double eta_used    = std::min(0.95, std::max(0.1, eta));
  const int    max_bt_iter = 50;
  
  const bool theta_is_diag = (norm(Theta - diagmat(Theta.diag()), "fro") < 1e-12);
  
  // Non-smooth part g(B) of the composite objective
  auto penalty = [&lamB](const arma::mat& M) { return accu(lamB % abs(M)); };
  
  arma::mat B      = B0;  // current iterate x_k
  arma::mat B_prev = B0;  // previous iterate x_{k-1}
  arma::mat Y      = B0;  // extrapolated point y_k
  arma::mat Y_prev = B0;  // previous extrapolated point (BB memory)
  
  arma::mat B_new(p, q, fill::zeros);      // trial / next iterate
  arma::mat grad_f(p, q, fill::zeros);     // grad f(Y)
  arma::mat grad_prev(p, q, fill::zeros);  // grad f(Y_prev) (BB memory)
  arma::mat Gt(p, q, fill::zeros);         // gradient mapping (Y - B_new) / alpha
  arma::mat lamBt(p, q, fill::zeros);      // alpha * lamB
  arma::mat tempY, tempB;                  // xtx*Y - xty and xtx*B_new - xty
  
  // Initial step 1/L0 from cheap spectral-norm estimates of
  // L = (2/n) ||xtx||_2 ||Theta||_2, with no eigendecomposition. The
  // max(1, .) keeps a tiny Theta norm estimate from inflating the step.
  const arma::vec theta_d    = Theta.diag();
  const double    norm_xtx   = power_norm2_sym(xtx, 6);
  const double    norm_theta = theta_is_diag ? norm(theta_d, "inf") : power_norm2_sym(Theta, 6);
  const double    L0         = std::max(1e-12, two_inv_n * norm_xtx * std::max(1.0, norm_theta));
  const double    alpha_init = 1.0 / L0;
  double alpha = alpha_init;
  
  bool   use_BB = false;
  double t_k    = 1.0;  // FISTA momentum parameter
  
  // Composite objective F = f + g at the starting point
  double F_B = compute_objective(B, xtx, xty, Theta, inv_n, theta_is_diag) + penalty(B);
  
  int  iter = 1;
  bool converged = false;
  bool warned_small_step = false;
  bool warned_max_bt     = false;
  
  while (iter <= maxitrin && !converged) {
    if (iter > 1) {
      grad_prev = grad_f;
      Y_prev    = Y;
    }
    
    // Gradient at Y, plus tempY = xtx*Y - xty reused for f(Y)
    compute_gradient(grad_f, Y, xtx, xty, Theta, two_inv_n, theta_is_diag, &tempY);
    const double f_Y = compute_objective_from_temp(Y, tempY, xty, Theta, inv_n, theta_is_diag);
    
    // ---- Step proposal ----
    bool bb_applied = false;
    if (iter > 2 && use_BB) {
      const arma::mat s_BB = Y - Y_prev;          // Y_k - Y_{k-1}
      const arma::mat y_BB = grad_f - grad_prev;  // grad(Y_k) - grad(Y_{k-1})
      const double s_dot_y = accu(s_BB % y_BB);
      const double s_dot_s = accu(s_BB % s_BB);
      if (std::abs(s_dot_y) > 1e-16 && s_dot_s > 1e-32) {
        const double alpha_BB = s_dot_s / s_dot_y;  // BB1 ("long") step
        if (alpha_BB > 1e-16 && alpha_BB < 1e16) {
          alpha      = alpha_BB;
          bb_applied = true;
        }
      }
    }
    if (!bb_applied) {
      // Recover 10% per iteration from earlier backtracking, capped at the
      // initial spectral step. A larger step inherited from BB is kept.
      alpha = std::max(alpha, std::min(1.1 * alpha, alpha_init));
    }
    
    // ---- Backtracking on the quadratic upper bound of f ----
    //   f(B_new) <= f(Y) - alpha <grad f(Y), Gt> + (alpha/2) ||Gt||^2
    // alpha is shrunk only when another trial follows, so on exit it is
    // always the step that produced B_new and Gt.
    double f_B_new_smooth = f_Y;
    for (int bt_iter = 1; ; ++bt_iter) {
      lamBt = lamB * alpha;
      B_new = Y - alpha * grad_f;
      soft_threshold_inplace(B_new, lamBt);
      Gt = (Y - B_new) / alpha;
      
      const double Q_bound = f_Y - alpha * accu(grad_f % Gt) + 0.5 * alpha * accu(square(Gt));
      tempB = xtx * B_new - xty;
      f_B_new_smooth = compute_objective_from_temp(B_new, tempB, xty, Theta, inv_n, theta_is_diag);
      if (f_B_new_smooth <= Q_bound + 1e-12) break;  // sufficient decrease
      
      if (bt_iter >= max_bt_iter) {
        if (!warned_max_bt) {
          warning("Backtracking line search reached maximum iterations");
          warned_max_bt = true;
        }
        break;
      }
      const double alpha_next = alpha * eta_used;
      if (alpha_next < 1e-16) {
        if (!warned_small_step) {
          warning("Step size became too small, stopping line search");
          warned_small_step = true;
        }
        break;
      }
      alpha = alpha_next;
    }
    
    const double F_new = f_B_new_smooth + penalty(B_new);
    
    // ---- Adaptive restart (O'Donoghue–Candès) ----
    bool restart_triggered = false;
    if (adaptive_restart && iter > 1) {
      if (F_new > F_B + restart_tol * std::abs(F_B)) {
        // (a) function scheme on the composite objective
        restart_triggered = true;
      } else if (iter > 2) {
        // (b) gradient scheme: (Y - B_new)'(B_new - B) > 0
        restart_triggered = accu((Y - B_new) % (B_new - B)) > 0.0;
      }
    }
    
    // ---- Convergence tests ----
    const double change_norm   = norm(B_new - B, "fro");
    const double B_norm        = norm(B, "fro");
    const double grad_map_norm = norm(Gt, "fro");
    
    const bool conv_change = (B_norm > 1e-8) ? (change_norm / B_norm < tolin) : (change_norm < tolin);
    const bool conv_grad   = grad_map_norm < std::max(tolin, 1e-8);
    const bool conv_obj    = (iter > 1) && (std::abs(F_new - F_B) < tolin * std::max(1.0, std::abs(F_B)));
    
    if (conv_change || conv_grad || conv_obj) {
      B.swap(B_new);
      converged = true;
    } else {
      // BB needs two accepted steps of history
      if (iter > 1) use_BB = true;
      
      B_prev.swap(B);  // B_prev <- x_k
      B.swap(B_new);   // B <- x_{k+1}; B_new becomes scratch
      F_B = F_new;
      
      // FISTA momentum
      if (restart_triggered) {
        t_k = 1.0;
        Y   = B;  // no extrapolation
      } else {
        const double t_old  = t_k;
        t_k = 0.5 * (1.0 + std::sqrt(1.0 + 4.0 * t_old * t_old));
        const double beta_k = (t_old - 1.0) / t_k;
        Y = B + beta_k * (B - B_prev);
      }
      
      ++iter;
    }
  }
  
  return List::create(
    Named("Bhat")      = B,
    Named("it.final")  = iter,
    Named("converged") = converged
  );
}
