From de2b9b7707ad07a5ae92560b0e761b2bda42cca4 Mon Sep 17 00:00:00 2001 From: rabisg Date: Sun, 1 Jan 2017 17:57:34 +0530 Subject: [PATCH] Fixes #14787: Set Origin header in http_network_or_cache_fetch Sets Origin header on request with CORS flag set or on requests other than those with GET/HEAD methods --- components/net/http_loader.rs | 24 ++++++++++-- tests/unit/net/http_loader.rs | 74 ++++++++++++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 9 deletions(-) diff --git a/components/net/http_loader.rs b/components/net/http_loader.rs index 91bf0b643f3..9ffc95b8620 100644 --- a/components/net/http_loader.rs +++ b/components/net/http_loader.rs @@ -22,6 +22,7 @@ use hyper::header::{Authorization, Basic, CacheControl, CacheDirective, ContentE use hyper::header::{ContentLength, Encoding, Header, Headers, Host, IfMatch, IfRange}; use hyper::header::{IfUnmodifiedSince, IfModifiedSince, IfNoneMatch, Location, Pragma, Quality}; use hyper::header::{QualityItem, Referer, SetCookie, UserAgent, qitem}; +use hyper::header::Origin as HyperOrigin; use hyper::method::Method; use hyper::net::Fresh; use hyper::status::StatusCode; @@ -785,6 +786,15 @@ fn http_redirect_fetch(request: Rc, main_fetch(request, cache, cors_flag, true, target, done_chan, context) } +fn try_immutable_origin_to_hyper_origin(url_origin: &ImmutableOrigin) -> Option { + match *url_origin { + // TODO (servo/servo#15569) Set "Origin: null" when hyper supports it + ImmutableOrigin::Opaque(_) => None, + ImmutableOrigin::Tuple(ref scheme, ref host, ref port) => + Some(HyperOrigin::new(scheme.clone(), host.to_string(), Some(port.clone()))) + } +} + /// [HTTP network or cache fetch](https://fetch.spec.whatwg.org#http-network-or-cache-fetch) fn http_network_or_cache_fetch(request: Rc, authentication_fetch_flag: bool, @@ -843,10 +853,16 @@ fn http_network_or_cache_fetch(request: Rc, }; // Step 9 - if cors_flag || - (*http_request.method.borrow() != Method::Get && *http_request.method.borrow() != Method::Head) { - // TODO update this when https://github.com/hyperium/hyper/pull/691 is finished - // http_request.headers.borrow_mut().set_raw("origin", origin); + if !http_request.omit_origin_header.get() { + let method = http_request.method.borrow(); + if cors_flag || (*method != Method::Get && *method != Method::Head) { + debug_assert!(*http_request.origin.borrow() != Origin::Client); + if let Origin::Origin(ref url_origin) = *http_request.origin.borrow() { + if let Some(hyper_origin) = try_immutable_origin_to_hyper_origin(url_origin) { + http_request.headers.borrow_mut().set(hyper_origin) + } + } + } } // Step 10 diff --git a/tests/unit/net/http_loader.rs b/tests/unit/net/http_loader.rs index eae4f4c3c4c..94da5a975fc 100644 --- a/tests/unit/net/http_loader.rs +++ b/tests/unit/net/http_loader.rs @@ -13,8 +13,8 @@ use flate2::Compression; use flate2::write::{DeflateEncoder, GzEncoder}; use hyper::LanguageTag; use hyper::header::{Accept, AcceptEncoding, ContentEncoding, ContentLength, Cookie as CookieHeader}; -use hyper::header::{AcceptLanguage, Authorization, Basic, Date}; -use hyper::header::{Encoding, Headers, Host, Location, Quality, QualityItem, SetCookie, qitem}; +use hyper::header::{AcceptLanguage, AccessControlAllowOrigin, Authorization, Basic, Date}; +use hyper::header::{Encoding, Headers, Host, Location, Origin, Quality, QualityItem, SetCookie, qitem}; use hyper::header::{StrictTransportSecurity, UserAgent}; use hyper::method::Method; use hyper::mime::{Mime, SubLevel, TopLevel}; @@ -28,12 +28,13 @@ use net::cookie_storage::CookieStorage; use net::resource_thread::AuthCacheEntry; use net_traits::{CookieSource, NetworkError}; use net_traits::hosts::replace_host_table; -use net_traits::request::{Request, RequestInit, CredentialsMode, Destination}; +use net_traits::request::{Request, RequestInit, RequestMode, CredentialsMode, Destination}; use net_traits::response::ResponseBody; use new_fetch_context; use servo_url::ServoUrl; use std::collections::HashMap; use std::io::{Read, Write}; +use std::str::FromStr; use std::sync::{Arc, Mutex, RwLock, mpsc}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::Receiver; @@ -146,8 +147,13 @@ fn test_check_default_headers_loaded_in_every_request() { assert!(response.status.unwrap().is_success()); // Testing for method.POST - headers.set(ContentLength(0 as u64)); - *expected_headers.lock().unwrap() = Some(headers.clone()); + let mut post_headers = headers.clone(); + post_headers.set(ContentLength(0 as u64)); + let url_str = url.as_str(); + // request gets header "Origin: http://example.com" but expected_headers has + // "Origin: http://example.com/" which do not match for equality so strip trailing '/' + post_headers.set(Origin::from_str(&url_str[..url_str.len()-1]).unwrap()); + *expected_headers.lock().unwrap() = Some(post_headers); let request = Request::from_init(RequestInit { url: url.clone(), method: Method::Post, @@ -1118,3 +1124,61 @@ fn test_auth_ui_needs_www_auth() { assert_eq!(response.status.unwrap(), StatusCode::Unauthorized); } + +#[test] +fn test_origin_set() { + let origin_header = Arc::new(Mutex::new(None)); + let origin_header_clone = origin_header.clone(); + let handler = move |request: HyperRequest, mut resp: HyperResponse| { + let origin_header_clone = origin_header.clone(); + resp.headers_mut().set(AccessControlAllowOrigin::Any); + match request.headers.get::() { + None => assert_eq!(origin_header_clone.lock().unwrap().take(), None), + Some(h) => assert_eq!(*h, origin_header_clone.lock().unwrap().take().unwrap()), + } + }; + let (mut server, url) = make_server(handler); + + let mut origin = Origin::new(url.scheme(), url.host_str().unwrap(), url.port()); + *origin_header_clone.lock().unwrap() = Some(origin.clone()); + let request = Request::from_init(RequestInit { + url: url.clone(), + method: Method::Post, + body: None, + origin: url.clone(), + .. RequestInit::default() + }); + let response = fetch(request, None); + assert!(response.status.unwrap().is_success()); + + let origin_url = ServoUrl::parse("http://example.com").unwrap(); + origin = Origin::new(origin_url.scheme(), origin_url.host_str().unwrap(), origin_url.port()); + // Test Origin header is set on Get request with CORS mode + let request = Request::from_init(RequestInit { + url: url.clone(), + method: Method::Get, + mode: RequestMode::CorsMode, + body: None, + origin: origin_url.clone(), + .. RequestInit::default() + }); + + *origin_header_clone.lock().unwrap() = Some(origin.clone()); + let response = fetch(request, None); + assert!(response.status.unwrap().is_success()); + + // Test Origin header is not set on method Head + let request = Request::from_init(RequestInit { + url: url.clone(), + method: Method::Head, + body: None, + origin: url.clone(), + .. RequestInit::default() + }); + + *origin_header_clone.lock().unwrap() = None; + let response = fetch(request, None); + assert!(response.status.unwrap().is_success()); + + let _ = server.close(); +}