Coding standards and documentation improvements in tls_nb.py

This commit is contained in:
Stephan Erb 2008-12-24 11:10:58 +00:00
parent 5c02a907b4
commit 1e00674505
1 changed files with 97 additions and 77 deletions

View File

@ -17,11 +17,9 @@
import socket import socket
from client import PlugIn from client import PlugIn
from protocol import *
import sys import sys
import os import os
import errno
import time import time
import traceback import traceback
@ -48,25 +46,27 @@ except ImportError:
print >> sys.stderr, "PyOpenSSL not found, falling back to Python builtin SSL objects (insecure)." print >> sys.stderr, "PyOpenSSL not found, falling back to Python builtin SSL objects (insecure)."
print >> sys.stderr, "=" * 79 print >> sys.stderr, "=" * 79
def torf(cond, tv, fv):
if cond: return tv
return fv
def gattr(obj, attr, default=None): def gattr(obj, attr, default=None):
try: try:
return getattr(obj, attr) return getattr(obj, attr)
except: except AttributeError:
return default return default
class SSLWrapper: class SSLWrapper:
'''
Abstract SSLWrapper base class
'''
class Error(IOError): class Error(IOError):
def __init__(self, sock=None, exc=None, errno=None, strerror=None, peer=None): ''' Generic SSL Error Wrapper '''
def __init__(self, sock=None, exc=None, errno=None, strerror=None,
peer=None):
self.parent = IOError self.parent = IOError
errno = errno or gattr(exc, 'errno') errno = errno or gattr(exc, 'errno')
strerror = strerror or gattr(exc, 'strerror') or gattr(exc, 'args') strerror = strerror or gattr(exc, 'strerror') or gattr(exc, 'args')
if not isinstance(strerror, basestring): strerror = repr(strerror) if not isinstance(strerror, basestring):
strerror = repr(strerror)
self.sock = sock self.sock = sock
self.exc = exc self.exc = exc
@ -97,17 +97,22 @@ class SSLWrapper:
if len(ppeer) == 2 and isinstance(ppeer[0], basestring) \ if len(ppeer) == 2 and isinstance(ppeer[0], basestring) \
and isinstance(ppeer[1], int): and isinstance(ppeer[1], int):
self.peer = ppeer self.peer = ppeer
except: pass except:
pass
def __str__(self): def __str__(self):
s = str(self.__class__) s = str(self.__class__)
if self.peer: s += " for %s:%d" % self.peer if self.peer:
if self.errno is not None: s += ": [Errno: %d]" % self.errno s += " for %s:%d" % self.peer
if self.strerror: s += " (%s)" % self.strerror if self.errno is not None:
s += ": [Errno: %d]" % self.errno
if self.strerror:
s += " (%s)" % self.strerror
if self.exc_name: if self.exc_name:
s += ", Caused by %s" % self.exc_name s += ", Caused by %s" % self.exc_name
if self.exc_str: if self.exc_str:
if self.strerror: s += "(%s)" % self.exc_str if self.strerror:
s += "(%s)" % self.exc_str
else: s += "(%s)" % str(self.exc_args) else: s += "(%s)" % str(self.exc_args)
return s return s
@ -117,18 +122,21 @@ class SSLWrapper:
log.debug("%s.__init__ called with %s", self.__class__, sslobj) log.debug("%s.__init__ called with %s", self.__class__, sslobj)
def recv(self, data, flags=None): def recv(self, data, flags=None):
''' Receive wrapper for SSL object '''
Receive wrapper for SSL object
We can return None out of this function to signal that no data is We can return None out of this function to signal that no data is
available right now. Better than an exception, which differs available right now. Better than an exception, which differs
depending on which SSL lib we're using. Unfortunately returning '' depending on which SSL lib we're using. Unfortunately returning ''
can indicate that the socket has been closed, so to be sure, we avoid can indicate that the socket has been closed, so to be sure, we avoid
this by returning None. ''' this by returning None.
'''
raise NotImplementedError
raise NotImplementedException() def send(self, data, flags=None, now=False):
''' Send wrapper for SSL object '''
raise NotImplementedError
def send(self, data, flags=None, now = False):
raise NotImplementedException()
class PyOpenSSLWrapper(SSLWrapper): class PyOpenSSLWrapper(SSLWrapper):
'''Wrapper class for PyOpenSSL's recv() and send() methods''' '''Wrapper class for PyOpenSSL's recv() and send() methods'''
@ -138,21 +146,24 @@ class PyOpenSSLWrapper(SSLWrapper):
self.parent.__init__(self, *args) self.parent.__init__(self, *args)
def is_numtoolarge(self, e): def is_numtoolarge(self, e):
''' Magic methods don't need documentation '''
t = ('asn1 encoding routines', 'a2d_ASN1_OBJECT', 'first num too large') t = ('asn1 encoding routines', 'a2d_ASN1_OBJECT', 'first num too large')
return isinstance(e.args, (list, tuple)) and len(e.args) == 1 and \ return (isinstance(e.args, (list, tuple)) and len(e.args) == 1 and
isinstance(e.args[0], (list, tuple)) and len(e.args[0]) == 2 and \ isinstance(e.args[0], (list, tuple)) and len(e.args[0]) == 2 and
e.args[0][0] == e.args[0][1] == t e.args[0][0] == e.args[0][1] == t)
def recv(self, bufsize, flags=None): def recv(self, bufsize, flags=None):
retval = None retval = None
try: try:
if flags is None: retval = self.sslobj.recv(bufsize) if flags is None:
else: retval = self.sslobj.recv(bufsize, flags) retval = self.sslobj.recv(bufsize)
else:
retval = self.sslobj.recv(bufsize, flags)
except (OpenSSL.SSL.WantReadError, OpenSSL.SSL.WantWriteError), e: except (OpenSSL.SSL.WantReadError, OpenSSL.SSL.WantWriteError), e:
log.debug("Recv: Want-error: " + repr(e)) log.debug("Recv: Want-error: " + repr(e))
except OpenSSL.SSL.SysCallError, e: except OpenSSL.SSL.SysCallError, e:
log.debug("Recv: Got OpenSSL.SSL.SysCallError: " + repr(e), exc_info=True) log.debug("Recv: Got OpenSSL.SSL.SysCallError: " + repr(e),
#traceback.print_exc() exc_info=True)
raise SSLWrapper.Error(self.sock or self.sslobj, e) raise SSLWrapper.Error(self.sock or self.sslobj, e)
except OpenSSL.SSL.Error, e: except OpenSSL.SSL.Error, e:
if self.is_numtoolarge(e): if self.is_numtoolarge(e):
@ -160,22 +171,21 @@ class PyOpenSSLWrapper(SSLWrapper):
log.warning("Recv: OpenSSL: asn1enc: first num too large (ignored)") log.warning("Recv: OpenSSL: asn1enc: first num too large (ignored)")
else: else:
log.debug("Recv: Caught OpenSSL.SSL.Error:", exc_info=True) log.debug("Recv: Caught OpenSSL.SSL.Error:", exc_info=True)
#traceback.print_exc()
#print "Current Stack:"
#traceback.print_stack()
raise SSLWrapper.Error(self.sock or self.sslobj, e) raise SSLWrapper.Error(self.sock or self.sslobj, e)
return retval return retval
def send(self, data, flags=None, now = False): def send(self, data, flags=None, now=False):
try: try:
if flags is None: return self.sslobj.send(data) if flags is None:
else: return self.sslobj.send(data, flags) return self.sslobj.send(data)
else:
return self.sslobj.send(data, flags)
except (OpenSSL.SSL.WantReadError, OpenSSL.SSL.WantWriteError), e: except (OpenSSL.SSL.WantReadError, OpenSSL.SSL.WantWriteError), e:
#log.debug("Send: " + repr(e)) #log.debug("Send: " + repr(e))
time.sleep(0.1) # prevent 100% CPU usage time.sleep(0.1) # prevent 100% CPU usage
except OpenSSL.SSL.SysCallError, e: except OpenSSL.SSL.SysCallError, e:
log.error("Send: Got OpenSSL.SSL.SysCallError: " + repr(e), exc_info=True) log.error("Send: Got OpenSSL.SSL.SysCallError: " + repr(e),
#traceback.print_exc() exc_info=True)
raise SSLWrapper.Error(self.sock or self.sslobj, e) raise SSLWrapper.Error(self.sock or self.sslobj, e)
except OpenSSL.SSL.Error, e: except OpenSSL.SSL.Error, e:
if self.is_numtoolarge(e): if self.is_numtoolarge(e):
@ -183,12 +193,10 @@ class PyOpenSSLWrapper(SSLWrapper):
log.warning("Send: OpenSSL: asn1enc: first num too large (ignored)") log.warning("Send: OpenSSL: asn1enc: first num too large (ignored)")
else: else:
log.error("Send: Caught OpenSSL.SSL.Error:", exc_info=True) log.error("Send: Caught OpenSSL.SSL.Error:", exc_info=True)
#traceback.print_exc()
#print "Current Stack:"
#traceback.print_stack()
raise SSLWrapper.Error(self.sock or self.sslobj, e) raise SSLWrapper.Error(self.sock or self.sslobj, e)
return 0 return 0
class StdlibSSLWrapper(SSLWrapper): class StdlibSSLWrapper(SSLWrapper):
'''Wrapper class for Python socket.ssl read() and write() methods''' '''Wrapper class for Python socket.ssl read() and write() methods'''
@ -218,7 +226,12 @@ class StdlibSSLWrapper(SSLWrapper):
class NonBlockingTLS(PlugIn): class NonBlockingTLS(PlugIn):
''' TLS connection used to encrypts already estabilished tcp connection.''' '''
TLS connection used to encrypts already estabilished tcp connection.
Can be plugged into NonBlockingTCP and will make use of StdlibSSLWrapper or
PyOpenSSLWrapper.
'''
def __init__(self, cacerts, mycerts): def __init__(self, cacerts, mycerts):
''' '''
@ -238,7 +251,8 @@ class NonBlockingTLS(PlugIn):
def PlugIn(self, owner): def PlugIn(self, owner):
''' '''
start using encryption immediately Use to PlugIn TLS into transport and start establishing immediately
Returns True if TLS/SSL was established correctly, otherwise False.
''' '''
log.info('Starting TLS estabilishing') log.info('Starting TLS estabilishing')
PlugIn.PlugIn(self, owner) PlugIn.PlugIn(self, owner)
@ -249,14 +263,12 @@ class NonBlockingTLS(PlugIn):
return False return False
return res return res
def _dumpX509(self, cert, stream=sys.stderr): def _dumpX509(self, cert, stream=sys.stderr):
print >> stream, "Digest (SHA-1):", cert.digest("sha1") print >> stream, "Digest (SHA-1):", cert.digest("sha1")
print >> stream, "Digest (MD5):", cert.digest("md5") print >> stream, "Digest (MD5):", cert.digest("md5")
print >> stream, "Serial #:", cert.get_serial_number() print >> stream, "Serial #:", cert.get_serial_number()
print >> stream, "Version:", cert.get_version() print >> stream, "Version:", cert.get_version()
print >> stream, "Expired:", torf(cert.has_expired(), "Yes", "No") print >> stream, "Expired:", ("Yes" if cert.has_expired() else "No")
print >> stream, "Subject:" print >> stream, "Subject:"
self._dumpX509Name(cert.get_subject(), stream) self._dumpX509Name(cert.get_subject(), stream)
print >> stream, "Issuer:" print >> stream, "Issuer:"
@ -267,23 +279,51 @@ class NonBlockingTLS(PlugIn):
print >> stream, "X509Name:", str(name) print >> stream, "X509Name:", str(name)
def _dumpPKey(self, pkey, stream=sys.stderr): def _dumpPKey(self, pkey, stream=sys.stderr):
typedict = {OpenSSL.crypto.TYPE_RSA: "RSA", OpenSSL.crypto.TYPE_DSA: "DSA"} typedict = {OpenSSL.crypto.TYPE_RSA: "RSA",
OpenSSL.crypto.TYPE_DSA: "DSA"}
print >> stream, "PKey bits:", pkey.bits() print >> stream, "PKey bits:", pkey.bits()
print >> stream, "PKey type: %s (%d)" % (typedict.get(pkey.type(), "Unknown"), pkey.type()) print >> stream, "PKey type: %s (%d)" % (typedict.get(pkey.type(),
"Unknown"), pkey.type())
def _startSSL(self): def _startSSL(self):
''' Immediatedly switch socket to TLS mode. Used internally.''' ''' Immediatedly switch socket to TLS mode. Used internally.'''
log.debug("_startSSL called") log.debug("_startSSL called")
if USE_PYOPENSSL: result = self._startSSL_pyOpenSSL() if USE_PYOPENSSL:
else: result = self._startSSL_stdlib() result = self._startSSL_pyOpenSSL()
else:
result = self._startSSL_stdlib()
if result: if result:
log.debug("Synchronous handshake completed") log.debug('Synchronous handshake completed')
return True return True
else: else:
return False return False
def _load_user_certs(self, cert_path, cert_store):
if not os.path.isfile(cert_path):
return
f = open(cert_path)
lines = f.readlines()
i = 0
begin = -1
for line in lines:
if 'BEGIN CERTIFICATE' in line:
begin = i
elif 'END CERTIFICATE' in line and begin > -1:
cert = ''.join(lines[begin:i+2])
try:
x509cert = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM, cert)
cert_store.add_cert(x509cert)
except OpenSSL.crypto.Error, exception_obj:
log.warning('Unable to load a certificate from file %s: %s' %\
(self.mycerts, exception_obj.args[0][0][2]))
except:
log.warning('Unknown error while loading certificate from file%s'
% self.mycerts)
begin = -1
i += 1
def _startSSL_pyOpenSSL(self): def _startSSL_pyOpenSSL(self):
log.debug("_startSSL_pyOpenSSL called") log.debug("_startSSL_pyOpenSSL called")
@ -292,39 +332,18 @@ class NonBlockingTLS(PlugIn):
#tcpsock._sslContext = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD) #tcpsock._sslContext = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
tcpsock._sslContext = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) tcpsock._sslContext = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
tcpsock.ssl_errnum = 0 tcpsock.ssl_errnum = 0
tcpsock._sslContext.set_verify(OpenSSL.SSL.VERIFY_PEER, self._ssl_verify_callback) tcpsock._sslContext.set_verify(OpenSSL.SSL.VERIFY_PEER,
self._ssl_verify_callback)
try: try:
tcpsock._sslContext.load_verify_locations(self.cacerts) tcpsock._sslContext.load_verify_locations(self.cacerts)
except: except:
log.warning('Unable to load SSL certificates from file %s' % \ log.warning('Unable to load SSL certificates from file %s' % \
os.path.abspath(self.cacerts)) os.path.abspath(self.cacerts))
# load users certs self._load_user_certs(self.mycerts, tcpsock._sslContext.get_cert_store())
if os.path.isfile(self.mycerts):
store = tcpsock._sslContext.get_cert_store()
f = open(self.mycerts)
lines = f.readlines()
i = 0
begin = -1
for line in lines:
if 'BEGIN CERTIFICATE' in line:
begin = i
elif 'END CERTIFICATE' in line and begin > -1:
cert = ''.join(lines[begin:i+2])
try:
X509cert = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM, cert)
store.add_cert(X509cert)
except OpenSSL.crypto.Error, exception_obj:
log.warning('Unable to load a certificate from file %s: %s' %\
(self.mycerts, exception_obj.args[0][0][2]))
except:
log.warning('Unknown error while loading certificate from file %s' %\
self.mycerts)
begin = -1
i += 1
tcpsock._sslObj = OpenSSL.SSL.Connection(tcpsock._sslContext, tcpsock._sock)
tcpsock._sslObj.set_connect_state() # set to client mode
tcpsock._sslObj = OpenSSL.SSL.Connection(tcpsock._sslContext,
tcpsock._sock)
tcpsock._sslObj.set_connect_state() # set to client mode
wrapper = PyOpenSSLWrapper(tcpsock._sslObj) wrapper = PyOpenSSLWrapper(tcpsock._sslObj)
tcpsock._recv = wrapper.recv tcpsock._recv = wrapper.recv
tcpsock._send = wrapper.send tcpsock._send = wrapper.send
@ -340,7 +359,6 @@ class NonBlockingTLS(PlugIn):
self._owner.ssl_lib = PYOPENSSL self._owner.ssl_lib = PYOPENSSL
return True return True
def _startSSL_stdlib(self): def _startSSL_stdlib(self):
log.debug("_startSSL_stdlib called") log.debug("_startSSL_stdlib called")
tcpsock=self._owner tcpsock=self._owner
@ -371,5 +389,7 @@ class NonBlockingTLS(PlugIn):
return True return True
except: except:
log.error("Exception caught in _ssl_info_callback:", exc_info=True) log.error("Exception caught in _ssl_info_callback:", exc_info=True)
traceback.print_exc() # Make sure something is printed, even if log is disabled. # Make sure something is printed, even if log is disabled.
traceback.print_exc()
# vim: se ts=3: