#!/usr/bin/python

import errno
import functools
import getopt
import os
import socket
import stat
import sys
import threading
import traceback

import fuse
from py9p import marshal9p
from py9p import py9p


class Error(py9p.Error):
    """Exception type of this module; a plain subclass of ``py9p.Error``."""

# Declare the fuse-python API revision this module is written against.
fuse.fuse_python_api = (0, 2)


def mode2stat(mode):
    """
    Convert a 9P mode word into a POSIX ``st_mode`` style value.

    The low nine permission bits are copied unchanged; the py9p ``DM*``
    flags are shifted down into the POSIX bit range.

    NOTE(review): the shift amounts presumably line the ``DM*`` flags up
    with the matching ``stat.S_IF*`` / ``stat.S_IS*`` bits, assuming py9p
    uses the standard 9P2000.u flag values -- confirm against py9p.
    """
    result = mode & 0o777

    # File type: the first term is non-zero only when DMDIR is clear, the
    # second only when DMDIR is set, so exactly one of them contributes.
    dir_flag = mode & py9p.DMDIR
    result |= (dir_flag ^ py9p.DMDIR) >> 16
    result |= dir_flag >> 17

    # A symlink contributes two separate bits.
    link_flag = mode & py9p.DMSYMLINK
    result |= link_flag >> 10
    result |= link_flag >> 12

    # Set-uid and set-gid use the same shift, so handle them together.
    result |= (mode & (py9p.DMSETUID | py9p.DMSETGID)) >> 8
    result |= (mode & py9p.DMSTICKY) >> 7
    return result

def mode2plan(mode):
    """
    Convert a POSIX ``st_mode`` style value into a 9P mode word.

    The nine permission bits are copied unchanged. The directory, set-uid,
    set-gid and sticky bits are shifted up by the same amounts that
    ``mode2stat`` shifts them down, and bit 25 is set for symbolic links.

    :param mode: POSIX mode (file type and permission bits)
    :return: integer 9P mode
    """
    plan = mode & 0o777
    plan |= (mode & stat.S_IFDIR) << 17
    plan |= (mode & stat.S_ISUID) << 8
    plan |= (mode & stat.S_ISGID) << 8
    plan |= (mode & stat.S_ISVTX) << 7
    # Compare only the file-type field: a real symlink mode also carries
    # permission bits, so comparing the whole mode with S_IFLNK would
    # almost never match.
    if stat.S_IFMT(mode) == stat.S_IFLNK:
        plan |= 1 << 25
    return plan

class fStat(fuse.Stat):
    """
    ``fuse.Stat`` record filled in from a stat-like object.

    The source object must provide ``mode``, ``length``, ``uidnum``,
    ``gidnum``, ``atime`` and ``mtime`` attributes.
    """
    def __init__(self, inode):
        source_mode = inode.mode

        self.st_mode = mode2stat(source_mode)
        # Inode and device numbers are not provided; both are fixed at 0.
        self.st_ino = 0
        self.st_dev = 0
        # NOTE(review): ``source_mode`` is handed to mode2stat() as a 9P
        # mode, yet it is tested against the POSIX S_IFDIR bit here --
        # confirm which bit (and which link count) is intended.
        self.st_nlink = inode.length if source_mode & stat.S_IFDIR else 1
        self.st_uid = inode.uidnum
        self.st_gid = inode.gidnum
        self.st_size = inode.length
        self.st_atime = inode.atime
        # The change time simply mirrors the modification time.
        self.st_mtime = self.st_ctime = inode.mtime


def guard(c):
    """
    Decorator for filesystem methods: serialise the call and trap errors.

    The wrapped method runs while holding ``self._lock``. If it raises an
    ordinary exception, the traceback is printed and ``-errno.EIO`` is
    returned instead. ``KeyboardInterrupt`` and ``SystemExit`` are not
    intercepted. Entry and the type of the result are traced on stdout.

    :param c: method taking ``self`` as its first argument
    :return: wrapper with the same call signature
    """
    @functools.wraps(c)
    def wrapped(self, *argv, **kwarg):
        # Stays in place when the call raises.
        ret = -errno.EIO
        # Single pre-formatted argument: prints the same text whether
        # ``print`` is a statement or a function.
        print(">>  %s" % (c,))
        with self._lock:
            try:
                ret = c(self, *argv, **kwarg)
            except Exception:
                traceback.print_exc()
        print("!!  %s" % (type(ret),))
        return ret
    return wrapped

class ClientFS(fuse.Fuse):
    """
    FUSE filesystem that forwards operations to a 9P server via py9p.

    Each operation walks a scratch fid from the client's root to the target
    path, issues its 9P requests and clunks the fid again. All operations
    run under ``self._lock``, so the fixed scratch fids are never used by
    two operations at once.
    """

    # Scratch fids: one for per-file operations, one for directory listing.
    _FILE_FID = 63
    _DIR_FID = 64
    # Largest payload requested or sent in a single 9P read/write.
    _CHUNK = 8192

    def __init__(self, fd, user, authmode, keyfile, mountpoint):
        """
        :param fd: connected socket passed to ``py9p.Client``
        :param user: user name passed to ``py9p.Client``
        :param authmode: authentication mode passed to ``py9p.Client``
        :param keyfile: key file passed to ``py9p.Client`` as ``key``
        :param mountpoint: directory to mount on (resolved with realpath)
        """
        # Stat entries collected by readdir(), keyed by ``qid.path``.
        # NOTE(review): getattr() looks entries up by ``py9p.hash8(path)``,
        # so a hit presumably requires the server to derive qid.path the
        # same way -- confirm.
        self.cache = {}
        self.client = py9p.Client(fd,
                chatty=False,
                user=user,
                authmode=authmode,
                key=keyfile,
                dotu=1)

        fuse.Fuse.__init__(self, version="%prog " + fuse.__version__,
                dash_s_do='undef')

        self.fuse_args.setmod('foreground')
        self.fuse_args.add('debug')
        self.fuse_args.mountpoint = os.path.realpath(mountpoint)

        self._lock = threading.Lock()

    @staticmethod
    def _names(path):
        """Split *path* on "/" and return its non-empty components."""
        return [name for name in path.split("/") if name]

    def _forget(self, path):
        """Drop the cached stat entry getattr() would use for *path*."""
        self.cache.pop(py9p.hash8(path), None)

    def open(self, path, flags):
        """Placeholder: does nothing and returns None."""
        pass

    def chmod(self, path, mode):
        """Placeholder: does nothing and returns None."""
        pass

    def chown(self, path, uid, gid):
        """Placeholder: does nothing and returns None."""
        pass

    def utime(self, path, times):
        """Placeholder: does nothing and returns None."""
        pass

    @guard
    def mknod(self, path, mode, dev):
        """
        Create *path* on the server, or truncate it if it already exists.

        Device nodes (``dev != 0``) are refused with ``-errno.ENOSYS``.
        Returns ``-errno.EIO`` on any RPC error other than
        "file not found", None otherwise.
        """
        if dev != 0:
            return -errno.ENOSYS
        # FIXME: anything not flagged as a regular file is created as a
        # directory.
        if not mode & stat.S_IFREG:
            mode |= stat.S_IFDIR
        names = self._names(path)
        fid = self._FILE_FID
        # Cached attributes (if any) are about to become stale.
        self._forget(path)
        try:
            self.client._walk(self.client.ROOT, fid, names)
            try:
                self.client._open(fid, py9p.OTRUNC)
            finally:
                # Release the fid even when the open fails.
                self.client._clunk(fid)
        except py9p.RpcError as e:
            if e.message != "file not found":
                return -errno.EIO
            # The file does not exist yet: walk to the parent directory
            # and create the last path component there.
            self.client._walk(self.client.ROOT, fid, names[:-1])
            try:
                self.client._create(fid, names[-1], mode2plan(mode), 0)
            finally:
                self.client._clunk(fid)

    @guard
    def truncate(self, path, size):
        """
        Truncate *path* to zero length by opening it with ``OTRUNC``.

        Any other *size* is refused with ``-errno.ENOSYS``.
        """
        if size != 0:
            return -errno.ENOSYS
        fid = self._FILE_FID
        self._forget(path)
        self.client._walk(self.client.ROOT, fid, self._names(path))
        try:
            self.client._open(fid, py9p.OTRUNC)
        finally:
            self.client._clunk(fid)

    @guard
    def write(self, path, buf, offset):
        """
        Write *buf* to *path* starting at *offset*.

        The data is sent in pieces of at most ``_CHUNK`` bytes.
        Returns ``len(buf)``.
        """
        fid = self._FILE_FID
        chunk = self._CHUNK
        # The cached size/mtime no longer describe the file.
        self._forget(path)
        self.client._walk(self.client.ROOT, fid, self._names(path))
        try:
            self.client._open(fid, py9p.OWRITE)
            for start in range(0, len(buf), chunk):
                # Slice by absolute end position so every chunk, not only
                # the first one, carries its data.
                self.client._write(fid, offset + start,
                        buf[start:start + chunk])
        finally:
            self.client._clunk(fid)
        return len(buf)

    @guard
    def read(self, path, size, offset):
        """
        Read up to *size* bytes from *path* starting at *offset*.

        Requests are limited to ``_CHUNK`` bytes each and stop as soon as
        the server returns no data. Returns the bytes read.
        """
        fid = self._FILE_FID
        pieces = []
        remaining = size
        self.client._walk(self.client.ROOT, fid, self._names(path))
        try:
            self.client._open(fid, py9p.OREAD)
            while remaining > 0:
                ret = self.client._read(fid, offset,
                        min(remaining, self._CHUNK))
                if not ret.data:
                    # Empty reply: nothing more to read.
                    break
                pieces.append(ret.data)
                offset += len(ret.data)
                remaining -= len(ret.data)
        finally:
            self.client._clunk(fid)
        # Guard against a reply that is longer than what was asked for.
        return bytes().join(pieces)[:size]

    @guard
    def readlink(self, path):
        """Return the data of a single read (up to ``_CHUNK`` bytes) of *path*."""
        fid = self._FILE_FID
        self.client._walk(self.client.ROOT, fid, self._names(path))
        try:
            self.client._open(fid, py9p.OREAD)
            ret = self.client._read(fid, 0, self._CHUNK)
        finally:
            self.client._clunk(fid)
        return ret.data

    @guard
    def getattr(self, path):
        """
        Return an ``fStat`` for *path*.

        A cached entry is used when present; otherwise the server is asked.
        Returns ``-errno.ENOENT`` when the server answers "file not found"
        and ``-errno.EIO`` for any other RPC error.
        """
        key = py9p.hash8(path)
        if key in self.cache:
            return fStat(self.cache[key])
        fid = self._FILE_FID
        try:
            self.client._walk(self.client.ROOT, fid, self._names(path))
            try:
                ret = self.client._stat(fid).stat[0]
            finally:
                self.client._clunk(fid)
        except py9p.RpcError as e:
            if e.message == "file not found":
                return -errno.ENOENT
            return -errno.EIO
        return fStat(ret)

    def readdir(self, path, offset):
        """
        Yield a ``fuse.Direntry`` for every entry of directory *path*.

        The whole directory is read while holding the lock; each entry's
        stat record is stored in ``self.cache`` as it is yielded. The
        *offset* argument is ignored: reading always starts at 0.
        """
        fid = self._DIR_FID
        dirs = []
        with self._lock:
            self.client._walk(self.client.ROOT, fid, self._names(path))
            try:
                self.client._open(fid, py9p.OREAD)
                offset = 0
                while True:
                    ret = self.client._read(fid, offset, self._CHUNK)
                    if len(ret.data) == 0:
                        break
                    offset += len(ret.data)
                    # Decode the stat records packed into this reply.
                    p9 = marshal9p.Marshal9P()
                    p9.setBuf(ret.data)
                    fcall = py9p.Fcall(py9p.Rstat)
                    p9.decstat(fcall, 0)
                    dirs.extend(fcall.stat)
            finally:
                self.client._clunk(fid)

        for i in dirs:
            self.cache[i.qid.path] = i
            yield fuse.Direntry(i.name)

def usage(prog):
    """Print the command-line synopsis and option summary for *prog*."""
    template = """
Usage: %s [-c mode] [-k file] [-l user] [-p port] \
user@server:port mountpoint

 -c mode  -- authentication mode to use (none|pki)
 -k file  -- path to the private RSA key for PKI (implies -c pki)
 -l user  -- username to use in authentication
 -p port  -- TCP port to use
    """
    message = template % prog
    print(message)

def _parse_target(spec):
    """
    Split a ``[user@]server[:port]`` specification.

    :return: ``(user, server, port)``; ``user`` and ``port`` are None when
        that part is absent, and ``port`` is returned as text
    :raises ValueError: on more than one "@" or more than one ":"
    """
    if spec.count("@") > 1 or spec.count(":") > 1:
        raise ValueError(spec)
    user = None
    port = None
    server = spec
    if "@" in server:
        user, server = server.split("@", 1)
        if ":" in user:
            # The ":" belongs after the server, not inside the user name.
            raise ValueError(spec)
    if ":" in server:
        server, port = server.split(":", 1)
    return user, server, port


def main():
    """
    Parse the command line, connect to the server and run the filesystem.

    Returns None on normal completion, otherwise an exit status:
    255 bad options or argument count, 254 bad target specification,
    253 bad port, 252 no user name available, 251 connection failure,
    250 ``py9p.Error`` while running, 249 any other exception.
    """
    prog = sys.argv[0]
    port = py9p.PORT
    user = os.environ.get('USER', None)
    authmode = None
    keyfile = None

    try:
        opts, args = getopt.getopt(sys.argv[1:], "c:k:l:p:")
    except getopt.GetoptError:
        usage(prog)
        return 255
    if len(args) != 2:
        usage(prog)
        return 255

    for opt, optarg in opts:
        if opt == "-c":
            authmode = optarg
        elif opt == "-k":
            authmode = "pki"
            keyfile = optarg
        elif opt == "-l":
            user = optarg
        elif opt == "-p":
            port = optarg

    try:
        target_user, server, target_port = _parse_target(args[0])
    except ValueError:
        print("invalid target host specification")
        usage(prog)
        return 254

    # Parts given in the target take precedence over -l / -p.
    if target_user is not None:
        user = target_user
    if target_port is not None:
        port = target_port

    try:
        port = int(port)
    except (TypeError, ValueError):
        print("invalid port specification")
        usage(prog)
        return 253

    mountpoint = args[1]

    if user is None:
        # Neither $USER, -l nor the target supplied a user name.
        usage(prog)
        return 252

    # A "/" in the server name selects a UNIX-domain socket path.
    if "/" in server:
        sock = socket.socket(socket.AF_UNIX)
        connect = server
    else:
        sock = socket.socket(socket.AF_INET)
        connect = (server, port)
    try:
        sock.connect(connect)
    except socket.error as e:
        sock.close()
        # Not every socket error carries an (errno, message) pair.
        print("%s: %s" % (server, getattr(e, "strerror", None) or e))
        return 251

    try:
        fs = ClientFS(sock,
                user,
                authmode,
                keyfile,
                mountpoint)
        fs.main()
    except py9p.Error as e:
        print("connection error: %s" % (e))
        return 250
    except Exception as e:
        print(type(e))
        print("error: %s" % (e))
        return 249

if __name__ == "__main__":
    # Run the client; a non-None result from main() becomes the process
    # exit status. Unexpected exceptions are reported and then re-raised.
    try:
        status = main()
    except KeyboardInterrupt:
        print("interrupted")
    except EOFError:
        print("done")
    except Exception as e:
        print("unhandled exception: %s" % (e))
        raise
    else:
        if status is not None:
            sys.exit(status)
