1
2
3
4 package org.sourceforge.jemmrpc.shared;
5
6 import java.io.IOException;
7 import java.io.NotSerializableException;
8 import java.io.ObjectInputStream;
9 import java.io.ObjectOutputStream;
10 import java.lang.reflect.InvocationTargetException;
11 import java.lang.reflect.Method;
12 import java.lang.reflect.Proxy;
13 import java.net.Socket;
14 import java.util.ArrayList;
15 import java.util.HashMap;
16 import java.util.List;
17 import java.util.Map;
18 import java.util.concurrent.ConcurrentHashMap;
19 import java.util.concurrent.CountDownLatch;
20 import java.util.concurrent.ExecutorService;
21 import java.util.concurrent.SynchronousQueue;
22
23 import org.apache.log4j.Logger;
24 import org.sourceforge.jemmrpc.client.RPCClient;
25 import org.sourceforge.jemmrpc.server.RPCServer;
26
27
28
29
30
31
32
33
34
35
36
37
38 public class RPCHandler implements Runnable
39 {
40 Logger logger = Logger.getLogger(RPCHandler.class);
41
42 Socket socket;
43
44 protected CountDownLatch initialisationLatch = new CountDownLatch(1);
45
46 protected volatile boolean closing = false;
47
48 protected ObjectInputStream is;
49
50 protected ObjectOutputStream os;
51
52
53 Map<Class<?>, Object> remoteInterfaces;
54
55
56 Map<Class<?>, Object> localInterfaces;
57
58 ExecutorService requestExecutor;
59
60 ThreadLocal<String> threadIdTL = new ThreadLocal<String>();
61
62 ThreadLocal<SynchronousQueue<Message>> syncQueueTL = new ThreadLocal<SynchronousQueue<Message>>();
63
64 ConcurrentHashMap<String, SynchronousQueue<Message>> msgSyncPoints = new ConcurrentHashMap<String, SynchronousQueue<Message>>();
65
66 boolean isClient;
67
68 static ThreadLocal<Object> connectionIdTL = new ThreadLocal<Object>();
69
70 protected RPCHandlerListener listener = null;
71
72 Object connectionId;
73
74
75
76
77
78
79
80
81
82
83 public RPCHandler(boolean isClient, Socket socket,
84 Map<Class<?>, Object> localInterfaces, ExecutorService requestExecutor,
85 Object connectionId)
86 {
87 this.isClient = isClient;
88 this.socket = socket;
89
90 this.localInterfaces = localInterfaces;
91 this.requestExecutor = requestExecutor;
92 this.connectionId = connectionId;
93 }
94
95
96
97
98
99
100 public synchronized void setHandlerListener(RPCHandlerListener listener)
101 {
102 this.listener = listener;
103 }
104
105
106
107
108
109
110
111
112 public synchronized Object getRemoteIF(Class<?> ifClass)
113 {
114 if (!ifClass.isInterface())
115 throw new IllegalArgumentException("given class is not an interface");
116
117 Object obj = remoteInterfaces.get(ifClass);
118 if (obj == null)
119 {
120 if (!remoteInterfaces.keySet().contains(ifClass))
121 throw new IllegalArgumentException("Interface " + ifClass
122 + " not offered by server");
123
124 obj = createProxyClass(ifClass);
125 remoteInterfaces.put(ifClass, obj);
126 }
127
128 return obj;
129 }
130
131 protected Object createProxyClass(Class<?> ifClass)
132 {
133 final Class<?>[] ifs = { ifClass };
134 final RPCProxyHandler ph = new RPCProxyHandler(this, ifClass);
135 final Object obj = Proxy.newProxyInstance(this.getClass().getClassLoader(),
136 ifs,ph);
137 return obj;
138 }
139
140
141
142
143
144
145 public void run()
146 {
147 try
148 {
149 if (isClient)
150 {
151 is = new ObjectInputStream(socket.getInputStream());
152 os = new ObjectOutputStream(socket.getOutputStream());
153 }
154 else
155 {
156 os = new ObjectOutputStream(socket.getOutputStream());
157 is = new ObjectInputStream(socket.getInputStream());
158 }
159
160 final AvailableIFsMessage ifOMsg = new AvailableIFsMessage(localInterfaces
161 .keySet().toArray(new Class<?>[0]));
162 os.writeObject(ifOMsg);
163 os.flush();
164
165 try
166 {
167 final AvailableIFsMessage ifIMsg = (AvailableIFsMessage) is.readObject();
168
169 synchronized (this)
170 {
171
172 remoteInterfaces = new HashMap<Class<?>, Object>();
173 for (final Class<?> ifClass : ifIMsg.offeredIFs)
174 remoteInterfaces.put(ifClass, null);
175 }
176 }
177 catch (final ClassNotFoundException e)
178 {
179 throw new IOException("Error initialising connection");
180 }
181
182 initialisationLatch.countDown();
183
184 while (true)
185 {
186 Object o = null;
187 try
188 {
189 o = is.readObject();
190 }
191 catch (final ClassNotFoundException e1)
192 {
193 e1.printStackTrace();
194 }
195
196 if (o instanceof Message)
197 {
198 final Message message = (Message) o;
199 try
200 {
201 receiveMessage(message);
202 }
203 catch (final Exception e)
204 {
205 logger.warn(
206 "exception thrown whilst sending message to receiver", e);
207 }
208 } else
209 logger.warn("Invalid object on stream " + o);
210 }
211 }
212 catch (final IOException se)
213 {
214 connectionTerminated();
215 }
216 }
217
218 protected void receiveMessage(Message message)
219 {
220 if (message instanceof RPCCallRespMessage)
221 {
222 final String threadId = message.getThreadId();
223 final SynchronousQueue<Message> replyQueue = msgSyncPoints.get(threadId);
224 if (replyQueue != null)
225 try
226 {
227 replyQueue.put(message);
228 } catch (final InterruptedException e)
229 {
230 logger.info("Receive thread interrupted");
231 }
232 else
233 logger.error("No client thread found for sync message to " + threadId);
234 } else if (message instanceof RPCCallMessage)
235 processCallMessage((RPCCallMessage) message);
236 else if (message instanceof ErrorMessage)
237 logger.warn("Error message recieved from server "
238 + ((ErrorMessage) message).errorMsg);
239 else
240 logger.warn("Invalid message type received by client: " + message.getClass());
241 }
242
243
244
245
246
247
248
249
250 protected void processCallMessage(final RPCCallMessage message)
251 {
252 requestExecutor.execute(new Runnable()
253 {
254 public void run()
255 {
256 connectionIdTL.set(connectionId);
257 try
258 {
259 final Class<?> targetIF = message.getIfClass();
260 final Object targetIFImpl = localInterfaces.get(targetIF);
261 if (targetIFImpl != null)
262 {
263 final Method method = targetIF.getMethod(message.methodName,
264 message.parameterTypes);
265 if (method == null)
266 throw new IllegalArgumentException(
267 "Interface method does not exist");
268 final Object resp = method.invoke(targetIFImpl,
269 message.getParameters());
270 if (!message.asyncCall)
271 writeMessage(new RPCCallRespMessage(message.threadId, true,
272 resp));
273 } else
274 throw new IllegalArgumentException("Interface not supported");
275 }
276 catch (final Exception e)
277 {
278 final Throwable cause = e instanceof InvocationTargetException ? e
279 .getCause() : e;
280 if (!message.asyncCall)
281 writeMessage(new RPCCallRespMessage(message.threadId, false,
282 cause));
283 else
284 logger.warn("Exception caught whilst processing async call to "
285 + message.ifClass + "." + message.methodName + "()", e);
286 }
287 connectionIdTL.set(null);
288 }
289 });
290 }
291
292 protected synchronized void connectionTerminated()
293 {
294 if (!closing)
295 if (isClient)
296 logger.error("Lost connection to Server");
297 else
298 logger.info("Lost connection to client");
299
300 closing = true;
301 if (msgSyncPoints.size() > 0)
302 {
303 logger.error("Client connection closed with waiters active");
304 final List<String> list = new ArrayList<String>();
305 list.addAll(msgSyncPoints.keySet());
306
307 final ErrorMessage errorMessage = new ErrorMessage("Server connection lost");
308 for (final String threadId : list)
309 {
310 final SynchronousQueue<Message> queue = msgSyncPoints.get(threadId);
311 if (queue != null)
312 if (!queue.offer(errorMessage))
313 logger.warn("Unable to inform thread " + threadId
314 + " of connection close");
315 }
316 }
317
318 if (listener != null)
319 listener.connectionTerminated();
320 }
321
322
323
324
325 public void close()
326 {
327 closing = true;
328 try
329 {
330 socket.close();
331 } catch (final IOException ioe)
332 {
333 logger.warn("IOException thrown whilst closing client socket", ioe);
334 }
335 }
336
337 protected synchronized void writeMessage(Message message)
338 {
339 try
340 {
341 os.writeObject(message);
342 os.flush();
343 } catch (final NotSerializableException nse)
344 {
345 logger.error("Sent message not serializable " + nse);
346 } catch (final IOException e)
347 {
348 logger.warn("error caught writing object", e);
349 }
350
351 }
352
353
354
355
356
357
358
359
360 public Message sendSyncMessage(Message message)
361 {
362 final String threadId = ThreadUtil.getThreadId();
363
364 SynchronousQueue<Message> sq = syncQueueTL.get();
365 if (sq == null)
366 {
367 sq = new SynchronousQueue<Message>();
368 syncQueueTL.set(sq);
369 }
370
371 msgSyncPoints.put(threadId, sq);
372
373 writeMessage(message);
374 Message replyMsg = null;
375 try
376 {
377 replyMsg = sq.take();
378 } catch (final InterruptedException ie)
379 {
380 replyMsg = new ErrorMessage(
381 "InterruptedException received whilst waiting for reply");
382 }
383
384 msgSyncPoints.remove(threadId);
385
386 return replyMsg;
387 }
388
389
390
391
392
393
394
395
396
397 protected void makeAsyncCall(Class<?> ifClass, String methodName,
398 Class<?>[] parameterTypes, Object[] args)
399 {
400 final RPCCallMessage callMessage = new RPCCallMessage(ThreadUtil.getThreadId(),
401 true, ifClass, methodName, parameterTypes, args);
402 writeMessage(callMessage);
403 }
404
405
406
407
408
409
410
411
412
413
414 protected RPCCallRespMessage makeSyncCall(Class<?> ifClass, String methodName,
415 Class<?>[] parameterTypes, Object[] args)
416 {
417 final RPCCallMessage callMessage = new RPCCallMessage(ThreadUtil.getThreadId(),
418 false, ifClass, methodName, parameterTypes, args);
419 final Message replyMsg = sendSyncMessage(callMessage);
420 if (replyMsg instanceof RPCCallRespMessage)
421 return (RPCCallRespMessage) replyMsg;
422 else
423 return new RPCCallRespMessage(ThreadUtil.getThreadId(), false,
424 new IllegalStateException("Unexpected message returned "
425 + replyMsg.getClass()));
426 }
427
428
429
430
431
432
433
434
435 public static Object getConnectionId()
436 {
437 return connectionIdTL.get();
438 }
439
440
441
442
443
444 public void start()
445 {
446 (new Thread(this)).start();
447 try
448 {
449 initialisationLatch.await();
450 } catch (final InterruptedException e)
451 {
452
453 }
454 }
455
456 }