base/
runner.rs

1use std::io::{stdout, Write};
2use std::sync::Arc;
3
4use futures::future::try_join_all;
5use rand::rngs::SmallRng;
6use rand_distr::uniform::SampleUniform;
7use tokio::sync::{mpsc, Mutex};
8
9use rand::{Rng, SeedableRng};
10use rand_distr::{Distribution, Exp1, Open01, Standard, StandardNormal};
11
12use crate::executor::AgentExtTrait;
13use crate::stat::FileWriters;
14use crate::{
15    executor::{Executor, InstanceExt, Memory},
16    opinion::MyFloat,
17    stat::Stat,
18};
19
20#[derive(Debug, serde::Deserialize)]
21pub struct RuntimeParams {
22    pub seed_state: u64,
23    pub iteration_count: u32,
24}
25
26pub async fn run<V, E, Ax, Ix>(
27    mut writers: FileWriters,
28    runtime: &RuntimeParams,
29    exec: E,
30    max_permits: Option<usize>,
31) -> anyhow::Result<()>
32where
33    V: MyFloat + SampleUniform + 'static,
34    V::Sampler: Sync + Send,
35    Open01: Distribution<V>,
36    Standard: Distribution<V>,
37    StandardNormal: Distribution<V>,
38    Exp1: Distribution<V>,
39    E: Executor<V, Ax, Ix> + Send + Sync + 'static,
40    Ax: AgentExtTrait<V, Exec = E, Ix = Ix> + Default + Send + 'static,
41    Ix: InstanceExt<V, SmallRng, E> + Send + 'static,
42{
43    println!("initialising...");
44
45    let permits = max_permits.unwrap_or(num_cpus::get());
46    let (tx, mut rx) = mpsc::channel::<Stat>(permits);
47    let handle = tokio::spawn(async move {
48        while let Some(stat) = rx.recv().await {
49            writers.write(stat).unwrap();
50        }
51        writers.finish().unwrap();
52    });
53
54    let mut rng = SmallRng::seed_from_u64(runtime.seed_state);
55    let rngs: Vec<SmallRng> = (0..(runtime.iteration_count))
56        .map(|_| SmallRng::from_rng(&mut rng))
57        .collect::<Result<Vec<_>, _>>()?;
58
59    let exec = Arc::new(exec);
60    let mut manager = Manager::new(permits, |id| Memory::new(exec.as_ref(), id));
61
62    let mut jhs = Vec::new();
63    print!("started.");
64    for (num_iter, rng) in rngs.into_iter().enumerate() {
65        let permit = manager.rent().await;
66        let tx = tx.clone();
67        jhs.push(tokio::spawn(permit.run(exec.clone(), num_iter, rng, tx)));
68    }
69
70    try_join_all(jhs).await?;
71    drop(tx);
72    handle.await.unwrap();
73    println!("\ndone.");
74    Ok(())
75}
76
77pub struct Manager<E> {
78    pub rx: mpsc::Receiver<usize>,
79    pub tx: mpsc::Sender<usize>,
80    pub resources: Vec<Arc<Mutex<E>>>,
81}
82
83impl<E> Manager<E> {
84    pub fn new<F: Fn(usize) -> E>(permits: usize, f: F) -> Self {
85        let mut resources = Vec::new();
86        let (tx, rx) = mpsc::channel(permits);
87        for i in 0..permits {
88            let r = Arc::new(Mutex::new(f(i)));
89            resources.push(r);
90            tx.try_send(i).unwrap();
91        }
92        Self { rx, tx, resources }
93    }
94
95    async fn rent(&mut self) -> EnvPermit<E> {
96        let idx = self.rx.recv().await.unwrap();
97        EnvPermit {
98            idx,
99            env: self.resources[idx].clone(),
100            tx: self.tx.clone(),
101        }
102    }
103}
104
105pub struct EnvPermit<E> {
106    idx: usize,
107    tx: mpsc::Sender<usize>,
108    env: Arc<Mutex<E>>,
109}
110
111impl<V, Ax> EnvPermit<Memory<V, Ax>>
112where
113    V: MyFloat + 'static,
114{
115    async fn run<Ex, Ix, R: Rng + Send + 'static>(
116        self,
117        exec: Arc<Ex>,
118        num_iter: usize,
119        rng: R,
120        tx: mpsc::Sender<Stat>,
121    ) -> anyhow::Result<()>
122    where
123        Open01: Distribution<V>,
124        Standard: Distribution<V>,
125        StandardNormal: Distribution<V>,
126        Exp1: Distribution<V>,
127        Ex: Executor<V, Ax, Ix> + Send + Sync + 'static,
128        Ax: AgentExtTrait<V, Exec = Ex, Ix = Ix> + Default + Send + 'static,
129        Ix: InstanceExt<V, R, Ex> + Send,
130    {
131        let env = self.env.clone();
132        let handle = tokio::spawn(async move {
133            if num_iter % 100 == 0 {
134                println!("\n{num_iter}");
135            }
136            if num_iter % 10 == 0 {
137                print!("|");
138                stdout().flush().unwrap();
139            }
140            print!(".");
141            let mut memory = env.lock().await;
142            exec.execute::<R>(&mut memory, num_iter as u32, rng)
143        });
144        let ss = handle.await?;
145        for s in ss {
146            tx.send(s).await.unwrap();
147        }
148        Ok(())
149    }
150}
151
152impl<E> Drop for EnvPermit<E> {
153    fn drop(&mut self) {
154        self.tx.try_send(self.idx).unwrap();
155    }
156}
157
#[cfg(test)]
mod tests {
    use std::fs::read_to_string;

    use anyhow::Context;
    use serde_json::json;

    use super::RuntimeParams;

    /// A JSON runtime config deserialises into `RuntimeParams`. Keys the struct
    /// does not declare (here `num_parallel`) are ignored.
    #[test]
    fn test_json_config() -> anyhow::Result<()> {
        let runtime = json!({
            "seed_state": 0,
            "num_parallel": 1,
            "iteration_count": 1,
        });
        let runtime = serde_json::from_value::<RuntimeParams>(runtime)?;
        println!("{:?}", runtime);
        assert_eq!(runtime.seed_state, 0);
        assert_eq!(runtime.iteration_count, 1);
        Ok(())
    }

    /// The checked-in `test_runtime.toml` fixture deserialises into `RuntimeParams`.
    #[test]
    fn test_toml_config() -> anyhow::Result<()> {
        // Resolve the fixture relative to the crate root, not the process cwd.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/test_runtime.toml");
        let text = read_to_string(path).with_context(|| format!("reading {path}"))?;
        let runtime = toml::from_str::<RuntimeParams>(&text)?;
        assert_eq!(runtime.seed_state, 0);
        assert_eq!(runtime.iteration_count, 1);

        Ok(())
    }
}