source: mainline/uspace/lib/cpp/include/impl/random.hpp@ 08be4a4

lfn serial ticket/834-toolchain-update topic/msim-upgrade topic/simplify-dev-export
Last change on this file since 08be4a4 was 08be4a4, checked in by Dzejrou <dzejrou@…>, 7 years ago

cpp: used constexpr builtin wrappers to avoid reallocation of the helper array in congruential engine, fixed compile time bugs and added mersenne_twister_engine

  • Property mode set to 100644
File size: 20.5 KB
Line 
1/*
2 * Copyright (c) 2018 Jaroslav Jindrak
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 *
9 * - Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 * - Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 * - The name of the author may not be used to endorse or promote products
15 * derived from this software without specific prior written permission.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
18 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
19 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
20 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
21 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
22 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
26 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29#ifndef LIBCPP_RANDOM
30#define LIBCPP_RANDOM
31
32#include <cstdlib>
33#include <ctime>
34#include <initializer_list>
35#include <internal/builtins.hpp>
36#include <limits>
37#include <type_traits>
38#include <vector>
39
40/**
41 * Note: Variables with one or two lettered
42 * names here are named after their counterparts in
43 * the standard. If one needs to understand their meaning,
44 * they should seek the mentioned standard section near
45 * the declaration of these variables.
46 * Note: There will be a lot of mathematical expressions in this header.
47 * All of these are taken directly from the standard's requirements
48 * and as such won't be commented here, check the appropriate
49 * sections if you need explanation of these formulae.
50 */
51
52namespace std
53{
54 namespace aux
55 {
56 /**
57 * This is the minimum requirement imposed by the
58 * standard for a type to qualify as a seed sequence
59 * in overloading resolutions.
60 * (This is because the engines have constructors
61 * that accept sequence and seed and without this
62 * minimal requirements overload resolution would fail.)
63 */
64 template<class Sequence, class ResultType>
65 struct is_seed_sequence
66 : aux::value_is<
67 bool, !is_convertible_v<Sequence, ResultType>
68 >
69 { /* DUMMY BODY */ };
70
71 template<class T, class Engine>
72 inline constexpr bool is_seed_sequence_v = is_seed_sequence<T, Engine>::value;
73 }
74
75 /**
76 * 26.5.3.1, class template linear_congruential_engine:
77 */
78
79 template<class UIntType, UIntType a, UIntType c, UIntType m>
80 class linear_congruential_engine
81 {
82 static_assert(m == 0 || (a < m && c < m));
83
84 public:
85 using result_type = UIntType;
86
87 static constexpr result_type multiplier = a;
88 static constexpr result_type increment = c;
89 static constexpr result_type modulus = m;
90
91 static constexpr result_type min()
92 {
93 return c == 0U ? 1U : 0U;
94 }
95
96 static constexpr result_type max()
97 {
98 return m - 1U;
99 }
100
101 static constexpr result_type default_seed = 1U;
102
103 explicit linear_congruential_engine(result_type s = default_seed)
104 : state_{}
105 {
106 seed(s);
107 }
108
109 linear_congruential_engine(const linear_congruential_engine& other)
110 : state_{other.state_}
111 { /* DUMMY BODY */ }
112
113 template<class Seq>
114 explicit linear_congruential_engine(
115 enable_if_t<aux::is_seed_sequence_v<Seq, result_type>, Seq&> q
116 )
117 : state_{}
118 {
119 seed(q);
120 }
121
122 void seed(result_type s = default_seed)
123 {
124 if (c % modulus_ == 0 && s == 0)
125 state_ = 0;
126 else
127 state_ = s;
128 }
129
130 template<class Seq>
131 void seed(
132 enable_if_t<aux::is_seed_sequence_v<Seq, result_type>, Seq&> q
133 )
134 {
135 q.generate(arr_, arr_ + k_ + 3);
136
137 result_type s{};
138 for (size_t j = 0; j < k_; ++j)
139 s += arr_[j + 3] * aux::pow2(32U * j);
140 s = s % modulus_;
141
142 seed(s);
143 }
144
145 result_type operator()()
146 {
147 return generate_();
148 }
149
150 void discard(unsigned long long z)
151 {
152 for (unsigned long long i = 0ULL; i < z; ++i)
153 transition_();
154 }
155
156 bool operator==(const linear_congruential_engine& rhs) const
157 {
158 return state_ = rhs.state_;
159 }
160
161 bool operator!=(const linear_congruential_engine& rhs) const
162 {
163 return !(*this == rhs);
164 }
165
166 template<class Char, class Traits>
167 basic_ostream<Char, Traits>& operator<<(basic_ostream<Char, Traits>& os) const
168 {
169 auto flags = os.flags();
170 os.flags(ios_base::dec | ios_base::left);
171
172 os << state_;
173
174 os.flags(flags);
175 return os;
176 }
177
178 template<class Char, class Traits>
179 basic_istream<Char, Traits>& operator>>(basic_istream<Char, Traits>& is) const
180 {
181 auto flags = is.flags();
182 is.flags(ios_base::dec);
183
184 result_type tmp{};
185 if (is >> tmp)
186 state_ = tmp;
187 else
188 is.setstate(ios::failbit);
189
190 is.flags(flags);
191 return is;
192 }
193
194 private:
195 result_type state_;
196
197 static constexpr result_type modulus_ =
198 (m == 0) ? (numeric_limits<result_type>::max() + 1) : m;
199
200 /**
201 * We use constexpr builtins to keep this array
202 * between calls to seed(Seq&), which means we don't
203 * have to keep allocating and deleting it.
204 */
205 static constexpr size_t k_ = static_cast<size_t>(aux::ceil(aux::log2(modulus_) / 32));
206 result_type arr_[k_ + 3];
207
208 void transition_()
209 {
210 state_ = (a * state_ + c) % modulus_;
211 }
212
213 result_type generate_()
214 {
215 transition_();
216
217 return state_;
218 }
219 };
220
221 /**
222 * 26.5.3.2, class template mersenne_twister_engine:
223 */
224
225 template<
226 class UIntType, size_t w, size_t n, size_t m, size_t r,
227 UIntType a, size_t u, UIntType d, size_t s,
228 UIntType b, size_t t, UIntType c, size_t l, UIntType f
229 >
230 class mersenne_twister_engine
231 {
232 // TODO: fix these
233 /* static_assert(0 < m && m <= n); */
234 /* static_assert(2 * u < w); */
235 /* static_assert(r <= w && u <= w && s <= w && t <= w && l <= w); */
236 /* /1* static_assert(w <= numeric_limits<UIntType>::digits); *1/ */
237 /* static_assert(a <= (1U << w) - 1U); */
238 /* static_assert(b <= (1U << w) - 1U); */
239 /* static_assert(c <= (1U << w) - 1U); */
240 /* static_assert(d <= (1U << w) - 1U); */
241 /* static_assert(f <= (1U << w) - 1U); */
242
243 public:
244 using result_type = UIntType;
245
246 static constexpr size_t word_size = w;
247 static constexpr size_t state_size = n;
248 static constexpr size_t shift_size = m;
249 static constexpr size_t mask_bits = r;
250 static constexpr UIntType xor_mask = a;
251
252 static constexpr size_t tempering_u = u;
253 static constexpr UIntType tempering_d = d;
254 static constexpr size_t tempering_s = s;
255 static constexpr UIntType tempering_b = b;
256 static constexpr size_t tempering_t = t;
257 static constexpr UIntType tempering_c = c;
258 static constexpr size_t tempering_l = l;
259
260 static constexpr UIntType initialization_multiplier = f;
261
262 static constexpr result_type min()
263 {
264 return result_type{};
265 }
266
267 static constexpr result_type max()
268 {
269 return static_cast<result_type>(aux::pow2(w)) - 1U;
270 }
271
272 static constexpr result_type default_seed = 5489U;
273
274 explicit mersenne_twister_engine(result_type value = default_seed)
275 : state_{}, i_{}
276 {
277 seed(value);
278 }
279
280 template<class Seq>
281 explicit mersenne_twister_engine(
282 enable_if_t<aux::is_seed_sequence_v<Seq, result_type>, Seq&> q
283 )
284 : state_{}, i_{}
285 {
286 seed(q);
287 }
288
289 void seed(result_type value = default_seed)
290 {
291 state_[idx_(-n)] = value % aux::pow2u(w);;
292
293 for (long long i = 1 - n; i <= -1; ++i)
294 {
295 state_[idx_(i)] = (f * (state_[idx_(i - 1)] ^
296 (state_[idx_(i - 1)] >> (w - 2))) + 1 % n) % aux::pow2u(w);
297 }
298 }
299
300 template<class Seq>
301 void seed(
302 enable_if_t<aux::is_seed_sequence_v<Seq, result_type>, Seq&> q
303 )
304 {
305 q.generate(arr_, arr_ + n * k_);
306
307 for (long long i = -n; i <= -1; ++i)
308 {
309 state_[idx_(i)] = result_type{};
310 for (long long j = 0; j < k_; ++j)
311 state_[idx_(i)] += arr_[k_ * (i + n) + j] * aux::pow2(32 * j);
312 state_[idx_(i)] %= aux::pow2(w);
313 }
314 }
315
316 result_type operator()()
317 {
318 return generate_();
319 }
320
321 void discard(unsigned long long z)
322 {
323 for (unsigned long long i = 0ULL; i < z; ++i)
324 transition_();
325 }
326
327 bool operator==(const mersenne_twister_engine& rhs) const
328 {
329 for (size_t i = 0; i < n; ++i)
330 {
331 if (state_[i] != rhs.state_[i])
332 return false;
333 }
334
335 return true;
336 }
337
338 bool operator!=(const mersenne_twister_engine& rhs) const
339 {
340 return !(*this == rhs);
341 }
342
343 template<class Char, class Traits>
344 basic_ostream<Char, Traits>& operator<<(basic_ostream<Char, Traits>& os) const
345 {
346 auto flags = os.flags();
347 os.flags(ios_base::dec | ios_base::left);
348
349 for (size_t j = n + 1; j > 1; --j)
350 {
351 os << state_[idx_(i_ - j - 1)];
352
353 if (j > 2)
354 os << os.widen(' ');
355 }
356
357 os.flags(flags);
358 return os;
359 }
360
361 template<class Char, class Traits>
362 basic_istream<Char, Traits>& operator>>(basic_istream<Char, Traits>& is) const
363 {
364 auto flags = is.flags();
365 is.flags(ios_base::dec);
366
367 for (size_t j = n + 1; j > 1; --j)
368 {
369 if (!(is >> state_[idx_(i_ - j - 1)]))
370 {
371 is.setstate(ios::failbit);
372 break;
373 }
374 }
375
376 is.flags(flags);
377 return is;
378 }
379
380 private:
381 result_type state_[n];
382 size_t i_;
383
384 static constexpr size_t k_ = static_cast<size_t>(w / 32);
385 result_type arr_[n * k_];
386
387 void transition_()
388 {
389 auto mask = (result_type{1} << r) - 1;
390 auto y = (state_[idx_(i_ - n)] & ~mask) | (state_[idx_(i_ + 1 - n)] & mask);
391 auto alpha = a * (y & 1);
392 state_[i_] = state_[idx_(i_ + m - n)] ^ (y >> 1) ^ alpha;
393
394 i_ = (i_ + 1) % n;
395 }
396
397 result_type generate_()
398 {
399 auto z1 = state_[i_] ^ ((state_[i_] >> u) & d);
400 auto z2 = z1 ^ (lshift_(z1, s) & b);
401 auto z3 = z2 ^ (lshift_(z2, t) & c);
402 auto z4 = z3 ^ (z3 >> l);
403
404 transition_();
405
406 return z4;
407 }
408
409 size_t idx_(size_t idx) const
410 {
411 return idx % n;
412 }
413
414 result_type lshift_(result_type val, size_t count)
415 {
416 return (val << count) % aux::pow2u(w);
417 }
418 };
419
420 /**
421 * 26.5.3.3, class template subtract_with_carry_engine:
422 */
423
424 template<class UIntType, size_t w, size_t s, size_t r>
425 class subtract_with_carry_engine;
426
427 /**
428 * 26.5.4.2, class template discard_block_engine:
429 */
430
431 template<class Engine, size_t p, size_t r>
432 class discard_block_engine;
433
434 /**
435 * 26.5.4.3, class template independent_bits_engine:
436 */
437
438 template<class Engine, size_t w, class UIntType>
439 class independent_bits_engine;
440
441 /**
442 * 26.5.4.4, class template shuffle_order_engine:
443 */
444
445 template<class Engine, size_t k>
446 class shuffle_order_engine;
447
448 /**
449 * 26.5.5, engines and engine adaptors with predefined
450 * parameters:
451 * TODO: check their requirements for testing
452 */
453
454 using minstd_rand0 = linear_congruential_engine<uint_fast32_t, 16807, 0, 2147483647>;
455 using minstd_rand = linear_congruential_engine<uint_fast32_t, 48271, 0, 2147483647>;
456 using mt19937 = mersenne_twister_engine<
457 uint_fast32_t, 32, 624, 397, 31, 0x9908b0df, 11, 0xffffffff, 7,
458 0x9d2c5680, 15, 0xefc60000, 18, 1812433253
459 >;
460 using mt19937_64 = mersenne_twister_engine<
461 uint_fast64_t, 64, 312, 156, 31, 0xb5026f5aa96619e9, 29,
462 0x5555555555555555, 17, 0x71d67fffeda60000, 37, 0xfff7eee000000000,
463 43, 6364136223846793005
464 >;
465 using ranlux24_base = subtract_with_carry_engine<uint_fast32_t, 24, 10, 24>;
466 using ranlux48_base = subtract_with_carry_engine<uint_fast64_t, 48, 5, 12>;
467 using ranlux24 = discard_block_engine<ranlux24_base, 223, 23>;
468 using ranlux48 = discard_block_engine<ranlux48_base, 389, 11>;
469 using knuth_b = shuffle_order_engine<minstd_rand0, 256>;
470
471 using default_random_engine = minstd_rand0;
472
473 /**
474 * 26.5.6, class random_device:
475 */
476
477 class random_device
478 {
479 using result_type = unsigned int;
480
481 static constexpr result_type min()
482 {
483 return numeric_limits<result_type>::min();
484 }
485
486 static constexpr result_type max()
487 {
488 return numeric_limits<result_type>::max();
489 }
490
491 explicit random_device(const string& token = "")
492 {
493 /**
494 * Note: token can be used to choose between
495 * random generators, but HelenOS only
496 * has one :/
497 * Also note that it is implementation
498 * defined how this class generates
499 * random numbers and I decided to use
500 * time seeding with C stdlib random,
501 * - feel free to change it if you know
502 * something better.
503 */
504 hel::srandom(hel::time(nullptr));
505 }
506
507 result_type operator()()
508 {
509 return hel::random();
510 }
511
512 double entropy() const noexcept
513 {
514 return 0.0;
515 }
516
517 random_device(const random_device&) = delete;
518 random_device& operator=(const random_device&) = delete;
519 };
520
521 /**
522 * 26.5.7.1, class seed_seq:
523 */
524
/**
 * Seed sequence: stores a list of 32-bit seed values and
 * scrambles them into an arbitrarily long range of well
 * distributed 32-bit words that engines seed themselves from.
 */
class seed_seq
{
    public:
        using result_type = uint_least32_t;

        seed_seq()
            : vec_{}
        { /* DUMMY BODY */ }

        template<class T>
        seed_seq(initializer_list<T> init)
            : seed_seq(init.begin(), init.end())
        { /* DUMMY BODY */ }

        /**
         * Stores every value of the range reduced modulo 2^32.
         */
        template<class InputIterator>
        seed_seq(InputIterator first, InputIterator last)
            : vec_{}
        {
            for (; first != last; ++first)
                vec_.push_back(static_cast<result_type>(*first) & word_mask_);
        }

        /**
         * Fills [first, last) with 32-bit words derived from
         * the stored seeds, an empty range is left untouched.
         * The constants and the order of the updates are the
         * ones prescribed by the standard and must not be
         * changed, equal seeds have to produce equal output
         * on every implementation.
         * All arithmetic is performed modulo 2^32.
         */
        template<class RandomAccessGenerator>
        void generate(RandomAccessGenerator first,
                      RandomAccessGenerator last)
        {
            if (first == last)
                return;

            const size_t n = static_cast<size_t>(last - first);
            const size_t s = vec_.size();

            for (size_t k = 0; k < n; ++k)
                first[k] = 0x8b8b8b8bU;

            const size_t t = (n >= 623U) ? 11U
                           : (n >= 68U) ? 7U
                           : (n >= 39U) ? 5U
                           : (n >= 7U) ? 3U
                           : (n - 1U) / 2U;
            const size_t p = (n - t) / 2U;
            const size_t q = p + t;
            const size_t m = (s + 1U > n) ? (s + 1U) : n;

            const auto scramble = [](result_type x) -> result_type {
                return x ^ (x >> 27);
            };

            for (size_t k = 0; k < m; ++k)
            {
                const size_t kn = k % n;
                const size_t kpn = (k + p) % n;
                const size_t kqn = (k + q) % n;
                // (k - 1) mod n without going below zero for k == 0.
                const size_t km1 = (k + n - 1U) % n;

                const result_type mixed = static_cast<result_type>(
                    (first[kn] ^ first[kpn] ^ first[km1]) & word_mask_
                );
                const result_type r1 = (1664525U * scramble(mixed)) & word_mask_;

                result_type r2 = r1;
                if (k == 0)
                    r2 += static_cast<result_type>(s);
                else if (k <= s)
                    r2 += static_cast<result_type>(kn) + vec_[k - 1];
                else
                    r2 += static_cast<result_type>(kn);
                r2 &= word_mask_;

                first[kpn] = (first[kpn] + r1) & word_mask_;
                first[kqn] = (first[kqn] + r2) & word_mask_;
                first[kn] = r2;
            }

            for (size_t k = m; k < m + n; ++k)
            {
                const size_t kn = k % n;
                const size_t kpn = (k + p) % n;
                const size_t kqn = (k + q) % n;
                const size_t km1 = (k + n - 1U) % n;

                const result_type mixed = static_cast<result_type>(
                    (first[kn] + first[kpn] + first[km1]) & word_mask_
                );
                const result_type r3 = (1566083941U * scramble(mixed)) & word_mask_;
                const result_type r4 = static_cast<result_type>(
                    (r3 - kn) & word_mask_
                );

                first[kpn] = (first[kpn] ^ r3) & word_mask_;
                first[kqn] = (first[kqn] ^ r4) & word_mask_;
                first[kn] = r4;
            }
        }

        /**
         * Number of stored seed values.
         */
        size_t size() const
        {
            return vec_.size();
        }

        /**
         * Copies the stored seed values to dest.
         */
        template<class OutputIterator>
        void param(OutputIterator dest) const
        {
            for (const auto& x: vec_)
                *dest++ = x;
        }

        seed_seq(const seed_seq&) = delete;
        seed_seq& operator=(const seed_seq&) = delete;

    private:
        /**
         * result_type is allowed to be wider than 32 bits,
         * the mask keeps all stored and generated words
         * within 32 bits.
         */
        static constexpr result_type word_mask_ = 0xffffffffU;

        vector<result_type> vec_;
};
575
576 /**
577 * 26.5.7.2, function template generate_canonical:
578 */
579
580 template<class RealType, size_t bits, class URNG>
581 RealType generate_canonical(URNG& g);
582
583 /**
584 * 26.5.8.2.1, class template uniform_int_distribution:
585 */
586
587 template<class IntType = int>
588 class uniform_int_distribution;
589
590 /**
591 * 26.5.8.2.2, class template uniform_real_distribution:
592 */
593
594 template<class RealType = double>
595 class uniform_real_distribution;
596
597 /**
598 * 26.5.8.3.1, class bernoulli_distribution:
599 */
600
601 class bernoulli_distribution;
602
603 /**
604 * 26.5.8.3.2, class template binomial_distribution:
605 */
606
607 template<class IntType = int>
608 class binomial_distribution;
609
610 /**
611 * 26.5.8.3.3, class template geometric_distribution:
612 */
613
614 template<class IntType = int>
615 class geometric_distribution;
616
617 /**
618 * 26.5.8.3.4, class template negative_binomial_distribution:
619 */
620
621 template<class IntType = int>
622 class negative_binomial_distribution;
623
624 /**
625 * 26.5.8.4.1, class template poisson_distribution:
626 */
627
628 template<class IntType = int>
629 class poisson_distribution;
630
631 /**
632 * 26.5.8.4.2, class template exponential_distribution:
633 */
634
635 template<class RealType = double>
636 class exponential_distribution;
637
638 /**
639 * 26.5.8.4.3, class template gamma_distribution:
640 */
641
642 template<class RealType = double>
643 class gamma_distribution;
644
645 /**
646 * 26.5.8.4.4, class template weibull_distribution:
647 */
648
649 template<class RealType = double>
650 class weibull_distribution;
651
652 /**
653 * 26.5.8.4.5, class template extreme_value_distribution:
654 */
655
656 template<class RealType = double>
657 class extreme_value_distribution;
658
659 /**
660 * 26.5.8.5.1, class template normal_distribution:
661 */
662
663 template<class RealType = double>
664 class normal_distribution;
665
666 /**
667 * 26.5.8.5.2, class template lognormal_distribution:
668 */
669
670 template<class RealType = double>
671 class lognormal_distribution;
672
673 /**
674 * 26.5.8.5.3, class template chi_squared_distribution:
675 */
676
677 template<class RealType = double>
678 class chi_squared_distribution;
679
680 /**
681 * 26.5.8.5.4, class template cauchy_distribution:
682 */
683
684 template<class RealType = double>
685 class cauchy_distribution;
686
687 /**
688 * 26.5.8.5.5, class template fisher_f_distribution:
689 */
690
691 template<class RealType = double>
692 class fisher_f_distribution;
693
694 /**
695 * 26.5.8.5.6, class template student_t_distribution:
696 */
697
698 template<class RealType = double>
699 class student_t_distribution;
700
701 /**
702 * 26.5.8.6.1, class template discrete_distribution:
703 */
704
705 template<class IntType = int>
706 class discrete_distribution;
707
708 /**
709 * 26.5.8.6.2, class template piecewise_constant_distribution:
710 */
711
712 template<class RealType = double>
713 class piecewise_constant_distribution;
714
715 /**
716 * 26.5.8.6.3, class template piecewise_linear_distribution:
717 */
718
719 template<class RealType = double>
720 class piecewise_linear_distribution;
721}
722
723#endif
Note: See TracBrowser for help on using the repository browser.