diff --git a/components/net/fetch/cors_cache.rs b/components/net/fetch/cors_cache.rs index 1f6d39f71b7..f6e83f79459 100644 --- a/components/net/fetch/cors_cache.rs +++ b/components/net/fetch/cors_cache.rs @@ -10,7 +10,7 @@ //! with CORSRequest being expanded into FetchRequest (etc) use hyper::method::Method; -use net_traits::request::Origin; +use net_traits::request::{CredentialsMode, Origin, Request}; use std::ascii::AsciiExt; use time; use time::{now, Timespec}; @@ -66,17 +66,10 @@ impl CORSCacheEntry { } } -/// Properties of Request required to cache match. -pub struct CacheRequestDetails { - pub origin: Origin, - pub destination: Url, - pub credentials: bool -} - -fn match_headers(cors_cache: &CORSCacheEntry, cors_req: &CacheRequestDetails) -> bool { - cors_cache.origin == cors_req.origin && - cors_cache.url == cors_req.destination && - (cors_cache.credentials || !cors_req.credentials) +fn match_headers(cors_cache: &CORSCacheEntry, cors_req: &Request) -> bool { + cors_cache.origin == *cors_req.origin.borrow() && + cors_cache.url == cors_req.current_url() && + (cors_cache.credentials || cors_req.credentials_mode != CredentialsMode::Include) } /// A simple, vector-based CORS Cache @@ -89,13 +82,13 @@ impl CORSCache { CORSCache(vec![]) } - fn find_entry_by_header<'a>(&'a mut self, request: &CacheRequestDetails, + fn find_entry_by_header<'a>(&'a mut self, request: &Request, header_name: &str) -> Option<&'a mut CORSCacheEntry> { self.cleanup(); self.0.iter_mut().find(|e| match_headers(e, request) && e.header_or_method.match_header(header_name)) } - fn find_entry_by_method<'a>(&'a mut self, request: &CacheRequestDetails, + fn find_entry_by_method<'a>(&'a mut self, request: &Request, method: Method) -> Option<&'a mut CORSCacheEntry> { // we can take the method from CORSRequest itself self.cleanup(); @@ -103,10 +96,11 @@ impl CORSCache { } /// [Clear the cache](https://fetch.spec.whatwg.org/#concept-cache-clear) - pub fn clear (&mut self, request: CacheRequestDetails) { + pub fn clear (&mut self, request: &Request) { let CORSCache(buf) = self.clone(); let new_buf: Vec = - buf.into_iter().filter(|e| e.origin == request.origin && request.destination == e.url).collect(); + buf.into_iter().filter(|e| e.origin == *request.origin.borrow() && + request.current_url() == e.url).collect(); *self = CORSCache(new_buf); } @@ -122,7 +116,7 @@ impl CORSCache { /// Returns true if an entry with a /// [matching header](https://fetch.spec.whatwg.org/#concept-cache-match-header) is found - pub fn match_header(&mut self, request: CacheRequestDetails, header_name: &str) -> bool { + pub fn match_header(&mut self, request: &Request, header_name: &str) -> bool { self.find_entry_by_header(&request, header_name).is_some() } @@ -130,13 +124,13 @@ impl CORSCache { /// [matching header](https://fetch.spec.whatwg.org/#concept-cache-match-header) is found. /// /// If not, it will insert an equivalent entry - pub fn match_header_and_update(&mut self, request: CacheRequestDetails, + pub fn match_header_and_update(&mut self, request: &Request, header_name: &str, new_max_age: u32) -> bool { match self.find_entry_by_header(&request, header_name).map(|e| e.max_age = new_max_age) { Some(_) => true, None => { - self.insert(CORSCacheEntry::new(request.origin, request.destination, new_max_age, - request.credentials, + self.insert(CORSCacheEntry::new(request.origin.borrow().clone(), request.current_url(), new_max_age, + request.credentials_mode == CredentialsMode::Include, HeaderOrMethod::HeaderData(header_name.to_owned()))); false } @@ -145,7 +139,7 @@ impl CORSCache { /// Returns true if an entry with a /// [matching method](https://fetch.spec.whatwg.org/#concept-cache-match-method) is found - pub fn match_method(&mut self, request: CacheRequestDetails, method: Method) -> bool { + pub fn match_method(&mut self, request: &Request, method: Method) -> bool { self.find_entry_by_method(&request, method).is_some() } @@ -153,12 +147,13 @@ impl CORSCache { /// [a matching method](https://fetch.spec.whatwg.org/#concept-cache-match-method) is found. /// /// If not, it will insert an equivalent entry - pub fn match_method_and_update(&mut self, request: CacheRequestDetails, method: Method, new_max_age: u32) -> bool { + pub fn match_method_and_update(&mut self, request: &Request, method: Method, new_max_age: u32) -> bool { match self.find_entry_by_method(&request, method.clone()).map(|e| e.max_age = new_max_age) { Some(_) => true, None => { - self.insert(CORSCacheEntry::new(request.origin, request.destination, new_max_age, - request.credentials, HeaderOrMethod::MethodData(method))); + self.insert(CORSCacheEntry::new(request.origin.borrow().clone(), request.current_url(), new_max_age, + request.credentials_mode == CredentialsMode::Include, + HeaderOrMethod::MethodData(method))); false } } diff --git a/components/net/fetch/methods.rs b/components/net/fetch/methods.rs index 1b4c43195ea..df5eae0a892 100644 --- a/components/net/fetch/methods.rs +++ b/components/net/fetch/methods.rs @@ -3,7 +3,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use data_loader::decode; -use fetch::cors_cache::{CORSCache, CacheRequestDetails}; +use fetch::cors_cache::CORSCache; use http_loader::{NetworkHttpRequestFactory, create_http_connector, obtain_response}; use hyper::header::{Accept, AcceptLanguage, Authorization, AccessControlAllowCredentials}; use hyper::header::{AccessControlAllowOrigin, AccessControlAllowHeaders, AccessControlAllowMethods}; @@ -402,23 +402,13 @@ fn http_fetch(request: Rc, // Substep 1 if cors_preflight_flag { - let origin = request.origin.borrow().clone(); - let url = request.current_url(); - let credentials = request.credentials_mode == CredentialsMode::Include; - let method_cache_match = cache.match_method(CacheRequestDetails { - origin: origin.clone(), - destination: url.clone(), - credentials: credentials - }, request.method.borrow().clone()); + let method_cache_match = cache.match_method(&*request, + request.method.borrow().clone()); let method_mismatch = !method_cache_match && (!is_simple_method(&request.method.borrow()) || request.use_cors_preflight); let header_mismatch = request.headers.borrow().iter().any(|view| - !cache.match_header(CacheRequestDetails { - origin: origin.clone(), - destination: url.clone(), - credentials: credentials - }, view.name()) && !is_simple_header(&view) + !cache.match_header(&*request, view.name()) && !is_simple_header(&view) ); // Sub-substep 1 @@ -1027,20 +1017,12 @@ fn cors_preflight_fetch(request: Rc, cache: &mut CORSCache) -> Response // Substep 11, 12 for method in &methods { - cache.match_method_and_update(CacheRequestDetails { - origin: request.origin.borrow().clone(), - destination: request.current_url(), - credentials: request.credentials_mode == CredentialsMode::Include - }, method.clone(), max_age); + cache.match_method_and_update(&*request, method.clone(), max_age); } // Substep 13, 14 for header_name in &header_names { - cache.match_header_and_update(CacheRequestDetails { - origin: request.origin.borrow().clone(), - destination: request.current_url(), - credentials: request.credentials_mode == CredentialsMode::Include - }, &*header_name, max_age); + cache.match_header_and_update(&*request, &*header_name, max_age); } // Substep 15 diff --git a/tests/unit/net/fetch.rs b/tests/unit/net/fetch.rs index d880e3937c0..a53807bff38 100644 --- a/tests/unit/net/fetch.rs +++ b/tests/unit/net/fetch.rs @@ -13,7 +13,7 @@ use hyper::server::{Handler, Listening, Server}; use hyper::server::{Request as HyperRequest, Response as HyperResponse}; use hyper::status::StatusCode; use hyper::uri::RequestUri; -use net::fetch::cors_cache::{CacheRequestDetails, CORSCache}; +use net::fetch::cors_cache::CORSCache; use net::fetch::methods::{fetch, fetch_async, fetch_with_cors_cache}; use net_traits::AsyncFetchListener; use net_traits::request::{Origin, RedirectMode, Referer, Request, RequestMode}; @@ -238,8 +238,8 @@ fn test_cors_preflight_cache_fetch() { let wrapped_request0 = Rc::new(request.clone()); let wrapped_request1 = Rc::new(request); - let fetch_response0 = fetch_with_cors_cache(wrapped_request0, &mut cache); - let fetch_response1 = fetch_with_cors_cache(wrapped_request1, &mut cache); + let fetch_response0 = fetch_with_cors_cache(wrapped_request0.clone(), &mut cache); + let fetch_response1 = fetch_with_cors_cache(wrapped_request1.clone(), &mut cache); let _ = server.close(); assert!(!fetch_response0.is_network_error() && !fetch_response1.is_network_error()); @@ -248,11 +248,8 @@ fn test_cors_preflight_cache_fetch() { assert_eq!(1, counter.load(Ordering::SeqCst)); // The entry exists in the CORS-preflight cache - assert_eq!(true, cache.match_method(CacheRequestDetails { - origin: origin, - destination: url, - credentials: false - }, Method::Get)); + assert_eq!(true, cache.match_method(&*wrapped_request0, Method::Get)); + assert_eq!(true, cache.match_method(&*wrapped_request1, Method::Get)); match *fetch_response0.body.lock().unwrap() { ResponseBody::Done(ref body) => assert_eq!(&**body, ACK),