function [nuhat,nu_init,stats,invC] = estimate_nu(X,MAX_iter,varargin)
% ESTIMATE_NU computes an estimate of the degrees of freedom (d.o.f.)
% parameter nu of the multivariate t (MVT) distribution using
% Algorithm 1 described in Ollila et al. 2020 [1].
%
% The data are assumed to be centered: the sample covariance matrix (SCM)
% is computed as X'*X/n and the elliptical kurtosis estimate is requested
% with the "centered" flag set.
%
% INPUT
% -----
% X       : the data matrix with n rows (observations) and p columns
%           (variables). Must satisfy n >= p >= 2 and contain no NaN's.
% MAX_iter: maximum number of iterations (positive integer).
%           Default is 2 (used when the argument is omitted or empty).
%
% Name-Value Pair Arguments
% -------------------------
% ESTIMATE_NU can be called with numerous optional arguments. Optional
% arguments are given in parameter pairs, so that the first argument is
% the name of the parameter and the next argument is the value for
% that parameter. Optional parameter pairs can be given in any order.
%
% Name      Value and description
%==========================================================================
% Optional inputs:
%
% 'nu_init' : initial value for nu. A real number > 2. If not given
%       ([]), it is computed from the elliptical kurtosis estimate
%       kappahat as 2/kappahat + 4; a computed value that is not > 2
%       is replaced by 100.
%
% 'etaS' : positive real number equal to the scale of the sample
%       covariance matrix (trace(S)/p). If not given ([]), it is
%       computed from X.
%
% 'invC' : [] (default) or a pos. def. p x p matrix
%       corresponding to the inverse scatter matrix used to start the
%       fixed-point (FP) iterations. If not specified ([]), the algorithm
%       uses the inverse of the SCM as the initial start for the FP
%       algorithm.
%
% OUTPUT:
% -------
% nuhat   : estimate of the d.o.f. parameter
% nu_init : the initial value of nu that was used (the user-supplied value,
%           or the estimate based on elliptical kurtosis).
% stats   : a structure array with fields
%       .nu0  : initial estimate of nu (same value as nu_init)
%       .etaS : scale of the SCM, i.e., trace(S)/p (or the supplied 'etaS')
%       .iter : index of the last iteration that was performed
%
% invC    : the inverse of the MVT estimator at the last iteration
%
% REFERENCE:
% ----------
% [1] E. Ollila, D.P. Palomar, and F. Pascal, "Shrinking the eigenvalues of
%       M-estimators of covariance matrix", Arxiv, 2020.
%
% Author: Esa Ollila, May 2020, Aalto University.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%% Apply the documented default for MAX_iter
if nargin < 2 || isempty(MAX_iter)
    MAX_iter = 2;
end

argin = inputParser;
[n, p] = size(X);

%% Required arguments:
valX = @(X) assert(numel(size(X))==2 && size(X,1) >= size(X,2) && size(X,2) >= 2, ...
        ['''X'' must be a matrix having at least two columns (variables) ' ...
         'and at least as many rows (observations) as columns']);
addRequired(argin,'X',valX);

% isscalar is tested first so that a non-scalar value yields the message
% below instead of an error raised by the short-circuit operators.
valITER= @(x) assert(isscalar(x) && x > 0 && isreal(x) && (x-round(x)==0), ...
    '''MAX_ITER'' must be a positive integer');
addRequired(argin,'MAX_ITER',valITER);

% X(:) makes the test a scalar condition whatever the dimensions of X are
% (X has not been validated by the parser at this point).
if any(isnan(X(:)))
    error ('Input data contains NaN''s.');
end

if ~isa(X,'double')
    X = double(X);
end


%%  optional arguments:

valnu = @(x) assert(isempty(x) || (isscalar(x) && x > 2),'''nu'' must be positive scalar > 2');
addParameter(argin,'nu_init',[], valnu);
%
valEtaS = @(x) assert(isempty(x) || (isscalar(x) && x > 0), ...
    '''scale'' must be positive scalar or empty set');
addParameter(argin,'etaS',[], valEtaS);
%
valinvC = @(x) assert( isempty(x) || ( ismatrix(x) && size(x,2)==size(x,1) && size(x,1)==p), ...
    'the input invC must be a empty set or matrix of size p x p');
addParameter(argin,'invC',[],valinvC);

%%  parse inputs
parse(argin,X,MAX_iter,varargin{:});
invC  = argin.Results.invC;
etaS   = argin.Results.etaS;
nu_init = argin.Results.nu_init;

%% code starts ...
S = [];   % SCM; computed at most once, and only if it is needed
if  isempty(invC)

    S  = X'*X/n ; % the sample covariance matrix (SCM), data assumed centered
    invC = S \ eye(p);

end

if isempty(etaS)

    if isempty(S)
        S  = X'*X/n ; % the sample covariance matrix (SCM)
    end
    etaS = trace(S)/p;
end


%-- Compute the initial value if not given
if  isempty(nu_init)

    is_centered = true; % we assume that the data is centered
    kappahat = ellkurt(X,[],is_centered);
    nu_init = 2/kappahat+4;

    % If kappahat < 0, then nu_init can be negative or otherwise fall
    % outside the admissible range nu > 2 (the same range that is enforced
    % for a user-supplied 'nu_init'). Replace such a value by a large
    % number (nu=100).
    if nu_init <= 2
        nu_init = 100;
    end
end

nu0 = nu_init;
stats.nu0 = nu_init;
printitn = 0;   % print progress every printitn iterations; 0 = no printing

for iter=1:MAX_iter

    %-- Compute the MVT-estimator using the estimated dof parameter:
    [C,invC] = MVT(X,'nu',nu0,'invC',invC);
    etaC = trace(C)/p;

    % Update nu from the ratio of the scales; the ratio is clamped to be
    % at least 1.
    % NOTE(review): when the ratio is clamped to exactly 1, nu1 evaluates
    % to Inf and is passed to MVT on the next iteration -- confirm that
    % MVT handles nu = Inf as intended.
    eta_ratio = max(etaS/etaC,1);
    nu1 = 2*eta_ratio/(eta_ratio-1);

    % relative change of nu between two successive iterations
    d = abs(nu1-nu0)/abs(nu0);

    if printitn > 0 && mod(iter,printitn)==0
        fprintf('At iter = %2d, dis=%.6f  v=%.6f\n',iter,d,nu1);
    end

    % stop when the relative change is below 1%
    if  d< 1.0e-2
        break
    end

    nu0 = nu1;

end

nuhat = nu1;
stats.etaS = etaS;
stats.iter = iter;
 