Update tungstenite

This commit is contained in:
Fabrice Desré 2023-02-25 13:22:47 -08:00 committed by The Capyloon Team
parent bc8cea2495
commit 0d0540fc95
5 changed files with 98 additions and 264 deletions

View file

@ -22,7 +22,7 @@ use embedder_traits::resources::{self, Resource};
use futures::future::TryFutureExt;
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use http::header::{self, HeaderName, HeaderValue};
use ipc_channel::ipc::{IpcReceiver, IpcSender};
use ipc_channel::router::ROUTER;
use net_traits::request::{RequestBuilder, RequestMode};
@ -33,14 +33,13 @@ use servo_url::ServoUrl;
use std::fs;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio2::net::TcpStream;
use tokio2::runtime::Runtime;
use tokio2::select;
use tokio2::sync::mpsc::{unbounded_channel, UnboundedReceiver};
use tungstenite::error::Error;
use tokio::net::TcpStream;
use tokio::runtime::Runtime;
use tokio::select;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
use tungstenite::error::Result as WebSocketResult;
use tungstenite::error::{Error, ProtocolError, UrlError};
use tungstenite::handshake::client::{Request, Response};
use tungstenite::http::header::{self as WSHeader, HeaderValue as WSHeaderValue};
use tungstenite::protocol::CloseFrame;
use tungstenite::Message;
use url::Url;
@ -65,20 +64,32 @@ fn create_request(
) -> WebSocketResult<Request> {
let mut builder = Request::get(resource_url.as_str());
let headers = builder.headers_mut().unwrap();
headers.insert("Origin", WSHeaderValue::from_str(origin)?);
headers.insert("Origin", HeaderValue::from_str(origin)?);
let origin = resource_url.origin();
let host = format!(
"{}",
origin
.host()
.ok_or_else(|| Error::Url(UrlError::NoHostName))?
);
headers.insert("Host", HeaderValue::from_str(&host)?);
headers.insert("Connection", HeaderValue::from_static("upgrade"));
headers.insert("Upgrade", HeaderValue::from_static("websocket"));
headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
headers.insert("Sec-WebSocket-Key", key);
if !protocols.is_empty() {
let protocols = protocols.join(",");
headers.insert(
"Sec-WebSocket-Protocol",
WSHeaderValue::from_str(&protocols)?,
);
headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
}
let mut cookie_jar = http_state.cookie_jar.write().unwrap();
cookie_jar.remove_expired_cookies_for_url(resource_url);
if let Some(cookie_list) = cookie_jar.cookies_for_url(resource_url, CookieSource::HTTP) {
headers.insert("Cookie", WSHeaderValue::from_str(&cookie_list)?);
headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
}
if resource_url.password().is_some() || resource_url.username() != "" {
@ -89,7 +100,7 @@ fn create_request(
));
headers.insert(
"Authorization",
WSHeaderValue::from_str(&format!("Basic {}", basic))?,
HeaderValue::from_str(&format!("Basic {}", basic))?,
);
}
@ -110,18 +121,18 @@ fn process_ws_response(
trace!("processing websocket http response for {}", resource_url);
let mut protocol_in_use = None;
if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
let protocol_name = protocol_name.to_str().unwrap();
let protocol_name = protocol_name.to_str().unwrap_or("");
if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
return Err(Error::Protocol(
"Protocol in use not in client-supplied protocol list".into(),
));
return Err(Error::Protocol(ProtocolError::InvalidHeader(
HeaderName::from_static("sec-websocket-protocol"),
)));
}
protocol_in_use = Some(protocol_name.to_string());
}
let mut jar = http_state.cookie_jar.write().unwrap();
// TODO(eijebong): Replace thise once typed headers settled on a cookie impl
for cookie in response.headers().get_all(WSHeader::SET_COOKIE) {
for cookie in response.headers().get_all(header::SET_COOKIE) {
if let Ok(s) = std::str::from_utf8(cookie.as_bytes()) {
if let Some(cookie) =
Cookie::from_cookie_string(s.into(), resource_url, CookieSource::HTTP)
@ -131,23 +142,11 @@ fn process_ws_response(
}
}
// We need to make a new header map here because tungstenite depends on
// a more recent version of http than the rest of the network stack, so the
// HeaderMap types are incompatible.
let mut headers = HeaderMap::new();
for (key, value) in response.headers().iter() {
if let (Ok(key), Ok(value)) = (
HeaderName::from_bytes(key.as_ref()),
HeaderValue::from_bytes(value.as_ref()),
) {
headers.insert(key, value);
}
}
http_state
.hsts_list
.write()
.unwrap()
.update_hsts_list_from_response(resource_url, &headers);
.update_hsts_list_from_response(resource_url, &response.headers());
Ok(protocol_in_use)
}
@ -283,6 +282,10 @@ async fn run_ws_loop(
));
break;
}
Message::Frame(_) => {
warn!("Unexpected websocket frame message");
}
}
}
}
@ -309,20 +312,20 @@ async fn start_websocket(
let host_str = client
.uri()
.host()
.ok_or_else(|| Error::Url("No host string".into()))?;
.ok_or_else(|| Error::Url(UrlError::NoHostName))?;
let host = replace_host(host_str);
let mut net_url =
Url::parse(&client.uri().to_string()).map_err(|e| Error::Url(e.to_string().into()))?;
let mut net_url = Url::parse(&client.uri().to_string())
.map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
net_url
.set_host(Some(&host))
.map_err(|e| Error::Url(e.to_string().into()))?;
.map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
let domain = net_url
.host()
.ok_or_else(|| Error::Url("No host string".into()))?;
.ok_or_else(|| Error::Url(UrlError::NoHostName))?;
let port = net_url
.port_or_known_default()
.ok_or_else(|| Error::Url("Unknown port".into()))?;
.ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
@ -366,32 +369,13 @@ fn connect(
};
// https://fetch.spec.whatwg.org/#websocket-opening-handshake
// By standard, we should work with an http(s):// URL (req_url),
// but as ws-rs expects to be called with a ws(s):// URL (net_url)
// we upgrade ws to wss, so we don't have to convert http(s) back to ws(s).
http_state
.hsts_list
.read()
.unwrap()
.apply_hsts_rules(&mut req_builder.url);
let scheme = req_builder.url.scheme();
let mut req_url = req_builder.url.clone();
match scheme {
"ws" => {
req_url
.as_mut_url()
.set_scheme("http")
.map_err(|()| "couldn't replace scheme".to_string())?;
},
"wss" => {
req_url
.as_mut_url()
.set_scheme("https")
.map_err(|()| "couldn't replace scheme".to_string())?;
},
_ => {},
}
let req_url = req_builder.url.clone();
if should_be_blocked_due_to_bad_port(&req_url) {
return Err("Port blocked".to_string());
@ -403,7 +387,7 @@ fn connect(
};
let client = match create_request(
&req_builder.url,
&req_url,
&req_builder.origin.ascii_serialization(),
&protocols,
&*http_state,