function [a_hat,P,stats,covest] = penscm_rie0(X,Xv,c,etas,verbose,ret_cov,preconly) 
% PENSCM_RIE0 computes a regularized (shrinkage) sample covariance matrix
% (SCM) estimator that solves the penalized negative Gaussian likelihood 
% function with a Riemannian penalty function that shrinks towards a scaled 
% identity matrix c*I. 
%
% [a_hat,P,stats,covest] = PENSCM_RIE0(X,Xv,c,etas,verbose,ret_cov,preconly) 
%
% PENSCM_RIE0 is the utility function of PENSCM.
% 
% Given an input data matrix X (rows are observations) and an optional 
% validation data set Xv, the algorithm scans through a set of penalty 
% parameter values, either the input grid etas or the default grid computed 
% by the function. The outputs are the log(eigenvalue) estimates 
% a = log(lambda_j) and an orthonormal eigenvector matrix P of the solution 
% to the penalized optimization problem. If the validation data set is also 
% given, the algorithm selects the penalty value with the smallest 
% validation criterion and (optionally) returns the corresponding 
% covariance matrix and precision matrix estimates. 
% 
% INPUT 
%
%    X          (numeric matrix) Data set of size n x p, n >= 2
%    Xv         (numeric matrix) Validation data set of size nn x p.
%               This is optional (default []). 
%    c          The covariance matrix estimator is shrunk towards c*I, 
%               where c > 0, or c is the string 'mean'. In case 'mean', the 
%               value of c will be the mean of the eigenvalues of the SCM 
%               of the (centered) data X. (Default 1)
%    etas       (numeric vector) Vector of eta values (positive reals).
%               If empty or omitted, a default grid is used.
%    verbose    (logical) print details of the algorithm (default true)
%    ret_cov    (logical) if true, then return the covariance and inverse 
%               covariance matrix corresponding to the penalty parameter   
%               value that had the smallest validation criterion 
%               (stats.lik_eta) on the validation data set. (Default false)
%    preconly   (logical) if true, then return the inverse 
%               covariance matrix only (default false)
%
% OUTPUT 
%  a_hat        (cell) All log(eigenvalue) estimates, one p x 1 vector for 
%               each eta value on the grid
%  P            (numeric matrix) p x p matrix of eigenvectors of the SCM
%  stats        (struct) with following fields: 
%      etas     grid of eta values that was used
%      lik_eta  validation criterion trace(Sv*Theta) + log det(Sigma) for 
%               each eta value (Inf if the iteration did not converge); 
%               empty if Xv is not given
%      it       # of NR iterations for each eta value
%      ind      index of the smallest value in lik_eta vector; empty if Xv 
%               is not given
%  covest       (struct) with following fields:
%    Theta_hat  Best precision matrix estimate based on validation set.
%               Returned only if ret_cov is equal to true
%    Sigma_hat  Best covariance matrix estimate based on validation set.
%               Returned only if ret_cov is equal to true and preconly is
%               false
% 
% NOTE: Theta_hat, Sigma_hat are non-empty only if ret_cov is true and 
% either a validation (test) data set Xv is given or etas is a single value. 
% 
% See also: PENSCM
%
% toolbox: SCATTER
% authors: Copyright by Esa Ollila and Michael Muma, 2019
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

[n,p] = size(X);

%-- default values of optional arguments
if nargin < 7 || isempty(preconly)
    preconly = false;
end

if nargin < 6 || isempty(ret_cov)
    ret_cov = false;
end

if nargin < 5 || isempty(verbose)
    verbose = true;
end

if nargin < 4
    etas = [];
end

if nargin < 3 || isempty(c)
    c = 1;
end

if nargin < 2
    Xv = [];
end

if n < 2
    % the SCM is normalized by 1/(n-1), so at least two rows are required
    error('penscm_rie0:tooFewObservations', ...
        'The data matrix X needs to have at least two rows (observations)');
end

%-- grid of penalty parameters 
% ~all(x > 0) (rather than any(x <= 0)) also rejects NaN entries
if ~all(etas(:) > 0)
    error('The grid of penalty parameter values need to be positive reals');
end

if isempty(etas)
    % Set-up the default grid of eta values
    % NOTE: If eta is very small and n < p, then the NR algorithm may have
    % convergence problems
    betas = fliplr([0.001 0.01 (1:9)/10]); % small beta --> large eta
    etas = (1-betas)./betas;
end
stats.etas = etas;
neta = length(etas);

%--  'c' parameter 
use_mean_c = false;
if ischar(c)
    if ~strcmpi(c,'mean')
        error('stats:RSCM_Rie0:Invalid_c_type', ...
            'The parameter c needs to be a positive real or the string ''mean''');
    end
    use_mean_c = true; % c is computed below, once the eigenvalues are known
elseif ~(isnumeric(c) || islogical(c)) || ~isreal(c)
    error('stats:RSCM_Rie0:Invalid_c_type', ...
        'The parameter c needs to be a positive real or the string ''mean''');
elseif ~all(c(:) > 0)
    % log(c) is the shrinkage target, so c = 0 (or NaN) is not allowed
    error('stats:RSCM_Rie0:negative_c', ...
        'The parameter c needs to be positive');
end
   
%-- set up 
MAX_ITER = 1000; % max number of NR-iterations
EPS = 1e-7;      % tolerance level  
X = bsxfun(@minus, X, sum(X,1)./n);  % Centering the training data

%-- Compute the eigenvectors P and eigenvalues d of the SCM
if p/n < 20 
    S = (X'*X)./(n-1); % := cov(X); 
    [P,D,~] = svd(S);  % eigenvectors and eigenvalue matrix
    d = diag(D);
else
    % We use the SVD trick when the dimensionality p is at least one order 
    % of magnitude larger than the number of observations n: the n x n 
    % matrix X*X' is decomposed instead of the p x p matrix X'*X. 
    [U,D,~] = svd((X*X'));  % eigenvectors and eigenvalue matrix
    d = diag(D);
    indx = find(d>1e-10); 
    % Note:  if p >= n, then there are p - n +1  zeros eigenvalues
    % or equivalently: n-1 non-zero eigenvalues
    inv_sqrtd = 1./sqrt(d(indx));     
    P1 = X'*U(:,indx)*diag(inv_sqrtd); % eigenvectors of non-zero eigenvalues
    P2 = null(P1');                    % orthonormal basis of the complement
    P = [P1 P2];
    d = [d/(n-1); zeros(p-n,1)];
    %--Testing: all these should be close to zero 
    % norm((P*diag(d)*P')-S,'fro')
    % trace(S) - sum(d) % should be close to zero 
end

if use_mean_c 
    c = mean(d);
end
logc = log(c);

% set-up the initial guess (zero eigenvalues are replaced by a small value)
a_init = log(d);
a_init(d==0) = log(1e-12); 
a0 = a_init;

% initialize 
it = zeros(1,neta);
a_hat = cell(neta,1); 
a1 = a0;
covest = struct('Theta_hat',[],'Sigma_hat',[]);

if ~isempty(Xv)
    nv = size(Xv,1);
    Xv = bsxfun(@minus, Xv, sum(Xv,1)./nv);  % Centering the validation data
    % Only the diagonal of P'*Sv*P, where Sv = Xv'*Xv/(nv-1), is needed for
    % trace(Sv*Theta); it equals the scaled squared column norms of Xv*P.
    dv = sum(abs(Xv*P).^2, 1).' ./ (nv-1);
    stats.lik_eta = zeros(neta,1);  % this is for cross validations
else
    stats.lik_eta = [];
end

if verbose && neta > 1 
    fprintf('Computing the estimator over grid of eta values\n');
end

% then do the iterations through a grid of etas
for ii = 1:neta

    eta = etas(ii);
    converged = false;
        
    for iter = 1:MAX_ITER

        % update the terms in Eq. (6) according to the current step
        w = d .* exp(-a0);
        g = w - 1 - 2*eta*(a0-logc);
        del = w + 2*eta;                      
        % compute the update vector:   
        update = g./del;
          
        % update step
        a1 = a0 + update;  
    
        % Termination rule 
        err = norm(update,Inf);  

        if err < EPS
            converged = true;
            break
        end
    
        a0 = a1; 
    end
    
    if verbose
        fprintf(' . ');
    end
    
    if ~converged
        fprintf(1,'WARNING! Slow convergence: the error of the solution is %f\n',err);  
        fprintf(1,'after %d iterations\n',iter);
        % do not warm-start the next eta from a non-converged iterate
        a0 = a_init;
    end    
    
    it(ii) = iter;
    a_hat{ii} = a1;
        
    if ~isempty(Xv)
        if converged
            % trace(Sv*Theta) + log det(Sigma) in the eigenbasis P
            stats.lik_eta(ii) = sum(dv .* exp(-a1)) + sum(a1);
        else
            stats.lik_eta(ii) = Inf;
        end
    end
       
end

if verbose
    fprintf('\n');
end

stats.it = it;

%-- choose the estimate that is returned in covest
a_best = [];
if ~isempty(Xv)
    [~,stats.ind] = min(stats.lik_eta); 
    a_best = a_hat{stats.ind};
else
    stats.ind = [];
    if neta == 1
        a_best = a1;
    end
end

if ret_cov && ~isempty(a_best)
    % P*diag(v)*P' computed by scaling the columns of P
    covest.Theta_hat = bsxfun(@times, P, exp(-a_best).') * P'; 
    if ~preconly 
        covest.Sigma_hat = bsxfun(@times, P, exp(a_best).') * P';   
    end
end
