Update lightrec 20220910 (#686)
[pcsx_rearmed.git] / deps / lightrec / optimizer.c
index 8da84ee..2eba60e 100644 (file)
@@ -119,11 +119,31 @@ static u64 opcode_read_mask(union code op)
        }
 }
 
-static u64 opcode_write_mask(union code op)
+static u64 mult_div_write_mask(union code op) /* bitmask (BIT(reg)) of the registers written by a mult/div-type opcode */
 {
        u64 flags;
 
+       if (!OPT_FLAG_MULT_DIV) /* flag off: rd/imm are ignored, LO and HI are always reported */
+               return BIT(REG_LO) | BIT(REG_HI);
+
+       if (op.r.rd) /* a non-zero rd field takes the place of LO in the mask */
+               flags = BIT(op.r.rd);
+       else
+               flags = BIT(REG_LO);
+       if (op.r.imm) /* a non-zero imm field takes the place of HI in the mask */
+               flags |= BIT(op.r.imm);
+       else
+               flags |= BIT(REG_HI);
+
+       return flags;
+}
+
+static u64 opcode_write_mask(union code op)
+{
        switch (op.i.op) {
+       case OP_META_MULT2:
+       case OP_META_MULTU2:
+               return mult_div_write_mask(op);
        case OP_SPECIAL:
                switch (op.r.op) {
                case OP_SPECIAL_JR:
@@ -134,18 +154,7 @@ static u64 opcode_write_mask(union code op)
                case OP_SPECIAL_MULTU:
                case OP_SPECIAL_DIV:
                case OP_SPECIAL_DIVU:
-                       if (!OPT_FLAG_MULT_DIV)
-                               return BIT(REG_LO) | BIT(REG_HI);
-
-                       if (op.r.rd)
-                               flags = BIT(op.r.rd);
-                       else
-                               flags = BIT(REG_LO);
-                       if (op.r.imm)
-                               flags |= BIT(op.r.imm);
-                       else
-                               flags |= BIT(REG_HI);
-                       return flags;
+                       return mult_div_write_mask(op);
                case OP_SPECIAL_MTHI:
                        return BIT(REG_HI);
                case OP_SPECIAL_MTLO:
@@ -361,6 +370,22 @@ static bool opcode_is_store(union code op)
        }
 }
 
+static u8 opcode_get_io_size(union code op) /* returns 8, 16 or 32 depending on op.i.op (presumably the access width in bits) */
+{
+       switch (op.i.op) {
+       case OP_LB:
+       case OP_LBU:
+       case OP_SB:
+               return 8; /* LB, LBU and SB */
+       case OP_LH:
+       case OP_LHU:
+       case OP_SH:
+               return 16; /* LH, LHU and SH */
+       default:
+               return 32; /* every other opcode; no check here that it is a load or store */
+       }
+}
+
 bool opcode_is_io(union code op)
 {
        return opcode_is_load(op) || opcode_is_store(op);
@@ -601,10 +626,48 @@ static u32 lightrec_propagate_consts(const struct opcode *op,
                                known &= ~BIT(c.r.rd);
                        }
                        break;
+               case OP_SPECIAL_MULT:
+               case OP_SPECIAL_MULTU:
+               case OP_SPECIAL_DIV:
+               case OP_SPECIAL_DIVU:
+                       if (OPT_FLAG_MULT_DIV && c.r.rd)
+                               known &= ~BIT(c.r.rd);
+                       if (OPT_FLAG_MULT_DIV && c.r.imm)
+                               known &= ~BIT(c.r.imm);
+                       break;
                default:
                        break;
                }
                break;
+       case OP_META_MULT2:
+       case OP_META_MULTU2:
+               if (OPT_FLAG_MULT_DIV && (known & BIT(c.r.rs))) {
+                       if (c.r.rd) {
+                               known |= BIT(c.r.rd);
+
+                               if (c.r.op < 32)
+                                       v[c.r.rd] = v[c.r.rs] << c.r.op;
+                               else
+                                       v[c.r.rd] = 0;
+                       }
+
+                       if (c.r.imm) {
+                               known |= BIT(c.r.imm);
+
+                               if (c.r.op >= 32)
+                                       v[c.r.imm] = v[c.r.rs] << (c.r.op - 32);
+                               else if (c.i.op == OP_META_MULT2)
+                                       v[c.r.imm] = (s32) v[c.r.rs] >> (32 - c.r.op);
+                               else
+                                       v[c.r.imm] = v[c.r.rs] >> (32 - c.r.op);
+                       }
+               } else {
+                       if (OPT_FLAG_MULT_DIV && c.r.rd)
+                               known &= ~BIT(c.r.rd);
+                       if (OPT_FLAG_MULT_DIV && c.r.imm)
+                               known &= ~BIT(c.r.imm);
+               }
+               break;
        case OP_REGIMM:
                break;
        case OP_ADDI:
@@ -911,7 +974,8 @@ static int lightrec_transform_branches(struct lightrec_state *state,
                                op->i.imm = offset;
 
                        }
-               default: /* fall-through */
+                       fallthrough;
+               default:
                        break;
                }
        }
@@ -919,6 +983,11 @@ static int lightrec_transform_branches(struct lightrec_state *state,
        return 0;
 }
 
+static inline bool is_power_of_two(u32 value) /* true when popcount32(value) is exactly 1 */
+{
+       return popcount32(value) == 1; /* assuming popcount32() counts set bits, 0 is rejected as well */
+}
+
 static int lightrec_transform_ops(struct lightrec_state *state, struct block *block)
 {
        struct opcode *list = block->opcode_list;
@@ -926,6 +995,7 @@ static int lightrec_transform_ops(struct lightrec_state *state, struct block *bl
        u32 known = BIT(0);
        u32 values[32] = { 0 };
        unsigned int i;
+       u8 tmp;
 
        for (i = 0; i < block->nb_ops; i++) {
                prev = op;
@@ -1000,6 +1070,28 @@ static int lightrec_transform_ops(struct lightrec_state *state, struct block *bl
                                        op->r.rs = op->r.rt;
                                }
                                break;
+                       case OP_SPECIAL_MULT:
+                       case OP_SPECIAL_MULTU:
+                               if ((known & BIT(op->r.rs)) &&
+                                   is_power_of_two(values[op->r.rs])) {
+                                       tmp = op->c.i.rs;
+                                       op->c.i.rs = op->c.i.rt;
+                                       op->c.i.rt = tmp;
+                               } else if (!(known & BIT(op->r.rt)) ||
+                                          !is_power_of_two(values[op->r.rt])) {
+                                       break;
+                               }
+
+                               pr_debug("Multiply by power-of-two: %u\n",
+                                        values[op->r.rt]);
+
+                               if (op->r.op == OP_SPECIAL_MULT)
+                                       op->i.op = OP_META_MULT2;
+                               else
+                                       op->i.op = OP_META_MULTU2;
+
+                               op->r.op = ffs32(values[op->r.rt]);
+                               break;
                        case OP_SPECIAL_OR:
                        case OP_SPECIAL_ADD:
                        case OP_SPECIAL_ADDU:
@@ -1028,6 +1120,64 @@ static int lightrec_transform_ops(struct lightrec_state *state, struct block *bl
        return 0;
 }
 
+static bool lightrec_can_switch_delay_slot(union code op, union code next_op)
+{      /* Returns false when next_op touches a register that op depends on (see each case), true otherwise. */
+       switch (op.i.op) {
+       case OP_SPECIAL:
+               switch (op.r.op) {
+               case OP_SPECIAL_JALR: /* next_op must neither read nor write JALR's rd */
+                       if (opcode_reads_register(next_op, op.r.rd) ||
+                           opcode_writes_register(next_op, op.r.rd))
+                               return false;
+                       fallthrough;
+               case OP_SPECIAL_JR: /* next_op must not write rs (checked for JALR too) */
+                       if (opcode_writes_register(next_op, op.r.rs))
+                               return false;
+                       fallthrough;
+               default: /* other OP_SPECIAL opcodes: no restriction */
+                       break;
+               }
+               fallthrough;
+       case OP_J: /* no register condition is checked for J */
+               break;
+       case OP_JAL: /* next_op must neither read nor write register 31 */
+               if (opcode_reads_register(next_op, 31) ||
+                   opcode_writes_register(next_op, 31))
+                       return false;
+
+               break;
+       case OP_BEQ:
+       case OP_BNE: /* next_op must not write rt; a zero rt field is not checked */
+               if (op.i.rt && opcode_writes_register(next_op, op.i.rt))
+                       return false;
+               fallthrough;
+       case OP_BLEZ:
+       case OP_BGTZ: /* next_op must not write rs; a zero rs field is not checked */
+               if (op.i.rs && opcode_writes_register(next_op, op.i.rs))
+                       return false;
+               break;
+       case OP_REGIMM:
+               switch (op.r.rt) {
+               case OP_REGIMM_BLTZAL:
+               case OP_REGIMM_BGEZAL: /* same register 31 restriction as OP_JAL */
+                       if (opcode_reads_register(next_op, 31) ||
+                           opcode_writes_register(next_op, 31))
+                               return false;
+                       fallthrough;
+               case OP_REGIMM_BLTZ:
+               case OP_REGIMM_BGEZ: /* next_op must not write rs; a zero rs field is not checked */
+                       if (op.i.rs && opcode_writes_register(next_op, op.i.rs))
+                               return false;
+                       break;
+               }
+               fallthrough;
+       default: /* any other opcode: no restriction */
+               break;
+       }
+
+       return true;
+}
+
 static int lightrec_switch_delay_slots(struct lightrec_state *state, struct block *block)
 {
        struct opcode *list, *next = &block->opcode_list[0];
@@ -1050,71 +1200,20 @@ static int lightrec_switch_delay_slots(struct lightrec_state *state, struct bloc
                    !op_flag_no_ds(block->opcode_list[i - 1].flags))
                        continue;
 
-               if (op_flag_sync(list->flags) || op_flag_sync(next->flags))
+               if (op_flag_sync(next->flags))
                        continue;
 
-               switch (list->i.op) {
-               case OP_SPECIAL:
-                       switch (op.r.op) {
-                       case OP_SPECIAL_JALR:
-                               if (opcode_reads_register(next_op, op.r.rd) ||
-                                   opcode_writes_register(next_op, op.r.rd))
-                                       continue;
-                               fallthrough;
-                       case OP_SPECIAL_JR:
-                               if (opcode_writes_register(next_op, op.r.rs))
-                                       continue;
-                               fallthrough;
-                       default:
-                               break;
-                       }
-                       fallthrough;
-               case OP_J:
-                       break;
-               case OP_JAL:
-                       if (opcode_reads_register(next_op, 31) ||
-                           opcode_writes_register(next_op, 31))
-                               continue;
-                       else
-                               break;
-               case OP_BEQ:
-               case OP_BNE:
-                       if (op.i.rt && opcode_writes_register(next_op, op.i.rt))
-                               continue;
-                       fallthrough;
-               case OP_BLEZ:
-               case OP_BGTZ:
-                       if (op.i.rs && opcode_writes_register(next_op, op.i.rs))
-                               continue;
-                       break;
-               case OP_REGIMM:
-                       switch (op.r.rt) {
-                       case OP_REGIMM_BLTZAL:
-                       case OP_REGIMM_BGEZAL:
-                               if (opcode_reads_register(next_op, 31) ||
-                                   opcode_writes_register(next_op, 31))
-                                       continue;
-                               fallthrough;
-                       case OP_REGIMM_BLTZ:
-                       case OP_REGIMM_BGEZ:
-                               if (op.i.rs &&
-                                   opcode_writes_register(next_op, op.i.rs))
-                                       continue;
-                               break;
-                       }
-                       fallthrough;
-               default:
-                       break;
-               }
+               if (!lightrec_can_switch_delay_slot(list->c, next_op))
+                       continue;
 
                pr_debug("Swap branch and delay slot opcodes "
                         "at offsets 0x%x / 0x%x\n",
                         i << 2, (i + 1) << 2);
 
-               flags = next->flags;
+               flags = next->flags | (list->flags & LIGHTREC_SYNC);
                list->c = next_op;
                next->c = op;
-               next->flags = list->flags | LIGHTREC_NO_DS;
+               next->flags = (list->flags | LIGHTREC_NO_DS) & ~LIGHTREC_SYNC;
                list->flags = flags | LIGHTREC_NO_DS;
        }
 
@@ -1123,7 +1222,7 @@ static int lightrec_switch_delay_slots(struct lightrec_state *state, struct bloc
 
 static int shrink_opcode_list(struct lightrec_state *state, struct block *block, u16 new_size)
 {
-       struct opcode *list;
+       struct opcode_list *list, *old_list;
 
        if (new_size >= block->nb_ops) {
                pr_err("Invalid shrink size (%u vs %u)\n",
@@ -1131,19 +1230,20 @@ static int shrink_opcode_list(struct lightrec_state *state, struct block *block,
                return -EINVAL;
        }
 
-
        list = lightrec_malloc(state, MEM_FOR_IR,
-                              sizeof(*list) * new_size);
+                              sizeof(*list) + sizeof(struct opcode) * new_size);
        if (!list) {
                pr_err("Unable to allocate memory\n");
                return -ENOMEM;
        }
 
-       memcpy(list, block->opcode_list, sizeof(*list) * new_size);
+       old_list = container_of(block->opcode_list, struct opcode_list, ops);
+       memcpy(list->ops, old_list->ops, sizeof(struct opcode) * new_size);
 
-       lightrec_free_opcode_list(state, block);
-       block->opcode_list = list;
+       lightrec_free_opcode_list(state, block->opcode_list);
+       list->nb_ops = new_size;
        block->nb_ops = new_size;
+       block->opcode_list = list->ops;
 
        pr_debug("Shrunk opcode list of block PC 0x%08x to %u opcodes\n",
                 block->pc, new_size);
@@ -1449,6 +1549,7 @@ static int lightrec_flag_io(struct lightrec_state *state, struct block *block)
        u32 values[32] = { 0 };
        unsigned int i;
        u32 val, kunseg_val;
+       bool no_mask;
 
        for (i = 0; i < block->nb_ops; i++) {
                prev = list;
@@ -1483,7 +1584,7 @@ static int lightrec_flag_io(struct lightrec_state *state, struct block *block)
                                    kunseg(values[list->i.rs]) < (kunseg(block->pc) +
                                                                  block->nb_ops * 4)) {
                                        pr_debug("Self-modifying block detected\n");
-                                       block->flags |= BLOCK_NEVER_COMPILE;
+                                       block_set_flags(block, BLOCK_NEVER_COMPILE);
                                        list->flags |= LIGHTREC_SMC;
                                }
                        }
@@ -1505,10 +1606,11 @@ static int lightrec_flag_io(struct lightrec_state *state, struct block *block)
                                psx_map = lightrec_get_map_idx(state, kunseg_val);
 
                                list->flags &= ~LIGHTREC_IO_MASK;
+                               no_mask = val == kunseg_val;
 
                                switch (psx_map) {
                                case PSX_MAP_KERNEL_USER_RAM:
-                                       if (val == kunseg_val)
+                                       if (no_mask)
                                                list->flags |= LIGHTREC_NO_MASK;
                                        fallthrough;
                                case PSX_MAP_MIRROR1:
@@ -1516,19 +1618,36 @@ static int lightrec_flag_io(struct lightrec_state *state, struct block *block)
                                case PSX_MAP_MIRROR3:
                                        pr_debug("Flaging opcode %u as RAM access\n", i);
                                        list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_RAM);
+                                       if (no_mask && state->mirrors_mapped)
+                                               list->flags |= LIGHTREC_NO_MASK;
                                        break;
                                case PSX_MAP_BIOS:
                                        pr_debug("Flaging opcode %u as BIOS access\n", i);
                                        list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_BIOS);
+                                       if (no_mask)
+                                               list->flags |= LIGHTREC_NO_MASK;
                                        break;
                                case PSX_MAP_SCRATCH_PAD:
                                        pr_debug("Flaging opcode %u as scratchpad access\n", i);
                                        list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_SCRATCH);
+                                       if (no_mask)
+                                               list->flags |= LIGHTREC_NO_MASK;
 
                                        /* Consider that we're never going to run code from
                                         * the scratchpad. */
                                        list->flags |= LIGHTREC_NO_INVALIDATE;
                                        break;
+                               case PSX_MAP_HW_REGISTERS:
+                                       if (state->ops.hw_direct &&
+                                           state->ops.hw_direct(kunseg_val,
+                                                                opcode_is_store(list->c),
+                                                                opcode_get_io_size(list->c))) {
+                                               pr_debug("Flagging opcode %u as direct I/O access\n",
+                                                        i);
+                                               list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_DIRECT_HW);
+                                               break;
+                                       }
+                                       fallthrough;
                                default:
                                        pr_debug("Flagging opcode %u as I/O access\n",
                                                 i);
@@ -1591,6 +1710,9 @@ static u8 get_mfhi_mflo_reg(const struct block *block, u16 offset,
                        }
 
                        return mflo ? REG_LO : REG_HI;
+               case OP_META_MULT2:
+               case OP_META_MULTU2:
+                       return 0;
                case OP_SPECIAL:
                        switch (op->r.op) {
                        case OP_SPECIAL_MULT:
@@ -1736,20 +1858,26 @@ static int lightrec_flag_mults_divs(struct lightrec_state *state, struct block *
                if (prev)
                        known = lightrec_propagate_consts(list, prev, known, values);
 
-               if (list->i.op != OP_SPECIAL)
-                       continue;
-
-               switch (list->r.op) {
-               case OP_SPECIAL_DIV:
-               case OP_SPECIAL_DIVU:
-                       /* If we are dividing by a non-zero constant, don't
-                        * emit the div-by-zero check. */
-                       if (lightrec_always_skip_div_check() ||
-                           (known & BIT(list->c.r.rt) && values[list->c.r.rt]))
-                               list->flags |= LIGHTREC_NO_DIV_CHECK;
+               switch (list->i.op) {
+               case OP_SPECIAL:
+                       switch (list->r.op) {
+                       case OP_SPECIAL_DIV:
+                       case OP_SPECIAL_DIVU:
+                               /* If we are dividing by a non-zero constant, don't
+                                * emit the div-by-zero check. */
+                               if (lightrec_always_skip_div_check() ||
+                                   ((known & BIT(list->c.r.rt)) && values[list->c.r.rt]))
+                                       list->flags |= LIGHTREC_NO_DIV_CHECK;
+                               fallthrough;
+                       case OP_SPECIAL_MULT:
+                       case OP_SPECIAL_MULTU:
+                               break;
+                       default:
+                               continue;
+                       }
                        fallthrough;
-               case OP_SPECIAL_MULT:
-               case OP_SPECIAL_MULTU:
+               case OP_META_MULT2:
+               case OP_META_MULTU2:
                        break;
                default:
                        continue;
@@ -1929,7 +2057,8 @@ static int lightrec_replace_memset(struct lightrec_state *state, struct block *b
                if (i == ARRAY_SIZE(memset_code) - 1) {
                        /* success! */
                        pr_debug("Block at PC 0x%x is a memset\n", block->pc);
-                       block->flags |= BLOCK_IS_MEMSET | BLOCK_NEVER_COMPILE;
+                       block_set_flags(block,
+                                       BLOCK_IS_MEMSET | BLOCK_NEVER_COMPILE);
 
                        /* Return non-zero to skip other optimizers. */
                        return 1;