Refactor mime_classifier

Use more iterators in particular.
This commit is contained in:
Johann Tuffe 2015-08-28 22:47:41 +08:00
parent 71b277d567
commit dd1c8c826e

View file

@ -3,7 +3,6 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
use std::borrow::ToOwned; use std::borrow::ToOwned;
use std::cmp::max;
pub struct MIMEClassifier { pub struct MIMEClassifier {
image_classifier: GroupedClassifier, image_classifier: GroupedClassifier,
@ -24,51 +23,34 @@ impl MIMEClassifier {
data: &[u8]) -> Option<(String, String)> { data: &[u8]) -> Option<(String, String)> {
match *supplied_type { match *supplied_type {
None => { None => self.sniff_unknown_type(!no_sniff, data),
return self.sniff_unknown_type(!no_sniff, data);
}
Some((ref media_type, ref media_subtype)) => { Some((ref media_type, ref media_subtype)) => {
match (&**media_type, &**media_subtype) { match (&**media_type, &**media_subtype) {
("unknown", "unknown") | ("application", "unknown") | ("*", "*") => { ("unknown", "unknown") |
return self.sniff_unknown_type(!no_sniff, data); ("application", "unknown") |
} ("*", "*") => self.sniff_unknown_type(!no_sniff, data),
_ => { _ => {
if no_sniff { if no_sniff {
return supplied_type.clone(); supplied_type.clone()
} } else if check_for_apache_bug {
if check_for_apache_bug { self.sniff_text_or_data(data)
return self.sniff_text_or_data(data); } else if MIMEClassifier::is_xml(media_type, media_subtype) {
} supplied_type.clone()
} else if MIMEClassifier::is_html(media_type, media_subtype) {
if MIMEClassifier::is_xml(media_type, media_subtype) { //Implied in section 7.3, but flow is not clear
return supplied_type.clone(); self.feeds_classifier.classify(data).or(supplied_type.clone())
} } else {
//Inplied in section 7.3, but flow is not clear match (&**media_type, &**media_subtype) {
if MIMEClassifier::is_html(media_type, media_subtype) { ("image", _) => self.image_classifier.classify(data),
return self.feeds_classifier ("audio", _) | ("video", _) | ("application", "ogg") =>
.classify(data) self.audio_video_classifer.classify(data),
.or(supplied_type.clone()); _ => None
} }.or(supplied_type.clone())
if &**media_type == "image" {
if let Some(tp) = self.image_classifier.classify(data) {
return Some(tp);
}
}
match (&**media_type, &**media_subtype) {
("audio", _) | ("video", _) | ("application", "ogg") => {
if let Some(tp) = self.audio_video_classifer.classify(data) {
return Some(tp);
}
}
_ => {}
} }
} }
} }
} }
} }
return supplied_type.clone();
} }
pub fn new() -> MIMEClassifier { pub fn new() -> MIMEClassifier {
@ -99,13 +81,15 @@ impl MIMEClassifier {
fn sniff_text_or_data(&self, data: &[u8]) -> Option<(String, String)> { fn sniff_text_or_data(&self, data: &[u8]) -> Option<(String, String)> {
self.binary_or_plaintext.classify(data) self.binary_or_plaintext.classify(data)
} }
fn is_xml(tp: &str, sub_tp: &str) -> bool { fn is_xml(tp: &str, sub_tp: &str) -> bool {
let suffix = &sub_tp[(max(sub_tp.len() as isize - "+xml".len() as isize, 0) as usize)..]; sub_tp.ends_with("+xml") ||
match (tp, sub_tp, suffix) { match (tp, sub_tp) {
(_, _, "+xml") | ("application", "xml",_) | ("text", "xml",_) => {true} ("application", "xml") | ("text", "xml") => true,
_ => {false} _ => false
} }
} }
fn is_html(tp: &str, sub_tp: &str) -> bool { fn is_html(tp: &str, sub_tp: &str) -> bool {
tp == "text" && sub_tp == "html" tp == "text" && sub_tp == "html"
} }
@ -141,13 +125,11 @@ impl <'a, T: Iterator<Item=&'a u8> + Clone> Matches for T {
// Side effects // Side effects
// moves the iterator when match is found // moves the iterator when match is found
fn matches(&mut self, matches: &[u8]) -> bool { fn matches(&mut self, matches: &[u8]) -> bool {
for (byte_a, byte_b) in self.clone().take(matches.len()).zip(matches) { let result = self.clone().zip(matches).all(|(s, m)| *s == *m);
if byte_a != byte_b { if result {
return false; self.nth(matches.len());
}
} }
self.nth(matches.len()); result
true
} }
} }
@ -155,36 +137,27 @@ struct ByteMatcher {
pattern: &'static [u8], pattern: &'static [u8],
mask: &'static [u8], mask: &'static [u8],
leading_ignore: &'static [u8], leading_ignore: &'static [u8],
content_type: (&'static str,&'static str) content_type: (&'static str, &'static str)
} }
impl ByteMatcher { impl ByteMatcher {
fn matches(&self, data: &[u8]) -> Option<usize> { fn matches(&self, data: &[u8]) -> Option<usize> {
if data.len() < self.pattern.len() { if data.len() < self.pattern.len() {
return None; None
} else if data == self.pattern {
Some(self.pattern.len())
} else {
data[..data.len() - self.pattern.len()].iter()
.position(|x| !self.leading_ignore.contains(x))
.and_then(|start|
if data[start..].iter()
.zip(self.pattern.iter()).zip(self.mask.iter())
.all(|((&data, &pattern), &mask)| (data & mask) == (pattern & mask)) {
Some(start + self.pattern.len())
} else {
None
})
} }
//TODO replace with iterators if I ever figure them out...
let mut i: usize = 0;
let max_i = data.len()-self.pattern.len();
loop {
if !self.leading_ignore.iter().any(|x| *x == data[i]) {
break;
}
i = i + 1;
if i > max_i {
return None;
}
}
for j in 0..self.pattern.len() {
if (data[i] & self.mask[j]) != (self.pattern[j] & self.mask[j]) {
return None;
}
i = i + 1;
}
Some(i)
} }
} }
@ -202,14 +175,13 @@ struct TagTerminatedByteMatcher {
impl MIMEChecker for TagTerminatedByteMatcher { impl MIMEChecker for TagTerminatedByteMatcher {
fn classify(&self, data: &[u8]) -> Option<(String, String)> { fn classify(&self, data: &[u8]) -> Option<(String, String)> {
let pattern = self.matcher.matches(data); self.matcher.matches(data).and_then(|j|
let pattern_matches = pattern.map(|j| j < data.len() && (data[j] == b' ' || data[j] == b'>')); if j < data.len() && (data[j] == b' ' || data[j] == b'>') {
if pattern_matches.unwrap_or(false) { Some((self.matcher.content_type.0.to_owned(),
Some((self.matcher.content_type.0.to_owned(), self.matcher.content_type.1.to_owned()))
self.matcher.content_type.1.to_owned())) } else {
} else { None
None })
}
} }
} }
pub struct Mp4Matcher; pub struct Mp4Matcher;
@ -219,48 +191,21 @@ impl Mp4Matcher {
if data.len() < 12 { if data.len() < 12 {
return false; return false;
} }
let box_size = ((data[0] as u32) << 3 | (data[1] as u32) << 2 | let box_size = ((data[0] as u32) << 3 | (data[1] as u32) << 2 |
(data[2] as u32) << 1 | (data[3] as u32)) as usize; (data[2] as u32) << 1 | (data[3] as u32)) as usize;
if (data.len() < box_size) || (box_size % 4 != 0) { if (data.len() < box_size) || (box_size % 4 != 0) {
return false; return false;
} }
//TODO replace with iterators
let ftyp = [0x66, 0x74, 0x79, 0x70]; let ftyp = [0x66, 0x74, 0x79, 0x70];
let mp4 = [0x6D, 0x70, 0x34]; if !data[4..].starts_with(&ftyp) {
return false;
for i in 4..8 {
if data[i] != ftyp[i - 4] {
return false;
}
}
let mut all_match = true;
for i in 8..11 {
if data[i] != mp4[i - 8] {
all_match = false;
break;
}
}
if all_match {
return true;
} }
let mut bytes_read: usize = 16; let mp4 = [0x6D, 0x70, 0x34];
data[8..].starts_with(&mp4) ||
while bytes_read < box_size { data[16..box_size].chunks(4).any(|chunk| chunk.starts_with(&mp4))
all_match = true;
for i in 0..3 {
if mp4[i] != data[i + bytes_read] {
all_match = false;
break;
}
}
if all_match {
return true;
}
bytes_read = bytes_read + 4;
}
false
} }
} }
@ -278,27 +223,24 @@ struct BinaryOrPlaintextClassifier;
impl BinaryOrPlaintextClassifier { impl BinaryOrPlaintextClassifier {
fn classify_impl(&self, data: &[u8]) -> (&'static str, &'static str) { fn classify_impl(&self, data: &[u8]) -> (&'static str, &'static str) {
if (data.len() >= 2 && if data == &[0xFFu8, 0xFEu8] ||
((data[0] == 0xFFu8 && data[1] == 0xFEu8) || data == &[0xFEu8, 0xFFu8] ||
(data[0] == 0xFEu8 && data[1] == 0xFFu8))) || data.starts_with(&[0xEFu8, 0xBBu8, 0xBFu8])
(data.len() >= 3 && data[0] == 0xEFu8 && data[1] == 0xBBu8 && data[2] == 0xBFu8)
{ {
("text", "plain") ("text", "plain")
} } else if data.iter().any(|&x| x <= 0x08u8 ||
else if data.len() >= 1 && data.iter().any(|&x| x <= 0x08u8 || x == 0x0Bu8 ||
x == 0x0Bu8 || (x >= 0x0Eu8 && x <= 0x1Au8) ||
(x >= 0x0Eu8 && x <= 0x1Au8) || (x >= 0x1Cu8 && x <= 0x1Fu8)) {
(x >= 0x1Cu8 && x <= 0x1Fu8)) {
("application", "octet-stream") ("application", "octet-stream")
} } else {
else {
("text", "plain") ("text", "plain")
} }
} }
} }
impl MIMEChecker for BinaryOrPlaintextClassifier { impl MIMEChecker for BinaryOrPlaintextClassifier {
fn classify(&self, data: &[u8]) -> Option<(String, String)> { fn classify(&self, data: &[u8]) -> Option<(String, String)> {
return as_string_option(Some(self.classify_impl(data))); as_string_option(Some(self.classify_impl(data)))
} }
} }
struct GroupedClassifier { struct GroupedClassifier {
@ -358,7 +300,6 @@ impl GroupedClassifier {
box ByteMatcher::application_pdf() box ByteMatcher::application_pdf()
] ]
} }
} }
fn plaintext_classifier() -> GroupedClassifier { fn plaintext_classifier() -> GroupedClassifier {
GroupedClassifier { GroupedClassifier {
@ -403,68 +344,95 @@ impl MIMEChecker for GroupedClassifier {
} }
} }
/// Outcome of looking for a `start` delimiter followed by an `end`
/// delimiter in a byte stream (produced by `eats_until`).
enum Match {
    /// The start delimiter matched, but the input ran out before the end
    /// delimiter was found.
    Start,
    /// The start delimiter did not match at the current position.
    DidNotMatch,
    /// Both the start delimiter and the end delimiter were found.
    StartAndEnd
}

impl Match {
    /// Returns `self` unless it is `Match::DidNotMatch`, in which case the
    /// fallback `f` is invoked and its result is returned instead.
    ///
    /// `f` is only called when it is actually needed, so a chain of
    /// alternatives stops at the first one that did match.
    fn chain<F>(self, f: F) -> Match
        where F: FnOnce() -> Match {
        match self {
            Match::DidNotMatch => f(),
            matched => matched,
        }
    }
}
/// Tries to consume `start` at the current position of `matcher` and, if
/// that succeeds, keeps advancing `matcher` looking for `end`.
///
/// Returns:
/// * `Match::DidNotMatch` when `start` does not match at the current position,
/// * `Match::StartAndEnd` when `start` matched and `end` was subsequently found,
/// * `Match::Start` when `start` matched but the iterator was exhausted
///   before `end` was found.
///
/// `matcher` is advanced as a side effect of the search.
fn eats_until<'a, T>(matcher: &mut T, start: &[u8], end: &[u8]) -> Match
    where T: Iterator<Item=&'a u8> + Clone {
    if !matcher.matches(start) {
        return Match::DidNotMatch;
    }

    // A single-byte terminator only needs a plain scan for that byte.
    if end.len() == 1 {
        let terminator = end[0];
        return if matcher.any(|&byte| byte == terminator) {
            Match::StartAndEnd
        } else {
            Match::Start
        };
    }

    // Multi-byte terminator: test at the current position, otherwise step
    // forward one byte and try again until the input runs out.
    loop {
        if matcher.matches(end) {
            return Match::StartAndEnd;
        }
        if matcher.next().is_none() {
            return Match::Start;
        }
    }
}
struct FeedsClassifier; struct FeedsClassifier;
impl FeedsClassifier { impl FeedsClassifier {
fn classify_impl(&self, data: &[u8]) -> Option<(&'static str, &'static str)> { fn classify_impl(&self, data: &[u8]) -> Option<(&'static str, &'static str)> {
let length = data.len();
let mut data_iterator = data.iter();
// acceptable byte sequences
let utf8_bom = &[0xEFu8, 0xBBu8, 0xBFu8];
// can not be feed unless length is > 3 // can not be feed unless length is > 3
if length < 3 { if data.len() < 3 {
return None; return None;
} }
// eat the first three bytes if they are equal to UTF-8 BOM let mut matcher = data.iter();
data_iterator.matches(utf8_bom);
// continuously search for next "<" until end of data_iterator // eat the first three acceptable byte sequences if they are equal to UTF-8 BOM
let utf8_bom = &[0xEFu8, 0xBBu8, 0xBFu8];
matcher.matches(utf8_bom);
// continuously search for next "<" until end of matcher
// TODO: need max_bytes to prevent inadvertently examining html document // TODO: need max_bytes to prevent inadvertently examining html document
// eg. an html page with a feed example // eg. an html page with a feed example
while !data_iterator.find(|&data_iterator| *data_iterator == b'<').is_none() { loop {
if data_iterator.matches(b"?") { if matcher.find(|&x| *x == b'<').is_none() {
// eat until ?> return None;
while !data_iterator.matches(b"?>") { }
if data_iterator.next().is_none() {
return None; match eats_until(&mut matcher, b"?", b"?>")
} .chain(|| eats_until(&mut matcher, b"!--", b"-->"))
} .chain(|| eats_until(&mut matcher, b"!", b">")) {
} else if data_iterator.matches(b"!--") { Match::StartAndEnd => continue,
// eat until --> Match::DidNotMatch => {},
while !data_iterator.matches(b"-->") { Match::Start => return None
if data_iterator.next().is_none() { }
return None;
} if matcher.matches(b"rss") {
}
} else if data_iterator.matches(b"!") {
data_iterator.find(|&data_iterator| *data_iterator == b'>');
} else if data_iterator.matches(b"rss") {
return Some(("application", "rss+xml")); return Some(("application", "rss+xml"));
} else if data_iterator.matches(b"feed") { }
if matcher.matches(b"feed") {
return Some(("application", "atom+xml")); return Some(("application", "atom+xml"));
} else if data_iterator.matches(b"rdf: RDF") { }
while !data_iterator.next().is_none() { if matcher.matches(b"rdf: RDF") {
if data_iterator.matches(b"http: //purl.org/rss/1.0/") { while matcher.next().is_some() {
while !data_iterator.next().is_none() { match eats_until(&mut matcher,
if data_iterator.matches(b"http: //www.w3.org/1999/02/22-rdf-syntax-ns#") { b"http: //purl.org/rss/1.0/",
return Some(("application", "rss+xml")); b"http: //www.w3.org/1999/02/22-rdf-syntax-ns#")
} .chain(|| eats_until(&mut matcher,
} b"http: //www.w3.org/1999/02/22-rdf-syntax-ns#",
} else if data_iterator.matches(b"http: //www.w3.org/1999/02/22-rdf-syntax-ns#") { b"http: //purl.org/rss/1.0/")) {
while !data_iterator.next().is_none() { Match::StartAndEnd => return Some(("application", "rss+xml")),
if data_iterator.matches(b"http: //purl.org/rss/1.0/") { Match::DidNotMatch => {},
return Some(("application", "rss+xml")); Match::Start => return None
}
}
} }
} }
return None;
} }
} }
None
} }
} }