#! /usr/bin/env python
#
# privacyidea-checkotp.py - python implementation of PrivacyIDEA OTP check for
#			    command-line use or integration with FreeRadius
#
# Version 2.0, latest version, documentation and bugtracker available at:
#		https://gitlab.lindenaar.net/privacyidea/checkotp
#
# Copyright (c) 2016 Frederik Lindenaar
#
# This script is free software: you can redistribute and/or modify it under the
# terms of version 3 of the GNU General Public License as published by the Free
# Software Foundation, or (at your option) any later version of the license.
#
# This script is distributed in the hope that it will be useful but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, visit <http://www.gnu.org/licenses/> to download it.

import sys, os, logging, json
from getpass import getpass
from urllib import urlencode
from urllib2 import Request, HTTPError, urlopen
from argparse import ArgumentParser as StandardArgumentParser, FileType, \
              _StoreAction as StoreAction, _StoreConstAction as StoreConstAction

# Program identification
VERSION = "2.0"
PROG_NAME = os.path.splitext(os.path.basename(__file__))[0]
PROG_VERSION = ' '.join((PROG_NAME, VERSION))

# API endpoint appended to the server URL by checkotp() when it is missing
URL_API_SUFFIX = '/validate/check'

# Environment variables used as fallback for the command-line parameters
ENV_VAR_USER = 'USER_NAME'
ENV_VAR_USERSTRIPPED = 'STRIPPED_USER_NAME'
ENV_VAR_PWD = 'USER_PASSWORD'
ENV_VAR_NAS = 'NAS_IP_ADDRESS'

# Log formats: console output and (via --logfile) file output
LOG_FORMAT = '%(levelname)s - %(message)s'
LOG_FORMAT_FILE = ' - '.join(('%(asctime)s', LOG_FORMAT))

# Custom levels above CRITICAL: RADIUS (-r, only Radius output) and NONE (-q)
LOGGING_RADIUS = logging.CRITICAL + 10
LOGGING_NONE = LOGGING_RADIUS + 10

# Setup logging
logging.basicConfig(format=LOG_FORMAT)
logging.addLevelName(LOGGING_RADIUS, 'RADIUS')
logging.addLevelName(LOGGING_NONE, 'NONE')
logger = logging.getLogger(PROG_NAME)
logger.setLevel(logging.CRITICAL)


################[ wrapper to stop ArgumentParser from exiting ]################
# Stop ArgumentParser from exiting with an error message upon errors
# based on http://stackoverflow.com/questions/14728376/i-want-python-argparse-to-throw-an-exception-rather-than-usage/14728477#14728477
# the only way to do this seems to be overriding error() and raising an exception
class ArgumentParserError(Exception): pass

class ArgumentParser(StandardArgumentParser):
    def error(self, message):
        raise ArgumentParserError(message)

##################[ Action to immediately set the log level ]##################
class SetLogLevel(StoreConstAction):
    """argparse action applying its ``const`` as level of the program logger.

    The level is changed as soon as the option is parsed. Nothing is stored
    in the namespace, so the option's attribute keeps its default value.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        program_logger = logging.getLogger(PROG_NAME)
        program_logger.setLevel(self.const)

####################[ Action to immediately log to a file ]####################
class SetLogFile(StoreAction):
    """argparse action that also sends program log output to the given file.

    Besides storing the file name like a normal ``store`` action, it attaches
    a FileHandler (using LOG_FORMAT_FILE) to the program logger. It also
    disables propagation, so records no longer reach the root (console)
    handler.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        # regular 'store' behaviour: remember the file name in the namespace
        super(SetLogFile, self).__call__(parser, namespace, values,
                                         option_string)
        file_handler = logging.FileHandler(values)
        file_handler.setFormatter(logging.Formatter(LOG_FORMAT_FILE))
        program_logger = logging.getLogger(PROG_NAME)
        program_logger.propagate = False
        program_logger.addHandler(file_handler)

###############################################################################


def isempty(str):
    """Return True when ``str`` is None or has a length of zero."""
    if str is None:
        return True
    return len(str) == 0


def envvar(name, default=None):
    """Return environment variable ``name``, or ``default`` when it is unset."""
    environment = os.environ
    return environment[name] if name in environment else default


def dequote(str):
    """Remove one pair of enclosing double quotes from a string, if present.

    Used by parse_args() on values taken from environment variables.

    Args:
        str: the string to process, may be None (the parameter name shadows
             the builtin but is kept for backward compatibility).

    Returns:
        ``str`` without its first and last character when it is at least two
        characters long and both of them are '"'. Otherwise ``str`` is
        returned unchanged (including None and '').
    """
    # Require two characters: a lone '"' has no separate closing quote and
    # must not be turned into an empty string.
    if str is not None and len(str) >= 2 and str[0] == str[-1] == '"':
        return str[1:-1]
    return str


def parse_args():
    """Parse the command line, taking missing parameters from the environment.

    The positional ``principal``, ``password`` and ``nas`` parameters default
    to the de-quoted values of the environment variables STRIPPED_USER_NAME
    (falling back to USER_NAME), USER_PASSWORD and NAS_IP_ADDRESS. A password
    of "-" is read from stdin. With --prompt (and a principal) the password is
    asked for interactively instead. The log level / log file options take
    effect while parsing (SetLogLevel / SetLogFile actions).

    Returns:
        argparse.Namespace with (at least) url, principal, password, nas,
        isserial and prompt.

    Raises:
        ArgumentParserError: for invalid arguments and when no principal or
            password is available (raised by ArgumentParser.error()).
    """

    # Setup argument parser
    parser = ArgumentParser(
        description='check an OTP against PrivacyIDEA from the command-line',
        epilog='* parameter is required but can also be passed in environment '
               'variables\n  %s (or %s) and %s. Value for nas can be set in %s. '
               'In case the value for password equals "-" it is read from stdin'
               % (ENV_VAR_USERSTRIPPED, ENV_VAR_USER, ENV_VAR_PWD, ENV_VAR_NAS)
    )
    parser.add_argument('-V', '--version',action="version",version=PROG_VERSION)

    parser.add_argument('url',             help='URL to PrivacyIDEA/LinOTP')
    parser.add_argument('principal',       default=dequote(envvar(ENV_VAR_USERSTRIPPED, envvar(ENV_VAR_USER))),
                        nargs='?',         help='user or token serial to login with *')
    parser.add_argument('password',        default=dequote(envvar(ENV_VAR_PWD)),
                        nargs='?',         help='password + OTP to authenticate with *')
    parser.add_argument('nas',             default=dequote(envvar(ENV_VAR_NAS)),
                        nargs='?',         help='ID of the Network Access System')

    pgroup = parser.add_mutually_exclusive_group(required=False)
    pgroup.add_argument('-q', '--quiet',   action=SetLogLevel, const=LOGGING_NONE,
                        default=logging.CRITICAL,
                        help='quiet (no output, only exit with exit code)')
    pgroup.add_argument('-v', '--verbose', action=SetLogLevel, const=logging.INFO,
                        help='more verbose output')
    pgroup.add_argument('-d', '--debug',   action=SetLogLevel, const=logging.DEBUG,
                        help='debug output (more verbose)')
    pgroup.add_argument('-r', '--radius',  action=SetLogLevel, const=LOGGING_RADIUS,
                        help='run in radius mode (only produce Radius output)')

    parser.add_argument('-l', '--logfile', action=SetLogFile,
                        help='send logging output to logfile')

    pgroup = parser.add_mutually_exclusive_group()
    # Both options share dest 'isserial'. argparse takes the default from the
    # first option registered, and 'store_false' would default to True. Set
    # it explicitly so the documented default (principal is a login) holds.
    pgroup.add_argument('-u', '--user',    action='store_false', dest='isserial',
                        default=False,
                        help='provided principal contains a login (default)')
    pgroup.add_argument('-s', '--serial',  action='store_true', dest='isserial',
                        help='provided principal contains a token serial')

    parser.add_argument('-p', '--prompt',  action='store_true',
                        help='prompt for password + OTP (not in Radius mode)')

    # parse arguments
    args = parser.parse_args()

    # Post-process command line options
    if args.prompt and not isempty(args.principal):
        args.password = getpass("please enter password: " )
    elif args.password == '-':
        # note: strip() also removes leading/trailing blanks of the password
        args.password = sys.stdin.readline().strip()

    # We should now be ready to authenticate, fail if that's not the case
    if isempty(args.principal) or isempty(args.password):
        parser.error('user/serial and password are required!')

    # if we got here all seems OK
    return args


def checkotp(url, subject, secret, isserial=False, nas=None):
    """Check a subject (user or token) with secret against PrivacyIDEA / LinOTP.

    Args:
        url      (str) : URL to connect to, URL_API_SUFFIX is added if missing
        subject  (str) : subject to authenticate (user or a token serial)
        secret   (str) : secret (password+OTP) to authenticate with
        isserial (bool): True if subject is a token serial (optional)
        nas      (str) : string to pass-on as the nas string (optional)

    Returns:
        The result response from the PrivacyIDEA server (decoded JSON, a
        mapping object for a well-formed reply)

    Raises:
        Errors from urlopen() (e.g. HTTPError) and from decoding the JSON
        reply are propagated to the caller.
    """
    from contextlib import closing

    # Complete (fix) URL; endswith() instead of url[-1] also copes with ''
    if not url.endswith(URL_API_SUFFIX):
        url += URL_API_SUFFIX[1:] if url.endswith('/') else URL_API_SUFFIX
    logger.info('connecting to %s', url)

    # Prepare the parameters
    params = {'pass': secret, 'serial' if isserial else 'user': subject}
    if not isempty(nas):
        params['nas'] = nas
    if logger.isEnabledFor(logging.DEBUG):
        # never log the secret itself
        logger.debug('HTTP request parameters: %s', ', '.join(
            '%s="%s"' % (key, '***MASKED***' if key == 'pass' else value)
            for key, value in params.items()))

    # Perform the API authentication request, closing the connection afterwards
    with closing(urlopen(Request(url, data=urlencode(params)))) as reply:
        response = json.load(reply)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug('result: %s', json.dumps(response, indent=4))

    return response


####################[ command-line script implementation ]####################
if __name__ == '__main__':

    # Outcome unless a definitive answer is obtained below: exit code 2 with
    # Auth-Type ERROR. Set up-front so the reporting code below can never
    # hit an unbound name.
    radius_result = (2, 'ERROR')

    try:
        args = parse_args()

        response = checkotp(args.url, args.principal, args.password,
                            args.isserial, args.nas)

        # a reply without a usable 'result' mapping ends up as an ERROR
        resultdata = response.get('result')
        authenticated = resultdata.get('status') and resultdata.get('value')

    except (ArgumentParserError, HTTPError) as e:
        logger.critical('authentication failed: %s', e)

    except Exception as e:
        # Top-level boundary: connection problems, a non-JSON reply, an
        # unexpected reply structure, ... are reported as an ERROR result.
        # The traceback is only logged in debug mode.
        logger.critical('authentication failed: %s', e,
                        exc_info=logger.isEnabledFor(logging.DEBUG))

    else:
        radius_result = (0, 'PrivacyIDEA') if authenticated else (1, 'REJECT')

        if logger.isEnabledFor(logging.INFO):
            logger.info('Got response from : %s', response.get('version'))
            logger.info('Got valid result  : %s', resultdata.get('status'))
            logger.info('Authenticated     : %s', authenticated)
            detaildata = response.get('detail') or {}
            for field in 'message', 'type', 'serial':
                if field in detaildata:
                    logger.info('Token %-12s: %s', field, detaildata.get(field))

    # Deliberately not in a 'finally': SystemExit raised by --help/--version
    # must pass through with argparse's own exit code. Radius output is
    # produced in radius mode (-r), or when logging goes to a file
    # (propagation disabled by --logfile) and the level is not NONE (-q).
    if (not logger.propagate and logger.isEnabledFor(LOGGING_RADIUS)) \
            or logger.getEffectiveLevel() == LOGGING_RADIUS:
        print('Auth-Type=%s' % radius_result[1])
    sys.exit(radius_result[0])