diff --git a/components/net/Cargo.toml b/components/net/Cargo.toml index f8c9a6bb5ea..385f81da21e 100644 --- a/components/net/Cargo.toml +++ b/components/net/Cargo.toml @@ -43,6 +43,7 @@ openssl = "0.7.6" rustc-serialize = "0.3" threadpool = "1.0" time = "0.1.17" +unicase = "1.4.0" url = {version = "0.5.7", features = ["heap_size"]} uuid = { version = "0.2", features = ["v4"] } websocket = "0.16.1" diff --git a/components/net/fetch/methods.rs b/components/net/fetch/methods.rs index 0980bfee301..a9a8a3969f8 100644 --- a/components/net/fetch/methods.rs +++ b/components/net/fetch/methods.rs @@ -5,12 +5,12 @@ use data_loader::decode; use fetch::cors_cache::{BasicCORSCache, CORSCache, CacheRequestDetails}; use http_loader::{NetworkHttpRequestFactory, create_http_connector, obtain_response}; -use hyper::header::{Accept, CacheControl, IfMatch, IfRange, IfUnmodifiedSince, Location}; -use hyper::header::{AcceptLanguage, ContentLength, ContentLanguage, HeaderView, Pragma}; -use hyper::header::{AccessControlAllowCredentials, AccessControlAllowOrigin}; -use hyper::header::{Authorization, Basic, CacheDirective, ContentEncoding, Encoding}; -use hyper::header::{ContentType, Headers, IfModifiedSince, IfNoneMatch}; -use hyper::header::{QualityItem, q, qitem, Referer as RefererHeader, UserAgent}; +use hyper::header::{Accept, AcceptLanguage, Authorization, AccessControlAllowCredentials}; +use hyper::header::{AccessControlAllowOrigin, AccessControlAllowHeaders, AccessControlAllowMethods}; +use hyper::header::{AccessControlRequestHeaders, AccessControlMaxAge, AccessControlRequestMethod, Basic}; +use hyper::header::{CacheControl, CacheDirective, ContentEncoding, ContentLength, ContentLanguage, ContentType}; +use hyper::header::{Encoding, HeaderView, Headers, IfMatch, IfRange, IfUnmodifiedSince, IfModifiedSince}; +use hyper::header::{IfNoneMatch, Pragma, Location, QualityItem, Referer as RefererHeader, UserAgent, q, qitem}; use hyper::method::Method; use hyper::mime::{Attr, Mime, SubLevel, TopLevel, Value}; use hyper::status::StatusCode; @@ -20,9 +20,12 @@ use net_traits::request::{RedirectMode, Referer, Request, RequestMode, ResponseT use net_traits::response::{HttpsState, TerminationReason}; use net_traits::response::{Response, ResponseBody, ResponseType}; use resource_thread::CancellationListener; +use std::collections::HashSet; use std::io::Read; +use std::iter::FromIterator; use std::rc::Rc; use std::thread; +use unicase::UniCase; use url::idna::domain_to_ascii; use url::{Origin as UrlOrigin, OpaqueOrigin, Url, UrlParser, whatwg_scheme_type_mapper}; use util::thread::spawn_named; @@ -210,7 +213,7 @@ fn main_fetch(request: Rc, cors_flag: bool, recursive_flag: bool) -> Re let internal_response = if response.is_network_error() { &network_error_res } else { - response.get_actual_response() + response.actual_response() }; // Step 13 @@ -245,7 +248,7 @@ fn main_fetch(request: Rc, cors_flag: bool, recursive_flag: bool) -> Re // Step 16 if request.synchronous { - response.get_actual_response().wait_until_done(); + response.actual_response().wait_until_done(); return response; } @@ -263,7 +266,7 @@ fn main_fetch(request: Rc, cors_flag: bool, recursive_flag: bool) -> Re let internal_response = if response.is_network_error() { &network_error_res } else { - response.get_actual_response() + response.actual_response() }; // Step 18 @@ -367,7 +370,7 @@ fn http_fetch(request: Rc, } // Substep 4 - let actual_response = res.get_actual_response(); + let actual_response = res.actual_response(); if actual_response.url_list.borrow().is_empty() { *actual_response.url_list.borrow_mut() = request.url_list.borrow().clone(); } @@ -392,7 +395,7 @@ fn http_fetch(request: Rc, }, request.method.borrow().clone()); let method_mismatch = !method_cache_match && (!is_simple_method(&request.method.borrow()) || - request.use_cors_preflight); + request.use_cors_preflight); let header_mismatch = request.headers.borrow().iter().any(|view| !cache.match_header(CacheRequestDetails { origin: origin.clone(), @@ -403,7 +406,7 @@ fn http_fetch(request: Rc, // Sub-substep 1 if method_mismatch || header_mismatch { - let preflight_result = preflight_fetch(request.clone()); + let preflight_result = cors_preflight_fetch(request.clone(), Some(cache)); // Sub-substep 2 if preflight_result.response_type == ResponseType::Error { return Response::network_error(); @@ -417,8 +420,7 @@ fn http_fetch(request: Rc, // Substep 3 let credentials = match request.credentials_mode { CredentialsMode::Include => true, - CredentialsMode::CredentialsSameOrigin if - request.response_tainting.get() == ResponseTainting::Basic + CredentialsMode::CredentialsSameOrigin if request.response_tainting.get() == ResponseTainting::Basic => true, _ => false }; @@ -439,7 +441,7 @@ fn http_fetch(request: Rc, let mut response = response.unwrap(); // Step 5 - match response.get_actual_response().status.unwrap() { + match response.actual_response().status.unwrap() { // Code 301, 302, 303, 307, 308 StatusCode::MovedPermanently | StatusCode::Found | StatusCode::SeeOther | @@ -520,21 +522,21 @@ fn http_redirect_fetch(request: Rc, assert_eq!(response.return_internal.get(), true); // Step 3 - // this step is done early, because querying if Location is available says + // this step is done early, because querying if Location exists says // if it is None or Some, making it easy to seperate from the retrieval failure case - if !response.get_actual_response().headers.has::() { + if !response.actual_response().headers.has::() { return Rc::try_unwrap(response).ok().unwrap(); } // Step 2 - let location = match response.get_actual_response().headers.get::() { + let location = match response.actual_response().headers.get::() { Some(&Location(ref location)) => location.clone(), // Step 4 _ => return Response::network_error() }; // Step 5 - let response_url = response.get_actual_response().url.as_ref().unwrap(); + let response_url = response.actual_response().url.as_ref().unwrap(); let location_url = UrlParser::new().base_url(response_url).parse(&*location); // Step 6 @@ -577,7 +579,7 @@ fn http_redirect_fetch(request: Rc, } // Step 13 - let status_code = response.get_actual_response().status.unwrap(); + let status_code = response.actual_response().status.unwrap(); if ((status_code == StatusCode::MovedPermanently || status_code == StatusCode::Found) && *request.method.borrow() == Method::Post) || status_code == StatusCode::SeeOther { @@ -878,11 +880,11 @@ fn http_network_fetch(request: Rc, // Substep 2 - // TODO how can I tell if response was retrieved over HTTPS? + // TODO Determine if response was retrieved over HTTPS // TODO Servo needs to decide what ciphers are to be treated as "deprecated" response.https_state = HttpsState::None; - // TODO how do I read request? + // TODO Read request // Step 5 // TODO when https://bugzilla.mozilla.org/show_bug.cgi?id=1030660 @@ -927,8 +929,113 @@ fn http_network_fetch(request: Rc, } /// [CORS preflight fetch](https://fetch.spec.whatwg.org#cors-preflight-fetch) -fn preflight_fetch(_request: Rc) -> Response { - // TODO: Implement preflight fetch spec +fn cors_preflight_fetch(request: Rc, cache: Option) -> Response { + // Step 1 + let mut preflight = Request::new(request.current_url(), Some(request.origin.borrow().clone()), false); + *preflight.method.borrow_mut() = Method::Options; + preflight.initiator = request.initiator.clone(); + preflight.type_ = request.type_.clone(); + preflight.destination = request.destination.clone(); + preflight.referer = request.referer.clone(); + + // Step 2 + preflight.headers.borrow_mut().set::( + AccessControlRequestMethod(request.method.borrow().clone())); + + // Step 3, 4 + let mut value = request.headers.borrow().iter() + .filter_map(|ref view| if is_simple_header(view) { + None + } else { + Some(UniCase(view.name().to_owned())) + }).collect::>>(); + value.sort(); + + // Step 5 + preflight.headers.borrow_mut().set::( + AccessControlRequestHeaders(value)); + + // Step 6 + let preflight = Rc::new(preflight); + let response = http_network_or_cache_fetch(preflight.clone(), false, false); + + // Step 7 + if cors_check(request.clone(), &response).is_ok() && + response.status.map_or(false, |status| status.is_success()) { + // Substep 1 + let mut methods = if response.headers.has::() { + match response.headers.get::() { + Some(&AccessControlAllowMethods(ref m)) => m.clone(), + // Substep 3 + None => return Response::network_error() + } + } else { + vec![] + }; + + // Substep 2 + let header_names = if response.headers.has::() { + match response.headers.get::() { + Some(&AccessControlAllowHeaders(ref hn)) => hn.clone(), + // Substep 3 + None => return Response::network_error() + } + } else { + vec![] + }; + + // Substep 4 + if methods.is_empty() && request.use_cors_preflight { + methods = vec![request.method.borrow().clone()]; + } + + // Substep 5 + if methods.iter().all(|method| *method != *request.method.borrow()) && + !is_simple_method(&*request.method.borrow()) { + return Response::network_error(); + } + + // Substep 6 + let set: HashSet<&UniCase> = HashSet::from_iter(header_names.iter()); + if request.headers.borrow().iter().any(|ref hv| !set.contains(&UniCase(hv.name().to_owned())) && + !is_simple_header(hv)) { + return Response::network_error(); + } + + // Substep 7, 8 + let max_age = response.headers.get::().map(|acma| acma.0).unwrap_or(0); + + // TODO: Substep 9 - Need to define what an imposed limit on max-age is + + // Substep 10 + let mut cache = match cache { + Some(c) => c, + None => return response + }; + + // Substep 11, 12 + for method in &methods { + cache.match_method_and_update(CacheRequestDetails { + origin: request.origin.borrow().clone(), + destination: request.current_url(), + credentials: request.credentials_mode == CredentialsMode::Include + }, method.clone(), max_age); + } + + // Substep 13, 14 + for header_name in &header_names { + cache.match_header_and_update(CacheRequestDetails { + origin: request.origin.borrow().clone(), + destination: request.current_url(), + credentials: request.credentials_mode == CredentialsMode::Include + }, &*header_name, max_age); + } + + // Substep 15 + return response; + } + + // Step 8 Response::network_error() } @@ -936,7 +1043,6 @@ fn preflight_fetch(_request: Rc) -> Response { fn cors_check(request: Rc, response: &Response) -> Result<(), ()> { // Step 1 - // let headers = request.headers.borrow(); let origin = response.headers.get::().cloned(); // Step 2 @@ -944,18 +1050,18 @@ fn cors_check(request: Rc, response: &Response) -> Result<(), ()> { // Step 3 if request.credentials_mode != CredentialsMode::Include && - origin == AccessControlAllowOrigin::Any { + origin == AccessControlAllowOrigin::Any { return Ok(()); } // Step 4 let origin = match origin { AccessControlAllowOrigin::Value(origin) => origin, - // if it's Any or Null at this point, I see nothing to do but return Err(()) + // if it's Any or Null at this point, there's nothing to do but return Err(()) _ => return Err(()) }; - // strings are already utf-8 encoded, so I don't need to re-encode origin for this step + // strings are already utf-8 encoded, so there's no need to re-encode origin for this step match ascii_serialise_origin(&request.origin.borrow()) { Ok(request_origin) => { if request_origin != origin { diff --git a/components/net/lib.rs b/components/net/lib.rs index 2b008a89a32..1926bee5085 100644 --- a/components/net/lib.rs +++ b/components/net/lib.rs @@ -27,6 +27,7 @@ extern crate openssl; extern crate rustc_serialize; extern crate threadpool; extern crate time; +extern crate unicase; extern crate url; extern crate util; extern crate uuid; @@ -47,7 +48,7 @@ pub mod resource_thread; pub mod storage_thread; pub mod websocket_loader; -/// An implementation of the [Fetch spec](https://fetch.spec.whatwg.org/) +/// An implementation of the [Fetch specification](https://fetch.spec.whatwg.org/) pub mod fetch { pub mod cors_cache; pub mod methods; diff --git a/components/net_traits/response.rs b/components/net_traits/response.rs index 9e6cedd2765..1c1796df52a 100644 --- a/components/net_traits/response.rs +++ b/components/net_traits/response.rs @@ -144,7 +144,7 @@ impl Response { } } - pub fn get_actual_response(&self) -> &Response { + pub fn actual_response(&self) -> &Response { if self.return_internal.get() && self.internal_response.is_some() { &**self.internal_response.as_ref().unwrap() } else { diff --git a/components/servo/Cargo.lock b/components/servo/Cargo.lock index 122dca34ba6..ba51bf7418e 100644 --- a/components/servo/Cargo.lock +++ b/components/servo/Cargo.lock @@ -918,7 +918,7 @@ dependencies = [ "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", "traitobject 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)", "typeable 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1294,6 +1294,7 @@ dependencies = [ "rustc-serialize 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", "threadpool 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", "uuid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1327,7 +1328,7 @@ dependencies = [ "net_traits 0.0.1", "plugins 0.0.1", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", ] @@ -1737,7 +1738,7 @@ dependencies = [ "string_cache 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", "style 0.0.1", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", "uuid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2102,8 +2103,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "unicase" -version = "1.0.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rustc_version 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "unicode-bidi" @@ -2355,7 +2359,7 @@ dependencies = [ "openssl 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)", "rustc-serialize 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", ] diff --git a/ports/cef/Cargo.lock b/ports/cef/Cargo.lock index 19a76595c6a..44482780143 100644 --- a/ports/cef/Cargo.lock +++ b/ports/cef/Cargo.lock @@ -837,7 +837,7 @@ dependencies = [ "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", "traitobject 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)", "typeable 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1206,6 +1206,7 @@ dependencies = [ "rustc-serialize 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", "threadpool 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", "uuid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1604,7 +1605,7 @@ dependencies = [ "string_cache 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", "style 0.0.1", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", "uuid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1981,8 +1982,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "unicase" -version = "1.0.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rustc_version 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "unicode-bidi" @@ -2223,7 +2227,7 @@ dependencies = [ "openssl 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)", "rustc-serialize 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", ] diff --git a/ports/gonk/Cargo.lock b/ports/gonk/Cargo.lock index b741ead8af2..49b5dee5f4a 100644 --- a/ports/gonk/Cargo.lock +++ b/ports/gonk/Cargo.lock @@ -819,7 +819,7 @@ dependencies = [ "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", "traitobject 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)", "typeable 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1188,6 +1188,7 @@ dependencies = [ "rustc-serialize 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", "threadpool 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", "uuid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1586,7 +1587,7 @@ dependencies = [ "string_cache 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", "style 0.0.1", "time 0.1.34 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", "util 0.0.1", "uuid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1961,8 +1962,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "unicase" -version = "1.0.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rustc_version 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "unicode-bidi" @@ -2173,7 +2177,7 @@ dependencies = [ "openssl 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)", "rustc-serialize 0.3.16 (registry+https://github.com/rust-lang/crates.io-index)", - "unicase 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicase 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 0.5.8 (registry+https://github.com/rust-lang/crates.io-index)", ] diff --git a/tests/unit/net/fetch.rs b/tests/unit/net/fetch.rs index 4b6eff5da69..b2353d33113 100644 --- a/tests/unit/net/fetch.rs +++ b/tests/unit/net/fetch.rs @@ -2,7 +2,8 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use hyper::header::{AccessControlAllowHeaders, AccessControlAllowOrigin}; +use hyper::header::{AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowOrigin}; +use hyper::header::{AccessControlAllowMethods, AccessControlRequestHeaders, AccessControlRequestMethod}; use hyper::header::{CacheControl, ContentLanguage, ContentType, Expires, LastModified}; use hyper::header::{Headers, HttpDate, Location, SetCookie, Pragma}; use hyper::method::Method; @@ -16,6 +17,7 @@ use net_traits::AsyncFetchListener; use net_traits::request::{Origin, RedirectMode, Referer, Request, RequestMode}; use net_traits::response::{CacheState, Response, ResponseBody, ResponseType}; use std::rc::Rc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{Sender, channel}; use std::sync::{Arc, Mutex}; use time::{self, Duration}; @@ -136,6 +138,74 @@ fn test_fetch_data() { } } +#[test] +fn test_cors_preflight_fetch() { + static ACK: &'static [u8] = b"ACK"; + let state = Arc::new(AtomicUsize::new(0)); + let handler = move |request: HyperRequest, mut response: HyperResponse| { + if request.method == Method::Options && state.clone().fetch_add(1, Ordering::SeqCst) == 0 { + assert!(request.headers.has::()); + assert!(request.headers.has::()); + response.headers_mut().set(AccessControlAllowOrigin::Any); + response.headers_mut().set(AccessControlAllowCredentials); + response.headers_mut().set(AccessControlAllowMethods(vec![Method::Get])); + } else { + response.headers_mut().set(AccessControlAllowOrigin::Any); + response.send(ACK).unwrap(); + } + }; + let (mut server, url) = make_server(handler); + + let origin = Origin::Origin(UrlOrigin::UID(OpaqueOrigin::new())); + let mut request = Request::new(url, Some(origin), false); + request.referer = Referer::NoReferer; + request.use_cors_preflight = true; + request.mode = RequestMode::CORSMode; + let wrapped_request = Rc::new(request); + + let fetch_response = fetch(wrapped_request); + let _ = server.close(); + + assert!(!fetch_response.is_network_error()); + + match *fetch_response.body.lock().unwrap() { + ResponseBody::Done(ref body) => assert_eq!(&**body, ACK), + _ => panic!() + }; +} + +#[test] +fn test_cors_preflight_fetch_network_error() { + static ACK: &'static [u8] = b"ACK"; + let state = Arc::new(AtomicUsize::new(0)); + let handler = move |request: HyperRequest, mut response: HyperResponse| { + if request.method == Method::Options && state.clone().fetch_add(1, Ordering::SeqCst) == 0 { + assert!(request.headers.has::()); + assert!(request.headers.has::()); + response.headers_mut().set(AccessControlAllowOrigin::Any); + response.headers_mut().set(AccessControlAllowCredentials); + response.headers_mut().set(AccessControlAllowMethods(vec![Method::Get])); + } else { + response.headers_mut().set(AccessControlAllowOrigin::Any); + response.send(ACK).unwrap(); + } + }; + let (mut server, url) = make_server(handler); + + let origin = Origin::Origin(UrlOrigin::UID(OpaqueOrigin::new())); + let mut request = Request::new(url, Some(origin), false); + *request.method.borrow_mut() = Method::Extension("CHICKEN".to_owned()); + request.referer = Referer::NoReferer; + request.use_cors_preflight = true; + request.mode = RequestMode::CORSMode; + let wrapped_request = Rc::new(request); + + let fetch_response = fetch(wrapped_request); + let _ = server.close(); + + assert!(fetch_response.is_network_error()); +} + #[test] fn test_fetch_response_is_basic_filtered() { @@ -342,7 +412,7 @@ fn test_fetch_with_local_urls_only() { assert!(server_response.is_network_error()); } -fn test_fetch_redirect_count(message: &'static [u8], redirect_cap: u32) -> Response { +fn setup_server_and_fetch(message: &'static [u8], redirect_cap: u32) -> Response { let handler = move |request: HyperRequest, mut response: HyperResponse| { @@ -382,7 +452,7 @@ fn test_fetch_redirect_count_ceiling() { // how many redirects to cause let redirect_cap = 20; - let fetch_response = test_fetch_redirect_count(MESSAGE, redirect_cap); + let fetch_response = setup_server_and_fetch(MESSAGE, redirect_cap); assert!(!fetch_response.is_network_error()); assert_eq!(fetch_response.response_type, ResponseType::Basic); @@ -402,7 +472,7 @@ fn test_fetch_redirect_count_failure() { // how many redirects to cause let redirect_cap = 21; - let fetch_response = test_fetch_redirect_count(MESSAGE, redirect_cap); + let fetch_response = setup_server_and_fetch(MESSAGE, redirect_cap); assert!(fetch_response.is_network_error());