restic/internal/backend/smb/conpool.go

239 lines
6.0 KiB
Go

package smb
import (
"context"
"fmt"
"net"
"strconv"
"sync/atomic"
"github.com/hirochachacha/go-smb2"
"github.com/restic/restic/internal/debug"
)
// Parts of this code have been adapted from Rclone (https://github.com/rclone)
// Copyright (C) 2012 by Nick Craig-Wood http://www.craig-wood.com/nick/
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// conn bundles everything needed to talk to the SMB server over one
// connection: the network connection, the SMB session negotiated over it and,
// optionally, the share currently mounted on that session.
type conn struct {
	// conn points at the network connection returned by the dialer in dial.
	conn *net.Conn
	// smbSession is the SMB session established over conn by dial; it is
	// used to mount shares, list shares and log off.
	smbSession *smb2.Session
	// smbShare is the currently mounted share, or nil when no share is mounted.
	smbShare *smb2.Share
	// shareName is the share name recorded by mountShare ("" for none); it is
	// compared against the requested share to decide whether to remount.
	shareName string
}
// close unmounts the mounted share (if any) and logs off the SMB session.
//
// If Logoff fails, the underlying network connection is closed explicitly so
// the socket is not leaked. NOTE(review): go-smb2's Logoff is expected to
// release the transport itself on success; confirm against the library
// version in use.
//
// The Umount error takes precedence over the Logoff error; nil is returned
// when both steps succeed.
func (c *conn) close() (err error) {
	if c.smbShare != nil {
		err = c.smbShare.Umount()
		c.smbShare = nil
	}
	c.shareName = ""

	logoffErr := c.smbSession.Logoff()
	if logoffErr != nil && c.conn != nil {
		// Nothing else owns the socket once logoff has failed; drop it here.
		_ = (*c.conn).Close()
	}

	if err != nil {
		return err
	}
	return logoffErr
}
// closed reports whether the connection is no longer usable.
//
// It probes the server with a cheap request: a Stat of "." on the mounted
// share, or a share listing on the session when no share is mounted. Any
// probe error means the connection is considered closed; a successful probe
// means it is still alive.
func (c *conn) closed() bool {
	var probeErr error
	if c.smbShare != nil {
		_, probeErr = c.smbShare.Stat(".")
	} else {
		_, probeErr = c.smbSession.ListSharenames()
	}
	return probeErr != nil
}
// addSession marks one more SMB session as being in use.
//
// Every call must eventually be balanced by a call to removeSession.
func (b *Backend) addSession() {
	atomic.AddInt32(&b.sessions, +1)
}

// removeSession marks a session previously registered with addSession as no
// longer in use.
func (b *Backend) removeSession() {
	atomic.AddInt32(&b.sessions, -1)
}

// getSessions returns the number of SMB sessions currently marked in use.
// drainPool refuses to close pooled connections while this is non-zero.
func (b *Backend) getSessions() int32 {
	active := atomic.LoadInt32(&b.sessions)
	return active
}
// dial connects to addr over the given network and performs the SMB
// negotiate and NTLM session-setup handshake on that connection. The
// returned conn holds the new session and has no share mounted.
//
// ctx bounds both the network dial and the SMB handshake. If the handshake
// fails, the freshly dialed network connection is closed before returning so
// it does not leak.
func (b *Backend) dial(ctx context.Context, network, addr string) (*conn, error) {
	var dialer net.Dialer
	tconn, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	// NOTE(review): the client GUID is filled with the raw bytes of the
	// configured string (truncated or zero-padded to 16 bytes), not parsed
	// from a textual GUID; confirm this matches the option's documented format.
	var clientID [16]byte
	copy(clientID[:], b.ClientGUID)

	d := &smb2.Dialer{
		Negotiator: smb2.Negotiator{
			RequireMessageSigning: b.RequireMessageSigning,
			SpecifiedDialect:      b.Dialect,
			ClientGuid:            clientID,
		},
		Initiator: &smb2.NTLMInitiator{
			User:     b.User,
			Password: b.Password.Unwrap(),
			Domain:   b.Domain,
		},
	}

	session, err := d.DialContext(ctx, tconn)
	if err != nil {
		// No session was created, so nothing else will close tconn.
		_ = tconn.Close()
		return nil, err
	}

	return &conn{
		smbSession: session,
		conn:       &tconn,
	}, nil
}
// newConnection opens a brand-new connection to the configured SMB server
// and, when share is non-empty, mounts that share on it and records it as the
// connection's current share.
//
// The connection is dialed with context.Background() because pooled
// connections outlive the request that created them. If mounting fails, the
// session is torn down and an error is returned.
func (b *Backend) newConnection(share string) (c *conn, err error) {
	// As we are pooling these connections we need to decouple
	// them from the current context.
	ctx := context.Background()
	c, err = b.dial(ctx, "tcp", b.Host+":"+strconv.Itoa(b.Port))
	if err != nil {
		return nil, fmt.Errorf("couldn't connect SMB: %w", err)
	}
	if share == "" {
		return c, nil
	}

	mounted, err := c.smbSession.Mount(share)
	if err != nil {
		// No share is attached to c yet, so close only logs off the session.
		_ = c.close()
		return nil, fmt.Errorf("couldn't initialize SMB: %w", err)
	}
	c.smbShare = mounted.WithContext(ctx)
	// Record the mounted share so mountShare can tell later whether a
	// remount (or an unmount) is needed when this connection is reused.
	c.shareName = share
	return c, nil
}
// mountShare switches the connection so that exactly share is mounted, or so
// that no share is mounted when share is "". It is a no-op when the recorded
// share already matches and the mount state agrees with it.
//
// On any error the connection is left with no share mounted and shareName
// cleared, so its recorded state never claims a share that is not mounted.
func (c *conn) mountShare(share string) (err error) {
	// Trust shareName only if it agrees with whether a share is actually
	// attached; this also repairs connections whose share was mounted without
	// its name being recorded.
	if c.shareName == share && (c.smbShare != nil) == (share != "") {
		return nil
	}

	// The recorded share is stale from here on; clear it first so every early
	// return below leaves shareName consistent with smbShare.
	c.shareName = ""
	if c.smbShare != nil {
		err = c.smbShare.Umount()
		c.smbShare = nil
		if err != nil {
			return err
		}
	}
	if share == "" {
		return nil
	}

	mounted, err := c.smbSession.Mount(share)
	if err != nil {
		return err
	}
	c.smbShare = mounted
	c.shareName = share
	return nil
}
// getConnection returns a connection with share mounted (or with no share
// mounted when share is ""). It reuses a pooled connection when one can be
// switched to share, and opens a new one otherwise.
//
// Pooled connections that fail to switch are closed rather than just
// dropped, so their sessions and sockets are released. Those closes run after
// poolMu is released to keep the time spent holding the lock short.
func (b *Backend) getConnection(_ context.Context, share string) (c *conn, err error) {
	var unusable []*conn

	b.poolMu.Lock()
	for len(b.pool) > 0 {
		candidate := b.pool[0]
		b.pool[0] = nil // don't keep a stale reference in the backing array
		b.pool = b.pool[1:]
		if mountErr := candidate.mountShare(share); mountErr != nil {
			debug.Log("Discarding unusable SMB connection: %v", mountErr)
			unusable = append(unusable, candidate)
			continue
		}
		c = candidate
		break
	}
	b.poolMu.Unlock()

	for _, bad := range unusable {
		// Best effort: the connection is already known to be unusable.
		_ = bad.close()
	}

	if c != nil {
		return c, nil
	}
	return b.newConnection(share)
}
// putConnection hands c back to the pool for reuse.
//
// c is first probed with a cheap request: a Stat of "." on the mounted share,
// or a share listing when no share is mounted. If the probe fails, c is
// closed instead of being pooled. When an idle timeout is configured, the
// pool-emptying timer is pushed back by IdleTimeout.
func (b *Backend) putConnection(c *conn) {
	var probeErr error
	if c.smbShare == nil {
		_, probeErr = c.smbSession.ListSharenames()
	} else {
		_, probeErr = c.smbShare.Stat(".")
	}
	if probeErr != nil {
		debug.Log("Connection failed, closing: %v", probeErr)
		_ = c.close()
		return
	}

	b.poolMu.Lock()
	defer b.poolMu.Unlock()
	b.pool = append(b.pool, c)
	// drainPool only touches b.drain when IdleTimeout > 0, so the timer is
	// only meaningful in that case. NOTE(review): b.drain is presumably nil
	// without an idle timeout; confirm in the backend constructor.
	if b.Config.IdleTimeout > 0 {
		b.drain.Reset(b.Config.IdleTimeout) // nudge on the pool emptying timer
	}
}
// drainPool closes the idle connections in the pool and empties it.
//
// While any SMB session is still marked in use (see getSessions), nothing is
// closed and, when an idle timeout is configured, the drain timer is re-armed
// instead. Connections for which closed() reports true are not closed again.
// If several closes fail, the last error is returned.
func (b *Backend) drainPool() (err error) {
	b.poolMu.Lock()
	defer b.poolMu.Unlock()

	if sessions := b.getSessions(); sessions != 0 {
		debug.Log("Not closing %d unused connections as %d sessions active", len(b.pool), sessions)
		// Guarded like the Stop call below: the timer is only used when an
		// idle timeout is configured.
		if b.Config.IdleTimeout > 0 {
			b.drain.Reset(b.Config.IdleTimeout) // nudge on the pool emptying timer
		}
		return nil
	}

	if b.Config.IdleTimeout > 0 {
		b.drain.Stop()
	}
	if len(b.pool) != 0 {
		debug.Log("Closing %d unused connections", len(b.pool))
	}
	for i, c := range b.pool {
		if !c.closed() {
			if cErr := c.close(); cErr != nil {
				err = cErr
			}
		}
		b.pool[i] = nil
	}
	b.pool = nil
	return err
}