SilverLining's Blog

Mini-Spring-v2 之 IoC 与 DI 优化

提出问题

在 V1 版本中,我们已经简单实现了 Mini Spring 最基本的功能,但是代码并不是很优雅,还存在着一些问题。例如:

本文将首先对 MiniSpring 中的 IoC 与 DI 的逻辑进行优化和重构。

从 Servlet 到 AppcationContext

先来回顾一下我们之前写的 IoC 与 DI 部分的流程:

  1. 调用 Servlet.init() 方法
  2. 读取配置文件(xml,yml)
  3. 扫描包路径下的类,将配置封装为 BeanDefinition 对象
  4. 初始化 IoC 容器,完成对象实例化,并封装为 BeanWrapper
  5. 完成 DI 依赖注入

DispatcherSerlvet 作为 SpringMVC 的入口,我们之前的实现方式是在其 init()方法中完成了 IoC 容器的初始化。而在我们实际使用 Spring 的经验中,我们见得最多的是 ApplicationContext 类。似乎 Spring 托管的所有实例 Bean 都可以通过调用 getBean()方法来得到。

那么这个 ApplicationContext 又是从何而来的呢?从 Spring 源码中我们可以看到, DispatcherServlet 的类图如下:

从上面复杂的调用关系,我们可以简单的得出一个结论:Spring 在 Servlet 的 init() 方法中,初始化了 IoC 容器和 SpringMVC 所依赖的九大组件。

流程分析

在重构前,我们基于 Spring 的使用经验,来做一些合理的推测:

  1. Spring 中的 ApplicationContext 是一个类似于工厂类的角色;
  2. ApplicationContext.getBean 方法作用是根据 BeanName 来获取实例;
  3. Spring IoC 中的默认 Scope 是单例,并且是延迟加载(Lazy)的;
  4. Spring 中的依赖注入肯定发生在 IoC 初始化之后,其 DI 过程是在调用了 getBean 方法后完成的。

基于以上推测出,第 0 步应当是,在调用 Servlet.init() 方法时,就要完成 ApplicationContext 对象的初始化。

在读取配置文件时,考虑到要支持多种格式,如 xml,properties,注解等,此处可以抽象出 BeanDefinitionReader 类来处理配置的读取。

其后,BeanDefinitionReader 类还需要负责,将读取到的配置载入到内存中,此处我们抽象出 BeanDefinition 类来承载配置内容。

将 Bean 对象放入 IoC 容器中时,抽象出 BeanWrapper 对象来表示增强版的 Bean 实例对象。

流程图表示如下:

引入 Bean 包装类

ApplicationContext 类中最重要的方法有两个:

为了实现在容器中托管对象,我们先抽象出两个 Bean 对象的包装类。

BeanWrapper

BeanWrapper 的主要作用是存放 Bean 对象,和他的 Class 信息。这里的 Bean 对象可能是原对象,也可能是经过动态代理增强后的对象。

public class MyBeanWrapper {
    private Object wrapperInstance;
    private Class<?> wrapperClass;

    public MyBeanWrapper(Object wrapperInstance) {
        this.wrapperInstance = wrapperInstance;
        this.wrapperClass = wrapperInstance.getClass();
    }

    public Object getWrapperInstance() {
        return wrapperInstance;
    }

    // 返回代理后的 Class,可能是 $Proxy0
    public Class<?> getWrapperClass() {
        return wrapperClass;
    }
}

BeanDefinition

BeanDefinition 类暂时没什么作用,我们就先把 BeanName(即在工厂中保存的用于 getBean(name) 的名字)和 BeanClassName 保存起来。

public class MyBeanDefinition {
    private String factoryBeanName;
    private String beanClassName;

    // getters and setters...
}

IoC 顶层设计

从 Servlet 开始的初始化

在 Servlet 的初始化方法中,读取 web.xml 中的配置,然后创建出 ApplicationContext 对象。

public class MyDispatcherServlet extends HttpServlet {
    public void init(ServletConfig config) throws ServletException {
        // 初始化 ApplicationContext
        String contextConfigLocation = config.getInitParameter(CONTEXT_CONFIG_LOCATION);
        String[] configLocations = new String[] {contextConfigLocation};
        MyApplicationContext applicationContext = new MyApplicationContext(configLocations);
    }
}

ApplicationContext 类的初始化

public class MyApplicationContext {

    private final Logger logger = LoggerFactory.getLogger(MyApplicationContext.class);

    private final Map<String, MyBeanDefinition> beanDefinitionMap = new ConcurrentHashMap<>();
    private MyBeanDefinitionReader reader;

    // 保存原始 Bean 对象
    private final Map<String, Object> factoryBeanObjectCache = new ConcurrentHashMap<>();
    // 保存 BeanWrapper 对象的 IoC 容器
    private final Map<String, MyBeanWrapper> factoryBeanInstanceCache = new ConcurrentHashMap<>();

    public MyApplicationContext(String... configLocations) {

        try {
            // 1. 加载配置文件
            reader = new MyBeanDefinitionReader(configLocations);

            // 2. 解析配置文件,封装成 BeanDefinition 对象
            List<MyBeanDefinition> beanDefinitionList = reader.loadBeanDefinitions();

            // 3. 缓存 BeanDefinition 至 beanDefinitionMap
            // 此时所有的 Bean 还未真正实例化,还只是配置阶段
            this.registerBeanDefinitions(beanDefinitionList);

            // 4. 完成非延迟加载的类的依赖注入
            // 真正的实例化是从 getBean() 方法开始的
            this.doAutoWiredForNonLazyClass();

        } catch (Exception e) {
            logger.error("Exception init ApplicationContext", e);
        }
    }

    private void doAutoWiredForNonLazyClass() {
        // 暂时只处理非延迟加载
        for (Map.Entry<String, MyBeanDefinition> beanDefinitionEntry :
                this.beanDefinitionMap.entrySet()) {
            String beanName = beanDefinitionEntry.getKey();
            this.getBean(beanName);
        }
    }

    private void registerBeanDefinitions(List<MyBeanDefinition> beanDefinitionList) throws Exception {
        for (MyBeanDefinition beanDefinition : beanDefinitionList) {
            if (beanDefinitionMap.containsKey(beanDefinition.getFactoryBeanName()) ||
                    beanDefinitionMap.containsKey(beanDefinition.getBeanClassName())) {
                throw new Exception("BeanDefinition " +
                        beanDefinition.getFactoryBeanName() + " already exists");
            }
            beanDefinitionMap.put(beanDefinition.getFactoryBeanName(), beanDefinition);
            beanDefinitionMap.put(beanDefinition.getBeanClassName(), beanDefinition);
        }
    }
}

这里的 BeanDefinitionReader 在构造方法中,会执行包扫描,然后把扫描出的需要注入的 Bean 对象的 BeanName 和 BeanClassName 封装成 BeanWrapper,保存起来。

public class MyBeanDefinitionReader {

    private static final String CONFIG_SCAN_PACKAGE = "scanPackage";
    private final Logger logger = LoggerFactory.getLogger(MyBeanDefinitionReader.class);

    // 配置文件中的设置
    private final Properties contextConfigProperties = new Properties();
    // 保存扫描的结果,需要注册的 Bean Classes
    private List<String> registryBeanClasses = new ArrayList<>();

    public MyBeanDefinitionReader(String... configLocations) {
        for (String configLocation : configLocations) {
            // 读取配置文件
            this.loadContextConfig(configLocation);
            // 扫描包路径下的类
            this.scanPackage(contextConfigProperties.getProperty(CONFIG_SCAN_PACKAGE));
        }
    }

    public List<MyBeanDefinition> loadBeanDefinitions() {
        List<MyBeanDefinition> result = new ArrayList<>();
        try {
            for (String className : registryBeanClasses) {
                Class<?> beanClass = Class.forName(className);
                // 保存 BeanName, FullClassName
                String beanName = StringUtils.lowercaseInitial(beanClass.getSimpleName());
                String beanClassName = beanClass.getName();
                // BeanName 的优先级
                // 1. 使用自定义的注解
                // 2. 默认类名首字母小写
                result.add(this.createBeanDefinition(beanName, beanClassName));
                // 3. 最后接口注入
                for (Class<?> itf : beanClass.getInterfaces()) {
                    result.add(createBeanDefinition(itf.getName(), beanClass.getName()));
                }
            }
        } catch (Exception e) {
            logger.error("Fail to instantiate class: ", e);
        }
        return result;
    }

    private void loadContextConfig(String contextConfigLocation) {
        // 策略模式
        if (contextConfigLocation.startsWith("classpath:")) {
            String configPath = contextConfigLocation.replace("classpath:", "");
            try (InputStream inputStream = this.getClass().getClassLoader()
                    .getResourceAsStream(configPath)) {
                contextConfigProperties.load(inputStream);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    private void scanPackage(String scanPackage) {
        // 将包名替换为文件路径
        String resourcePath = "/" + scanPackage.replace(".", "/");
        URL packageURL = this.getClass().getClassLoader().getResource(resourcePath);
        if (packageURL == null) {
            logger.error("Configured package [{}] is not found", resourcePath);
            throw new RuntimeException("Configured package is not found");
        }

        File rootDir = new File(packageURL.getFile());
        for (File dir : rootDir.listFiles()) {
            if (dir.isDirectory()) {
                this.scanPackage(scanPackage + "." + dir.getName());
            } else {
                if (!dir.getName().endsWith(".class")) {
                    continue;
                }
                String clazzFullName = scanPackage + "." +
                        dir.getName().replace(".class", "");
                registryBeanClasses.add(clazzFullName);
            }
        }
    }

    private MyBeanDefinition createBeanDefinition(String beanName, String beanClassName) {
        MyBeanDefinition beanDefinition = new MyBeanDefinition();
        beanDefinition.setFactoryBeanName(beanName);
        beanDefinition.setBeanClassName(beanClassName);
        return beanDefinition;
    }
}

DI 注入部分

至此,IoC 容器已经准备完毕,接下来我们来完成 DI 的部分,其核心流程是通过 getBean()方法触发的。

public class MyApplicationContext {
    public Object getBean(String beanName) {
        // 1. 拿到 BeanDefinition 配置信息
        MyBeanDefinition beanDefinition = beanDefinitionMap.get(beanName);
        // 2. 通过反射实例化 Bean
        Object instance = this.instantiateBean(beanName, beanDefinition);
        // 3. 封装成 BeanWrapper 对象
        // 装饰器模式的体现
        MyBeanWrapper beanWrapper = new MyBeanWrapper(instance);
        // 4. 将 BeanWrapper 保存至 IoC 容器中
        factoryBeanInstanceCache.put(beanName, beanWrapper);
        // 5. 执行依赖注入
        this.populateBean(beanName, beanDefinition, beanWrapper);

        return beanWrapper.getWrapperInstance();
    }

    /**
     * 处理 BeanWrapper 的依赖注入,注意循环依赖的处理
     *
     * @param beanName          找到 Bean
     * @param beanDefinition    如何注入的配置信息
     * @param beanWrapper       注入的对象
     */
    private void populateBean(String beanName, MyBeanDefinition beanDefinition, MyBeanWrapper beanWrapper) {
        Object instance = beanWrapper.getWrapperInstance();
        Class<?> clazz = beanWrapper.getWrapperClass();
        if (!( clazz.isAnnotationPresent(MyService.class) ||
                clazz.isAnnotationPresent(MyController.class) ))  {
            return;
        }
        // 获取所有 public / protected / private 方法
        for (Field field : clazz.getDeclaredFields()) {
            if (!field.isAnnotationPresent(MyAutowired.class)) {
                continue;
            }
            String autoWiredBeanName = field.getAnnotation(MyAutowired.class)
                    .value().trim();
            if (autoWiredBeanName.isEmpty()) {
                // 没有自定义类名,默认使用类名注入
                autoWiredBeanName = field.getType().getName();
            }
            // 即使是非 public 方法,但是设置了注解,也需要强制注入
            field.setAccessible(true);

            // TODO:对于循环引用,需要两次循环处理
            // 1. 把第一次读取结果为空的 BeanDefinition 存到第一个缓存
            // 2. 第一次循环结束后,再遍历一次,检查第一次的缓存为空的 Bean 再进行赋值
            try {
                if (!this.factoryBeanInstanceCache.containsKey(autoWiredBeanName)) {
                    continue;
                }
                field.set(instance, this.factoryBeanInstanceCache
                        .get(autoWiredBeanName).getWrapperInstance());
            } catch (IllegalAccessException e) {
                logger.error("Set {}.{} fail", clazz.getName(), field.getName());
            }
        }
    }

    private Object instantiateBean(String beanName, MyBeanDefinition beanDefinition) {
        Object instance = null;
        try {
            Class<?> beanClass = Class.forName(beanDefinition.getBeanClassName());
            instance = beanClass.getDeclaredConstructor().newInstance();
            // 备份原始的 Bean Object
            factoryBeanObjectCache.put(beanName, instance);
            return instance;
        } catch (Exception e) {
            logger.error("Error init bean, exception: ", e);
        }
        return instance;
    }
}

一些细节问题

Bean 的单例模式

之前的实现中,我们忽