#!/usr/bin/python3
# SPDX-License-Identifier: Apache-2.0
#
# Based on https://github.com/enarx/spdx/blob/master/verify-spdx-headers
#

import os
import re

# A single SPDX license identifier, e.g. "Apache-2.0" or "MIT".
SLUG = re.compile(r'[a-zA-Z0-9.-]+')
# The SPDX tag followed by one identifier, captured as group 1.  A raw string
# is used so that `\s` reaches the regex engine instead of being treated as an
# (invalid) string escape sequence.
SPDX = re.compile(rf'SPDX-License-Identifier:\s+({SLUG.pattern})')

class Language:
    """SPDX header matcher for the comment styles of one language.

    Each positional argument describes one comment style as a regular
    expression fragment: either a single string (the text that opens the
    comment) or an ``(opener, closer)`` tuple for comments that also need a
    closing delimiter.  One anchored pattern is compiled per style; a line
    matches when it consists of the opener, the SPDX tag and the optional
    closer, separated by optional whitespace.

    Args:
        *comments: Comment styles, as described above.
        shebang: Must be a bool.  It is stored on the instance but not read
            by any method of this class.

    Raises:
        TypeError: If ``shebang`` is not a bool.
    """

    def __init__(self, *comments, shebang=False):
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(shebang, bool):
            raise TypeError("shebang must be a bool")
        self.__shebang = shebang

        self.__match = []
        for comment in comments:
            if isinstance(comment, tuple):
                init, fini = comment
            else:
                init, fini = comment, ''

            # Raw f-string: `\s` must reach the regex engine untouched.
            pattern = rf"^{init}\s*{SPDX.pattern}\s*{fini}\s*$"
            self.__match.append(re.compile(pattern))

    def license(self, path):
        """Return the license identifier from the SPDX header of *path*.

        The file is read line by line and the first line matching any of the
        registered comment styles wins.

        Args:
            path: Path of the file to inspect.

        Returns:
            The captured license identifier, or None when no line matches.
        """
        # The SPDX tag itself is plain ASCII, so undecodable bytes elsewhere
        # in the file are replaced rather than allowed to abort the scan.
        with open(path, encoding='utf-8', errors='replace') as f:
            for line in f:
                for matcher in self.__match:
                    match = matcher.match(line)
                    if match:
                        return match.group(1)
        return None


class Index:
    """Map files to a language and report the SPDX license of each one.

    A file's language is derived from its extension or, failing that, from
    the interpreter named on its shebang line.  Only languages registered in
    ``__init__`` are checked; files of any other kind are skipped.
    """

    # Interpreter named on a shebang line -> language name.
    INTERPRETERS = {
        'python3': 'python',
        'python2': 'python',
        'python': 'python',
        'ruby': 'ruby',
    }

    # File extension -> language name.  Names that have no entry in the
    # language table built by __init__ (yaml, json, toml, md) resolve to
    # None in language(), so such files are skipped by scan().
    EXTENSIONS = {
        '.py': 'python',
        '.proto': 'protobuf',
        '.rs': 'rust',
        '.yml': 'yaml',
        '.yaml': 'yaml',
        '.json': 'json',
        '.toml': 'toml',
        '.md': 'md',
        '.rb': 'ruby',
        '.c': 'c',
        '.h': 'c',
        '.cpp': 'c++',
        '.hpp': 'c++',
        '.cc': 'c++',
        '.hh': 'c++',
    }

    def __init__(self):
        """Register the comment styles searched for in each language.

        The patterns are intentionally loose (most languages include a
        catch-all ``.*`` opener, so the tag may follow any leading text on
        the line) and can be relaxed further as needed.
        """
        self.__languages = {
            'python': Language('#+.*', shebang=True),
            'ruby': Language('#+', shebang=True),
            'c': Language('//+', '.*', ('/\\*', '\\*/')),
            'c++': Language('//+', '.*',  ('/\\*', '\\*/')),
            'rust': Language('//+', '.*', '//!', ('/\\*', '\\*/')),
            'protobuf': Language('//+', '//!', '.*', ('/\\*', '\\*/')),
        }

    def language(self, path):
        """Return the Language registered for *path*, or None.

        The extension is looked up first.  Only when the extension is
        unknown is the file opened and its shebang line inspected.

        Args:
            path: Path of the file to classify.

        Returns:
            The registered Language object, or None when the file's kind is
            unknown or has no registered language.
        """
        name = self.EXTENSIONS.get(os.path.splitext(path)[1])
        if name is None:
            name = self.INTERPRETERS.get(self._interpreter(path))
        return self.__languages.get(name)

    @staticmethod
    def _interpreter(path):
        """Return the interpreter named on *path*'s shebang line, or None."""
        with open(path, "rb") as f:
            if f.read(2) != b'#!':
                return None
            # Decode leniently: a stray non-UTF-8 byte must not hide the
            # interpreter name.
            words = f.readline().decode('utf-8', errors='replace').split()
        if not words:
            return None

        program = os.path.basename(words[0])
        if program != 'env':
            # "#!/usr/bin/python3 -u": arguments after the program are ignored.
            return program

        # "#!/usr/bin/env [-S] [VAR=value] python3": the interpreter is the
        # first word that is neither an option nor a variable assignment.
        for word in words[1:]:
            if not word.startswith('-') and '=' not in word:
                return os.path.basename(word)
        return None

    def scan(self):
        """Yield ``(path, license)`` for each recognised file in ``$files``.

        The ``files`` environment variable holds a comma-separated list of
        paths; an unset or empty variable yields nothing.  Entries that are
        not regular files, that have ``.git`` or ``.github`` among their
        absolute path components, or whose language is not registered are
        skipped.

        Yields:
            Tuples of the absolute file path and the license identifier found
            in it (None when the file has no SPDX header).
        """
        ignore_dirs = {".git", ".github"}

        listed = os.getenv('files') or ''
        for entry in listed.split(','):
            # Tolerate "a.py, b.py": surrounding whitespace is not part of
            # the file name.
            entry = entry.strip()
            if not entry:
                continue

            path = os.path.abspath(entry)
            if not os.path.isfile(path):
                continue

            # Ignore the specified directories.
            if not ignore_dirs.isdisjoint(path.split(os.sep)):
                continue

            # Find the language of the file.
            language = self.language(path)
            if language is None:
                continue

            # Parse the SPDX header for the language.
            yield (path, language.license(path))

if __name__ == '__main__':
    import sys

    # The allowed licenses come from the `licenses` environment variable
    # (comma separated) or, when it is unset, from the command-line
    # arguments.  Each argument may itself be a comma-separated list.
    license_arg = os.getenv('licenses')
    raw_items = sys.argv[1:] if license_arg is None else [license_arg]
    licenses = [x.strip() for item in raw_items for x in item.split(',')]
    if not licenses:
        print("No allowed licenses given!")
        raise SystemExit(1)

    # Validate the arguments
    print(f"{len(licenses)} will be accepted!")
    for allowed in licenses:
        # fullmatch: the whole value must be one identifier, not merely
        # start with one.
        if not SLUG.fullmatch(allowed):
            print("Invalid license '%s'!" % allowed)
            raise SystemExit(1)
        print(f"Allowed license: {allowed}")

    # Check every scanned file and report all of them before failing.
    index = Index()
    failed = False
    for (path, found) in index.scan():
        if found in licenses:
            print(f"OK: {found:16} {path}")
        elif found is None:
            failed = True
            print(f"NO SPDX {path}")
        else:
            failed = True
            print(f"WRONG LICENSE: {found:16} {path}")
    if failed:
        # A string argument makes the interpreter print the message to stderr
        # and exit with status 1, without a traceback.
        raise SystemExit("Check the output, some files have not the right licenses!")

    raise SystemExit(0)
