动态代理:JDK实现方式源码分析

JDK动态代理

所谓动态代理是指:在程序运行期间根据需要动态创建代理类及其实例来完成具体的功能,动态代理主要分为JDK动态代理和cglib动态代理两大类,本文主要对JDK动态代理进行探讨。

使用步骤

  1. 新建接口
  2. 新建一个接口实现类
  3. 实现代理类回调接口InvocationHandler
  4. 通过Proxy.newProxyInstance()方法创建代理类

使用案例

1,新建HelloWord接口

public interface HelloWord {
    
    
    void sayHello();
    void sayGoodBye();
}

2,HelloWord接口实现类HelloWordImpl

public class HelloWordImpl implements HelloWord {
    
    
    @Override
    public void sayHello() {
    
    
        System.out.println("Hello");
    }

    @Override
    public void sayGoodBye() {
    
    
        System.out.println("GoodBye");
    }
}

3,InvokeHandler接口实现类

	public static class HelloWordInvokeHandler implements InvocationHandler {
    
    

        private Object target;

        public HelloWordInvokeHandler(HelloWord helloWord) {
    
    
            this.target = helloWord;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    
    
            System.out.println("执行前");
            Object invoke = method.invoke(target, args);
            System.out.println("执行后");
            return invoke;
        }
    }

4,创建代理类,并调用代理类

public class MyTest {
    
    
    public static void main(String[] args) {
    
    
        // 被代理实例
        HelloWordImpl helloWord = new HelloWordImpl();
        // 获取classLoader
        ClassLoader classLoader = helloWord.getClass().getClassLoader();
        // 获取接口数组
        Class<?>[] interfaces = helloWord.getClass().getInterfaces();
        // 创建代理类
        HelloWord proxyInstance = (HelloWord) Proxy.newProxyInstance(classLoader, interfaces, new HelloWordInvokeHandler(helloWord));
        // 调用代理类
        proxyInstance.sayHello();
    }
}

5,运行输出

执行前
Hello
执行后

JDK动态代理源码分析

分析Proxy.newProxyInstance()创建代理类的过程,大致分为以下步骤:

  1. 为接口创建代理类的字节码文件
  2. 使用ClassLoader将字节码文件加载到 JVM
  3. 创建代理类实例对象

创建完代理类,通过调用代理类方法,回调到InvokeHander实例,利用反射实现被代理类方法的调用。

Proxy.newProxyInstance()源码入手

public class Proxy implements java.io.Serializable {
    
    
    
    // 创建的代理类的构造函数参数,为InvocationHandler对象
    private static final Class<?>[] constructorParams =
        {
    
     InvocationHandler.class };
    
    @CallerSensitive
    public static Object newProxyInstance(ClassLoader loader,
                                          Class<?>[] interfaces,
                                          InvocationHandler h)
        throws IllegalArgumentException
    {
    
    
        // 要求InvocationHandler不等于null,这里很好理解,因为主要靠InvocationHandler回调来实现对被代理类的调用。
        Objects.requireNonNull(h);
		// 接口数组拷贝一份
        final Class<?>[] intfs = interfaces.clone();
        // 安全检查
        final SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
    
    
            checkProxyAccess(Reflection.getCallerClass(), loader, intfs);
        }

        // 如果缓存中存在直接获取,或者创建新的代理类class对象
        Class<?> cl = getProxyClass0(loader, intfs);

        /*
         * Invoke its constructor with the designated invocation handler.
         */
        try {
    
    
            if (sm != null) {
    
    
                checkNewProxyPermission(Reflection.getCallerClass(), cl);
            }
			// 得到代理类对象的构造函数,这个构造函数的参数由constructorParams指定
            // 参数constructorParames为常量值:private static final Class<?>[] constructorParams = { InvocationHandler.class };
            final Constructor<?> cons = cl.getConstructor(constructorParams);
            final InvocationHandler ih = h;
            if (!Modifier.isPublic(cl.getModifiers())) {
    
    
                AccessController.doPrivileged(new PrivilegedAction<Void>() {
    
    
                    public Void run() {
    
    
                        cons.setAccessible(true);
                        return null;
                    }
                });
            }
            // 创建代理类的对象
            return cons.newInstance(new Object[]{
    
    h});
        } catch (IllegalAccessException|InstantiationException e) {
    
    
            throw new InternalError(e.toString(), e);
        } catch (InvocationTargetException e) {
    
    
            Throwable t = e.getCause();
            if (t instanceof RuntimeException) {
    
    
                throw (RuntimeException) t;
            } else {
    
    
                throw new InternalError(t.toString(), t);
            }
        } catch (NoSuchMethodException e) {
    
    
            throw new InternalError(e.toString(), e);
        }
    }
}

以上源码完成了代理类的创建:

  1. getProxyClass0()方法:为接口创建代理类的字节码文件,并使用ClassLoader将字节妈文件加载到JVM,返回代理类的Class对象
  2. cl.getConstructor(constructorParams)方法:获得代理类指定的构造函数
  3. cons.newInstance(new Object[]{h})方法:通过反射创建代理类对象

最重要的是第一步getProxyClass0()方法:

public class Proxy implements java.io.Serializable {
    
    
    private static final WeakCache<ClassLoader, Class<?>[], Class<?>>
        proxyClassCache = new WeakCache<>(new KeyFactory(), new ProxyClassFactory());
    
    //此方法也是Proxy类下的方法
    private static Class<?> getProxyClass0(ClassLoader loader,
                                           Class<?>... interfaces) {
    
    
        // 接口个数超过65535个,直接抛出异常。
        if (interfaces.length > 65535) {
    
    
            throw new IllegalArgumentException("interface limit exceeded");
        }

        // 如果代理类被指定的类加载器loader定义了,并实现了给定的接口interfaces,
        // 那么就返回缓存的代理类对象,否则使用ProxyClassFactory创建代理类。
        return proxyClassCache.get(loader, interfaces);
    }
}

proxyClassCache介绍

private static final WeakCache<ClassLoader, Class<?>[], Class<?>>
        proxyClassCache = new WeakCache<>(new KeyFactory(), new ProxyClassFactory());

proxyClassCache是个WeakCache类的对象,调用proxyClassCache.get(loader, interfaces); 可以得到缓存的代理类或创建代理类(没有缓存的情况)。说明WeakCache中有get这个方法。先看下WeakCache类的定义(这里先只给出变量的定义和构造函数):

// K代表key的类型,P代表参数的类型,V代表value的类型。
// WeakCache<ClassLoader, Class<?>[], Class<?>>  proxyClassCache  说明proxyClassCache存的值是Class<?>对象,正是我们需要的代理类对象。
final class WeakCache<K, P, V> {
    
    
    private final ConcurrentMap<Object, ConcurrentMap<Object, Supplier<V>>> map = new ConcurrentHashMap<>();
    private final BiFunction<K, P, ?> subKeyFactory;
    private final BiFunction<K, P, V> valueFactory;
    
    public WeakCache(BiFunction<K, P, ?> subKeyFactory,BiFunction<K, P, V> valueFactory) {
    
    
        this.subKeyFactory = Objects.requireNonNull(subKeyFactory);
        this.valueFactory = Objects.requireNonNull(valueFactory);
    }
}

其中map变量是实现缓存的核心变量,其可以翻译为:ConcurrentMap<key, ConcurrentMap<subKey, value>>

  • key:是传进来的Classloader进行包装后的对象,cacheKey
  • subKey:是由WeakCache构造函数传人的KeyFactory()生成的。
  • value:是产生代理类的对象,是由WeakCache构造函数传人的ProxyClassFactory()生成的。

通过sub-key拿到一个Supplier<Class<?>>对象,然后调用这个对象的get方法,最终得到代理类的Class对象。

回到proxyClassCache.get(loader, interfaces);源码

final class WeakCache<K, P, V> {
    
    

    private final ReferenceQueue<K> refQueue = new ReferenceQueue<>();
    private final ConcurrentMap<Object, ConcurrentMap<Object, Supplier<V>>> map
        = new ConcurrentHashMap<>();
    private final ConcurrentMap<Supplier<V>, Boolean> reverseMap
        = new ConcurrentHashMap<>();
    private final BiFunction<K, P, ?> subKeyFactory;
    private final BiFunction<K, P, V> valueFactory;

    public WeakCache(BiFunction<K, P, ?> subKeyFactory,BiFunction<K, P, V> valueFactory) {
    
    
        this.subKeyFactory = Objects.requireNonNull(subKeyFactory);
        this.valueFactory = Objects.requireNonNull(valueFactory);
    }

    // K和P就是WeakCache定义中的泛型,key是类加载器,parameter是接口类数组
    public V get(K key, P parameter) {
    
    
        // 检查parameter不为空
        Objects.requireNonNull(parameter);
		// 清除无效的缓存
        expungeStaleEntries();
		// cacheKey就是刚才提到的key,也就是一级key
        Object cacheKey = CacheKey.valueOf(key, refQueue);

        // 根据key获取到ConcurrentMap<Object, Supplier<V>> valuesMap对象
        // 如果valuesMap之前不存在,则新创建一个对象放进去
        ConcurrentMap<Object, Supplier<V>> valuesMap = map.get(cacheKey);
        if (valuesMap == null) {
    
    
            ConcurrentMap<Object, Supplier<V>> oldValuesMap
                = map.putIfAbsent(cacheKey,
                                  valuesMap = new ConcurrentHashMap<>());
            if (oldValuesMap != null) {
    
    
                valuesMap = oldValuesMap;
            }
        }

        // 生成subKey
        Object subKey = Objects.requireNonNull(subKeyFactory.apply(key, parameter));
        // 通过subKey获取supplier对象
        Supplier<V> supplier = valuesMap.get(subKey);
        // supplier实际上就是这个factory
        Factory factory = null;

        while (true) {
    
    
            // 如果缓存里有supplier ,那就直接通过get方法,得到代理类对象,返回,就结束了
            if (supplier != null) {
    
    
                V value = supplier.get();
                if (value != null) {
    
    
                    return value;
                }
            }
            // 下面的所有代码目的就是:如果缓存中没有supplier,则创建一个Factory对象,把factory对象在多线程的环境下安全的赋给supplier。
            // 因为是在while(true)中,赋值成功后又回到上面去调get方法,返回才结束。
            if (factory == null) {
    
    
                factory = new Factory(key, parameter, subKey, valuesMap);
            }

            if (supplier == null) {
    
    
                supplier = valuesMap.putIfAbsent(subKey, factory);
                if (supplier == null) {
    
    
                    // successfully installed Factory
                    supplier = factory;
                }
                // else retry with winning supplier
            } else {
    
    
                if (valuesMap.replace(subKey, supplier, factory)) {
    
    
                    // successfully replaced
                    // cleared CacheEntry / unsuccessful Factory
                    // with our Factory
                    supplier = factory;
                } else {
    
    
                    // retry with current supplier
                    supplier = valuesMap.get(subKey);
                }
            }
        }
    }

所以接下来我们看Factory类中的get方法:

private final class Factory implements Supplier<V> {
    
    

        private final K key;
        private final P parameter;
        private final Object subKey;
        private final ConcurrentMap<Object, Supplier<V>> valuesMap;

        Factory(K key, P parameter, Object subKey,
                ConcurrentMap<Object, Supplier<V>> valuesMap) {
    
    
            this.key = key;
            this.parameter = parameter;
            this.subKey = subKey;
            this.valuesMap = valuesMap;
        }

        @Override
        public synchronized V get() {
    
     // serialize access
            // 重新检查得到的supplier是不是当前对象
            Supplier<V> supplier = valuesMap.get(subKey);
            if (supplier != this) {
    
    
                return null;
            }
            // else still us (supplier == this)

            // create new value
            V value = null;
            try {
    
    
                // 代理类就是在这个位置调用valueFactory生成的
                // valueFactory就是我们传入的 new ProxyClassFactory()
                // 一会我们分析ProxyClassFactory()的apply方法
                value = Objects.requireNonNull(valueFactory.apply(key, parameter));
            } finally {
    
    
                if (value == null) {
    
     // remove us on failure
                    valuesMap.remove(subKey, this);
                }
            }
            // the only path to reach here is with non-null value
            assert value != null;

            // 把value包装成弱引用
            CacheValue<V> cacheValue = new CacheValue<>(value);

            // reverseMap是用来实现缓存的有效性
            reverseMap.put(cacheValue, Boolean.TRUE);

            // try replacing us with CacheValue (this should always succeed)
            if (!valuesMap.replace(subKey, this, cacheValue)) {
    
    
                throw new AssertionError("Should not reach here");
            }

            // successfully replaced us with new CacheValue -> return the value
            // wrapped by it
            return value;
        }
    }

接下来到ProxyClassFactory的apply方法,代理类就是在这里生成的:

// 这里的BiFunction<T, U, R>是个函数式接口,可以理解为用T,U两种类型做参数,得到R类型的返回值
private static final class ProxyClassFactory
        implements BiFunction<ClassLoader, Class<?>[], Class<?>>
    {
    
    
        // 所有代理类名字的前缀
        private static final String proxyClassNamePrefix = "$Proxy";

        // 用于生成代理类名字的计数器
        private static final AtomicLong nextUniqueNumber = new AtomicLong();

        @Override
        public Class<?> apply(ClassLoader loader, Class<?>[] interfaces) {
    
    

            Map<Class<?>, Boolean> interfaceSet = new IdentityHashMap<>(interfaces.length);
            // 验证代理接口,可不看
            for (Class<?> intf : interfaces) {
    
    
                /*
                 * Verify that the class loader resolves the name of this
                 * interface to the same Class object.
                 */
                Class<?> interfaceClass = null;
                try {
    
    
                    interfaceClass = Class.forName(intf.getName(), false, loader);
                } catch (ClassNotFoundException e) {
    
    
                }
                if (interfaceClass != intf) {
    
    
                    throw new IllegalArgumentException(
                        intf + " is not visible from class loader");
                }
                /*
                 * Verify that the Class object actually represents an
                 * interface.
                 */
                if (!interfaceClass.isInterface()) {
    
    
                    throw new IllegalArgumentException(
                        interfaceClass.getName() + " is not an interface");
                }
                /*
                 * Verify that this interface is not a duplicate.
                 */
                if (interfaceSet.put(interfaceClass, Boolean.TRUE) != null) {
    
    
                    throw new IllegalArgumentException(
                        "repeated interface: " + interfaceClass.getName());
                }
            }
			// 生成的代理类的包名 
            String proxyPkg = null;     // package to define proxy class in
            // 代理类访问控制符: public ,final
            int accessFlags = Modifier.PUBLIC | Modifier.FINAL;

            // 验证所有非公共的接口在同一个包内;公共的就无需处理
            // 生成包名和类名的逻辑,包名默认是com.sun.proxy,类名默认是$Proxy 加上一个自增的整数值
            // 如果被代理类是 non-public proxy interface ,则用和被代理类接口一样的包名
            for (Class<?> intf : interfaces) {
    
    
                int flags = intf.getModifiers();
                if (!Modifier.isPublic(flags)) {
    
    
                    accessFlags = Modifier.FINAL;
                    String name = intf.getName();
                    int n = name.lastIndexOf('.');
                    String pkg = ((n == -1) ? "" : name.substring(0, n + 1));
                    if (proxyPkg == null) {
    
    
                        proxyPkg = pkg;
                    } else if (!pkg.equals(proxyPkg)) {
    
    
                        throw new IllegalArgumentException(
                            "non-public interfaces from different packages");
                    }
                }
            }

            if (proxyPkg == null) {
    
    
                // if no non-public proxy interfaces, use com.sun.proxy package
                proxyPkg = ReflectUtil.PROXY_PACKAGE + ".";
            }

            /*
             * Choose a name for the proxy class to generate.
             */
            long num = nextUniqueNumber.getAndIncrement();
            // 代理类的完全限定名,如com.sun.proxy.$Proxy0.class
            String proxyName = proxyPkg + proxyClassNamePrefix + num;

            // 核心部分,生成代理类的字节码
            byte[] proxyClassFile = ProxyGenerator.generateProxyClass(
                proxyName, interfaces, accessFlags);
            try {
    
    
                // 把代理类加载到JVM中,返回代理类的class对象,至此动态代理过程基本结束了
                return defineClass0(loader, proxyName,
                                    proxyClassFile, 0, proxyClassFile.length);
            } catch (ClassFormatError e) {
    
    
                throw new IllegalArgumentException(e.toString());
            }
        }
    }

实际生成的字节码的方法是:ProxyGenerator.generateProxyClass(proxyName, interfaces, accessFlags);

	public static byte[] generateProxyClass(final String var0, Class<?>[] var1, int var2) {
    
    
        ProxyGenerator var3 = new ProxyGenerator(var0, var1, var2);
        final byte[] var4 = var3.generateClassFile();
    	// 将要生成代理类的字节码文件保存在磁盘中
        if (saveGeneratedFiles) {
    
    
            AccessController.doPrivileged(new PrivilegedAction<Void>() {
    
    
                public Void run() {
    
    
                    try {
    
    
                        int var1 = var0.lastIndexOf(46);
                        Path var2;
                        if (var1 > 0) {
    
    
                            Path var3 = Paths.get(var0.substring(0, var1).replace('.', File.separatorChar));
                            Files.createDirectories(var3);
                            var2 = var3.resolve(var0.substring(var1 + 1, var0.length()) + ".class");
                        } else {
    
    
                            var2 = Paths.get(var0 + ".class");
                        }

                        Files.write(var2, var4, new OpenOption[0]);
                        return null;
                    } catch (IOException var4x) {
    
    
                        throw new InternalError("I/O exception saving generated file: " + var4x);
                    }
                }
            });
        }

        return var4;
    }

如果想要生成的话可以添加如下参数:

System.getProperties().put("sun.misc.ProxyGenerator.saveGeneratedFiles", "true");

生成的代理类字节码如下:

package com.sun.proxy;

import com.bobo.proxy.HelloWord;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;

public final class $Proxy0 extends Proxy implements HelloWord {
    
    
    private static Method m1;
    private static Method m3;
    private static Method m2;
    private static Method m4;
    private static Method m0;

    public $Proxy0(InvocationHandler var1) throws  {
    
    
        super(var1);
    }

    public final boolean equals(Object var1) throws  {
    
    
        try {
    
    
            return (Boolean)super.h.invoke(this, m1, new Object[]{
    
    var1});
        } catch (RuntimeException | Error var3) {
    
    
            throw var3;
        } catch (Throwable var4) {
    
    
            throw new UndeclaredThrowableException(var4);
        }
    }

    public final void sayHello() throws  {
    
    
        try {
    
    
            super.h.invoke(this, m3, (Object[])null);
        } catch (RuntimeException | Error var2) {
    
    
            throw var2;
        } catch (Throwable var3) {
    
    
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final String toString() throws  {
    
    
        try {
    
    
            return (String)super.h.invoke(this, m2, (Object[])null);
        } catch (RuntimeException | Error var2) {
    
    
            throw var2;
        } catch (Throwable var3) {
    
    
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final void sayGoodBye() throws  {
    
    
        try {
    
    
            super.h.invoke(this, m4, (Object[])null);
        } catch (RuntimeException | Error var2) {
    
    
            throw var2;
        } catch (Throwable var3) {
    
    
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final int hashCode() throws  {
    
    
        try {
    
    
            return (Integer)super.h.invoke(this, m0, (Object[])null);
        } catch (RuntimeException | Error var2) {
    
    
            throw var2;
        } catch (Throwable var3) {
    
    
            throw new UndeclaredThrowableException(var3);
        }
    }

    static {
    
    
        try {
    
    
            m1 = Class.forName("java.lang.Object").getMethod("equals", Class.forName("java.lang.Object"));
            m3 = Class.forName("com.bobo.proxy.HelloWord").getMethod("sayHello");
            m2 = Class.forName("java.lang.Object").getMethod("toString");
            m4 = Class.forName("com.bobo.proxy.HelloWord").getMethod("sayGoodBye");
            m0 = Class.forName("java.lang.Object").getMethod("hashCode");
        } catch (NoSuchMethodException var2) {
    
    
            throw new NoSuchMethodError(var2.getMessage());
        } catch (ClassNotFoundException var3) {
    
    
            throw new NoClassDefFoundError(var3.getMessage());
        }
    }
}

总结:

通过上面生成的代理类我们很清晰的看到:

  1. 代理类继承java.lang.reflect.Proxy类,并实现了并实现了我们定义的HelloWord接口
  2. 通过反射获取被代理类每个方法的Method对象,定义成m1,m2,m3,m4,m5
  3. 代理类通过执行InvokeHandler的invoke方法,把被代理类的method对象和参数回调到InvokeHandler里面
  4. 最终通过我们在InvokeHandler里面的回调来实现对被代理类的调用,并且在调用被代理类时,可对其调用前后增强。

猜你喜欢

转载自blog.csdn.net/u013277209/article/details/111593453