Jusene's Blog

Python 线程同步

字数统计: 2k阅读时长: 11 min
2018/07/01 Share

多线程

并发: 假同时,一段时间内同时处理多个任务,单核也可以并发
并行: 真同时,同时处理多个任务,必须要多核

python中实现并发的手段:

操作系统:

  • 线程
  • 进程

主流的语言提供用户空间的调度:协程

1
import logging
2
import threading
3
4
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s [%(threadName)s]: %(message)s')
5
6
for i in range(10):
7
    threading.Thread(target=lambda x: logging.info('worker-{}'.format(x)), args=(i,)).start()
8
9
2018-06-23 17:34:44,178 Thread-1726 worker-0
10
2018-06-23 17:34:44,178 Thread-1727 worker-1
11
2018-06-23 17:34:44,178 Thread-1728 worker-2
12
2018-06-23 17:34:44,179 Thread-1729 worker-3
13
2018-06-23 17:34:44,179 Thread-1730 worker-4
14
2018-06-23 17:34:44,179 Thread-1731 worker-5
15
2018-06-23 17:34:44,180 Thread-1732 worker-6
16
2018-06-23 17:34:44,180 Thread-1733 worker-7
17
2018-06-23 17:34:44,180 Thread-1734 worker-8
18
2018-06-23 17:34:44,180 Thread-1735 worker-9

daemon 与 non-daemon

1
import logging
2
import threading
3
4
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s [%(threadName)s]: %(message)s')
5
6
def worker():
7
    logging.info('starting...')
8
    time.sleep(2)
9
    logging.info('complated...')
10
11
12
if __name__ == "__main__":
13
    logging.info('starting...')
14
    t1 = threading.Thread(target=worker, name='non-deamon')
15
    t2 = threading.Thread(target=worker, name='deamon', daemon=True)
16
    t1.start()
17
    t2.start()
18
    logging.info('complated...')
19
20
2018-06-24 13:26:41,247 INFO [MainThread]: starting...
21
2018-06-24 13:26:41,247 INFO [non-deamon]: starting...
22
2018-06-24 13:26:41,247 INFO [deamon]: starting...
23
2018-06-24 13:26:41,247 INFO [MainThread]: complated...
24
2018-06-24 13:26:43,252 INFO [non-deamon]: complated...

主线程退出时,其daemon子线程也会随之退出;non-daemon子线程不会退出,进程会等待它们执行完毕

thread local

1
thl = threading.local()
2
thl.data = 5
3
4
thl.data
5
5
6
7
threading.Thread(target=lambda: print(thl.data)).start()
8
AttributeError: '_thread._local' object has no attribute 'data'
9
10
threading.local对象的属性,只有当前线程可见,其他线程不可见

定时器

1
timer = threading.Timer(interval=200, function=lambda: print('worker'))
2
timer.daemon=True
3
timer.name='worker'
4
timer.start()
5
timer.cancel()
6
threading.enumerate()

同步

1
import random
2
import datetime
3
import logging
4
import time
5
6
logging.basicConfig(level=logging.INFO, format=('%(asctime)s %(levelname)s [%(threadName)s]: %(message)s'))
7
8
def worker(event: threading.Event):
9
    s = random.randint(1,5)
10
    time.sleep(s)
11
    logging.info('sleep {}'.format(s))
12
13
def boss(event: threading.Event):
14
    start = datetime.datetime.now()
15
    event.wait()
16
    logging.info('worker sleep {}'.format(datetime.datetime.now() - start))
17
18
if __name__ == "__main__":
19
    event = threading.Event()
20
    threading.Thread(target=boss, args=(event,), name='boss').start()
21
    for i in range(5):
22
        thread = threading.Thread(target=worker, args=(event,), name='worker')
23
        thread.start()
24
        thread.join()
25
    event.set()
26
27
2018-06-27 21:40:26,127 INFO [worker]: sleep 4
28
2018-06-27 21:40:27,132 INFO [worker]: sleep 1
29
2018-06-27 21:40:32,136 INFO [worker]: sleep 5
30
2018-06-27 21:40:33,139 INFO [worker]: sleep 1
31
2018-06-27 21:40:36,144 INFO [worker]: sleep 3
32
2018-06-27 21:40:36,148 INFO [boss]: worker sleep 0:00:14.024013

event.wait会阻塞线程直到set方法调用或者超时

1
import random
2
import datetime
3
import logging
4
import time
5
6
logging.basicConfig(level=logging.INFO, format=('%(asctime)s %(levelname)s [%(threadName)s]: %(message)s'))
7
8
def worker(event: threading.Event):
9
    s = random.randint(1,5)
10
    event.wait(s)
11
    event.set()
12
    logging.info('sleep {}'.format(s))
13
14
def boss(event: threading.Event):
15
    start = datetime.datetime.now()
16
    event.wait()
17
    logging.info('worker sleep {}'.format(datetime.datetime.now() - start))
18
19
if __name__ == "__main__":
20
    event = threading.Event()
21
    threading.Thread(target=boss, args=(event,), name='boss').start()
22
    for i in range(5):
23
        thread = threading.Thread(target=worker, args=(event,), name='worker')
24
        thread.start()
25
        thread.join()
26
27
2018-06-27 21:51:14,352 INFO [worker]: sleep 1
28
2018-06-27 21:51:14,353 INFO [boss]: worker sleep 0:00:01.003538
29
2018-06-27 21:51:14,353 INFO [worker]: sleep 3
30
2018-06-27 21:51:14,354 INFO [worker]: sleep 4
31
2018-06-27 21:51:14,354 INFO [worker]: sleep 4
32
2018-06-27 21:51:14,354 INFO [worker]: sleep 4

event可以在线程之间发送信号

通常用于某个线程要等待其他线程处理完成某些动作之后才启动

1
import threading
2
import logging
3
4
logging.basicConfig(level=logging.INFO, format=('%(asctime)s %(levelname)s [%(threadName)s]: %(message)s'))
5
6
def worker(event: threading.Event):
7
    while not event.wait(3):
8
        logging.info('run run run')
9
10
if __name__ == "__main__":
11
    event = threading.Event()
12
    threading.Thread(target=worker, args=(event,), name='worker').start()
13
14
event.wait(1)
15
False
16
event.set()
17
event.wait(1)
18
True
19
event.is_set()
20
event.clear()
21
22
def worker(event: threading.Event):
23
    while not event.is_set():
24
        pass

lock

1
class Counter:
2
     def __init__(self):
3
        with open('test.txt', 'w') as f:
4
            f.write(str(1))
5
     def write(self):
6
        with open('test.txt') as f:
7
             num = f.read()
8
        with open('test.txt', 'w') as f:
9
             f.write(str(int(num) + 2)) 
10
counter = Counter()
11
 
12
for _ in range(10):     
13
    threading.Thread(target=counter.write).start()

会发现有的线程会获取不到数据而抛出异常:以 'w' 模式打开文件时会先清空文件,此时其他线程读到的是空字符串,int('') 会抛出 ValueError

lock

1
lock = threading.Lock()
2
3
class Counter:
4
     def __init__(self):
5
        with open('test.txt', 'w') as f:
6
            f.write(str(1))
7
     def write(self):
8
        lock.acquire()
9
        with open('test.txt') as f:
10
             num = f.read()
11
        with open('test.txt', 'w') as f:
12
             f.write(str(int(num) + 2))
13
        lock.release() 
14
counter = Counter()
15
 
16
for _ in range(10):     
17
    threading.Thread(target=counter.write).start()

预先启动10个线程,处理一些任务,当其中一个线程处理其中一个任务时,其他线程处理其他任务。

1
def worker(tasks):
2
    for task in tasks:
3
        if task.lock.acquire(blocking=False):
4
            logging.info(task.name)
5
6
class Task:
7
    def __init__(self, name):
8
        self.name = name
9
        self.lock = threading.Lock()
10
11
tasks = [Task(x) for x in range(10)]
12
13
for x in range(5):
14
    threading.Thread(target=worker, args=(tasks,), name='worker={}'.format(x)).start()
15
16
原理:
17
lock=threading.Lock()
18
lock.acquire()
19
True
20
lock.acquire(blocking=False)
21
False
22
lock.release()
23
lock.acquire(blocking=False)
24
True
25
26
lock.acquire(timeout=3)
27
False

rlock

可重入锁在同一个线程内可以多次acquire成功,但同一时刻只能有一个线程acquire成功;acquire几次,就需要release几次

Condition

1
class Dispatcher:
2
    def __init__(self):
3
        self.data = None
4
        self.event = threading.Event()
5
    def consumer(self):
6
        while True:
7
            self.event.wait()
8
            logging.info(self.data)
9
            self.event.clear()
10
    def producer(self, value):
11
        self.data = value
12
        self.event.set()
13
14
dispatcher = Dispatcher()
15
consumer = threading.Thread(target=dispatcher.consumer,name='consumer')
16
consumer.start()
17
consumer.is_alive()
18
for i in range(10):
19
    threading.Thread(target=dispatcher.producer, args=(i,), name='producer').start()
1
class Dispatcher:
2
    def __init__(self):
3
        self.data = None
4
        self.event = threading.Event()
5
        self.cond = threading.Condition()
6
    def consumer(self):
7
        while not self.event.is_set():
8
            with self.cond:
9
                self.cond.wait()
10
                logging.info(self.data)
11
    def producer(self):
12
        while True:
13
            self.data = random.randint(1,100)
14
            logging.info(self.data)
15
            with self.cond:
16
                self.cond.notifyAll()
17
                self.event.wait(1)

Condition通常用于生产者消费者模式,生产者生产消息之后,使用notify或者notify_all通知消费者

消费者使用wait方式阻塞线程,等待通知

notify可以指定通知几个线程,默认一个,notify_all通知所有消费者

Barrier 栅栏

1
import threading
2
import logging
3
4
logging.basicConfig(level=logging.INFO, format=('%(asctime)s %(levelname)s [%(threadName)s]: %(message)s'))
5
6
def worker(barrier: threading.Barrier):
7
    logging.info('waiting for {} threads'.format(barrier.n_waiting))
8
    try:
9
        worker_id = barrier.wait()
10
    except threading.BrokenBarrierError:
11
        logging.warning('aborting')
12
    else:
13
        logging.info('after barrier {}'.format(worker_id))
14
15
barrier=threading.Barrier(3)
16
17
for x in range(3):
18
    threading.Thread(target=worker, name='worker-{}'.format(x), args=(barrier,)).start()
19
20
2018-06-30 21:37:41,175 INFO [worker-0]: waiting for 0 threads
21
2018-06-30 21:37:41,176 INFO [worker-1]: waiting for 1 threads
22
2018-06-30 21:37:41,176 INFO [worker-2]: waiting for 2 threads
23
2018-06-30 21:37:41,176 INFO [worker-2]: after barrier 2
24
2018-06-30 21:37:41,176 INFO [worker-0]: after barrier 0
25
2018-06-30 21:37:41,176 INFO [worker-1]: after barrier 1

凑齐一波线程,才可以往下执行

barrier.n_waiting 多少个线程在等

异常触发:

1
for x in range(3):
2
    threading.Thread(target=worker, name='worker-{}'.format(x), args=(barrier,)).start()
3
2018-06-30 22:01:22,492 INFO [worker-0]: waiting for 0 threads
4
2018-06-30 22:01:22,492 INFO [worker-1]: waiting for 1 threads
5
6
barrier.abort()
7
2018-06-30 22:01:34,064 WARNING [worker-0]: aborting
8
2018-06-30 22:01:34,065 WARNING [worker-1]: aborting

barrier.reset() 重置barrier
barrier.wait(timeout=3) 当超时会抛出BrokenBarrierError异常

Semaphore 信号量

1
s = threading.Semaphore(3)
2
s.acquire()
3
True
4
s.acquire(False)
5
True
6
s.acquire(False)
7
True
8
s.acquire(False)
9
False

锁是信号量的特例: 初始值为1的信号量

1
class Pool:
2
    def __init__(self, num):
3
        self.num = num
4
        self.conns = [self._make_connect(x) for x in range(num)]
5
        self.s = threading.Semaphore(num)
6
7
    def _make_connect(self, name):
8
        return name
9
10
    def get(self):
11
        self.s.acquire()
12
        return self.conns.pop()
13
14
    def return_resource(self, conn):
15
        self.conns.insert(0,conn)
16
        self.s.release()
17
18
def worker(pool):
19
    logging.info('started')
20
    name = pool.get()
21
    logging.info('get connect {}'.format(name))
22
    time.sleep(2)
23
    pool.return_resource(name)
24
    logging.info('return resource {}'.format(name))
25
26
pool=Pool(3)
27
for x in range(5):
28
    threading.Thread(target=worker, args=(pool,),name='worker-{}'.format(x)).start()

信号量是对资源的保护,但是和锁不一样的地方在于,锁限制同一时刻只有一个线程可以访问共享资源,而信号量限制指定个数的线程可以访问共享资源。

线程之间的通讯

1
import queue
2
3
def producer(queue: queue.Queue, event: threading.Event):
4
    while not event.wait(3):
5
        data = random.randint(0, 100)
6
        logging.info(data)
7
        queue.put(data)
8
9
def consumer(queue: queue.Queue, event: threading.Event):
10
    while not event.is_set():
11
        logging.info(queue.get())
12
13
q = queue.Queue()
14
e = threading.Event()
15
16
threading.Thread(target=consumer, name='consumer', args=(q,e)).start()
17
threading.Thread(target=producer, name='producer', args=(q,e)).start()
CATALOG
  1. 1. 多线程
  2. 2. daemon 与 non-daemon
  3. 3. thread local
  4. 4. 定时器
  5. 5. 同步
  6. 6. lock
  7. 7. lock
  8. 8. rlock
  9. 9. Condition
  10. 10. Barrier 栅栏
  11. 11. Semaphore 信号量
  12. 12. 线程之间的通讯