package ru.yandex.chemodan.app.cvdemo2.admin;

import java.time.Instant;
import java.util.List;

import lombok.Data;
import org.junit.Test;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.misc.random.Random2;

/**
 * @author tolmalev
 */
public class SmallGridGeneratorTest {

    MapF<String, List<BaseGrid>> baseGridsByPhotoTypes = Cf.hashMap();
    MapF<String, int[][]> mp = Cf.hashMap();

    @Test
    public void generateFullGrid() {
        baseGridsByPhotoTypes = Cf.hashMap();
        mp = Cf.hashMap();

        generateBaseGrids();

        int photosCount = 1000000;
        int iterations = 100;

        String typesStr = Cf.repeat(() -> Random2.R.nextBoolean() ? "v" : "h", photosCount).mkString("");
        double[] beauty = Cf.repeat(() -> Random2.R.nextDouble(), photosCount).mapToDoubleArray(i -> i);


        for (int i = 0; i < iterations; i++) {
            generateFullGrid(typesStr, beauty, 8, 0.25);
        }

        Instant start = Instant.now();
        for (int i = 0; i < iterations; i++) {
            generateFullGrid(typesStr, beauty, 8, 0.25);
        }
        Instant end = Instant.now();

        System.out.println("Time past: " + (end.toEpochMilli() - start.toEpochMilli()));
        System.out.println("Time for one generation: " + (end.toEpochMilli() - start.toEpochMilli()) / 100.0);
    }

    @Test
    public void generateBaseGrids() {
        baseGridsByPhotoTypes.getOrElseUpdate("h", Cf.arrayList()).add(new BaseGrid(6, 6, Cf.list(new PhotoPosition(0, 0, 6, 6))));
        baseGridsByPhotoTypes.getOrElseUpdate("v", Cf.arrayList()).add(new BaseGrid(6, 8, Cf.list(new PhotoPosition(0, 0, 6, 8))));

        mp.put("h", new int[][] {
                {2, 2},
                {3, 3},
                {4, 4},
                {6, 6}
        });

        mp.put("v", new int[][] {
                {2, 4},
                {3, 5},
                {4, 6},
                {6, 10}
        });

        mp.put("s", new int[][] {
                {3, 4},
                {6, 8},
        });

        mp.put("p", new int[][] {
                {4, 2},
                {6, 3}
        });

        int[][] grid = new int[200][6];

        for (int i = 1; i <= 10; i++) {
            generateBaseGrids(new Grid(), grid, i);
        }
        baseGridsByPhotoTypes.values().flatMap(a -> a).groupBy(l -> l.pos.size()).forAllEntries((len, c) -> {
            System.out.println("Count grids for " + len + " photos = " + c.size());
            return true;
        });
    }

    public FullGrid generateFullGrid(String typesStr, double[] beauty, int beautySizeK, double beautyK) {
        double averageBeauty = 0;
        for (int i = 0; i < beauty.length; i++) {
            averageBeauty += beauty[i];
        }
        averageBeauty /= beauty.length;

        FullGrid[] grids = new FullGrid[typesStr.length() + 1];
        grids[0] = new FullGrid(0, 0, new BaseGrid(0, 0, Cf.list()), Option.empty());

        for (int i = 0; i < typesStr.length(); i++) {
            for (int j = Math.max(0, i - 10); j <= i; j++) {
                String lastTypes = typesStr.substring(j, i + 1);
                for (BaseGrid grid : Random2.R.shuffle(baseGridsByPhotoTypes.getOrElse(lastTypes, Cf.list()))) {
                    FullGrid prevGrid = grids[j];

                    double newWeight = prevGrid.weight - grid.H;
                    newWeight += grid.getPos().size();

                    for (int p = 0; p < grid.getPos().size(); p++) {
                        double thisBeauty = beauty[j + p];
                        thisBeauty = Math.min(thisBeauty, 5);
                        thisBeauty = Math.max(thisBeauty, -5);
                        int size = grid.getPos().get(p).w * grid.getPos().get(p).h;
                        newWeight += (thisBeauty - averageBeauty) * (size - beautySizeK) * beautyK;
                    }
                    if (grid.getPos().size() == 1 && prevGrid.lastBaseGrid.getPos().size() == 1) {
                        newWeight -= grid.H * 2;
                    }

                    int difSizes = Cf.x(grid.getPos()).map(p -> p.w * p.h).unique().size();
                    newWeight += difSizes * 4;

                    if (grids[i + 1] == null || newWeight > grids[i + 1].weight) {
                        grids[i + 1] = new FullGrid(
                                prevGrid.totalH + grid.H,
                                newWeight,
                                grid,
                                Option.of(prevGrid)
                        );
                    }
                }
            }
        }
        return grids[typesStr.length()];
    }

    private void generateBaseGrids(Grid cur, int[][] grid, int maxCount) {
        if (maxCount == 0) {
            boolean gridIsBad = false;
            int gridLines = 0;
            for (int i = 0; i < grid.length; i++) {
                boolean fullFilled = true;
                boolean fullZero = true;
                boolean allNotSameAsPrev = i != 0;

                for (int j = 0; j < grid[i].length; j++) {
                    if (i > 0 && grid[i][j] == grid[i - 1][j]) {
                        allNotSameAsPrev = false;
                    }

                    if (grid[i][j] == 0) {
                        fullFilled = false;
                    } else {
                        fullZero = false;
                    }
                }
                if (!fullZero && allNotSameAsPrev) {
                    gridIsBad = true;
                    break;
                }
                if (!fullFilled && !fullZero) {
                    gridIsBad = true;
                    break;
                }
                if (fullZero) {
                    gridIsBad = false;
                    gridLines = i;
                    break;
                }
            }
            if (!gridIsBad) {
                printGrid(cur, grid, gridLines);
            }
            return;
        }

        int startI = -1, startJ = -1;
        for (int i = 0; startI < 0 && i < grid.length; i++) {
            for (int j = 0; startI < 0 && j < grid[i].length; j++) {
                if (grid[i][j] == 0) {
                    startI = i;
                    startJ = j;
                }
            }
        }

        for (String type : Cf.list("h", "v", "s")) {
            for (int[] sizes : mp.get(type)) {
                int w = sizes[0];
                int h = sizes[1];

                if (startJ + w <= grid[0].length) {
                    boolean goodPlace = true;
                    for (int i = 0; goodPlace && i < h; i++) {
                        for (int j = 0; goodPlace && j < w; j++) {
                            if (grid[startI + i][startJ + j] != 0) {
                                goodPlace = false;
                            }
                        }
                    }
                    if (!goodPlace) {
                        continue;
                    }

                    for (int i = 0; goodPlace && i < h; i++) {
                        for (int j = 0; goodPlace && j < w; j++) {
                            grid[startI + i][startJ + j] = maxCount;
                        }
                    }

                    cur.name += type;
                    cur.positions.add(new PhotoPosition(startJ, startI, w, h));
                    generateBaseGrids(cur, grid, maxCount - 1);

                    cur.name = cur.name.substring(0, cur.name.length() - 1);
                    cur.positions.remove(cur.positions.size() - 1);

                    for (int i = 0; goodPlace && i < h; i++) {
                        for (int j = 0; goodPlace && j < w; j++) {
                            grid[startI + i][startJ + j] = 0;
                        }
                    }
                }
            }
        }
    }

    private void printGrid(Grid cur, int[][] grid, int gridLines) {
        System.out.println(cur.name);
        for (PhotoPosition pos : cur.positions) {
            System.out.println(pos.x + "," + pos.y + " " + pos.w + "x" + pos.h);
        }

//        for (int i = 0; i < gridLines; i++) {
//            for (int j = 0; j < grid[i].length; j++) {
//                System.out.print(grid[i][j]);
//            }
//            System.out.println();
//        }
//        System.out.println();
//        System.out.println();

//        allGrids.add(new Grid(cur.name, Cf.x(cur.positions).makeReadOnly()));
        List<BaseGrid> positions = baseGridsByPhotoTypes.getOrElseUpdate(cur.name, Cf.arrayList());
        BaseGrid baseGrid = new BaseGrid(grid[0].length, gridLines, cur.positions.map(i -> i));
        positions.add(baseGrid);
    }

    static class Grid {
        String name;
        ListF<PhotoPosition> positions;

        public Grid(String name, List<PhotoPosition> positions) {
            this.name = name;
            this.positions = Cf.x(positions);
        }

        public Grid() {
            this.name = "";
            this.positions = Cf.arrayList();
        }
    }

    @Data
    static class BaseGrid {
        public final int W;
        public final int H;

        public final List<PhotoPosition> pos;
    }

    @Data
    static class PhotoPosition {
        public final int x;
        public final int y;
        public final int w;
        public final int h;

        double getAspect() {
            return 1.0 * w / h;
        }
    }

    @Data
    static class FullGrid {
        public final int totalH;
        public final double weight;

        public final BaseGrid lastBaseGrid;

        public final Option<FullGrid> prevGrid;
    }

    @Data
    static class FinalGrid {
        public final List<PhotoPosition> pos;
    }
}
