function [C1,invC1,nu,b,iter] = MVT(X,varargin)
% MVT computes the M-estimator of scatter using the t-distribution weight
% function and the fixed-point (FP) algorithm.
%
% If the data follows a multivariate t-distribution (MVT) with the
% correctly specified degrees of freedom (d.o.f.), then this function
% gives the maximum likelihood estimate (MLE) of the scatter parameter.
%
% Data is assumed to be centered (i.e., the symmetry center parameter = 0).
% Both real- and complex-valued data are accepted.
%
% USAGE
% -----
% C = MVT(X);
% [C,invC] = MVT(X,'nu',4,'scaling','covariance','printitn',1);
% [C1,invC1,nu,b,iter] = MVT(X,___);
%
% INPUT
% -----
%   X       : the data matrix with n rows (observations) and p columns
%             (variables). Requires p >= 2 and n >= p; NaN's are not allowed.
%
% Name-Value Pair Arguments
% -------------------------
% MVT can be called with numerous optional arguments. Optional
% arguments are given in parameter pairs, so that the first argument is
% the name of the parameter and the next argument is the value for
% that parameter. Optional parameter pairs can be given in any order.
%
% Name      Value and description
%==========================================================================
%  Optional inputs:
%
%   'nu' : [] (default) or a positive real scalar (Inf is allowed)
%       corresponds to the d.o.f. parameter of the t-distribution
%       - []  implies that nu is estimated from the data using Algorithm 1 [1]
%       - nu = Inf implies using the Gaussian weight u(t) = 1, in which case
%         the function returns the sample covariance matrix (SCM)
%       - 0 < nu < Inf implies using a t-weight function u(t;nu)
%
%   'invC' : [] (default) or a pos. def. p x p matrix
%       corresponds to the inverse scatter matrix used to start the FP
%       iterations. If not specified ([]), the algorithm uses the inverse
%       of the SCM as the initial start for the FP algorithm.
%
%  'scaling' : 'none' (default), 'gaussian', 'covariance'
%       -'none' implies no scaling, so using the standard MVT weight
%           function u(t;v) = (p+v)/(v+t)
%       -'covariance' implies using the weight function corresponding to the
%           MVT distribution where the scatter matrix coincides with the
%           covariance matrix, so u(t;v) = (p+v)/(v-2+t). Requires nu > 2;
%           otherwise a warning is issued and the standard weight is used.
%       -'gaussian' implies using an MVT weight function u(t;nu) that is
%           scaled so that the estimator is consistent to the covariance
%           matrix for Gaussian data
%
% 'iter_nu' : positive integer (default is 2)
%       maximum number of iterations used to estimate the nu parameter.
%
% 'printitn' : non-negative integer (default is 0)
%       Print iteration convergence. If 1, then print each iteration, if 2,
%       then every 2nd iteration, etc. 0 disables printing.
%
% OUTPUT
% ------
%   C1      : M-estimator of scatter for t-weight u(t;nu)
%   invC1   : the inverse matrix of C1
%   nu      : the d.o.f. parameter used in the weight function u(t;nu)
%   b       : scaling constant used (b = 1 unless scaling = 'gaussian' and
%             nu is finite)
%   iter    : number of FP iterations (0 when nu = Inf)
%
% If the FP algorithm does not converge within 2000 iterations, a warning
% with identifier 'MVT:slowConvergence' is issued and the last iterate is
% returned.
%
% REFERENCE
% ---------
% [1] E. Ollila, D.P. Palomar, and F. Pascal, "Shrinking the eigenvalues of
%  M-estimators of covariance matrix", Arxiv, 2020.
%
% Author: Esa Ollila, May 2020, Aalto University.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

argin = inputParser;
[n, p] = size(X);

valX = @(X) assert(numel(size(X))==2 && size(X,1) >= size(X,2) && size(X,2) >= 2, ...
        '''X'' must be a matrix having at least two columns (variables) and at least as many rows as columns');
addRequired(argin,'X',valX);

if any(any(isnan(X)))
    error ('Input data contains NaN''s.');
end

if ~isa(X,'double')
    X = double(X);
end

realdata = isreal(X);

%-- optional input parsing rules
valnu = @(x) assert(isempty(x) || (isscalar(x) && x > 0),'''nu'' must be positive scalar');
addParameter(argin,'nu',[], valnu);

valinvC = @(x) assert( isempty(x) || ( ismatrix(x) && size(x,2)==size(x,1) && size(x,1)==p), ...
    'the input invC must be a empty set or matrix of size p x p');
addParameter(argin,'invC',[],valinvC);

valscaling = @(x) assert(any(strcmpi(x,{'gaussian','covariance','none'})), ...
    'must be string equal to ''gaussian'' or ''covariance'' or ''none''');
addParameter(argin,'scaling','none',valscaling);

valprintitn = @(x) assert( isscalar(x) && x >= 0 && (x==round(x)),  ...
    '''printitn'' must be a non-negative integer');
addParameter(argin,'printitn',0, valprintitn);

valiter = @(x) assert(isempty(x) || ( (isscalar(x) && x > 0) && round(x)==x) , '''iter_nu'' must be positive integer');
addParameter(argin,'iter_nu',2, valiter);

%---  parse inputs
parse(argin,X,varargin{:});
scaling     = argin.Results.scaling;
invC1       = argin.Results.invC;
nu          = argin.Results.nu;
printitn    = argin.Results.printitn;
MAX_ITER_nu = argin.Results.iter_nu;

% If nu is not given, then estimate nu from the data using Algorithm 1 of
% Ollila et al. (2020); that routine also returns an initial inverse scatter.
if isempty(nu)
    [nu,~,~,invC1] = estimate_nu(X,MAX_ITER_nu);
end

% If the initial start of the FP algorithm is not given (or nu = Inf, in
% which case the SCM is the answer), use the SCM and its inverse.
if isempty(invC1) || isinf(nu)
    C1 = X'*X/n; % SCM
    invC1 = C1\eye(p);
end

% Scaling constant b. The Gaussian weight (nu = Inf) is already consistent
% for the covariance matrix, so b = 1 there; the consistency-factor helper
% would otherwise produce NaN for an infinite nu.
if strcmpi(scaling,'gaussian') && ~isinf(nu)
    b = tloss_consistency_factor(p,nu,realdata);
else
    b = 1;
end

% 'covariance' scaling is only defined for nu > 2; tell the user when we
% fall back to the standard weight instead of doing it silently.
usecov = strcmpi(scaling,'covariance') && nu > 2;
if strcmpi(scaling,'covariance') && ~usecov
    warning('MVT:covarianceScaling', ...
        '''covariance'' scaling requires nu > 2; using the standard t-weight function instead.');
end

if realdata
    if usecov
        ufun = @(t,v) (p+v)./(v-2+t); % weight: scatter == covariance
    else
        ufun = @(t,v) (p+v)./(v+t);   % standard t-weight
    end
else
    % complex-valued data: weights of the complex MVT distribution
    if usecov
        ufun = @(t,v) (v+2*p)./(v - 2 + 2*t);
    else
        ufun = @(t,v) (v+2*p)./(v + 2*t);
    end
end
const = 1/(b*n);

% If nu = Inf, return the Gaussian MLE (= SCM computed above);
% otherwise compute the MVT M-estimator with the FP algorithm.
if ~isinf(nu)

    MAX_ITER = 2000; % max number of FP iterations
    EPS = 5.0e-4;    % convergence tolerance on max|I - invC1*C1|
    converged = false;
    for iter = 1:MAX_ITER

        t = real(sum((X*invC1).*conj(X),2)); % quadratic forms of each row
        C1 = const*X'*(X.*repmat(ufun(t,nu),1,p));
        err = eye(p)-invC1*C1;
        d = norm(err(:),Inf);

        if printitn > 0 && mod(iter,printitn)==0
            fprintf('At iter = %4d, dis=%.6f\n',iter,d);
        end

        invC1 = C1\eye(p);

        if d <= EPS
            converged = true;
            break;
        end
    end
    % FP iterations are done; remove round-off imaginary parts on the diagonal
    C1(1:(p+1):p^2) = real(C1(1:(p+1):p^2));

    % Only report when the tolerance was really not reached (the original
    % check 'iter==MAX_ITER' fired on convergence at the last iteration and
    % never on actual non-convergence).
    if ~converged
        warning('MVT:slowConvergence', ['WARNING! Slow convergence: error of the solution is %f\n ' ...
            'after %d iterations\n'],d,iter);
    end
else
    % nu = Inf: the SCM is used, no FP iterations
    iter = 0;
end

end

function b=tloss_consistency_factor(p,v,realdata)
% TLOSS_CONSISTENCY_FACTOR  Scaling factor b for the multivariate t-weight
% function such that the resulting scatter estimate is a consistent
% estimator of the covariance matrix when the data are Gaussian.
%
%   p        : dimension of the data
%   v        : d.o.f. parameter of the t-weight function
%   realdata : true for real-valued data, false for complex-valued data
%
% b is first obtained by numerical integration; if that yields NaN, a
% Monte Carlo estimate based on 100000 chi-square draws is used instead.
%--------------------------------------------------------------------------

% A complex p-variate vector is handled as a real vector of dimension 2p.
if realdata
    pint = p;
else
    pint = 2*p;
end
b = tloss_consistency_factor_int(pint,v);

if ~isnan(b)
    return
end

% Numerical integration failed: Monte Carlo approximation of
% b = (1/p) E[ t * u(t) ], with psi(t) = t * u(t).
nMC = 100000;
if realdata
    tt = chi2rnd(p,1,nMC);
    psivals = (p+v)*tt./(v+tt);
else
    tt = (1/2)*chi2rnd(2*p,1,nMC);
    psivals = (v+2*p)*tt./(v+2*tt);
end
b = (1/p)*mean(psivals);

end

function b=tloss_consistency_factor_int(p,v)
% TLOSS_CONSISTENCY_FACTOR_INT  Consistency factor by numerical integration.
%
%   b = (1/p) E[ ||x||^2 u_v(||x||^2) ],  x ~ N_p(0,I),  u_v(t) = (p+v)/(v+t).
%
%   Since s = ||x||^2 ~ chi2(p), this equals b = ((v+p)/p) * E[ s/(v+s) ].
%   The result may be NaN if the integral cannot be evaluated; the caller
%   then falls back to a Monte Carlo estimate.
%
%   The chi-square density is evaluated in log-space: the normalising
%   constant 2^(p/2)*gamma(p/2) overflows to Inf once p exceeds about 343,
%   and x.^(p/2) overflows near the density's peak for similar p, which
%   made the direct formula return NaN.

if isinf(v)
    % u_v(t) -> 1 as v -> Inf, hence b = E[||x||^2]/p = 1.
    b = 1;
    return
end

logc = (p/2)*log(2) + gammaln(p/2);                  % log(2^(p/2)*gamma(p/2))
sfun = @(x) exp((p/2)*log(x) - x/2 - logc)./(v + x); % s*f_chi2(s)/(v+s)

w = warning('off','all');
restoreWarnings = onCleanup(@() warning(w)); %#ok<NASGU> restores state even if integral errors

% Split at x = p, near the peak of the integrand, so the adaptive
% quadrature on the semi-infinite part does not miss a narrow, distant peak.
q = integral(sfun,0,p) + integral(sfun,p,Inf);
b = ((v+p)/p)*q; % consistency factor
end
