#!/usr/bin/python3

from optparse import OptionParser
import sys
import os
import time
import socket
import struct
import json
import binascii
import re
import logging

try:
    import asyncio
except ImportError as e:
    # system doesn't provide asyncio module (eg. Python 3.3),
    # so use module bundled with silentarmy
    p = os.path.join(sys.path[0], 'thirdparty', 'asyncio')
    sys.path.insert(1, p) # make it the 2nd path
    import asyncio

# Console verbosity: 0 prints only regular output, a value above 0 enables
# verbose() and a value above 1 also enables very_verbose(). main() overwrites
# it with the number of -v/--verbose options given on the command line.
verbose_level = 0

def b2hex(b):
    '''Return the hexadecimal representation of the bytes object b as a str.

    Same result as bytes.hex() (Python 3.5 and later), but implemented with
    binascii.'''
    hex_bytes = binascii.hexlify(b)
    return str(hex_bytes, 'ascii')

def warn(msg):
    '''Write msg, followed by a newline, to standard error.'''
    line = msg + '\n'
    sys.stderr.write(line)

def fatal(msg):
    '''Write msg, followed by a newline, to standard error, then terminate
    the program with exit status 1 (by raising SystemExit).'''
    sys.stderr.write(''.join((msg, '\n')))
    raise SystemExit(1)

def verbose(msg):
    '''Print msg, but only when at least one level of verbosity is on.'''
    if verbose_level <= 0:
        return
    print(msg)

def very_verbose(msg):
    '''Print msg, but only when at least two levels of verbosity are on.'''
    if verbose_level <= 1:
        return
    print(msg)

def my_ensure_future(coro):
    '''Schedule the coroutine object coro on the current event loop and
    return the task wrapping it.'''
    return asyncio.get_event_loop().create_task(coro)

def parse_url(url):
    '''Split a pool url of the form "stratum+tcp://host:port" into its parts.

    The url may carry a trailing "#xnsub" or "/#xnsub" marker. Returns the
    tuple (host, port, xnsub) where port is an int and xnsub is True when the
    marker was present. Any malformed url ends the program through fatal().'''
    scheme = 'stratum+tcp://'
    marker = '#xnsub'
    if not url.startswith(scheme):
        fatal('Invalid stratum url: %s' % url)
    xnsub = url.endswith(marker)
    if xnsub:
        url = url[:-len(marker)].strip('/')
    url = url[len(scheme):]
    # the port is whatever follows the last colon
    (host, sep, port_text) = url.rpartition(':')
    if not sep:
        fatal('Invalid stratum url: %s' % url)
    if not host:
        fatal('Invalid host: %s' % host)
    try:
        port = int(port_text)
    except ValueError:
        fatal('Invalid port number: %s' % port_text)
    return (host, port, xnsub)

def decode_solver_line(line):
    '''Decode a line read from the solver. 3 types of lines exist:
    "sol: <job_id> <ntime> <nonce_rightpart> <solSize+sol>"
    "status: <nr_sols> <nr_shares>"
    "<arbitrary log message>"

    Returns one of these tuples, whose first element is always lowercase:
    ('sol', job_id, ntime, nonce_rightpart, sol)   all fields are str
    ('status', nr_sols, nr_shares)                 both counters are int
    ('msg', line)                                  anything else
    '''
    sol = re.match(
            r'(?i)^(sol): ([^ ]+) ([0-9a-f]{8}) ([0-9a-f]+) ([0-9a-f]+)$',
            line)
    if sol is not None:
        # The pattern is case-insensitive, so the captured tag may be "SOL"
        # or "Sol"; callers compare against 'sol', so report a fixed tag
        # instead of echoing the captured text.
        return ('sol',) + sol.groups()[1:]
    status = re.match(r'(?i)^status: ([0-9]+) ([0-9]+)$', line)
    if status is not None:
        return ('status', int(status.group(1)), int(status.group(2)))
    return ('msg', line)

#
# StratumClientProtocol
#
class StratumClientProtocol(asyncio.Protocol):
    '''asyncio protocol implementation of stratum.

    Bytes received from the pool are accumulated and split on newlines; each
    complete line is handed to Silentarmy.process_incoming_msg() and whatever
    message it returns is written back to the pool.'''

    def __init__(self, silentarmy):
        '''silentarmy is the instance of Silentarmy'''
        self.sa = silentarmy
        # bytes received so far that do not yet form a complete line
        self.read_buffer = b''
        # transport of the current connection, None until connection_made()
        self.transport = None

    #
    # asyncio callbacks
    #
    def connection_made(self, transport):
        '''Remember the transport and start the stratum handshake by sending
        mining.subscribe.'''
        self.transport = transport
        # The same protocol object can be handed to several successive
        # connections, so discard any partial line left over from a previous
        # one; otherwise it would be glued to the first message received on
        # this connection.
        self.read_buffer = b''
        verbose('Successfully connected to %s:%s' %
                (self.sa.host, self.sa.port))
        self.do_send(self.sa.stratum_msg('mining.subscribe'))
        self.sa.st_state = 'SENT_SUBSCRIBE'

    def data_received(self, data):
        '''Buffer data and process every complete (newline-terminated)
        message it contains, sending back any reply.'''
        very_verbose('From stratum server: %s' % (repr(data)))
        self.read_buffer += data
        # process 1 or more messages
        while True:
            i = self.read_buffer.find(b'\n')
            if i == -1:
                # the rest (if any) is an incomplete message; keep it
                break
            tosend = self.sa.process_incoming_msg(self.read_buffer[:i])
            self.read_buffer = self.read_buffer[i + 1:]
            if tosend is not None:
                self.do_send(tosend)

    def connection_lost(self, exc):
        '''Report why the connection ended (exc is None when it was closed
        without error) and schedule a reconnection.'''
        if exc is None:
            print('Stratum: connection was closed (invalid user/pwd?)')
        else:
            print('Stratum: lost connection: %s' % exc)
        my_ensure_future(self.sa.reconnect())

    #
    # other methods
    #
    def do_send(self, data):
        '''Write the bytes object data to the pool.'''
        very_verbose('To stratum server: %s' % (repr(data)))
        self.transport.write(data)

#
# Silentarmy
#
class Silentarmy:
    '''Silentarmy Zcash miner'''

    def __init__(self, opts):
        '''Create the miner state. opts holds the parsed command-line
        options. Nothing is connected or launched here.'''
        self.opts = opts
        # Pool address and event loop; all three are filled in by init()
        self.loop = None
        self.host = None
        self.port = None

        # Solvers: running solver processes keyed by device ID, and the path
        # of the solver executable (next to this script)
        solver_path = os.path.join(sys.path[0], 'sa-solver')
        self.solver_binary = solver_path
        self.solver_procs = {}

        # Stratum connection state
        self.st_protocol = StratumClientProtocol(self)
        self.st_transport = None
        self.st_state = 'DISCONNECTED'
        self.st_conn_attempt = 0
        self.st_had_job = False
        self.st_extranonce = False
        self.st_id = 0
        self.st_expected_id = None
        self.st_accepted = 0

        # Current Equihash work
        self.job_id = None
        self.target = None
        self.nonce_leftpart = None
        self.zcash_nonceless_header = None

        # Per-device counters, keyed by device ID "<gpu>.<instance>". With
        #   gpu 0 instance 0 having found 100 solutions,
        #   gpu 0 instance 1 having found 200 solutions,
        #   gpu 1 instance 0 having found 300 solutions,
        # total_sols is { '0.0': 100, '0.1': 200, '1.0': 300 }.
        # total_shares uses the same keys.
        self.total_sols = {}
        self.total_shares = {}

    def init(self):
        '''Parse the pool url from the options, pick the event loop and, in
        debug mode, turn on asyncio debugging and DEBUG-level logging.'''
        self.host, self.port, self.st_extranonce = parse_url(self.opts.pool)
        if sys.platform != 'win32':
            self.loop = asyncio.get_event_loop()
        else:
            # ProactorEventLoop needed to support subprocesses with Python 3.5
            self.loop = asyncio.ProactorEventLoop()
            asyncio.set_event_loop(self.loop)
        if not self.opts.debug:
            return
        self.loop.set_debug(True)
        logging.basicConfig(level=logging.DEBUG)
        logging.getLogger('asyncio').setLevel(logging.DEBUG)

    @asyncio.coroutine
    def reconnect(self):
        '''Open a connection to the pool, using self.st_protocol as the
        protocol. Every attempt after the first one is preceded by a 1 sec
        pause; when an attempt fails, another one is scheduled.'''
        if self.st_conn_attempt > 0:
            # not the first attempt: pause before retrying
            yield from asyncio.sleep(1)
        self.st_conn_attempt += 1
        attempt = self.st_conn_attempt
        suffix = " (attempt %d)" % attempt if attempt > 1 else ""
        print("Connecting to %s:%d%s" % (self.host, self.port, suffix))
        self.st_state = 'CONNECTING'
        pending = self.loop.create_connection(lambda: self.st_protocol,
                self.host, self.port)
        try:
            (transport, _) = yield from pending
            self.st_transport = transport
        except Exception as err:
            print("Stratum: error connecting: %s" % err)
            my_ensure_future(self.reconnect())

    @asyncio.coroutine
    def show_stats(self):
        '''Print mining statistics once per second, forever.

        Each line shows the total sol/s rate, the per-GPU sol/s rates and the
        number of shares. Nothing is printed before the first job has been
        received or while no solver has reported a status.'''
        def gpus(last_sols, last_times, period_gpu):
            '''Return the per-GPU rates as a list of (gpuid, sol/s) pairs,
            for example [('0', 20.0), ('1', 200.0)]. last_sols and last_times
            hold the samples, most recent first.'''
            # compare against the sample taken period_gpu samples ago, or the
            # oldest one while fewer samples are available
            idx = min(period_gpu, len(last_sols) - 1)
            newest = last_sols[0]
            reference = last_sols[idx]
            devids = newest.keys()
            sols = {}
            for devid in devids:
                try:
                    sols[devid] = newest[devid] - reference[devid]
                except KeyError:
                    # XXX - ugly fix need to look into why stats are missing
                    sols[devid] = 0
            sec = last_times[0] - last_times[idx]
            # sols is keyed by device ID:
            #   { '0.0':10, '0.1':10, '1.0':100, '1.1':100 }
            # fold the instances of each GPU together:
            #   [ ('0', 20 / sec), ('1', 200 / sec) ]
            gpuids = {devid.split('.')[0] for devid in devids}
            result = []
            for gpuid in gpuids:
                prefix = '%s.' % gpuid
                found = sum([v for (k, v) in sols.items()
                    if k.startswith(prefix)])
                result.append((gpuid, found / sec if sec else 0))
            return result
        period_glo = 30 # total sol/s rate is computed over <period_glo> sec
        period_gpu = 10 # per-GPU sol/s rate is computed over <period_gpu> sec
        last_sols = []
        last_times = []
        while True:
            yield from asyncio.sleep(1)
            # stay quiet until a job was received and some stats exist
            if not (self.st_had_job and self.total_sols):
                continue
            last_sols.insert(0, self.total_sols.copy())
            last_times.insert(0, time.time())
            sols = sum(last_sols[0].values()) - sum(last_sols[-1].values())
            sec = last_times[0] - last_times[-1]
            rate_avg = "%.1f" % (sols / sec) if sec else 0.0
            if len(last_sols) > period_glo:
                # forget the oldest sample
                last_sols.pop()
                last_times.pop()
            rate_gpus = sorted(gpus(last_sols, last_times, period_gpu))
            info_gpus = ', '.join(['dev%s %.1f' % (gpuid, rate)
                for (gpuid, rate) in rate_gpus])
            shares = sum(self.total_shares.values())
            plural = '' if shares == 1 else 's'
            print("Total %s sol/s [%s] %d share%s" %
                    (rate_avg, info_gpus, shares, plural))
            sys.stdout.flush()

    def list_devices(self):
        '''Replace the current process with "sa-solver --list". Only returns
        control on failure, in which case the program ends via fatal().'''
        binary = self.solver_binary
        try:
            # on success the solver takes over this process
            os.execl(binary, binary, '--list')
        except FileNotFoundError:
            fatal(("Could not find '%s' binary; make sure to run 'make' to "
                "compile it") % binary)
        except Exception as err:
            fatal("Failed to execute '%s': %s" % (binary, err))

    def run(self):
        '''Entry point of the miner: either list the devices, or connect to
        the pool, launch one solver per (GPU, instance) pair and run the
        event loop until interrupted.'''
        if self.opts.do_list:
            self.list_devices()
        self.init()
        my_ensure_future(self.reconnect())
        my_ensure_future(self.show_stats())
        instances = self.opts.instances
        for gpuid in self.opts.use:
            for instid in range(instances):
                # device IDs have the form "<gpu>.<instance>"
                my_ensure_future(
                        self.start_solvers("%d.%d" % (gpuid, instid)))
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            print('\nQuitting')
            sys.exit(0)
        verbose('Closing event loop')
        self.loop.close()

    def cleanup_solvers(self, devid):
        '''Terminate a solver and clean up resources. This might be called for
        example when EOF is read from stdout.'''
        # wait for the process to end (sometimes EOF is received right before
        # Python has time to fill returncode with its status)
        yield from proc.wait()
        print('Solver %s: exit status %d' % (devid, proc.returncode))
        if devid in self.solver_procs:
            del self.solver_procs[devid]

    @asyncio.coroutine
    def start_solvers(self, devid):
        '''Launch the solver for device ID devid ("<gpu>.<instance>"), check
        its banner, then process the lines it prints until EOF: solutions are
        submitted to the pool, status lines update the per-device counters
        and anything else is treated as a log message.'''
        verbose('Solver %s: launching' % devid)
        gpuid = devid.split('.')[0]
        # execute "sa-solver --mining --use <id>"
        launch = asyncio.create_subprocess_exec(
                self.solver_binary, '--mining', '--use', gpuid,
                stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT)
        try:
            proc = yield from launch
        except FileNotFoundError:
            warn(("Could not find '%s' binary; make sure to run 'make' to "
                "compile it") % self.solver_binary)
            # exit without using sys.exit() because this raises SystemExit
            # and asyncio catches it and prints a stack trace which confuses
            # end-users.
            os._exit(1)
        except Exception as err:
            fatal("Failed to execute '%s': %s" % (self.solver_binary, err))
        self.solver_procs[devid] = proc

        # the first line printed by the solver must be its banner
        raw_banner = yield from proc.stdout.readline()
        if not raw_banner:
            print('Solver %s: EOF while reading banner' % devid)
            self.cleanup_solvers(devid)
            return
        banner = raw_banner.decode('ascii').rstrip()
        very_verbose('From solver %s: banner "%s"' % (devid, banner))
        if banner != "SILENTARMY mining mode ready":
            print('Solver %s: unexpected banner "%s"' % (devid, banner))
            proc.kill()
            self.cleanup_solvers(devid)
            return

        # jobs will be written to stdin by update_mining_job(), while this
        # code reads solutions from stdout
        while True:
            raw_line = yield from proc.stdout.readline()
            if not raw_line:
                # EOF: the solver closed its stdout
                self.cleanup_solvers(devid)
                return
            text = raw_line.decode('ascii').rstrip()
            very_verbose('From solver %s: %s' % (devid, text))
            decoded = decode_solver_line(text)
            kind = decoded[0]
            fields = decoded[1:]
            if kind == 'status':
                (nr_sols, nr_shares) = fields
                very_verbose('Solver %s: found %d sols %d shares so far' %
                        (devid, nr_sols, nr_shares))
                self.total_sols[devid] = nr_sols
                self.total_shares[devid] = nr_shares
            elif kind == 'sol':
                very_verbose('Solver %s: submitting share: %s' %
                        (devid, str(fields)))
                submit = self.stratum_msg('mining.submit', *fields)
                self.st_protocol.do_send(submit)
            elif kind == 'msg':
                (msg,) = fields
                very_verbose('Solver %s: reported: %s' % (devid, msg))
            else:
                fatal('Invalid solver line: %s' % repr(decoded))

    def update_mining_job(self):
        '''Called every time the miner receives a piece of data from the
        stratum server that might update the mining work.'''
        # In order to start mining, the client needs 4 things:
        # - have nonce_leftpart (result of mining.subscribe)
        # - be authorized (after mining.authorize)
        # - have a target (sent by mining.set_target)
        # - have a block header (sent by mining.notify)
        if self.nonce_leftpart is None:
            return
        if self.st_state != 'AUTHORIZED':
            return
        if self.target is None:
            return
        if self.zcash_nonceless_header is None:
            return
        for gpuid in self.opts.use:
            for instid in range(self.opts.instances):
                devid = "%d.%d" % (gpuid, instid)
                if devid not in self.solver_procs:
                    # happens if solver crashed
                    print('Solver %s: not running, relaunching it' % devid)
                    my_ensure_future(self.start_solvers(devid))
                    # TODO: ideally the mining job should be sent to the solver
                    # as soon as it is back up and running
        if not self.st_had_job:
            print('Stratum server sent us the first job')
            l = len(self.opts.use)
            print('Mining on %d device%s' % (l, '' if l == 1 else 's'))
            self.st_had_job = True
        job = "%s %s %s %s\n" % (b2hex(self.target), self.job_id,
                b2hex(self.zcash_nonceless_header),
                b2hex(self.nonce_leftpart))
        very_verbose('To solvers: %s' % job.rstrip())
        for devid in self.solver_procs:
            self.solver_procs[devid].stdin.write(job.encode('ascii'))

    def set_nonce_leftpart(self, n):
        '''Store the part of the nonce fixed by the stratum server. n is that
        part as a hex string. The program ends via fatal() when it is longer
        than 17 bytes.'''
        leftpart = bytes.fromhex(n)
        self.nonce_leftpart = leftpart
        nr_fixed = len(leftpart)
        very_verbose('Stratum server fixes %d bytes of the nonce' % nr_fixed)
        if nr_fixed <= 17:
            return
        # SILENTARMY requires the last 12 bytes to be zero, then 3 bytes
        # to vary the nonce, this leaves at most 17 bytes that can be
        # fixed by the server.
        fatal('Stratum: SILENTARMY is not compatible with servers ' +
                'fixing the first %d bytes of the nonce' % nr_fixed)

    def set_target(self, t):
        '''Store the target sent by the stratum server. t is the target as a
        string of 64 hex digits; anything else raises an Exception.'''
        verbose('Received target %s' % t)
        if re.match(r'^[0-9a-fA-F]{64}$', t) is None:
            raise Exception('Invalid target: %s' % t)
        had_target = self.target is not None
        # kept in internal byte order, i.e. reversed
        self.target = bytes.fromhex(t)[::-1]
        # only the very first target is applied *immediately*; later ones
        # take effect with the next job
        if not had_target:
            self.update_mining_job()

    def set_new_job(self, job_id, nversion, hash_prev_block, hash_merkle_root,
            hash_reserved, ntime, nbits, clean_jobs):
        '''Record the job announced by a mining.notify message.

        All header fields are hex strings. A job with a false clean_jobs is
        ignored. Raises an Exception when a field is invalid, in which case
        the previously stored job (job_id and header) is left untouched.'''
        verbose('Received job "%s"' % job_id)
        if not clean_jobs:
            verbose('Ignoring job "%s" (clean_jobs=False)' % job_id)
            return
        if nversion != '04000000':
            raise Exception('Invalid version: %s' % nversion)
        if not re.match(r'^[0-9a-fA-F]{64}$', hash_prev_block):
            raise Exception('Invalid hashPrevBlock: %s' % hash_prev_block)
        if not re.match(r'^[0-9a-fA-F]{64}$', hash_merkle_root):
            raise Exception('Invalid hashMerkleRoot: %s' % hash_merkle_root)
        if hash_reserved != '00' * 32:
            raise Exception('Invalid hashReserved: %s' % hash_reserved)
        if not re.match(r'^[0-9a-fA-F]{8}$', ntime):
            raise Exception('Invalid nTime: %s' % ntime)
        if not re.match(r'^[0-9a-fA-F]{8}$', nbits):
            raise Exception('Invalid nBits: %s' % nbits)
        header = bytes.fromhex(nversion + hash_prev_block +
                hash_merkle_root + hash_reserved + ntime + nbits)
        # Commit the job id and its header together, and only once everything
        # has been validated, so that a rejected job can never leave the new
        # job id paired with the header of the previous job.
        self.job_id = job_id
        self.zcash_nonceless_header = header

    def stratum_next_id(self):
        self.st_id += 1
        self.st_expected_id = self.st_id
        return self.st_id

    def stratum_msg(self, method, *args):
        '''Generate a stratum message to call the specified method.'''
        if method == 'mining.subscribe':
            p = ["silentarmy", None, self.host, str(self.port)]
        elif method == 'mining.extranonce.subscribe':
            p = []
        elif method == 'mining.authorize':
            if self.opts.pwd:
                p = [self.opts.user,self.opts.pwd]
            else:
                p = [self.opts.user,'']
        elif method == 'mining.submit':
            (job_id, ntime, nonce_rightpart, sol) = args
            p = [self.opts.user, job_id, ntime, nonce_rightpart, sol]
        else:
            fatal('Bug: unknown method %s' % method)
        msg_id = self.stratum_next_id()
        msg = json.dumps({'id':msg_id, 'method':method, 'params':p}) + '\n'
        return msg.encode('utf-8')

    def process_incoming_msg(self, msg):
        '''Process an incoming stratum message.
        Return None, or a message to send back.

        msg is one line received from the server, as bytes holding a JSON
        document. A message with a 'result' field is treated as the response
        to the last call made by the miner and drives the handshake:
        SENT_SUBSCRIBE -> (SENT_EXTRANONCE_SUBSCRIBE ->) SENT_AUTHORIZE ->
        AUTHORIZED. A message with a 'method' field is a call made by the
        server. Any exception raised while processing is caught and reported
        on stdout; None is returned in that case.'''
        try:
            # from here on msg is the decoded JSON value, not the raw bytes
            msg = json.loads(msg.decode())
            if 'id' not in msg:
                raise Exception("'id' field is missing")
            # server returning a method call result
            if 'result' in msg:
                if 'error' in msg and msg['error'] is not None:
                    # the error is only reported: the state machine does not
                    # advance and st_expected_id is left as it was
                    print("Stratum server returned an error: %s" % msg['error'])
                    return
                if msg['id'] != self.st_expected_id:
                    # XXX need to track which outstanding IDs we are waiting
                    # a response for
                    very_verbose("Stratum server returned wrong id: %s" % \
                            msg['id'])
                    # attempt to proceed and ignore this error
                self.st_expected_id = None
                if self.st_state == 'SENT_SUBSCRIBE':
                    # result: [ <ignored>, nonce_leftpart ]
                    self.set_nonce_leftpart(msg['result'][1])
                    if self.st_extranonce:
                        self.st_state = 'SENT_EXTRANONCE_SUBSCRIBE'
                        return self.stratum_msg('mining.extranonce.subscribe')
                    else:
                        self.st_state = 'SENT_AUTHORIZE'
                        return self.stratum_msg('mining.authorize')
                elif self.st_state == 'SENT_EXTRANONCE_SUBSCRIBE':
                    # ignore
                    self.st_state = 'SENT_AUTHORIZE'
                    return self.stratum_msg('mining.authorize')
                elif self.st_state == 'SENT_AUTHORIZE':
                    # result: succeeded
                    if not msg['result']:
                        raise Exception('mining.authorize failed')
                    self.st_state = 'AUTHORIZED'
                    # being authorized may be the last thing needed to mine
                    self.update_mining_job()
                elif self.st_state == 'AUTHORIZED':
                    # result: succeeded
                    # NOTE(review): every error-free result received in this
                    # state is counted as an accepted share; the value of
                    # msg['result'] is not checked - confirm this is intended
                    very_verbose("Stratum server accepted a share")
                    self.st_accepted += 1
                else:
                    fatal('Bug: unknown state %s' % self.st_state)
            # server calling a method
            elif 'method' in msg:
                if msg['method'] == 'mining.set_target':
                    # params: [ target ]
                    self.set_target(msg['params'][0])
                elif msg['method'] == 'mining.set_extranonce':
                    # params: [ nonce_leftpart ]
                    self.set_nonce_leftpart(msg['params'][0])
                elif msg['method'] == 'mining.notify':
                    # params: [ job_id, nVersion, hashPrevBlock, hashMerkleRoot,
                    #   hashReserved, nTime, nBits, clean_jobs ]
                    self.set_new_job(*msg['params'][:8])
                    self.update_mining_job()
                elif msg['method'] == 'client.reconnect':
                    print("Stratum server forcing a reconnection")
                    self.st_transport.close()
                    # reconnection will happen automatically in connection_lost()
                else:
                    raise Exception('Unimplemented method: %s' % msg['method'])
            else:
                raise Exception('Message is neither a result nor a method call')
        except Exception as e:
            # msg is the raw bytes if decoding failed, else the decoded value
            print('Stratum: invalid msg from server: %s: %s\n' % (e, msg))
        return None

#
# Main
#
def _build_option_parser():
    '''Return the OptionParser describing the command-line options.'''
    parser = OptionParser()
    parser.add_option(
            "-v", "--verbose",
            dest="verbose", action="count", default=0,
            help="verbose mode (may be repeated for more verbosity)")
    parser.add_option(
            "--debug",
            dest="debug", action="store_true",
            help="enable debug mode (for developers only)")
    parser.add_option(
            "--list",
            dest="do_list", action="store_true",
            help="list available OpenCL devices by ID (GPUs...)")
    parser.add_option(
            "--use",
            dest="use", action="store", type="string", metavar="LIST",
            default='0',
            help="use specified GPU device IDs to mine, for example to use " +
            "the first three: 0,1,2 (default: 0)")
    parser.add_option(
            "--instances",
            dest="instances", action="store", type="int", metavar="N",
            default=2,
            help="run N instances of Equihash per GPU (default: 2)")
    parser.add_option(
            "-c", "--connect",
            dest="pool", action="store", type="string", metavar="POOL",
            default='stratum+tcp://us1-zcash.flypool.org:3333',
            help="connect to POOL, for example stratum+tcp://example.com:1234" +
            " (add \"#xnsub\" to enable extranonce.subscribe)")
    parser.add_option(
            "-u", "--user",
            dest="user", action="store", type="string", metavar="USER",
            default="t1cVviFvgJinQ4w3C2m2CfRxgP5DnHYaoFC",
            help="username for connecting to the pool")
    parser.add_option(
            "-p", "--pwd",
            dest="pwd", action="store", type="string", metavar="PWD",
            help="password for connecting to the pool")
    return parser

def main():
    '''Parse and validate the command line, then run the miner.'''
    global verbose_level
    parser = _build_option_parser()
    opts, args = parser.parse_args()
    if args:
        parser.error("Extraneous arguments found on command line")
    if not opts.do_list and not opts.pool:
        parser.error("No pool was specified; use --connect")
    verbose_level = opts.verbose
    # --use is a comma-separated list of integer GPU IDs; keep it as a set
    try:
        opts.use = {int(gpuid) for gpuid in opts.use.split(',')}
    except Exception:
        fatal("Invalid syntax for --use: %s" % opts.use)
    if opts.instances < 1:
        fatal("The number of instances per GPU should be 1 or greater")
    miner = Silentarmy(opts)
    miner.run()

# Start the miner only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
