467 lines
		
	
	
		
			No EOL
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			467 lines
		
	
	
		
			No EOL
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #
 | |
| # (c) Dave Kirby 2001 - 2005
 | |
| #     mock@thedeveloperscoach.com
 | |
| #
 | |
| # Original call interceptor and call assertion code by Phil Dawes (pdawes@users.sourceforge.net)
 | |
| # Call interceptor code enhanced by Bruce Cropley (cropleyb@yahoo.com.au)
 | |
| #
 | |
| # This Python  module and associated files are released under the FreeBSD
 | |
| # license. Essentially, you can do what you like with it except pretend you wrote
 | |
| # it yourself.
 | |
| # 
 | |
| # 
 | |
| #     Copyright (c) 2005, Dave Kirby
 | |
| # 
 | |
| #     All rights reserved.
 | |
| # 
 | |
| #     Redistribution and use in source and binary forms, with or without
 | |
| #     modification, are permitted provided that the following conditions are met:
 | |
| # 
 | |
| #         * Redistributions of source code must retain the above copyright
 | |
| #           notice, this list of conditions and the following disclaimer.
 | |
| # 
 | |
| #         * Redistributions in binary form must reproduce the above copyright
 | |
| #           notice, this list of conditions and the following disclaimer in the
 | |
| #           documentation and/or other materials provided with the distribution.
 | |
| # 
 | |
| #         * Neither the name of this library nor the names of its
 | |
| #           contributors may be used to endorse or promote products derived from
 | |
| #           this software without specific prior written permission.
 | |
| # 
 | |
| #     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 | |
| #     ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 | |
| #     WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
| #     DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 | |
| #     ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 | |
| #     (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 | |
| #     LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 | |
| #     ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | |
| #     (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 | |
| #     SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
| # 
 | |
| #         mock@thedeveloperscoach.com
 | |
| 
 | |
| 
 | |
| """
 | |
| Mock object library for Python. Mock objects can be used when unit testing
 | |
| to remove a dependency on another production class. They are typically used
 | |
| when the dependency would either pull in lots of other classes, or
 | |
| significantly slow down the execution of the test.
 | |
| They are also used to create exceptional conditions that cannot otherwise
 | |
| be easily triggered in the class under test.
 | |
| """
 | |
| 
 | |
| __version__ = "0.1.0"
 | |
| 
 | |
| # Added in Python 2.1
 | |
| import inspect
 | |
| import re
 | |
| 
 | |
| class MockInterfaceError(Exception):
 | |
|     pass
 | |
| 
 | |
| class Mock:
 | |
|     """
 | |
|     The Mock class emulates any other class for testing purposes.
 | |
|     All method calls are stored for later examination.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, returnValues=None, realClass=None):
 | |
|         """
 | |
|         The Mock class constructor takes a dictionary of method names and
 | |
|         the values they return.  Methods that are not in the returnValues
 | |
|         dictionary will return None.
 | |
|         You may also supply a class whose interface is being mocked.
 | |
|         All calls will be checked to see if they appear in the original
 | |
|         interface. Any calls to methods not appearing in the real class
 | |
|         will raise a MockInterfaceError.  Any calls that would fail due to
 | |
|         non-matching parameter lists will also raise a MockInterfaceError.
 | |
|         Both of these help to prevent the Mock class getting out of sync
 | |
|         with the class it is Mocking.
 | |
|         """
 | |
|         self.mockCalledMethods = {}
 | |
|         self.mockAllCalledMethods = []
 | |
|         self.mockReturnValues = returnValues or {}
 | |
|         self.mockExpectations = {}
 | |
|         self.realClassMethods = None
 | |
|         if realClass:
 | |
|             self.realClassMethods = dict(inspect.getmembers(realClass, inspect.isroutine))
 | |
|             for retMethod in self.mockReturnValues.keys():
 | |
|                 if not self.realClassMethods.has_key(retMethod):
 | |
|                     raise MockInterfaceError("Return value supplied for method '%s' that was not in the original class" % retMethod)
 | |
|         self._setupSubclassMethodInterceptors()
 | |
|      
 | |
|     def _setupSubclassMethodInterceptors(self):
 | |
|         methods = inspect.getmembers(self.__class__,inspect.isroutine)
 | |
|         baseMethods = dict(inspect.getmembers(Mock, inspect.ismethod))
 | |
|         for m in methods:
 | |
|             name = m[0]
 | |
|             # Don't record calls to methods of Mock base class.
 | |
|             if not name in baseMethods:
 | |
|                 self.__dict__[name] = MockCallable(name, self, handcrafted=True)
 | |
|  
 | |
|     def __getattr__(self, name):
 | |
|         return MockCallable(name, self)
 | |
|     
 | |
|     def mockAddReturnValues(self, **methodReturnValues ):
 | |
|         self.mockReturnValues.update(methodReturnValues)
 | |
|         
 | |
|     def mockSetExpectation(self, name, testFn, after=0, until=0):
 | |
|         self.mockExpectations.setdefault(name, []).append((testFn,after,until))
 | |
| 
 | |
|     def _checkInterfaceCall(self, name, callParams, callKwParams):
 | |
|         """
 | |
|         Check that a call to a method of the given name to the original
 | |
|         class with the given parameters would not fail. If it would fail,
 | |
|         raise a MockInterfaceError.
 | |
|         Based on the Python 2.3.3 Reference Manual section 5.3.4: Calls.
 | |
|         """
 | |
|         if self.realClassMethods == None:
 | |
|             return
 | |
|         if not self.realClassMethods.has_key(name):
 | |
|             raise MockInterfaceError("Calling mock method '%s' that was not found in the original class" % name)
 | |
| 
 | |
|         func = self.realClassMethods[name]
 | |
|         try:
 | |
|             args, varargs, varkw, defaults = inspect.getargspec(func)
 | |
|         except TypeError:
 | |
|             # func is not a Python function. It is probably a builtin,
 | |
|             # such as __repr__ or __coerce__. TODO: Checking?
 | |
|             # For now assume params are OK.
 | |
|             return
 | |
| 
 | |
|         # callParams doesn't include self; args does include self.
 | |
|         numPosCallParams = 1 + len(callParams)
 | |
| 
 | |
|         if numPosCallParams > len(args) and not varargs:
 | |
|             raise MockInterfaceError("Original %s() takes at most %s arguments (%s given)" % 
 | |
|                 (name, len(args), numPosCallParams))
 | |
| 
 | |
|         # Get the number of positional arguments that appear in the call,
 | |
|         # also check for duplicate parameters and unknown parameters
 | |
|         numPosSeen = _getNumPosSeenAndCheck(numPosCallParams, callKwParams, args, varkw)
 | |
| 
 | |
|         lenArgsNoDefaults = len(args) - len(defaults or [])
 | |
|         if numPosSeen < lenArgsNoDefaults:
 | |
|             raise MockInterfaceError("Original %s() takes at least %s arguments (%s given)" % (name, lenArgsNoDefaults, numPosSeen))
 | |
| 
 | |
|     def mockGetAllCalls(self):
 | |
|         """
 | |
|         Return a list of MockCall objects,
 | |
|         representing all the methods in the order they were called.
 | |
|         """
 | |
|         return self.mockAllCalledMethods
 | |
|     getAllCalls = mockGetAllCalls  # deprecated - kept for backward compatibility
 | |
| 
 | |
|     def mockGetNamedCalls(self, methodName):
 | |
|         """
 | |
|         Return a list of MockCall objects,
 | |
|         representing all the calls to the named method in the order they were called.
 | |
|         """
 | |
|         return self.mockCalledMethods.get(methodName, [])
 | |
|     getNamedCalls = mockGetNamedCalls  # deprecated - kept for backward compatibility
 | |
| 
 | |
|     def mockCheckCall(self, index, name, *args, **kwargs):
 | |
|         '''test that the index-th call had the specified name and parameters'''
 | |
|         call = self.mockAllCalledMethods[index]
 | |
|         assert name == call.getName(), "%r != %r" % (name, call.getName())
 | |
|         call.checkArgs(*args, **kwargs)
 | |
| 
 | |
| 
 | |
| def _getNumPosSeenAndCheck(numPosCallParams, callKwParams, args, varkw):
 | |
|     """
 | |
|     Positional arguments can appear as call parameters either named as
 | |
|     a named (keyword) parameter, or just as a value to be matched by
 | |
|     position. Count the positional arguments that are given by either
 | |
|     keyword or position, and check for duplicate specifications.
 | |
|     Also check for arguments specified by keyword that do not appear
 | |
|     in the method's parameter list.
 | |
|     """
 | |
|     posSeen = {}
 | |
|     for arg in args[:numPosCallParams]:
 | |
|         posSeen[arg] = True
 | |
|     for kwp in callKwParams:
 | |
|         if posSeen.has_key(kwp):
 | |
|             raise MockInterfaceError("%s appears as both a positional and named parameter." % kwp)
 | |
|         if kwp in args:
 | |
|             posSeen[kwp] = True
 | |
|         elif not varkw:
 | |
|             raise MockInterfaceError("Original method does not have a parameter '%s'" % kwp)
 | |
|     return len(posSeen)
 | |
| 
 | |
| class MockCall:
 | |
|     """
 | |
|     MockCall records the name and parameters of a call to an instance
 | |
|     of a Mock class. Instances of MockCall are created by the Mock class,
 | |
|     but can be inspected later as part of the test.
 | |
|     """
 | |
|     def __init__(self, name, params, kwparams ):
 | |
|         self.name = name
 | |
|         self.params = params
 | |
|         self.kwparams = kwparams
 | |
| 
 | |
|     def checkArgs(self, *args, **kwargs):
 | |
|         assert args == self.params, "%r != %r" % (args, self.params)
 | |
|         assert kwargs == self.kwparams, "%r != %r" % (kwargs, self.kwparams)
 | |
| 
 | |
|     def getParam( self, n ):
 | |
|         if isinstance(n, int):
 | |
|             return self.params[n]
 | |
|         elif isinstance(n, str):
 | |
|             return self.kwparams[n]
 | |
|         else:
 | |
|             raise IndexError, 'illegal index type for getParam'
 | |
| 
 | |
|     def getNumParams(self):
 | |
|         return len(self.params)
 | |
| 
 | |
|     def getNumKwParams(self):
 | |
|         return len(self.kwparams)
 | |
| 
 | |
|     def getName(self):
 | |
|         return self.name
 | |
|     
 | |
|     #pretty-print the method call
 | |
|     def __str__(self):
 | |
|         s = self.name + "("
 | |
|         sep = ''
 | |
|         for p in self.params:
 | |
|             s = s + sep + repr(p)
 | |
|             sep = ', '
 | |
|         items = self.kwparams.items()
 | |
|         items.sort()
 | |
|         for k,v in items:
 | |
|             s = s + sep + k + '=' + repr(v)
 | |
|             sep = ', '
 | |
|         s = s + ')'
 | |
|         return s
 | |
|     def __repr__(self):
 | |
|         return self.__str__()
 | |
| 
 | |
| class MockCallable:
 | |
|     """
 | |
|     Intercepts the call and records it, then delegates to either the mock's
 | |
|     dictionary of mock return values that was passed in to the constructor,
 | |
|     or a handcrafted method of a Mock subclass.
 | |
|     """
 | |
|     def __init__(self, name, mock, handcrafted=False):
 | |
|         self.name = name
 | |
|         self.mock = mock
 | |
|         self.handcrafted = handcrafted
 | |
| 
 | |
|     def __call__(self,  *params, **kwparams):
 | |
|         self.mock._checkInterfaceCall(self.name, params, kwparams)
 | |
|         thisCall = self.recordCall(params,kwparams)
 | |
|         self.checkExpectations(thisCall, params, kwparams)
 | |
|         return self.makeCall(params, kwparams)
 | |
| 
 | |
|     def recordCall(self, params, kwparams):
 | |
|         """
 | |
|         Record the MockCall in an ordered list of all calls, and an ordered
 | |
|         list of calls for that method name.
 | |
|         """
 | |
|         thisCall = MockCall(self.name, params, kwparams)
 | |
|         calls = self.mock.mockCalledMethods.setdefault(self.name, [])
 | |
|         calls.append(thisCall)
 | |
|         self.mock.mockAllCalledMethods.append(thisCall)
 | |
|         return thisCall
 | |
| 
 | |
|     def makeCall(self, params, kwparams):
 | |
|         if self.handcrafted:
 | |
|             allPosParams = (self.mock,) + params
 | |
|             func = _findFunc(self.mock.__class__, self.name)
 | |
|             if not func:
 | |
|                 raise NotImplementedError
 | |
|             return func(*allPosParams, **kwparams)
 | |
|         else:
 | |
|             returnVal = self.mock.mockReturnValues.get(self.name)
 | |
|             if isinstance(returnVal, ReturnValuesBase):
 | |
|                 returnVal = returnVal.next()
 | |
|             return returnVal
 | |
| 
 | |
|     def checkExpectations(self, thisCall, params, kwparams):
 | |
|         if self.name in self.mock.mockExpectations:
 | |
|             callsMade = len(self.mock.mockCalledMethods[self.name])
 | |
|             for (expectation, after, until) in self.mock.mockExpectations[self.name]:
 | |
|                 if callsMade > after and (until==0 or callsMade < until):
 | |
|                     assert expectation(self.mock, thisCall, len(self.mock.mockAllCalledMethods)-1), 'Expectation failed: '+str(thisCall)
 | |
| 
 | |
| 
 | |
| def _findFunc(cl, name):
 | |
|     """ Depth first search for a method with a given name. """
 | |
|     if cl.__dict__.has_key(name):
 | |
|         return cl.__dict__[name]
 | |
|     for base in cl.__bases__:
 | |
|         func = _findFunc(base, name)
 | |
|         if func:
 | |
|             return func
 | |
|     return None
 | |
| 
 | |
| 
 | |
| 
 | |
| class ReturnValuesBase:
 | |
|     def next(self):
 | |
|         try:
 | |
|             return self.iter.next()
 | |
|         except StopIteration:
 | |
|             raise AssertionError("No more return values")
 | |
|     def __iter__(self):
 | |
|         return self
 | |
| 
 | |
| class ReturnValues(ReturnValuesBase):
 | |
|     def __init__(self, *values):
 | |
|         self.iter = iter(values)
 | |
|         
 | |
| 
 | |
| class ReturnIterator(ReturnValuesBase):
 | |
|     def __init__(self, iterator):
 | |
|         self.iter = iter(iterator)
 | |
| 
 | |
|         
 | |
| def expectParams(*params, **keywords):
 | |
|     '''check that the callObj is called with specified params and keywords
 | |
|     '''
 | |
|     def fn(mockObj, callObj, idx):
 | |
|         return callObj.params == params and callObj.kwparams == keywords
 | |
|     return fn
 | |
| 
 | |
| 
 | |
| def expectAfter(*methods):
 | |
|     '''check that the function is only called after all the functions in 'methods'
 | |
|     '''
 | |
|     def fn(mockObj, callObj, idx):
 | |
|         calledMethods = [method.getName() for method in mockObj.mockGetAllCalls()]
 | |
|         #skip last entry, since that is the current call
 | |
|         calledMethods = calledMethods[:-1]
 | |
|         for method in methods:  
 | |
|             if method not in calledMethods:
 | |
|                 return False
 | |
|         return True
 | |
|     return fn
 | |
| 
 | |
| def expectException(exception, *args, **kwargs):
 | |
|     ''' raise an exception when the method is called
 | |
|     '''
 | |
|     def fn(mockObj, callObj, idx):
 | |
|         raise exception(*args, **kwargs)
 | |
|     return fn
 | |
| 
 | |
| 
 | |
| def expectParam(paramIdx, cond):
 | |
|     '''check that the callObj is called with parameter specified by paramIdx (a position index or keyword)
 | |
|     fulfills the condition specified by cond.
 | |
|     cond is a function that takes a single argument, the value to test.
 | |
|     '''
 | |
|     def fn(mockObj, callObj, idx):
 | |
|         param = callObj.getParam(paramIdx)
 | |
|         return cond(param)
 | |
|     return fn
 | |
| 
 | |
| def EQ(value):
 | |
|     def testFn(param):
 | |
|         return param == value
 | |
|     return testFn
 | |
| 
 | |
| def NE(value):
 | |
|     def testFn(param):
 | |
|         return param != value
 | |
|     return testFn
 | |
| 
 | |
| def GT(value):
 | |
|     def testFn(param):
 | |
|         return param > value
 | |
|     return testFn
 | |
| 
 | |
| def LT(value):
 | |
|     def testFn(param):
 | |
|         return param < value
 | |
|     return testFn
 | |
| 
 | |
| def GE(value):
 | |
|     def testFn(param):
 | |
|         return param >= value
 | |
|     return testFn
 | |
| 
 | |
| def LE(value):
 | |
|     def testFn(param):
 | |
|         return param <= value
 | |
|     return testFn
 | |
| 
 | |
| def AND(*condlist):
 | |
|     def testFn(param):
 | |
|         for cond in condlist:
 | |
|             if not cond(param):
 | |
|                 return False
 | |
|         return True
 | |
|     return testFn
 | |
| 
 | |
| def OR(*condlist):
 | |
|     def testFn(param):
 | |
|         for cond in condlist:
 | |
|             if cond(param):
 | |
|                 return True
 | |
|         return False
 | |
|     return testFn
 | |
| 
 | |
| def NOT(cond):
 | |
|     def testFn(param):
 | |
|         return not cond(param)
 | |
|     return testFn
 | |
| 
 | |
| def MATCHES(regex, *args, **kwargs):
 | |
|     compiled_regex = re.compile(regex, *args, **kwargs)
 | |
|     def testFn(param):
 | |
|         return compiled_regex.match(param) != None
 | |
|     return testFn
 | |
| 
 | |
| def SEQ(*sequence):
 | |
|     iterator = iter(sequence)
 | |
|     def testFn(param):
 | |
|         try:
 | |
|             cond = iterator.next()
 | |
|         except StopIteration:
 | |
|             raise AssertionError('SEQ exhausted')
 | |
|         return cond(param)
 | |
|     return testFn
 | |
| 
 | |
| def IS(instance):
 | |
|     def testFn(param):
 | |
|         return param is instance
 | |
|     return testFn
 | |
| 
 | |
| def ISINSTANCE(class_):
 | |
|     def testFn(param):
 | |
|         return isinstance(param, class_) 
 | |
|     return testFn
 | |
| 
 | |
| def ISSUBCLASS(class_):
 | |
|     def testFn(param):
 | |
|         return issubclass(param, class_) 
 | |
|     return testFn
 | |
| 
 | |
| def CONTAINS(val):
 | |
|     def testFn(param):
 | |
|         return val in param 
 | |
|     return testFn
 | |
| 
 | |
| def IN(container):
 | |
|     def testFn(param):
 | |
|         return param in container
 | |
|     return testFn
 | |
| 
 | |
| def HASATTR(attr):
 | |
|     def testFn(param):
 | |
|         return hasattr(param, attr)
 | |
|     return testFn
 | |
| 
 | |
| def HASMETHOD(method):
 | |
|     def testFn(param):
 | |
|         return hasattr(param, method) and callable(getattr(param, method))
 | |
|     return testFn
 | |
| 
 | |
| CALLABLE = callable
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
| # vim: se ts=3: |