线程-信号量

线程-信号量


 1# author: wangy
 2# date: 2024/8/10
 3# description: 信号量
 4
 5"""
 6信号量允许多个线程同时访问资源.
 7信号量的目的是为了'保护资源', 而不是共享资源.
 8比如控制数据库连接数量以控制并发量, 控制线程池大小等等
 9
10信号量持有"许可", 任何想访问对象的线程都必须通过信号量
11获得许可, 结束后"归还"许可. 许可用完后, 其他线程则无法
12再访问对象, 须等待其他线程"归还"许可.
13
14获得许可的线程, 不一定要归还许可
15
16许可其实是一个线程安全的"计数器"
17"""
18from threading import current_thread, BoundedSemaphore, Thread, Lock
19
20
21class ConnectionDB:
22    init_id = 0
23
24    def __init__(self):
25        ConnectionDB.init_id += 1
26        self.__connect_id = ConnectionDB.init_id
27        self.__connect_name = f"{self.__class__.__name__}-{self.__connect_id}"
28
29    def connect(self):
30        print(f"{current_thread().name}: {self.__connect_name} connected ✔️")
31
32    def disconnect(self):
33        print(f"{current_thread().name}: {self.__connect_name} disconnected ✖️")
34
35
36class Pool:
37
38    def __init__(self, cls, pool_size: int = 10):
39        self.__pool_size = pool_size
40        # 一般使用有界信号量
41        self.__semaphore = BoundedSemaphore(self.__pool_size)
42        self.__lock = Lock()
43        self.__objs = [cls() for x in range(pool_size)]
44
45    def __get_connection(self):
46        if self.__semaphore.acquire(timeout=5):
47            with self.__lock:
48                return self.__objs.pop()
49        else:
50            raise TimeoutError("too many connections")
51
52    def use_conn(self):
53        try:
54            conn = self.__get_connection()
55        except TimeoutError as e:
56            print(f"{current_thread().name}: {e}")
57            return
58        try:
59            conn.connect()
60            time.sleep(3)
61        finally:
62            conn.disconnect()
63            with self.__lock:
64                self.__objs.append(conn)
65            self.__semaphore.release()
66
67
68def run(p: Pool):
69    p.use_conn()
70
71
72if __name__ == '__main__':
73    pool = Pool(ConnectionDB, 4)
74    ts = [Thread(target=run, args=(pool,), daemon=True)
75          for x in range(10)]
76    for t in ts:
77        t.start()
78
79    for t in ts:
80        t.join()