diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py index 715def9..e9743e4 100644 --- a/third_party/tlslite/tlslite/constants.py +++ b/third_party/tlslite/tlslite/constants.py @@ -54,6 +54,7 @@ class ExtensionType: # RFC 6066 / 4366 status_request = 5 # RFC 6066 / 4366 srp = 12 # RFC 5054 cert_type = 9 # RFC 6091 + alpn = 16 # RFC 7301 signed_cert_timestamps = 18 # RFC 6962 extended_master_secret = 23 # RFC 7627 token_binding = 24 # draft-ietf-tokbind-negotiation diff --git a/third_party/tlslite/tlslite/handshakesettings.py b/third_party/tlslite/tlslite/handshakesettings.py index d7be5b3..69fc6f4 100644 --- a/third_party/tlslite/tlslite/handshakesettings.py +++ b/third_party/tlslite/tlslite/handshakesettings.py @@ -128,6 +128,12 @@ class HandshakeSettings(object): Note that TACK support is not standardized by IETF and uses a temporary TLS Extension number, so should NOT be used in production software. + + @type alpnProtos: list of strings. + @param alpnProtos: A list of supported upper layer protocols to use in the + Application-Layer Protocol Negotiation Extension (RFC 7301). For the + client, the order does not matter. For the server, the list is in + decreasing order of preference. """ def __init__(self): self.minKeySize = 1023 @@ -146,6 +152,7 @@ class HandshakeSettings(object): self.enableChannelID = True self.enableExtendedMasterSecret = True self.supportedTokenBindingParams = [] + self.alpnProtos = None # Validates the min/max fields, and certificateTypes # Filters out unsupported cipherNames and cipherImplementations @@ -166,6 +173,7 @@ class HandshakeSettings(object): other.enableChannelID = self.enableChannelID other.enableExtendedMasterSecret = self.enableExtendedMasterSecret other.supportedTokenBindingParams = self.supportedTokenBindingParams + other.alpnProtos = self.alpnProtos; if not cipherfactory.tripleDESPresent: other.cipherNames = [e for e in self.cipherNames if e != "3des"] diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py index 5762ac6..1ce9320 100644 --- a/third_party/tlslite/tlslite/messages.py +++ b/third_party/tlslite/tlslite/messages.py @@ -18,6 +18,27 @@ from .x509 import X509 from .x509certchain import X509CertChain from .utils.tackwrapper import * +def parse_next_protos(b): + protos = [] + while True: + if len(b) == 0: + break + l = b[0] + b = b[1:] + if len(b) < l: + raise BadNextProtos(len(b)) + protos.append(b[:l]) + b = b[l:] + return protos + +def next_protos_encoded(protocol_list): + b = bytearray() + for e in protocol_list: + if len(e) > 255 or len(e) == 0: + raise BadNextProtos(len(e)) + b += bytearray( [len(e)] ) + bytearray(e) + return b + class RecordHeader3(object): def __init__(self): self.type = 0 @@ -111,6 +132,7 @@ class ClientHello(HandshakeMsg): self.compression_methods = [] # a list of 8-bit values self.srp_username = None # a string self.tack = False + self.alpn_protos_advertised = None self.supports_npn = False self.server_name = bytearray(0) self.channel_id = False @@ -121,7 +143,8 @@ class ClientHello(HandshakeMsg): def create(self, version, random, session_id, cipher_suites, certificate_types=None, srpUsername=None, - tack=False, supports_npn=False, serverName=None): + tack=False, alpn_protos_advertised=None, + supports_npn=False, serverName=None): self.client_version = version self.random = random self.session_id = session_id @@ -131,6 +154,7 @@ class ClientHello(HandshakeMsg): if srpUsername: self.srp_username = bytearray(srpUsername, "utf-8") self.tack = tack + self.alpn_protos_advertised = alpn_protos_advertised self.supports_npn = supports_npn if serverName: self.server_name = bytearray(serverName, "utf-8") @@ -171,6 +195,11 @@ class ClientHello(HandshakeMsg): self.certificate_types = p.getVarList(1, 1) elif extType == ExtensionType.tack: self.tack = True + elif extType == ExtensionType.alpn: + structLength = p.get(2) + if structLength + 2 != extLength: + raise SyntaxError() + self.alpn_protos_advertised = parse_next_protos(p.getFixBytes(structLength)) elif extType == ExtensionType.supports_npn: self.supports_npn = True elif extType == ExtensionType.server_name: @@ -243,6 +272,12 @@ class ClientHello(HandshakeMsg): w2.add(ExtensionType.srp, 2) w2.add(len(self.srp_username)+1, 2) w2.addVarSeq(self.srp_username, 1, 1) + if self.alpn_protos_advertised is not None: + encoded_alpn_protos_advertised = next_protos_encoded(self.alpn_protos_advertised) + w2.add(ExtensionType.alpn, 2) + w2.add(len(encoded_alpn_protos_advertised) + 2, 2) + w2.add(len(encoded_alpn_protos_advertised), 2) + w2.addFixSeq(encoded_alpn_protos_advertised, 1) if self.supports_npn: w2.add(ExtensionType.supports_npn, 2) w2.add(0, 2) @@ -267,6 +302,13 @@ class BadNextProtos(Exception): def __str__(self): return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length +class InvalidALPNResponse(Exception): + def __init__(self, l): + self.length = l + + def __str__(self): + return 'ALPN server response protocol list has invalid length %d. It must be of length one.' % self.length + class ServerHello(HandshakeMsg): def __init__(self): HandshakeMsg.__init__(self, HandshakeType.server_hello) @@ -277,6 +319,7 @@ class ServerHello(HandshakeMsg): self.certificate_type = CertificateType.x509 self.compression_method = 0 self.tackExt = None + self.alpn_proto_selected = None self.next_protos_advertised = None self.next_protos = None self.channel_id = False @@ -286,7 +329,8 @@ class ServerHello(HandshakeMsg): self.status_request = False def create(self, version, random, session_id, cipher_suite, - certificate_type, tackExt, next_protos_advertised): + certificate_type, tackExt, alpn_proto_selected, + next_protos_advertised): self.server_version = version self.random = random self.session_id = session_id @@ -294,6 +338,7 @@ class ServerHello(HandshakeMsg): self.certificate_type = certificate_type self.compression_method = 0 self.tackExt = tackExt + self.alpn_proto_selected = alpn_proto_selected self.next_protos_advertised = next_protos_advertised return self @@ -316,35 +361,22 @@ class ServerHello(HandshakeMsg): self.certificate_type = p.get(1) elif extType == ExtensionType.tack and tackpyLoaded: self.tackExt = TackExtension(p.getFixBytes(extLength)) + elif extType == ExtensionType.alpn: + structLength = p.get(2) + if structLength + 2 != extLength: + raise SyntaxError() + alpn_protos = parse_next_protos(p.getFixBytes(structLength)) + if len(alpn_protos) != 1: + raise InvalidALPNResponse(len(alpn_protos)); + self.alpn_proto_selected = alpn_protos[0] elif extType == ExtensionType.supports_npn: - self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength)) + self.next_protos = parse_next_protos(p.getFixBytes(extLength)) else: p.getFixBytes(extLength) soFar += 4 + extLength p.stopLengthCheck() return self - def __parse_next_protos(self, b): - protos = [] - while True: - if len(b) == 0: - break - l = b[0] - b = b[1:] - if len(b) < l: - raise BadNextProtos(len(b)) - protos.append(b[:l]) - b = b[l:] - return protos - - def __next_protos_encoded(self): - b = bytearray() - for e in self.next_protos_advertised: - if len(e) > 255 or len(e) == 0: - raise BadNextProtos(len(e)) - b += bytearray( [len(e)] ) + bytearray(e) - return b - def write(self): w = Writer() w.add(self.server_version[0], 1) @@ -365,8 +397,15 @@ class ServerHello(HandshakeMsg): w2.add(ExtensionType.tack, 2) w2.add(len(b), 2) w2.bytes += b + if self.alpn_proto_selected is not None: + alpn_protos_single_element_list = [self.alpn_proto_selected] + encoded_alpn_protos_advertised = next_protos_encoded(alpn_protos_single_element_list) + w2.add(ExtensionType.alpn, 2) + w2.add(len(encoded_alpn_protos_advertised) + 2, 2) + w2.add(len(encoded_alpn_protos_advertised), 2) + w2.addFixSeq(encoded_alpn_protos_advertised, 1) if self.next_protos_advertised is not None: - encoded_next_protos_advertised = self.__next_protos_encoded() + encoded_next_protos_advertised = next_protos_encoded(self.next_protos_advertised) w2.add(ExtensionType.supports_npn, 2) w2.add(len(encoded_next_protos_advertised), 2) w2.addFixSeq(encoded_next_protos_advertised, 1) diff --git a/third_party/tlslite/tlslite/tlsconnection.py b/third_party/tlslite/tlslite/tlsconnection.py index 41aab85..de5d580 100644 --- a/third_party/tlslite/tlslite/tlsconnection.py +++ b/third_party/tlslite/tlslite/tlsconnection.py @@ -495,6 +495,10 @@ class TLSConnection(TLSRecordLayer): settings = HandshakeSettings() settings = settings._filter() + if settings.alpnProtos is not None: + if len(settings.alpnProtos) == 0: + raise ValueError("Caller passed no alpnProtos") + if clientCertChain: if not isinstance(clientCertChain, X509CertChain): raise ValueError("Unrecognized certificate type") @@ -651,7 +655,8 @@ class TLSConnection(TLSRecordLayer): session.sessionID, cipherSuites, certificateTypes, session.srpUsername, - reqTack, nextProtos is not None, + reqTack, settings.alpnProtos, + nextProtos is not None, session.serverName) #Or send ClientHello (without) @@ -661,7 +666,8 @@ class TLSConnection(TLSRecordLayer): bytearray(0), cipherSuites, certificateTypes, srpUsername, - reqTack, nextProtos is not None, + reqTack, settings.alpnProtos, + nextProtos is not None, serverName) for result in self._sendMsg(clientHello): yield result @@ -714,6 +720,16 @@ class TLSConnection(TLSRecordLayer): AlertDescription.illegal_parameter, "Server responded with unrequested Tack Extension"): yield result + if serverHello.alpn_proto_selected and not clientHello.alpn_protos_advertised: + for result in self._sendError(\ + AlertDescription.illegal_parameter, + "Server responded with unrequested ALPN Extension"): + yield result + if serverHello.alpn_proto_selected and serverHello.next_protos: + for result in self._sendError(\ + AlertDescription.illegal_parameter, + "Server responded with both ALPN and NPN extension"): + yield result if serverHello.next_protos and not clientHello.supports_npn: for result in self._sendError(\ AlertDescription.illegal_parameter, @@ -1315,6 +1331,15 @@ class TLSConnection(TLSRecordLayer): else: sessionID = bytearray(0) + alpn_proto_selected = None + if (clientHello.alpn_protos_advertised is not None + and settings.alpnProtos is not None): + for proto in settings.alpnProtos: + if proto in clientHello.alpn_protos_advertised: + alpn_proto_selected = proto + nextProtos = None + break; + if not clientHello.supports_npn: nextProtos = None @@ -1330,6 +1355,7 @@ class TLSConnection(TLSRecordLayer): serverHello = ServerHello() serverHello.create(self.version, getRandomBytes(32), sessionID, \ cipherSuite, CertificateType.x509, tackExt, + alpn_proto_selected, nextProtos) serverHello.channel_id = \ clientHello.channel_id and settings.enableChannelID @@ -1500,6 +1526,14 @@ class TLSConnection(TLSRecordLayer): else: assert(False) + alpn_proto_selected = None + if (clientHello.alpn_protos_advertised is not None + and settings.alpnProtos is not None): + for proto in settings.alpnProtos: + if proto in clientHello.alpn_protos_advertised: + alpn_proto_selected = proto + break; + #If resumption was requested and we have a session cache... if clientHello.session_id and sessionCache: session = None @@ -1540,7 +1574,8 @@ class TLSConnection(TLSRecordLayer): serverHello = ServerHello() serverHello.create(self.version, getRandomBytes(32), session.sessionID, session.cipherSuite, - CertificateType.x509, None, None) + CertificateType.x509, None, + alpn_proto_selected, None) serverHello.extended_master_secret = \ clientHello.extended_master_secret and \ settings.enableExtendedMasterSecret