Spring Boot 中使用 Filter 过滤 POST 和 GET 请求参数

//定义一个filter过滤器



import org.apache.commons.lang.StringUtils;
import org.springframework.stereotype.Component;
import org.apache.commons.lang.StringEscapeUtils;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.Map;
import java.util.Set;

@Component
@WebFilter(filterName = "ValidatorFilter" , urlPatterns = "/*")
public class ValidatorFilter implements Filter {

    /**
     * Substrings that are replaced by their XML escape in GET parameter values.
     * NOTE(review): commons-lang StringEscapeUtils.escapeXml only rewrites XML special
     * characters, so the "%" entry is left unchanged — confirm that is intended.
     */
    String[] strArr = {"\"","%","'"};

    /**
     * Wraps GET and POST requests so downstream code sees escaped parameter values;
     * every other request (and any non-HTTP request) is passed through unchanged.
     */
    @Override
    public void doFilter(ServletRequest request,
                         ServletResponse response,
                         FilterChain chain) throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        String method = httpServletRequest.getMethod();

        if ("POST".equals(method)) {
            // The parameter map is read before the body is wrapped, in the same order as before.
            // NOTE(review): for form-encoded POSTs this presumably lets the container consume the
            // body, leaving the wrapper an empty body to buffer — confirm this is desired.
            Map<String, String[]> map = httpServletRequest.getParameterMap();
            chain.doFilter(new PostParameterRequestWrapper(httpServletRequest, method, map), response);
        } else if ("GET".equals(method)) {
            chain.doFilter(escapeGetParameters(httpServletRequest), response);
        } else {
            // Previously PUT/DELETE/... fell through both branches and never reached the chain.
            chain.doFilter(request, response);
        }
    }

    /**
     * Builds a wrapper whose parameters have every entry of {@link #strArr} escaped.
     * Multi-valued parameters keep all of their values; only changed parameters are replaced.
     */
    private GetParameterRequestWrapper escapeGetParameters(HttpServletRequest request) {
        GetParameterRequestWrapper wrapper = new GetParameterRequestWrapper(request);
        for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
            String[] values = entry.getValue();
            if (values == null) {
                continue;
            }
            String[] escaped = new String[values.length];
            boolean changed = false;
            for (int i = 0; i < values.length; i++) {
                escaped[i] = escapeValue(values[i]);
                if (escaped[i] != null && !escaped[i].equals(values[i])) {
                    changed = true;
                }
            }
            if (changed) {
                // Re-put the same key so the wrapper serves the escaped values.
                wrapper.addParameter(entry.getKey(), escaped);
            }
        }
        return wrapper;
    }

    /**
     * Escapes every entry of {@link #strArr} in {@code value}, accumulating the replacements
     * (the old loop restarted from the raw value each time, keeping only the last escape).
     *
     * @return the escaped value, or {@code null} when {@code value} is null
     */
    private String escapeValue(String value) {
        if (value == null) {
            return null;
        }
        String result = value;
        for (String bad : strArr) {
            if (result.contains(bad)) {
                result = result.replace(bad, StringEscapeUtils.escapeXml(bad));
            }
        }
        return result;
    }

    @Override
    public void destroy() { }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }
}
//GET 请求：包装请求并修改其中的参数值，对不合法的字符进行转义



import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.*;

class GetParameterRequestWrapper extends HttpServletRequestWrapper {

    /**
     * Editable copy of the request parameters. A LinkedHashMap keeps the iteration order of
     * the wrapped request's parameter map.
     */
    private final Map<String , String[]> params = new LinkedHashMap<String, String[]>();

    /** Copies all parameters of {@code request} so they can be replaced. */
    public GetParameterRequestWrapper(HttpServletRequest request) {
        super(request);
        this.params.putAll(request.getParameterMap());
    }

    /** Copies the request parameters, then adds or overrides them with {@code extendParams}. */
    public GetParameterRequestWrapper(HttpServletRequest request , Map<String , Object> extendParams) {
        this(request);
        addAllParameters(extendParams);
    }

    /** Returns the first value of {@code name}, or {@code null} if it has no value. */
    @Override
    public String getParameter(String name) {
        String[] values = params.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    /** Returns a copy of the values of {@code name}, or {@code null} if absent. */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = params.get(name);
        return values == null ? null : values.clone();
    }

    /**
     * Returns the (possibly modified) parameter names. Overridden so it stays consistent with
     * {@link #getParameter} / {@link #getParameterValues}.
     */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(params.keySet());
    }

    /**
     * Returns a read-only view of the (possibly modified) parameters; without this override
     * callers using the map would still see the original, unescaped values.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        return Collections.unmodifiableMap(params);
    }

    /** Adds or replaces every entry of {@code otherParams}; a {@code null} map is ignored. */
    public void addAllParameters(Map<String , Object>otherParams) {
        if (otherParams == null) {
            return;
        }
        for(Map.Entry<String , Object>entry : otherParams.entrySet()) {
            addParameter(entry.getKey() , entry.getValue());
        }
    }

    /**
     * Adds or replaces parameter {@code name}. A String[] is stored as-is, a String as a
     * one-element array, anything else via {@code String.valueOf}; {@code null} is ignored.
     */
    public void addParameter(String name , Object value) {
        if(value != null) {
            if(value instanceof String[]) {
                params.put(name , (String[])value);
            }else if(value instanceof String) {
                params.put(name , new String[] {(String)value});
            }else {
                params.put(name , new String[] {String.valueOf(value)});
            }
        }
    }
}
//POST 请求：包装请求体（JSON），对参数值中不合法的字符进行转义





import org.apache.commons.lang.StringEscapeUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class PostParameterRequestWrapper  extends HttpServletRequestWrapper {

    /** Default body charset; JSON is UTF-8 unless the request declares otherwise. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** Body served to downstream readers; never null (empty when the request had no body). */
    private final byte[] body;

    /** Charset used to decode and re-encode the body and to build {@link #getReader()}. */
    private final Charset bodyCharset;

    /**
     * Characters that are XML-escaped inside JSON string values. Each entry is expected to be a
     * single character whose escapeXml form contains no quote or backslash (true for these
     * defaults), so the rewritten body remains valid JSON.
     * NOTE(review): escapeXml leaves "%" unchanged, so that entry is effectively a no-op — confirm.
     */
    String[] strArr = {"\"","%","'"};

    /**
     * Buffers the request body and XML-escapes the characters in {@link #strArr} inside every
     * JSON string value. Keys, numbers, literals and nesting are preserved. Bodies that do not
     * start with '{' or '[' (form, multipart, plain text) are passed through byte-for-byte.
     *
     * @param request   the request to wrap
     * @param method    HTTP method of the request (not used)
     * @param newParams parameter map read by the caller (not used)
     * @throws IOException if the request body cannot be read
     */
    public PostParameterRequestWrapper(HttpServletRequest request, String method, Map<String, String[]> newParams) throws IOException {
        super(request);
        this.bodyCharset = resolveCharset(request);
        this.body = rewriteBody(readFully(request.getInputStream()));
    }

    /** Returns the escaped body, or {@code raw} itself when nothing needs to change. */
    private byte[] rewriteBody(byte[] raw) {
        if (raw.length == 0) {
            return raw;
        }
        String text = new String(raw, bodyCharset);
        if (!looksLikeJson(text)) {
            return raw;
        }
        String escaped = escapeJsonStringValues(text);
        return escaped.equals(text) ? raw : escaped.getBytes(bodyCharset);
    }

    /** Uses the request's declared character encoding, falling back to UTF-8. */
    private static Charset resolveCharset(ServletRequest request) {
        String declared = request.getCharacterEncoding();
        if (declared != null) {
            try {
                return Charset.forName(declared);
            } catch (IllegalArgumentException unsupported) {
                // Unknown or malformed charset name: fall back to UTF-8 as the original code did.
            }
        }
        return UTF_8;
    }

    /** True when the first non-whitespace character opens a JSON object or array. */
    private static boolean looksLikeJson(String text) {
        char first = firstNonWhitespace(text, 0);
        return first == '{' || first == '[';
    }

    /** Returns the first non-whitespace char at or after {@code from}, or '\0' if there is none. */
    private static char firstNonWhitespace(String text, int from) {
        for (int i = from; i < text.length(); i++) {
            char c = text.charAt(i);
            if (!Character.isWhitespace(c)) {
                return c;
            }
        }
        return '\0';
    }

    /**
     * Copies {@code json}, escaping blocked characters inside string literals that are values.
     * A literal followed (after whitespace) by ':' is an object key and is left untouched.
     */
    private String escapeJsonStringValues(String json) {
        StringBuilder out = new StringBuilder(json.length() + 16);
        int i = 0;
        while (i < json.length()) {
            char c = json.charAt(i);
            if (c != '"') {
                out.append(c);
                i++;
                continue;
            }
            int close = findClosingQuote(json, i);
            String content = json.substring(i + 1, close);
            boolean isKey = close < json.length() && firstNonWhitespace(json, close + 1) == ':';
            out.append('"').append(isKey ? content : escapeStringContent(content));
            if (close < json.length()) {
                out.append('"');
            }
            i = close + 1;
        }
        return out.toString();
    }

    /** Index of the quote closing the literal opened at {@code open}, or json.length() if unterminated. */
    private static int findClosingQuote(String json, int open) {
        int j = open + 1;
        while (j < json.length()) {
            char c = json.charAt(j);
            if (c == '\\') {
                j += 2; // skip the escaped character so an escaped quote does not end the literal
            } else if (c == '"') {
                return j;
            } else {
                j++;
            }
        }
        return json.length();
    }

    /**
     * Escapes blocked characters in the raw (still JSON-escaped) content of a string literal.
     * Escape sequences, including backslash-u hex escapes, are decoded for the check so they
     * cannot bypass the filter; everything that is not blocked is copied unchanged.
     */
    private String escapeStringContent(String raw) {
        StringBuilder out = new StringBuilder(raw.length() + 8);
        int i = 0;
        while (i < raw.length()) {
            char c = raw.charAt(i);
            int end = i + 1;
            String decoded = String.valueOf(c);
            if (c == '\\' && i + 1 < raw.length()) {
                char next = raw.charAt(i + 1);
                if (next == 'u' && isHex4(raw, i + 2)) {
                    end = i + 6;
                    decoded = String.valueOf((char) Integer.parseInt(raw.substring(i + 2, end), 16));
                } else {
                    end = i + 2;
                    decoded = decodeSimpleEscape(next);
                }
            }
            String replacement = blockedEscape(decoded);
            out.append(replacement != null ? replacement : raw.substring(i, end));
            i = end;
        }
        return out.toString();
    }

    /** Decodes the character following a backslash in a JSON string (other than 'u'). */
    private static String decodeSimpleEscape(char next) {
        switch (next) {
            case 'b': return "\b";
            case 'f': return "\f";
            case 'n': return "\n";
            case 'r': return "\r";
            case 't': return "\t";
            default:  return String.valueOf(next); // quote, backslash, slash (or invalid escape)
        }
    }

    /** True when four hex digits start at {@code from}. */
    private static boolean isHex4(String s, int from) {
        if (from + 4 > s.length()) {
            return false;
        }
        for (int k = from; k < from + 4; k++) {
            if (Character.digit(s.charAt(k), 16) < 0) {
                return false;
            }
        }
        return true;
    }

    /** Returns the XML escape of {@code decoded} if it is listed in {@link #strArr}, else null. */
    private String blockedEscape(String decoded) {
        for (String bad : strArr) {
            if (bad.equals(decoded)) {
                return StringEscapeUtils.escapeXml(bad);
            }
        }
        return null;
    }

    /**
     * Reads the whole body of {@code request} and decodes it as UTF-8, keeping line breaks.
     * On an I/O error the stack trace is printed and whatever was read so far is returned.
     *
     * @param request request whose input stream is consumed
     * @return the body text (possibly partial on error)
     */
    public String getBodyString(final ServletRequest request) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try {
            copy(request.getInputStream(), buffer);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new String(buffer.toByteArray(), UTF_8);
    }

    /**
     * Copies the remaining bytes of {@code inputStream} into an in-memory stream. The source is
     * not closed. On an I/O error the stack trace is printed and the bytes read so far are kept.
     *
     * @param inputStream stream to drain
     * @return an in-memory stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try {
            copy(inputStream, buffer);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new ByteArrayInputStream(buffer.toByteArray());
    }

    /** Drains {@code in} into {@code out}. */
    private static void copy(InputStream in, ByteArrayOutputStream out) throws IOException {
        byte[] chunk = new byte[1024];
        int len;
        while ((len = in.read(chunk)) > -1) {
            out.write(chunk, 0, len);
        }
    }

    /** Reads all remaining bytes of {@code in}, propagating I/O errors. */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        copy(in, out);
        return out.toByteArray();
    }

    /** Returns a reader over the buffered body using the body's charset (not the platform default). */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset));
    }

    /** Length of the (possibly rewritten) body, which can differ from the Content-Length header. */
    @Override
    public int getContentLength() {
        return body.length;
    }

    /** Length of the (possibly rewritten) body, which can differ from the Content-Length header. */
    @Override
    public long getContentLengthLong() {
        return body.length;
    }

    /** Returns a fresh stream over the buffered body; may be called repeatedly. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The whole body is in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // NOTE(review): async (non-blocking) reads are not supported; the listener is ignored.
            }
        };
    }

}

猜你喜欢

转载自blog.csdn.net/SmallTenMr/article/details/82786468