#!/usr/bin/python

import socket
import sys
import os
import fuse
import stat
import errno
import time
import getopt

from py9p import marshal9p
from py9p import py9p

# Fid numbers are split into two disjoint, inclusive ranges.
# MIN_TFID..MAX_TFID feeds ClientFS.tfidcache (the fid each guarded
# operation receives as ``tfid``); MIN_FID..MAX_FID is the default range of
# FidCache and feeds ClientFS.fidcache (fids acquired by open()).
MIN_TFID = 64
MAX_TFID = 1023
MIN_FID = 1024
MAX_FID = 65535


# Fatal error table used by paluu(): name -> (process exit status, message).
# Entries whose status is above 200 additionally trigger the usage text.
errcodes = {
    "usage": (255, ""),
    "host": (254, "invalid host specification"),
    "port": (253, "invalid port specification"),
    "timeout": (252, "invalid timeout specification"),
    "key": (155, "key decryption error, probably bad password "
                 "or wrong keyfile"),
    "socket": (154, "socket error"),
    "9connect": (153, "9p server connection error"),
    "undef": (100, "error"),
}

# 9P error string -> negative errno returned to FUSE (looked up by guard()).
rpccodes = dict.fromkeys(
    ("duplicate fid", "unknown fid"),
    -errno.EBADFD)
rpccodes.update(dict.fromkeys(
    ("create prohibited",
     "remove prohibited",
     "stat prohibited",
     "wstat prohibited",
     "permission denied"),
    -errno.EPERM))


def usage():
    """
    Print the command line synopsis and the option summary to stdout.
    """
    lines = (
        "",
        "Usage: fuse9p [-d] [-c mode] [-k file] [-l user] [-p port] "
        "[-t secs] user@server:port mountpoint",
        "",
        " -c mode  -- authentication mode to use (none|pki)",
        " -d       -- turn on debug mode and run in foreground",
        " -k file  -- path to the private RSA key for PKI (implies -c pki)",
        " -l user  -- username to use in authentication",
        " -p port  -- TCP port to use",
        " -t secs  -- timeout for the socket",
        "    ",
    )
    print("\n".join(lines))


def paluu(code, payload=None):
    """
    Report a fatal error and terminate the process.

    ``code`` is a key of ``errcodes``; its message is printed and its
    number becomes the exit status.  Statuses above 200 also print the
    usage text.  ``payload``, when given, is printed as extra detail
    (callers pass the causing exception).

    Never returns: always ends in ``sys.exit``.
    """
    status, message = errcodes[code]
    print(message)
    if status > 200:
        usage()
    if payload is not None:
        print(str(payload))
    sys.exit(status)


class Error(py9p.Error):
    """Error type of this module; a plain subclass of ``py9p.Error``."""


class NoFidError(Exception):
    """Raised by ``FidCache.acquire`` when the pool has no free fid left."""


# Select the fuse-python API revision this module is written against.
fuse.fuse_python_api = (0, 2)


def open2stat(mode):
    """
    Translate a 9P open mode into the flag layout produced by shifting
    the py9p bits back down (the inverse shifts of ``open2plan``).

    The two lowest bits (access mode) are copied as they are; the
    OAPPEND, OEXCL and OTRUNC bits are moved by 4, 5 (right) and
    5 (left) positions respectively.
    """
    flags = mode & 3
    flags |= (mode & py9p.OAPPEND) >> 4
    flags |= (mode & py9p.OEXCL) >> 5
    flags |= (mode & py9p.OTRUNC) << 5
    return flags


def open2plan(mode):
    """
    Translate ``os.O_*`` open flags into a 9P open mode.

    The two lowest bits (access mode) are copied as they are; the
    O_APPEND, O_EXCL and O_TRUNC bits are moved by 4, 5 (left) and
    5 (right) positions respectively, mirroring ``open2stat``.  All
    other flags are dropped.

    NOTE(review): the fixed shifts presumably line up with the py9p
    constants only for the Linux values of the ``os.O_*`` flags.
    """
    plan = mode & 3
    plan |= (mode & os.O_APPEND) << 4
    plan |= (mode & os.O_EXCL) << 5
    plan |= (mode & os.O_TRUNC) >> 5
    return plan


def mode2stat(mode):
    """
    Convert a 9P mode word into a POSIX ``st_mode`` value.

    The nine permission bits are copied.  The remaining bits are built
    by shifting py9p flags into place: the directory flag (or, when it
    is absent, its complement) selects the file-type bit, the symlink
    flag contributes two bits, and the setuid/setgid/sticky flags are
    moved next to the permission bits.
    """
    dir_flag = mode & py9p.DMDIR
    link_flag = mode & py9p.DMSYMLINK

    result = mode & 0o777
    # Non-directories: DMDIR is clear, so the XOR yields the whole DMDIR
    # bit, which lands on the regular-file type bit after the shift.
    result |= (dir_flag ^ py9p.DMDIR) >> 16
    result |= dir_flag >> 17
    result |= link_flag >> 10
    result |= link_flag >> 12
    result |= (mode & py9p.DMSETUID) >> 8
    result |= (mode & py9p.DMSETGID) >> 8
    result |= (mode & py9p.DMSTICKY) >> 7
    return result


def mode2plan(mode):
    """
    Convert a POSIX ``st_mode`` value into a 9P mode word.

    The nine permission bits are copied; the S_IFDIR bit is shifted up
    by 17, S_ISUID and S_ISGID by 8, S_ISVTX by 7, and symbolic links
    set bit 25.

    The symlink test uses ``stat.S_ISLNK`` so that it looks at the file
    type field only.  Comparing the whole mode with ``stat.S_IFLNK``
    matched nothing but a link with all permission bits cleared.
    """
    plan = mode & 0o777
    plan |= (mode & stat.S_IFDIR) << 17
    plan |= (mode & stat.S_ISUID) << 8
    plan |= (mode & stat.S_ISGID) << 8
    plan |= (mode & stat.S_ISVTX) << 7
    if stat.S_ISLNK(mode):
        plan |= 1 << 25
    return plan


class fStat(fuse.Stat):
    """
    FUSE stat structure, that will represent PyVFS Inode
    """
    def __init__(self, inode):
        """
        Fill the stat fields from ``inode``, a 9P stat record.

        The fields read are ``mode``, ``length``, ``uidnum``, ``gidnum``,
        ``atime`` and ``mtime``.  9P has no separate change time, so
        ``st_ctime`` repeats ``mtime``.
        """
        self.st_mode = mode2stat(inode.mode)
        self.st_ino = 0
        self.st_dev = 0
        # inode.mode is a 9P mode word, so the directory test has to use
        # the 9P flag; masking it with the POSIX stat.S_IFDIR constant
        # looked at an unrelated bit.
        if inode.mode & py9p.DMDIR:
            # Directories report their length as the link count, but never
            # less than one: a zero-length directory keeps the value 1
            # that every entry got before.
            self.st_nlink = max(inode.length, 1)
        else:
            self.st_nlink = 1
        self.st_uid = inode.uidnum
        self.st_gid = inode.gidnum
        self.st_size = inode.length
        self.st_atime = inode.atime
        self.st_mtime = inode.mtime
        self.st_ctime = inode.mtime


def guard(c):
    """
    Decorator for ClientFS operations that talk to the 9P server.

    The wrapped method is called as ``c(self, tfid, *args, **kwargs)``
    where ``tfid`` is a fid number taken from ``self.tfidcache`` for the
    duration of the call.  The fid is given back to the pool on every
    exit path, including error returns and retries.

    Return value:
      * whatever ``c`` returns on success;
      * ``-errno.EMFILE`` when a fid pool is exhausted (NoFidError);
      * for ``py9p.RpcError`` the matching entry of ``rpccodes``, or
        ``-errno.EIO`` when the message is not listed.

    Any other exception is treated as a broken connection:
    ``self._reconnect()`` is called and the operation is retried.  After
    three failed attempts the process is terminated via
    ``paluu("socket")``.

    Note: if ``c`` is a generator function, calling it only creates the
    generator; its body runs after the fid was released and outside of
    the handlers below.
    """
    def wrapped(self, *argv, **kwarg):
        tries = 0
        while True:
            if tries > 2:
                paluu("socket")
            tries += 1
            try:
                tfid = self.tfidcache.acquire()
            except NoFidError:
                return -errno.EMFILE
            try:
                return c(self, tfid.fid, *argv, **kwarg)
            except NoFidError:
                return -errno.EMFILE
            except py9p.RpcError as e:
                # The operation may have failed half-way with the fid
                # still bound on the server.  Clunk it (best effort; an
                # unbound fid just yields another error) so that it is
                # safe to hand it out again.
                try:
                    self.client._clunk(tfid.fid)
                except Exception:
                    pass
                # Exceptions carry ``.message`` on Python 2 only.
                message = getattr(e, "message", None)
                if message is None and e.args:
                    message = e.args[0]
                return rpccodes.get(message, -errno.EIO)
            except Exception:
                # Deliberately not a bare ``except``: SystemExit and
                # KeyboardInterrupt must not trigger a reconnect.
                self._reconnect()
            finally:
                self.tfidcache.release(tfid)
    return wrapped


class FidCache(dict):
    """
    Pool of free fid numbers covering the inclusive range
    ``start``..``limit``.  Fids are handed out from the front of
    ``self.fids`` and returned to its end.
    """
    def __init__(self, start=MIN_FID, limit=MAX_FID):
        super(FidCache, self).__init__()
        self.start = start
        self.limit = limit
        self.fids = list(range(start, limit + 1))

    def acquire(self):
        """
        Take the next free fid and return it wrapped in a ``Fid``.
        Raises NoFidError when the pool is empty.
        """
        if not self.fids:
            raise NoFidError()
        number = self.fids.pop(0)
        return Fid(number)

    def release(self, f):
        """
        Put the number of ``f`` (an object with a ``fid`` attribute)
        back into the pool.
        """
        self.fids.append(f.fid)


class Fid(object):
    """Holder of one fid number, available as the ``fid`` attribute."""

    def __init__(self, fid):
        self.fid = fid


class ClientFS(fuse.Fuse):
    """
    FUSE filesystem that forwards every operation to a 9P server
    through a ``py9p.Client``.

    Methods decorated with ``@guard`` are called by FUSE without the
    ``tfid`` argument; the decorator supplies a fid number from
    ``self.tfidcache`` and translates exceptions into errno values.
    Fids of files kept open come from ``self.fidcache`` instead.
    """
    def __init__(self, address, credentials, mountpoint,
            debug=False, timeout=10):
        """
        Connect to the server and prepare the FUSE arguments.

        ``address`` is a ``(host, port)`` pair; a host containing "/" is
        taken as the path of a UNIX socket.  ``timeout`` is the socket
        timeout in seconds.  ``debug`` keeps FUSE in the foreground and
        enables chatty output.
        """

        self.address = address
        self.credentials = credentials
        self.debug = debug
        self.timeout = timeout
        self.sock = None
        self._reconnect()
        self.dircache = {}
        self.fidcache = FidCache()
        self.tfidcache = FidCache(start=MIN_TFID, limit=MAX_TFID)

        fuse.Fuse.__init__(self, version="%prog " + fuse.__version__,
                dash_s_do='undef')

        if debug:
            self.fuse_args.setmod('foreground')
            self.fuse_args.add('debug')
        self.fuse_args.mountpoint = os.path.realpath(mountpoint)

    @staticmethod
    def _names(path):
        """
        Split ``path`` on "/" and return the non-empty elements as a
        real list, so that it can be sliced and measured on Python 3 as
        well (``filter()`` returns an iterator there).
        """
        return [name for name in path.split("/") if name]

    @staticmethod
    def _rpc_message(error):
        """
        Return the message of an RPC error.  ``.message`` exists on
        Python 2 exceptions only; fall back to the first argument.
        """
        message = getattr(error, "message", None)
        if message is None and error.args:
            message = error.args[0]
        return message

    def _clunk_quietly(self, fid):
        """Clunk ``fid`` on the server, ignoring any failure."""
        try:
            self.client._clunk(fid)
        except Exception:
            pass

    def _reconnect(self):
        """
        Drop the current socket (if any), open a new one and create a
        fresh ``py9p.Client`` on top of it.  A failed connect terminates
        the process through ``paluu("socket", ...)``.
        """
        old = getattr(self, "sock", None)
        if old is not None:
            try:
                old.close()
            except Exception:
                # The old socket is being thrown away anyway.
                pass

        if self.address[0].find("/") > -1:
            self.sock = socket.socket(socket.AF_UNIX)
            # AF_UNIX sockets connect to a path, not to a (host, port).
            target = self.address[0]
        else:
            self.sock = socket.socket(socket.AF_INET)
            target = self.address
        # Honour the timeout given to the constructor.
        self.sock.settimeout(self.timeout)
        try:
            self.sock.connect(target)
        except Exception as e:
            paluu("socket", e)
        self.client = py9p.Client(
                fd=self.sock,
                chatty=self.debug,
                credentials=self.credentials,
                dotu=1)

    @guard
    def open(self, tfid, path, mode):
        """
        Walk a fid from ``self.fidcache`` to ``path`` and open it.
        Returns the ``Fid`` object, which FUSE passes back as ``f`` to
        read/write/release.  On failure the fid is returned to the pool
        and the exception is re-raised for ``guard`` to translate.
        """
        f = self.fidcache.acquire()
        try:
            self.client._walk(self.client.ROOT,
                    f.fid, self._names(path))
            self.client._open(f.fid, open2plan(mode))
        except Exception as e:
            if isinstance(e, py9p.RpcError):
                # The walk may have succeeded before the open failed.
                self._clunk_quietly(f.fid)
            self.fidcache.release(f)
            raise
        return f

    @guard
    def _wstat(self, tfid, path,
            uid=py9p.ERRUNDEF,
            gid=py9p.ERRUNDEF,
            mode=py9p.ERRUNDEF):
        """
        Send a wstat for ``path``.  ``uid``, ``gid`` and ``mode`` default
        to ``py9p.ERRUNDEF``; atime and mtime are always set to the
        current time.
        """
        self.client._walk(self.client.ROOT,
                tfid, self._names(path))
        now = int(time.time())
        stats = [py9p.Dir(
            dotu=1,
            type=0,
            dev=0,
            qid=py9p.Qid(0, 0, py9p.hash8(path)),
            mode=mode,
            atime=now,
            mtime=now,
            length=py9p.ERRUNDEF,
            name=path.split("/")[-1],
            uid="",
            gid="",
            muid="",
            extension="",
            uidnum=uid,
            gidnum=gid,
            muidnum=py9p.ERRUNDEF), ]
        self.client._wstat(tfid, stats)
        self.client._clunk(tfid)

    def chmod(self, path, mode):
        """Change the mode of ``path`` through ``_wstat``."""
        return self._wstat(path, mode=mode2plan(mode))

    def chown(self, path, uid, gid):
        """Change the numeric owner and group of ``path``."""
        return self._wstat(path, uid, gid)

    def utime(self, path, times):
        """Accepted but ignored: nothing is sent to the server."""
        pass

    @guard
    def unlink(self, tfid, path):
        """Walk to ``path`` and remove it."""
        self.client._walk(self.client.ROOT,
                tfid, self._names(path))
        self.client._remove(tfid)

    def rmdir(self, path):
        """Remove a directory; same request as ``unlink``."""
        # Propagate the result, otherwise every failure reads as success.
        return self.unlink(path)

    @guard
    def symlink(self, tfid, target, path):
        """Create the symbolic link ``path`` pointing to ``target``."""
        names = self._names(path)
        self.client._walk(self.client.ROOT, tfid, names[:-1])
        self.client._create(tfid, names[-1],
                py9p.DMSYMLINK, 0, target)
        self.client._clunk(tfid)

    @guard
    def mknod(self, tfid, path, mode, dev):
        """
        Create ``path``.  Device nodes (``dev != 0``) are not supported.
        An existing file is opened with OTRUNC instead; only when the
        server answers "file not found" is the entry created in its
        parent directory.
        """
        if dev != 0:
            return -errno.ENOSYS
        # FIXME
        if not mode & stat.S_IFREG:
            mode |= stat.S_IFDIR
        names = self._names(path)
        try:
            self.client._walk(self.client.ROOT, tfid, names)
            self.client._open(tfid, py9p.OTRUNC)
            self.client._clunk(tfid)
        except py9p.RpcError as e:
            message = self._rpc_message(e)
            if message != "file not found":
                # The walk may have bound the fid before the failure.
                self._clunk_quietly(tfid)
                return rpccodes.get(message, -errno.EIO)
            self.client._walk(self.client.ROOT, tfid, names[:-1])
            self.client._create(tfid, names[-1], mode2plan(mode), 0)
            self.client._clunk(tfid)

    def mkdir(self, path, mode):
        """Create a directory through ``mknod``."""
        return self.mknod(path, mode | stat.S_IFDIR, 0)

    @guard
    def truncate(self, tfid, path, size):
        """Truncate ``path`` to zero length; other sizes give ENOSYS."""
        if size != 0:
            return -errno.ENOSYS
        self.client._walk(self.client.ROOT,
                tfid, self._names(path))
        self.client._open(tfid, py9p.OTRUNC)
        self.client._clunk(tfid)

    @guard
    def write(self, tfid, path, buf, offset, f):
        """
        Write ``buf`` at ``offset`` through the open fid ``f`` in pieces
        of at most 8192 bytes.  Returns ``len(buf)``.

        NOTE(review): the byte count acknowledged by the server is not
        checked — confirm whether py9p exposes it.
        """
        # Each piece is sliced by start/end offsets.  Using the piece
        # length as the end index produced empty payloads for everything
        # after the first 8192 bytes.
        for start in range(0, len(buf), 8192):
            self.client._write(f.fid, offset + start,
                    buf[start:start + 8192])
        return len(buf)

    @guard
    def read(self, tfid, path, size, offset, f):
        """
        Read up to ``size`` bytes at ``offset`` through the open fid
        ``f``.  Requests are at most 8192 bytes each and are repeated
        until ``size`` bytes arrived or the server returns no data.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            ret = self.client._read(f.fid, offset, min(remaining, 8192))
            if not ret.data:
                break
            chunks.append(ret.data)
            offset += len(ret.data)
            remaining -= len(ret.data)
        return b"".join(chunks)[:size]

    @guard
    def release(self, tfid, path, flags, f):
        """
        Close the open fid ``f``: clunk it on the server (best effort)
        and return its number to ``self.fidcache`` in any case.
        """
        if getattr(f, "fid", None) is None:
            return
        self._clunk_quietly(f.fid)
        self.fidcache.release(f)

    @guard
    def readlink(self, tfid, path):
        """Return the first (up to 8192) bytes read from ``path``."""
        self.client._walk(self.client.ROOT,
                tfid, self._names(path))
        self.client._open(tfid, py9p.OREAD)
        ret = self.client._read(tfid, 0, 8192)
        self.client._clunk(tfid)
        return ret.data

    @guard
    def getattr(self, tfid, path):
        """
        Return an ``fStat`` for ``path``, ``-errno.ENOENT`` when the
        server reports "file not found", or the ``rpccodes`` entry
        (default ``-errno.EIO``) for other RPC errors.

        NOTE(review): the cache is looked up by ``py9p.hash8(path)`` but
        filled by ``readdir`` under ``qid.path`` and never invalidated —
        confirm both keys agree for the servers in use.
        """
        key = py9p.hash8(path)
        if key in self.dircache:
            return fStat(self.dircache[key])
        try:
            self.client._walk(self.client.ROOT,
                    tfid, self._names(path))
            ret = self.client._stat(tfid).stat[0]
        except py9p.RpcError as e:
            message = self._rpc_message(e)
            if message == "file not found":
                return -errno.ENOENT
            # The walk may have bound the fid before the stat failed.
            self._clunk_quietly(tfid)
            return rpccodes.get(message, -errno.EIO)
        self.client._clunk(tfid)
        return fStat(ret)

    @guard
    def readdir(self, tfid, path, offset):
        """
        Read the whole directory ``path`` and return a list of
        ``fuse.Direntry`` objects.  The ``offset`` argument is ignored;
        the directory is always read from the beginning.  Every stat
        record is also stored in ``self.dircache`` under its
        ``qid.path``.

        A list is returned instead of yielding, so that all RPCs run
        inside ``guard`` while the transient fid is still reserved.
        """
        self.client._walk(self.client.ROOT,
                tfid, self._names(path))
        self.client._open(tfid, py9p.OREAD)
        position = 0
        entries = []
        while True:
            ret = self.client._read(tfid, position, 8192)
            if len(ret.data) == 0:
                break
            position += len(ret.data)
            p9 = marshal9p.Marshal9P(dotu=1)
            p9.setBuf(ret.data)
            fcall = py9p.Fcall(py9p.Rstat)
            p9.decstat(fcall, 0)
            entries.extend(fcall.stat)
        self.client._clunk(tfid)

        for entry in entries:
            self.dircache[entry.qid.path] = entry
        return [fuse.Direntry(entry.name) for entry in entries]


if __name__ == "__main__":
    # Command line entry point: parse the options, build the credentials
    # and mount the remote tree.  Every validation failure goes through
    # paluu(), which prints a diagnostic and terminates the process.
    # Checks are explicit instead of ``assert`` statements, because
    # asserts are stripped when Python runs with -O.
    port = py9p.PORT
    user = os.environ.get('USER', None)
    authmode = None
    keyfile = None
    debug = False
    timeout = 10

    try:
        opts, args = getopt.getopt(sys.argv[1:], "dc:k:l:p:t:")
    except getopt.GetoptError:
        paluu("usage")
    if len(args) != 2:
        paluu("usage")

    for opt, optarg in opts:
        if opt == "-c":
            authmode = optarg
        elif opt == "-d":
            debug = True
        elif opt == "-k":
            # A key file implies PKI authentication.
            authmode = "pki"
            keyfile = optarg
        elif opt == "-l":
            user = optarg
        elif opt == "-p":
            port = optarg
        elif opt == "-t":
            timeout = optarg

    # The target has the form [user@]server[:port].  Splitting on "@"
    # first and on ":" second keeps "user@server" (no port) from being
    # read as "server:port".
    login, at, hostport = args[0].rpartition("@")
    server, colon, target_port = hostport.partition(":")
    if not server or (at and not login) or "@" in login:
        paluu("host")
    if at:
        # A user given in the target overrides -l and $USER.
        user = login
    if colon:
        # A port given in the target overrides -p.
        port = target_port

    try:
        port = int(port)
    except (TypeError, ValueError):
        paluu("port")

    mountpoint = args[1]

    try:
        timeout = int(timeout)
    except (TypeError, ValueError):
        paluu("timeout")

    if user is None:
        paluu("usage")

    try:
        credentials = py9p.Credentials(user, authmode, "", keyfile)
    except Exception:
        paluu("key")

    try:
        fs = ClientFS((server, port),
                credentials,
                mountpoint,
                debug,
                timeout)
        fs.main()
    except py9p.Error as e:
        paluu("9connect", e)
    except Exception as e:
        paluu("undef", e)
