P4wnP1_aloa/mnetlink/client.go
2018-11-13 14:10:39 +01:00

179 lines
3.6 KiB
Go

package mnetlink
import (
"fmt"
"errors"
"golang.org/x/sys/unix"
"math/rand"
"os"
"sync/atomic"
"syscall"
"time"
)
type Client struct {
Family int
sock_fd int
seq uint32
pid uint32
}
func NewNl(family int) (res *Client, err error) {
res = &Client{
Family: family,
}
// random start seq
r := rand.New(rand.NewSource(time.Now().UnixNano()))
res.seq = r.Uint32()
// assign current PID as PortID
res.pid = uint32(os.Getpid())
return res,nil
}
func (c *Client) incSeq() {
atomic.AddUint32(&c.seq,1)
}
func (c *Client) Open() (err error) {
// if family is 0, choose NETLINK_USERSOCK by default
if c.Family == 0 { c.Family = unix.NETLINK_USERSOCK }
c.sock_fd,err = unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, c.Family)
if err != nil { return }
// bind socket
err = unix.Bind(c.sock_fd, &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: 0,
Pid: uint32(os.Getpid()),
})
return
}
func (c *Client) Close() (err error) {
unix.Close(c.sock_fd)
return
}
func (c *Client) AddGroupMembership(groupid int) (err error) {
err = syscall.SetsockoptInt(c.sock_fd, unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, groupid)
return
}
func (c *Client) DropGroupMembership(groupid int) (err error) {
err = syscall.SetsockoptInt(c.sock_fd, unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, groupid)
return
}
func (c *Client) Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) (err error) {
err = unix.Sendmsg(c.sock_fd, p, oob, to, flags)
return
}
func (c *Client) Send(msg Message) (err error) {
// adjust seq
msg.Seq = c.seq
msg.Pid = c.pid
raw,err := msg.MarshalBinary()
if err != nil { return }
addr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
}
//fmt.Printf("Sending raw2:\n%+v\n", hex.Dump(raw))
err = c.Sendmsg(raw, nil, addr, 0)
if err == nil { c.incSeq() }
return err
}
func (c *Client) Receive() (msgs []Message, err error) {
for {
//fmt.Println("Reading")
raw_in := c.Read()
for len(raw_in) > unix.NLMSG_HDRLEN {
msg := Message{}
msg.UnmarshalBinary(raw_in)
if msg.IsTypeError() {
return nil,errors.New(fmt.Sprintf("Error response: %+v %+v", msg.GetData(), msg.GetErrNo()))
}
msgs = append(msgs, msg)
//fmt.Printf("Received raw: \n%v\n", hex.Dump(raw_in))
//fmt.Printf("Received %+v\nData:\n%+v\n", msg, hex.Dump(msg.data))
// check if last message
if msg.IsTypeDone() || !msg.HasFlagMulti() {
return
}
raw_in = raw_in[AlignMsg(int(msg.Len)):]
}
}
return
}
/*
func (c *Client) Receive2() (msgs []Message, err error) {
for {
//fmt.Println("Reading")
raw_in := c.Read()
msg := Message{}
msg.UnmarshalBinary(raw_in)
if msg.IsTypeError() {
return nil,errors.New(fmt.Sprintf("Error response: %+v %+v", msg.GetData(), msg.GetErrNo()))
}
msgs = append(msgs, msg)
fmt.Printf("Received raw: \n%v\n", hex.Dump(raw_in))
fmt.Printf("Received %+v\nData:\n%+v\n", msg, hex.Dump(msg.data))
// check if last message
if msg.IsTypeDone() || !msg.HasFlagMulti() {
break
}
}
return
}
*/
func (c *Client) Read() (res []byte) {
rcvBuf := make([]byte, os.Getpagesize())
for {
//fmt.Println("calling receive")
// peek into rcv to fetch bytes available
n,_,_ := unix.Recvfrom(c.sock_fd, rcvBuf, unix.MSG_PEEK)
//fmt.Println("Bytes received: ", n)
if len(rcvBuf) < n {
fmt.Println("Receive buffer too small, increasing...")
rcvBuf = make([]byte, len(rcvBuf)*2)
} else {
break
}
}
n,_,_ := unix.Recvfrom(c.sock_fd, rcvBuf, 0)
nlMsgRaw := make([]byte, n)
copy(nlMsgRaw, rcvBuf) // Copy over as many bytes as readen
return nlMsgRaw
/*
msg := NlMsg{}
msg.fromWire(nlMsgRaw)
return msg.Payload
*/
}