base/
stat.rs

1use std::{fs::File, marker::PhantomData, path::PathBuf, sync::Arc};
2
3use polars_arrow::{
4    array::{ArrayRef, BooleanArray, PrimitiveArray},
5    datatypes::{ArrowDataType, ArrowSchema, Field, Metadata},
6    io::ipc::write::{stream_async::WriteOptions, FileWriter},
7    legacy::error::PolarsResult,
8    record_batch::RecordBatch,
9};
10
11use crate::info::InfoLabel;
12
/// Running counters for a single piece of information.
///
/// A snapshot of these counters is appended as one row of an `InfoStat` by `InfoStat::push`.
/// Every recording method below increments exactly one counter by one.
#[derive(Default)]
pub struct InfoData {
    num_posted: u32,
    num_received: u32,
    num_shared: u32,
    num_viewed: u32,
    num_fst_viewed: u32,
}

impl InfoData {
    /// Adds one to `counter`; shared by every public recording method.
    fn bump(counter: &mut u32) {
        *counter += 1;
    }

    /// Increments `num_posted` by one.
    pub fn posted(&mut self) {
        Self::bump(&mut self.num_posted);
    }

    /// Increments `num_received` by one.
    pub fn received(&mut self) {
        Self::bump(&mut self.num_received);
    }

    /// Increments `num_shared` by one.
    pub fn shared(&mut self) {
        Self::bump(&mut self.num_shared);
    }

    /// Increments `num_fst_viewed` (first views) by one.
    pub fn first_viewed(&mut self) {
        Self::bump(&mut self.num_fst_viewed);
    }

    /// Increments `num_viewed` by one.
    pub fn viewed(&mut self) {
        Self::bump(&mut self.num_viewed);
    }
}
43
44#[derive(Debug)]
45pub enum Stat {
46    Info(InfoStat),
47    Agent(AgentStat),
48    Pop(PopStat),
49}
50
/// Column buffers for the `info` table: one `Vec` per column, one entry per row.
#[derive(Debug, Default)]
pub struct InfoStat {
    /// Column `num_iter` (UInt32).
    num_iter: Vec<u32>,
    /// Column `t` (UInt32).
    t: Vec<u32>,
    /// Column `info_label` (UInt8), converted from an `InfoLabel`.
    info_label: Vec<u8>,
    /// Column `num_posted` (UInt32).
    num_posted: Vec<u32>,
    /// Column `num_received` (UInt32).
    num_received: Vec<u32>,
    /// Column `num_shared` (UInt32).
    num_shared: Vec<u32>,
    /// Column `num_viewed` (UInt32).
    num_viewed: Vec<u32>,
    /// Column `num_fst_viewed` (UInt32).
    num_fst_viewed: Vec<u32>,
}
62
63impl StatTrait for InfoStat {
64    fn fields() -> Vec<Field> {
65        vec![
66            Field::new("num_iter", ArrowDataType::UInt32, false),
67            Field::new("t", ArrowDataType::UInt32, false),
68            Field::new("info_label", ArrowDataType::UInt8, false),
69            Field::new("num_posted", ArrowDataType::UInt32, false),
70            Field::new("num_received", ArrowDataType::UInt32, false),
71            Field::new("num_shared", ArrowDataType::UInt32, false),
72            Field::new("num_viewed", ArrowDataType::UInt32, false),
73            Field::new("num_fst_viewed", ArrowDataType::UInt32, false),
74        ]
75    }
76
77    fn to_columns(self) -> Vec<ArrayRef> {
78        vec![
79            Box::new(PrimitiveArray::from_vec(self.num_iter)),
80            Box::new(PrimitiveArray::from_vec(self.t)),
81            Box::new(PrimitiveArray::from_vec(self.info_label)),
82            Box::new(PrimitiveArray::from_vec(self.num_posted)),
83            Box::new(PrimitiveArray::from_vec(self.num_received)),
84            Box::new(PrimitiveArray::from_vec(self.num_shared)),
85            Box::new(PrimitiveArray::from_vec(self.num_viewed)),
86            Box::new(PrimitiveArray::from_vec(self.num_fst_viewed)),
87        ]
88    }
89
90    fn label() -> &'static str {
91        "info"
92    }
93}
94
95impl From<InfoStat> for Stat {
96    fn from(value: InfoStat) -> Self {
97        Self::Info(value)
98    }
99}
100
101impl InfoStat {
102    pub fn push(&mut self, num_iter: u32, t: u32, d: &InfoData, label: &InfoLabel) {
103        self.num_iter.push(num_iter);
104        self.t.push(t);
105        self.info_label.push(label.into());
106        self.num_posted.push(d.num_posted);
107        self.num_received.push(d.num_received);
108        self.num_shared.push(d.num_shared);
109        self.num_viewed.push(d.num_viewed);
110        self.num_fst_viewed.push(d.num_fst_viewed);
111    }
112}
113
/// Column buffers for the `agent` table: one `Vec` per column, one entry per row.
#[derive(Debug, Default)]
pub struct AgentStat {
    /// Column `num_iter` (UInt32).
    num_iter: Vec<u32>,
    /// Column `t` (UInt32).
    t: Vec<u32>,
    /// Column `agent_idx` (UInt32).
    agent_idx: Vec<u32>,
    /// Column `selfish` (Boolean).
    selfish: Vec<bool>,
}
121
122impl StatTrait for AgentStat {
123    fn fields() -> Vec<Field> {
124        vec![
125            Field::new("num_iter", ArrowDataType::UInt32, false),
126            Field::new("t", ArrowDataType::UInt32, false),
127            Field::new("agent_idx", ArrowDataType::UInt32, false),
128            Field::new("selfish", ArrowDataType::Boolean, false),
129        ]
130    }
131
132    fn to_columns(self) -> Vec<ArrayRef> {
133        vec![
134            Box::new(PrimitiveArray::from_vec(self.num_iter)),
135            Box::new(PrimitiveArray::from_vec(self.t)),
136            Box::new(PrimitiveArray::from_vec(self.agent_idx)),
137            Box::new(BooleanArray::from_slice(&self.selfish)),
138        ]
139    }
140
141    fn label() -> &'static str {
142        "agent"
143    }
144}
145
146impl AgentStat {
147    pub fn push_selfish(&mut self, num_iter: u32, t: u32, agent_idx: usize) {
148        self.num_iter.push(num_iter);
149        self.t.push(t);
150        self.agent_idx.push(agent_idx as u32);
151        self.selfish.push(true);
152    }
153}
154
155impl From<AgentStat> for Stat {
156    fn from(value: AgentStat) -> Self {
157        Self::Agent(value)
158    }
159}
160
/// Running population-level counters, appended as one row of a `PopStat` by `PopStat::push`.
#[derive(Default)]
pub struct PopData {
    /// Number of times `selfish` has been called.
    pub num_selfish: u32,
}

impl PopData {
    /// Increments `num_selfish` by one.
    pub fn selfish(&mut self) {
        let counter = &mut self.num_selfish;
        *counter += 1;
    }
}
171
/// Column buffers for the `pop` table: one `Vec` per column, one entry per row.
#[derive(Debug, Default)]
pub struct PopStat {
    /// Column `num_iter` (UInt32).
    num_iter: Vec<u32>,
    /// Column `t` (UInt32).
    t: Vec<u32>,
    /// Column `num_selfish` (UInt32).
    num_selfish: Vec<u32>,
}
178
179impl StatTrait for PopStat {
180    fn fields() -> Vec<Field> {
181        vec![
182            Field::new("num_iter", ArrowDataType::UInt32, false),
183            Field::new("t", ArrowDataType::UInt32, false),
184            Field::new("num_selfish", ArrowDataType::UInt32, false),
185        ]
186    }
187
188    fn to_columns(self) -> Vec<ArrayRef> {
189        vec![
190            Box::new(PrimitiveArray::from_vec(self.num_iter)),
191            Box::new(PrimitiveArray::from_vec(self.t)),
192            Box::new(PrimitiveArray::from_vec(self.num_selfish)),
193        ]
194    }
195
196    fn label() -> &'static str {
197        "pop"
198    }
199}
200
201impl PopStat {
202    pub fn push(&mut self, num_iter: u32, t: u32, d: PopData) {
203        self.num_iter.push(num_iter);
204        self.t.push(t);
205        self.num_selfish.push(d.num_selfish);
206    }
207}
208
209impl From<PopStat> for Stat {
210    fn from(value: PopStat) -> Self {
211        Self::Pop(value)
212    }
213}
214
215pub trait StatTrait {
216    fn label() -> &'static str;
217    fn fields() -> Vec<Field>;
218    fn to_columns(self) -> Vec<ArrayRef>;
219}
220
221struct MyWriter<T> {
222    writer: FileWriter<File>,
223    _marker: PhantomData<T>,
224}
225
226impl<T: StatTrait> MyWriter<T> {
227    fn try_new(
228        output_dir: &PathBuf,
229        identifier: &str,
230        metadata: Metadata,
231        overwriting: bool,
232        compress: bool,
233    ) -> anyhow::Result<Self> {
234        let output_path = output_dir.join(format!("{identifier}_{}.arrow", T::label()));
235        if !overwriting && output_path.exists() {
236            panic!(
237                "{} already exists. If you want to overwrite it, run with the overwriting option.",
238                output_path.display()
239            );
240        }
241
242        let schema = Arc::new(ArrowSchema::from(T::fields()).with_metadata(metadata));
243        let writer = FileWriter::try_new(
244            File::create(output_path)?,
245            schema,
246            None,
247            WriteOptions {
248                compression: if compress {
249                    Some(polars_arrow::io::ipc::write::Compression::ZSTD)
250                } else {
251                    None
252                },
253            },
254        )?;
255        Ok(Self {
256            writer,
257            _marker: PhantomData,
258        })
259    }
260
261    fn write(&mut self, data: T) -> PolarsResult<()> {
262        let batch = RecordBatch::try_new(data.to_columns())?;
263        self.writer.write(&batch, None)
264    }
265
266    fn finish(&mut self) -> PolarsResult<()> {
267        self.writer.finish()
268    }
269}
270
271pub struct FileWriters {
272    info: MyWriter<InfoStat>,
273    agent: MyWriter<AgentStat>,
274    pop: MyWriter<PopStat>,
275}
276
277impl FileWriters {
278    pub fn try_new(
279        identifier: &str,
280        output_dir: &PathBuf,
281        overwriting: bool,
282        compressing: bool,
283        metadata: Metadata,
284    ) -> anyhow::Result<Self> {
285        Ok(Self {
286            info: MyWriter::try_new(
287                output_dir,
288                identifier,
289                metadata.clone(),
290                overwriting,
291                compressing,
292            )?,
293            agent: MyWriter::try_new(
294                output_dir,
295                identifier,
296                metadata.clone(),
297                overwriting,
298                compressing,
299            )?,
300            pop: MyWriter::try_new(output_dir, identifier, metadata, overwriting, compressing)?,
301        })
302    }
303
304    pub fn write(&mut self, stat: Stat) -> PolarsResult<()> {
305        match stat {
306            Stat::Info(stat) => self.info.write(stat)?,
307            Stat::Agent(stat) => self.agent.write(stat)?,
308            Stat::Pop(stat) => self.pop.write(stat)?,
309        }
310        Ok(())
311    }
312
313    pub fn finish(&mut self) -> PolarsResult<()> {
314        self.info.finish()?;
315        self.agent.finish()?;
316        self.pop.finish()?;
317        Ok(())
318    }
319}