1use std::{fmt, iter::Sum, ops::AddAssign};
2
3use approx::{ulps_eq, UlpsEq};
4use num_traits::Float;
5use subjective_logic::{
6 iter::Container,
7 marr_d1, marr_d2,
8 multi_array::labeled::{MArrD1, MArrD2},
9};
10use tracing::debug;
11
12use crate::opinion::{Theta, Thetad, A};
13
/// Parameter set for the value and probability-weighting functions evaluated
/// by the methods of this type (the name presumably abbreviates "cumulative
/// prospect theory" — confirm with the module owner).
#[derive(Debug, Clone, Default)]
pub struct CPT<V> {
    /// Exponent the value function applies to positive outcomes.
    alpha: V,
    /// Exponent the value function applies to the magnitude of negative outcomes.
    beta: V,
    /// Multiplier applied (with a negative sign) to the value of negative outcomes.
    lambda: V,
    /// Exponent of the probability-weighting function used for positive outcomes.
    gamma: V,
    /// Exponent of the probability-weighting function used for negative outcomes.
    delta: V,
}
22
23impl<V> CPT<V> {
24 pub fn reset(&mut self, alpha: V, beta: V, lambda: V, gamma: V, delta: V) {
25 self.alpha = alpha;
26 self.beta = beta;
27 self.lambda = lambda;
28 self.gamma = gamma;
29 self.delta = delta;
30 }
31
32 fn w(p: V, e: V) -> V
48 where
49 V: Float + UlpsEq + AddAssign + Sum,
50 {
51 if ulps_eq!(p, V::one()) {
52 V::one()
53 } else {
54 let temp = p.powf(e);
55 temp / (temp + (V::one() - p).powf(e)).powf(V::one() / e)
56 }
57 }
58
59 fn positive_weight(&self, p: V) -> V
61 where
62 V: Float + UlpsEq + AddAssign + Sum,
63 {
64 Self::w(p, self.gamma)
65 }
66
67 fn negative_weight(&self, p: V) -> V
69 where
70 V: Float + UlpsEq + AddAssign + Sum,
71 {
72 Self::w(p, self.delta)
73 }
74
75 fn positive_value(&self, x: V) -> V
77 where
78 V: Float + UlpsEq + AddAssign + Sum,
79 {
80 x.powf(self.alpha)
81 }
82
83 fn negative_value(&self, x: V) -> V
85 where
86 V: Float + UlpsEq + AddAssign + Sum,
87 {
88 -self.lambda * x.abs().powf(self.beta)
89 }
90
91 fn positive_valuate<P: Container<Idx, Output = V>, Idx: Copy>(
93 &self,
94 positive_level_sets: &[(V, Vec<Idx>)],
95 prob: &P,
96 ) -> V
97 where
98 V: Float + UlpsEq + AddAssign + Sum,
99 {
100 positive_level_sets
101 .iter()
102 .scan((V::zero(), V::zero()), |(w, acc), (o, ids)| {
103 let w0 = *w;
104 *acc += ids.iter().map(|i| prob[*i]).sum::<V>();
105 *w = self.positive_weight(*acc);
106 Some(self.positive_value(*o) * (*w - w0))
107 })
108 .sum::<V>()
109 }
110
111 fn negative_valuate<P: Container<Idx, Output = V>, Idx: Copy>(
113 &self,
114 negative_level_sets: &[(V, Vec<Idx>)],
115 prob: &P,
116 ) -> V
117 where
118 V: Float + UlpsEq + AddAssign + Sum,
119 {
120 negative_level_sets
121 .iter()
122 .scan((V::zero(), V::zero()), |(w, acc), (o, ids)| {
123 let w0 = *w;
124 *acc += ids.iter().map(|i| prob[*i]).sum::<V>();
125 *w = self.negative_weight(*acc);
126 Some(self.negative_value(*o) * (*w - w0))
127 })
128 .sum::<V>()
129 }
130
131 pub fn valuate<P: Container<Idx, Output = V>, Idx: Copy>(
133 &self,
134 level_sets: &LevelSet<Idx, V>,
135 prob: &P,
136 ) -> V
137 where
138 V: Float + UlpsEq + AddAssign + Sum,
139 {
140 self.positive_valuate(&level_sets.positive, prob)
141 + self.negative_valuate(&level_sets.negative, prob)
142 }
143}
144
/// Outcomes grouped by value and split by sign.
///
/// Each entry pairs one outcome value with every index that has exactly that
/// outcome.
#[derive(Debug, Clone)]
pub struct LevelSet<Idx, V> {
    /// Groups for outcomes greater than zero; `LevelSet::new` stores them
    /// from the largest outcome to the smallest.
    positive: Vec<(V, Vec<Idx>)>,
    /// Groups for outcomes less than zero; `LevelSet::new` stores them
    /// from the most negative outcome upwards.
    negative: Vec<(V, Vec<Idx>)>,
}
150
151impl<Idx, V> Default for LevelSet<Idx, V> {
152 fn default() -> Self {
153 Self {
154 positive: Vec::new(),
155 negative: Vec::new(),
156 }
157 }
158}
159
160impl<Idx, V: Float> LevelSet<Idx, V> {
161 pub fn new<T>(outcome: &T) -> Self
162 where
163 T: Container<Idx, Output = V>,
164 for<'a> &'a T: IntoIterator<Item = &'a V>,
165 Idx: Copy,
166 {
167 let mut pos = T::indexes()
168 .zip(outcome)
169 .filter(|(_, &o)| o > V::zero())
170 .collect::<Vec<_>>();
171 pos.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
172 let mut positive = Vec::<(V, Vec<Idx>)>::new();
173 for (i, &o) in pos {
174 match positive.last_mut() {
175 Some((o2, v)) if o == *o2 => {
176 v.push(i);
177 }
178 _ => {
179 positive.push((o, vec![i]));
180 }
181 }
182 }
183 let mut neg = T::indexes()
184 .zip(outcome)
185 .filter(|(_, &o)| o < V::zero())
186 .collect::<Vec<_>>();
187 neg.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
188 let mut negative = Vec::<(V, Vec<Idx>)>::new();
189 for (i, &o) in neg {
190 match negative.last_mut() {
191 Some((o2, v)) if o == *o2 => {
192 v.push(i);
193 }
194 _ => {
195 negative.push((o, vec![i]));
196 }
197 }
198 }
199 Self { positive, negative }
200 }
201}
202
203#[derive(Debug, Default)]
204pub struct Prospect<V> {
205 pub selfish: [LevelSet<Theta, V>; 2],
206 pub sharing: [LevelSet<(A, Thetad), V>; 2],
207}
208
209impl<V> Prospect<V> {
210 pub fn reset(&mut self, x0: V, x1: V, y: V)
211 where
212 V: Float + fmt::Debug,
213 {
214 let selfish_outcome_maps: [MArrD1<Theta, V>; 2] =
215 [marr_d1![V::zero(), x1], marr_d1![x0, x0]];
216 let sharing_outcome_maps: [MArrD2<A, Thetad, V>; 2] = [
217 marr_d2![[V::zero(), x1], [x0, x0]],
218 marr_d2![[y, x1 + y], [x0 + y, x0 + y]],
219 ];
220 let selfish = [
221 LevelSet::new(&selfish_outcome_maps[0]),
222 LevelSet::new(&selfish_outcome_maps[1]),
223 ];
224 let sharing = [
225 LevelSet::new(&sharing_outcome_maps[0]),
226 LevelSet::new(&sharing_outcome_maps[1]),
227 ];
228
229 debug!(target: "X outcomes", x = ?selfish_outcome_maps);
230 debug!(target: "Y outcomes", x = ?sharing_outcome_maps);
231 self.selfish = selfish;
232 self.sharing = sharing;
233 }
234
235 }
247
#[cfg(test)]
mod tests {
    use crate::decision::{LevelSet, CPT};
    use approx::ulps_eq;
    use subjective_logic::{domain::Domain, impl_domain, marr_d2};

    /// Reference value function written out with literal parameters.
    fn v(o: f32) -> f32 {
        match o.is_sign_negative() {
            true => -2.25 * (-o).powf(0.88),
            false => o.powf(0.88),
        }
    }

    /// Reference weighting function for exponent `e`.
    fn weight(p: f32, e: f32) -> f32 {
        let q = p.powf(e);
        q / (q + (1.0 - p).powf(e)).powf(1.0 / e)
    }

    /// Reference weighting for positive outcomes.
    fn wp(p: f32) -> f32 {
        weight(p, 0.61)
    }

    /// Reference weighting for negative outcomes.
    fn wm(p: f32) -> f32 {
        weight(p, 0.69)
    }

    /// The parameter set shared by both tests.
    fn sample_cpt() -> CPT<f32> {
        CPT {
            alpha: 0.88,
            beta: 0.88,
            lambda: 2.25,
            gamma: 0.61,
            delta: 0.69,
        }
    }

    /// Expected valuation for outcomes {6, 2, 4, -3, -1, -5}, each with
    /// probability 1/6, computed with the reference functions above.
    fn expected_from_reference() -> f32 {
        v(2.0) * (wp(1.0 / 2.0) - wp(1.0 / 3.0))
            + v(4.0) * (wp(1.0 / 3.0) - wp(1.0 / 6.0))
            + v(6.0) * (wp(1.0 / 6.0) - wp(0.0))
            + v(-5.0) * (wm(1.0 / 6.0) - wm(0.0))
            + v(-3.0) * (wm(1.0 / 3.0) - wm(1.0 / 6.0))
            + v(-1.0) * (wm(1.0 / 2.0) - wm(1.0 / 3.0))
    }

    /// The same expected valuation, assembled from the building blocks of `cpt`.
    fn expected_from_parts(cpt: &CPT<f32>) -> f32 {
        cpt.positive_value(2.0)
            * (cpt.positive_weight(1.0 / 2.0) - cpt.positive_weight(1.0 / 3.0))
            + cpt.positive_value(4.0)
                * (cpt.positive_weight(1.0 / 3.0) - cpt.positive_weight(1.0 / 6.0))
            + cpt.positive_value(6.0) * (cpt.positive_weight(1.0 / 6.0) - cpt.positive_weight(0.0))
            + cpt.negative_value(-5.0)
                * (cpt.negative_weight(1.0 / 6.0) - cpt.negative_weight(0.0))
            + cpt.negative_value(-3.0)
                * (cpt.negative_weight(1.0 / 3.0) - cpt.negative_weight(1.0 / 6.0))
            + cpt.negative_value(-1.0)
                * (cpt.negative_weight(1.0 / 2.0) - cpt.negative_weight(1.0 / 3.0))
    }

    #[test]
    fn test_cpt() {
        let outcome = [6.0, 2.0, 4.0, -3.0, -1.0, -5.0];
        let prob = [1.0 / 6.0; 6];
        let cpt = sample_cpt();
        let ls = LevelSet::<_, f32>::new(&outcome);

        let a = cpt.valuate(&ls, &prob);
        let b = expected_from_reference();
        let c = expected_from_parts(&cpt);

        assert!(ulps_eq!(a, b));
        assert!(ulps_eq!(a, c));
    }

    struct X;
    impl_domain!(X = 2);

    struct Y;
    impl_domain!(Y = 3);

    #[test]
    fn test_cpt_prod() {
        let outcome = marr_d2!(X, Y; [[6.0, 2.0, 4.0], [-3.0, -1.0, -5.0]]);
        let prob = marr_d2!(X, Y; [
            [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0],
            [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]
        ]);
        let cpt = sample_cpt();
        let ls = LevelSet::<_, f32>::new(&outcome);

        let a = cpt.valuate(&ls, &prob);
        let b = expected_from_reference();
        let c = expected_from_parts(&cpt);

        assert!(ulps_eq!(a, b));
        assert!(ulps_eq!(a, c));
    }
}