cosmo disp

function cosmo_disp(x, varargin)
    % display the input as a string representation
    %
    % cosmo_disp(x,opt)
    %
    % Inputs:
    %   x              any type of data element (can be a dataset struct)
    %   opt            Optional struct (or key-value pairs, see the examples
    %                  below) with fields
    %     .threshold   If the number of values in an array along a dimension
    %                  exceeds threshold, then the array is shown in summary
    %                  style along that dimension. Default: 5
    %     .edgeitems   When an array is shown in summary style, edgeitems sets
    %                  the number of items at the beginning and end of the
    %                  array to be shown (separated by '...' in rows and by ':'
    %                  in columns).
    %                  Default: 3
    %     .precision   Numeric precision: the maximum number of significant
    %                  digits used when converting numbers (as in num2str)
    %                  Default: 3
    %     .strlen      Maximal string length, if a string is longer the
    %                  beginning and end are shown separated by ' ... '.
    %                  Default: 20
    %     .depth       Maximum recursion depth
    %                  Default: 6
    %     .show_size   If true, the size of non-scalar arrays is shown as
    %                  '@' followed by the dimensions. The size is always
    %                  shown for arrays displayed in summary style.
    %                  Default: false
    %
    % Side effect:     Calling this function causes the representation of x
    %                  to be displayed.
    %
    %
    % Examples:
    %     % display a complicated data structure
    %     x=struct();
    %     x.a_cell={[],{'cell within cell',[1 2; 3 4]}};
    %     x.small_matrix=[10 11 12; 13 14 15];
    %     x.big_matrix=reshape(1:200,10,20);
    %     x.huge=2^40;
    %     x.tiny=2^-40;
    %     x.a_string='hello world';
    %     x.a_struct.another_struct.name='me';
    %     x.a_struct.another_struct.func=@abs;
    %     cosmo_disp(x);
    %     %|| .a_cell
    %     %||   { [  ]  { 'cell within cell'  [ 1         2
    %     %||                                   3         4 ] } }
    %     %|| .small_matrix
    %     %||   [ 10        11        12
    %     %||     13        14        15 ]
    %     %|| .big_matrix
    %     %||   [  1        11        21  ...  171       181       191
    %     %||      2        12        22  ...  172       182       192
    %     %||      3        13        23  ...  173       183       193
    %     %||      :         :         :        :         :         :
    %     %||      8        18        28  ...  178       188       198
    %     %||      9        19        29  ...  179       189       199
    %     %||     10        20        30  ...  180       190       200 ]@10x20
    %     %|| .huge
    %     %||   [ 1.1e+12 ]
    %     %|| .tiny
    %     %||   [ 9.09e-13 ]
    %     %|| .a_string
    %     %||   'hello world'
    %     %|| .a_struct
    %     %||   .another_struct
    %     %||     .name
    %     %||       'me'
    %     %||     .func
    %     %||       @abs
    %     %
    %     cosmo_disp(x.a_cell)
    %     %|| { [  ]  { 'cell within cell'  [ 1         2
    %     %||                                 3         4 ] } }
    %     cosmo_disp(x.a_cell{2}{2})
    %     %|| [ 1         2
    %     %||   3         4 ]
    %
    %     % make a cell in a cell in a cell in a cell ...
    %     m={{{{{{{{{{{'hello'}}}}}}}}}}};
    %     cosmo_disp(m)
    %     %|| { { { { { { <cell> } } } } } }
    %     cosmo_disp(m,'depth',8)
    %     %|| { { { { { { { { <cell> } } } } } } } }
    %     cosmo_disp(m,'depth',Inf)
    %     %|| { { { { { { { { { { { 'hello' } } } } } } } } } } }
    %
    %     % illustrate 'threshold' and 'edgeitems' arguments
    %     cosmo_disp(num2cell('a':'k'))
    %     %|| { 'a'  'b'  'c' ... 'i'  'j'  'k'   }@1x11
    %     cosmo_disp(num2cell('a':'k'),'threshold',Inf)
    %     %|| { 'a'  'b'  'c'  'd'  'e'  'f'  'g'  'h'  'i'  'j'  'k' }
    %     cosmo_disp(num2cell('a':'k'),'edgeitems',2)
    %     %|| { 'a'  'b' ... 'j'  'k'   }@1x11
    %
    %     % illustrate 'precision' argument
    %     cosmo_disp(pi*[1 2],'precision',1);
    %     %|| [ 3       6 ]
    %     cosmo_disp(pi*[1 2],'precision',3);
    %     %|| [ 3.14      6.28 ]
    %     cosmo_disp(pi*[1 2],'precision',5);
    %     %|| [ 3.1416      6.2832 ]
    %     cosmo_disp(pi*[1 2],'precision',7);
    %     %|| [ 3.141593      6.283185 ]
    %
    %     % illustrate n-dimensional arrays
    %     x=zeros([2 2 1 2 3]);
    %     x(:)=2*(1:numel(x));
    %     cosmo_disp(x)
    %     %|| <double>@2x2x1x2x3
    %     %||    (:,:,1,1,1) =  [ 2         6
    %     %||                     4         8 ]
    %     %||    (:,:,1,2,1) =  [ 10        14
    %     %||                     12        16 ]
    %     %||    (:,:,1,1,2) =  [ 18        22
    %     %||                     20        24 ]
    %     %||    (:,:,1,2,2) =  [ 26        30
    %     %||                     28        32 ]
    %     %||    (:,:,1,1,3) =  [ 34        38
    %     %||                     36        40 ]
    %     %||    (:,:,1,2,3) =  [ 42        46
    %     %||                     44        48 ]
    %     cosmo_disp(reshape(char(65:72),[2 2 2]))
    %     %|| <char>@2x2x2
    %     %||    (:,:,1) = 'AC
    %     %||               BD'
    %     %||    (:,:,2) = 'EG
    %     %||               FH'
    %     cosmo_disp(zeros([2 3 5 7 0 2]))
    %     %|| <double>@2x3x5x7x0x2 (empty)
    %
    %     % illustrate non-singleton structs
    %     x=struct('x',{1 2; 3 4});
    %     cosmo_disp(x);
    %     %|| <struct>@2x2
    %     %||    (1,1).x
    %     %||           [ 1 ]
    %     %||    (2,1).x
    %     %||           [ 3 ]
    %     %||    (1,2).x
    %     %||           [ 2 ]
    %     %||    (2,2).x
    %     %||           [ 4 ]
    %     x3=cat(3,x,x,x);
    %     cosmo_disp(x3);
    %     %|| <struct>@2x2x3
    %     %||    (1,1,1).x
    %     %||             [ 1 ]
    %     %||    (2,1,1).x
    %     %||             [ 3 ]
    %     %||    (1,2,1).x
    %     %||             [ 2 ]
    %     %||      :        :
    %     %||    (2,1,3).x
    %     %||             [ 3 ]
    %     %||    (1,2,3).x
    %     %||             [ 2 ]
    %     %||    (2,2,3).x
    %     %||             [ 4 ]
    %
    % Notes:
    %   - Unlike the builtin 'disp' function, this function shows the contents
    %     of the input using recursion. For example if a cell contains a
    %     struct, then the contents of that struct is shown as well
    %   - A use case is displaying dataset structs
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % default options; see the help text above for their meaning
    defaults = struct('threshold', 5, ...
                      'edgeitems', 3, ...
                      'precision', 3, ...
                      'strlen', 20, ...
                      'depth', 6, ...
                      'show_size', false);

    % user-supplied options override the defaults
    opt = cosmo_structjoin(defaults, varargin);

    % build the string representation of x, then print it
    disp(disp_helper(x, opt));

function s = disp_helper(x, opt)
    % Recursive worker behind cosmo_disp: returns (instead of printing) the
    % string representation of x, so that it can be applied to nested
    % elements. When the remaining recursion budget opt.depth is used up,
    % or when x does not support size(), only a '<class>' summary is made.
    if opt.depth <= 0
        s = any2summary_str(x, opt);
        return
    end

    % nested elements get one level less of recursion budget
    opt.depth = opt.depth - 1;

    if has_size(x)
        s = nd_any2str(x, opt);
    else
        s = any2summary_str(x, opt);
    end

function s = any2summary_str(x, unused)
    % Summary representation: the class name between '<' and '>', followed
    % by '@' and the size for non-scalar x. Values that do not support
    % size() are treated as being 1x1. The second argument is ignored.
    sz = [1 1];
    if has_size(x)
        sz = size(x);
    end
    s = surround_with(true, '<', class(x), '>', sz);

function tf = has_size(x)
    % true if size(x) can be called without raising an error; some classes
    % do not support size
    tf = true;
    try
        size(x);
    catch
        tf = false;
    end

function y = nd_any2str(x, opt)
    % Dispatch x to the string conversion function for its type and
    % dimensionality. Struct arrays with more than one element, and all
    % non-struct arrays with more than two dimensions, are shown element by
    % element or slice by slice through multi_any2string.
    if isstruct(x)
        if numel(x) > 1
            y = multi_any2string(x, 0, opt);
        else
            y = struct2str(x, opt);
        end
        return
    end

    if numel(size(x)) ~= 2
        % n-dimensional array: show each 2D slice separately
        y = multi_any2string(x, 2, opt);
    elseif iscell(x)
        y = cell2str(x, opt);
    elseif isnumeric(x) || islogical(x)
        y = matrix2str(x, opt);
    elseif ischar(x)
        y = string2str(x, opt);
    elseif isa(x, 'function_handle')
        y = function_handle2str(x, opt);
    else
        y = any2summary_str(x, opt);
    end

function y = multi_any2string(x, ndim_post, opt)
    % Representation of x as a list of its parts: single elements when
    % ndim_post is 0 (used for struct arrays) or 2D slices when ndim_post
    % is 2 (used for n-dimensional arrays). A '<class>@size' header is
    % followed by one labelled line per part; if there are too many parts,
    % only the first and last opt.edgeitems are shown, separated by ':'.
    full_size = size(x);
    trailing_size = full_size((ndim_post + 1):end);
    n_trailing = prod(trailing_size);

    % subscripts for all combinations of the trailing dimensions
    ranges = cell(numel(trailing_size), 1);
    for dim = 1:numel(trailing_size)
        ranges{dim} = num2cell(1:trailing_size(dim));
    end
    subscripts = cosmo_cartprod(ranges);

    if ndim_post <= 0
        x_flat = x(:);
    else
        x_flat = reshape(x, [full_size(1:ndim_post) n_trailing]);
    end

    [head_idxs, tail_idxs] = get_mx_idxs(x_flat, opt.edgeitems, ...
                                         opt.threshold, ndim_post + 1);

    header = any2summary_str(x, opt);
    head_rows = nd_any2str_helper(x_flat, subscripts, head_idxs, opt);

    if ~isempty(tail_idxs)
        tail_rows = nd_any2str_helper(x_flat, subscripts, tail_idxs, opt);

        % ':' roughly halfway the widest entry in each of the two columns
        dot_row = cell(1, 2);
        for col = 1:2
            widths = cellfun(@(e) size(e, 2), tail_rows(:, col));
            dot_row{col} = [spaces(1, round(max(widths) / 2)) ':'];
        end
    else
        tail_rows = {'', ''};
        dot_row = {'', ''};
    end

    body = strcat_(cat(1, head_rows, dot_row, tail_rows));
    y = strcat_({header; body});

function s = nd_any2str_helper(xflat, p, idxs, opt)
    % Build a numel(idxs) x 2 cell: column 1 holds a subscript label and
    % column 2 the string representation of the part at that index. xflat
    % is either a vector (one element per index, used for struct arrays)
    % or a 3D array (one 2D slice per index); row idx of p holds the
    % subscripts used in the label.
    s = cell(numel(idxs), 2);
    n_lead = numel(size(xflat)) - 1;

    for row = 1:numel(idxs)
        idx = idxs(row);
        if n_lead == 1
            % for struct
            part = xflat(idx);
            prefix = '';
            postfix = '';
        elseif n_lead == 2
            % anything else
            part = xflat(:, :, idx);
            prefix = ':,:,';
            postfix = ' = ';
        else
            assert(false);
        end

        part_str = disp_helper(part, opt);
        subs = sprintf(',%d', p(idx, :));
        s{row, 1} = sprintf('   (%s%s)%s', prefix, subs(2:end), postfix);
        s{row, 2} = part_str;
    end

function y = strcat_(xs)
    % Lay out a 2D cell array of char arrays as a single char array.
    %
    % Input:
    %   xs    nr x nc cell whose elements are char arrays; non-char empty
    %         elements (such as []) are treated as ''.
    % Output:
    %   y     char array in which the elements of xs are arranged as in a
    %         grid: each element is padded on the right with spaces to the
    %         maximal width in its cell column, and padded at the bottom
    %         with blank rows to the maximal height in its cell row. Cell
    %         rows in which every element has zero height are left out.
    %         If xs is empty, '' is returned.
    if isempty(xs)
        y = '';
        return
    end

    % elements in xs are expected to be char (or empty)
    [nr, nc] = size(xs);
    ys = cell(1, nc);

    % sizes for each element: maximal width per cell column (1 x nc) and
    % maximal height per cell row (nr x 1)
    width_per_col = max_element_size(xs, 2);
    height_per_row = max_element_size(xs, 1);
    for k = 1:nc
        xcol = cell(nr, 1);
        width = width_per_col(k);
        row_pos = 0;
        for j = 1:nr
            height = height_per_row(j);
            if height == 0
                % nothing to show in this cell row
                continue
            end

            x = xs{j, k};
            if ~ischar(x) && isempty(x)
                x = '';
            end

            sx = size(x);
            to_add = [height width] - sx;

            % pad with spaces, to the right and below, so that the block
            % becomes height x width
            row_pos = row_pos + 1;
            xcol{row_pos} = [[x spaces(sx(1), to_add(2))]; ...
                             spaces(to_add(1), width)];
        end
        % stack the padded blocks of this cell column vertically
        ys{k} = char(xcol{1:row_pos});
    end
    % put the cell columns next to each other
    y = [ys{:}];

function m = max_element_size(x, dim)
    % For a 2D cell x, the maximum of size(element, dim) over the other
    % cell dimension: per cell column when dim==2 (result 1 x nc), per cell
    % row when dim==1 (result nr x 1). A plain loop is faster than cellfun.
    sizes = zeros(size(x));
    for idx = 1:numel(x)
        sizes(idx) = size(x{idx}, dim);
    end
    other_dim = 3 - dim;
    m = max(sizes, [], other_dim);

function y = spaces(nx, ny)
    % nx x ny char array of blanks; negative sizes are treated as zero.
    % Growing and filling the array is faster than repmat(' ',nx,ny).
    nx = max(nx, 0);
    ny = max(ny, 0);

    if nx == 0 || ny == 0
        y = reshape('', nx, ny);
    else
        y(nx, ny) = ' ';
        y(:) = ' ';
    end

function y = struct2str(x, opt)
    % String representation of a struct with at most one element.
    %
    % An empty struct array, or a 1x1 struct without fields, is shown as
    % 'struct (empty)'. If opt.show_size is set and the struct array has no
    % elements, the size is inserted as in 'struct@0x0 (empty)'. Otherwise
    % each field is shown as a line '.name', followed by the recursively
    % obtained representation of its value, indented by two spaces.
    is_empty = numel(x) == 0;

    if ~is_empty
        % struct arrays with more than one element are handled elsewhere
        % (by multi_any2string)
        assert(numel(x) == 1);
        fns = fieldnames(x);
        is_empty = isempty(fns);
    end

    if is_empty
        show_size = opt.show_size;
        y = surround_with(show_size, '', 'struct', '', size(x));
        % surround_with already appends ' (empty)' when it shows the size
        % of an array without elements; do not add it a second time
        if ~(show_size && numel(x) == 0)
            y = [y ' (empty)'];
        end
        return
    end

    n = numel(fns);
    r = cell(n * 2, 1);
    for k = 1:n
        fn = fns{k};
        r{k * 2 - 1} = ['.' fn];
        d = disp_helper(x.(fn), opt);
        % indent every line of the value by two spaces
        r{k * 2} = [spaces(size(d, 1), 2) d];
    end
    y = strcat_(r);

function s = function_handle2str(x, opt)
    % String representation of a function handle: the text returned by
    % func2str (shortened like other strings, see opt.strlen), preceded by
    % '@'. For anonymous functions func2str already returns text starting
    % with '@' (e.g. '@(x)x+1'), in which case no second '@' is added.
    s_with_quotes = string2str(func2str(x), opt);

    % strip the quotes that string2str puts around the text
    s = s_with_quotes(2:(end - 1));

    if isempty(s) || s(1) ~= '@'
        s = ['@' s];
    end

function s = string2str(x, opt)
    % Quoted representation of a char array x. If x has more than
    % opt.strlen columns, only its first and last columns are kept, joined
    % by ' ... '. The closing quote is placed on the last row.
    if ~ischar(x)
        error('expected a char');
    end
    [n_rows, n_cols] = size(x);

    if n_cols > opt.strlen
        sep = ' ... ';
        n_keep = floor((opt.strlen - numel(sep)) / 2);
        head = x(:, 1:n_keep);
        tail = x(:, n_cols + ((1 - n_keep):0));
        x = strcat_({head, sep, tail});
    end

    q = '''';
    closing = [spaces(n_rows - 1, 1); q];
    s = strcat_({q, x, closing});

function s = cell2str(x, opt)
    % display a cell
    %
    % Returns a char array with the contents of the 2D cell x, each shown
    % recursively, between '{ ' and ' }'. Along a dimension that is too
    % long (see get_mx_idxs) only the first and last opt.edgeitems
    % elements are shown, separated by ' ... ' (columns) or by a row with
    % ':' (rows); in that case the size of x is appended as '@NxM'.

    edgeitems = opt.edgeitems;
    threshold = opt.threshold;

    % get indices of rows and columns to show
    [r_pre, r_post] = get_mx_idxs(x, edgeitems, threshold, 1);
    [c_pre, c_post] = get_mx_idxs(x, edgeitems, threshold, 2);

    part_idxs = {{r_pre, r_post}, {c_pre, c_post}};

    % one extra row / column for the ':' / ' ... ' separator, if needed
    nrows = numel([r_pre r_post]) + ~isempty(r_post);
    ncols = numel([c_pre c_post]) + ~isempty(c_post);

    % layout grid: element strings go in odd columns (3, 5, ...), spacing
    % and ' ... ' separators in even columns
    sinfix = cell(nrows, ncols * 2 + 1);
    for k = 1:(ncols - 1)
        sinfix{1, k * 2 + 2} = '  ';
    end

    cpos = 1;
    for cpart = 1:2
        col_idxs = part_idxs{2}{cpart};
        nc = numel(col_idxs);

        rpos = 0;
        for rpart = 1:2
            row_idxs = part_idxs{1}{rpart};

            nr = numel(row_idxs);
            if nr == 0
                continue
            end
            for ci = 1:nc
                col_idx = col_idxs(ci);
                trgc = cpos + ci * 2;
                for ri = 1:nr
                    row_idx = row_idxs(ri);
                    sinfix{rpos + ri, trgc} = disp_helper(x{row_idx, ...
                                                            col_idx}, opt);
                    if cpart == 2 && ci == 1
                        % columns were elided before this one
                        sinfix{rpos + ri, trgc - 1} = ' ... ';
                    end
                end

                if rpart == 2
                    % rows were elided: put a ':' roughly halfway the
                    % widest entry of this column. Use the width (number of
                    % columns) rather than numel, so that multi-row entries
                    % do not push the ':' beyond the column.
                    max_width = max(cellfun(@(c) size(c, 2), ...
                                            sinfix(:, trgc)));
                    pre_spaces = spaces(1, floor(max_width / 2 - 1));
                    sinfix{rpos, trgc} = [pre_spaces ':'];
                end
            end
            rpos = rpos + nr + 1;
        end
        cpos = cpos + nc * 2;
    end

    show_size = opt.show_size || ~isempty(r_post) || ~isempty(c_post);
    s = surround_with(show_size, '{ ', strcat_(sinfix), ' }', size(x));

function pre_infix_post = surround_with(show_size, pre, infix, post, matrix_sz)
    % Concatenate pre, infix and post horizontally, with post aligned to
    % the last row of infix. If show_size is true and matrix_sz does not
    % describe a single element, '@' followed by the dimensions joined by
    % 'x' is appended after post, plus ' (empty)' for zero elements.
    n_elem = prod(matrix_sz);

    size_str = '';
    if show_size && n_elem ~= 1
        dims_str = sprintf('x%d', matrix_sz);
        size_str = ['@' dims_str(2:end)];
        if n_elem == 0
            size_str = [size_str ' (empty)'];
        end
    end

    % blank rows above post so that it lines up with the last infix row
    n_pad_rows = size(infix, 1) - 1;
    post_block = strcat_({spaces(n_pad_rows, 1); [post size_str]});
    pre_infix_post = strcat_({pre, infix, post_block});

function s = matrix2str(x, opt)
    % String representation of a 2D numeric or logical matrix x.
    %
    % Values are converted by num2str, with opt.precision as its precision
    % argument. Along a dimension that is too long (see get_mx_idxs) only
    % the first and last opt.edgeitems rows/columns are shown: elided rows
    % are marked by a line with a ':' below each number column, elided
    % columns by ' ... '. In that case the size of x is appended as '@NxM'.
    % The result is enclosed in '[ ' and ' ]'.
    if isempty(x)
        % empty matrix: '[  ]', followed by the size if opt.show_size is set
        show_size = opt.show_size;
        s = surround_with(show_size, '[', '  ', ']', size(x));
        return
    end

    % display a matrix
    edgeitems = opt.edgeitems;
    threshold = opt.threshold;
    precision = opt.precision;

    % get indices of rows and columns to show
    [r_pre, r_post] = get_mx_idxs(x, edgeitems, threshold, 1);
    [c_pre, c_post] = get_mx_idxs(x, edgeitems, threshold, 2);

    % data to be shown
    y = x([r_pre r_post], [c_pre c_post]);

    % convert to string (one text row per row of y)
    s = num2str(y, precision);

    % number of characters in first and second dimension
    [nc_row, nc_col] = size(s);

    % see where each column is a space; that's a potential split point
    sp_col = sum(s == ' ', 1) == nc_row;

    % col_index has value k for characters in the k-th column, else zero:
    % character columns belonging to the k-th number column get value k,
    % all-blank character columns between numbers get zero.
    % NOTE(review): this numbering assumes the first character column of s
    % is not blank; a leading blank column would shift all values by one
    col_index = zeros(1, nc_col);
    col_count = 1;
    in_num = true;
    for k = 1:nc_col
        if in_num
            if sp_col(k)
                col_count = col_count + 1;
                in_num = false;

            else
                col_index(k) = col_count;
            end
        elseif ~sp_col(k)
            in_num = true;
            col_index(k) = col_count;
        end
    end

    % deal with rows: row_blocks holds (1) the first edgeitems text rows,
    % (2) a line with ':' marks, (3) the last edgeitems text rows; without
    % elided rows only the first block is used
    row_blocks = cell(3, 1);
    if isempty(r_post)
        row_blocks{1, 1} = s;
    else
        % insert ':' for each column
        line = spaces(1, nc_col);
        for k = 1:max(col_index)
            idxs = find(col_index == k);
            median_pos = round(mean(idxs));
            line(median_pos) = ':';
        end
        row_blocks{1} = s(1:edgeitems, :);
        row_blocks{2} = line;
        row_blocks{3} = s(edgeitems + (1:edgeitems), :);
    end

    % deal with columns: each row block is split between number columns
    % edgeitems and edgeitems+1 (keeping one blank on either side of the
    % split), with ' ... ' in between except on the ':' line
    row_and_col_blocks = cell(3, 3);
    for row = 1:3
        if isempty(c_post)
            % linear index: fills row_and_col_blocks{row, 1}
            row_and_col_blocks{row} = row_blocks{row};
        else
            % insert ' ... ' halfway each row
            pre_end = find(col_index == edgeitems, 1, 'last') + 1;
            post_start = find(col_index == (edgeitems + 1), 1, 'first') - 1;

            r = row_blocks{row, 1};
            if isempty(r)
                continue
            end
            row_and_col_blocks{row, 1} = r(:, 1:pre_end);
            if row ~= 2
                row_and_col_blocks{row, 2} = repmat(' ... ', size(r, 1), 1);
            end
            row_and_col_blocks{row, 3} = r(:, post_start:end);
        end
    end

    show_size = opt.show_size || ~isempty(r_post) || ~isempty(c_post);
    s = surround_with(show_size, '[ ', strcat_(row_and_col_blocks), ' ]', ...
                      size(x));

function [pre, post] = get_mx_idxs(x, edgeitems, threshold, dim)
    % Indices along dimension dim of x that should be shown.
    %
    % With n=size(x,dim): if n exceeds both threshold and 2*edgeitems, the
    % array is shown in summary style and pre holds the first edgeitems
    % indices while post holds the last edgeitems indices. Otherwise pre
    % holds all indices 1:n and post is empty. Taking the max of both
    % bounds means that a threshold of Inf disables summary style.
    n = size(x, dim);
    summarize = n > max(threshold, 2 * edgeitems);

    if ~summarize
        pre = 1:n;
        post = [];
        return
    end

    pre = 1:edgeitems;
    post = (n - edgeitems) + (1:edgeitems);