From 4fcf28804c88ed090d96f737db7814cce44ac1fd Mon Sep 17 00:00:00 2001 From: MDK Date: Thu, 3 Aug 2023 15:25:17 +0800 Subject: [PATCH] standard commit --- cmd/cache.go | 37 +++++++++ cmd/domain.go | 25 ++++++ cmd/upstream.go | 2 - cmd/version.go | 20 +++++ prober/cache_prober.go | 66 ++++++++++++++++ prober/rdns_prober.go | 65 ++++++--------- prober/result_handler.go | 31 ++++++++ prober/version_prober.go | 1 + scheduler/scheduler.go | 68 ++++++++++++++++ utils/dns_utils.go | 166 +++++++++++++++++++++++++++++++-------- utils/input_utils.go | 32 ++++++++ utils/output_utils.go | 40 +++++++--- 12 files changed, 466 insertions(+), 87 deletions(-) create mode 100644 cmd/cache.go create mode 100644 cmd/domain.go create mode 100644 cmd/version.go create mode 100644 prober/cache_prober.go create mode 100644 prober/result_handler.go create mode 100644 prober/version_prober.go create mode 100644 scheduler/scheduler.go create mode 100644 utils/input_utils.go diff --git a/cmd/cache.go b/cmd/cache.go new file mode 100644 index 0000000..c72af16 --- /dev/null +++ b/cmd/cache.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "dtool/prober" + "dtool/scheduler" + "dtool/utils" + + "github.com/spf13/cobra" +) + +var query_cnt int +var inputfile string +var outputfile string +var cacheCmd = &cobra.Command{ + Use: "cache", + Short: "cache related test", + Long: "cache related test", + Run: cache_test, +} + +func cache_test(cmd *cobra.Command, args []string) { + if len(args) == 1 { + if utils.IsValidIP(args[0]) { + prober.RecursiveCacheTest(args[0], query_cnt) + } + } else { + scheduler.CreateTask(prober.RecursiveCacheProbe, inputfile, outputfile, 10) + } +} + +func init() { + cacheCmd.Flags().StringVarP(&inputfile, "input", "i", "", "input file(optional)") + cacheCmd.Flags().StringVarP(&outputfile, "output", "o", "", "output file(optional)") + cacheCmd.MarkFlagsRequiredTogether("input", "output") + cacheCmd.Flags().IntVarP(&query_cnt, "num", "n", 20, "number of queries in one test") + rootCmd.AddCommand(cacheCmd) +} diff --git a/cmd/domain.go b/cmd/domain.go new file mode 100644 index 0000000..85a7644 --- /dev/null +++ b/cmd/domain.go @@ -0,0 +1,25 @@ +package cmd + +import ( + "github.com/spf13/cobra" +) + +var input string +var output string +var domainCmd = &cobra.Command{ + Use: "domain", + Short: "query the ip and nameserver information", + Long: "query the ip and nameserver information", + Run: getDomainInfo, +} + +func getDomainInfo(cmd *cobra.Command, args []string) { + +} + +func init() { + domainCmd.Flags().StringVarP(&input, "input", "i", "", "") + domainCmd.Flags().StringVarP(&output, "output", "o", "", "") + domainCmd.MarkFlagsRequiredTogether("input", "output") + rootCmd.AddCommand(domainCmd) +} diff --git a/cmd/upstream.go b/cmd/upstream.go index 5c315de..c40528b 100644 --- a/cmd/upstream.go +++ b/cmd/upstream.go @@ -2,9 +2,7 @@ package cmd import ( "dtool/prober" - _ "dtool/prober" "dtool/utils" - _ "dtool/utils" "errors" "github.com/spf13/cobra" diff --git a/cmd/version.go b/cmd/version.go new file mode 100644 index 0000000..3c6712b --- /dev/null +++ b/cmd/version.go @@ -0,0 +1,20 @@ +package cmd + +import ( + "github.com/spf13/cobra" +) + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "get server version with version.bind", + Long: "get server version with version.bind chaos txt request", + Run: version, +} + +func version(cmd *cobra.Command, args []string) { + +} + +func init() { + rootCmd.AddCommand(versionCmd) +} diff --git a/prober/cache_prober.go b/prober/cache_prober.go new file mode 100644 index 0000000..b8dd9af --- /dev/null +++ b/prober/cache_prober.go @@ -0,0 +1,66 @@ +package prober + +import ( + "dtool/utils" + _ "fmt" + "strconv" + "strings" + "time" +) + +const query_num = 20 + +type CacheStruct struct { + target string + dict map[int]map[string]bool +} + +func RecursiveCacheProbe(ip string) CacheStruct { + data := CacheStruct{ip, make(map[int]map[string]bool)} + stop := 0 + time_now := strconv.FormatInt(time.Now().Unix(), 10) + for i := 0; i < query_num; i++ { + if stop >= 3 { + break + } + subdomain := strings.Join([]string{strings.Replace(ip, ".", "-", -1), "fwd", strconv.Itoa(i), time_now}, "-") + domain := subdomain + ".echodns.xyz." + res, err := utils.SendQuery(ip, domain) + if err != nil { + //fmt.Printf("Error sending query: %s\n", err) + stop += 1 + continue + } + cache_id, rdns, err := utils.ParseCNAMEChain(res) + if err != nil { + //fmt.Printf("Error parsing response: %s\n", err) + stop += 1 + continue + } + if data.dict[cache_id] == nil { + data.dict[cache_id] = make(map[string]bool) + } + data.dict[cache_id][rdns] = true + stop = 0 + } + return data +} + +func RecursiveCacheTest(ip string, num int) { + res := make(map[string]map[int][]string) + temp := make(map[int][]string) + data := RecursiveCacheProbe(ip) + if len(data.dict) > 0 { + for cache_id := range data.dict { + for rdns := range data.dict[cache_id] { + temp[cache_id] = append(temp[cache_id], rdns) + } + } + } + res[ip] = temp + utils.OutputJSON(res, "-", " ") +} + +func ClientCacheProbe(ip string) { + +} diff --git a/prober/rdns_prober.go b/prober/rdns_prober.go index 8fae0e3..4227b6a 100644 --- a/prober/rdns_prober.go +++ b/prober/rdns_prober.go @@ -1,56 +1,37 @@ package prober import ( - "bufio" - "fmt" - "io" - "os" "strconv" "strings" "sync" "time" "dtool/utils" + + "github.com/miekg/dns" ) -type Data struct { +type RecursiveStruct struct { target string dict map[string]bool } var dataset map[string][]string -func retrieve_ip(pool chan string, filename string) { - cnt := 0 - f, err := os.Open(filename) - if err != nil { - panic(err) - } - fmt.Println("sending msg ...") - reader := bufio.NewReader(f) - for { - s, err := reader.ReadString('\n') - if err == io.EOF { - break - } - s = strings.Trim(s, "\n") - pool <- s - cnt++ - if cnt%10 == 0 { - fmt.Println(cnt) - } - } - close(pool) -} - -func active_probe(n int, addr string) Data { - target_ip := addr[:len(addr)-3] - data := Data{target_ip, make(map[string]bool)} +func active_probe(n int, addr string) RecursiveStruct { + target_ip := addr + data := RecursiveStruct{target_ip, make(map[string]bool)} stop := 0 timestamp := strconv.FormatInt(time.Now().Unix(), 10) for i := 0; i < n; i++ { subdomain := strings.Join([]string{strings.Replace(target_ip, ".", "-", -1), "echo", strconv.Itoa(i), timestamp}, "-") - rdns_ip, err := utils.SendQuery(addr, subdomain) + domain := dns.Fqdn(subdomain + ".echodns.xyz") + res, err := utils.SendQuery(addr, domain) + if err != nil { + stop += 1 + continue + } + rdns_ip, err := utils.ParseAResponse(res) if err == nil { data.dict[rdns_ip] = true } else { @@ -63,11 +44,10 @@ func active_probe(n int, addr string) Data { return data } -func upstream_prober(ip_pool chan string, data_pool chan Data, wg *sync.WaitGroup) { +func upstream_prober(ip_pool chan string, data_pool chan RecursiveStruct, wg *sync.WaitGroup) { for { if s, ok := <-ip_pool; ok { - addr := s + ":53" - data := active_probe(20, addr) + data := active_probe(20, s) if data.dict != nil { data_pool <- data } @@ -78,14 +58,14 @@ func upstream_prober(ip_pool chan string, data_pool chan Data, wg *sync.WaitGrou wg.Done() } -func create_probers(num int, ip_pool chan string, data_pool chan Data, wg *sync.WaitGroup) { +func create_probers(num int, ip_pool chan string, data_pool chan RecursiveStruct, wg *sync.WaitGroup) { for i := 0; i < num; i++ { wg.Add(1) go upstream_prober(ip_pool, data_pool, wg) } } -func store_data(pool chan Data, wg *sync.WaitGroup) { +func store_data(pool chan RecursiveStruct, wg *sync.WaitGroup) { wg.Add(1) for { var temp []string @@ -106,27 +86,26 @@ func store_data(pool chan Data, wg *sync.WaitGroup) { func Get_upstream_file(filename string, output string, prober_num int) { dataset = map[string][]string{} ip_pool := make(chan string, 500) - data_pool := make(chan Data, 200) + data_pool := make(chan RecursiveStruct, 200) var probe_tasks sync.WaitGroup var store_task sync.WaitGroup - go retrieve_ip(ip_pool, filename) + go utils.RetrieveLines(ip_pool, filename) create_probers(prober_num, ip_pool, data_pool, &probe_tasks) go store_data(data_pool, &store_task) probe_tasks.Wait() close(data_pool) store_task.Wait() - utils.OutputJSON(dataset, output) + utils.OutputJSON(dataset, output, "") } func Get_upstream_ip(ip string) { dataset = make(map[string][]string) var temp []string - addr := ip + ":53" - data := active_probe(10, addr) + data := active_probe(10, ip) for rdns := range data.dict { temp = append(temp, rdns) } dataset[data.target] = temp - utils.OutputJSON(dataset, "-") + utils.OutputJSON(dataset, "-", " ") } diff --git a/prober/result_handler.go b/prober/result_handler.go new file mode 100644 index 0000000..35f85f7 --- /dev/null +++ b/prober/result_handler.go @@ -0,0 +1,31 @@ +package prober + +import "dtool/utils" + +func OutputHandler(data interface{}) (string, error) { + var output_str string + var err error + switch value := data.(type) { + case CacheStruct: + result := make(map[string]map[int][]string) + temp := make(map[int][]string) + if len(value.dict) > 0 { + for cache_id := range value.dict { + for rdns := range value.dict[cache_id] { + temp[cache_id] = append(temp[cache_id], rdns) + } + } + } + result[value.target] = temp + output_str, err = utils.ToJSON(result, "") + case RecursiveStruct: + result := make(map[string][]string) + temp := []string{} + for rdns := range value.dict { + temp = append(temp, rdns) + } + result[value.target] = temp + output_str, err = utils.ToJSON(result, "") + } + return output_str, err +} diff --git a/prober/version_prober.go b/prober/version_prober.go new file mode 100644 index 0000000..8219d21 --- /dev/null +++ b/prober/version_prober.go @@ -0,0 +1 @@ +package prober diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go new file mode 100644 index 0000000..496851c --- /dev/null +++ b/scheduler/scheduler.go @@ -0,0 +1,68 @@ +package scheduler + +import ( + "bufio" + "dtool/prober" + "dtool/utils" + "fmt" + "os" + "sync" +) + +//type ProbeTask func(string) interface{} + +func output_process(output chan interface{}, file string, wg *sync.WaitGroup) { + f, err := os.Create(file) + if err != nil { + panic(err) + } + defer f.Close() + writer := bufio.NewWriter(f) + for { + if data, ok := <-output; ok { + str, err := prober.OutputHandler(data) + if err != nil { + fmt.Printf("Error generating output: %s\n", err) + continue + } + _, err = writer.WriteString(str + "\n") + if err != nil { + fmt.Printf("Error writing file: %s\n", err) + } + } else { + break + } + } + writer.Flush() + wg.Done() +} + +func concurrent_execution[T any](fn func(string) T, input chan string, output chan interface{}, wg *sync.WaitGroup) { + for { + if ip, ok := <-input; ok { + data := fn(ip) + output <- data + } else { + break + } + } + wg.Done() +} + +func CreateTask[T any](fn func(string) T, input_file string, output_file string, concurrent_num int) { + input_pool := make(chan string, 500) + output_pool := make(chan interface{}, 500) + var probe_tasks sync.WaitGroup + var store_tasks sync.WaitGroup + + go utils.RetrieveLines(input_pool, input_file) + probe_tasks.Add(concurrent_num) + for i := 0; i < concurrent_num; i++ { + go concurrent_execution(fn, input_pool, output_pool, &probe_tasks) + } + store_tasks.Add(1) + go output_process(output_pool, output_file, &store_tasks) + probe_tasks.Wait() + close(output_pool) + store_tasks.Wait() +} diff --git a/utils/dns_utils.go b/utils/dns_utils.go index f6d4578..9b398ca 100644 --- a/utils/dns_utils.go +++ b/utils/dns_utils.go @@ -1,10 +1,11 @@ +// dns_utils.go contains the necessary functions for building +// and parsing a DNS packet package utils import ( + "encoding/binary" "fmt" - "strconv" "strings" - "time" "github.com/miekg/dns" ) @@ -13,43 +14,144 @@ type WrongAnswerError struct { Message string } +type QueryStruct struct { + Id uint16 + RD bool + Qname string + Qclass uint16 + Qtype uint16 +} + +type DomainInfo struct { + IPList []string + NSList map[string][]string +} + func (e *WrongAnswerError) Error() string { return fmt.Sprintf("Wrong Answer: %s", e.Message) } -func SendQuery(addr string, dn string) (string, error) { - var ( - domain string - rdns_ip string - ) - if dn == "timestamp" { - timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) - domain = strings.Join([]string{timestamp, "-scan.echodns.xyz."}, "") - } else { - domain = strings.Join([]string{dn, ".echodns.xyz."}, "") - } - //fmt.Println(domain) - m := new(dns.Msg) - m.SetQuestion(domain, dns.TypeA) - m.RecursionDesired = true +// build the question section of a dns packet +func QuestionMaker(domain string, qclass uint16, qtype uint16) *dns.Question { + return &dns.Question{Name: domain, Qtype: qtype, Qclass: qclass} +} - res, err := dns.Exchange(m, addr) - if err == nil { - if len(res.Answer) == 1 { - if a, ok := res.Answer[0].(*dns.A); ok { - rdns_ip = a.A.String() - } else { - rdns_ip = "" - err = &WrongAnswerError{ - Message: "Wrong Record Type", - } - } +// build a dns query message +func QueryMaker(query QueryStruct) *dns.Msg { + msg := new(dns.Msg) + if query.Id < 0 { + msg.Id = dns.Id() + } else { + msg.Id = query.Id + } + msg.RecursionDesired = query.RD + + var query_name string + var query_class, query_type uint16 + msg.Question = make([]dns.Question, 1) + query_name = dns.Fqdn(query.Qname) + + // default class INET + if query.Qclass == 0 { + query_class = 1 + } else { + query_class = query.Qclass + } + + // default type A + if query.Qtype == 0 { + query_type = 1 + } else { + query_type = query.Qtype + } + + question := QuestionMaker(query_name, query_class, query_type) + msg.Question[0] = *question + return msg +} + +func ParseAResponse(msg *dns.Msg) (string, error) { + var ip_str string + if len(msg.Answer) == 1 { + if a, ok := msg.Answer[0].(*dns.A); ok { + ip_str = a.A.String() } else { - rdns_ip = "" - err = &WrongAnswerError{ - Message: "Wrong Answer Section", + err := &WrongAnswerError{ + Message: "Wrong record type", + } + return ip_str, err + } + } else { + err := &WrongAnswerError{ + Message: "Wrong answer section", + } + return ip_str, err + } + return ip_str, nil +} + +func ParseCNAMEChain(msg *dns.Msg) (int, string, error) { + var cache_id int + var rdns_ip string + if len(msg.Answer) == 3 { + if cname, ok := msg.Answer[0].(*dns.CNAME); ok { + rdns_ip = strings.Join(strings.Split(strings.Split(cname.Target, ".")[0], "-")[1:], ".") + if a, ok := msg.Answer[2].(*dns.A); ok { + cache_id = int(binary.BigEndian.Uint32(a.A)) } } + } else { + err := &WrongAnswerError{ + Message: "Wrong record number", + } + return cache_id, rdns_ip, err } - return rdns_ip, err + return cache_id, rdns_ip, nil +} + +func ParseTXTResponse(msg *dns.Msg) (string, error) { + var txt_string string + if len(msg.Answer) == 1 { + if txt, ok := msg.Answer[0].(*dns.TXT); ok { + txt_string = txt.String() + } else { + err := &WrongAnswerError{ + Message: "Wrong record type", + } + return txt_string, err + } + } else { + err := &WrongAnswerError{ + Message: "wrong record number", + } + return txt_string, err + } + return txt_string, nil +} + +func SendQuery(ip string, domain string) (*dns.Msg, error) { + addr := ip + ":53" + query := new(QueryStruct) + query.Qname = domain + query.RD = true + query.Id = dns.Id() + m := QueryMaker(*query) + + res, err := dns.Exchange(m, addr) + + return res, err +} + +func SendVersionQuery(ip string) (*dns.Msg, error) { + addr := ip + "53" + query := new(QueryStruct) + query.Id = dns.Id() + query.Qname = "version.bind" + query.Qclass = dns.ClassCHAOS + query.Qtype = dns.TypeTXT + m := QueryMaker(*query) + + res, err := dns.Exchange(m, addr) + + return res, err } diff --git a/utils/input_utils.go b/utils/input_utils.go new file mode 100644 index 0000000..a968f07 --- /dev/null +++ b/utils/input_utils.go @@ -0,0 +1,32 @@ +package utils + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" +) + +func RetrieveLines(pool chan string, filename string) { + cnt := 0 + f, err := os.Open(filename) + if err != nil { + panic(err) + } + fmt.Println("reading file ...") + reader := bufio.NewReader(f) + for { + s, err := reader.ReadString('\n') + if err == io.EOF { + break + } + s = strings.Trim(s, "\n") + pool <- s + cnt++ + if cnt%10 == 0 { + fmt.Println(cnt) + } + } + close(pool) +} diff --git a/utils/output_utils.go b/utils/output_utils.go index d914494..396b626 100644 --- a/utils/output_utils.go +++ b/utils/output_utils.go @@ -3,22 +3,42 @@ package utils import ( "encoding/json" "fmt" - "io/ioutil" + "os" ) -func OutputJSON(data interface{}, filename string) error { - jsonstr, err := json.MarshalIndent(data, "", " ") - if err != nil { - fmt.Println("JSON encoding error:", err) - return err - } - if filename == "-" { - fmt.Println(string(jsonstr)) +func ToJSON(data interface{}, indent string) (string, error) { + var jsonbyte []byte + var err error + if indent == "" { + jsonbyte, err = json.Marshal(data) } else { - err := ioutil.WriteFile(filename, jsonstr, 0666) + jsonbyte, err = json.MarshalIndent(data, "", indent) + } + return string(jsonbyte), err +} + +func WriteFileLine(filename string, data string) error { + if filename == "-" { + fmt.Println(data) + } else { + err := os.WriteFile(filename, []byte(data), 0666) if err != nil { return err } } return nil } + +func OutputJSON(data interface{}, filename string, indent string) error { + str, err := ToJSON(data, indent) + if err != nil { + fmt.Printf("Error encoding JSON: %s", err) + return err + } + err = WriteFileLine(filename, str) + if err != nil { + fmt.Printf("Error writing to file: %s", err) + return err + } + return nil +}