手写spring(简易版)

本文版权归远方的风lyh和博客园共有,欢迎转载,但须保留此段声明,并给出原文链接,谢谢合作。如有错误之处,望不吝批评指正!

       理解Spring本质:

     相信大家之前使用 Spring MVC 时都会在 web.xml 中配置 Spring(如下)。这段配置其实就是声明了一个 Servlet:DispatcherServlet 通过其父类继承了 HttpServlet(HttpServlet 是抽象类而不是接口),并重写了 doGet、doPost 等请求处理方法,因此所有请求都会交给 DispatcherServlet 来处理。

    <servlet>
        <servlet-name>spring-mvc</servlet-name>
        <servlet-class>org.springframework.web.servlet.DispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>WEB-INF/spring/spring-mvc.xml</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
        <async-supported>true</async-supported>
    </servlet>
    <servlet-mapping>
        <servlet-name>spring-mvc</servlet-name>
        <url-pattern>/</url-pattern>
    </servlet-mapping>

       手写spring:

  配置

    web.xml:配置一个 Servlet,由它接收所有请求

<!DOCTYPE web-app PUBLIC
        "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
        "http://java.sun.com/dtd/web-app_2_3.dtd" >

<web-app>
    <display-name>Archetype Created Web Application</display-name>
    <servlet>
        <servlet-name>MySpringMVC</servlet-name>
        <servlet-class>cn.lyh.mySpring.MyDispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>context.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
    </servlet>

    <servlet-mapping>
        <servlet-name>MySpringMVC</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>

     context.properties:

#包扫描
scan.package=cn.lyh.mySpringTest

  注解类

/**
 * Marks a field that MyDispatcherServlet fills from its IoC container during startup.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface MyAutowired {

    /**
     * Container key consulted by the dispatcher when resolving the bean to inject.
     *
     * @return the explicit bean key, or an empty string when none is given
     */
    String value() default "";
}


/**
 * Marks a class as a request-handling controller; the dispatcher instantiates it and
 * scans its methods for request mappings.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface MyController {

    /**
     * Optional name for the controller.
     *
     * @return the controller name, or an empty string when none is given
     */
    String value() default "";
}

/**
 * Maps a URL path segment to a controller (class level) or a handler method (method level).
 * The dispatcher joins the class-level and method-level values with "/" to form the full path.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(value = {ElementType.TYPE, ElementType.METHOD})
public @interface MyRequestMapping {

    /**
     * Path segment for this class or method.
     *
     * @return the path segment, or an empty string when none is given
     */
    String value() default "";
}


/**
 * Binds a handler-method parameter to the request parameter with the given name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface MyRequestParam {

    /**
     * Name (alias) of the request parameter to bind; required, there is no default.
     *
     * @return the request parameter name
     */
    String value();
}

/**
 * Marks a class that the dispatcher should instantiate as its response-body handler.
 * The marked class is expected to implement ResponseBodyHandler.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface MyResponseAdvice {
}


/**
 * Marks a handler method whose non-null return value is written as a JSON response body.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface MyResponseBody {
}


/**
 * Marks a service class that the dispatcher instantiates and registers in its IoC container.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface MyService {

    /**
     * Optional container key; when blank the dispatcher derives one from the class name.
     *
     * @return the custom bean key, or an empty string when none is given
     */
    String value() default "";
}

   MyDispatcherServlet(核心实现):

    MyDispatcherServlet 继承了 HttpServlet,并重写了 doGet、doPost、init 方法:

package cn.lyh.mySpring;

import cn.lyh.mySpring.Handler.ResponseBodyHandler;
import cn.lyh.mySpring.annotation.*;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import org.apache.log4j.Logger;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URL;
import java.util.*;

/**
 * Minimal front controller modelled on Spring MVC's {@code DispatcherServlet}.
 *
 * <p>{@link #init(ServletConfig)} loads the properties file named by the {@code contextConfigLocation}
 * init-param, scans the package given by its {@code scan.package} key, instantiates {@link MyController},
 * {@link MyService} and {@link MyResponseAdvice} classes, injects {@link MyAutowired} fields and builds a
 * path -> method table from {@link MyRequestMapping}. GET and POST requests are routed through that table.
 *
 * <p>The internal maps are written only during {@code init} and are only read while serving requests.
 *
 * @author lyh
 */
public class MyDispatcherServlet extends HttpServlet {
    /** Properties loaded from the contextConfigLocation file. */
    private final Properties contextConfig = new Properties();
    /** Fully-qualified names of the classes found by the package scan. */
    private final List<String> classNames = new ArrayList<>();
    /** IoC container: bean key -> instance. */
    private final Map<String, Object> ioc = new HashMap<>();
    /** Normalized request path (context path stripped, repeated '/' collapsed) -> handler method. */
    private final Map<String, Method> handlerMapping = new HashMap<>();
    private static final Logger logger = Logger.getLogger(MyDispatcherServlet.class);
    /** Optional converter for @MyResponseBody results; the first @MyResponseAdvice class found wins. */
    private ResponseBodyHandler responseBodyHandler;

    /**
     * Handles GET exactly like POST so both verbs share one dispatch path.
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    /**
     * Dispatches the request to its mapped handler method.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatcherServlet(req, resp);
    }

    /**
     * Bootstraps the container from the {@code contextConfigLocation} init-param.
     *
     * @param config servlet configuration supplied by the container
     * @throws ServletException if any initialization step fails; the original exception is the cause
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // GenericServlet keeps the config here; without this call getServletConfig() returns null.
        super.init(config);
        String contextConfigLocation = config.getInitParameter("contextConfigLocation");
        try {
            initMyDispatcherServlet(contextConfigLocation);
        } catch (Exception e) {
            logger.error("mySpring init failed", e);
            throw new ServletException(e.getMessage(), e);
        }
    }

    /**
     * Routes one request to its handler method.
     *
     * @param request  current request
     * @param response current response
     */
    private void doDispatcherServlet(HttpServletRequest request, HttpServletResponse response) {
        invoke(request, response);
    }

    /**
     * Looks up the handler for the request path, binds its arguments, invokes it and writes the result.
     *
     * <p>Unknown paths get a 404 with a plain-text body. A non-null result is printed to the response
     * writer, serialized to JSON first when the method carries {@link MyResponseBody}. If the handler
     * throws, the status is set to 500.
     */
    private void invoke(HttpServletRequest request, HttpServletResponse response) {
        String queryUrl = toHandlerPath(request);
        Method method = handlerMapping.get(queryUrl);
        if (null == method) {
            writeNotFound(request, response);
            return;
        }
        Object[] paramValues = getMethodParamAndValue(request, response, method);
        try {
            // Controllers are registered under their simple class name with a lower-case first letter.
            String controllerClassName = toFirstWordLower(method.getDeclaringClass().getSimpleName());
            Object object = method.invoke(ioc.get(controllerClassName), paramValues);
            if (object != null) {
                if (method.isAnnotationPresent(MyResponseBody.class)) {
                    response.setHeader("content-type", "application/json;charset=UTF-8");
                    if (null == responseBodyHandler) {
                        object = JSONObject.toJSONString(object, SerializerFeature.WriteMapNullValue);
                    } else {
                        // NOTE(review): equals() returns a Boolean, so the body becomes "true"/"false".
                        // This presumably should call ResponseBodyHandler's conversion method, whose
                        // signature is not visible here -- TODO confirm and replace.
                        object = responseBodyHandler.equals(object);
                    }
                }
                response.getWriter().print(object);
                logger.debug("request-> " + request.getRequestURI() + ", response success ->" + response.getStatus());
            }
        } catch (InvocationTargetException e) {
            // The handler itself threw: report a server error rather than an empty 200.
            logger.error("handler failed: " + request.getRequestURI(), e.getCause());
            response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        } catch (IllegalAccessException e) {
            logger.error("handler not accessible: " + request.getRequestURI(), e);
            response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        } catch (IOException e) {
            logger.error("write response failed: " + request.getRequestURI(), e);
        }
    }

    /**
     * Returns the request URI without the context path and with runs of '/' collapsed,
     * which is the form used for handlerMapping keys.
     */
    private String toHandlerPath(HttpServletRequest request) {
        String path = request.getRequestURI();
        String contextPath = request.getContextPath();
        // Mappings are registered without the context path, so strip it when deployed under one.
        if (contextPath != null && !contextPath.isEmpty() && path.startsWith(contextPath)) {
            path = path.substring(contextPath.length());
        }
        return path.replaceAll("/+", "/");
    }

    /**
     * Writes a 404 status and a plain-text message naming the requested URI.
     */
    private void writeNotFound(HttpServletRequest request, HttpServletResponse response) {
        response.setStatus(404);
        logger.debug("request fail(404): " + request.getRequestURI());
        // try-with-resources: the writer is only closed if getWriter() actually returned one.
        try (PrintWriter pw = response.getWriter()) {
            pw.print("404    not find       ->        " + request.getRequestURI());
            pw.flush();
        } catch (IOException e) {
            logger.error("write 404 failed: " + request.getRequestURI(), e);
        }
    }

    /**
     * Builds the argument array for {@code method}.
     *
     * <p>ServletRequest/ServletResponse parameters receive the current request/response. Every other
     * parameter is bound to the request parameter named by {@link MyRequestParam}, or by the reflective
     * parameter name when the annotation is absent (that is the source name only when compiled with
     * {@code -parameters}; otherwise it is arg0, arg1, ...). Supported conversions: String, int/Integer,
     * float/Float, double/Double, long/Long, boolean/Boolean; entity binding is not supported. A missing
     * request parameter is passed as null, which makes the invocation fail for primitive parameters.
     *
     * @param request  current request
     * @param response current response
     * @param method   handler method to bind arguments for
     * @return the argument values, one per declared parameter
     * @throws NumberFormatException if a numeric request parameter cannot be parsed
     */
    private Object[] getMethodParamAndValue(HttpServletRequest request, HttpServletResponse response, Method method) {
        Parameter[] parameters = method.getParameters();
        Object[] paramValues = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Parameter parameter = parameters[i];
            Class<?> type = parameter.getType();
            if (ServletRequest.class.isAssignableFrom(type)) {
                paramValues[i] = request;
            } else if (ServletResponse.class.isAssignableFrom(type)) {
                paramValues[i] = response;
            } else {
                MyRequestParam requestParam = parameter.getAnnotation(MyRequestParam.class);
                String bindingValue = requestParam != null ? requestParam.value() : parameter.getName();
                paramValues[i] = convertParam(request.getParameter(bindingValue), type);
            }
        }
        return paramValues;
    }

    /**
     * Converts a raw request-parameter string to {@code type}; unsupported types get the raw string.
     */
    private static Object convertParam(String raw, Class<?> type) {
        if (raw == null) {
            return null;
        }
        if (type == Integer.class || type == int.class) {
            return Integer.parseInt(raw);
        } else if (type == Float.class || type == float.class) {
            return Float.parseFloat(raw);
        } else if (type == Double.class || type == double.class) {
            return Double.parseDouble(raw);
        } else if (type == Long.class || type == long.class) {
            return Long.parseLong(raw);
        } else if (type == Boolean.class || type == boolean.class) {
            return Boolean.parseBoolean(raw);
        }
        return raw;
    }

    /**
     * Runs the initialization pipeline: load config, scan, instantiate, inject, map URLs.
     *
     * @param contextConfigLocation classpath location of the properties file
     * @throws Exception if the config is missing or unreadable, the scan package is not set,
     *                   or two handlers map to the same URL
     */
    private void initMyDispatcherServlet(String contextConfigLocation) throws Exception {
        logger.info("-----------------------------mySpring init start-----------------------------------------");
        logger.debug("doLoadConfig:" + contextConfigLocation);
        // load configuration
        doLoadConfig(contextConfigLocation);
        // package scan
        logger.debug("scan:" + contextConfig.getProperty("scan.package"));
        doScanner(contextConfig.getProperty("scan.package"));
        // create instances (IoC)
        doInstance();
        // dependency injection
        doAutowired();
        // URL mapping
        initHandlerMapping();
    }

    /**
     * Injects every {@link MyAutowired} field of every bean in the container.
     *
     * <p>A non-blank annotation value is used as the bean key. Otherwise the field type's fully-qualified
     * name is tried (services are registered under their interface names), then the type's simple name
     * with a lower-case first letter. Fields whose dependency cannot be found are left untouched.
     */
    private void doAutowired() {
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Object bean = entry.getValue();
            for (Field field : bean.getClass().getDeclaredFields()) {
                MyAutowired myAutowired = field.getAnnotation(MyAutowired.class);
                if (myAutowired == null) {
                    continue;
                }
                Object dependency = resolveDependency(field, myAutowired);
                if (dependency == null) {
                    logger.warn("autowired bean not found: " + bean.getClass().getName() + "." + field.getName());
                    continue;
                }
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    logger.error("autowired fail: " + bean.getClass().getName() + "." + field.getName(), e);
                }
            }
        }
    }

    /**
     * Finds the container bean for an {@link MyAutowired} field, or null when none matches.
     */
    private Object resolveDependency(Field field, MyAutowired myAutowired) {
        String alias = myAutowired.value().trim();
        if (!alias.isEmpty()) {
            return ioc.get(alias);
        }
        Class<?> type = field.getType();
        Object byTypeName = ioc.get(type.getName());
        return byTypeName != null ? byTypeName : ioc.get(toFirstWordLower(type.getSimpleName()));
    }

    /**
     * Builds the URL -> method table from {@link MyRequestMapping} on controllers and their methods.
     * A class-level mapping is optional; the URL is "/" + class value + "/" + method value with
     * repeated slashes collapsed.
     *
     * @throws IllegalStateException if two handler methods map to the same URL
     */
    private void initHandlerMapping() {
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) {
                continue;
            }
            MyRequestMapping classMapping = clazz.getAnnotation(MyRequestMapping.class);
            String prefix = classMapping == null ? "" : classMapping.value();
            for (Method method : clazz.getDeclaredMethods()) {
                MyRequestMapping methodMapping = method.getAnnotation(MyRequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }
                String url = ("/" + prefix + "/" + methodMapping.value()).replaceAll("/+", "/");
                // Request URLs must be unique; silently overwriting would hide one of the handlers.
                if (handlerMapping.containsKey(url)) {
                    logger.error("mapping request url:" + url + " is already exist! request url must only");
                    throw new IllegalStateException("mapping:" + url + " is already exist!");
                }
                handlerMapping.put(url, method);
                logger.debug("mapping: " + url);
            }
        }
    }

    /**
     * Loads the properties file at {@code contextConfigLocation} from the classpath.
     *
     * @param contextConfigLocation classpath resource name of the properties file
     * @throws Exception if the resource does not exist or cannot be read
     */
    private void doLoadConfig(String contextConfigLocation) throws Exception {
        InputStream is = contextConfigLocation == null
                ? null
                : this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation);
        if (is == null) {
            logger.error("config:" + contextConfigLocation + " not exist");
            throw new Exception("config:" + contextConfigLocation + " not exist");
        }
        // try-with-resources closes the stream; a read failure aborts init instead of being ignored.
        try (InputStream in = is) {
            contextConfig.load(in);
        }
    }

    /**
     * Recursively collects the fully-qualified names of the .class files under {@code packageName}.
     *
     * @param packageName dot-separated package to scan
     * @throws Exception if {@code packageName} is null or blank
     */
    private void doScanner(String packageName) throws Exception {
        if (packageName == null || packageName.trim().isEmpty()) {
            throw new Exception("init scan is empty");
        }
        String pkg = packageName.trim();
        // ClassLoader resource names take no leading '/'; URLClassLoader returns null when one is given.
        URL url = this.getClass().getClassLoader().getResource(pkg.replace('.', '/'));
        if (null == url) {
            logger.warn("scan package not found: " + pkg);
            return;
        }
        // For file: URLs, toURI() decodes escapes such as %20 that URL.getFile() would leave in the path.
        File dir = "file".equals(url.getProtocol()) ? new File(url.toURI()) : new File(url.getFile());
        File[] files = dir.listFiles();
        if (files == null) {
            logger.warn("scan package is not a readable directory: " + url);
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                doScanner(pkg + "." + fileName);
            } else if (fileName.endsWith(".class")) {
                String className = pkg + "." + fileName.substring(0, fileName.length() - ".class".length());
                logger.debug("scan class find:" + className);
                classNames.add(className);
            }
        }
    }

    /**
     * Instantiates the scanned controller, service and response-advice classes.
     * A class that fails to load or register is logged and skipped.
     */
    private void doInstance() {
        for (String className : classNames) {
            try {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(MyController.class)) {
                    logger.debug("MyController instance: " + clazz.getName());
                    ioc.put(toFirstWordLower(clazz.getSimpleName()), clazz.getDeclaredConstructor().newInstance());
                } else if (clazz.isAnnotationPresent(MyService.class)) {
                    registerService(clazz);
                } else if (clazz.isAnnotationPresent(MyResponseAdvice.class)) {
                    registerResponseAdvice(clazz);
                }
            } catch (Exception e) {
                logger.error("instance fail: " + className, e);
            }
        }
    }

    /**
     * Registers one {@link MyService} instance under its custom key (or lower-camel simple name)
     * and under the fully-qualified name of every interface it implements.
     *
     * @throws IllegalStateException if the service key is already taken
     */
    private void registerService(Class<?> clazz) throws Exception {
        Object instance = clazz.getDeclaredConstructor().newInstance();
        logger.debug("MyService instance: " + clazz.getName());
        String alias = clazz.getAnnotation(MyService.class).value().trim();
        String key = alias.isEmpty() ? toFirstWordLower(clazz.getSimpleName()) : alias;
        if (ioc.containsKey(key)) {
            logger.error("MyService instance: " + key + "  is  exist");
            throw new IllegalStateException("MyService instance: " + key + "  is  exist");
        }
        ioc.put(key, instance);
        // Reuse the same instance for the interface keys so every lookup sees one singleton.
        for (Class<?> interfaceType : clazz.getInterfaces()) {
            if (ioc.containsKey(interfaceType.getName())) {
                logger.warn("MyService interface key overwritten: " + interfaceType.getName());
            }
            ioc.put(interfaceType.getName(), instance);
        }
    }

    /**
     * Installs a {@link MyResponseAdvice} class as the response-body handler unless one is already set.
     *
     * @throws IllegalStateException if the class does not implement ResponseBodyHandler
     */
    private void registerResponseAdvice(Class<?> clazz) throws Exception {
        // The subtype test must be asked of the interface: Interface.class.isAssignableFrom(impl).
        if (!ResponseBodyHandler.class.isAssignableFrom(clazz)) {
            logger.error("class '" + clazz.getName() + "' must implement ResponseBodyHandler");
            throw new IllegalStateException("class '" + clazz.getName() + "' must implement ResponseBodyHandler");
        }
        if (null == responseBodyHandler) {
            responseBodyHandler = (ResponseBodyHandler) clazz.getDeclaredConstructor().newInstance();
        }
    }

    /**
     * Returns {@code name} with its first character lower-cased (null and "" are returned unchanged).
     *
     * @param name the string to convert
     * @return the converted string
     */
    private String toFirstWordLower(String name) {
        if (name == null || name.isEmpty()) {
            return name;
        }
        char[] chars = name.toCharArray();
        chars[0] = Character.toLowerCase(chars[0]);
        return String.valueOf(chars);
    }

}

  TestController:

import cn.lyh.mySpring.annotation.*;
import cn.lyh.mySpringTest.domain.User;
import cn.lyh.mySpringTest.service.TestService;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;


/**
 * Demo controller for MyDispatcherServlet; every handler is mapped under "/test".
 */
@MyController
@MyRequestMapping("/test")
public class TestController {
    /** Injected by the dispatcher; not used by the demo handlers below. */
    @MyAutowired
    private TestService testService;

    /**
     * /test/test1: echoes the bound parameters as plain text.
     */
    @MyRequestMapping("test1")
    public String test1(@MyRequestParam("name") String name,
                        @MyRequestParam("sex") Integer sex,
                        HttpServletRequest request,
                        HttpServletResponse response) throws IOException {

        return "name=" + name + "sex=" + sex;
    }

    /**
     * /test/test2: returns nothing, so no response body is written.
     */
    @MyRequestMapping("test2")
    public void test2() {
    }

    /**
     * /test/test3: returns the bound parameters as a JSON object.
     */
    @MyRequestMapping("test3")
    @MyResponseBody
    public Map<String, Object> test3(@MyRequestParam("name") String name,
                                     @MyRequestParam("sex") Integer sex,
                                     HttpServletRequest request,
                                     HttpServletResponse response) throws IOException {
        Map<String, Object> result = new HashMap<>();
        result.put("name", name);
        result.put("sex", sex);

        return result;
    }

    /**
     * /test/test4: returns a User built from the parameters as JSON.
     */
    @MyRequestMapping("test4")
    @MyResponseBody
    public User test4(@MyRequestParam("name") String name,
                      @MyRequestParam("sex") Integer sex,
                      HttpServletRequest request,
                      HttpServletResponse response) throws IOException {
        User user = new User();
        user.setName(name);
        user.setId(sex);

        return user;
    }

    /**
     * /test/test5: returns a one-element list of User as JSON.
     */
    @MyRequestMapping("test5")
    @MyResponseBody
    public List<User> test5(@MyRequestParam("name") String name,
                            @MyRequestParam("sex") Integer sex,
                            HttpServletRequest request,
                            HttpServletResponse response) throws IOException {
        List<User> list = new ArrayList<>();
        User user = new User();
        user.setName(name);
        user.setId(sex);
        list.add(user);

        return list;
    }

    /**
     * /test/test6: returns a list holding a User with a null name, to show null fields in the JSON.
     * NOTE: this overloads the Java name test5; routing is by URL, so it is served at /test/test6.
     */
    @MyRequestMapping("test6")
    @MyResponseBody
    public List<User> test5(HttpServletRequest request,
                            HttpServletResponse response) throws IOException {
        List<User> list = new ArrayList<>();
        User user = new User();
        user.setName(null);
        user.setId(1);
        list.add(user);

        return list;
    }


}

  pom文件依赖:

 <dependency>
            <groupId>javax.servlet</groupId>
            <artifactId>javax.servlet-api</artifactId>
            <version>4.0.1</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.6.6</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>1.7.2</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.17</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.47</version>
            <scope>compile</scope>
        </dependency>

最后附上源码地址(MySpring Module)

猜你喜欢

转载自www.cnblogs.com/lyhc/p/10129250.html