Restart component on Jabber connection failure

darcs-hash:20080819060517-86b55-a7f828daab8f171e9f4b168d19aa4c288f2fc747.gz
This commit is contained in:
David Rousselie
2008-08-19 08:05:17 +02:00
parent 113c242f05
commit 48f562df6c
4 changed files with 208 additions and 166 deletions

View File

@@ -36,6 +36,7 @@ import re
import traceback
import string
import time
import socket
from Queue import Queue
@@ -665,10 +666,10 @@ class JCLComponent(Component, object):
Call Component main loop
Clean up when shutting down JCLcomponent
"""
self.connect()
self.spool_dir += "/" + unicode(self.jid)
self.running = True
self.last_activity = int(time.time())
self.connect()
timer_thread = threading.Thread(target=self.time_handler,
name="TimerThread")
timer_thread.start()
@@ -679,17 +680,20 @@ class JCLComponent(Component, object):
self.stream.loop_iter(JCLComponent.timeout)
if self.queue.qsize():
raise self.queue.get(0)
except socket.error, e:
self.__logger.info("Connection failed, restarting.")
return (True, 5)
finally:
self.running = False
timer_thread.join(JCLComponent.timeout)
self.wait_event.set()
if self.stream and not self.stream.eof \
and self.stream.socket is not None:
presences = self.account_manager.get_presence_all("unavailable")
self.send_stanzas(presences)
self.wait_event.set()
timer_thread.join(JCLComponent.timeout)
self.disconnect()
self.__logger.debug("Exitting normally")
return self._restart
self.disconnect()
self.__logger.debug("Exitting normally")
return (self._restart, 0)
def _get_restart(self):
return self._restart

View File

@@ -29,6 +29,7 @@ import re
from ConfigParser import ConfigParser
import tempfile
import os
import socket
from pyxmpp.jid import JID
from pyxmpp.iq import Iq
@@ -56,8 +57,8 @@ class MockStream(object):
jid="",
secret="",
server="",
port="",
keepalive=True):
port=1,
keepalive=None):
self.sent = []
self.connection_started = False
self.connection_stopped = False
@@ -117,7 +118,9 @@ class MockStream(object):
class MockStreamNoConnect(MockStream):
def connect(self):
self.connection_started = True
self.eof = True
def loop_iter(self, timeout):
return
class MockStreamRaiseException(MockStream):
def loop_iter(self, timeout):
@@ -222,160 +225,10 @@ class JCLComponent_TestCase(JCLTestCase):
self.assertEquals(len(handler1.handled), 1)
self.assertEquals(len(handler2.handled), 0)
###########################################################################
# 'run' tests
###########################################################################
def __comp_run(self):
try:
self.comp.run()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def __comp_time_handler(self):
try:
self.saved_time_handler()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def test_run(self):
"""Test basic main loop execution"""
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
# Tests in subclasses might be more precise
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
result = self.comp.run()
self.assertFalse(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_restart(self):
"""Test main loop execution with restart"""
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
# Tests in subclasses might be more precise
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
self.comp.restart = True
result = self.comp.run()
self.assertTrue(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_unhandled_error(self):
"""Test main loop unhandled error from a component handler"""
def do_nothing():
return
self.comp.time_unit = 1
self.comp.stream = MockStreamRaiseException()
self.comp.stream_class = MockStreamRaiseException
self.comp.handle_tick = do_nothing
try:
self.comp.run()
except Exception, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_ni_handle_tick(self):
"""Test JCLComponent 'NotImplemented' error from handle_tick method"""
self.comp.time_unit = 1
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
try:
self.comp.run()
except NotImplementedError, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_go_offline(self):
"""Test main loop send offline presence when exiting"""
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
self.comp.time_unit = 1
self.max_tick_count = 1
self.comp.handle_tick = self.__handle_tick_test_time_handler
model.db_connect()
user1 = User(jid="test1@test.com")
account11 = Account(user=user1,
name="account11",
jid="account11@jcl.test.com")
account12 = Account(user=user1,
name="account12",
jid="account12@jcl.test.com")
account2 = Account(user=User(jid="test2@test.com"),
name="account2",
jid="account2@jcl.test.com")
model.db_disconnect()
self.comp.run()
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
presence_sent = self.comp.stream.sent
self.assertEqual(len(presence_sent), 5)
self.assertEqual(len([presence
for presence in presence_sent
if presence.get_to_jid() == "test1@test.com"]),
3)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"jcl.test.com"
and presence.xpath_eval("@type")[0].get_content()
== "unavailable"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account11@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account12@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(len([presence \
for presence in presence_sent
if presence.get_to_jid() == "test2@test.com"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account2@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
###########################################################################
# 'time_handler' tests
###########################################################################
def __handle_tick_test_time_handler(self):
def _handle_tick_test_time_handler(self):
self.max_tick_count -= 1
if self.max_tick_count == 0:
self.comp.running = False
@@ -383,7 +236,7 @@ class JCLComponent_TestCase(JCLTestCase):
def test_time_handler(self):
self.comp.time_unit = 1
self.max_tick_count = 1
self.comp.handle_tick = self.__handle_tick_test_time_handler
self.comp.handle_tick = self._handle_tick_test_time_handler
self.comp.stream = MockStream()
self.comp.running = True
self.comp.time_handler()
@@ -3054,6 +2907,187 @@ class JCLComponent_TestCase(JCLTestCase):
self.assertEquals(fields[1].children.name, "value")
self.assertEquals(fields[1].children.content, "1")
###########################################################################
# 'run' tests
###########################################################################
class JCLComponent_run_TestCase(JCLComponent_TestCase):
def __comp_run(self):
try:
self.comp.run()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def __comp_time_handler(self):
try:
self.saved_time_handler()
except:
# Ignore exception, might be obtain from self.comp.queue
pass
def test_run(self):
"""Test basic main loop execution"""
def do_nothing():
self.comp.running = False
return
self.comp.handle_tick = do_nothing
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
(result, time_to_wait) = self.comp.run()
self.assertEquals(time_to_wait, 0)
self.assertFalse(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_restart(self):
"""Test main loop execution with restart"""
def do_nothing():
self.comp.running = False
return
self.comp.handle_tick = do_nothing
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
self.comp.stream = MockStreamNoConnect()
self.comp.stream_class = MockStreamNoConnect
self.comp.restart = True
(result, time_to_wait) = self.comp.run()
self.assertEquals(time_to_wait, 0)
self.assertTrue(result)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
def test_run_connection_failed(self):
"""Test when connection to Jabber server failed"""
class MockStreamLoopFailed(MockStream):
def connect(self):
self.connection_started = True
def loop_iter(self, timeout):
self.socket = None
raise socket.error
self.comp.time_unit = 1
# Do not loop, handle_tick is virtual
self.comp.stream = MockStreamLoopFailed()
self.comp.stream_class = MockStreamLoopFailed
self.comp.restart = False
(result, time_to_wait) = self.comp.run()
self.assertEquals(time_to_wait, 5)
self.assertTrue(result)
self.assertFalse(self.comp.running)
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertFalse(self.comp.stream.connection_stopped)
def test_run_unhandled_error(self):
"""Test main loop unhandled error from a component handler"""
def do_nothing():
return
self.comp.time_unit = 1
self.comp.stream = MockStreamRaiseException()
self.comp.stream_class = MockStreamRaiseException
self.comp.handle_tick = do_nothing
try:
self.comp.run()
except Exception, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_ni_handle_tick(self):
"""Test JCLComponent 'NotImplemented' error from handle_tick method"""
self.comp.time_unit = 1
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
try:
self.comp.run()
except NotImplementedError, e:
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
return
self.fail("No exception caught")
def test_run_go_offline(self):
"""Test main loop send offline presence when exiting"""
self.comp.stream = MockStream()
self.comp.stream_class = MockStream
self.comp.time_unit = 1
self.max_tick_count = 1
self.comp.handle_tick = self._handle_tick_test_time_handler
model.db_connect()
user1 = User(jid="test1@test.com")
account11 = Account(user=user1,
name="account11",
jid="account11@jcl.test.com")
account12 = Account(user=user1,
name="account12",
jid="account12@jcl.test.com")
account2 = Account(user=User(jid="test2@test.com"),
name="account2",
jid="account2@jcl.test.com")
model.db_disconnect()
self.comp.run()
self.assertTrue(self.comp.stream.connection_started)
threads = threading.enumerate()
self.assertEquals(len(threads), 1)
self.assertTrue(self.comp.stream.connection_stopped)
if self.comp.queue.qsize():
raise self.comp.queue.get(0)
presence_sent = self.comp.stream.sent
self.assertEqual(len(presence_sent), 5)
self.assertEqual(len([presence
for presence in presence_sent
if presence.get_to_jid() == "test1@test.com"]),
3)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"jcl.test.com"
and presence.xpath_eval("@type")[0].get_content()
== "unavailable"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account11@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account12@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
self.assertEqual(len([presence \
for presence in presence_sent
if presence.get_to_jid() == "test2@test.com"]),
2)
self.assertEqual(\
len([presence
for presence in presence_sent
if presence.get_from_jid() == \
"account2@jcl.test.com"
and presence.xpath_eval("@type")[0].get_content() \
== "unavailable"]),
1)
class Handler_TestCase(JCLTestCase):
def setUp(self):
self.handler = Handler(None)
@@ -3230,6 +3264,7 @@ class AccountManager_TestCase(JCLTestCase):
def suite():
test_suite = unittest.TestSuite()
test_suite.addTest(unittest.makeSuite(JCLComponent_TestCase, 'test'))
test_suite.addTest(unittest.makeSuite(JCLComponent_run_TestCase, 'test'))
test_suite.addTest(unittest.makeSuite(Handler_TestCase, 'test'))
test_suite.addTest(unittest.makeSuite(AccountManager_TestCase, 'test'))
return test_suite

View File

@@ -25,6 +25,7 @@ import os
import sys
from ConfigParser import ConfigParser
from getopt import gnu_getopt
import threading
from jcl.lang import Lang
from jcl.jabber.component import JCLComponent
@@ -94,6 +95,7 @@ class JCLRunner(object):
lambda arg: self.print_help())]
self.logger = logging.getLogger()
self.__debug = False
self.wait_event = threading.Event()
def set_attr(self, attr, value):
setattr(self, attr, value)
@@ -222,7 +224,8 @@ class JCLRunner(object):
self.component_version + " is starting ...")
restart = True
while restart:
restart = run_func()
(restart, time_to_wait) = run_func()
self.wait_event.wait(time_to_wait)
self.logger.debug(self.component_name + " is exiting")
finally:
if os.path.exists(self.pid_file):

View File

@@ -206,7 +206,7 @@ class JCLRunner_TestCase(unittest.TestCase):
self.has_run_func = False
def run_func(component_self):
self.has_run_func = True
return False
return (False, 0)
self.runner.pid_file = "/tmp/jcl.pid"
db_path = tempfile.mktemp("db", "jcltest", DB_DIR)
@@ -234,7 +234,7 @@ class JCLRunner_TestCase(unittest.TestCase):
db_url = "sqlite://" + db_path
self.runner.db_url = db_url
def do_nothing():
pass
return (False, 0)
self.runner._run(do_nothing)
model.db_connect()
# dropTable should succeed because tables should exist
@@ -254,9 +254,9 @@ class JCLRunner_TestCase(unittest.TestCase):
self.i = 0
def restart(self):
self.i += 1
yield True
yield (True, 0)
self.i += 1
yield False
yield (False, 0)
self.i += 1
restart_generator = restart(self)
self.runner._run(lambda : restart_generator.next())