# Source code for module tlslite.messages (package tlslite).

# Authors:
#   Trevor Perrin
#   Google - handling CertificateRequest.certificate_types
#   Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support
#   Dimitris Moraitis - Anon ciphersuites
#
# See the LICENSE file for legal information regarding use of this file.
  8   
  9  """Classes representing TLS messages.""" 
 10   
 11  from .utils.compat import * 
 12  from .utils.cryptomath import * 
 13  from .errors import * 
 14  from .utils.codec import * 
 15  from .constants import * 
 16  from .x509 import X509 
 17  from .x509certchain import X509CertChain 
 18  from .utils.tackwrapper import * 
 19   
class RecordHeader3(object):
    """Header of an SSLv3/TLS record: content type, version and length."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = False

    def create(self, version, type, length):
        """Set the header fields and return self for chaining."""
        self.type = type
        self.version = version
        self.length = length
        return self

    def write(self):
        """Serialize as: type (1 byte), major, minor (1 byte each),
        length (2 bytes)."""
        writer = Writer()
        fields = ((self.type, 1),
                  (self.version[0], 1),
                  (self.version[1], 1),
                  (self.length, 2))
        for value, size in fields:
            writer.add(value, size)
        return writer.bytes

    def parse(self, p):
        """Read the 5-byte header from parser p and return self."""
        self.type = p.get(1)
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self

class RecordHeader2(object):
    """Header of an SSLv2-format record (used for SSLv2-style ClientHellos)."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = True

    def parse(self, p):
        """Parse a 2-byte SSLv2 header from p and return self.

        Raises SyntaxError unless the first byte is exactly 128. The record
        is reported as a handshake message with version (2, 0).
        """
        if not p.get(1) == 128:
            raise SyntaxError()
        self.type = ContentType.handshake
        self.version = (2, 0)
        # The length is taken from the second byte alone, so headers that
        # encode longer lengths are not supported; could be a problem.
        self.length = p.get(1)
        return self

class Alert(object):
    """An alert record: a one-byte level followed by a one-byte description."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set the description and level (fatal by default); return self."""
        self.level = level
        self.description = description
        return self

    def parse(self, p):
        """Read exactly two bytes (level, then description) from p."""
        p.setLengthCheck(2)
        self.level, self.description = p.get(1), p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize as two bytes: level then description."""
        writer = Writer()
        for field in (self.level, self.description):
            writer.add(field, 1)
        return writer.bytes

class HandshakeMsg(object):
    """Base class for handshake messages."""

    def __init__(self, handshakeType):
        self.contentType = ContentType.handshake
        self.handshakeType = handshakeType

    def postWrite(self, w):
        """Return w's bytes prefixed with the handshake header.

        The header is the message type (1 byte) followed by the body
        length (3 bytes).
        """
        body = w.bytes
        header = Writer()
        header.add(self.handshakeType, 1)
        header.add(len(body), 3)
        return header.bytes + body

class ClientHello(HandshakeMsg):
    """The ClientHello handshake message, in TLS or SSLv2-compatible form."""

    def __init__(self, ssl2=False):
        HandshakeMsg.__init__(self, HandshakeType.client_hello)
        self.ssl2 = ssl2
        self.client_version = (0, 0)
        self.random = bytearray(32)
        self.session_id = bytearray(0)
        self.cipher_suites = []  # 16-bit cipher suite values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []  # 8-bit compression method values
        self.srp_username = None  # bytearray once set by create()/parse()
        self.tack = False
        self.supports_npn = False
        self.server_name = bytearray(0)

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srpUsername=None,
               tack=False, supports_npn=False, serverName=None):
        """Populate the message for sending and return self.

        srpUsername and serverName are text strings that are stored UTF-8
        encoded; when empty or None the corresponding attribute is left
        unchanged. Only the null compression method (0) is offered.
        """
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        if srpUsername:
            self.srp_username = bytearray(srpUsername, "utf-8")
        self.tack = tack
        self.supports_npn = supports_npn
        if serverName:
            self.server_name = bytearray(serverName, "utf-8")
        return self

    def parse(self, p):
        """Parse the message body from parser p and return self."""
        if self.ssl2:
            self._parseSSLv2(p)
        else:
            self._parseTLS(p)
        return self

    def _parseSSLv2(self, p):
        """Parse the body of an SSLv2-compatible client hello."""
        self.client_version = (p.get(1), p.get(1))
        cipherSpecsLength = p.get(2)
        sessionIDLength = p.get(2)
        challengeLength = p.get(2)
        # Cipher specs are read as 3-byte values.
        self.cipher_suites = p.getFixList(3, cipherSpecsLength // 3)
        self.session_id = p.getFixBytes(sessionIDLength)
        challenge = p.getFixBytes(challengeLength)
        # Left-pad a short challenge with zeros to the 32-byte random size.
        if len(challenge) < 32:
            challenge = bytearray(32 - len(challenge)) + challenge
        self.random = challenge
        self.compression_methods = [0]  # not present on the wire; faked
        # No stopLengthCheck() is done for the SSLv2 form.

    def _parseTLS(self, p):
        """Parse the body of a regular TLS client hello."""
        p.startLengthCheck(3)
        self.client_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suites = p.getVarList(2, 2)
        self.compression_methods = p.getVarList(1, 1)
        if not p.atLengthCheck():
            self._parseExtensions(p)
        p.stopLengthCheck()

    def _parseExtensions(self, p):
        """Parse the extension list; raise SyntaxError on a length mismatch."""
        totalExtLength = p.get(2)
        consumed = 0
        while consumed != totalExtLength:
            extType = p.get(2)
            extLength = p.get(2)
            start = p.index
            if extType == ExtensionType.srp:
                self.srp_username = p.getVarBytes(1)
            elif extType == ExtensionType.cert_type:
                self.certificate_types = p.getVarList(1, 1)
            elif extType == ExtensionType.tack:
                self.tack = True
            elif extType == ExtensionType.supports_npn:
                self.supports_npn = True
            elif extType == ExtensionType.server_name:
                self._parseServerName(p.getFixBytes(extLength))
            else:
                p.getFixBytes(extLength)  # skip unrecognized extension
            # Each handler must consume exactly the declared extension length.
            if p.index - start != extLength:
                raise SyntaxError("Bad length for extension_data")
            consumed += 4 + extLength

    def _parseServerName(self, data):
        """Store the first host_name entry of a server_name extension body.

        If the list holds no host_name entry, server_name is left unchanged.
        """
        nameParser = Parser(data)
        nameParser.startLengthCheck(2)
        while not nameParser.atLengthCheck():
            nameType = nameParser.get(1)
            hostName = nameParser.getVarBytes(2)
            if nameType == NameType.host_name:
                self.server_name = hostName
                break

    def write(self):
        """Serialize the message, including the handshake header."""
        w = Writer()
        w.add(self.client_version[0], 1)
        w.add(self.client_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.addVarSeq(self.cipher_suites, 2, 2)
        w.addVarSeq(self.compression_methods, 1, 1)
        extensions = self._writeExtensions()
        # The extensions block is omitted entirely when there are none.
        if len(extensions):
            w.add(len(extensions), 2)
            w.bytes += extensions
        return self.postWrite(w)

    def _writeExtensions(self):
        """Return the encoded extensions, without the 2-byte list length."""
        ext = Writer()
        if self.certificate_types and \
                self.certificate_types != [CertificateType.x509]:
            ext.add(ExtensionType.cert_type, 2)
            ext.add(len(self.certificate_types) + 1, 2)
            ext.addVarSeq(self.certificate_types, 1, 1)
        if self.srp_username:
            ext.add(ExtensionType.srp, 2)
            ext.add(len(self.srp_username) + 1, 2)
            ext.addVarSeq(self.srp_username, 1, 1)
        if self.supports_npn:
            ext.add(ExtensionType.supports_npn, 2)
            ext.add(0, 2)
        if self.server_name:
            nameLength = len(self.server_name)
            ext.add(ExtensionType.server_name, 2)
            # list length (2) + name type (1) + name length (2) + name
            ext.add(nameLength + 5, 2)
            # name type (1) + name length (2) + name
            ext.add(nameLength + 3, 2)
            ext.add(NameType.host_name, 1)
            ext.addVarSeq(self.server_name, 1, 2)
        if self.tack:
            ext.add(ExtensionType.tack, 2)
            ext.add(0, 2)
        return ext.bytes

class BadNextProtos(Exception):
    """Raised when a next-protocol list element has an unusable length."""

    def __init__(self, l):
        # The offending length, reported by __str__.
        self.length = l

    def __str__(self):
        template = ('Cannot encode a list of next protocols because it '
                    'contains an element with invalid length %d. Element '
                    'lengths must be 0 < x < 256')
        return template % self.length

class ServerHello(HandshakeMsg):
    """The ServerHello handshake message, including its extensions."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.server_hello)
        self.server_version = (0, 0)
        self.random = bytearray(32)
        self.session_id = bytearray(0)
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0
        self.tackExt = None
        # Protocol names this side advertises; encoded by write().
        self.next_protos_advertised = None
        # Protocol names received; filled in by parse().
        self.next_protos = None

    def create(self, version, random, session_id, cipher_suite,
               certificate_type, tackExt, next_protos_advertised):
        """Populate the message for sending and return self.

        The compression method is always set to 0.
        """
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        self.tackExt = tackExt
        self.next_protos_advertised = next_protos_advertised
        return self

    def parse(self, p):
        """Parse the message body from p and return self.

        Handled extensions: cert_type (must be exactly 1 byte, otherwise
        SyntaxError), tack (only when tackpyLoaded is true) and
        supports_npn. Any other extension is skipped.
        """
        p.startLengthCheck(3)
        self.server_version = (p.get(1), p.get(1))
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            extensionsLength = p.get(2)
            consumed = 0
            while consumed != extensionsLength:
                extType = p.get(2)
                extLength = p.get(2)
                if extType == ExtensionType.cert_type:
                    if extLength != 1:
                        raise SyntaxError()
                    self.certificate_type = p.get(1)
                elif extType == ExtensionType.tack and tackpyLoaded:
                    self.tackExt = TackExtension(p.getFixBytes(extLength))
                elif extType == ExtensionType.supports_npn:
                    self.next_protos = self.__parse_next_protos(
                        p.getFixBytes(extLength))
                else:
                    p.getFixBytes(extLength)
                consumed += extLength + 4
        p.stopLengthCheck()
        return self

    def __parse_next_protos(self, b):
        """Split a sequence of 1-byte-length-prefixed names into a list.

        Raises BadNextProtos (with the number of bytes left) when a name's
        declared length exceeds the remaining data.
        """
        protos = []
        offset = 0
        while offset < len(b):
            nameLength = b[offset]
            offset += 1
            remaining = len(b) - offset
            if remaining < nameLength:
                raise BadNextProtos(remaining)
            protos.append(b[offset:offset + nameLength])
            offset += nameLength
        return protos

    def __next_protos_encoded(self):
        """Encode next_protos_advertised as 1-byte-length-prefixed names.

        Raises BadNextProtos for a name that is empty or longer than 255.
        """
        encoded = bytearray()
        for proto in self.next_protos_advertised:
            if not 0 < len(proto) <= 255:
                raise BadNextProtos(len(proto))
            encoded += bytearray([len(proto)]) + bytearray(proto)
        return encoded

    def __extensions_encoded(self):
        """Return the encoded extensions, without the 2-byte list length."""
        ext = Writer()
        if self.certificate_type and \
                self.certificate_type != CertificateType.x509:
            ext.add(ExtensionType.cert_type, 2)
            ext.add(1, 2)
            ext.add(self.certificate_type, 1)
        if self.tackExt:
            tackBytes = self.tackExt.serialize()
            ext.add(ExtensionType.tack, 2)
            ext.add(len(tackBytes), 2)
            ext.bytes += tackBytes
        if self.next_protos_advertised is not None:
            npnBytes = self.__next_protos_encoded()
            ext.add(ExtensionType.supports_npn, 2)
            ext.add(len(npnBytes), 2)
            ext.addFixSeq(npnBytes, 1)
        return ext.bytes

    def write(self):
        """Serialize the message, including the handshake header."""
        w = Writer()
        w.add(self.server_version[0], 1)
        w.add(self.server_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.add(self.cipher_suite, 2)
        w.add(self.compression_method, 1)
        extensions = self.__extensions_encoded()
        # The extensions block is omitted entirely when there are none.
        if len(extensions):
            w.add(len(extensions), 2)
            w.bytes += extensions
        return self.postWrite(w)

class Certificate(HandshakeMsg):
    """The Certificate handshake message carrying a certificate chain.

    Only CertificateType.x509 is supported; any other certificateType makes
    parse() and write() raise AssertionError.
    """

    def __init__(self, certificateType):
        HandshakeMsg.__init__(self, HandshakeType.certificate)
        self.certificateType = certificateType
        self.certChain = None

    def create(self, certChain):
        """Set the chain to send and return self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Parse the message body from p and return self.

        certChain is only set when the received list is non-empty.
        Raises SyntaxError if a certificate overruns the declared
        certificate_list length.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index < chainLength:
                certBytes = p.getVarBytes(3)
                x509 = X509()
                x509.parseBinary(certBytes)
                certificate_list.append(x509)
                # Each entry is a 3-byte length prefix plus the DER bytes.
                index += len(certBytes) + 3
            if index != chainLength:
                # Stop instead of parsing data past the declared list as
                # further certificates.
                raise SyntaxError("Bad length for certificate_list")
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the chain (possibly empty), with the handshake header."""
        w = Writer()
        if self.certificateType == CertificateType.x509:
            if self.certChain:
                certificate_list = self.certChain.x509List
            else:
                certificate_list = []
            # Serialize each certificate once; the encodings are needed for
            # both the total length prefix and the body.
            encodedCerts = [cert.writeBytes() for cert in certificate_list]
            chainLength = sum(len(certBytes) + 3 for certBytes in encodedCerts)
            w.add(chainLength, 3)
            for certBytes in encodedCerts:
                w.addVarSeq(certBytes, 1, 3)
        else:
            raise AssertionError()
        return self.postWrite(w)

class CertificateRequest(HandshakeMsg):
    """The CertificateRequest message: accepted certificate types and CA
    distinguished names."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.certificate_request)
        # Apple's Secure Transport library rejects an empty
        # certificate_types list, so default to rsa_sign.
        self.certificate_types = [ClientCertificateType.rsa_sign]
        self.certificate_authorities = []

    def create(self, certificate_types, certificate_authorities):
        """Set the certificate types and CA names and return self."""
        self.certificate_types = certificate_types
        self.certificate_authorities = certificate_authorities
        return self

    def parse(self, p):
        """Parse the message body from p and return self."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        listLength = p.get(2)
        authorities = []
        self.certificate_authorities = authorities
        consumed = 0
        while consumed != listLength:
            dn = p.getVarBytes(2)
            authorities.append(dn)
            # Each name is a 2-byte length prefix plus its bytes.
            consumed += len(dn) + 2
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the message, including the handshake header."""
        writer = Writer()
        writer.addVarSeq(self.certificate_types, 1, 1)
        # Total length of the CA list: each name plus its 2-byte prefix.
        writer.add(sum(len(dn) + 2 for dn in self.certificate_authorities), 2)
        for dn in self.certificate_authorities:
            writer.addVarSeq(dn, 1, 2)
        return self.postWrite(writer)

class ServerKeyExchange(HandshakeMsg):
    """The ServerKeyExchange message for SRP and anonymous DH suites.

    The layout depends on cipherSuite:
      - SRP suites carry (N, g, s, B), plus a signature for SRP
        certificate suites.
      - Anonymous suites carry the DH parameters (p, g, Ys) and no
        signature.
    """

    def __init__(self, cipherSuite):
        HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
        self.cipherSuite = cipherSuite
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = bytearray(0)
        self.srp_B = 0
        # Anon DH params:
        self.dh_p = 0
        self.dh_g = 0
        self.dh_Ys = 0
        self.signature = bytearray(0)

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Set the SRP parameters and return self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def createDH(self, dh_p, dh_g, dh_Ys):
        """Set the anonymous DH parameters and return self."""
        self.dh_p = dh_p
        self.dh_g = dh_g
        self.dh_Ys = dh_Ys
        return self

    def parse(self, p):
        """Parse the message body from p according to cipherSuite."""
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_N = bytesToNumber(p.getVarBytes(2))
            self.srp_g = bytesToNumber(p.getVarBytes(2))
            self.srp_s = p.getVarBytes(1)
            self.srp_B = bytesToNumber(p.getVarBytes(2))
            if self.cipherSuite in CipherSuite.srpCertSuites:
                self.signature = p.getVarBytes(2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_p = bytesToNumber(p.getVarBytes(2))
            self.dh_g = bytesToNumber(p.getVarBytes(2))
            self.dh_Ys = bytesToNumber(p.getVarBytes(2))
        p.stopLengthCheck()
        return self

    def _addParams(self, w):
        """Append the key exchange parameters, without any signature, to w.

        SRP suites write (N, g, s, B); anonymous suites write (p, g, Ys);
        any other suite writes nothing.
        """
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
            w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
            w.addVarSeq(self.srp_s, 1, 1)
            w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
        elif self.cipherSuite in CipherSuite.anonSuites:
            w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
            w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
            w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)

    def write(self):
        """Serialize the parameters, plus the signature for SRP certificate
        suites, with the handshake header."""
        w = Writer()
        self._addParams(w)
        if self.cipherSuite in CipherSuite.srpAllSuites and \
                self.cipherSuite in CipherSuite.srpCertSuites:
            w.addVarSeq(self.signature, 1, 2)
        # TODO: signed DH params are not supported; anon suites are unsigned.
        return self.postWrite(w)

    def hash(self, clientRandom, serverRandom):
        """Return MD5(data) + SHA1(data) for the ServerKeyExchange signature.

        data = clientRandom + serverRandom + the encoded parameters. The
        handshake header and the signature field are not included (see
        RFC 5054 section 2.8 and RFC 4346 section 7.4.3).
        """
        # Serialize the parameters directly. Calling write() with
        # cipherSuite cleared would emit no parameters at all, leaving the
        # signed parameters out of the digest.
        w = Writer()
        self._addParams(w)
        data = clientRandom + serverRandom + w.bytes
        return MD5(data) + SHA1(data)

class ServerHelloDone(HandshakeMsg):
    """The ServerHelloDone handshake message, which has an empty body."""

    def __init__(self):
        # Without this call the method body was empty (a syntax error) and
        # contentType/handshakeType were never set.
        HandshakeMsg.__init__(self, HandshakeType.server_hello_done)

    def create(self):
        """Nothing to fill in; return self."""
        return self

    def parse(self, p):
        """Check that the body is empty and return self."""
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize as a handshake header with a zero-length body."""
        w = Writer()
        return self.postWrite(w)

class ClientKeyExchange(HandshakeMsg):
    """The ClientKeyExchange message for SRP, RSA and anonymous DH suites.

    cipherSuite selects the body format. version is only consulted for
    certificate (RSA) suites:
      - (3, 1) and (3, 2) prefix the encrypted premaster secret with a
        2-byte length.
      - (3, 0) sends it bare.
      - Any other version raises AssertionError.
    """

    def __init__(self, cipherSuite, version=None):
        HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
        self.cipherSuite = cipherSuite
        self.version = version
        self.srp_A = 0
        # Initialized like srp_A so that write() for an anonymous suite does
        # not raise AttributeError when createDH() has not been called.
        self.dh_Yc = 0
        self.encryptedPreMasterSecret = bytearray(0)

    def createSRP(self, srp_A):
        """Set the client's SRP public value A and return self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Set the RSA-encrypted premaster secret and return self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def createDH(self, dh_Yc):
        """Set the client's DH public value Yc and return self."""
        self.dh_Yc = dh_Yc
        return self

    def parse(self, p):
        """Parse the message body from p according to cipherSuite/version.

        Raises AssertionError for an unsupported suite or version.
        """
        p.startLengthCheck(3)
        if self.cipherSuite in CipherSuite.srpAllSuites:
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3, 1), (3, 2)):
                self.encryptedPreMasterSecret = p.getVarBytes(2)
            elif self.version == (3, 0):
                # No length prefix: the rest of the message is the secret.
                self.encryptedPreMasterSecret = \
                    p.getFixBytes(len(p.bytes) - p.index)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.anonSuites:
            self.dh_Yc = bytesToNumber(p.getVarBytes(2))
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the message according to cipherSuite/version.

        Raises AssertionError for an unsupported suite or version.
        """
        w = Writer()
        if self.cipherSuite in CipherSuite.srpAllSuites:
            w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
        elif self.cipherSuite in CipherSuite.certSuites:
            if self.version in ((3, 1), (3, 2)):
                w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
            elif self.version == (3, 0):
                w.addFixSeq(self.encryptedPreMasterSecret, 1)
            else:
                raise AssertionError()
        elif self.cipherSuite in CipherSuite.anonSuites:
            w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
        else:
            raise AssertionError()
        return self.postWrite(w)

class CertificateVerify(HandshakeMsg):
    """The CertificateVerify message: one 2-byte-length-prefixed signature."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
        self.signature = bytearray(0)

    def create(self, signature):
        """Set the signature and return self."""
        self.signature = signature
        return self

    def parse(self, p):
        """Read the signature from p; it must fill the whole body."""
        p.startLengthCheck(3)
        self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize the signature with the handshake header."""
        writer = Writer()
        writer.addVarSeq(self.signature, 1, 2)
        return self.postWrite(writer)

class ChangeCipherSpec(object):
    """The ChangeCipherSpec record: a single byte, normally 1."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Reset the type byte to 1 and return self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read exactly one byte from p as the type."""
        p.setLengthCheck(1)
        self.type = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize as the single type byte."""
        writer = Writer()
        writer.add(self.type, 1)
        return writer.bytes

class NextProtocol(HandshakeMsg):
    """The NextProtocol handshake message (NPN): the chosen protocol name
    followed by zero padding."""

    def __init__(self):
        HandshakeMsg.__init__(self, HandshakeType.next_protocol)
        self.next_proto = None

    def create(self, next_proto):
        """Set the selected protocol name and return self."""
        self.next_proto = next_proto
        return self

    def parse(self, p):
        """Read the protocol name and discard the padding that follows."""
        p.startLengthCheck(3)
        self.next_proto = p.getVarBytes(1)
        p.getVarBytes(1)  # padding, ignored
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the message; `trial` is accepted but not used."""
        w = Writer()
        w.addVarSeq(self.next_proto, 1, 1)
        # Choose 1..32 padding bytes so that name + padding + their two
        # length bytes is a multiple of 32.
        padding = bytearray(32 - ((len(self.next_proto) + 2) % 32))
        w.addVarSeq(padding, 1, 1)
        return self.postWrite(w)

class Finished(HandshakeMsg):
    """The Finished message carrying verify_data.

    verify_data is 36 bytes for version (3, 0) and 12 bytes for (3, 1) and
    (3, 2); parse() raises AssertionError for any other version.
    """

    def __init__(self, version):
        HandshakeMsg.__init__(self, HandshakeType.finished)
        self.version = version
        self.verify_data = bytearray(0)

    def create(self, verify_data):
        """Set verify_data and return self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read verify_data of the version-dependent length from p."""
        p.startLengthCheck(3)
        if self.version == (3, 0):
            dataLength = 36
        elif self.version in ((3, 1), (3, 2)):
            dataLength = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(dataLength)
        p.stopLengthCheck()
        return self

    def write(self):
        """Serialize verify_data with the handshake header."""
        writer = Writer()
        writer.addFixSeq(self.verify_data, 1)
        return self.postWrite(writer)

class ApplicationData(object):
    """An application_data record holding raw payload bytes."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = bytearray(0)

    def create(self, bytes):
        """Set the payload and return self."""
        self.bytes = bytes
        return self

    def splitFirstByte(self):
        """Move the first payload byte into a new ApplicationData and
        return it; this record keeps the remainder."""
        head = ApplicationData().create(self.bytes[:1])
        self.bytes = self.bytes[1:]
        return head

    def parse(self, p):
        """Take the parser's whole buffer as the payload."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the payload unchanged."""
        return self.bytes