1use std::io::{stdout, Write};
2use std::sync::Arc;
3
4use futures::future::try_join_all;
5use rand::rngs::SmallRng;
6use rand_distr::uniform::SampleUniform;
7use tokio::sync::{mpsc, Mutex};
8
9use rand::{Rng, SeedableRng};
10use rand_distr::{Distribution, Exp1, Open01, Standard, StandardNormal};
11
12use crate::executor::AgentExtTrait;
13use crate::stat::FileWriters;
14use crate::{
15 executor::{Executor, InstanceExt, Memory},
16 opinion::MyFloat,
17 stat::Stat,
18};
19
20#[derive(Debug, serde::Deserialize)]
21pub struct RuntimeParams {
22 pub seed_state: u64,
23 pub iteration_count: u32,
24}
25
26pub async fn run<V, E, Ax, Ix>(
27 mut writers: FileWriters,
28 runtime: &RuntimeParams,
29 exec: E,
30 max_permits: Option<usize>,
31) -> anyhow::Result<()>
32where
33 V: MyFloat + SampleUniform + 'static,
34 V::Sampler: Sync + Send,
35 Open01: Distribution<V>,
36 Standard: Distribution<V>,
37 StandardNormal: Distribution<V>,
38 Exp1: Distribution<V>,
39 E: Executor<V, Ax, Ix> + Send + Sync + 'static,
40 Ax: AgentExtTrait<V, Exec = E, Ix = Ix> + Default + Send + 'static,
41 Ix: InstanceExt<V, SmallRng, E> + Send + 'static,
42{
43 println!("initialising...");
44
45 let permits = max_permits.unwrap_or(num_cpus::get());
46 let (tx, mut rx) = mpsc::channel::<Stat>(permits);
47 let handle = tokio::spawn(async move {
48 while let Some(stat) = rx.recv().await {
49 writers.write(stat).unwrap();
50 }
51 writers.finish().unwrap();
52 });
53
54 let mut rng = SmallRng::seed_from_u64(runtime.seed_state);
55 let rngs: Vec<SmallRng> = (0..(runtime.iteration_count))
56 .map(|_| SmallRng::from_rng(&mut rng))
57 .collect::<Result<Vec<_>, _>>()?;
58
59 let exec = Arc::new(exec);
60 let mut manager = Manager::new(permits, |id| Memory::new(exec.as_ref(), id));
61
62 let mut jhs = Vec::new();
63 print!("started.");
64 for (num_iter, rng) in rngs.into_iter().enumerate() {
65 let permit = manager.rent().await;
66 let tx = tx.clone();
67 jhs.push(tokio::spawn(permit.run(exec.clone(), num_iter, rng, tx)));
68 }
69
70 try_join_all(jhs).await?;
71 drop(tx);
72 handle.await.unwrap();
73 println!("\ndone.");
74 Ok(())
75}
76
77pub struct Manager<E> {
78 pub rx: mpsc::Receiver<usize>,
79 pub tx: mpsc::Sender<usize>,
80 pub resources: Vec<Arc<Mutex<E>>>,
81}
82
83impl<E> Manager<E> {
84 pub fn new<F: Fn(usize) -> E>(permits: usize, f: F) -> Self {
85 let mut resources = Vec::new();
86 let (tx, rx) = mpsc::channel(permits);
87 for i in 0..permits {
88 let r = Arc::new(Mutex::new(f(i)));
89 resources.push(r);
90 tx.try_send(i).unwrap();
91 }
92 Self { rx, tx, resources }
93 }
94
95 async fn rent(&mut self) -> EnvPermit<E> {
96 let idx = self.rx.recv().await.unwrap();
97 EnvPermit {
98 idx,
99 env: self.resources[idx].clone(),
100 tx: self.tx.clone(),
101 }
102 }
103}
104
105pub struct EnvPermit<E> {
106 idx: usize,
107 tx: mpsc::Sender<usize>,
108 env: Arc<Mutex<E>>,
109}
110
111impl<V, Ax> EnvPermit<Memory<V, Ax>>
112where
113 V: MyFloat + 'static,
114{
115 async fn run<Ex, Ix, R: Rng + Send + 'static>(
116 self,
117 exec: Arc<Ex>,
118 num_iter: usize,
119 rng: R,
120 tx: mpsc::Sender<Stat>,
121 ) -> anyhow::Result<()>
122 where
123 Open01: Distribution<V>,
124 Standard: Distribution<V>,
125 StandardNormal: Distribution<V>,
126 Exp1: Distribution<V>,
127 Ex: Executor<V, Ax, Ix> + Send + Sync + 'static,
128 Ax: AgentExtTrait<V, Exec = Ex, Ix = Ix> + Default + Send + 'static,
129 Ix: InstanceExt<V, R, Ex> + Send,
130 {
131 let env = self.env.clone();
132 let handle = tokio::spawn(async move {
133 if num_iter % 100 == 0 {
134 println!("\n{num_iter}");
135 }
136 if num_iter % 10 == 0 {
137 print!("|");
138 stdout().flush().unwrap();
139 }
140 print!(".");
141 let mut memory = env.lock().await;
142 exec.execute::<R>(&mut memory, num_iter as u32, rng)
143 });
144 let ss = handle.await?;
145 for s in ss {
146 tx.send(s).await.unwrap();
147 }
148 Ok(())
149 }
150}
151
152impl<E> Drop for EnvPermit<E> {
153 fn drop(&mut self) {
154 self.tx.try_send(self.idx).unwrap();
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::fs::read_to_string;
    use std::path::Path;

    use super::RuntimeParams;
    use serde_json::json;

    #[test]
    fn test_json_config() -> anyhow::Result<()> {
        // NOTE(review): `num_parallel` is not a field of `RuntimeParams`. serde ignores
        // unknown keys by default (the struct has no `deny_unknown_fields`), so it is
        // accepted. Remove it if it is a leftover from an older config layout.
        let runtime = json!({
            "seed_state": 0,
            "num_parallel": 1,
            "iteration_count": 1,
        });
        let runtime = serde_json::from_value::<RuntimeParams>(runtime)?;
        assert_eq!(runtime.seed_state, 0);
        assert_eq!(runtime.iteration_count, 1);
        Ok(())
    }

    #[test]
    fn test_toml_config() -> anyhow::Result<()> {
        // Resolve the fixture against the crate root so the test does not depend on the
        // process working directory.
        let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("test_runtime.toml");
        let runtime = toml::from_str::<RuntimeParams>(&read_to_string(path)?)?;
        assert_eq!(runtime.seed_state, 0);
        assert_eq!(runtime.iteration_count, 1);

        Ok(())
    }
}