# asyncdns.py


#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2014-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import absolute_import, division, print_function, \
    with_statement

import os
import socket
import struct
import re
import logging

from shadowsocks import common, lru_cache, eventloop, shell

CACHE_SWEEP_INTERVAL = 30

VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d\-_]{1,63}(? 63:
            return None
        results.append(common.chr(l))#一个字节
        results.append(label)
    results.append(b'\0')
    return b''.join(results)

def build_request(address, qtype):
    """Assemble a complete DNS query packet for *address*.

    Layout: a random 2-byte transaction ID, the remaining 10 header bytes,
    the label-encoded QNAME produced by build_address(), then QTYPE and
    QCLASS (always QCLASS_IN). Exactly one question is encoded.

    :param address: hostname as bytes, e.g. b'www.example.com'
    :param qtype: numeric query type to ask for
    :return: the query packet as bytes
    :raises TypeError: if build_address() returns None (it does so for a
        label longer than 63 bytes), because None cannot be concatenated.
    """
    txn_id = os.urandom(2)  # random ID; the server echoes it in the reply
    # Flags byte 1 = 0x01: QR=0 (query), OPCODE=0 (standard query), AA=0,
    # TC=0, RD=1 (recursion desired). Flags byte 2 = 0 (RA/RCODE unused in a
    # query). Then QDCOUNT=1, ANCOUNT=0, NSCOUNT=0, ARCOUNT=0.
    rest_of_header = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0)
    qname = build_address(address)
    question_tail = struct.pack('!HH', qtype, QCLASS_IN)  # 2 bytes each
    return txn_id + rest_of_header + qname + question_tail

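# Every query built by build_request() sets the RD (recursion desired) flag,
# i.e. it asks the upstream server to perform recursive resolution.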

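# Decode one record's RDATA: a printable address for A/AAAA records, a domain
# name for CNAME/NS records, and the raw bytes for any other type.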
def parse_ip(addrtype, data, length, offset):
    """Decode the RDATA of a single resource record.

    :param addrtype: the record's TYPE field
    :param data: the full DNS message (bytes)
    :param length: the record's RDLENGTH
    :param offset: position of the RDATA within *data*
    :return: a textual IPv4 address for QTYPE_A, a textual IPv6 address for
        QTYPE_AAAA, the decoded domain name (bytes) for QTYPE_CNAME/QTYPE_NS,
        and the undecoded RDATA bytes for every other type (e.g. SOA).
    """
    if addrtype == QTYPE_A:
        return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
    if addrtype == QTYPE_AAAA:
        return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
    if addrtype in (QTYPE_CNAME, QTYPE_NS):
        # RDATA is a (possibly compressed) domain name; only the name is kept,
        # the byte count reported by parse_name() is not needed here.
        _consumed, target = parse_name(data, offset)
        return target
    # Unsupported type: hand back the raw RDATA slice unparsed.
    return data[offset:offset + length]

# Upper bound on compression-pointer hops while decoding one name. Legitimate
# names need far fewer (a name has at most 127 labels); the cap stops pointer
# loops in a corrupt or hostile reply from spinning indefinitely.
_MAX_POINTER_HOPS = 128


def parse_name(data, offset):
    """Decode a (possibly compressed) domain name starting at *offset*.

    Implements RFC 1035 section 4.1.4 message compression: a length byte whose
    two high bits are set begins a 2-byte pointer whose low 14 bits give the
    position in the message where the rest of the name continues.

    :param data: the full DNS message (bytes)
    :param offset: position of the name's first length byte
    :return: ``(consumed, name)``. *consumed* is the number of bytes the name
        occupies at *offset* (through the terminating zero byte, or through
        the first pointer), so callers can skip past it. *name* is the labels
        joined with dots, e.g. b'www.google.com'.
    :raises ValueError: on a reserved label type (high bits 01 or 10), or
        when more than _MAX_POINTER_HOPS pointers are followed (pointer loop).
    :raises struct.error: if the name runs past the end of *data*.
    """
    labels = []
    pos = offset
    consumed = None  # fixed at the first pointer, or at the final zero byte
    hops = 0
    while True:
        # unpack_from yields an int on both Python 2 str and Python 3 bytes.
        length = struct.unpack_from('!B', data, pos)[0]
        if length == 0:
            if consumed is None:
                consumed = pos - offset + 1  # include the zero terminator
            return consumed, b'.'.join(labels)
        kind = length & 0xC0
        if kind == 0xC0:
            # Compression pointer; the in-place name ends with these 2 bytes.
            if consumed is None:
                consumed = pos - offset + 2
            hops += 1
            if hops > _MAX_POINTER_HOPS:
                raise ValueError('too many DNS name compression pointers')
            pos = struct.unpack_from('!H', data, pos)[0] & 0x3FFF
        elif kind:
            # 0x40 / 0x80 prefixes are reserved/extended label types; reading
            # them as plain lengths would silently misparse the message.
            raise ValueError('unsupported DNS label type 0x%02x' % kind)
        else:
            labels.append(data[pos + 1:pos + 1 + length])
            pos += 1 + length

# rfc1035
# record
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                                               |
#    /                                               /
#    /                      NAME                     /
#    |                                               |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      TYPE                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     CLASS                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      TTL                      |
#    |                                               |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                   RDLENGTH                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
#    /                     RDATA                     /
#    /                                               /
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False):
    """Parse one question entry or resource record starting at *offset*.

    :param data: the full DNS message (bytes)
    :param offset: position of the entry's NAME field
    :param question: truthy for a question-section entry (NAME, TYPE, CLASS
        only); falsy for a resource record (which adds TTL, RDLENGTH, RDATA)
    :return: ``(size, record)``. *size* is the number of bytes the entry
        occupies, so the caller can advance to the next one. *record* is
        ``(name, value, type, class, ttl)`` for a resource record, where
        *value* comes from parse_ip(). For a question it is
        ``(name, None, type, class, None, None)``.
    """
    nlen, name = parse_name(data, offset)
    fixed = offset + nlen  # first byte after the NAME field
    if question:
        # Questions carry only TYPE and CLASS after the name.
        record_type, record_class = struct.unpack(
            '!HH', data[fixed:fixed + 4]
        )
        return nlen + 4, (name, None, record_type, record_class, None, None)
    # TYPE(2) CLASS(2) TTL(4) RDLENGTH(2) = 10 fixed bytes before RDATA.
    record_type, record_class, record_ttl, record_rdlength = struct.unpack(
        '!HHIH', data[fixed:fixed + 10]
    )
    # RFC 2181 section 8: the TTL is unsigned, and a value with the most
    # significant bit set must be treated as zero (a signed decode would
    # produce a negative TTL instead).
    if record_ttl & 0x80000000:
        record_ttl = 0
    ip = parse_ip(record_type, data, record_rdlength, fixed + 10)
    return nlen + 10 + record_rdlength, \
        (name, ip, record_type, record_class, record_ttl)

def parse_header(data):
    """Decode the fixed 12-byte DNS message header.

    :param data: a DNS message (bytes); only the first 12 bytes are read
    :return: None if *data* is shorter than 12 bytes. Otherwise a 9-tuple
        ``(id, qr, tc, ra, rcode, qdcount, ancount, nscount, arcount)``.
        The flag fields are masked bit values, not booleans:
        - *qr* is 128 or 0
        - *tc* is 2 or 0
        - *ra* is 128 or 0
        - *rcode* is in 0-15
    """
    if len(data) < 12:
        return None
    (msg_id, flags_hi, flags_lo,
     qdcount, ancount, nscount, arcount) = struct.unpack('!HBBHHHH', data[:12])
    # flags_hi bits: QR(0x80) OPCODE(0x78) AA(0x04) TC(0x02) RD(0x01)
    # flags_lo bits: RA(0x80) ... RCODE(0x0F)
    qr = flags_hi & 128     # set in responses
    tc = flags_hi & 2       # message was truncated
    ra = flags_lo & 128     # server offers recursion
    rcode = flags_lo & 15   # response code (0 = no error)
    return (msg_id, qr, tc, ra, rcode, qdcount, ancount, nscount, arcount)

def parse_response(data):
    """Parse a raw DNS reply into a DNSResponse.

    The question, answer, authority and additional sections are walked in
    order; each parse_record() call reports how many bytes to advance. Only
    question and answer entries are kept. Authority and additional records
    are parsed and then discarded, but a malformed one still makes the whole
    parse fail.

    :param data: the raw DNS message bytes
    :return: a DNSResponse, or None in either of these cases:
        - *data* is shorter than 12 bytes
        - any parsing step raises; the exception is reported via
          shell.print_exception
        In the DNSResponse:
        - ``hostname`` is the first question's name
        - ``questions`` holds ``(None, type, class)`` per question
        - ``answers`` holds ``(value, type, class)`` per answer record
    """
    try:
        if len(data) < 12:
            return None
        header = parse_header(data)
        if not header:
            return None
        qdcount, ancount, nscount, arcount = header[5:9]

        cursor = 12  # first byte after the fixed header
        questions = []
        for _ in range(qdcount):
            size, entry = parse_record(data, cursor, True)
            cursor += size
            if entry:
                questions.append(entry)
        answers = []
        for _ in range(ancount):
            size, entry = parse_record(data, cursor)
            cursor += size
            if entry:
                answers.append(entry)
        # Authority + additional sections: advance past them, keep nothing.
        for _ in range(nscount + arcount):
            size, _entry = parse_record(data, cursor)
            cursor += size

        response = DNSResponse()
        if questions:
            response.hostname = questions[0][0]
        # Keep (value, type, class); the name and TTL are dropped.
        response.questions.extend((q[1], q[2], q[3]) for q in questions)
        response.answers.extend((a[1], a[2], a[3]) for a in answers)
        return response
    except Exception as e:
        shell.print_exception(e)
        return None

def is_valid_hostname(hostname):
    """Return True if *hostname* (bytes) is a syntactically valid DNS name.

    The name may be at most 255 bytes. A single trailing dot (a fully
    qualified name such as b'example.com.') is allowed. Every dot-separated
    label must match VALID_HOSTNAME. An empty name is rejected rather than
    raising.
    """
    if len(hostname) > 255:
        return False
    # Use endswith(): on Python 3, hostname[-1] is an int, so comparing it
    # with b'.' is always False and FQDNs would be rejected. Indexing also
    # raises IndexError on an empty name.
    if hostname.endswith(b'.'):
        hostname = hostname[:-1]
    return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.'))

class DNSResponse(object):
    """Decoded DNS reply, as populated by parse_response().

    Attributes:
        hostname: name from the first question entry, or None.
        questions: list of (addr, type, class) tuples, one per question
            (addr is None for questions).
        answers: list of (addr, type, class) tuples, one per answer record.
    """

    def __init__(self):
        self.hostname = None
        self.questions = []
        self.answers = []

    def __str__(self):
        # Renders as "<hostname>: <answers list>".
        answer_text = str(self.answers)
        return '%s: %s' % (self.hostname, answer_text)

# Lookup status markers. Their use is outside this view; presumably stored in
# DNSResolver._hostname_status to track which query type is in flight.
STATUS_FIRST, STATUS_SECOND = 0, 1

class DNSResolver(object):

    def __init__(self, server_list=None, prefer_ipv6=False):
        """Initialise resolver state.

        :param server_list: DNS server addresses to query. When None, the
            list is filled by self._parse_resolv(), which reads
            /etc/resolv.conf.
        :param prefer_ipv6: when True, QTYPE_AAAA is placed ahead of QTYPE_A
            in self._QTYPES; otherwise QTYPE_A comes first.
        """
        self._loop = None
        self._hosts = {}
        self._hostname_status = {}
        self._hostname_to_cb = {}
        self._cb_to_hostname = {}
        # Built with timeout=300 and no per-entry close callback; timeout
        # units are defined by shadowsocks.lru_cache (presumably seconds).
        self._cache = lru_cache.LRUCache(timeout=300)
        self._sock = None
        if server_list is not None:
            self._servers = server_list
        else:
            self._servers = None
            self._parse_resolv()
        self._QTYPES = [QTYPE_AAAA, QTYPE_A] if prefer_ipv6 \
            else [QTYPE_A, QTYPE_AAAA]
        self._parse_hosts()
        # TODO monitor hosts change and reload hosts
        # TODO parse /etc/gai.conf and follow its rules

    def _parse_resolv(self):
        self._servers = []
        try:
            with open('/etc/resolv.conf', 'rb') as f:#从/etc/resolv.conf找dns服务器
                content = f.readlines()
                for line in content:
                    line = line.strip()
                    if not (line and line.startswith(b'nameserver')):
                        continue

                    parts = line.split()
                    if len(parts) 

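# 发表评论 ("Post a comment"): blog-page footer captured with this listing;
# the source above is truncated at this point.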