Add per thread SQLObject DBConnection

make associated tests pass
SQLObject multi-thread access ***END OF DESCRIPTION***

Place the long patch description above the ***END OF DESCRIPTION*** marker.
The first line of this file will be the patch name.

This patch contains the following changes:

M ./run_tests.py -1 +6
M ./src/jcl/jabber/component.py -30 +70
M ./src/jcl/jabber/feeder.py -5 +11
M ./src/jcl/model/account.py +6
M ./tests/jcl/jabber/test_component.py -19 +62
M ./tests/jcl/jabber/test_feeder.py -7 +45

darcs-hash:20061009172608-86b55-804b8910c5ff19414a4f016289d3f03fb73866f9.gz
This commit is contained in:
David Rousselie
2006-10-09 19:26:08 +02:00
parent d505e65972
commit 9e28468b81
6 changed files with 200 additions and 62 deletions

View File

@@ -49,13 +49,18 @@ if __name__ == '__main__':
feeder_component_suite = unittest.makeSuite(FeederComponent_TestCase, "test") feeder_component_suite = unittest.makeSuite(FeederComponent_TestCase, "test")
feeder_suite = unittest.makeSuite(Feeder_TestCase, "test") feeder_suite = unittest.makeSuite(Feeder_TestCase, "test")
sender_suite = unittest.makeSuite(Sender_TestCase, "test") sender_suite = unittest.makeSuite(Sender_TestCase, "test")
# jcl_suite = unittest.TestSuite()
# jcl_suite.addTest(FeederComponent_TestCase('test_handle_tick'))
# jcl_suite.addTest(FeederComponent_TestCase('test_run'))
# jcl_suite = unittest.TestSuite((feeder_component_suite))
# jcl_suite = unittest.TestSuite((component_suite))
jcl_suite = unittest.TestSuite((component_suite, jcl_suite = unittest.TestSuite((component_suite,
feeder_component_suite, feeder_component_suite,
feeder_suite, feeder_suite,
sender_suite)) sender_suite))
test_support.run_suite(jcl_suite) test_support.run_suite(jcl_suite)
coverage.stop() coverage.stop()
coverage.analysis(jcl.jabber.component) coverage.analysis(jcl.jabber.component)
coverage.analysis(jcl.jabber.feeder) coverage.analysis(jcl.jabber.feeder)

View File

@@ -26,21 +26,27 @@
__revision__ = "$Id: component.py,v 1.3 2005/09/18 20:24:07 dax Exp $" __revision__ = "$Id: component.py,v 1.3 2005/09/18 20:24:07 dax Exp $"
import thread import sys
import threading import threading
import time import time
import logging import logging
import signal import signal
import re import re
from Queue import Queue
from sqlobject.dbconnection import connectionForURI
from pyxmpp.jid import JID from pyxmpp.jid import JID
from pyxmpp.jabberd.component import Component from pyxmpp.jabberd.component import Component
from pyxmpp.jabber.disco import DiscoInfo, DiscoItems from pyxmpp.jabber.disco import DiscoInfo, DiscoItems, DiscoItem
from pyxmpp.message import Message from pyxmpp.message import Message
from pyxmpp.presence import Presence from pyxmpp.presence import Presence
from pyxmpp.streambase import StreamError, FatalStreamError from pyxmpp.streambase import StreamError, FatalStreamError
from jcl.jabber.x import X from jcl.jabber.x import X
from jcl.model import account
from jcl.model.account import Account from jcl.model.account import Account
VERSION = "0.1" VERSION = "0.1"
@@ -57,12 +63,28 @@ class JCLComponent(Component):
""" """
timeout = 1 timeout = 1
def set_account_class(self, account_class):
"""account_class attribut setter
create associated table via SQLObject"""
self.__account_class = account_class
self.db_connect()
self.__account_class.createTable() # TODO: ifNotExists = True)
self.db_disconnect()
def get_account_class(self):
"""account_class attribut getter"""
return self.__account_class
account_class = property(get_account_class, set_account_class)
def __init__(self, def __init__(self,
jid, jid,
secret, secret,
server, server,
port, port,
db_connection_str,
disco_category = "gateway", disco_category = "gateway",
disco_type = "headline"): disco_type = "headline"):
Component.__init__(self, \ Component.__init__(self, \
@@ -75,8 +97,9 @@ class JCLComponent(Component):
# default values # default values
self.name = "Jabber Component Library generic component" self.name = "Jabber Component Library generic component"
self.spool_dir = "." self.spool_dir = "."
self.db_connection_str = db_connection_str
self.__account_class = None self.__account_class = None
self.account_class = Account self.set_account_class(Account)
self.version = VERSION self.version = VERSION
self.accounts = [] self.accounts = []
@@ -89,15 +112,6 @@ class JCLComponent(Component):
signal.signal(signal.SIGINT, self.signal_handler) signal.signal(signal.SIGINT, self.signal_handler)
signal.signal(signal.SIGTERM, self.signal_handler) signal.signal(signal.SIGTERM, self.signal_handler)
def set_account_class(self, account_class):
self.__account_class = account_class
self.__account_class.createTable(ifNotExists = True)
def get_account_class(self):
return self.__account_class
account_class = property(get_account_class, set_account_class)
def run(self): def run(self):
"""Main loop """Main loop
@@ -109,7 +123,12 @@ class JCLComponent(Component):
self.spool_dir += "/" + str(self.jid) self.spool_dir += "/" + str(self.jid)
self.running = True self.running = True
self.connect() self.connect()
thread.start_new_thread(self.time_handler, ()) ## TODO : workaround to make test_run pass on FeederComponent
# time.sleep(1)
##
timer_thread = threading.Thread(target = self.time_handler, \
name = "TimerThread")
timer_thread.start()
try: try:
while (self.running and self.stream while (self.running and self.stream
and not self.stream.eof and self.stream.socket is not None): and not self.stream.eof and self.stream.socket is not None):
@@ -134,21 +153,34 @@ class JCLComponent(Component):
# to_jid = jid, \ # to_jid = jid, \
# stanza_type = "unavailable") # stanza_type = "unavailable")
# self.stream.send(p) # self.stream.send(p)
threads = threading.enumerate() # threads = threading.enumerate()
for _thread in threads: timer_thread.join(JCLComponent.timeout)
try: # for _thread in threads:
_thread.join(10 * JCLComponent.timeout) # try:
except: # _thread.join(10 * JCLComponent.timeout)
pass # except:
for _thread in threads: # pass
try: # for _thread in threads:
_thread.join(JCLComponent.timeout) # try:
except: # _thread.join(JCLComponent.timeout)
pass # except:
# pass
self.disconnect() self.disconnect()
# TODO : terminate SQLObject # TODO : terminate SQLObject
self.__logger.debug("Exitting normally") self.__logger.debug("Exitting normally")
# TODO : terminate SQLObject
#################
# SQlite connections are not multi-threaded
# Utils workaround methods
#################
def db_connect(self):
account.hub.threadConnection = \
connectionForURI(self.db_connection_str)
def db_disconnect(self):
# account.hub.threadConnection.close()
del account.hub.threadConnection
########################################################################### ###########################################################################
@@ -198,6 +230,7 @@ class JCLComponent(Component):
self.stream.set_message_handler("normal", \ self.stream.set_message_handler("normal", \
self.handle_message) self.handle_message)
current_jid = None current_jid = None
self.db_connect()
for account in self.account_class.select(orderBy = "user_jid"): for account in self.account_class.select(orderBy = "user_jid"):
if account.user_jid != current_jid: if account.user_jid != current_jid:
presence = Presence(from_jid = unicode(self.jid), \ presence = Presence(from_jid = unicode(self.jid), \
@@ -209,6 +242,7 @@ class JCLComponent(Component):
to_jid = account.user_jid, \ to_jid = account.user_jid, \
stanza_type = "probe") stanza_type = "probe")
self.stream.send(presence) self.stream.send(presence)
self.db_disconnect()
def signal_handler(self, signum, frame): def signal_handler(self, signum, frame):
"""Stop method handler """Stop method handler
@@ -232,15 +266,19 @@ class JCLComponent(Component):
""" """
self.__logger.debug("DISCO_GET_ITEMS") self.__logger.debug("DISCO_GET_ITEMS")
## TODO Lang ## TODO Lang
## lang_class = self.__lang.get_lang_class_from_node(input_query.get_node()) ## lang_class = self.__lang.get_lang_class_from_node(info_query.get_node())
## base_from_jid = unicode(input_query.get_from().bare()) base_from_jid = unicode(info_query.get_from().bare())
disco_items = DiscoItems() disco_items = DiscoItems()
if not node: if not node:
## TODO : list accounts ## TODO : list accounts
for account in self.accounts: self.db_connect()
for account in self.account_class.select(Account.q.user_jid == \
base_from_jid):
self.__logger.debug(str(account)) self.__logger.debug(str(account))
## DiscoItem(di, JID(name + "@" + unicode(self.jid)), \ DiscoItem(disco_items, \
## name, str_name) JID(account.jid), \
account.name, account.long_name)
self.db_disconnect()
return disco_items return disco_items
def handle_get_version(self, input_query): def handle_get_version(self, input_query):
@@ -266,9 +304,11 @@ class JCLComponent(Component):
input_query = input_query.make_result_response() input_query = input_query.make_result_response()
query = input_query.new_query("jabber:iq:register") query = input_query.new_query("jabber:iq:register")
if to_jid and to_jid != self.jid: if to_jid and to_jid != self.jid:
self.db_connect()
self.get_reg_form_init(lang_class, \ self.get_reg_form_init(lang_class, \
self.accounts.select() # TODO self.account_class.select() # TODO
).attach_xml(query) ).attach_xml(query)
self.db_disconnect()
else: else:
self.get_reg_form(lang_class).attach_xml(query) self.get_reg_form(lang_class).attach_xml(query)
self.stream.send(input_query) self.stream.send(input_query)

View File

@@ -40,12 +40,14 @@ class FeederComponent(JCLComponent):
jid, jid,
secret, secret,
server, server,
port): port,
db_connection_str):
JCLComponent.__init__(self, \ JCLComponent.__init__(self, \
jid, \ jid, \
secret, \ secret, \
server, \ server, \
port) port, \
db_connection_str)
self.name = "Generic Feeder Component" self.name = "Generic Feeder Component"
# Define default feeder and sender, can be override # Define default feeder and sender, can be override
self.feeder = Feeder() self.feeder = Feeder()
@@ -56,9 +58,13 @@ class FeederComponent(JCLComponent):
def handle_tick(self): def handle_tick(self):
"""Implement main feed/send behavior""" """Implement main feed/send behavior"""
for account in Account.select(): pass
for data in self.feeder.feed(account): self.db_connect()
self.sender.send(account, data) for acc in self.account_class.select():
print "OK"
# for data in self.feeder.feed(account):
# self.sender.send(account, data)
self.db_disconnect()

View File

@@ -28,9 +28,15 @@ __revision__ = "$Id: account.py,v 1.3 2005/09/18 20:24:07 dax Exp $"
from sqlobject.main import SQLObject from sqlobject.main import SQLObject
from sqlobject.col import StringCol from sqlobject.col import StringCol
from sqlobject.dbconnection import ConnectionHub
# create a hub to attach a per thread connection
hub = ConnectionHub()
class Account(SQLObject): class Account(SQLObject):
"""Base Account class""" """Base Account class"""
_cacheValue = False
_connection = hub
user_jid = StringCol() user_jid = StringCol()
name = StringCol() name = StringCol()
jid = StringCol() jid = StringCol()

View File

@@ -26,18 +26,32 @@ import unittest
import thread import thread
import threading import threading
import time import time
import sys
import os
from sqlobject import * from sqlobject import *
from sqlobject.dbconnection import TheURIOpener
from jcl.jabber.component import JCLComponent from jcl.jabber.component import JCLComponent
from jcl.model import account
from jcl.model.account import Account from jcl.model.account import Account
from jcl.lang import Lang from jcl.lang import Lang
DB_PATH = "/tmp/test.db"
DB_URL = DB_PATH# + "?debug=1&debugThreading=1"
class MockStream(object): class MockStream(object):
def __init__(self): def __init__(self, \
jid = "",
secret = "",
server = "",
port = "",
keepalive = True):
self.sended = [] self.sended = []
self.connection_started = False self.connection_started = False
self.connection_stoped = False self.connection_stopped = False
self.eof = False
self.socket = []
def send(self, iq): def send(self, iq):
self.sended.append(iq) self.sended.append(iq)
@@ -77,43 +91,65 @@ class MockStream(object):
self.connection_started = True self.connection_started = True
def disconnect(self): def disconnect(self):
self.connection_stoped = True self.connection_stopped = True
def loop_iter(self, timeout): def loop_iter(self, timeout):
time.sleep(timeout) time.sleep(timeout)
def close(self):
pass
class JCLComponent_TestCase(unittest.TestCase): class JCLComponent_TestCase(unittest.TestCase):
def setUp(self): def setUp(self):
connection = sqlhub.processConnection = connectionForURI('sqlite:/:memory:') if os.path.exists(DB_PATH):
os.unlink(DB_PATH)
self.comp = JCLComponent("jcl.test.com", self.comp = JCLComponent("jcl.test.com",
"password", "password",
"localhost", "localhost",
"5347") "5347",
'sqlite://' + DB_URL)
self.max_tick_count = 2
def tearDown(self): def tearDown(self):
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
Account.dropTable(ifExists = True) Account.dropTable(ifExists = True)
del TheURIOpener.cachedURIs['sqlite://' + DB_URL]
account.hub.threadConnection.close()
del account.hub.threadConnection
if os.path.exists(DB_PATH):
os.unlink(DB_PATH)
def test_constructor(self): def test_constructor(self):
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
self.assertTrue(Account._connection.tableExists("account")) self.assertTrue(Account._connection.tableExists("account"))
if os.path.exists(DB_PATH):
print DB_PATH + " exists cons"
del account.hub.threadConnection
def test_run(self): def test_run(self):
self.comp.stream = MockStream() self.comp.stream = MockStream()
run_thread = thread.start_new_thread(self.comp.run, ()) self.comp.stream_class = MockStream
run_thread = threading.Thread(target = self.comp.run, \
name = "run_thread")
run_thread.start()
time.sleep(1)
self.assertTrue(self.comp.stream.connection_started) self.assertTrue(self.comp.stream.connection_started)
self.comp.running = False self.comp.running = False
time.sleep(JCLComponent.timeout + 1) time.sleep(JCLComponent.timeout + 1)
threads = threading.enumerate() threads = threading.enumerate()
self.assertNone(threads) self.assertEquals(len(threads), 1)
for _thread in threads: self.assertTrue(self.comp.stream.connection_stopped)
try: if self.comp.queue.qsize():
_thread.join(1) raise self.comp.queue.get(0)
except:
pass
self.assertTrue(self.comp.connection_stoped)
def test_run_go_offline(self): def test_run_go_offline(self):
## TODO : verify offline stanza are sent ## TODO : verify offline stanza are sent
pass pass
def __handle_tick_test_time_handler(self):
self.max_tick_count -= 1
if self.max_tick_count == 0:
self.comp.running = False
def test_authenticated_handler(self): def test_authenticated_handler(self):
self.comp.stream = MockStream() self.comp.stream = MockStream()
@@ -121,12 +157,17 @@ class JCLComponent_TestCase(unittest.TestCase):
self.assertTrue(True) self.assertTrue(True)
def test_authenticated_send_probe(self): def test_authenticated_send_probe(self):
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
account11 = Account(user_jid = "test1@test.com", \ account11 = Account(user_jid = "test1@test.com", \
name = "test11") name = "test11", \
jid = "account11@jcl.test.com")
account12 = Account(user_jid = "test1@test.com", \ account12 = Account(user_jid = "test1@test.com", \
name = "test12") name = "test12", \
jid = "account12@jcl.test.com")
account2 = Account(user_jid = "test2@test.com", \ account2 = Account(user_jid = "test2@test.com", \
name = "test2") name = "test2", \
jid = "account2@jcl.test.com")
del account.hub.threadConnection
self.comp.stream = stream = MockStream() self.comp.stream = stream = MockStream()
self.comp.authenticated() self.comp.authenticated()
@@ -150,8 +191,10 @@ class JCLComponent_TestCase(unittest.TestCase):
self.assertTrue(True) self.assertTrue(True)
def test_get_reg_form_init(self): def test_get_reg_form_init(self):
account = Account(user_jid = "", name = "") account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
self.comp.get_reg_form_init(Lang.en, account) account1 = Account(user_jid = "", name = "", jid = "")
del account.hub.threadConnection
self.comp.get_reg_form_init(Lang.en, account1)
self.assertTrue(True) self.assertTrue(True)
def test_disco_get_info(self): def test_disco_get_info(self):

View File

@@ -22,42 +22,80 @@
## ##
import unittest import unittest
import os
from sqlobject import * from sqlobject import *
from sqlobject.dbconnection import TheURIOpener
from tests.jcl.jabber.test_component import JCLComponent_TestCase from tests.jcl.jabber.test_component import JCLComponent_TestCase
from jcl.jabber.feeder import Feeder, Sender from jcl.jabber.feeder import FeederComponent, Feeder, Sender
from jcl.model.account import Account from jcl.model.account import Account
from jcl.model import account
DB_PATH = "/tmp/test.db"
DB_URL = DB_PATH #+ "?debug=1&debugThreading=1"
class FeederComponent_TestCase(JCLComponent_TestCase): class FeederComponent_TestCase(JCLComponent_TestCase):
pass def setUp(self):
if os.path.exists(DB_PATH):
os.unlink(DB_PATH)
self.comp = FeederComponent("jcl.test.com",
"password",
"localhost",
"5347",
'sqlite://' + DB_URL)
def tearDown(self):
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
Account.dropTable(ifExists = True)
del TheURIOpener.cachedURIs['sqlite://' + DB_URL]
account.hub.threadConnection.close()
del account.hub.threadConnection
if os.path.exists(DB_PATH):
os.unlink(DB_PATH)
def test_constructor(self):
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
self.assertTrue(Account._connection.tableExists("account"))
del account.hub.threadConnection
class Feeder_TestCase(unittest.TestCase): class Feeder_TestCase(unittest.TestCase):
def setUp(self): def setUp(self):
connection = sqlhub.processConnection = connectionForURI('sqlite:/:memory:') if os.path.exists(DB_PATH):
os.unlink(DB_PATH)
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
Account.createTable() Account.createTable()
def tearDown(self): def tearDown(self):
Account.dropTable(ifExists = True) Account.dropTable(ifExists = True)
del account.hub.threadConnection
# os.unlink(DB_PATH)
def test_feed_exist(self): def test_feed_exist(self):
feeder = Feeder() feeder = Feeder()
feeder.feed(Account(user_jid = "test@test.com", \ feeder.feed(Account(user_jid = "test@test.com", \
name = "test")) name = "test", \
jid = "test@jcl.test.com"))
self.assertTrue(True) self.assertTrue(True)
class Sender_TestCase(unittest.TestCase): class Sender_TestCase(unittest.TestCase):
def setUp(self): def setUp(self):
connection = sqlhub.processConnection = connectionForURI('sqlite:/:memory:') if os.path.exists(DB_PATH):
os.unlink(DB_PATH)
account.hub.threadConnection = connectionForURI('sqlite://' + DB_URL)
Account.createTable() Account.createTable()
def tearDown(self): def tearDown(self):
Account.dropTable(ifExists = True) Account.dropTable(ifExists = True)
del account.hub.threadConnection
# os.unlink(DB_PATH)
def test_send_exist(self): def test_send_exist(self):
sender = Sender() sender = Sender()
account = Account(user_jid = "test@test.com", \ account = Account(user_jid = "test@test.com", \
name = "test") name = "test", \
jid = "test@jcl.test.com")
sender.send(to_account = account, \ sender.send(to_account = account, \
message = "Hello World") message = "Hello World")
self.assertTrue(True) self.assertTrue(True)