469 lines
16 KiB
Python
469 lines
16 KiB
Python
#
|
|
# (c) Dave Kirby 2001 - 2005
|
|
# mock@thedeveloperscoach.com
|
|
#
|
|
# Original call interceptor and call assertion code by Phil Dawes (pdawes@users.sourceforge.net)
|
|
# Call interceptor code enhanced by Bruce Cropley (cropleyb@yahoo.com.au)
|
|
#
|
|
# This Python module and associated files are released under the FreeBSD
|
|
# license. Essentially, you can do what you like with it except pretend you wrote
|
|
# it yourself.
|
|
#
|
|
#
|
|
# Copyright (c) 2005, Dave Kirby
|
|
# Copyright (c) 2009, Yann Leboulanger
|
|
#
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright
|
|
# notice, this list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright
|
|
# notice, this list of conditions and the following disclaimer in the
|
|
# documentation and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of this library nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
|
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
|
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
# mock@thedeveloperscoach.com
|
|
|
|
|
|
"""
|
|
Mock object library for Python. Mock objects can be used when unit testing
|
|
to remove a dependency on another production class. They are typically used
|
|
when the dependency would either pull in lots of other classes, or
|
|
significantly slow down the execution of the test.
|
|
They are also used to create exceptional conditions that cannot otherwise
|
|
be easily triggered in the class under test.
|
|
"""
|
|
|
|
# Version string of this mock library.
__version__ = "0.1.0"
|
|
|
|
# The inspect module was added in Python 2.1.
|
|
import inspect
|
|
import re
|
|
|
|
class MockInterfaceError(Exception):
    """
    Raised when a Mock that was given a realClass is used in a way the real
    class's interface would not allow: a return value supplied for a method
    the class lacks, or a call whose arguments do not fit the method's
    parameter list.
    """
|
|
|
|
class Mock(object):
    """
    The Mock class emulates any other class for testing purposes.
    All method calls are stored for later examination.
    """

    def __init__(self, returnValues=None, realClass=None):
        """
        The Mock class constructor takes a dictionary of method names and
        the values they return. Methods that are not in the returnValues
        dictionary will return None.

        You may also supply a class whose interface is being mocked.
        Supplying a return value for a method that is not in the real class
        raises a MockInterfaceError, as does any call to one of the real
        class's methods that would fail due to a non-matching parameter
        list. Calls to names the real class does not define are not
        checked; they are recorded like any other mock call.
        These checks help to prevent the Mock class getting out of sync
        with the class it is Mocking.
        """
        # method name -> list of MockCall objects, in call order
        self.mockCalledMethods = {}
        # every MockCall, across all methods, in call order
        self.mockAllCalledMethods = []
        self.mockReturnValues = returnValues or {}
        # method name -> list of (testFn, after, until) tuples
        self.mockExpectations = {}
        self.realClass = realClass
        # name -> routine of realClass; stays None when no realClass is given
        self.realClassMethods = None
        if realClass:
            self.realClassMethods = dict(inspect.getmembers(realClass, inspect.isroutine))
            for retMethod in self.mockReturnValues:
                if retMethod not in self.realClassMethods:
                    raise MockInterfaceError("Return value supplied for method '%s' that was not in the original class" % retMethod)
            self._setupSubclassMethodInterceptors()

    def _setupSubclassMethodInterceptors(self):
        """
        Shadow every routine of realClass with a handcrafted MockCallable
        stored in this instance's __dict__, so calls are recorded and then
        delegated to the real implementation. Names defined on Mock itself
        are skipped so the mock's own machinery keeps working.
        """
        # Collect the names defined directly on Mock from the class __dict__.
        # This works on Python 2 and 3 alike. inspect.ismethod() would match
        # nothing on Python 3, where class attributes are plain functions.
        # Then Mock's own methods (e.g. _checkInterfaceCall) could be
        # shadowed when realClass is a Mock subclass, and every intercepted
        # call would recurse into itself.
        baseMethods = set(name for name, value in vars(Mock).items()
                          if inspect.isroutine(value))
        for name, _ in inspect.getmembers(self.realClass, inspect.isroutine):
            # Don't record calls to methods of Mock base class.
            if name not in baseMethods:
                self.__dict__[name] = MockCallable(name, self, handcrafted=True)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: any unknown name
        # becomes a recording MockCallable.
        return MockCallable(name, self)

    def mockAddReturnValues(self, **methodReturnValues):
        """Add or replace canned return values, keyed by method name."""
        self.mockReturnValues.update(methodReturnValues)

    def mockSetExpectation(self, name, testFn, after=0, until=0):
        """
        Register testFn as an expectation for calls to method ``name``.

        testFn is called as testFn(mock, call, index) with this Mock, the
        MockCall and the call's index in mockGetAllCalls(). It must return a
        true value, otherwise an AssertionError is raised. It is applied to
        the k-th call of ``name`` (counting from 1) when k > after and
        (until == 0 or k < until).
        """
        self.mockExpectations.setdefault(name, []).append((testFn, after, until))

    def _checkInterfaceCall(self, name, callParams, callKwParams):
        """
        Check that a call to a method of the given name to the original
        class with the given parameters would not fail. If it would fail,
        raise a MockInterfaceError.
        Based on the Python 2.3.3 Reference Manual section 5.3.4: Calls.

        Nothing is checked when there is no realClass, when the name is not
        a routine of realClass, or when its signature cannot be introspected.
        """
        if self.realClassMethods is None:
            return
        if name not in self.realClassMethods:
            return

        func = self.realClassMethods[name]
        # inspect.getargspec is deprecated on Python 3 and was removed in
        # 3.11. Use getfullargspec where it exists; it also reports
        # keyword-only parameters.
        getFullArgSpec = getattr(inspect, 'getfullargspec', None)
        try:
            if getFullArgSpec is not None:
                spec = getFullArgSpec(func)
                args, varargs, varkw, defaults = spec.args, spec.varargs, spec.varkw, spec.defaults
                kwonlyargs = spec.kwonlyargs or []
                kwonlydefaults = spec.kwonlydefaults or {}
            else:
                args, varargs, varkw, defaults = inspect.getargspec(func)
                kwonlyargs, kwonlydefaults = [], {}
        except TypeError:
            # func's signature cannot be introspected; it is probably a
            # builtin, such as __repr__ or __coerce__. TODO: Checking?
            # For now assume params are OK.
            return

        # callParams doesn't include self; args does include self.
        numPosCallParams = 1 + len(callParams)

        if numPosCallParams > len(args) and not varargs:
            raise MockInterfaceError("Original %s() takes at most %s arguments (%s given)" %
                                     (name, len(args), numPosCallParams))

        # Keyword-only parameters can only be passed by name and are not in
        # args. Leave them out of the positional bookkeeping below, which
        # would otherwise reject them as unknown parameters.
        posKwNames = [kw for kw in callKwParams if kw not in kwonlyargs]

        # Get the number of positional arguments that appear in the call,
        # also check for duplicate parameters and unknown parameters
        numPosSeen = _getNumPosSeenAndCheck(numPosCallParams, posKwNames, args, varkw)

        lenArgsNoDefaults = len(args) - len(defaults or [])
        if numPosSeen < lenArgsNoDefaults:
            raise MockInterfaceError("Original %s() takes at least %s arguments (%s given)" % (name, lenArgsNoDefaults, numPosSeen))

        for kwonly in kwonlyargs:
            if kwonly not in callKwParams and kwonly not in kwonlydefaults:
                raise MockInterfaceError("Original %s() missing required keyword-only argument '%s'" % (name, kwonly))

    def mockGetAllCalls(self):
        """
        Return a list of MockCall objects,
        representing all the methods in the order they were called.
        """
        return self.mockAllCalledMethods
    getAllCalls = mockGetAllCalls  # deprecated - kept for backward compatibility

    def mockGetNamedCalls(self, methodName):
        """
        Return a list of MockCall objects,
        representing all the calls to the named method in the order they were called.
        """
        return self.mockCalledMethods.get(methodName, [])
    getNamedCalls = mockGetNamedCalls  # deprecated - kept for backward compatibility

    def mockCheckCall(self, index, name, *args, **kwargs):
        """
        Assert that the index-th recorded call (see mockGetAllCalls) had the
        given name and exactly the given arguments. Raises AssertionError on
        a mismatch and IndexError if fewer calls were made.
        """
        call = self.mockAllCalledMethods[index]
        assert name == call.getName(), "%r != %r" % (name, call.getName())
        call.checkArgs(*args, **kwargs)
|
|
|
|
|
|
def _getNumPosSeenAndCheck(numPosCallParams, callKwParams, args, varkw):
    """
    Count how many of the original method's named parameters (``args``) a
    call fills, and validate the call's keyword arguments.

    The first numPosCallParams entries of ``args`` are filled by position.
    Any keyword in callKwParams that names an entry of ``args`` fills that
    entry too. Returns the number of distinct entries filled.

    Raises MockInterfaceError if a keyword names a parameter that was
    already filled by position. It is also raised if a keyword names no
    parameter at all while the original method takes no **kwargs (varkw
    is falsy).
    """
    filled = set(args[:numPosCallParams])
    for keyword in callKwParams:
        if keyword in filled:
            raise MockInterfaceError("%s appears as both a positional and named parameter." % keyword)
        if keyword in args:
            filled.add(keyword)
            continue
        if not varkw:
            raise MockInterfaceError("Original method does not have a parameter '%s'" % keyword)
    return len(filled)
|
|
|
|
class MockCall:
    """
    MockCall records the name and parameters of a call to an instance
    of a Mock class. Instances of MockCall are created by the Mock class,
    but can be inspected later as part of the test.
    """
    def __init__(self, name, params, kwparams):
        self.name = name            # name of the method that was called
        self.params = params        # positional arguments of the call
        self.kwparams = kwparams    # keyword arguments of the call (dict)

    def checkArgs(self, *args, **kwargs):
        """Assert that the call was made with exactly these arguments."""
        assert args == self.params, "%r != %r" % (args, self.params)
        assert kwargs == self.kwparams, "%r != %r" % (kwargs, self.kwparams)

    def getParam(self, n):
        """
        Return one argument of the call: the n-th positional argument when
        n is an int, or the keyword argument named n when n is a str.
        Indexing errors (IndexError/KeyError) propagate when there is no
        such argument; IndexError is raised for any other type of n.
        """
        if isinstance(n, int):
            return self.params[n]
        elif isinstance(n, str):
            return self.kwparams[n]
        else:
            # Use the call form rather than the Python-2-only
            # "raise E, msg" statement, which is a SyntaxError on Python 3.
            raise IndexError('illegal index type for getParam')

    def getNumParams(self):
        """Return the number of positional arguments."""
        return len(self.params)

    def getNumKwParams(self):
        """Return the number of keyword arguments."""
        return len(self.kwparams)

    def getName(self):
        """Return the name of the called method."""
        return self.name

    def __str__(self):
        """
        Pretty-print the method call, e.g. "f(1, 'a', key=2)". Positional
        arguments come first in order, then keyword arguments sorted by name.
        """
        # Collect the pieces and join once, instead of repeatedly
        # concatenating onto a growing string.
        parts = [repr(p) for p in self.params]
        parts.extend(k + '=' + repr(v) for k, v in sorted(self.kwparams.items()))
        return self.name + '(' + ', '.join(parts) + ')'

    def __repr__(self):
        return self.__str__()
|
|
|
|
class MockCallable:
    """
    Stand-in for one method of a Mock.

    Calling it proceeds in four steps:
    (1) check the call against the mock's real class, if any;
    (2) record it as a MockCall;
    (3) run the expectations registered for this method name;
    (4) produce the result. A handcrafted callable invokes the method of
        that name found on the mock's realClass. Any other callable takes
        the result from the mock's dictionary of return values.
    """

    def __init__(self, name, mock, handcrafted=False):
        # name        -- name of the method being intercepted
        # mock        -- the owning Mock, which stores calls and settings
        # handcrafted -- True to delegate to the real class's implementation
        self.name, self.mock, self.handcrafted = name, mock, handcrafted

    def __call__(self, *params, **kwparams):
        self.mock._checkInterfaceCall(self.name, params, kwparams)
        record = self.recordCall(params, kwparams)
        self.checkExpectations(record, params, kwparams)
        return self.makeCall(params, kwparams)

    def recordCall(self, params, kwparams):
        """
        Create a MockCall for this invocation and append it to both the
        mock's per-name list and its list of all calls. Return it.
        """
        record = MockCall(self.name, params, kwparams)
        owner = self.mock
        owner.mockCalledMethods.setdefault(self.name, []).append(record)
        owner.mockAllCalledMethods.append(record)
        return record

    def makeCall(self, params, kwparams):
        """
        Produce the call's result.

        A handcrafted callable runs the realClass method of the same name,
        with the mock as its first argument; NotImplementedError is raised
        if no such method is found. Any other callable returns the canned
        value for this name (None if absent). When that value is a
        ReturnValuesBase sequence, its next item is returned instead.
        """
        owner = self.mock
        if not self.handcrafted:
            canned = owner.mockReturnValues.get(self.name)
            if isinstance(canned, ReturnValuesBase):
                canned = canned.next()
            return canned
        positional = (owner,) + params
        impl = _findFunc(owner.realClass, self.name)
        if not impl:
            raise NotImplementedError
        return impl(*positional, **kwparams)

    def checkExpectations(self, thisCall, params, kwparams):
        """
        Run each expectation registered for this method name whose
        (after, until) window covers the current call count, and assert
        that it returns a true value.
        """
        owner = self.mock
        if self.name not in owner.mockExpectations:
            return
        callsMade = len(owner.mockCalledMethods[self.name])
        for expectation, after, until in owner.mockExpectations[self.name]:
            inWindow = callsMade > after and (until == 0 or callsMade < until)
            if inWindow:
                # The index handed over is that of the newest recorded call.
                assert expectation(owner, thisCall, len(owner.mockAllCalledMethods) - 1), \
                    'Expectation failed: ' + str(thisCall)
|
|
|
|
|
|
def _findFunc(cl, name):
    """
    Look up ``name`` in the __dict__ of class ``cl``. If it is not there,
    search the base classes depth-first, left to right. Return the raw
    __dict__ entry, or None when no class in the hierarchy defines it.
    """
    try:
        return vars(cl)[name]
    except KeyError:
        pass
    for klass in cl.__bases__:
        found = _findFunc(klass, name)
        if found:
            return found
    return None
|
|
|
|
|
|
|
|
class ReturnValuesBase:
    """
    Base class for canned return-value sequences.

    Subclasses store an iterator in ``self.iter``, and each call to next()
    yields its next item. MockCallable.makeCall calls next() when a mocked
    method's return value is a ReturnValuesBase instance.
    """
    def next(self):
        """
        Return the next stored value.

        Raises AssertionError("No more return values") once the values are
        exhausted, so a test that calls the mock more often than expected
        fails.
        """
        # Python 3 iterators provide __next__(); Python 2 ones provide next().
        advance = getattr(self.iter, '__next__', None) or self.iter.next
        try:
            return advance()
        except StopIteration:
            raise AssertionError("No more return values")

    # Python 3 name of the iterator-protocol method. Without it, __iter__
    # returning self would hand Python 3 a non-iterator.
    __next__ = next

    def __iter__(self):
        return self
|
|
|
|
class ReturnValues(ReturnValuesBase):
    """
    Return values given as positional arguments, handed out in order by
    next() (inherited from ReturnValuesBase).
    """

    def __init__(self, *values):
        # values arrives as a tuple; keep a single-pass iterator over it.
        remaining = iter(values)
        self.iter = remaining
|
|
|
|
|
|
class ReturnIterator(ReturnValuesBase):
    """
    Return values drawn one at a time from any iterable, handed out by
    next() (inherited from ReturnValuesBase).
    """

    def __init__(self, iterator):
        # Any iterable is accepted; iter() turns it into an iterator.
        source = iter(iterator)
        self.iter = source
|
|
|
|
|
|
def expectParams(*params, **keywords):
    """
    Build an expectation that passes only when the call was made with
    exactly ``params`` as positional arguments and ``keywords`` as keyword
    arguments.
    """
    def matchesExactly(mockObj, callObj, idx):
        if not (callObj.params == params):
            return False
        return callObj.kwparams == keywords
    return matchesExactly
|
|
|
|
|
|
def expectAfter(*methods):
    """
    Build an expectation that passes only if every method name in
    ``methods`` was called on the mock before the current call. Order among
    those earlier calls is not checked.
    """
    def calledEarlier(mockObj, callObj, idx):
        names = [call.getName() for call in mockObj.mockGetAllCalls()]
        # The newest entry is the call being checked, so leave it out.
        earlierNames = names[:-1]
        return all(name in earlierNames for name in methods)
    return calledEarlier
|
|
|
|
def expectException(exception, *args, **kwargs):
    """
    Build an expectation that, instead of checking anything, raises
    ``exception(*args, **kwargs)`` whenever the expected method is called.
    """
    def raiser(mockObj, callObj, idx):
        error = exception(*args, **kwargs)
        raise error
    return raiser
|
|
|
|
|
|
def expectParam(paramIdx, cond):
    """
    Build an expectation that fetches one argument of the call and checks
    it with ``cond``.

    paramIdx selects the argument via callObj.getParam: an int for a
    positional argument or a str for a keyword argument. cond takes that
    single value and returns whether it is acceptable.
    """
    def paramSatisfies(mockObj, callObj, idx):
        return cond(callObj.getParam(paramIdx))
    return paramSatisfies
|
|
|
|
def EQ(value):
    """Condition: the parameter compares equal (==) to ``value``."""
    return lambda param: param == value
|
|
|
|
def NE(value):
    """Condition: the parameter compares unequal (!=) to ``value``."""
    return lambda param: param != value
|
|
|
|
def GT(value):
    """Condition: the parameter is strictly greater than ``value``."""
    return lambda param: param > value
|
|
|
|
def LT(value):
    """Condition: the parameter is strictly less than ``value``."""
    return lambda param: param < value
|
|
|
|
def GE(value):
    """Condition: the parameter is greater than or equal to ``value``."""
    return lambda param: param >= value
|
|
|
|
def LE(value):
    """Condition: the parameter is less than or equal to ``value``."""
    return lambda param: param <= value
|
|
|
|
def AND(*condlist):
    """
    Condition: every condition in ``condlist`` holds. Evaluation stops at
    the first failing condition; with no conditions it always holds.
    """
    def allHold(param):
        return all(cond(param) for cond in condlist)
    return allHold
|
|
|
|
def OR(*condlist):
    """
    Condition: at least one condition in ``condlist`` holds. Evaluation
    stops at the first passing condition; with no conditions it never holds.
    """
    def anyHolds(param):
        return any(cond(param) for cond in condlist)
    return anyHolds
|
|
|
|
def NOT(cond):
    """Condition: ``cond`` does not hold for the parameter."""
    return lambda param: not cond(param)
|
|
|
|
def MATCHES(regex, *args, **kwargs):
    """
    Condition: ``regex`` matches at the start of the parameter. This uses
    re.match semantics, not a full-string match. Any extra arguments are
    passed to re.compile; the pattern is compiled once, up front.
    """
    pattern = re.compile(regex, *args, **kwargs)

    def matchesStart(param):
        return pattern.match(param) is not None
    return matchesStart
|
|
|
|
def SEQ(*sequence):
    """
    Condition that applies the given conditions in turn: the k-th
    evaluation uses the k-th condition. It raises
    AssertionError('SEQ exhausted') when evaluated more times than there
    are conditions.
    """
    iterator = iter(sequence)
    # Python 3 iterators provide __next__() and have no .next() method;
    # Python 2 iterators provide next(). Pick whichever exists, once.
    advance = getattr(iterator, '__next__', None) or iterator.next

    def testFn(param):
        try:
            cond = advance()
        except StopIteration:
            raise AssertionError('SEQ exhausted')
        return cond(param)
    return testFn
|
|
|
|
def IS(instance):
    """Condition: the parameter is the very object ``instance`` (identity)."""
    return lambda param: param is instance
|
|
|
|
def ISINSTANCE(class_):
    """Condition: isinstance(parameter, class_) holds."""
    return lambda param: isinstance(param, class_)
|
|
|
|
def ISSUBCLASS(class_):
    """
    Condition: issubclass(parameter, class_) holds. The parameter must
    itself be a class, otherwise issubclass raises TypeError.
    """
    return lambda param: issubclass(param, class_)
|
|
|
|
def CONTAINS(val):
    """Condition: the parameter (a container) contains ``val``."""
    return lambda param: val in param
|
|
|
|
def IN(container):
    """Condition: the parameter is a member of ``container``."""
    return lambda param: param in container
|
|
|
|
def HASATTR(attr):
    """Condition: the parameter has an attribute named ``attr``."""
    return lambda param: hasattr(param, attr)
|
|
|
|
def HASMETHOD(method):
    """Condition: the parameter has a callable attribute named ``method``."""
    def hasCallableAttr(param):
        if not hasattr(param, method):
            return False
        return callable(getattr(param, method))
    return hasCallableAttr
|
|
|
|
# Condition that passes when the parameter is callable. The builtin
# callable() already has the one-argument predicate shape used by the
# other conditions above, so it is reused directly.
CALLABLE = callable
|
|
|
|
|
|
|
|
|
|
# vim: se ts=3:
|