#
# Copyright (C) 2018  Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import unittest
from flask import Flask
from datetime import timedelta

from pylorax.api.crossdomain import crossdomain


server = Flask(__name__)

@server.route('/01')
@crossdomain(origin='*', methods=['GET'])
def hello_world_01():
    return 'Hello, World!'


@server.route('/02')
@crossdomain(origin='*', headers=['TESTING'])
def hello_world_02():
    return 'Hello, World!'

@server.route('/03')
@crossdomain(origin='*', max_age=timedelta(days=7))
def hello_world_03():
    return 'Hello, World!'


@server.route('/04')
@crossdomain(origin='*', attach_to_all=False)
def hello_world_04():
    return 'Hello, World!'


@server.route('/05')
@crossdomain(origin='*', automatic_options=False)
def hello_world_05():
    return 'Hello, World!'


@server.route('/06')
@crossdomain(origin=['https://redhat.com', 'http://weldr.io'])
def hello_world_06():
    return 'Hello, World!'


class CrossdomainTest(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        self.server = server.test_client()

    def test_01_with_methods_specified(self):
        # first send a preflight request to check what methods are allowed
        response = self.server.options("/01")
        self.assertEqual(200, response.status_code)
        self.assertIn('GET', response.headers['Access-Control-Allow-Methods'])

        # then try to issue a POST request which isn't allowed
        response = self.server.post("/01")
        self.assertEqual(405, response.status_code)

    def test_02_with_headers_specified(self):
        response = self.server.get("/02")
        self.assertEqual(200, response.status_code)
        self.assertEqual(b'Hello, World!', response.data)

        self.assertEqual('TESTING', response.headers['Access-Control-Allow-Headers'])

    def test_03_with_max_age_as_timedelta(self):
        response = self.server.get("/03")
        self.assertEqual(200, response.status_code)
        self.assertEqual(b'Hello, World!', response.data)

        expected_max_age = int(timedelta(days=7).total_seconds())
        actual_max_age = int(response.headers['Access-Control-Max-Age'])
        self.assertEqual(expected_max_age, actual_max_age)

    def test_04_attach_to_all_false(self):
        response = self.server.get("/04")
        self.assertEqual(200, response.status_code)
        self.assertEqual(b'Hello, World!', response.data)

        # when attach_to_all is False the decorator will not assign
        # the Access-Control-* headers to the response
        for header, _ in response.headers:
            self.assertFalse(header.startswith('Access-Control-'))


    def test_05_options_request(self):
        response = self.server.options("/05")
        self.assertEqual(200, response.status_code)
        self.assertEqual(b'Hello, World!', response.data)

        # Not always in the same order, so test individually
        for m in ["HEAD", "OPTIONS", "GET"]:
            self.assertIn(m, response.headers['Access-Control-Allow-Methods'])


    def test_06_with_origin_as_list(self):
        response = self.server.get("/06")
        self.assertEqual(200, response.status_code)
        self.assertEqual(b'Hello, World!', response.data)

        for header, value in response.headers:
            if header == 'Access-Control-Allow-Origin':
                self.assertIn(value, ['https://redhat.com', 'http://weldr.io'])