1use base::{
2 decision::{Prospect, CPT},
3 opinion::{
4 DeducedOpinions, FPhi, FPsi, FixedOpinions, KPhi, KPsi, MyFloat, MyOpinions, Phi, Psi,
5 StateOpinions, Theta, Thetad, A, B, FH, FO, H, KH, KO, O,
6 },
7};
8
9use itertools::Itertools;
10use rand::{seq::SliceRandom, Rng};
11use rand_distr::{Beta, Distribution, Exp1, Open01, Standard, StandardNormal, Uniform};
12use serde_with::{serde_as, FromInto, TryFromInto};
13
14use subjective_logic::{
15 domain::Domain,
16 mul::{
17 labeled::{OpinionD1, SimplexD1},
18 Simplex,
19 },
20 multi_array::labeled::{MArrD1, MArrD2},
21};
22
23#[serde_as]
24#[derive(Debug, Clone, serde::Deserialize)]
25#[serde(bound(deserialize = "V: serde::Deserialize<'de>"))]
26pub struct InitialOpinions<V: MyFloat> {
27 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
28 psi: OpinionD1<Psi, V>,
29 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
30 fpsi: OpinionD1<FPsi, V>,
31 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
32 kpsi: OpinionD1<KPsi, V>,
33 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
34 phi: OpinionD1<Phi, V>,
35 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
36 fphi: OpinionD1<FPhi, V>,
37 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
38 kphi: OpinionD1<KPhi, V>,
39 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
40 o: OpinionD1<O, V>,
41 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
42 fo: OpinionD1<FO, V>,
43 #[serde_as(as = "TryFromInto<(Vec<V>, V, Vec<V>)>")]
44 ko: OpinionD1<KO, V>,
45 #[serde_as(as = "TryFromInto<Vec<(Vec<V>, V)>>")]
46 h_psi_if_phi1: MArrD1<Psi, SimplexD1<H, V>>,
47 #[serde_as(as = "TryFromInto<Vec<(Vec<V>, V)>>")]
48 fh_fpsi_if_fphi1: MArrD1<FPsi, SimplexD1<FH, V>>,
49 #[serde_as(as = "TryFromInto<Vec<(Vec<V>, V)>>")]
50 kh_kpsi_if_kphi1: MArrD1<KPsi, SimplexD1<KH, V>>,
51 #[serde_as(as = "TryFromInto<Vec<(Vec<V>, V)>>")]
52 h_b_if_phi1: MArrD1<B, SimplexD1<H, V>>,
53}
54
55impl<V: MyFloat> InitialOpinions<V> {
56 fn reset_to(self, state: &mut StateOpinions<V>)
57 where
58 Standard: Distribution<V>,
59 StandardNormal: Distribution<V>,
60 Exp1: Distribution<V>,
61 Open01: Distribution<V>,
62 {
63 let InitialOpinions {
64 psi,
65 phi,
66 o,
67 fo,
68 ko,
69 h_psi_if_phi1,
70 h_b_if_phi1,
71 fpsi,
72 fphi,
73 fh_fpsi_if_fphi1,
74 kpsi,
75 kphi,
76 kh_kpsi_if_kphi1,
77 } = self;
78 state.reset(
79 psi,
80 phi,
81 o,
82 fo,
83 ko,
84 h_psi_if_phi1,
85 h_b_if_phi1,
86 fpsi,
87 fphi,
88 fh_fpsi_if_fphi1,
89 kpsi,
90 kphi,
91 kh_kpsi_if_kphi1,
92 );
93 }
94}
95
96#[serde_as]
97#[derive(Debug, serde::Deserialize, Clone)]
98#[serde(bound(deserialize = "V: serde::Deserialize<'de>"))]
99pub struct InitialBaseRates<V> {
100 #[serde_as(as = "TryFromInto<Vec<V>>")]
101 a: MArrD1<A, V>,
102 #[serde_as(as = "TryFromInto<Vec<V>>")]
103 b: MArrD1<B, V>,
104 #[serde_as(as = "TryFromInto<Vec<V>>")]
105 h: MArrD1<H, V>,
106 #[serde_as(as = "TryFromInto<Vec<V>>")]
107 fh: MArrD1<FH, V>,
108 #[serde_as(as = "TryFromInto<Vec<V>>")]
109 kh: MArrD1<KH, V>,
110 #[serde_as(as = "TryFromInto<Vec<V>>")]
111 theta: MArrD1<Theta, V>,
112 #[serde_as(as = "TryFromInto<Vec<V>>")]
113 thetad: MArrD1<Thetad, V>,
114}
115
116impl<V: MyFloat> InitialBaseRates<V> {
117 fn reset_to(self, ded: &mut DeducedOpinions<V>) {
118 let InitialBaseRates {
119 a,
120 b,
121 h,
122 fh,
123 kh,
124 theta,
125 thetad,
126 } = self;
127 ded.reset(
128 OpinionD1::vacuous_with(h),
129 OpinionD1::vacuous_with(fh),
130 OpinionD1::vacuous_with(kh),
131 OpinionD1::vacuous_with(a),
132 OpinionD1::vacuous_with(b),
133 OpinionD1::vacuous_with(theta),
134 OpinionD1::vacuous_with(thetad),
135 );
136 }
137}
138
139pub struct ConditionSamples<V>
140where
141 V: MyFloat,
142 Open01: Distribution<V>,
143{
144 pub h_psi_if_phi0: ConditionSampler<Psi, H, V>,
145 pub h_b_if_phi0: ConditionSampler<B, H, V>,
146 pub o_b: ConditionSampler<B, O, V>,
147 pub a_fh: ConditionSampler<FH, A, V>,
148 pub b_kh: ConditionSampler<KH, B, V>,
149 pub theta_h: ConditionSampler<H, Theta, V>,
150 pub thetad_h: ConditionSampler<H, Thetad, V>,
151}
152
153pub enum ConditionSampler<D0, D1, V>
154where
155 D0: Domain,
156 D1: Domain,
157 V: MyFloat,
158 Open01: Distribution<V>,
159{
160 Array(Vec<MArrD1<D0, SimplexD1<D1, V>>>),
161 Random(MArrD1<D0, SimplexContainer<D1::Idx, V>>),
162}
163
164impl<D0: Domain, D1: Domain<Idx: Copy>, V: MyFloat> ConditionSampler<D0, D1, V>
165where
166 Open01: Distribution<V>,
167{
168 pub fn sample<R: Rng>(&self, rng: &mut R) -> MArrD1<D0, SimplexD1<D1, V>> {
169 match self {
170 ConditionSampler::Array(vec) => vec.choose(rng).unwrap().to_owned(),
171 ConditionSampler::Random(marr_d1) => MArrD1::from_iter(marr_d1.into_iter().map(|c| {
172 let mut acc = V::zero();
173 let mut b = MArrD1::default();
174 let mut u = V::default();
175 for x in &c.fixed {
176 match x {
177 SimplexIndexed::B(d1, v) => {
178 acc += *v;
179 b[*d1] = *v;
180 }
181 SimplexIndexed::U(v) => {
182 acc += *v;
183 u = *v;
184 }
185 }
186 }
187 if let Some(x) = &c.sampler {
188 match x {
189 SimplexIndexed::B(d1, s) => {
190 let v = s.choose(rng);
191 acc += v;
192 b[*d1] = v;
193 }
194 SimplexIndexed::U(s) => {
195 let v = s.choose(rng);
196 acc += v;
197 u = v;
198 }
199 }
200 }
201 match &c.auto {
202 SimplexIndexed::B(d1, _) => b[*d1] = V::one() - acc,
203 SimplexIndexed::U(_) => u = V::one() - acc,
204 }
205 Simplex::new_unchecked(b, u)
206 })),
207 }
208 }
209}
210
211pub struct SimplexContainer<Idx, V>
212where
213 V: MyFloat,
214 Open01: Distribution<V>,
215{
216 pub sampler: Option<SimplexIndexed<Idx, Sampler<V>>>,
217 pub fixed: Vec<SimplexIndexed<Idx, V>>,
218 pub auto: SimplexIndexed<Idx, ()>,
219}
220
/// Addresses one component of a simplex and carries a payload `T` for it.
pub enum SimplexIndexed<Idx, T> {
    /// The belief component at index `Idx`.
    B(Idx, T),
    /// The uncertainty component.
    U(T),
}
225
226#[serde_as]
227#[derive(Debug, serde::Deserialize)]
228#[serde(bound(deserialize = "V: serde::Deserialize<'de>"))]
229pub struct InitialStates<V: MyFloat> {
230 pub initial_opinions: InitialOpinions<V>,
231 pub initial_base_rates: InitialBaseRates<V>,
232}
233
234pub struct OpinionSamples<V: MyFloat>
235where
236 Open01: Distribution<V>,
237{
238 pub initial_opinions: InitialOpinions<V>,
239 pub initial_base_rates: InitialBaseRates<V>,
240 pub condition: ConditionSamples<V>,
241 pub uncertainty: UncertaintySamples<V>,
242}
243
244pub struct UncertaintySamples<V> {
245 pub fh_fpsi_if_fphi0: Vec<MArrD1<FPsi, V>>,
246 pub kh_kpsi_if_kphi0: Vec<MArrD1<KPsi, V>>,
247 pub fh_fphi_fo: Vec<MArrD2<FPhi, FO, V>>,
248 pub kh_kphi_ko: Vec<MArrD2<KPhi, KO, V>>,
249}
250
251impl<V: MyFloat> OpinionSamples<V>
252where
253 Open01: Distribution<V>,
254{
255 pub fn reset_to<R: Rng>(&self, ops: &mut MyOpinions<V>, rng: &mut R)
256 where
257 Standard: Distribution<V>,
258 StandardNormal: Distribution<V>,
259 Exp1: Distribution<V>,
260 Open01: Distribution<V>,
261 {
262 reset_fixed(&self.condition, &self.uncertainty, &mut ops.fixed, rng);
263 self.initial_opinions.clone().reset_to(&mut ops.state);
264 self.initial_base_rates.clone().reset_to(&mut ops.ded);
265 }
266}
267
268fn reset_fixed<V: MyFloat, R: Rng>(
269 condition: &ConditionSamples<V>,
270 uncertainty: &UncertaintySamples<V>,
271 fixed: &mut FixedOpinions<V>,
272 rng: &mut R,
273) where
274 Open01: Distribution<V>,
275{
276 let o_b = condition.o_b.sample(rng);
277 let b_kh = condition.b_kh.sample(rng);
278 let a_fh = condition.a_fh.sample(rng);
279 let theta_h = condition.theta_h.sample(rng);
280 let thetad_h = condition.thetad_h.sample(rng);
281 let h_psi_if_phi0 = condition.h_psi_if_phi0.sample(rng);
282 let h_b_if_phi0 = condition.h_b_if_phi0.sample(rng);
283 let uncertainty_fh_fpsi_if_fphi0 = uncertainty.fh_fpsi_if_fphi0.choose(rng).unwrap().to_owned();
284 let uncertainty_kh_kpsi_if_kphi0 = uncertainty.kh_kpsi_if_kphi0.choose(rng).unwrap().to_owned();
285 let uncertainty_fh_fo_fphi = uncertainty.fh_fphi_fo.choose(rng).unwrap().to_owned();
286 let uncertainty_kh_ko_kphi = uncertainty.kh_kphi_ko.choose(rng).unwrap().to_owned();
287 fixed.reset(
288 o_b,
289 b_kh,
290 a_fh,
291 theta_h,
292 thetad_h,
293 h_psi_if_phi0,
294 h_b_if_phi0,
295 uncertainty_fh_fpsi_if_fphi0,
296 uncertainty_kh_kpsi_if_kphi0,
297 uncertainty_fh_fo_fphi,
298 uncertainty_kh_ko_kphi,
299 );
300}
301
302#[derive(Debug, serde::Deserialize)]
303#[serde(rename_all = "lowercase")]
304pub enum SamplerOption<V> {
305 Single(V),
306 Array(Vec<V>),
307 Uniform(V, V),
308 Beta(V, V),
309}
310
311pub enum Sampler<V: MyFloat>
312where
313 Open01: Distribution<V>,
314{
315 Single(V),
316 Arr(Vec<V>),
317 Uni(Uniform<V>),
318 Beta(Beta<V>),
319}
320
321impl<V> From<SamplerOption<V>> for Sampler<V>
322where
323 V: MyFloat,
324 Open01: Distribution<V>,
325{
326 fn from(value: SamplerOption<V>) -> Self {
327 match value {
328 SamplerOption::Single(v) => Self::Single(v),
329 SamplerOption::Array(v) => Self::Arr(v),
330 SamplerOption::Uniform(low, high) => Self::Uni(Uniform::new(low, high)),
331 SamplerOption::Beta(alpha, beta) => Self::Beta(Beta::new(alpha, beta).unwrap()),
332 }
333 }
334}
335
336impl<V> Sampler<V>
337where
338 V: MyFloat,
339 Open01: Distribution<V>,
340{
341 pub fn choose<R: Rng>(&self, rng: &mut R) -> V {
342 match self {
343 Self::Single(v) => *v,
344 Self::Arr(v) => *v.choose(rng).unwrap(),
345 Self::Uni(u) => u.sample(rng),
346 Self::Beta(b) => b.sample(rng),
347 }
348 }
349}
350
351#[serde_as]
352#[derive(serde::Deserialize)]
353#[serde(bound(deserialize = "V: serde::Deserialize<'de>"))]
354pub struct ProbabilitySamples<V>
355where
356 V: MyFloat,
357 Open01: Distribution<V>,
358{
359 #[serde_as(as = "FromInto<SamplerOption<V>>")]
360 pub viewing: Sampler<V>,
361 #[serde_as(as = "FromInto<SamplerOption<V>>")]
362 pub viewing_friend: Sampler<V>,
363 #[serde_as(as = "FromInto<SamplerOption<V>>")]
364 pub viewing_social: Sampler<V>,
365 #[serde_as(as = "FromInto<SamplerOption<V>>")]
366 pub arrival_friend: Sampler<V>,
367 #[serde_as(as = "FromInto<SamplerOption<V>>")]
368 pub plural_ignore_friend: Sampler<V>,
369 #[serde_as(as = "FromInto<SamplerOption<V>>")]
370 pub plural_ignore_social: Sampler<V>,
371}
372
373#[serde_as]
374#[derive(serde::Deserialize)]
375#[serde(bound(deserialize = "V: serde::Deserialize<'de>"))]
376pub struct SharerTrustSamples<V>
377where
378 V: MyFloat,
379 Open01: Distribution<V>,
380{
381 #[serde_as(as = "FromInto<SamplerOption<V>>")]
382 pub misinfo: Sampler<V>,
383 #[serde_as(as = "FromInto<SamplerOption<V>>")]
384 pub correction: Sampler<V>,
385 #[serde_as(as = "FromInto<SamplerOption<V>>")]
386 pub obserbation: Sampler<V>,
387 #[serde_as(as = "FromInto<SamplerOption<V>>")]
388 pub inhibition: Sampler<V>,
389}
390
391#[derive(Debug, serde::Deserialize)]
392pub enum PopSampleType<V> {
393 Random(V),
394 Top(V),
395 Middle(V),
396 Bottom(V),
397}
398
399#[derive(Debug, serde::Deserialize)]
400pub struct Informing<V> {
401 pub step: u32,
402 pub pop_agents: V,
403}
404
405#[derive(Debug, serde::Deserialize)]
406pub struct InformingParams<V> {
407 pub max_pop_misinfo: V,
409 pub misinfo: Vec<Informing<V>>,
410
411 pub max_pop_correction: V,
413 pub correction: Vec<Informing<V>>,
414
415 pub max_pop_observation: V,
416 pub prob_post_observation: V,
417 pub max_step_pop_observation: V,
418
419 pub max_pop_inhibition: PopSampleType<V>,
421 pub inhibition: Vec<Informing<V>>,
422}
423
424pub struct InformationSamples<V> {
425 pub misinfo: Vec<OpinionD1<Psi, V>>,
427 pub correction: Vec<OpinionD1<Psi, V>>,
428 pub observation: Vec<OpinionD1<O, V>>,
429 pub inhibition: Vec<(
430 OpinionD1<Phi, V>,
431 MArrD1<Psi, SimplexD1<H, V>>,
432 MArrD1<B, SimplexD1<H, V>>,
433 )>,
434}
435
/// Per-agent support levels together with agent indexes ranked by level.
pub struct SupportLevelTable<V> {
    /// Support level of each agent, indexed by agent index.
    pub levels: Vec<V>,
    /// Agent indexes ordered by descending level (as built by `from_vec`).
    pub indexes_by_level: Vec<usize>,
}
442
443impl<V: MyFloat> SupportLevelTable<V> {
444 pub fn level(&self, idx: usize) -> V {
445 self.levels[idx]
446 }
447
448 pub fn from_vec(levels: Vec<V>) -> Self {
450 let indexes_by_level = levels
451 .iter()
452 .enumerate()
453 .sorted_by(|a, b| b.1.partial_cmp(&a.1).unwrap())
454 .map(|(i, _)| i)
455 .collect_vec();
456 Self {
457 levels,
458 indexes_by_level,
459 }
460 }
461
462 pub fn random<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<usize> {
480 self.indexes_by_level
481 .choose_multiple(rng, n)
482 .cloned()
483 .collect()
484 }
485
486 pub fn top<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<usize> {
487 let mut v = self.indexes_by_level.iter().take(n).cloned().collect_vec();
488 v.shuffle(rng);
489 v
490 }
491
492 pub fn middle<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<usize>
493 where
494 V: MyFloat,
495 {
496 macro_rules! level_of {
497 ($e:expr) => {
498 self.levels[self.indexes_by_level[$e]]
499 };
500 }
501 let l = self.indexes_by_level.len();
502 let c = self.indexes_by_level.len() / 2;
503 let median = if l % 2 == 1 {
504 level_of!(c)
505 } else {
506 (level_of!(c) + level_of!(c - 1)) / V::from_u32(2).unwrap()
507 };
508
509 let from = c.checked_sub(n).unwrap_or(0);
510 let to = (c + n).min(l);
511 let mut v = (from..to)
512 .sorted_by(|&i, &j| {
513 let a = (level_of!(i) - median).abs();
514 let b = (level_of!(j) - median).abs();
515 a.partial_cmp(&b).unwrap()
516 })
517 .take(n)
518 .map(|i| self.indexes_by_level[i])
519 .collect_vec();
520 v.shuffle(rng);
521 v
522 }
523
524 pub fn bottom<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<usize> {
525 let mut v = self
526 .indexes_by_level
527 .iter()
528 .rev()
529 .take(n)
530 .cloned()
531 .collect_vec();
532 v.shuffle(rng);
533 v
534 }
535}
536
537#[derive(Debug, serde::Deserialize)]
538pub struct CptRecord<V> {
539 alpha: V,
540 beta: V,
541 gamma: V,
542 delta: V,
543 lambda: V,
544}
545
546pub struct CptSamples<V>(pub Vec<CptRecord<V>>);
547
548impl<V: MyFloat> CptSamples<V> {
549 pub fn reset_to<R: Rng>(&self, cpt: &mut CPT<V>, rng: &mut R) {
550 let &CptRecord {
551 alpha,
552 beta,
553 gamma,
554 delta,
555 lambda,
556 } = self.0.choose(rng).unwrap();
557 cpt.reset(alpha, beta, lambda, gamma, delta);
558 }
559}
560
561#[derive(Debug, serde::Deserialize)]
562pub struct ProspectRecord<V> {
563 x0: V,
564 x1: V,
565 y: V,
566}
567
568pub struct ProspectSamples<V>(pub Vec<ProspectRecord<V>>);
569
570impl<V: MyFloat> ProspectSamples<V> {
571 pub fn reset_to<R: Rng>(&self, prospect: &mut Prospect<V>, rng: &mut R) {
572 let &ProspectRecord { x0, x1, y } = self.0.choose(rng).unwrap();
573 prospect.reset(x0, x1, y);
574 }
575}
576
#[cfg(test)]
mod tests {
    use rand::rngs::mock::StepRng;

    use super::SupportLevelTable;

    /// Deterministic RNG. Selection helpers shuffle their output, so assertions compare
    /// sorted results.
    fn rng() -> StepRng {
        StepRng::new(0, 1)
    }

    fn sorted(mut v: Vec<usize>) -> Vec<usize> {
        v.sort_unstable();
        v
    }

    /// Levels by agent: 0 -> 0.5, 1 -> 0.2, 2 -> 0.3, 3 -> 0.1, 4 -> 0.6.
    fn odd_table() -> SupportLevelTable<f32> {
        SupportLevelTable::<f32>::from_vec(vec![0.5, 0.2, 0.3, 0.1, 0.6])
    }

    #[test]
    fn test_support_levels() {
        let levels = vec![0.5, 0.2, 0.3, 0.1, 0.6];
        let sls = SupportLevelTable::<f32>::from_vec(levels);
        assert_eq!(sls.indexes_by_level, vec![4, 0, 2, 1, 3]);
        assert_eq!(sls.level(4), 0.6);
    }

    #[test]
    fn test_top_and_bottom() {
        let sls = odd_table();
        assert_eq!(sorted(sls.top(2, &mut rng())), vec![0, 4]);
        assert_eq!(sorted(sls.bottom(2, &mut rng())), vec![1, 3]);
        // Asking for more agents than exist yields every agent exactly once.
        assert_eq!(sorted(sls.top(10, &mut rng())), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_middle_odd_len() {
        // Median level is 0.3 (agent 2); the next closest is 0.2 (agent 1).
        assert_eq!(sorted(odd_table().middle(2, &mut rng())), vec![1, 2]);
    }

    #[test]
    fn test_middle_even_len() {
        // Median is (0.3 + 0.2) / 2 = 0.25; agents 2 (0.3) and 3 (0.2) are closest.
        let sls = SupportLevelTable::<f32>::from_vec(vec![0.4, 0.1, 0.3, 0.2]);
        assert_eq!(sorted(sls.middle(2, &mut rng())), vec![2, 3]);
    }

    #[test]
    fn test_random_is_distinct_subset() {
        let picked = sorted(odd_table().random(3, &mut rng()));
        assert_eq!(picked.len(), 3);
        assert!(picked.windows(2).all(|w| w[0] < w[1]));
        assert!(picked.iter().all(|&i| i < 5));
    }
}