# tcprelay.py


#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import absolute_import, division, print_function, \
    with_statement

import time
import socket
import errno
import struct
import logging
import traceback
import random

from shadowsocks import encrypt, eventloop, shell, common
from shadowsocks.common import parse_header, onetimeauth_verify, \
    onetimeauth_gen, ONETIMEAUTH_BYTES, ONETIMEAUTH_CHUNK_BYTES, \
    ONETIMEAUTH_CHUNK_DATA_LEN, ADDRTYPE_AUTH

# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512

MSG_FASTOPEN = 0x20000000

# SOCKS METHOD definition
METHOD_NOAUTH = 0

# SOCKS command definition
CMD_CONNECT = 1
CMD_BIND = 2
CMD_UDP_ASSOCIATE = 3

# for each opening port, we have a TCP Relay

# for each connection, we have a TCP Relay Handler to handle the connection

# for each handler, we have 2 sockets:
#    local:   connected to the client
#    remote:  connected to remote server

# for each handler, it could be at one of several stages:

# as sslocal:
# stage 0 auth METHOD received from local, reply with selection message
# stage 1 addr received from local, query DNS for remote
# stage 2 UDP assoc
# stage 3 DNS resolved, connect to remote (the relay connection, one at a time)
# stage 4 still connecting, more data from local received
# stage 5 remote connected, piping local and remote

# as ssserver:
# stage 0 just jump to stage 1
# stage 1 addr received from local, query DNS for remote
# stage 3 DNS resolved, connect to remote
# stage 4 still connecting, more data from local received
# stage 5 remote connected, piping local and remote

STAGE_INIT = 0
STAGE_ADDR = 1
STAGE_UDP_ASSOC = 2
STAGE_DNS = 3
STAGE_CONNECTING = 4
STAGE_STREAM = 5
STAGE_DESTROYED = -1

# for each handler, we have 2 stream directions:
#    upstream:    from client to server direction
#                 read local and write to remote
#    downstream:  from server to client direction
#                 read remote and write to local

STREAM_UP = 0
STREAM_DOWN = 1

# for each stream, it's waiting for reading, or writing, or both
WAIT_STATUS_INIT = 0
WAIT_STATUS_READING = 1
WAIT_STATUS_WRITING = 2
WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING

BUF_SIZE = 32 * 1024

# helper exceptions for TCPRelayHandler

class BadSocksHeader(Exception):
    """Helper exception for TCPRelayHandler.

    Presumably signals a malformed SOCKS header; the raise sites are not
    part of this definition.
    """

class NoAcceptableMethods(Exception):
    """Helper exception for TCPRelayHandler.

    Presumably signals that none of the offered SOCKS auth methods is
    usable; the raise sites are not part of this definition.
    """

class TCPRelayHandler(object):
    def __init__(self, server, fd_to_handlers, loop, local_sock, config,
                 dns_resolver, is_local):
        """Set up per-connection relay state and start watching the client socket.

        Args:
            server: the owning relay. It is registered with the loop as the
                event handler for this connection's socket and is notified
                of activity through ``update_activity``.
            fd_to_handlers: dict shared with the relay, mapping a file
                descriptor to its handler; this handler registers itself
                under ``local_sock.fileno()``.
            loop: event loop; ``add`` is used here.
            local_sock: the already-accepted client-side socket.
            config: configuration mapping. ``password`` and ``method`` are
                read with ``[]``; ``one_time_auth`` and ``forbidden_ip`` are
                read with ``.get``.
            dns_resolver: resolver kept for looking up the remote address
                later.
            is_local: truthy when running as sslocal, falsy as ssserver.
        """
        self._server = server
        # shared with the relay so it can route events back to this handler
        self._fd_to_handlers = fd_to_handlers
        self._loop = loop
        # the accepted client-side socket passed in by the relay
        self._local_sock = local_sock
        # created later, once the remote address has been resolved
        self._remote_sock = None
        self._config = config
        self._dns_resolver = dns_resolver

        # TCP Relay works as either sslocal or ssserver
        # if is_local, this is sslocal
        self._is_local = is_local
        self._stage = STAGE_INIT
        # per-connection cipher built from the configured password and method
        self._encryptor = encrypt.Encryptor(config['password'],
                                            config['method'])
        self._ota_enable = config.get('one_time_auth', False)
        self._ota_enable_session = self._ota_enable
        # background on why one-time-auth was deprecated:
        # https://prinsss.github.io/why-do-shadowsocks-deprecate-ota/
        # the table below sketches the OTA-authenticated request header
        '''
        +------+---------------------+------------------+-----------+
        | ATYP | Destination Address | Destination Port | HMAC - SHA1 |
        +------+---------------------+------------------+-----------+
        | 1    | Variable            | 2                |      10     |
        +------+---------------------+------------------+-----------+
        '''
        self._ota_buff_head = b''
        self._ota_buff_data = b''
        self._ota_len = 0
        self._ota_chunk_idx = 0
        self._fastopen_connected = False
        self._data_to_write_to_local = []
        self._data_to_write_to_remote = []
        # initial state: the downstream waits for nothing and the upstream
        # waits for reading, i.e. only POLL_IN on the local socket matters
        self._upstream_status = WAIT_STATUS_READING
        self._downstream_status = WAIT_STATUS_INIT
        # getpeername() returns the peer's address tuple; keep only the
        # first two items (host, port)
        self._client_address = local_sock.getpeername()[:2]
        # filled in later (the assignment site is outside this method)
        self._remote_address = None
        # NOTE(review): presumably already converted to a common.IPNetwork
        # style container when the config was built; it is only used with
        # the ``in`` operator -- confirm against the config loader
        self._forbidden_iplist = config.get('forbidden_ip')
        if is_local:
            # only sslocal picks an upstream server
            self._chosen_server = self._get_a_server()
        # register this handler under the client socket's file descriptor
        fd_to_handlers[local_sock.fileno()] = self
        local_sock.setblocking(False)
        # TCP_NODELAY disables Nagle's algorithm to reduce latency
        local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        # the relay (not this handler) is registered as the loop's handler;
        # it dispatches events back through fd_to_handlers
        loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR,
                 self._server)
        # placeholder value before the first activity report below
        self.last_activity = 0
        # report initial activity to the relay
        self._update_activity()

    def __hash__(self): #对象的内存地址 id() 不除以16了 以消除碰撞
        # default __hash__ is id / 16
        # we want to eliminate collisions
        return id(self)

    @property#def remote_address(self)。通过 @property 装饰器 通过方法获取属性值 不能修改 修改需要@remote_address.setter 修饰器
    def remote_address(self):
        return self._remote_address

    def _get_a_server(self):
        server = self._config['server']
        server_port = self._config['server_port']
        if type(server_port) == list:
            server_port = random.choice(server_port)
        if type(server) == list:
            server = random.choice(server)
        logging.debug('chosen server: %s:%d', server, server_port)
        return server, server_port

    def _update_activity(self, data_len=0):
        # tell the TCP Relay we have activities recently
        # else it will think we are inactive and timed out
        self._server.update_activity(self, data_len)

    def _update_stream(self, stream, status):
        """Move one stream direction to a new wait status and re-arm polling.

        Nothing happens when ``stream`` is neither STREAM_UP nor STREAM_DOWN,
        or when the status is unchanged. Otherwise the poll masks of the
        sockets that exist are recomputed from both stream statuses.

        Args:
            stream: STREAM_UP or STREAM_DOWN.
            status: one of the WAIT_STATUS_* values.
        """
        # record the new status; bail out early when nothing changed
        if stream == STREAM_DOWN:
            if self._downstream_status == status:
                return
            self._downstream_status = status
        elif stream == STREAM_UP:
            if self._upstream_status == status:
                return
            self._upstream_status = status
        else:
            return

        # local socket: read side feeds the upstream, write side drains the
        # downstream
        if self._local_sock:
            local_mask = eventloop.POLL_ERR
            if self._downstream_status & WAIT_STATUS_WRITING:
                local_mask |= eventloop.POLL_OUT
            if self._upstream_status & WAIT_STATUS_READING:
                local_mask |= eventloop.POLL_IN
            self._loop.modify(self._local_sock, local_mask)

        # remote socket: read side feeds the downstream, write side drains
        # the upstream
        if self._remote_sock:
            remote_mask = eventloop.POLL_ERR
            if self._downstream_status & WAIT_STATUS_READING:
                remote_mask |= eventloop.POLL_IN
            if self._upstream_status & WAIT_STATUS_WRITING:
                remote_mask |= eventloop.POLL_OUT
            self._loop.modify(self._remote_sock, remote_mask)

    def _write_to_sock(self, data, sock):
        #写不完 才监听可写 这是运用了非阻塞套接字的特效
        #不在没写入时 就监听可写 这样会loop一直自循环
        # write data to sock
        # if only some of the data are written, put remaining in the buffer
        # and update the stream to wait for writing
        if not data or not sock:
            return False
        uncomplete = False
        try:
            l = len(data)
            s = sock.send(data)#它返回一个整数值,表示实际发送的字节数
            if s H', port)#>网络字节序
                self._write_to_sock(header + addr_to_send + port_to_send,
                                    self._local_sock)#返回给socks5 client 关联数据包
                self._stage = STAGE_UDP_ASSOC
                # just wait for the client to disconnect
                #sock5 udp有写相关内容 这个tcp连接关了 udp就关了
                return
            elif cmd == CMD_CONNECT:
                # just trim VER CMD RSV
                data = data[3:]
            else:
                logging.error('unknown command %d', cmd)
                self.destroy()
                return
        header_result = parse_header(data) #common.parse_header     return addrtype, to_bytes(dest_addr), dest_port, header_length
        if header_result is None:#密码错了 这里解密的就错了 这里返回了none 提出异常
            raise Exception('can not parse header')
        addrtype, remote_addr, remote_port, header_length = header_result#拿到了远端地址端口 可能是域名
        logging.info('connecting %s:%d from %s:%d' %
                     (common.to_str(remote_addr), remote_port,
                      self._client_address[0], self._client_address[1]))
        if self._is_local is False:
            # spec https://shadowsocks.org/en/spec/one-time-auth.html
            self._ota_enable_session = addrtype & ADDRTYPE_AUTH
            if self._ota_enable and not self._ota_enable_session:
                logging.warn('client one time auth is required')
                return
            if self._ota_enable_session:
                if len(data)  header_length:
                self._data_to_write_to_remote.append(data[header_length:])
            # notice here may go into _handle_dns_resolved directly
            self._dns_resolver.resolve(remote_addr,
                                       self._handle_dns_resolved)#解析的是远端的地址 也就是目标服务器的地址

    def _create_remote_socket(self, ip, port):#这里传入的ip只是为了获得建立连接的五元组
        addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
                                   socket.SOL_TCP)
        if len(addrs) == 0:
            raise Exception("getaddrinfo failed for %s:%d" % (ip, port))
        af, socktype, proto, canonname, sa = addrs[0]
        if self._forbidden_iplist:
            if common.to_str(sa[0]) in self._forbidden_iplist:#是comman.IPNetwork的实例 验证是通过__contain__方法 简单说是先判断ip是v4 v6 然后移位掩码去看和列表里各个地址段是否一致 一致则抛出异常 NO则通过
                raise Exception('IP %s is in forbidden list, reject' %
                                common.to_str(sa[0]))
        remote_sock = socket.socket(af, socktype, proto)
        self._remote_sock = remote_sock
        self._fd_to_handlers[remote_sock.fileno()] = self#注册处理器 调用handle event
        remote_sock.setblocking(False)
        remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        return remote_sock

    @shell.exception_handle(self_=True)
    def _handle_dns_resolved(self, result, error):
        """DNS callback: start connecting to the resolved address.

        Args:
            result: indexable result of the lookup; ``result[1]`` is used as
                the IP address (the resolver presumably calls back with
                ``(hostname, ip)`` -- confirm against the resolver).
            error: falsy on success; otherwise it is logged and the handler
                is destroyed.

        On success the stage becomes STAGE_CONNECTING. For sslocal with
        ``fast_open`` enabled no socket is created here; otherwise a
        non-blocking connect is started and the remote socket is added to
        the loop.
        """
        if error:
            addr, port = self._client_address[0], self._client_address[1]
            logging.error('%s when handling connection from %s:%d' %
                          (error, addr, port))
            # lookup failed: give up on this connection
            self.destroy()
            return
        if not (result and result[1]):
            # no usable address in the result
            self.destroy()
            return

        ip = result[1]
        self._stage = STAGE_CONNECTING
        remote_addr = ip
        # sslocal connects to the chosen upstream server's port, ssserver to
        # the port stored in _remote_address (set outside this method)
        if self._is_local:
            remote_port = self._chosen_server[1]
        else:
            remote_port = self._remote_address[1]

        if self._is_local and self._config['fast_open']:
            # for fastopen:
            # wait for more data arrive and send them in one SYN
            self._stage = STAGE_CONNECTING
            # we don't have to wait for remote since it's not
            # created
            # from here on only read events on the local socket are watched
            self._update_stream(STREAM_UP, WAIT_STATUS_READING)

            # TODO when there is already data in this packet
        else:
            # else do connect
            # returns a configured, not yet connected, non-blocking socket
            remote_sock = self._create_remote_socket(remote_addr,
                                                     remote_port)
            try:
                remote_sock.connect((remote_addr, remote_port))
            except (OSError, IOError) as e:
                # EINPROGRESS is expected for a non-blocking connect.
                # NOTE(review): every other errno is swallowed here as well,
                # since the ``if`` has no else branch
                if eventloop.errno_from_exception(e) == \
                        errno.EINPROGRESS:
                    pass
            # watch the remote socket for writability and errors
            self._loop.add(remote_sock,
                           eventloop.POLL_ERR | eventloop.POLL_OUT,
                           self._server)
            self._stage = STAGE_CONNECTING
            # upstream: read from local and write to remote;
            # downstream: read from remote
            self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
            self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)

    def _write_to_sock_remote(self, data):
        self._write_to_sock(data, self._remote_sock)

    def _ota_chunk_data(self, data, data_cb):
        # spec https://shadowsocks.org/en/spec/one-time-auth.html
        unchunk_data = b''
        while len(data) > 0:
            if self._ota_len == 0:
                # get DATA.LEN + HMAC-SHA1
                length = ONETIMEAUTH_CHUNK_BYTES - len(self._ota_buff_head)
                self._ota_buff_head += data[:length]
                data = data[length:]
                if len(self._ota_buff_head) H', data_len)[0]
            length = min(self._ota_len - len(self._ota_buff_data), len(data))
            self._ota_buff_data += data[:length]
            data = data[length:]
            if len(self._ota_buff_data) == self._ota_len:
                # get a chunk data
                _hash = self._ota_buff_head[ONETIMEAUTH_CHUNK_DATA_LEN:]
                _data = self._ota_buff_data
                index = struct.pack('>I', self._ota_chunk_idx)
                key = self._encryptor.decipher_iv + index
                if onetimeauth_verify(_hash, _data, key) is False:
                    logging.warn('one time auth fail, drop chunk !')
                else:
                    unchunk_data += _data
                    self._ota_chunk_idx += 1
                self._ota_buff_head = b''
                self._ota_buff_data = b''
                self._ota_len = 0
        data_cb(unchunk_data)
        return

    def _ota_chunk_data_gen(self, data):
        """Wrap ``data`` into one one-time-auth chunk and advance the counter.

        The chunk is: 2-byte big-endian length of ``data``, then the tag
        returned by ``onetimeauth_gen``, then ``data`` itself. The tag key
        is the encryptor's ``cipher_iv`` followed by the 4-byte big-endian
        chunk index.

        Args:
            data: payload bytes; must be shorter than 65536 bytes for the
                2-byte length field.

        Returns:
            The framed chunk.
        """
        length_prefix = struct.pack(">H", len(data))
        counter = struct.pack('>I', self._ota_chunk_idx)
        auth_key = self._encryptor.cipher_iv + counter
        tag = onetimeauth_gen(data, auth_key)
        # the index is advanced only after the tag was produced
        self._ota_chunk_idx += 1
        return length_prefix + tag + data

    def _handle_stage_stream(self, data):#对于local是未加密数据 对于server是解密了的数据
        #将数据发往remote _on_remote_read在获取数据后 会调用该函数 将明文数据作为参数传入
        if self._is_local:
            if self._ota_enable_session:
                data = self._ota_chunk_data_gen(data)
            data = self._encryptor.encrypt(data)#加密数据
            self._write_to_sock(data, self._remote_sock)
            #可能写入会阻塞 在_write_to_sock方法会在_data_to_write_to_xxx添加未写完的内容 并修改监听事件
        else:
            if self._ota_enable_session:
                self._ota_chunk_data(data, self._write_to_sock_remote)
            else:
                self._write_to_sock(data, self._remote_sock)
        return

    '''

       The client connects to the server, and sends a version
       identifier/method selection message:

                       +----+----------+----------+
                       |VER | NMETHODS | METHODS  |
                       +----+----------+----------+
                       | 1  |    1     | 1 to 255 |
                       +----+----------+----------+

    '''
    def _check_auth_method(self, data):

        # VER, NMETHODS, and at least 1 METHODS
        if len(data) connection.setsockopt(socket.SOL_TCP, 23, 5) here 23 is the protocol number of TCP_FASTOPEN it is not defined in socket module if python2 so writing manually here and 5 is the queue length for number of TFO request which are yet to complete 3 way handshake.
            except socket.error:
                logging.error('warning: fast open is not available')
                self._config['fast_open'] = False
        server_socket.listen(1024)

        #传入连接请求的队列的最大长度 当服务器套接字处于监听状态时,它会等待传入的连接请求。如果在一段时间内有多个客户端尝试连接到服务器,但服务器尚未来得及处理,这些连接请求将会被放置在一个队列中。listen() 方法的参数指定了这个队列的最大长度。
        #在使用 accept() 方法接受传入连接请求时,它会从队列中移除一个等待的连接请求,并创建一个新的套接字用于与客户端进行通信。
        self._server_socket = server_socket
        self._stat_callback = stat_callback

    def add_to_loop(self, loop):
        """Attach this relay to an event loop.

        Registers the listening socket for readable/error events with this
        relay as the handler, and registers ``handle_periodic`` as a
        periodic callback.

        Args:
            loop: the event loop to attach to.

        Raises:
            Exception: if the relay is already attached to a loop, or has
                already been closed.
        """
        if self._eventloop:
            raise Exception('already add to loop')
        if self._closed:
            raise Exception('already closed')
        self._eventloop = loop
        watched = eventloop.POLL_IN | eventloop.POLL_ERR
        loop.add(self._server_socket, watched, self)
        loop.add_periodic(self.handle_periodic)

    def remove_handler(self, handler):
        """Drop ``handler`` from the timeout bookkeeping.

        The handler's slot in the timeout queue is blanked rather than
        deleted, and its index entry is removed. Unknown handlers are
        ignored.

        Args:
            handler: the handler to forget; looked up by ``hash(handler)``.
        """
        key = hash(handler)
        slot = self._handler_to_timeouts.get(key, -1)
        if slot < 0:
            return
        # delete is O(n), so we just set it to None
        self._timeouts[slot] = None
        del self._handler_to_timeouts[key]

    def update_activity(self, handler, data_len):#tcprelayhandler传进 默认datalen是0
        if data_len and self._stat_callback:#默认是none
            self._stat_callback(self._listen_port, data_len)#其他类调用这个类用的回调函数结构

        # set handler to active
        now = int(time.time())
        if now - handler.last_activity = 0:
            # delete is O(n), so we just set it to None
            self._timeouts[index] = None
        length = len(self._timeouts)
        self._timeouts.append(handler)
        self._handler_to_timeouts[hash(handler)] = length#handler_to_timeouts更新到最后 更新活动只有当大于超时时才更新

    def _sweep_timeout(self):#主要是tcprelayhandler的destroy方法 毁掉超时的连接套接字 dns是lruchache的清理
        # tornado's timeout memory management is more flexible than we need
        # we just need a sorted last_activity queue and it's faster than heapq
        # in fact we can do O(1) insertion/remove so we invent our own
        if self._timeouts:
            logging.log(shell.VERBOSE_LEVEL, 'sweeping timeouts')
            now = time.time()
            length = len(self._timeouts)
            pos = self._timeout_offset
            while pos  TIMEOUTS_CLEAN_SIZE and pos > length >> 1:#大于长度的一半
                # clean up the timeout queue when it gets larger than half
                # of the queue
                self._timeouts = self._timeouts[pos:]
                for key in self._handler_to_timeouts:#先将超时的移除了 在timeouts里只出现一次
                    #会调用本类的remove_handler方法  del self._handler_to_timeouts[hash(handler)] 将_handler_to_timeouts中的handler移除 在之后的遍历就没有了
                    #超时了 已经清除 如果未过半 _handler_to_timeouts里也没有这个handler了 _timeouts[index]里也释放了
                    self._handler_to_timeouts[key] -= pos
                pos = 0
            self._timeout_offset = pos

    def handle_event(self, sock, fd, event):
        """Handle an event from the loop and dispatch it to the right handler.

        Events on the listening socket accept a new connection and create a
        TCPRelayHandler for it (the handler registers itself in
        ``_fd_to_handlers``). Events on any other socket are forwarded to
        the handler registered for ``fd``; unknown fds are ignored.

        Args:
            sock: the socket the event belongs to, or a falsy value when no
                socket is known for ``fd``.
            fd: file descriptor the event was reported for.
            event: event mask from the event loop.

        Raises:
            Exception: if the listening socket reports POLL_ERR.
        """
        if not sock:
            # No socket is known for this fd. Returning here also avoids
            # treating a falsy ``sock`` as the listening socket once
            # ``_server_socket`` has been reset to None.
            # logging.warning replaces the deprecated logging.warn alias.
            logging.warning('poll removed fd')
            return
        logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
                    eventloop.EVENT_NAMES.get(event, event))
        if sock == self._server_socket:
            # a new incoming connection on the listening socket
            if event & eventloop.POLL_ERR:
                # TODO
                raise Exception('server_socket error')
            try:
                logging.debug('accept')
                conn = self._server_socket.accept()
                # conn[0] is the connected socket; the new handler keeps
                # itself alive through the shared fd-to-handler map
                TCPRelayHandler(self, self._fd_to_handlers,
                                self._eventloop, conn[0], self._config,
                                self._dns_resolver, self._is_local)
            except (OSError, IOError) as e:
                error_no = eventloop.errno_from_exception(e)
                if error_no in (errno.EAGAIN, errno.EINPROGRESS,
                                errno.EWOULDBLOCK):
                    # nothing to accept right now
                    return
                else:
                    shell.print_exception(e)
                    if self._config['verbose']:
                        traceback.print_exc()
        else:
            # an established connection: route to its registered handler
            handler = self._fd_to_handlers.get(fd, None)
            if handler:
                handler.handle_event(sock, event)

    def handle_periodic(self):
        """Periodic callback: finish a deferred close, then sweep timeouts.

        When the relay is marked closed, the listening socket (if still
        present) is removed from the loop and closed, and the loop is
        stopped once no connection handlers remain. The timeout sweep runs
        on every call.
        """
        if self._closed:
            listener = self._server_socket
            if listener:
                # stop watching the listening socket before closing it
                self._eventloop.remove(listener)
                listener.close()
                self._server_socket = None
                logging.info('closed TCP port %d', self._listen_port)
            if not self._fd_to_handlers:
                # every connection is gone, so the loop can stop
                logging.info('stopping')
                self._eventloop.stop()
        self._sweep_timeout()

    def close(self, next_tick=False):
        """Close the relay.

        Args:
            next_tick: when true, only mark the relay as closed;
                ``handle_periodic`` then closes the listening socket and
                stops the loop once all connections are gone. When false,
                tear everything down immediately: unregister from the loop,
                close the listening socket and destroy every handler.
        """
        logging.debug('TCP close')
        self._closed = True
        if next_tick:
            return
        # handle_periodic may already have closed the listening socket and
        # reset it to None after an earlier close(next_tick=True); guard so
        # an immediate close afterwards does not fail on None.
        listener = self._server_socket
        if self._eventloop:
            self._eventloop.remove_periodic(self.handle_periodic)
            if listener is not None:
                self._eventloop.remove(listener)
        if listener is not None:
            listener.close()
            self._server_socket = None
        # iterate over a copy: the handlers share this dict and may mutate it
        for handler in list(self._fd_to_handlers.values()):
            handler.destroy()


# (page footer left over from the source web page: "Post a comment")