606 lines
19 KiB
Python
606 lines
19 KiB
Python
## xmlstream.py
|
|
##
|
|
## Copyright (C) 2001 Matthew Allum
|
|
##
|
|
## This program is free software; you can redistribute it and/or modify
|
|
## it under the terms of the GNU Lesser General Public License as published
|
|
## by the Free Software Foundation; either version 2, or (at your option)
|
|
## any later version.
|
|
##
|
|
## This program is distributed in the hope that it will be useful,
|
|
## but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
## GNU Lesser General Public License for more details.
|
|
|
|
|
|
"""\
|
|
xmlstream.py provides simple functionality for implementing
|
|
XML stream based network protocols. It is used as a base
|
|
for jabber.py.
|
|
|
|
xmlstream.py manages the network connectivity and xml parsing
|
|
of the stream. When a complete 'protocol element' ( meaning a
|
|
complete child of the xmlstream's root ) is parsed, the dispatch
|
|
method is called with a 'Node' instance of this structure.
|
|
The Node class is a very simple XML DOM like class for
|
|
manipulating XML documents or 'protocol elements' in this
|
|
case.
|
|
|
|
"""
|
|
|
|
# $Id: xmlstream.py,v 1.26 2003/02/20 10:22:33 shire Exp $
|
|
|
|
import site
|
|
site.encoding = 'UTF-8'
|
|
import time, sys, re, socket
|
|
from select import select
|
|
from string import split,find,replace,join
|
|
import xml.parsers.expat
|
|
|
|
VERSION = 0.3

# True/False are builtins; they are deliberately not redefined here
# (assigning to them shadows the builtins and is a SyntaxError on Python 3).

# Connection types accepted by Stream / Client (the `connection` argument).
TCP = 1
STDIO = 0
TCP_SSL = 2

# Encoding used for the data returned by Stream.read() and for outgoing
# non-unicode data.  site.encoding is assigned earlier in this module; fall
# back to UTF-8 should that attribute be absent.
ENCODING = getattr(site, 'encoding', 'UTF-8')

BLOCK_SIZE = 1024  ## Number of bytes to read at a time in socket transactions
|
|
|
|
|
|
def XMLescape(txt):
    """Escape the XML special characters &, < and > in txt.

    '&' is replaced first so that the entity references produced for '<'
    and '>' are not themselves escaped a second time.  Quote characters are
    left alone; a caller embedding the result in a quoted attribute value
    must escape the matching quote itself.
    """
    txt = txt.replace("&", "&amp;")
    txt = txt.replace("<", "&lt;")
    txt = txt.replace(">", "&gt;")
    return txt
|
|
|
|
def XMLunescape(txt):
    """Replace the predefined XML entity references in txt with characters.

    Decodes &lt; &gt; &quot; &apos; and &amp;.  '&amp;' is decoded last so
    that an escaped reference such as '&amp;lt;' becomes the literal text
    '&lt;' rather than '<'.
    """
    txt = txt.replace("&lt;", "<")
    txt = txt.replace("&gt;", ">")
    txt = txt.replace("&quot;", '"')
    txt = txt.replace("&apos;", "'")
    txt = txt.replace("&amp;", "&")
    return txt
|
|
|
|
class error(Exception):
    """Exception raised by this module (e.g. on connection failures).

    The constructor argument is converted to a string and kept in
    ``value``; str() of the exception returns that string.
    """

    def __init__(self, value):
        self.value = str(value)
        # Populate Exception.args so tracebacks and repr() show the message.
        Exception.__init__(self, self.value)

    def __str__(self):
        return self.value
|
|
|
|
class Node:
    """A simple XML DOM like class.

    A node has a tag name, an optional namespace, a dict of attributes
    (``attrs``), a list of text chunks (``data``) and a list of child nodes
    (``kids``).  When serialised, ``data[i]`` is written before ``kids[i]``
    and any chunks left over after the last child are written at the end.
    """

    def __init__(self, tag='', parent=None, attrs=None):
        # Expat (created with namespace_separator=' ') reports a namespaced
        # tag as "namespace-uri localname"; anything else is a bare name.
        bits = tag.split()
        if len(bits) == 2:
            self.namespace, self.name = bits
        else:
            # Also covers the default tag '' -- unpacking the empty split
            # used to raise ValueError, so Node() could not be constructed.
            self.name = tag
            self.namespace = ''

        if attrs is None:
            attrs = {}
        self.attrs = attrs

        self.data = []
        self.kids = []
        self.parent = parent

    def setParent(self, node):
        "Set the node's parent node."
        self.parent = node

    def getParent(self):
        "Return the node's parent node."
        return self.parent

    def getName(self):
        "Return the node's tag name."
        return self.name

    def setName(self, val):
        "Set the node's tag name."
        self.name = val

    def putAttr(self, key, val):
        "Add a name/value attribute to the node."
        self.attrs[key] = val

    def getAttr(self, key):
        "Return the value of the named attribute, or None if it is not set."
        return self.attrs.get(key)

    def putData(self, data):
        "Append a chunk of text to the node's data."
        self.data.append(data)

    def insertData(self, data):
        "Append a chunk of text to the node's data (same as putData)."
        self.data.append(data)

    def getData(self):
        "Return all of the node's text chunks joined into one string."
        return ''.join(self.data)

    def getDataAsParts(self):
        "Return the node's text chunks as a list."
        return self.data

    def getNamespace(self):
        "Return the node's namespace ('' if it has none)."
        return self.namespace

    def setNamespace(self, namespace):
        "Set the node's namespace."
        self.namespace = namespace

    def insertTag(self, name):
        """Add a child tag of name 'name' to the node.

        Returns the newly created node.
        """
        newnode = Node(tag=name, parent=self)
        self.kids.append(newnode)
        return newnode

    def insertNode(self, node):
        "Add node as a child of this node and return it."
        self.kids.append(node)
        return node

    def insertXML(self, xml_str):
        "Parse raw xml and add the resulting tree as a child of the node."
        newnode = NodeBuilder(xml_str).getDom()
        self.kids.append(newnode)
        return newnode

    def __str__(self):
        return self._xmlnode2str()

    def _xmlnode2str(self, parent=None):
        """Return an xml ( string ) representation of the node and its
        children.

        An xmlns declaration is written only when the node has a namespace
        that differs from the namespace of the parent it is serialised under.
        """
        parts = ["<" + self.name]
        if self.namespace and parent and parent.namespace != self.namespace:
            parts.append(" xmlns = '%s' " % self.namespace)
        for key in self.attrs.keys():
            val = self.attrs[key]
            # Only stringify non-string values: str() on a unicode value with
            # non-ASCII characters raises UnicodeEncodeError on Python 2.
            if not isinstance(val, (type(''), type(u''))):
                val = str(val)
            # Values are wrapped in single quotes, so a literal ' must be
            # written as &apos; or it would end the attribute early.
            val = XMLescape(val).replace("'", "&apos;")
            parts.append(" %s='%s'" % (key, val))
        parts.append(">")

        kids = self.kids or []
        for cnt, kid in enumerate(kids):
            if cnt < len(self.data):
                parts.append(XMLescape(self.data[cnt]))
            parts.append(kid._xmlnode2str(parent=self))
        # Write every remaining chunk, not just one: the parser can split a
        # single text run into several chunks, which were silently dropped.
        for chunk in self.data[len(kids):]:
            parts.append(XMLescape(chunk))

        parts.append("</" + self.name + ">")
        return ''.join(parts)

    def getTag(self, name):
        """Return the first child node with tag name, or None if not found."""
        for node in self.kids:
            if node.getName() == name:
                return node
        return None

    def getTags(self, name):
        """Like getTag but return a list of all matching child nodes."""
        return [node for node in self.kids if node.getName() == name]

    def getChildren(self):
        """Return the node's children."""
        return self.kids
|
|
|
|
class NodeBuilder:
    """Builds a 'minidom' of Node objects from a complete XML document.

    Primarily for the insertXML method of Node.  The whole string is parsed
    in the constructor; the root node is then available from getDom().
    Whitespace-only character data is discarded.
    """

    def __init__(self, data):
        self._parser = xml.parsers.expat.ParserCreate(namespace_separator=' ')
        self._parser.StartElementHandler = self.unknown_starttag
        self._parser.EndElementHandler = self.unknown_endtag
        self._parser.CharacterDataHandler = self.handle_data

        self.__depth = 0
        self.__done = 0  # set to 1 by dispatch() once the root element closes
        # Raw string: '\s' in a plain literal is an invalid escape sequence.
        self.__space_regex = re.compile(r'^\s+$')

        # isfinal=1: data is expected to be a complete document.
        self._parser.Parse(data, 1)

    def unknown_starttag(self, tag, attrs):
        """Expat callback: create a Node and descend into it."""
        self.__depth = self.__depth + 1
        if self.__depth == 1:
            # Root element of the parsed fragment.
            self._mini_dom = Node(tag=tag, attrs=attrs)
            self._ptr = self._mini_dom
        else:
            self._ptr.kids.append(Node(tag=tag,
                                       parent=self._ptr,
                                       attrs=attrs))
            self._ptr = self._ptr.kids[-1]

    def unknown_endtag(self, tag):
        """Expat callback: climb back to the parent, or finish at the root."""
        self.__depth = self.__depth - 1
        if self.__depth == 0:
            self.dispatch(self._mini_dom)
        elif self.__depth > 0:
            self._ptr = self._ptr.parent

    def handle_data(self, data):
        """Expat callback: store character data unless it is all whitespace."""
        if not self.__space_regex.match(data):
            self._ptr.data.append(data)

    def dispatch(self, dom):
        """Called with the root node once it has been fully parsed."""
        self.__done = 1

    def getDom(self):
        """Return the root Node built from the parsed data."""
        return self._mini_dom
|
|
|
|
|
|
class Stream:
    """Base class for an XML stream over TCP, TCP+SSL or stdin/stdout.

    Incoming bytes are fed to an expat parser.  Each complete child of the
    stream's root element (depth 2 in the parse) is built into a Node tree
    and passed to dispatch(), which subclasses are expected to override.
    """

    def __init__(
            self, host, port, namespace,
            debug=True,
            log=None,
            sock=None,
            id=None,
            connection=TCP
            ):

        self._parser = xml.parsers.expat.ParserCreate(namespace_separator=' ')
        self._parser.StartElementHandler = self._unknown_starttag
        self._parser.EndElementHandler = self._unknown_endtag
        self._parser.CharacterDataHandler = self._handle_data

        self._host = host
        self._port = port
        self._namespace = namespace
        self.__depth = 0  # current element nesting depth within the stream
        self._sock = sock

        self._sslObj = None
        self._sslIssuer = None
        self._sslServer = None

        self._incomingID = None  # id attribute of the peer's stream header
        self._outgoingID = id    # id we put in our own stream header

        self._debug = debug
        self._connection = connection

        self.DEBUG("stream init called")

        if log:
            if type(log) is type(""):
                try:
                    self._logFH = open(log, 'w')
                except (IOError, OSError):
                    # The message previously lacked the file name and the
                    # process exited with a success status.
                    print("ERROR: can't open %s for writing" % log)
                    sys.exit(1)
            else:  ## assume it's a stream type object
                self._logFH = log
        else:
            self._logFH = None
        self._timestampLog = True

    def timestampLog(self, timestamp):
        """ Enable or disable the showing of a timestamp in the log.
        By default, timestamping is enabled.
        """
        self._timestampLog = timestamp

    def DEBUG(self, txt):
        """Write txt to stderr, prefixed with 'DEBUG: ', if debugging is on."""
        if not self._debug:
            return
        try:
            sys.stderr.write("DEBUG: %s\n" % txt)
        except UnicodeError:
            # unicode strikes again ;) -- retry with every non-ASCII
            # character replaced by '?'.
            chars = []
            for ch in txt:
                if ord(ch) < 128:
                    chars.append(ch)
                else:
                    chars.append('?')
            sys.stderr.write("DEBUG: %s\n" % u''.join(chars))

    def getSocket(self):
        """Return the underlying socket (None for STDIO streams)."""
        return self._sock

    def header(self):
        """Send our opening <stream:stream> tag, then read the peer's reply."""
        self.DEBUG("stream: sending initial header")
        hdr = (u"<?xml version='1.0' encoding='UTF-8' ?> "
               u"<stream:stream to='%s' xmlns='%s'" % (self._host,
                                                       self._namespace))
        if self._outgoingID:
            hdr = hdr + " id='%s' " % self._outgoingID
        hdr = hdr + " xmlns:stream='http://etherx.jabber.org/streams'>"
        self.write(hdr)
        self.read()

    def _handle_data(self, data):
        """XML Parser callback"""
        self.DEBUG("data-> " + data)
        # Text directly inside <stream:stream> (depth 1), such as whitespace
        # keep-alives, belongs to no protocol element.  Without this guard it
        # raised AttributeError before the first element (no _ptr yet) and
        # was appended to an already dispatched element afterwards.
        if self.__depth < 2:
            return
        self._ptr.data.append(data)

    def _unknown_starttag(self, tag, attrs):
        """XML Parser callback"""
        self.__depth = self.__depth + 1
        self.DEBUG("DEPTH -> %i , tag -> %s, attrs -> %s" % \
                   (self.__depth, tag, str(attrs)))
        if self.__depth == 2:
            # Start of a new protocol element.
            self._mini_dom = Node(tag=tag, attrs=attrs)
            self._ptr = self._mini_dom
        elif self.__depth > 2:
            self._ptr.kids.append(Node(tag=tag, parent=self._ptr, attrs=attrs))
            self._ptr = self._ptr.kids[-1]
        else:  ## it's the stream tag:
            if 'id' in attrs:
                self._incomingID = attrs['id']

    def _unknown_endtag(self, tag):
        """XML Parser callback"""
        self.__depth = self.__depth - 1
        self.DEBUG("DEPTH -> %i" % self.__depth)
        if self.__depth == 1:
            self.dispatch(self._mini_dom)
        elif self.__depth > 1:
            self._ptr = self._ptr.parent
        else:
            self.DEBUG("*** Server closed connection ? ****")

    def dispatch(self, nodes, depth=0):
        """Override with the method you want to be called with
        a node structure of a 'protocol element'.

        The default implementation only walks the tree.  The parser calls it
        with a single Node; a list of nodes is accepted as well.
        """
        if isinstance(nodes, Node):
            # A Node is not iterable; iterating it directly raised TypeError.
            nodes = [nodes]
        for n in nodes:
            if n.kids:
                self.dispatch(n.kids, depth + 1)

    def _do_read(self, action, buff_size):
        """workhorse for read() method.

        Calls action(buff_size) repeatedly while it returns full buffers and
        returns everything read concatenated.

        added 021231 by jaclu"""
        data = ''
        data_in = action(buff_size)
        while data_in:
            data = data + data_in
            if len(data_in) != buff_size:
                break
            data_in = action(buff_size)
        return data

    def read(self):
        """Reads incoming data and feeds it to the parser. Called by
        process() once the source is readable.

        Returns the data read, re-encoded in ENCODING; an empty string means
        the peer closed the connection.

        changed 021231 by jaclu
        """
        if self._connection == TCP:
            raw_data = self._do_read(self._sock.recv, BLOCK_SIZE)
        elif self._connection == TCP_SSL:
            raw_data = self._do_read(self._sslObj.read, BLOCK_SIZE)
        elif self._connection == STDIO:
            # Stream has no stdin attribute; read the process's stdin.
            raw_data = self._do_read(sys.stdin.read, BLOCK_SIZE)
        else:
            raw_data = ''  # should never get here

        # Decode leniently: a multi-byte UTF-8 character split across two
        # reads made the former strict decode raise UnicodeDecodeError.
        data = unicode(raw_data, 'utf-8', 'replace').encode(ENCODING, 'replace')
        self.DEBUG("got data %s" % data)
        self.log(data, 'RECV:')
        # expat decodes UTF-8 itself and keeps incomplete trailing sequences
        # until the next call, so give it the raw bytes.
        self._parser.Parse(raw_data)
        return data

    def write(self, raw_data=u''):
        """Writes raw outgoing data. blocks

        Unicode is sent as UTF-8; byte strings are assumed to be ISO-8859-1.
        On any error, disconnected() is called.

        changed 021231 by jaclu, added unicode encoding
        """
        if type(raw_data) == type(u''):
            data_out = raw_data.encode('utf-8', 'replace')
        else:
            # since not supplied as unicode, we must guess at
            # what the data is, iso-8859-1 seems reasonable.
            # To avoid this auto assumption,
            # send your data as a unicode string!
            data_out = unicode(raw_data, 'iso-8859-1').encode(ENCODING, 'replace')
        try:
            if self._connection == TCP:
                # send() may transmit only part of the buffer.
                self._sock.sendall(data_out)
            elif self._connection == TCP_SSL:
                self._sslObj.write(data_out)
            elif self._connection == STDIO:
                # Stream has no stdout attribute; write to the process's
                # stdout and flush so the peer sees the data immediately.
                sys.stdout.write(data_out)
                sys.stdout.flush()
            self.log(data_out, 'SENT:')
            self.DEBUG("sent %s" % data_out)
        except Exception:
            self.DEBUG("xmlstream write threw error")
            self.disconnected()

    def process(self, timeout):
        """Wait up to timeout seconds for input and handle it.

        Returns True if data was read, False on timeout or disconnection.
        Raises error for an unsupported connection type.
        """
        if self._connection in (TCP, TCP_SSL):
            reader = self._sock
        elif self._connection == STDIO:
            reader = sys.stdin
        else:
            raise error("unsupported connection type %s" % self._connection)

        ready_for_read, ready_for_write, err = \
            select([reader], [], [], timeout)
        if reader in ready_for_read:
            if not self.read():  # length of 0 means disconnect
                ## raise error("network error") ?
                self.disconnected()
                return False
            return True
        return False

    def disconnect(self):
        """Close the stream and the socket (if there is one)."""
        self.write("</stream:stream>")
        # STDIO streams have no socket to close.
        if self._sock is not None:
            self._sock.close()
        self._sock = None

    def disconnected(self):  ## To be overridden ##
        """Called when a Network Error or disconnection occurs.
        Designed to be overridden"""
        self.DEBUG("Network Disconnection")

    def log(self, data, inout=''):
        """Logs data to the specified filehandle. Data is time stamped
        (unless disabled with timestampLog) and prefixed with inout"""
        if self._logFH is not None:
            if self._timestampLog:
                self._logFH.write("%s - %s - %s\n" % (time.asctime(), inout, data))
            else:
                self._logFH.write("%s - %s\n" % (inout, data))
            self._logFH.flush()

    def getIncomingID(self):
        """Return the id the peer sent in its stream header."""
        return self._incomingID

    def getOutgoingID(self):
        """Return the id we send in our own stream header."""
        # Previously returned the incoming id by mistake.
        return self._outgoingID
|
|
|
|
|
|
class Client(Stream):
    """A Stream that opens its own connection to (host, port)."""

    def connect(self):
        """Attempt to connect to the specified host and send the stream header.

        Does nothing (returns None) for STDIO connections.  Raises error if
        the TCP connection or the SSL setup fails; the socket is closed
        first so it does not leak.  Returns 0 on success.
        """
        self.DEBUG("client connect called to %s %s type %i" % (self._host,
                                                                self._port,
                                                                self._connection))

        ## TODO: check below that stdin/stdout are actually open
        if self._connection == STDIO:
            return

        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._sock.connect((self._host, self._port))
        except socket.error:
            # sys.exc_info() works with both old and new except syntax.
            e = sys.exc_info()[1]
            self.DEBUG("socket error")
            self._sock.close()
            self._sock = None
            raise error(e)

        if self._connection == TCP_SSL:
            try:
                self.DEBUG("Attempting to create ssl socket")
                self._sslObj = socket.ssl(self._sock, None, None)
                self._sslIssuer = self._sslObj.issuer()
                self._sslServer = self._sslObj.server()
            except Exception:
                self.DEBUG("Socket Error: No SSL Support")
                self._sslObj = None
                self._sock.close()
                self._sock = None
                raise error("No SSL Support")

        self.DEBUG("connected")
        self.header()
        return 0
|
|
|
|
class Server:
    """Minimal select()-based listening server (experimental).

    Accepts connections on (host, port), creates a Stream for each one and
    currently just echoes back whatever a client sends.
    """

    def now(self):
        """Return the current local time as a ctime() string."""
        return time.ctime(time.time())

    def __init__(self, maxclients=10, host='', port=5222):
        """Bind and listen on (host, port) with a backlog of maxclients."""
        self.host = host
        self.port = port
        self.streams = []

        # make main sockets for accepting new client requests
        self.mainsocks, self.readsocks, self.writesocks = [], [], []

        # Only the `socket` module is imported, so the constructor and the
        # constants must be qualified (the bare names raised NameError).
        self.portsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.portsock.bind((self.host, self.port))
        self.portsock.listen(maxclients)

        self.mainsocks.append(self.portsock)  # add to main list to identify
        self.readsocks.append(self.portsock)  # add to select inputs list

    def serve(self):
        """Event loop: listen and multiplex until the server process is killed."""
        print('select-server loop starting')

        while 1:
            print("LOOPING")
            readables, writeables, exceptions = select(self.readsocks,
                                                       self.writesocks, [])
            for sockobj in readables:
                if sockobj in self.mainsocks:
                    # listening socket is ready: accept will not block
                    newsock, address = sockobj.accept()
                    print('Connect: %s %s' % (address, id(newsock)))
                    self.readsocks.append(newsock)  # add to select list
                    self._makeNewStream(newsock)
                else:
                    # client socket: recv should not block
                    data = sockobj.recv(1024)
                    print('\tgot %s on %s' % (data, id(sockobj)))
                    if not data:
                        # closed by the client: close here and forget it
                        sockobj.close()
                        self.readsocks.remove(sockobj)
                        stream = self._getStreamFromSocket(sockobj)
                        if stream is not None:
                            self.streams.remove(stream)
                    else:
                        # this may block: should really select for writes
                        # too.  sendall() avoids silently dropping the tail.
                        sockobj.sendall('Echo=>%s' % data)

    def _makeNewStream(self, sckt):
        """Wrap an accepted socket in a Stream, send its header and track it."""
        new_stream = Stream('localhost', self.port,
                            'jabber:client',
                            sock=sckt)
        self.streams.append(new_stream)
        ## maybe override for a 'server stream'
        new_stream.header()
        return new_stream

    def _getStreamSockets(self):
        """Return the sockets of all tracked streams."""
        return [s.getSocket() for s in self.streams]

    def _getStreamFromSocket(self, sock):
        """Return the tracked stream using sock, or None."""
        for s in self.streams:
            if s.getSocket() == sock:
                return s
        return None
|
|
|