Regularized discriminant analysis using the RSCM estimators

By running this script you can reproduce Figure 7 in the arXiv version of our paper, Ollila and Raninen (2019). This example was left out of the published version.

Contents

Regularized linear and quadratic discriminant analysis

The problem is to classify an observation $\mathbf{x} \in \mathbb{R}^p$ into one of $K$ populations (classes). We assume no knowledge of the class a priori probabilities. In quadratic discriminant analysis (QDA), a new observation $\mathbf{x}$ is assigned to class $\hat k \in \{1, \ldots, K\}$ by the rule

$$  \hat k = \arg \min_k \, (\mathbf{x}- \hat \mu_k)^\top
\hat \Sigma_k^{-1} (\mathbf{x}- \hat \mu_k) + \log|\hat \Sigma_k|, $$

where $\hat \mu_k \in \mathbb{R}^p$ and $\hat \Sigma_k \in \mathbb{R}^{p \times p}$ are estimates of the mean vector and the covariance matrix of class $k$ computed from the training data set. Here we use the sample means as the $\hat \mu_k$ and different regularized SCM (RSCM) estimators as the $\hat \Sigma_k$, and evaluate the performance of the resulting regularized QDA rules.
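The QDA rule above can be sketched directly in MATLAB. This is a minimal illustration, not part of the toolbox: the inputs `mu_hat` (a p x K matrix of class mean estimates) and `Sig_hat` (a p x p x K array of covariance estimates) are hypothetical names.

```matlab
function khat = qda_rule(x, mu_hat, Sig_hat)
% Sketch of the QDA classification rule (illustrative only).
% x       : p x 1 observation to classify
% mu_hat  : p x K matrix of estimated class means
% Sig_hat : p x p x K array of estimated class covariance matrices
K = size(mu_hat, 2);
d = zeros(1, K);
for k = 1:K
    xc   = x - mu_hat(:,k);                   % center by the class mean
    d(k) = xc.' * (Sig_hat(:,:,k) \ xc) ...   % Mahalanobis distance term
         + log(det(Sig_hat(:,:,k)));          % log-determinant penalty
end
[~, khat] = min(d);                           % assign to the minimizing class
end
```

In practice one would use a Cholesky factorization of each $\hat \Sigma_k$ to get the log-determinant and the quadratic form in a numerically stable way; the direct form above just mirrors the equation.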

In linear discriminant analysis (LDA), one assumes that the class covariance matrices are equal, and in this case the rule may be written in the form

$$  \hat k = \arg \max_k \, \{ \mathbf{x}^\top \mathbf{b}_k
- (1/2) \hat \mu_k^\top \mathbf{b}_k \} ,$$

where $\mathbf{b}_k=\hat \Sigma^{-1} \hat \mu_k$ and $\hat \Sigma$ is an estimate of the common covariance matrix of the classes computed from the pooled training data. Again, we use different regularized SCM estimators in place of $\hat \Sigma$ and evaluate the performance of the resulting regularized LDA rules.
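The LDA rule admits a compact matrix form: collecting the vectors $\mathbf{b}_k$ as columns of $B = \hat \Sigma^{-1} M$, where $M$ has the $\hat \mu_k$ as columns, matches the `RSCM \ mu` expressions passed to `LDA` in the script below. A minimal sketch (with hypothetical input names `mu_hat` and `Sig_hat`):

```matlab
function khat = lda_rule(x, mu_hat, Sig_hat)
% Sketch of the LDA classification rule (illustrative only).
% x       : p x 1 observation to classify
% mu_hat  : p x K matrix of estimated class means
% Sig_hat : p x p estimate of the common covariance matrix
B = Sig_hat \ mu_hat;               % b_k = inv(Sigma)*mu_k, as columns of B
scores = x.' * B ...                % x' * b_k for each class k
       - 0.5 * sum(mu_hat .* B, 1); % minus (1/2) * mu_k' * b_k
[~, khat] = max(scores);            % assign to the maximizing class
end
```

Solving `Sig_hat \ mu_hat` once for all classes avoids forming the explicit inverse, which is why the script precomputes `RSCM \ mu` before calling `LDA`.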

We test the performance of the regularized LDA and QDA rules on the Phoneme data set.

Load the Phoneme data set

Each .mat file needs to contain the data as an n x p matrix 'data' and an n x 1 vector 'classlabels' indicating which class each observation belongs to.

clear; clc;
load phonemedata.mat

train_percentage = 1/13;
K       = max(classlabels) % number of groups (classes)
p       = size(data,2);    % dimension
NRSIM   = 50;              % number of simulation runs
errLDA  = zeros(5,NRSIM);  % LDA (computed only if N > p)
errQDA  = zeros(4,NRSIM);  % QDA (computed only if min(n_i) > p)
rng(0); % set random seed for reproducibility
prior = ones(1,K);  % Choose the uniform priors
K =

     5

Simulation over 50 independent random splits into training and test sets

tic;
for iter=1:NRSIM

    % Training set X and test set Z selection by random sampling
    [yt,Xt,y,X,mu] = create_data(data,classlabels,train_percentage,true);
    N  = size(X,1);
    Nt = size(Xt,1);

    % Compute the spatial median of each class and center the class data w.r.t. it
    mu2 = zeros(p,K);
    X2 = zeros(size(X));
    for ii=1:K
        mu2(:,ii) = spatmed(X(y==ii,:));           % class spat. medians
        X2(y==ii,:) =  bsxfun(@minus,X(y==ii,:),mu2(:,ii).');
    end

    % SCM - LDA (pooled sample covariance matrix)
    if N > p
        poolSCM = cov(X,1)*(N/(N-K));
        yhat = LDA(Xt,mu,poolSCM \ mu);
        errLDA(1,iter) = sum(yhat ~= yt)/Nt;
    end

    % ELL2 - QDA
    [yhat,stat2] = RegQDA(Xt,X,y,mu,'ell2',prior);
    errQDA(2,iter) = sum(yhat ~= yt)/Nt;

    % ELL2 - LDA
    [RSCM2p,~,stat2p] = regscm(X,'approach','ell2', ...
        'centered',true,'verbose','off'); % pooled ELL2-RSCM
    yhat = LDA(Xt,mu,RSCM2p \ mu);
    errLDA(3,iter) = sum(yhat ~= yt)/Nt;

    % ELL1 - QDA
    kappavals = zeros(1,K);
    for ii=1:K,  kappavals(ii) = stat2{ii}.kappa;  end
    yhat = RegQDA(Xt,X,y,mu,'ell1',prior,kappavals);
    errQDA(1,iter) = sum(yhat ~= yt)/Nt;

    % ELL1 - LDA
    is_centered = true;
    gammahat = gamell1(X2,is_centered); % X2 is already centered by the class spatial medians
    [RSCM1p,~,stat1p] = regscm(X,'approach','ell1','centered',true, ...
        'kappa',stat2p.kappa,'gamma',gammahat,'verbose','off');
    yhat = LDA(Xt,mu, RSCM1p \ mu);
    errLDA(2,iter) = sum(yhat ~= yt)/Nt;

    % GAU - QDA
    [yhat,stat4] = RegQDA(Xt,X,y,mu,'ell2',prior,zeros(1,K));
    errQDA(4,iter) = sum(yhat ~= yt)/Nt;

    % GAU - LDA
    [RSCM4p,~,stat4p] = regscm(X,'approach','ell2','centered',true, ...
        'kappa',0,'verbose','off');
    yhat = LDA(Xt,mu, RSCM4p \ mu);
    errLDA(5,iter) = sum(yhat ~= yt)/Nt;

    % LWE - QDA
    yhat = RegQDA(Xt,X,y,mu,'lw');
    errQDA(3,iter) = sum(yhat ~= yt)/Nt;

    % LWE - LDA
    [LWEp,~, stat3p] = regscm(X,'approach','lw', ...
        'verbose','off','centered',true);
    yhat = LDA(Xt,mu, LWEp \ mu);
    errLDA(4,iter) = sum(yhat ~= yt)/Nt;

    if mod(iter,10)==0
        fprintf('.');
    end
end
toc;

median([errLDA(1,:)' errQDA'])
median(errLDA')
std([errLDA(1,:)' errQDA'])
.....Elapsed time is 33.975200 seconds.

ans =

    0.1678    0.1285    0.1436    0.1521    0.1866


ans =

    0.1678    0.0996    0.1057    0.1062    0.1086


ans =

    0.0153    0.0129    0.0203    0.0214    0.0324

Plot the results

figure(1); clf
lab = {'LDA','Ell1-QDA','Ell2-QDA','LW-QDA','Gau-QDA','Ell1-LDA','Ell2-LDA','LW-LDA','Gau-LDA'};
koe=boxplot([errLDA(1,:)' errQDA' errLDA(2:end,:)'],'Labels',lab,'OutlierSize',10);
set(koe,'LineWidth',2);
set(gca,'FontSize',16)
set(gca,'LineWidth',2);

For more details, see the paper Ollila and Raninen (2019).

References

E. Ollila and E. Raninen, "Optimal shrinkage covariance matrix estimation under random sampling from elliptical distributions," IEEE Transactions on Signal Processing, 2019.