Restart component on Jabber connection failure

darcs-hash:20080819060517-86b55-a7f828daab8f171e9f4b168d19aa4c288f2fc747.gz
This commit is contained in:
David Rousselie
2008-08-19 08:05:17 +02:00
parent 113c242f05
commit 48f562df6c
4 changed files with 208 additions and 166 deletions

View File

@@ -36,6 +36,7 @@ import re
import traceback import traceback
import string import string
import time import time
import socket
from Queue import Queue from Queue import Queue
@@ -665,10 +666,10 @@ class JCLComponent(Component, object):
Call Component main loop Call Component main loop
Clean up when shutting down JCLcomponent Clean up when shutting down JCLcomponent
""" """
self.connect()
self.spool_dir += "/" + unicode(self.jid) self.spool_dir += "/" + unicode(self.jid)
self.running = True self.running = True
self.last_activity = int(time.time()) self.last_activity = int(time.time())
self.connect()
timer_thread = threading.Thread(target=self.time_handler, timer_thread = threading.Thread(target=self.time_handler,
name="TimerThread") name="TimerThread")
timer_thread.start() timer_thread.start()
@@ -679,17 +680,20 @@ class JCLComponent(Component, object):
self.stream.loop_iter(JCLComponent.timeout) self.stream.loop_iter(JCLComponent.timeout)
if self.queue.qsize(): if self.queue.qsize():
raise self.queue.get(0) raise self.queue.get(0)
except socket.error, e:
self.__logger.info("Connection failed, restarting.")
return (True, 5)
finally: finally:
self.running = False self.running = False
timer_thread.join(JCLComponent.timeout)
self.wait_event.set()
if self.stream and not self.stream.eof \ if self.stream and not self.stream.eof \
and self.stream.socket is not None: and self.stream.socket is not None:
presences = self.account_manager.get_presence_all("unavailable") presences = self.account_manager.get_presence_all("unavailable")
self.send_stanzas(presences) self.send_stanzas(presences)
self.wait_event.set() self.disconnect()
timer_thread.join(JCLComponent.timeout) self.__logger.debug("Exitting normally")
self.disconnect() return (self._restart, 0)
self.__logger.debug("Exitting normally")
return self._restart
def _get_restart(self): def _get_restart(self):
return self._restart return self._restart

View File

@@ -29,6 +29,7 @@ import re
from ConfigParser import ConfigParser from ConfigParser import ConfigParser
import tempfile import tempfile
import os import os
import socket
from pyxmpp.jid import JID from pyxmpp.jid import JID
from pyxmpp.iq import Iq from pyxmpp.iq import Iq
@@ -56,8 +57,8 @@ class MockStream(object):
jid="", jid="",
secret="", secret="",
server="", server="",
port="", port=1,
keepalive=True): keepalive=None):
self.sent = [] self.sent = []
self.connection_started = False self.connection_started = False
self.connection_stopped = False self.connection_stopped = False
@@ -117,7 +118,9 @@ class MockStream(object):
class MockStreamNoConnect(MockStream): class MockStreamNoConnect(MockStream):
def connect(self): def connect(self):
self.connection_started = True self.connection_started = True
self.eof = True
def loop_iter(self, timeout):
return
class MockStreamRaiseException(MockStream): class MockStreamRaiseException(MockStream):
def loop_iter(self, timeout): def loop_iter(self, timeout):
@@ -222,160 +225,10 @@ class JCLComponent_TestCase(JCLTestCase):
self.assertEquals(len(handler1.handled), 1) self.assertEquals(len(handler1.handled), 1)
self.assertEquals(len(handler2.handled), 0) self.assertEquals(len(handler2.handled), 0)
###########################################################################
# 'run' tests
###########################################################################
def __comp_run(self):
try:
self.comp.run()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def __comp_time_handler(self):
try:
self.saved_time_handler()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def test_run(self):
"""Test basic main loop execution"""
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
# Tests in subclasses might be more precise
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
result = self.comp.run()
self.assertFalse(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_restart(self):
"""Test main loop execution with restart"""
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
# Tests in subclasses might be more precise
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
self.comp.restart = True
result = self.comp.run()
self.assertTrue(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_unhandled_error(self):
"""Test main loop unhandled error from a component handler"""
def do_nothing():
return
self.comp.time_unit = 1
self.comp.stream = MockStreamRaiseException()
self.comp.stream_class = MockStreamRaiseException
self.comp.handle_tick = do_nothing
try:
self.comp.run()
except Exception, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_ni_handle_tick(self):
"""Test JCLComponent 'NotImplemented' error from handle_tick method"""
self.comp.time_unit = 1
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
try:
self.comp.run()
except NotImplementedError, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_go_offline(self):
"""Test main loop send offline presence when exiting"""
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
self.comp.time_unit = 1
self.max_tick_count = 1
self.comp.handle_tick = self.__handle_tick_test_time_handler
model.db_connect()
user1 = User(jid="test1@test.com")
account11 = Account(user=user1,
name="account11",
jid="account11@jcl.test.com")
account12 = Account(user=user1,
name="account12",
jid="account12@jcl.test.com")
account2 = Account(user=User(jid="test2@test.com"),
name="account2",
jid="account2@jcl.test.com")
model.db_disconnect()
self.comp.run()
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
presence_sent = self.comp.stream.sent
self.assertEqual(len(presence_sent), 5)
self.assertEqual(len([presence
for presence in presence_sent
if presence.get_to_jid() == "test1@test.com"]),
3)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"jcl.test.com"
and presence.xpath_eval("@type")[0].get_content()
== "unavailable"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account11@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account12@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(len([presence \
for presence in presence_sent
if presence.get_to_jid() == "test2@test.com"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account2@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
########################################################################### ###########################################################################
# 'time_handler' tests # 'time_handler' tests
########################################################################### ###########################################################################
def __handle_tick_test_time_handler(self): def _handle_tick_test_time_handler(self):
self.max_tick_count -= 1 self.max_tick_count -= 1
if self.max_tick_count == 0: if self.max_tick_count == 0:
self.comp.running = False self.comp.running = False
@@ -383,7 +236,7 @@ class JCLComponent_TestCase(JCLTestCase):
def test_time_handler(self): def test_time_handler(self):
self.comp.time_unit = 1 self.comp.time_unit = 1
self.max_tick_count = 1 self.max_tick_count = 1
self.comp.handle_tick = self.__handle_tick_test_time_handler self.comp.handle_tick = self._handle_tick_test_time_handler
self.comp.stream = MockStream() self.comp.stream = MockStream()
self.comp.running = True self.comp.running = True
self.comp.time_handler() self.comp.time_handler()
@@ -3054,6 +2907,187 @@ class JCLComponent_TestCase(JCLTestCase):
self.assertEquals(fields[1].children.name, "value") self.assertEquals(fields[1].children.name, "value")
self.assertEquals(fields[1].children.content, "1") self.assertEquals(fields[1].children.content, "1")
###########################################################################
# 'run' tests
###########################################################################
class JCLComponent_run_TestCase(JCLComponent_TestCase):
def __comp_run(self):
try:
self.comp.run()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def __comp_time_handler(self):
try:
self.saved_time_handler()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def test_run(self):
"""Test basic main loop execution"""
def do_nothing():
self.comp.running = False
return
self.comp.handle_tick = do_nothing
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
(result, time_to_wait) = self.comp.run()
self.assertEquals(time_to_wait, 0)
self.assertFalse(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_restart(self):
"""Test main loop execution with restart"""
def do_nothing():
self.comp.running = False
return
self.comp.handle_tick = do_nothing
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
self.comp.restart = True
(result, time_to_wait) = self.comp.run()
self.assertEquals(time_to_wait, 0)
self.assertTrue(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_connection_failed(self):
"""Test when connection to Jabber server failed"""
class MockStreamLoopFailed(MockStream):
def connect(self):
self.connection_started = True
def loop_iter(self, timeout):
self.socket = None
raise socket.error
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
self.comp.stream = MockStreamLoopFailed()
self.comp.stream_class = MockStreamLoopFailed
self.comp.restart = False
(result, time_to_wait) = self.comp.run()
self.assertEquals(time_to_wait, 5)
self.assertTrue(result)
self.assertFalse(self.comp.running)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertFalse(self.comp.stream.connection_stopped)
def test_run_unhandled_error(self):
"""Test main loop unhandled error from a component handler"""
def do_nothing():
return
self.comp.time_unit = 1
self.comp.stream = MockStreamRaiseException()
self.comp.stream_class = MockStreamRaiseException
self.comp.handle_tick = do_nothing
try:
self.comp.run()
except Exception, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_ni_handle_tick(self):
"""Test JCLComponent 'NotImplemented' error from handle_tick method"""
self.comp.time_unit = 1
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
try:
self.comp.run()
except NotImplementedError, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_go_offline(self):
"""Test main loop send offline presence when exiting"""
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
self.comp.time_unit = 1
self.max_tick_count = 1
self.comp.handle_tick = self._handle_tick_test_time_handler
model.db_connect()
user1 = User(jid="test1@test.com")
account11 = Account(user=user1,
name="account11",
jid="account11@jcl.test.com")
account12 = Account(user=user1,
name="account12",
jid="account12@jcl.test.com")
account2 = Account(user=User(jid="test2@test.com"),
name="account2",
jid="account2@jcl.test.com")
model.db_disconnect()
self.comp.run()
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
presence_sent = self.comp.stream.sent
self.assertEqual(len(presence_sent), 5)
self.assertEqual(len([presence
for presence in presence_sent
if presence.get_to_jid() == "test1@test.com"]),
3)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"jcl.test.com"
and presence.xpath_eval("@type")[0].get_content()
== "unavailable"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account11@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account12@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(len([presence \
for presence in presence_sent
if presence.get_to_jid() == "test2@test.com"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account2@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
class Handler_TestCase(JCLTestCase): class Handler_TestCase(JCLTestCase):
def setUp(self): def setUp(self):
self.handler = Handler(None) self.handler = Handler(None)
@@ -3230,6 +3264,7 @@ class AccountManager_TestCase(JCLTestCase):
def suite(): def suite():
test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
test_suite.addTest(unittest.makeSuite(JCLComponent_TestCase, 'test')) test_suite.addTest(unittest.makeSuite(JCLComponent_TestCase, 'test'))
test_suite.addTest(unittest.makeSuite(JCLComponent_run_TestCase, 'test'))
test_suite.addTest(unittest.makeSuite(Handler_TestCase, 'test')) test_suite.addTest(unittest.makeSuite(Handler_TestCase, 'test'))
test_suite.addTest(unittest.makeSuite(AccountManager_TestCase, 'test')) test_suite.addTest(unittest.makeSuite(AccountManager_TestCase, 'test'))
return test_suite return test_suite

View File

@@ -25,6 +25,7 @@ import os
import sys import sys
from ConfigParser import ConfigParser from ConfigParser import ConfigParser
from getopt import gnu_getopt from getopt import gnu_getopt
import threading
from jcl.lang import Lang from jcl.lang import Lang
from jcl.jabber.component import JCLComponent from jcl.jabber.component import JCLComponent
@@ -94,6 +95,7 @@ class JCLRunner(object):
lambda arg: self.print_help())] lambda arg: self.print_help())]
self.logger = logging.getLogger() self.logger = logging.getLogger()
self.__debug = False self.__debug = False
self.wait_event = threading.Event()
def set_attr(self, attr, value): def set_attr(self, attr, value):
setattr(self, attr, value) setattr(self, attr, value)
@@ -222,7 +224,8 @@ class JCLRunner(object):
self.component_version + " is starting ...") self.component_version + " is starting ...")
restart = True restart = True
while restart: while restart:
restart = run_func() (restart, time_to_wait) = run_func()
self.wait_event.wait(time_to_wait)
self.logger.debug(self.component_name + " is exiting") self.logger.debug(self.component_name + " is exiting")
finally: finally:
if os.path.exists(self.pid_file): if os.path.exists(self.pid_file):

View File

@@ -206,7 +206,7 @@ class JCLRunner_TestCase(unittest.TestCase):
self.has_run_func = False self.has_run_func = False
def run_func(component_self): def run_func(component_self):
self.has_run_func = True self.has_run_func = True
return False return (False, 0)
self.runner.pid_file = "/tmp/jcl.pid" self.runner.pid_file = "/tmp/jcl.pid"
db_path = tempfile.mktemp("db", "jcltest", DB_DIR) db_path = tempfile.mktemp("db", "jcltest", DB_DIR)
@@ -234,7 +234,7 @@ class JCLRunner_TestCase(unittest.TestCase):
db_url = "sqlite://" + db_path db_url = "sqlite://" + db_path
self.runner.db_url = db_url self.runner.db_url = db_url
def do_nothing(): def do_nothing():
pass return (False, 0)
self.runner._run(do_nothing) self.runner._run(do_nothing)
model.db_connect() model.db_connect()
# dropTable should succeed because tables should exist # dropTable should succeed because tables should exist
@@ -254,9 +254,9 @@ class JCLRunner_TestCase(unittest.TestCase):
self.i = 0 self.i = 0
def restart(self): def restart(self):
self.i += 1 self.i += 1
yield True yield (True, 0)
self.i += 1 self.i += 1
yield False yield (False, 0)
self.i += 1 self.i += 1
restart_generator = restart(self) restart_generator = restart(self)
self.runner._run(lambda : restart_generator.next()) self.runner._run(lambda : restart_generator.next())