# -*- test-case-name: twisted.test.test_policies -*-
# Copyright (c) 2001-2009 Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Resource limiting policies.

@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""

# system imports
import sys, operator

from zope.interface import directlyProvides, providedBy

# twisted imports
from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
from twisted.internet import error
from twisted.python import log

class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    @ivar wrappedProtocol: An L{IProtocol} provider to which L{IProtocol}
        method calls onto this L{ProtocolWrapper} will be proxied.

    @ivar factory: The L{WrappingFactory} which created this
        L{ProtocolWrapper}.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.wrappedProtocol = wrappedProtocol
        self.factory = factory

    def makeConnection(self, transport):
        """
        When a connection is made, register this wrapper with its factory,
        save the real transport, and connect the wrapped protocol to this
        L{ProtocolWrapper} to intercept any transport calls it makes.
        """
        # This wrapper stands in for the real transport as far as the wrapped
        # protocol is concerned, so advertise the same interfaces it does.
        directlyProvides(self, providedBy(transport))
        Protocol.makeConnection(self, transport)
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    # Transport relaying

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached for attributes not found by normal lookup: anything
        # else the wrapped protocol asks of its "transport" goes to the real
        # transport.
        return getattr(self.transport, name)

    # Protocol relaying

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)

class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them."""

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Live wrapper protocols, used as a set (values are always 1).
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr))

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]

class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Reports every byte read and written to its factory, and exposes the
    hooks the factory uses to pause and resume reading and writing.

    @ivar producer: The producer most recently registered through
        L{registerProducer}, or C{None} if there is none.
    """

    # Declared on the class so that lookups of C{producer} never fall through
    # ProtocolWrapper.__getattr__ to the real transport's own (possibly None)
    # C{producer} attribute.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # sum() rather than reduce(operator.add, ...): an empty sequence must
        # count as zero bytes instead of raising TypeError.
        self.factory.registerWritten(sum(map(len, seq)))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading from the underlying transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the underlying transport."""
        self.transport.resumeProducing()

    def throttleWrites(self):
        """Pause the registered producer, if any."""
        if self.producer is not None:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Resume the registered producer, if any."""
        if self.producer is not None:
            self.producer.resumeProducing()

class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Handles returned by callLater for pending calls, or None; kept so
        # they can be cancelled once the last connection goes away.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)

    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.
        """
        self.writtenThisSecond += length

    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """
        Checks if we've passed the read bandwidth limit, throttling reads if
        so, then reschedules itself to run again in one second.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Reading readThisSecond bytes at readLimit bytes/second takes
            # readThisSecond / readLimit seconds; one second has already been
            # spent, so stay paused for the remainder.
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """
        Checks if we've passed the write bandwidth limit, throttling writes
        if so, then reschedules itself to run again in one second.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            # Same computation as in checkReadBandwidth.
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = self.callLater(throttleTime,
                                                     self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def throttleWrites(self):
        """
        Throttle writes on all protocols.
        """
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.throttleWrites()

    def unthrottleWrites(self):
        """
        Stop throttling writes on all protocols.
        """
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleWrites()

    def buildProtocol(self, addr):
        """
        Build a throttled protocol, or return C{None} (refusing the
        connection) when C{maxConnectionCount} connections are already open.
        """
        if self.connectionCount == 0:
            # First connection: start the once-per-second bandwidth checks.
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        else:
            log.msg("Max connection count reached!")
            return None

    def unregisterProtocol(self, p):
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            # No connections left: stop every pending check/unthrottle call.
            # Each handle is cleared after cancelling it so a later idle
            # period does not cancel the same call twice, which would raise
            # AlreadyCancelled.
            if self.unthrottleReadsID is not None:
                self.unthrottleReadsID.cancel()
                self.unthrottleReadsID = None
            if self.checkReadBandwidthID is not None:
                self.checkReadBandwidthID.cancel()
                self.checkReadBandwidthID = None
            if self.unthrottleWritesID is not None:
                self.unthrottleWritesID.cancel()
                self.unthrottleWritesID = None
            if self.checkWriteBandwidthID is not None:
                self.checkWriteBandwidthID.cancel()
                self.checkWriteBandwidthID = None

class SpewingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that logs, via L{log.msg}, every chunk of data received
    and sent, then relays it as usual.
    """

    def dataReceived(self, data):
        log.msg("Received: %r" % data)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        log.msg("Sending: %r" % data)
        ProtocolWrapper.write(self, data)

class SpewingFactory(WrappingFactory):
    """
    Wrapping factory whose connections log all traffic; each connection is
    wrapped in a L{SpewingProtocol}.
    """

    protocol = SpewingProtocol

class LimitConnectionsByPeer(WrappingFactory):
    """
    Wrapping factory that refuses a new connection when its peer host already
    has C{maxConnectionsPerPeer} open connections.

    @ivar peerConnections: Mapping from peer host to its number of open
        connections; created by L{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}

    def _peerHost(self, addr):
        """
        Return the host part of the address C{addr}.

        Both L{buildProtocol} and L{unregisterProtocol} derive their
        C{peerConnections} key through this one helper, so a connection is
        always counted and released under the same key.  An address object
        with a C{host} attribute uses it; otherwise C{addr} is treated as a
        C{(host, port)} tuple.
        """
        host = getattr(addr, "host", None)
        if host is None:
            host = addr[0]
        return host

    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        # Also forget the wrapper in self.protocols: it was added there by
        # the inherited registerProtocol and would otherwise leak.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        self.peerConnections[peerHost] -= 1
        if self.peerConnections[peerHost] == 0:
            del self.peerConnections[peerHost]

class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        if (self.connectionLimit is None or
            self.connectionCount < self.connectionLimit):
            # Build the normal protocol
            wrappedProtocol = self.protocol()
        elif self.overflowProtocol is None:
            # Just drop the connection
            return None
        else:
            # Too many connections, so build the overflow protocol
            wrappedProtocol = self.overflowProtocol()

        wrappedProtocol.factory = self
        # The wrapper calls unregisterProtocol on connectionLost, which is
        # where connectionCount is decremented again.
        protocol = ProtocolWrapper(self, wrappedProtocol)
        self.connectionCount += 1
        return protocol

    def registerProtocol(self, p):
        # Counting is done in buildProtocol; nothing to record here.
        pass

    def unregisterProtocol(self, p):
        self.connectionCount -= 1

class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory} which also provides a
            C{callLater(period, func)} method, such as L{TimeoutFactory}.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)

    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                  self.timeoutFunc)

    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                # The timeout already fired or was cancelled; either way
                # there is nothing left to cancel.
                pass
            self.timeoutCall = None

    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.
        """
        # After the timeout has fired, timeoutCall still refers to the spent
        # call; resetting it would raise AlreadyCalled.
        if self.timeoutCall and self.timeoutCall.active():
            self.timeoutCall.reset(self.timeoutPeriod)

    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)

    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()

class TimeoutFactory(WrappingFactory):
    """
    Factory for L{TimeoutProtocol}: every connection is dropped after
    C{timeoutPeriod} seconds without activity.
    """
    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             timeoutPeriod=self.timeoutPeriod)

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)

class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that writes one line to C{logfile} per connection event
    and per chunk of data received (prefixed C{C}) or sent (prefixed C{S} or
    C{SV}).
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the datareceived logged.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number

    def _log(self, line):
        self.logfile.write(line + '\n')
        # The log file stays open for the life of the connection; flush so
        # each entry is visible immediately.
        self.logfile.flush()

    def _mungeData(self, data):
        """
        Truncate C{data} to at most C{lengthLimit} characters (when a limit
        is set), replacing the tail with a 12-character elision marker.
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            # Clamp at 0: for limits below the marker length a negative slice
            # index would keep most of the data instead of truncating it.
            data = data[:max(self.lengthLimit - 12, 0)] + '<... elided>'
        return data

    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)

    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)

    # ITransport
    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)

    def writeSequence(self, iovec):
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)

    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)

class TrafficLoggingFactory(WrappingFactory):
    """
    Factory for L{TrafficLoggingProtocol}.  Connection number C{n} is logged
    to the file named C{logfilePrefix + '-' + str(n)}.
    """
    protocol = TrafficLoggingProtocol

    # Number of connections built so far (since the last resetCounter).
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)

    def open(self, name):
        """
        Open the log file C{name} for writing and return it; called by
        L{buildProtocol} once per connection.
        """
        return file(name, 'w')

    def buildProtocol(self, addr):
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             logfile, self.lengthLimit, self._counter)

    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0

class TimeoutMixin:
    """Mixin for protocols which wish to timeout connections

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """
    timeOut = None

    # Object returned by callLater for the armed timeout, or None when no
    # timeout is pending.
    __timeoutCall = None

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)

    def resetTimeout(self):
        """Reset the timeout count down"""
        if self.__timeoutCall is not None and self.timeOut is not None:
            self.__timeoutCall.reset(self.timeOut)

    def setTimeout(self, period):
        """Change the timeout period

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
        C{None} to disable the timeout.

        @return: The previous timeout period.
        """
        prev = self.timeOut
        self.timeOut = period

        if self.__timeoutCall is not None:
            if period is None:
                self.__timeoutCall.cancel()
                self.__timeoutCall = None
            else:
                self.__timeoutCall.reset(period)
        elif period is not None:
            self.__timeoutCall = self.callLater(period, self.__timedOut)

        return prev

    def __timedOut(self):
        # The call has fired: forget it before running the handler, so
        # resetTimeout/setTimeout do not act on the spent call.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """Called when the connection times out.
        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()

