1use std::{fs::File, marker::PhantomData, path::PathBuf, sync::Arc};
2
3use polars_arrow::{
4 array::{ArrayRef, BooleanArray, PrimitiveArray},
5 datatypes::{ArrowDataType, ArrowSchema, Field, Metadata},
6 io::ipc::write::{stream_async::WriteOptions, FileWriter},
7 legacy::error::PolarsResult,
8 record_batch::RecordBatch,
9};
10
11use crate::info::InfoLabel;
12
/// Running event counters for a single piece of information.
///
/// A snapshot of these counters becomes one row of [`InfoStat`] via
/// [`InfoStat::push`].
#[derive(Default)]
pub struct InfoData {
    num_posted: u32,
    num_received: u32,
    num_shared: u32,
    num_viewed: u32,
    num_fst_viewed: u32,
}

impl InfoData {
    /// Adds one to the given counter.
    fn bump(counter: &mut u32) {
        *counter += 1;
    }

    /// Increments `num_posted` by one.
    pub fn posted(&mut self) {
        Self::bump(&mut self.num_posted);
    }

    /// Increments `num_received` by one.
    pub fn received(&mut self) {
        Self::bump(&mut self.num_received);
    }

    /// Increments `num_shared` by one.
    pub fn shared(&mut self) {
        Self::bump(&mut self.num_shared);
    }

    /// Increments `num_fst_viewed` by one.
    pub fn first_viewed(&mut self) {
        Self::bump(&mut self.num_fst_viewed);
    }

    /// Increments `num_viewed` by one.
    pub fn viewed(&mut self) {
        Self::bump(&mut self.num_viewed);
    }
}
43
44#[derive(Debug)]
45pub enum Stat {
46 Info(InfoStat),
47 Agent(AgentStat),
48 Pop(PopStat),
49}
50
51#[derive(Default, Debug)]
52pub struct InfoStat {
53 num_iter: Vec<u32>,
54 t: Vec<u32>,
55 info_label: Vec<u8>,
56 num_posted: Vec<u32>,
57 num_received: Vec<u32>,
58 num_shared: Vec<u32>,
59 num_viewed: Vec<u32>,
60 num_fst_viewed: Vec<u32>,
61}
62
63impl StatTrait for InfoStat {
64 fn fields() -> Vec<Field> {
65 vec![
66 Field::new("num_iter", ArrowDataType::UInt32, false),
67 Field::new("t", ArrowDataType::UInt32, false),
68 Field::new("info_label", ArrowDataType::UInt8, false),
69 Field::new("num_posted", ArrowDataType::UInt32, false),
70 Field::new("num_received", ArrowDataType::UInt32, false),
71 Field::new("num_shared", ArrowDataType::UInt32, false),
72 Field::new("num_viewed", ArrowDataType::UInt32, false),
73 Field::new("num_fst_viewed", ArrowDataType::UInt32, false),
74 ]
75 }
76
77 fn to_columns(self) -> Vec<ArrayRef> {
78 vec![
79 Box::new(PrimitiveArray::from_vec(self.num_iter)),
80 Box::new(PrimitiveArray::from_vec(self.t)),
81 Box::new(PrimitiveArray::from_vec(self.info_label)),
82 Box::new(PrimitiveArray::from_vec(self.num_posted)),
83 Box::new(PrimitiveArray::from_vec(self.num_received)),
84 Box::new(PrimitiveArray::from_vec(self.num_shared)),
85 Box::new(PrimitiveArray::from_vec(self.num_viewed)),
86 Box::new(PrimitiveArray::from_vec(self.num_fst_viewed)),
87 ]
88 }
89
90 fn label() -> &'static str {
91 "info"
92 }
93}
94
95impl From<InfoStat> for Stat {
96 fn from(value: InfoStat) -> Self {
97 Self::Info(value)
98 }
99}
100
101impl InfoStat {
102 pub fn push(&mut self, num_iter: u32, t: u32, d: &InfoData, label: &InfoLabel) {
103 self.num_iter.push(num_iter);
104 self.t.push(t);
105 self.info_label.push(label.into());
106 self.num_posted.push(d.num_posted);
107 self.num_received.push(d.num_received);
108 self.num_shared.push(d.num_shared);
109 self.num_viewed.push(d.num_viewed);
110 self.num_fst_viewed.push(d.num_fst_viewed);
111 }
112}
113
114#[derive(Default, Debug)]
115pub struct AgentStat {
116 num_iter: Vec<u32>,
117 t: Vec<u32>,
118 agent_idx: Vec<u32>,
119 selfish: Vec<bool>,
120}
121
122impl StatTrait for AgentStat {
123 fn fields() -> Vec<Field> {
124 vec![
125 Field::new("num_iter", ArrowDataType::UInt32, false),
126 Field::new("t", ArrowDataType::UInt32, false),
127 Field::new("agent_idx", ArrowDataType::UInt32, false),
128 Field::new("selfish", ArrowDataType::Boolean, false),
129 ]
130 }
131
132 fn to_columns(self) -> Vec<ArrayRef> {
133 vec![
134 Box::new(PrimitiveArray::from_vec(self.num_iter)),
135 Box::new(PrimitiveArray::from_vec(self.t)),
136 Box::new(PrimitiveArray::from_vec(self.agent_idx)),
137 Box::new(BooleanArray::from_slice(&self.selfish)),
138 ]
139 }
140
141 fn label() -> &'static str {
142 "agent"
143 }
144}
145
146impl AgentStat {
147 pub fn push_selfish(&mut self, num_iter: u32, t: u32, agent_idx: usize) {
148 self.num_iter.push(num_iter);
149 self.t.push(t);
150 self.agent_idx.push(agent_idx as u32);
151 self.selfish.push(true);
152 }
153}
154
155impl From<AgentStat> for Stat {
156 fn from(value: AgentStat) -> Self {
157 Self::Agent(value)
158 }
159}
160
/// Population-wide counters; a snapshot is appended to [`PopStat`] via
/// [`PopStat::push`].
#[derive(Default)]
pub struct PopData {
    /// Number of times [`PopData::selfish`] has been called.
    pub num_selfish: u32,
}

impl PopData {
    /// Increments `num_selfish` by one.
    pub fn selfish(&mut self) {
        let PopData { num_selfish } = self;
        *num_selfish += 1;
    }
}
171
172#[derive(Default, Debug)]
173pub struct PopStat {
174 num_iter: Vec<u32>,
175 t: Vec<u32>,
176 num_selfish: Vec<u32>,
177}
178
179impl StatTrait for PopStat {
180 fn fields() -> Vec<Field> {
181 vec![
182 Field::new("num_iter", ArrowDataType::UInt32, false),
183 Field::new("t", ArrowDataType::UInt32, false),
184 Field::new("num_selfish", ArrowDataType::UInt32, false),
185 ]
186 }
187
188 fn to_columns(self) -> Vec<ArrayRef> {
189 vec![
190 Box::new(PrimitiveArray::from_vec(self.num_iter)),
191 Box::new(PrimitiveArray::from_vec(self.t)),
192 Box::new(PrimitiveArray::from_vec(self.num_selfish)),
193 ]
194 }
195
196 fn label() -> &'static str {
197 "pop"
198 }
199}
200
201impl PopStat {
202 pub fn push(&mut self, num_iter: u32, t: u32, d: PopData) {
203 self.num_iter.push(num_iter);
204 self.t.push(t);
205 self.num_selfish.push(d.num_selfish);
206 }
207}
208
209impl From<PopStat> for Stat {
210 fn from(value: PopStat) -> Self {
211 Self::Pop(value)
212 }
213}
214
215pub trait StatTrait {
216 fn label() -> &'static str;
217 fn fields() -> Vec<Field>;
218 fn to_columns(self) -> Vec<ArrayRef>;
219}
220
221struct MyWriter<T> {
222 writer: FileWriter<File>,
223 _marker: PhantomData<T>,
224}
225
226impl<T: StatTrait> MyWriter<T> {
227 fn try_new(
228 output_dir: &PathBuf,
229 identifier: &str,
230 metadata: Metadata,
231 overwriting: bool,
232 compress: bool,
233 ) -> anyhow::Result<Self> {
234 let output_path = output_dir.join(format!("{identifier}_{}.arrow", T::label()));
235 if !overwriting && output_path.exists() {
236 panic!(
237 "{} already exists. If you want to overwrite it, run with the overwriting option.",
238 output_path.display()
239 );
240 }
241
242 let schema = Arc::new(ArrowSchema::from(T::fields()).with_metadata(metadata));
243 let writer = FileWriter::try_new(
244 File::create(output_path)?,
245 schema,
246 None,
247 WriteOptions {
248 compression: if compress {
249 Some(polars_arrow::io::ipc::write::Compression::ZSTD)
250 } else {
251 None
252 },
253 },
254 )?;
255 Ok(Self {
256 writer,
257 _marker: PhantomData,
258 })
259 }
260
261 fn write(&mut self, data: T) -> PolarsResult<()> {
262 let batch = RecordBatch::try_new(data.to_columns())?;
263 self.writer.write(&batch, None)
264 }
265
266 fn finish(&mut self) -> PolarsResult<()> {
267 self.writer.finish()
268 }
269}
270
271pub struct FileWriters {
272 info: MyWriter<InfoStat>,
273 agent: MyWriter<AgentStat>,
274 pop: MyWriter<PopStat>,
275}
276
277impl FileWriters {
278 pub fn try_new(
279 identifier: &str,
280 output_dir: &PathBuf,
281 overwriting: bool,
282 compressing: bool,
283 metadata: Metadata,
284 ) -> anyhow::Result<Self> {
285 Ok(Self {
286 info: MyWriter::try_new(
287 output_dir,
288 identifier,
289 metadata.clone(),
290 overwriting,
291 compressing,
292 )?,
293 agent: MyWriter::try_new(
294 output_dir,
295 identifier,
296 metadata.clone(),
297 overwriting,
298 compressing,
299 )?,
300 pop: MyWriter::try_new(output_dir, identifier, metadata, overwriting, compressing)?,
301 })
302 }
303
304 pub fn write(&mut self, stat: Stat) -> PolarsResult<()> {
305 match stat {
306 Stat::Info(stat) => self.info.write(stat)?,
307 Stat::Agent(stat) => self.agent.write(stat)?,
308 Stat::Pop(stat) => self.pop.write(stat)?,
309 }
310 Ok(())
311 }
312
313 pub fn finish(&mut self) -> PolarsResult<()> {
314 self.info.finish()?;
315 self.agent.finish()?;
316 self.pop.finish()?;
317 Ok(())
318 }
319}