From 48f562df6ca1c9700d3389ad47934347d3639505 Mon Sep 17 00:00:00 2001 From: David Rousselie Date: Tue, 19 Aug 2008 08:05:17 +0200 Subject: [PATCH] Restart component on Jabber connection failure darcs-hash:20080819060517-86b55-a7f828daab8f171e9f4b168d19aa4c288f2fc747.gz --- src/jcl/jabber/component.py | 16 +- src/jcl/jabber/tests/component.py | 345 ++++++++++++++++-------------- src/jcl/runner.py | 5 +- src/jcl/tests/runner.py | 8 +- 4 files changed, 208 insertions(+), 166 deletions(-) diff --git a/src/jcl/jabber/component.py b/src/jcl/jabber/component.py index e50ac51..adced1c 100644 --- a/src/jcl/jabber/component.py +++ b/src/jcl/jabber/component.py @@ -36,6 +36,7 @@ import re import traceback import string import time +import socket from Queue import Queue @@ -665,10 +666,10 @@ class JCLComponent(Component, object): Call Component main loop Clean up when shutting down JCLcomponent """ + self.connect() self.spool_dir += "/" + unicode(self.jid) self.running = True self.last_activity = int(time.time()) - self.connect() timer_thread = threading.Thread(target=self.time_handler, name="TimerThread") timer_thread.start() @@ -679,17 +680,20 @@ class JCLComponent(Component, object): self.stream.loop_iter(JCLComponent.timeout) if self.queue.qsize(): raise self.queue.get(0) + except socket.error, e: + self.__logger.info("Connection failed, restarting.") + return (True, 5) finally: self.running = False + timer_thread.join(JCLComponent.timeout) + self.wait_event.set() if self.stream and not self.stream.eof \ and self.stream.socket is not None: presences = self.account_manager.get_presence_all("unavailable") self.send_stanzas(presences) - self.wait_event.set() - timer_thread.join(JCLComponent.timeout) - self.disconnect() - self.__logger.debug("Exitting normally") - return self._restart + self.disconnect() + self.__logger.debug("Exitting normally") + return (self._restart, 0) def _get_restart(self): return self._restart diff --git a/src/jcl/jabber/tests/component.py b/src/jcl/jabber/tests/component.py index fd8b90c..c8888d6 100644 --- a/src/jcl/jabber/tests/component.py +++ b/src/jcl/jabber/tests/component.py @@ -29,6 +29,7 @@ import re from ConfigParser import ConfigParser import tempfile import os +import socket from pyxmpp.jid import JID from pyxmpp.iq import Iq @@ -56,8 +57,8 @@ class MockStream(object): jid="", secret="", server="", - port="", - keepalive=True): + port=1, + keepalive=None): self.sent = [] self.connection_started = False self.connection_stopped = False @@ -117,7 +118,9 @@ class MockStream(object): class MockStreamNoConnect(MockStream): def connect(self): self.connection_started = True - self.eof = True + + def loop_iter(self, timeout): + return class MockStreamRaiseException(MockStream): def loop_iter(self, timeout): @@ -222,160 +225,10 @@ class JCLComponent_TestCase(JCLTestCase): self.assertEquals(len(handler1.handled), 1) self.assertEquals(len(handler2.handled), 0) - ########################################################################### - # 'run' tests - ########################################################################### - def __comp_run(self): - try: - self.comp.run() - except: - # Ignore exception, might be obtain from self.comp.queue - pass - - def __comp_time_handler(self): - try: - self.saved_time_handler() - except: - # Ignore exception, might be obtain from self.comp.queue - pass - - def test_run(self): - """Test basic main loop execution""" - self.comp.time_unit = 1 - # Do not loop, handle_tick is virtual - # Tests in subclasses might be more precise - self.comp.stream = MockStreamNoConnect() - self.comp.stream_class = MockStreamNoConnect - result = self.comp.run() - self.assertFalse(result) - self.assertTrue(self.comp.stream.connection_started) - threads = threading.enumerate() - self.assertEquals(len(threads), 1) - self.assertTrue(self.comp.stream.connection_stopped) - if self.comp.queue.qsize(): - raise self.comp.queue.get(0) - - def test_run_restart(self): - """Test main loop execution with restart""" - self.comp.time_unit = 1 - # Do not loop, handle_tick is virtual - # Tests in subclasses might be more precise - self.comp.stream = MockStreamNoConnect() - self.comp.stream_class = MockStreamNoConnect - self.comp.restart = True - result = self.comp.run() - self.assertTrue(result) - self.assertTrue(self.comp.stream.connection_started) - threads = threading.enumerate() - self.assertEquals(len(threads), 1) - self.assertTrue(self.comp.stream.connection_stopped) - if self.comp.queue.qsize(): - raise self.comp.queue.get(0) - - def test_run_unhandled_error(self): - """Test main loop unhandled error from a component handler""" - def do_nothing(): - return - self.comp.time_unit = 1 - self.comp.stream = MockStreamRaiseException() - self.comp.stream_class = MockStreamRaiseException - self.comp.handle_tick = do_nothing - try: - self.comp.run() - except Exception, e: - threads = threading.enumerate() - self.assertEquals(len(threads), 1) - self.assertTrue(self.comp.stream.connection_stopped) - return - self.fail("No exception caught") - - def test_run_ni_handle_tick(self): - """Test JCLComponent 'NotImplemented' error from handle_tick method""" - self.comp.time_unit = 1 - self.comp.stream = MockStream() - self.comp.stream_class = MockStream - try: - self.comp.run() - except NotImplementedError, e: - threads = threading.enumerate() - self.assertEquals(len(threads), 1) - self.assertTrue(self.comp.stream.connection_stopped) - return - self.fail("No exception caught") - - def test_run_go_offline(self): - """Test main loop send offline presence when exiting""" - self.comp.stream = MockStream() - self.comp.stream_class = MockStream - self.comp.time_unit = 1 - self.max_tick_count = 1 - self.comp.handle_tick = self.__handle_tick_test_time_handler - model.db_connect() - user1 = User(jid="test1@test.com") - account11 = Account(user=user1, - name="account11", - jid="account11@jcl.test.com") - account12 = Account(user=user1, - name="account12", - jid="account12@jcl.test.com") - account2 = Account(user=User(jid="test2@test.com"), - name="account2", - jid="account2@jcl.test.com") - model.db_disconnect() - self.comp.run() - self.assertTrue(self.comp.stream.connection_started) - threads = threading.enumerate() - self.assertEquals(len(threads), 1) - self.assertTrue(self.comp.stream.connection_stopped) - if self.comp.queue.qsize(): - raise self.comp.queue.get(0) - presence_sent = self.comp.stream.sent - self.assertEqual(len(presence_sent), 5) - self.assertEqual(len([presence - for presence in presence_sent - if presence.get_to_jid() == "test1@test.com"]), - 3) - self.assertEqual(\ - len([presence - for presence in presence_sent - if presence.get_from_jid() == \ - "jcl.test.com" - and presence.xpath_eval("@type")[0].get_content() - == "unavailable"]), - 2) - self.assertEqual(\ - len([presence - for presence in presence_sent - if presence.get_from_jid() == \ - "account11@jcl.test.com" - and presence.xpath_eval("@type")[0].get_content() \ - == "unavailable"]), - 1) - self.assertEqual(\ - len([presence - for presence in presence_sent - if presence.get_from_jid() == \ - "account12@jcl.test.com" - and presence.xpath_eval("@type")[0].get_content() \ - == "unavailable"]), - 1) - self.assertEqual(len([presence \ - for presence in presence_sent - if presence.get_to_jid() == "test2@test.com"]), - 2) - self.assertEqual(\ - len([presence - for presence in presence_sent - if presence.get_from_jid() == \ - "account2@jcl.test.com" - and presence.xpath_eval("@type")[0].get_content() \ - == "unavailable"]), - 1) - ########################################################################### # 'time_handler' tests ########################################################################### - def __handle_tick_test_time_handler(self): + def _handle_tick_test_time_handler(self): self.max_tick_count -= 1 if self.max_tick_count == 0: self.comp.running = False @@ -383,7 +236,7 @@ class JCLComponent_TestCase(JCLTestCase): def test_time_handler(self): self.comp.time_unit = 1 self.max_tick_count = 1 - self.comp.handle_tick = self.__handle_tick_test_time_handler + self.comp.handle_tick = self._handle_tick_test_time_handler self.comp.stream = MockStream() self.comp.running = True self.comp.time_handler() @@ -3054,6 +2907,187 @@ class JCLComponent_TestCase(JCLTestCase): self.assertEquals(fields[1].children.name, "value") self.assertEquals(fields[1].children.content, "1") +########################################################################### +# 'run' tests +########################################################################### +class JCLComponent_run_TestCase(JCLComponent_TestCase): + def __comp_run(self): + try: + self.comp.run() + except: + # Ignore exception, might be obtain from self.comp.queue + pass + + def __comp_time_handler(self): + try: + self.saved_time_handler() + except: + # Ignore exception, might be obtain from self.comp.queue + pass + + def test_run(self): + """Test basic main loop execution""" + def do_nothing(): + self.comp.running = False + return + self.comp.handle_tick = do_nothing + self.comp.time_unit = 1 + # Do not loop, handle_tick is virtual + self.comp.stream = MockStreamNoConnect() + self.comp.stream_class = MockStreamNoConnect + (result, time_to_wait) = self.comp.run() + self.assertEquals(time_to_wait, 0) + self.assertFalse(result) + self.assertTrue(self.comp.stream.connection_started) + threads = threading.enumerate() + self.assertEquals(len(threads), 1) + self.assertTrue(self.comp.stream.connection_stopped) + if self.comp.queue.qsize(): + raise self.comp.queue.get(0) + + def test_run_restart(self): + """Test main loop execution with restart""" + def do_nothing(): + self.comp.running = False + return + self.comp.handle_tick = do_nothing + self.comp.time_unit = 1 + # Do not loop, handle_tick is virtual + self.comp.stream = MockStreamNoConnect() + self.comp.stream_class = MockStreamNoConnect + self.comp.restart = True + (result, time_to_wait) = self.comp.run() + self.assertEquals(time_to_wait, 0) + self.assertTrue(result) + self.assertTrue(self.comp.stream.connection_started) + threads = threading.enumerate() + self.assertEquals(len(threads), 1) + self.assertTrue(self.comp.stream.connection_stopped) + if self.comp.queue.qsize(): + raise self.comp.queue.get(0) + + def test_run_connection_failed(self): + """Test when connection to Jabber server failed""" + class MockStreamLoopFailed(MockStream): + def connect(self): + self.connection_started = True + def loop_iter(self, timeout): + self.socket = None + raise socket.error + self.comp.time_unit = 1 + # Do not loop, handle_tick is virtual + self.comp.stream = MockStreamLoopFailed() + self.comp.stream_class = MockStreamLoopFailed + self.comp.restart = False + (result, time_to_wait) = self.comp.run() + self.assertEquals(time_to_wait, 5) + self.assertTrue(result) + self.assertFalse(self.comp.running) + self.assertTrue(self.comp.stream.connection_started) + threads = threading.enumerate() + self.assertEquals(len(threads), 1) + self.assertFalse(self.comp.stream.connection_stopped) + + def test_run_unhandled_error(self): + """Test main loop unhandled error from a component handler""" + def do_nothing(): + return + self.comp.time_unit = 1 + self.comp.stream = MockStreamRaiseException() + self.comp.stream_class = MockStreamRaiseException + self.comp.handle_tick = do_nothing + try: + self.comp.run() + except Exception, e: + threads = threading.enumerate() + self.assertEquals(len(threads), 1) + self.assertTrue(self.comp.stream.connection_stopped) + return + self.fail("No exception caught") + + def test_run_ni_handle_tick(self): + """Test JCLComponent 'NotImplemented' error from handle_tick method""" + self.comp.time_unit = 1 + self.comp.stream = MockStream() + self.comp.stream_class = MockStream + try: + self.comp.run() + except NotImplementedError, e: + threads = threading.enumerate() + self.assertEquals(len(threads), 1) + self.assertTrue(self.comp.stream.connection_stopped) + return + self.fail("No exception caught") + + def test_run_go_offline(self): + """Test main loop send offline presence when exiting""" + self.comp.stream = MockStream() + self.comp.stream_class = MockStream + self.comp.time_unit = 1 + self.max_tick_count = 1 + self.comp.handle_tick = self._handle_tick_test_time_handler + model.db_connect() + user1 = User(jid="test1@test.com") + account11 = Account(user=user1, + name="account11", + jid="account11@jcl.test.com") + account12 = Account(user=user1, + name="account12", + jid="account12@jcl.test.com") + account2 = Account(user=User(jid="test2@test.com"), + name="account2", + jid="account2@jcl.test.com") + model.db_disconnect() + self.comp.run() + self.assertTrue(self.comp.stream.connection_started) + threads = threading.enumerate() + self.assertEquals(len(threads), 1) + self.assertTrue(self.comp.stream.connection_stopped) + if self.comp.queue.qsize(): + raise self.comp.queue.get(0) + presence_sent = self.comp.stream.sent + self.assertEqual(len(presence_sent), 5) + self.assertEqual(len([presence + for presence in presence_sent + if presence.get_to_jid() == "test1@test.com"]), + 3) + self.assertEqual(\ + len([presence + for presence in presence_sent + if presence.get_from_jid() == \ + "jcl.test.com" + and presence.xpath_eval("@type")[0].get_content() + == "unavailable"]), + 2) + self.assertEqual(\ + len([presence + for presence in presence_sent + if presence.get_from_jid() == \ + "account11@jcl.test.com" + and presence.xpath_eval("@type")[0].get_content() \ + == "unavailable"]), + 1) + self.assertEqual(\ + len([presence + for presence in presence_sent + if presence.get_from_jid() == \ + "account12@jcl.test.com" + and presence.xpath_eval("@type")[0].get_content() \ + == "unavailable"]), + 1) + self.assertEqual(len([presence \ + for presence in presence_sent + if presence.get_to_jid() == "test2@test.com"]), + 2) + self.assertEqual(\ + len([presence + for presence in presence_sent + if presence.get_from_jid() == \ + "account2@jcl.test.com" + and presence.xpath_eval("@type")[0].get_content() \ + == "unavailable"]), + 1) + class Handler_TestCase(JCLTestCase): def setUp(self): self.handler = Handler(None) @@ -3230,6 +3264,7 @@ class AccountManager_TestCase(JCLTestCase): def suite(): test_suite = unittest.TestSuite() test_suite.addTest(unittest.makeSuite(JCLComponent_TestCase, 'test')) + test_suite.addTest(unittest.makeSuite(JCLComponent_run_TestCase, 'test')) test_suite.addTest(unittest.makeSuite(Handler_TestCase, 'test')) test_suite.addTest(unittest.makeSuite(AccountManager_TestCase, 'test')) return test_suite diff --git a/src/jcl/runner.py b/src/jcl/runner.py index f7c507f..19c75c7 100644 --- a/src/jcl/runner.py +++ b/src/jcl/runner.py @@ -25,6 +25,7 @@ import os import sys from ConfigParser import ConfigParser from getopt import gnu_getopt +import threading from jcl.lang import Lang from jcl.jabber.component import JCLComponent @@ -94,6 +95,7 @@ class JCLRunner(object): lambda arg: self.print_help())] self.logger = logging.getLogger() self.__debug = False + self.wait_event = threading.Event() def set_attr(self, attr, value): setattr(self, attr, value) @@ -222,7 +224,8 @@ class JCLRunner(object): self.component_version + " is starting ...") restart = True while restart: - restart = run_func() + (restart, time_to_wait) = run_func() + self.wait_event.wait(time_to_wait) self.logger.debug(self.component_name + " is exiting") finally: if os.path.exists(self.pid_file): diff --git a/src/jcl/tests/runner.py b/src/jcl/tests/runner.py index 3d3df2e..663f29e 100644 --- a/src/jcl/tests/runner.py +++ b/src/jcl/tests/runner.py @@ -206,7 +206,7 @@ class JCLRunner_TestCase(unittest.TestCase): self.has_run_func = False def run_func(component_self): self.has_run_func = True - return False + return (False, 0) self.runner.pid_file = "/tmp/jcl.pid" db_path = tempfile.mktemp("db", "jcltest", DB_DIR) @@ -234,7 +234,7 @@ class JCLRunner_TestCase(unittest.TestCase): db_url = "sqlite://" + db_path self.runner.db_url = db_url def do_nothing(): - pass + return (False, 0) self.runner._run(do_nothing) model.db_connect() # dropTable should succeed because tables should exist @@ -254,9 +254,9 @@ class JCLRunner_TestCase(unittest.TestCase): self.i = 0 def restart(self): self.i += 1 - yield True + yield (True, 0) self.i += 1 - yield False + yield (False, 0) self.i += 1 restart_generator = restart(self) self.runner._run(lambda : restart_generator.next())