动态规划求解最短路径问题

一、需求背景

现有一张地图,各结点代表城市,两结点间连线代表道路,线上数字表示城市间的距离。如下图所示,请找出从起点A到终点E的最短距离。

二、算法描述

利用动态规划的思想,求解最短路径问题,算法过程如下:

1.节点标号。

将节点A到节点E进行标号,A节点标号0,B1节点标号1......以此类型,节点E标号10。

2.描述最优解方程。

令f(i)表示从起点0到节点i的最短距离,节点j为与节点i相连接的节点,d[j][i]表示节点j与节点i之间的距离,则:

f(i) = min(f(j) + d[j][i])

很显然,f(0)=0。

3.自底向上,逐步求解。

利用第2步的公式,从节点1开始,逐步求解,直至节点10结束。

三、算法实现

3.1 各节点距离不同,只求一条最短路径。

先计算出各个节点的最短路径集合dist,然后根据dist自顶向下计算最优路径。

具体代码为:

package com.test.dynamicalgothrim;

import java.util.Arrays;
import java.util.Stack;

/**
 * 利用动态规划求解最短路径问题
 */

public class OneRouteDjsterMinDistance {
    // 计算最短距离
    public static int[] calMinDistance(int[][] distance) {
        int[] dist = new int[distance.length];
        dist[0] = 0;

        for (int i = 1; i < distance.length; i++) {
            int iMinDist = Integer.MAX_VALUE;
            for (int j = 0; j < i; j++) {
                if (distance[j][i] != 0) {
                    if ((dist[j] + distance[j][i]) < iMinDist) {
                        iMinDist = dist[j] + distance[j][i];
                    }
                }
            }
            dist[i] = iMinDist;
        }

        return dist;
    }

    // 计算路径
    public static String calTheRoute(int[][] distance, int[] dist) {
        Stack<Integer> st = new Stack<>();
        StringBuilder buf = new StringBuilder();
        int i = distance.length - 1;
        st.add(i); // 将尾插入
        while (i > 0) {
            // int num = 0;
            for (int j = 0; j < i; j++) {
                if (distance[j][i] != 0) {
                    // num++;
                    if (dist[i] - distance[j][i] == dist[j]) {
                        st.add(j);
                    }
                }
            }
            i = st.peek();
        }

        String arrow = "-->";
        while (!st.empty()) {
            buf.append(st.pop()).append(arrow);
        }
        String result = buf.toString();
        result = result.substring(0, result.length()-arrow.length());
        return result;
    }


    public static void main(String[] args) {
         6个点
//        int[][] map = new int[6][6];
//        map[0][1]=2;map[0][2]=3;map[0][3]=6;
//        map[1][0]=2;map[1][4]=4;map[1][5]=6;
//        map[2][0]=3;map[2][3]=2;
//        map[3][0]=6;map[3][2]=2;map[3][4]=1;map[3][5]=3;
//        map[4][1]=4;map[4][3]=1;
//        map[5][1]=6;map[5][3]=3;
         11个点
        int[][] map = new int[11][11];
        map[0][1]=5;map[0][2]=3;
        map[1][0]=5;map[1][3]=1;map[1][4]=6;map[1][5]=8;
        map[2][0]=3;map[2][4]=8;map[2][6]=4;
        map[3][1]=1;map[3][7]=5;map[3][8]=6;
        map[4][1]=6;map[4][2]=8;map[4][7]=5;
        map[5][1]=3;map[5][9]=8;
        map[6][2]=4;map[6][9]=3;
        map[7][3]=5;map[7][4]=5;map[7][10]=3;
        map[8][3]=6;map[8][10]=4;
        map[9][5]=8;map[9][6]=3;map[9][10]=3;
        int[] dist = calMinDistance(map);

        System.out.println("dist:" + Arrays.toString(dist));
        System.out.println("最短路径长度为:" + dist[map.length - 1]);
        System.out.println("最短路径为:" + calTheRoute(map, dist));
    }
}

输出结果为:

dist:[0, 5, 3, 6, 11, 13, 7, 11, 12, 10, 13]
最短路径长度为:13
最短路径为:0-->2-->6-->9-->10

3.2 存在相同节点距离,只求一条最短路径。

若存在相同节点距离,则就不能根据dist自顶向下计算最优路径;否则,会出现错误节点。这种情况下,可以在计算最短路径时同步保存各节点的最优路径。

具体代码如下:

package com.test.dynamicalgothrim;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * 利用动态规划求解最短路径问题
 *
 */

public class OneRouteMinDistance {
    // 计算最短距离
    public static int[] calMinDistance(int[][] distance, Map<Integer, List<Integer>> point2Routes) {
        int[] dist = new int[distance.length];
        dist[0] = 0;

        for (int i = 1; i < distance.length; i++) {
            int iMinDist = Integer.MAX_VALUE;
            int pre = 0;
            for (int j = 0; j < i; j++) {
                if (distance[j][i] != 0) {
                    if ((dist[j] + distance[j][i]) < iMinDist) {
                        iMinDist = dist[j] + distance[j][i];
                        pre = j;
                    }
                }
            }
            List<Integer> preRoutes = point2Routes.getOrDefault(pre, Lists.newArrayList());
            List<Integer> routes = point2Routes.getOrDefault(i, Lists.newArrayList());
            routes.addAll(preRoutes);
            routes.add(pre);
            point2Routes.put(i, routes);
            dist[i] = iMinDist;
        }

        return dist;
    }


    public static void main(String[] args) {
         6个点
//        int[][] map = new int[6][6];
//        map[0][1]=2;map[0][2]=3;map[0][3]=6;
//        map[1][0]=2;map[1][4]=4;map[1][5]=6;
//        map[2][0]=3;map[2][3]=2;
//        map[3][0]=6;map[3][2]=2;map[3][4]=1;map[3][5]=3;
//        map[4][1]=4;map[4][3]=1;
//        map[5][1]=6;map[5][3]=3;
         11个点
        int[][] map = new int[11][11];
        map[0][1]=5;map[0][2]=3;
        map[1][0]=5;map[1][3]=1;map[1][4]=6;map[1][5]=8;
        map[2][0]=3;map[2][4]=8;map[2][6]=4;
        map[3][1]=1;map[3][7]=5;map[3][8]=6;
        map[4][1]=6;map[4][2]=8;map[4][7]=5;
        map[5][1]=3;map[5][9]=8;
        map[6][2]=4;map[6][9]=3;
        map[7][3]=5;map[7][4]=5;map[7][10]=3;
        map[8][3]=6;map[8][10]=4;
        map[9][5]=8;map[9][6]=3;map[9][10]=3;
        Map<Integer, List<Integer>> point2Routes = Maps.newHashMap();
        int[] dist = calMinDistance(map, point2Routes);
        System.out.println("dist:" + Arrays.toString(dist));
        System.out.println("最短路径长度为:" + dist[map.length - 1]);

        List<Integer> routePoints = Lists.newArrayList();
        routePoints.addAll(point2Routes.get(map.length - 1));
        routePoints.add(map.length - 1);
        String arrow = "-->";
        StringBuilder buf = new StringBuilder();
        for (Integer point: routePoints) {
            buf.append(point).append(arrow);
        }

        String result = buf.toString();
        result = result.substring(0, result.length()-arrow.length());
        System.out.println("最短路径为:" + result);
    }
}

计算结果为:

dist:[0, 5, 3, 6, 11, 13, 7, 11, 12, 10, 13]
最短路径长度为:13
最短路径为:0-->2-->6-->9-->10

3.3 存在相同节点距离,求取所有最短路径。

这种情况下,在计算最短路径的过程中,需要记录各节点所有的最短路径。

具体代码为:

package com.test.dynamicalgothrim;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.collections4.CollectionUtils;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * 利用动态规划求解最短路径问题
 *
 */

public class MultiRoutesMinDistance {
    // 计算最短距离
    public static int[] calMinDistance(int[][] distance, Map<Integer, List<List<Integer>>> point2Routes) {
        int[] dist = new int[distance.length];
        dist[0] = 0;

        for (int i = 1; i < distance.length; i++) {
            // 计算起点到节点i的最短路径值
            int iMinDist = Integer.MAX_VALUE;
            for (int j = 0; j < i; j++) {
                if (distance[j][i] != 0) {
                    int j2iDist = dist[j] + distance[j][i];
                    if (j2iDist < iMinDist) {
                        iMinDist = j2iDist;
                    }
                }
            }
            // 可能存在多条路径节点
            List<Integer> preList = Lists.newArrayList();
            for (int j = 0; j < i; j++) {
                if (distance[j][i] != 0) {
                    int j2iDistance = dist[j] + distance[j][i];
                    if (j2iDistance == iMinDist) {
                        preList.add(j);
                    }
                }
            }

            List<List<Integer>> routes = Lists.newArrayList();
            preList.forEach(pre-> {
                List<List<Integer>> preRoutes = point2Routes.getOrDefault(pre, Lists.newArrayList());
                if (CollectionUtils.isNotEmpty(preRoutes)) {
                    preRoutes.forEach(preRoute -> {
                        List<Integer> route = Lists.newArrayList();
                        route.addAll(preRoute);
                        route.add(pre);
                        routes.add(route);
                    });
                } else {
                    List<Integer> route = Lists.newArrayList();
                    route.add(pre);
                    routes.add(route);
                }
            });

            point2Routes.put(i, routes);
            dist[i] = iMinDist;
        }

        return dist;
    }


    public static void main(String[] args) {
         6个点
//        int[][] map = new int[6][6];
//        map[0][1]=2;map[0][2]=3;map[0][3]=6;
//        map[1][0]=2;map[1][4]=4;map[1][5]=6;
//        map[2][0]=3;map[2][3]=2;
//        map[3][0]=6;map[3][2]=2;map[3][4]=1;map[3][5]=3;
//        map[4][1]=4;map[4][3]=1;
//        map[5][1]=6;map[5][3]=3;
         11个点
        int[][] map = new int[11][11];
        map[0][1]=5;map[0][2]=3;
        map[1][0]=5;map[1][3]=1;map[1][4]=6;map[1][5]=8;
        map[2][0]=3;map[2][4]=8;map[2][6]=4;
        map[3][1]=1;map[3][7]=5;map[3][8]=6;
        map[4][1]=6;map[4][2]=8;map[4][7]=5;
        map[5][1]=3;map[5][9]=8;
        map[6][2]=4;map[6][9]=3;
        map[7][3]=5;map[7][4]=5;map[7][10]=3;
        map[8][3]=6;map[8][10]=4;
        map[9][5]=8;map[9][6]=3;map[9][10]=3;
        Map<Integer, List<List<Integer>>> point2Routes = Maps.newHashMap();
        int[] dist = calMinDistance(map, point2Routes);
        System.out.println("dist:" + Arrays.toString(dist));
        System.out.println("最短路径长度为:" + dist[map.length - 1]);

        String arrow = "-->";
        List<List<Integer>> pointRoutes = point2Routes.get(map.length - 1);
        int i = 1;
        for (List<Integer> routePoints: pointRoutes) {
            StringBuilder buf = new StringBuilder();
            for (Integer point: routePoints) {
                buf.append(point).append(arrow);
            }
            // 添加终点
            buf.append(map.length - 1);
            String result = buf.toString();
            System.out.println("第 " + i + " 条最短路径为:" + result);
            i++;
        }
    }
}

运行结果:

dist:[0, 5, 3, 6, 11, 13, 7, 11, 12, 10, 13]
最短路径长度为:13
第 1 条最短路径为:0-->2-->6-->9-->10

3.4 从本地文件读取数据源计算最短路径

如果是从本地文件读取数据,可以使用如下方法:

public static int[][] readTheFile(File f, int n) throws Exception {
        Reader input = null;
        try {
            input = new FileReader(f);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            throw new Exception("文件读取失败");
        }

        BufferedReader buf = null;
        buf = new BufferedReader(input);
        List<String> list = new ArrayList<String>();
        try {
            String str = buf.readLine();
            while (str != null) {
                list.add(str);
                str = buf.readLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
            throw new Exception("文件解析失败");
        }

        Iterator<String> it = list.iterator();
        int[][] distance = new int[n][n];
        while (it.hasNext()) {
            String[] str1 = it.next().split(",");
            int i = Integer.parseInt(str1[0]); // 序号值
            int j = Integer.parseInt(str1[1]);  // 序号值
            distance[i - 1][j - 1] = Integer.parseInt(str1[2]); // 第i个节点与第j个节点之间的距离值
        }
        return distance;
    }

文件内容为:

1,2,5
1,3,3
2,1,5
2,4,1
2,5,6
2,6,8
3,1,3
3,5,8
3,7,4
4,2,1
4,8,5
4,9,6
5,2,6
5,3,8
5,8,5
6,2,3
6,10,8
7,3,4
7,10,3
8,4,5
8,5,5
8,11,3
9,4,6
9,11,4
10,6,8
10,7,3
10,11,3

此时,map[][]值通过如下方式获得:

 File f = new File("/test/data/distance_data.csv");
 int map[][] = readTheFile(f, 11);

运行结果与从内存中直接赋值是一样的。

猜你喜欢

转载自blog.csdn.net/chinawangfei/article/details/123405461