RPM build fix (reverted CI changes which will need to be un-reverted or made conditional) and vendor Rust dependencies to make builds much faster in any CI system.

This commit is contained in:
Adam Ierymenko
2022-06-08 07:32:16 -04:00
parent 373ca30269
commit d5ca4e5f52
12611 changed files with 2898014 additions and 284 deletions

View File

@@ -0,0 +1,483 @@
extern crate openssl;
extern crate openssl_probe;
use self::openssl::error::ErrorStack;
use self::openssl::hash::MessageDigest;
use self::openssl::nid::Nid;
use self::openssl::pkcs12::Pkcs12;
use self::openssl::pkey::{PKey, Private};
use self::openssl::ssl::{
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
SslVerifyMode,
};
use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
use std::error;
use std::fmt;
use std::io;
use std::sync::Once;
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
#[cfg(have_min_max_version)]
fn supported_protocols(
min: Option<Protocol>,
max: Option<Protocol>,
ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
use self::openssl::ssl::SslVersion;
fn cvt(p: Protocol) -> SslVersion {
match p {
Protocol::Sslv3 => SslVersion::SSL3,
Protocol::Tlsv10 => SslVersion::TLS1,
Protocol::Tlsv11 => SslVersion::TLS1_1,
Protocol::Tlsv12 => SslVersion::TLS1_2,
Protocol::__NonExhaustive => unreachable!(),
}
}
ctx.set_min_proto_version(min.map(cvt))?;
ctx.set_max_proto_version(max.map(cvt))?;
Ok(())
}
#[cfg(not(have_min_max_version))]
fn supported_protocols(
min: Option<Protocol>,
max: Option<Protocol>,
ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
use self::openssl::ssl::SslOptions;
let no_ssl_mask = SslOptions::NO_SSLV2
| SslOptions::NO_SSLV3
| SslOptions::NO_TLSV1
| SslOptions::NO_TLSV1_1
| SslOptions::NO_TLSV1_2;
ctx.clear_options(no_ssl_mask);
let mut options = SslOptions::empty();
options |= match min {
None => SslOptions::empty(),
Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
Some(Protocol::Tlsv11) => {
SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
}
Some(Protocol::Tlsv12) => {
SslOptions::NO_SSLV2
| SslOptions::NO_SSLV3
| SslOptions::NO_TLSV1
| SslOptions::NO_TLSV1_1
}
Some(Protocol::__NonExhaustive) => unreachable!(),
};
options |= match max {
None | Some(Protocol::Tlsv12) => SslOptions::empty(),
Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
Some(Protocol::Sslv3) => {
SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
}
Some(Protocol::__NonExhaustive) => unreachable!(),
};
ctx.set_options(options);
Ok(())
}
fn init_trust() {
static ONCE: Once = Once::new();
ONCE.call_once(openssl_probe::init_ssl_cert_env_vars);
}
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
use std::fs;
if let Ok(dir) = fs::read_dir("/system/etc/security/cacerts") {
let certs = dir
.filter_map(|r| r.ok())
.filter_map(|e| fs::read(e.path()).ok())
.filter_map(|b| X509::from_pem(&b).ok());
for cert in certs {
if let Err(err) = connector.cert_store_mut().add_cert(cert) {
debug!("load_android_root_certs error: {:?}", err);
}
}
}
Ok(())
}
#[derive(Debug)]
pub enum Error {
Normal(ErrorStack),
Ssl(ssl::Error, X509VerifyResult),
EmptyChain,
NotPkcs8,
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match *self {
Error::Normal(ref e) => error::Error::source(e),
Error::Ssl(ref e, _) => error::Error::source(e),
Error::EmptyChain => None,
Error::NotPkcs8 => None,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
Error::EmptyChain => write!(
fmt,
"at least one certificate must be provided to create an identity"
),
Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
}
}
}
impl From<ErrorStack> for Error {
fn from(err: ErrorStack) -> Error {
Error::Normal(err)
}
}
#[derive(Clone)]
pub struct Identity {
pkey: PKey<Private>,
cert: X509,
chain: Vec<X509>,
}
impl Identity {
pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
let pkcs12 = Pkcs12::from_der(buf)?;
let parsed = pkcs12.parse(pass)?;
Ok(Identity {
pkey: parsed.pkey,
cert: parsed.cert,
// > The stack is the reverse of what you might expect due to the way
// > PKCS12_parse is implemented, so we need to load it backwards.
// > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44
chain: parsed.chain.into_iter().flatten().rev().collect(),
})
}
pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
return Err(Error::NotPkcs8);
}
let pkey = PKey::private_key_from_pem(key)?;
let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
let chain = cert_chain.collect();
Ok(Identity { pkey, cert, chain })
}
}
#[derive(Clone)]
pub struct Certificate(X509);
impl Certificate {
pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
let cert = X509::from_der(buf)?;
Ok(Certificate(cert))
}
pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
let cert = X509::from_pem(buf)?;
Ok(Certificate(cert))
}
pub fn to_der(&self) -> Result<Vec<u8>, Error> {
let der = self.0.to_der()?;
Ok(der)
}
}
pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> MidHandshakeTlsStream<S> {
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
{
pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
match self.0.handshake() {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
pub enum HandshakeError<S> {
Failure(Error),
WouldBlock(MidHandshakeTlsStream<S>),
}
impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
match e {
ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
ssl::HandshakeError::Failure(e) => {
let v = e.ssl().verify_result();
HandshakeError::Failure(Error::Ssl(e.into_error(), v))
}
ssl::HandshakeError::WouldBlock(s) => {
HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
}
}
}
}
impl<S> From<ErrorStack> for HandshakeError<S> {
fn from(e: ErrorStack) -> HandshakeError<S> {
HandshakeError::Failure(e.into())
}
}
#[derive(Clone)]
pub struct TlsConnector {
connector: SslConnector,
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
}
impl TlsConnector {
pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
init_trust();
let mut connector = SslConnector::builder(SslMethod::tls())?;
if let Some(ref identity) = builder.identity {
connector.set_certificate(&identity.0.cert)?;
connector.set_private_key(&identity.0.pkey)?;
for cert in identity.0.chain.iter() {
// https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
// specifies that "When sending a certificate chain, extra chain certificates are
// sent in order following the end entity certificate."
connector.add_extra_chain_cert(cert.to_owned())?;
}
}
supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
if builder.disable_built_in_roots {
connector.set_cert_store(X509StoreBuilder::new()?.build());
}
for cert in &builder.root_certificates {
if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
debug!("add_cert error: {:?}", err);
}
}
#[cfg(feature = "alpn")]
{
if !builder.alpn.is_empty() {
// Wire format is each alpn preceded by its length as a byte.
let mut alpn_wire_format = Vec::with_capacity(
builder
.alpn
.iter()
.map(|s| s.as_bytes().len())
.sum::<usize>()
+ builder.alpn.len(),
);
for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
alpn_wire_format.push(alpn.len() as u8);
alpn_wire_format.extend(alpn);
}
connector.set_alpn_protos(&alpn_wire_format)?;
}
}
#[cfg(target_os = "android")]
load_android_root_certs(&mut connector)?;
Ok(TlsConnector {
connector: connector.build(),
use_sni: builder.use_sni,
accept_invalid_hostnames: builder.accept_invalid_hostnames,
accept_invalid_certs: builder.accept_invalid_certs,
})
}
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut ssl = self
.connector
.configure()?
.use_server_name_indication(self.use_sni)
.verify_hostname(!self.accept_invalid_hostnames);
if self.accept_invalid_certs {
ssl.set_verify(SslVerifyMode::NONE);
}
let s = ssl.connect(domain, stream)?;
Ok(TlsStream(s))
}
}
impl fmt::Debug for TlsConnector {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("TlsConnector")
// n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted
.field("use_sni", &self.use_sni)
.field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
.field("accept_invalid_certs", &self.accept_invalid_certs)
.finish()
}
}
#[derive(Clone)]
pub struct TlsAcceptor(SslAcceptor);
impl TlsAcceptor {
pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
acceptor.set_private_key(&builder.identity.0.pkey)?;
acceptor.set_certificate(&builder.identity.0.cert)?;
for cert in builder.identity.0.chain.iter() {
// https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
// specifies that "When sending a certificate chain, extra chain certificates are
// sent in order following the end entity certificate."
acceptor.add_extra_chain_cert(cert.to_owned())?;
}
supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
Ok(TlsAcceptor(acceptor.build()))
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let s = self.0.accept(stream)?;
Ok(TlsStream(s))
}
}
pub struct TlsStream<S>(ssl::SslStream<S>);
impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> TlsStream<S> {
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S: io::Read + io::Write> TlsStream<S> {
pub fn buffered_read_size(&self) -> Result<usize, Error> {
Ok(self.0.ssl().pending())
}
pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
Ok(self.0.ssl().peer_certificate().map(Certificate))
}
#[cfg(feature = "alpn")]
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
Ok(self
.0
.ssl()
.selected_alpn_protocol()
.map(|alpn| alpn.to_vec()))
}
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
let cert = if self.0.ssl().is_server() {
self.0.ssl().certificate().map(|x| x.to_owned())
} else {
self.0.ssl().peer_certificate()
};
let cert = match cert {
Some(cert) => cert,
None => return Ok(None),
};
let algo_nid = cert.signature_algorithm().object().nid();
let signature_algorithms = match algo_nid.signature_algorithms() {
Some(algs) => algs,
None => return Ok(None),
};
let md = match signature_algorithms.digest {
Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
nid => match MessageDigest::from_nid(nid) {
Some(md) => md,
None => return Ok(None),
},
};
let digest = cert.digest(md)?;
Ok(Some(digest.to_vec()))
}
pub fn shutdown(&mut self) -> io::Result<()> {
match self.0.shutdown() {
Ok(_) => Ok(()),
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
Err(e) => Err(e
.into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
}
}
}
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}

View File

@@ -0,0 +1,562 @@
extern crate schannel;
use self::schannel::cert_context::{CertContext, HashAlgorithm, KeySpec};
use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions};
use self::schannel::crypt_prov::{AcquireOptions, ProviderType};
use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred};
use self::schannel::tls_stream;
use std::error;
use std::fmt;
use std::io;
use std::str;
use {TlsAcceptorBuilder, TlsConnectorBuilder};
const SEC_E_NO_CREDENTIALS: u32 = 0x8009030E;
static PROTOCOLS: &'static [Protocol] = &[
Protocol::Ssl3,
Protocol::Tls10,
Protocol::Tls11,
Protocol::Tls12,
];
fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] {
let mut protocols = PROTOCOLS;
if let Some(p) = max.and_then(|max| protocols.get(..=max as usize)) {
protocols = p;
}
if let Some(p) = min.and_then(|min| protocols.get(min as usize..)) {
protocols = p;
}
protocols
}
pub struct Error(io::Error);
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
error::Error::source(&self.0)
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, fmt)
}
}
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl From<io::Error> for Error {
fn from(error: io::Error) -> Error {
Error(error)
}
}
#[derive(Clone)]
pub struct Identity {
cert: CertContext,
}
impl Identity {
pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
let store = PfxImportOptions::new().password(pass).import(buf)?;
let mut identity = None;
for cert in store.certs() {
if cert
.private_key()
.silent(true)
.compare_key(true)
.acquire()
.is_ok()
{
identity = Some(cert);
break;
}
}
let identity = match identity {
Some(identity) => identity,
None => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"No identity found in PKCS #12 archive",
)
.into());
}
};
Ok(Identity { cert: identity })
}
pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result<Identity, Error> {
if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "not a PKCS#8 key").into());
}
let mut store = Memory::new()?.into_store();
let mut cert_iter = pem::PemBlock::new(pem).into_iter();
let leaf = cert_iter.next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"at least one certificate must be provided to create an identity",
)
})?;
let cert = CertContext::from_pem(std::str::from_utf8(leaf).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"leaf cert contains invalid utf8",
)
})?)?;
let name = gen_container_name();
let mut options = AcquireOptions::new();
options.container(&name);
let type_ = ProviderType::rsa_full();
let mut container = match options.acquire(type_) {
Ok(container) => container,
Err(_) => options.new_keyset(true).acquire(type_)?,
};
container.import().import_pkcs8_pem(&key)?;
cert.set_key_prov_info()
.container(&name)
.type_(type_)
.keep_open(true)
.key_spec(KeySpec::key_exchange())
.set()?;
let mut context = store.add_cert(&cert, CertAdd::Always)?;
for int_cert in cert_iter {
let certificate = Certificate::from_pem(int_cert)?;
context = store.add_cert(&certificate.0, CertAdd::Always)?;
}
Ok(Identity { cert: context })
}
}
// The name of the container must be unique to have multiple active keys.
fn gen_container_name() -> String {
use std::sync::atomic::{AtomicUsize, Ordering};
static COUNTER: AtomicUsize = AtomicUsize::new(0);
format!("native-tls-{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}
#[derive(Clone)]
pub struct Certificate(CertContext);
impl Certificate {
pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
let cert = CertContext::new(buf)?;
Ok(Certificate(cert))
}
pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
match str::from_utf8(buf) {
Ok(s) => {
let cert = CertContext::from_pem(s)?;
Ok(Certificate(cert))
}
Err(_) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"PEM representation contains non-UTF-8 bytes",
)
.into()),
}
}
pub fn to_der(&self) -> Result<Vec<u8>, Error> {
Ok(self.0.to_der().to_vec())
}
}
pub struct MidHandshakeTlsStream<S>(tls_stream::MidHandshakeTlsStream<S>);
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> MidHandshakeTlsStream<S> {
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
{
pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
match self.0.handshake() {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
pub enum HandshakeError<S> {
Failure(Error),
WouldBlock(MidHandshakeTlsStream<S>),
}
impl<S> From<tls_stream::HandshakeError<S>> for HandshakeError<S> {
fn from(e: tls_stream::HandshakeError<S>) -> HandshakeError<S> {
match e {
tls_stream::HandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
tls_stream::HandshakeError::Interrupted(s) => {
HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
}
}
}
}
impl<S> From<io::Error> for HandshakeError<S> {
fn from(e: io::Error) -> HandshakeError<S> {
HandshakeError::Failure(e.into())
}
}
#[derive(Clone, Debug)]
pub struct TlsConnector {
cert: Option<CertContext>,
roots: CertStore,
min_protocol: Option<::Protocol>,
max_protocol: Option<::Protocol>,
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
disable_built_in_roots: bool,
#[cfg(feature = "alpn")]
alpn: Vec<String>,
}
impl TlsConnector {
pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
let cert = builder.identity.as_ref().map(|i| i.0.cert.clone());
let mut roots = Memory::new()?.into_store();
for cert in &builder.root_certificates {
roots.add_cert(&(cert.0).0, CertAdd::ReplaceExisting)?;
}
Ok(TlsConnector {
cert,
roots,
min_protocol: builder.min_protocol,
max_protocol: builder.max_protocol,
use_sni: builder.use_sni,
accept_invalid_hostnames: builder.accept_invalid_hostnames,
accept_invalid_certs: builder.accept_invalid_certs,
disable_built_in_roots: builder.disable_built_in_roots,
#[cfg(feature = "alpn")]
alpn: builder.alpn.clone(),
})
}
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
if let Some(cert) = self.cert.as_ref() {
builder.cert(cert.clone());
}
let cred = builder.acquire(Direction::Outbound)?;
let mut builder = tls_stream::Builder::new();
builder
.cert_store(self.roots.clone())
.domain(domain)
.use_sni(self.use_sni)
.accept_invalid_hostnames(self.accept_invalid_hostnames);
if self.accept_invalid_certs {
builder.verify_callback(|_| Ok(()));
} else if self.disable_built_in_roots {
let roots_copy = self.roots.clone();
builder.verify_callback(move |res| {
if let Err(err) = res.result() {
// Propagate previous error encountered during normal cert validation.
return Err(err);
}
if let Some(chain) = res.chain() {
if chain
.certificates()
.any(|cert| roots_copy.certs().any(|root_cert| root_cert == cert))
{
return Ok(());
}
}
Err(io::Error::new(
io::ErrorKind::Other,
"unable to find any user-specified roots in the final cert chain",
))
});
}
#[cfg(feature = "alpn")]
{
if !self.alpn.is_empty() {
builder.request_application_protocols(
&self.alpn.iter().map(|s| s.as_bytes()).collect::<Vec<_>>(),
);
}
}
match builder.connect(cred, stream) {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
#[derive(Clone)]
pub struct TlsAcceptor {
cert: CertContext,
min_protocol: Option<::Protocol>,
max_protocol: Option<::Protocol>,
}
impl TlsAcceptor {
pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
Ok(TlsAcceptor {
cert: builder.identity.0.cert.clone(),
min_protocol: builder.min_protocol,
max_protocol: builder.max_protocol,
})
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
builder.cert(self.cert.clone());
// FIXME we're probably missing the certificate chain?
let cred = builder.acquire(Direction::Inbound)?;
match tls_stream::Builder::new().accept(cred, stream) {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
pub struct TlsStream<S>(tls_stream::TlsStream<S>);
impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> TlsStream<S> {
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S: io::Read + io::Write> TlsStream<S> {
pub fn buffered_read_size(&self) -> Result<usize, Error> {
Ok(self.0.get_buf().len())
}
pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
match self.0.peer_certificate() {
Ok(cert) => Ok(Some(Certificate(cert))),
Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => Ok(None),
Err(e) => Err(Error(e)),
}
}
#[cfg(feature = "alpn")]
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
Ok(self.0.negotiated_application_protocol()?)
}
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
let cert = if self.0.is_server() {
self.0.certificate()
} else {
self.0.peer_certificate()
};
let cert = match cert {
Ok(cert) => cert,
Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => return Ok(None),
Err(e) => return Err(Error(e)),
};
let signature_algorithms = cert.sign_hash_algorithms()?;
let hash = match signature_algorithms.rsplit('/').next().unwrap() {
"MD5" | "SHA1" | "SHA256" => HashAlgorithm::sha256(),
"SHA384" => HashAlgorithm::sha384(),
"SHA512" => HashAlgorithm::sha512(),
_ => return Ok(None),
};
let digest = cert.fingerprint(hash)?;
Ok(Some(digest))
}
pub fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown()?;
Ok(())
}
}
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
mod pem {
/// Split data by PEM guard lines
pub struct PemBlock<'a> {
pem_block: &'a str,
cur_end: usize,
}
impl<'a> PemBlock<'a> {
pub fn new(data: &'a [u8]) -> PemBlock<'a> {
let s = ::std::str::from_utf8(data).unwrap();
PemBlock {
pem_block: s,
cur_end: s.find("-----BEGIN").unwrap_or(s.len()),
}
}
}
impl<'a> Iterator for PemBlock<'a> {
type Item = &'a [u8];
fn next(&mut self) -> Option<Self::Item> {
let last = self.pem_block.len();
if self.cur_end >= last {
return None;
}
let begin = self.cur_end;
let pos = self.pem_block[begin + 1..].find("-----BEGIN");
self.cur_end = match pos {
Some(end) => end + begin + 1,
None => last,
};
return Some(&self.pem_block[begin..self.cur_end].as_bytes());
}
}
#[test]
fn test_split() {
// Split three certs, CRLF line terminators.
assert_eq!(
PemBlock::new(
b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n\
-----BEGIN SECOND-----\r\n-----END SECOND\r\n\
-----BEGIN THIRD-----\r\n-----END THIRD\r\n"
)
.collect::<Vec<&[u8]>>(),
vec![
b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n" as &[u8],
b"-----BEGIN SECOND-----\r\n-----END SECOND\r\n",
b"-----BEGIN THIRD-----\r\n-----END THIRD\r\n"
]
);
// Split three certs, CRLF line terminators except at EOF.
assert_eq!(
PemBlock::new(
b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n\
-----BEGIN SECOND-----\r\n-----END SECOND-----\r\n\
-----BEGIN THIRD-----\r\n-----END THIRD-----"
)
.collect::<Vec<&[u8]>>(),
vec![
b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n" as &[u8],
b"-----BEGIN SECOND-----\r\n-----END SECOND-----\r\n",
b"-----BEGIN THIRD-----\r\n-----END THIRD-----"
]
);
// Split two certs, LF line terminators.
assert_eq!(
PemBlock::new(
b"-----BEGIN FIRST-----\n-----END FIRST-----\n\
-----BEGIN SECOND-----\n-----END SECOND\n"
)
.collect::<Vec<&[u8]>>(),
vec![
b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8],
b"-----BEGIN SECOND-----\n-----END SECOND\n"
]
);
// Split two certs, CR line terminators.
assert_eq!(
PemBlock::new(
b"-----BEGIN FIRST-----\r-----END FIRST-----\r\
-----BEGIN SECOND-----\r-----END SECOND\r"
)
.collect::<Vec<&[u8]>>(),
vec![
b"-----BEGIN FIRST-----\r-----END FIRST-----\r" as &[u8],
b"-----BEGIN SECOND-----\r-----END SECOND\r"
]
);
// Split two certs, LF line terminators except at EOF.
assert_eq!(
PemBlock::new(
b"-----BEGIN FIRST-----\n-----END FIRST-----\n\
-----BEGIN SECOND-----\n-----END SECOND"
)
.collect::<Vec<&[u8]>>(),
vec![
b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8],
b"-----BEGIN SECOND-----\n-----END SECOND"
]
);
// Split a single cert, LF line terminators.
assert_eq!(
PemBlock::new(b"-----BEGIN FIRST-----\n-----END FIRST-----\n").collect::<Vec<&[u8]>>(),
vec![b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8]]
);
// Split a single cert, LF line terminators except at EOF.
assert_eq!(
PemBlock::new(b"-----BEGIN FIRST-----\n-----END FIRST-----").collect::<Vec<&[u8]>>(),
vec![b"-----BEGIN FIRST-----\n-----END FIRST-----" as &[u8]]
);
// (Don't) split garbage.
assert_eq!(
PemBlock::new(b"junk").collect::<Vec<&[u8]>>(),
Vec::<&[u8]>::new()
);
assert_eq!(
PemBlock::new(b"junk-----BEGIN garbage").collect::<Vec<&[u8]>>(),
vec![b"-----BEGIN garbage" as &[u8]]
);
}
}

View File

@@ -0,0 +1,632 @@
extern crate libc;
extern crate security_framework;
extern crate security_framework_sys;
extern crate tempfile;
use self::security_framework::base;
use self::security_framework::certificate::SecCertificate;
use self::security_framework::identity::SecIdentity;
use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOptions};
use self::security_framework::random::SecRandom;
use self::security_framework::secure_transport::{
self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide,
};
use self::security_framework_sys::base::{errSecIO, errSecParam};
use self::tempfile::TempDir;
use std::error;
use std::fmt;
use std::io;
use std::str;
use std::sync::Mutex;
use std::sync::Once;
#[cfg(not(target_os = "ios"))]
use self::security_framework::os::macos::certificate::{PropertyType, SecCertificateExt};
#[cfg(not(target_os = "ios"))]
use self::security_framework::os::macos::certificate_oids::CertificateOid;
#[cfg(not(target_os = "ios"))]
use self::security_framework::os::macos::identity::SecIdentityExt;
#[cfg(not(target_os = "ios"))]
use self::security_framework::os::macos::import_export::{
ImportOptions, Pkcs12ImportOptionsExt, SecItems,
};
#[cfg(not(target_os = "ios"))]
use self::security_framework::os::macos::keychain::{self, KeychainSettings, SecKeychain};
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
static SET_AT_EXIT: Once = Once::new();
#[cfg(not(target_os = "ios"))]
lazy_static! {
static ref TEMP_KEYCHAIN: Mutex<Option<(SecKeychain, TempDir)>> = Mutex::new(None);
}
fn convert_protocol(protocol: Protocol) -> SslProtocol {
match protocol {
Protocol::Sslv3 => SslProtocol::SSL3,
Protocol::Tlsv10 => SslProtocol::TLS1,
Protocol::Tlsv11 => SslProtocol::TLS11,
Protocol::Tlsv12 => SslProtocol::TLS12,
Protocol::__NonExhaustive => unreachable!(),
}
}
pub struct Error(base::Error);
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
error::Error::source(&self.0)
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, fmt)
}
}
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl From<base::Error> for Error {
fn from(error: base::Error) -> Error {
Error(error)
}
}
#[derive(Clone, Debug)]
pub struct Identity {
identity: SecIdentity,
chain: Vec<SecCertificate>,
}
impl Identity {
#[cfg(target_os = "ios")]
pub fn from_pkcs8(_: &[u8], _: &[u8]) -> Result<Identity, Error> {
panic!("Not implemented on iOS");
}
#[cfg(not(target_os = "ios"))]
pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result<Identity, Error> {
if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
return Err(Error(base::Error::from(errSecParam)));
}
let dir = TempDir::new().map_err(|_| Error(base::Error::from(errSecIO)))?;
let keychain = keychain::CreateOptions::new()
.password(&random_password()?)
.create(dir.path().join("identity.keychain"))?;
let mut items = SecItems::default();
ImportOptions::new()
.filename("key.pem")
.items(&mut items)
.keychain(&keychain)
.import(&key)?;
ImportOptions::new()
.filename("chain.pem")
.items(&mut items)
.keychain(&keychain)
.import(&pem)?;
let cert = items
.certificates
.get(0)
.ok_or_else(|| Error(base::Error::from(errSecParam)))?;
let ident = SecIdentity::with_certificate(&[keychain], cert)?;
Ok(Identity {
identity: ident,
chain: items.certificates,
})
}
pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
let mut imports = Identity::import_options(buf, pass)?;
let import = imports.pop().unwrap();
let identity = import
.identity
.expect("Pkcs12 files must include an identity");
// FIXME: Compare the certificates for equality using CFEqual
let identity_cert = identity.certificate()?.to_der();
Ok(Identity {
identity,
chain: import
.cert_chain
.unwrap_or(vec![])
.into_iter()
.filter(|c| c.to_der() != identity_cert)
.collect(),
})
}
#[cfg(not(target_os = "ios"))]
fn import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error> {
SET_AT_EXIT.call_once(|| {
extern "C" fn atexit() {
*TEMP_KEYCHAIN.lock().unwrap() = None;
}
unsafe {
libc::atexit(atexit);
}
});
let keychain = match *TEMP_KEYCHAIN.lock().unwrap() {
Some((ref keychain, _)) => keychain.clone(),
ref mut lock @ None => {
let dir = TempDir::new().map_err(|_| Error(base::Error::from(errSecIO)))?;
let mut keychain = keychain::CreateOptions::new()
.password(pass)
.create(dir.path().join("tmp.keychain"))?;
keychain.set_settings(&KeychainSettings::new())?;
*lock = Some((keychain.clone(), dir));
keychain
}
};
let mut import_opts = Pkcs12ImportOptions::new();
// Method shadowed by deprecated method.
<Pkcs12ImportOptions as Pkcs12ImportOptionsExt>::keychain(&mut import_opts, keychain);
let imports = import_opts.passphrase(pass).import(buf)?;
Ok(imports)
}
#[cfg(target_os = "ios")]
fn import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error> {
let imports = Pkcs12ImportOptions::new().passphrase(pass).import(buf)?;
Ok(imports)
}
}
fn random_password() -> Result<String, Error> {
use std::fmt::Write;
let mut bytes = [0_u8; 10];
SecRandom::default()
.copy_bytes(&mut bytes)
.map_err(|_| Error(base::Error::from(errSecIO)))?;
let mut s = String::with_capacity(2 * bytes.len());
for byte in bytes {
write!(s, "{:02X}", byte).map_err(|_| Error(base::Error::from(errSecIO)))?;
}
Ok(s)
}
#[derive(Clone)]
pub struct Certificate(SecCertificate);
impl Certificate {
pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
let cert = SecCertificate::from_der(buf)?;
Ok(Certificate(cert))
}
#[cfg(not(target_os = "ios"))]
pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
let mut items = SecItems::default();
ImportOptions::new().items(&mut items).import(buf)?;
if items.certificates.len() == 1 && items.identities.is_empty() && items.keys.is_empty() {
Ok(Certificate(items.certificates.pop().unwrap()))
} else {
Err(Error(base::Error::from(errSecParam)))
}
}
#[cfg(target_os = "ios")]
pub fn from_pem(_: &[u8]) -> Result<Certificate, Error> {
panic!("Not implemented on iOS");
}
pub fn to_der(&self) -> Result<Vec<u8>, Error> {
Ok(self.0.to_der())
}
}
pub enum HandshakeError<S> {
WouldBlock(MidHandshakeTlsStream<S>),
Failure(Error),
}
impl<S> From<secure_transport::ClientHandshakeError<S>> for HandshakeError<S> {
fn from(e: secure_transport::ClientHandshakeError<S>) -> HandshakeError<S> {
match e {
secure_transport::ClientHandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
secure_transport::ClientHandshakeError::Interrupted(s) => {
HandshakeError::WouldBlock(MidHandshakeTlsStream::Client(s))
}
}
}
}
impl<S> From<base::Error> for HandshakeError<S> {
fn from(e: base::Error) -> HandshakeError<S> {
HandshakeError::Failure(e.into())
}
}
pub enum MidHandshakeTlsStream<S> {
Server(
secure_transport::MidHandshakeSslStream<S>,
Option<SecCertificate>,
),
Client(secure_transport::MidHandshakeClientBuilder<S>),
}
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
MidHandshakeTlsStream::Server(ref s, _) => s.fmt(fmt),
MidHandshakeTlsStream::Client(ref s) => s.fmt(fmt),
}
}
}
impl<S> MidHandshakeTlsStream<S> {
pub fn get_ref(&self) -> &S {
match *self {
MidHandshakeTlsStream::Server(ref s, _) => s.get_ref(),
MidHandshakeTlsStream::Client(ref s) => s.get_ref(),
}
}
pub fn get_mut(&mut self) -> &mut S {
match *self {
MidHandshakeTlsStream::Server(ref mut s, _) => s.get_mut(),
MidHandshakeTlsStream::Client(ref mut s) => s.get_mut(),
}
}
}
impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
{
pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
match self {
MidHandshakeTlsStream::Server(s, cert) => match s.handshake() {
Ok(stream) => Ok(TlsStream { stream, cert }),
Err(secure_transport::HandshakeError::Failure(e)) => {
Err(HandshakeError::Failure(Error(e)))
}
Err(secure_transport::HandshakeError::Interrupted(s)) => Err(
HandshakeError::WouldBlock(MidHandshakeTlsStream::Server(s, cert)),
),
},
MidHandshakeTlsStream::Client(s) => match s.handshake() {
Ok(stream) => Ok(TlsStream { stream, cert: None }),
Err(e) => Err(e.into()),
},
}
}
}
#[derive(Clone, Debug)]
pub struct TlsConnector {
identity: Option<Identity>,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
roots: Vec<SecCertificate>,
use_sni: bool,
danger_accept_invalid_hostnames: bool,
danger_accept_invalid_certs: bool,
disable_built_in_roots: bool,
#[cfg(feature = "alpn")]
alpn: Vec<String>,
}
impl TlsConnector {
pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
Ok(TlsConnector {
identity: builder.identity.as_ref().map(|i| i.0.clone()),
min_protocol: builder.min_protocol,
max_protocol: builder.max_protocol,
roots: builder
.root_certificates
.iter()
.map(|c| (c.0).0.clone())
.collect(),
use_sni: builder.use_sni,
danger_accept_invalid_hostnames: builder.accept_invalid_hostnames,
danger_accept_invalid_certs: builder.accept_invalid_certs,
disable_built_in_roots: builder.disable_built_in_roots,
#[cfg(feature = "alpn")]
alpn: builder.alpn.clone(),
})
}
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut builder = ClientBuilder::new();
if let Some(min) = self.min_protocol {
builder.protocol_min(convert_protocol(min));
}
if let Some(max) = self.max_protocol {
builder.protocol_max(convert_protocol(max));
}
if let Some(identity) = self.identity.as_ref() {
builder.identity(&identity.identity, &identity.chain);
}
builder.anchor_certificates(&self.roots);
builder.use_sni(self.use_sni);
builder.danger_accept_invalid_hostnames(self.danger_accept_invalid_hostnames);
builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs);
builder.trust_anchor_certificates_only(self.disable_built_in_roots);
#[cfg(feature = "alpn")]
{
if !self.alpn.is_empty() {
builder.alpn_protocols(&self.alpn.iter().map(String::as_str).collect::<Vec<_>>());
}
}
match builder.handshake(domain, stream) {
Ok(stream) => Ok(TlsStream { stream, cert: None }),
Err(e) => Err(e.into()),
}
}
}
#[derive(Clone)]
pub struct TlsAcceptor {
identity: Identity,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
}
impl TlsAcceptor {
pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
Ok(TlsAcceptor {
identity: builder.identity.0.clone(),
min_protocol: builder.min_protocol,
max_protocol: builder.max_protocol,
})
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
if let Some(min) = self.min_protocol {
ctx.set_protocol_version_min(convert_protocol(min))?;
}
if let Some(max) = self.max_protocol {
ctx.set_protocol_version_max(convert_protocol(max))?;
}
ctx.set_certificate(&self.identity.identity, &self.identity.chain)?;
let cert = Some(self.identity.identity.certificate()?);
match ctx.handshake(stream) {
Ok(stream) => Ok(TlsStream { stream, cert }),
Err(secure_transport::HandshakeError::Failure(e)) => {
Err(HandshakeError::Failure(Error(e)))
}
Err(secure_transport::HandshakeError::Interrupted(s)) => Err(
HandshakeError::WouldBlock(MidHandshakeTlsStream::Server(s, cert)),
),
}
}
}
pub struct TlsStream<S> {
stream: secure_transport::SslStream<S>,
cert: Option<SecCertificate>,
}
impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.stream, fmt)
}
}
impl<S> TlsStream<S> {
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
}
}
impl<S: io::Read + io::Write> TlsStream<S> {
pub fn buffered_read_size(&self) -> Result<usize, Error> {
Ok(self.stream.context().buffered_read_size()?)
}
#[allow(deprecated)]
pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
let trust = match self.stream.context().peer_trust2()? {
Some(trust) => trust,
None => return Ok(None),
};
trust.evaluate()?;
Ok(trust.certificate_at_index(0).map(Certificate))
}
#[cfg(feature = "alpn")]
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
match self.stream.context().alpn_protocols() {
Ok(protocols) => {
// Per RFC7301, "ProtocolNameList" MUST contain exactly one "ProtocolName".
assert!(protocols.len() < 2);
if protocols.is_empty() {
// Not sure this is actually possible.
Ok(None)
} else {
Ok(Some(protocols.into_iter().next().unwrap().into_bytes()))
}
}
// The macOS API appears to return `errSecParam` whenever no ALPN was negotiated, both
// when it isn't attempted and when it isn't successful.
Err(e) if e.code() == errSecParam => Ok(None),
Err(other) => Err(Error::from(other)),
}
}
#[cfg(target_os = "ios")]
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
Ok(None)
}
#[cfg(not(target_os = "ios"))]
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
let cert = match self.cert {
Some(ref cert) => cert.clone(),
None => match self.peer_certificate()? {
Some(cert) => cert.0,
None => return Ok(None),
},
};
let property = match cert
.properties(Some(&[CertificateOid::x509_v1_signature_algorithm()]))
.ok()
.and_then(|p| p.get(CertificateOid::x509_v1_signature_algorithm()))
{
Some(property) => property,
None => return Ok(None),
};
let section = match property.get() {
PropertyType::Section(section) => section,
_ => return Ok(None),
};
let algorithm = match section
.iter()
.filter(|p| p.label().to_string() == "Algorithm")
.next()
{
Some(property) => property,
None => return Ok(None),
};
let algorithm = match algorithm.get() {
PropertyType::String(algorithm) => algorithm,
_ => return Ok(None),
};
let digest = match &*algorithm.to_string() {
// MD5
"1.2.840.113549.2.5" | "1.2.840.113549.1.1.4" | "1.3.14.3.2.3" => Digest::Sha256,
// SHA-1
"1.3.14.3.2.26"
| "1.3.14.3.2.15"
| "1.2.840.113549.1.1.5"
| "1.3.14.3.2.29"
| "1.2.840.10040.4.3"
| "1.3.14.3.2.13"
| "1.2.840.10045.4.1" => Digest::Sha256,
// SHA-224
"2.16.840.1.101.3.4.2.4"
| "1.2.840.113549.1.1.14"
| "2.16.840.1.101.3.4.3.1"
| "1.2.840.10045.4.3.1" => Digest::Sha224,
// SHA-256
"2.16.840.1.101.3.4.2.1" | "1.2.840.113549.1.1.11" | "1.2.840.10045.4.3.2" => {
Digest::Sha256
}
// SHA-384
"2.16.840.1.101.3.4.2.2" | "1.2.840.113549.1.1.12" | "1.2.840.10045.4.3.3" => {
Digest::Sha384
}
// SHA-512
"2.16.840.1.101.3.4.2.3" | "1.2.840.113549.1.1.13" | "1.2.840.10045.4.3.4" => {
Digest::Sha512
}
_ => return Ok(None),
};
let der = cert.to_der();
Ok(Some(digest.hash(&der)))
}
pub fn shutdown(&mut self) -> io::Result<()> {
self.stream.close()?;
Ok(())
}
}
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.stream.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.stream.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
}
}
enum Digest {
Sha224,
Sha256,
Sha384,
Sha512,
}
impl Digest {
fn hash(&self, data: &[u8]) -> Vec<u8> {
unsafe {
assert!(data.len() <= CC_LONG::max_value() as usize);
match *self {
Digest::Sha224 => {
let mut buf = [0; CC_SHA224_DIGEST_LENGTH];
CC_SHA224(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
buf.to_vec()
}
Digest::Sha256 => {
let mut buf = [0; CC_SHA256_DIGEST_LENGTH];
CC_SHA256(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
buf.to_vec()
}
Digest::Sha384 => {
let mut buf = [0; CC_SHA384_DIGEST_LENGTH];
CC_SHA384(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
buf.to_vec()
}
Digest::Sha512 => {
let mut buf = [0; CC_SHA512_DIGEST_LENGTH];
CC_SHA512(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
buf.to_vec()
}
}
}
}
}
// FIXME ideally we'd pull these in from elsewhere
const CC_SHA224_DIGEST_LENGTH: usize = 28;
const CC_SHA256_DIGEST_LENGTH: usize = 32;
const CC_SHA384_DIGEST_LENGTH: usize = 48;
const CC_SHA512_DIGEST_LENGTH: usize = 64;
#[allow(non_camel_case_types)]
type CC_LONG = u32;
extern "C" {
fn CC_SHA224(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
fn CC_SHA256(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
fn CC_SHA384(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
fn CC_SHA512(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
}

721
zeroidc/vendor/native-tls/src/lib.rs vendored Normal file
View File

@@ -0,0 +1,721 @@
//! An abstraction over platform-specific TLS implementations.
//!
//! Many applications require TLS/SSL communication in one form or another as
//! part of their implementation, but finding a library for this isn't always
//! trivial! The purpose of this crate is to provide a seamless integration
//! experience on all platforms with a cross-platform API that deals with all
//! the underlying details for you.
//!
//! # How is this implemented?
//!
//! This crate uses SChannel on Windows (via the `schannel` crate), Secure
//! Transport on OSX (via the `security-framework` crate), and OpenSSL (via the
//! `openssl` crate) on all other platforms. Future futures may also enable
//! other TLS frameworks as well, but these initial libraries are likely to
//! remain as the defaults.
//!
//! Note that this crate also strives to be secure-by-default. For example when
//! using OpenSSL it will configure validation callbacks to ensure that
//! hostnames match certificates, use strong ciphers, etc. This implies that
//! this crate is *not* just a thin abstraction around the underlying libraries,
//! but also an implementation that strives to strike reasonable defaults.
//!
//! # Supported features
//!
//! This crate supports the following features out of the box:
//!
//! * TLS/SSL client communication
//! * TLS/SSL server communication
//! * PKCS#12 encoded identities
//! * X.509/PKCS#8 encoded identities
//! * Secure-by-default for client and server
//! * Includes hostname verification for clients
//! * Supports asynchronous I/O for both the server and the client
//!
//! # Cargo Features
//!
//! * `vendored` - If enabled, the crate will compile and statically link to a
//! vendored copy of OpenSSL. This feature has no effect on Windows and
//! macOS, where OpenSSL is not used.
//!
//! # Examples
//!
//! To connect as a client to a remote server:
//!
//! ```rust
//! use native_tls::TlsConnector;
//! use std::io::{Read, Write};
//! use std::net::TcpStream;
//!
//! let connector = TlsConnector::new().unwrap();
//!
//! let stream = TcpStream::connect("google.com:443").unwrap();
//! let mut stream = connector.connect("google.com", stream).unwrap();
//!
//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
//! let mut res = vec![];
//! stream.read_to_end(&mut res).unwrap();
//! println!("{}", String::from_utf8_lossy(&res));
//! ```
//!
//! To accept connections as a server from remote clients:
//!
//! ```rust,no_run
//! use native_tls::{Identity, TlsAcceptor, TlsStream};
//! use std::fs::File;
//! use std::io::{Read};
//! use std::net::{TcpListener, TcpStream};
//! use std::sync::Arc;
//! use std::thread;
//!
//! let mut file = File::open("identity.pfx").unwrap();
//! let mut identity = vec![];
//! file.read_to_end(&mut identity).unwrap();
//! let identity = Identity::from_pkcs12(&identity, "hunter2").unwrap();
//!
//! let listener = TcpListener::bind("0.0.0.0:8443").unwrap();
//! let acceptor = TlsAcceptor::new(identity).unwrap();
//! let acceptor = Arc::new(acceptor);
//!
//! fn handle_client(stream: TlsStream<TcpStream>) {
//! // ...
//! }
//!
//! for stream in listener.incoming() {
//! match stream {
//! Ok(stream) => {
//! let acceptor = acceptor.clone();
//! thread::spawn(move || {
//! let stream = acceptor.accept(stream).unwrap();
//! handle_client(stream);
//! });
//! }
//! Err(e) => { /* connection failed */ }
//! }
//! }
//! ```
#![doc(html_root_url = "https://docs.rs/native-tls/0.2")]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#[macro_use]
#[cfg(any(target_os = "macos", target_os = "ios"))]
extern crate lazy_static;
use std::any::Any;
use std::error;
use std::fmt;
use std::io;
use std::result;
#[cfg(not(any(target_os = "macos", target_os = "windows", target_os = "ios")))]
#[macro_use]
extern crate log;
#[cfg(any(target_os = "macos", target_os = "ios"))]
#[path = "imp/security_framework.rs"]
mod imp;
#[cfg(target_os = "windows")]
#[path = "imp/schannel.rs"]
mod imp;
#[cfg(not(any(target_os = "macos", target_os = "windows", target_os = "ios")))]
#[path = "imp/openssl.rs"]
mod imp;
#[cfg(test)]
mod test;
/// A typedef of the result-type returned by many methods.
pub type Result<T> = result::Result<T, Error>;
/// An error returned from the TLS implementation.
pub struct Error(imp::Error);
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
error::Error::source(&self.0)
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, fmt)
}
}
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl From<imp::Error> for Error {
fn from(err: imp::Error) -> Error {
Error(err)
}
}
/// A cryptographic identity.
///
/// An identity is an X509 certificate along with its corresponding private key and chain of certificates to a trusted
/// root.
#[derive(Clone)]
pub struct Identity(imp::Identity);
impl Identity {
/// Parses a DER-formatted PKCS #12 archive, using the specified password to decrypt the key.
///
/// The archive should contain a leaf certificate and its private key, as well any intermediate
/// certificates that should be sent to clients to allow them to build a chain to a trusted
/// root. The chain certificates should be in order from the leaf certificate towards the root.
///
/// PKCS #12 archives typically have the file extension `.p12` or `.pfx`, and can be created
/// with the OpenSSL `pkcs12` tool:
///
/// ```bash
/// openssl pkcs12 -export -out identity.pfx -inkey key.pem -in cert.pem -certfile chain_certs.pem
/// ```
pub fn from_pkcs12(der: &[u8], password: &str) -> Result<Identity> {
let identity = imp::Identity::from_pkcs12(der, password)?;
Ok(Identity(identity))
}
/// Parses a chain of PEM encoded X509 certificates, with the leaf certificate first.
/// `key` is a PEM encoded PKCS #8 formatted private key for the leaf certificate.
///
/// The certificate chain should contain any intermediate cerficates that should be sent to
/// clients to allow them to build a chain to a trusted root.
///
/// A certificate chain here means a series of PEM encoded certificates concatenated together.
pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result<Identity> {
let identity = imp::Identity::from_pkcs8(pem, key)?;
Ok(Identity(identity))
}
}
/// An X509 certificate.
#[derive(Clone)]
pub struct Certificate(imp::Certificate);
impl Certificate {
/// Parses a DER-formatted X509 certificate.
pub fn from_der(der: &[u8]) -> Result<Certificate> {
let cert = imp::Certificate::from_der(der)?;
Ok(Certificate(cert))
}
/// Parses a PEM-formatted X509 certificate.
pub fn from_pem(pem: &[u8]) -> Result<Certificate> {
let cert = imp::Certificate::from_pem(pem)?;
Ok(Certificate(cert))
}
/// Returns the DER-encoded representation of this certificate.
pub fn to_der(&self) -> Result<Vec<u8>> {
let der = self.0.to_der()?;
Ok(der)
}
}
/// A TLS stream which has been interrupted midway through the handshake process.
pub struct MidHandshakeTlsStream<S>(imp::MidHandshakeTlsStream<S>);
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> MidHandshakeTlsStream<S> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
{
/// Restarts the handshake process.
///
/// If the handshake completes successfully then the negotiated stream is
/// returned. If there is a problem, however, then an error is returned.
/// Note that the error may not be fatal. For example if the underlying
/// stream is an asynchronous one then `HandshakeError::WouldBlock` may
/// just mean to wait for more I/O to happen later.
pub fn handshake(self) -> result::Result<TlsStream<S>, HandshakeError<S>> {
match self.0.handshake() {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
/// An error returned from `ClientBuilder::handshake`.
#[derive(Debug)]
pub enum HandshakeError<S> {
/// A fatal error.
Failure(Error),
/// A stream interrupted midway through the handshake process due to a
/// `WouldBlock` error.
///
/// Note that this is not a fatal error and it should be safe to call
/// `handshake` at a later time once the stream is ready to perform I/O
/// again.
WouldBlock(MidHandshakeTlsStream<S>),
}
impl<S> error::Error for HandshakeError<S>
where
S: Any + fmt::Debug,
{
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match *self {
HandshakeError::Failure(ref e) => Some(e),
HandshakeError::WouldBlock(_) => None,
}
}
}
impl<S> fmt::Display for HandshakeError<S>
where
S: Any + fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match *self {
HandshakeError::Failure(ref e) => fmt::Display::fmt(e, fmt),
HandshakeError::WouldBlock(_) => fmt.write_str("the handshake process was interrupted"),
}
}
}
impl<S> From<imp::HandshakeError<S>> for HandshakeError<S> {
fn from(e: imp::HandshakeError<S>) -> HandshakeError<S> {
match e {
imp::HandshakeError::Failure(e) => HandshakeError::Failure(Error(e)),
imp::HandshakeError::WouldBlock(s) => {
HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
}
}
}
}
/// SSL/TLS protocol versions.
#[derive(Debug, Copy, Clone)]
pub enum Protocol {
/// The SSL 3.0 protocol.
///
/// # Warning
///
/// SSL 3.0 has severe security flaws, and should not be used unless absolutely necessary. If
/// you are not sure if you need to enable this protocol, you should not.
Sslv3,
/// The TLS 1.0 protocol.
Tlsv10,
/// The TLS 1.1 protocol.
Tlsv11,
/// The TLS 1.2 protocol.
Tlsv12,
#[doc(hidden)]
__NonExhaustive,
}
/// A builder for `TlsConnector`s.
pub struct TlsConnectorBuilder {
identity: Option<Identity>,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
root_certificates: Vec<Certificate>,
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
use_sni: bool,
disable_built_in_roots: bool,
#[cfg(feature = "alpn")]
alpn: Vec<String>,
}
impl TlsConnectorBuilder {
/// Sets the identity to be used for client certificate authentication.
pub fn identity(&mut self, identity: Identity) -> &mut TlsConnectorBuilder {
self.identity = Some(identity);
self
}
/// Sets the minimum supported protocol version.
///
/// A value of `None` enables support for the oldest protocols supported by the implementation.
///
/// Defaults to `Some(Protocol::Tlsv10)`.
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut TlsConnectorBuilder {
self.min_protocol = protocol;
self
}
/// Sets the maximum supported protocol version.
///
/// A value of `None` enables support for the newest protocols supported by the implementation.
///
/// Defaults to `None`.
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut TlsConnectorBuilder {
self.max_protocol = protocol;
self
}
/// Adds a certificate to the set of roots that the connector will trust.
///
/// The connector will use the system's trust root by default. This method can be used to add
/// to that set when communicating with servers not trusted by the system.
///
/// Defaults to an empty set.
pub fn add_root_certificate(&mut self, cert: Certificate) -> &mut TlsConnectorBuilder {
self.root_certificates.push(cert);
self
}
/// Controls the use of built-in system certificates during certificate validation.
///
/// Defaults to `false` -- built-in system certs will be used.
pub fn disable_built_in_roots(&mut self, disable: bool) -> &mut TlsConnectorBuilder {
self.disable_built_in_roots = disable;
self
}
/// Request specific protocols through ALPN (Application-Layer Protocol Negotiation).
///
/// Defaults to no protocols.
#[cfg(feature = "alpn")]
#[cfg_attr(docsrs, doc(cfg(feature = "alpn")))]
pub fn request_alpns(&mut self, protocols: &[&str]) -> &mut TlsConnectorBuilder {
self.alpn = protocols.iter().map(|s| (*s).to_owned()).collect();
self
}
/// Controls the use of certificate validation.
///
/// Defaults to `false`.
///
/// # Warning
///
/// You should think very carefully before using this method. If invalid certificates are trusted, *any*
/// certificate for *any* site will be trusted for use. This includes expired certificates. This introduces
/// significant vulnerabilities, and should only be used as a last resort.
pub fn danger_accept_invalid_certs(
&mut self,
accept_invalid_certs: bool,
) -> &mut TlsConnectorBuilder {
self.accept_invalid_certs = accept_invalid_certs;
self
}
/// Controls the use of Server Name Indication (SNI).
///
/// Defaults to `true`.
pub fn use_sni(&mut self, use_sni: bool) -> &mut TlsConnectorBuilder {
self.use_sni = use_sni;
self
}
/// Controls the use of hostname verification.
///
/// Defaults to `false`.
///
/// # Warning
///
/// You should think very carefully before using this method. If invalid hostnames are trusted, *any* valid
/// certificate for *any* site will be trusted for use. This introduces significant vulnerabilities, and should
/// only be used as a last resort.
pub fn danger_accept_invalid_hostnames(
&mut self,
accept_invalid_hostnames: bool,
) -> &mut TlsConnectorBuilder {
self.accept_invalid_hostnames = accept_invalid_hostnames;
self
}
/// Creates a new `TlsConnector`.
pub fn build(&self) -> Result<TlsConnector> {
let connector = imp::TlsConnector::new(self)?;
Ok(TlsConnector(connector))
}
}
/// A builder for client-side TLS connections.
///
/// # Examples
///
/// ```rust
/// use native_tls::TlsConnector;
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
///
/// let connector = TlsConnector::new().unwrap();
///
/// let stream = TcpStream::connect("google.com:443").unwrap();
/// let mut stream = connector.connect("google.com", stream).unwrap();
///
/// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
/// let mut res = vec![];
/// stream.read_to_end(&mut res).unwrap();
/// println!("{}", String::from_utf8_lossy(&res));
/// ```
#[derive(Clone, Debug)]
pub struct TlsConnector(imp::TlsConnector);
impl TlsConnector {
/// Returns a new connector with default settings.
pub fn new() -> Result<TlsConnector> {
TlsConnector::builder().build()
}
/// Returns a new builder for a `TlsConnector`.
pub fn builder() -> TlsConnectorBuilder {
TlsConnectorBuilder {
identity: None,
min_protocol: Some(Protocol::Tlsv10),
max_protocol: None,
root_certificates: vec![],
use_sni: true,
accept_invalid_certs: false,
accept_invalid_hostnames: false,
disable_built_in_roots: false,
#[cfg(feature = "alpn")]
alpn: vec![],
}
}
/// Initiates a TLS handshake.
///
/// The provided domain will be used for both SNI and certificate hostname
/// validation.
///
/// If the socket is nonblocking and a `WouldBlock` error is returned during
/// the handshake, a `HandshakeError::WouldBlock` error will be returned
/// which can be used to restart the handshake when the socket is ready
/// again.
///
/// The domain is ignored if both SNI and hostname verification are
/// disabled.
pub fn connect<S>(
&self,
domain: &str,
stream: S,
) -> result::Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let s = self.0.connect(domain, stream)?;
Ok(TlsStream(s))
}
}
/// A builder for `TlsAcceptor`s.
pub struct TlsAcceptorBuilder {
identity: Identity,
min_protocol: Option<Protocol>,
max_protocol: Option<Protocol>,
}
impl TlsAcceptorBuilder {
/// Sets the minimum supported protocol version.
///
/// A value of `None` enables support for the oldest protocols supported by the implementation.
///
/// Defaults to `Some(Protocol::Tlsv10)`.
pub fn min_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut TlsAcceptorBuilder {
self.min_protocol = protocol;
self
}
/// Sets the maximum supported protocol version.
///
/// A value of `None` enables support for the newest protocols supported by the implementation.
///
/// Defaults to `None`.
pub fn max_protocol_version(&mut self, protocol: Option<Protocol>) -> &mut TlsAcceptorBuilder {
self.max_protocol = protocol;
self
}
/// Creates a new `TlsAcceptor`.
pub fn build(&self) -> Result<TlsAcceptor> {
let acceptor = imp::TlsAcceptor::new(self)?;
Ok(TlsAcceptor(acceptor))
}
}
/// A builder for server-side TLS connections.
///
/// # Examples
///
/// ```rust,no_run
/// use native_tls::{Identity, TlsAcceptor, TlsStream};
/// use std::fs::File;
/// use std::io::{Read};
/// use std::net::{TcpListener, TcpStream};
/// use std::sync::Arc;
/// use std::thread;
///
/// let mut file = File::open("identity.pfx").unwrap();
/// let mut identity = vec![];
/// file.read_to_end(&mut identity).unwrap();
/// let identity = Identity::from_pkcs12(&identity, "hunter2").unwrap();
///
/// let listener = TcpListener::bind("0.0.0.0:8443").unwrap();
/// let acceptor = TlsAcceptor::new(identity).unwrap();
/// let acceptor = Arc::new(acceptor);
///
/// fn handle_client(stream: TlsStream<TcpStream>) {
/// // ...
/// }
///
/// for stream in listener.incoming() {
/// match stream {
/// Ok(stream) => {
/// let acceptor = acceptor.clone();
/// thread::spawn(move || {
/// let stream = acceptor.accept(stream).unwrap();
/// handle_client(stream);
/// });
/// }
/// Err(e) => { /* connection failed */ }
/// }
/// }
/// ```
#[derive(Clone)]
pub struct TlsAcceptor(imp::TlsAcceptor);
impl TlsAcceptor {
/// Creates a acceptor with default settings.
///
/// The identity acts as the server's private key/certificate chain.
pub fn new(identity: Identity) -> Result<TlsAcceptor> {
TlsAcceptor::builder(identity).build()
}
/// Returns a new builder for a `TlsAcceptor`.
///
/// The identity acts as the server's private key/certificate chain.
pub fn builder(identity: Identity) -> TlsAcceptorBuilder {
TlsAcceptorBuilder {
identity,
min_protocol: Some(Protocol::Tlsv10),
max_protocol: None,
}
}
/// Initiates a TLS handshake.
///
/// If the socket is nonblocking and a `WouldBlock` error is returned during
/// the handshake, a `HandshakeError::WouldBlock` error will be returned
/// which can be used to restart the handshake when the socket is ready
/// again.
pub fn accept<S>(&self, stream: S) -> result::Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
match self.0.accept(stream) {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
/// A stream managing a TLS session.
pub struct TlsStream<S>(imp::TlsStream<S>);
impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> TlsStream<S> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S: io::Read + io::Write> TlsStream<S> {
/// Returns the number of bytes that can be read without resulting in any
/// network calls.
pub fn buffered_read_size(&self) -> Result<usize> {
Ok(self.0.buffered_read_size()?)
}
/// Returns the peer's leaf certificate, if available.
pub fn peer_certificate(&self) -> Result<Option<Certificate>> {
Ok(self.0.peer_certificate()?.map(Certificate))
}
/// Returns the tls-server-end-point channel binding data as defined in [RFC 5929].
///
/// [RFC 5929]: https://tools.ietf.org/html/rfc5929
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>> {
Ok(self.0.tls_server_end_point()?)
}
/// Returns the negotiated ALPN protocol.
#[cfg(feature = "alpn")]
#[cfg_attr(docsrs, doc(cfg(feature = "alpn")))]
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>> {
Ok(self.0.negotiated_alpn()?)
}
/// Shuts down the TLS session.
pub fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown()?;
Ok(())
}
}
impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
fn _check_kinds() {
use std::net::TcpStream;
fn is_sync<T: Sync>() {}
fn is_send<T: Send>() {}
is_sync::<Error>();
is_send::<Error>();
is_sync::<TlsConnectorBuilder>();
is_send::<TlsConnectorBuilder>();
is_sync::<TlsConnector>();
is_send::<TlsConnector>();
is_sync::<TlsAcceptorBuilder>();
is_send::<TlsAcceptorBuilder>();
is_sync::<TlsAcceptor>();
is_send::<TlsAcceptor>();
is_sync::<TlsStream<TcpStream>>();
is_send::<TlsStream<TcpStream>>();
is_sync::<MidHandshakeTlsStream<TcpStream>>();
is_send::<MidHandshakeTlsStream<TcpStream>>();
}

573
zeroidc/vendor/native-tls/src/test.rs vendored Normal file
View File

@@ -0,0 +1,573 @@
use std::fs;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::process::{Command, Stdio};
use std::string::String;
use std::thread;
use super::*;
macro_rules! p {
($e:expr) => {
match $e {
Ok(r) => r,
Err(e) => panic!("{:?}", e),
}
};
}
#[test]
fn connect_google() {
let builder = p!(TlsConnector::new());
let s = p!(TcpStream::connect("google.com:443"));
let mut socket = p!(builder.connect("google.com", s));
p!(socket.write_all(b"GET / HTTP/1.0\r\n\r\n"));
let mut result = vec![];
p!(socket.read_to_end(&mut result));
println!("{}", String::from_utf8_lossy(&result));
assert!(result.starts_with(b"HTTP/1.0"));
assert!(result.ends_with(b"</HTML>\r\n") || result.ends_with(b"</html>"));
}
#[test]
fn connect_bad_hostname() {
let builder = p!(TlsConnector::new());
let s = p!(TcpStream::connect("google.com:443"));
builder.connect("goggle.com", s).unwrap_err();
}
#[test]
fn connect_bad_hostname_ignored() {
let builder = p!(TlsConnector::builder()
.danger_accept_invalid_hostnames(true)
.build());
let s = p!(TcpStream::connect("google.com:443"));
builder.connect("goggle.com", s).unwrap();
}
#[test]
fn connect_no_root_certs() {
let builder = p!(TlsConnector::builder().disable_built_in_roots(true).build());
let s = p!(TcpStream::connect("google.com:443"));
assert!(builder.connect("google.com", s).is_err());
}
#[test]
fn server_no_root_certs() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::new(identity));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.disable_built_in_roots(true)
.add_root_certificate(root_ca)
.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
p!(j.join());
}
#[test]
fn server() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::new(identity));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.add_root_certificate(root_ca)
.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
p!(j.join());
}
#[test]
fn certificate_from_pem() {
let dir = tempfile::tempdir().unwrap();
let keys = test_cert_gen::keys();
let der_path = dir.path().join("cert.der");
fs::write(&der_path, &keys.client.ca.get_der()).unwrap();
let output = Command::new("openssl")
.arg("x509")
.arg("-in")
.arg(der_path)
.arg("-inform")
.arg("der")
.stderr(Stdio::piped())
.output()
.unwrap();
assert!(output.status.success());
let cert = Certificate::from_pem(&output.stdout).unwrap();
assert_eq!(cert.to_der().unwrap(), keys.client.ca.get_der());
}
#[test]
fn peer_certificate() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::new(identity));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let socket = p!(builder.accept(socket));
assert!(socket.peer_certificate().unwrap().is_none());
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.add_root_certificate(root_ca)
.build());
let socket = p!(builder.connect("localhost", socket));
let cert = socket.peer_certificate().unwrap().unwrap();
assert_eq!(
cert.to_der().unwrap(),
keys.server.cert_and_key.cert.get_der()
);
p!(j.join());
}
#[test]
fn server_tls11_only() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::builder(identity)
.min_protocol_version(Some(Protocol::Tlsv12))
.max_protocol_version(Some(Protocol::Tlsv12))
.build());
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.add_root_certificate(root_ca)
.min_protocol_version(Some(Protocol::Tlsv12))
.max_protocol_version(Some(Protocol::Tlsv12))
.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
p!(j.join());
}
#[test]
fn server_no_shared_protocol() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::builder(identity)
.min_protocol_version(Some(Protocol::Tlsv12))
.build());
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
assert!(builder.accept(socket).is_err());
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.add_root_certificate(root_ca)
.min_protocol_version(Some(Protocol::Tlsv11))
.max_protocol_version(Some(Protocol::Tlsv11))
.build());
assert!(builder.connect("localhost", socket).is_err());
p!(j.join());
}
#[test]
fn server_untrusted() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::new(identity));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
// FIXME should assert error
// https://github.com/steffengy/schannel-rs/issues/20
let _ = builder.accept(socket);
});
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::new());
builder.connect("localhost", socket).unwrap_err();
p!(j.join());
}
#[test]
fn server_untrusted_unverified() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::new(identity));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
p!(j.join());
}
#[test]
fn import_same_identity_multiple_times() {
let keys = test_cert_gen::keys();
let _ = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let _ = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let cert = keys.server.cert_and_key.cert.to_pem().into_bytes();
let key = rsa_to_pkcs8(&keys.server.cert_and_key.key.to_pem_incorrect()).into_bytes();
let _ = p!(Identity::from_pkcs8(&cert, &key));
let _ = p!(Identity::from_pkcs8(&cert, &key));
}
#[test]
fn from_pkcs8_rejects_rsa_key() {
let keys = test_cert_gen::keys();
let cert = keys.server.cert_and_key.cert.to_pem().into_bytes();
let rsa_key = keys.server.cert_and_key.key.to_pem_incorrect();
assert!(Identity::from_pkcs8(&cert, rsa_key.as_bytes()).is_err());
let pkcs8_key = rsa_to_pkcs8(&rsa_key);
assert!(Identity::from_pkcs8(&cert, pkcs8_key.as_bytes()).is_ok());
}
#[test]
fn shutdown() {
let keys = test_cert_gen::keys();
let identity = p!(Identity::from_pkcs12(
&keys.server.cert_and_key_pkcs12.pkcs12.0,
&keys.server.cert_and_key_pkcs12.password
));
let builder = p!(TlsAcceptor::new(identity));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
assert_eq!(p!(socket.read(&mut buf)), 0);
p!(socket.shutdown());
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let builder = p!(TlsConnector::builder()
.add_root_certificate(root_ca)
.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
p!(socket.shutdown());
p!(j.join());
}
#[test]
#[cfg(feature = "alpn")]
fn alpn_google_h2() {
let builder = p!(TlsConnector::builder().request_alpns(&["h2"]).build());
let s = p!(TcpStream::connect("google.com:443"));
let socket = p!(builder.connect("google.com", s));
let alpn = p!(socket.negotiated_alpn());
assert_eq!(alpn, Some(b"h2".to_vec()));
}
#[test]
#[cfg(feature = "alpn")]
fn alpn_google_invalid() {
let builder = p!(TlsConnector::builder().request_alpns(&["h2c"]).build());
let s = p!(TcpStream::connect("google.com:443"));
let socket = p!(builder.connect("google.com", s));
let alpn = p!(socket.negotiated_alpn());
assert_eq!(alpn, None);
}
#[test]
#[cfg(feature = "alpn")]
fn alpn_google_none() {
let builder = p!(TlsConnector::new());
let s = p!(TcpStream::connect("google.com:443"));
let socket = p!(builder.connect("google.com", s));
let alpn = p!(socket.negotiated_alpn());
assert_eq!(alpn, None);
}
#[test]
fn server_pkcs8() {
let keys = test_cert_gen::keys();
let cert = keys.server.cert_and_key.cert.to_pem().into_bytes();
let key = rsa_to_pkcs8(&keys.server.cert_and_key.key.to_pem_incorrect()).into_bytes();
let ident = Identity::from_pkcs8(&cert, &key).unwrap();
let ident2 = ident.clone();
let builder = p!(TlsAcceptor::new(ident));
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let mut builder = TlsConnector::builder();
// FIXME
// This checks that we can successfully add a certificate on the client side.
// Unfortunately, we can not request client certificates through the API of this library,
// otherwise we could check in the server thread that
// socket.peer_certificate().unwrap().is_some()
builder.identity(ident2);
builder.add_root_certificate(root_ca);
let builder = p!(builder.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
p!(j.join());
}
#[test]
fn two_servers() {
let keys1 = test_cert_gen::gen_keys();
let cert = keys1.server.cert_and_key.cert.to_pem().into_bytes();
let key = rsa_to_pkcs8(&keys1.server.cert_and_key.key.to_pem_incorrect()).into_bytes();
let identity = p!(Identity::from_pkcs8(&cert, &key));
let builder = TlsAcceptor::builder(identity);
let builder = p!(builder.build());
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port = p!(listener.local_addr()).port();
let j = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let keys2 = test_cert_gen::gen_keys();
let cert = keys2.server.cert_and_key.cert.to_pem().into_bytes();
let key = rsa_to_pkcs8(&keys2.server.cert_and_key.key.to_pem_incorrect()).into_bytes();
let identity = p!(Identity::from_pkcs8(&cert, &key));
let builder = TlsAcceptor::builder(identity);
let builder = p!(builder.build());
let listener = p!(TcpListener::bind("0.0.0.0:0"));
let port2 = p!(listener.local_addr()).port();
let j2 = thread::spawn(move || {
let socket = p!(listener.accept()).0;
let mut socket = p!(builder.accept(socket));
let mut buf = [0; 5];
p!(socket.read_exact(&mut buf));
assert_eq!(&buf, b"hello");
p!(socket.write_all(b"world"));
});
let root_ca = Certificate::from_der(keys1.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port)));
let mut builder = TlsConnector::builder();
builder.add_root_certificate(root_ca);
let builder = p!(builder.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
let root_ca = Certificate::from_der(keys2.client.ca.get_der()).unwrap();
let socket = p!(TcpStream::connect(("localhost", port2)));
let mut builder = TlsConnector::builder();
builder.add_root_certificate(root_ca);
let builder = p!(builder.build());
let mut socket = p!(builder.connect("localhost", socket));
p!(socket.write_all(b"hello"));
let mut buf = vec![];
p!(socket.read_to_end(&mut buf));
assert_eq!(buf, b"world");
p!(j.join());
p!(j2.join());
}
fn rsa_to_pkcs8(pem: &str) -> String {
let mut child = Command::new("openssl")
.arg("pkcs8")
.arg("-topk8")
.arg("-nocrypt")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.unwrap();
{
let child_stdin = child.stdin.as_mut().unwrap();
child_stdin.write_all(pem.as_bytes()).unwrap();
}
String::from_utf8(child.wait_with_output().unwrap().stdout).unwrap()
}