cosmo overlap skl

function [y_in_x, x_in_y] = cosmo_overlap(xs, ys)
    % compute overlap between vectors or cellstrings in two cells
    %
    % [y_in_x,x_in_y]=cosmo_overlap(xs,ys)
    %
    % Inputs:
    %    xs         cell with N elements, each either a numeric array or
    %               a cell with strings
    %    ys         cell with M elements, each either a numeric array or
    %               a cell with strings
    %
    % Output:
    %    y_in_x     NxM numeric matrix with the (i,j)-th element indicating
    %               the ratio of elements in ys{j} that are present in xs{i}
    %    x_in_y     NxM numeric matrix with the (i,j)-th element indicating
    %               the ratio of elements in xs{i} that are present in ys{j}
    %
    % Examples:
    %     % Compute overlap between two cells with cellstrings
    %     xs={{'a'},{'a','b'},{'a','b','c'},{}};
    %     ys={{'b'},{'c','b','a'}};
    %     [y_in_x,x_in_y]=cosmo_overlap(xs,ys);
    %     cosmo_disp(y_in_x);
    %     %|| [ 0     0.333
    %     %||   1     0.667
    %     %||   1         1
    %     %||   0         0 ]
    %     %||
    %     cosmo_disp(x_in_y);
    %     %|| [     0         1
    %     %||     0.5         1
    %     %||   0.333         1
    %     %||     NaN       NaN ]
    %
    %     % Compute overlap between two cells with numeric arrays
    %     xs={1,[1 2],1:3,[]};
    %     ys={2,[3,2,1]};
    %     [y_in_x,x_in_y]=cosmo_overlap(xs,ys);
    %     cosmo_disp(y_in_x);
    %     %|| [ 0     0.333
    %     %||   1     0.667
    %     %||   1         1
    %     %||   0         0 ]
    %     %||
    %     cosmo_disp(x_in_y);
    %     %|| [     0         1
    %     %||     0.5         1
    %     %||   0.333         1
    %     %||     NaN       NaN ]
    %
    % Notes:
    %   - an empty xs{i} gives NaN (0/0) in the i-th row of x_in_y;
    %     likewise an empty ys{j} gives NaN in the j-th column of y_in_x
    %   - the denominators are numel(xs{i}) and numel(ys{j}), whereas a
    %     value shared by xs{i} and ys{j} adds one to the numerator no
    %     matter how often it is repeated within xs{i} or ys{j}
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    [x_cols, n_x, x_sizes, x_owner] = get_counts(xs);
    [y_cols, n_y, y_sizes, y_owner] = get_counts(ys);

    % stack every value of xs, followed by every value of ys, in a single
    % column
    stacked = [x_cols(:); y_cols(:)];
    all_values = cat(1, stacked{:});

    % for each stacked value, the element it came from: xs are numbered
    % 1..n_x and ys are numbered (n_x+1)..(n_x+n_y)
    owner = [x_owner; (n_x + y_owner)];

    % stacked positions up to this one hold values from xs, positions
    % after it hold values from ys
    last_x_pos = sum(x_sizes);

    % because string comparisons are slow, replace each value by the index
    % of its unique value and work with those indices instead
    [unq_values, unused_first, unq_idx] = unique(all_values);
    n_unq = numel(unq_values);

    % bring identical values together; 'order' holds, for each group, the
    % stacked positions in increasing order
    [unq_idx_sorted, order] = sort(unq_idx);

    % group_bounds(g) is where the g-th unique value starts in 'order';
    % the extra last entry is one past the end
    is_group_start = [true; any(diff(unq_idx_sorted, 1), 2)];
    group_bounds = [find(is_group_start); 1 + numel(is_group_start)];

    % shared_count(i,j) counts the unique values present in both xs{i}
    % and ys{j}
    shared_count = zeros(n_x, n_y);

    for g = 1:n_unq
        % the bounds are indexed directly in the two tests below (rather
        % than stored in variables first) because that is faster
        if order(group_bounds(g)) > last_x_pos
            % smallest position is in ys: value does not occur in xs
            continue
        end

        if order(group_bounds(g + 1) - 1) <= last_x_pos
            % largest position is in xs: value does not occur in ys
            continue
        end

        % value occurs in both xs and ys
        positions = order(group_bounds(g):(group_bounds(g + 1) - 1));

        % split the positions in those from xs and those from ys
        split_at = find_first_greater_than(positions, last_x_pos);
        x_idx = owner(positions(1:(split_at - 1)));
        y_idx = owner(positions(split_at:end)) - n_x;

        shared_count(x_idx, y_idx) = shared_count(x_idx, y_idx) + 1;
    end

    % normalize by the number of values in each ys{j} (columns) and
    % in each xs{i} (rows)
    y_in_x = bsxfun(@rdivide, shared_count, y_sizes');
    x_in_y = bsxfun(@rdivide, shared_count, x_sizes);

function i = find_first_greater_than(sorted_vs, thr)
    % find the first position in a sorted vector with a value exceeding
    % a threshold, using binary search
    %
    % Inputs:
    %    sorted_vs   non-empty vector; assumed (not verified) to be sorted
    %                in ascending order
    %    thr         scalar threshold
    %
    % Output:
    %    i           smallest index for which sorted_vs(i)>thr, or
    %                numel(sorted_vs)+1 if no value exceeds thr
    n_values = numel(sorted_vs);

    if sorted_vs(end) <= thr
        % even the largest value does not exceed the threshold
        i = n_values + 1;
    else
        % the answer always lies in lo..hi; shrink until lo==hi
        lo = 1;
        hi = n_values;

        while lo < hi
            mid = lo + floor((hi - lo) / 2);
            if sorted_vs(mid) <= thr
                % mid and everything before it is too small
                lo = mid + 1;
            else
                % mid exceeds thr, so it remains a candidate
                hi = mid;
            end
        end

        i = lo;
    end

function [xs_vec, n, c, i] = get_counts(xs)
    % reshape each element of a cell to a column and keep track of the
    % sizes and origin of the values
    %
    % Input:
    %    xs        cell of which each element is a numeric array or a
    %              cellstr; anything else raises an error
    %
    % Outputs:
    %    xs_vec    n x 1 cell, with xs_vec{k} equal to xs{k}(:)
    %    n         number of elements in xs
    %    c         n x 1 vector, with c(k) equal to numel(xs{k})
    %    i         sum(c) x 1 vector; when the columns in xs_vec are
    %              stacked vertically, i(p) is the index k of the element
    %              xs{k} that the p-th value came from
    n = numel(xs);
    xs_vec = cell(n, 1);
    c = zeros(n, 1);

    for k = 1:n
        elem = xs{k};
        if ~(isnumeric(elem) || iscellstr(elem))
            error('only cells with numeric or cellstr input is supported');
        end
        xs_vec{k} = elem(:);
        c(k) = numel(elem);
    end

    % offsets(k) is the number of values that precede those of xs{k};
    % offsets(end) is the total number of values
    offsets = cumsum([0; c]);

    i = zeros(offsets(end), 1);
    for k = 1:n
        i((offsets(k) + 1):offsets(k + 1)) = k;
    end