diff --git a/cmd/domain.go b/cmd/domain.go deleted file mode 100644 index 85a7644..0000000 --- a/cmd/domain.go +++ /dev/null @@ -1,25 +0,0 @@ -package cmd - -import ( - "github.com/spf13/cobra" -) - -var input string -var output string -var domainCmd = &cobra.Command{ - Use: "domain", - Short: "query the ip and nameserver information", - Long: "query the ip and nameserver information", - Run: getDomainInfo, -} - -func getDomainInfo(cmd *cobra.Command, args []string) { - -} - -func init() { - domainCmd.Flags().StringVarP(&input, "input", "i", "", "") - domainCmd.Flags().StringVarP(&output, "output", "o", "", "") - domainCmd.MarkFlagsRequiredTogether("input", "output") - rootCmd.AddCommand(domainCmd) -} diff --git a/cmd/query.go b/cmd/query.go new file mode 100644 index 0000000..5838bb1 --- /dev/null +++ b/cmd/query.go @@ -0,0 +1,106 @@ +package cmd + +import ( + "dtool/utils" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/spf13/cobra" +) + +// input parameters +// var query_target string +var query_port string +var query_name string +var query_type uint16 +var query_class uint16 +var query_rate int +var query_packets int +var query_random_prefix bool + +/* +DNS请求测试功能,可设置参数包括域名(随机前缀),资源记录type/class,请求速率,请求数量 +*/ +var queryCmd = &cobra.Command{ + Use: "query", + Short: "", + Long: "", + Run: queryTest, +} + +func concurrent_request(ip, domain, port string, qtype, qclass uint16, result chan int, wg *sync.WaitGroup) { + wg.Add(1) + res, err := utils.SendQuery(ip, domain, port, qtype, qclass) + if err != nil { + result <- (-1) + //fmt.Println(err) + } else { + rcode := utils.ParseRcode(res) + result <- rcode + } + wg.Done() +} + +func response_stat(stat []int, result chan int, wg *sync.WaitGroup) { + wg.Add(1) + defer wg.Done() + for { + if v, ok := <-result; ok { + switch v { + case 0: + stat[0]++ + case 2: + stat[1]++ + case 3: + stat[2]++ + case -1: + stat[3]++ + default: + stat[4]++ + } + } else { + break + } + } +} + +func queryTest(cmd *cobra.Command, args []string) { + var wg1 sync.WaitGroup + var wg2 sync.WaitGroup + result := make(chan int) + stat := make([]int, 5) + if len(args) == 1 { + if utils.IsValidIP(args[0]) { + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + go response_stat(stat, result, &wg2) + for i := 0; i < query_packets; i++ { + var qname string + if query_random_prefix { + qname = strings.Join([]string{timestamp, strconv.Itoa(i)}, "-") + query_name + } else { + qname = query_name + } + go concurrent_request(args[0], qname, query_port, query_type, query_class, result, &wg1) + time.Sleep(time.Second / time.Duration(query_rate)) + } + wg1.Wait() + close(result) + wg2.Wait() + fmt.Printf("total: %v noerr: %v servfail: %v nx: %v error: %v other: %v\n", query_packets, stat[0], stat[1], stat[2], stat[3], stat[4]) + } + } +} + +func init() { + queryCmd.Flags().BoolVarP(&query_random_prefix, "random", "R", false, "random query name") + queryCmd.Flags().StringVarP(&query_port, "port", "p", "53", "query target port") + queryCmd.Flags().StringVarP(&query_name, "name", "d", "example.com", "query name in DNS requests") + queryCmd.Flags().Uint16VarP(&query_type, "type", "t", 1, "target record type") + queryCmd.Flags().Uint16VarP(&query_class, "class", "c", 1, "target record class") + queryCmd.Flags().IntVarP(&query_rate, "rate", "r", 10, "request rate (packets per second)") + queryCmd.Flags().IntVarP(&query_packets, "packets", "n", 100, "total request number") + rootCmd.AddCommand(queryCmd) +} diff --git a/dtool b/dtool index 3f38d77..f6479fc 100755 Binary files a/dtool and b/dtool differ diff --git a/prober/cache_prober.go b/prober/cache_prober.go index 94adeb0..7b7890d 100644 --- a/prober/cache_prober.go +++ b/prober/cache_prober.go @@ -27,7 +27,7 @@ func RecursiveCacheProbe(ip string, sld string) CacheStruct { } subdomain := strings.Join([]string{strings.Replace(ip, ".", "-", -1), "fwd", strconv.Itoa(i), time_now}, "-") domain := dns.Fqdn(subdomain + "." + sld) - res, err := utils.SendQuery(ip, domain) + res, err := utils.SendAQuery(ip, domain) if err != nil { //fmt.Printf("Error sending query: %s\n", err) stop += 1 diff --git a/prober/rdns_prober.go b/prober/rdns_prober.go index 4227b6a..cc36233 100644 --- a/prober/rdns_prober.go +++ b/prober/rdns_prober.go @@ -26,7 +26,7 @@ func active_probe(n int, addr string) RecursiveStruct { for i := 0; i < n; i++ { subdomain := strings.Join([]string{strings.Replace(target_ip, ".", "-", -1), "echo", strconv.Itoa(i), timestamp}, "-") domain := dns.Fqdn(subdomain + ".echodns.xyz") - res, err := utils.SendQuery(addr, domain) + res, err := utils.SendAQuery(addr, domain) if err != nil { stop += 1 continue diff --git a/utils/dns_utils.go b/utils/dns_utils.go index 641968a..d89926e 100644 --- a/utils/dns_utils.go +++ b/utils/dns_utils.go @@ -75,6 +75,10 @@ func QueryMaker(query QueryStruct) *dns.Msg { return msg } +func ParseRcode(msg *dns.Msg) int { + return msg.MsgHdr.Rcode +} + func ParseAResponse(msg *dns.Msg) (string, error) { var ip_str string if len(msg.Answer) == 1 { @@ -134,7 +138,22 @@ func ParseTXTResponse(msg *dns.Msg) (string, error) { return txt_string, nil } -func SendQuery(ip string, domain string) (*dns.Msg, error) { +func SendQuery(ip, domain, port string, qtype, qclass uint16) (*dns.Msg, error) { + addr := ip + ":" + port + query := new(QueryStruct) + query.Qname = domain + query.RD = true + query.Qtype = qtype + query.Qclass = uint16(qclass) + query.Id = dns.Id() + m := QueryMaker(*query) + + res, err := dns.Exchange(m, addr) + + return res, err +} + +func SendAQuery(ip string, domain string) (*dns.Msg, error) { addr := ip + ":53" query := new(QueryStruct) query.Qname = domain