diff --git a/fw/executor/profiler.go b/fw/executor/profiler.go new file mode 100644 index 00000000..2e876ca7 --- /dev/null +++ b/fw/executor/profiler.go @@ -0,0 +1,75 @@ +package executor + +import ( + "os" + "runtime" + "runtime/pprof" + + "github.com/named-data/YaNFD/core" +) + +type Profiler struct { + config *YaNFDConfig + cpuFile *os.File + memFile *os.File + block *pprof.Profile +} + +func NewProfiler(config *YaNFDConfig) *Profiler { + return &Profiler{config: config} +} + +func (p *Profiler) Start() (err error) { + if p.config.CpuProfile != "" { + p.cpuFile, err = os.Create(p.config.CpuProfile) + if err != nil { + core.LogFatal("Main", "Unable to open output file for CPU profile: ", err) + } + + core.LogInfo("Main", "Profiling CPU - outputting to ", p.config.CpuProfile) + pprof.StartCPUProfile(p.cpuFile) + } + + if p.config.MemProfile != "" { + memProfileFile, err := os.Create(p.config.MemProfile) + if err != nil { + core.LogFatal("Main", "Unable to open output file for memory profile: ", err) + } + + core.LogInfo("Main", "Profiling memory - outputting to ", p.config.MemProfile) + runtime.GC() + if err := pprof.WriteHeapProfile(memProfileFile); err != nil { + core.LogFatal("Main", "Unable to write memory profile: ", err) + } + } + + if p.config.BlockProfile != "" { + core.LogInfo("Main", "Profiling blocking operations - outputting to ", p.config.BlockProfile) + runtime.SetBlockProfileRate(1) + p.block = pprof.Lookup("block") + } + + return +} + +func (p *Profiler) Stop() { + if p.block != nil { + blockProfileFile, err := os.Create(p.config.BlockProfile) + if err != nil { + core.LogFatal("Main", "Unable to open output file for block profile: ", err) + } + if err := p.block.WriteTo(blockProfileFile, 0); err != nil { + core.LogFatal("Main", "Unable to write block profile: ", err) + } + blockProfileFile.Close() + } + + if p.memFile != nil { + p.memFile.Close() + } + + if p.cpuFile != nil { + pprof.StopCPUProfile() + p.cpuFile.Close() + } +} diff --git a/fw/executor/yanfd.go b/fw/executor/yanfd.go index ed23dc1b..9a072df0 100644 --- a/fw/executor/yanfd.go +++ b/fw/executor/yanfd.go @@ -10,8 +10,6 @@ package executor import ( "net" "os" - "runtime" - "runtime/pprof" "time" "github.com/named-data/YaNFD/core" @@ -38,15 +36,13 @@ type YaNFDConfig struct { // YaNFD is the wrapper class for the NDN Forwarding Daemon. // Note: only one instance of this class should be created. type YaNFD struct { - config *YaNFDConfig - - cpuProfileFile *os.File - memProfileFile *os.File - blockProfiler *pprof.Profile + config *YaNFDConfig + profiler *Profiler unixListener *face.UnixStreamListener wsListener *face.WebSocketListener tcpListeners []*face.TCPListener + udpListener *face.UDPListener } // NewYaNFD creates a YaNFD. Don't call this function twice. @@ -68,33 +64,9 @@ func NewYaNFD(config *YaNFDConfig) *YaNFD { table.Configure() mgmt.Configure() - // Initialize profiling - var cpuProfileFile *os.File - var memProfileFile *os.File - var blockProfiler *pprof.Profile - var err error - if config.CpuProfile != "" { - cpuProfileFile, err = os.Create(config.CpuProfile) - if err != nil { - core.LogFatal("Main", "Unable to open output file for CPU profile: ", err) - } - - core.LogInfo("Main", "Profiling CPU - outputting to ", config.CpuProfile) - pprof.StartCPUProfile(cpuProfileFile) - } - - if config.BlockProfile != "" { - core.LogInfo("Main", "Profiling blocking operations - outputting to ", config.BlockProfile) - runtime.SetBlockProfileRate(1) - blockProfiler = pprof.Lookup("block") - // Output at end of runtime - } - return &YaNFD{ - config: config, - cpuProfileFile: cpuProfileFile, - memProfileFile: memProfileFile, - blockProfiler: blockProfiler, + config: config, + profiler: NewProfiler(config), } } @@ -103,18 +75,18 @@ func NewYaNFD(config *YaNFDConfig) *YaNFD { func (y *YaNFD) Start() { core.LogInfo("Main", "Starting YaNFD") + // Start profiler + y.profiler.Start() + // Initialize FIB table fibTableAlgorithm := core.GetConfigStringDefault("tables.fib.algorithm", "nametree") table.CreateFIBTable(fibTableAlgorithm) // Create null face - nullFace := face.MakeNullLinkService(face.MakeNullTransport()) - face.FaceTable.Add(nullFace) - go nullFace.Run(nil) + face.MakeNullLinkService(face.MakeNullTransport()).Run(nil) // Start management thread - management := mgmt.MakeMgmtThread() - go management.Run() + go mgmt.MakeMgmtThread().Run() // Create forwarding threads if fw.NumFwThreads < 1 || fw.NumFwThreads > fw.MaxFwThreads { @@ -170,10 +142,13 @@ func (y *YaNFD) Start() { core.LogError("Main", "Unable to create MulticastUDPTransport for ", path, " on ", iface.Name, ": ", err) continue } - multicastUDPFace := face.MakeNDNLPLinkService(multicastUDPTransport, face.MakeNDNLPLinkServiceOptions()) - face.FaceTable.Add(multicastUDPFace) + + face.MakeNDNLPLinkService( + multicastUDPTransport, + face.MakeNDNLPLinkServiceOptions(), + ).Run(nil) + faceCnt += 1 - go multicastUDPFace.Run(nil) core.LogInfo("Main", "Created multicast UDP face for ", path, " on ", iface.Name) } @@ -184,6 +159,7 @@ func (y *YaNFD) Start() { } faceCnt += 1 go udpListener.Run() + y.udpListener = udpListener core.LogInfo("Main", "Created UDP listener for ", path, " on ", iface.Name) if tcpEnabled { @@ -194,8 +170,8 @@ func (y *YaNFD) Start() { } faceCnt += 1 go tcpListener.Run() - core.LogInfo("Main", "Created TCP listener for ", path, " on ", iface.Name) y.tcpListeners = append(y.tcpListeners, tcpListener) + core.LogInfo("Main", "Created TCP listener for ", path, " on ", iface.Name) } } } @@ -240,28 +216,22 @@ func (y *YaNFD) Stop() { core.LogInfo("Main", "Forwarder shutting down ...") core.ShouldQuit = true - if y.config.MemProfile != "" { - memProfileFile, err := os.Create(y.config.MemProfile) - if err != nil { - core.LogFatal("Main", "Unable to open output file for memory profile: ", err) - } - - core.LogInfo("Main", "Profiling memory - outputting to ", y.config.MemProfile) - runtime.GC() - if err := pprof.WriteHeapProfile(memProfileFile); err != nil { - core.LogFatal("Main", "Unable to write memory profile: ", err) - } - } + // Stop profiler + y.profiler.Stop() // Wait for unix socket listener to quit if y.unixListener != nil { y.unixListener.Close() - <-y.unixListener.HasQuit } if y.wsListener != nil { y.wsListener.Close() } + // Wait for UDP listener to quit + if y.udpListener != nil { + y.udpListener.Close() + } + // Wait for TCP listeners to quit for _, tcpListener := range y.tcpListeners { tcpListener.Close() @@ -272,12 +242,6 @@ func (y *YaNFD) Stop() { face.Close() } - // Wait for all faces to quit - for _, face := range face.FaceTable.GetAll() { - core.LogTrace("Main", "Waiting for face ", face, " to quit") - <-face.GetHasQuit() - } - // Tell all forwarding threads to quit for _, fw := range fw.Threads { fw.TellToQuit() @@ -287,23 +251,4 @@ func (y *YaNFD) Stop() { for _, fw := range fw.Threads { <-fw.HasQuit } - - // Shutdown Profilers - if y.config.BlockProfile != "" { - blockProfileFile, err := os.Create(y.config.BlockProfile) - if err != nil { - core.LogFatal("Main", "Unable to open output file for block profile: ", err) - } - if err := y.blockProfiler.WriteTo(blockProfileFile, 0); err != nil { - core.LogFatal("Main", "Unable to write block profile: ", err) - } - blockProfileFile.Close() - } - if y.config.MemProfile != "" { - y.memProfileFile.Close() - } - if y.config.CpuProfile != "" { - pprof.StopCPUProfile() - y.cpuProfileFile.Close() - } } diff --git a/fw/face/internal-transport.go b/fw/face/internal-transport.go index b791bd26..bc95e73c 100644 --- a/fw/face/internal-transport.go +++ b/fw/face/internal-transport.go @@ -8,7 +8,6 @@ package face import ( - "runtime" "strconv" "github.com/named-data/YaNFD/core" @@ -37,20 +36,21 @@ func MakeInternalTransport() *InternalTransport { defn.MaxNDNPacketSize) t.recvQueue = make(chan []byte, faceQueueSize) t.sendQueue = make(chan []byte, faceQueueSize) - t.changeState(defn.Up) + t.running.Store(true) return t } // RegisterInternalTransport creates, registers, and starts an InternalTransport. func RegisterInternalTransport() (LinkService, *InternalTransport) { - t := MakeInternalTransport() - l := MakeNDNLPLinkService(t, NDNLPLinkServiceOptions{ - IsIncomingFaceIndicationEnabled: true, - IsConsumerControlledForwardingEnabled: true, - }) - FaceTable.Add(l) - go l.Run(nil) - return l, t + transport := MakeInternalTransport() + + options := MakeNDNLPLinkServiceOptions() + options.IsIncomingFaceIndicationEnabled = true + options.IsConsumerControlledForwardingEnabled = true + link := MakeNDNLPLinkService(transport, options) + link.Run(nil) + + return link, transport } func (t *InternalTransport) String() string { @@ -83,7 +83,7 @@ func (t *InternalTransport) Send(netWire enc.Wire, pitToken []byte, nextHopFaceI Fragment: netWire, } if len(pitToken) > 0 { - lpPkt.PitToken = pitToken + lpPkt.PitToken = append([]byte{}, pitToken...) } if nextHopFaceID != nil { lpPkt.NextHopFaceId = utils.IdPtr(*nextHopFaceID) @@ -103,27 +103,22 @@ func (t *InternalTransport) Send(netWire enc.Wire, pitToken []byte, nextHopFaceI // Receive receives a packet from the perspective of the internal component. func (t *InternalTransport) Receive() (enc.Wire, []byte, uint64) { - shouldContinue := true - // We need to use a for loop to silently ignore invalid packets - for shouldContinue { - select { - case frame := <-t.recvQueue: - pkt, _, err := spec.ReadPacket(enc.NewBufferReader(frame)) - if err != nil { - core.LogWarn(t, "Unable to decode received block - DROP: ", err) - continue - } - lpPkt := pkt.LpPacket - if lpPkt.Fragment.Length() == 0 { - core.LogWarn(t, "Received empty fragment - DROP") - continue - } - - return lpPkt.Fragment, lpPkt.PitToken, *lpPkt.IncomingFaceId - case <-t.hasQuit: - shouldContinue = false + for frame := range t.recvQueue { + packet, _, err := spec.ReadPacket(enc.NewBufferReader(frame)) + if err != nil { + core.LogWarn(t, "Unable to decode received block - DROP: ", err) + continue } + + lpPkt := packet.LpPacket + if lpPkt.Fragment.Length() == 0 { + core.LogWarn(t, "Received empty fragment - DROP") + continue + } + + return lpPkt.Fragment, lpPkt.PitToken, *lpPkt.IncomingFaceId } + return nil, []byte{}, 0 } @@ -135,54 +130,26 @@ func (t *InternalTransport) sendFrame(frame []byte) { t.nOutBytes += uint64(len(frame)) - core.LogDebug(t, "Sending frame of size ", len(frame)) - frameCopy := make([]byte, len(frame)) copy(frameCopy, frame) t.recvQueue <- frameCopy } func (t *InternalTransport) runReceive() { - core.LogTrace(t, "Starting receive thread") - - if lockThreadsToCores { - runtime.LockOSThread() - } - - for { - core.LogTrace(t, "Waiting for frame from component") - select { - case <-t.hasQuit: - return - case frame := <-t.sendQueue: - core.LogTrace(t, "Component send of size ", len(frame)) - - if len(frame) > defn.MaxNDNPacketSize { - core.LogWarn(t, "Component trying to send too much data - DROP") - continue - } - - t.nInBytes += uint64(len(frame)) - - t.linkService.handleIncomingFrame(frame) + for frame := range t.sendQueue { + if len(frame) > defn.MaxNDNPacketSize { + core.LogWarn(t, "Component trying to send too much data - DROP") + continue } - } -} -func (t *InternalTransport) changeState(new defn.State) { - if t.state == new { - return + t.nInBytes += uint64(len(frame)) + t.linkService.handleIncomingFrame(frame) } +} - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new - - if t.state != defn.Up { - // Stop link service - t.hasQuit <- true - t.hasQuit <- true // Send again to stop any pending receives - t.linkService.tellTransportQuit() - - FaceTable.Remove(t.faceID) +func (t *InternalTransport) Close() { + if t.running.Swap(false) { + close(t.recvQueue) + close(t.sendQueue) } } diff --git a/fw/face/link-service.go b/fw/face/link-service.go index 70367b13..53b96b63 100644 --- a/fw/face/link-service.go +++ b/fw/face/link-service.go @@ -38,16 +38,16 @@ type LinkService interface { State() defn.State // Run is the main entry point for running face thread - // optNewFrame is optional new incoming frame - Run(optNewFrame []byte) + // initial is optional new incoming frame + Run(initial []byte) - // SendPacket Add a packet to the send queue for this link service + // Add a packet to the send queue for this link service SendPacket(packet *defn.Pkt) + // Synchronously handle an incoming frame and dispatch to fw handleIncomingFrame(frame []byte) + // Close the face Close() - tellTransportQuit() - GetHasQuit() chan bool // Counters NInInterests() uint64 @@ -60,12 +60,10 @@ type LinkService interface { // linkServiceBase is the type upon which all link service implementations should be built type linkServiceBase struct { - faceID uint64 - transport transport - HasQuit chan bool - hasImplQuit chan bool - hasTransportQuit chan bool - sendQueue chan *defn.Pkt + faceID uint64 + transport transport + stopped chan bool + sendQueue chan *defn.Pkt // Counters nInInterests uint64 @@ -89,23 +87,12 @@ func (l *linkServiceBase) SetFaceID(faceID uint64) { } } -func (l *linkServiceBase) tellTransportQuit() { - l.hasTransportQuit <- true -} - -// GetHasQuit returns the channel that indicates when the face has quit. -func (l *linkServiceBase) GetHasQuit() chan bool { - return l.HasQuit -} - // // "Constructors" and threading // func (l *linkServiceBase) makeLinkServiceBase() { - l.HasQuit = make(chan bool) - l.hasImplQuit = make(chan bool) - l.hasTransportQuit = make(chan bool) + l.stopped = make(chan bool) l.sendQueue = make(chan *defn.Pkt, faceQueueSize) } @@ -170,7 +157,10 @@ func (l *linkServiceBase) ExpirationPeriod() time.Duration { // State returns the state of the underlying transport. func (l *linkServiceBase) State() defn.State { - return l.transport.State() + if l.transport.IsRunning() { + return defn.Up + } + return defn.Down } // @@ -207,6 +197,11 @@ func (l *linkServiceBase) NOutBytes() uint64 { return l.transport.NOutBytes() } +// Close the underlying transport +func (l *linkServiceBase) Close() { + l.transport.Close() +} + // // Forwarding pipeline // @@ -277,7 +272,3 @@ func (l *linkServiceBase) dispatchData(pkt *defn.Pkt) { core.LogTrace(l, "Dispatched Data to thread ", thread) dispatch.GetFWThread(thread).QueueData(pkt) } - -func (l *linkServiceBase) Close() { - l.transport.changeState(defn.Down) -} diff --git a/fw/face/multicast-udp-transport.go b/fw/face/multicast-udp-transport.go index fdab6969..90c4c4f7 100644 --- a/fw/face/multicast-udp-transport.go +++ b/fw/face/multicast-udp-transport.go @@ -9,9 +9,9 @@ package face import ( "errors" + "fmt" "net" - "runtime" - "strconv" + "strings" "github.com/named-data/YaNFD/core" defn "github.com/named-data/YaNFD/defn" @@ -36,24 +36,21 @@ func MakeMulticastUDPTransport(localURI *defn.URI) (*MulticastUDPTransport, erro return nil, core.ErrNotCanonical } - t := new(MulticastUDPTransport) - // Get local interface - localIf, err := InterfaceByIP(net.ParseIP(localURI.PathHost())) - if err != nil || localIf == nil { - core.LogError(t, "Unable to get interface for local URI ", localURI, ": ", err) - } - + // Get remote Uri + var remote string if localURI.Scheme() == "udp4" { - t.makeTransportBase( - defn.DecodeURIString("udp4://"+udp4MulticastAddress+":"+strconv.FormatUint(uint64(UDPMulticastPort), 10)), - localURI, PersistencyPermanent, defn.NonLocal, defn.MultiAccess, defn.MaxNDNPacketSize) + remote = fmt.Sprintf("udp4://%s:%d", udp4MulticastAddress, UDPMulticastPort) } else if localURI.Scheme() == "udp6" { - t.makeTransportBase( - defn.DecodeURIString("udp6://["+udp6MulticastAddress+"%"+localIf.Name+"]:"+ - strconv.FormatUint(uint64(UDPMulticastPort), 10)), - localURI, PersistencyPermanent, defn.NonLocal, defn.MultiAccess, defn.MaxNDNPacketSize) + remote = fmt.Sprintf("udp6://[%s]:%d", udp6MulticastAddress, UDPMulticastPort) } - t.scope = defn.NonLocal + + // Create transport + t := &MulticastUDPTransport{} + t.makeTransportBase( + defn.DecodeURIString(remote), + localURI, PersistencyPermanent, + defn.NonLocal, defn.MultiAccess, + defn.MaxNDNPacketSize) // Format group and local addresses t.groupAddr.IP = net.ParseIP(t.remoteURI.PathHost()) @@ -65,32 +62,51 @@ func MakeMulticastUDPTransport(localURI *defn.URI) (*MulticastUDPTransport, erro // Configure dialer so we can allow address reuse t.dialer = &net.Dialer{LocalAddr: &t.localAddr, Control: impl.SyscallReuseAddr} + t.running.Store(true) // Create send connection - sendConn, err := t.dialer.Dial(t.remoteURI.Scheme(), t.groupAddr.String()) + err := t.connectSend() if err != nil { - return nil, errors.New("Unable to create send connection to group address: " + err.Error()) + t.Close() + return nil, err } - t.sendConn = sendConn.(*net.UDPConn) // Create receive connection - t.recvConn, err = net.ListenMulticastUDP(t.remoteURI.Scheme(), localIf, &t.groupAddr) + err = t.connectRecv() if err != nil { - return nil, errors.New("Unable to create receive connection for group address on " + - localIf.Name + ": " + err.Error()) + t.Close() + return nil, err } - t.changeState(defn.Up) - return t, nil } +func (t *MulticastUDPTransport) connectSend() error { + sendConn, err := t.dialer.Dial(t.remoteURI.Scheme(), t.groupAddr.String()) + if err != nil { + return errors.New("unable to create send connection to group address: " + err.Error()) + } + t.sendConn = sendConn.(*net.UDPConn) + return nil +} + +func (t *MulticastUDPTransport) connectRecv() error { + localIf, err := InterfaceByIP(net.ParseIP(t.localURI.PathHost())) + if err != nil || localIf == nil { + return fmt.Errorf("unable to get interface for local URI %s: %s", t.localURI, err.Error()) + } + + t.recvConn, err = net.ListenMulticastUDP(t.remoteURI.Scheme(), localIf, &t.groupAddr) + if err != nil { + return fmt.Errorf("unable to create receive conn for group %s: %s", localIf.Name, err.Error()) + } + return nil +} + func (t *MulticastUDPTransport) String() string { - return "MulticastUDPTransport, FaceID=" + strconv.FormatUint(t.faceID, 10) + - ", RemoteURI=" + t.remoteURI.String() + ", LocalURI=" + t.localURI.String() + return fmt.Sprintf("MulticastUDPTransport, FaceID=%d, RemoteURI=%s, LocalURI=%s", t.faceID, t.remoteURI, t.localURI) } -// SetPersistency changes the persistency of the face. func (t *MulticastUDPTransport) SetPersistency(persistency Persistency) bool { if persistency == t.persistency { return true @@ -104,7 +120,6 @@ func (t *MulticastUDPTransport) SetPersistency(persistency Persistency) bool { return false } -// GetSendQueueSize returns the current size of the send queue. func (t *MulticastUDPTransport) GetSendQueueSize() uint64 { rawConn, err := t.recvConn.SyscallConn() if err != nil { @@ -114,82 +129,65 @@ func (t *MulticastUDPTransport) GetSendQueueSize() uint64 { } func (t *MulticastUDPTransport) sendFrame(frame []byte) { + if !t.running.Load() { + return + } + if len(frame) > t.MTU() { core.LogWarn(t, "Attempted to send frame larger than MTU - DROP") return } - core.LogDebug(t, "Sending frame of size ", len(frame)) _, err := t.sendConn.Write(frame) if err != nil { core.LogWarn(t, "Unable to send on socket - DROP") - t.sendConn.Close() - sendConn, err := t.dialer.Dial(t.remoteURI.Scheme(), t.groupAddr.String()) - if err != nil { - core.LogError(t, "Unable to create send connection to group address: ", err) + + // Re-create the socket if connection is still running + if t.running.Load() { + err = t.connectSend() + if err != nil { + core.LogError(t, "Unable to re-create send connection: ", err) + return + } } - t.sendConn = sendConn.(*net.UDPConn) } + t.nOutBytes += uint64(len(frame)) } func (t *MulticastUDPTransport) runReceive() { - if lockThreadsToCores { - runtime.LockOSThread() - } + defer t.Close() - recvBuf := make([]byte, defn.MaxNDNPacketSize) for { - readSize, remoteAddr, err := t.recvConn.ReadFromUDP(recvBuf) + err := readTlvStream(t.recvConn, func(b []byte) { + t.nInBytes += uint64(len(b)) + t.linkService.handleIncomingFrame(b) + }, func(err error) bool { + // Same as unicast UDP transport + return strings.Contains(err.Error(), "connection refused") + }) if err != nil { - if err.Error() == "EOF" { - core.LogDebug(t, "EOF - Face DOWN") - t.changeState(defn.Down) - break - } else { - core.LogWarn(t, "Unable to read from socket (", err, ") - DROP") - t.recvConn.Close() - localIf, err := InterfaceByIP(net.ParseIP(t.localURI.PathHost())) - if err != nil || localIf == nil { - core.LogError(t, "Unable to get interface for local URI ", t.localURI, ": ", err) + core.LogWarn(t, "Unable to read from socket (", err, ") - Face DOWN") + + // Re-create the socket if connection is still running + if t.running.Load() { + err = t.connectRecv() + if err != nil { + core.LogError(t, "Unable to re-create receive connection: ", err) + return } - t.recvConn, _ = net.ListenMulticastUDP(t.remoteURI.Scheme(), localIf, &t.groupAddr) } } - - core.LogTrace(t, "Receive of size ", readSize, " from ", remoteAddr) - t.nInBytes += uint64(readSize) - - if readSize > defn.MaxNDNPacketSize { - core.LogWarn(t, "Received too much data without valid TLV block - DROP") - } - if readSize <= 0 { - core.LogInfo(t, "Socket close.") - continue - } - - // Packet was successfully received, send up to link service - t.linkService.handleIncomingFrame(recvBuf[:readSize]) } } -func (t *MulticastUDPTransport) changeState(new defn.State) { - if t.state == new { - return - } - - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new - - if t.state != defn.Up { - core.LogInfo(t, "Closing UDP socket") - t.hasQuit <- true - t.sendConn.Close() - t.recvConn.Close() - - // Stop link service - t.linkService.tellTransportQuit() - - FaceTable.Remove(t.faceID) +func (t *MulticastUDPTransport) Close() { + if t.running.Swap(false) { + if t.sendConn != nil { + t.sendConn.Close() + } + if t.recvConn != nil { + t.recvConn.Close() + } } } diff --git a/fw/face/ndnlp-link-service.go b/fw/face/ndnlp-link-service.go index ed33fb2a..f53aa242 100644 --- a/fw/face/ndnlp-link-service.go +++ b/fw/face/ndnlp-link-service.go @@ -48,10 +48,12 @@ type NDNLPLinkServiceOptions struct { } func MakeNDNLPLinkServiceOptions() NDNLPLinkServiceOptions { - var o NDNLPLinkServiceOptions - o.BaseCongestionMarkingInterval = time.Duration(100) * time.Millisecond - o.DefaultCongestionThresholdBytes = uint64(math.Pow(2, 16)) - return o + return NDNLPLinkServiceOptions{ + BaseCongestionMarkingInterval: time.Duration(100) * time.Millisecond, + DefaultCongestionThresholdBytes: uint64(math.Pow(2, 16)), + IsReassemblyEnabled: true, + IsFragmentationEnabled: true, + } } // NDNLPLinkService is a link service implementing the NDNLPv2 link protocol @@ -125,28 +127,35 @@ func (l *NDNLPLinkService) computeHeaderOverhead() { } // Run starts the face and associated goroutines -func (l *NDNLPLinkService) Run(optNewFrame []byte) { +func (l *NDNLPLinkService) Run(initial []byte) { if l.transport == nil { core.LogError(l, "Unable to start face due to unset transport") return } - if optNewFrame != nil { - l.handleIncomingFrame(optNewFrame) + // Add self to face table. Removed in runSend. + FaceTable.Add(l) + + // Process initial incoming frame + if initial != nil { + l.handleIncomingFrame(initial) } // Start transport goroutines - go l.transport.runReceive() + go l.runReceive() go l.runSend() +} + +func (l *NDNLPLinkService) runReceive() { + if lockThreadsToCores { + runtime.LockOSThread() + } - // Wait for link service send goroutine to quit - <-l.hasImplQuit - l.HasQuit <- true + l.transport.runReceive() + l.stopped <- true } func (l *NDNLPLinkService) runSend() { - core.LogTrace(l, "Starting send thread") - if lockThreadsToCores { runtime.LockOSThread() } @@ -155,8 +164,8 @@ func (l *NDNLPLinkService) runSend() { select { case pkt := <-l.sendQueue: sendPacket(l, pkt) - case <-l.hasTransportQuit: - l.hasImplQuit <- true + case <-l.stopped: + FaceTable.Remove(l.transport.FaceID()) return } } @@ -165,11 +174,6 @@ func (l *NDNLPLinkService) runSend() { func sendPacket(l *NDNLPLinkService, pkt *defn.Pkt) { wire := pkt.Raw - if l.transport.State() != defn.Up { - core.LogWarn(l, "Attempted to send frame on down face - DROP and stop LinkService") - l.hasImplQuit <- true - return - } // Counters if pkt.L3.Interest != nil { l.nOutInterests++ @@ -322,12 +326,7 @@ func (l *NDNLPLinkService) handleIncomingFrame(frame []byte) { } // Reassembly - if l.options.IsReassemblyEnabled { - if LP.Sequence == nil { - core.LogInfo(l, "Received NDNLPv2 frame without Sequence but reassembly requires it - DROP") - return - } - + if l.options.IsReassemblyEnabled && LP.Sequence != nil { fragIndex := uint64(0) if LP.FragIndex != nil { fragIndex = *LP.FragIndex diff --git a/fw/face/null-link-service.go b/fw/face/null-link-service.go index 9b33b347..8f227ac6 100644 --- a/fw/face/null-link-service.go +++ b/fw/face/null-link-service.go @@ -36,25 +36,12 @@ func (l *NullLinkService) String() string { } // Run runs the NullLinkService. -func (l *NullLinkService) Run(optNewFrame []byte) { - if l.transport == nil { - core.LogError(l, "Unable to start face due to unset transport") - return - } - - // Start transport goroutines - go l.transport.runReceive() - - if optNewFrame != nil { - l.handleIncomingFrame(optNewFrame) - } - - // Wait for transport receive goroutine to quit - <-l.hasTransportQuit - - core.LogTrace(l, "Transport has quit") - - l.HasQuit <- true +func (l *NullLinkService) Run(initial []byte) { + FaceTable.Add(l) + go func() { + l.transport.runReceive() + FaceTable.Remove(l.transport.FaceID()) + }() } func (l *NullLinkService) handleIncomingFrame(frame []byte) { diff --git a/fw/face/null-transport.go b/fw/face/null-transport.go index 803a6b7d..0bac6f6e 100644 --- a/fw/face/null-transport.go +++ b/fw/face/null-transport.go @@ -10,20 +10,27 @@ package face import ( "strconv" - "github.com/named-data/YaNFD/core" defn "github.com/named-data/YaNFD/defn" ) // NullTransport is a transport that drops all packets. type NullTransport struct { transportBase + close chan bool } // MakeNullTransport makes a NullTransport. func MakeNullTransport() *NullTransport { - t := new(NullTransport) - t.makeTransportBase(defn.MakeNullFaceURI(), defn.MakeNullFaceURI(), PersistencyPermanent, defn.NonLocal, defn.PointToPoint, defn.MaxNDNPacketSize) - t.changeState(defn.Up) + t := &NullTransport{ + close: make(chan bool), + } + t.makeTransportBase( + defn.MakeNullFaceURI(), + defn.MakeNullFaceURI(), + PersistencyPermanent, + defn.NonLocal, + defn.PointToPoint, + defn.MaxNDNPacketSize) return t } @@ -50,18 +57,17 @@ func (t *NullTransport) GetSendQueueSize() uint64 { return 0 } -func (t *NullTransport) changeState(new defn.State) { - if t.state == new { - return - } - - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new +func (t *NullTransport) sendFrame([]byte) { + // Do nothing +} - if t.state != defn.Up { - // Stop link service - t.linkService.tellTransportQuit() +func (t *NullTransport) runReceive() { + t.running.Store(true) + <-t.close +} - FaceTable.Remove(t.faceID) +func (t *NullTransport) Close() { + if t.running.Swap(false) { + t.close <- true } } diff --git a/fw/face/stream-transport.go b/fw/face/stream-transport.go new file mode 100644 index 00000000..7d4f9408 --- /dev/null +++ b/fw/face/stream-transport.go @@ -0,0 +1,71 @@ +package face + +import ( + "errors" + "io" + + defn "github.com/named-data/YaNFD/defn" + enc "github.com/zjkmxy/go-ndn/pkg/encoding" +) + +func readTlvStream( + reader io.Reader, + onFrame func([]byte), + ignoreError func(error) bool, +) error { + recvBuf := make([]byte, defn.MaxNDNPacketSize*32) + recvOff := 0 + tlvOff := 0 + + for { + readSize, err := reader.Read(recvBuf[recvOff:]) + recvOff += readSize + if err != nil { + if ignoreError != nil && ignoreError(err) { + continue + } + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + // Determine whether valid packet received + for { + rdr := enc.NewBufferReader(recvBuf[tlvOff:recvOff]) + + typ, err := enc.ReadTLNum(rdr) + if err != nil { + // Probably incomplete packet + break + } + + len, err := enc.ReadTLNum(rdr) + if err != nil { + // Probably incomplete packet + break + } + + tlvSize := typ.EncodingLength() + len.EncodingLength() + int(len) + + if recvOff-tlvOff >= tlvSize { + // Packet was successfully received, send up to link service + onFrame(recvBuf[tlvOff : tlvOff+tlvSize]) + tlvOff += tlvSize + } else if recvOff-tlvOff > defn.MaxNDNPacketSize { + // Invalid packet, something went wrong + return errors.New("received too much data without valid TLV block") + } else { + // Incomplete packet (for sure) + break + } + } + + // If less than one packet space remains in buffer, shift to beginning + if recvOff-tlvOff < defn.MaxNDNPacketSize { + copy(recvBuf, recvBuf[tlvOff:recvOff]) + recvOff -= tlvOff + tlvOff = 0 + } + } +} diff --git a/fw/face/table.go b/fw/face/table.go index b818c1aa..2832a170 100644 --- a/fw/face/table.go +++ b/fw/face/table.go @@ -10,6 +10,7 @@ package face import ( "sync" "sync/atomic" + "time" "github.com/named-data/YaNFD/core" defn "github.com/named-data/YaNFD/defn" @@ -29,6 +30,7 @@ type Table struct { func init() { FaceTable.faces = sync.Map{} FaceTable.nextFaceID.Store(1) + go FaceTable.ExpirationHandler() } // Add adds a face to the face table. @@ -80,3 +82,21 @@ func (t *Table) Remove(id uint64) { table.Rib.CleanUpFace(id) core.LogDebug("FaceTable", "Unregistered FaceID=", id) } + +// ExpirationHandler stops the faces that have expired +func (t *Table) ExpirationHandler() { + for { + // Check for expired faces every 10 seconds + time.Sleep(10 * time.Second) + + // Iterate the face table + t.faces.Range(func(_, face interface{}) bool { + transport := face.(LinkService).Transport() + if transport != nil && transport.ExpirationPeriod() < 0 { + core.LogInfo(transport, "Face expired") + transport.Close() + } + return true + }) + } +} diff --git a/fw/face/tcp-listener.go b/fw/face/tcp-listener.go index dcf1c82e..49dd31c3 100644 --- a/fw/face/tcp-listener.go +++ b/fw/face/tcp-listener.go @@ -9,8 +9,9 @@ package face import ( "context" + "errors" + "fmt" "net" - "strconv" "github.com/named-data/YaNFD/core" defn "github.com/named-data/YaNFD/defn" @@ -21,7 +22,7 @@ import ( type TCPListener struct { conn net.Listener localURI *defn.URI - HasQuit chan bool + stopped chan bool } // MakeTCPListener constructs a TCPListener. @@ -33,31 +34,33 @@ func MakeTCPListener(localURI *defn.URI) (*TCPListener, error) { l := new(TCPListener) l.localURI = localURI - l.HasQuit = make(chan bool, 1) + l.stopped = make(chan bool, 1) return l, nil } func (l *TCPListener) String() string { - return "TCPListener, " + l.localURI.String() + return fmt.Sprintf("TCPListener, %s", l.localURI) } -// Run starts the TCP listener. func (l *TCPListener) Run() { + defer func() { l.stopped <- true }() + // Create dialer and set reuse address option listenConfig := &net.ListenConfig{Control: impl.SyscallReuseAddr} // Create listener - var err error var remote string if l.localURI.Scheme() == "tcp4" { - remote = l.localURI.PathHost() + ":" + strconv.Itoa(int(l.localURI.Port())) + remote = fmt.Sprintf("%s:%d", l.localURI.PathHost(), l.localURI.Port()) } else { - remote = "[" + l.localURI.Path() + "]:" + strconv.Itoa(int(l.localURI.Port())) + remote = fmt.Sprintf("[%s]:%d", l.localURI.Path(), l.localURI.Port()) } + + // Start listening for incoming connections + var err error l.conn, err = listenConfig.Listen(context.Background(), l.localURI.Scheme(), remote) if err != nil { core.LogError(l, "Unable to start TCP listener: ", err) - l.HasQuit <- true return } @@ -65,6 +68,9 @@ func (l *TCPListener) Run() { for !core.ShouldQuit { remoteConn, err := l.conn.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } core.LogWarn(l, "Unable to accept connection: ", err) continue } @@ -74,21 +80,17 @@ func (l *TCPListener) Run() { core.LogError(l, "Failed to create new unicast TCP transport: ", err) continue } - newLinkService := MakeNDNLPLinkService(newTransport, MakeNDNLPLinkServiceOptions()) - // Add face to table (which assigns FaceID) before passing current frame to link service - FaceTable.Add(newLinkService) - go newLinkService.Run(nil) + core.LogInfo(l, "Accepting new TCP face ", newTransport.RemoteURI()) + options := MakeNDNLPLinkServiceOptions() + options.IsFragmentationEnabled = false // reliable stream + MakeNDNLPLinkService(newTransport, options).Run(nil) } - - l.HasQuit <- true } -// Close closes the TCPListener. func (l *TCPListener) Close() { - core.LogInfo(l, "Stopping listener") if l.conn != nil { l.conn.Close() - l.conn = nil + <-l.stopped } } diff --git a/fw/face/transport.go b/fw/face/transport.go index 8cba52af..3b8a9dcd 100644 --- a/fw/face/transport.go +++ b/fw/face/transport.go @@ -8,6 +8,7 @@ package face import ( + "sync/atomic" "time" defn "github.com/named-data/YaNFD/defn" @@ -27,16 +28,19 @@ type transport interface { LinkType() defn.LinkType MTU() int SetMTU(mtu int) - State() defn.State ExpirationPeriod() time.Duration + FaceID() uint64 + // Get the number of queued outgoing packets GetSendQueueSize() uint64 - - runReceive() - + // Send a frame (make if copy if necessary) sendFrame([]byte) - - changeState(newState defn.State) + // Receive frames in an infinite loop + runReceive() + // Transport is currently running (up) + IsRunning() bool + // Close the transport (runReceive should exit) + Close() // Counters NInBytes() uint64 @@ -46,6 +50,7 @@ type transport interface { // transportBase provides logic common types between transport types type transportBase struct { linkService LinkService + running atomic.Bool faceID uint64 remoteURI *defn.URI @@ -56,26 +61,32 @@ type transportBase struct { mtu int expirationTime *time.Time - state defn.State - - hasQuit chan bool - // Counters nInBytes uint64 nOutBytes uint64 } -func (t *transportBase) makeTransportBase(remoteURI *defn.URI, localURI *defn.URI, persistency Persistency, scope defn.Scope, linkType defn.LinkType, mtu int) { +func (t *transportBase) makeTransportBase( + remoteURI *defn.URI, + localURI *defn.URI, + persistency Persistency, + scope defn.Scope, + linkType defn.LinkType, + mtu int, +) { + t.running = atomic.Bool{} t.remoteURI = remoteURI t.localURI = localURI t.persistency = persistency t.scope = scope t.linkType = linkType - t.state = defn.Down t.mtu = mtu - t.hasQuit = make(chan bool, 2) } +// +// Setters +// + func (t *transportBase) setFaceID(faceID uint64) { t.faceID = faceID } @@ -88,42 +99,36 @@ func (t *transportBase) setLinkService(linkService LinkService) { // Getters // -// LocalURI returns the local URI of the transport. func (t *transportBase) LocalURI() *defn.URI { return t.localURI } -// RemoteURI returns the remote URI of the transport. func (t *transportBase) RemoteURI() *defn.URI { return t.remoteURI } -// Persistency returns the persistency of the transport. func (t *transportBase) Persistency() Persistency { return t.persistency } -// Scope returns the scope of the transport. func (t *transportBase) Scope() defn.Scope { return t.scope } -// LinkType returns the type of the transport. func (t *transportBase) LinkType() defn.LinkType { return t.linkType } -// MTU returns the maximum transmission unit (MTU) of the Transport. func (t *transportBase) MTU() int { return t.mtu } -// SetMTU sets the MTU of the transport. func (t *transportBase) SetMTU(mtu int) { t.mtu = mtu } -// ExpirationPeriod returns the time until this face expires. If transport not on-demand, returns 0. +// ExpirationPeriod returns the time until this face expires. +// If transport not on-demand, returns 0. func (t *transportBase) ExpirationPeriod() time.Duration { if t.expirationTime == nil || t.persistency != PersistencyOnDemand { return 0 @@ -131,35 +136,22 @@ func (t *transportBase) ExpirationPeriod() time.Duration { return time.Until(*t.expirationTime) } -// State returns the state of the transport. -func (t *transportBase) State() defn.State { - return t.state +func (t *transportBase) FaceID() uint64 { + return t.faceID +} + +func (t *transportBase) IsRunning() bool { + return t.running.Load() } // // Counters // -// NInBytes returns the number of link-layer bytes received on this transport. func (t *transportBase) NInBytes() uint64 { return t.nInBytes } -// NOutBytes returns the number of link-layer bytes sent on this transport. func (t *transportBase) NOutBytes() uint64 { return t.nOutBytes } - -// -// Stubs -// - -func (t *transportBase) runReceive() { - // Overridden in specific transport implementation -} - -func (t *transportBase) sendFrame(frame []byte) { - // Overridden in specific transport implementation - - t.nOutBytes += uint64(len(frame)) -} diff --git a/fw/face/udp-listener.go b/fw/face/udp-listener.go index f9cf05ed..9991246e 100644 --- a/fw/face/udp-listener.go +++ b/fw/face/udp-listener.go @@ -9,6 +9,8 @@ package face import ( "context" + "errors" + "fmt" "net" "strconv" @@ -21,7 +23,7 @@ import ( type UDPListener struct { conn net.PacketConn localURI *defn.URI - HasQuit chan bool + stopped chan bool } // MakeUDPListener constructs a UDPListener. @@ -33,31 +35,34 @@ func MakeUDPListener(localURI *defn.URI) (*UDPListener, error) { l := new(UDPListener) l.localURI = localURI - l.HasQuit = make(chan bool, 1) + l.stopped = make(chan bool, 1) return l, nil } func (l *UDPListener) String() string { - return "UDPListener, " + l.localURI.String() + return fmt.Sprintf("UDPListener, %s", l.localURI) } // Run starts the UDP listener. func (l *UDPListener) Run() { + defer func() { l.stopped <- true }() + // Create dialer and set reuse address option listenConfig := &net.ListenConfig{Control: impl.SyscallReuseAddr} // Create listener - var err error var remote string if l.localURI.Scheme() == "udp4" { - remote = l.localURI.PathHost() + ":" + strconv.Itoa(int(l.localURI.Port())) + remote = fmt.Sprintf("%s:%d", l.localURI.PathHost(), l.localURI.Port()) } else { - remote = "[" + l.localURI.Path() + "]:" + strconv.Itoa(int(l.localURI.Port())) + remote = fmt.Sprintf("[%s]:%d", l.localURI.Path(), l.localURI.Port()) } + + // Start listening for incoming connections + var err error l.conn, err = listenConfig.ListenPacket(context.Background(), l.localURI.Scheme(), remote) if err != nil { core.LogError(l, "Unable to start UDP listener: ", err) - l.HasQuit <- true return } @@ -66,8 +71,11 @@ func (l *UDPListener) Run() { for !core.ShouldQuit { readSize, remoteAddr, err := l.conn.ReadFrom(recvBuf) if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } core.LogWarn(l, "Unable to read from socket (", err, ") - DROP ") - break + return } // Construct remote URI @@ -89,21 +97,21 @@ func (l *UDPListener) Run() { continue } - core.LogTrace(l, "Receive of size ", readSize, " from ", remoteURI) - // If frame received here, must be for new remote endpoint newTransport, err := MakeUnicastUDPTransport(remoteURI, l.localURI, PersistencyOnDemand) if err != nil { core.LogError(l, "Failed to create new unicast UDP transport: ", err) continue } - newLinkService := MakeNDNLPLinkService(newTransport, MakeNDNLPLinkServiceOptions()) - // Add face to table (which assigns FaceID) before passing current frame to link service - FaceTable.Add(newLinkService) - go newLinkService.Run(recvBuf[:readSize]) + core.LogInfo(l, "Accepting new UDP face ", newTransport.RemoteURI()) + MakeNDNLPLinkService(newTransport, MakeNDNLPLinkServiceOptions()).Run(recvBuf[:readSize]) } +} - l.conn.Close() - l.HasQuit <- true +func (l *UDPListener) Close() { + if l.conn != nil { + l.conn.Close() + <-l.stopped + } } diff --git a/fw/face/unicast-tcp-transport.go b/fw/face/unicast-tcp-transport.go index 2572a1cd..19fd369a 100644 --- a/fw/face/unicast-tcp-transport.go +++ b/fw/face/unicast-tcp-transport.go @@ -9,27 +9,37 @@ package face import ( "errors" + "fmt" "net" - "runtime" "strconv" "time" "github.com/named-data/YaNFD/core" defn "github.com/named-data/YaNFD/defn" "github.com/named-data/YaNFD/face/impl" + "github.com/zjkmxy/go-ndn/pkg/utils" ) // UnicastTCPTransport is a unicast TCP transport. type UnicastTCPTransport struct { + transportBase + dialer *net.Dialer conn *net.TCPConn localAddr net.TCPAddr remoteAddr net.TCPAddr - transportBase + + // Permanent face reconnection + rechan chan bool + closed bool // (permanently) } -// MakeUnicastTCPTransport creates a new unicast TCP transport. -func MakeUnicastTCPTransport(remoteURI *defn.URI, localURI *defn.URI, persistency Persistency) (*UnicastTCPTransport, error) { +// Makes an outgoing unicast TCP transport. +func MakeUnicastTCPTransport( + remoteURI *defn.URI, + localURI *defn.URI, + persistency Persistency, +) (*UnicastTCPTransport, error) { // Validate URIs. if !remoteURI.IsCanonical() || (remoteURI.Scheme() != "tcp4" && remoteURI.Scheme() != "tcp6") { @@ -39,11 +49,11 @@ func MakeUnicastTCPTransport(remoteURI *defn.URI, localURI *defn.URI, persistenc return nil, errors.New("do not specify localURI for TCP") } + // Construct transport t := new(UnicastTCPTransport) - // All persistencies are accepted. t.makeTransportBase(remoteURI, localURI, persistency, defn.NonLocal, defn.PointToPoint, defn.MaxNDNPacketSize) - t.expirationTime = new(time.Time) - *t.expirationTime = time.Now().Add(tcpLifetime) + t.expirationTime = utils.IdPtr(time.Now().Add(tcpLifetime)) + t.rechan = make(chan bool, 1) // Set scope ip := net.ParseIP(remoteURI.Path()) @@ -54,40 +64,31 @@ func MakeUnicastTCPTransport(remoteURI *defn.URI, localURI *defn.URI, persistenc } // Set local and remote addresses - if localURI != nil { - t.localAddr.IP = net.ParseIP(localURI.Path()) - t.localAddr.Port = int(localURI.Port()) - } else { - t.localAddr.Port = int(TCPUnicastPort) - } + t.localAddr.Port = int(TCPUnicastPort) t.remoteAddr.IP = net.ParseIP(remoteURI.Path()) t.remoteAddr.Port = int(remoteURI.Port()) - // Attempt to "dial" remote URI - var err error // Configure dialer so we can allow address reuse // Fix: for TCP we shouldn't specify the local address. Instead, we should obtain it from system. // Though it succeeds in Windows and MacOS, Linux does not allow this. t.dialer = &net.Dialer{Control: impl.SyscallReuseAddr} - conn, err := t.dialer.Dial(t.remoteURI.Scheme(), net.JoinHostPort(t.remoteURI.Path(), strconv.Itoa(int(t.remoteURI.Port())))) - if err != nil { - return nil, errors.New("Unable to connect to remote endpoint: " + err.Error()) - } - t.conn = conn.(*net.TCPConn) - if localURI == nil { - t.localAddr = *t.conn.LocalAddr().(*net.TCPAddr) - t.localURI = defn.DecodeURIString("tcp://" + t.localAddr.String()) - } + // Do not attempt to connect here at all, since it blocks the main thread + // The cost is that we can't compute the localUri here + // We will attempt to connect in the receive loop instead - t.changeState(defn.Up) - - go t.expirationHandler() + // Fake for filling up the response + t.localURI = defn.DecodeURIString("tcp://127.0.0.1:0") return t, nil } -func AcceptUnicastTCPTransport(remoteConn net.Conn, localURI *defn.URI, persistency Persistency) (*UnicastTCPTransport, error) { +// Accept an incoming unicast TCP transport. +func AcceptUnicastTCPTransport( + remoteConn net.Conn, + localURI *defn.URI, + persistency Persistency, +) (*UnicastTCPTransport, error) { // Construct remote URI var remoteURI *defn.URI remoteAddr := remoteConn.RemoteAddr() @@ -108,11 +109,19 @@ func AcceptUnicastTCPTransport(remoteConn net.Conn, localURI *defn.URI, persiste return nil, err } + // Construct transport t := new(UnicastTCPTransport) - // All persistencies are accepted. t.makeTransportBase(remoteURI, localURI, persistency, defn.NonLocal, defn.PointToPoint, defn.MaxNDNPacketSize) - t.expirationTime = new(time.Time) - *t.expirationTime = time.Now().Add(tcpLifetime) + t.expirationTime = utils.IdPtr(time.Now().Add(tcpLifetime)) + t.rechan = make(chan bool, 1) + + var success bool + t.conn, success = remoteConn.(*net.TCPConn) + if !success { + core.LogError("UnicastTCPTransport", "Specified connection ", remoteConn, " is not a net.TCPConn") + return nil, errors.New("specified connection is not a net.TCPConn") + } + t.running.Store(true) // Set scope ip := net.ParseIP(remoteURI.Path()) @@ -127,45 +136,25 @@ func AcceptUnicastTCPTransport(remoteConn net.Conn, localURI *defn.URI, persiste t.localAddr.IP = net.ParseIP(localURI.Path()) t.localAddr.Port = int(localURI.Port()) } else { - t.localAddr.Port = int(TCPUnicastPort) - } - t.remoteAddr.IP = net.ParseIP(remoteURI.Path()) - t.remoteAddr.Port = int(remoteURI.Port()) - - var success bool - t.conn, success = remoteConn.(*net.TCPConn) - if !success { - core.LogError("UnicastTCPTransport", "Specified connection ", remoteConn, " is not a net.TCPConn") - return nil, errors.New("specified connection is not a net.TCPConn") - } - - if localURI == nil { t.localAddr = *t.conn.LocalAddr().(*net.TCPAddr) - t.localURI = defn.DecodeURIString("tcp://" + t.localAddr.String()) + t.localURI = defn.DecodeURIString(fmt.Sprintf("tcp://%s", &t.localAddr)) } - t.changeState(defn.Up) - - go t.expirationHandler() + t.remoteAddr.IP = net.ParseIP(remoteURI.Path()) + t.remoteAddr.Port = int(remoteURI.Port()) return t, nil } func (t *UnicastTCPTransport) String() string { - return "UnicastTCPTransport, FaceID=" + strconv.FormatUint(t.faceID, 10) + ", RemoteURI=" + t.remoteURI.String() + ", LocalURI=" + t.localURI.String() + return fmt.Sprintf("UnicastTCPTransport, FaceID=%d, RemoteURI=%s, LocalURI=%s", t.faceID, t.remoteURI, t.localURI) } -// SetPersistency changes the persistency of the face. func (t *UnicastTCPTransport) SetPersistency(persistency Persistency) bool { - if persistency == t.persistency { - return true - } - t.persistency = persistency return true } -// GetSendQueueSize returns the current size of the send queue. func (t *UnicastTCPTransport) GetSendQueueSize() uint64 { rawConn, err := t.conn.SyscallConn() if err != nil { @@ -174,113 +163,124 @@ func (t *UnicastTCPTransport) GetSendQueueSize() uint64 { return impl.SyscallGetSocketSendQueueSize(rawConn) } -// onTransportFailure modifies the state of the UnicastTCPTransport to indicate -// a failure in transmission. -func (t *UnicastTCPTransport) onTransportFailure(fromReceive bool) { - switch t.persistency { - case PersistencyPermanent: - // Restart socket - t.conn.Close() - var err error - conn, err := t.dialer.Dial(t.remoteURI.Scheme(), net.JoinHostPort(t.remoteURI.Path(), - strconv.Itoa(int(t.remoteURI.Port())))) - if err != nil { - core.LogError(t, "Unable to connect to remote endpoint: ", err) - } - t.conn = conn.(*net.TCPConn) +// Set the connection and params. +func (t *UnicastTCPTransport) setConn(conn *net.TCPConn) { + t.conn = conn + t.localAddr = *t.conn.LocalAddr().(*net.TCPAddr) + t.localURI = defn.DecodeURIString("tcp://" + t.localAddr.String()) +} - if fromReceive { - t.runReceive() - } else { - // Old receive thread will error out, so we need to replace it - go t.runReceive() - } - default: - t.changeState(defn.Down) +// Attempt to reconnect to the remote transport. +func (t *UnicastTCPTransport) reconnect() { + // Shut down the existing socket + if t.conn != nil { + t.conn.Close() } -} -// expirationHandler checks if the face should expire (if on demand) -func (t *UnicastTCPTransport) expirationHandler() { + // Number of attempts we have made so far + attempt := 0 + + // Keep trying to reconnect until successful + // If the transport is not permanent, do not attempt to restart + // Do this inside the loop to account for changes to persistency for { - time.Sleep(time.Duration(10) * time.Second) - if t.state == defn.Down { - break + attempt++ + + // If there is no connection, this is the initial attempt to + // connect for any face, so we will continue regardless + // However, make only one attempt to connect for non-permanent faces + if !(t.conn == nil && attempt == 1) { + // Do not continue if the transport is not permanent or closed + if t.Persistency() != PersistencyPermanent || t.closed { + t.rechan <- false // do not continue + return + } } - if t.persistency == PersistencyOnDemand && (t.expirationTime.Before(time.Now()) || t.expirationTime.Equal(time.Now())) { - core.LogInfo(t, "Face expired") - t.changeState(defn.Down) - break + + // Restart socket for permanent transport + remote := net.JoinHostPort(t.remoteURI.Path(), strconv.Itoa(int(t.remoteURI.Port()))) + conn, err := t.dialer.Dial(t.remoteURI.Scheme(), remote) + if err != nil { + core.LogWarn(t, "Unable to connect to remote endpoint [", attempt, "]: ", err) + time.Sleep(5 * time.Second) // TODO: configurable + continue + } + + // If the transport was closed while we were trying to reconnect, + // close the new connection and return without notifying + if t.closed { + conn.Close() + return } + + // Connected to remote again + t.setConn(conn.(*net.TCPConn)) + t.rechan <- true // continue + return } } func (t *UnicastTCPTransport) sendFrame(frame []byte) { + if !t.running.Load() { + return + } + if len(frame) > t.MTU() { core.LogWarn(t, "Attempted to send frame larger than MTU - DROP") return } - core.LogDebug(t, "Sending frame of size ", len(frame)) _, err := t.conn.Write(frame) if err != nil { core.LogWarn(t, "Unable to send on socket - DROP") - t.onTransportFailure(false) + t.CloseConn() // receive might restart if needed return } + t.nOutBytes += uint64(len(frame)) *t.expirationTime = time.Now().Add(tcpLifetime) } func (t *UnicastTCPTransport) runReceive() { - if lockThreadsToCores { - runtime.LockOSThread() - } + defer t.Close() - recvBuf := make([]byte, defn.MaxNDNPacketSize) for { - readSize, err := t.conn.Read(recvBuf) - if err != nil { - if err.Error() == "EOF" { - core.LogDebug(t, "EOF - Face DOWN") - } else { - core.LogWarn(t, "Unable to read from socket (", err, ") - DROP") - t.onTransportFailure(true) + // The connection can be nil if the initial connection attempt + // failed for a persistent face. In that case we will reconnect. + if t.conn != nil { + err := readTlvStream(t.conn, func(b []byte) { + t.nInBytes += uint64(len(b)) + *t.expirationTime = time.Now().Add(tcpLifetime) + t.linkService.handleIncomingFrame(b) + }, nil) + if err == nil { + break // EOF } - t.changeState(defn.Down) - break - } - core.LogTrace(t, "Receive of size ", readSize) - t.nInBytes += uint64(readSize) - *t.expirationTime = time.Now().Add(tcpLifetime) + core.LogWarn(t, "Unable to read from socket (", err, ") - Face DOWN") + } - if readSize > defn.MaxNDNPacketSize { - core.LogWarn(t, "Received too much data without valid TLV block - DROP") - continue + // Persistent faces will reconnect, otherwise close + go t.reconnect() + if !<-t.rechan { + return // do not continue } - // Send up to link service - t.linkService.handleIncomingFrame(recvBuf[:readSize]) + core.LogInfo(t, "Connected socket - Face UP") + t.running.Store(true) } } -func (t *UnicastTCPTransport) changeState(new defn.State) { - if t.state == new { - return - } - - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new - - if t.state != defn.Up { - core.LogInfo(t, "Closing TCP socket") - t.hasQuit <- true +// Close the inner connection if running without closing the transport. +func (t *UnicastTCPTransport) CloseConn() { + if t.running.Swap(false) { t.conn.Close() - - // Stop link service - t.linkService.tellTransportQuit() - - FaceTable.Remove(t.faceID) } } + +// Close the connection permanently - this will not attempt to reconnect. +func (t *UnicastTCPTransport) Close() { + t.closed = true + t.rechan <- false + t.CloseConn() +} diff --git a/fw/face/unicast-udp-transport.go b/fw/face/unicast-udp-transport.go index e29a9b64..a43dc60b 100644 --- a/fw/face/unicast-udp-transport.go +++ b/fw/face/unicast-udp-transport.go @@ -9,9 +9,10 @@ package face import ( "errors" + "fmt" "net" - "runtime" "strconv" + "strings" "time" "github.com/named-data/YaNFD/core" @@ -30,7 +31,9 @@ type UnicastUDPTransport struct { // MakeUnicastUDPTransport creates a new unicast UDP transport. func MakeUnicastUDPTransport( - remoteURI *defn.URI, localURI *defn.URI, persistency Persistency, + remoteURI *defn.URI, + localURI *defn.URI, + persistency Persistency, ) (*UnicastUDPTransport, error) { // Validate URIs if !remoteURI.IsCanonical() || (remoteURI.Scheme() != "udp4" && remoteURI.Scheme() != "udp6") || @@ -38,8 +41,8 @@ func MakeUnicastUDPTransport( return nil, core.ErrNotCanonical } + // Construct transport t := new(UnicastUDPTransport) - // All persistencies are accepted t.makeTransportBase(remoteURI, localURI, persistency, defn.NonLocal, defn.PointToPoint, defn.MaxNDNPacketSize) t.expirationTime = new(time.Time) *t.expirationTime = time.Now().Add(udpLifetime) @@ -62,45 +65,36 @@ func MakeUnicastUDPTransport( t.remoteAddr.IP = net.ParseIP(remoteURI.Path()) t.remoteAddr.Port = int(remoteURI.Port()) - // Attempt to "dial" remote URI - var err error // Configure dialer so we can allow address reuse + // Unlike TCP, we don't need to do this in a separate goroutine because + // we don't need to wait for the connection to be established t.dialer = &net.Dialer{LocalAddr: &t.localAddr, Control: impl.SyscallReuseAddr} - conn, err := t.dialer.Dial(t.remoteURI.Scheme(), - net.JoinHostPort(t.remoteURI.Path(), strconv.Itoa(int(t.remoteURI.Port())))) + remote := net.JoinHostPort(t.remoteURI.Path(), strconv.Itoa(int(t.remoteURI.Port()))) + conn, err := t.dialer.Dial(t.remoteURI.Scheme(), remote) if err != nil { return nil, errors.New("Unable to connect to remote endpoint: " + err.Error()) } + t.conn = conn.(*net.UDPConn) + t.running.Store(true) if localURI == nil { t.localAddr = *t.conn.LocalAddr().(*net.UDPAddr) t.localURI = defn.DecodeURIString("udp://" + t.localAddr.String()) } - t.changeState(defn.Up) - - go t.expirationHandler() - return t, nil } func (t *UnicastUDPTransport) String() string { - return "UnicastUDPTransport, FaceID=" + strconv.FormatUint(t.faceID, 10) + - ", RemoteURI=" + t.remoteURI.String() + ", LocalURI=" + t.localURI.String() + return fmt.Sprintf("UnicastUDPTransport, FaceID=%d, RemoteURI=%s, LocalURI=%s", t.faceID, t.remoteURI, t.localURI) } -// SetPersistency changes the persistency of the face. func (t *UnicastUDPTransport) SetPersistency(persistency Persistency) bool { - if persistency == t.persistency { - return true - } - t.persistency = persistency return true } -// GetSendQueueSize returns the current size of the send queue. func (t *UnicastUDPTransport) GetSendQueueSize() uint64 { rawConn, err := t.conn.SyscallConn() if err != nil { @@ -109,111 +103,46 @@ func (t *UnicastUDPTransport) GetSendQueueSize() uint64 { return impl.SyscallGetSocketSendQueueSize(rawConn) } -func (t *UnicastUDPTransport) onTransportFailure(fromReceive bool) { - switch t.persistency { - case PersistencyPermanent: - // Restart socket - t.conn.Close() - var err error - conn, err := t.dialer.Dial(t.remoteURI.Scheme(), - net.JoinHostPort(t.remoteURI.Path(), strconv.Itoa(int(t.remoteURI.Port())))) - if err != nil { - core.LogError(t, "Unable to connect to remote endpoint: ", err) - } - t.conn = conn.(*net.UDPConn) - - if fromReceive { - t.runReceive() - } else { - // Old receive thread will error out, so we need to replace it - go t.runReceive() - } - default: - t.changeState(defn.Down) - } -} - -// expirationHandler checks if the face should expire (if on demand) -func (t *UnicastUDPTransport) expirationHandler() { - for { - time.Sleep(time.Duration(10) * time.Second) - if t.state == defn.Down { - break - } - if t.persistency == PersistencyOnDemand && (t.expirationTime.Before(time.Now()) || - t.expirationTime.Equal(time.Now())) { - core.LogInfo(t, "Face expired") - t.changeState(defn.Down) - break - } +func (t *UnicastUDPTransport) sendFrame(frame []byte) { + if !t.running.Load() { + return } -} -func (t *UnicastUDPTransport) sendFrame(frame []byte) { if len(frame) > t.MTU() { core.LogWarn(t, "Attempted to send frame larger than MTU - DROP") return } - core.LogDebug(t, "Sending frame of size ", len(frame)) _, err := t.conn.Write(frame) if err != nil { - core.LogWarn(t, "Unable to send on socket - DROP") - t.onTransportFailure(false) + core.LogWarn(t, "Unable to send on socket - DROP and Face DOWN") + t.Close() return } + t.nOutBytes += uint64(len(frame)) *t.expirationTime = time.Now().Add(udpLifetime) } func (t *UnicastUDPTransport) runReceive() { - if lockThreadsToCores { - runtime.LockOSThread() - } + defer t.Close() - recvBuf := make([]byte, defn.MaxNDNPacketSize) - for { - readSize, err := t.conn.Read(recvBuf) - if err != nil { - if err.Error() == "EOF" { - core.LogDebug(t, "EOF") - } else { - core.LogWarn(t, "Unable to read from socket (", err, ") - DROP") - t.onTransportFailure(true) - } - break - } - - core.LogTrace(t, "Receive of size ", readSize) - t.nInBytes += uint64(readSize) + err := readTlvStream(t.conn, func(b []byte) { + t.nInBytes += uint64(len(b)) *t.expirationTime = time.Now().Add(udpLifetime) - - if readSize > defn.MaxNDNPacketSize { - core.LogWarn(t, "Received too much data without valid TLV block - DROP") - continue - } - - // Send up to link service - t.linkService.handleIncomingFrame(recvBuf[:readSize]) + t.linkService.handleIncomingFrame(b) + }, func(err error) bool { + // Ignore since UDP is a connectionless protocol + // This happens if the other side is not listening (ICMP) + return strings.Contains(err.Error(), "connection refused") + }) + if err != nil { + core.LogWarn(t, "Unable to read from socket (", err, ") - Face DOWN") } } -func (t *UnicastUDPTransport) changeState(new defn.State) { - if t.state == new { - return - } - - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new - - if t.state != defn.Up { - core.LogInfo(t, "Closing UDP socket") - t.hasQuit <- true +func (t *UnicastUDPTransport) Close() { + if t.running.Swap(false) { t.conn.Close() - - // Stop link service - t.linkService.tellTransportQuit() - - FaceTable.Remove(t.faceID) } } diff --git a/fw/face/unix-stream-listener.go b/fw/face/unix-stream-listener.go index ea8a96e8..bbc5ae67 100644 --- a/fw/face/unix-stream-listener.go +++ b/fw/face/unix-stream-listener.go @@ -8,6 +8,7 @@ package face import ( + "errors" "net" "os" "path" @@ -21,7 +22,7 @@ type UnixStreamListener struct { conn net.Listener localURI *defn.URI nextFD int // We can't (at least easily) access the actual FD through net.Conn, so we'll make our own - HasQuit chan bool + stopped chan bool } // MakeUnixStreamListener constructs a UnixStreamListener. @@ -31,20 +32,20 @@ func MakeUnixStreamListener(localURI *defn.URI) (*UnixStreamListener, error) { return nil, core.ErrNotCanonical } - l := &UnixStreamListener{ + return &UnixStreamListener{ localURI: localURI, nextFD: 1, - HasQuit: make(chan bool, 1), - } - return l, nil + stopped: make(chan bool, 1), + }, nil } func (l *UnixStreamListener) String() string { return "UnixStreamListener, " + l.localURI.String() } -// Run starts the Unix stream listener. func (l *UnixStreamListener) Run() { + defer func() { l.stopped <- true }() + // Delete any existing socket os.Remove(l.localURI.Path()) @@ -67,18 +68,16 @@ func (l *UnixStreamListener) Run() { core.LogInfo(l, "Listening") // Run accept loop - for { + for !core.ShouldQuit { newConn, err := l.conn.Accept() if err != nil { - if err.Error() == "EOF" { - // Must have failed due to being closed, so quit quietly - } else { - core.LogWarn(l, "Unable to accept connection: ", err) + if errors.Is(err, net.ErrClosed) { + return } - break + core.LogWarn(l, "Unable to accept connection: ", err) + return } - // Construct remote URI remoteURI := defn.MakeFDFaceURI(l.nextFD) l.nextFD++ if !remoteURI.IsCanonical() { @@ -91,24 +90,17 @@ func (l *UnixStreamListener) Run() { core.LogError(l, "Failed to create new Unix stream transport: ", err) continue } - newLinkService := MakeNDNLPLinkService(newTransport, MakeNDNLPLinkServiceOptions()) - if newLinkService == nil { - core.LogError(l, "Failed to create new NDNLPv2 transport: ", err) - continue - } core.LogInfo(l, "Accepting new Unix stream face ", remoteURI) - - // Add face to table and start its thread - FaceTable.Add(newLinkService) - go newLinkService.Run(nil) + options := MakeNDNLPLinkServiceOptions() + options.IsFragmentationEnabled = false // reliable stream + MakeNDNLPLinkService(newTransport, options).Run(nil) } - - l.HasQuit <- true } -// Close closes the UnixStreamListener. func (l *UnixStreamListener) Close() { - core.LogInfo(l, "Stopping listener") - l.conn.Close() + if l.conn != nil { + l.conn.Close() + <-l.stopped + } } diff --git a/fw/face/unix-stream-transport.go b/fw/face/unix-stream-transport.go index 3f1b29c2..c4a82e77 100644 --- a/fw/face/unix-stream-transport.go +++ b/fw/face/unix-stream-transport.go @@ -8,23 +8,17 @@ package face import ( - "bufio" - "io" + "fmt" "net" - "runtime" - "strconv" "github.com/named-data/YaNFD/core" defn "github.com/named-data/YaNFD/defn" "github.com/named-data/YaNFD/face/impl" - - enc "github.com/zjkmxy/go-ndn/pkg/encoding" ) // UnixStreamTransport is a Unix stream transport for communicating with local applications. type UnixStreamTransport struct { - conn *net.UnixConn - reader *bufio.Reader + conn *net.UnixConn transportBase } @@ -40,16 +34,13 @@ func MakeUnixStreamTransport(remoteURI *defn.URI, localURI *defn.URI, conn net.C // Set connection t.conn = conn.(*net.UnixConn) - t.reader = bufio.NewReaderSize(t.conn, 32*defn.MaxNDNPacketSize) - - t.changeState(defn.Up) + t.running.Store(true) return t, nil } func (t *UnixStreamTransport) String() string { - return "UnixStreamTransport, FaceID=" + strconv.FormatUint(t.faceID, 10) + - ", RemoteURI=" + t.remoteURI.String() + ", LocalURI=" + t.localURI.String() + return fmt.Sprintf("UnixStreamTransport, FaceID=%d, RemoteURI=%s, LocalURI=%s", t.faceID, t.remoteURI, t.localURI) } // SetPersistency changes the persistency of the face. @@ -76,83 +67,39 @@ func (t *UnixStreamTransport) GetSendQueueSize() uint64 { } func (t *UnixStreamTransport) sendFrame(frame []byte) { + if !t.running.Load() { + return + } + if len(frame) > t.MTU() { core.LogWarn(t, "Attempted to send frame larger than MTU - DROP") return } - core.LogDebug(t, "Sending frame of size ", len(frame)) _, err := t.conn.Write(frame) if err != nil { core.LogWarn(t, "Unable to send on socket - DROP and Face DOWN") - t.changeState(defn.Down) + t.Close() + return } t.nOutBytes += uint64(len(frame)) } func (t *UnixStreamTransport) runReceive() { - core.LogTrace(t, "Starting receive thread") - - if lockThreadsToCores { - runtime.LockOSThread() - } + defer t.Close() - handleError := func(err error) { - if err.Error() == "EOF" { - core.LogDebug(t, "EOF - Face DOWN") - } else { - core.LogWarn(t, "Unable to read from socket (", err, ") - DROP and Face DOWN") - } - t.changeState(defn.Down) - } - - recvBuf := make([]byte, defn.MaxNDNPacketSize) - for { - typ, err := enc.ReadTLNum(t.reader) - if err != nil { - handleError(err) - break - } - - len, err := enc.ReadTLNum(t.reader) - if err != nil { - handleError(err) - break - } - - cursor := 0 - cursor += typ.EncodeInto(recvBuf[cursor:]) - cursor += len.EncodeInto(recvBuf[cursor:]) - - lenRead, err := io.ReadFull(t.reader, recvBuf[cursor:cursor+int(len)]) - if err != nil { - handleError(err) - break - } - cursor += lenRead - - t.linkService.handleIncomingFrame(recvBuf[:cursor]) - t.nInBytes += uint64(cursor) + err := readTlvStream(t.conn, func(b []byte) { + t.nInBytes += uint64(len(b)) + t.linkService.handleIncomingFrame(b) + }, nil) + if err != nil { + core.LogWarn(t, "Unable to read from socket (", err, ") - Face DOWN") } } -func (t *UnixStreamTransport) changeState(new defn.State) { - if t.state == new { - return - } - - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new - - if t.state != defn.Up { - core.LogInfo(t, "Closing Unix stream socket") - t.hasQuit <- true +func (t *UnixStreamTransport) Close() { + if t.running.Swap(false) { t.conn.Close() - - // Stop link service - t.linkService.tellTransportQuit() - - FaceTable.Remove(t.faceID) } } diff --git a/fw/face/web-socket-listener.go b/fw/face/web-socket-listener.go index e2b361b0..d2727ceb 100644 --- a/fw/face/web-socket-listener.go +++ b/fw/face/web-socket-listener.go @@ -33,7 +33,13 @@ type WebSocketListenerConfig struct { TLSKey string } -// URL returns server URL. +// WebSocketListener listens for incoming WebSockets connections. +type WebSocketListener struct { + server http.Server + upgrader websocket.Upgrader + localURI *defn.URI +} + func (cfg WebSocketListenerConfig) URL() *url.URL { addr := net.JoinHostPort(cfg.Bind, strconv.FormatUint(uint64(cfg.Port), 10)) u := &url.URL{ @@ -55,7 +61,6 @@ func (cfg WebSocketListenerConfig) String() string { return b.String() } -// NewWebSocketListener constructs a WebSocketListener. func NewWebSocketListener(cfg WebSocketListenerConfig) (*WebSocketListener, error) { localURI := cfg.URL() ret := &WebSocketListener{ @@ -80,31 +85,21 @@ func NewWebSocketListener(cfg WebSocketListenerConfig) (*WebSocketListener, erro return ret, nil } -// WebSocketListener listens for incoming WebSockets connections. -type WebSocketListener struct { - server http.Server - upgrader websocket.Upgrader - localURI *defn.URI -} - -var _ Listener = &WebSocketListener{} - func (l *WebSocketListener) String() string { return "WebSocketListener, " + l.localURI.String() } -// Run starts the WebSocket listener. func (l *WebSocketListener) Run() { l.server.Handler = http.HandlerFunc(l.handler) - var e error + var err error if l.server.TLSConfig == nil { - e = l.server.ListenAndServe() + err = l.server.ListenAndServe() } else { - e = l.server.ListenAndServeTLS("", "") + err = l.server.ListenAndServeTLS("", "") } - if !errors.Is(e, http.ErrServerClosed) { - core.LogFatal(l, "Unable to start listener: ", e) + if !errors.Is(err, http.ErrServerClosed) { + core.LogFatal(l, "Unable to start listener: ", err) } } @@ -114,15 +109,14 @@ func (l *WebSocketListener) handler(w http.ResponseWriter, r *http.Request) { return } - t := NewWebSocketTransport(l.localURI, c) - linkService := MakeNDNLPLinkService(t, MakeNDNLPLinkServiceOptions()) + newTransport := NewWebSocketTransport(l.localURI, c) + core.LogInfo(l, "Accepting new WebSocket face ", newTransport.RemoteURI()) - core.LogInfo(l, "Accepting new WebSocket face ", t.RemoteURI()) - FaceTable.Add(linkService) - go linkService.Run(nil) + options := MakeNDNLPLinkServiceOptions() + options.IsFragmentationEnabled = false // reliable stream + MakeNDNLPLinkService(newTransport, options).Run(nil) } -// Close closes the WebSocketListener. func (l *WebSocketListener) Close() { core.LogInfo(l, "Stopping listener") l.server.Shutdown(context.TODO()) diff --git a/fw/face/web-socket-transport.go b/fw/face/web-socket-transport.go index 4f4a6553..24e118f9 100644 --- a/fw/face/web-socket-transport.go +++ b/fw/face/web-socket-transport.go @@ -8,9 +8,8 @@ package face import ( + "fmt" "net" - "runtime" - "strconv" "github.com/gorilla/websocket" "github.com/named-data/YaNFD/core" @@ -23,12 +22,8 @@ type WebSocketTransport struct { c *websocket.Conn } -var _ transport = &WebSocketTransport{} - -// NewWebSocketTransport creates a Unix stream transport. func NewWebSocketTransport(localURI *defn.URI, c *websocket.Conn) (t *WebSocketTransport) { remoteURI := defn.MakeWebSocketClientFaceURI(c.RemoteAddr()) - t = &WebSocketTransport{c: c} scope := defn.NonLocal ip := net.ParseIP(remoteURI.PathHost()) @@ -36,55 +31,59 @@ func NewWebSocketTransport(localURI *defn.URI, c *websocket.Conn) (t *WebSocketT scope = defn.Local } + t = &WebSocketTransport{c: c} t.makeTransportBase(remoteURI, localURI, PersistencyOnDemand, scope, defn.PointToPoint, defn.MaxNDNPacketSize) - t.changeState(defn.Up) + t.running.Store(true) + return t } func (t *WebSocketTransport) String() string { - return "WebSocketTransport, FaceID=" + strconv.FormatUint(t.faceID, 10) + - ", RemoteURI=" + t.remoteURI.String() + ", LocalURI=" + t.localURI.String() + return fmt.Sprintf("WebSocketTransport, FaceID=%d, RemoteURI=%s, LocalURI=%s", t.faceID, t.remoteURI, t.localURI) } -// SetPersistency changes the persistency of the face. func (t *WebSocketTransport) SetPersistency(persistency Persistency) bool { return persistency == PersistencyOnDemand } -// GetSendQueueSize returns the current size of the send queue. func (t *WebSocketTransport) GetSendQueueSize() uint64 { return 0 } func (t *WebSocketTransport) sendFrame(frame []byte) { + if !t.running.Load() { + return + } + if len(frame) > t.MTU() { core.LogWarn(t, "Attempted to send frame larger than MTU - DROP") return } - core.LogDebug(t, "Sending frame of size ", len(frame)) e := t.c.WriteMessage(websocket.BinaryMessage, frame) if e != nil { core.LogWarn(t, "Unable to send on socket - DROP and Face DOWN") - t.changeState(defn.Down) + t.Close() + return } t.nOutBytes += uint64(len(frame)) } func (t *WebSocketTransport) runReceive() { - core.LogTrace(t, "Starting receive thread") - - if lockThreadsToCores { - runtime.LockOSThread() - } + defer t.Close() for { mt, message, e := t.c.ReadMessage() if e != nil { - core.LogWarn(t, "Unable to read from socket (", e, ") - DROP and Face DOWN") - t.changeState(defn.Down) - break + if websocket.IsCloseError(e) { + // gracefully closed + } else if websocket.IsUnexpectedCloseError(e) { + core.LogInfo(t, "WebSocket closed unexpectedly (", e, ") - DROP and Face DOWN") + } else { + core.LogWarn(t, "Unable to read from WebSocket (", e, ") - DROP and Face DOWN") + } + return } if mt != websocket.BinaryMessage { @@ -92,35 +91,17 @@ func (t *WebSocketTransport) runReceive() { continue } - core.LogTrace(t, "Receive of size ", len(message)) - t.nInBytes += uint64(len(message)) - if len(message) > defn.MaxNDNPacketSize { core.LogWarn(t, "Received too much data without valid TLV block - DROP") continue } - // Send up to link service + t.nInBytes += uint64(len(message)) t.linkService.handleIncomingFrame(message) } } -func (t *WebSocketTransport) changeState(new defn.State) { - if t.state == new { - return - } - - core.LogInfo(t, "state: ", t.state, " -> ", new) - t.state = new - - if t.state != defn.Up { - core.LogInfo(t, "Closing Unix stream socket") - t.hasQuit <- true - t.c.Close() - - // Stop link service - t.linkService.tellTransportQuit() - - FaceTable.Remove(t.faceID) - } +func (t *WebSocketTransport) Close() { + t.running.Store(false) + t.c.Close() } diff --git a/fw/fw/thread.go b/fw/fw/thread.go index f82cbc58..96a936db 100644 --- a/fw/fw/thread.go +++ b/fw/fw/thread.go @@ -391,10 +391,9 @@ func (t *Thread) processIncomingData(packet *defn.Pkt) { // Get PIT if present var pitToken *uint32 - if len(packet.PitToken) == 6 { - pitToken = new(uint32) - // We have already guaranteed that, if a PIT token is present, it is 6 bytes long - *pitToken = binary.BigEndian.Uint32(packet.PitToken[2:6]) + //lint:ignore S1009 removing the nil check causes a segfault ¯\_(ツ)_/¯ + if packet.PitToken != nil && len(packet.PitToken) == 6 { + pitToken = utils.IdPtr(binary.BigEndian.Uint32(packet.PitToken[2:6])) } // Get incoming face diff --git a/fw/mgmt/face.go b/fw/mgmt/face.go index e5c5312a..47d99257 100644 --- a/fw/mgmt/face.go +++ b/fw/mgmt/face.go @@ -201,10 +201,7 @@ func (f *FaceModule) create(interest *spec.Interest, pitToken []byte, inFace uin } linkService = face.MakeNDNLPLinkService(transport, options) - face.FaceTable.Add(linkService) - - // Start new face - go linkService.Run(nil) + linkService.Run(nil) } else if URI.Scheme() == "tcp4" || URI.Scheme() == "tcp6" { // Check that remote endpoint is not a unicast address if remoteAddr := net.ParseIP(URI.Path()); remoteAddr != nil && !remoteAddr.IsGlobalUnicast() && @@ -256,6 +253,7 @@ func (f *FaceModule) create(interest *spec.Interest, pitToken []byte, inFace uin // NDNLP link service parameters options := face.MakeNDNLPLinkServiceOptions() + options.IsFragmentationEnabled = false // reliable stream if params.Flags != nil { // Mask already guaranteed to be present if Flags is above flags := *params.Flags @@ -284,10 +282,7 @@ func (f *FaceModule) create(interest *spec.Interest, pitToken []byte, inFace uin } linkService = face.MakeNDNLPLinkService(transport, options) - face.FaceTable.Add(linkService) - - // Start new face - go linkService.Run(nil) + linkService.Run(nil) } else { // Unsupported scheme core.LogWarn(f, "Cannot create face with URI ", URI, ": Unsupported scheme ", URI) diff --git a/fw/table/pit-cs.go b/fw/table/pit-cs.go index 8899e9cd..e1aab94f 100644 --- a/fw/table/pit-cs.go +++ b/fw/table/pit-cs.go @@ -133,7 +133,7 @@ func (bpe *basePitEntry) InsertInRecord( record.LatestTimestamp = time.Now() record.LatestInterest = interest.NameV.Clone() record.ExpirationTime = time.Now().Add(time.Millisecond * 4000) - record.PitToken = incomingPitToken + record.PitToken = append([]byte{}, incomingPitToken...) bpe.inRecords[face] = record return record, false }