function [yhat,B,I,K,cvstats,covstats] =  CRDA(Xt,X,y,varargin)
% CRDA performs compressive regularized (linear) discriminant analysis,
% referred to as CRDA, proposed in Tabassum and Ollila (2019) (see also
% Tabassum and Ollila (2018) for preliminary results).
%
% CRDA classifies each column of the test data set Xt (p x N) into one of
% the G classes. Test data set Xt and training data set X must have the
% same number of rows (features or variables). Vector y is the class
% variable of the training data. Its unique values define the classes; each
% element defines the class to which the corresponding column of X belongs.
% The input y is a numeric column vector with integer elements in
% 1,2,...,G, where G is the number of classes. Note that y must have the
% same number of rows as there are columns in X. The output yhat indicates
% the class to which each column of Xt has been assigned; yhat is an N x 1
% vector of integers in 1,2,...,G.
%
% By default, the CRDA function uses the CRDA2 method, which uses the
% Ell2-RSCM estimator as the estimator of the covariance matrix, and cross
% validation (CV) to select the optimal joint sparsity level K and the
% hard-thresholding selector function.
%
% The grid of sparsity levels K used in the range [1,p] is determined
% automatically if not given as an optional parameter.
%
% For more details, we refer to Ollila and Tabassum (2019).
%
% USAGE:
% ------
% rng(iter); % for reproducibility (CV folds are drawn at random)
% yhat = CRDA(Xt,X,y)
% yhat = CRDA(Xt,X,y,Name,Value)
% [yhat,B,I,K,cvstats,covstats] =  CRDA(___)
%
% EXAMPLES
% --------
% To call the CRDA2 method you can use any of the following examples:
%
% mc = 0;
% rng(mc); % for reproducibility (due to randomness of cross validation)
% yhat1 = CRDA(Xt,X,y,'verbose','on');
% rng(mc); % for reproducibility (due to randomness of cross validation)
% yhat2 = CRDA(Xt,X,y,'method','crda2','verbose','on');
% rng(mc); % for reproducibility (due to randomness of cross validation)
% yhat3 = CRDA(Xt,X,y,'q','cv','cov','ell2','K','cv','verbose','on');
% isequal(yhat1,yhat2,yhat3)
%
% Name-Value Pair Arguments
% -------------------------
% CRDA can be called with numerous optional arguments. Optional
% arguments are given in parameter pairs, so that the first argument is
% the name of the parameter and the next argument is the value for
% that parameter. Optional parameter pairs can be given in any order.
%
% Name      Value and description
%==========================================================================
% 'method'  (string) 'crda1', 'crda2' (default) or 'crda3'
%           Named preset for the options 'cov', 'q' and 'K':
%           'crda1' : cov = 'ell1',    q = 'cv',  K = 'cv'
%           'crda2' : cov = 'ell2',    q = 'cv',  K = 'cv'
%           'crda3' : cov = 'riemann', q = 'inf', K = []
%           When 'method' is passed explicitly, the preset overrides any
%           'cov', 'q' and 'K' given in the same call. When 'method' is
%           not passed, the 'crda2' preset only fills in those of 'cov',
%           'q' and 'K' that were not passed explicitly.
%==========================================================================
% 'cov' specifies which covariance matrix estimator is to be
% used in the LDA discriminant rule. Three estimators are available.
%
% 'cov'     (string) which estimate to use
%           'ell2' (default) use Ell2-RSCM estimator as detailed in
%                            Ollila and Raninen (2019).
%           'ell1'           use Ell1-RSCM as detailed in Ollila and
%                            Raninen (2019)
%           'riemann'        use Rie-PSCM estimator with Riemannian penalty
%                            with shrinkage towards the mean of the
%                            eigenvalues of the sample covariance matrix.
%                            The estimator is computed using the
%                            Newton-Raphson algorithm detailed in Tyler
%                            and Xi (2019).
%==========================================================================
%
% 'q'        scalar (>=1) or string 'var' or 'inf' or 'cv'
%            If q is a real scalar, then it must be >= 1 and it
%            denotes the L_q-norm to be used in the hard
%            thresholding operator H_K(B,phi) in the CRDA method.
%            If q is equal to the string 'inf' (or 'var'), then the
%            L_infty-norm (or the sample variance) is used as the
%            hard-thresholding selector function. If q is equal to
%            the string 'cv', then CV is used to select the best
%            hard-thresholding selector function among the L_1-, L_2-,
%            L_inf-norm and the sample variance.
%
% 'K'       string 'cv', a scalar in the range [1,p], or []
%           'cv' selects the joint sparsity level by cross-validation
%           over 'kgrid'. Any other value is passed on unchanged to CRDA0
%           as its sparsity argument (a fixed level K, or [] to let CRDA0
%           choose K by its own rule; see CRDA0).
%
% 'prior'   numeric vector of length G
%           specifies prior values used in the discriminant rule. The
%           elements should sum to 1 (i.e., sum(prior)==1). Default is
%           uniform priors.
%
% 'kgrid'   vector of integer values in the range [1,2,...p]
%           each element specifies a joint sparsity level K that is used in
%           the hard-thresholding operator H_K(B,phi). The function uses
%           cross-validation (CV) to pick the optimal sparsity level from
%           this grid. If the value is not given, then a grid of 10
%           values, uniform in log-space on [0.05 x p, K_ub], is used,
%           where K_ub is the number of rows of B whose length exceeds
%           the mean row length (the smallest such count over the four
%           selector functions when q is chosen by CV).
%
% 'nfolds'  positive integer (default 5)
%           specifies the number of folds to be used in the CV scheme of
%           the joint sparsity values K in the kgrid.
%
% 'coefmat' real matrix of size p x G
%           specifies the coefficient matrix B computed from the training
%           data set X. This value should be equal to B = Sigma^-1 * mu,
%           where Sigma is the covariance matrix estimator based on the
%           training dataset X and mu is the p x G matrix of sample mean
%           vectors.  Note: coefmat is specified only if you have computed
%           the desired covariance matrix Sigma and the related coefficient
%           matrix B based on it.  Default value is [] which implies that
%           B is computed using the Ell1-RSCM, Ell2-RSCM estimator or the
%           Rie-PSCM estimator depending on the value of the optional
%           parameter 'cov'.
%
% 'mu'      real matrix of size p x G
%           matrix with class sample mean vectors as columns. It is only
%           used together with 'coefmat'; otherwise the means returned by
%           CRDA_COEFMAT are used.
%
% 'verbose' string equal to 'on' or 'off' (default).
%           When 'on' then one prints progress of the function in
%           text format.
%
% Output Arguments
% ----------------
% yhat     vector of size N x 1 of integers from 1, ..., G.
%          specifies the group to which each column of Xt has been
%          assigned
%
% B        matrix of size p x G
%          coefficient matrix calculated by the function (if not given by
%          the optional argument 'coefmat')
%
% I        indices of the row vectors of B as returned by CRDA0, i.e.,
%          ordered by their length, where the length is determined by the
%          optional argument 'q' (if q is scalar, q >=1, then the length
%          is determined by the L_q-norm)
%
% K        the sparsity level used: the best value in the grid (= K_CV)
%          found using CV, or the value returned by CRDA0 when CV is not
%          used
%
% cvstats  struct that contains data of the CV ([] when CV is not used)
%          with fields:
%   .nfolds     # of folds used in the CV scheme
%   .cverr      CV error rates averaged over the folds, one row for each K
%               in cvstats.kgrid (and one column per selector function in
%               the order of qvals = [inf,0,2,1] when q is chosen by CV)
%   .kgrid      the used grid of sparsity levels
%   .idxK0      index of the chosen sparsity level in cvstats.kgrid
%   .q          the q-value used (inf, 0 = sample variance, 2, 1, or the
%               fixed q given by the user)
%   .idxq0      index of the used q in the list qvals = [inf,0,2,1]
%               (empty if a fixed q outside this list was given)
%
% covstats third output of CRDA_COEFMAT ([] when 'coefmat' is given)
%
% -------------------------------------------------------------------------
% See also: CRDA0, CRDA_COEFMAT, HARD_THRESHOLD
%
% DEPENDENCIES:
% -------------
% Install the toolbox regularizedSCM needed for computation of Ell2-RSCM
% and Ell1-RSCM covariance matrix estimator:
%
% http://users.spa.aalto.fi/esollila/regscm/
%
% R-version available at:
%       https://github.com/mntabassm/compressiveRDA
%
% REFERENCES:
% ------------
% If you use this code in your research, then please cite:
%
% [1] M.N. Tabassum and E. Ollila,"A Compressive Classification Framework
%       for High-Dimensional Data," preprint, submitted for publication,
%       Oct. 2019.
%
% AUTHORS
% -------
% Esa Ollila and Muhammad Naveed Tabassum, Aalto University, October 2019.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if nargin < 3
    error (['You must supply the test data Xt, training data matrix X and' ...
        ' the corresponding classes y  as input argument.']);
end

%-- Check requirements on the input data matrix X
argin = inputParser;
valX = @(X) assert(isreal(X) && ismatrix(X),['''X'' must be a ' ...
    'real-valued matrix having at least two columns (variables)']);
addRequired(argin,'X',valX);

if any(isnan(X(:)))
    error ('Input data contains NaN''s.');
end

if ~isa (X, 'double')
    X = double(X);
end

if ~iscolumn(y) || numel(y) ~= size(X,2)
    error('''X'' must be an p x n matrix and ''y'' a n x 1 vector');
end

if ~isnumeric(y) || ~isreal(y) || any(y < 1) || any(y ~= round(y))
    error('''y'' must contain integer class labels ranging from 1 to G');
end

p = size(X,1);
G = max(y);                         % labels are 1,...,G
% Count the class sizes explicitly so that the result always has exactly
% G entries and does not depend on automatic histogram binning.
n = accumarray(y(:),1,[G 1]).';     % 1 x G class label counts

%-- optional input parsing rules
valmethod = @(x) assert(any(strcmpi(x,{'crda1','crda2','crda3'})), ...
    ('''method'' must be string equal to ''crda1'',''crda2'',''crda3'''));
addParameter(argin,'method','crda2',valmethod);

valK = @(x) assert(isempty(x) ...
    || (isnumeric(x) && isscalar(x) && x >= 1 && x <= p) ...
    || ((ischar(x) || isstring(x)) && any(strcmpi(x,'cv'))), ...
    ['''K'' must be a string equal to ''CV''  or an integer between ' ...
    '1 to %d or an empty string'],p);
addParameter(argin,'K',[], valK);

valcov = @(x) assert(any(strcmpi(x,{'ell1','ell2','riemann'})), ['''cov'' ' ...
    'must be string equal to ''ell2'' or ''ell1'' or ''riemann''']);
addParameter(argin,'cov','ell2',valcov);

valq = @(x) assert( (isnumeric(x) && isscalar(x) && x >= 1) ...
    || any(strcmpi(x,{'inf','var','cv'})), ...
    ['''q'' must be a string equal to ''var'' or ''inf'' or ''cv'' or a ' ...
    'real-valued scalar greater or equal to 1']);
addParameter(argin,'q','cv', valq);

valprior = @(x) assert( isreal(x) && (numel(x) == G) && isvector(x) && (abs(sum(x)-1)<=1e-10), ...
    '''prior'' must be a vector with G elements that sum to 1');
addParameter(argin,'prior',1/G * ones(1,G), valprior);

valkgrid = @(x) assert( isreal(x) && isvector(x) && isequal(round(x),x) && all(x>0), ...
    '''kgrid'' must be a vector with positive integer elements');
addParameter(argin,'kgrid',[], valkgrid);

valnfolds = @(x) assert( isscalar(x) && isreal(x) && x > 0 && isequal(round(x),x), ...
    '''nfolds'' must be a positive integer scalar');
addParameter(argin,'nfolds',5,valnfolds);

valcoefmu = @(x) assert( ismatrix(x) && size(x,2)==G && size(x,1)==p, ...
    'the input must be a real matrix of size p x G');
addParameter(argin,'coefmat',[],valcoefmu);
addParameter(argin,'mu',[],valcoefmu);

valVerbose = @(x) assert(any(strcmpi(x,{'on','off'})), ...
    '''verbose'' must be a string equal to ''on'' or ''off''');
addParameter(argin,'verbose','off',valVerbose);

%---  parse inputs
parse(argin,X,varargin{:});
q       = argin.Results.q;
covtype = argin.Results.cov;
prior   = argin.Results.prior;
nfolds  = argin.Results.nfolds;
B       = argin.Results.coefmat;
mu      = argin.Results.mu;
kgrid   = argin.Results.kgrid;
print_info = strcmpi(argin.Results.verbose,'on');
Kmethod = argin.Results.K;
method  = argin.Results.method;

if isempty(which('ellkurt'))
    % error() converts '\n' only when it gets more than one input argument
    error('%s\n%s','CRDAcv:: you need to install RegularizedSCM toolbox', ...
        'you may download it from http://users.spa.aalto.fi/esollila/regscm/');
end

%-- encode q to numeric: 0 = sample variance, inf = L_inf-norm, -1 = use CV
if ischar(q) || isstring(q)
    switch lower(char(q))
        case 'var'
            q = 0;
        case 'inf'
            q = inf;
        case 'cv'
            q = -1;
    end
end

%-- presets of the named method
switch lower(char(method))
    case 'crda1'
        presetCov = 'ell1';
        presetQ   = -1;
        presetK   = 'CV';
    case 'crda2'
        presetCov = 'ell2';
        presetQ   = -1;
        presetK   = 'CV';
    case 'crda3'
        presetCov = 'riemann';
        presetQ   = inf;
        presetK   = [];
end

% An explicitly requested method overrides 'cov', 'q' and 'K'. Otherwise
% the (default) preset only fills in the options the caller did not give,
% so that explicitly passed 'cov', 'q' and 'K' values are honoured.
notGiven    = @(name) any(strcmp(name,argin.UsingDefaults));
methodGiven = ~notGiven('method');
if methodGiven || notGiven('cov')
    covtype = presetCov;
end
if methodGiven || notGiven('q')
    q = presetQ;
end
if methodGiven || notGiven('K')
    Kmethod = presetK;
end

if print_info

    %-- print information about data
    fprintf('Number of variables           : %d\n', p);
    fprintf('Number of samples             : %d\n', sum(n));
    fprintf('Number of folds for CV        : %d\n', nfolds);

    if q >= 1
        fprintf('using L_{%.1f}-norm in the hard-thresholding operator\n',q);
    elseif q == 0
        fprintf('using sample variance in the hard-thresholding operator\n');
    else
        fprintf('using CV to pick the hard-thresholding operator\n');
    end

end

if isempty(B)
    if print_info
        fprintf('...computing the coefficient matrix using %s-estimator.\n',covtype);
    end
    [B, mu,covstats] = CRDA_coefmat(X,y,covtype);
else
    covstats = [];
end

if isempty(mu)
    % need to calculate the class means
    E = eye(G);
    Y = E(y, :);	% nxG indicator matrix => Yij = 1, i belong to group j
    mu = X*Y ./ n;
end

qvals = [inf,0,2,1];
% Encoding : inf = L_infty, 0 = var, 2 = L_2-norm, 1=L_1 norm
nq = numel(qvals);

if (ischar(Kmethod) || isstring(Kmethod)) && any(strcmpi(Kmethod,'CV'))

    if isempty(kgrid)
        % Kgrid not given, then compute it using K_up computed from B
        nK = 10; % default is 10 values in the grid
        fromK = 0.05*p;

        % lengths of the rows of B under the selector function(s) in use
        switch q
            case 0
                len = var(B,[],2);         % Variance (rows as vectors)
            case inf
                len = max(abs(B), [],2);   % Linf-norm (rows as vectors)
            case -1
                % use CV: one column per selector, in the order of qvals
                len = [max(abs(B),[],2), var(B,[],2), ...
                    sum(abs(B).^2, 2).^(1/2), sum(abs(B),2)];
            otherwise
                len = sum(abs(B).^q, 2).^(1/q);	% Lq-norm, q>=1
        end
        % upper end of the grid: number of rows longer than the mean row
        % length (smallest count over the selectors), but at least 1 so
        % that log(Kup) is finite
        Kup = max(1, min(sum(len > mean(len,1), 1)));
        % a sparsity level must be >= 1 (0.05*p rounds to 0 for p < 10)
        kgrid = max(1, round(exp(linspace(log(fromK),log(Kup),nK))));
    end

    if print_info
        fprintf(['...using K-grid of %d points on the interval '...
            '[%d,%d]\n'],length(kgrid),kgrid(1),kgrid(end));
    end

    nK = length(kgrid);
    if q == -1
        cverr = zeros(nK,nq);
    else
        cverr = zeros(nK,1);
    end
    indices = crossvalind('Kfold',y,nfolds);

    %--- compute the crossvalidation errors on the grid
    for ii = 1:nfolds

        test = (indices == ii);
        train = ~test;
        Xii = X(:,train); % used for training
        yii = y(train);
        Xho = X(:,test);  % hold-out (validation) data set
        yho = y(test);

        E = eye(G);
        Yii = E(yii, :);	% nxG indicator matrix => Yij = 1, i belong to group j
        nii  = accumarray(yii(:),1,[G 1]).';	% 1 x G class label counts
        muii = Xii*Yii ./ nii;

        % Center the training data
        for g = 1:G
            Xii(:,yii==g) = Xii(:,yii==g) - repmat(muii(:,g),1,nii(g));
        end

        if q == -1
            yhat = zeros(size(Xho,2),nq);
            Ind = zeros(p,nq);
        end

        % Sweep through the grid
        for jj = 1:nK

            Kjj = kgrid(jj);
            if jj == 1
                % first grid point: compute the coefficient matrix Bii of
                % this fold and the row ordering(s) Ind once
                if q == -1
                    [yhat(:,1),Bii,Ind(:,1)] = CRDA0(Xho,Xii,yii,covtype,Kjj,...
                        qvals(1),prior,[],muii,true);
                    for c = 2:nq
                        [yhat(:,c),~,Ind(:,c)] = CRDA0(Xho,Xii,yii,covtype,Kjj,...
                            qvals(c),prior,Bii,muii,true);
                    end
                else
                    [yhat,Bii,Ind] = CRDA0(Xho,Xii,yii,covtype,Kjj,...
                        q,prior,[],muii,true);
                end

            else
                % remaining grid points: reuse Bii and the row ordering(s)
                if q == -1
                    for c = 1:nq
                        yhat(:,c) = CRDA0(Xho,Xii,yii,covtype,Kjj,qvals(c),...
                            prior,Bii,muii,true,Ind(:,c));
                    end
                else
                    yhat = CRDA0(Xho,Xii,yii,covtype,Kjj,q,prior,Bii, ...
                        muii,true,Ind);
                end
            end

            % sum over the hold-out samples (dimension 1) so that a
            % single-sample fold is not summed across the selectors
            cverr(jj,:) = cverr(jj,:) + sum(yhat ~= yho,1)/length(yho);
        end
    end

    %-- Choose the best among the grid values
    cverr = cverr/nfolds;
    [idxK,idxq] = find(cverr <= min(cverr(:)));
    idxK0 = min(idxK);
    if q == -1
        % if there are many q values that had small CV err, then
        % determine the best as the one having small mean CV err
        bst_indx = idxq(idxK==idxK0);
        if numel(bst_indx) > 1
            [~,pick_indx] = min(mean(cverr(:,bst_indx),1));
            idxq0 = bst_indx(pick_indx);
        else
            idxq0 = bst_indx;
        end
        q0 = qvals(idxq0);
    else
        % the selector function was fixed by the caller
        q0 = q;
        idxq0 = find(qvals == q0,1);
    end
    K = kgrid(idxK0);

    if print_info
        fprintf('...chosen sparsity level : %d\n',K);
    end

    %-- now we can use CRDA for the best value
    [yhat,~,I] = CRDA0(Xt,X,y,covtype,K,q0,prior,B,mu);

    cvstats.nfolds = nfolds;
    cvstats.cverr = cverr;
    cvstats.kgrid = kgrid;
    cvstats.idxK0  = idxK0;
    cvstats.q = q0;
    cvstats.idxq0  = idxq0;
else
    % Use value for K (either mean value or given scalar in [1,p]
    [yhat,B,I,K] = CRDA0(Xt,X,y,covtype,Kmethod,q,prior,B,mu);
    cvstats = [];
end

