package sftp import ( "encoding" "sort" "sync" ) // The goal of the packetManager is to keep the outgoing packets in the same // order as the incoming. This is due to some sftp clients requiring this // behavior (eg. winscp). type packetSender interface { sendPacket(encoding.BinaryMarshaler) error } type packetManager struct { requests chan requestPacket responses chan responsePacket fini chan struct{} incoming requestPacketIDs outgoing responsePackets sender packetSender // connection object working *sync.WaitGroup } func newPktMgr(sender packetSender) *packetManager { s := &packetManager{ requests: make(chan requestPacket, SftpServerWorkerCount), responses: make(chan responsePacket, SftpServerWorkerCount), fini: make(chan struct{}), incoming: make([]uint32, 0, SftpServerWorkerCount), outgoing: make([]responsePacket, 0, SftpServerWorkerCount), sender: sender, working: &sync.WaitGroup{}, } go s.controller() return s } type responsePackets []responsePacket func (r responsePackets) Sort() { sort.Slice(r, func(i, j int) bool { return r[i].id() < r[j].id() }) } type requestPacketIDs []uint32 func (r requestPacketIDs) Sort() { sort.Slice(r, func(i, j int) bool { return r[i] < r[j] }) } // register incoming packets to be handled // send id of 0 for packets without id func (s *packetManager) incomingPacket(pkt requestPacket) { s.working.Add(1) s.requests <- pkt // buffer == SftpServerWorkerCount } // register outgoing packets as being ready func (s *packetManager) readyPacket(pkt responsePacket) { s.responses <- pkt s.working.Done() } // shut down packetManager controller func (s *packetManager) close() { // pause until current packets are processed s.working.Wait() close(s.fini) } // Passed a worker function, returns a channel for incoming packets. // The goal is to process packets in the order they are received as is // requires by section 7 of the RFC, while maximizing throughput of file // transfers. func (s *packetManager) workerChan(runWorker func(requestChan)) requestChan { rwChan := make(chan requestPacket, SftpServerWorkerCount) for i := 0; i < SftpServerWorkerCount; i++ { runWorker(rwChan) } cmdChan := make(chan requestPacket) runWorker(cmdChan) pktChan := make(chan requestPacket, SftpServerWorkerCount) go func() { // start with cmdChan curChan := cmdChan for pkt := range pktChan { // on file open packet, switch to rwChan switch pkt.(type) { case *sshFxpOpenPacket: curChan = rwChan // on file close packet, switch back to cmdChan // after waiting for any reads/writes to finish case *sshFxpClosePacket: // wait for rwChan to finish s.working.Wait() // stop using rwChan curChan = cmdChan } s.incomingPacket(pkt) curChan <- pkt } close(rwChan) close(cmdChan) s.close() }() return pktChan } // process packets func (s *packetManager) controller() { for { select { case pkt := <-s.requests: debug("incoming id: %v", pkt.id()) s.incoming = append(s.incoming, pkt.id()) if len(s.incoming) > 1 { s.incoming.Sort() } case pkt := <-s.responses: debug("outgoing pkt: %v", pkt.id()) s.outgoing = append(s.outgoing, pkt) if len(s.outgoing) > 1 { s.outgoing.Sort() } case <-s.fini: return } s.maybeSendPackets() } } // send as many packets as are ready func (s *packetManager) maybeSendPackets() { for { if len(s.outgoing) == 0 || len(s.incoming) == 0 { debug("break! -- outgoing: %v; incoming: %v", len(s.outgoing), len(s.incoming)) break } out := s.outgoing[0] in := s.incoming[0] // debug("incoming: %v", s.incoming) // debug("outgoing: %v", outfilter(s.outgoing)) if in == out.id() { s.sender.sendPacket(out) // pop off heads copy(s.incoming, s.incoming[1:]) // shift left s.incoming = s.incoming[:len(s.incoming)-1] // remove last copy(s.outgoing, s.outgoing[1:]) // shift left s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last } else { break } } } //func outfilter(o []responsePacket) []uint32 { // res := make([]uint32, 0, len(o)) // for _, v := range o { // res = append(res, v.id()) // } // return res //}