SpringMvc的手写版(PS:只是闲来无事写的简化版,仅供大家理解SpringMvc的运作原理)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sun5769675/article/details/77752343

最近手头正好有些时间,想着写点什么好呢?后来看到了一篇帖子说面试的时候有面试官问他能不能手写一套SpringMvc出来,不拉不拉的….不多说了。

所以想着就写写试试,捋了捋思路,无非就是三点(大神勿喷!):
1. 实例化
2. 注入
3. url映射

连起来说就是对加了@Controller、@Service注解的对象进行实例化,然后对这些对象中的某些加了@Autowired注解的属性进行依赖注入,然后对Controller中加了@RequestMapping注解的方法做url映射,用于请求来到的时候根据url映射到需要执行的方法,同时将传递的参数注入到方法中。

额。。。这个说的有点敷衍,确实打字太费劲了,最喜欢直接贴代码了!不过本着敬业的原则还是再重新说一下这个项目具体实现了哪些功能及使用方法。
1. 实例化规则,加了@Controller、@Service注解的对象默认beanName就是类名首字母小写,同时也可以写别名,那么beanName就是别名
2. 注入规则,加了@Autowired注解的属性,默认是通过这个属性的类型来找他的实现类,如果该接口实现类有多个那么抛出异常!那么我非得有两个实现类怎么办呢?可以,你需要给实现类起个别名比如@Service(“woShiNumberOne”),然后注入的时候@Autowired(“woShiNumberOne”)这样注入就可以了。
3. url映射方法,就是扫描每个Controller,如果类名上加了@RequestMapping注解那么这算是根目录,然后逐个扫描方法,只要方法上加了@RequestMapping注解那么这就是子目录,最后会将根目录+子目录拼接到一起映射到当前controller的当前方法。
4. 方法参数注入规则,自动识别当前方法有多少个参数,除了request和response这两个参数以外的其他任何参数都需要加@RequestParam注解,用来给这个参数定义别名,否则无法注入进来!而且由于是简化版这里只支持基本数据类型的注入,不支持对象的注入,这点还没来得及写。

这样可以了吧,下面可以开心的来读代码啦!!!

想要下载源码的小伙伴可以看这里,代码已经上传到了码云大家可以通过git下载:
https://git.oschina.net/sunchenbin/SpringMvcSimulate.git

不下源码的直接来看下项目结构:
这里写图片描述
OK下面我们先来自定义5个注解:

/** 
 * @Description controller注册的注解
 * @author chenbin.sun
 * @date 2017年8月30日下午5:12:55
 * 
 */  
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface Controller {

    /**
     * 表示给controller注册别名
     * @return
     */
    String value() default "";
}
/** 
 * @Description service注册的注解
 * @author chenbin.sun
 * @date 2017年8月30日下午5:12:40
 * 
 */  
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface Service {

    /**
     * 表示给service注册别名
     * @return
     */
    String value() default "";
}
/** 
 * @Description controller和方法上的注解
 * @author chenbin.sun
 * @date 2017年8月30日下午5:16:05
 * 
 */  
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestMapping {

    /**
     * 表示访问该方法的url
     * @return
     */
    String value() default "";
}
/** 
 * @Description 自动注入注解(如果不加别名自动通过接口类型注入实现类)
 * @author chenbin.sun
 * @date 2017年8月30日下午5:12:40
 * 
 */  
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Autowired {

    /**
     * 表示给filed注入的bean的name
     * @return
     */
    String value() default "";
}
/** 
 * @Description 用作请求传参数的别名
 * @author chenbin.sun
 * @date 2017年8月31日下午2:31:42
 * 
 */  
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
public @interface RequestParam {

    /**
     * 表示参数的别名,必填
     * @return
     */
    String value();
}

这五个注解的作用上面都描述了,有一点值得注意@RequestParam这个注解是用于写在方法的参数上的,用于给这个参数起别名的,前端传过来的参数的key必须和这个相同才能注入进来,否则注入失败,当然这里其实还可以扩展一些应用场景,比方说该参数是否必须传入啊等等,但这些都要相应的方法去做对应的实现,目前我没有写那么多。

下面看我的控制层,很简单,几种测试场景我都写上去了

/** 
 * @Description 测试控制器类
 * @author chenbin.sun
 * @date 2017年8月31日下午6:57:35
 * 
 */  
@Controller
@RequestMapping("/test")
public class TestController {

    @Autowired("testServiceImpl")
    private TestService testService;

    @Autowired
    private TestService2 testService2;

    @RequestMapping("/doTest")
    public void test(HttpServletRequest request, HttpServletResponse response, @RequestParam("param") String param){
        String result = testService.test();
        try {
            response.getWriter().println("do service result:" + result);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @RequestMapping("/doTest2")
    public void test2(HttpServletRequest request, HttpServletResponse response){
        String result = testService2.test2();
        try {
            response.getWriter().println("do service2 result:" + result);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

然后是业务逻辑层,先上两个接口

public interface TestService {

    String test();
}
public interface TestService2 {

    String test2();
}

so easy对不对!继续看两个实现类,为了测试service之间也能相互注入,所以写了两个接口两个实现类

@Service
public class TestServiceImpl implements TestService {

    @Autowired
    private TestService2 testService2;

    @Override
    public String test() {
        System.out.println(testService2.test2());
        return "method test do success!";
    }

}
@Service
public class TestServiceImpl2 implements TestService2 {

    @Override
    public String test2() {
        return "method test2 do success!";
    }

}

OK,下面重点来了,核心代码支持全部注解特性的功能都在这里

/** 
 * @Description 请求几种处理类
 * @author chenbin.sun
 * @date 2017年8月30日下午5:23:54
 * 
 */  
public class DispatcherServlet extends HttpServlet {
    private static final long serialVersionUID = 1378531571714153483L;

    /** 要扫描的包,只有在这个包下并且加了注解的才会呗扫描到 */
    private static final String PACKAGE = "chenbin.sun";

    private static final String CONTROLLER_KEY = "controller";

    private static final String METHOD_KEY = "method";

    /** 存放Controller中url和方法的对应关系,格式:{url:{controller:实例化后的对象,method:实例化的方法}} */
    private static Map<String, Map<String, Object>> urlMethodMapping = new HashMap<>();

    public DispatcherServlet() {  
        super();  
    } 

    /**
     * 初始化方法,用于实例化扫描到的对象,并做注入和url映射(注:该方法逻辑上已经判断了,只执行一次)
     */
    @Override
    public void init(ServletConfig config) throws ServletException { 
        // 只处理一次
        if (urlMethodMapping.size() > 0) {
            return;
        }
        // 开始扫描包下全部class文件
        Set<Class<?>> classes = ClassTools.getClasses(PACKAGE);

        // 存放Controller和Service的Map,格式:{beanName:实例化后的对象} 
        Map<String, Object> instanceNameMap = new HashMap<String, Object>();
        // 存放Service接口类型与接口实例对象的Map,格式:{Service.instance.class:实现类实例化后的对象} 
        Map<Class<?>, Object> instanceTypeMap = new HashMap<Class<?>, Object>();

        // 组装instanceMap
        buildInstanceMap(classes, instanceNameMap, instanceTypeMap);

        // 开始注入
        doIoc(instanceNameMap, instanceTypeMap);

        // 注入完之后开始映射url和method
        buildUrlMethodMapping(instanceNameMap, urlMethodMapping);
    }

    @Override  
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {  
        this.doPost(req, resp);  
    }  

    @Override  
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {  
        // 完整路径
        String url = req.getRequestURI();
        // 跟路径
        String path = req.getContextPath();
        // 计算出method上配置的路径
        String finallyUrl = url.replace(path, "");

        // 取出这个url对应的Controller和method
        Map<String, Object> map = urlMethodMapping.get(finallyUrl);
        if (map == null) {
            throw new RuntimeException("请求地址不存在!");
        }
        Method method = (Method) map.get(METHOD_KEY);
        try {
            // 封装需要注入的参数,目前只支持request和response以及加了@RequestParam标签的基本数据类型的参数注入
            List<Object> paramValue = buildParamObject(req, resp, method);

            // 没有参数的场合
            if (paramValue.size() == 0) {               
                method.invoke(map.get(CONTROLLER_KEY));
            }else {
                // 有参数的场合
                method.invoke(map.get(CONTROLLER_KEY), paramValue.toArray());
            }
        } catch (Exception e) {
            throw new RuntimeException("执行url对应的method失败!");
        }
    }

    /**
     * 封装需要注入的参数,目前只支持request和response以及加了@RequestParam标签的基本数据类型的参数注入
     * @param req
     * @param resp
     * @param method
     * @return
     */
    private List<Object> buildParamObject(HttpServletRequest req, HttpServletResponse resp, Method method) {

        // 封装需要注入的参数,目前只支持request和response以及加了@RequestParam标签的基本数据类型的参数注入
        Parameter[] parameters = method.getParameters();
        List<Object> paramValue = new ArrayList<>();
        for (Parameter parameter : parameters) {
            // 当前参数有别名注解并且别名不为空
            if(parameter.isAnnotationPresent(RequestParam.class) && !parameter.getAnnotation(RequestParam.class).value().isEmpty()){
                // 我们获取
                String value = req.getParameter(parameter.getAnnotation(RequestParam.class).value());
                paramValue.add(value);
            }else if (parameter.getParameterizedType().getTypeName().contains("HttpServletRequest")) {
                paramValue.add(req);
            }else if (parameter.getParameterizedType().getTypeName().contains("HttpServletResponse")) {
                paramValue.add(resp);
            }else{
                paramValue.add(null);
            }
            // 这里只做了request和response以及基本数据类型的参数注入,如果要做对象的注入也是可以写,这里暂时就不写了
            // TODO: 做对象的注入
        }
        return paramValue;
    }  

    /**
     * 注入完之后开始映射url和method
     * @param instanceMap
     * @param urlMethodMapping
     */
    private void buildUrlMethodMapping(Map<String, Object> instanceMap,
            Map<String, Map<String, Object>> urlMethodMapping) {
        // 注入完之后开始映射url和method
        // 组装urlMethodMapping
        for (Entry<String, Object> entry : instanceMap.entrySet()) {

            // 迭代出所有的url
            String parenturl = "";

            // 判断Controller上是否加了requestMapping
            if (entry.getValue().getClass().isAnnotationPresent(RequestMapping.class)) {
                parenturl = entry.getValue().getClass().getAnnotation(RequestMapping.class).value();
            }

            // 取出全部的method
            Method[] methods = entry.getValue().getClass().getMethods();

            // 迭代全部的方法,检查哪些方法上加了requestMaping注解
            for (Method method : methods) {
                if (method.isAnnotationPresent(RequestMapping.class)) {

                    // 得到一个完整的url请求
                    String url = parenturl + method.getAnnotation(RequestMapping.class).value();
                    Map<String, Object> value = new HashMap<>();
                    value.put(CONTROLLER_KEY, entry.getValue());
                    value.put(METHOD_KEY, method);
                    urlMethodMapping.put(url, value );
                }
            }
        }
    }

    /**
     * 根据实例Map开始注入
     * @param instanceMap
     */
    private void doIoc(Map<String, Object> instanceMap, Map<Class<?>, Object> instanceTypeMap) {
        // 开始注入,我们只对加了@Controller和@Service标签中的,属性加了@autowired的进行注入操作
        for (Entry<String, Object> entry : instanceMap.entrySet()) {

            // 取出全部的属性
            Field[] fields = entry.getValue().getClass().getDeclaredFields();

            // 循环属性校验哪些是加了@autowired注解的
            for (Field field : fields) {
                field.setAccessible(true);// 可访问私有属性

                // 有注解的时候
                if (field.isAnnotationPresent(Autowired.class)) {

                    // 没有配别名注入的时候
                    if (field.getAnnotation(Autowired.class).value().isEmpty()) {
                        // 直接获取
                        try {
                            // 根据类型来获取他的实现类
                            Object object = instanceTypeMap.get(field.getType());
                            field.set(entry.getValue(), object);
                        } catch (IllegalArgumentException | IllegalAccessException e) {
                            // TODO Auto-generated catch block
                            e.printStackTrace();
                        }
                    } else {
                        try {
                            // 将被注入的对象
                            Object object = instanceMap.get(field.getAnnotation(Autowired.class).value());
                            field.set(entry.getValue(), object);
                        } catch (Exception e) {
                            throw new RuntimeException("开始注入时出现了异常");
                        }
                    }
                }
            }
        }
    }

    /**
     * 组装instanceMap
     * @param classes
     * @param instanceMap
     */
    private void buildInstanceMap(Set<Class<?>> classes, Map<String, Object> instanceMap, Map<Class<?>, Object> instanceTypeMap) {
        // 开始循环全部class
        for (Class<?> clasz : classes) {

            // 组装instanceMap
            // 判断是否是是加了Controller注解的java对象
            if (clasz.isAnnotationPresent(Controller.class)) {
                try {
                    // 实例化对象
                    Object obj = clasz.newInstance();
                    Controller controller = clasz.getAnnotation(Controller.class);

                    // 如果没有设置别名,那么用类名首字母小写做key
                    if (controller.value().isEmpty()) {
                        instanceMap.put(firstLowerName(clasz.getSimpleName()), obj);
                    }else{
                        // 如果设置了别名那么用别名做key
                        instanceMap.put(controller.value(), obj);
                    }
                } catch (Exception e) {
                    throw new RuntimeException("初始化instanceMap时在处理Controller注解时出现了异常");
                }               
            }else if(clasz.isAnnotationPresent(Service.class)) {
                // 实例化对象
                Object obj = null;
                try {
                    // 实例化对象
                    obj = clasz.newInstance();
                    Service service = clasz.getAnnotation(Service.class);

                    // 如果没有设置别名,那么用类名首字母小写做key
                    if (service.value().isEmpty()) {
                        instanceMap.put(firstLowerName(clasz.getSimpleName()), obj);
                    }else{
                        // 如果设置了别名那么用别名做key
                        instanceMap.put(service.value(), obj);
                    }
                } catch (Exception e) {
                    throw new RuntimeException("初始化instanceMap时在处理Service注解时出现了异常");
                }
                // 实现的接口数组
                Class<?>[] interfaces = clasz.getInterfaces();
                for (Class<?> class1 : interfaces) {
                    if (instanceTypeMap.get(class1) != null) {
                        throw new RuntimeException(class1.getName() + "接口不能被多个类实现!");
                    }
                    instanceTypeMap.put(class1, obj);
                }
            }else {
                continue;
            }
        }
    }

    /**
     * 首字母小写
     * @param name
     * @return
     */
    private String firstLowerName(String name) {
        name = name.substring(0, 1).toLowerCase() + name.substring(1);
       return  name;
    }
}

因为扫包的代码太长了,所以我单独写成了工具类,代码如下

public class ClassTools {
    /**
     * 从包package中获取所有的Class
     * 
     * @param pack
     * @return
     */
    public static Set<Class<?>> getClasses(String pack) {

        // 第一个class类的集合
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        // 是否循环迭代
        boolean recursive = true;
        // 获取包的名字 并进行替换
        String packageName = pack;
        String packageDirName = packageName.replace('.', '/');
        // 定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try {
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            // 循环迭代下去
            while (dirs.hasMoreElements()) {
                // 获取下一个元素
                URL url = dirs.nextElement();
                // 得到协议的名称
                String protocol = url.getProtocol();
                // 如果是以文件的形式保存在服务器上
                if ("file".equals(protocol)) {
                    System.err.println("file类型的扫描");
                    // 获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    // 以文件的方式扫描整个包下的文件 并添加到集合中
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                } else if ("jar".equals(protocol)) {
                    // 如果是jar包文件
                    // 定义一个JarFile
                    System.err.println("jar类型的扫描");
                    JarFile jar;
                    try {
                        // 获取jar
                        jar = ((JarURLConnection) url.openConnection()).getJarFile();
                        // 从此jar包 得到一个枚举类
                        Enumeration<JarEntry> entries = jar.entries();
                        // 同样的进行循环迭代
                        while (entries.hasMoreElements()) {
                            // 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件
                            JarEntry entry = entries.nextElement();
                            String name = entry.getName();
                            // 如果是以/开头的
                            if (name.charAt(0) == '/') {
                                // 获取后面的字符串
                                name = name.substring(1);
                            }
                            // 如果前半部分和定义的包名相同
                            if (name.startsWith(packageDirName)) {
                                int idx = name.lastIndexOf('/');
                                // 如果以"/"结尾 是一个包
                                if (idx != -1) {
                                    // 获取包名 把"/"替换成"."
                                    packageName = name.substring(0, idx).replace('/', '.');
                                }
                                // 如果可以迭代下去 并且是一个包
                                if ((idx != -1) || recursive) {
                                    // 如果是一个.class文件 而且不是目录
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        // 去掉后面的".class" 获取真正的类名
                                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                                        try {
                                            // 添加到classes
                                            classes.add(Class.forName(packageName + '.' + className));
                                        } catch (ClassNotFoundException e) {
                                            // log
                                            // .error("添加用户自定义视图类错误
                                            // 找不到此类的.class文件");
                                            e.printStackTrace();
                                        }
                                    }
                                }
                            }
                        }
                    } catch (IOException e) {
                        // log.error("在扫描用户定义视图时从jar包获取文件出错");
                        e.printStackTrace();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return classes;
    }

    /**
     * 以文件的形式来获取包下的所有Class
     * 
     * @param packageName
     * @param packagePath
     * @param recursive
     * @param classes
     */
    public static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive,
            Set<Class<?>> classes) {
        // 获取此包的目录 建立一个File
        File dir = new File(packagePath);
        // 如果不存在或者 也不是目录就直接返回
        if (!dir.exists() || !dir.isDirectory()) {
            // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
            return;
        }
        // 如果存在 就获取包下的所有文件 包括目录
        File[] dirfiles = dir.listFiles(new FileFilter() {

            // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
            public boolean accept(File file) {
                return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
            }
        });
        // 循环所有文件
        for (File file : dirfiles) {
            // 如果是目录 则继续扫描
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive,
                        classes);
            } else {
                // 如果是java类文件 去掉后面的.class 只留下类名
                String className = file.getName().substring(0, file.getName().length() - 6);
                try {
                    // 添加到集合中去
                    // classes.add(Class.forName(packageName + '.' +
                    // className));
                    // 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
                    classes.add(
                            Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
                } catch (ClassNotFoundException e) {
                    // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * 取出list对象中的某个属性的值作为list返回
     * 
     * @param objList
     * @param fieldName
     * @return
     */
    public static <T, E> List<E> getPropertyValueList(List<T> objList, String fieldName) {
        List<E> list = new ArrayList<E>();
        try {
            for (T object : objList) {
                Field field = object.getClass().getDeclaredField(fieldName);
                field.setAccessible(true);
                list.add((E) field.get(object));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        return list;
    }
}

然后将这个Servlet配置到web.xml中,代码如下

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://xmlns.jcp.org/xml/ns/javaee" xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_3_1.xsd" id="WebApp_ID" version="3.1">
  <display-name>SpringMvcSimulate</display-name>
  <servlet>
        <servlet-name>testServlet</servlet-name>
        <servlet-class>chenbin.sun.servlet.DispatcherServlet</servlet-class>
    </servlet>
    <!-- ... -->
    <servlet-mapping>
        <servlet-name>testServlet</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>

好了现在这个简化版的springmvc就完成了,启动tomcat后,输入请求url地址:
http://localhost:8080/SpringMvcSimulate/test/doTest?param=aaaaa
可以看到方法执行完毕后返回的信息。

猜你喜欢

转载自blog.csdn.net/sun5769675/article/details/77752343