Update lightrec 20220716 (#672)
[pcsx_rearmed.git] / deps / lightrec / optimizer.c
1 // SPDX-License-Identifier: LGPL-2.1-or-later
2 /*
3  * Copyright (C) 2014-2021 Paul Cercueil <paul@crapouillou.net>
4  */
5
6 #include "lightrec-config.h"
7 #include "disassembler.h"
8 #include "lightrec.h"
9 #include "memmanager.h"
10 #include "optimizer.h"
11 #include "regcache.h"
12
13 #include <errno.h>
14 #include <stdbool.h>
15 #include <stdlib.h>
16 #include <string.h>
17
/* Expand to @ptr when the option @enabled is set, and to NULL otherwise. */
#define IF_OPT(enabled, ptr) ((enabled) ? (ptr) : NULL)

/*
 * Optimizer callbacks (presumably an array of them, given the count field)
 * along with the number of entries.
 */
struct optimizer_list {
	void (**optimizers)(struct opcode *);
	unsigned int nb_optimizers;
};

/* Defined further below; reg_is_read() and reg_is_written() need it first. */
static bool is_nop(union code insn);
26
27 bool is_unconditional_jump(union code c)
28 {
29         switch (c.i.op) {
30         case OP_SPECIAL:
31                 return c.r.op == OP_SPECIAL_JR || c.r.op == OP_SPECIAL_JALR;
32         case OP_J:
33         case OP_JAL:
34                 return true;
35         case OP_BEQ:
36         case OP_BLEZ:
37                 return c.i.rs == c.i.rt;
38         case OP_REGIMM:
39                 return (c.r.rt == OP_REGIMM_BGEZ ||
40                         c.r.rt == OP_REGIMM_BGEZAL) && c.i.rs == 0;
41         default:
42                 return false;
43         }
44 }
45
46 bool is_syscall(union code c)
47 {
48         return (c.i.op == OP_SPECIAL && c.r.op == OP_SPECIAL_SYSCALL) ||
49                 (c.i.op == OP_CP0 && (c.r.rs == OP_CP0_MTC0 ||
50                                         c.r.rs == OP_CP0_CTC0) &&
51                  (c.r.rd == 12 || c.r.rd == 13));
52 }
53
54 static u64 opcode_read_mask(union code op)
55 {
56         switch (op.i.op) {
57         case OP_SPECIAL:
58                 switch (op.r.op) {
59                 case OP_SPECIAL_SYSCALL:
60                 case OP_SPECIAL_BREAK:
61                         return 0;
62                 case OP_SPECIAL_JR:
63                 case OP_SPECIAL_JALR:
64                 case OP_SPECIAL_MTHI:
65                 case OP_SPECIAL_MTLO:
66                         return BIT(op.r.rs);
67                 case OP_SPECIAL_MFHI:
68                         return BIT(REG_HI);
69                 case OP_SPECIAL_MFLO:
70                         return BIT(REG_LO);
71                 case OP_SPECIAL_SLL:
72                         if (!op.r.imm)
73                                 return 0;
74                         fallthrough;
75                 case OP_SPECIAL_SRL:
76                 case OP_SPECIAL_SRA:
77                         return BIT(op.r.rt);
78                 default:
79                         return BIT(op.r.rs) | BIT(op.r.rt);
80                 }
81         case OP_CP0:
82                 switch (op.r.rs) {
83                 case OP_CP0_MTC0:
84                 case OP_CP0_CTC0:
85                         return BIT(op.r.rt);
86                 default:
87                         return 0;
88                 }
89         case OP_CP2:
90                 if (op.r.op == OP_CP2_BASIC) {
91                         switch (op.r.rs) {
92                         case OP_CP2_BASIC_MTC2:
93                         case OP_CP2_BASIC_CTC2:
94                                 return BIT(op.r.rt);
95                         default:
96                                 break;
97                         }
98                 }
99                 return 0;
100         case OP_J:
101         case OP_JAL:
102         case OP_LUI:
103                 return 0;
104         case OP_BEQ:
105                 if (op.i.rs == op.i.rt)
106                         return 0;
107                 fallthrough;
108         case OP_BNE:
109         case OP_LWL:
110         case OP_LWR:
111         case OP_SB:
112         case OP_SH:
113         case OP_SWL:
114         case OP_SW:
115         case OP_SWR:
116                 return BIT(op.i.rs) | BIT(op.i.rt);
117         default:
118                 return BIT(op.i.rs);
119         }
120 }
121
122 static u64 opcode_write_mask(union code op)
123 {
124         u64 flags;
125
126         switch (op.i.op) {
127         case OP_SPECIAL:
128                 switch (op.r.op) {
129                 case OP_SPECIAL_JR:
130                 case OP_SPECIAL_SYSCALL:
131                 case OP_SPECIAL_BREAK:
132                         return 0;
133                 case OP_SPECIAL_MULT:
134                 case OP_SPECIAL_MULTU:
135                 case OP_SPECIAL_DIV:
136                 case OP_SPECIAL_DIVU:
137                         if (!OPT_FLAG_MULT_DIV)
138                                 return BIT(REG_LO) | BIT(REG_HI);
139
140                         if (op.r.rd)
141                                 flags = BIT(op.r.rd);
142                         else
143                                 flags = BIT(REG_LO);
144                         if (op.r.imm)
145                                 flags |= BIT(op.r.imm);
146                         else
147                                 flags |= BIT(REG_HI);
148                         return flags;
149                 case OP_SPECIAL_MTHI:
150                         return BIT(REG_HI);
151                 case OP_SPECIAL_MTLO:
152                         return BIT(REG_LO);
153                 case OP_SPECIAL_SLL:
154                         if (!op.r.imm)
155                                 return 0;
156                         fallthrough;
157                 default:
158                         return BIT(op.r.rd);
159                 }
160         case OP_ADDI:
161         case OP_ADDIU:
162         case OP_SLTI:
163         case OP_SLTIU:
164         case OP_ANDI:
165         case OP_ORI:
166         case OP_XORI:
167         case OP_LUI:
168         case OP_LB:
169         case OP_LH:
170         case OP_LWL:
171         case OP_LW:
172         case OP_LBU:
173         case OP_LHU:
174         case OP_LWR:
175                 return BIT(op.i.rt);
176         case OP_JAL:
177                 return BIT(31);
178         case OP_CP0:
179                 switch (op.r.rs) {
180                 case OP_CP0_MFC0:
181                 case OP_CP0_CFC0:
182                         return BIT(op.i.rt);
183                 default:
184                         return 0;
185                 }
186         case OP_CP2:
187                 if (op.r.op == OP_CP2_BASIC) {
188                         switch (op.r.rs) {
189                         case OP_CP2_BASIC_MFC2:
190                         case OP_CP2_BASIC_CFC2:
191                                 return BIT(op.i.rt);
192                         default:
193                                 break;
194                         }
195                 }
196                 return 0;
197         case OP_REGIMM:
198                 switch (op.r.rt) {
199                 case OP_REGIMM_BLTZAL:
200                 case OP_REGIMM_BGEZAL:
201                         return BIT(31);
202                 default:
203                         return 0;
204                 }
205         case OP_META_MOV:
206                 return BIT(op.r.rd);
207         default:
208                 return 0;
209         }
210 }
211
212 bool opcode_reads_register(union code op, u8 reg)
213 {
214         return opcode_read_mask(op) & BIT(reg);
215 }
216
217 bool opcode_writes_register(union code op, u8 reg)
218 {
219         return opcode_write_mask(op) & BIT(reg);
220 }
221
222 static int find_prev_writer(const struct opcode *list, unsigned int offset, u8 reg)
223 {
224         union code c;
225         unsigned int i;
226
227         if (op_flag_sync(list[offset].flags))
228                 return -1;
229
230         for (i = offset; i > 0; i--) {
231                 c = list[i - 1].c;
232
233                 if (opcode_writes_register(c, reg)) {
234                         if (i > 1 && has_delay_slot(list[i - 2].c))
235                                 break;
236
237                         return i - 1;
238                 }
239
240                 if (op_flag_sync(list[i - 1].flags) ||
241                     has_delay_slot(c) ||
242                     opcode_reads_register(c, reg))
243                         break;
244         }
245
246         return -1;
247 }
248
249 static int find_next_reader(const struct opcode *list, unsigned int offset, u8 reg)
250 {
251         unsigned int i;
252         union code c;
253
254         if (op_flag_sync(list[offset].flags))
255                 return -1;
256
257         for (i = offset; ; i++) {
258                 c = list[i].c;
259
260                 if (opcode_reads_register(c, reg)) {
261                         if (i > 0 && has_delay_slot(list[i - 1].c))
262                                 break;
263
264                         return i;
265                 }
266
267                 if (op_flag_sync(list[i].flags) ||
268                     has_delay_slot(c) || opcode_writes_register(c, reg))
269                         break;
270         }
271
272         return -1;
273 }
274
275 static bool reg_is_dead(const struct opcode *list, unsigned int offset, u8 reg)
276 {
277         unsigned int i;
278
279         if (op_flag_sync(list[offset].flags))
280                 return false;
281
282         for (i = offset + 1; ; i++) {
283                 if (opcode_reads_register(list[i].c, reg))
284                         return false;
285
286                 if (opcode_writes_register(list[i].c, reg))
287                         return true;
288
289                 if (has_delay_slot(list[i].c)) {
290                         if (op_flag_no_ds(list[i].flags) ||
291                             opcode_reads_register(list[i + 1].c, reg))
292                                 return false;
293
294                         return opcode_writes_register(list[i + 1].c, reg);
295                 }
296         }
297 }
298
299 static bool reg_is_read(const struct opcode *list,
300                         unsigned int a, unsigned int b, u8 reg)
301 {
302         /* Return true if reg is read in one of the opcodes of the interval
303          * [a, b[ */
304         for (; a < b; a++) {
305                 if (!is_nop(list[a].c) && opcode_reads_register(list[a].c, reg))
306                         return true;
307         }
308
309         return false;
310 }
311
312 static bool reg_is_written(const struct opcode *list,
313                            unsigned int a, unsigned int b, u8 reg)
314 {
315         /* Return true if reg is written in one of the opcodes of the interval
316          * [a, b[ */
317
318         for (; a < b; a++) {
319                 if (!is_nop(list[a].c) && opcode_writes_register(list[a].c, reg))
320                         return true;
321         }
322
323         return false;
324 }
325
326 static bool reg_is_read_or_written(const struct opcode *list,
327                                    unsigned int a, unsigned int b, u8 reg)
328 {
329         return reg_is_read(list, a, b, reg) || reg_is_written(list, a, b, reg);
330 }
331
332 static bool opcode_is_load(union code op)
333 {
334         switch (op.i.op) {
335         case OP_LB:
336         case OP_LH:
337         case OP_LWL:
338         case OP_LW:
339         case OP_LBU:
340         case OP_LHU:
341         case OP_LWR:
342         case OP_LWC2:
343                 return true;
344         default:
345                 return false;
346         }
347 }
348
349 static bool opcode_is_store(union code op)
350 {
351         switch (op.i.op) {
352         case OP_SB:
353         case OP_SH:
354         case OP_SW:
355         case OP_SWL:
356         case OP_SWR:
357         case OP_SWC2:
358                 return true;
359         default:
360                 return false;
361         }
362 }
363
364 bool opcode_is_io(union code op)
365 {
366         return opcode_is_load(op) || opcode_is_store(op);
367 }
368
369 /* TODO: Complete */
370 static bool is_nop(union code op)
371 {
372         if (opcode_writes_register(op, 0)) {
373                 switch (op.i.op) {
374                 case OP_CP0:
375                         return op.r.rs != OP_CP0_MFC0;
376                 case OP_LB:
377                 case OP_LH:
378                 case OP_LWL:
379                 case OP_LW:
380                 case OP_LBU:
381                 case OP_LHU:
382                 case OP_LWR:
383                         return false;
384                 default:
385                         return true;
386                 }
387         }
388
389         switch (op.i.op) {
390         case OP_SPECIAL:
391                 switch (op.r.op) {
392                 case OP_SPECIAL_AND:
393                         return op.r.rd == op.r.rt && op.r.rd == op.r.rs;
394                 case OP_SPECIAL_ADD:
395                 case OP_SPECIAL_ADDU:
396                         return (op.r.rd == op.r.rt && op.r.rs == 0) ||
397                                 (op.r.rd == op.r.rs && op.r.rt == 0);
398                 case OP_SPECIAL_SUB:
399                 case OP_SPECIAL_SUBU:
400                         return op.r.rd == op.r.rs && op.r.rt == 0;
401                 case OP_SPECIAL_OR:
402                         if (op.r.rd == op.r.rt)
403                                 return op.r.rd == op.r.rs || op.r.rs == 0;
404                         else
405                                 return (op.r.rd == op.r.rs) && op.r.rt == 0;
406                 case OP_SPECIAL_SLL:
407                 case OP_SPECIAL_SRA:
408                 case OP_SPECIAL_SRL:
409                         return op.r.rd == op.r.rt && op.r.imm == 0;
410                 case OP_SPECIAL_MFHI:
411                 case OP_SPECIAL_MFLO:
412                         return op.r.rd == 0;
413                 default:
414                         return false;
415                 }
416         case OP_ORI:
417         case OP_ADDI:
418         case OP_ADDIU:
419                 return op.i.rt == op.i.rs && op.i.imm == 0;
420         case OP_BGTZ:
421                 return (op.i.rs == 0 || op.i.imm == 1);
422         case OP_REGIMM:
423                 return (op.i.op == OP_REGIMM_BLTZ ||
424                                 op.i.op == OP_REGIMM_BLTZAL) &&
425                         (op.i.rs == 0 || op.i.imm == 1);
426         case OP_BNE:
427                 return (op.i.rs == op.i.rt || op.i.imm == 1);
428         default:
429                 return false;
430         }
431 }
432
433 bool load_in_delay_slot(union code op)
434 {
435         switch (op.i.op) {
436         case OP_CP0:
437                 switch (op.r.rs) {
438                 case OP_CP0_MFC0:
439                 case OP_CP0_CFC0:
440                         return true;
441                 default:
442                         break;
443                 }
444
445                 break;
446         case OP_CP2:
447                 if (op.r.op == OP_CP2_BASIC) {
448                         switch (op.r.rs) {
449                         case OP_CP2_BASIC_MFC2:
450                         case OP_CP2_BASIC_CFC2:
451                                 return true;
452                         default:
453                                 break;
454                         }
455                 }
456
457                 break;
458         case OP_LB:
459         case OP_LH:
460         case OP_LW:
461         case OP_LWL:
462         case OP_LWR:
463         case OP_LBU:
464         case OP_LHU:
465                 return true;
466         default:
467                 break;
468         }
469
470         return false;
471 }
472
473 static u32 lightrec_propagate_consts(const struct opcode *op,
474                                      const struct opcode *prev,
475                                      u32 known, u32 *v)
476 {
477         union code c = prev->c;
478
479         /* Register $zero is always, well, zero */
480         known |= BIT(0);
481         v[0] = 0;
482
483         if (op_flag_sync(op->flags))
484                 return BIT(0);
485
486         switch (c.i.op) {
487         case OP_SPECIAL:
488                 switch (c.r.op) {
489                 case OP_SPECIAL_SLL:
490                         if (known & BIT(c.r.rt)) {
491                                 known |= BIT(c.r.rd);
492                                 v[c.r.rd] = v[c.r.rt] << c.r.imm;
493                         } else {
494                                 known &= ~BIT(c.r.rd);
495                         }
496                         break;
497                 case OP_SPECIAL_SRL:
498                         if (known & BIT(c.r.rt)) {
499                                 known |= BIT(c.r.rd);
500                                 v[c.r.rd] = v[c.r.rt] >> c.r.imm;
501                         } else {
502                                 known &= ~BIT(c.r.rd);
503                         }
504                         break;
505                 case OP_SPECIAL_SRA:
506                         if (known & BIT(c.r.rt)) {
507                                 known |= BIT(c.r.rd);
508                                 v[c.r.rd] = (s32)v[c.r.rt] >> c.r.imm;
509                         } else {
510                                 known &= ~BIT(c.r.rd);
511                         }
512                         break;
513                 case OP_SPECIAL_SLLV:
514                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
515                                 known |= BIT(c.r.rd);
516                                 v[c.r.rd] = v[c.r.rt] << (v[c.r.rs] & 0x1f);
517                         } else {
518                                 known &= ~BIT(c.r.rd);
519                         }
520                         break;
521                 case OP_SPECIAL_SRLV:
522                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
523                                 known |= BIT(c.r.rd);
524                                 v[c.r.rd] = v[c.r.rt] >> (v[c.r.rs] & 0x1f);
525                         } else {
526                                 known &= ~BIT(c.r.rd);
527                         }
528                         break;
529                 case OP_SPECIAL_SRAV:
530                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
531                                 known |= BIT(c.r.rd);
532                                 v[c.r.rd] = (s32)v[c.r.rt]
533                                           >> (v[c.r.rs] & 0x1f);
534                         } else {
535                                 known &= ~BIT(c.r.rd);
536                         }
537                         break;
538                 case OP_SPECIAL_ADD:
539                 case OP_SPECIAL_ADDU:
540                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
541                                 known |= BIT(c.r.rd);
542                                 v[c.r.rd] = (s32)v[c.r.rt] + (s32)v[c.r.rs];
543                         } else {
544                                 known &= ~BIT(c.r.rd);
545                         }
546                         break;
547                 case OP_SPECIAL_SUB:
548                 case OP_SPECIAL_SUBU:
549                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
550                                 known |= BIT(c.r.rd);
551                                 v[c.r.rd] = v[c.r.rt] - v[c.r.rs];
552                         } else {
553                                 known &= ~BIT(c.r.rd);
554                         }
555                         break;
556                 case OP_SPECIAL_AND:
557                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
558                                 known |= BIT(c.r.rd);
559                                 v[c.r.rd] = v[c.r.rt] & v[c.r.rs];
560                         } else {
561                                 known &= ~BIT(c.r.rd);
562                         }
563                         break;
564                 case OP_SPECIAL_OR:
565                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
566                                 known |= BIT(c.r.rd);
567                                 v[c.r.rd] = v[c.r.rt] | v[c.r.rs];
568                         } else {
569                                 known &= ~BIT(c.r.rd);
570                         }
571                         break;
572                 case OP_SPECIAL_XOR:
573                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
574                                 known |= BIT(c.r.rd);
575                                 v[c.r.rd] = v[c.r.rt] ^ v[c.r.rs];
576                         } else {
577                                 known &= ~BIT(c.r.rd);
578                         }
579                         break;
580                 case OP_SPECIAL_NOR:
581                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
582                                 known |= BIT(c.r.rd);
583                                 v[c.r.rd] = ~(v[c.r.rt] | v[c.r.rs]);
584                         } else {
585                                 known &= ~BIT(c.r.rd);
586                         }
587                         break;
588                 case OP_SPECIAL_SLT:
589                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
590                                 known |= BIT(c.r.rd);
591                                 v[c.r.rd] = (s32)v[c.r.rs] < (s32)v[c.r.rt];
592                         } else {
593                                 known &= ~BIT(c.r.rd);
594                         }
595                         break;
596                 case OP_SPECIAL_SLTU:
597                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
598                                 known |= BIT(c.r.rd);
599                                 v[c.r.rd] = v[c.r.rs] < v[c.r.rt];
600                         } else {
601                                 known &= ~BIT(c.r.rd);
602                         }
603                         break;
604                 default:
605                         break;
606                 }
607                 break;
608         case OP_REGIMM:
609                 break;
610         case OP_ADDI:
611         case OP_ADDIU:
612                 if (known & BIT(c.i.rs)) {
613                         known |= BIT(c.i.rt);
614                         v[c.i.rt] = v[c.i.rs] + (s32)(s16)c.i.imm;
615                 } else {
616                         known &= ~BIT(c.i.rt);
617                 }
618                 break;
619         case OP_SLTI:
620                 if (known & BIT(c.i.rs)) {
621                         known |= BIT(c.i.rt);
622                         v[c.i.rt] = (s32)v[c.i.rs] < (s32)(s16)c.i.imm;
623                 } else {
624                         known &= ~BIT(c.i.rt);
625                 }
626                 break;
627         case OP_SLTIU:
628                 if (known & BIT(c.i.rs)) {
629                         known |= BIT(c.i.rt);
630                         v[c.i.rt] = v[c.i.rs] < (u32)(s32)(s16)c.i.imm;
631                 } else {
632                         known &= ~BIT(c.i.rt);
633                 }
634                 break;
635         case OP_ANDI:
636                 if (known & BIT(c.i.rs)) {
637                         known |= BIT(c.i.rt);
638                         v[c.i.rt] = v[c.i.rs] & c.i.imm;
639                 } else {
640                         known &= ~BIT(c.i.rt);
641                 }
642                 break;
643         case OP_ORI:
644                 if (known & BIT(c.i.rs)) {
645                         known |= BIT(c.i.rt);
646                         v[c.i.rt] = v[c.i.rs] | c.i.imm;
647                 } else {
648                         known &= ~BIT(c.i.rt);
649                 }
650                 break;
651         case OP_XORI:
652                 if (known & BIT(c.i.rs)) {
653                         known |= BIT(c.i.rt);
654                         v[c.i.rt] = v[c.i.rs] ^ c.i.imm;
655                 } else {
656                         known &= ~BIT(c.i.rt);
657                 }
658                 break;
659         case OP_LUI:
660                 known |= BIT(c.i.rt);
661                 v[c.i.rt] = c.i.imm << 16;
662                 break;
663         case OP_CP0:
664                 switch (c.r.rs) {
665                 case OP_CP0_MFC0:
666                 case OP_CP0_CFC0:
667                         known &= ~BIT(c.r.rt);
668                         break;
669                 }
670                 break;
671         case OP_CP2:
672                 if (c.r.op == OP_CP2_BASIC) {
673                         switch (c.r.rs) {
674                         case OP_CP2_BASIC_MFC2:
675                         case OP_CP2_BASIC_CFC2:
676                                 known &= ~BIT(c.r.rt);
677                                 break;
678                         }
679                 }
680                 break;
681         case OP_LB:
682         case OP_LH:
683         case OP_LWL:
684         case OP_LW:
685         case OP_LBU:
686         case OP_LHU:
687         case OP_LWR:
688         case OP_LWC2:
689                 known &= ~BIT(c.i.rt);
690                 break;
691         case OP_META_MOV:
692                 if (known & BIT(c.r.rs)) {
693                         known |= BIT(c.r.rd);
694                         v[c.r.rd] = v[c.r.rs];
695                 } else {
696                         known &= ~BIT(c.r.rd);
697                 }
698                 break;
699         default:
700                 break;
701         }
702
703         return known;
704 }
705
706 static void lightrec_optimize_sll_sra(struct opcode *list, unsigned int offset)
707 {
708         struct opcode *prev, *prev2 = NULL, *curr = &list[offset];
709         struct opcode *to_change, *to_nop;
710         int idx, idx2;
711
712         if (curr->r.imm != 24 && curr->r.imm != 16)
713                 return;
714
715         idx = find_prev_writer(list, offset, curr->r.rt);
716         if (idx < 0)
717                 return;
718
719         prev = &list[idx];
720
721         if (prev->i.op != OP_SPECIAL || prev->r.op != OP_SPECIAL_SLL ||
722             prev->r.imm != curr->r.imm || prev->r.rd != curr->r.rt)
723                 return;
724
725         if (prev->r.rd != prev->r.rt && curr->r.rd != curr->r.rt) {
726                 /* sll rY, rX, 16
727                  * ...
728                  * srl rZ, rY, 16 */
729
730                 if (!reg_is_dead(list, offset, curr->r.rt) ||
731                     reg_is_read_or_written(list, idx, offset, curr->r.rd))
732                         return;
733
734                 /* If rY is dead after the SRL, and rZ is not used after the SLL,
735                  * we can change rY to rZ */
736
737                 pr_debug("Detected SLL/SRA with middle temp register\n");
738                 prev->r.rd = curr->r.rd;
739                 curr->r.rt = prev->r.rd;
740         }
741
742         /* We got a SLL/SRA combo. If imm #16, that's a cast to u16.
743          * If imm #24 that's a cast to u8.
744          *
745          * First of all, make sure that the target register of the SLL is not
746          * read before the SRA. */
747
748         if (prev->r.rd == prev->r.rt) {
749                 /* sll rX, rX, 16
750                  * ...
751                  * srl rY, rX, 16 */
752                 to_change = curr;
753                 to_nop = prev;
754
755                 /* rX is used after the SRA - we cannot convert it. */
756                 if (prev->r.rd != curr->r.rd && !reg_is_dead(list, offset, prev->r.rd))
757                         return;
758         } else {
759                 /* sll rY, rX, 16
760                  * ...
761                  * srl rY, rY, 16 */
762                 to_change = prev;
763                 to_nop = curr;
764         }
765
766         idx2 = find_prev_writer(list, idx, prev->r.rt);
767         if (idx2 >= 0) {
768                 /* Note that PSX games sometimes do casts after
769                  * a LHU or LBU; in this case we can change the
770                  * load opcode to a LH or LB, and the cast can
771                  * be changed to a MOV or a simple NOP. */
772
773                 prev2 = &list[idx2];
774
775                 if (curr->r.rd != prev2->i.rt &&
776                     !reg_is_dead(list, offset, prev2->i.rt))
777                         prev2 = NULL;
778                 else if (curr->r.imm == 16 && prev2->i.op == OP_LHU)
779                         prev2->i.op = OP_LH;
780                 else if (curr->r.imm == 24 && prev2->i.op == OP_LBU)
781                         prev2->i.op = OP_LB;
782                 else
783                         prev2 = NULL;
784
785                 if (prev2) {
786                         if (curr->r.rd == prev2->i.rt) {
787                                 to_change->opcode = 0;
788                         } else if (reg_is_dead(list, offset, prev2->i.rt) &&
789                                    !reg_is_read_or_written(list, idx2 + 1, offset, curr->r.rd)) {
790                                 /* The target register of the SRA is dead after the
791                                  * LBU/LHU; we can change the target register of the
792                                  * LBU/LHU to the one of the SRA. */
793                                 prev2->i.rt = curr->r.rd;
794                                 to_change->opcode = 0;
795                         } else {
796                                 to_change->i.op = OP_META_MOV;
797                                 to_change->r.rd = curr->r.rd;
798                                 to_change->r.rs = prev2->i.rt;
799                         }
800
801                         if (to_nop->r.imm == 24)
802                                 pr_debug("Convert LBU+SLL+SRA to LB\n");
803                         else
804                                 pr_debug("Convert LHU+SLL+SRA to LH\n");
805                 }
806         }
807
808         if (!prev2) {
809                 pr_debug("Convert SLL/SRA #%u to EXT%c\n",
810                          prev->r.imm,
811                          prev->r.imm == 24 ? 'C' : 'S');
812
813                 if (to_change == prev) {
814                         to_change->i.rs = prev->r.rt;
815                         to_change->i.rt = curr->r.rd;
816                 } else {
817                         to_change->i.rt = curr->r.rd;
818                         to_change->i.rs = prev->r.rt;
819                 }
820
821                 if (to_nop->r.imm == 24)
822                         to_change->i.op = OP_META_EXTC;
823                 else
824                         to_change->i.op = OP_META_EXTS;
825         }
826
827         to_nop->opcode = 0;
828 }
829
830 static void lightrec_remove_useless_lui(struct block *block, unsigned int offset,
831                                         u32 known, u32 *values)
832 {
833         struct opcode *list = block->opcode_list,
834                       *op = &block->opcode_list[offset];
835         int reader;
836
837         if (!op_flag_sync(op->flags) && (known & BIT(op->i.rt)) &&
838             values[op->i.rt] == op->i.imm << 16) {
839                 pr_debug("Converting duplicated LUI to NOP\n");
840                 op->opcode = 0x0;
841                 return;
842         }
843
844         if (op->i.imm != 0 || op->i.rt == 0)
845                 return;
846
847         reader = find_next_reader(list, offset + 1, op->i.rt);
848         if (reader <= 0)
849                 return;
850
851         if (opcode_writes_register(list[reader].c, op->i.rt) ||
852             reg_is_dead(list, reader, op->i.rt)) {
853                 pr_debug("Removing useless LUI 0x0\n");
854
855                 if (list[reader].i.rs == op->i.rt)
856                         list[reader].i.rs = 0;
857                 if (list[reader].i.op == OP_SPECIAL &&
858                     list[reader].i.rt == op->i.rt)
859                         list[reader].i.rt = 0;
860                 op->opcode = 0x0;
861         }
862 }
863
864 static void lightrec_modify_lui(struct block *block, unsigned int offset)
865 {
866         union code c, *lui = &block->opcode_list[offset].c;
867         bool stop = false, stop_next = false;
868         unsigned int i;
869
870         for (i = offset + 1; !stop && i < block->nb_ops; i++) {
871                 c = block->opcode_list[i].c;
872                 stop = stop_next;
873
874                 if ((opcode_is_store(c) && c.i.rt == lui->i.rt)
875                     || (!opcode_is_load(c) && opcode_reads_register(c, lui->i.rt)))
876                         break;
877
878                 if (opcode_writes_register(c, lui->i.rt)) {
879                         pr_debug("Convert LUI at offset 0x%x to kuseg\n",
880                                  i - 1 << 2);
881                         lui->i.imm = kunseg(lui->i.imm << 16) >> 16;
882                         break;
883                 }
884
885                 if (has_delay_slot(c))
886                         stop_next = true;
887         }
888 }
889
890 static int lightrec_transform_branches(struct lightrec_state *state,
891                                        struct block *block)
892 {
893         struct opcode *op;
894         unsigned int i;
895         s32 offset;
896
897         for (i = 0; i < block->nb_ops; i++) {
898                 op = &block->opcode_list[i];
899
900                 switch (op->i.op) {
901                 case OP_J:
902                         /* Transform J opcode into BEQ $zero, $zero if possible. */
903                         offset = (s32)((block->pc & 0xf0000000) >> 2 | op->j.imm)
904                                 - (s32)(block->pc >> 2) - (s32)i - 1;
905
906                         if (offset == (s16)offset) {
907                                 pr_debug("Transform J into BEQ $zero, $zero\n");
908                                 op->i.op = OP_BEQ;
909                                 op->i.rs = 0;
910                                 op->i.rt = 0;
911                                 op->i.imm = offset;
912
913                         }
914                 default: /* fall-through */
915                         break;
916                 }
917         }
918
919         return 0;
920 }
921
922 static int lightrec_transform_ops(struct lightrec_state *state, struct block *block)
923 {
924         struct opcode *list = block->opcode_list;
925         struct opcode *prev, *op = NULL;
926         u32 known = BIT(0);
927         u32 values[32] = { 0 };
928         unsigned int i;
929
930         for (i = 0; i < block->nb_ops; i++) {
931                 prev = op;
932                 op = &list[i];
933
934                 if (prev)
935                         known = lightrec_propagate_consts(op, prev, known, values);
936
937                 /* Transform all opcodes detected as useless to real NOPs
938                  * (0x0: SLL r0, r0, #0) */
939                 if (op->opcode != 0 && is_nop(op->c)) {
940                         pr_debug("Converting useless opcode 0x%08x to NOP\n",
941                                         op->opcode);
942                         op->opcode = 0x0;
943                 }
944
945                 if (!op->opcode)
946                         continue;
947
948                 switch (op->i.op) {
949                 case OP_BEQ:
950                         if (op->i.rs == op->i.rt) {
951                                 op->i.rs = 0;
952                                 op->i.rt = 0;
953                         } else if (op->i.rs == 0) {
954                                 op->i.rs = op->i.rt;
955                                 op->i.rt = 0;
956                         }
957                         break;
958
959                 case OP_BNE:
960                         if (op->i.rs == 0) {
961                                 op->i.rs = op->i.rt;
962                                 op->i.rt = 0;
963                         }
964                         break;
965
966                 case OP_LUI:
967                         if (!prev || !has_delay_slot(prev->c))
968                                 lightrec_modify_lui(block, i);
969                         lightrec_remove_useless_lui(block, i, known, values);
970                         break;
971
972                 /* Transform ORI/ADDI/ADDIU with imm #0 or ORR/ADD/ADDU/SUB/SUBU
973                  * with register $zero to the MOV meta-opcode */
974                 case OP_ORI:
975                 case OP_ADDI:
976                 case OP_ADDIU:
977                         if (op->i.imm == 0) {
978                                 pr_debug("Convert ORI/ADDI/ADDIU #0 to MOV\n");
979                                 op->i.op = OP_META_MOV;
980                                 op->r.rd = op->i.rt;
981                         }
982                         break;
983                 case OP_SPECIAL:
984                         switch (op->r.op) {
985                         case OP_SPECIAL_SRA:
986                                 if (op->r.imm == 0) {
987                                         pr_debug("Convert SRA #0 to MOV\n");
988                                         op->i.op = OP_META_MOV;
989                                         op->r.rs = op->r.rt;
990                                         break;
991                                 }
992
993                                 lightrec_optimize_sll_sra(block->opcode_list, i);
994                                 break;
995                         case OP_SPECIAL_SLL:
996                         case OP_SPECIAL_SRL:
997                                 if (op->r.imm == 0) {
998                                         pr_debug("Convert SLL/SRL #0 to MOV\n");
999                                         op->i.op = OP_META_MOV;
1000                                         op->r.rs = op->r.rt;
1001                                 }
1002                                 break;
1003                         case OP_SPECIAL_OR:
1004                         case OP_SPECIAL_ADD:
1005                         case OP_SPECIAL_ADDU:
1006                                 if (op->r.rs == 0) {
1007                                         pr_debug("Convert OR/ADD $zero to MOV\n");
1008                                         op->i.op = OP_META_MOV;
1009                                         op->r.rs = op->r.rt;
1010                                 }
1011                                 fallthrough;
1012                         case OP_SPECIAL_SUB:
1013                         case OP_SPECIAL_SUBU:
1014                                 if (op->r.rt == 0) {
1015                                         pr_debug("Convert OR/ADD/SUB $zero to MOV\n");
1016                                         op->i.op = OP_META_MOV;
1017                                 }
1018                                 fallthrough;
1019                         default:
1020                                 break;
1021                         }
1022                         fallthrough;
1023                 default:
1024                         break;
1025                 }
1026         }
1027
1028         return 0;
1029 }
1030
1031 static int lightrec_switch_delay_slots(struct lightrec_state *state, struct block *block)
1032 {
1033         struct opcode *list, *next = &block->opcode_list[0];
1034         unsigned int i;
1035         union code op, next_op;
1036         u32 flags;
1037
1038         for (i = 0; i < block->nb_ops - 1; i++) {
1039                 list = next;
1040                 next = &block->opcode_list[i + 1];
1041                 next_op = next->c;
1042                 op = list->c;
1043
1044                 if (!has_delay_slot(op) || op_flag_no_ds(list->flags) ||
1045                     op_flag_emulate_branch(list->flags) ||
1046                     op.opcode == 0 || next_op.opcode == 0)
1047                         continue;
1048
1049                 if (i && has_delay_slot(block->opcode_list[i - 1].c) &&
1050                     !op_flag_no_ds(block->opcode_list[i - 1].flags))
1051                         continue;
1052
1053                 if (op_flag_sync(list->flags) || op_flag_sync(next->flags))
1054                         continue;
1055
1056                 switch (list->i.op) {
1057                 case OP_SPECIAL:
1058                         switch (op.r.op) {
1059                         case OP_SPECIAL_JALR:
1060                                 if (opcode_reads_register(next_op, op.r.rd) ||
1061                                     opcode_writes_register(next_op, op.r.rd))
1062                                         continue;
1063                                 fallthrough;
1064                         case OP_SPECIAL_JR:
1065                                 if (opcode_writes_register(next_op, op.r.rs))
1066                                         continue;
1067                                 fallthrough;
1068                         default:
1069                                 break;
1070                         }
1071                         fallthrough;
1072                 case OP_J:
1073                         break;
1074                 case OP_JAL:
1075                         if (opcode_reads_register(next_op, 31) ||
1076                             opcode_writes_register(next_op, 31))
1077                                 continue;
1078                         else
1079                                 break;
1080                 case OP_BEQ:
1081                 case OP_BNE:
1082                         if (op.i.rt && opcode_writes_register(next_op, op.i.rt))
1083                                 continue;
1084                         fallthrough;
1085                 case OP_BLEZ:
1086                 case OP_BGTZ:
1087                         if (op.i.rs && opcode_writes_register(next_op, op.i.rs))
1088                                 continue;
1089                         break;
1090                 case OP_REGIMM:
1091                         switch (op.r.rt) {
1092                         case OP_REGIMM_BLTZAL:
1093                         case OP_REGIMM_BGEZAL:
1094                                 if (opcode_reads_register(next_op, 31) ||
1095                                     opcode_writes_register(next_op, 31))
1096                                         continue;
1097                                 fallthrough;
1098                         case OP_REGIMM_BLTZ:
1099                         case OP_REGIMM_BGEZ:
1100                                 if (op.i.rs &&
1101                                     opcode_writes_register(next_op, op.i.rs))
1102                                         continue;
1103                                 break;
1104                         }
1105                         fallthrough;
1106                 default:
1107                         break;
1108                 }
1109
1110                 pr_debug("Swap branch and delay slot opcodes "
1111                          "at offsets 0x%x / 0x%x\n",
1112                          i << 2, (i + 1) << 2);
1113
1114                 flags = next->flags;
1115                 list->c = next_op;
1116                 next->c = op;
1117                 next->flags = list->flags | LIGHTREC_NO_DS;
1118                 list->flags = flags | LIGHTREC_NO_DS;
1119         }
1120
1121         return 0;
1122 }
1123
1124 static int shrink_opcode_list(struct lightrec_state *state, struct block *block, u16 new_size)
1125 {
1126         struct opcode *list;
1127
1128         if (new_size >= block->nb_ops) {
1129                 pr_err("Invalid shrink size (%u vs %u)\n",
1130                        new_size, block->nb_ops);
1131                 return -EINVAL;
1132         }
1133
1134
1135         list = lightrec_malloc(state, MEM_FOR_IR,
1136                                sizeof(*list) * new_size);
1137         if (!list) {
1138                 pr_err("Unable to allocate memory\n");
1139                 return -ENOMEM;
1140         }
1141
1142         memcpy(list, block->opcode_list, sizeof(*list) * new_size);
1143
1144         lightrec_free_opcode_list(state, block);
1145         block->opcode_list = list;
1146         block->nb_ops = new_size;
1147
1148         pr_debug("Shrunk opcode list of block PC 0x%08x to %u opcodes\n",
1149                  block->pc, new_size);
1150
1151         return 0;
1152 }
1153
1154 static int lightrec_detect_impossible_branches(struct lightrec_state *state,
1155                                                struct block *block)
1156 {
1157         struct opcode *op, *list = block->opcode_list, *next = &list[0];
1158         unsigned int i;
1159         int ret = 0;
1160         s16 offset;
1161
1162         for (i = 0; i < block->nb_ops - 1; i++) {
1163                 op = next;
1164                 next = &list[i + 1];
1165
1166                 if (!has_delay_slot(op->c) ||
1167                     (!load_in_delay_slot(next->c) &&
1168                      !has_delay_slot(next->c) &&
1169                      !(next->i.op == OP_CP0 && next->r.rs == OP_CP0_RFE)))
1170                         continue;
1171
1172                 if (op->c.opcode == next->c.opcode) {
1173                         /* The delay slot is the exact same opcode as the branch
1174                          * opcode: this is effectively a NOP */
1175                         next->c.opcode = 0;
1176                         continue;
1177                 }
1178
1179                 offset = i + 1 + (s16)op->i.imm;
1180                 if (load_in_delay_slot(next->c) &&
1181                     (offset >= 0 && offset < block->nb_ops) &&
1182                     !opcode_reads_register(list[offset].c, next->c.i.rt)) {
1183                         /* The 'impossible' branch is a local branch - we can
1184                          * verify here that the first opcode of the target does
1185                          * not use the target register of the delay slot */
1186
1187                         pr_debug("Branch at offset 0x%x has load delay slot, "
1188                                  "but is local and dest opcode does not read "
1189                                  "dest register\n", i << 2);
1190                         continue;
1191                 }
1192
1193                 op->flags |= LIGHTREC_EMULATE_BRANCH;
1194
1195                 if (op == list) {
1196                         pr_debug("First opcode of block PC 0x%08x is an impossible branch\n",
1197                                  block->pc);
1198
1199                         /* If the first opcode is an 'impossible' branch, we
1200                          * only keep the first two opcodes of the block (the
1201                          * branch itself + its delay slot) */
1202                         if (block->nb_ops > 2)
1203                                 ret = shrink_opcode_list(state, block, 2);
1204                         break;
1205                 }
1206         }
1207
1208         return ret;
1209 }
1210
1211 static int lightrec_local_branches(struct lightrec_state *state, struct block *block)
1212 {
1213         struct opcode *list;
1214         unsigned int i;
1215         s32 offset;
1216
1217         for (i = 0; i < block->nb_ops; i++) {
1218                 list = &block->opcode_list[i];
1219
1220                 if (should_emulate(list))
1221                         continue;
1222
1223                 switch (list->i.op) {
1224                 case OP_BEQ:
1225                 case OP_BNE:
1226                 case OP_BLEZ:
1227                 case OP_BGTZ:
1228                 case OP_REGIMM:
1229                         offset = i + 1 + (s16)list->i.imm;
1230                         if (offset >= 0 && offset < block->nb_ops)
1231                                 break;
1232                         fallthrough;
1233                 default:
1234                         continue;
1235                 }
1236
1237                 pr_debug("Found local branch to offset 0x%x\n", offset << 2);
1238
1239                 if (should_emulate(&block->opcode_list[offset])) {
1240                         pr_debug("Branch target must be emulated - skip\n");
1241                         continue;
1242                 }
1243
1244                 if (offset && has_delay_slot(block->opcode_list[offset - 1].c)) {
1245                         pr_debug("Branch target is a delay slot - skip\n");
1246                         continue;
1247                 }
1248
1249                 pr_debug("Adding sync at offset 0x%x\n", offset << 2);
1250
1251                 block->opcode_list[offset].flags |= LIGHTREC_SYNC;
1252                 list->flags |= LIGHTREC_LOCAL_BRANCH;
1253         }
1254
1255         return 0;
1256 }
1257
1258 bool has_delay_slot(union code op)
1259 {
1260         switch (op.i.op) {
1261         case OP_SPECIAL:
1262                 switch (op.r.op) {
1263                 case OP_SPECIAL_JR:
1264                 case OP_SPECIAL_JALR:
1265                         return true;
1266                 default:
1267                         return false;
1268                 }
1269         case OP_J:
1270         case OP_JAL:
1271         case OP_BEQ:
1272         case OP_BNE:
1273         case OP_BLEZ:
1274         case OP_BGTZ:
1275         case OP_REGIMM:
1276                 return true;
1277         default:
1278                 return false;
1279         }
1280 }
1281
1282 bool should_emulate(const struct opcode *list)
1283 {
1284         return op_flag_emulate_branch(list->flags) && has_delay_slot(list->c);
1285 }
1286
1287 static bool op_writes_rd(union code c)
1288 {
1289         switch (c.i.op) {
1290         case OP_SPECIAL:
1291         case OP_META_MOV:
1292                 return true;
1293         default:
1294                 return false;
1295         }
1296 }
1297
1298 static void lightrec_add_reg_op(struct opcode *op, u8 reg, u32 reg_op)
1299 {
1300         if (op_writes_rd(op->c) && reg == op->r.rd)
1301                 op->flags |= LIGHTREC_REG_RD(reg_op);
1302         else if (op->i.rs == reg)
1303                 op->flags |= LIGHTREC_REG_RS(reg_op);
1304         else if (op->i.rt == reg)
1305                 op->flags |= LIGHTREC_REG_RT(reg_op);
1306         else
1307                 pr_debug("Cannot add unload/clean/discard flag: "
1308                          "opcode does not touch register %s!\n",
1309                          lightrec_reg_name(reg));
1310 }
1311
1312 static void lightrec_add_unload(struct opcode *op, u8 reg)
1313 {
1314         lightrec_add_reg_op(op, reg, LIGHTREC_REG_UNLOAD);
1315 }
1316
1317 static void lightrec_add_discard(struct opcode *op, u8 reg)
1318 {
1319         lightrec_add_reg_op(op, reg, LIGHTREC_REG_DISCARD);
1320 }
1321
1322 static void lightrec_add_clean(struct opcode *op, u8 reg)
1323 {
1324         lightrec_add_reg_op(op, reg, LIGHTREC_REG_CLEAN);
1325 }
1326
1327 static void
1328 lightrec_early_unload_sync(struct opcode *list, s16 *last_r, s16 *last_w)
1329 {
1330         unsigned int reg;
1331         s16 offset;
1332
1333         for (reg = 0; reg < 34; reg++) {
1334                 offset = s16_max(last_w[reg], last_r[reg]);
1335
1336                 if (offset >= 0)
1337                         lightrec_add_unload(&list[offset], reg);
1338         }
1339
1340         memset(last_r, 0xff, sizeof(*last_r) * 34);
1341         memset(last_w, 0xff, sizeof(*last_w) * 34);
1342 }
1343
/*
 * Register-lifetime pass: walk the block once and attach CLEAN / UNLOAD /
 * DISCARD hints (through lightrec_add_clean/_unload/_discard) to the opcode
 * that last touched each register, following the rules listed below.
 *
 * 34 registers are tracked (presumably the 32 GPRs plus HI/LO — confirm
 * against REG_HI/REG_LO). Always returns 0.
 */
static int lightrec_early_unload(struct lightrec_state *state, struct block *block)
{
	u16 i, offset;
	struct opcode *op;
	/*
	 * last_r / last_w: index of the last opcode that read / wrote each
	 * register since the last full unload; -1 (memset 0xff) means none.
	 * last_sync: index reached right after the most recent branch (and
	 * its delay slot, if not swapped).
	 * next_sync: index at which last_sync will next be updated.
	 */
	s16 last_r[34], last_w[34], last_sync = 0, next_sync = 0;
	/* dirty: registers written and not yet cleaned/unloaded/discarded.
	 * loaded: registers read (or cleaned) and not yet unloaded/discarded. */
	u64 mask_r, mask_w, dirty = 0, loaded = 0;
	u8 reg;

	memset(last_r, 0xff, sizeof(last_r));
	memset(last_w, 0xff, sizeof(last_w));

	/*
	 * Clean if:
	 * - the register is dirty, and is read again after a branch opcode
	 *
	 * Unload if:
	 * - the register is dirty or loaded, and is not read again
	 * - the register is dirty or loaded, and is written again after a branch opcode
	 * - the next opcode has the SYNC flag set
	 *
	 * Discard if:
	 * - the register is dirty or loaded, and is written again
	 */

	for (i = 0; i < block->nb_ops; i++) {
		op = &block->opcode_list[i];

		if (op_flag_sync(op->flags) || should_emulate(op)) {
			/* The next opcode has the SYNC flag set, or is a branch
			 * that should be emulated: unload all registers. */
			lightrec_early_unload_sync(block->opcode_list, last_r, last_w);
			dirty = 0;
			loaded = 0;
		}

		/* We just went past the branch (+ delay slot) seen earlier. */
		if (next_sync == i) {
			last_sync = i;
			pr_debug("Last sync: 0x%x\n", last_sync << 2);
		}

		/* The boundary is after the delay slot, or directly after
		 * the branch when it is flagged NO_DS (e.g. once
		 * lightrec_switch_delay_slots() moved the delay slot before it). */
		if (has_delay_slot(op->c)) {
			next_sync = i + 1 + !op_flag_no_ds(op->flags);
			pr_debug("Next sync: 0x%x\n", next_sync << 2);
		}

		mask_r = opcode_read_mask(op->c);
		mask_w = opcode_write_mask(op->c);

		for (reg = 0; reg < 34; reg++) {
			if (mask_r & BIT(reg)) {
				if (dirty & BIT(reg) && last_w[reg] < last_sync) {
					/* The register is dirty, and is read
					 * again after a branch: clean it */

					lightrec_add_clean(&block->opcode_list[last_w[reg]], reg);
					dirty &= ~BIT(reg);
					loaded |= BIT(reg);
				}

				last_r[reg] = i;
			}

			if (mask_w & BIT(reg)) {
				if ((dirty & BIT(reg) && last_w[reg] < last_sync) ||
				    (loaded & BIT(reg) && last_r[reg] < last_sync)) {
					/* The register is dirty or loaded, and
					 * is written again after a branch:
					 * unload it */

					offset = s16_max(last_w[reg], last_r[reg]);
					lightrec_add_unload(&block->opcode_list[offset], reg);
					dirty &= ~BIT(reg);
					loaded &= ~BIT(reg);
				} else if (!(mask_r & BIT(reg)) &&
					   ((dirty & BIT(reg) && last_w[reg] > last_sync) ||
					   (loaded & BIT(reg) && last_r[reg] > last_sync))) {
					/* The register is dirty or loaded, and
					 * is written again: discard it */

					offset = s16_max(last_w[reg], last_r[reg]);
					lightrec_add_discard(&block->opcode_list[offset], reg);
					dirty &= ~BIT(reg);
					loaded &= ~BIT(reg);
				}

				last_w[reg] = i;
			}

		}

		/* Update the masks only after the per-register checks above,
		 * which must see the state before this opcode. */
		dirty |= mask_w;
		loaded |= mask_r;
	}

	/* Unload all registers that are dirty or loaded at the end of block. */
	lightrec_early_unload_sync(block->opcode_list, last_r, last_w);

	return 0;
}
1443
1444 static int lightrec_flag_io(struct lightrec_state *state, struct block *block)
1445 {
1446         struct opcode *prev = NULL, *list = NULL;
1447         enum psx_map psx_map;
1448         u32 known = BIT(0);
1449         u32 values[32] = { 0 };
1450         unsigned int i;
1451         u32 val, kunseg_val;
1452
1453         for (i = 0; i < block->nb_ops; i++) {
1454                 prev = list;
1455                 list = &block->opcode_list[i];
1456
1457                 if (prev)
1458                         known = lightrec_propagate_consts(list, prev, known, values);
1459
1460                 switch (list->i.op) {
1461                 case OP_SB:
1462                 case OP_SH:
1463                 case OP_SW:
1464                         if (OPT_FLAG_STORES) {
1465                                 /* Mark all store operations that target $sp or $gp
1466                                  * as not requiring code invalidation. This is based
1467                                  * on the heuristic that stores using one of these
1468                                  * registers as address will never hit a code page. */
1469                                 if (list->i.rs >= 28 && list->i.rs <= 29 &&
1470                                     !state->maps[PSX_MAP_KERNEL_USER_RAM].ops) {
1471                                         pr_debug("Flaging opcode 0x%08x as not "
1472                                                  "requiring invalidation\n",
1473                                                  list->opcode);
1474                                         list->flags |= LIGHTREC_NO_INVALIDATE;
1475                                         list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_DIRECT);
1476                                 }
1477
1478                                 /* Detect writes whose destination address is inside the
1479                                  * current block, using constant propagation. When these
1480                                  * occur, we mark the blocks as not compilable. */
1481                                 if ((known & BIT(list->i.rs)) &&
1482                                     kunseg(values[list->i.rs]) >= kunseg(block->pc) &&
1483                                     kunseg(values[list->i.rs]) < (kunseg(block->pc) +
1484                                                                   block->nb_ops * 4)) {
1485                                         pr_debug("Self-modifying block detected\n");
1486                                         block->flags |= BLOCK_NEVER_COMPILE;
1487                                         list->flags |= LIGHTREC_SMC;
1488                                 }
1489                         }
1490                         fallthrough;
1491                 case OP_SWL:
1492                 case OP_SWR:
1493                 case OP_SWC2:
1494                 case OP_LB:
1495                 case OP_LBU:
1496                 case OP_LH:
1497                 case OP_LHU:
1498                 case OP_LW:
1499                 case OP_LWL:
1500                 case OP_LWR:
1501                 case OP_LWC2:
1502                         if (OPT_FLAG_IO && (known & BIT(list->i.rs))) {
1503                                 val = values[list->i.rs] + (s16) list->i.imm;
1504                                 kunseg_val = kunseg(val);
1505                                 psx_map = lightrec_get_map_idx(state, kunseg_val);
1506
1507                                 list->flags &= ~LIGHTREC_IO_MASK;
1508
1509                                 switch (psx_map) {
1510                                 case PSX_MAP_KERNEL_USER_RAM:
1511                                         if (val == kunseg_val)
1512                                                 list->flags |= LIGHTREC_NO_MASK;
1513                                         fallthrough;
1514                                 case PSX_MAP_MIRROR1:
1515                                 case PSX_MAP_MIRROR2:
1516                                 case PSX_MAP_MIRROR3:
1517                                         pr_debug("Flaging opcode %u as RAM access\n", i);
1518                                         list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_RAM);
1519                                         break;
1520                                 case PSX_MAP_BIOS:
1521                                         pr_debug("Flaging opcode %u as BIOS access\n", i);
1522                                         list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_BIOS);
1523                                         break;
1524                                 case PSX_MAP_SCRATCH_PAD:
1525                                         pr_debug("Flaging opcode %u as scratchpad access\n", i);
1526                                         list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_SCRATCH);
1527
1528                                         /* Consider that we're never going to run code from
1529                                          * the scratchpad. */
1530                                         list->flags |= LIGHTREC_NO_INVALIDATE;
1531                                         break;
1532                                 default:
1533                                         pr_debug("Flagging opcode %u as I/O access\n",
1534                                                  i);
1535                                         list->flags |= LIGHTREC_IO_MODE(LIGHTREC_IO_HW);
1536                                         break;
1537                                 }
1538                         }
1539                         fallthrough;
1540                 default:
1541                         break;
1542                 }
1543         }
1544
1545         return 0;
1546 }
1547
1548 static u8 get_mfhi_mflo_reg(const struct block *block, u16 offset,
1549                             const struct opcode *last,
1550                             u32 mask, bool sync, bool mflo, bool another)
1551 {
1552         const struct opcode *op, *next = &block->opcode_list[offset];
1553         u32 old_mask;
1554         u8 reg2, reg = mflo ? REG_LO : REG_HI;
1555         u16 branch_offset;
1556         unsigned int i;
1557
1558         for (i = offset; i < block->nb_ops; i++) {
1559                 op = next;
1560                 next = &block->opcode_list[i + 1];
1561                 old_mask = mask;
1562
1563                 /* If any other opcode writes or reads to the register
1564                  * we'd use, then we cannot use it anymore. */
1565                 mask |= opcode_read_mask(op->c);
1566                 mask |= opcode_write_mask(op->c);
1567
1568                 if (op_flag_sync(op->flags))
1569                         sync = true;
1570
1571                 switch (op->i.op) {
1572                 case OP_BEQ:
1573                 case OP_BNE:
1574                 case OP_BLEZ:
1575                 case OP_BGTZ:
1576                 case OP_REGIMM:
1577                         /* TODO: handle backwards branches too */
1578                         if (!last && op_flag_local_branch(op->flags) &&
1579                             (s16)op->c.i.imm >= 0) {
1580                                 branch_offset = i + 1 + (s16)op->c.i.imm
1581                                         - !!op_flag_no_ds(op->flags);
1582
1583                                 reg = get_mfhi_mflo_reg(block, branch_offset, NULL,
1584                                                         mask, sync, mflo, false);
1585                                 reg2 = get_mfhi_mflo_reg(block, offset + 1, next,
1586                                                          mask, sync, mflo, false);
1587                                 if (reg > 0 && reg == reg2)
1588                                         return reg;
1589                                 if (!reg && !reg2)
1590                                         return 0;
1591                         }
1592
1593                         return mflo ? REG_LO : REG_HI;
1594                 case OP_SPECIAL:
1595                         switch (op->r.op) {
1596                         case OP_SPECIAL_MULT:
1597                         case OP_SPECIAL_MULTU:
1598                         case OP_SPECIAL_DIV:
1599                         case OP_SPECIAL_DIVU:
1600                                 return 0;
1601                         case OP_SPECIAL_MTHI:
1602                                 if (!mflo)
1603                                         return 0;
1604                                 continue;
1605                         case OP_SPECIAL_MTLO:
1606                                 if (mflo)
1607                                         return 0;
1608                                 continue;
1609                         case OP_SPECIAL_JR:
1610                                 if (op->r.rs != 31)
1611                                         return reg;
1612
1613                                 if (!sync && !op_flag_no_ds(op->flags) &&
1614                                     (next->i.op == OP_SPECIAL) &&
1615                                     ((!mflo && next->r.op == OP_SPECIAL_MFHI) ||
1616                                     (mflo && next->r.op == OP_SPECIAL_MFLO)))
1617                                         return next->r.rd;
1618
1619                                 return 0;
1620                         case OP_SPECIAL_JALR:
1621                                 return reg;
1622                         case OP_SPECIAL_MFHI:
1623                                 if (!mflo) {
1624                                         if (another)
1625                                                 return op->r.rd;
1626                                         /* Must use REG_HI if there is another MFHI target*/
1627                                         reg2 = get_mfhi_mflo_reg(block, i + 1, next,
1628                                                          0, sync, mflo, true);
1629                                         if (reg2 > 0 && reg2 != REG_HI)
1630                                                 return REG_HI;
1631
1632                                         if (!sync && !(old_mask & BIT(op->r.rd)))
1633                                                 return op->r.rd;
1634                                         else
1635                                                 return REG_HI;
1636                                 }
1637                                 continue;
1638                         case OP_SPECIAL_MFLO:
1639                                 if (mflo) {
1640                                         if (another)
1641                                                 return op->r.rd;
1642                                         /* Must use REG_LO if there is another MFLO target*/
1643                                         reg2 = get_mfhi_mflo_reg(block, i + 1, next,
1644                                                          0, sync, mflo, true);
1645                                         if (reg2 > 0 && reg2 != REG_LO)
1646                                                 return REG_LO;
1647
1648                                         if (!sync && !(old_mask & BIT(op->r.rd)))
1649                                                 return op->r.rd;
1650                                         else
1651                                                 return REG_LO;
1652                                 }
1653                                 continue;
1654                         default:
1655                                 break;
1656                         }
1657
1658                         fallthrough;
1659                 default:
1660                         continue;
1661                 }
1662         }
1663
1664         return reg;
1665 }
1666
1667 static void lightrec_replace_lo_hi(struct block *block, u16 offset,
1668                                    u16 last, bool lo)
1669 {
1670         unsigned int i;
1671         u32 branch_offset;
1672
1673         /* This function will remove the following MFLO/MFHI. It must be called
1674          * only if get_mfhi_mflo_reg() returned a non-zero value. */
1675
1676         for (i = offset; i < last; i++) {
1677                 struct opcode *op = &block->opcode_list[i];
1678
1679                 switch (op->i.op) {
1680                 case OP_BEQ:
1681                 case OP_BNE:
1682                 case OP_BLEZ:
1683                 case OP_BGTZ:
1684                 case OP_REGIMM:
1685                         /* TODO: handle backwards branches too */
1686                         if (op_flag_local_branch(op->flags) && (s16)op->c.i.imm >= 0) {
1687                                 branch_offset = i + 1 + (s16)op->c.i.imm
1688                                         - !!op_flag_no_ds(op->flags);
1689
1690                                 lightrec_replace_lo_hi(block, branch_offset, last, lo);
1691                                 lightrec_replace_lo_hi(block, i + 1, branch_offset, lo);
1692                         }
1693                         break;
1694
1695                 case OP_SPECIAL:
1696                         if (lo && op->r.op == OP_SPECIAL_MFLO) {
1697                                 pr_debug("Removing MFLO opcode at offset 0x%x\n",
1698                                          i << 2);
1699                                 op->opcode = 0;
1700                                 return;
1701                         } else if (!lo && op->r.op == OP_SPECIAL_MFHI) {
1702                                 pr_debug("Removing MFHI opcode at offset 0x%x\n",
1703                                          i << 2);
1704                                 op->opcode = 0;
1705                                 return;
1706                         }
1707
1708                         fallthrough;
1709                 default:
1710                         break;
1711                 }
1712         }
1713 }
1714
/*
 * Tell whether DIV/DIVU opcodes must always be flagged as not needing the
 * division-by-zero check, whatever the divisor is.
 *
 * This is decided at build time and is true only when compiling for a MIPS
 * host. NOTE(review): presumably the host's native divide already yields
 * the behaviour expected by the guest on a zero divisor; confirm.
 */
static bool lightrec_always_skip_div_check(void)
{
#if defined(__mips__)
	static const bool host_is_mips = true;
#else
	static const bool host_is_mips = false;
#endif

	return host_is_mips;
}
1723
/*
 * Annotate the MULT/MULTU/DIV/DIVU opcodes of a block.
 *
 * - DIV(U) gets LIGHTREC_NO_DIV_CHECK when lightrec_always_skip_div_check()
 *   is true, or when the divisor register (rt) holds a known non-zero
 *   constant according to lightrec_propagate_consts().
 * - LIGHTREC_NO_LO / LIGHTREC_NO_HI are set when get_mfhi_mflo_reg() returns
 *   0 for LO / HI. Both flags are cleared again if neither LO nor HI gets a
 *   non-zero answer (per the debug message, the values are then presumably
 *   used by a parent block).
 * - When get_mfhi_mflo_reg() returns a GPR other than REG_LO / REG_HI, the
 *   matching MFLO/MFHI is removed by lightrec_replace_lo_hi() and that GPR
 *   is stored in r.rd (for LO) or r.imm (for HI); otherwise those fields
 *   are set to 0.
 *
 * Opcodes in a delay slot, and those flagged op_flag_no_ds(), only receive
 * the NO_DIV_CHECK treatment. The last opcode of the block is not examined.
 *
 * Always returns 0, so it never stops the optimizer chain.
 */
static int lightrec_flag_mults_divs(struct lightrec_state *state, struct block *block)
{
	struct opcode *prev, *list = NULL;
	u8 reg_hi, reg_lo;
	unsigned int i;
	/* Bitmask of registers whose value is known; $zero always is (0). */
	u32 known = BIT(0);
	u32 values[32] = { 0 };

	for (i = 0; i < block->nb_ops - 1; i++) {
		prev = list;
		list = &block->opcode_list[i];

		/* Update the known-constant state with the previous opcode. */
		if (prev)
			known = lightrec_propagate_consts(list, prev, known, values);

		if (list->i.op != OP_SPECIAL)
			continue;

		switch (list->r.op) {
		case OP_SPECIAL_DIV:
		case OP_SPECIAL_DIVU:
			/* If we are dividing by a non-zero constant, don't
			 * emit the div-by-zero check. */
			if (lightrec_always_skip_div_check() ||
			    (known & BIT(list->c.r.rt) && values[list->c.r.rt]))
				list->flags |= LIGHTREC_NO_DIV_CHECK;
			fallthrough;
		case OP_SPECIAL_MULT:
		case OP_SPECIAL_MULTU:
			break;
		default:
			/* Not a multiplication/division: nothing to flag. */
			continue;
		}

		/* Don't support opcodes in delay slots */
		if ((i && has_delay_slot(block->opcode_list[i - 1].c)) ||
		    op_flag_no_ds(list->flags)) {
			continue;
		}

		/* 0: LO not needed here; REG_LO: keep LO; other: target GPR. */
		reg_lo = get_mfhi_mflo_reg(block, i + 1, NULL, 0, false, true, false);
		if (reg_lo == 0) {
			pr_debug("Mark MULT(U)/DIV(U) opcode at offset 0x%x as"
				 " not writing LO\n", i << 2);
			list->flags |= LIGHTREC_NO_LO;
		}

		/* Same query for HI. */
		reg_hi = get_mfhi_mflo_reg(block, i + 1, NULL, 0, false, false, false);
		if (reg_hi == 0) {
			pr_debug("Mark MULT(U)/DIV(U) opcode at offset 0x%x as"
				 " not writing HI\n", i << 2);
			list->flags |= LIGHTREC_NO_HI;
		}

		if (!reg_lo && !reg_hi) {
			pr_debug("Both LO/HI unused in this block, they will "
				 "probably be used in parent block - removing "
				 "flags.\n");
			list->flags &= ~(LIGHTREC_NO_LO | LIGHTREC_NO_HI);
		}

		/* Redirect LO to a GPR and drop the now-useless MFLO. */
		if (reg_lo > 0 && reg_lo != REG_LO) {
			pr_debug("Found register %s to hold LO (rs = %u, rt = %u)\n",
				 lightrec_reg_name(reg_lo), list->r.rs, list->r.rt);

			lightrec_replace_lo_hi(block, i + 1, block->nb_ops, true);
			list->r.rd = reg_lo;
		} else {
			list->r.rd = 0;
		}

		/* Redirect HI to a GPR and drop the now-useless MFHI. */
		if (reg_hi > 0 && reg_hi != REG_HI) {
			pr_debug("Found register %s to hold HI (rs = %u, rt = %u)\n",
				 lightrec_reg_name(reg_hi), list->r.rs, list->r.rt);

			lightrec_replace_lo_hi(block, i + 1, block->nb_ops, false);
			list->r.imm = reg_hi;
		} else {
			list->r.imm = 0;
		}
	}

	return 0;
}
1808
1809 static bool remove_div_sequence(struct block *block, unsigned int offset)
1810 {
1811         struct opcode *op;
1812         unsigned int i, found = 0;
1813
1814         /*
1815          * Scan for the zero-checking sequence that GCC automatically introduced
1816          * after most DIV/DIVU opcodes. This sequence checks the value of the
1817          * divisor, and if zero, executes a BREAK opcode, causing the BIOS
1818          * handler to crash the PS1.
1819          *
1820          * For DIV opcodes, this sequence additionally checks that the signed
1821          * operation does not overflow.
1822          *
1823          * With the assumption that the games never crashed the PS1, we can
1824          * therefore assume that the games never divided by zero or overflowed,
1825          * and these sequences can be removed.
1826          */
1827
1828         for (i = offset; i < block->nb_ops; i++) {
1829                 op = &block->opcode_list[i];
1830
1831                 if (!found) {
1832                         if (op->i.op == OP_SPECIAL &&
1833                             (op->r.op == OP_SPECIAL_DIV || op->r.op == OP_SPECIAL_DIVU))
1834                                 break;
1835
1836                         if ((op->opcode & 0xfc1fffff) == 0x14000002) {
1837                                 /* BNE ???, zero, +8 */
1838                                 found++;
1839                         } else {
1840                                 offset++;
1841                         }
1842                 } else if (found == 1 && !op->opcode) {
1843                         /* NOP */
1844                         found++;
1845                 } else if (found == 2 && op->opcode == 0x0007000d) {
1846                         /* BREAK 0x1c00 */
1847                         found++;
1848                 } else if (found == 3 && op->opcode == 0x2401ffff) {
1849                         /* LI at, -1 */
1850                         found++;
1851                 } else if (found == 4 && (op->opcode & 0xfc1fffff) == 0x14010004) {
1852                         /* BNE ???, at, +16 */
1853                         found++;
1854                 } else if (found == 5 && op->opcode == 0x3c018000) {
1855                         /* LUI at, 0x8000 */
1856                         found++;
1857                 } else if (found == 6 && (op->opcode & 0x141fffff) == 0x14010002) {
1858                         /* BNE ???, at, +16 */
1859                         found++;
1860                 } else if (found == 7 && !op->opcode) {
1861                         /* NOP */
1862                         found++;
1863                 } else if (found == 8 && op->opcode == 0x0006000d) {
1864                         /* BREAK 0x1800 */
1865                         found++;
1866                         break;
1867                 } else {
1868                         break;
1869                 }
1870         }
1871
1872         if (found >= 3) {
1873                 if (found != 9)
1874                         found = 3;
1875
1876                 pr_debug("Removing DIV%s sequence at offset 0x%x\n",
1877                          found == 9 ? "" : "U", offset << 2);
1878
1879                 for (i = 0; i < found; i++)
1880                         block->opcode_list[offset + i].opcode = 0;
1881
1882                 return true;
1883         }
1884
1885         return false;
1886 }
1887
1888 static int lightrec_remove_div_by_zero_check_sequence(struct lightrec_state *state,
1889                                                       struct block *block)
1890 {
1891         struct opcode *op;
1892         unsigned int i;
1893
1894         for (i = 0; i < block->nb_ops; i++) {
1895                 op = &block->opcode_list[i];
1896
1897                 if (op->i.op == OP_SPECIAL &&
1898                     (op->r.op == OP_SPECIAL_DIVU || op->r.op == OP_SPECIAL_DIV) &&
1899                     remove_div_sequence(block, i + 1))
1900                         op->flags |= LIGHTREC_NO_DIV_CHECK;
1901         }
1902
1903         return 0;
1904 }
1905
/*
 * Reference opcodes of a memset-like routine, as shown by the disassembly
 * below: it stores zero to a1 consecutive 32-bit words starting at a0
 * (nothing is stored if a1 == 0), then returns through ra.
 * lightrec_replace_memset() compares the start of a block against this
 * table.
 */
static const u32 memset_code[] = {
	0x10a00006,	// beqz		a1, 2f
	0x24a2ffff,	// addiu	v0,a1,-1
	0x2403ffff,	// li		v1,-1
	0xac800000,	// 1: sw	zero,0(a0)
	0x2442ffff,	// addiu	v0,v0,-1
	0x1443fffd,	// bne		v0,v1, 1b
	0x24840004,	// addiu	a0,a0,4
	0x03e00008,	// 2: jr	ra
	0x00000000,	// nop
};
1917
1918 static int lightrec_replace_memset(struct lightrec_state *state, struct block *block)
1919 {
1920         unsigned int i;
1921         union code c;
1922
1923         for (i = 0; i < block->nb_ops; i++) {
1924                 c = block->opcode_list[i].c;
1925
1926                 if (c.opcode != memset_code[i])
1927                         return 0;
1928
1929                 if (i == ARRAY_SIZE(memset_code) - 1) {
1930                         /* success! */
1931                         pr_debug("Block at PC 0x%x is a memset\n", block->pc);
1932                         block->flags |= BLOCK_IS_MEMSET | BLOCK_NEVER_COMPILE;
1933
1934                         /* Return non-zero to skip other optimizers. */
1935                         return 1;
1936                 }
1937         }
1938
1939         return 0;
1940 }
1941
1942 static int (*lightrec_optimizers[])(struct lightrec_state *state, struct block *) = {
1943         IF_OPT(OPT_REMOVE_DIV_BY_ZERO_SEQ, &lightrec_remove_div_by_zero_check_sequence),
1944         IF_OPT(OPT_REPLACE_MEMSET, &lightrec_replace_memset),
1945         IF_OPT(OPT_DETECT_IMPOSSIBLE_BRANCHES, &lightrec_detect_impossible_branches),
1946         IF_OPT(OPT_TRANSFORM_OPS, &lightrec_transform_branches),
1947         IF_OPT(OPT_LOCAL_BRANCHES, &lightrec_local_branches),
1948         IF_OPT(OPT_TRANSFORM_OPS, &lightrec_transform_ops),
1949         IF_OPT(OPT_SWITCH_DELAY_SLOTS, &lightrec_switch_delay_slots),
1950         IF_OPT(OPT_FLAG_IO || OPT_FLAG_STORES, &lightrec_flag_io),
1951         IF_OPT(OPT_FLAG_MULT_DIV, &lightrec_flag_mults_divs),
1952         IF_OPT(OPT_EARLY_UNLOAD, &lightrec_early_unload),
1953 };
1954
1955 int lightrec_optimize(struct lightrec_state *state, struct block *block)
1956 {
1957         unsigned int i;
1958         int ret;
1959
1960         for (i = 0; i < ARRAY_SIZE(lightrec_optimizers); i++) {
1961                 if (lightrec_optimizers[i]) {
1962                         ret = (*lightrec_optimizers[i])(state, block);
1963                         if (ret)
1964                                 return ret;
1965                 }
1966         }
1967
1968         return 0;
1969 }