diff --git a/device.go b/device.go new file mode 100644 index 0000000..5579eba --- /dev/null +++ b/device.go @@ -0,0 +1,171 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "strconv" + "strings" + + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +const ( + // SysfsDevices = "/sys/bus/pci/devices" + SysfsDevices = "/root/demo" + MgmtPrefix = "/dev/xclmgmt" + UserPrefix = "/dev/dri" + QdmaPrefix = "/dev/xfpga" + QDMASTR = "dma.qdma.u" + UserPFKeyword = "drm" + DRMSTR = "renderD" + ROMSTR = "rom" + SNSTR = "xmc.u." + DSAverFile = "VBNV" + DSAtsFile = "timestamp" + InstanceFile = "instance" + MgmtFile = "mgmt_pf" + UserFile = "user_pf" + VendorFile = "vendor" + DeviceFile = "device" + SNFile = "serial_num" + VtShell = "xilinx_u30" + U30CommonShell = "ama_u30" + XilinxVendorID = "0x10ee" + ADVANTECH_ID = "0x13fe" + AWS_ID = "0x1d0f" + AristaVendorID = "0x3475" +) + +type Pairs struct { + Mgmt string + User string + Qdma string +} + +type Device struct { + index string + deviceID string //devid of the user pf + Healthy string + SN string +} + +func GetInstance(DBDF string) (string, error) { + strArray := strings.Split(DBDF, ":") + domain, err := strconv.ParseUint(strArray[0], 16, 16) + if err != nil { + return "", fmt.Errorf("strconv failed: %s\n", strArray[0]) + } + bus, err := strconv.ParseUint(strArray[1], 16, 8) + if err != nil { + return "", fmt.Errorf("strconv failed: %s\n", strArray[1]) + } + strArray = strings.Split(strArray[2], ".") + dev, err := strconv.ParseUint(strArray[0], 16, 8) + if err != nil { + return "", fmt.Errorf("strconv failed: %s\n", strArray[0]) + } + fc, err := strconv.ParseUint(strArray[1], 16, 8) + if err != nil { + return "", fmt.Errorf("strconv failed: %s\n", strArray[1]) + } + ret := domain*65536 + bus*256 + dev*8 + fc + return strconv.FormatUint(ret, 10), nil +} + +func GetFileNameFromPrefix(dir string, prefix string) (string, error) { + userFiles, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("Can't read folder %s", dir) + } + for _, userFile := range userFiles { + fname := userFile.Name() + + if !strings.HasPrefix(fname, prefix) { + continue + } + return fname, nil + } + return "", nil +} + +func GetFileContent(file string) (string, error) { + if buf, err := ioutil.ReadFile(file); err != nil { + return "", fmt.Errorf("Can't read file %s", file) + } else { + ret := strings.Trim(string(buf), "\n") + return ret, nil + } +} + +func FileExist(fname string) bool { + if _, err := os.Stat(fname); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +func IsMgmtPf(pciID string) bool { + fname := path.Join(SysfsDevices, pciID, MgmtFile) + return FileExist(fname) +} + +func IsUserPf(pciID string) bool { + fname := path.Join(SysfsDevices, pciID, UserFile) + return FileExist(fname) +} + +func GetDevices() ([]Device, error) { + var devices []Device + pciFiles, err := ioutil.ReadDir(SysfsDevices) + if err != nil { + return nil, fmt.Errorf("Can't read folder %s", SysfsDevices) + } + + for _, pciFile := range pciFiles { + pciID := pciFile.Name() + // get device id + fname := path.Join(SysfsDevices, pciID, DeviceFile) + content, err := GetFileContent(fname) + if err != nil { + return nil, err + } + devid := content + fname = path.Join(SysfsDevices, pciID, "SN") + content, err = GetFileContent(fname) + if err != nil { + return nil, err + } + sn := content + healthy := pluginapi.Healthy + devices = append(devices, Device{ + index: strconv.Itoa(len(devices) + 1), + deviceID: devid, + Healthy: healthy, + SN: sn, + }) + } + return devices, nil +} + +// func main() { +// devices, err := GetDevices() +// if err != nil { +// fmt.Printf("%s !!!\n", err) +// return +// } +// //SNFolder, err := GetFileNameFromPrefix(path.Join(SysfsDevices, "0000:e3:00.1"), SNSTR) +// //fname := path.Join(SysfsDevices, "0000:e3:00.1", SNFolder, SNFile) +// //content, err := GetFileContent(fname) +// //SN := content +// //fmt.Printf("SN: %v \n", SN) +// for _, device := range devices { +// fmt.Printf("Device: %v \n", device) +// fmt.Printf("ID: %s \n\n", device.deviceID) +// fmt.Printf("SN: %s \n\n", device.SN) +// fmt.Printf("Heathy: %s \n\n", device.Healthy) +// } +// } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a790a22 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module deviceplugindemo + +go 1.20 + +require ( + github.com/fsnotify/fsnotify v1.6.0 + github.com/sirupsen/logrus v1.9.0 + golang.org/x/net v0.8.0 + google.golang.org/grpc v1.53.0 + k8s.io/kubelet v0.26.2 +) + +require ( + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.5.2 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect + google.golang.org/protobuf v1.28.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8e3500e --- /dev/null +++ b/go.sum @@ -0,0 +1,69 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w= +google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= +google.golang.org/grpc v1.53.0 h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc= +google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/kubelet v0.26.2 h1:egg7YfhCpH9wvLwQdL2Mzuy4/kC6hO91azY0jgdYPWA= +k8s.io/kubelet v0.26.2/go.mod h1:IXthU5hcJQE6+K33LuaYYO0wUcYO8glhl/ip1Hzux44= diff --git a/main.go b/main.go new file mode 100644 index 0000000..1756578 --- /dev/null +++ b/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "flag" + "os" + "syscall" + + "github.com/fsnotify/fsnotify" + log "github.com/sirupsen/logrus" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +func main() { + // Parse command-line arguments + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flagLogLevel := flag.String("log-level", "info", "Define the logging level: error, info, debug.") + flag.Parse() + + switch *flagLogLevel { + case "debug": + log.SetLevel(log.DebugLevel) + case "info": + log.SetLevel(log.InfoLevel) + } + + log.Println("Starting FS watcher.") + watcher, err := newFSWatcher(pluginapi.DevicePluginPath) + if err != nil { + log.Printf("Failed to created FS watcher: %s.", err) + os.Exit(1) + } + defer watcher.Close() + + log.Println("Starting OS watcher.") + sigs := newOSWatcher(syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + restart := true + var devicePlugin *FPGADevicePlugin +L: + for { + if restart { + devicePlugin = NewFPGADevicePlugin() + restart = false + } + + select { + case update := <-devicePlugin.updateChan: + devicePlugin.checkDeviceUpdate(update) + + case event := <-watcher.Events: + if event.Name == pluginapi.KubeletSocket && event.Op&fsnotify.Create == fsnotify.Create { + log.Printf("inotify: %s created, restarting.", pluginapi.KubeletSocket) + restart = true + } + + case err := <-watcher.Errors: + log.Printf("inotify: %s", err) + + case s := <-sigs: + switch s { + case syscall.SIGHUP: + log.Println("Received SIGHUP, restarting.") + restart = true + default: + log.Printf("Received signal \"%v\", shutting down.", s) + break L + } + } + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..e2d9c75 --- /dev/null +++ b/server.go @@ -0,0 +1,385 @@ +package main + +import ( + "fmt" + "net" + "os" + "path" + "reflect" + _ "runtime/debug" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/context" + "google.golang.org/grpc" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +const ( + resourceNamePrefix = "deviceplugindemo/" + serverSockPath = pluginapi.DevicePluginPath + // AWS_SN = "F1-Node" +) + +// FPGADevicePluginServer implements the Kubernetes device plugin API +type FPGADevicePluginServer struct { + devType string + devices map[string]Device + socket string + stop chan interface{} + update chan map[string]Device + + server *grpc.Server +} + +type FPGADevicePlugin struct { + devices map[string]map[string]Device + servers map[string]*FPGADevicePluginServer + updateChan chan map[string]map[string]Device +} + +// NewFPGADevicePlugin returns an initialized FPGADevicePlugin +func NewFPGADevicePlugin() *FPGADevicePlugin { + log.Debugf("create FPGA device plugin") + updateChan := make(chan map[string]map[string]Device) + plugin := FPGADevicePlugin{ + devices: make(map[string]map[string]Device), + servers: make(map[string]*FPGADevicePluginServer), + updateChan: updateChan, + } + + go func() { + for { + devices, err := GetDevices() + if err != nil { + time.Sleep(75 * time.Second) + devices, err = GetDevices() + if err != nil { + log.Errorf("Error to get FPGA devices: %v", err) + break + } + } + devMap := make(map[string]map[string]Device) + for _, device := range devices { + + DSAtype := device.index + id := device.deviceID + if subMap, ok := devMap[DSAtype]; ok { + subMap = devMap[DSAtype] + subMap[id] = device + } else { + subMap = make(map[string]Device) + devMap[DSAtype] = subMap + subMap[id] = device + } + } + updateChan <- devMap + time.Sleep(5 * time.Second) + } + close(updateChan) + }() + + return &plugin +} + +func (m *FPGADevicePlugin) checkDeviceUpdate(n map[string]map[string]Device) { + added := make(map[string]map[string]Device) + updated := make(map[string]map[string]Device) + removed := make(map[string]map[string]Device) + + for oDevType, oDevices := range m.devices { + if nDevices, ok := n[oDevType]; ok { + if !reflect.DeepEqual(oDevices, nDevices) { + updated[oDevType] = nDevices + } + delete(n, oDevType) + } else { + removed[oDevType] = oDevices + } + } + for nDevType, nDevices := range n { + added[nDevType] = nDevices + } + + //create new server for added devices + for aDevType, aDevices := range added { + devicePluginServer := m.NewFPGADevicePluginServer(aDevType, aDevices) + m.devices[aDevType] = aDevices + m.servers[aDevType] = devicePluginServer + go func(aDevType string, aDevices map[string]Device, name string) { + if err := m.servers[aDevType].Serve(name); err != nil { + log.Println("Could not contact Kubelet, Exit. Did you enable the device plugin feature gate?") + os.Exit(1) + } + m.servers[aDevType].update <- aDevices + }(aDevType, aDevices, resourceNamePrefix+aDevType) + } + + //stop server for removed devices + for rDevType, rDevices := range removed { + log.Debugf("Remove device %v", rDevices) + m.servers[rDevType].Stop() + delete(m.servers, rDevType) + delete(m.devices, rDevType) + } + + //send update for updated devices + for uDevType, uDevices := range updated { + m.devices[uDevType] = uDevices + m.servers[uDevType].update <- uDevices + } +} + +// NewFPGADevicePluginServer returns an initialized FPGADevicePluginServer +func (m *FPGADevicePlugin) NewFPGADevicePluginServer(devType string, devices map[string]Device) *FPGADevicePluginServer { + return &FPGADevicePluginServer{ + devType: devType, + devices: devices, + socket: path.Join(serverSockPath, devType+"-demodevice.sock"), + stop: make(chan interface{}), + update: make(chan map[string]Device, 1), + } +} + +// waitForServer checks if grpc server is alive +// by making grpc blocking connection to the server socket +func waitForServer(socket string, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + conn, err := grpc.DialContext(ctx, socket, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + }), + ) + if conn != nil { + conn.Close() + } + + if err != nil { + fmt.Errorf("Failed dial context at %s", socket) + return err + } + return nil +} + +func (m *FPGADevicePluginServer) deviceExists(id string) bool { + for k, _ := range m.devices { + if k == id { + return true + } + } + return false +} + +func (m *FPGADevicePluginServer) PreStartContainer(ctx context.Context, rqt *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { + return nil, fmt.Errorf("PreStartContainer() should not be called") +} + +func (m *FPGADevicePluginServer) GetDevicePluginOptions(ctx context.Context, empty *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { + fmt.Println("GetDevicePluginOptions: return empty options") + return new(pluginapi.DevicePluginOptions), nil +} + +// Start starts the gRPC server of the device plugin +func (m *FPGADevicePluginServer) Start() error { + err := m.cleanup() + if err != nil { + return err + } + + sock, err := net.Listen("unix", m.socket) + if err != nil { + return err + } + + m.server = grpc.NewServer() + pluginapi.RegisterDevicePluginServer(m.server, m) + + go m.server.Serve(sock) + + // Wait for the server to start + if err = waitForServer(m.socket, 10*time.Second); err != nil { + return err + } + + return nil +} + +// Stop stops the gRPC server +func (m *FPGADevicePluginServer) Stop() error { + if m.server == nil { + return nil + } + + m.server.Stop() + m.server = nil + close(m.stop) + close(m.update) + + return m.cleanup() +} + +func (m *FPGADevicePluginServer) cleanup() error { + if err := os.Remove(m.socket); err != nil && !os.IsNotExist(err) { + return err + } + + return nil +} + +// Register registers the device plugin for the given resourceName with Kubelet. +func (m *FPGADevicePluginServer) Register(kubeletEndpoint, resourceName string) error { + conn, err := grpc.Dial(kubeletEndpoint, grpc.WithInsecure(), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + })) + + if err != nil { + log.Debugf("Cann't connect to kubelet service") + return err + } + defer conn.Close() + + client := pluginapi.NewRegistrationClient(conn) + reqt := &pluginapi.RegisterRequest{ + Version: pluginapi.Version, + Endpoint: path.Base(m.socket), + ResourceName: resourceName, + } + + _, err = client.Register(context.Background(), reqt) + if err != nil { + log.Debugf("Cann't register to kubelet service") + return err + } + return nil +} + +// func IsContain(items []string, item string) bool { +// AWS_SN := "F1-Node" +// for _, eachItem := range items { +// if eachItem == item && strings.EqualFold(item, AWS_SN) != true { +// return true +// } +// } +// return false +// } +func (m *FPGADevicePluginServer) sendDevices(s pluginapi.DevicePlugin_ListAndWatchServer) error { + resp := new(pluginapi.ListAndWatchResponse) + + check_range := m.devices + SerialNums := []string{} + for _, device := range check_range { + if device.SN == "" { + log.Printf("Error, Device %v has empty Serial number", device.deviceID) + } else { + SerialNums = append(SerialNums, device.SN) + tem := &pluginapi.Device{ + ID: device.deviceID, + Health: device.Healthy, + } + resp.Devices = append(resp.Devices, tem) + } + } + log.Printf("Check SeialNums arry: %v", SerialNums) + log.Printf("Sending %d device(s) %v to kubelet", len(resp.Devices), resp.Devices) + if err := s.Send(resp); err != nil { + m.Stop() + log.Debugf("Cannot update device list") + return err + } + return nil +} + +// ListAndWatch lists devices and update that list according to the health status +func (m *FPGADevicePluginServer) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error { + log.Debugf("In ListAndWatch(%s): stream: %v", m.devType, s) + //debug.PrintStack() + for m.devices = range m.update { + if err := m.sendDevices(s); err != nil { + return err + } + } + return nil +} + +// Allocate which return list of devices. +func (m *FPGADevicePluginServer) Allocate(ctx context.Context, req *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { + log.Debugf("In Allocate()") + response := new(pluginapi.AllocateResponse) + for _, creq := range req.ContainerRequests { + log.Debugf("Request IDs: %v", creq.DevicesIDs) + + cres := new(pluginapi.ContainerAllocateResponse) + + // Check same serial number devices, devices with same serail number "F1-node" will be marked as independent devices + deviceIDs_arry := creq.DevicesIDs + + for _, id := range deviceIDs_arry { + log.Printf("Receiving request %s", id) + dev, ok := m.devices[id] + if !ok { + return nil, fmt.Errorf("Invalid allocation request with non-existing device %s", id) + } + if !m.deviceExists(id) { + return nil, fmt.Errorf("invalid allocation request: unknown device: %s", id) + } + fname := path.Join(SysfsDevices, dev.deviceID, DeviceFile) + cres.Mounts = append(cres.Mounts, &pluginapi.Mount{ + HostPath: fname, + ContainerPath: fname, + ReadOnly: false, + }) + response.ContainerResponses = append(response.ContainerResponses, cres) + } + } + return response, nil +} + +// Serve starts the gRPC server and register the device plugin to Kubelet +func (m *FPGADevicePluginServer) Serve(resourceName string) error { + log.Debugf("In Serve(%s)", m.socket) + err := m.Start() + if err != nil { + log.Errorf("Could not start device plugin: %v", err) + return err + } + log.Infof("Starting to serve on %s", m.socket) + + err = m.Register(pluginapi.KubeletSocket, resourceName) + if err != nil { + log.Errorf("Could not register device plugin: %v", err) + m.Stop() + return err + } + log.Infof("Registered device plugin with Kubelet %s", resourceName) + + return nil +} + +func (m *FPGADevicePluginServer) GetPreferredAllocation(ctx context.Context, req *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { + response := new(pluginapi.PreferredAllocationResponse) + for _, creq := range req.ContainerRequests { + log.Debugf("Request IDs: %v", creq.AvailableDeviceIDs) + + cres := new(pluginapi.ContainerPreferredAllocationResponse) + + // Check same serial number devices, devices with same serail number "F1-node" will be marked as independent devices + deviceIDs_arry := creq.AvailableDeviceIDs + + for _, id := range deviceIDs_arry { + log.Printf("Receiving request %s", id) + dev, ok := m.devices[id] + if !ok { + return nil, fmt.Errorf("Invalid allocation request with non-existing device %s", id) + } + if !m.deviceExists(id) { + return nil, fmt.Errorf("invalid allocation request: unknown device: %s", id) + } + fname := path.Join(SysfsDevices, dev.deviceID, DeviceFile) + cres.DeviceIDs = append(cres.DeviceIDs, fname) + response.ContainerResponses = append(response.ContainerResponses, cres) + } + } + return response, nil +} diff --git a/watch.go b/watch.go new file mode 100644 index 0000000..53a0125 --- /dev/null +++ b/watch.go @@ -0,0 +1,32 @@ +package main + +import ( + "os" + "os/signal" + + "github.com/fsnotify/fsnotify" +) + +func newFSWatcher(files ...string) (*fsnotify.Watcher, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + for _, f := range files { + err = watcher.Add(f) + if err != nil { + watcher.Close() + return nil, err + } + } + + return watcher, nil +} + +func newOSWatcher(sigs ...os.Signal) chan os.Signal { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, sigs...) + + return sigChan +}