线程池的简单实现(Rust)

栏目: 编程语言 · Rust · 发布时间: 5年前

内容简介:线程池,就是一组工作线程,工作线程的数量一般与CPU核数相关(如果是CPU密集型任务,可初始设为 ,如果是IO密集型任务,可初始设为 ,运行过程中可能会依据任务的繁忙程度而动态增减),由线程池负责管理工作线程的创建,异常处理(如果工作线程异常退出,会创建新的工作线程弥补线程池中的工作线程数量),任务分配等工作。线程池中一般会有一个任务队列,所有工作线程从任务队列中取任务,执行,如此反复。其核心是避免大量线程的创建及频繁的线程切换,尽最大可能提高CPU利用率。线程池有不同的实现形式,主要的区别就是如何指定

线程池,就是一组工作线程,工作线程的数量一般与CPU核数相关(如果是CPU密集型任务,可初始设为 ,如果是IO密集型任务,可初始设为 ,运行过程中可能会依据任务的繁忙程度而动态增减),由线程池负责管理工作线程的创建,异常处理(如果工作线程异常退出,会创建新的工作线程弥补线程池中的工作线程数量),任务分配等工作。线程池中一般会有一个任务队列,所有工作线程从任务队列中取任务,执行,如此反复。其核心是避免大量线程的创建及频繁的线程切换,尽最大可能提高CPU利用率。

线程池有不同的实现形式,主要的区别就是如何指定因任务队列中任务的繁忙程度与调度管理工作线程的数量的调度策略。比如如果任务队列中有大量的任务等待处理,是否需要根据待处理任务队列的任务数量而开启新的工作线程去处理,等任务队列中的任务完成,再关闭部分工作线程。

线程池一般适用于大量短任务的处理,这样可以避免开启大量线程及频繁的线程切换,提高效率。如果是长时任务,则线程池的优势不明显,并且可能造成其他短任务(要求快速得到响应)得不到运行,造成饥饿。同时线程池不适用于有特定优先级的任务。

二、使用示例

该示例实现了接收客户端的连接,并echo回应连接。使用mio+threadpool的方式。threadpool是rust的一个线程池库。

//! mio+threadpool
#[macro_use]
extern crate log;
extern crate simple_logger;
extern crate mio;
extern crate threadpool;
extern crate num_cpus;

use std::thread;
use std::str::FromStr;
use std::time::Duration;
use std::io::{Read,Write};
use threadpool::{ThreadPool,Builder};
use mio::*;
use mio::tcp::{TcpListener, TcpStream};

fn main() {
    simple_logger::init().unwrap();
    let server_handle=run_server(None);
    server_handle.join();
}

fn run_server(timeout: Option<Duration>)->thread::JoinHandle<()>{
    let handle=thread::spawn(move||{
        let num_cpus=num_cpus::get_physical();
        let pool=Builder::new().num_threads(num_cpus).thread_name(String::from("threadpool")).build();

        const SERVER: Token = Token(0);
        let addr = "127.0.0.1:12345".parse().unwrap();
        let server = TcpListener::bind(&addr).unwrap();

        let poll = Poll::new().unwrap();
        poll.register(&server, SERVER, Ready::readable(), PollOpt::edge()).unwrap();
        let mut events = Events::with_capacity(1024);
        loop {
            match poll.poll(&mut events, timeout){
                Ok(size)=>{
                    trace!("event size={}",size);
                    if size<=0{
                        break;
                    }
                },
                Err(e)=>{
                    error!("{}",e);
                    break;
                }
            }
            for event in events.iter() {
                match event.token() {
                    SERVER => {
                        let (stream,_) = server.accept().unwrap();
                        pool.execute(move ||{
                            simple_echo(stream);
                        });
                    },
                    _ => unreachable!(),
                }
            }
        }

        pool.join();
    });

    handle
}

fn simple_echo(mut stream:TcpStream) {
    info!("New accept {:?}", stream.peer_addr());
    let mut buf = String::new();
    if let Err(e) = stream.read_to_string(&mut buf) {
        error!("{}", e);
    }

    thread::sleep_ms(1000); //加上延时是为了验证线程池工作
    info!("server receive data: {}", buf);
    stream.write_all(buf.as_bytes());
}

#[cfg(test)]
mod tests{
    use super::*;

    #[test]
    fn test_server(){
        simple_logger::init().unwrap();
        let server_handle=run_server(Some(Duration::new(10,0)));
        thread::sleep_ms(1000);
        let client_handle=run_client(4);
        client_handle.join();
        server_handle.join();
    }

    fn run_client(num: usize)->thread::JoinHandle<()>{
        let handle=thread::spawn(move||{
            let mut ths=Vec::new();
            for id in 0..num{
                let h=thread::spawn(move||{
                    client(id);
                });
                ths.push(h);
            }

            for h in ths{
                h.join().unwrap();
            }
        });

        handle
    }

    fn client(id: usize){
        let mut stream = std::net::TcpStream::connect("127.0.0.1:12345").unwrap();
        let mut data=format!("client data {}",id);
        stream.write_all(data.as_bytes());
        let mut buffer=String::new();
        stream.read_to_string(&mut buffer);
        info!("client {} received data:{}",id,buffer);

        info!("connect {} end!",id);
    }
}


复制代码

三、Rust threadpool源码 实现

该线程池实现了对工作线程的创建,线程异常panic处理,工作线程数量可运行时改变,但数量数需要具体指定,并没有实现随任务队列中任务繁忙程度而动态改变等功能。下面代码实现了线程池最基本的功能(对工作线程的管理),列出部分源码如下:

trait FnBox {
    fn call_box(self: Box<Self>);
}

impl<F: FnOnce()> FnBox for F {
    fn call_box(self: Box<F>) {
        (*self)()
    }
}

type Thunk<'a> = Box<FnBox + Send + 'a>;

复制代码

Sentinel主要作用是检测出线程panic后,新建工作线程补充到线程池中。

struct Sentinel<'a> {
    shared_data: &'a Arc<ThreadPoolSharedData>,
    active: bool,
}

impl<'a> Sentinel<'a> {
    fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
        Sentinel {
            shared_data: shared_data,
            active: true,
        }
    }

    /// Cancel and destroy this sentinel.
    fn cancel(mut self) {
        self.active = false;
    }
}

impl<'a> Drop for Sentinel<'a> {
    fn drop(&mut self) {
        if self.active {
            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
            if thread::panicking() {
                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
            }
            self.shared_data.no_work_notify_all();
            spawn_in_pool(self.shared_data.clone())
        }
    }
}

复制代码

线程池建造者,负责构造线程池

/// [`ThreadPool`] factory, which can be used in order to configure the properties of the [`ThreadPool`].
#[derive(Clone, Default)]
pub struct Builder {
    num_threads: Option<usize>,
    thread_name: Option<String>,
    thread_stack_size: Option<usize>,
}

impl Builder {
    /// Initiate a new [`Builder`].
    pub fn new() -> Builder {
        Builder {
            num_threads: None,
            thread_name: None,
            thread_stack_size: None,
        }
    }

    ...

    // Finalize the Builder and build the ThreadPool.
    pub fn build(self) -> ThreadPool {
        let (tx, rx) = channel::<Thunk<'static>>();
        let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);
        let shared_data = Arc::new(ThreadPoolSharedData {
            name: self.thread_name,
            job_receiver: Mutex::new(rx),
            empty_condvar: Condvar::new(),
            empty_trigger: Mutex::new(()),
            join_generation: AtomicUsize::new(0),
            queued_count: AtomicUsize::new(0),
            active_count: AtomicUsize::new(0),
            max_thread_count: AtomicUsize::new(num_threads),
            panic_count: AtomicUsize::new(0),
            stack_size: self.thread_stack_size,
        });

        // Threadpool threads
        for _ in 0..num_threads {
            spawn_in_pool(shared_data.clone());
        }

        ThreadPool {
            jobs: tx,
            shared_data: shared_data,
        }
    }
}
复制代码
struct ThreadPoolSharedData {
    name: Option<String>,
    job_receiver: Mutex<Receiver<Thunk<'static>>>,
    empty_trigger: Mutex<()>,
    empty_condvar: Condvar,
    join_generation: AtomicUsize,
    queued_count: AtomicUsize,
    active_count: AtomicUsize,
    max_thread_count: AtomicUsize,
    panic_count: AtomicUsize,
    stack_size: Option<usize>,
}

impl ThreadPoolSharedData {
    fn has_work(&self) -> bool {
        self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
    }

    /// Notify all observers joining this pool if there is no more work to do.
    fn no_work_notify_all(&self) {
        if !self.has_work() {
            *self.empty_trigger.lock().expect(
                "Unable to notify all joining threads",
            );
            self.empty_condvar.notify_all();
        }
    }
}

复制代码

线程池结构体

// Abstraction of a thread pool for basic parallelism.
pub struct ThreadPool {
    // How the threadpool communicates with subthreads.
    // This is the only such Sender, so when it is dropped all subthreads will quit.
    jobs: Sender<Thunk<'static>>,
    shared_data: Arc<ThreadPoolSharedData>,
}

impl ThreadPool {
    // Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
    pub fn new(num_threads: usize) -> ThreadPool {
        Builder::new().num_threads(num_threads).build()
    }

    // Executes the function `job` on a thread in the pool.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
        self.jobs.send(Box::new(job)).expect( "ThreadPool::execute unable to send job into queue.");
    }

    /// Returns the number of jobs waiting to executed in the pool.
    pub fn queued_count(&self) -> usize {
        self.shared_data.queued_count.load(Ordering::Relaxed)
    }

    /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
    #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
    pub fn set_threads(&mut self, num_threads: usize) {
        self.set_num_threads(num_threads)
    }

    /// Sets the number of worker-threads to use as `num_threads`.  Can be used to change the threadpool size during runtime. Will not abort already running or waiting threads.
    pub fn set_num_threads(&mut self, num_threads: usize) {
        assert!(num_threads >= 1);
        let prev_num_threads = self.shared_data.max_thread_count.swap(
            num_threads,
            Ordering::Release,
        );
        if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
            // Spawn new threads
            for _ in 0..num_spawn {
                spawn_in_pool(self.shared_data.clone());
            }
        }
    }

    /// Block the current thread until all jobs in the pool have been executed.
    pub fn join(&self) {
        // fast path requires no mutex
        if self.shared_data.has_work() == false {
            return ();
        }

        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
        let mut lock = self.shared_data.empty_trigger.lock().unwrap();

        while generation == self.shared_data.join_generation.load(Ordering::Relaxed) &&
                self.shared_data.has_work() {
            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
        }

        // increase generation if we are the first thread to come out of the loop
        self.shared_data.join_generation.compare_and_swap(generation, generation.wrapping_add(1), Ordering::SeqCst);
    }
}
复制代码

创建线程,实现线程增减,如果运行中工作线程panic,则新建工作线程补充到线程池中。如果是减少当前的工作线程数量,则要等到工作线程运行到自动结束。不会强制终结目前正在运行的工作线程。

fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
    let mut builder = thread::Builder::new();
    if let Some(ref name) = shared_data.name {
        builder = builder.name(name.clone());
    }
    if let Some(ref stack_size) = shared_data.stack_size {
        builder = builder.stack_size(stack_size.to_owned());
    }
    builder
        .spawn(move || {
            // Will spawn a new thread on panic unless it is cancelled.
            let sentinel = Sentinel::new(&shared_data);

            loop {
                // Shutdown this thread if the pool has become smaller
                let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
                let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
                if thread_counter_val >= max_thread_count_val {
                    break;
                }
                let message = {
                    // Only lock jobs for the time it takes
                    // to get a job, not run it.
                    let lock = shared_data.job_receiver.lock().expect(
                        "Worker thread unable to lock job_receiver",
                    );
                    lock.recv()
                };

                let job = match message {
                    Ok(job) => job,
                    // The ThreadPool was dropped.
                    Err(..) => break,
                };
                // Do not allow IR around the job execution
                shared_data.active_count.fetch_add(1, Ordering::SeqCst);
                shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);

                job.call_box();

                shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
                shared_data.no_work_notify_all();
            }

            sentinel.cancel();
        })
        .unwrap();
}

复制代码

四、其他线程池的实现

线程池有很多实现形式,用Rust语言实现的一个较为复杂的线程池可以参考tokio_threadpool,是一个基于work-stealing算法的线程池。


以上所述就是小编给大家介绍的《线程池的简单实现(Rust)》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

算法引论

算法引论

[美]乌迪·曼博(Udi Manber) / 黄林鹏、谢瑾奎、陆首博、等 / 电子工业出版社 / 2010-1 / 36.00元

本书是国际算法大师乌迪·曼博(Udi Manber)博士撰写的一本享有盛誉的著作。全书共分12章:第1章到第4章为介绍性内容,涉及数学归纳法、算法分析、数据结构等内容;第5章提出了与归纳证明进行类比的算法设计思想;第6章到第9章分别给出了4个领域的算法,如序列和集合的算法、图算法、几何算法、代数和数值算法;第10章涉及归约,也是第11章的序幕,而后者涉及NP完全问题;第12章则介绍了并行算法;最后......一起来看看 《算法引论》 这本书的介绍吧!

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码

MD5 加密
MD5 加密

MD5 加密工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具