function [Sest,beta0,S,stats] = RSCM(X,varargin)
% RSCM computes the regularized (shrinkage) sample covariance matrix (SCM)
% using the shrinkage estimators proposed in [1]. The function computes
% either the RSCM-Ell1, RSCM-Ell2, RSCM-CV (K-fold cross-validation) or
% RSCM-LOOCV (leave-one-out cross-validation) estimator.
%
% The data is assumed to be centered (i.e., the symmetry center is 0).
%
% INPUT
% -----
%   X       : the data matrix with n rows (observations) and p columns
%             (variables), p >= 2. Observations can be real or
%             complex-valued. X must not contain NaN values.
%
% Name-Value Pair Arguments
% -------------------------
% RSCM can be called with numerous optional arguments. Optional
% arguments are given in parameter pairs, so that the first argument is
% the name of the parameter and the next argument is the value for
% that parameter. Optional parameter pairs can be given in any order.
%
% Name      Value and description
%==========================================================================
% --The basic parameter is the choice of the estimator to use
%
% 'approach' : (string, case-insensitive) defines which method is used to
%       estimate the shrinkage parameter beta. Valid choices are
%           'ell1' (default)    use RSCM-Ell1 estimator
%           'ell2'              use RSCM-Ell2 estimator
%           'cv'                use RSCM-CV (K-fold cross-validation)
%           'loocv'             use leave-one-out cross-validation
%==========================================================================
%
%  Optional inputs if 'approach' is 'ell1' or 'ell2':
%
% 'kappa' : real scalar > kappa_lowerb, where kappa_lowerb = -2/(p+2) for
%       real-valued data and kappa_lowerb = -1/(p+1) for complex-valued
%       data. Elliptical kurtosis parameter to use. If not given, the
%       function estimates kappa using the marginal kurtosis values
%       (function ellkurt).
%
% 'gamma' : real scalar in [1,p].
%       Sphericity measure to use. By default, the function computes the
%       Ell1-estimator (gamell1) or the Ell2-estimator (gamell2M) of
%       sphericity, depending on 'approach'.
%
%==========================================================================
%
%  Optional inputs if 'approach' is 'cv':
%
% 'folds'  : positive integer (default = 5)
%       number K of folds used in K-fold cross-validation.
%
% 'betas'  : vector with elements in [0,1]
%       specifies the grid of beta values. If not given, the default grid
%       is beta = 0:0.025:1.
%
% 'cvloss' : 'MSE' (default) or 'Gaussian'
%       loss comparing the shrinkage estimate computed from the training
%       folds with the SCM Sv of the validation fold. 'MSE' uses the
%       Frobenius norm of the difference; 'Gaussian' uses the Gaussian
%       negative log-likelihood tr(Sigma^{-1} Sv) + log det(Sigma).
%
%==========================================================================
%
% OUTPUT
% ------
%   Sest    : the shrinkage SCM equal to beta0*S + (1-beta0)*[trace(S)/p]*I
%   beta0   : the found optimal MMSE/CV shrinkage parameter beta in [0,1]
%   S       : the SCM X'*X/n
%   stats   : a structure with fields
%       .eta      : trace(S)/p
%       .approach : the approach used (lower case)
%     if 'approach' is 'ell1' or 'ell2' then also the fields
%       .kapest   : elliptical kurtosis kappa
%       .gamest   : estimate of sphericity
%     if 'approach' is 'cv' then also the fields
%       .cvloss   : loss used in cross-validation
%       .folds    : value K used for K-fold cross-validation
%       .betas    : grid of beta values used in cross-validation
%       .ind      : index of the selected beta0 in .betas
%
% REFERENCE
% ---------
%  [1] E. Ollila, D.P. Palomar, and F. Pascal, "M-estimators of scatter with
%                   eigenvalue shrinkage", Arxiv, 2020.
%
% Author: Esa Ollila, May 2020, Aalto University.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%% Parse inputs
argin = inputParser;
[n, p] = size(X);

valX = @(X) assert(numel(size (X))==2 && size(X,2) >= 2, ...
        '''X'' must be a matrix having at least two columns (variables)');
addRequired(argin,'X',valX);

if any (any (isnan (X)))
    error ('Input data contains NaN''s.');
end

if ~isa (X, 'double')
    X = double(X);
end

realdata = isreal(X);

%-- approach must be one of the supported methods (matched case-insensitively)
addParameter(argin,'approach','ell1', ...
    @(x) logical(sum(strcmpi(x,{'ell1','ell2','cv','loocv'}))));

%-- optional input parsing rules for 'ell1' / 'ell2'

% Theoretical lower bound for the elliptical kurtosis kappa:
% -2/(p+2) for real-valued data and -2/(2p+2) = -1/(p+1) for complex data.
if realdata
    kappa_lowerb = -2/(p+2);
else
    kappa_lowerb = -2/(2*p+2);
end

valKappa = @(x) assert(isscalar(x) && x > kappa_lowerb && isreal(x), ...
    ['''kappa'' must be a real-valued scalar and larger than the ' ...
    'lowerbound ', num2str(kappa_lowerb)]);
addParameter(argin,'kappa',[], valKappa);

% Sphericity takes values in [1,p]; isscalar guards the && below
valGamma = @(x) assert( isempty(x) || ...
    (isscalar(x) && isreal(x) && x >= 1 && x <= p), ...
    ['''gamma'' must be empty set or real-valued scalar in the range [1,' num2str(p) ']']);
addParameter(argin,'gamma',[], valGamma);

%-- optional input parsing rules for 'cv'

valBetas = @(x) assert( all(x>=0) && all(x<=1) && isreal(x) && isvector(x), ...
        '''betas'' must be a vector with elements in [0,1]');
addParameter(argin,'betas',[], valBetas);

valfolds = @(x) assert(isscalar(x) && isreal(x) && x > 0 && (x-round(x)==0), ...
    '''folds'' must be a positive integer');
addParameter(argin,'folds',5, valfolds);

% cvloss can be Gaussian negative log-likelihood or MSE
addParameter(argin,'cvloss','MSE', ...
    @(x) logical(sum(strcmpi(x,{'MSE','Gaussian'}))));

%---  parse inputs
parse(argin,X,varargin{:});
compute_gamma = isempty(argin.Results.gamma);
compute_kappa = isempty(argin.Results.kappa);
folds = argin.Results.folds;
betas = argin.Results.betas;
% Validation is case-insensitive, so normalize before the case-sensitive switch
approach = lower(char(argin.Results.approach));
cvloss = argin.Results.cvloss;

%% CODE starts

%-- Compute the SCM
% note: it is assumed that the data is centered
S = X'*X/n;
% real(): the trace of a Hermitian matrix is real; drop round-off imag. parts
eta = real(trace(S))/p;

%-- Now compute the optimal shrinkage parameter value
switch approach
    case {'ell1','ell2'}

        is_centered = true; % assumption is that data is centered or zero mean

        %--  elliptical kurtosis \kappa
        if compute_kappa
            print_info = false;
            kapest = ellkurt(X,[],is_centered,print_info);
        else
            kapest = argin.Results.kappa;
        end

        %-- sphericity \gamma
        if compute_gamma
            if strcmp(approach,'ell1')
                gamest = gamell1(X,is_centered);
            else
                gamest0 = (real(trace(S^2))/p)/eta^2;
                gamest = gamell2M(n,p,1+kapest,gamest0,realdata);
            end
        else
            gamest = argin.Results.gamma;
        end
        T = gamest - 1;

        %-- data-adaptive estimate of the shrinkage parameter \beta_0
        if realdata
            a = kapest*(2*gamest*(1-1/p) + p-1) + gamest*(1-2/p);
        else
            a = kapest*(gamest*(1-1/p)+p-1) - gamest/p;
        end
        beta0 = T/( T + (a+p)/n );

        beta0 = min(max(0,beta0),1); % assure that beta is between 0 and 1
        stats.gamest = gamest;
        stats.kapest = kapest;

    case 'cv'

        %---- use K-fold cross-validation
        if isempty(betas)
            % default grid on [0,1] with 0.025 steps
            betas = 0:0.025:1;
        end

        CVO = cvpartition(n,'KFold',folds);
        err = zeros(length(betas),folds);
        print_info = false;

        for i = 1:CVO.NumTestSets

            % SCM of the training folds
            Xtr = X(CVO.training(i),:);
            Str = Xtr'*Xtr/size(Xtr,1);

            % SCM of the validation fold
            Xv = X(CVO.test(i),:);
            Sv = (Xv'*Xv)./size(Xv,1);

            if strcmpi(cvloss,'Gaussian')
                % Evaluate tr(Sigma^{-1} Sv) + log det(Sigma) in the
                % eigenbasis of Str. etaTr is local to this fold; it must not
                % overwrite the full-data eta used for Sest below.
                [P,D,~] = svd(Str);
                d = diag(D);
                etaTr = mean(d);
                sv = real(diag(P'*Sv*P));

                lik_beta = zeros(length(betas),1);
                for j = 1:length(betas)
                    be = betas(j);
                    dbe = be*d + (1-be)*etaTr;
                    logd = log(dbe);
                    lik_beta(j) = sum(sv.*exp(-logd)) + sum(logd);
                end
                err(:,i) = lik_beta;

            else
                % Frobenius-norm error against the validation SCM
                etaStr = real(trace(Str))/p;

                errNMSE = zeros(length(betas),1);
                for j = 1:length(betas)
                    be = betas(j);
                    Str_be = be*Str + (1-be)*etaStr*eye(p);
                    errNMSE(j) = norm(Str_be - Sv,'fro');
                end
                err(:,i) = errNMSE;
            end

            if print_info
                fprintf('  fold %d done\n',i);
            end
        end

        [~,ind] = min(mean(err,2));
        beta0 = betas(ind);
        stats.cvloss = cvloss;
        stats.betas = betas;
        stats.folds = folds;
        stats.ind = ind;

    otherwise
        %---- 'loocv': closed-form leave-one-out cross-validation
        % ||x_i||^2 needs abs() so that it is also correct for complex data
        normX2 = sum(abs(X).^2,2);
        ave_normX4 = mean(normX2.^2);
        trS2 = real(trace(S^2));
        trS = real(trace(S));
        tmp2 = (n^2-2*n)/((n-1)^2);
        tmp1 = n/(n-1);
        num = tmp1*trS2 - trS^2/n - (1/(n-1))*ave_normX4;
        den = tmp2*trS2 - trS^2/n + (1/(n-1))*ave_normX4;
        % keep beta in [0,1] as in the ell1/ell2 branch
        beta0 = min(max(0,num/den),1);
end

%-- shrinkage RSCM estimator:
Sest = beta0*S + (1-beta0)*eta*eye(p);
stats.eta = eta;
stats.approach = approach;

end

