base/
decision.rs

1use std::{fmt, iter::Sum, ops::AddAssign};
2
3use approx::{ulps_eq, UlpsEq};
4use num_traits::Float;
5use subjective_logic::{
6    iter::Container,
7    marr_d1, marr_d2,
8    multi_array::labeled::{MArrD1, MArrD2},
9};
10use tracing::debug;
11
12use crate::opinion::{Theta, Thetad, A};
13
/// Parameters of a cumulative prospect theory (CPT) evaluator.
///
/// The fields are private; set them through [`CPT::reset`] or start from `Default`.
#[derive(Debug, Default, Clone)]
pub struct CPT<V> {
    /// Exponent of the value function for gains: `v(x) = x^alpha`.
    alpha: V,
    /// Exponent of the value function for losses: `v(x) = -lambda * |x|^beta`.
    beta: V,
    /// Multiplier applied to the value of losses (loss aversion).
    lambda: V,
    /// Exponent of the probability weighting function applied to gains.
    gamma: V,
    /// Exponent of the probability weighting function applied to losses.
    delta: V,
}
22
23impl<V> CPT<V> {
24    pub fn reset(&mut self, alpha: V, beta: V, lambda: V, gamma: V, delta: V) {
25        self.alpha = alpha;
26        self.beta = beta;
27        self.lambda = lambda;
28        self.gamma = gamma;
29        self.delta = delta;
30    }
31
32    // pub fn reset_with(&mut self, params: &CptParams<V>, rng: &mut impl Rng)
33    // where
34    //     V: SampleUniform,
35    //     Open01: Distribution<V>,
36    //     Standard: Distribution<V>,
37    // {
38    //     self.reset(
39    //         params.alpha.sample(rng),
40    //         params.beta.sample(rng),
41    //         params.lambda.sample(rng),
42    //         params.gamma.sample(rng),
43    //         params.delta.sample(rng),
44    //     );
45    // }
46
47    fn w(p: V, e: V) -> V
48    where
49        V: Float + UlpsEq + AddAssign + Sum,
50    {
51        if ulps_eq!(p, V::one()) {
52            V::one()
53        } else {
54            let temp = p.powf(e);
55            temp / (temp + (V::one() - p).powf(e)).powf(V::one() / e)
56        }
57    }
58
59    /// Computes a probability weighting function for gains
60    fn positive_weight(&self, p: V) -> V
61    where
62        V: Float + UlpsEq + AddAssign + Sum,
63    {
64        Self::w(p, self.gamma)
65    }
66
67    /// Computes a probability weighting function for losses
68    fn negative_weight(&self, p: V) -> V
69    where
70        V: Float + UlpsEq + AddAssign + Sum,
71    {
72        Self::w(p, self.delta)
73    }
74
75    /// Computes a value funciton for positive value
76    fn positive_value(&self, x: V) -> V
77    where
78        V: Float + UlpsEq + AddAssign + Sum,
79    {
80        x.powf(self.alpha)
81    }
82
83    /// Computes a value funciton for negative value
84    fn negative_value(&self, x: V) -> V
85    where
86        V: Float + UlpsEq + AddAssign + Sum,
87    {
88        -self.lambda * x.abs().powf(self.beta)
89    }
90
91    /// Computes Choquet integral of a positive function
92    fn positive_valuate<P: Container<Idx, Output = V>, Idx: Copy>(
93        &self,
94        positive_level_sets: &[(V, Vec<Idx>)],
95        prob: &P,
96    ) -> V
97    where
98        V: Float + UlpsEq + AddAssign + Sum,
99    {
100        positive_level_sets
101            .iter()
102            .scan((V::zero(), V::zero()), |(w, acc), (o, ids)| {
103                let w0 = *w;
104                *acc += ids.iter().map(|i| prob[*i]).sum::<V>();
105                *w = self.positive_weight(*acc);
106                Some(self.positive_value(*o) * (*w - w0))
107            })
108            .sum::<V>()
109    }
110
111    /// Computes Choquet integral of a negative function
112    fn negative_valuate<P: Container<Idx, Output = V>, Idx: Copy>(
113        &self,
114        negative_level_sets: &[(V, Vec<Idx>)],
115        prob: &P,
116    ) -> V
117    where
118        V: Float + UlpsEq + AddAssign + Sum,
119    {
120        negative_level_sets
121            .iter()
122            .scan((V::zero(), V::zero()), |(w, acc), (o, ids)| {
123                let w0 = *w;
124                *acc += ids.iter().map(|i| prob[*i]).sum::<V>();
125                *w = self.negative_weight(*acc);
126                Some(self.negative_value(*o) * (*w - w0))
127            })
128            .sum::<V>()
129    }
130
131    /// Computes CPT
132    pub fn valuate<P: Container<Idx, Output = V>, Idx: Copy>(
133        &self,
134        level_sets: &LevelSet<Idx, V>,
135        prob: &P,
136    ) -> V
137    where
138        V: Float + UlpsEq + AddAssign + Sum,
139    {
140        self.positive_valuate(&level_sets.positive, prob)
141            + self.negative_valuate(&level_sets.negative, prob)
142    }
143}
144
/// Outcome indices grouped by outcome value, split into gains and losses.
///
/// As built by `LevelSet::new`, every entry is `(outcome, indices having that outcome)`.
#[derive(Debug, Clone)]
pub struct LevelSet<Idx, V> {
    /// Strictly positive outcomes, ordered from the largest to the smallest.
    positive: Vec<(V, Vec<Idx>)>,
    /// Strictly negative outcomes, ordered from the most negative upwards.
    negative: Vec<(V, Vec<Idx>)>,
}
150
151impl<Idx, V> Default for LevelSet<Idx, V> {
152    fn default() -> Self {
153        Self {
154            positive: Vec::new(),
155            negative: Vec::new(),
156        }
157    }
158}
159
160impl<Idx, V: Float> LevelSet<Idx, V> {
161    pub fn new<T>(outcome: &T) -> Self
162    where
163        T: Container<Idx, Output = V>,
164        for<'a> &'a T: IntoIterator<Item = &'a V>,
165        Idx: Copy,
166    {
167        let mut pos = T::indexes()
168            .zip(outcome)
169            .filter(|(_, &o)| o > V::zero())
170            .collect::<Vec<_>>();
171        pos.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
172        let mut positive = Vec::<(V, Vec<Idx>)>::new();
173        for (i, &o) in pos {
174            match positive.last_mut() {
175                Some((o2, v)) if o == *o2 => {
176                    v.push(i);
177                }
178                _ => {
179                    positive.push((o, vec![i]));
180                }
181            }
182        }
183        let mut neg = T::indexes()
184            .zip(outcome)
185            .filter(|(_, &o)| o < V::zero())
186            .collect::<Vec<_>>();
187        neg.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
188        let mut negative = Vec::<(V, Vec<Idx>)>::new();
189        for (i, &o) in neg {
190            match negative.last_mut() {
191                Some((o2, v)) if o == *o2 => {
192                    v.push(i);
193                }
194                _ => {
195                    negative.push((o, vec![i]));
196                }
197            }
198        }
199        Self { positive, negative }
200    }
201}
202
203#[derive(Debug, Default)]
204pub struct Prospect<V> {
205    pub selfish: [LevelSet<Theta, V>; 2],
206    pub sharing: [LevelSet<(A, Thetad), V>; 2],
207}
208
209impl<V> Prospect<V> {
210    pub fn reset(&mut self, x0: V, x1: V, y: V)
211    where
212        V: Float + fmt::Debug,
213    {
214        let selfish_outcome_maps: [MArrD1<Theta, V>; 2] =
215            [marr_d1![V::zero(), x1], marr_d1![x0, x0]];
216        let sharing_outcome_maps: [MArrD2<A, Thetad, V>; 2] = [
217            marr_d2![[V::zero(), x1], [x0, x0]],
218            marr_d2![[y, x1 + y], [x0 + y, x0 + y]],
219        ];
220        let selfish = [
221            LevelSet::new(&selfish_outcome_maps[0]),
222            LevelSet::new(&selfish_outcome_maps[1]),
223        ];
224        let sharing = [
225            LevelSet::new(&sharing_outcome_maps[0]),
226            LevelSet::new(&sharing_outcome_maps[1]),
227        ];
228
229        debug!(target: "X outcomes", x = ?selfish_outcome_maps);
230        debug!(target: "Y outcomes", x = ?sharing_outcome_maps);
231        self.selfish = selfish;
232        self.sharing = sharing;
233    }
234
235    // pub fn reset_with(&mut self, loss_params: &LossParams<V>, rng: &mut impl Rng)
236    // where
237    //     V: SampleUniform,
238    //     Open01: Distribution<V>,
239    //     Standard: Distribution<V>,
240    // {
241    //     let x0 = loss_params.x0.sample(rng);
242    //     let x1 = x0 * loss_params.x1_of_x0.sample(rng);
243    //     let y = x0 * loss_params.y_of_x0.sample(rng);
244    //     self.reset(x0, x1, y);
245    // }
246}
247
#[cfg(test)]
mod tests {
    use crate::decision::{LevelSet, CPT};
    use approx::ulps_eq;
    use subjective_logic::{domain::Domain, impl_domain, marr_d2};

    /// Hand-written value function: `o^0.88` for gains, `-2.25 * |o|^0.88` for losses.
    fn v(o: f32) -> f32 {
        if !o.is_sign_negative() {
            return o.powf(0.88);
        }
        -2.25 * (-o).powf(0.88)
    }

    /// Hand-written weighting function `p^e / (p^e + (1 - p)^e)^(1 / e)`.
    fn weight(p: f32, e: f32) -> f32 {
        let q = p.powf(e);
        q / (q + (1.0 - p).powf(e)).powf(1.0 / e)
    }

    /// Weighting for gains.
    fn wp(p: f32) -> f32 {
        weight(p, 0.61)
    }

    /// Weighting for losses.
    fn wm(p: f32) -> f32 {
        weight(p, 0.69)
    }

    /// CPT parameters matching the hand-written `v`, `wp` and `wm`.
    fn reference_cpt() -> CPT<f32> {
        CPT {
            alpha: 0.88,
            beta: 0.88,
            lambda: 2.25,
            gamma: 0.61,
            delta: 0.69,
        }
    }

    /// Expected value of six equally likely outcomes {6, 2, 4, -3, -1, -5},
    /// computed from the hand-written reference functions.
    fn closed_form() -> f32 {
        v(2.0) * (wp(1.0 / 2.0) - wp(1.0 / 3.0))
            + v(4.0) * (wp(1.0 / 3.0) - wp(1.0 / 6.0))
            + v(6.0) * (wp(1.0 / 6.0) - wp(0.0))
            + v(-5.0) * (wm(1.0 / 6.0) - wm(0.0))
            + v(-3.0) * (wm(1.0 / 3.0) - wm(1.0 / 6.0))
            + v(-1.0) * (wm(1.0 / 2.0) - wm(1.0 / 3.0))
    }

    /// The same expected value assembled from `cpt`'s own value and weight functions.
    fn from_parts(cpt: &CPT<f32>) -> f32 {
        cpt.positive_value(2.0)
            * (cpt.positive_weight(1.0 / 2.0) - cpt.positive_weight(1.0 / 3.0))
            + cpt.positive_value(4.0)
                * (cpt.positive_weight(1.0 / 3.0) - cpt.positive_weight(1.0 / 6.0))
            + cpt.positive_value(6.0) * (cpt.positive_weight(1.0 / 6.0) - cpt.positive_weight(0.0))
            + cpt.negative_value(-5.0)
                * (cpt.negative_weight(1.0 / 6.0) - cpt.negative_weight(0.0))
            + cpt.negative_value(-3.0)
                * (cpt.negative_weight(1.0 / 3.0) - cpt.negative_weight(1.0 / 6.0))
            + cpt.negative_value(-1.0)
                * (cpt.negative_weight(1.0 / 2.0) - cpt.negative_weight(1.0 / 3.0))
    }

    #[test]
    fn test_cpt() {
        let outcome = [6.0, 2.0, 4.0, -3.0, -1.0, -5.0];
        let prob = [1.0 / 6.0; 6];
        let cpt = reference_cpt();

        let ls = LevelSet::<_, f32>::new(&outcome);
        let a = cpt.valuate(&ls, &prob);

        assert!(ulps_eq!(a, closed_form()));
        assert!(ulps_eq!(a, from_parts(&cpt)));
    }

    struct X;
    impl_domain!(X = 2);

    struct Y;
    impl_domain!(Y = 3);

    #[test]
    fn test_cpt_prod() {
        let outcome = marr_d2!(X, Y; [[6.0, 2.0, 4.0], [-3.0, -1.0, -5.0]]);
        let prob = marr_d2!(X, Y; [
            [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0],
            [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]
        ]);
        let cpt = reference_cpt();

        let ls = LevelSet::<_, f32>::new(&outcome);
        let a = cpt.valuate(&ls, &prob);

        assert!(ulps_eq!(a, closed_form()));
        assert!(ulps_eq!(a, from_parts(&cpt)));
    }
}
350}