Update web-platform-tests to revision e15b5ebba7465e09bcda2962f6758cddcdcfa248

This commit is contained in:
WPT Sync Bot 2018-10-09 21:32:32 -04:00
parent 68e55ead42
commit 3eaee747ed
214 changed files with 4692 additions and 245 deletions

View file

@ -55,7 +55,7 @@ def test_subprocess_exit(server_subprocesses, tempfile_name):
# which are relevant to this functionality. Disable the check so that
# the constructor is only used to create relevant processes.
with open(tempfile_name, 'w') as handle:
json.dump({"check_subdomains": False}, handle)
json.dump({"check_subdomains": False, "bind_address": False}, handle)
# The `logger` module from the wptserver package uses a singleton
# pattern which resists testing. In order to avoid conflicting with

View file

@ -2,4 +2,4 @@ html5lib == 1.0.1
mozinfo == 0.10
mozlog==3.8
mozdebug == 0.1
urllib3[secure] == 1.22
urllib3[secure]==1.23

View file

@ -508,7 +508,8 @@ class CallbackHandler(object):
self.actions = {
"click": ClickAction(self.logger, self.protocol),
"send_keys": SendKeysAction(self.logger, self.protocol)
"send_keys": SendKeysAction(self.logger, self.protocol),
"action_sequence": ActionSequenceAction(self.logger, self.protocol)
}
def __call__(self, result):
@ -539,7 +540,8 @@ class CallbackHandler(object):
except Exception:
self.logger.warning("Action %s failed" % action)
self.logger.warning(traceback.format_exc())
self._send_message("complete", "failure")
self._send_message("complete", "error")
raise
else:
self.logger.debug("Action %s completed" % action)
self._send_message("complete", "success")
@ -559,13 +561,9 @@ class ClickAction(object):
def __call__(self, payload):
selector = payload["selector"]
elements = self.protocol.select.elements_by_selector(selector)
if len(elements) == 0:
raise ValueError("Selector matches no elements")
elif len(elements) > 1:
raise ValueError("Selector matches multiple elements")
element = self.protocol.select.element_by_selector(selector)
self.logger.debug("Clicking element: %s" % selector)
self.protocol.click.element(elements[0])
self.protocol.click.element(element)
class SendKeysAction(object):
@ -576,10 +574,27 @@ class SendKeysAction(object):
def __call__(self, payload):
selector = payload["selector"]
keys = payload["keys"]
elements = self.protocol.select.elements_by_selector(selector)
if len(elements) == 0:
raise ValueError("Selector matches no elements")
elif len(elements) > 1:
raise ValueError("Selector matches multiple elements")
element = self.protocol.select.element_by_selector(selector)
self.logger.debug("Sending keys to element: %s" % selector)
self.protocol.send_keys.send_keys(elements[0], keys)
self.protocol.send_keys.send_keys(element, keys)
class ActionSequenceAction(object):
def __init__(self, logger, protocol):
self.logger = logger
self.protocol = protocol
def __call__(self, payload):
# TODO: some sort of shallow error checking
actions = payload["actions"]
for actionSequence in actions:
if actionSequence["type"] == "pointer":
for action in actionSequence["actions"]:
if (action["type"] == "pointerMove" and
isinstance(action["origin"], dict)):
action["origin"] = self.get_element(action["origin"]["selector"])
self.protocol.action_sequence.send_actions({"actions": actions})
def get_element(self, selector):
element = self.protocol.select.element_by_selector(selector)
return element

View file

@ -19,7 +19,8 @@ from .base import (CallbackHandler,
WebDriverProtocol,
extra_timeout,
strip_server)
from .protocol import (AssertsProtocolPart,
from .protocol import (ActionSequenceProtocolPart,
AssertsProtocolPart,
BaseProtocolPart,
TestharnessProtocolPart,
PrefsProtocolPart,
@ -359,6 +360,16 @@ class MarionetteSendKeysProtocolPart(SendKeysProtocolPart):
return element.send_keys(keys)
class MarionetteActionSequenceProtocolPart(ActionSequenceProtocolPart):
def setup(self):
self.marionette = self.parent.marionette
def send_actions(self, actions):
actions = self.marionette._to_json(actions)
self.logger.info(actions)
self.marionette._send_message("WebDriver:PerformActions", actions)
class MarionetteTestDriverProtocolPart(TestDriverProtocolPart):
def setup(self):
self.marionette = self.parent.marionette
@ -433,6 +444,7 @@ class MarionetteProtocol(Protocol):
MarionetteSelectorProtocolPart,
MarionetteClickProtocolPart,
MarionetteSendKeysProtocolPart,
MarionetteActionSequenceProtocolPart,
MarionetteTestDriverProtocolPart,
MarionetteAssertsProtocolPart,
MarionetteCoverageProtocolPart]

View file

@ -18,6 +18,7 @@ from .protocol import (BaseProtocolPart,
SelectorProtocolPart,
ClickProtocolPart,
SendKeysProtocolPart,
ActionSequenceProtocolPart,
TestDriverProtocolPart)
from ..testrunner import Stop
@ -26,15 +27,18 @@ here = os.path.join(os.path.split(__file__)[0])
webdriver = None
exceptions = None
RemoteConnection = None
Command = None
def do_delayed_imports():
global webdriver
global exceptions
global RemoteConnection
global Command
from selenium import webdriver
from selenium.common import exceptions
from selenium.webdriver.remote.remote_connection import RemoteConnection
from selenium.webdriver.remote.command import Command
class SeleniumBaseProtocolPart(BaseProtocolPart):
@ -135,6 +139,7 @@ class SeleniumClickProtocolPart(ClickProtocolPart):
def element(self, element):
return element.click()
class SeleniumSendKeysProtocolPart(SendKeysProtocolPart):
def setup(self):
self.webdriver = self.parent.webdriver
@ -143,6 +148,14 @@ class SeleniumSendKeysProtocolPart(SendKeysProtocolPart):
return element.send_keys(keys)
class SeleniumActionSequenceProtocolPart(ActionSequenceProtocolPart):
def setup(self):
self.webdriver = self.parent.webdriver
def send_actions(self, actions):
self.webdriver.execute(Command.W3C_ACTIONS, {"actions": actions})
class SeleniumTestDriverProtocolPart(TestDriverProtocolPart):
def setup(self):
self.webdriver = self.parent.webdriver
@ -163,7 +176,8 @@ class SeleniumProtocol(Protocol):
SeleniumSelectorProtocolPart,
SeleniumClickProtocolPart,
SeleniumSendKeysProtocolPart,
SeleniumTestDriverProtocolPart]
SeleniumTestDriverProtocolPart,
SeleniumActionSequenceProtocolPart]
def __init__(self, executor, browser, capabilities, **kwargs):
do_delayed_imports()

View file

@ -18,6 +18,7 @@ from .protocol import (BaseProtocolPart,
SelectorProtocolPart,
ClickProtocolPart,
SendKeysProtocolPart,
ActionSequenceProtocolPart,
TestDriverProtocolPart)
from ..testrunner import Stop
@ -25,6 +26,7 @@ import webdriver as client
here = os.path.join(os.path.split(__file__)[0])
class WebDriverBaseProtocolPart(BaseProtocolPart):
def setup(self):
self.webdriver = self.parent.webdriver
@ -146,6 +148,14 @@ class WebDriverSendKeysProtocolPart(SendKeysProtocolPart):
return element.send_element_command("POST", "value", {"value": list(keys)})
class WebDriverActionSequenceProtocolPart(ActionSequenceProtocolPart):
def setup(self):
self.webdriver = self.parent.webdriver
def send_actions(self, actions):
self.webdriver.actions.perform(actions)
class WebDriverTestDriverProtocolPart(TestDriverProtocolPart):
def setup(self):
self.webdriver = self.parent.webdriver
@ -166,6 +176,7 @@ class WebDriverProtocol(Protocol):
WebDriverSelectorProtocolPart,
WebDriverClickProtocolPart,
WebDriverSendKeysProtocolPart,
WebDriverActionSequenceProtocolPart,
WebDriverTestDriverProtocolPart]
def __init__(self, executor, browser, capabilities, **kwargs):

View file

@ -242,6 +242,14 @@ class SelectorProtocolPart(ProtocolPart):
name = "select"
def element_by_selector(self, selector):
elements = self.elements_by_selector(selector)
if len(elements) == 0:
raise ValueError("Selector '%s' matches no elements" % selector)
elif len(elements) > 1:
raise ValueError("Selector '%s' matches multiple elements" % selector)
return elements[0]
@abstractmethod
def elements_by_selector(self, selector):
"""Select elements matching a CSS selector
@ -279,6 +287,20 @@ class SendKeysProtocolPart(ProtocolPart):
pass
class ActionSequenceProtocolPart(ProtocolPart):
"""Protocol part for performing trusted clicks"""
__metaclass__ = ABCMeta
name = "action_sequence"
@abstractmethod
def send_actions(self, actions):
"""Send a sequence of actions to the window.
:param actions: A protocol-specific handle to an array of actions."""
pass
class TestDriverProtocolPart(ProtocolPart):
"""Protocol part that implements the basic functionality required for
all testdriver-based tests."""

View file

@ -70,4 +70,22 @@
window.opener.postMessage({"type": "action", "action": "send_keys", "selector": selector, "keys": keys}, "*");
return pending_promise;
};
window.test_driver_internal.action_sequence = function(actions) {
const pending_promise = new Promise(function(resolve, reject) {
pending_resolve = resolve;
pending_reject = reject;
});
for (let actionSequence of actions) {
if (actionSequence.type == "pointer") {
for (let action of actionSequence.actions) {
if (action.type == "pointerMove" && action.origin instanceof Element) {
action.origin = {selector: get_selector(action.origin)};
}
}
}
}
window.opener.postMessage({"type": "action", "action": "action_sequence", "actions": actions}, "*");
return pending_promise;
};
})();

View file

@ -52,7 +52,7 @@ class TestRunner(object):
:param command_queue: subprocess.Queue used to send commands to the
process
:param result_queue: subprocess.Queue used to send results to the
parent TestManager process
parent TestRunnerManager process
:param executor: TestExecutor object that will actually run a test.
"""
self.command_queue = command_queue
@ -304,7 +304,7 @@ class TestRunnerManager(threading.Thread):
self.test_runner_proc = None
threading.Thread.__init__(self, name="Thread-TestrunnerManager-%i" % self.manager_number)
threading.Thread.__init__(self, name="TestRunnerManager-%i" % self.manager_number)
# This is started in the actual new thread
self.logger = None
@ -321,9 +321,9 @@ class TestRunnerManager(threading.Thread):
self.capture_stdio = capture_stdio
def run(self):
"""Main loop for the TestManager.
"""Main loop for the TestRunnerManager.
TestManagers generally receive commands from their
TestRunnerManagers generally receive commands from their
TestRunner updating them on the status of a test. They
may also have a stop flag set by the main thread indicating
that the manager should shut down the next time the event loop
@ -490,7 +490,7 @@ class TestRunnerManager(threading.Thread):
self.child_stop_flag)
self.test_runner_proc = Process(target=start_runner,
args=args,
name="Thread-TestRunner-%i" % self.manager_number)
name="TestRunner-%i" % self.manager_number)
self.test_runner_proc.start()
self.logger.debug("Test runner started")
# Now we wait for either an init_succeeded event or an init_failed event
@ -623,10 +623,10 @@ class TestRunnerManager(threading.Thread):
def wait_finished(self):
assert isinstance(self.state, RunnerManagerState.running)
# The browser should be stopped already, but this ensures we do any post-stop
# processing
self.logger.debug("Wait finished")
# The browser should be stopped already, but this ensures we do any
# post-stop processing
return self.after_test_end(self.state.test, True)
def after_test_end(self, test, restart):
@ -674,7 +674,7 @@ class TestRunnerManager(threading.Thread):
self.cleanup()
def teardown(self):
self.logger.debug("teardown in testrunnermanager")
self.logger.debug("TestRunnerManager teardown")
self.test_runner_proc = None
self.command_queue.close()
self.remote_queue.close()
@ -695,7 +695,7 @@ class TestRunnerManager(threading.Thread):
self.test_runner_proc.terminate()
self.test_runner_proc.join(10)
else:
self.logger.debug("Testrunner exited with code %i" % self.test_runner_proc.exitcode)
self.logger.debug("Runner process exited with code %i" % self.test_runner_proc.exitcode)
def runner_teardown(self):
self.ensure_runner_stopped()
@ -705,7 +705,7 @@ class TestRunnerManager(threading.Thread):
self.remote_queue.put((command, args))
def cleanup(self):
self.logger.debug("TestManager cleanup")
self.logger.debug("TestRunnerManager cleanup")
if self.browser:
self.browser.cleanup()
while True:
@ -716,12 +716,17 @@ class TestRunnerManager(threading.Thread):
else:
if cmd == "log":
self.log(*data)
elif cmd == "runner_teardown":
# It's OK for the "runner_teardown" message to be left in
# the queue during cleanup, as we will already have tried
# to stop the TestRunner in `stop_runner`.
pass
else:
self.logger.warning("%r: %r" % (cmd, data))
self.logger.warning("Command left in command_queue during cleanup: %r, %r" % (cmd, data))
while True:
try:
cmd, data = self.remote_queue.get_nowait()
self.logger.warning("%r: %r" % (cmd, data))
self.logger.warning("Command left in remote_queue during cleanup: %r, %r" % (cmd, data))
except Empty:
break
@ -747,7 +752,7 @@ class ManagerGroup(object):
restart_on_unexpected=True,
debug_info=None,
capture_stdio=True):
"""Main thread object that owns all the TestManager threads."""
"""Main thread object that owns all the TestRunnerManager threads."""
self.suite_name = suite_name
self.size = size
self.test_source_cls = test_source_cls

View file

@ -39,6 +39,7 @@ metadata files are used to store the expected test results.
def setup_logging(*args, **kwargs):
global logger
logger = wptlogging.setup(*args, **kwargs)
return logger
def get_loader(test_paths, product, debug=None, run_info_extras=None, **kwargs):

View file

@ -1,24 +0,0 @@
language: python
sudo: false
cache:
directories:
- $HOME/.cache/pip
matrix:
include:
- python: 2.7
env: TOXENV=py27
- python: pypy
env: TOXENV=pypy
install:
- pip install -U tox codecov
script:
- tox
after_success:
- coverage combine
- codecov

View file

@ -75,7 +75,7 @@ class TestUsingServer(unittest.TestCase):
req.add_data(body)
if auth is not None:
req.add_header("Authorization", "Basic %s" % base64.b64encode('%s:%s' % auth))
req.add_header("Authorization", b"Basic %s" % base64.b64encode((b"%s:%s" % auth)))
return urlopen(req)

View file

@ -57,40 +57,35 @@ class TestSlice(TestUsingServer):
self.assertEqual(resp.read(), expected[:10])
class TestSub(TestUsingServer):
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_config(self):
resp = self.request("/sub.txt", query="pipe=sub")
expected = "localhost localhost %i" % self.server.port
expected = b"localhost localhost %i" % self.server.port
self.assertEqual(resp.read().rstrip(), expected)
@pytest.mark.xfail(sys.platform == "win32",
reason="https://github.com/web-platform-tests/wpt/issues/12949")
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_file_hash(self):
resp = self.request("/sub_file_hash.sub.txt")
expected = """
expected = b"""
md5: JmI1W8fMHfSfCarYOSxJcw==
sha1: nqpWqEw4IW8NjD6R375gtrQvtTo=
sha224: RqQ6fMmta6n9TuA/vgTZK2EqmidqnrwBAmQLRQ==
sha256: G6Ljg1uPejQxqFmvFOcV/loqnjPTW5GSOePOfM/u0jw=
sha384: lkXHChh1BXHN5nT5BYhi1x67E1CyYbPKRKoF2LTm5GivuEFpVVYtvEBHtPr74N9E
sha512: r8eLGRTc7ZznZkFjeVLyo6/FyQdra9qmlYCwKKxm3kfQAswRS9+3HsYk3thLUhcFmmWhK4dXaICz
JwGFonfXwg=="""
sha512: r8eLGRTc7ZznZkFjeVLyo6/FyQdra9qmlYCwKKxm3kfQAswRS9+3HsYk3thLUhcFmmWhK4dXaICzJwGFonfXwg=="""
self.assertEqual(resp.read().rstrip(), expected.strip())
def test_sub_file_hash_unrecognized(self):
with self.assertRaises(urllib.error.HTTPError):
self.request("/sub_file_hash_unrecognized.sub.txt")
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_headers(self):
resp = self.request("/sub_headers.txt", query="pipe=sub", headers={"X-Test": "PASS"})
expected = "PASS"
expected = b"PASS"
self.assertEqual(resp.read().rstrip(), expected)
@pytest.mark.xfail(sys.platform == "win32",
reason="https://github.com/web-platform-tests/wpt/issues/12949")
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_location(self):
resp = self.request("/sub_location.sub.txt?query_string")
expected = """
@ -101,30 +96,26 @@ pathname: /sub_location.sub.txt
port: {0}
query: ?query_string
scheme: http
server: http://localhost:{0}""".format(self.server.port)
server: http://localhost:{0}""".format(self.server.port).encode("ascii")
self.assertEqual(resp.read().rstrip(), expected.strip())
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_params(self):
resp = self.request("/sub_params.txt", query="test=PASS&pipe=sub")
expected = "PASS"
expected = b"PASS"
self.assertEqual(resp.read().rstrip(), expected)
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_url_base(self):
resp = self.request("/sub_url_base.sub.txt")
self.assertEqual(resp.read().rstrip(), "Before / After")
self.assertEqual(resp.read().rstrip(), b"Before / After")
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_uuid(self):
resp = self.request("/sub_uuid.sub.txt")
self.assertRegexpMatches(resp.read().rstrip(), r"Before [a-f0-9-]+ After")
self.assertRegexpMatches(resp.read().rstrip(), b"Before [a-f0-9-]+ After")
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_sub_var(self):
resp = self.request("/sub_var.sub.txt")
port = self.server.port
expected = "localhost %s A %s B localhost C" % (port, port)
expected = b"localhost %d A %d B localhost C" % (port, port)
self.assertEqual(resp.read().rstrip(), expected)
class TestTrickle(TestUsingServer):

View file

@ -1,5 +1,4 @@
import sys
# -*- coding: utf-8 -*-
import pytest
wptserve = pytest.importorskip("wptserve")
@ -115,16 +114,40 @@ class TestRequest(TestUsingServer):
resp = self.request("/test/some_route")
self.assertEqual(b"some route", resp.read())
def test_non_ascii_in_headers(self):
@wptserve.handlers.handler
def handler(request, response):
return request.headers["foo"]
route = ("GET", "/test/test_unicode_in_headers", handler)
self.server.router.register(*route)
# Try some non-ASCII characters and the server shouldn't crash.
encoded_text = u"你好".encode("utf-8")
resp = self.request(route[1], headers={"foo": encoded_text})
self.assertEqual(encoded_text, resp.read())
# Try a different encoding from utf-8 to make sure the binary value is
# returned in verbatim.
encoded_text = u"どうも".encode("shift-jis")
resp = self.request(route[1], headers={"foo": encoded_text})
self.assertEqual(encoded_text, resp.read())
class TestAuth(TestUsingServer):
@pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
def test_auth(self):
@wptserve.handlers.handler
def handler(request, response):
return " ".join((request.auth.username, request.auth.password))
return b" ".join((request.auth.username, request.auth.password))
route = ("GET", "/test/test_auth", handler)
self.server.router.register(*route)
resp = self.request(route[1], auth=("test", "PASS"))
resp = self.request(route[1], auth=(b"test", b"PASS"))
self.assertEqual(200, resp.getcode())
self.assertEqual(["test", "PASS"], resp.read().split(" "))
self.assertEqual([b"test", b"PASS"], resp.read().split(b" "))
encoded_text = u"どうも".encode("shift-jis")
resp = self.request(route[1], auth=(encoded_text, encoded_text))
self.assertEqual(200, resp.getcode())
self.assertEqual([encoded_text, encoded_text], resp.read().split(b" "))

View file

@ -1,5 +1,6 @@
from cgi import escape
from collections import deque
import base64
import gzip as gzip_module
import hashlib
import os
@ -393,7 +394,7 @@ class SubFunctions(object):
@staticmethod
def file_hash(request, algorithm, path):
algorithm = algorithm.decode("ascii")
assert isinstance(algorithm, text_type)
if algorithm not in SubFunctions.supported_algorithms:
raise ValueError("Unsupported encryption algorithm: '%s'" % algorithm)
@ -401,7 +402,7 @@ class SubFunctions(object):
absolute_path = os.path.join(request.doc_root, path)
try:
with open(absolute_path) as f:
with open(absolute_path, "rb") as f:
hash_obj.update(f.read())
except IOError:
# In this context, an unhandled IOError will be interpreted by the
@ -411,7 +412,7 @@ class SubFunctions(object):
# the path to the file to be hashed is invalid.
raise Exception('Cannot open file for hash computation: "%s"' % absolute_path)
return hash_obj.digest().encode('base64').strip()
return base64.b64encode(hash_obj.digest()).strip()
def template(request, content, escape_type="html"):
#TODO: There basically isn't any error handling here
@ -490,9 +491,14 @@ def template(request, content, escape_type="html"):
escape_func = {"html": lambda x:escape(x, quote=True),
"none": lambda x:x}[escape_type]
#Should possibly support escaping for other contexts e.g. script
#TODO: read the encoding of the response
return escape_func(text_type(value)).encode("utf-8")
# Should possibly support escaping for other contexts e.g. script
# TODO: read the encoding of the response
# cgi.escape() only takes text strings in Python 3.
if isinstance(value, binary_type):
value = value.decode("utf-8")
elif isinstance(value, int):
value = text_type(value)
return escape_func(value).encode("utf-8")
template_regexp = re.compile(br"{{([^}]*)}}")
new_content = template_regexp.sub(config_replacement, content)

View file

@ -3,6 +3,10 @@ from .utils import HTTPException
class RangeParser(object):
def __call__(self, header, file_size):
try:
header = header.decode("ascii")
except UnicodeDecodeError:
raise HTTPException(400, "Non-ASCII range header value")
prefix = "bytes="
if not header.startswith(prefix):
raise HTTPException(416, message="Unrecognised range type %s" % (header,))

View file

@ -1,7 +1,7 @@
import base64
import cgi
from six.moves.http_cookies import BaseCookie
from six import BytesIO
from six import BytesIO, binary_type, text_type
import tempfile
from six.moves.urllib.parse import parse_qsl, urlsplit
@ -308,7 +308,7 @@ class Request(object):
self.raw_input.seek(0)
fs = cgi.FieldStorage(fp=self.raw_input,
environ={"REQUEST_METHOD": self.method},
headers=self.headers,
headers=self.raw_headers,
keep_blank_values=True)
self._POST = MultiDict.from_field_storage(fs)
self.raw_input.seek(pos)
@ -318,7 +318,7 @@ class Request(object):
def cookies(self):
if self._cookies is None:
parser = BaseCookie()
cookie_headers = self.headers.get("cookie", "")
cookie_headers = self.headers.get("cookie", b"")
parser.load(cookie_headers)
cookies = Cookies()
for key, value in parser.iteritems():
@ -355,11 +355,34 @@ class H2Request(Request):
super(H2Request, self).__init__(request_handler)
def _maybe_encode(s):
"""Encodes a text-type string into binary data using iso-8859-1.
Returns `str` in Python 2 and `bytes` in Python 3. The function is a no-op
if the argument already has a binary type.
"""
if isinstance(s, binary_type):
return s
# Python 3 assumes iso-8859-1 when parsing headers, which will garble text
# with non ASCII characters. We try to encode the text back to binary.
# https://github.com/python/cpython/blob/273fc220b25933e443c82af6888eb1871d032fb8/Lib/http/client.py#L213
if isinstance(s, text_type):
return s.encode("iso-8859-1")
raise TypeError("Unexpected value in RequestHeaders: %r" % s)
class RequestHeaders(dict):
"""Dictionary-like API for accessing request headers."""
"""Read-only dictionary-like API for accessing request headers.
Unlike BaseHTTPRequestHandler.headers, this class always returns all
headers with the same name (separated by commas). And it ensures all keys
(i.e. names of headers) and values have binary type.
"""
def __init__(self, items):
for header in items.keys():
key = header.lower()
key = _maybe_encode(header).lower()
# get all headers with the same name
values = items.getallmatchingheaders(header)
if len(values) > 1:
@ -369,15 +392,17 @@ class RequestHeaders(dict):
for value in values:
# getallmatchingheaders returns raw header lines, so
# split to get name, value
multiples.append(value.split(':', 1)[1].strip())
dict.__setitem__(self, key, multiples)
multiples.append(_maybe_encode(value).split(b':', 1)[1].strip())
headers = multiples
else:
dict.__setitem__(self, key, [items[header]])
headers = [_maybe_encode(items[header])]
dict.__setitem__(self, key, headers)
def __getitem__(self, key):
"""Get all headers of a certain (case-insensitive) name. If there is
more than one, the values are returned comma separated"""
key = _maybe_encode(key)
values = dict.__getitem__(self, key.lower())
if len(values) == 1:
return values[0]
@ -403,6 +428,7 @@ class RequestHeaders(dict):
def get_list(self, key, default=missing):
"""Get all the header values for a particular field name as
a list"""
key = _maybe_encode(key)
try:
return dict.__getitem__(self, key.lower())
except KeyError:
@ -412,6 +438,7 @@ class RequestHeaders(dict):
raise
def __contains__(self, key):
key = _maybe_encode(key)
return dict.__contains__(self, key.lower())
def iteritems(self):
@ -590,21 +617,28 @@ class Authentication(object):
The password supplied in the HTTP Authorization
header, or None
Both attributes are binary strings (`str` in Py2, `bytes` in Py3), since
RFC7617 Section 2.1 does not specify the encoding for username & passsword
(as long it's compatible with ASCII). UTF-8 should be a relatively safe
choice if callers need to decode them as most browsers use it.
"""
def __init__(self, headers):
self.username = None
self.password = None
auth_schemes = {"Basic": self.decode_basic}
auth_schemes = {b"Basic": self.decode_basic}
if "authorization" in headers:
header = headers.get("authorization")
auth_type, data = header.split(" ", 1)
assert isinstance(header, binary_type)
auth_type, data = header.split(b" ", 1)
if auth_type in auth_schemes:
self.username, self.password = auth_schemes[auth_type](data)
else:
raise HTTPException(400, "Unsupported authentication scheme %s" % auth_type)
def decode_basic(self, data):
decoded_data = base64.decodestring(data)
return decoded_data.split(":", 1)
assert isinstance(data, binary_type)
decoded_data = base64.b64decode(data)
return decoded_data.split(b":", 1)