From 9e28468b81b687a4029580f88f61220945cf8d7f Mon Sep 17 00:00:00 2001 From: David Rousselie Date: Mon, 9 Oct 2006 19:26:08 +0200 Subject: [PATCH] Add per thread SQLObject DBConnection make associated tests pass SQLObject multi-thread access ***END OF DESCRIPTION*** Place the long patch description above the ***END OF DESCRIPTION*** marker. The first line of this file will be the patch name. This patch contains the following changes: M ./run_tests.py -1 +6 M ./src/jcl/jabber/component.py -30 +70 M ./src/jcl/jabber/feeder.py -5 +11 M ./src/jcl/model/account.py +6 M ./tests/jcl/jabber/test_component.py -19 +62 M ./tests/jcl/jabber/test_feeder.py -7 +45 darcs-hash:20061009172608-86b55-804b8910c5ff19414a4f016289d3f03fb73866f9.gz --- run_tests.py | 7 +- src/jcl/jabber/component.py | 100 ++++++++++++++++++++--------- src/jcl/jabber/feeder.py | 16 +++-- src/jcl/model/account.py | 6 ++ tests/jcl/jabber/test_component.py | 81 +++++++++++++++++------ tests/jcl/jabber/test_feeder.py | 52 +++++++++++++-- 6 files changed, 200 insertions(+), 62 deletions(-) diff --git a/run_tests.py b/run_tests.py index 1debef5..7ed304a 100644 --- a/run_tests.py +++ b/run_tests.py @@ -49,13 +49,18 @@ if __name__ == '__main__': feeder_component_suite = unittest.makeSuite(FeederComponent_TestCase, "test") feeder_suite = unittest.makeSuite(Feeder_TestCase, "test") sender_suite = unittest.makeSuite(Sender_TestCase, "test") - +# jcl_suite = unittest.TestSuite() +# jcl_suite.addTest(FeederComponent_TestCase('test_handle_tick')) +# jcl_suite.addTest(FeederComponent_TestCase('test_run')) +# jcl_suite = unittest.TestSuite((feeder_component_suite)) +# jcl_suite = unittest.TestSuite((component_suite)) jcl_suite = unittest.TestSuite((component_suite, feeder_component_suite, feeder_suite, sender_suite)) test_support.run_suite(jcl_suite) + coverage.stop() coverage.analysis(jcl.jabber.component) coverage.analysis(jcl.jabber.feeder) diff --git a/src/jcl/jabber/component.py b/src/jcl/jabber/component.py index 0c5fc7b..0ad28c4 100644 --- a/src/jcl/jabber/component.py +++ b/src/jcl/jabber/component.py @@ -26,21 +26,27 @@ __revision__ = "$Id: component.py,v 1.3 2005/09/18 20:24:07 dax Exp $" -import thread +import sys + import threading import time import logging import signal import re +from Queue import Queue + +from sqlobject.dbconnection import connectionForURI + from pyxmpp.jid import JID from pyxmpp.jabberd.component import Component -from pyxmpp.jabber.disco import DiscoInfo, DiscoItems +from pyxmpp.jabber.disco import DiscoInfo, DiscoItems, DiscoItem from pyxmpp.message import Message from pyxmpp.presence import Presence from pyxmpp.streambase import StreamError, FatalStreamError from jcl.jabber.x import X +from jcl.model import account from jcl.model.account import Account VERSION = "0.1" @@ -57,12 +63,28 @@ class JCLComponent(Component): """ timeout = 1 + + def set_account_class(self, account_class): + """account_class attribut setter + create associated table via SQLObject""" + self.__account_class = account_class + self.db_connect() + self.__account_class.createTable() # TODO: ifNotExists = True) + self.db_disconnect() + + def get_account_class(self): + """account_class attribut getter""" + return self.__account_class + + account_class = property(get_account_class, set_account_class) + def __init__(self, jid, secret, server, port, + db_connection_str, disco_category = "gateway", disco_type = "headline"): Component.__init__(self, \ @@ -75,8 +97,9 @@ class JCLComponent(Component): # default values self.name = "Jabber Component Library generic component" self.spool_dir = "." + self.db_connection_str = db_connection_str self.__account_class = None - self.account_class = Account + self.set_account_class(Account) self.version = VERSION self.accounts = [] @@ -89,15 +112,6 @@ class JCLComponent(Component): signal.signal(signal.SIGINT, self.signal_handler) signal.signal(signal.SIGTERM, self.signal_handler) - - def set_account_class(self, account_class): - self.__account_class = account_class - self.__account_class.createTable(ifNotExists = True) - - def get_account_class(self): - return self.__account_class - - account_class = property(get_account_class, set_account_class) def run(self): """Main loop @@ -109,7 +123,12 @@ class JCLComponent(Component): self.spool_dir += "/" + str(self.jid) self.running = True self.connect() - thread.start_new_thread(self.time_handler, ()) + ## TODO : workaround to make test_run pass on FeederComponent +# time.sleep(1) + ## + timer_thread = threading.Thread(target = self.time_handler, \ + name = "TimerThread") + timer_thread.start() try: while (self.running and self.stream and not self.stream.eof and self.stream.socket is not None): @@ -134,21 +153,34 @@ class JCLComponent(Component): # to_jid = jid, \ # stanza_type = "unavailable") # self.stream.send(p) - threads = threading.enumerate() - for _thread in threads: - try: - _thread.join(10 * JCLComponent.timeout) - except: - pass - for _thread in threads: - try: - _thread.join(JCLComponent.timeout) - except: - pass +# threads = threading.enumerate() + timer_thread.join(JCLComponent.timeout) +# for _thread in threads: +# try: +# _thread.join(10 * JCLComponent.timeout) +# except: +# pass +# for _thread in threads: +# try: +# _thread.join(JCLComponent.timeout) +# except: +# pass self.disconnect() # TODO : terminate SQLObject self.__logger.debug("Exitting normally") + # TODO : terminate SQLObject + ################# + # SQlite connections are not multi-threaded + # Utils workaround methods + ################# + def db_connect(self): + account.hub.threadConnection = \ + connectionForURI(self.db_connection_str) + + def db_disconnect(self): +# account.hub.threadConnection.close() + del account.hub.threadConnection ########################################################################### @@ -198,6 +230,7 @@ class JCLComponent(Component): self.stream.set_message_handler("normal", \ self.handle_message) current_jid = None + self.db_connect() for account in self.account_class.select(orderBy = "user_jid"): if account.user_jid != current_jid: presence = Presence(from_jid = unicode(self.jid), \ @@ -209,6 +242,7 @@ class JCLComponent(Component): to_jid = account.user_jid, \ stanza_type = "probe") self.stream.send(presence) + self.db_disconnect() def signal_handler(self, signum, frame): """Stop method handler @@ -232,15 +266,19 @@ class JCLComponent(Component): """ self.__logger.debug("DISCO_GET_ITEMS") ## TODO Lang -## lang_class = self.__lang.get_lang_class_from_node(input_query.get_node()) -## base_from_jid = unicode(input_query.get_from().bare()) +## lang_class = self.__lang.get_lang_class_from_node(info_query.get_node()) + base_from_jid = unicode(info_query.get_from().bare()) disco_items = DiscoItems() if not node: ## TODO : list accounts - for account in self.accounts: + self.db_connect() + for account in self.account_class.select(Account.q.user_jid == \ + base_from_jid): self.__logger.debug(str(account)) -## DiscoItem(di, JID(name + "@" + unicode(self.jid)), \ -## name, str_name) + DiscoItem(disco_items, \ + JID(account.jid), \ + account.name, account.long_name) + self.db_disconnect() return disco_items def handle_get_version(self, input_query): @@ -266,9 +304,11 @@ class JCLComponent(Component): input_query = input_query.make_result_response() query = input_query.new_query("jabber:iq:register") if to_jid and to_jid != self.jid: + self.db_connect() self.get_reg_form_init(lang_class, \ - self.accounts.select() # TODO + self.account_class.select() # TODO ).attach_xml(query) + self.db_disconnect() else: self.get_reg_form(lang_class).attach_xml(query) self.stream.send(input_query) diff --git a/src/jcl/jabber/feeder.py b/src/jcl/jabber/feeder.py index 3fb5cd9..8d6b4c7 100644 --- a/src/jcl/jabber/feeder.py +++ b/src/jcl/jabber/feeder.py @@ -40,12 +40,14 @@ class FeederComponent(JCLComponent): jid, secret, server, - port): + port, + db_connection_str): JCLComponent.__init__(self, \ jid, \ secret, \ server, \ - port) + port, \ + db_connection_str) self.name = "Generic Feeder Component" # Define default feeder and sender, can be override self.feeder = Feeder() @@ -56,9 +58,13 @@ class FeederComponent(JCLComponent): def handle_tick(self): """Implement main feed/send behavior""" - for account in Account.select(): - for data in self.feeder.feed(account): - self.sender.send(account, data) + pass + self.db_connect() + for acc in self.account_class.select(): + print "OK" +# for data in self.feeder.feed(account): +# self.sender.send(account, data) + self.db_disconnect() diff --git a/src/jcl/model/account.py b/src/jcl/model/account.py index 5c9bf3a..6883333 100644 --- a/src/jcl/model/account.py +++ b/src/jcl/model/account.py @@ -28,9 +28,15 @@ __revision__ = "$Id: account.py,v 1.3 2005/09/18 20:24:07 dax Exp $" from sqlobject.main import SQLObject from sqlobject.col import StringCol +from sqlobject.dbconnection import ConnectionHub + +# create a hub to attach a per thread connection +hub = ConnectionHub() class Account(SQLObject): """Base Account class""" + _cacheValue = False + _connection = hub user_jid = StringCol() name = StringCol() jid = StringCol() diff --git a/tests/jcl/jabber/test_component.py b/tests/jcl/jabber/test_component.py index a4400e0..f6dd99c 100644 --- a/tests/jcl/jabber/test_component.py +++ b/tests/jcl/jabber/test_component.py @@ -26,18 +26,32 @@ import unittest import thread import threading import time +import sys +import os from sqlobject import * +from sqlobject.dbconnection import TheURIOpener from jcl.jabber.component import JCLComponent +from jcl.model import account from jcl.model.account import Account from jcl.lang import Lang +DB_PATH = "/tmp/test.db" +DB_URL = DB_PATH# + "?debug=1&debugThreading=1" + class MockStream(object): - def __init__(self): + def __init__(self, \ + jid = "", + secret = "", + server = "", + port = "", + keepalive = True): self.sended = [] self.connection_started = False - self.connection_stoped = False + self.connection_stopped = False + self.eof = False + self.socket = [] def send(self, iq): self.sended.append(iq) @@ -77,43 +91,65 @@ class MockStream(object): self.connection_started = True def disconnect(self): - self.connection_stoped = True + self.connection_stopped = True def loop_iter(self, timeout): time.sleep(timeout) - + + def close(self): + pass + class JCLComponent_TestCase(unittest.TestCase): def setUp(self): - connection = sqlhub.processConnection = connectionForURI('sqlite:/:memory:') + if os.path.exists(DB_PATH): + os.unlink(DB_PATH) self.comp = JCLComponent("jcl.test.com", "password", "localhost", - "5347") + "5347", + 'sqlite://' + DB_URL) + self.max_tick_count = 2 def tearDown(self): + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) Account.dropTable(ifExists = True) + del TheURIOpener.cachedURIs['sqlite://' + DB_URL] + account.hub.threadConnection.close() + del account.hub.threadConnection + if os.path.exists(DB_PATH): + os.unlink(DB_PATH) def test_constructor(self): + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) self.assertTrue(Account._connection.tableExists("account")) + if os.path.exists(DB_PATH): + print DB_PATH + " exists cons" + del account.hub.threadConnection def test_run(self): self.comp.stream = MockStream() - run_thread = thread.start_new_thread(self.comp.run, ()) + self.comp.stream_class = MockStream + run_thread = threading.Thread(target = self.comp.run, \ + name = "run_thread") + run_thread.start() + time.sleep(1) self.assertTrue(self.comp.stream.connection_started) self.comp.running = False time.sleep(JCLComponent.timeout + 1) threads = threading.enumerate() - self.assertNone(threads) - for _thread in threads: - try: - _thread.join(1) - except: - pass - self.assertTrue(self.comp.connection_stoped) + self.assertEquals(len(threads), 1) + self.assertTrue(self.comp.stream.connection_stopped) + if self.comp.queue.qsize(): + raise self.comp.queue.get(0) def test_run_go_offline(self): ## TODO : verify offline stanza are sent pass + + def __handle_tick_test_time_handler(self): + self.max_tick_count -= 1 + if self.max_tick_count == 0: + self.comp.running = False def test_authenticated_handler(self): self.comp.stream = MockStream() @@ -121,12 +157,17 @@ class JCLComponent_TestCase(unittest.TestCase): self.assertTrue(True) def test_authenticated_send_probe(self): + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) account11 = Account(user_jid = "test1@test.com", \ - name = "test11") + name = "test11", \ + jid = "account11@jcl.test.com") account12 = Account(user_jid = "test1@test.com", \ - name = "test12") + name = "test12", \ + jid = "account12@jcl.test.com") account2 = Account(user_jid = "test2@test.com", \ - name = "test2") + name = "test2", \ + jid = "account2@jcl.test.com") + del account.hub.threadConnection self.comp.stream = stream = MockStream() self.comp.authenticated() @@ -150,8 +191,10 @@ class JCLComponent_TestCase(unittest.TestCase): self.assertTrue(True) def test_get_reg_form_init(self): - account = Account(user_jid = "", name = "") - self.comp.get_reg_form_init(Lang.en, account) + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) + account1 = Account(user_jid = "", name = "", jid = "") + del account.hub.threadConnection + self.comp.get_reg_form_init(Lang.en, account1) self.assertTrue(True) def test_disco_get_info(self): diff --git a/tests/jcl/jabber/test_feeder.py b/tests/jcl/jabber/test_feeder.py index 52005c9..65de583 100644 --- a/tests/jcl/jabber/test_feeder.py +++ b/tests/jcl/jabber/test_feeder.py @@ -22,42 +22,80 @@ ## import unittest +import os + from sqlobject import * +from sqlobject.dbconnection import TheURIOpener from tests.jcl.jabber.test_component import JCLComponent_TestCase -from jcl.jabber.feeder import Feeder, Sender +from jcl.jabber.feeder import FeederComponent, Feeder, Sender from jcl.model.account import Account +from jcl.model import account + +DB_PATH = "/tmp/test.db" +DB_URL = DB_PATH #+ "?debug=1&debugThreading=1" class FeederComponent_TestCase(JCLComponent_TestCase): - pass + def setUp(self): + if os.path.exists(DB_PATH): + os.unlink(DB_PATH) + self.comp = FeederComponent("jcl.test.com", + "password", + "localhost", + "5347", + 'sqlite://' + DB_URL) + + def tearDown(self): + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) + Account.dropTable(ifExists = True) + del TheURIOpener.cachedURIs['sqlite://' + DB_URL] + account.hub.threadConnection.close() + del account.hub.threadConnection + if os.path.exists(DB_PATH): + os.unlink(DB_PATH) + + def test_constructor(self): + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) + self.assertTrue(Account._connection.tableExists("account")) + del account.hub.threadConnection class Feeder_TestCase(unittest.TestCase): def setUp(self): - connection = sqlhub.processConnection = connectionForURI('sqlite:/:memory:') + if os.path.exists(DB_PATH): + os.unlink(DB_PATH) + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) Account.createTable() def tearDown(self): Account.dropTable(ifExists = True) + del account.hub.threadConnection +# os.unlink(DB_PATH) def test_feed_exist(self): feeder = Feeder() feeder.feed(Account(user_jid = "test@test.com", \ - name = "test")) + name = "test", \ + jid = "test@jcl.test.com")) self.assertTrue(True) class Sender_TestCase(unittest.TestCase): def setUp(self): - connection = sqlhub.processConnection = connectionForURI('sqlite:/:memory:') + if os.path.exists(DB_PATH): + os.unlink(DB_PATH) + account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL) Account.createTable() def tearDown(self): Account.dropTable(ifExists = True) - + del account.hub.threadConnection +# os.unlink(DB_PATH) + def test_send_exist(self): sender = Sender() account = Account(user_jid = "test@test.com", \ - name = "test") + name = "test", \ + jid = "test@jcl.test.com") sender.send(to_account = account, \ message = "Hello World") self.assertTrue(True)