From 074b201d7da4f47964c661776b4df6fbbf552fad Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Thu, 8 Jan 2015 20:33:06 -0800 Subject: [PATCH 01/41] Reference ScriptRock ssh library, not current google one --- client.go | 3 ++- example_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index f48b0be7..c26a8714 100644 --- a/client.go +++ b/client.go @@ -10,7 +10,8 @@ import ( "github.com/kr/fs" - "golang.org/x/crypto/ssh" + //"golang.org/x/crypto/ssh" + ssh "github.com/ScriptRock/ssh_block" ) // New creates a new SFTP client on conn. diff --git a/example_test.go b/example_test.go index 3f73726d..34dba1fb 100644 --- a/example_test.go +++ b/example_test.go @@ -6,7 +6,8 @@ import ( "os" "os/exec" - "golang.org/x/crypto/ssh" + //"golang.org/x/crypto/ssh" + ssh "github.com/ScriptRock/ssh_block" "github.com/pkg/sftp" ) From 1165da51c711fe7149023e9ae6a3171435330f5e Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Fri, 13 Feb 2015 09:59:47 -0800 Subject: [PATCH 02/41] Updated to use github.com/Scriptrock/crypto/ssh --- client.go | 2 +- example_test.go | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index c26a8714..73cf949f 100644 --- a/client.go +++ b/client.go @@ -11,7 +11,7 @@ import ( "github.com/kr/fs" //"golang.org/x/crypto/ssh" - ssh "github.com/ScriptRock/ssh_block" + "github.com/ScriptRock/crypto/ssh" ) // New creates a new SFTP client on conn. diff --git a/example_test.go b/example_test.go index 34dba1fb..e38ecaf4 100644 --- a/example_test.go +++ b/example_test.go @@ -6,10 +6,8 @@ import ( "os" "os/exec" - //"golang.org/x/crypto/ssh" - ssh "github.com/ScriptRock/ssh_block" - - "github.com/pkg/sftp" + "github.com/ScriptRock/crypto/ssh" + "github.com/ScriptRock/sftp" ) func Example(conn *ssh.Client) { From 5b6348f034e71f1377f88aaff7344ac7c2c86ea1 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sat, 25 Jul 2015 01:19:29 -0700 Subject: [PATCH 03/41] version, lstat --- .gitignore | 6 ++ attrs.go | 115 +++++++++++++++++++- client.go | 7 ++ packet.go | 115 +++++++++++++++++++- packet_test.go | 6 +- server.go | 215 ++++++++++++++++++++++++++++++++++++++ server_standalone/main.go | 15 +++ 7 files changed, 468 insertions(+), 11 deletions(-) create mode 100644 .gitignore create mode 100644 server.go create mode 100644 server_standalone/main.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..00ea0bf2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.*.swo +.*.swp + +testdata/ +server_standalone/server_standalone + diff --git a/attrs.go b/attrs.go index 0d37db08..f7c69cc9 100644 --- a/attrs.go +++ b/attrs.go @@ -10,11 +10,24 @@ import ( ) const ( - ssh_FILEXFER_ATTR_SIZE = 0x00000001 - ssh_FILEXFER_ATTR_UIDGID = 0x00000002 - ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 - ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 - ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 + ssh_FILEXFER_ATTR_SIZE = 0x00000001 + ssh_FILEXFER_ATTR_UIDGID = 0x00000002 + ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 + ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 // protocol version 2 + ssh_FILEXFER_ATTR_ACCESSTIME = 0x00000008 // protocol version 3 onwards + ssh_FILEXFER_ATTR_CREATETIME = 0x00000010 + ssh_FILEXFER_ATTR_MODIFYTIME = 0x00000020 + ssh_FILEXFER_ATTR_ACL = 0x00000040 + ssh_FILEXFER_ATTR_OWNERGROUP = 0x00000080 + ssh_FILEXFER_ATTR_SUBSECOND_TIMES = 0x00000100 + ssh_FILEXFER_ATTR_BITS = 0x00000200 + ssh_FILEXFER_ATTR_ALLOCATION_SIZE = 0x00000400 + ssh_FILEXFER_ATTR_TEXT_HINT = 0x00000800 + ssh_FILEXFER_ATTR_MIME_TYPE = 0x00001000 + ssh_FILEXFER_ATTR_LINK_COUNT = 0x00002000 + ssh_FILEXFER_ATTR_UNTRANSLATED_NAME = 0x00004000 + ssh_FILEXFER_ATTR_CTIME = 0x00008000 + ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 ) // fileInfo is an artificial type designed to satisfy os.FileInfo. @@ -106,6 +119,57 @@ func unmarshalAttrs(b []byte) (*FileStat, []byte) { return &fs, b } +func marshalFileInfo(b []byte, fi os.FileInfo) []byte { + // attributes variable struct, and also variable per protocol version + // spec version 3 attributes: + // uint32 flags + // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE + // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS + // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED + // string extended_type + // string extended_data + // ... more extended data (extended_type - extended_data pairs), + // so that number of pairs equals extended_count + + uid := uint32(0) + gid := uint32(0) + mtime := uint32(fi.ModTime().Unix()) + atime := mtime + + flags := ssh_FILEXFER_ATTR_SIZE | + ssh_FILEXFER_ATTR_PERMISSIONS | + ssh_FILEXFER_ATTR_ACMODTIME | + uint32(0) + + if statt, ok := fi.Sys().(*syscall.Stat_t); ok { + flags |= ssh_FILEXFER_ATTR_UIDGID + uid = statt.Uid + gid = statt.Gid + } + + b = marshalUint32(b, flags) // flags + if flags&ssh_FILEXFER_ATTR_SIZE != 0 { + b = marshalUint64(b, uint64(fi.Size())) // size + } + if flags&ssh_FILEXFER_ATTR_UIDGID != 0 { + b = marshalUint32(b, uid) + b = marshalUint32(b, gid) + } + if flags&ssh_FILEXFER_ATTR_PERMISSIONS != 0 { + b = marshalUint32(b, fromFileMode(fi.Mode())) // permissions + } + if flags&ssh_FILEXFER_ATTR_ACMODTIME != 0 { + b = marshalUint32(b, atime) + b = marshalUint32(b, mtime) + } + + return b +} + // toFileMode converts sftp filemode bits to the os.FileMode specification func toFileMode(mode uint32) os.FileMode { var fm = os.FileMode(mode & 0777) @@ -136,3 +200,44 @@ func toFileMode(mode uint32) os.FileMode { } return fm } + +// fromFileMode converts from the os.FileMode specification to sftp filemode bits +func fromFileMode(mode os.FileMode) uint32 { + ret := uint32(0) + + if mode&os.ModeDevice != 0 { + if mode&os.ModeCharDevice != 0 { + ret |= syscall.S_IFCHR + } else { + ret |= syscall.S_IFBLK + } + } + if mode&os.ModeDir != 0 { + ret |= syscall.S_IFDIR + } + if mode&os.ModeSymlink != 0 { + ret |= syscall.S_IFLNK + } + if mode&os.ModeNamedPipe != 0 { + ret |= syscall.S_IFIFO + } + if mode&os.ModeSetgid != 0 { + ret |= syscall.S_ISGID + } + if mode&os.ModeSetuid != 0 { + ret |= syscall.S_ISUID + } + if mode&os.ModeSticky != 0 { + ret |= syscall.S_ISVTX + } + if mode&os.ModeSocket != 0 { + ret |= syscall.S_IFSOCK + } + + if mode == 0 { + ret |= syscall.S_IFREG + } + ret |= uint32(mode & os.ModePerm) + + return ret +} diff --git a/client.go b/client.go index 73cf949f..f895548c 100644 --- a/client.go +++ b/client.go @@ -689,6 +689,13 @@ func unmarshalStatus(id uint32, data []byte) error { } } +func marshalStatus(b []byte, err StatusError) []byte { + b = marshalUint32(b, err.Code) + b = marshalString(b, err.msg) + b = marshalString(b, err.lang) + return b +} + // flags converts the flags passed to OpenFile into ssh flags. // Unsupported flags are ignored. func flags(f int) uint32 { diff --git a/packet.go b/packet.go index f2c2b671..c131f96f 100644 --- a/packet.go +++ b/packet.go @@ -2,11 +2,14 @@ package sftp import ( "encoding" + "encoding/hex" "fmt" "io" "reflect" ) +var shortPacketError = fmt.Errorf("packet too short") + func marshalUint32(b []byte, v uint32) []byte { return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) } @@ -52,17 +55,45 @@ func unmarshalUint32(b []byte) (uint32, []byte) { return v, b[4:] } +func unmarshalUint32Safe(b []byte) (uint32, []byte, error) { + if len(b) < 4 { + return 0, nil, shortPacketError + } + v := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + return v, b[4:], nil +} + func unmarshalUint64(b []byte) (uint64, []byte) { h, b := unmarshalUint32(b) l, b := unmarshalUint32(b) return uint64(h)<<32 | uint64(l), b } +func unmarshalUint64Safe(b []byte) (uint64, []byte, error) { + if len(b) < 8 { + return 0, nil, shortPacketError + } + h, b := unmarshalUint32(b) + l, b := unmarshalUint32(b) + return uint64(h)<<32 | uint64(l), b, nil +} + func unmarshalString(b []byte) (string, []byte) { n, b := unmarshalUint32(b) return string(b[:n]), b[n:] } +func unmarshalStringSafe(b []byte) (string, []byte, error) { + n, b, err := unmarshalUint32Safe(b) + if err != nil { + return "", nil, err + } + if int64(n) > int64(len(b)) { + return "", nil, shortPacketError + } + return string(b[:n]), b[n:], nil +} + // sendPacket marshals p according to RFC 4234. func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { @@ -72,7 +103,7 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { } l := uint32(len(bb)) hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} - debug("send packet %T, len: %v", m, l) + debug("send packet %T, len: %v data: %v", m, l, hex.EncodeToString(bb)) _, err = w.Write(hdr) if err != nil { return err @@ -81,39 +112,101 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { return err } +func (svr *Server) sendPacket(m encoding.BinaryMarshaler) error { + return sendPacket(svr.out, m) +} + func recvPacket(r io.Reader) (uint8, []byte, error) { var b = []byte{0, 0, 0, 0} if _, err := io.ReadFull(r, b); err != nil { return 0, nil, err } l, _ := unmarshalUint32(b) + debug("recv packet %d bytes", l) b = make([]byte, l) if _, err := io.ReadFull(r, b); err != nil { + debug("recv packet %d bytes: err %v", l, err) return 0, nil, err } + debug("recv packet %d bytes: %v", l, hex.EncodeToString(b)) return b[0], b[1:], nil } +type ExtensionPair struct { + Name string + Data string +} + +func unmarshalExtensionPair(b []byte) (ExtensionPair, []byte, error) { + ep := ExtensionPair{} + var err error = nil + ep.Name, b, err = unmarshalStringSafe(b) + if err != nil { + return ep, b, err + } + ep.Data, b, err = unmarshalStringSafe(b) + if err != nil { + return ep, b, err + } + return ep, b, err +} + // Here starts the definition of packets along with their MarshalBinary // implementations. // Manually writing the marshalling logic wins us a lot of time and // allocation. type sshFxInitPacket struct { + Version uint32 + Extensions []ExtensionPair +} + +func (p sshFxInitPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 // byte + uint32 + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_INIT) + b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) + } + return b, nil +} + +func (p *sshFxInitPacket) UnmarshalBinary(b []byte) (err error) { + if p.Version, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + for len(b) > 0 { + ep := ExtensionPair{} + ep, b, err = unmarshalExtensionPair(b) + if err != nil { + return err + } + p.Extensions = append(p.Extensions, ep) + } + return nil +} + +type sshFxVersionPacket struct { Version uint32 Extensions []struct { Name, Data string } } -func (p sshFxInitPacket) MarshalBinary() ([]byte, error) { +func (p sshFxVersionPacket) MarshalBinary() ([]byte, error) { l := 1 + 4 // byte + uint32 for _, e := range p.Extensions { l += 4 + len(e.Name) + 4 + len(e.Data) } b := make([]byte, 0, l) - b = append(b, ssh_FXP_INIT) + b = append(b, ssh_FXP_VERSION) b = marshalUint32(b, p.Version) for _, e := range p.Extensions { b = marshalString(b, e.Name) @@ -133,6 +226,18 @@ func marshalIdString(packetType byte, id uint32, str string) ([]byte, error) { return b, nil } +func unmarshalIdString(b []byte, id *uint32, str *string) (err error) { + *id, b, err = unmarshalUint32Safe(b) + if err != nil { + return + } + *str, b, err = unmarshalStringSafe(b) + if err != nil { + return + } + return +} + type sshFxpReaddirPacket struct { Id uint32 Handle string @@ -160,6 +265,10 @@ func (p sshFxpLstatPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_LSTAT, p.Id, p.Path) } +func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Path) +} + type sshFxpFstatPacket struct { Id uint32 Handle string diff --git a/packet_test.go b/packet_test.go index 80a1ebf0..88a1522c 100644 --- a/packet_test.go +++ b/packet_test.go @@ -144,7 +144,7 @@ var sendPacketTests = []struct { }{ {sshFxInitPacket{ Version: 3, - Extensions: []struct{ Name, Data string }{ + Extensions: []ExtensionPair{ {"posix-rename@openssh.com", "1"}, }, }, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, @@ -197,7 +197,7 @@ var recvPacketTests = []struct { }{ {sp(sshFxInitPacket{ Version: 3, - Extensions: []struct{ Name, Data string }{ + Extensions: []ExtensionPair{ {"posix-rename@openssh.com", "1"}, }, }), ssh_FXP_INIT, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, @@ -217,7 +217,7 @@ func BenchmarkMarshalInit(b *testing.B) { for i := 0; i < b.N; i++ { sp(sshFxInitPacket{ Version: 3, - Extensions: []struct{ Name, Data string }{ + Extensions: []ExtensionPair{ {"posix-rename@openssh.com", "1"}, }, }) diff --git a/server.go b/server.go new file mode 100644 index 00000000..e5000273 --- /dev/null +++ b/server.go @@ -0,0 +1,215 @@ +package sftp + +// sftp server counterpart + +import ( + "encoding" + "fmt" + "io" + "os" + "sync" + "syscall" +) + +type FileSystem interface { + Lstat(p string) (os.FileInfo, error) +} + +type nativeFs struct { +} + +func (nfs *nativeFs) Lstat(p string) (os.FileInfo, error) { return os.Lstat(p) } + +type Server struct { + in io.Reader + out io.Writer + rootDir string + lastId uint32 + fs FileSystem + pktChan chan serverRespondablePacket + openFiles map[string]*svrFile + openFilesLock *sync.Mutex +} + +type serverRespondablePacket interface { + encoding.BinaryUnmarshaler + respond(svr *Server) error +} + +type svrFile struct { +} + +// Creates a new server instance around the provided streams. +// A subsequent call to Run() is required. +func NewServer(in io.Reader, out io.Writer, rootDir string) (*Server, error) { + if rootDir == "" { + if wd, err := os.Getwd(); err != nil { + return nil, err + } else { + rootDir = wd + } + } + return &Server{ + in: in, + out: out, + rootDir: rootDir, + fs: &nativeFs{}, + pktChan: make(chan serverRespondablePacket, 4), + openFiles: map[string]*svrFile{}, + openFilesLock: &sync.Mutex{}, + }, nil +} + +// Unmarshal a single logical packet from the secure channel +func (svr *Server) rxPackets() error { + defer close(svr.pktChan) + + for { + pktType, pktBytes, err := recvPacket(svr.in) + if err == io.EOF { + return nil + } else if err != nil { + fmt.Fprintf(os.Stderr, "recvPacket error: %v\n", err) + return err + } + + if pkt, err := svr.decodePacket(fxp(pktType), pktBytes); err != nil { + fmt.Fprintf(os.Stderr, "decodePacket error: %v\n", err) + return err + } else { + svr.pktChan <- pkt + } + } +} + +func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondablePacket, error) { + //pktId, restBytes := unmarshalUint32(pktBytes[1:]) + var pkt serverRespondablePacket = nil + switch pktType { + case ssh_FXP_INIT: + pkt = &sshFxInitPacket{} + case ssh_FXP_LSTAT: + pkt = &sshFxpLstatPacket{} + case ssh_FXP_VERSION: + case ssh_FXP_OPEN: + case ssh_FXP_CLOSE: + case ssh_FXP_READ: + case ssh_FXP_WRITE: + case ssh_FXP_FSTAT: + case ssh_FXP_SETSTAT: + case ssh_FXP_FSETSTAT: + case ssh_FXP_OPENDIR: + case ssh_FXP_READDIR: + case ssh_FXP_REMOVE: + case ssh_FXP_MKDIR: + case ssh_FXP_RMDIR: + case ssh_FXP_REALPATH: + case ssh_FXP_STAT: + case ssh_FXP_RENAME: + case ssh_FXP_READLINK: + case ssh_FXP_SYMLINK: + case ssh_FXP_STATUS: + case ssh_FXP_HANDLE: + case ssh_FXP_DATA: + case ssh_FXP_NAME: + case ssh_FXP_ATTRS: + case ssh_FXP_EXTENDED: + case ssh_FXP_EXTENDED_REPLY: + default: + } + if pkt == nil { + return nil, fmt.Errorf("unhandled packet type: %s", pktType.String()) + } + if err := pkt.UnmarshalBinary(pktBytes); err != nil { + return nil, err + } + return pkt, nil +} + +// Run this server until the streams stop or until the subsystem is stopped +func (svr *Server) Run() error { + go svr.rxPackets() + for pkt := range svr.pktChan { + fmt.Fprintf(os.Stderr, "pkt: %T %v\n", pkt, pkt) + pkt.respond(svr) + } + return nil +} + +func (p sshFxInitPacket) respond(svr *Server) error { + return svr.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) +} + +type sshFxpLstatReponse struct { + Id uint32 + info os.FileInfo +} + +func (p sshFxpLstatReponse) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_ATTRS} + b = marshalUint32(b, p.Id) + b = marshalFileInfo(b, p.info) + return b, nil +} + +func (p sshFxpLstatPacket) respond(svr *Server) error { + // stat the requested file + if info, err := svr.fs.Lstat(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(sshFxpLstatReponse{p.Id, info}) + } +} + +type sshFxpStatusPacket struct { + Id uint32 + StatusError +} + +func (p sshFxpStatusPacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_STATUS} + b = marshalUint32(b, p.Id) + b = marshalStatus(b, p.StatusError) + return b, nil +} + +func statusFromError(id uint32, err error) sshFxpStatusPacket { + ret := sshFxpStatusPacket{ + Id: id, + StatusError: StatusError{ + // ssh_FX_OK = 0 + // ssh_FX_EOF = 1 + // ssh_FX_NO_SUCH_FILE = 2 ENOENT + // ssh_FX_PERMISSION_DENIED = 3 + // ssh_FX_FAILURE = 4 + // ssh_FX_BAD_MESSAGE = 5 + // ssh_FX_NO_CONNECTION = 6 + // ssh_FX_CONNECTION_LOST = 7 + // ssh_FX_OP_UNSUPPORTED = 8 + Code: ssh_FX_FAILURE, + msg: err.Error(), + lang: "", + }, + } + debug("statusFromError: error is %T %#v", err, err) + if err == io.EOF { + ret.StatusError.Code = ssh_FX_EOF + } + if pathError, ok := err.(*os.PathError); ok { + debug("statusFromError: error is %T %#v", pathError.Err, pathError.Err) + if errno, ok := pathError.Err.(syscall.Errno); ok { + if errno == 0 { + ret.StatusError.Code = ssh_FX_OK + } else if errno == syscall.ENOENT { + ret.StatusError.Code = ssh_FX_NO_SUCH_FILE + } else if errno == syscall.EPERM { + ret.StatusError.Code = ssh_FX_PERMISSION_DENIED + } else { + ret.StatusError.Code = ssh_FX_FAILURE + } + + ret.StatusError.Code = uint32(errno) + } + } + return ret +} diff --git a/server_standalone/main.go b/server_standalone/main.go new file mode 100644 index 00000000..e053c243 --- /dev/null +++ b/server_standalone/main.go @@ -0,0 +1,15 @@ +package main + +// small wrapper around sftp server that allows it to be used as a separate process subsystem call by the ssh server. +// in practice this will statically link; however this allows unit testing from the sftp client. + +import ( + "os" + + "github.com/ScriptRock/sftp" +) + +func main() { + svr, _ := sftp.NewServer(os.Stdin, os.Stdout, "") + svr.Run() +} From 91d0c9e68a826eb141092ab124908843753d50a0 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sat, 25 Jul 2015 19:07:33 -0700 Subject: [PATCH 04/41] mkdir --- packet.go | 11 +++++++++++ server.go | 53 +++++++++++++++++++++++++++++++++-------------------- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/packet.go b/packet.go index c131f96f..69689718 100644 --- a/packet.go +++ b/packet.go @@ -418,6 +418,17 @@ func (p sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { return b, nil } +func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) (err error) { + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + type sshFxpSetstatPacket struct { Id uint32 Path string diff --git a/server.go b/server.go index e5000273..ecc788b9 100644 --- a/server.go +++ b/server.go @@ -13,12 +13,14 @@ import ( type FileSystem interface { Lstat(p string) (os.FileInfo, error) + Mkdir(name string, perm os.FileMode) error } type nativeFs struct { } -func (nfs *nativeFs) Lstat(p string) (os.FileInfo, error) { return os.Lstat(p) } +func (nfs *nativeFs) Lstat(p string) (os.FileInfo, error) { return os.Lstat(p) } +func (nfs *nativeFs) Mkdir(name string, perm os.FileMode) error { return os.Mkdir(name, perm) } type Server struct { in io.Reader @@ -102,6 +104,7 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_READDIR: case ssh_FXP_REMOVE: case ssh_FXP_MKDIR: + pkt = &sshFxpMkdirPacket{} case ssh_FXP_RMDIR: case ssh_FXP_REALPATH: case ssh_FXP_STAT: @@ -161,6 +164,12 @@ func (p sshFxpLstatPacket) respond(svr *Server) error { } } +func (p sshFxpMkdirPacket) respond(svr *Server) error { + // ignore flags field + err := svr.fs.Mkdir(p.Path, 0755) + return svr.sendPacket(statusFromError(p.Id, err)) +} + type sshFxpStatusPacket struct { Id uint32 StatusError @@ -186,29 +195,33 @@ func statusFromError(id uint32, err error) sshFxpStatusPacket { // ssh_FX_NO_CONNECTION = 6 // ssh_FX_CONNECTION_LOST = 7 // ssh_FX_OP_UNSUPPORTED = 8 - Code: ssh_FX_FAILURE, - msg: err.Error(), + Code: ssh_FX_OK, + msg: "", lang: "", }, } - debug("statusFromError: error is %T %#v", err, err) - if err == io.EOF { - ret.StatusError.Code = ssh_FX_EOF - } - if pathError, ok := err.(*os.PathError); ok { - debug("statusFromError: error is %T %#v", pathError.Err, pathError.Err) - if errno, ok := pathError.Err.(syscall.Errno); ok { - if errno == 0 { - ret.StatusError.Code = ssh_FX_OK - } else if errno == syscall.ENOENT { - ret.StatusError.Code = ssh_FX_NO_SUCH_FILE - } else if errno == syscall.EPERM { - ret.StatusError.Code = ssh_FX_PERMISSION_DENIED - } else { - ret.StatusError.Code = ssh_FX_FAILURE + if err != nil { + debug("statusFromError: error is %T %#v", err, err) + ret.StatusError.Code = ssh_FX_FAILURE + ret.StatusError.msg = err.Error() + if err == io.EOF { + ret.StatusError.Code = ssh_FX_EOF + } + if pathError, ok := err.(*os.PathError); ok { + debug("statusFromError: error is %T %#v", pathError.Err, pathError.Err) + if errno, ok := pathError.Err.(syscall.Errno); ok { + if errno == 0 { + ret.StatusError.Code = ssh_FX_OK + } else if errno == syscall.ENOENT { + ret.StatusError.Code = ssh_FX_NO_SUCH_FILE + } else if errno == syscall.EPERM { + ret.StatusError.Code = ssh_FX_PERMISSION_DENIED + } else { + ret.StatusError.Code = ssh_FX_FAILURE + } + + ret.StatusError.Code = uint32(errno) } - - ret.StatusError.Code = uint32(errno) } } return ret From 058e1bee5829356e66c5632d03bc2ca6c92d70de Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 26 Jul 2015 01:32:19 -0700 Subject: [PATCH 05/41] open & close --- packet.go | 41 ++++++++++++++++++++ server.go | 110 ++++++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 136 insertions(+), 15 deletions(-) diff --git a/packet.go b/packet.go index 69689718..bd8e0110 100644 --- a/packet.go +++ b/packet.go @@ -287,6 +287,10 @@ func (p sshFxpClosePacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_CLOSE, p.Id, p.Handle) } +func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Handle) +} + type sshFxpRemovePacket struct { Id uint32 Filename string @@ -335,6 +339,19 @@ func (p sshFxpOpenPacket) MarshalBinary() ([]byte, error) { return b, nil } +func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) (err error) { + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return + } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { + return + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return + } + return +} + type sshFxpReadPacket struct { Id uint32 Handle string @@ -449,3 +466,27 @@ func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { b = marshal(b, p.Attrs) return b, nil } + +type sshFxpHandlePacket struct { + Id uint32 + Handle string +} + +func (p sshFxpHandlePacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_HANDLE} + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Handle) + return b, nil +} + +type sshFxpStatusPacket struct { + Id uint32 + StatusError +} + +func (p sshFxpStatusPacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_STATUS} + b = marshalUint32(b, p.Id) + b = marshalStatus(b, p.StatusError) + return b, nil +} diff --git a/server.go b/server.go index ecc788b9..d572e29b 100644 --- a/server.go +++ b/server.go @@ -16,11 +16,38 @@ type FileSystem interface { Mkdir(name string, perm os.FileMode) error } +type FileSystemOpen interface { + FileSystem + OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) +} + +type FileSystemSFTPOpen interface { + FileSystem + OpenFile(path string, f int) (*File, error) // sftp package has a strange OpenFile method with no perm +} + +// common subset of os.File and sftp.File +type svrFile interface { + Chmod(mode os.FileMode) error + Chown(uid, gid int) error + Close() error + Read(b []byte) (int, error) + Seek(offset int64, whence int) (int64, error) + Stat() (os.FileInfo, error) + Truncate(size int64) error + Write(b []byte) (int, error) + // func (f *File) WriteTo(w io.Writer) (int64, error) // not in os + // func (f *File) ReadFrom(r io.Reader) (int64, error) // not in os +} + type nativeFs struct { } func (nfs *nativeFs) Lstat(p string) (os.FileInfo, error) { return os.Lstat(p) } func (nfs *nativeFs) Mkdir(name string, perm os.FileMode) error { return os.Mkdir(name, perm) } +func (nfs *nativeFs) OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) { + return os.OpenFile(name, flag, perm) +} type Server struct { in io.Reader @@ -29,8 +56,29 @@ type Server struct { lastId uint32 fs FileSystem pktChan chan serverRespondablePacket - openFiles map[string]*svrFile - openFilesLock *sync.Mutex + openFiles map[string]svrFile + openFilesLock *sync.RWMutex + handleCount int +} + +func (svr *Server) nextHandle(f svrFile) string { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + svr.handleCount++ + handle := fmt.Sprintf("%d", svr.handleCount) + svr.openFiles[handle] = f + return handle +} + +func (svr *Server) closeHandle(handle string) error { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + if f, ok := svr.openFiles[handle]; ok { + delete(svr.openFiles, handle) + return f.Close() + } else { + return syscall.EBADF + } } type serverRespondablePacket interface { @@ -38,9 +86,6 @@ type serverRespondablePacket interface { respond(svr *Server) error } -type svrFile struct { -} - // Creates a new server instance around the provided streams. // A subsequent call to Run() is required. func NewServer(in io.Reader, out io.Writer, rootDir string) (*Server, error) { @@ -57,8 +102,8 @@ func NewServer(in io.Reader, out io.Writer, rootDir string) (*Server, error) { rootDir: rootDir, fs: &nativeFs{}, pktChan: make(chan serverRespondablePacket, 4), - openFiles: map[string]*svrFile{}, - openFilesLock: &sync.Mutex{}, + openFiles: map[string]svrFile{}, + openFilesLock: &sync.RWMutex{}, }, nil } @@ -94,7 +139,9 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable pkt = &sshFxpLstatPacket{} case ssh_FXP_VERSION: case ssh_FXP_OPEN: + pkt = &sshFxpOpenPacket{} case ssh_FXP_CLOSE: + pkt = &sshFxpClosePacket{} case ssh_FXP_READ: case ssh_FXP_WRITE: case ssh_FXP_FSTAT: @@ -170,16 +217,49 @@ func (p sshFxpMkdirPacket) respond(svr *Server) error { return svr.sendPacket(statusFromError(p.Id, err)) } -type sshFxpStatusPacket struct { - Id uint32 - StatusError +func (p sshFxpOpenPacket) respond(svr *Server) error { + osFlags := 0 + if p.Pflags&ssh_FXF_READ != 0 && p.Pflags&ssh_FXF_WRITE != 0 { + osFlags |= os.O_RDWR + } else if p.Pflags&ssh_FXF_READ != 0 { + osFlags |= os.O_RDONLY + } else if p.Pflags&ssh_FXF_WRITE != 0 { + osFlags |= os.O_WRONLY + } + if p.Pflags&ssh_FXF_APPEND != 0 { + osFlags |= os.O_APPEND + } + if p.Pflags&ssh_FXF_CREAT != 0 { + osFlags |= os.O_CREATE + } + if p.Pflags&ssh_FXF_TRUNC != 0 { + osFlags |= os.O_TRUNC + } + if p.Pflags&ssh_FXF_EXCL != 0 { + osFlags |= os.O_EXCL + } + + if fso, ok := svr.fs.(FileSystemOpen); ok { + if f, err := fso.OpenFile(p.Path, osFlags, 0644); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + handle := svr.nextHandle(f) + return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) + } + } else if sftpo, ok := svr.fs.(FileSystemSFTPOpen); ok { + if f, err := sftpo.OpenFile(p.Path, osFlags); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + handle := svr.nextHandle(f) + return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) + } + } else { + return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + } } -func (p sshFxpStatusPacket) MarshalBinary() ([]byte, error) { - b := []byte{ssh_FXP_STATUS} - b = marshalUint32(b, p.Id) - b = marshalStatus(b, p.StatusError) - return b, nil +func (p sshFxpClosePacket) respond(svr *Server) error { + return svr.sendPacket(statusFromError(p.Id, svr.closeHandle(p.Handle))) } func statusFromError(id uint32, err error) sshFxpStatusPacket { From 2888b4a6b140074cd859d29a774a42239ec248ac Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Wed, 29 Jul 2015 17:24:24 -0700 Subject: [PATCH 06/41] implement read --- client_integration_test.go | 37 ++++++++++++++++++++++++++++++++++ packet.go | 41 ++++++++++++++++++++++++++++++++++++++ server.go | 28 ++++++++++++++++++++++++++ 3 files changed, 106 insertions(+) diff --git a/client_integration_test.go b/client_integration_test.go index 828deea3..205aac74 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -427,6 +427,43 @@ func sameFile(want, got os.FileInfo) bool { want.Size() == got.Size() } +func TestClientReadSimple(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f, err := ioutil.TempFile(d, "read-test") + if err != nil { + t.Fatal(err) + } + fname := f.Name() + f.Write([]byte("hello")) + f.Close() + + f2, err := sftp.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + stuff := make([]byte, 32) + n, err := f2.Read(stuff) + if err != nil && err != io.EOF { + t.Fatalf("err: %v", err) + } + if n != 5 { + t.Fatalf("n should be 5, is %v", n) + } + if string(stuff[0:5]) != "hello" { + t.Fatalf("invalid contents") + } +} + var clientReadTests = []struct { n int64 }{ diff --git a/packet.go b/packet.go index bd8e0110..a4c27906 100644 --- a/packet.go +++ b/packet.go @@ -373,6 +373,19 @@ func (p sshFxpReadPacket) MarshalBinary() ([]byte, error) { return b, nil } +func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) (err error) { + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return + } else if p.Len, b, err = unmarshalUint32Safe(b); err != nil { + return + } + return +} + type sshFxpRenamePacket struct { Id uint32 Oldpath string @@ -490,3 +503,31 @@ func (p sshFxpStatusPacket) MarshalBinary() ([]byte, error) { b = marshalStatus(b, p.StatusError) return b, nil } + +type sshFxpDataPacket struct { + Id uint32 + Length uint32 + Data []byte +} + +func (p sshFxpDataPacket) MarshalBinary() ([]byte, error) { + b := []byte{ssh_FXP_DATA} + b = marshalUint32(b, p.Id) + b = marshalUint32(b, p.Length) + b = append(b, p.Data[:p.Length]...) + return b, nil +} + +func (p *sshFxpDataPacket) UnmarshalBinary(b []byte) (err error) { + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return fmt.Errorf("truncated packet") + } else { + p.Data = make([]byte, p.Length) + copy(p.Data, b) + return nil + } +} diff --git a/server.go b/server.go index d572e29b..ea170a4b 100644 --- a/server.go +++ b/server.go @@ -81,6 +81,13 @@ func (svr *Server) closeHandle(handle string) error { } } +func (svr *Server) getHandle(handle string) (svrFile, bool) { + svr.openFilesLock.RLock() + defer svr.openFilesLock.RUnlock() + f, ok := svr.openFiles[handle] + return f, ok +} + type serverRespondablePacket interface { encoding.BinaryUnmarshaler respond(svr *Server) error @@ -143,6 +150,7 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_CLOSE: pkt = &sshFxpClosePacket{} case ssh_FXP_READ: + pkt = &sshFxpReadPacket{} case ssh_FXP_WRITE: case ssh_FXP_FSTAT: case ssh_FXP_SETSTAT: @@ -262,6 +270,26 @@ func (p sshFxpClosePacket) respond(svr *Server) error { return svr.sendPacket(statusFromError(p.Id, svr.closeHandle(p.Handle))) } +func (p sshFxpReadPacket) respond(svr *Server) error { + if f, ok := svr.getHandle(p.Handle); !ok { + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } else if p.Len > maxWritePacket { + return svr.sendPacket(statusFromError(p.Id, syscall.EINVAL)) + } else if osf, ok := f.(*os.File); ok { + debug("in readpacket server respond: len %d", p.Len) + ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)} + if n, err := osf.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + ret.Length = uint32(n) + return svr.sendPacket(ret) + } + } else { + // server error... + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } +} + func statusFromError(id uint32, err error) sshFxpStatusPacket { ret := sshFxpStatusPacket{ Id: id, From c9cee8ac6ffd80cf1b52eafef74328af32d2e0c3 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Wed, 29 Jul 2015 17:37:58 -0700 Subject: [PATCH 07/41] implement write --- packet.go | 18 ++++++++++++++++++ server.go | 34 +++++++++++++++++++++++++--------- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/packet.go b/packet.go index a4c27906..ac5abeff 100644 --- a/packet.go +++ b/packet.go @@ -429,6 +429,24 @@ func (s sshFxpWritePacket) MarshalBinary() ([]byte, error) { return b, nil } +func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) (err error) { + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return + } else if uint32(len(b)) < p.Length { + err = shortPacketError + return + } else { + p.Data = append([]byte{}, b[:p.Length]...) + } + return +} + type sshFxpMkdirPacket struct { Id uint32 Path string diff --git a/server.go b/server.go index ea170a4b..08adc765 100644 --- a/server.go +++ b/server.go @@ -152,6 +152,7 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_READ: pkt = &sshFxpReadPacket{} case ssh_FXP_WRITE: + pkt = &sshFxpWritePacket{} case ssh_FXP_FSTAT: case ssh_FXP_SETSTAT: case ssh_FXP_FSETSTAT: @@ -273,17 +274,32 @@ func (p sshFxpClosePacket) respond(svr *Server) error { func (p sshFxpReadPacket) respond(svr *Server) error { if f, ok := svr.getHandle(p.Handle); !ok { return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) - } else if p.Len > maxWritePacket { - return svr.sendPacket(statusFromError(p.Id, syscall.EINVAL)) - } else if osf, ok := f.(*os.File); ok { - debug("in readpacket server respond: len %d", p.Len) - ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)} - if n, err := osf.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) { - return svr.sendPacket(statusFromError(p.Id, err)) + } else { + if p.Len > maxWritePacket { + p.Len = maxWritePacket + } + if osf, ok := f.(*os.File); ok { + debug("in readpacket server respond: len %d", p.Len) + ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)} + if n, err := osf.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + ret.Length = uint32(n) + return svr.sendPacket(ret) + } } else { - ret.Length = uint32(n) - return svr.sendPacket(ret) + // server error... + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) } + } +} + +func (p sshFxpWritePacket) respond(svr *Server) error { + if f, ok := svr.getHandle(p.Handle); !ok { + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } else if osf, ok := f.(*os.File); ok { + _, err := osf.WriteAt(p.Data, int64(p.Offset)) + return svr.sendPacket(statusFromError(p.Id, err)) } else { // server error... return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) From bf6b5bce28b4e19fb2e0740c61231e071cd3fef9 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Thu, 30 Jul 2015 09:21:59 -0700 Subject: [PATCH 08/41] fstat --- packet.go | 8 +++++--- server.go | 22 +++++++++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/packet.go b/packet.go index ac5abeff..15845cbf 100644 --- a/packet.go +++ b/packet.go @@ -2,7 +2,6 @@ package sftp import ( "encoding" - "encoding/hex" "fmt" "io" "reflect" @@ -103,7 +102,7 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { } l := uint32(len(bb)) hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} - debug("send packet %T, len: %v data: %v", m, l, hex.EncodeToString(bb)) + debug("send packet %T, len: %v", m, l) _, err = w.Write(hdr) if err != nil { return err @@ -128,7 +127,6 @@ func recvPacket(r io.Reader) (uint8, []byte, error) { debug("recv packet %d bytes: err %v", l, err) return 0, nil, err } - debug("recv packet %d bytes: %v", l, hex.EncodeToString(b)) return b[0], b[1:], nil } @@ -278,6 +276,10 @@ func (p sshFxpFstatPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_FSTAT, p.Id, p.Handle) } +func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Handle) +} + type sshFxpClosePacket struct { Id uint32 Handle string diff --git a/server.go b/server.go index 08adc765..8101c500 100644 --- a/server.go +++ b/server.go @@ -154,6 +154,7 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_WRITE: pkt = &sshFxpWritePacket{} case ssh_FXP_FSTAT: + pkt = &sshFxpFstatPacket{} case ssh_FXP_SETSTAT: case ssh_FXP_FSETSTAT: case ssh_FXP_OPENDIR: @@ -199,12 +200,12 @@ func (p sshFxInitPacket) respond(svr *Server) error { return svr.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) } -type sshFxpLstatReponse struct { +type sshFxpStatReponse struct { Id uint32 info os.FileInfo } -func (p sshFxpLstatReponse) MarshalBinary() ([]byte, error) { +func (p sshFxpStatReponse) MarshalBinary() ([]byte, error) { b := []byte{ssh_FXP_ATTRS} b = marshalUint32(b, p.Id) b = marshalFileInfo(b, p.info) @@ -216,7 +217,22 @@ func (p sshFxpLstatPacket) respond(svr *Server) error { if info, err := svr.fs.Lstat(p.Path); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(sshFxpLstatReponse{p.Id, info}) + return svr.sendPacket(sshFxpStatReponse{p.Id, info}) + } +} + +func (p sshFxpFstatPacket) respond(svr *Server) error { + if f, ok := svr.getHandle(p.Handle); !ok { + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } else if osf, ok := f.(*os.File); ok { + if info, err := osf.Stat(); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(sshFxpStatReponse{p.Id, info}) + } + } else { + // server error... + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) } } From 7ab0966023d3837128d6eb3001e4c5a3d84b19d7 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Thu, 30 Jul 2015 23:43:00 -0700 Subject: [PATCH 09/41] readonly (only for open right now) --- server.go | 54 +++++++++++++++++++++++++++------------ server_standalone/main.go | 19 +++++++++++++- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 8101c500..7f7ab39e 100644 --- a/server.go +++ b/server.go @@ -52,6 +52,9 @@ func (nfs *nativeFs) OpenFile(name string, flag int, perm os.FileMode) (file *os type Server struct { in io.Reader out io.Writer + debugStream io.Writer + debugLevel int + readOnly bool rootDir string lastId uint32 fs FileSystem @@ -95,7 +98,7 @@ type serverRespondablePacket interface { // Creates a new server instance around the provided streams. // A subsequent call to Run() is required. -func NewServer(in io.Reader, out io.Writer, rootDir string) (*Server, error) { +func NewServer(in io.Reader, out io.Writer, debugStream io.Writer, debugLevel int, readOnly bool, rootDir string) (*Server, error) { if rootDir == "" { if wd, err := os.Getwd(); err != nil { return nil, err @@ -106,6 +109,9 @@ func NewServer(in io.Reader, out io.Writer, rootDir string) (*Server, error) { return &Server{ in: in, out: out, + debugStream: debugStream, + debugLevel: debugLevel, + readOnly: readOnly, rootDir: rootDir, fs: &nativeFs{}, pktChan: make(chan serverRespondablePacket, 4), @@ -245,12 +251,23 @@ func (p sshFxpMkdirPacket) respond(svr *Server) error { func (p sshFxpOpenPacket) respond(svr *Server) error { osFlags := 0 if p.Pflags&ssh_FXF_READ != 0 && p.Pflags&ssh_FXF_WRITE != 0 { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } osFlags |= os.O_RDWR - } else if p.Pflags&ssh_FXF_READ != 0 { - osFlags |= os.O_RDONLY } else if p.Pflags&ssh_FXF_WRITE != 0 { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } osFlags |= os.O_WRONLY + } else if p.Pflags&ssh_FXF_READ != 0 { + osFlags |= os.O_RDONLY + } else { + // how are they opening? + return svr.sendPacket(statusFromError(p.Id, syscall.EINVAL)) + } + if p.Pflags&ssh_FXF_APPEND != 0 { osFlags |= os.O_APPEND } @@ -322,6 +339,20 @@ func (p sshFxpWritePacket) respond(svr *Server) error { } } +func errnoToSshErr(errno syscall.Errno) uint32 { + if errno == 0 { + return ssh_FX_OK + } else if errno == syscall.ENOENT { + return ssh_FX_NO_SUCH_FILE + } else if errno == syscall.EPERM { + return ssh_FX_PERMISSION_DENIED + } else { + return ssh_FX_FAILURE + } + + return uint32(errno) +} + func statusFromError(id uint32, err error) sshFxpStatusPacket { ret := sshFxpStatusPacket{ Id: id, @@ -346,21 +377,12 @@ func statusFromError(id uint32, err error) sshFxpStatusPacket { ret.StatusError.msg = err.Error() if err == io.EOF { ret.StatusError.Code = ssh_FX_EOF - } - if pathError, ok := err.(*os.PathError); ok { + } else if errno, ok := err.(syscall.Errno); ok { + ret.StatusError.Code = errnoToSshErr(errno) + } else if pathError, ok := err.(*os.PathError); ok { debug("statusFromError: error is %T %#v", pathError.Err, pathError.Err) if errno, ok := pathError.Err.(syscall.Errno); ok { - if errno == 0 { - ret.StatusError.Code = ssh_FX_OK - } else if errno == syscall.ENOENT { - ret.StatusError.Code = ssh_FX_NO_SUCH_FILE - } else if errno == syscall.EPERM { - ret.StatusError.Code = ssh_FX_PERMISSION_DENIED - } else { - ret.StatusError.Code = ssh_FX_FAILURE - } - - ret.StatusError.Code = uint32(errno) + ret.StatusError.Code = errnoToSshErr(errno) } } } diff --git a/server_standalone/main.go b/server_standalone/main.go index e053c243..7bdfcefe 100644 --- a/server_standalone/main.go +++ b/server_standalone/main.go @@ -4,12 +4,29 @@ package main // in practice this will statically link; however this allows unit testing from the sftp client. import ( + "flag" + "io/ioutil" "os" "github.com/ScriptRock/sftp" ) func main() { - svr, _ := sftp.NewServer(os.Stdin, os.Stdout, "") + readOnly := false + debugLevelStr := "none" + debugLevel := 0 + debugStderr := false + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.StringVar(&debugLevelStr, "l", "none", "debug level") + flag.Parse() + + debugStream := ioutil.Discard + if debugStderr { + debugStream = os.Stderr + debugLevel = 1 + } + + svr, _ := sftp.NewServer(os.Stdin, os.Stdout, debugStream, debugLevel, readOnly, "") svr.Run() } From 435f753792d60133ec3be774bfa291ee925a2e2b Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Fri, 31 Jul 2015 15:46:13 -0700 Subject: [PATCH 10/41] readdir, rename, remove --- client_integration_test.go | 2 +- packet.go | 69 ++++++++++++++++++++ server.go | 127 ++++++++++++++++++++++++++++++++----- 3 files changed, 180 insertions(+), 18 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index 205aac74..49aa1e18 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -404,7 +404,7 @@ func TestClientRename(t *testing.T) { } } -func TestClientReadLine(t *testing.T) { +func TestClientReadLink(t *testing.T) { sftp, cmd := testClient(t, READWRITE) defer cmd.Wait() defer sftp.Close() diff --git a/packet.go b/packet.go index 15845cbf..293f4f7d 100644 --- a/packet.go +++ b/packet.go @@ -4,6 +4,7 @@ import ( "encoding" "fmt" "io" + "os" "reflect" ) @@ -22,6 +23,9 @@ func marshalString(b []byte, v string) []byte { } func marshal(b []byte, v interface{}) []byte { + if v == nil { + return b + } switch v := v.(type) { case uint8: return append(b, v) @@ -31,6 +35,8 @@ func marshal(b []byte, v interface{}) []byte { return marshalUint64(b, v) case string: return marshalString(b, v) + case os.FileInfo: + return marshalFileInfo(b, v) default: switch d := reflect.ValueOf(v); d.Kind() { case reflect.Struct: @@ -245,6 +251,10 @@ func (p sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_READDIR, p.Id, p.Handle) } +func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Handle) +} + type sshFxpOpendirPacket struct { Id uint32 Path string @@ -254,6 +264,10 @@ func (p sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_OPENDIR, p.Id, p.Path) } +func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Path) +} + type sshFxpLstatPacket struct { Id uint32 Path string @@ -302,6 +316,10 @@ func (p sshFxpRemovePacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_REMOVE, p.Id, p.Filename) } +func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Filename) +} + type sshFxpRmdirPacket struct { Id uint32 Path string @@ -311,6 +329,10 @@ func (p sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_RMDIR, p.Id, p.Path) } +func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Path) +} + type sshFxpReadlinkPacket struct { Id uint32 Path string @@ -320,6 +342,42 @@ func (p sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_READLINK, p.Id, p.Path) } +func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Path) +} + +type sshFxpNameAttr struct { + Name string + Attrs interface{} +} + +func (p sshFxpNameAttr) MarshalBinary() ([]byte, error) { + b := []byte{} + b = marshalString(b, p.Name) + b = marshal(b, p.Attrs) + return b, nil +} + +type sshFxpNamePacket struct { + Id uint32 + NameAttrs []sshFxpNameAttr +} + +func (p sshFxpNamePacket) MarshalBinary() ([]byte, error) { + b := []byte{} + b = append(b, ssh_FXP_NAME) + b = marshalUint32(b, p.Id) + b = marshalUint32(b, uint32(len(p.NameAttrs))) + for _, na := range p.NameAttrs { + if ab, err := na.MarshalBinary(); err != nil { + return nil, err + } else { + b = append(b, ab...) + } + } + return b, nil +} + type sshFxpOpenPacket struct { Id uint32 Path string @@ -407,6 +465,17 @@ func (p sshFxpRenamePacket) MarshalBinary() ([]byte, error) { return b, nil } +func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) (err error) { + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return + } else if p.Newpath, b, err = unmarshalStringSafe(b); err != nil { + return + } + return +} + type sshFxpWritePacket struct { Id uint32 Handle string diff --git a/server.go b/server.go index 7f7ab39e..0f0d4128 100644 --- a/server.go +++ b/server.go @@ -12,18 +12,23 @@ import ( ) type FileSystem interface { - Lstat(p string) (os.FileInfo, error) - Mkdir(name string, perm os.FileMode) error + Lstat(name string) (os.FileInfo, error) + Remove(name string) error + Rename(oldpath, newpath string) error } -type FileSystemOpen interface { +type FileSystemOS interface { FileSystem OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) + Readlink(path string) (string, error) + Mkdir(name string, perm os.FileMode) error } -type FileSystemSFTPOpen interface { +type FileSystemSFTP interface { FileSystem OpenFile(path string, f int) (*File, error) // sftp package has a strange OpenFile method with no perm + ReadLink(path string) (string, error) + Mkdir(name string) error } // common subset of os.File and sftp.File @@ -43,12 +48,18 @@ type svrFile interface { type nativeFs struct { } -func (nfs *nativeFs) Lstat(p string) (os.FileInfo, error) { return os.Lstat(p) } -func (nfs *nativeFs) Mkdir(name string, perm os.FileMode) error { return os.Mkdir(name, perm) } -func (nfs *nativeFs) OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) { - return os.OpenFile(name, flag, perm) +func (nfs *nativeFs) Lstat(path string) (os.FileInfo, error) { return os.Lstat(path) } +func (nfs *nativeFs) Mkdir(path string, perm os.FileMode) error { return os.Mkdir(path, perm) } +func (nfs *nativeFs) Remove(path string) error { return os.Remove(path) } +func (nfs *nativeFs) Rename(oldpath, newpath string) error { return os.Rename(oldpath, newpath) } +func (nfs *nativeFs) Readlink(path string) (string, error) { return os.Readlink(path) } +func (nfs *nativeFs) OpenFile(path string, flag int, perm os.FileMode) (file *os.File, err error) { + return os.OpenFile(path, flag, perm) } +var __typecheck_fsos FileSystemOS = &nativeFs{} +var __typecheck_sftpos FileSystemSFTP = &Client{} + type Server struct { in io.Reader out io.Writer @@ -164,15 +175,20 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_SETSTAT: case ssh_FXP_FSETSTAT: case ssh_FXP_OPENDIR: + pkt = &sshFxpOpendirPacket{} case ssh_FXP_READDIR: + pkt = &sshFxpReaddirPacket{} case ssh_FXP_REMOVE: + pkt = &sshFxpRemovePacket{} case ssh_FXP_MKDIR: pkt = &sshFxpMkdirPacket{} case ssh_FXP_RMDIR: case ssh_FXP_REALPATH: case ssh_FXP_STAT: case ssh_FXP_RENAME: + pkt = &sshFxpRenamePacket{} case ssh_FXP_READLINK: + pkt = &sshFxpReadlinkPacket{} case ssh_FXP_SYMLINK: case ssh_FXP_STATUS: case ssh_FXP_HANDLE: @@ -206,12 +222,12 @@ func (p sshFxInitPacket) respond(svr *Server) error { return svr.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) } -type sshFxpStatReponse struct { +type sshFxpStatResponse struct { Id uint32 info os.FileInfo } -func (p sshFxpStatReponse) MarshalBinary() ([]byte, error) { +func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) { b := []byte{ssh_FXP_ATTRS} b = marshalUint32(b, p.Id) b = marshalFileInfo(b, p.info) @@ -223,7 +239,7 @@ func (p sshFxpLstatPacket) respond(svr *Server) error { if info, err := svr.fs.Lstat(p.Path); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(sshFxpStatReponse{p.Id, info}) + return svr.sendPacket(sshFxpStatResponse{p.Id, info}) } } @@ -234,7 +250,7 @@ func (p sshFxpFstatPacket) respond(svr *Server) error { if info, err := osf.Stat(); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(sshFxpStatReponse{p.Id, info}) + return svr.sendPacket(sshFxpStatResponse{p.Id, info}) } } else { // server error... @@ -243,11 +259,59 @@ func (p sshFxpFstatPacket) respond(svr *Server) error { } func (p sshFxpMkdirPacket) respond(svr *Server) error { - // ignore flags field - err := svr.fs.Mkdir(p.Path, 0755) + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } + // TODO FIXME: ignore flags field + if fso, ok := svr.fs.(FileSystemOS); ok { + err := fso.Mkdir(p.Path, 0755) + return svr.sendPacket(statusFromError(p.Id, err)) + } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { + err := sftpo.Mkdir(p.Path) + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + } +} + +func (p sshFxpRemovePacket) respond(svr *Server) error { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } + err := svr.fs.Remove(p.Filename) return svr.sendPacket(statusFromError(p.Id, err)) } +func (p sshFxpRenamePacket) respond(svr *Server) error { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } + err := svr.fs.Rename(p.Oldpath, p.Newpath) + return svr.sendPacket(statusFromError(p.Id, err)) +} + +func (p sshFxpReadlinkPacket) respond(svr *Server) error { + if fso, ok := svr.fs.(FileSystemOS); ok { + if f, err := fso.Readlink(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, nil}}}) + } + } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { + if f, err := sftpo.ReadLink(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, nil}}}) + } + } else { + return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + } +} + +func (p sshFxpOpendirPacket) respond(svr *Server) error { + return sshFxpOpenPacket{p.Id, p.Path, ssh_FXF_READ, 0}.respond(svr) +} + func (p sshFxpOpenPacket) respond(svr *Server) error { osFlags := 0 if p.Pflags&ssh_FXF_READ != 0 && p.Pflags&ssh_FXF_WRITE != 0 { @@ -281,14 +345,14 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { osFlags |= os.O_EXCL } - if fso, ok := svr.fs.(FileSystemOpen); ok { + if fso, ok := svr.fs.(FileSystemOS); ok { if f, err := fso.OpenFile(p.Path, osFlags, 0644); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { handle := svr.nextHandle(f) return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) } - } else if sftpo, ok := svr.fs.(FileSystemSFTPOpen); ok { + } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { if f, err := sftpo.OpenFile(p.Path, osFlags); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { @@ -312,7 +376,6 @@ func (p sshFxpReadPacket) respond(svr *Server) error { p.Len = maxWritePacket } if osf, ok := f.(*os.File); ok { - debug("in readpacket server respond: len %d", p.Len) ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)} if n, err := osf.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) { return svr.sendPacket(statusFromError(p.Id, err)) @@ -328,6 +391,10 @@ func (p sshFxpReadPacket) respond(svr *Server) error { } func (p sshFxpWritePacket) respond(svr *Server) error { + if svr.readOnly { + // shouldn't really get here, the open should have failed + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } if f, ok := svr.getHandle(p.Handle); !ok { return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) } else if osf, ok := f.(*os.File); ok { @@ -339,6 +406,32 @@ func (p sshFxpWritePacket) respond(svr *Server) error { } } +func (p sshFxpReaddirPacket) respond(svr *Server) error { + if f, ok := svr.getHandle(p.Handle); !ok { + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } else { + dirents := []os.FileInfo{} + var err error = nil + + if osf, ok := f.(*os.File); ok { + dirents, err = osf.Readdir(128) + } else { + // server error... + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } + + if err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } + + ret := sshFxpNamePacket{p.Id, nil} + for _, dirent := range dirents { + ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{dirent.Name(), dirent}) + } + return svr.sendPacket(ret) + } +} + func errnoToSshErr(errno syscall.Errno) uint32 { if errno == 0 { return ssh_FX_OK From 0f2bc1aa17059fdfe2e4c03b450933d264c1f8c7 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Fri, 31 Jul 2015 23:09:51 -0700 Subject: [PATCH 11/41] server is passing the client integration tests now. I don't understand the ATTRs field, it has a name in it. --- client.go | 1 + client_integration_test.go | 28 ++++++++++++++++++++++++++++ packet.go | 6 ++++-- server.go | 2 +- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index f895548c..67200822 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package sftp import ( "encoding" + "fmt" "io" "os" "path" diff --git a/client_integration_test.go b/client_integration_test.go index 49aa1e18..a0d38aba 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -27,6 +27,7 @@ const ( debuglevel = "ERROR" // set to "DEBUG" for debugging ) +var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") var testSftp = flag.String("sftp", "/usr/lib/openssh/sftp-server", "location of the sftp server binary") @@ -36,6 +37,33 @@ func testClient(t testing.TB, readonly bool) (*Client, *exec.Cmd) { if !*testIntegration { t.Skip("skipping intergration test") } + + if *testServerImpl { + txPipeRd, txPipeWr := io.Pipe() + rxPipeRd, rxPipeWr := io.Pipe() + + server, err := NewServer(txPipeRd, rxPipeWr, os.Stderr, 0, readonly, ".") + if err != nil { + t.Fatal(err) + } + go server.Run() + + client, err := NewClientPipe(rxPipeRd, txPipeWr) + if err != nil { + t.Fatal(err) + } + + if err := client.sendInit(); err != nil { + t.Fatal(err) + } + if err := client.recvVersion(); err != nil { + t.Fatal(err) + } + + // dummy command... + return client, exec.Command("true") + } + cmd := exec.Command(*testSftp, "-e", "-R", "-l", debuglevel) // log to stderr, read only if !readonly { cmd = exec.Command(*testSftp, "-e", "-l", debuglevel) // log to stderr diff --git a/packet.go b/packet.go index 293f4f7d..bff7e7e9 100644 --- a/packet.go +++ b/packet.go @@ -348,13 +348,15 @@ func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { type sshFxpNameAttr struct { Name string - Attrs interface{} + Attrs []interface{} } func (p sshFxpNameAttr) MarshalBinary() ([]byte, error) { b := []byte{} b = marshalString(b, p.Name) - b = marshal(b, p.Attrs) + for _, attr := range p.Attrs { + b = marshal(b, attr) + } return b, nil } diff --git a/server.go b/server.go index 0f0d4128..3d0adcec 100644 --- a/server.go +++ b/server.go @@ -426,7 +426,7 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { ret := sshFxpNamePacket{p.Id, nil} for _, dirent := range dirents { - ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{dirent.Name(), dirent}) + ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{dirent.Name(), []interface{}{dirent.Name(), dirent}}) } return svr.sendPacket(ret) } From f4c4138a0e19cd3a99f8259512819923f8eaaf7a Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Tue, 4 Aug 2015 20:47:26 -0700 Subject: [PATCH 12/41] update client integration tests for more coverage --- client.go | 1 - client_integration_test.go | 192 +++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 67200822..f895548c 100644 --- a/client.go +++ b/client.go @@ -2,7 +2,6 @@ package sftp import ( "encoding" - "fmt" "io" "os" "path" diff --git a/client_integration_test.go b/client_integration_test.go index a0d38aba..63619b0f 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -11,11 +11,15 @@ import ( "math/rand" "os" "os/exec" + "os/user" "path" "path/filepath" "reflect" + "regexp" + "strconv" "testing" "testing/quick" + "time" "github.com/kr/fs" ) @@ -27,6 +31,8 @@ const ( debuglevel = "ERROR" // set to "DEBUG" for debugging ) +var spaceRegex = regexp.MustCompile(`\s+`) + var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") var testSftp = flag.String("sftp", "/usr/lib/openssh/sftp-server", "location of the sftp server binary") @@ -450,6 +456,192 @@ func TestClientReadLink(t *testing.T) { } } +func TestClientChmod(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Chmod(f.Name(), 0531); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(f.Name()); err != nil { + t.Fatal(err) + } else if stat.Mode()&os.ModePerm != 0531 { + t.Fatalf("invalid perm %o\n", stat.Mode()) + } +} + +func TestClientChmodReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Chmod(f.Name(), 0531); err == nil { + t.Fatal("expected error") + } +} + +func TestClientChown(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + chownto, err := user.Lookup("daemon") // seems common-ish... + if err != nil { + t.Fatal(err) + } + + if usr.Uid != "0" { + t.Log("must be root to run chown tests") + t.Skip() + } + toUid, err := strconv.Atoi(chownto.Uid) + if err != nil { + t.Fatal(err) + } + toGid, err := strconv.Atoi(chownto.Gid) + if err != nil { + t.Fatal(err) + } + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + before, err := exec.Command("ls", "-nl", f.Name()).Output() + if err != nil { + t.Fatal(err) + } + if err := sftp.Chown(f.Name(), toUid, toGid); err != nil { + t.Fatal(err) + } + after, err := exec.Command("ls", "-nl", f.Name()).Output() + if err != nil { + t.Fatal(err) + } + + beforeWords := spaceRegex.Split(string(before), -1) + if beforeWords[2] != "0" { + t.Fatalf("bad previous user? should be root") + } + afterWords := spaceRegex.Split(string(after), -1) + if afterWords[2] != chownto.Uid || afterWords[3] != chownto.Gid { + t.Fatalf("bad chown: %#v", afterWords) + } + t.Logf("before: %v", string(before)) + t.Logf(" after: %v", string(after)) +} + +func TestClientChownReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + chownto, err := user.Lookup("daemon") // seems common-ish... + if err != nil { + t.Fatal(err) + } + + if usr.Uid != "0" { + t.Log("must be root to run chown tests") + t.Skip() + } + toUid, err := strconv.Atoi(chownto.Uid) + if err != nil { + t.Fatal(err) + } + toGid, err := strconv.Atoi(chownto.Gid) + if err != nil { + t.Fatal(err) + } + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + if err := sftp.Chown(f.Name(), toUid, toGid); err == nil { + t.Fatal("expected error") + } +} + +func TestClientChtimes(t *testing.T) { + sftp, cmd := testClient(t, READWRITE) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + + atime := time.Date(2013, 2, 23, 13, 24, 35, 0, time.UTC) + mtime := time.Date(1985, 6, 12, 6, 6, 6, 0, time.UTC) + if err := sftp.Chtimes(f.Name(), atime, mtime); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(f.Name()); err != nil { + t.Fatal(err) + } else if stat.ModTime().Sub(mtime) != 0 { + t.Fatalf("incorrect mtime: %v vs %v", stat.ModTime(), mtime) + } +} + +func TestClientChtimesReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + + atime := time.Date(2013, 2, 23, 13, 24, 35, 0, time.UTC) + mtime := time.Date(1985, 6, 12, 6, 6, 6, 0, time.UTC) + if err := sftp.Chtimes(f.Name(), atime, mtime); err == nil { + t.Fatal("expected error") + } +} + +/* +func TestClientStatVFS(t *testing.T) { + sftp, cmd := testClient(t, READONLY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + + if svfs, err := sftp.StatVFS("/"); err != nil { + t.Fatal(err) + } else { + t.Fatalf("vfs: %v", *svfs) + } +} + +func (c *Client) StatVFS(path string) (*StatVFS, error) +func (c *Client) Truncate(path string, size int64) error +func (c *Client) Walk(root string) *fs.Walker +*/ + func sameFile(want, got os.FileInfo) bool { return want.Name() == got.Name() && want.Size() == got.Size() From 348ee1a4692ea4570c2d1b421f0e500487e5195d Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Tue, 4 Aug 2015 22:21:35 -0700 Subject: [PATCH 13/41] truncate + darwin VFS test --- client_integration_darwin_test.go | 41 ++++++++++++++++++++++++++ client_integration_test.go | 49 ++++++++++++++++++++++++------- 2 files changed, 80 insertions(+), 10 deletions(-) create mode 100644 client_integration_darwin_test.go diff --git a/client_integration_darwin_test.go b/client_integration_darwin_test.go new file mode 100644 index 00000000..949a0427 --- /dev/null +++ b/client_integration_darwin_test.go @@ -0,0 +1,41 @@ +package sftp + +import ( + "syscall" + "testing" +) + +func TestClientStatVFS(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + vfs, err := sftp.StatVFS("/") + if err != nil { + t.Fatal(err) + } + + // get system stats + s := syscall.Statfs_t{} + err = syscall.Statfs("/", &s) + if err != nil { + t.Fatal(err) + } + + // check some stats + if vfs.Files != uint64(s.Files) { + t.Fatal("fr_size does not match") + } + + if vfs.Bfree != uint64(s.Bfree) { + t.Fatal("f_bsize does not match") + } + + if vfs.Favail != uint64(s.Ffree) { + t.Fatal("f_namemax does not match") + } + + if vfs.Bavail != s.Bavail { + t.Fatal("f_bavail does not match") + } +} diff --git a/client_integration_test.go b/client_integration_test.go index 5347fc5f..0d35a42e 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -663,9 +663,8 @@ func TestClientChtimesReadonly(t *testing.T) { } } -/* -func TestClientStatVFS(t *testing.T) { - sftp, cmd := testClient(t, READONLY) +func TestClientTruncate(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) defer cmd.Wait() defer sftp.Close() @@ -673,18 +672,48 @@ func TestClientStatVFS(t *testing.T) { if err != nil { t.Fatal(err) } + fname := f.Name() - if svfs, err := sftp.StatVFS("/"); err != nil { + if n, err := f.Write([]byte("hello world")); n != 11 || err != nil { + t.Fatal(err) + } + f.Close() + + if err := sftp.Truncate(fname, 5); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(fname); err != nil { t.Fatal(err) - } else { - t.Fatalf("vfs: %v", *svfs) + } else if stat.Size() != 5 { + t.Fatalf("unexpected size: %d", stat.Size()) } } -func (c *Client) StatVFS(path string) (*StatVFS, error) -func (c *Client) Truncate(path string, size int64) error -func (c *Client) Walk(root string) *fs.Walker -*/ +func TestClientTruncateReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + fname := f.Name() + + if n, err := f.Write([]byte("hello world")); n != 11 || err != nil { + t.Fatal(err) + } + f.Close() + + if err := sftp.Truncate(fname, 5); err == nil { + t.Fatal("expected error") + } + if stat, err := os.Stat(fname); err != nil { + t.Fatal(err) + } else if stat.Size() != 11 { + t.Fatalf("unexpected size: %d", stat.Size()) + } +} func sameFile(want, got os.FileInfo) bool { return want.Name() == got.Name() && From 82ef5086ee79b79f19ee12c0867e41748d30a45d Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Tue, 4 Aug 2015 23:37:18 -0700 Subject: [PATCH 14/41] server integration test scaffolding --- server_integration_test.go | 342 +++++++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 server_integration_test.go diff --git a/server_integration_test.go b/server_integration_test.go new file mode 100644 index 00000000..c3120a43 --- /dev/null +++ b/server_integration_test.go @@ -0,0 +1,342 @@ +package sftp + +// sftp server integration tests +// enable with -integration + +import ( + "bytes" + "encoding/hex" + "flag" + "fmt" + "net" + "os" + "os/exec" + "strconv" + "strings" + "testing" + + "github.com/ScriptRock/crypto/ssh" +) + +var testSftpClientBin = flag.String("sftp_client", "/usr/bin/sftp", "location of the sftp client binary") + +/*********************************************************************************************** + + +SSH server scaffolding; very simple, no strict auth. This is for unit testing, not real servers + + +***********************************************************************************************/ + +var ( + hostPrivateKeySigner ssh.Signer + privKey = []byte(` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEArhp7SqFnXVZAgWREL9Ogs+miy4IU/m0vmdkoK6M97G9NX/Pj +wf8I/3/ynxmcArbt8Rc4JgkjT2uxx/NqR0yN42N1PjO5Czu0dms1PSqcKIJdeUBV +7gdrKSm9Co4d2vwfQp5mg47eG4w63pz7Drk9+VIyi9YiYH4bve7WnGDswn4ycvYZ +slV5kKnjlfCdPig+g5P7yQYud0cDWVwyA0+kxvL6H3Ip+Fu8rLDZn4/P1WlFAIuc +PAf4uEKDGGmC2URowi5eesYR7f6GN/HnBs2776laNlAVXZUmYTUfOGagwLsEkx8x +XdNqntfbs2MOOoK+myJrNtcB9pCrM0H6um19uQIDAQABAoIBABkWr9WdVKvalgkP +TdQmhu3mKRNyd1wCl+1voZ5IM9Ayac/98UAvZDiNU4Uhx52MhtVLJ0gz4Oa8+i16 +IkKMAZZW6ro/8dZwkBzQbieWUFJ2Fso2PyvB3etcnGU8/Yhk9IxBDzy+BbuqhYE2 +1ebVQtz+v1HvVZzaD11bYYm/Xd7Y28QREVfFen30Q/v3dv7dOteDE/RgDS8Czz7w +jMW32Q8JL5grz7zPkMK39BLXsTcSYcaasT2ParROhGJZDmbgd3l33zKCVc1zcj9B +SA47QljGd09Tys958WWHgtj2o7bp9v1Ufs4LnyKgzrB80WX1ovaSQKvd5THTLchO +kLIhUAECgYEA2doGXy9wMBmTn/hjiVvggR1aKiBwUpnB87Hn5xCMgoECVhFZlT6l +WmZe7R2klbtG1aYlw+y+uzHhoVDAJW9AUSV8qoDUwbRXvBVlp+In5wIqJ+VjfivK +zgIfzomL5NvDz37cvPmzqIeySTowEfbQyq7CUQSoDtE9H97E2wWZhDkCgYEAzJdJ +k+NSFoTkHhfD3L0xCDHpRV3gvaOeew8524fVtVUq53X8m91ng4AX1r74dCUYwwiF +gqTtSSJfx2iH1xKnNq28M9uKg7wOrCKrRqNPnYUO3LehZEC7rwUr26z4iJDHjjoB +uBcS7nw0LJ+0Zeg1IF+aIdZGV3MrAKnrzWPixYECgYBsffX6ZWebrMEmQ89eUtFF +u9ZxcGI/4K8ErC7vlgBD5ffB4TYZ627xzFWuBLs4jmHCeNIJ9tct5rOVYN+wRO1k +/CRPzYUnSqb+1jEgILL6istvvv+DkE+ZtNkeRMXUndWwel94BWsBnUKe0UmrSJ3G +sq23J3iCmJW2T3z+DpXbkQKBgQCK+LUVDNPE0i42NsRnm+fDfkvLP7Kafpr3Umdl +tMY474o+QYn+wg0/aPJIf9463rwMNyyhirBX/k57IIktUdFdtfPicd2MEGETElWv +nN1GzYxD50Rs2f/jKisZhEwqT9YNyV9DkgDdGGdEbJNYqbv0qpwDIg8T9foe8E1p +bdErgQKBgAt290I3L316cdxIQTkJh1DlScN/unFffITwu127WMr28Jt3mq3cZpuM +Aecey/eEKCj+Rlas5NDYKsB18QIuAw+qqWyq0LAKLiAvP1965Rkc4PLScl3MgJtO +QYa37FK0p8NcDeUuF86zXBVutwS5nJLchHhKfd590ks57OROtm29 +-----END RSA PRIVATE KEY----- +`) +) + +func init() { + var err error + hostPrivateKeySigner, err = ssh.ParsePrivateKey(privKey) + if err != nil { + panic(err) + } +} + +func keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + } + return permissions, nil +} + +func pwAuth(conn ssh.ConnMetadata, pw []byte) (*ssh.Permissions, error) { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + } + return permissions, nil +} + +func basicServerConfig() *ssh.ServerConfig { + config := ssh.ServerConfig{ + Config: ssh.Config{ + MACs: []string{"hmac-sha1"}, + }, + PasswordCallback: pwAuth, + PublicKeyCallback: keyAuth, + } + config.AddHostKey(hostPrivateKeySigner) + return &config +} + +type sshServer struct { + conn net.Conn + config *ssh.ServerConfig + sshConn *ssh.ServerConn + newChans <-chan ssh.NewChannel + newReqs <-chan *ssh.Request +} + +func sshServerFromConn(conn net.Conn, config *ssh.ServerConfig) (*sshServer, error) { + // From a standard TCP connection to an encrypted SSH connection + sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config) + if err != nil { + return nil, err + } + + svr := &sshServer{conn, config, sshConn, newChans, newReqs} + svr.listenChannels() + return svr, nil +} + +func (svr *sshServer) Wait() error { + return svr.sshConn.Wait() +} + +func (svr *sshServer) Close() error { + return svr.sshConn.Close() +} + +func (svr *sshServer) listenChannels() { + go func() { + for chanReq := range svr.newChans { + go svr.handleChanReq(chanReq) + } + }() + go func() { + for req := range svr.newReqs { + go svr.handleReq(req) + } + }() +} + +func (svr *sshServer) handleReq(req *ssh.Request) { + switch req.Type { + default: + rejectRequest(req) + } +} + +type sshChannelServer struct { + svr *sshServer + chanReq ssh.NewChannel + ch ssh.Channel + newReqs <-chan *ssh.Request +} + +type sshSessionChannelServer struct { + *sshChannelServer + env []string +} + +func (svr *sshServer) handleChanReq(chanReq ssh.NewChannel) { + fmt.Fprintf(os.Stderr, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData())) + switch chanReq.ChannelType() { + case "session": + if ch, reqs, err := chanReq.Accept(); err != nil { + fmt.Fprintf(os.Stderr, "fail to accept channel request: %v\n", err) + chanReq.Reject(ssh.ResourceShortage, "channel accept failure") + } else { + chsvr := &sshSessionChannelServer{ + sshChannelServer: &sshChannelServer{svr, chanReq, ch, reqs}, + env: append([]string{}, os.Environ()...), + } + chsvr.handle() + } + default: + chanReq.Reject(ssh.UnknownChannelType, "channel type is not a session") + } +} + +func (chsvr *sshSessionChannelServer) handle() { + // should maybe do something here... + go chsvr.handleReqs() +} + +func (chsvr *sshSessionChannelServer) handleReqs() { + for req := range chsvr.newReqs { + chsvr.handleReq(req) + } +} + +func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) { + switch req.Type { + case "env": + chsvr.handleEnv(req) + case "subsystem": + chsvr.handleSubsystem(req) + default: + rejectRequest(req) + } +} + +func rejectRequest(req *ssh.Request) error { + fmt.Fprintf(os.Stderr, "ssh rejecting request, type: %s\n", req.Type) + err := req.Reply(false, []byte{}) + if err != nil { + fmt.Fprintf(os.Stderr, "ssh request reply had error: %v\n", err) + } + return err +} + +func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error { + fmt.Fprintf(os.Stderr, "ssh request unmarshaling error, type '%T': %v\n", s, err) + rejectRequest(req) + return err +} + +// env request form: +type sshEnvRequest struct { + Envvar string + Value string +} + +func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error { + envReq := &sshEnvRequest{} + if err := ssh.Unmarshal(req.Payload, envReq); err != nil { + return rejectRequestUnmarshalError(req, envReq, err) + } + req.Reply(true, nil) + + found := false + for i, envstr := range chsvr.env { + if strings.HasPrefix(envstr, envReq.Envvar+"=") { + found = true + chsvr.env[i] = envReq.Envvar + "=" + envReq.Value + } + } + if !found { + chsvr.env = append(chsvr.env, envReq.Envvar+"="+envReq.Value) + } + + return nil +} + +// Payload: int: command size, string: command +type sshSubsystemRequest struct { + Name string +} +type sshSubsystemExitStatus struct { + Status uint32 +} + +func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { + defer chsvr.ch.Close() + + subsystemReq := &sshSubsystemRequest{} + if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil { + return rejectRequestUnmarshalError(req, subsystemReq, err) + } + + // reply to the ssh client + + // no idea if this is actually correct spec-wise. + // just enough for an sftp server to start. + if subsystemReq.Name == "sftp" { + req.Reply(true, nil) + + sftpServer, err := NewServer(chsvr.ch, chsvr.ch, os.Stderr, 0, false, ".") + if err != nil { + return err + } + + // wait for the session to close + return sftpServer.Run() + } else { + return req.Reply(false, nil) + } +} + +/*********************************************************************************************** + + +Actual unit tests + + +***********************************************************************************************/ + +// starts an ssh server to test. returns: host string and port +func testServer(t *testing.T, readonly bool) (string, int) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + host, portStr, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + t.Fatal(err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + break + } + + _, err = sshServerFromConn(conn, basicServerConfig()) + if err != nil { + t.Fatal(err) + } + } + }() + + return host, port +} + +func runSftpClient(script string, path string, host string, port int) (string, error) { + cmd := exec.Command(*testSftpClientBin, "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) + stdout := &bytes.Buffer{} + cmd.Stdin = bytes.NewBufferString(script) + cmd.Stdout = stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return "", err + } + err := cmd.Wait() + return string(stdout.Bytes()), err +} + +func TestServerLstat(t *testing.T) { + host, port := testServer(t, READONLY) + + script := "ls" + output, err := runSftpClient(script, "/tmp/", host, port) + if err != nil { + t.Fatal(err) + } + + t.Log(output) +} From 4325c3654b9b80109b6231525d6a8c815d314930 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Wed, 5 Aug 2015 12:57:28 -0700 Subject: [PATCH 15/41] fix format of 'name' packets (shortname, longname, attrs), add Stat --- client.go | 24 +++++++++++++++++++ packet.go | 48 +++++++++++++++++++++++++++++++------- server.go | 42 ++++++++++++++++++++++++++++++--- server_integration_test.go | 2 +- 4 files changed, 104 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 7325f865..58e89e0a 100644 --- a/client.go +++ b/client.go @@ -260,6 +260,30 @@ func (c *Client) opendir(path string) (string, error) { } } +func (c *Client) Stat(p string) (os.FileInfo, error) { + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpStatPacket{ + Id: id, + Path: p, + }) + if err != nil { + return nil, err + } + switch typ { + case ssh_FXP_ATTRS: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIdErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), nil + case ssh_FXP_STATUS: + return nil, unmarshalStatus(id, data) + default: + return nil, unimplementedPacketErr(typ) + } +} + func (c *Client) Lstat(p string) (os.FileInfo, error) { id := c.nextId() typ, data, err := c.sendRequest(sshFxpLstatPacket{ diff --git a/packet.go b/packet.go index 79e692f4..aaba532d 100644 --- a/packet.go +++ b/packet.go @@ -246,6 +246,8 @@ type sshFxpReaddirPacket struct { Handle string } +func (p sshFxpReaddirPacket) id() uint32 { return p.Id } + func (p sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_READDIR, p.Id, p.Handle) } @@ -254,13 +256,13 @@ func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error { return unmarshalIdString(b, &p.Id, &p.Handle) } -func (p sshFxpReaddirPacket) id() uint32 { return p.Id } - type sshFxpOpendirPacket struct { Id uint32 Path string } +func (p sshFxpOpendirPacket) id() uint32 { return p.Id } + func (p sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_OPENDIR, p.Id, p.Path) } @@ -269,8 +271,6 @@ func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error { return unmarshalIdString(b, &p.Id, &p.Path) } -func (p sshFxpOpendirPacket) id() uint32 { return p.Id } - type sshFxpLstatPacket struct { Id uint32 Path string @@ -286,6 +286,21 @@ func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error { return unmarshalIdString(b, &p.Id, &p.Path) } +type sshFxpStatPacket struct { + Id uint32 + Path string +} + +func (p sshFxpStatPacket) id() uint32 { return p.Id } + +func (p sshFxpStatPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_LSTAT, p.Id, p.Path) +} + +func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Path) +} + type sshFxpFstatPacket struct { Id uint32 Handle string @@ -306,6 +321,8 @@ type sshFxpClosePacket struct { Handle string } +func (p sshFxpClosePacket) id() uint32 { return p.Id } + func (p sshFxpClosePacket) MarshalBinary() ([]byte, error) { return marshalIdString(ssh_FXP_CLOSE, p.Id, p.Handle) } @@ -314,8 +331,6 @@ func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error { return unmarshalIdString(b, &p.Id, &p.Handle) } -func (p sshFxpClosePacket) id() uint32 { return p.Id } - type sshFxpRemovePacket struct { Id uint32 Filename string @@ -361,14 +376,31 @@ func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { return unmarshalIdString(b, &p.Id, &p.Path) } +type sshFxpRealpathPacket struct { + Id uint32 + Path string +} + +func (p sshFxpRealpathPacket) id() uint32 { return p.Id } + +func (p sshFxpRealpathPacket) MarshalBinary() ([]byte, error) { + return marshalIdString(ssh_FXP_READLINK, p.Id, p.Path) +} + +func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error { + return unmarshalIdString(b, &p.Id, &p.Path) +} + type sshFxpNameAttr struct { - Name string - Attrs []interface{} + Name string + LongName string + Attrs []interface{} } func (p sshFxpNameAttr) MarshalBinary() ([]byte, error) { b := []byte{} b = marshalString(b, p.Name) + b = marshalString(b, p.LongName) for _, attr := range p.Attrs { b = marshal(b, attr) } diff --git a/server.go b/server.go index c5d0e2e3..32b93cdd 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "os" + "path" + "path/filepath" "sync" "syscall" ) @@ -14,6 +16,7 @@ import ( type FileSystem interface { Lstat(name string) (os.FileInfo, error) Remove(name string) error + Stat(name string) (os.FileInfo, error) Rename(oldpath, newpath string) error } @@ -21,6 +24,7 @@ type FileSystemOS interface { FileSystem OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) Readlink(path string) (string, error) + Realpath(path string) (string, error) Mkdir(name string, perm os.FileMode) error } @@ -49,10 +53,15 @@ type nativeFs struct { } func (nfs *nativeFs) Lstat(path string) (os.FileInfo, error) { return os.Lstat(path) } +func (nfs *nativeFs) Stat(path string) (os.FileInfo, error) { return os.Stat(path) } func (nfs *nativeFs) Mkdir(path string, perm os.FileMode) error { return os.Mkdir(path, perm) } func (nfs *nativeFs) Remove(path string) error { return os.Remove(path) } func (nfs *nativeFs) Rename(oldpath, newpath string) error { return os.Rename(oldpath, newpath) } func (nfs *nativeFs) Readlink(path string) (string, error) { return os.Readlink(path) } +func (nfs *nativeFs) Realpath(path string) (string, error) { + f, err := filepath.Abs(path) + return filepath.Clean(f), err +} func (nfs *nativeFs) OpenFile(path string, flag int, perm os.FileMode) (file *os.File, err error) { return os.OpenFile(path, flag, perm) } @@ -186,7 +195,9 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable pkt = &sshFxpMkdirPacket{} case ssh_FXP_RMDIR: case ssh_FXP_REALPATH: + pkt = &sshFxpRealpathPacket{} case ssh_FXP_STAT: + pkt = &sshFxpStatPacket{} case ssh_FXP_RENAME: pkt = &sshFxpRenamePacket{} case ssh_FXP_READLINK: @@ -245,6 +256,15 @@ func (p sshFxpLstatPacket) respond(svr *Server) error { } } +func (p sshFxpStatPacket) respond(svr *Server) error { + // stat the requested file + if info, err := svr.fs.Stat(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(sshFxpStatResponse{p.Id, info}) + } +} + func (p sshFxpFstatPacket) respond(svr *Server) error { if f, ok := svr.getHandle(p.Handle); !ok { return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) @@ -292,18 +312,32 @@ func (p sshFxpRenamePacket) respond(svr *Server) error { return svr.sendPacket(statusFromError(p.Id, err)) } +var emptyFileStat = []interface{}{uint32(0)} + func (p sshFxpReadlinkPacket) respond(svr *Server) error { if fso, ok := svr.fs.(FileSystemOS); ok { if f, err := fso.Readlink(p.Path); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, nil}}}) + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) } } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { if f, err := sftpo.ReadLink(p.Path); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, nil}}}) + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) + } + } else { + return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + } +} + +func (p sshFxpRealpathPacket) respond(svr *Server) error { + if fso, ok := svr.fs.(FileSystemOS); ok { + if f, err := fso.Realpath(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) + } else { + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) } } else { return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) @@ -412,10 +446,12 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { if f, ok := svr.getHandle(p.Handle); !ok { return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) } else { + dirname := "" dirents := []os.FileInfo{} var err error = nil if osf, ok := f.(*os.File); ok { + dirname = osf.Name() dirents, err = osf.Readdir(128) } else { // server error... @@ -428,7 +464,7 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { ret := sshFxpNamePacket{p.Id, nil} for _, dirent := range dirents { - ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{dirent.Name(), []interface{}{dirent.Name(), dirent}}) + ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{dirent.Name(), path.Join(dirname, dirent.Name()), []interface{}{dirent}}) } return svr.sendPacket(ret) } diff --git a/server_integration_test.go b/server_integration_test.go index c3120a43..f1a2a947 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -322,7 +322,7 @@ func runSftpClient(script string, path string, host string, port int) (string, e cmd.Stdin = bytes.NewBufferString(script) cmd.Stdout = stdout cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { + if err := cmd.Start(); err != nil { return "", err } err := cmd.Wait() From f9e831be301c4e386109f9bc79784eea95210bee Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Wed, 5 Aug 2015 23:24:33 -0700 Subject: [PATCH 16/41] proper ssh closing sequence --- server.go | 26 +++++++------- server_integration_test.go | 70 ++++++++++++++++++++++++++++---------- 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/server.go b/server.go index 32b93cdd..c0e73984 100644 --- a/server.go +++ b/server.go @@ -149,14 +149,15 @@ func (svr *Server) rxPackets() error { for { pktType, pktBytes, err := recvPacket(svr.in) if err == io.EOF { + fmt.Fprintf(svr.debugStream, "rxPackets loop done\n") return nil } else if err != nil { - fmt.Fprintf(os.Stderr, "recvPacket error: %v\n", err) + fmt.Fprintf(svr.debugStream, "recvPacket error: %v\n", err) return err } if pkt, err := svr.decodePacket(fxp(pktType), pktBytes); err != nil { - fmt.Fprintf(os.Stderr, "decodePacket error: %v\n", err) + fmt.Fprintf(svr.debugStream, "decodePacket error: %v\n", err) return err } else { svr.pktChan <- pkt @@ -164,6 +165,17 @@ func (svr *Server) rxPackets() error { } } +// Run this server until the streams stop or until the subsystem is stopped +func (svr *Server) Run() error { + go svr.rxPackets() + for pkt := range svr.pktChan { + fmt.Fprintf(svr.debugStream, "pkt: %T %v\n", pkt, pkt) + pkt.respond(svr) + } + fmt.Fprintf(svr.debugStream, "Run finished\n") + return nil +} + func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondablePacket, error) { //pktId, restBytes := unmarshalUint32(pktBytes[1:]) var pkt serverRespondablePacket = nil @@ -221,16 +233,6 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable return pkt, nil } -// Run this server until the streams stop or until the subsystem is stopped -func (svr *Server) Run() error { - go svr.rxPackets() - for pkt := range svr.pktChan { - fmt.Fprintf(os.Stderr, "pkt: %T %v\n", pkt, pkt) - pkt.respond(svr) - } - return nil -} - func (p sshFxInitPacket) respond(svr *Server) error { return svr.sendPacket(sshFxVersionPacket{sftpProtocolVersion, nil}) } diff --git a/server_integration_test.go b/server_integration_test.go index f1a2a947..f22918d8 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -19,6 +19,9 @@ import ( ) var testSftpClientBin = flag.String("sftp_client", "/usr/bin/sftp", "location of the sftp client binary") +var sshServerDebugStream = os.Stdout // ioutil.Discard +var sftpServerDebugStream = os.Stdout // ioutil.Discard +var sftpClientDebugStream = os.Stdout // ioutil.Discard /*********************************************************************************************** @@ -158,11 +161,11 @@ type sshSessionChannelServer struct { } func (svr *sshServer) handleChanReq(chanReq ssh.NewChannel) { - fmt.Fprintf(os.Stderr, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData())) + fmt.Fprintf(sshServerDebugStream, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData())) switch chanReq.ChannelType() { case "session": if ch, reqs, err := chanReq.Accept(); err != nil { - fmt.Fprintf(os.Stderr, "fail to accept channel request: %v\n", err) + fmt.Fprintf(sshServerDebugStream, "fail to accept channel request: %v\n", err) chanReq.Reject(ssh.ResourceShortage, "channel accept failure") } else { chsvr := &sshSessionChannelServer{ @@ -185,6 +188,7 @@ func (chsvr *sshSessionChannelServer) handleReqs() { for req := range chsvr.newReqs { chsvr.handleReq(req) } + fmt.Fprintf(sshServerDebugStream, "ssh server session channel complete\n") } func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) { @@ -199,16 +203,16 @@ func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) { } func rejectRequest(req *ssh.Request) error { - fmt.Fprintf(os.Stderr, "ssh rejecting request, type: %s\n", req.Type) + fmt.Fprintf(sshServerDebugStream, "ssh rejecting request, type: %s\n", req.Type) err := req.Reply(false, []byte{}) if err != nil { - fmt.Fprintf(os.Stderr, "ssh request reply had error: %v\n", err) + fmt.Fprintf(sshServerDebugStream, "ssh request reply had error: %v\n", err) } return err } func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error { - fmt.Fprintf(os.Stderr, "ssh request unmarshaling error, type '%T': %v\n", s, err) + fmt.Fprintf(sshServerDebugStream, "ssh request unmarshaling error, type '%T': %v\n", s, err) rejectRequest(req) return err } @@ -244,12 +248,17 @@ func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error { type sshSubsystemRequest struct { Name string } + type sshSubsystemExitStatus struct { Status uint32 } func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { - defer chsvr.ch.Close() + defer func() { + err1 := chsvr.ch.CloseWrite() + err2 := chsvr.ch.Close() + fmt.Fprintf(sshServerDebugStream, "ssh server subsystem request complete, err: %v %v\n", err1, err2) + }() subsystemReq := &sshSubsystemRequest{} if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil { @@ -263,13 +272,32 @@ func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { if subsystemReq.Name == "sftp" { req.Reply(true, nil) - sftpServer, err := NewServer(chsvr.ch, chsvr.ch, os.Stderr, 0, false, ".") - if err != nil { - return err - } + if false { + // use the sftp server backend; this is to test the ssh code, not the sftp code + cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr + cmd.Stdin = chsvr.ch + cmd.Stdout = chsvr.ch + cmd.Stderr = sftpServerDebugStream + if err := cmd.Start(); err != nil { + return err + } + return cmd.Wait() + } else { + sftpServer, err := NewServer(chsvr.ch, chsvr.ch, sftpServerDebugStream, 0, false, ".") + if err != nil { + return err + } + + // wait for the session to close + runErr := sftpServer.Run() + exitStatus := uint32(1) + if runErr == nil { + exitStatus = uint32(0) + } - // wait for the session to close - return sftpServer.Run() + _, exitStatusErr := chsvr.ch.SendRequest("exit-status", false, ssh.Marshal(sshSubsystemExitStatus{exitStatus})) + return exitStatusErr + } } else { return req.Reply(false, nil) } @@ -303,13 +331,19 @@ func testServer(t *testing.T, readonly bool) (string, int) { for { conn, err := listener.Accept() if err != nil { + fmt.Fprintf(sshServerDebugStream, "ssh server socket closed\n") break } - _, err = sshServerFromConn(conn, basicServerConfig()) - if err != nil { - t.Fatal(err) - } + go func() { + defer conn.Close() + sshSvr, err := sshServerFromConn(conn, basicServerConfig()) + if err != nil { + t.Fatal(err) + } + err = sshSvr.Wait() + fmt.Fprintf(sshServerDebugStream, "ssh server finished, err: %v\n", err) + }() } }() @@ -317,11 +351,11 @@ func testServer(t *testing.T, readonly bool) (string, int) { } func runSftpClient(script string, path string, host string, port int) (string, error) { - cmd := exec.Command(*testSftpClientBin, "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) + cmd := exec.Command(*testSftpClientBin, "-vvvv", "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) stdout := &bytes.Buffer{} cmd.Stdin = bytes.NewBufferString(script) cmd.Stdout = stdout - cmd.Stderr = os.Stderr + cmd.Stderr = sftpClientDebugStream if err := cmd.Start(); err != nil { return "", err } From 1502f6c9e692efe84f61aa58e2732c4782f56ec3 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Fri, 7 Aug 2015 00:51:14 -0700 Subject: [PATCH 17/41] compare golang sftp subsystem to openssh --- server_integration_test.go | 55 ++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/server_integration_test.go b/server_integration_test.go index f22918d8..3ba49dfd 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -23,6 +23,11 @@ var sshServerDebugStream = os.Stdout // ioutil.Discard var sftpServerDebugStream = os.Stdout // ioutil.Discard var sftpClientDebugStream = os.Stdout // ioutil.Discard +const ( + GOLANG_SFTP = true + OPENSSH_SFTP = false +) + /*********************************************************************************************** @@ -101,21 +106,22 @@ func basicServerConfig() *ssh.ServerConfig { } type sshServer struct { - conn net.Conn - config *ssh.ServerConfig - sshConn *ssh.ServerConn - newChans <-chan ssh.NewChannel - newReqs <-chan *ssh.Request + useSubsystem bool + conn net.Conn + config *ssh.ServerConfig + sshConn *ssh.ServerConn + newChans <-chan ssh.NewChannel + newReqs <-chan *ssh.Request } -func sshServerFromConn(conn net.Conn, config *ssh.ServerConfig) (*sshServer, error) { +func sshServerFromConn(conn net.Conn, useSubsystem bool, config *ssh.ServerConfig) (*sshServer, error) { // From a standard TCP connection to an encrypted SSH connection sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config) if err != nil { return nil, err } - svr := &sshServer{conn, config, sshConn, newChans, newReqs} + svr := &sshServer{useSubsystem, conn, config, sshConn, newChans, newReqs} svr.listenChannels() return svr, nil } @@ -272,8 +278,9 @@ func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { if subsystemReq.Name == "sftp" { req.Reply(true, nil) - if false { - // use the sftp server backend; this is to test the ssh code, not the sftp code + if !chsvr.svr.useSubsystem { + // use the openssh sftp server backend; this is to test the ssh code, not the sftp code, + // or is used for comparison between our sftp subsystem and the openssh sftp subsystem cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr cmd.Stdin = chsvr.ch cmd.Stdout = chsvr.ch @@ -312,7 +319,7 @@ Actual unit tests ***********************************************************************************************/ // starts an ssh server to test. returns: host string and port -func testServer(t *testing.T, readonly bool) (string, int) { +func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -337,7 +344,7 @@ func testServer(t *testing.T, readonly bool) (string, int) { go func() { defer conn.Close() - sshSvr, err := sshServerFromConn(conn, basicServerConfig()) + sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig()) if err != nil { t.Fatal(err) } @@ -347,7 +354,7 @@ func testServer(t *testing.T, readonly bool) (string, int) { } }() - return host, port + return listener, host, port } func runSftpClient(script string, path string, host string, port int) (string, error) { @@ -363,14 +370,28 @@ func runSftpClient(script string, path string, host string, port int) (string, e return string(stdout.Bytes()), err } -func TestServerLstat(t *testing.T) { - host, port := testServer(t, READONLY) +func TestServerCompareSubsystems(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + listenerOp, hostOp, portOp := testServer(t, OPENSSH_SFTP, READONLY) + defer listenerGo.Close() + defer listenerOp.Close() + + script := ` +ls / +ls /dev/ +` + outputGo, err := runSftpClient(script, "/", hostGo, portGo) + if err != nil { + t.Fatal(err) + } - script := "ls" - output, err := runSftpClient(script, "/tmp/", host, port) + outputOp, err := runSftpClient(script, "/", hostOp, portOp) if err != nil { t.Fatal(err) } - t.Log(output) + if outputGo != outputOp { + t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", outputGo, outputOp) + } + t.Logf("go output:\n%v\nopenssh output:\n%v\n", outputGo, outputOp) } From 0d8e136458ae142a68114833c66ebc31f202d33d Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 19:36:47 -0700 Subject: [PATCH 18/41] removing sftp server to client layer; using straight os.* calls --- client_integration_test.go | 120 +++++++++++++++---- packet.go | 23 +++- server.go | 234 +++++++++++++------------------------ server_integration_test.go | 7 +- 4 files changed, 202 insertions(+), 182 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index 0d35a42e..b707726e 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -86,6 +86,30 @@ func (w delayedWriter) Close() error { return nil } +func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) { + txPipeRd, txPipeWr := io.Pipe() + rxPipeRd, rxPipeWr := io.Pipe() + + server, err := NewServer(txPipeRd, rxPipeWr, os.Stderr, 0, readonly, ".") + if err != nil { + t.Fatal(err) + } + go server.Run() + + var ctx io.WriteCloser = txPipeWr + if delay > NO_DELAY { + ctx = newDelayedWriter(ctx, delay) + } + + client, err := NewClientPipe(rxPipeRd, ctx) + if err != nil { + t.Fatal(err) + } + + // dummy command... + return client, exec.Command("true") +} + // testClient returns a *Client connected to a localy running sftp-server // the *exec.Cmd returned must be defer Wait'd. func testClient(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) { @@ -94,29 +118,7 @@ func testClient(t testing.TB, readonly bool, delay time.Duration) (*Client, *exe } if *testServerImpl { - txPipeRd, txPipeWr := io.Pipe() - rxPipeRd, rxPipeWr := io.Pipe() - - server, err := NewServer(txPipeRd, rxPipeWr, os.Stderr, 0, readonly, ".") - if err != nil { - t.Fatal(err) - } - go server.Run() - - client, err := NewClientPipe(rxPipeRd, txPipeWr) - if err != nil { - t.Fatal(err) - } - - if err := client.sendInit(); err != nil { - t.Fatal(err) - } - if err := client.recvVersion(); err != nil { - t.Fatal(err) - } - - // dummy command... - return client, exec.Command("true") + return testClientGoSvr(t, readonly, delay) } cmd := exec.Command(*testSftp, "-e", "-R", "-l", debuglevel) // log to stderr, read only @@ -757,6 +759,78 @@ func TestClientReadSimple(t *testing.T) { } } +func TestClientReadDir(t *testing.T) { + sftp1, cmd1 := testClient(t, READONLY, NO_DELAY) + sftp2, cmd2 := testClientGoSvr(t, READONLY, NO_DELAY) + defer cmd1.Wait() + defer cmd2.Wait() + defer sftp1.Close() + defer sftp2.Close() + + dir := "/dev/" + + d, err := os.Open(dir) + if err != nil { + t.Fatal(err) + } + defer d.Close() + osfiles, err := d.Readdir(4096) + if err != nil { + t.Fatal(err) + } + + sftp1Files, err := sftp1.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + sftp2Files, err := sftp2.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + + osFilesByName := map[string]os.FileInfo{} + for _, f := range osfiles { + osFilesByName[f.Name()] = f + } + sftp1FilesByName := map[string]os.FileInfo{} + for _, f := range sftp1Files { + sftp1FilesByName[f.Name()] = f + } + sftp2FilesByName := map[string]os.FileInfo{} + for _, f := range sftp2Files { + sftp2FilesByName[f.Name()] = f + } + + if len(osFilesByName) != len(sftp1FilesByName) || len(sftp1FilesByName) != len(sftp2FilesByName) { + t.Fatalf("os gives %v, sftp1 gives %v, sftp2 gives %v", len(osFilesByName), len(sftp1FilesByName), len(sftp2FilesByName)) + } + + for name, osF := range osFilesByName { + sftp1F, ok := sftp1FilesByName[name] + if !ok { + t.Fatalf("%v present in os but not sftp1", name) + } + sftp2F, ok := sftp2FilesByName[name] + if !ok { + t.Fatalf("%v present in os but not sftp2", name) + } + + //t.Logf("%v: %v %v %v", name, osF, sftp1F, sftp2F) + if osF.Size() != sftp1F.Size() || sftp1F.Size() != sftp2F.Size() { + t.Fatalf("size %v %v %v", osF.Size(), sftp1F.Size(), sftp2F.Size()) + } + if osF.IsDir() != sftp1F.IsDir() || sftp1F.IsDir() != sftp2F.IsDir() { + t.Fatalf("isdir %v %v %v", osF.IsDir(), sftp1F.IsDir(), sftp2F.IsDir()) + } + if osF.ModTime().Sub(sftp1F.ModTime()) > time.Second || sftp1F.ModTime() != sftp2F.ModTime() { + t.Fatalf("modtime %v %v %v", osF.ModTime(), sftp1F.ModTime(), sftp2F.ModTime()) + } + if osF.Mode() != sftp1F.Mode() || sftp1F.Mode() != sftp2F.Mode() { + t.Fatalf("mode %x %x %x", osF.Mode(), sftp1F.Mode(), sftp2F.Mode()) + } + } +} + var clientReadTests = []struct { n int64 }{ diff --git a/packet.go b/packet.go index aaba532d..4cfe75b4 100644 --- a/packet.go +++ b/packet.go @@ -8,7 +8,13 @@ import ( "reflect" ) -var shortPacketError = fmt.Errorf("packet too short") +var ( + shortPacketError = fmt.Errorf("packet too short") + debugDumpTxPacket = false + debugDumpRxPacket = false + debugDumpTxPacketBytes = false + debugDumpRxPacketBytes = false +) func marshalUint32(b []byte, v uint32) []byte { return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) @@ -105,9 +111,13 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { if err != nil { return fmt.Errorf("marshal2(%#v): binary marshaller failed", err) } + if debugDumpTxPacketBytes { + debug("send packet: %s %d bytes %x", fxp(bb[0]).String(), len(bb), bb[1:]) + } else if debugDumpTxPacket { + debug("send packet: %s %d bytes", fxp(bb[0]).String(), len(bb)) + } l := uint32(len(bb)) hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} - debug("send packet %T, len: %v", m, l) _, err = w.Write(hdr) if err != nil { return err @@ -117,6 +127,9 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { } func (svr *Server) sendPacket(m encoding.BinaryMarshaler) error { + // any responder can call sendPacket(); actual socket access must be serialized + svr.outMutex.Lock() + defer svr.outMutex.Unlock() return sendPacket(svr.out, m) } @@ -126,12 +139,16 @@ func recvPacket(r io.Reader) (uint8, []byte, error) { return 0, nil, err } l, _ := unmarshalUint32(b) - debug("recv packet %d bytes", l) b = make([]byte, l) if _, err := io.ReadFull(r, b); err != nil { debug("recv packet %d bytes: err %v", l, err) return 0, nil, err } + if debugDumpRxPacketBytes { + debug("recv packet: %s %d bytes %x", fxp(b[0]).String(), l, b[1:]) + } else if debugDumpRxPacket { + debug("recv packet: %s %d bytes", fxp(b[0]).String(), l) + } return b[0], b[1:], nil } diff --git a/server.go b/server.go index c0e73984..3a01d264 100644 --- a/server.go +++ b/server.go @@ -13,79 +13,24 @@ import ( "syscall" ) -type FileSystem interface { - Lstat(name string) (os.FileInfo, error) - Remove(name string) error - Stat(name string) (os.FileInfo, error) - Rename(oldpath, newpath string) error -} - -type FileSystemOS interface { - FileSystem - OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) - Readlink(path string) (string, error) - Realpath(path string) (string, error) - Mkdir(name string, perm os.FileMode) error -} - -type FileSystemSFTP interface { - FileSystem - OpenFile(path string, f int) (*File, error) // sftp package has a strange OpenFile method with no perm - ReadLink(path string) (string, error) - Mkdir(name string) error -} - -// common subset of os.File and sftp.File -type svrFile interface { - Chmod(mode os.FileMode) error - Chown(uid, gid int) error - Close() error - Read(b []byte) (int, error) - Seek(offset int64, whence int) (int64, error) - Stat() (os.FileInfo, error) - Truncate(size int64) error - Write(b []byte) (int, error) - // func (f *File) WriteTo(w io.Writer) (int64, error) // not in os - // func (f *File) ReadFrom(r io.Reader) (int64, error) // not in os -} - -type nativeFs struct { -} - -func (nfs *nativeFs) Lstat(path string) (os.FileInfo, error) { return os.Lstat(path) } -func (nfs *nativeFs) Stat(path string) (os.FileInfo, error) { return os.Stat(path) } -func (nfs *nativeFs) Mkdir(path string, perm os.FileMode) error { return os.Mkdir(path, perm) } -func (nfs *nativeFs) Remove(path string) error { return os.Remove(path) } -func (nfs *nativeFs) Rename(oldpath, newpath string) error { return os.Rename(oldpath, newpath) } -func (nfs *nativeFs) Readlink(path string) (string, error) { return os.Readlink(path) } -func (nfs *nativeFs) Realpath(path string) (string, error) { - f, err := filepath.Abs(path) - return filepath.Clean(f), err -} -func (nfs *nativeFs) OpenFile(path string, flag int, perm os.FileMode) (file *os.File, err error) { - return os.OpenFile(path, flag, perm) -} - -var __typecheck_fsos FileSystemOS = &nativeFs{} -var __typecheck_sftpos FileSystemSFTP = &Client{} - type Server struct { in io.Reader - out io.Writer + out io.WriteCloser + outMutex *sync.Mutex debugStream io.Writer debugLevel int readOnly bool rootDir string lastId uint32 - fs FileSystem - pktChan chan serverRespondablePacket - openFiles map[string]svrFile + pktChan chan rxPacket + openFiles map[string]*os.File openFilesLock *sync.RWMutex handleCount int maxTxPacket uint32 + WorkerCount int } -func (svr *Server) nextHandle(f svrFile) string { +func (svr *Server) nextHandle(f *os.File) string { svr.openFilesLock.Lock() defer svr.openFilesLock.Unlock() svr.handleCount++ @@ -105,7 +50,7 @@ func (svr *Server) closeHandle(handle string) error { } } -func (svr *Server) getHandle(handle string) (svrFile, bool) { +func (svr *Server) getHandle(handle string) (*os.File, bool) { svr.openFilesLock.RLock() defer svr.openFilesLock.RUnlock() f, ok := svr.openFiles[handle] @@ -119,7 +64,7 @@ type serverRespondablePacket interface { // Creates a new server instance around the provided streams. // A subsequent call to Run() is required. -func NewServer(in io.Reader, out io.Writer, debugStream io.Writer, debugLevel int, readOnly bool, rootDir string) (*Server, error) { +func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLevel int, readOnly bool, rootDir string) (*Server, error) { if rootDir == "" { if wd, err := os.Getwd(); err != nil { return nil, err @@ -127,21 +72,28 @@ func NewServer(in io.Reader, out io.Writer, debugStream io.Writer, debugLevel in rootDir = wd } } + workerCount := 8 return &Server{ in: in, out: out, + outMutex: &sync.Mutex{}, debugStream: debugStream, debugLevel: debugLevel, readOnly: readOnly, rootDir: rootDir, - fs: &nativeFs{}, - pktChan: make(chan serverRespondablePacket, 4), - openFiles: map[string]svrFile{}, + pktChan: make(chan rxPacket, workerCount), + openFiles: map[string]*os.File{}, openFilesLock: &sync.RWMutex{}, maxTxPacket: 1 << 15, + WorkerCount: workerCount, }, nil } +type rxPacket struct { + pktType fxp + pktBytes []byte +} + // Unmarshal a single logical packet from the secure channel func (svr *Server) rxPackets() error { defer close(svr.pktChan) @@ -156,24 +108,43 @@ func (svr *Server) rxPackets() error { return err } - if pkt, err := svr.decodePacket(fxp(pktType), pktBytes); err != nil { + svr.pktChan <- rxPacket{fxp(pktType), pktBytes} + } +} + +// Up to N parallel servers +func (svr *Server) sftpServerWorker(doneChan chan error) { + for pkt := range svr.pktChan { + if pkt, err := svr.decodePacket(pkt.pktType, pkt.pktBytes); err != nil { fmt.Fprintf(svr.debugStream, "decodePacket error: %v\n", err) - return err + doneChan <- err + return } else { - svr.pktChan <- pkt + //fmt.Fprintf(svr.debugStream, "pkt: %T %v\n", pkt, pkt) + pkt.respond(svr) } } + doneChan <- nil } // Run this server until the streams stop or until the subsystem is stopped func (svr *Server) Run() error { + if svr.WorkerCount <= 0 { + return fmt.Errorf("sftp server requires > 0 workers") + } go svr.rxPackets() - for pkt := range svr.pktChan { - fmt.Fprintf(svr.debugStream, "pkt: %T %v\n", pkt, pkt) - pkt.respond(svr) + doneChan := make(chan error) + for i := 0; i < svr.WorkerCount; i++ { + go svr.sftpServerWorker(doneChan) } - fmt.Fprintf(svr.debugStream, "Run finished\n") - return nil + for i := 0; i < svr.WorkerCount; i++ { + if err := <-doneChan; err != nil { + // abort early and shut down the session on un-decodable packets + break + } + } + fmt.Fprintf(svr.debugStream, "sftp server run finished\n") + return svr.out.Close() } func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondablePacket, error) { @@ -251,7 +222,7 @@ func (p sshFxpStatResponse) MarshalBinary() ([]byte, error) { func (p sshFxpLstatPacket) respond(svr *Server) error { // stat the requested file - if info, err := svr.fs.Lstat(p.Path); err != nil { + if info, err := os.Lstat(p.Path); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { return svr.sendPacket(sshFxpStatResponse{p.Id, info}) @@ -260,7 +231,7 @@ func (p sshFxpLstatPacket) respond(svr *Server) error { func (p sshFxpStatPacket) respond(svr *Server) error { // stat the requested file - if info, err := svr.fs.Stat(p.Path); err != nil { + if info, err := os.Stat(p.Path); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { return svr.sendPacket(sshFxpStatResponse{p.Id, info}) @@ -270,15 +241,10 @@ func (p sshFxpStatPacket) respond(svr *Server) error { func (p sshFxpFstatPacket) respond(svr *Server) error { if f, ok := svr.getHandle(p.Handle); !ok { return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) - } else if osf, ok := f.(*os.File); ok { - if info, err := osf.Stat(); err != nil { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - return svr.sendPacket(sshFxpStatResponse{p.Id, info}) - } + } else if info, err := f.Stat(); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) } else { - // server error... - return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + return svr.sendPacket(sshFxpStatResponse{p.Id, info}) } } @@ -287,22 +253,15 @@ func (p sshFxpMkdirPacket) respond(svr *Server) error { return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) } // TODO FIXME: ignore flags field - if fso, ok := svr.fs.(FileSystemOS); ok { - err := fso.Mkdir(p.Path, 0755) - return svr.sendPacket(statusFromError(p.Id, err)) - } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { - err := sftpo.Mkdir(p.Path) - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) - } + err := os.Mkdir(p.Path, 0755) + return svr.sendPacket(statusFromError(p.Id, err)) } func (p sshFxpRemovePacket) respond(svr *Server) error { if svr.readOnly { return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) } - err := svr.fs.Remove(p.Filename) + err := os.Remove(p.Filename) return svr.sendPacket(statusFromError(p.Id, err)) } @@ -310,39 +269,26 @@ func (p sshFxpRenamePacket) respond(svr *Server) error { if svr.readOnly { return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) } - err := svr.fs.Rename(p.Oldpath, p.Newpath) + err := os.Rename(p.Oldpath, p.Newpath) return svr.sendPacket(statusFromError(p.Id, err)) } var emptyFileStat = []interface{}{uint32(0)} func (p sshFxpReadlinkPacket) respond(svr *Server) error { - if fso, ok := svr.fs.(FileSystemOS); ok { - if f, err := fso.Readlink(p.Path); err != nil { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) - } - } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { - if f, err := sftpo.ReadLink(p.Path); err != nil { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) - } + if f, err := os.Readlink(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) } } func (p sshFxpRealpathPacket) respond(svr *Server) error { - if fso, ok := svr.fs.(FileSystemOS); ok { - if f, err := fso.Realpath(p.Path); err != nil { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) - } + if f, err := filepath.Abs(p.Path); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + f = filepath.Clean(f) + return svr.sendPacket(sshFxpNamePacket{p.Id, []sshFxpNameAttr{sshFxpNameAttr{f, f, emptyFileStat}}}) } } @@ -383,22 +329,11 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { osFlags |= os.O_EXCL } - if fso, ok := svr.fs.(FileSystemOS); ok { - if f, err := fso.OpenFile(p.Path, osFlags, 0644); err != nil { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - handle := svr.nextHandle(f) - return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) - } - } else if sftpo, ok := svr.fs.(FileSystemSFTP); ok { - if f, err := sftpo.OpenFile(p.Path, osFlags); err != nil { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - handle := svr.nextHandle(f) - return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) - } + if f, err := os.OpenFile(p.Path, osFlags, 0644); err != nil { + return svr.sendPacket(statusFromError(p.Id, err)) } else { - return svr.sendPacket(statusFromError(p.Id, fmt.Errorf("unknown filesystem backend"))) + handle := svr.nextHandle(f) + return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) } } @@ -413,17 +348,12 @@ func (p sshFxpReadPacket) respond(svr *Server) error { if p.Len > svr.maxTxPacket { p.Len = svr.maxTxPacket } - if osf, ok := f.(*os.File); ok { - ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)} - if n, err := osf.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) { - return svr.sendPacket(statusFromError(p.Id, err)) - } else { - ret.Length = uint32(n) - return svr.sendPacket(ret) - } + ret := sshFxpDataPacket{Id: p.Id, Length: p.Len, Data: make([]byte, p.Len)} + if n, err := f.ReadAt(ret.Data, int64(p.Offset)); err != nil && (err != io.EOF || n == 0) { + return svr.sendPacket(statusFromError(p.Id, err)) } else { - // server error... - return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + ret.Length = uint32(n) + return svr.sendPacket(ret) } } } @@ -435,12 +365,9 @@ func (p sshFxpWritePacket) respond(svr *Server) error { } if f, ok := svr.getHandle(p.Handle); !ok { return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) - } else if osf, ok := f.(*os.File); ok { - _, err := osf.WriteAt(p.Data, int64(p.Offset)) - return svr.sendPacket(statusFromError(p.Id, err)) } else { - // server error... - return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + _, err := f.WriteAt(p.Data, int64(p.Offset)) + return svr.sendPacket(statusFromError(p.Id, err)) } } @@ -452,22 +379,21 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { dirents := []os.FileInfo{} var err error = nil - if osf, ok := f.(*os.File); ok { - dirname = osf.Name() - dirents, err = osf.Readdir(128) - } else { - // server error... - return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) - } - + dirname = f.Name() + dirents, err = f.Readdir(128) if err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } ret := sshFxpNamePacket{p.Id, nil} for _, dirent := range dirents { - ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{dirent.Name(), path.Join(dirname, dirent.Name()), []interface{}{dirent}}) + ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ + dirent.Name(), + path.Join(dirname, dirent.Name()), + []interface{}{dirent}, + }) } + //debug("readdir respond %v", ret) return svr.sendPacket(ret) } } diff --git a/server_integration_test.go b/server_integration_test.go index 3ba49dfd..689398e9 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -377,8 +377,11 @@ func TestServerCompareSubsystems(t *testing.T) { defer listenerOp.Close() script := ` -ls / -ls /dev/ +#ls / +#ls -l / +#ls /dev/ +#ls -l /dev/ +ls -l /tmp/shit/ ` outputGo, err := runSftpClient(script, "/", hostGo, portGo) if err != nil { From 5d19ff23723391887db39cf3f5694275f66c891c Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 19:37:30 -0700 Subject: [PATCH 19/41] copy os/user/* to here --- group/lookup.go | 22 ++++++ group/lookup_plan9.go | 46 +++++++++++++ group/lookup_stubs.go | 28 ++++++++ group/lookup_unix.go | 112 ++++++++++++++++++++++++++++++ group/lookup_windows.go | 149 ++++++++++++++++++++++++++++++++++++++++ group/user.go | 43 ++++++++++++ group/user_test.go | 89 ++++++++++++++++++++++++ 7 files changed, 489 insertions(+) create mode 100644 group/lookup.go create mode 100644 group/lookup_plan9.go create mode 100644 group/lookup_stubs.go create mode 100644 group/lookup_unix.go create mode 100644 group/lookup_windows.go create mode 100644 group/user.go create mode 100644 group/user_test.go diff --git a/group/lookup.go b/group/lookup.go new file mode 100644 index 00000000..09f00c7b --- /dev/null +++ b/group/lookup.go @@ -0,0 +1,22 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +// Current returns the current user. +func Current() (*User, error) { + return current() +} + +// Lookup looks up a user by username. If the user cannot be found, the +// returned error is of type UnknownUserError. +func Lookup(username string) (*User, error) { + return lookup(username) +} + +// LookupId looks up a user by userid. If the user cannot be found, the +// returned error is of type UnknownUserIdError. +func LookupId(uid string) (*User, error) { + return lookupId(uid) +} diff --git a/group/lookup_plan9.go b/group/lookup_plan9.go new file mode 100644 index 00000000..f7ef3482 --- /dev/null +++ b/group/lookup_plan9.go @@ -0,0 +1,46 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "io/ioutil" + "os" + "syscall" +) + +// Partial os/user support on Plan 9. +// Supports Current(), but not Lookup()/LookupId(). +// The latter two would require parsing /adm/users. +const ( + userFile = "/dev/user" +) + +func current() (*User, error) { + ubytes, err := ioutil.ReadFile(userFile) + if err != nil { + return nil, fmt.Errorf("user: %s", err) + } + + uname := string(ubytes) + + u := &User{ + Uid: uname, + Gid: uname, + Username: uname, + Name: uname, + HomeDir: os.Getenv("home"), + } + + return u, nil +} + +func lookup(username string) (*User, error) { + return nil, syscall.EPLAN9 +} + +func lookupId(uid string) (*User, error) { + return nil, syscall.EPLAN9 +} diff --git a/group/lookup_stubs.go b/group/lookup_stubs.go new file mode 100644 index 00000000..4fb0e3c6 --- /dev/null +++ b/group/lookup_stubs.go @@ -0,0 +1,28 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !cgo,!windows,!plan9 android + +package user + +import ( + "fmt" + "runtime" +) + +func init() { + implemented = false +} + +func current() (*User, error) { + return nil, fmt.Errorf("user: Current not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +func lookup(username string) (*User, error) { + return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +func lookupId(uid string) (*User, error) { + return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/group/lookup_unix.go b/group/lookup_unix.go new file mode 100644 index 00000000..0871473d --- /dev/null +++ b/group/lookup_unix.go @@ -0,0 +1,112 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris +// +build cgo + +package user + +import ( + "fmt" + "runtime" + "strconv" + "strings" + "syscall" + "unsafe" +) + +/* +#include +#include +#include +#include + +static int mygetpwuid_r(int uid, struct passwd *pwd, + char *buf, size_t buflen, struct passwd **result) { + return getpwuid_r(uid, pwd, buf, buflen, result); +} +*/ +import "C" + +func current() (*User, error) { + return lookupUnix(syscall.Getuid(), "", false) +} + +func lookup(username string) (*User, error) { + return lookupUnix(-1, username, true) +} + +func lookupId(uid string) (*User, error) { + i, e := strconv.Atoi(uid) + if e != nil { + return nil, e + } + return lookupUnix(i, "", false) +} + +func lookupUnix(uid int, username string, lookupByName bool) (*User, error) { + var pwd C.struct_passwd + var result *C.struct_passwd + + var bufSize C.long + if runtime.GOOS == "dragonfly" || runtime.GOOS == "freebsd" { + // DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX + // and just return -1. So just use the same + // size that Linux returns. + bufSize = 1024 + } else { + bufSize = C.sysconf(C._SC_GETPW_R_SIZE_MAX) + if bufSize <= 0 || bufSize > 1<<20 { + return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize) + } + } + buf := C.malloc(C.size_t(bufSize)) + defer C.free(buf) + var rv C.int + if lookupByName { + nameC := C.CString(username) + defer C.free(unsafe.Pointer(nameC)) + rv = C.getpwnam_r(nameC, + &pwd, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup username %s: %s", username, syscall.Errno(rv)) + } + if result == nil { + return nil, UnknownUserError(username) + } + } else { + // mygetpwuid_r is a wrapper around getpwuid_r to + // to avoid using uid_t because C.uid_t(uid) for + // unknown reasons doesn't work on linux. + rv = C.mygetpwuid_r(C.int(uid), + &pwd, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("user: lookup userid %d: %s", uid, syscall.Errno(rv)) + } + if result == nil { + return nil, UnknownUserIdError(uid) + } + } + u := &User{ + Uid: strconv.Itoa(int(pwd.pw_uid)), + Gid: strconv.Itoa(int(pwd.pw_gid)), + Username: C.GoString(pwd.pw_name), + Name: C.GoString(pwd.pw_gecos), + HomeDir: C.GoString(pwd.pw_dir), + } + // The pw_gecos field isn't quite standardized. Some docs + // say: "It is expected to be a comma separated list of + // personal data where the first item is the full name of the + // user." + if i := strings.Index(u.Name, ","); i >= 0 { + u.Name = u.Name[:i] + } + return u, nil +} diff --git a/group/lookup_windows.go b/group/lookup_windows.go new file mode 100644 index 00000000..99c325ff --- /dev/null +++ b/group/lookup_windows.go @@ -0,0 +1,149 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "fmt" + "syscall" + "unsafe" +) + +func isDomainJoined() (bool, error) { + var domain *uint16 + var status uint32 + err := syscall.NetGetJoinInformation(nil, &domain, &status) + if err != nil { + return false, err + } + syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + return status == syscall.NetSetupDomainName, nil +} + +func lookupFullNameDomain(domainAndUser string) (string, error) { + return syscall.TranslateAccountName(domainAndUser, + syscall.NameSamCompatible, syscall.NameDisplay, 50) +} + +func lookupFullNameServer(servername, username string) (string, error) { + s, e := syscall.UTF16PtrFromString(servername) + if e != nil { + return "", e + } + u, e := syscall.UTF16PtrFromString(username) + if e != nil { + return "", e + } + var p *byte + e = syscall.NetUserGetInfo(s, u, 10, &p) + if e != nil { + return "", e + } + defer syscall.NetApiBufferFree(p) + i := (*syscall.UserInfo10)(unsafe.Pointer(p)) + if i.FullName == nil { + return "", nil + } + name := syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:]) + return name, nil +} + +func lookupFullName(domain, username, domainAndUser string) (string, error) { + joined, err := isDomainJoined() + if err == nil && joined { + name, err := lookupFullNameDomain(domainAndUser) + if err == nil { + return name, nil + } + } + name, err := lookupFullNameServer(domain, username) + if err == nil { + return name, nil + } + // domain worked neigher as a domain nor as a server + // could be domain server unavailable + // pretend username is fullname + return username, nil +} + +func newUser(usid *syscall.SID, gid, dir string) (*User, error) { + username, domain, t, e := usid.LookupAccount("") + if e != nil { + return nil, e + } + if t != syscall.SidTypeUser { + return nil, fmt.Errorf("user: should be user account type, not %d", t) + } + domainAndUser := domain + `\` + username + uid, e := usid.String() + if e != nil { + return nil, e + } + name, e := lookupFullName(domain, username, domainAndUser) + if e != nil { + return nil, e + } + u := &User{ + Uid: uid, + Gid: gid, + Username: domainAndUser, + Name: name, + HomeDir: dir, + } + return u, nil +} + +func current() (*User, error) { + t, e := syscall.OpenCurrentProcessToken() + if e != nil { + return nil, e + } + defer t.Close() + u, e := t.GetTokenUser() + if e != nil { + return nil, e + } + pg, e := t.GetTokenPrimaryGroup() + if e != nil { + return nil, e + } + gid, e := pg.PrimaryGroup.String() + if e != nil { + return nil, e + } + dir, e := t.GetUserProfileDirectory() + if e != nil { + return nil, e + } + return newUser(u.User.Sid, gid, dir) +} + +// BUG(brainman): Lookup and LookupId functions do not set +// Gid and HomeDir fields in the User struct returned on windows. + +func newUserFromSid(usid *syscall.SID) (*User, error) { + // TODO(brainman): do not know where to get gid and dir fields + gid := "unknown" + dir := "Unknown directory" + return newUser(usid, gid, dir) +} + +func lookup(username string) (*User, error) { + sid, _, t, e := syscall.LookupSID("", username) + if e != nil { + return nil, e + } + if t != syscall.SidTypeUser { + return nil, fmt.Errorf("user: should be user account type, not %d", t) + } + return newUserFromSid(sid) +} + +func lookupId(uid string) (*User, error) { + sid, e := syscall.StringToSid(uid) + if e != nil { + return nil, e + } + return newUserFromSid(sid) +} diff --git a/group/user.go b/group/user.go new file mode 100644 index 00000000..e8680fe5 --- /dev/null +++ b/group/user.go @@ -0,0 +1,43 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package user allows user account lookups by name or id. +package user + +import ( + "strconv" +) + +var implemented = true // set to false by lookup_stubs.go's init + +// User represents a user account. +// +// On posix systems Uid and Gid contain a decimal number +// representing uid and gid. On windows Uid and Gid +// contain security identifier (SID) in a string format. +// On Plan 9, Uid, Gid, Username, and Name will be the +// contents of /dev/user. +type User struct { + Uid string // user id + Gid string // primary group id + Username string + Name string + HomeDir string +} + +// UnknownUserIdError is returned by LookupId when +// a user cannot be found. +type UnknownUserIdError int + +func (e UnknownUserIdError) Error() string { + return "user: unknown userid " + strconv.Itoa(int(e)) +} + +// UnknownUserError is returned by Lookup when +// a user cannot be found. +type UnknownUserError string + +func (e UnknownUserError) Error() string { + return "user: unknown user " + string(e) +} diff --git a/group/user_test.go b/group/user_test.go new file mode 100644 index 00000000..9d9420e8 --- /dev/null +++ b/group/user_test.go @@ -0,0 +1,89 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package user + +import ( + "runtime" + "testing" +) + +func check(t *testing.T) { + if !implemented { + t.Skip("user: not implemented; skipping tests") + } +} + +func TestCurrent(t *testing.T) { + check(t) + + u, err := Current() + if err != nil { + t.Fatalf("Current: %v", err) + } + if u.HomeDir == "" { + t.Errorf("didn't get a HomeDir") + } + if u.Username == "" { + t.Errorf("didn't get a username") + } +} + +func compare(t *testing.T, want, got *User) { + if want.Uid != got.Uid { + t.Errorf("got Uid=%q; want %q", got.Uid, want.Uid) + } + if want.Username != got.Username { + t.Errorf("got Username=%q; want %q", got.Username, want.Username) + } + if want.Name != got.Name { + t.Errorf("got Name=%q; want %q", got.Name, want.Name) + } + // TODO(brainman): fix it once we know how. + if runtime.GOOS == "windows" { + t.Skip("skipping Gid and HomeDir comparisons") + } + if want.Gid != got.Gid { + t.Errorf("got Gid=%q; want %q", got.Gid, want.Gid) + } + if want.HomeDir != got.HomeDir { + t.Errorf("got HomeDir=%q; want %q", got.HomeDir, want.HomeDir) + } +} + +func TestLookup(t *testing.T) { + check(t) + + if runtime.GOOS == "plan9" { + t.Skipf("Lookup not implemented on %q", runtime.GOOS) + } + + want, err := Current() + if err != nil { + t.Fatalf("Current: %v", err) + } + got, err := Lookup(want.Username) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + compare(t, want, got) +} + +func TestLookupId(t *testing.T) { + check(t) + + if runtime.GOOS == "plan9" { + t.Skipf("LookupId not implemented on %q", runtime.GOOS) + } + + want, err := Current() + if err != nil { + t.Fatalf("Current: %v", err) + } + got, err := LookupId(want.Uid) + if err != nil { + t.Fatalf("LookupId: %v", err) + } + compare(t, want, got) +} From 8d4ed823fa666fdc084518862876dbead1b17ba0 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 19:58:29 -0700 Subject: [PATCH 20/41] apply patch set 2 from https://codereview.appspot.com/13454043 --- group/lookup.go | 31 ++++++- group/lookup_stubs.go | 12 ++- group/lookup_unix.go | 186 ++++++++++++++++++++++++++++++++++++++++-- group/user.go | 22 +++++ group/user_test.go | 49 +++++++++++ 5 files changed, 290 insertions(+), 10 deletions(-) diff --git a/group/lookup.go b/group/lookup.go index 09f00c7b..66d8782e 100644 --- a/group/lookup.go +++ b/group/lookup.go @@ -12,11 +12,38 @@ func Current() (*User, error) { // Lookup looks up a user by username. If the user cannot be found, the // returned error is of type UnknownUserError. func Lookup(username string) (*User, error) { - return lookup(username) + return lookupUser(username) } // LookupId looks up a user by userid. If the user cannot be found, the // returned error is of type UnknownUserIdError. func LookupId(uid string) (*User, error) { - return lookupId(uid) + return lookupUserId(uid) +} + +// CurrentGroup returns the current group. +func CurrentGroup() (*Group, error) { + return currentGroup() +} + +// LookupGroup looks up a group by name. If the group cannot be found, the +// returned error is of type UnknownGroupError. +func LookupGroup(groupname string) (*Group, error) { + return lookupGroup(groupname) +} + +// LookupGroupId looks up a group by groupid. If the group cannot be found, the +// returned error is of type UnknownGroupIdError. +func LookupGroupId(gid string) (*Group, error) { + return lookupGroupId(gid) +} + +// In indicates whether the user is a member of the given group. +func (u *User) In(g *Group) (bool, error) { + return userInGroup(u, g) +} + +// Members returns the list of members of the group. +func (g *Group) Members() ([]string, error) { + return groupMembers(g) } diff --git a/group/lookup_stubs.go b/group/lookup_stubs.go index 4fb0e3c6..83f7174a 100644 --- a/group/lookup_stubs.go +++ b/group/lookup_stubs.go @@ -19,10 +19,18 @@ func current() (*User, error) { return nil, fmt.Errorf("user: Current not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) } -func lookup(username string) (*User, error) { +func lookupUser(username string) (*User, error) { return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) } -func lookupId(uid string) (*User, error) { +func lookupUserId(uid string) (*User, error) { return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) } + +func lookupGroup(groupname string) (*Group, error) { + return nil, fmt.Errorf("user: LookupGroup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +func lookupGroupId(int) (*Group, error) { + return nil, fmt.Errorf("user: LookupGroupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/group/lookup_unix.go b/group/lookup_unix.go index 0871473d..d047efc6 100644 --- a/group/lookup_unix.go +++ b/group/lookup_unix.go @@ -20,32 +20,55 @@ import ( #include #include #include +#include #include static int mygetpwuid_r(int uid, struct passwd *pwd, char *buf, size_t buflen, struct passwd **result) { return getpwuid_r(uid, pwd, buf, buflen, result); } + +static int mygetgrgid_r(int gid, struct group *grp, + char *buf, size_t buflen, struct group **result) { + return getgrgid_r(gid, grp, buf, buflen, result); +} + +static int mygetgrouplist(const char *user, gid_t group, gid_t *groups, + int *ngroups) { + return getgrouplist(user, group, (void *)groups, ngroups); +} + +static inline gid_t group_at(int i, gid_t *groups) { + return groups[i]; +} + +static inline char **next_member(char **members) { return members + 1; } + */ import "C" +const ( + userBuffer = iota + groupBuffer +) + func current() (*User, error) { - return lookupUnix(syscall.Getuid(), "", false) + return lookupUnixUser(syscall.Getuid(), "", false) } -func lookup(username string) (*User, error) { - return lookupUnix(-1, username, true) +func lookupUser(username string) (*User, error) { + return lookupUnixUser(-1, username, true) } -func lookupId(uid string) (*User, error) { +func lookupUserId(uid string) (*User, error) { i, e := strconv.Atoi(uid) if e != nil { return nil, e } - return lookupUnix(i, "", false) + return lookupUnixUser(i, "", false) } -func lookupUnix(uid int, username string, lookupByName bool) (*User, error) { +func lookupUnixUser(uid int, username string, lookupByName bool) (*User, error) { var pwd C.struct_passwd var result *C.struct_passwd @@ -110,3 +133,154 @@ func lookupUnix(uid int, username string, lookupByName bool) (*User, error) { } return u, nil } + +func currentGroup() (*Group, error) { + return lookupUnixGroup(syscall.Getgid(), "", false, buildGroup) +} + +func lookupGroup(groupname string) (*Group, error) { + return lookupUnixGroup(-1, groupname, true, buildGroup) +} + +func lookupGroupId(gid string) (*Group, error) { + i, e := strconv.Atoi(gid) + if e != nil { + return nil, e + } + return lookupUnixGroup(i, "", false, buildGroup) +} + +func lookupUnixGroup(gid int, groupname string, lookupByName bool, f func(*C.struct_group) *Group) (*Group, error) { + var grp C.struct_group + var result *C.struct_group + + buf, bufSize, err := allocBuffer(groupBuffer) + if err != nil { + return nil, err + } + defer C.free(buf) + + if lookupByName { + nameC := C.CString(groupname) + defer C.free(unsafe.Pointer(nameC)) + rv := C.getgrnam_r(nameC, + &grp, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("group: lookup groupname %s: %s", groupname, syscall.Errno(rv)) + } + if result == nil { + return nil, UnknownGroupError(groupname) + } + } else { + // mygetgrgid_r is a wrapper around getgrgid_r to + // to avoid using gid_t because C.gid_t(gid) for + // unknown reasons doesn't work on linux. + rv := C.mygetgrgid_r(C.int(gid), + &grp, + (*C.char)(buf), + C.size_t(bufSize), + &result) + if rv != 0 { + return nil, fmt.Errorf("group: lookup groupid %d: %s", gid, syscall.Errno(rv)) + } + if result == nil { + return nil, UnknownGroupIdError(gid) + } + } + g := f(&grp) + return g, nil +} + +func buildGroup(grp *C.struct_group) *Group { + g := &Group{ + Gid: strconv.Itoa(int(grp.gr_gid)), + Name: C.GoString(grp.gr_name), + } + return g +} + +func userInGroup(u *User, g *Group) (bool, error) { + if u.Gid == g.Gid { + return true, nil + } + gid, err := strconv.Atoi(g.Gid) + if err != nil { + return false, err + } + + nameC := C.CString(u.Username) + defer C.free(unsafe.Pointer(nameC)) + groupC := C.gid_t(gid) + ngroupsC := C.int(0) + + C.mygetgrouplist(nameC, groupC, nil, &ngroupsC) + ngroups := int(ngroupsC) + + groups := C.malloc(C.size_t(int(unsafe.Sizeof(groupC)) * ngroups)) + defer C.free(groups) + + rv := C.mygetgrouplist(nameC, groupC, (*C.gid_t)(groups), &ngroupsC) + if rv == -1 { + return false, fmt.Errorf("user: membership of %s in %s: %s", u.Username, g.Name, syscall.Errno(rv)) + } + + ngroups = int(ngroupsC) + for i := 0; i < ngroups; i++ { + gid := C.group_at(C.int(i), (*C.gid_t)(groups)) + if g.Gid == strconv.Itoa(int(gid)) { + return true, nil + } + } + return false, nil +} + +func groupMembers(g *Group) ([]string, error) { + var members []string + gid, err := strconv.Atoi(g.Gid) + if err != nil { + return nil, err + } + + _, err = lookupUnixGroup(gid, "", false, func(grp *C.struct_group) *Group { + cmem := grp.gr_mem + for *cmem != nil { + members = append(members, C.GoString(*cmem)) + cmem = C.next_member(cmem) + } + return g + }) + if err != nil { + return nil, err + } + + return members, nil +} + +func allocBuffer(bufType int) (unsafe.Pointer, C.long, error) { + var bufSize C.long + if runtime.GOOS == "freebsd" { + // FreeBSD doesn't have _SC_GETPW_R_SIZE_MAX + // or SC_GETGR_R_SIZE_MAX and just returns -1. + // So just use the same size that Linux returns + bufSize = 1024 + } else { + var size C.int + var constName string + switch bufType { + case userBuffer: + size = C._SC_GETPW_R_SIZE_MAX + constName = "_SC_GETPW_R_SIZE_MAX" + case groupBuffer: + size = C._SC_GETGR_R_SIZE_MAX + constName = "_SC_GETGR_R_SIZE_MAX" + } + bufSize = C.sysconf(size) + if bufSize <= 0 || bufSize > 1<<20 { + return nil, bufSize, fmt.Errorf("user: unreasonable %s of %d", constName, bufSize) + } + } + return C.malloc(C.size_t(bufSize)), bufSize, nil +} diff --git a/group/user.go b/group/user.go index e8680fe5..9a7b5c16 100644 --- a/group/user.go +++ b/group/user.go @@ -26,6 +26,12 @@ type User struct { HomeDir string } +// Group represents a group database entry. +type Group struct { + Gid string // group id + Name string // group name +} + // UnknownUserIdError is returned by LookupId when // a user cannot be found. type UnknownUserIdError int @@ -41,3 +47,19 @@ type UnknownUserError string func (e UnknownUserError) Error() string { return "user: unknown user " + string(e) } + +// UnknownGroupIdError is returned by LookupGroupId when +// a group cannot be found. +type UnknownGroupIdError int + +func (e UnknownGroupIdError) Error() string { + return "group: unknown groupid " + strconv.Itoa(int(e)) +} + +// UnknownGroupError is returned by LookupGroup when +// a group cannot be found. +type UnknownGroupError string + +func (e UnknownGroupError) Error() string { + return "group: unknown group " + string(e) +} diff --git a/group/user_test.go b/group/user_test.go index 9d9420e8..055b5ec4 100644 --- a/group/user_test.go +++ b/group/user_test.go @@ -87,3 +87,52 @@ func TestLookupId(t *testing.T) { } compare(t, want, got) } + +func compareGroup(t *testing.T, want, got *Group) { + if want.Gid != got.Gid { + t.Errorf("got Gid=%q; want %q", got.Gid, want.Gid) + } + if want.Name != got.Name { + t.Errorf("got Name=%q; want %q", got.Name, want.Name) + } +} + +func TestLookupGroup(t *testing.T) { + check(t) + + // Test LookupGroupId on the current user + want, err := CurrentGroup() + if err != nil { + t.Fatalf("CurrentGroup: %v", err) + } + got, err := LookupGroupId(want.Gid) + if err != nil { + t.Fatalf("LookupGroupId: %v", err) + } + compareGroup(t, want, got) + + members, err := got.Members() + if err != nil { + t.Fatalf("Members: %v", err) + } + for _, user := range members { + u, err := Lookup(user) + if err != nil { + t.Errorf("expected a valid group member; user=%v, err=%v", user, err) + } + isMember, err := u.In(got) + if err != nil { + t.Fatalf("u.In: %v", err) + } + if !isMember { + t.Errorf("expected user to be group member; user=%v, group=%v, err=%v", user, got.Name, err) + } + } + + // Test Lookup by groupname, using the groupname from LookupId + g, err := LookupGroup(got.Name) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + compareGroup(t, got, g) +} From 45ad5b7ca3c94e12f404da612ea35920c112e063 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 20:01:29 -0700 Subject: [PATCH 21/41] neuter OSX 'staff' group complaint, it makes no sense to me --- group/lookup_unix.go | 2 +- group/user_test.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/group/lookup_unix.go b/group/lookup_unix.go index d047efc6..c2f970cb 100644 --- a/group/lookup_unix.go +++ b/group/lookup_unix.go @@ -32,7 +32,7 @@ static int mygetgrgid_r(int gid, struct group *grp, char *buf, size_t buflen, struct group **result) { return getgrgid_r(gid, grp, buf, buflen, result); } - + static int mygetgrouplist(const char *user, gid_t group, gid_t *groups, int *ngroups) { return getgrouplist(user, group, (void *)groups, ngroups); diff --git a/group/user_test.go b/group/user_test.go index 055b5ec4..f34db9bd 100644 --- a/group/user_test.go +++ b/group/user_test.go @@ -125,7 +125,11 @@ func TestLookupGroup(t *testing.T) { t.Fatalf("u.In: %v", err) } if !isMember { - t.Errorf("expected user to be group member; user=%v, group=%v, err=%v", user, got.Name, err) + if runtime.GOOS == "darwin" && got.Name == "staff" { + // staff group on OSX is strange and I don't understand it + } else { + t.Errorf("expected user to be group member; user=%v, group=%v, err=%v", user, got.Name, err) + } } } From c7b2d976eafdb485ca94c66127f5ffeb4e4d2590 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 20:57:59 -0700 Subject: [PATCH 22/41] move group to user such that import matches --- {group => user}/lookup.go | 0 {group => user}/lookup_plan9.go | 0 {group => user}/lookup_stubs.go | 0 {group => user}/lookup_unix.go | 0 {group => user}/lookup_windows.go | 0 {group => user}/user.go | 0 {group => user}/user_test.go | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename {group => user}/lookup.go (100%) rename {group => user}/lookup_plan9.go (100%) rename {group => user}/lookup_stubs.go (100%) rename {group => user}/lookup_unix.go (100%) rename {group => user}/lookup_windows.go (100%) rename {group => user}/user.go (100%) rename {group => user}/user_test.go (100%) diff --git a/group/lookup.go b/user/lookup.go similarity index 100% rename from group/lookup.go rename to user/lookup.go diff --git a/group/lookup_plan9.go b/user/lookup_plan9.go similarity index 100% rename from group/lookup_plan9.go rename to user/lookup_plan9.go diff --git a/group/lookup_stubs.go b/user/lookup_stubs.go similarity index 100% rename from group/lookup_stubs.go rename to user/lookup_stubs.go diff --git a/group/lookup_unix.go b/user/lookup_unix.go similarity index 100% rename from group/lookup_unix.go rename to user/lookup_unix.go diff --git a/group/lookup_windows.go b/user/lookup_windows.go similarity index 100% rename from group/lookup_windows.go rename to user/lookup_windows.go diff --git a/group/user.go b/user/user.go similarity index 100% rename from group/user.go rename to user/user.go diff --git a/group/user_test.go b/user/user_test.go similarity index 100% rename from group/user_test.go rename to user/user_test.go From c36e806e57fda14e8fba1b3f3950cfa3fa5ab136 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 21:37:33 -0700 Subject: [PATCH 23/41] runLs output matching openssh for some stuff --- server.go | 3 +- server_integration_test.go | 37 +++++++-- server_stubs.go | 7 ++ server_unix.go | 149 +++++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 10 deletions(-) create mode 100644 server_stubs.go create mode 100644 server_unix.go diff --git a/server.go b/server.go index 3a01d264..dcd00a51 100644 --- a/server.go +++ b/server.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "os" - "path" "path/filepath" "sync" "syscall" @@ -389,7 +388,7 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { for _, dirent := range dirents { ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{ dirent.Name(), - path.Join(dirname, dirent.Name()), + runLs(dirname, dirent), []interface{}{dirent}, }) } diff --git a/server_integration_test.go b/server_integration_test.go index 689398e9..a1662fa9 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -358,7 +358,7 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s } func runSftpClient(script string, path string, host string, port int) (string, error) { - cmd := exec.Command(*testSftpClientBin, "-vvvv", "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) + cmd := exec.Command(*testSftpClientBin /*"-vvvv",*/, "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) stdout := &bytes.Buffer{} cmd.Stdin = bytes.NewBufferString(script) cmd.Stdout = stdout @@ -377,11 +377,13 @@ func TestServerCompareSubsystems(t *testing.T) { defer listenerOp.Close() script := ` -#ls / -#ls -l / -#ls /dev/ -#ls -l /dev/ -ls -l /tmp/shit/ +ls / +ls -l / +ls /dev/ +ls -l /dev/ +ls -l /etc/ +ls -l /bin/ +ls -l /usr/bin/ ` outputGo, err := runSftpClient(script, "/", hostGo, portGo) if err != nil { @@ -394,7 +396,26 @@ ls -l /tmp/shit/ } if outputGo != outputOp { - t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", outputGo, outputOp) + diffOffsetLine := 0 + diffOffsetNextLine := 0 + bad := false + for i := 0; i < len(outputGo) && i < len(outputOp); i++ { + if outputGo[i] != outputOp[i] { + bad = true + } else if outputGo[i] == '\n' { + if !bad { + diffOffsetLine = i + diffOffsetNextLine = i + } else { + diffOffsetNextLine = i + break + } + } + } + + t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", + outputGo[diffOffsetLine:diffOffsetNextLine], + outputOp[diffOffsetLine:diffOffsetNextLine]) } - t.Logf("go output:\n%v\nopenssh output:\n%v\n", outputGo, outputOp) + //t.Logf("go output:\n%v\nopenssh output:\n%v\n", outputGo, outputOp) } diff --git a/server_stubs.go b/server_stubs.go new file mode 100644 index 00000000..8e58c5fe --- /dev/null +++ b/server_stubs.go @@ -0,0 +1,7 @@ +// +build !cgo,!plan9 android + +package sftp + +func runLs(dirname string, dirent os.FileInfo) string { + return path.Join(dirname, dirent.Name()) +} diff --git a/server_unix.go b/server_unix.go new file mode 100644 index 00000000..03f62e57 --- /dev/null +++ b/server_unix.go @@ -0,0 +1,149 @@ +// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris +// +build cgo + +package sftp + +import ( + "fmt" + "os" + "path" + "syscall" + "time" + + "github.com/ScriptRock/sftp/user" +) + +func runLsTypeWord(dirent os.FileInfo) string { + // find first character, the type char + // b Block special file. + // c Character special file. + // d Directory. + // l Symbolic link. + // s Socket link. + // p FIFO. + // - Regular file. + tc := '-' + mode := dirent.Mode() + if (mode & os.ModeDir) != 0 { + tc = 'd' + } else if (mode & os.ModeDevice) != 0 { + tc = 'b' + if (mode & os.ModeCharDevice) != 0 { + tc = 'c' + } + } else if (mode & os.ModeSymlink) != 0 { + tc = 'l' + } else if (mode & os.ModeSocket) != 0 { + tc = 's' + } else if (mode & os.ModeNamedPipe) != 0 { + tc = 'p' + } + + // owner + orc := '-' + if (mode & 0400) != 0 { + orc = 'r' + } + owc := '-' + if (mode & 0200) != 0 { + owc = 'w' + } + oxc := '-' + ox := (mode & 0100) != 0 + setuid := (mode & os.ModeSetuid) != 0 + if ox && setuid { + oxc = 's' + } else if setuid { + oxc = 'S' + } else if ox { + oxc = 'x' + } + + // group + grc := '-' + if (mode & 040) != 0 { + grc = 'r' + } + gwc := '-' + if (mode & 020) != 0 { + gwc = 'w' + } + gxc := '-' + gx := (mode & 010) != 0 + setgid := (mode & os.ModeSetgid) != 0 + if gx && setgid { + gxc = 's' + } else if setgid { + gxc = 'S' + } else if gx { + gxc = 'x' + } + + // all / others + arc := '-' + if (mode & 04) != 0 { + arc = 'r' + } + awc := '-' + if (mode & 02) != 0 { + awc = 'w' + } + axc := '-' + ax := (mode & 01) != 0 + sticky := (mode & os.ModeSticky) != 0 + if ax && sticky { + axc = 't' + } else if sticky { + axc = 'T' + } else if ax { + axc = 'x' + } + + return fmt.Sprintf("%c%c%c%c%c%c%c%c%c%c", tc, orc, owc, oxc, grc, gwc, gxc, arc, awc, axc) +} + +func runLsStatt(dirname string, dirent os.FileInfo, statt *syscall.Stat_t) string { + // example from openssh sftp server: + // crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd + // format: + // {directory / char device / etc}{rwxrwxrwx} {number of links} owner group size month day [time (this year) | year (otherwise)] name + + typeword := runLsTypeWord(dirent) + numLinks := statt.Nlink + uid := statt.Uid + gid := statt.Gid + username := fmt.Sprintf("%d", uid) + if usr, err := user.LookupId(username); err == nil { + username = usr.Username + } + groupname := fmt.Sprintf("%d", gid) + if grp, err := user.LookupGroupId(groupname); err == nil { + groupname = grp.Name + } + + mtime := dirent.ModTime() + monthStr := mtime.Month().String()[0:3] + day := mtime.Day() + year := mtime.Year() + now := time.Now() + isOld := mtime.Before(now.Add(-1 * time.Hour * 24 * 182)) + yearOrTime := fmt.Sprintf("%02d:%02d", mtime.Hour(), mtime.Minute()) + if isOld { + yearOrTime = fmt.Sprintf("%d", year) + } + + return fmt.Sprintf("%s %4d %-8s %-8s %8d %s %2d %5s %s", typeword, numLinks, username, groupname, dirent.Size(), monthStr, day, yearOrTime, dirent.Name()) +} + +// ls -l style output for a file, which is in the 'long output' section of a readdir response packet +// this is a very simple (lazy) implementation, just enough to look almost like openssh in a few basic cases +func runLs(dirname string, dirent os.FileInfo) string { + dsys := dirent.Sys() + if dsys == nil { + } else if statt, ok := dsys.(*syscall.Stat_t); !ok { + } else { + return runLsStatt(dirname, dirent, statt) + } + + return path.Join(dirname, dirent.Name()) +} From d9371ace4647326aacf96016b540a9961ce18f40 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 21:43:59 -0700 Subject: [PATCH 24/41] skipping integration tests for server unless specified --- server_integration_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server_integration_test.go b/server_integration_test.go index a1662fa9..148ecf02 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -320,6 +320,10 @@ Actual unit tests // starts an ssh server to test. returns: host string and port func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) { + if !*testIntegration { + t.Skip("skipping intergration test") + } + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) From 0aec5ce5ec398839d49d04aa619d0825f1c603e0 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 21:54:42 -0700 Subject: [PATCH 25/41] use merge-to target branches, not ScriptRock ones --- client.go | 4 ++-- example_test.go | 6 ++++-- server_integration_test.go | 4 +++- server_unix.go | 3 ++- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 58e89e0a..db8a142d 100644 --- a/client.go +++ b/client.go @@ -15,8 +15,8 @@ import ( "github.com/kr/fs" - //"golang.org/x/crypto/ssh" - "github.com/ScriptRock/crypto/ssh" + "golang.org/x/crypto/ssh" + //"github.com/ScriptRock/crypto/ssh" ) // MaxPacket sets the maximum size of the payload. diff --git a/example_test.go b/example_test.go index e38ecaf4..0bd76aab 100644 --- a/example_test.go +++ b/example_test.go @@ -6,8 +6,10 @@ import ( "os" "os/exec" - "github.com/ScriptRock/crypto/ssh" - "github.com/ScriptRock/sftp" + //"github.com/ScriptRock/crypto/ssh" + //"github.com/ScriptRock/sftp" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" ) func Example(conn *ssh.Client) { diff --git a/server_integration_test.go b/server_integration_test.go index 148ecf02..9e209ceb 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -2,6 +2,7 @@ package sftp // sftp server integration tests // enable with -integration +// example invokation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/ScriptRock/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems import ( "bytes" @@ -15,7 +16,8 @@ import ( "strings" "testing" - "github.com/ScriptRock/crypto/ssh" + //"github.com/ScriptRock/crypto/ssh" + "golang.org/x/crypto/ssh" ) var testSftpClientBin = flag.String("sftp_client", "/usr/bin/sftp", "location of the sftp client binary") diff --git a/server_unix.go b/server_unix.go index 03f62e57..fe512490 100644 --- a/server_unix.go +++ b/server_unix.go @@ -10,7 +10,8 @@ import ( "syscall" "time" - "github.com/ScriptRock/sftp/user" + //"github.com/ScriptRock/sftp/user" + "github.com/pkg/sftp/user" ) func runLsTypeWord(dirent os.FileInfo) string { From 421e8919ba86d5d111ca94b10123153d3a5da8f1 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 22:01:51 -0700 Subject: [PATCH 26/41] remove unused flags --- attrs.go | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/attrs.go b/attrs.go index f7c69cc9..9ac4cda8 100644 --- a/attrs.go +++ b/attrs.go @@ -13,20 +13,7 @@ const ( ssh_FILEXFER_ATTR_SIZE = 0x00000001 ssh_FILEXFER_ATTR_UIDGID = 0x00000002 ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 - ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 // protocol version 2 - ssh_FILEXFER_ATTR_ACCESSTIME = 0x00000008 // protocol version 3 onwards - ssh_FILEXFER_ATTR_CREATETIME = 0x00000010 - ssh_FILEXFER_ATTR_MODIFYTIME = 0x00000020 - ssh_FILEXFER_ATTR_ACL = 0x00000040 - ssh_FILEXFER_ATTR_OWNERGROUP = 0x00000080 - ssh_FILEXFER_ATTR_SUBSECOND_TIMES = 0x00000100 - ssh_FILEXFER_ATTR_BITS = 0x00000200 - ssh_FILEXFER_ATTR_ALLOCATION_SIZE = 0x00000400 - ssh_FILEXFER_ATTR_TEXT_HINT = 0x00000800 - ssh_FILEXFER_ATTR_MIME_TYPE = 0x00001000 - ssh_FILEXFER_ATTR_LINK_COUNT = 0x00002000 - ssh_FILEXFER_ATTR_UNTRANSLATED_NAME = 0x00004000 - ssh_FILEXFER_ATTR_CTIME = 0x00008000 + ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 ) From ec909a249dbf753a23ac01bb0e682c6a5cff255b Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 22:02:47 -0700 Subject: [PATCH 27/41] gofmt --- attrs.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/attrs.go b/attrs.go index 9ac4cda8..a9f241f9 100644 --- a/attrs.go +++ b/attrs.go @@ -10,11 +10,11 @@ import ( ) const ( - ssh_FILEXFER_ATTR_SIZE = 0x00000001 - ssh_FILEXFER_ATTR_UIDGID = 0x00000002 - ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 - ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 - ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 + ssh_FILEXFER_ATTR_SIZE = 0x00000001 + ssh_FILEXFER_ATTR_UIDGID = 0x00000002 + ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 + ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 + ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 ) // fileInfo is an artificial type designed to satisfy os.FileInfo. From af012f1b56cb6d046662c5af8fa06f8e118dda70 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 22:21:45 -0700 Subject: [PATCH 28/41] fix import path for server_standalone --- server_standalone/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server_standalone/main.go b/server_standalone/main.go index 7bdfcefe..bfd33dc2 100644 --- a/server_standalone/main.go +++ b/server_standalone/main.go @@ -8,7 +8,8 @@ import ( "io/ioutil" "os" - "github.com/ScriptRock/sftp" + //"github.com/ScriptRock/sftp" + "github.com/pkg/sftp" ) func main() { From 7bb2083ca932ae61047c87d99d683ad14e8d96e2 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Sun, 6 Sep 2015 23:55:15 -0700 Subject: [PATCH 29/41] Address review comments; about to change decodePacket() --- attrs.go | 5 ++--- client.go | 1 - client_integration_test.go | 6 +++--- example_test.go | 2 -- packet.go | 14 +++++++++----- server.go | 28 +++++++++++++++++----------- server_integration_test.go | 5 ++--- server_standalone/main.go | 7 +++++-- server_unix.go | 1 - 9 files changed, 38 insertions(+), 31 deletions(-) diff --git a/attrs.go b/attrs.go index a9f241f9..a1f43cc0 100644 --- a/attrs.go +++ b/attrs.go @@ -127,10 +127,9 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte { mtime := uint32(fi.ModTime().Unix()) atime := mtime - flags := ssh_FILEXFER_ATTR_SIZE | + var flags uint32 = ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_PERMISSIONS | - ssh_FILEXFER_ATTR_ACMODTIME | - uint32(0) + ssh_FILEXFER_ATTR_ACMODTIME if statt, ok := fi.Sys().(*syscall.Stat_t); ok { flags |= ssh_FILEXFER_ATTR_UIDGID diff --git a/client.go b/client.go index db8a142d..6479f8bd 100644 --- a/client.go +++ b/client.go @@ -16,7 +16,6 @@ import ( "github.com/kr/fs" "golang.org/x/crypto/ssh" - //"github.com/ScriptRock/crypto/ssh" ) // MaxPacket sets the maximum size of the payload. diff --git a/client_integration_test.go b/client_integration_test.go index b707726e..e3c8c5bb 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -32,8 +32,6 @@ const ( debuglevel = "ERROR" // set to "DEBUG" for debugging ) -var spaceRegex = regexp.MustCompile(`\s+`) - var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") var testSftp = flag.String("sftp", "/usr/lib/openssh/sftp-server", "location of the sftp server binary") @@ -94,7 +92,7 @@ func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, if err != nil { t.Fatal(err) } - go server.Run() + go server.Serve() var ctx io.WriteCloser = txPipeWr if delay > NO_DELAY { @@ -578,6 +576,8 @@ func TestClientChown(t *testing.T) { t.Fatal(err) } + spaceRegex := regexp.MustCompile(`\s+`) + beforeWords := spaceRegex.Split(string(before), -1) if beforeWords[2] != "0" { t.Fatalf("bad previous user? should be root") diff --git a/example_test.go b/example_test.go index 0bd76aab..11dc1e12 100644 --- a/example_test.go +++ b/example_test.go @@ -6,8 +6,6 @@ import ( "os" "os/exec" - //"github.com/ScriptRock/crypto/ssh" - //"github.com/ScriptRock/sftp" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) diff --git a/packet.go b/packet.go index 4cfe75b4..1045e0fb 100644 --- a/packet.go +++ b/packet.go @@ -9,7 +9,10 @@ import ( ) var ( - shortPacketError = fmt.Errorf("packet too short") + shortPacketError = fmt.Errorf("packet too short") +) + +const ( debugDumpTxPacket = false debugDumpRxPacket = false debugDumpTxPacketBytes = false @@ -246,16 +249,17 @@ func marshalIdString(packetType byte, id uint32, str string) ([]byte, error) { return b, nil } -func unmarshalIdString(b []byte, id *uint32, str *string) (err error) { +func unmarshalIdString(b []byte, id *uint32, str *string) (error) { + var err error = nil *id, b, err = unmarshalUint32Safe(b) if err != nil { - return + return err } *str, b, err = unmarshalStringSafe(b) if err != nil { - return + return err } - return + return nil } type sshFxpReaddirPacket struct { diff --git a/server.go b/server.go index dcd00a51..d09256eb 100644 --- a/server.go +++ b/server.go @@ -12,6 +12,15 @@ import ( "syscall" ) +const ( + sftpServerWorkerCount = 8 +) + +// Server is an SSH File Transfer Protocol (sftp) server. +// This is intended to provide the sftp subsystem to an ssh server daemon. +// This implementation currently supports most of sftp server protocol version 3, +// as specified at http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 +// Currently unimplemented are SETSTAT, FSETSTAT, SYMLINK, possibly others; patches welcome. type Server struct { in io.Reader out io.WriteCloser @@ -26,7 +35,7 @@ type Server struct { openFilesLock *sync.RWMutex handleCount int maxTxPacket uint32 - WorkerCount int + workerCount int } func (svr *Server) nextHandle(f *os.File) string { @@ -62,7 +71,8 @@ type serverRespondablePacket interface { } // Creates a new server instance around the provided streams. -// A subsequent call to Run() is required. +// Various debug output will be written to debugStream, with verbosity set by debugLevel +// A subsequent call to Serve() is required. func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLevel int, readOnly bool, rootDir string) (*Server, error) { if rootDir == "" { if wd, err := os.Getwd(); err != nil { @@ -71,7 +81,6 @@ func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLev rootDir = wd } } - workerCount := 8 return &Server{ in: in, out: out, @@ -80,11 +89,11 @@ func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLev debugLevel: debugLevel, readOnly: readOnly, rootDir: rootDir, - pktChan: make(chan rxPacket, workerCount), + pktChan: make(chan rxPacket, sftpServerWorkerCount), openFiles: map[string]*os.File{}, openFilesLock: &sync.RWMutex{}, maxTxPacket: 1 << 15, - WorkerCount: workerCount, + workerCount: sftpServerWorkerCount, }, nil } @@ -127,16 +136,13 @@ func (svr *Server) sftpServerWorker(doneChan chan error) { } // Run this server until the streams stop or until the subsystem is stopped -func (svr *Server) Run() error { - if svr.WorkerCount <= 0 { - return fmt.Errorf("sftp server requires > 0 workers") - } +func (svr *Server) Serve() error { go svr.rxPackets() doneChan := make(chan error) - for i := 0; i < svr.WorkerCount; i++ { + for i := 0; i < svr.workerCount; i++ { go svr.sftpServerWorker(doneChan) } - for i := 0; i < svr.WorkerCount; i++ { + for i := 0; i < svr.workerCount; i++ { if err := <-doneChan; err != nil { // abort early and shut down the session on un-decodable packets break diff --git a/server_integration_test.go b/server_integration_test.go index 9e209ceb..0a2f484a 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -2,7 +2,7 @@ package sftp // sftp server integration tests // enable with -integration -// example invokation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/ScriptRock/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems +// example invokation (darwin): gofmt -w `find . -name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/pkg/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems import ( "bytes" @@ -16,7 +16,6 @@ import ( "strings" "testing" - //"github.com/ScriptRock/crypto/ssh" "golang.org/x/crypto/ssh" ) @@ -298,7 +297,7 @@ func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { } // wait for the session to close - runErr := sftpServer.Run() + runErr := sftpServer.Serve() exitStatus := uint32(1) if runErr == nil { exitStatus = uint32(0) diff --git a/server_standalone/main.go b/server_standalone/main.go index bfd33dc2..eb89a9e6 100644 --- a/server_standalone/main.go +++ b/server_standalone/main.go @@ -7,8 +7,8 @@ import ( "flag" "io/ioutil" "os" + "fmt" - //"github.com/ScriptRock/sftp" "github.com/pkg/sftp" ) @@ -29,5 +29,8 @@ func main() { } svr, _ := sftp.NewServer(os.Stdin, os.Stdout, debugStream, debugLevel, readOnly, "") - svr.Run() + if err := svr.Serve(); err != nil { + fmt.Fprintf(debugStream, "sftp server completed with error: %v", err) + os.Exit(1) + } } diff --git a/server_unix.go b/server_unix.go index fe512490..37e739d6 100644 --- a/server_unix.go +++ b/server_unix.go @@ -10,7 +10,6 @@ import ( "syscall" "time" - //"github.com/ScriptRock/sftp/user" "github.com/pkg/sftp/user" ) From d80ae36051ff19db93b847ecad2cc7c798c0374b Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 01:05:16 -0700 Subject: [PATCH 30/41] rmdir and symlink packet handling --- client.go | 19 +++++++++++ client_integration_test.go | 24 +++++++++++++- packet.go | 35 +++++++++++++++++++- server.go | 18 +++++++++++ server_integration_test.go | 17 +++++----- server_standalone/main.go | 2 +- server_test.go | 65 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 169 insertions(+), 11 deletions(-) create mode 100644 server_test.go diff --git a/client.go b/client.go index 6479f8bd..0ad90470 100644 --- a/client.go +++ b/client.go @@ -336,6 +336,25 @@ func (c *Client) ReadLink(p string) (string, error) { } } +// Symlink creates a symbolic link at 'newname', pointing at target 'oldname' +func (c *Client) Symlink(oldname, newname string) error { + id := c.nextId() + typ, data, err := c.sendRequest(sshFxpSymlinkPacket{ + Id: id, + Linkpath: newname, + Targetpath: oldname, + }) + if err != nil { + return err + } + switch typ { + case ssh_FXP_STATUS: + return okOrErr(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + // setstat is a convience wrapper to allow for changing of various parts of the file descriptor. func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { id := c.nextId() diff --git a/client_integration_test.go b/client_integration_test.go index e3c8c5bb..84c356e2 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -495,8 +495,30 @@ func TestClientReadLink(t *testing.T) { if err := os.Symlink(f.Name(), f2); err != nil { t.Fatal(err) } - if _, err := sftp.ReadLink(f2); err != nil { + if rl, err := sftp.ReadLink(f2); err != nil { t.Fatal(err) + } else if rl != f.Name() { + t.Fatalf("unexpected link target: %v, not %v", rl, f.Name()) + } +} + +func TestClientSymlink(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NO_DELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest") + if err != nil { + t.Fatal(err) + } + f2 := f.Name() + ".sym" + if err := sftp.Symlink(f.Name(), f2); err != nil { + t.Fatal(err) + } + if rl, err := sftp.ReadLink(f2); err != nil { + t.Fatal(err) + } else if rl != f.Name() { + t.Fatalf("unexpected link target: %v, not %v", rl, f.Name()) } } diff --git a/packet.go b/packet.go index 1045e0fb..14915d12 100644 --- a/packet.go +++ b/packet.go @@ -249,7 +249,7 @@ func marshalIdString(packetType byte, id uint32, str string) ([]byte, error) { return b, nil } -func unmarshalIdString(b []byte, id *uint32, str *string) (error) { +func unmarshalIdString(b []byte, id *uint32, str *string) error { var err error = nil *id, b, err = unmarshalUint32Safe(b) if err != nil { @@ -382,6 +382,39 @@ func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error { return unmarshalIdString(b, &p.Id, &p.Path) } +type sshFxpSymlinkPacket struct { + Id uint32 + Targetpath string + Linkpath string +} + +func (p sshFxpSymlinkPacket) id() uint32 { return p.Id } + +func (p sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Targetpath) + + 4 + len(p.Linkpath) + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_SYMLINK) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Targetpath) + b = marshalString(b, p.Linkpath) + return b, nil +} + +func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error { + var err error = nil + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Targetpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Linkpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + type sshFxpReadlinkPacket struct { Id uint32 Path string diff --git a/server.go b/server.go index d09256eb..dd452592 100644 --- a/server.go +++ b/server.go @@ -182,6 +182,7 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_MKDIR: pkt = &sshFxpMkdirPacket{} case ssh_FXP_RMDIR: + pkt = &sshFxpRmdirPacket{} case ssh_FXP_REALPATH: pkt = &sshFxpRealpathPacket{} case ssh_FXP_STAT: @@ -191,6 +192,7 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_READLINK: pkt = &sshFxpReadlinkPacket{} case ssh_FXP_SYMLINK: + pkt = &sshFxpSymlinkPacket{} case ssh_FXP_STATUS: case ssh_FXP_HANDLE: case ssh_FXP_DATA: @@ -262,6 +264,14 @@ func (p sshFxpMkdirPacket) respond(svr *Server) error { return svr.sendPacket(statusFromError(p.Id, err)) } +func (p sshFxpRmdirPacket) respond(svr *Server) error { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } + err := os.Remove(p.Path) + return svr.sendPacket(statusFromError(p.Id, err)) +} + func (p sshFxpRemovePacket) respond(svr *Server) error { if svr.readOnly { return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) @@ -278,6 +288,14 @@ func (p sshFxpRenamePacket) respond(svr *Server) error { return svr.sendPacket(statusFromError(p.Id, err)) } +func (p sshFxpSymlinkPacket) respond(svr *Server) error { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } + err := os.Symlink(p.Targetpath, p.Linkpath) + return svr.sendPacket(statusFromError(p.Id, err)) +} + var emptyFileStat = []interface{}{uint32(0)} func (p sshFxpReadlinkPacket) respond(svr *Server) error { diff --git a/server_integration_test.go b/server_integration_test.go index 0a2f484a..a2d2ef67 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "flag" "fmt" + "io/ioutil" "net" "os" "os/exec" @@ -20,9 +21,9 @@ import ( ) var testSftpClientBin = flag.String("sftp_client", "/usr/bin/sftp", "location of the sftp client binary") -var sshServerDebugStream = os.Stdout // ioutil.Discard -var sftpServerDebugStream = os.Stdout // ioutil.Discard -var sftpClientDebugStream = os.Stdout // ioutil.Discard +var sshServerDebugStream = ioutil.Discard +var sftpServerDebugStream = ioutil.Discard +var sftpClientDebugStream = ioutil.Discard const ( GOLANG_SFTP = true @@ -321,10 +322,6 @@ Actual unit tests // starts an ssh server to test. returns: host string and port func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) { - if !*testIntegration { - t.Skip("skipping intergration test") - } - listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -363,7 +360,7 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s } func runSftpClient(script string, path string, host string, port int) (string, error) { - cmd := exec.Command(*testSftpClientBin /*"-vvvv",*/, "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) + cmd := exec.Command(*testSftpClientBin, "-vvvv", "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) stdout := &bytes.Buffer{} cmd.Stdin = bytes.NewBufferString(script) cmd.Stdout = stdout @@ -376,6 +373,10 @@ func runSftpClient(script string, path string, host string, port int) (string, e } func TestServerCompareSubsystems(t *testing.T) { + if !*testIntegration { + t.Skip("skipping intergration test") + } + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) listenerOp, hostOp, portOp := testServer(t, OPENSSH_SFTP, READONLY) defer listenerGo.Close() diff --git a/server_standalone/main.go b/server_standalone/main.go index eb89a9e6..0b510a5c 100644 --- a/server_standalone/main.go +++ b/server_standalone/main.go @@ -5,9 +5,9 @@ package main import ( "flag" + "fmt" "io/ioutil" "os" - "fmt" "github.com/pkg/sftp" ) diff --git a/server_test.go b/server_test.go new file mode 100644 index 00000000..c6b8d969 --- /dev/null +++ b/server_test.go @@ -0,0 +1,65 @@ +package sftp + +import ( + "encoding/hex" + "math/rand" + "os" + "testing" + "time" +) + +func randName() string { + r := rand.New(rand.NewSource(time.Now().Unix())) + data := make([]byte, 16) + for i := 0; i < 16; i++ { + data[i] = byte(r.Uint32()) + } + return "sftp." + hex.EncodeToString(data) +} + +func TestServerMkdirRmdir(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpDir := "/tmp/" + randName() + defer os.RemoveAll(tmpDir) + + // mkdir remote + if _, err := runSftpClient("mkdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + // directory should now exist + if _, err := os.Stat(tmpDir); err != nil { + t.Fatal(err) + } + + // now remove the directory + if _, err := runSftpClient("rmdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpDir); err == nil { + t.Fatal("should have error after deleting the directory") + } +} + +func TestServerSymlink(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + link := "/tmp/" + randName() + //defer os.RemoveAll(link) + + // now create a symbolic link within the new directory + if output, err := runSftpClient("symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { + t.Fatalf("failed: %v %v", err, string(output)) + } + + // symlink should now exist + if stat, err := os.Lstat(link); err != nil { + t.Fatal(err) + } else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink { + t.Fatalf("is not a symlink: %v", stat.Mode()) + } +} From 90f1d88de0839edfac17eda84f233ffa3e145f02 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 01:06:16 -0700 Subject: [PATCH 31/41] add test cleanup back in --- server_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server_test.go b/server_test.go index c6b8d969..1cfb4b9d 100644 --- a/server_test.go +++ b/server_test.go @@ -49,7 +49,7 @@ func TestServerSymlink(t *testing.T) { defer listenerGo.Close() link := "/tmp/" + randName() - //defer os.RemoveAll(link) + defer os.RemoveAll(link) // now create a symbolic link within the new directory if output, err := runSftpClient("symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { From cbd08aeb80fe4a82c95613270f4f0b6b0511fb24 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 02:13:07 -0700 Subject: [PATCH 32/41] implement setstat & fsetstat --- packet.go | 50 +++++++++++++++++++++- server.go | 124 +++++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 158 insertions(+), 16 deletions(-) diff --git a/packet.go b/packet.go index 14915d12..148ee305 100644 --- a/packet.go +++ b/packet.go @@ -668,7 +668,15 @@ type sshFxpSetstatPacket struct { Attrs interface{} } -func (p sshFxpSetstatPacket) id() uint32 { return p.Id } +type sshFxpFsetstatPacket struct { + Id uint32 + Handle string + Flags uint32 + Attrs interface{} +} + +func (p sshFxpSetstatPacket) id() uint32 { return p.Id } +func (p sshFxpFsetstatPacket) id() uint32 { return p.Id } func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { l := 1 + 4 + // type(byte) + uint32 @@ -684,6 +692,46 @@ func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { return b, nil } +func (p sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { + l := 1 + 4 + // type(byte) + uint32 + 4 + len(p.Handle) + + 4 // uint32 + uint64 + + b := make([]byte, 0, l) + b = append(b, ssh_FXP_FSETSTAT) + b = marshalUint32(b, p.Id) + b = marshalString(b, p.Handle) + b = marshalUint32(b, p.Flags) + b = marshal(b, p.Attrs) + return b, nil +} + +func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { + var err error = nil + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { + var err error = nil + if p.Id, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + type sshFxpHandlePacket struct { Id uint32 Handle string diff --git a/server.go b/server.go index dd452592..6b64616f 100644 --- a/server.go +++ b/server.go @@ -10,6 +10,7 @@ import ( "path/filepath" "sync" "syscall" + "time" ) const ( @@ -20,7 +21,6 @@ const ( // This is intended to provide the sftp subsystem to an ssh server daemon. // This implementation currently supports most of sftp server protocol version 3, // as specified at http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 -// Currently unimplemented are SETSTAT, FSETSTAT, SYMLINK, possibly others; patches welcome. type Server struct { in io.Reader out io.WriteCloser @@ -31,19 +31,24 @@ type Server struct { rootDir string lastId uint32 pktChan chan rxPacket - openFiles map[string]*os.File + openFiles map[string]serverOpenFile openFilesLock *sync.RWMutex handleCount int maxTxPacket uint32 workerCount int } -func (svr *Server) nextHandle(f *os.File) string { +type serverOpenFile struct { + *os.File + path string +} + +func (svr *Server) nextHandle(path string, f *os.File) string { svr.openFilesLock.Lock() defer svr.openFilesLock.Unlock() svr.handleCount++ handle := fmt.Sprintf("%d", svr.handleCount) - svr.openFiles[handle] = f + svr.openFiles[handle] = serverOpenFile{f, path} return handle } @@ -58,7 +63,7 @@ func (svr *Server) closeHandle(handle string) error { } } -func (svr *Server) getHandle(handle string) (*os.File, bool) { +func (svr *Server) getHandle(handle string) (serverOpenFile, bool) { svr.openFilesLock.RLock() defer svr.openFilesLock.RUnlock() f, ok := svr.openFiles[handle] @@ -90,7 +95,7 @@ func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLev readOnly: readOnly, rootDir: rootDir, pktChan: make(chan rxPacket, sftpServerWorkerCount), - openFiles: map[string]*os.File{}, + openFiles: map[string]serverOpenFile{}, openFilesLock: &sync.RWMutex{}, maxTxPacket: 1 << 15, workerCount: sftpServerWorkerCount, @@ -160,7 +165,6 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable pkt = &sshFxInitPacket{} case ssh_FXP_LSTAT: pkt = &sshFxpLstatPacket{} - case ssh_FXP_VERSION: case ssh_FXP_OPEN: pkt = &sshFxpOpenPacket{} case ssh_FXP_CLOSE: @@ -172,7 +176,9 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable case ssh_FXP_FSTAT: pkt = &sshFxpFstatPacket{} case ssh_FXP_SETSTAT: + pkt = &sshFxpSetstatPacket{} case ssh_FXP_FSETSTAT: + pkt = &sshFxpFsetstatPacket{} case ssh_FXP_OPENDIR: pkt = &sshFxpOpendirPacket{} case ssh_FXP_READDIR: @@ -193,14 +199,8 @@ func (svr *Server) decodePacket(pktType fxp, pktBytes []byte) (serverRespondable pkt = &sshFxpReadlinkPacket{} case ssh_FXP_SYMLINK: pkt = &sshFxpSymlinkPacket{} - case ssh_FXP_STATUS: - case ssh_FXP_HANDLE: - case ssh_FXP_DATA: - case ssh_FXP_NAME: - case ssh_FXP_ATTRS: - case ssh_FXP_EXTENDED: - case ssh_FXP_EXTENDED_REPLY: default: + return nil, fmt.Errorf("unhandled packet type: %s", pktType.String()) } if pkt == nil { return nil, fmt.Errorf("unhandled packet type: %s", pktType.String()) @@ -355,7 +355,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { if f, err := os.OpenFile(p.Path, osFlags, 0644); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - handle := svr.nextHandle(f) + handle := svr.nextHandle(p.Path, f) return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) } } @@ -421,6 +421,100 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { } } +func (p sshFxpSetstatPacket) respond(svr *Server) error { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } else { + // additional unmarshalling is required for each possibility here + b := p.Attrs.([]byte) + var err error = nil + + debug("setstat name \"%s\"", p.Path) + if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 { + var size uint64 = 0 + if size, b, err = unmarshalUint64Safe(b); err == nil { + err = os.Truncate(p.Path, int64(size)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 { + var mode uint32 = 0 + if mode, b, err = unmarshalUint32Safe(b); err == nil { + err = os.Chmod(p.Path, os.FileMode(mode)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 { + var atime uint32 = 0 + var mtime uint32 = 0 + if atime, b, err = unmarshalUint32Safe(b); err != nil { + } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { + } else { + atimeT := time.Unix(int64(atime), 0) + mtimeT := time.Unix(int64(mtime), 0) + err = os.Chtimes(p.Path, atimeT, mtimeT) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 { + var uid uint32 = 0 + var gid uint32 = 0 + if uid, b, err = unmarshalUint32Safe(b); err != nil { + } else if gid, b, err = unmarshalUint32Safe(b); err != nil { + } else { + err = os.Chown(p.Path, int(uid), int(gid)) + } + } + + return svr.sendPacket(statusFromError(p.Id, err)) + } +} + +func (p sshFxpFsetstatPacket) respond(svr *Server) error { + if svr.readOnly { + return svr.sendPacket(statusFromError(p.Id, syscall.EPERM)) + } else if f, ok := svr.getHandle(p.Handle); !ok { + return svr.sendPacket(statusFromError(p.Id, syscall.EBADF)) + } else { + // additional unmarshalling is required for each possibility here + b := p.Attrs.([]byte) + var err error = nil + + debug("fsetstat name \"%s\"", f.path) + if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 { + var size uint64 = 0 + if size, b, err = unmarshalUint64Safe(b); err == nil { + err = f.Truncate(int64(size)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_PERMISSIONS) != 0 { + var mode uint32 = 0 + if mode, b, err = unmarshalUint32Safe(b); err == nil { + err = f.Chmod(os.FileMode(mode)) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_ACMODTIME) != 0 { + var atime uint32 = 0 + var mtime uint32 = 0 + if atime, b, err = unmarshalUint32Safe(b); err != nil { + } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { + } else { + atimeT := time.Unix(int64(atime), 0) + mtimeT := time.Unix(int64(mtime), 0) + err = os.Chtimes(f.path, atimeT, mtimeT) + } + } + if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 { + var uid uint32 = 0 + var gid uint32 = 0 + if uid, b, err = unmarshalUint32Safe(b); err != nil { + } else if gid, b, err = unmarshalUint32Safe(b); err != nil { + } else { + err = f.Chown(int(uid), int(gid)) + } + } + + return svr.sendPacket(statusFromError(p.Id, err)) + } +} + func errnoToSshErr(errno syscall.Errno) uint32 { if errno == 0 { return ssh_FX_OK From 20391fc5d1749bce033fa15ba140ee951892c067 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 02:20:58 -0700 Subject: [PATCH 33/41] skip server tests if the openssh sftp client binary is unavailable --- server_integration_test.go | 12 ++++++++---- server_test.go | 6 +++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/server_integration_test.go b/server_integration_test.go index a2d2ef67..7530c227 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -359,8 +359,12 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s return listener, host, port } -func runSftpClient(script string, path string, host string, port int) (string, error) { - cmd := exec.Command(*testSftpClientBin, "-vvvv", "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) +func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) { + // if sftp client binary is unavailable, skip test + if _, err := os.Stat(*testSftpClientBin); err != nil { + t.Skip("sftp client binary unavailable") + } + cmd := exec.Command(*testSftpClientBin /*"-vvvv",*/, "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path)) stdout := &bytes.Buffer{} cmd.Stdin = bytes.NewBufferString(script) cmd.Stdout = stdout @@ -391,12 +395,12 @@ ls -l /etc/ ls -l /bin/ ls -l /usr/bin/ ` - outputGo, err := runSftpClient(script, "/", hostGo, portGo) + outputGo, err := runSftpClient(t, script, "/", hostGo, portGo) if err != nil { t.Fatal(err) } - outputOp, err := runSftpClient(script, "/", hostOp, portOp) + outputOp, err := runSftpClient(t, script, "/", hostOp, portOp) if err != nil { t.Fatal(err) } diff --git a/server_test.go b/server_test.go index 1cfb4b9d..9599ad74 100644 --- a/server_test.go +++ b/server_test.go @@ -25,7 +25,7 @@ func TestServerMkdirRmdir(t *testing.T) { defer os.RemoveAll(tmpDir) // mkdir remote - if _, err := runSftpClient("mkdir "+tmpDir, "/", hostGo, portGo); err != nil { + if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil { t.Fatal(err) } @@ -35,7 +35,7 @@ func TestServerMkdirRmdir(t *testing.T) { } // now remove the directory - if _, err := runSftpClient("rmdir "+tmpDir, "/", hostGo, portGo); err != nil { + if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil { t.Fatal(err) } @@ -52,7 +52,7 @@ func TestServerSymlink(t *testing.T) { defer os.RemoveAll(link) // now create a symbolic link within the new directory - if output, err := runSftpClient("symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { + if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { t.Fatalf("failed: %v %v", err, string(output)) } From c69ab311abd8482ba655ac0dc8b0c0318240f78d Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 02:28:33 -0700 Subject: [PATCH 34/41] all tests invoking external binaries considered integration tests to satisfy wercker --- server.go | 1 - server_integration_test.go | 67 +++++++++++++++++++++++++++++++++++--- server_test.go | 65 ------------------------------------ 3 files changed, 62 insertions(+), 71 deletions(-) delete mode 100644 server_test.go diff --git a/server.go b/server.go index 6b64616f..984fc751 100644 --- a/server.go +++ b/server.go @@ -416,7 +416,6 @@ func (p sshFxpReaddirPacket) respond(svr *Server) error { []interface{}{dirent}, }) } - //debug("readdir respond %v", ret) return svr.sendPacket(ret) } } diff --git a/server_integration_test.go b/server_integration_test.go index 7530c227..bf129165 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -10,12 +10,14 @@ import ( "flag" "fmt" "io/ioutil" + "math/rand" "net" "os" "os/exec" "strconv" "strings" "testing" + "time" "golang.org/x/crypto/ssh" ) @@ -322,6 +324,10 @@ Actual unit tests // starts an ssh server to test. returns: host string and port func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) { + if !*testIntegration { + t.Skip("skipping intergration test") + } + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -377,10 +383,6 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i } func TestServerCompareSubsystems(t *testing.T) { - if !*testIntegration { - t.Skip("skipping intergration test") - } - listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) listenerOp, hostOp, portOp := testServer(t, OPENSSH_SFTP, READONLY) defer listenerGo.Close() @@ -427,5 +429,60 @@ ls -l /usr/bin/ outputGo[diffOffsetLine:diffOffsetNextLine], outputOp[diffOffsetLine:diffOffsetNextLine]) } - //t.Logf("go output:\n%v\nopenssh output:\n%v\n", outputGo, outputOp) +} + +func randName() string { + r := rand.New(rand.NewSource(time.Now().Unix())) + data := make([]byte, 16) + for i := 0; i < 16; i++ { + data[i] = byte(r.Uint32()) + } + return "sftp." + hex.EncodeToString(data) +} + +func TestServerMkdirRmdir(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpDir := "/tmp/" + randName() + defer os.RemoveAll(tmpDir) + + // mkdir remote + if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + // directory should now exist + if _, err := os.Stat(tmpDir); err != nil { + t.Fatal(err) + } + + // now remove the directory + if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpDir); err == nil { + t.Fatal("should have error after deleting the directory") + } +} + +func TestServerSymlink(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + link := "/tmp/" + randName() + defer os.RemoveAll(link) + + // now create a symbolic link within the new directory + if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { + t.Fatalf("failed: %v %v", err, string(output)) + } + + // symlink should now exist + if stat, err := os.Lstat(link); err != nil { + t.Fatal(err) + } else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink { + t.Fatalf("is not a symlink: %v", stat.Mode()) + } } diff --git a/server_test.go b/server_test.go deleted file mode 100644 index 9599ad74..00000000 --- a/server_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package sftp - -import ( - "encoding/hex" - "math/rand" - "os" - "testing" - "time" -) - -func randName() string { - r := rand.New(rand.NewSource(time.Now().Unix())) - data := make([]byte, 16) - for i := 0; i < 16; i++ { - data[i] = byte(r.Uint32()) - } - return "sftp." + hex.EncodeToString(data) -} - -func TestServerMkdirRmdir(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) - defer listenerGo.Close() - - tmpDir := "/tmp/" + randName() - defer os.RemoveAll(tmpDir) - - // mkdir remote - if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil { - t.Fatal(err) - } - - // directory should now exist - if _, err := os.Stat(tmpDir); err != nil { - t.Fatal(err) - } - - // now remove the directory - if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil { - t.Fatal(err) - } - - if _, err := os.Stat(tmpDir); err == nil { - t.Fatal("should have error after deleting the directory") - } -} - -func TestServerSymlink(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) - defer listenerGo.Close() - - link := "/tmp/" + randName() - defer os.RemoveAll(link) - - // now create a symbolic link within the new directory - if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { - t.Fatalf("failed: %v %v", err, string(output)) - } - - // symlink should now exist - if stat, err := os.Lstat(link); err != nil { - t.Fatal(err) - } else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink { - t.Fatalf("is not a symlink: %v", stat.Mode()) - } -} From 2eabfa33fbdc3c12c663eebd872a5ae12b1c68e9 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 22:06:17 -0700 Subject: [PATCH 35/41] remove user/ directory; uid -> username & gid -> groupname lookup removed --- user/lookup.go | 49 ------- user/lookup_plan9.go | 46 ------- user/lookup_stubs.go | 36 ------ user/lookup_unix.go | 286 ----------------------------------------- user/lookup_windows.go | 149 --------------------- user/user.go | 65 ---------- user/user_test.go | 142 -------------------- 7 files changed, 773 deletions(-) delete mode 100644 user/lookup.go delete mode 100644 user/lookup_plan9.go delete mode 100644 user/lookup_stubs.go delete mode 100644 user/lookup_unix.go delete mode 100644 user/lookup_windows.go delete mode 100644 user/user.go delete mode 100644 user/user_test.go diff --git a/user/lookup.go b/user/lookup.go deleted file mode 100644 index 66d8782e..00000000 --- a/user/lookup.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package user - -// Current returns the current user. -func Current() (*User, error) { - return current() -} - -// Lookup looks up a user by username. If the user cannot be found, the -// returned error is of type UnknownUserError. -func Lookup(username string) (*User, error) { - return lookupUser(username) -} - -// LookupId looks up a user by userid. If the user cannot be found, the -// returned error is of type UnknownUserIdError. -func LookupId(uid string) (*User, error) { - return lookupUserId(uid) -} - -// CurrentGroup returns the current group. -func CurrentGroup() (*Group, error) { - return currentGroup() -} - -// LookupGroup looks up a group by name. If the group cannot be found, the -// returned error is of type UnknownGroupError. -func LookupGroup(groupname string) (*Group, error) { - return lookupGroup(groupname) -} - -// LookupGroupId looks up a group by groupid. If the group cannot be found, the -// returned error is of type UnknownGroupIdError. -func LookupGroupId(gid string) (*Group, error) { - return lookupGroupId(gid) -} - -// In indicates whether the user is a member of the given group. -func (u *User) In(g *Group) (bool, error) { - return userInGroup(u, g) -} - -// Members returns the list of members of the group. -func (g *Group) Members() ([]string, error) { - return groupMembers(g) -} diff --git a/user/lookup_plan9.go b/user/lookup_plan9.go deleted file mode 100644 index f7ef3482..00000000 --- a/user/lookup_plan9.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package user - -import ( - "fmt" - "io/ioutil" - "os" - "syscall" -) - -// Partial os/user support on Plan 9. -// Supports Current(), but not Lookup()/LookupId(). -// The latter two would require parsing /adm/users. -const ( - userFile = "/dev/user" -) - -func current() (*User, error) { - ubytes, err := ioutil.ReadFile(userFile) - if err != nil { - return nil, fmt.Errorf("user: %s", err) - } - - uname := string(ubytes) - - u := &User{ - Uid: uname, - Gid: uname, - Username: uname, - Name: uname, - HomeDir: os.Getenv("home"), - } - - return u, nil -} - -func lookup(username string) (*User, error) { - return nil, syscall.EPLAN9 -} - -func lookupId(uid string) (*User, error) { - return nil, syscall.EPLAN9 -} diff --git a/user/lookup_stubs.go b/user/lookup_stubs.go deleted file mode 100644 index 83f7174a..00000000 --- a/user/lookup_stubs.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !cgo,!windows,!plan9 android - -package user - -import ( - "fmt" - "runtime" -) - -func init() { - implemented = false -} - -func current() (*User, error) { - return nil, fmt.Errorf("user: Current not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) -} - -func lookupUser(username string) (*User, error) { - return nil, fmt.Errorf("user: Lookup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) -} - -func lookupUserId(uid string) (*User, error) { - return nil, fmt.Errorf("user: LookupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) -} - -func lookupGroup(groupname string) (*Group, error) { - return nil, fmt.Errorf("user: LookupGroup not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) -} - -func lookupGroupId(int) (*Group, error) { - return nil, fmt.Errorf("user: LookupGroupId not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) -} diff --git a/user/lookup_unix.go b/user/lookup_unix.go deleted file mode 100644 index c2f970cb..00000000 --- a/user/lookup_unix.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris -// +build cgo - -package user - -import ( - "fmt" - "runtime" - "strconv" - "strings" - "syscall" - "unsafe" -) - -/* -#include -#include -#include -#include -#include - -static int mygetpwuid_r(int uid, struct passwd *pwd, - char *buf, size_t buflen, struct passwd **result) { - return getpwuid_r(uid, pwd, buf, buflen, result); -} - -static int mygetgrgid_r(int gid, struct group *grp, - char *buf, size_t buflen, struct group **result) { - return getgrgid_r(gid, grp, buf, buflen, result); -} - -static int mygetgrouplist(const char *user, gid_t group, gid_t *groups, - int *ngroups) { - return getgrouplist(user, group, (void *)groups, ngroups); -} - -static inline gid_t group_at(int i, gid_t *groups) { - return groups[i]; -} - -static inline char **next_member(char **members) { return members + 1; } - -*/ -import "C" - -const ( - userBuffer = iota - groupBuffer -) - -func current() (*User, error) { - return lookupUnixUser(syscall.Getuid(), "", false) -} - -func lookupUser(username string) (*User, error) { - return lookupUnixUser(-1, username, true) -} - -func lookupUserId(uid string) (*User, error) { - i, e := strconv.Atoi(uid) - if e != nil { - return nil, e - } - return lookupUnixUser(i, "", false) -} - -func lookupUnixUser(uid int, username string, lookupByName bool) (*User, error) { - var pwd C.struct_passwd - var result *C.struct_passwd - - var bufSize C.long - if runtime.GOOS == "dragonfly" || runtime.GOOS == "freebsd" { - // DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX - // and just return -1. So just use the same - // size that Linux returns. - bufSize = 1024 - } else { - bufSize = C.sysconf(C._SC_GETPW_R_SIZE_MAX) - if bufSize <= 0 || bufSize > 1<<20 { - return nil, fmt.Errorf("user: unreasonable _SC_GETPW_R_SIZE_MAX of %d", bufSize) - } - } - buf := C.malloc(C.size_t(bufSize)) - defer C.free(buf) - var rv C.int - if lookupByName { - nameC := C.CString(username) - defer C.free(unsafe.Pointer(nameC)) - rv = C.getpwnam_r(nameC, - &pwd, - (*C.char)(buf), - C.size_t(bufSize), - &result) - if rv != 0 { - return nil, fmt.Errorf("user: lookup username %s: %s", username, syscall.Errno(rv)) - } - if result == nil { - return nil, UnknownUserError(username) - } - } else { - // mygetpwuid_r is a wrapper around getpwuid_r to - // to avoid using uid_t because C.uid_t(uid) for - // unknown reasons doesn't work on linux. - rv = C.mygetpwuid_r(C.int(uid), - &pwd, - (*C.char)(buf), - C.size_t(bufSize), - &result) - if rv != 0 { - return nil, fmt.Errorf("user: lookup userid %d: %s", uid, syscall.Errno(rv)) - } - if result == nil { - return nil, UnknownUserIdError(uid) - } - } - u := &User{ - Uid: strconv.Itoa(int(pwd.pw_uid)), - Gid: strconv.Itoa(int(pwd.pw_gid)), - Username: C.GoString(pwd.pw_name), - Name: C.GoString(pwd.pw_gecos), - HomeDir: C.GoString(pwd.pw_dir), - } - // The pw_gecos field isn't quite standardized. Some docs - // say: "It is expected to be a comma separated list of - // personal data where the first item is the full name of the - // user." - if i := strings.Index(u.Name, ","); i >= 0 { - u.Name = u.Name[:i] - } - return u, nil -} - -func currentGroup() (*Group, error) { - return lookupUnixGroup(syscall.Getgid(), "", false, buildGroup) -} - -func lookupGroup(groupname string) (*Group, error) { - return lookupUnixGroup(-1, groupname, true, buildGroup) -} - -func lookupGroupId(gid string) (*Group, error) { - i, e := strconv.Atoi(gid) - if e != nil { - return nil, e - } - return lookupUnixGroup(i, "", false, buildGroup) -} - -func lookupUnixGroup(gid int, groupname string, lookupByName bool, f func(*C.struct_group) *Group) (*Group, error) { - var grp C.struct_group - var result *C.struct_group - - buf, bufSize, err := allocBuffer(groupBuffer) - if err != nil { - return nil, err - } - defer C.free(buf) - - if lookupByName { - nameC := C.CString(groupname) - defer C.free(unsafe.Pointer(nameC)) - rv := C.getgrnam_r(nameC, - &grp, - (*C.char)(buf), - C.size_t(bufSize), - &result) - if rv != 0 { - return nil, fmt.Errorf("group: lookup groupname %s: %s", groupname, syscall.Errno(rv)) - } - if result == nil { - return nil, UnknownGroupError(groupname) - } - } else { - // mygetgrgid_r is a wrapper around getgrgid_r to - // to avoid using gid_t because C.gid_t(gid) for - // unknown reasons doesn't work on linux. - rv := C.mygetgrgid_r(C.int(gid), - &grp, - (*C.char)(buf), - C.size_t(bufSize), - &result) - if rv != 0 { - return nil, fmt.Errorf("group: lookup groupid %d: %s", gid, syscall.Errno(rv)) - } - if result == nil { - return nil, UnknownGroupIdError(gid) - } - } - g := f(&grp) - return g, nil -} - -func buildGroup(grp *C.struct_group) *Group { - g := &Group{ - Gid: strconv.Itoa(int(grp.gr_gid)), - Name: C.GoString(grp.gr_name), - } - return g -} - -func userInGroup(u *User, g *Group) (bool, error) { - if u.Gid == g.Gid { - return true, nil - } - gid, err := strconv.Atoi(g.Gid) - if err != nil { - return false, err - } - - nameC := C.CString(u.Username) - defer C.free(unsafe.Pointer(nameC)) - groupC := C.gid_t(gid) - ngroupsC := C.int(0) - - C.mygetgrouplist(nameC, groupC, nil, &ngroupsC) - ngroups := int(ngroupsC) - - groups := C.malloc(C.size_t(int(unsafe.Sizeof(groupC)) * ngroups)) - defer C.free(groups) - - rv := C.mygetgrouplist(nameC, groupC, (*C.gid_t)(groups), &ngroupsC) - if rv == -1 { - return false, fmt.Errorf("user: membership of %s in %s: %s", u.Username, g.Name, syscall.Errno(rv)) - } - - ngroups = int(ngroupsC) - for i := 0; i < ngroups; i++ { - gid := C.group_at(C.int(i), (*C.gid_t)(groups)) - if g.Gid == strconv.Itoa(int(gid)) { - return true, nil - } - } - return false, nil -} - -func groupMembers(g *Group) ([]string, error) { - var members []string - gid, err := strconv.Atoi(g.Gid) - if err != nil { - return nil, err - } - - _, err = lookupUnixGroup(gid, "", false, func(grp *C.struct_group) *Group { - cmem := grp.gr_mem - for *cmem != nil { - members = append(members, C.GoString(*cmem)) - cmem = C.next_member(cmem) - } - return g - }) - if err != nil { - return nil, err - } - - return members, nil -} - -func allocBuffer(bufType int) (unsafe.Pointer, C.long, error) { - var bufSize C.long - if runtime.GOOS == "freebsd" { - // FreeBSD doesn't have _SC_GETPW_R_SIZE_MAX - // or SC_GETGR_R_SIZE_MAX and just returns -1. - // So just use the same size that Linux returns - bufSize = 1024 - } else { - var size C.int - var constName string - switch bufType { - case userBuffer: - size = C._SC_GETPW_R_SIZE_MAX - constName = "_SC_GETPW_R_SIZE_MAX" - case groupBuffer: - size = C._SC_GETGR_R_SIZE_MAX - constName = "_SC_GETGR_R_SIZE_MAX" - } - bufSize = C.sysconf(size) - if bufSize <= 0 || bufSize > 1<<20 { - return nil, bufSize, fmt.Errorf("user: unreasonable %s of %d", constName, bufSize) - } - } - return C.malloc(C.size_t(bufSize)), bufSize, nil -} diff --git a/user/lookup_windows.go b/user/lookup_windows.go deleted file mode 100644 index 99c325ff..00000000 --- a/user/lookup_windows.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package user - -import ( - "fmt" - "syscall" - "unsafe" -) - -func isDomainJoined() (bool, error) { - var domain *uint16 - var status uint32 - err := syscall.NetGetJoinInformation(nil, &domain, &status) - if err != nil { - return false, err - } - syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) - return status == syscall.NetSetupDomainName, nil -} - -func lookupFullNameDomain(domainAndUser string) (string, error) { - return syscall.TranslateAccountName(domainAndUser, - syscall.NameSamCompatible, syscall.NameDisplay, 50) -} - -func lookupFullNameServer(servername, username string) (string, error) { - s, e := syscall.UTF16PtrFromString(servername) - if e != nil { - return "", e - } - u, e := syscall.UTF16PtrFromString(username) - if e != nil { - return "", e - } - var p *byte - e = syscall.NetUserGetInfo(s, u, 10, &p) - if e != nil { - return "", e - } - defer syscall.NetApiBufferFree(p) - i := (*syscall.UserInfo10)(unsafe.Pointer(p)) - if i.FullName == nil { - return "", nil - } - name := syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:]) - return name, nil -} - -func lookupFullName(domain, username, domainAndUser string) (string, error) { - joined, err := isDomainJoined() - if err == nil && joined { - name, err := lookupFullNameDomain(domainAndUser) - if err == nil { - return name, nil - } - } - name, err := lookupFullNameServer(domain, username) - if err == nil { - return name, nil - } - // domain worked neigher as a domain nor as a server - // could be domain server unavailable - // pretend username is fullname - return username, nil -} - -func newUser(usid *syscall.SID, gid, dir string) (*User, error) { - username, domain, t, e := usid.LookupAccount("") - if e != nil { - return nil, e - } - if t != syscall.SidTypeUser { - return nil, fmt.Errorf("user: should be user account type, not %d", t) - } - domainAndUser := domain + `\` + username - uid, e := usid.String() - if e != nil { - return nil, e - } - name, e := lookupFullName(domain, username, domainAndUser) - if e != nil { - return nil, e - } - u := &User{ - Uid: uid, - Gid: gid, - Username: domainAndUser, - Name: name, - HomeDir: dir, - } - return u, nil -} - -func current() (*User, error) { - t, e := syscall.OpenCurrentProcessToken() - if e != nil { - return nil, e - } - defer t.Close() - u, e := t.GetTokenUser() - if e != nil { - return nil, e - } - pg, e := t.GetTokenPrimaryGroup() - if e != nil { - return nil, e - } - gid, e := pg.PrimaryGroup.String() - if e != nil { - return nil, e - } - dir, e := t.GetUserProfileDirectory() - if e != nil { - return nil, e - } - return newUser(u.User.Sid, gid, dir) -} - -// BUG(brainman): Lookup and LookupId functions do not set -// Gid and HomeDir fields in the User struct returned on windows. - -func newUserFromSid(usid *syscall.SID) (*User, error) { - // TODO(brainman): do not know where to get gid and dir fields - gid := "unknown" - dir := "Unknown directory" - return newUser(usid, gid, dir) -} - -func lookup(username string) (*User, error) { - sid, _, t, e := syscall.LookupSID("", username) - if e != nil { - return nil, e - } - if t != syscall.SidTypeUser { - return nil, fmt.Errorf("user: should be user account type, not %d", t) - } - return newUserFromSid(sid) -} - -func lookupId(uid string) (*User, error) { - sid, e := syscall.StringToSid(uid) - if e != nil { - return nil, e - } - return newUserFromSid(sid) -} diff --git a/user/user.go b/user/user.go deleted file mode 100644 index 9a7b5c16..00000000 --- a/user/user.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package user allows user account lookups by name or id. -package user - -import ( - "strconv" -) - -var implemented = true // set to false by lookup_stubs.go's init - -// User represents a user account. -// -// On posix systems Uid and Gid contain a decimal number -// representing uid and gid. On windows Uid and Gid -// contain security identifier (SID) in a string format. -// On Plan 9, Uid, Gid, Username, and Name will be the -// contents of /dev/user. -type User struct { - Uid string // user id - Gid string // primary group id - Username string - Name string - HomeDir string -} - -// Group represents a group database entry. -type Group struct { - Gid string // group id - Name string // group name -} - -// UnknownUserIdError is returned by LookupId when -// a user cannot be found. -type UnknownUserIdError int - -func (e UnknownUserIdError) Error() string { - return "user: unknown userid " + strconv.Itoa(int(e)) -} - -// UnknownUserError is returned by Lookup when -// a user cannot be found. -type UnknownUserError string - -func (e UnknownUserError) Error() string { - return "user: unknown user " + string(e) -} - -// UnknownGroupIdError is returned by LookupGroupId when -// a group cannot be found. -type UnknownGroupIdError int - -func (e UnknownGroupIdError) Error() string { - return "group: unknown groupid " + strconv.Itoa(int(e)) -} - -// UnknownGroupError is returned by LookupGroup when -// a group cannot be found. -type UnknownGroupError string - -func (e UnknownGroupError) Error() string { - return "group: unknown group " + string(e) -} diff --git a/user/user_test.go b/user/user_test.go deleted file mode 100644 index f34db9bd..00000000 --- a/user/user_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package user - -import ( - "runtime" - "testing" -) - -func check(t *testing.T) { - if !implemented { - t.Skip("user: not implemented; skipping tests") - } -} - -func TestCurrent(t *testing.T) { - check(t) - - u, err := Current() - if err != nil { - t.Fatalf("Current: %v", err) - } - if u.HomeDir == "" { - t.Errorf("didn't get a HomeDir") - } - if u.Username == "" { - t.Errorf("didn't get a username") - } -} - -func compare(t *testing.T, want, got *User) { - if want.Uid != got.Uid { - t.Errorf("got Uid=%q; want %q", got.Uid, want.Uid) - } - if want.Username != got.Username { - t.Errorf("got Username=%q; want %q", got.Username, want.Username) - } - if want.Name != got.Name { - t.Errorf("got Name=%q; want %q", got.Name, want.Name) - } - // TODO(brainman): fix it once we know how. - if runtime.GOOS == "windows" { - t.Skip("skipping Gid and HomeDir comparisons") - } - if want.Gid != got.Gid { - t.Errorf("got Gid=%q; want %q", got.Gid, want.Gid) - } - if want.HomeDir != got.HomeDir { - t.Errorf("got HomeDir=%q; want %q", got.HomeDir, want.HomeDir) - } -} - -func TestLookup(t *testing.T) { - check(t) - - if runtime.GOOS == "plan9" { - t.Skipf("Lookup not implemented on %q", runtime.GOOS) - } - - want, err := Current() - if err != nil { - t.Fatalf("Current: %v", err) - } - got, err := Lookup(want.Username) - if err != nil { - t.Fatalf("Lookup: %v", err) - } - compare(t, want, got) -} - -func TestLookupId(t *testing.T) { - check(t) - - if runtime.GOOS == "plan9" { - t.Skipf("LookupId not implemented on %q", runtime.GOOS) - } - - want, err := Current() - if err != nil { - t.Fatalf("Current: %v", err) - } - got, err := LookupId(want.Uid) - if err != nil { - t.Fatalf("LookupId: %v", err) - } - compare(t, want, got) -} - -func compareGroup(t *testing.T, want, got *Group) { - if want.Gid != got.Gid { - t.Errorf("got Gid=%q; want %q", got.Gid, want.Gid) - } - if want.Name != got.Name { - t.Errorf("got Name=%q; want %q", got.Name, want.Name) - } -} - -func TestLookupGroup(t *testing.T) { - check(t) - - // Test LookupGroupId on the current user - want, err := CurrentGroup() - if err != nil { - t.Fatalf("CurrentGroup: %v", err) - } - got, err := LookupGroupId(want.Gid) - if err != nil { - t.Fatalf("LookupGroupId: %v", err) - } - compareGroup(t, want, got) - - members, err := got.Members() - if err != nil { - t.Fatalf("Members: %v", err) - } - for _, user := range members { - u, err := Lookup(user) - if err != nil { - t.Errorf("expected a valid group member; user=%v, err=%v", user, err) - } - isMember, err := u.In(got) - if err != nil { - t.Fatalf("u.In: %v", err) - } - if !isMember { - if runtime.GOOS == "darwin" && got.Name == "staff" { - // staff group on OSX is strange and I don't understand it - } else { - t.Errorf("expected user to be group member; user=%v, group=%v, err=%v", user, got.Name, err) - } - } - } - - // Test Lookup by groupname, using the groupname from LookupId - g, err := LookupGroup(got.Name) - if err != nil { - t.Fatalf("Lookup: %v", err) - } - compareGroup(t, got, g) -} From a94c674357d5e4fa12f2a248ea2f396971055500 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 22:10:18 -0700 Subject: [PATCH 36/41] missed a file... --- server_unix.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/server_unix.go b/server_unix.go index 37e739d6..b0fceec1 100644 --- a/server_unix.go +++ b/server_unix.go @@ -9,8 +9,6 @@ import ( "path" "syscall" "time" - - "github.com/pkg/sftp/user" ) func runLsTypeWord(dirent os.FileInfo) string { @@ -113,13 +111,8 @@ func runLsStatt(dirname string, dirent os.FileInfo, statt *syscall.Stat_t) strin uid := statt.Uid gid := statt.Gid username := fmt.Sprintf("%d", uid) - if usr, err := user.LookupId(username); err == nil { - username = usr.Username - } groupname := fmt.Sprintf("%d", gid) - if grp, err := user.LookupGroupId(groupname); err == nil { - groupname = grp.Name - } + // TODO FIXME: uid -> username, gid -> groupname lookup for ls -l format output mtime := dirent.ModTime() monthStr := mtime.Month().String()[0:3] From bce43f23ac848039330834070e11f85e544511cd Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 22:50:46 -0700 Subject: [PATCH 37/41] Address review comments; alter server test to allow the user / group words to be different --- packet.go | 19 +++++++++--------- server.go | 21 ++++++++----------- server_integration_test.go | 41 +++++++++++++++++++++++--------------- server_unix.go | 4 +++- 4 files changed, 46 insertions(+), 39 deletions(-) diff --git a/packet.go b/packet.go index 148ee305..59549329 100644 --- a/packet.go +++ b/packet.go @@ -70,11 +70,12 @@ func unmarshalUint32(b []byte) (uint32, []byte) { } func unmarshalUint32Safe(b []byte) (uint32, []byte, error) { + var v uint32 = 0 if len(b) < 4 { return 0, nil, shortPacketError } - v := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 - return v, b[4:], nil + v, b = unmarshalUint32(b) + return v, b, nil } func unmarshalUint64(b []byte) (uint64, []byte) { @@ -84,12 +85,12 @@ func unmarshalUint64(b []byte) (uint64, []byte) { } func unmarshalUint64Safe(b []byte) (uint64, []byte, error) { + var v uint64 = 0 if len(b) < 8 { return 0, nil, shortPacketError } - h, b := unmarshalUint32(b) - l, b := unmarshalUint32(b) - return uint64(h)<<32 | uint64(l), b, nil + v, b = unmarshalUint64(b) + return v, b, nil } func unmarshalString(b []byte) (string, []byte) { @@ -115,9 +116,9 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { return fmt.Errorf("marshal2(%#v): binary marshaller failed", err) } if debugDumpTxPacketBytes { - debug("send packet: %s %d bytes %x", fxp(bb[0]).String(), len(bb), bb[1:]) + debug("send packet: %s %d bytes %x", fxp(bb[0]), len(bb), bb[1:]) } else if debugDumpTxPacket { - debug("send packet: %s %d bytes", fxp(bb[0]).String(), len(bb)) + debug("send packet: %s %d bytes", fxp(bb[0]), len(bb)) } l := uint32(len(bb)) hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} @@ -148,9 +149,9 @@ func recvPacket(r io.Reader) (uint8, []byte, error) { return 0, nil, err } if debugDumpRxPacketBytes { - debug("recv packet: %s %d bytes %x", fxp(b[0]).String(), l, b[1:]) + debug("recv packet: %s %d bytes %x", fxp(b[0]), l, b[1:]) } else if debugDumpRxPacket { - debug("recv packet: %s %d bytes", fxp(b[0]).String(), l) + debug("recv packet: %s %d bytes", fxp(b[0]), l) } return b[0], b[1:], nil } diff --git a/server.go b/server.go index 984fc751..25641b22 100644 --- a/server.go +++ b/server.go @@ -31,24 +31,19 @@ type Server struct { rootDir string lastId uint32 pktChan chan rxPacket - openFiles map[string]serverOpenFile + openFiles map[string]*os.File openFilesLock *sync.RWMutex handleCount int maxTxPacket uint32 workerCount int } -type serverOpenFile struct { - *os.File - path string -} - -func (svr *Server) nextHandle(path string, f *os.File) string { +func (svr *Server) nextHandle(f *os.File) string { svr.openFilesLock.Lock() defer svr.openFilesLock.Unlock() svr.handleCount++ handle := fmt.Sprintf("%d", svr.handleCount) - svr.openFiles[handle] = serverOpenFile{f, path} + svr.openFiles[handle] = f return handle } @@ -63,7 +58,7 @@ func (svr *Server) closeHandle(handle string) error { } } -func (svr *Server) getHandle(handle string) (serverOpenFile, bool) { +func (svr *Server) getHandle(handle string) (*os.File, bool) { svr.openFilesLock.RLock() defer svr.openFilesLock.RUnlock() f, ok := svr.openFiles[handle] @@ -95,7 +90,7 @@ func NewServer(in io.Reader, out io.WriteCloser, debugStream io.Writer, debugLev readOnly: readOnly, rootDir: rootDir, pktChan: make(chan rxPacket, sftpServerWorkerCount), - openFiles: map[string]serverOpenFile{}, + openFiles: map[string]*os.File{}, openFilesLock: &sync.RWMutex{}, maxTxPacket: 1 << 15, workerCount: sftpServerWorkerCount, @@ -355,7 +350,7 @@ func (p sshFxpOpenPacket) respond(svr *Server) error { if f, err := os.OpenFile(p.Path, osFlags, 0644); err != nil { return svr.sendPacket(statusFromError(p.Id, err)) } else { - handle := svr.nextHandle(p.Path, f) + handle := svr.nextHandle(f) return svr.sendPacket(sshFxpHandlePacket{p.Id, handle}) } } @@ -476,7 +471,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error { b := p.Attrs.([]byte) var err error = nil - debug("fsetstat name \"%s\"", f.path) + debug("fsetstat name \"%s\"", f.Name()) if (p.Flags & ssh_FILEXFER_ATTR_SIZE) != 0 { var size uint64 = 0 if size, b, err = unmarshalUint64Safe(b); err == nil { @@ -497,7 +492,7 @@ func (p sshFxpFsetstatPacket) respond(svr *Server) error { } else { atimeT := time.Unix(int64(atime), 0) mtimeT := time.Unix(int64(mtime), 0) - err = os.Chtimes(f.path, atimeT, mtimeT) + err = os.Chtimes(f.Name(), atimeT, mtimeT) } } if (p.Flags & ssh_FILEXFER_ATTR_UIDGID) != 0 { diff --git a/server_integration_test.go b/server_integration_test.go index bf129165..d4a2237d 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -14,6 +14,7 @@ import ( "net" "os" "os/exec" + "regexp" "strconv" "strings" "testing" @@ -407,27 +408,35 @@ ls -l /usr/bin/ t.Fatal(err) } - if outputGo != outputOp { - diffOffsetLine := 0 - diffOffsetNextLine := 0 + newlineRegex := regexp.MustCompile(`\r*\n`) + spaceRegex := regexp.MustCompile(`\s+`) + outputGoLines := newlineRegex.Split(outputGo, -1) + outputOpLines := newlineRegex.Split(outputOp, -1) + + for i, goLine := range outputGoLines { + if i > len(outputOpLines) { + t.Fatalf("output line count differs") + } + opLine := outputOpLines[i] bad := false - for i := 0; i < len(outputGo) && i < len(outputOp); i++ { - if outputGo[i] != outputOp[i] { - bad = true - } else if outputGo[i] == '\n' { - if !bad { - diffOffsetLine = i - diffOffsetNextLine = i - } else { - diffOffsetNextLine = i - break + if goLine != opLine { + goWords := spaceRegex.Split(goLine, -1) + opWords := spaceRegex.Split(opLine, -1) + // allow words[2] and [3] to be different as these are users & groups + for j, goWord := range goWords { + if j > len(opWords) { + bad = true + } + opWord := opWords[j] + if goWord != opWord && j != 2 && j != 3 { + bad = true } } } - t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", - outputGo[diffOffsetLine:diffOffsetNextLine], - outputOp[diffOffsetLine:diffOffsetNextLine]) + if bad { + t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", goLine, opLine) + } } } diff --git a/server_unix.go b/server_unix.go index b0fceec1..c13b1abc 100644 --- a/server_unix.go +++ b/server_unix.go @@ -119,7 +119,9 @@ func runLsStatt(dirname string, dirent os.FileInfo, statt *syscall.Stat_t) strin day := mtime.Day() year := mtime.Year() now := time.Now() - isOld := mtime.Before(now.Add(-1 * time.Hour * 24 * 182)) + nowDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + isOld := mtime.Before(nowDay.Add(-time.Hour * 24 * 365 / 2)) + yearOrTime := fmt.Sprintf("%02d:%02d", mtime.Hour(), mtime.Minute()) if isOld { yearOrTime = fmt.Sprintf("%d", year) From fe6bfd71e7cb588e9237ef77147b160a8a4025ec Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 22:53:32 -0700 Subject: [PATCH 38/41] remove testdata/ from .gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 00ea0bf2..a864e484 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ .*.swo .*.swp -testdata/ server_standalone/server_standalone From a6fc4b8c1f99745cc03640c6a72f2588423f8ba1 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Mon, 7 Sep 2015 23:04:52 -0700 Subject: [PATCH 39/41] Add comments for *Client.Stat and *Client.Lstat --- client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client.go b/client.go index 0ad90470..8ee052e8 100644 --- a/client.go +++ b/client.go @@ -259,6 +259,8 @@ func (c *Client) opendir(path string) (string, error) { } } +// Stat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file. func (c *Client) Stat(p string) (os.FileInfo, error) { id := c.nextId() typ, data, err := c.sendRequest(sshFxpStatPacket{ @@ -283,6 +285,8 @@ func (c *Client) Stat(p string) (os.FileInfo, error) { } } +// Lstat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. func (c *Client) Lstat(p string) (os.FileInfo, error) { id := c.nextId() typ, data, err := c.sendRequest(sshFxpLstatPacket{ From d7309968cd64a6a7e80b8ec5412107ab38035626 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Tue, 8 Sep 2015 17:03:18 -0700 Subject: [PATCH 40/41] add more tests; bug setting S_IFREG caused openssh sftp to refuse to get the file --- attrs.go | 2 +- server.go | 5 +++ server_integration_test.go | 77 +++++++++++++++++++++++++++++++++++--- server_unix.go | 3 +- 4 files changed, 78 insertions(+), 9 deletions(-) diff --git a/attrs.go b/attrs.go index a1f43cc0..41e0e277 100644 --- a/attrs.go +++ b/attrs.go @@ -220,7 +220,7 @@ func fromFileMode(mode os.FileMode) uint32 { ret |= syscall.S_IFSOCK } - if mode == 0 { + if mode&os.ModeType == 0 { ret |= syscall.S_IFREG } ret |= uint32(mode & os.ModePerm) diff --git a/server.go b/server.go index 25641b22..3b6c9175 100644 --- a/server.go +++ b/server.go @@ -149,6 +149,11 @@ func (svr *Server) Serve() error { } } fmt.Fprintf(svr.debugStream, "sftp server run finished\n") + // close any still-open files + for handle, file := range svr.openFiles { + fmt.Fprintf(svr.debugStream, "sftp server file with handle '%v' left open: %v\n", handle, file.Name()) + file.Close() + } return svr.out.Close() } diff --git a/server_integration_test.go b/server_integration_test.go index d4a2237d..d2fc6ebb 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -440,13 +440,18 @@ ls -l /usr/bin/ } } -func randName() string { - r := rand.New(rand.NewSource(time.Now().Unix())) - data := make([]byte, 16) - for i := 0; i < 16; i++ { - data[i] = byte(r.Uint32()) +var rng = rand.New(rand.NewSource(time.Now().Unix())) + +func randData(length int) []byte { + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte(rng.Uint32()) } - return "sftp." + hex.EncodeToString(data) + return data +} + +func randName() string { + return "sftp." + hex.EncodeToString(randData(16)) } func TestServerMkdirRmdir(t *testing.T) { @@ -495,3 +500,63 @@ func TestServerSymlink(t *testing.T) { t.Fatalf("is not a symlink: %v", stat.Mode()) } } + +func TestServerPut(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a file with random contents. This will be the local file pushed to the server + tmpFileLocalData := randData(10 * 1024 * 1024) + if err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644); err != nil { + t.Fatal(err) + } + + // sftp the file to the server + if output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFile2 should now exist, with the same contents + if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} + +func TestServerGet(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("get: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a file with random contents. This will be the remote file pulled from the server + tmpFileRemoteData := randData(10 * 1024 * 1024) + if err := ioutil.WriteFile(tmpFileRemote, tmpFileRemoteData, 0644); err != nil { + t.Fatal(err) + } + + // sftp the file to the server + if output, err := runSftpClient(t, "get "+tmpFileRemote+" "+tmpFileLocal, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFile2 should now exist, with the same contents + if tmpFileLocalData, err := ioutil.ReadFile(tmpFileLocal); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} diff --git a/server_unix.go b/server_unix.go index c13b1abc..8c3f0b44 100644 --- a/server_unix.go +++ b/server_unix.go @@ -119,8 +119,7 @@ func runLsStatt(dirname string, dirent os.FileInfo, statt *syscall.Stat_t) strin day := mtime.Day() year := mtime.Year() now := time.Now() - nowDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) - isOld := mtime.Before(nowDay.Add(-time.Hour * 24 * 365 / 2)) + isOld := mtime.Before(now.Add(-time.Hour * 24 * 365 / 2)) yearOrTime := fmt.Sprintf("%02d:%02d", mtime.Hour(), mtime.Minute()) if isOld { From e4daa2d013f20e3645b6742e97a2a76e55f11c27 Mon Sep 17 00:00:00 2001 From: Mark Sheahan Date: Tue, 8 Sep 2015 17:54:28 -0700 Subject: [PATCH 41/41] added server integration test for recursive put and recursive get --- server_integration_test.go | 111 +++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/server_integration_test.go b/server_integration_test.go index d2fc6ebb..2d887f03 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -14,12 +14,15 @@ import ( "net" "os" "os/exec" + "path" + "path/filepath" "regexp" "strconv" "strings" "testing" "time" + "github.com/kr/fs" "golang.org/x/crypto/ssh" ) @@ -560,3 +563,111 @@ func TestServerGet(t *testing.T) { t.Fatal("contents of file incorrect after put") } } + +func compareDirectoriesRecursive(t *testing.T, aroot, broot string) { + walker := fs.Walk(aroot) + for walker.Step() { + if err := walker.Err(); err != nil { + t.Fatal(err) + } + // find paths + aPath := walker.Path() + aRel, err := filepath.Rel(aroot, aPath) + if err != nil { + t.Fatalf("could not find relative path for %v: %v", aPath, err) + } + bPath := path.Join(broot, aRel) + + if aRel == "." { + continue + } + + //t.Logf("comparing: %v a: %v b %v", aRel, aPath, bPath) + + // if a is a link, the sftp recursive copy won't have copied it. ignore + aLink, err := os.Lstat(aPath) + if err != nil { + t.Fatalf("could not lstat %v: %v", aPath, err) + } + if aLink.Mode()&os.ModeSymlink != 0 { + continue + } + + // stat the files + aFile, err := os.Stat(aPath) + if err != nil { + t.Fatalf("could not stat %v: %v", aPath, err) + } + bFile, err := os.Stat(bPath) + if err != nil { + t.Fatalf("could not stat %v: %v", bPath, err) + } + + // compare stats, with some leniency for the timestamp + if aFile.Mode() != bFile.Mode() { + t.Fatalf("modes different for %v: %v vs %v", aRel, aFile.Mode(), bFile.Mode()) + } + if !aFile.IsDir() { + if aFile.Size() != bFile.Size() { + t.Fatalf("sizes different for %v: %v vs %v", aRel, aFile.Size(), bFile.Size()) + } + } + timeDiff := aFile.ModTime().Sub(bFile.ModTime()) + if timeDiff > time.Second || timeDiff < -time.Second { + t.Fatalf("mtimes different for %v: %v vs %v", aRel, aFile.ModTime(), bFile.ModTime()) + } + + // compare contents + if !aFile.IsDir() { + if aContents, err := ioutil.ReadFile(aPath); err != nil { + t.Fatal(err) + } else if bContents, err := ioutil.ReadFile(bPath); err != nil { + t.Fatal(err) + } else if string(aContents) != string(bContents) { + t.Fatalf("contents different for %v", aRel) + } + } + } +} + +func TestServerPutRecursive(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + dirLocal, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tmpDirRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpDirRemote) + + t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote) + + // push this directory (source code etc) recursively to the server + if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -r -P "+dirLocal+"/ "+tmpDirRemote+"/", "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + compareDirectoriesRecursive(t, dirLocal, path.Join(tmpDirRemote, path.Base(dirLocal))) +} + +func TestServerGetRecursive(t *testing.T) { + listenerGo, hostGo, portGo := testServer(t, GOLANG_SFTP, READONLY) + defer listenerGo.Close() + + dirRemote, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tmpDirLocal := "/tmp/" + randName() + defer os.RemoveAll(tmpDirLocal) + + t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote) + + // pull this directory (source code etc) recursively from the server + if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -r -P "+dirRemote+"/ "+tmpDirLocal+"/", "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + compareDirectoriesRecursive(t, dirRemote, path.Join(tmpDirLocal, path.Base(dirRemote))) +}