function result = cosmo_rand(varargin)
% generate uniform pseudo-random numbers, optionally using a seed value
%
% result=cosmo_rand(s1,...,sN,['single'|'double'],['seed',seed])
%
% Input:
%   s*              scalar or vector indicating dimensions of the result
%   'single'        (optional) return the result as class single
%   'double'        (optional) return the result as class double (default)
%   'seed', seed    (optional) if provided, use this seed value for
%                   pseudo-random number generation. A seed of 0 is
%                   equivalent to not providing a seed
%
% Output:
%   result          array of size s1 x s2 x ... sN. If the seed option is
%                   used, repeated calls with the same seed and element
%                   dimensions give the same result
%
% Example:
%     % generate 2x2 pseudo-random number matrices twice, just like 'rand'
%     % (repeated calls give different outputs)
%     x1=cosmo_rand(2,2);
%     x2=cosmo_rand(2,2);
%     isequal(x1,x2)
%     %|| false
%     %
%     % as above, but specify a seed; repeated calls give the same output
%     x3=cosmo_rand(2,2,'seed',314);
%     x4=cosmo_rand(2,2,'seed',314);
%     isequal(x3,x4)
%     %|| true
%     %
%     % using a different seed gives a different output
%     x5=cosmo_rand(2,2,'seed',315);
%     isequal(x3,x5)
%     %|| false
%
%
% Notes:
%   - this function behaves like the builtin 'rand' function, except that
%     it supports a 'seed' option, which allows for deterministic
%     pseudo-random number generation
%   - when using the 'seed' option, this function gives identical output
%     under both matlab and octave. To achieve this, the PRNG is set to a
%     different state for the two platforms
%   - this function uses the Mersenne twister algorithm by default, even
%     when 'seed' is used (unlike Matlab and Octave).
%   - all size arguments must precede any string argument
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    [dims, seed, to_class] = process_input(varargin{:});

    % without a seed the builtin generator is used as-is
    generator = @rand;

    use_seed = seed ~= 0;
    if use_seed
        on_matlab = cosmo_wtf('is_matlab');

        % platform-specific representation of the twister state
        twister = get_mersenne_state_from_seed(seed, on_matlab);

        if on_matlab
            % draw from a private stream, so that the global stream is
            % left untouched
            stream = RandStream('mt19937ar', 'Seed', twister.Seed);
            stream.State = twister.State;
            generator = @stream.rand;
        else
            % remember the current PRNG state; it is put back when this
            % function exits (i.e. when the cleanup object is destroyed)
            previous_state = rand('state');
            restore_on_exit = onCleanup(@()rand('state', previous_state));

            % switch the global PRNG to the seeded state
            rand('state', twister);
        end
    end

    result = to_class(generator(dims));

function rng_state = get_mersenne_state_from_seed(seed, is_matlab)
% compute the Mersenne twister state that corresponds to a seed
%
% Input:
%   seed          numeric scalar; it is converted to uint64 and only its
%                 lowest 32 bits are used
%   is_matlab     true to return the state in the form used for Matlab,
%                 false to return it in the form used for Octave
%
% Output:
%   rng_state     if is_matlab is true: struct with fields
%                   .State   625x1 uint32 vector
%                   .Type    'twister'
%                   .Seed    uint32(0)
%                 otherwise: 625x1 uint64 vector
%
% Notes:
%   - based on pseudo-code from wikipedia:
%     http://en.wikipedia.org/wiki/Mersenne_twister
%   - the most recently computed state is cached. Because the form of the
%     output depends on is_matlab, the cache is keyed on both seed and
%     is_matlab
    persistent cached_seed
    persistent cached_is_matlab
    persistent cached_rng_state

    cache_hit = isequal(cached_seed, seed) && ...
                    isequal(cached_is_matlab, is_matlab);
    if cache_hit
        rng_state = cached_rng_state;
        return
    end

    max_uint32 = 2^32 - 1;

    % 64-bit integers are used so that the product below can be formed
    % before it is masked back to 32 bits
    state = uint64(zeros(625, 1));
    state(1) = bitand(uint64(seed), max_uint32);

    mersenne_mult = uint64(1812433253);
    for j = 1:623
        v = mersenne_mult .* bitxor(state(j), bitshift(state(j), -30)) + uint64(j);
        state(j + 1) = bitand(v, max_uint32);
    end

    % the last element is a counter rather than part of the recurrence
    state(end) = 1;

    if is_matlab
        % reverse counter relative to Octave
        % (this is undocumented in both Matlab and Octave)
        state(end) = uint64(625) - state(end);

        % matlab uses a struct to set the state
        rng_state = struct();
        rng_state.State = uint32(state);
        rng_state.Type = 'twister';
        rng_state.Seed = uint32(0);
    else
        % octave uses a vector to set the state
        rng_state = state;
    end

    cached_rng_state = rng_state;
    cached_seed = seed;
    cached_is_matlab = is_matlab;

function result = identify_func(x)
% identity function: hands back its input unchanged. Used as the class
% conversion function when no conversion of the output is required
    result = x;

function [sizes, seed, class_func] = process_input(varargin)
% parse the input arguments into sizes, seed and output class conversion
%
% Input:
%   varargin      zero or more numeric size arguments (scalars or
%                 vectors), optionally followed by string arguments:
%                 'single' or 'double' (at most one of these), and/or
%                 the key-value pair 'seed', seed
%
% Output:
%   sizes         row vector with all size arguments concatenated, or 1
%                 if no size argument was provided
%   seed          the value following 'seed', or 0 if 'seed' is absent
%   class_func    function handle applied to the generated numbers:
%                 @single for 'single', @identify_func otherwise
%
% Notes:
%   - an error is raised for a size argument after a string argument, for
%     a repeated type, for 'seed' without a value, for an unsupported
%     key, and for an argument that is neither numeric nor a string
%   - the parsed result of the most recent call is cached and reused when
%     the same arguments are provided again
    persistent cached_varargin
    persistent cached_sizes
    persistent cached_seed
    persistent cached_class_func

    % persistent variables start out as [] and the cache is stored as a
    % cell, so the cache is only trusted once it holds a cell. This
    % prevents an empty argument list from being matched against a cache
    % that was never populated
    cache_is_set = iscell(cached_varargin);
    if cache_is_set && isequal(varargin, cached_varargin)
        sizes = cached_sizes;
        seed = cached_seed;
        class_func = cached_class_func;
        return
    end

    n = numel(varargin);

    seed = 0;
    sizes_cell = cell(1, n);
    class_func = [];
    has_processed_sizes = false;

    % process each argument; a while-loop is used because 'seed'
    % consumes two arguments
    k = 0;
    while k < n
        k = k + 1;
        arg = varargin{k};

        if isnumeric(arg)
            if has_processed_sizes
                error('size argument not allowed after string argument');
            end

            ensure_positive_vector(k, arg);
            sizes_cell{k} = arg(:)';

        elseif ischar(arg)
            % from here on only string arguments (and their values) are
            % accepted
            has_processed_sizes = true;

            switch arg
                case {'single', 'double'}
                    if ~isempty(class_func)
                        error('type can only be set once');
                    end

                    if strcmp(arg, 'single')
                        class_func = @single;
                    else
                        class_func = @identify_func;
                    end

                case 'seed'
                    k = k + 1;
                    if k > n
                        error('missing value after key ''%s''', arg);
                    end

                    value = varargin{k};
                    ensure_positive_scalar(k, value);
                    seed = value;

                otherwise
                    error('unsupported key ''%s''', arg);
            end

        else
            error('illegal input at position %d', k);
        end
    end

    % cells at positions of non-size arguments are empty and vanish here
    sizes = [sizes_cell{:}];

    % no size provided, output is scalar
    if isempty(sizes)
        sizes = 1;
    end

    % no type provided, output is not converted
    if isempty(class_func)
        class_func = @identify_func;
    end

    cached_varargin = varargin;
    cached_sizes = sizes;
    cached_seed = seed;
    cached_class_func = class_func;

function ensure_positive_scalar(k, arg)
% raise an error unless arg is a numeric scalar with a value of at least
% zero; k is the argument position reported in the error message

    % covers the numeric, vector and sign requirements
    ensure_positive_vector(k, arg);

    if isscalar(arg)
        return
    end

    error('argument at position %d is not a scalar', k);

function ensure_positive_vector(k, arg)
% raise an error unless arg is a numeric vector of which all elements are
% at least zero; k is the argument position reported in the error message
%
% Note: zero is accepted, although the error message says 'positive'

    % short-circuit evaluation: the comparison is only attempted for
    % numeric vectors
    is_valid = isvector(arg) && isnumeric(arg) && all(arg >= 0);

    if is_valid
        return
    end

    error('argument at position %d is not positive', k);