finish upstream subcommand

This commit is contained in:
MDK
2023-06-26 15:58:04 +08:00
parent b59fbce07c
commit 38be845348
7 changed files with 59 additions and 16 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/test

View File

@@ -1,13 +1,17 @@
package cmd package cmd
import ( import (
"dtool/prober"
_ "dtool/prober" _ "dtool/prober"
"dtool/utils" "dtool/utils"
_ "dtool/utils"
"errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var filename string var filename string
var output_file string
var thread_num int var thread_num int
var upstreamCmd = &cobra.Command{ var upstreamCmd = &cobra.Command{
Use: "upstream", Use: "upstream",
@@ -17,16 +21,30 @@ var upstreamCmd = &cobra.Command{
input target can be added as an argument or as a file input target can be added as an argument or as a file
-f input file with limited ip addresses (limit=50) -f input file with limited ip addresses (limit=50)
-o output file default json type
-t number of goroutine`, -t number of goroutine`,
Args: cobra.ExactArgs(1), //Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: upstream,
//prober.Get_upstream_ip(args[0]) }
utils.SendTencentHttpdnsQuery()
}, func upstream(cmd *cobra.Command, args []string) {
if len(args) > 1 {
panic(errors.New("too many arguments!"))
} else if len(args) == 1 {
if utils.IsValidIP(args[0]) {
prober.Get_upstream_ip(args[0])
} else {
panic(errors.New("invalid ip address"))
}
} else if len(args) == 0 {
prober.Get_upstream_file(filename, output_file, thread_num)
}
} }
func init() { func init() {
upstreamCmd.Flags().StringVarP(&filename, "file", "f", "", "input filename") upstreamCmd.Flags().StringVarP(&filename, "file", "f", "", "input file(optional)")
upstreamCmd.Flags().StringVarP(&output_file, "output", "o", "", "output file(optional)")
upstreamCmd.MarkFlagsRequiredTogether("file", "output")
upstreamCmd.Flags().IntVarP(&thread_num, "threads", "t", 10, "number of concurrent threads") upstreamCmd.Flags().IntVarP(&thread_num, "threads", "t", 10, "number of concurrent threads")
rootCmd.AddCommand(upstreamCmd) rootCmd.AddCommand(upstreamCmd)
} }

6
go.mod
View File

@@ -2,11 +2,13 @@ module dtool
go 1.20 go 1.20
require github.com/spf13/cobra v1.7.0 require (
github.com/miekg/dns v1.1.54
github.com/spf13/cobra v1.7.0
)
require ( require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/miekg/dns v1.1.54 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/mod v0.7.0 // indirect golang.org/x/mod v0.7.0 // indirect
golang.org/x/net v0.2.0 // indirect golang.org/x/net v0.2.0 // indirect

1
go.sum
View File

@@ -12,6 +12,7 @@ golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/tools v0.3.0 h1:SrNbZl6ECOS1qFzgTdQfWXZM9XBkiA6tkFrH9YSTPHM= golang.org/x/tools v0.3.0 h1:SrNbZl6ECOS1qFzgTdQfWXZM9XBkiA6tkFrH9YSTPHM=

View File

@@ -24,8 +24,7 @@ func retrieve_ip(pool chan string, filename string) {
cnt := 0 cnt := 0
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
fmt.Printf("cannot open file %s\n", filename) panic(err)
return
} }
fmt.Println("sending msg ...") fmt.Println("sending msg ...")
reader := bufio.NewReader(f) reader := bufio.NewReader(f)
@@ -104,7 +103,7 @@ func store_data(pool chan Data, wg *sync.WaitGroup) {
wg.Done() wg.Done()
} }
func Get_upstream_file(filename string, prober_num int) { func Get_upstream_file(filename string, output string, prober_num int) {
dataset = map[string][]string{} dataset = map[string][]string{}
ip_pool := make(chan string, 500) ip_pool := make(chan string, 500)
data_pool := make(chan Data, 200) data_pool := make(chan Data, 200)
@@ -117,7 +116,7 @@ func Get_upstream_file(filename string, prober_num int) {
probe_tasks.Wait() probe_tasks.Wait()
close(data_pool) close(data_pool)
store_task.Wait() store_task.Wait()
utils.OutputJSON(dataset) utils.OutputJSON(dataset, output)
} }
func Get_upstream_ip(ip string) { func Get_upstream_ip(ip string) {
@@ -129,5 +128,5 @@ func Get_upstream_ip(ip string) {
temp = append(temp, rdns) temp = append(temp, rdns)
} }
dataset[data.target] = temp dataset[data.target] = temp
utils.OutputJSON(dataset) utils.OutputJSON(dataset, "-")
} }

13
utils/other_utils.go Normal file
View File

@@ -0,0 +1,13 @@
package utils
import (
"net"
)
func IsValidIP(ip string) bool {
res := net.ParseIP(ip)
if res == nil {
return false
}
return true
}

View File

@@ -3,13 +3,22 @@ package utils
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
) )
func OutputJSON(data map[string][]string) { func OutputJSON(data interface{}, filename string) error {
jsonstr, err := json.MarshalIndent(data, "", " ") jsonstr, err := json.MarshalIndent(data, "", " ")
if err != nil { if err != nil {
fmt.Println("JSON encoding error:", err) fmt.Println("JSON encoding error:", err)
return return err
} }
fmt.Println(string(jsonstr)) if filename == "-" {
fmt.Println(string(jsonstr))
} else {
err := ioutil.WriteFile(filename, jsonstr, 0666)
if err != nil {
return err
}
}
return nil
} }