mirror of
https://github.com/RoganDawes/P4wnP1_aloa.git
synced 2025-03-26 01:21:44 +01:00
179 lines
3.6 KiB
Go
179 lines
3.6 KiB
Go
package mnetlink
|
|
|
|
import (
|
|
"fmt"
|
|
"errors"
|
|
"golang.org/x/sys/unix"
|
|
"math/rand"
|
|
"os"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
type Client struct {
|
|
Family int
|
|
sock_fd int
|
|
seq uint32
|
|
pid uint32
|
|
}
|
|
|
|
func NewNl(family int) (res *Client, err error) {
|
|
res = &Client{
|
|
Family: family,
|
|
}
|
|
|
|
// random start seq
|
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
res.seq = r.Uint32()
|
|
// assign current PID as PortID
|
|
res.pid = uint32(os.Getpid())
|
|
|
|
|
|
return res,nil
|
|
}
|
|
|
|
func (c *Client) incSeq() {
|
|
atomic.AddUint32(&c.seq,1)
|
|
}
|
|
|
|
func (c *Client) Open() (err error) {
|
|
// if family is 0, choose NETLINK_USERSOCK by default
|
|
if c.Family == 0 { c.Family = unix.NETLINK_USERSOCK }
|
|
|
|
c.sock_fd,err = unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, c.Family)
|
|
if err != nil { return }
|
|
|
|
|
|
// bind socket
|
|
err = unix.Bind(c.sock_fd, &unix.SockaddrNetlink{
|
|
Family: unix.AF_NETLINK,
|
|
Groups: 0,
|
|
Pid: uint32(os.Getpid()),
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Client) Close() (err error) {
|
|
unix.Close(c.sock_fd)
|
|
return
|
|
}
|
|
|
|
func (c *Client) AddGroupMembership(groupid int) (err error) {
|
|
err = syscall.SetsockoptInt(c.sock_fd, unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, groupid)
|
|
return
|
|
}
|
|
|
|
func (c *Client) DropGroupMembership(groupid int) (err error) {
|
|
err = syscall.SetsockoptInt(c.sock_fd, unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, groupid)
|
|
return
|
|
}
|
|
|
|
func (c *Client) Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) (err error) {
|
|
err = unix.Sendmsg(c.sock_fd, p, oob, to, flags)
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Client) Send(msg Message) (err error) {
|
|
// adjust seq
|
|
msg.Seq = c.seq
|
|
msg.Pid = c.pid
|
|
|
|
|
|
raw,err := msg.MarshalBinary()
|
|
if err != nil { return }
|
|
|
|
addr := &unix.SockaddrNetlink{
|
|
Family: unix.AF_NETLINK,
|
|
}
|
|
|
|
//fmt.Printf("Sending raw2:\n%+v\n", hex.Dump(raw))
|
|
err = c.Sendmsg(raw, nil, addr, 0)
|
|
if err == nil { c.incSeq() }
|
|
return err
|
|
}
|
|
|
|
func (c *Client) Receive() (msgs []Message, err error) {
|
|
for {
|
|
//fmt.Println("Reading")
|
|
raw_in := c.Read()
|
|
for len(raw_in) > unix.NLMSG_HDRLEN {
|
|
msg := Message{}
|
|
msg.UnmarshalBinary(raw_in)
|
|
|
|
if msg.IsTypeError() {
|
|
return nil,errors.New(fmt.Sprintf("Error response: %+v %+v", msg.GetData(), msg.GetErrNo()))
|
|
}
|
|
|
|
msgs = append(msgs, msg)
|
|
//fmt.Printf("Received raw: \n%v\n", hex.Dump(raw_in))
|
|
//fmt.Printf("Received %+v\nData:\n%+v\n", msg, hex.Dump(msg.data))
|
|
|
|
// check if last message
|
|
if msg.IsTypeDone() || !msg.HasFlagMulti() {
|
|
return
|
|
}
|
|
|
|
raw_in = raw_in[AlignMsg(int(msg.Len)):]
|
|
}
|
|
|
|
}
|
|
return
|
|
}
|
|
/*
|
|
func (c *Client) Receive2() (msgs []Message, err error) {
|
|
for {
|
|
//fmt.Println("Reading")
|
|
raw_in := c.Read()
|
|
msg := Message{}
|
|
msg.UnmarshalBinary(raw_in)
|
|
|
|
if msg.IsTypeError() {
|
|
return nil,errors.New(fmt.Sprintf("Error response: %+v %+v", msg.GetData(), msg.GetErrNo()))
|
|
|
|
}
|
|
|
|
msgs = append(msgs, msg)
|
|
fmt.Printf("Received raw: \n%v\n", hex.Dump(raw_in))
|
|
fmt.Printf("Received %+v\nData:\n%+v\n", msg, hex.Dump(msg.data))
|
|
|
|
// check if last message
|
|
if msg.IsTypeDone() || !msg.HasFlagMulti() {
|
|
break
|
|
}
|
|
}
|
|
return
|
|
}
|
|
*/
|
|
func (c *Client) Read() (res []byte) {
|
|
rcvBuf := make([]byte, os.Getpagesize())
|
|
for {
|
|
//fmt.Println("calling receive")
|
|
|
|
// peek into rcv to fetch bytes available
|
|
n,_,_ := unix.Recvfrom(c.sock_fd, rcvBuf, unix.MSG_PEEK)
|
|
//fmt.Println("Bytes received: ", n)
|
|
|
|
if len(rcvBuf) < n {
|
|
fmt.Println("Receive buffer too small, increasing...")
|
|
rcvBuf = make([]byte, len(rcvBuf)*2)
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
n,_,_ := unix.Recvfrom(c.sock_fd, rcvBuf, 0)
|
|
|
|
nlMsgRaw := make([]byte, n)
|
|
copy(nlMsgRaw, rcvBuf) // Copy over as many bytes as readen
|
|
|
|
return nlMsgRaw
|
|
/*
|
|
msg := NlMsg{}
|
|
msg.fromWire(nlMsgRaw)
|
|
return msg.Payload
|
|
*/
|
|
}
|
|
|