diff --git a/replication.py b/replication.py index 6ec850c..f9d2d18 100644 --- a/replication.py +++ b/replication.py @@ -84,7 +84,8 @@ class ReplicatedDatablock(object): del dict[self.uuid] else: dict[self.uuid] = self - pass + + return self.uuid def deserialize(self,data): """ diff --git a/replication_client.py b/replication_client.py index 85a48ab..c1c29f0 100644 --- a/replication_client.py +++ b/replication_client.py @@ -9,29 +9,42 @@ log = logging.getLogger(__name__) class Client(object): def __init__(self,factory=None, config=None): - self.rep_store = {} - self.net = ClientNetService(self.rep_store) - self.factory = factory + self._rep_store = {} + self._net = ClientNetService(self._rep_store) + assert(factory) + self._factory = factory def connect(self): - self.net.start() + self._net.start() - def replicate(self, object): - new_item = self.factory.construct(object)(owner="client") + def disconnect(self): + self._net.stop() - new_item.store(self.rep_store) + def register(self, object): + """ + Register a new item for replication + """ + assert(object) - def state(self): - return self.net.state + new_item = self._factory.construct(object)(owner="client") + new_item.store(self._rep_store) - def stop(self): - self.net.stop() + return new_item.uuid + + def get(self,object=None): + pass + + + def unregister(self,object): + pass + + class ClientNetService(threading.Thread): def __init__(self,store_reference=None): # Threading threading.Thread.__init__(self) - self.name = "NetLink" + self.name = "ClientNetLink" self.daemon = True self.exit_event = threading.Event() @@ -53,6 +66,7 @@ class ClientNetService(threading.Thread): def run(self): + log.info("Client is listening") poller = zmq.Poller() poller.register(self.snapshot, zmq.POLLIN) poller.register(self.subscriber, zmq.POLLIN) @@ -75,6 +89,10 @@ class ClientNetService(threading.Thread): def stop(self): self.exit_event.set() + self.snapshot.close() + self.subscriber.close() + self.publish.close() + self.state = 0 @@ -99,34 +117,44 @@ class ServerNetService(threading.Thread): def __init__(self,store_reference=None): # Threading threading.Thread.__init__(self) - self.name = "NetLink" + self.name = "ServerNetLink" self.daemon = True self.exit_event = threading.Event() self.store = store_reference self.context = zmq.Context.instance() - - # Update request - self.snapshot = self.context.socket(zmq.ROUTER) - self.snapshot.setsockopt(zmq.IDENTITY, b'SERVER') - self.snapshot.setsockopt(zmq.RCVHWM, 60) - self.snapshot.bind("tcp://*:5560") - - # Update all clients - self.publisher = self.context.socket(zmq.PUB) - self.publisher.setsockopt(zmq.SNDHWM, 60) - self.publisher.bind("tcp://*:5561") - time.sleep(0.2) - - # Update collector - self.pull = self.context.socket(zmq.PULL) - self.pull.setsockopt(zmq.RCVHWM, 60) - self.pull.bind("tcp://*:5562") - - + self.snapshot = None + self.publisher = None + self.pull = None self.state = 0 + self.bind_ports() + + def bind_ports(self): + try: + # Update request + self.snapshot = self.context.socket(zmq.ROUTER) + self.snapshot.setsockopt(zmq.IDENTITY, b'SERVER') + self.snapshot.setsockopt(zmq.RCVHWM, 60) + self.snapshot.bind("tcp://*:5560") + + # Update all clients + self.publisher = self.context.socket(zmq.PUB) + self.publisher.setsockopt(zmq.SNDHWM, 60) + self.publisher.bind("tcp://*:5561") + time.sleep(0.2) + + # Update collector + self.pull = self.context.socket(zmq.PULL) + self.pull.setsockopt(zmq.RCVHWM, 60) + self.pull.bind("tcp://*:5562") + + except zmq.error.ZMQError: + log.error("Address already in use, change net config") + + def run(self): + log.info("Server is listening") poller = zmq.Poller() poller.register(self.snapshot, zmq.POLLIN) poller.register(self.pull, zmq.POLLIN) @@ -139,4 +167,13 @@ class ServerNetService(threading.Thread): if not items: pass - time.sleep(.1) \ No newline at end of file + time.sleep(.1) + + def stop(self): + self.exit_event.set() + + self.snapshot.close() + self.pull.close() + self.publisher.close() + + self.state = 0 \ No newline at end of file diff --git a/test_replication.py b/test_replication.py index 7be9608..311f7e9 100644 --- a/test_replication.py +++ b/test_replication.py @@ -28,10 +28,26 @@ class RepSampleData(ReplicatedDatablock): return pickle.load(data) -class TestData(unittest.TestCase): - def setUp(self): - self.map = {} +# class TestClient(unittest.TestCase): +# def setUp(self): +# factory = ReplicatedDataFactory() +# self.client_api = Client(factory=factory) + +# def test_client_connect(self): +# self.client_api.connect() +# time.sleep(1) +# self.assertEqual(self.client_api._net.state,1) + + +# def test_client_disconnect(self): +# self.client_api.disconnect() +# time.sleep(1) +# self.assertEqual(self.client_api._net.state,0) + + + +class TestDataReplication(unittest.TestCase): # def test_server_launching(self): # log.info("test_server_launching") # self.server_api.serve() @@ -45,21 +61,28 @@ class TestData(unittest.TestCase): # self.assertEqual(self.server_api.state(),0) def test_setup_data_factory(self): - self.factory = ReplicatedDataFactory() - self.factory.register_type(SampleData, RepSampleData) - + factory = ReplicatedDataFactory() + factory.register_type(SampleData, RepSampleData) data_sample = SampleData() - rep_sample = self.factory.construct(data_sample)(owner="toto") + rep_sample = factory.construct(data_sample)(owner="toto") self.assertEqual(isinstance(rep_sample,RepSampleData), True) - def test_setup_net(self): - self.server_api = Server() - self.server_api.serve() - self.client_api = Client() - self.client_api.connect() + def test_replicate_client_data(self): + factory = ReplicatedDataFactory() + factory.register_type(SampleData, RepSampleData) + + server_api = Server() + server_api.serve() + client_api = Client(factory=factory) + client_api.connect() + + data_sample = SampleData() + data_sample_key = client_api.register(data_sample) + + + self.assertEqual(data_sample_key) + - def test_push_data(self): - self. # def test_client_connect(self): # log.info("test_client_connect")