function [e,ins,del,tru,hh,tt] = beat_score(beats,truth,collar,VERBOSE)
% [e,ins,del,tru,hh,tt] = beat_score(beats,truth,collar,VERBOSE)
% Compare beat times to ground truth.
% <beats> is a list of system-generated beat times (row or column).
% <truth> is a set of truth tap times, potentially a cell-array
% of different subjects' responses.  If <truth> is a numeric
% matrix, each row is one subject's taps and only strictly positive
% entries are kept (so shorter rows may be zero-padded).
% <collar> (default 0.2, also used if passed as []) is the hit
% tolerance as a fraction of each truth set's median inter-beat period.
% <VERBOSE> (default 0; forced to 1 when called with no outputs)
% prints a one-line summary.
% Return <e> as the average error rate of <beats> against full
% ground truth, a simple average of error rate against each
% individual truth set, where error rate = (inserts + deletions)/true
% and a hit means the beat time was within (collar*median true
% period) of a true hit. <ins>, <del>, <tru> return counts of
% individual beats inserted, deleted, and true, summed over all
% truth sets.  A truth set with no beats yields an infinite/NaN
% error rate for that set.
% <hh> returns a histogram of system beat times relative to true
% beat times (system minus true, so late beats are positive),
% quantized to 10ms bins, with timebase given in <tt>.  All pairwise
% differences are counted; those beyond +/-0.5 s land in the extreme
% bins, which are trimmed from the output.
% 2012-03-27 Dan Ellis dpwe@ee.columbia.edu
if nargin < 3 || isempty(collar); collar = 0.2; end
if nargin < 4; VERBOSE = 0; end
if nargout == 0; VERBOSE = 1; end

if isnumeric(truth)
  % make sure truth is always a cell array, one cell per row.
  trutharray = truth;
  nrows = size(trutharray, 1);
  truth = cell(1, nrows);
  for i = 1:nrows
    % keep only strictly positive values in each row
    row = trutharray(i, :);
    truth{i} = row(row > 0);
  end
end

% Force system beats into a row so the difference matrix below is
% always (ntrue x nsys), whatever orientation the caller used.
beats = beats(:)';
nsys = length(beats);

% parameters for histogram
maxt = 0.5;   % cover for +/- 0.5 sec around true beats
tres = 0.01;  % in 10ms bins
tt = (-maxt):tres:maxt;  % bin centers
h = zeros(1, length(tt));

ntruth = length(truth);
ntrue = zeros(1, ntruth);
inserts = zeros(1, ntruth);
deletes = zeros(1, ntruth);
errs = zeros(1, ntruth);  % not called "error" to avoid shadowing the builtin

for i = 1:ntruth
  truebeats = truth{i}(:)';
  ntrue(i) = length(truebeats);
  collartime = collar * median(diff(truebeats));
  if nsys > 0 && ntrue(i) > 0
    % (ntrue x nsys) matrix of reported beats - true, so late tracking is positive
    tdiffs = bsxfun(@minus, beats, truebeats');
    % insertions are any system-generated beats more than collartime
    % away from the nearest true beat
    inserts(i) = sum(min(abs(tdiffs), [], 1) > collartime);
    % deletions are true beats more than collartime away from nearest
    % system-generated beat
    deletes(i) = sum(min(abs(tdiffs), [], 2) > collartime);
    % update the histogram with *all* time differences (not only nearest)
    h = h + hist(tdiffs(:), tt);
  else
    % Nothing to match against: every system beat is an insertion and
    % every true beat is a deletion.
    inserts(i) = nsys;
    deletes(i) = ntrue(i);
  end
  % So error = (insertions + deletes)/ntrue
  errs(i) = (inserts(i) + deletes(i)) / ntrue(i);
end

% Average the error over truth sets
e = mean(errs);
% Combine error counts
ins = sum(inserts);
del = sum(deletes);
tru = sum(ntrue);
% trim the extreme bins, which also collect out-of-range differences
hh = h(2:end-1);
tt = tt(2:end-1);
if VERBOSE
  fprintf(1,'Overall error= %5.1f%% (%4d ins, %4d del, %4d true)\n', ...
          100*e, ins, del, tru);
end