#! /usr/bin/env python3
#
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import argparse
import hashlib
import logging
import os
import pprint
import sys
import threading
import time
import warnings
warnings.simplefilter("default")

# Use tqdm for progress reporting when it is installed; otherwise fall back
# to a silent stand-in so the tool has no hard dependency on tqdm.
try:
    import tqdm
    ProgressBar = tqdm.tqdm
except ImportError:
    class ProgressBar(object):
        """No-op replacement for ``tqdm.tqdm`` used when tqdm is unavailable.

        Only the parts of the tqdm interface used by this script are
        provided: construction with arbitrary arguments, use as a context
        manager, and ``update()``.
        """

        def __init__(self, *args, **kwargs):
            # All arguments (e.g. ``total=``) are accepted and ignored.
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            # Returns None, so exceptions raised inside the ``with`` body
            # propagate normally.
            pass

        def update(self, n=1):
            """Accept (and ignore) a progress increment.

            ``n`` mirrors the optional increment taken by
            ``tqdm.tqdm.update`` so both implementations can be called the
            same way.
            """
            pass

# Make the 'lib' directory that sits next to this script's directory
# importable.  __file__ is made absolute first: on interpreters where the
# main script's __file__ is the (possibly relative) path given on the command
# line, a bare file name would otherwise resolve to a 'lib' directory relative
# to the current working directory.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'lib'))

import hashserv

# Server address used when --address is not given.
DEFAULT_ADDRESS = 'unix://./hashserve.sock'
# Method name sent with every stress-test request.
METHOD = 'stress.test.method'


def main():
    """Parse the command line and run the selected client sub-command.

    Two sub-commands are available: ``stats`` (print, and optionally reset,
    the server statistics) and ``stress`` (issue many unihash queries from
    several concurrent clients and print timing figures).

    Returns:
        The exit status for the process: the value returned by the
        sub-command handler, or 0 when no sub-command was given.

    Raises:
        ValueError: if ``--log`` does not name an integer level constant of
            the ``logging`` module.
    """
    def seeded_digest(seed, index):
        """Return the SHA-256 hex digest of ``seed`` followed by ``index``."""
        digest = hashlib.sha256()
        digest.update(seed.encode('utf-8'))
        digest.update(str(index).encode('utf-8'))
        return digest.hexdigest()

    def handle_stats(args, client):
        """Pretty-print the result of ``get_stats()``, or of
        ``reset_stats()`` when ``--reset`` was given.  Returns 0."""
        if args.reset:
            s = client.reset_stats()
        else:
            s = client.get_stats()
        pprint.pprint(s)
        return 0

    def handle_stress(args, client):
        """Run the stress test and print a summary.  Returns 0.

        ``args.clients`` threads are started; each opens its own connection
        and performs ``args.requests`` ``get_unihash`` queries.  When
        ``--report`` is given, ``client`` (the connection opened by
        ``main``) is afterwards used to report one hash per request index.
        """
        def thread_main(pbar, lock):
            """Worker body: query the server and update the shared counters."""
            nonlocal found_hashes
            nonlocal missed_hashes
            nonlocal max_time

            # Every worker uses its own connection rather than sharing the
            # one opened by main().
            with hashserv.create_client(args.address) as thread_client:
                for i in range(args.requests):
                    taskhash = seeded_digest(args.taskhash_seed, i)

                    # Only the request itself is timed, not the hashing or
                    # the bookkeeping below.
                    start_time = time.perf_counter()
                    unihash = thread_client.get_unihash(METHOD, taskhash)
                    elapsed = time.perf_counter() - start_time

                    # The counters and the progress bar are shared between
                    # all workers, so update them under the lock.
                    with lock:
                        if unihash:
                            found_hashes += 1
                        else:
                            missed_hashes += 1

                        max_time = max(elapsed, max_time)
                        pbar.update()

        max_time = 0
        found_hashes = 0
        missed_hashes = 0
        lock = threading.Lock()
        total_requests = args.clients * args.requests
        start_time = time.perf_counter()
        with ProgressBar(total=total_requests) as pbar:
            threads = [threading.Thread(target=thread_main, args=(pbar, lock), daemon=False) for _ in range(args.clients)]
            for t in threads:
                t.start()

            for t in threads:
                t.join()

        elapsed = time.perf_counter() - start_time

        # Guard both divisions: "--clients 0" or "--requests 0" gives zero
        # total requests, which previously raised ZeroDivisionError.
        rate = total_requests / elapsed if elapsed > 0 else 0.0
        # Note: this is wall-clock time divided by the total request count,
        # not the mean of the individual request latencies.
        average = elapsed / total_requests if total_requests else 0.0

        # All workers have been joined, so the counters can be read without
        # taking the lock.
        print("%d requests in %.1fs. %.1f requests per second" % (total_requests, elapsed, rate))
        print("Average request time %.8fs" % average)
        print("Max request time was %.8fs" % max_time)
        print("Found %d hashes, missed %d" % (found_hashes, missed_hashes))

        if args.report:
            with ProgressBar(total=args.requests) as pbar:
                for i in range(args.requests):
                    taskhash = seeded_digest(args.taskhash_seed, i)
                    outhash = seeded_digest(args.outhash_seed, i)

                    # The task hash digest is passed as both the first and
                    # the last argument (presumably the taskhash and the
                    # unihash to record -- confirm against the hashserv
                    # client API).
                    client.report_unihash(taskhash, METHOD, outhash, taskhash)
                    pbar.update()

        return 0

    parser = argparse.ArgumentParser(description='Hash Equivalence Client')
    parser.add_argument('--address', default=DEFAULT_ADDRESS, help='Server address (default "%(default)s")')
    parser.add_argument('--log', default='WARNING', help='Set logging level')

    subparsers = parser.add_subparsers()

    stats_parser = subparsers.add_parser('stats', help='Show server stats')
    stats_parser.add_argument('--reset', action='store_true',
                              help='Reset server stats')
    stats_parser.set_defaults(func=handle_stats)

    stress_parser = subparsers.add_parser('stress', help='Run stress test')
    stress_parser.add_argument('--clients', type=int, default=10,
                               help='Number of simultaneous clients')
    stress_parser.add_argument('--requests', type=int, default=1000,
                               help='Number of requests each client will perform')
    stress_parser.add_argument('--report', action='store_true',
                               help='Report new hashes')
    stress_parser.add_argument('--taskhash-seed', default='',
                               help='Include string in taskhash')
    stress_parser.add_argument('--outhash-seed', default='',
                               help='Include string in outhash')
    stress_parser.set_defaults(func=handle_stress)

    args = parser.parse_args()

    logger = logging.getLogger('hashserv')

    # Map the level name (case-insensitive) to the logging module's numeric
    # constant; anything that is not an integer attribute is rejected.
    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        raise ValueError('Invalid log level: %s' % args.log)

    logger.setLevel(level)
    console = logging.StreamHandler()
    console.setLevel(level)
    logger.addHandler(console)

    # 'func' is only set when a sub-command was selected; without one there
    # is nothing to do and no connection is opened.
    func = getattr(args, 'func', None)
    if func:
        with hashserv.create_client(args.address) as client:
            return func(args, client)

    return 0


# Script entry point: run main() and turn its result into the exit status.
if __name__ == '__main__':
    import traceback

    # Assume failure until main() hands back its own status.
    exit_status = 1
    try:
        exit_status = main()
    except Exception:
        # Show the traceback and exit with status 1 instead of letting the
        # exception escape to the interpreter.
        traceback.print_exc()
    sys.exit(exit_status)
