function [entX entXY] = ent(dinsts, labels, bins)
% ENT  Entropy of each quantized feature dimension, and its joint entropy with the labels.
%
% Written by Jianxin Wu (Jianxin Wu, wujx2001@gmail.com)
% original version -- 2013.8.21
% slightly modified (mainly adding some comments to help reading) -- 2015.1.17
%
% Originally I thought this was a trivial function and did not put it online.
% Recently (after giving talks and receiving emails), I realized that some
% details are still important: although the algorithm is simple, a wrong
% implementation can still happen. Thus, comments were added and this file
% was put online.
%
% Inputs:
%   dinsts : data matrix, one example per row, one feature dimension per column.
%   labels : class label of each example, integers 1, 2, ..., C (C = number of
%            categories, taken as max(labels)). Examples with other label
%            values do not contribute to entXY.
%   bins   : quantization mode
%            -2 : quantize each value into 2 bits, edges [-1 -0.0125 0 0.0125 1]
%            -1 : quantize each value into 1 bit,  edges [-1 0 1]
%            otherwise: a positive integer; uniform quantization of the global
%                       range [min(dinsts(:)), max(dinsts(:))] into 'bins' bins
%            Bin k is [edge(k), edge(k+1)); the last bin also includes the
%            last edge. Values outside [edge(1), edge(end)] are not counted.
%
% Outputs (1 x D row vectors, in bits, D = number of columns of dinsts):
%   entX  : H(X_d), the entropy of each quantized feature dimension
%   entXY : H(X_d, Y), the JOINT entropy of each dimension and the labels.
%           This is NOT the mutual information itself; the mutual information
%           is I(X_d; Y) = entX + H(Y) - entXY, where H(Y) is the label entropy.
%
% Entropies use the convention 0*log2(0) = 0 (no additive smoothing), so a
% constant feature has entropy exactly 0.
%
% You may want to change the thresholds for bins = -1 (1-bit) or -2 (2-bits).
% E.g., if your data are non-negative, the current edges in X will not work:
% they are designed for FV / VLAD features (which lie in [-1 +1]).
% Even if your data are in [-1 +1], you may want to change the thresholds
% +/- 0.0125 to values that fit your dataset, e.g., the 0.1- and 0.9-quantiles.
[~, dim] = size(dinsts);
nr_class = max(labels); % we assume labels are 1, 2, 3, ...
if bins == -2 % this is the case for quantized 2-bit version
    X = [-1 -0.0125 0 0.0125 1];
elseif bins == -1 % 1-bit version
    X = [-1 0 1];
else
    vmin = min(dinsts(:));
    vmax = max(dinsts(:));
    if ~(vmax > vmin)
        % Constant data: widen the range so that linspace yields strictly
        % increasing edges and every value falls into the first bin.
        vmax = vmin + 1;
    end
    % linspace always returns exactly bins+1 edges and its last edge is
    % exactly vmax; quantize_counts folds values equal to vmax into the last bin.
    X = linspace(vmin, vmax, bins+1);
end
bins = length(X)-1;
probX = zeros(bins, dim);
probXY = zeros(bins*nr_class, dim);
for i = 1:dim
    if mod(i,1000)==0
        fprintf('.');
    end
    if mod(i,100000)==0
        fprintf('\n');
    end
    temp = dinsts(:,i);
    % marginal histogram of this dimension, normalized to a distribution
    countX = quantize_counts(temp, X);
    totalX = sum(countX);
    if totalX > 0
        countX = countX / totalX;
    end
    probX(:,i) = countX;
    % joint histogram: class j occupies rows (j-1)*bins+1 .. j*bins
    countXY = zeros(bins*nr_class, 1);
    for j = 1:nr_class
        countXY((j-1)*bins+1:j*bins) = quantize_counts(temp(labels==j), X);
    end
    totalXY = sum(countXY);
    if totalXY > 0
        countXY = countXY / totalXY;
    end
    probXY(:,i) = countXY;
end
fprintf('\n');
entX = column_entropy(probX);
entXY = column_entropy(probXY);


function counts = quantize_counts(values, edges)
% Histogram of 'values' over the length(edges)-1 bins defined by 'edges',
% returned as a column. Bin k is [edges(k), edges(k+1)); the last bin is
% closed on the right so values equal to edges(end) are counted in it.
% Values outside [edges(1), edges(end)] (and NaNs) are not counted.
counts = zeros(length(edges)-1, 1);
if isempty(values)
    return; % e.g. a class with no examples: all-zero counts
end
h = histc(values(:), edges);
h = h(:);
counts = h(1:end-1);
% histc's final slot holds values exactly equal to edges(end); fold it into the last bin
counts(end) = counts(end) + h(end);


function h = column_entropy(p)
% Shannon entropy in bits of each column of p (each column a distribution),
% using 0*log2(0) = 0. Always sums along dimension 1, so a single-row p
% still yields one entropy per column.
t = p .* log2(p);
t(p == 0) = 0;
h = -sum(t, 1);