001/*
002 * (c) 2005, 2009, 2010 ThoughtWorks Ltd
003 * All rights reserved.
004 *
005 * The software in this package is published under the terms of the BSD
006 * style license a copy of which has been included with this distribution in
007 * the LICENSE.txt file.
008 * 
009 * Created on 24-Feb-2005
010 */
011package com.thoughtworks.proxy.toys.dispatch;
012
013import static com.thoughtworks.proxy.toys.delegate.DelegationMode.DIRECT;
014
015import java.io.IOException;
016import java.io.InvalidObjectException;
017import java.io.ObjectInputStream;
018import java.io.ObjectOutputStream;
019import java.lang.reflect.Method;
020import java.util.ArrayList;
021import java.util.Arrays;
022import java.util.HashSet;
023import java.util.List;
024import java.util.Set;
025
026import com.thoughtworks.proxy.Invoker;
027import com.thoughtworks.proxy.ProxyFactory;
028import com.thoughtworks.proxy.factory.InvokerReference;
029import com.thoughtworks.proxy.factory.StandardProxyFactory;
030import com.thoughtworks.proxy.kit.ObjectReference;
031import com.thoughtworks.proxy.kit.ReflectionUtils;
032import com.thoughtworks.proxy.toys.delegate.DelegatingInvoker;
033
034
035/**
036 * Invoker that dispatches all invocations to different objects according the membership of the method.
037 *
038 * @author Jörg Schaible after idea by Rickard Öberg
039 * @since 0.2
040 */
041public class DispatchingInvoker implements Invoker {
042    private static final long serialVersionUID = 1L;
043    private List<Class<?>> types;
044    private Invoker[] invokers;
045    private transient Set<Method>[] methodSets;
046    private transient Method[] toStringMethods;
047
048    /**
049     * Construct a DispatchingInvoker.
050     *
051     * @param proxyFactory       the {@link ProxyFactory} to use
052     * @param types              the types of the generated proxy
053     * @param delegateReferences the {@link ObjectReference ObjectReferences} for the delegates
054     * @since 0.2
055     */
056    public DispatchingInvoker(
057            final ProxyFactory proxyFactory, final Class<?>[] types, final ObjectReference<Object>[] delegateReferences) {
058        this.types = Arrays.asList(types);
059        invokers = new Invoker[types.length];
060        toStringMethods = new Method[types.length];
061        @SuppressWarnings("unchecked")
062        Set<Method>[] sets = new Set[types.length];
063        methodSets = sets;
064        for (int i = 0; i < types.length; i++) {
065            for (final ObjectReference<Object> delegateReference : delegateReferences) {
066                if (types[i].isAssignableFrom(delegateReference.get().getClass())) {
067                    invokers[i] = new DelegatingInvoker<Object>(proxyFactory, delegateReference, DIRECT);
068                    methodSets[i] = new HashSet<Method>(Arrays.asList(types[i].getMethods()));
069                    for (Method method : methodSets[i]) {
070                        if (method.getName().equals("toString") && method.getParameterTypes().length == 0) {
071                            toStringMethods[i] = method;
072                            break;
073                        }
074                    }
075                    break;
076                }
077            }
078            if (invokers[i] == null) {
079                throw new DispatchingException("Cannot dispatch type " + types[i].getName(), types[i]);
080            }
081        }
082    }
083
084    /**
085     * Constructor used by pure reflection serialization.
086     * 
087     * @since 0.2
088     */
089    protected DispatchingInvoker() {
090    }
091
092    public Object invoke(final Object proxy, Method method, final Object[] args) throws Throwable {
093        if (method.equals(ReflectionUtils.equals)) {
094            final Object arg = args[0];
095            if (new StandardProxyFactory().isProxyClass(arg.getClass())
096                    && (InvokerReference.class.cast(arg)).getInvoker() instanceof DispatchingInvoker) {
097                final DispatchingInvoker invoker = DispatchingInvoker.class.cast((InvokerReference.class.cast(arg)).getInvoker());
098                if (types.size() == invoker.types.size()) {
099                    boolean isEqual = true;
100                    for (int i = 0; isEqual && i < types.size(); ++i) {
101                        final Class<?> type = types.get(i);
102                        for (int j = 0; isEqual && j < invoker.types.size(); ++j) {
103                            if (invoker.types.get(j).equals(type)) {
104                                if (!invokers[i].equals(invoker.invokers[j])) {
105                                    isEqual = false;
106                                }
107                            }
108                        }
109                    }
110                    return isEqual;
111                }
112            }
113            return Boolean.FALSE;
114        } else if (method.equals(ReflectionUtils.hashCode)) {
115            return hashCode();
116        } else if (method.equals(ReflectionUtils.toString)) {
117            for (int i = 0; i < invokers.length; i++) {
118                Method toString = toStringMethods[i];
119                if (toString != null && toString.getDeclaringClass().isAssignableFrom(proxy.getClass())) {
120                    return invokers[i].invoke(proxy, method, args);
121                }
122            }
123            return types.toString();
124        } else {
125            for (int i = 0; i < invokers.length; i++) {
126                if (methodSets[i].contains(method)) {
127                    return invokers[i].invoke(proxy, method, args);
128                }
129            }
130        }
131        throw new RuntimeException("Cannot dispatch method " + method.getName());
132    }
133
134    private void writeObject(final ObjectOutputStream out) throws IOException {
135        out.defaultWriteObject();
136        @SuppressWarnings("unchecked")
137        final List<Class<?>>[] types = new List[methodSets.length];
138        @SuppressWarnings("unchecked")
139        final List<String>[] names = new List[methodSets.length];
140        @SuppressWarnings("unchecked")
141        final List<Class<?>[]>[] arguments = new List[methodSets.length];
142        for (int i = 0; i < methodSets.length; i++) {
143            final Method[] methods = methodSets[i].toArray(new Method[methodSets[i].size()]);
144            types[i] = new ArrayList<Class<?>>();
145            names[i] = new ArrayList<String>();
146            arguments[i] = new ArrayList<Class<?>[]>();
147            for (Method method : methods) {
148                types[i].add(method.getDeclaringClass());
149                names[i].add(method.getName());
150                arguments[i].add(method.getParameterTypes());
151            }
152        }
153        out.writeObject(types);
154        out.writeObject(names);
155        out.writeObject(arguments);
156    }
157
158    private void readObject(final ObjectInputStream in) throws IOException, ClassNotFoundException {
159        in.defaultReadObject();
160        @SuppressWarnings("unchecked")
161        final List<Class<?>>[] types = List[].class.cast(in.readObject());
162        @SuppressWarnings("unchecked")
163        final List<String>[] names = List[].class.cast(in.readObject());
164        @SuppressWarnings("unchecked")
165        final List<Class<?>[]>[] arguments = List[].class.cast(in.readObject());
166        @SuppressWarnings("unchecked")
167        final Set<Method>[] set = new Set[types.length];
168        methodSets = set;
169        toStringMethods = new Method[types.length];
170        try {
171            for (int i = 0; i < methodSets.length; i++) {
172                methodSets[i] = new HashSet<Method>();
173                for (int j = 0; j < types[i].size(); j++) {
174                    final Class<?> type = types[i].get(j);
175                    final String name = names[i].get(j);
176                    final Class<?>[] argumentTypes = arguments[i].get(j);
177                    final Method method = type.getMethod(name, argumentTypes);
178                    methodSets[i].add(method);
179                    if (name.equals("toString") && argumentTypes.length == 0) {
180                        toStringMethods[i] = method;
181                    }
182                }
183            }
184        } catch (final NoSuchMethodException e) {
185            throw new InvalidObjectException(e.getMessage());
186        }
187    }
188}