cf431f2e264e14a28447facdd523b6c495f4c32c
[pcsx_rearmed.git] / deps / lightrec / optimizer.c
1 /*
2  * Copyright (C) 2014-2020 Paul Cercueil <paul@crapouillou.net>
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  */
14
15 #include "disassembler.h"
16 #include "lightrec.h"
17 #include "memmanager.h"
18 #include "optimizer.h"
19 #include "regcache.h"
20
21 #include <errno.h>
22 #include <stdbool.h>
23 #include <stdlib.h>
24
/*
 * Presumably a list of optimization passes (confirm against users of this
 * type): @optimizers points to an array of functions that each take a
 * struct opcode pointer, and @nb_optimizers is presumably its length.
 */
struct optimizer_list {
	void (**optimizers)(struct opcode *);	/* array of pass callbacks */
	unsigned nb_optimizers;			/* entries in @optimizers */
};
29
30 bool opcode_reads_register(union code op, u8 reg)
31 {
32         switch (op.i.op) {
33         case OP_SPECIAL:
34                 switch (op.r.op) {
35                 case OP_SPECIAL_SYSCALL:
36                 case OP_SPECIAL_BREAK:
37                         return false;
38                 case OP_SPECIAL_JR:
39                 case OP_SPECIAL_JALR:
40                 case OP_SPECIAL_MTHI:
41                 case OP_SPECIAL_MTLO:
42                         return op.r.rs == reg;
43                 case OP_SPECIAL_MFHI:
44                         return reg == REG_HI;
45                 case OP_SPECIAL_MFLO:
46                         return reg == REG_LO;
47                 case OP_SPECIAL_SLL:
48                 case OP_SPECIAL_SRL:
49                 case OP_SPECIAL_SRA:
50                         return op.r.rt == reg;
51                 default:
52                         return op.r.rs == reg || op.r.rt == reg;
53                 }
54         case OP_CP0:
55                 switch (op.r.rs) {
56                 case OP_CP0_MTC0:
57                 case OP_CP0_CTC0:
58                         return op.r.rt == reg;
59                 default:
60                         return false;
61                 }
62         case OP_CP2:
63                 if (op.r.op == OP_CP2_BASIC) {
64                         switch (op.r.rs) {
65                         case OP_CP2_BASIC_MTC2:
66                         case OP_CP2_BASIC_CTC2:
67                                 return op.r.rt == reg;
68                         default:
69                                 return false;
70                         }
71                 } else {
72                         return false;
73                 }
74         case OP_J:
75         case OP_JAL:
76         case OP_LUI:
77                 return false;
78         case OP_BEQ:
79         case OP_BNE:
80         case OP_LWL:
81         case OP_LWR:
82         case OP_SB:
83         case OP_SH:
84         case OP_SWL:
85         case OP_SW:
86         case OP_SWR:
87                 return op.i.rs == reg || op.i.rt == reg;
88         default:
89                 return op.i.rs == reg;
90         }
91 }
92
93 bool opcode_writes_register(union code op, u8 reg)
94 {
95         switch (op.i.op) {
96         case OP_SPECIAL:
97                 switch (op.r.op) {
98                 case OP_SPECIAL_JR:
99                 case OP_SPECIAL_JALR:
100                 case OP_SPECIAL_SYSCALL:
101                 case OP_SPECIAL_BREAK:
102                         return false;
103                 case OP_SPECIAL_MULT:
104                 case OP_SPECIAL_MULTU:
105                 case OP_SPECIAL_DIV:
106                 case OP_SPECIAL_DIVU:
107                         return reg == REG_LO || reg == REG_HI;
108                 case OP_SPECIAL_MTHI:
109                         return reg == REG_HI;
110                 case OP_SPECIAL_MTLO:
111                         return reg == REG_LO;
112                 default:
113                         return op.r.rd == reg;
114                 }
115         case OP_ADDI:
116         case OP_ADDIU:
117         case OP_SLTI:
118         case OP_SLTIU:
119         case OP_ANDI:
120         case OP_ORI:
121         case OP_XORI:
122         case OP_LUI:
123         case OP_LB:
124         case OP_LH:
125         case OP_LWL:
126         case OP_LW:
127         case OP_LBU:
128         case OP_LHU:
129         case OP_LWR:
130                 return op.i.rt == reg;
131         case OP_CP0:
132                 switch (op.r.rs) {
133                 case OP_CP0_MFC0:
134                 case OP_CP0_CFC0:
135                         return op.i.rt == reg;
136                 default:
137                         return false;
138                 }
139         case OP_CP2:
140                 if (op.r.op == OP_CP2_BASIC) {
141                         switch (op.r.rs) {
142                         case OP_CP2_BASIC_MFC2:
143                         case OP_CP2_BASIC_CFC2:
144                                 return op.i.rt == reg;
145                         default:
146                                 return false;
147                         }
148                 } else {
149                         return false;
150                 }
151         case OP_META_MOV:
152                 return op.r.rd == reg;
153         default:
154                 return false;
155         }
156 }
157
158 /* TODO: Complete */
159 static bool is_nop(union code op)
160 {
161         if (opcode_writes_register(op, 0)) {
162                 switch (op.i.op) {
163                 case OP_CP0:
164                         return op.r.rs != OP_CP0_MFC0;
165                 case OP_LB:
166                 case OP_LH:
167                 case OP_LWL:
168                 case OP_LW:
169                 case OP_LBU:
170                 case OP_LHU:
171                 case OP_LWR:
172                         return false;
173                 default:
174                         return true;
175                 }
176         }
177
178         switch (op.i.op) {
179         case OP_SPECIAL:
180                 switch (op.r.op) {
181                 case OP_SPECIAL_AND:
182                         return op.r.rd == op.r.rt && op.r.rd == op.r.rs;
183                 case OP_SPECIAL_ADD:
184                 case OP_SPECIAL_ADDU:
185                         return (op.r.rd == op.r.rt && op.r.rs == 0) ||
186                                 (op.r.rd == op.r.rs && op.r.rt == 0);
187                 case OP_SPECIAL_SUB:
188                 case OP_SPECIAL_SUBU:
189                         return op.r.rd == op.r.rs && op.r.rt == 0;
190                 case OP_SPECIAL_OR:
191                         if (op.r.rd == op.r.rt)
192                                 return op.r.rd == op.r.rs || op.r.rs == 0;
193                         else
194                                 return (op.r.rd == op.r.rs) && op.r.rt == 0;
195                 case OP_SPECIAL_SLL:
196                 case OP_SPECIAL_SRA:
197                 case OP_SPECIAL_SRL:
198                         return op.r.rd == op.r.rt && op.r.imm == 0;
199                 default:
200                         return false;
201                 }
202         case OP_ORI:
203         case OP_ADDI:
204         case OP_ADDIU:
205                 return op.i.rt == op.i.rs && op.i.imm == 0;
206         case OP_BGTZ:
207                 return (op.i.rs == 0 || op.i.imm == 1);
208         case OP_REGIMM:
209                 return (op.i.op == OP_REGIMM_BLTZ ||
210                                 op.i.op == OP_REGIMM_BLTZAL) &&
211                         (op.i.rs == 0 || op.i.imm == 1);
212         case OP_BNE:
213                 return (op.i.rs == op.i.rt || op.i.imm == 1);
214         default:
215                 return false;
216         }
217 }
218
219 bool load_in_delay_slot(union code op)
220 {
221         switch (op.i.op) {
222         case OP_CP0:
223                 switch (op.r.rs) {
224                 case OP_CP0_MFC0:
225                 case OP_CP0_CFC0:
226                         return true;
227                 default:
228                         break;
229                 }
230
231                 break;
232         case OP_CP2:
233                 if (op.r.op == OP_CP2_BASIC) {
234                         switch (op.r.rs) {
235                         case OP_CP2_BASIC_MFC2:
236                         case OP_CP2_BASIC_CFC2:
237                                 return true;
238                         default:
239                                 break;
240                         }
241                 }
242
243                 break;
244         case OP_LB:
245         case OP_LH:
246         case OP_LW:
247         case OP_LWL:
248         case OP_LWR:
249         case OP_LBU:
250         case OP_LHU:
251                 return true;
252         default:
253                 break;
254         }
255
256         return false;
257 }
258
259 static u32 lightrec_propagate_consts(union code c, u32 known, u32 *v)
260 {
261         switch (c.i.op) {
262         case OP_SPECIAL:
263                 switch (c.r.op) {
264                 case OP_SPECIAL_SLL:
265                         if (known & BIT(c.r.rt)) {
266                                 known |= BIT(c.r.rd);
267                                 v[c.r.rd] = v[c.r.rt] << c.r.imm;
268                         } else {
269                                 known &= ~BIT(c.r.rd);
270                         }
271                         break;
272                 case OP_SPECIAL_SRL:
273                         if (known & BIT(c.r.rt)) {
274                                 known |= BIT(c.r.rd);
275                                 v[c.r.rd] = v[c.r.rt] >> c.r.imm;
276                         } else {
277                                 known &= ~BIT(c.r.rd);
278                         }
279                         break;
280                 case OP_SPECIAL_SRA:
281                         if (known & BIT(c.r.rt)) {
282                                 known |= BIT(c.r.rd);
283                                 v[c.r.rd] = (s32)v[c.r.rt] >> c.r.imm;
284                         } else {
285                                 known &= ~BIT(c.r.rd);
286                         }
287                         break;
288                 case OP_SPECIAL_SLLV:
289                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
290                                 known |= BIT(c.r.rd);
291                                 v[c.r.rd] = v[c.r.rt] << (v[c.r.rs] & 0x1f);
292                         } else {
293                                 known &= ~BIT(c.r.rd);
294                         }
295                         break;
296                 case OP_SPECIAL_SRLV:
297                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
298                                 known |= BIT(c.r.rd);
299                                 v[c.r.rd] = v[c.r.rt] >> (v[c.r.rs] & 0x1f);
300                         } else {
301                                 known &= ~BIT(c.r.rd);
302                         }
303                         break;
304                 case OP_SPECIAL_SRAV:
305                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
306                                 known |= BIT(c.r.rd);
307                                 v[c.r.rd] = (s32)v[c.r.rt]
308                                           >> (v[c.r.rs] & 0x1f);
309                         } else {
310                                 known &= ~BIT(c.r.rd);
311                         }
312                         break;
313                 case OP_SPECIAL_ADD:
314                 case OP_SPECIAL_ADDU:
315                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
316                                 known |= BIT(c.r.rd);
317                                 v[c.r.rd] = (s32)v[c.r.rt] + (s32)v[c.r.rs];
318                         } else {
319                                 known &= ~BIT(c.r.rd);
320                         }
321                         break;
322                 case OP_SPECIAL_SUB:
323                 case OP_SPECIAL_SUBU:
324                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
325                                 known |= BIT(c.r.rd);
326                                 v[c.r.rd] = v[c.r.rt] - v[c.r.rs];
327                         } else {
328                                 known &= ~BIT(c.r.rd);
329                         }
330                         break;
331                 case OP_SPECIAL_AND:
332                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
333                                 known |= BIT(c.r.rd);
334                                 v[c.r.rd] = v[c.r.rt] & v[c.r.rs];
335                         } else {
336                                 known &= ~BIT(c.r.rd);
337                         }
338                         break;
339                 case OP_SPECIAL_OR:
340                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
341                                 known |= BIT(c.r.rd);
342                                 v[c.r.rd] = v[c.r.rt] | v[c.r.rs];
343                         } else {
344                                 known &= ~BIT(c.r.rd);
345                         }
346                         break;
347                 case OP_SPECIAL_XOR:
348                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
349                                 known |= BIT(c.r.rd);
350                                 v[c.r.rd] = v[c.r.rt] ^ v[c.r.rs];
351                         } else {
352                                 known &= ~BIT(c.r.rd);
353                         }
354                         break;
355                 case OP_SPECIAL_NOR:
356                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
357                                 known |= BIT(c.r.rd);
358                                 v[c.r.rd] = ~(v[c.r.rt] | v[c.r.rs]);
359                         } else {
360                                 known &= ~BIT(c.r.rd);
361                         }
362                         break;
363                 case OP_SPECIAL_SLT:
364                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
365                                 known |= BIT(c.r.rd);
366                                 v[c.r.rd] = (s32)v[c.r.rs] < (s32)v[c.r.rt];
367                         } else {
368                                 known &= ~BIT(c.r.rd);
369                         }
370                         break;
371                 case OP_SPECIAL_SLTU:
372                         if (known & BIT(c.r.rt) && known & BIT(c.r.rs)) {
373                                 known |= BIT(c.r.rd);
374                                 v[c.r.rd] = v[c.r.rs] < v[c.r.rt];
375                         } else {
376                                 known &= ~BIT(c.r.rd);
377                         }
378                         break;
379                 default:
380                         break;
381                 }
382                 break;
383         case OP_REGIMM:
384                 break;
385         case OP_ADDI:
386         case OP_ADDIU:
387                 if (known & BIT(c.i.rs)) {
388                         known |= BIT(c.i.rt);
389                         v[c.i.rt] = v[c.i.rs] + (s32)(s16)c.i.imm;
390                 } else {
391                         known &= ~BIT(c.i.rt);
392                 }
393                 break;
394         case OP_SLTI:
395                 if (known & BIT(c.i.rs)) {
396                         known |= BIT(c.i.rt);
397                         v[c.i.rt] = (s32)v[c.i.rs] < (s32)(s16)c.i.imm;
398                 } else {
399                         known &= ~BIT(c.i.rt);
400                 }
401                 break;
402         case OP_SLTIU:
403                 if (known & BIT(c.i.rs)) {
404                         known |= BIT(c.i.rt);
405                         v[c.i.rt] = v[c.i.rs] < (u32)(s32)(s16)c.i.imm;
406                 } else {
407                         known &= ~BIT(c.i.rt);
408                 }
409                 break;
410         case OP_ANDI:
411                 if (known & BIT(c.i.rs)) {
412                         known |= BIT(c.i.rt);
413                         v[c.i.rt] = v[c.i.rs] & c.i.imm;
414                 } else {
415                         known &= ~BIT(c.i.rt);
416                 }
417                 break;
418         case OP_ORI:
419                 if (known & BIT(c.i.rs)) {
420                         known |= BIT(c.i.rt);
421                         v[c.i.rt] = v[c.i.rs] | c.i.imm;
422                 } else {
423                         known &= ~BIT(c.i.rt);
424                 }
425                 break;
426         case OP_XORI:
427                 if (known & BIT(c.i.rs)) {
428                         known |= BIT(c.i.rt);
429                         v[c.i.rt] = v[c.i.rs] ^ c.i.imm;
430                 } else {
431                         known &= ~BIT(c.i.rt);
432                 }
433                 break;
434         case OP_LUI:
435                 known |= BIT(c.i.rt);
436                 v[c.i.rt] = c.i.imm << 16;
437                 break;
438         case OP_CP0:
439                 switch (c.r.rs) {
440                 case OP_CP0_MFC0:
441                 case OP_CP0_CFC0:
442                         known &= ~BIT(c.r.rt);
443                         break;
444                 }
445                 break;
446         case OP_CP2:
447                 if (c.r.op == OP_CP2_BASIC) {
448                         switch (c.r.rs) {
449                         case OP_CP2_BASIC_MFC2:
450                         case OP_CP2_BASIC_CFC2:
451                                 known &= ~BIT(c.r.rt);
452                                 break;
453                         }
454                 }
455                 break;
456         case OP_LB:
457         case OP_LH:
458         case OP_LWL:
459         case OP_LW:
460         case OP_LBU:
461         case OP_LHU:
462         case OP_LWR:
463         case OP_LWC2:
464                 known &= ~BIT(c.i.rt);
465                 break;
466         case OP_META_MOV:
467                 if (known & BIT(c.r.rs)) {
468                         known |= BIT(c.r.rd);
469                         v[c.r.rd] = v[c.r.rs];
470                 } else {
471                         known &= ~BIT(c.r.rd);
472                 }
473                 break;
474         default:
475                 break;
476         }
477
478         return known;
479 }
480
481 static int lightrec_add_meta(struct block *block,
482                              struct opcode *op, union code code)
483 {
484         struct opcode *meta;
485
486         meta = lightrec_malloc(block->state, MEM_FOR_IR, sizeof(*meta));
487         if (!meta)
488                 return -ENOMEM;
489
490         meta->c = code;
491         meta->flags = 0;
492
493         if (op) {
494                 meta->offset = op->offset;
495                 meta->next = op->next;
496                 op->next = meta;
497         } else {
498                 meta->offset = 0;
499                 meta->next = block->opcode_list;
500                 block->opcode_list = meta;
501         }
502
503         return 0;
504 }
505
506 static int lightrec_add_sync(struct block *block, struct opcode *prev)
507 {
508         return lightrec_add_meta(block, prev, (union code){
509                                  .j.op = OP_META_SYNC,
510                                  });
511 }
512
513 static int lightrec_transform_ops(struct block *block)
514 {
515         struct opcode *list = block->opcode_list;
516
517         for (; list; list = list->next) {
518
519                 /* Transform all opcodes detected as useless to real NOPs
520                  * (0x0: SLL r0, r0, #0) */
521                 if (list->opcode != 0 && is_nop(list->c)) {
522                         pr_debug("Converting useless opcode 0x%08x to NOP\n",
523                                         list->opcode);
524                         list->opcode = 0x0;
525                 }
526
527                 if (!list->opcode)
528                         continue;
529
530                 switch (list->i.op) {
531                 /* Transform BEQ / BNE to BEQZ / BNEZ meta-opcodes if one of the
532                  * two registers is zero. */
533                 case OP_BEQ:
534                         if ((list->i.rs == 0) ^ (list->i.rt == 0)) {
535                                 list->i.op = OP_META_BEQZ;
536                                 if (list->i.rs == 0) {
537                                         list->i.rs = list->i.rt;
538                                         list->i.rt = 0;
539                                 }
540                         } else if (list->i.rs == list->i.rt) {
541                                 list->i.rs = 0;
542                                 list->i.rt = 0;
543                         }
544                         break;
545                 case OP_BNE:
546                         if (list->i.rs == 0) {
547                                 list->i.op = OP_META_BNEZ;
548                                 list->i.rs = list->i.rt;
549                                 list->i.rt = 0;
550                         } else if (list->i.rt == 0) {
551                                 list->i.op = OP_META_BNEZ;
552                         }
553                         break;
554
555                 /* Transform ORI/ADDI/ADDIU with imm #0 or ORR/ADD/ADDU/SUB/SUBU
556                  * with register $zero to the MOV meta-opcode */
557                 case OP_ORI:
558                 case OP_ADDI:
559                 case OP_ADDIU:
560                         if (list->i.imm == 0) {
561                                 pr_debug("Convert ORI/ADDI/ADDIU #0 to MOV\n");
562                                 list->i.op = OP_META_MOV;
563                                 list->r.rd = list->i.rt;
564                         }
565                         break;
566                 case OP_SPECIAL:
567                         switch (list->r.op) {
568                         case OP_SPECIAL_SLL:
569                         case OP_SPECIAL_SRA:
570                         case OP_SPECIAL_SRL:
571                                 if (list->r.imm == 0) {
572                                         pr_debug("Convert SLL/SRL/SRA #0 to MOV\n");
573                                         list->i.op = OP_META_MOV;
574                                         list->r.rs = list->r.rt;
575                                 }
576                                 break;
577                         case OP_SPECIAL_OR:
578                         case OP_SPECIAL_ADD:
579                         case OP_SPECIAL_ADDU:
580                                 if (list->r.rs == 0) {
581                                         pr_debug("Convert OR/ADD $zero to MOV\n");
582                                         list->i.op = OP_META_MOV;
583                                         list->r.rs = list->r.rt;
584                                 }
585                         case OP_SPECIAL_SUB: /* fall-through */
586                         case OP_SPECIAL_SUBU:
587                                 if (list->r.rt == 0) {
588                                         pr_debug("Convert OR/ADD/SUB $zero to MOV\n");
589                                         list->i.op = OP_META_MOV;
590                                 }
591                         default: /* fall-through */
592                                 break;
593                         }
594                 default: /* fall-through */
595                         break;
596                 }
597         }
598
599         return 0;
600 }
601
602 static int lightrec_switch_delay_slots(struct block *block)
603 {
604         struct opcode *list, *prev;
605         u8 flags;
606
607         for (list = block->opcode_list, prev = NULL; list->next;
608              prev = list, list = list->next) {
609                 union code op = list->c;
610                 union code next_op = list->next->c;
611
612                 if (!has_delay_slot(op) ||
613                     list->flags & (LIGHTREC_NO_DS | LIGHTREC_EMULATE_BRANCH) ||
614                     op.opcode == 0)
615                         continue;
616
617                 if (prev && has_delay_slot(prev->c))
618                         continue;
619
620                 switch (list->i.op) {
621                 case OP_SPECIAL:
622                         switch (op.r.op) {
623                         case OP_SPECIAL_JALR:
624                                 if (opcode_reads_register(next_op, op.r.rd) ||
625                                     opcode_writes_register(next_op, op.r.rd))
626                                         continue;
627                         case OP_SPECIAL_JR: /* fall-through */
628                                 if (opcode_writes_register(next_op, op.r.rs))
629                                         continue;
630                         default: /* fall-through */
631                                 break;
632                         }
633                 case OP_J: /* fall-through */
634                         break;
635                 case OP_JAL:
636                         if (opcode_reads_register(next_op, 31) ||
637                             opcode_writes_register(next_op, 31))
638                                 continue;
639                         else
640                                 break;
641                 case OP_BEQ:
642                 case OP_BNE:
643                         if (op.i.rt && opcode_writes_register(next_op, op.i.rt))
644                                 continue;
645                 case OP_BLEZ: /* fall-through */
646                 case OP_BGTZ:
647                 case OP_META_BEQZ:
648                 case OP_META_BNEZ:
649                         if (op.i.rs && opcode_writes_register(next_op, op.i.rs))
650                                 continue;
651                         break;
652                 case OP_REGIMM:
653                         switch (op.r.rt) {
654                         case OP_REGIMM_BLTZAL:
655                         case OP_REGIMM_BGEZAL:
656                                 if (opcode_reads_register(next_op, 31) ||
657                                     opcode_writes_register(next_op, 31))
658                                         continue;
659                         case OP_REGIMM_BLTZ: /* fall-through */
660                         case OP_REGIMM_BGEZ:
661                                 if (op.i.rs &&
662                                     opcode_writes_register(next_op, op.i.rs))
663                                         continue;
664                                 break;
665                         }
666                 default: /* fall-through */
667                         break;
668                 }
669
670                 pr_debug("Swap branch and delay slot opcodes "
671                          "at offsets 0x%x / 0x%x\n", list->offset << 2,
672                          list->next->offset << 2);
673
674                 flags = list->next->flags;
675                 list->c = next_op;
676                 list->next->c = op;
677                 list->next->flags = list->flags | LIGHTREC_NO_DS;
678                 list->flags = flags | LIGHTREC_NO_DS;
679                 list->offset++;
680                 list->next->offset--;
681         }
682
683         return 0;
684 }
685
686 static int lightrec_detect_impossible_branches(struct block *block)
687 {
688         struct opcode *op, *next;
689
690         for (op = block->opcode_list, next = op->next; next;
691              op = next, next = op->next) {
692                 if (!has_delay_slot(op->c) ||
693                     (!load_in_delay_slot(next->c) &&
694                      !has_delay_slot(next->c) &&
695                      !(next->i.op == OP_CP0 && next->r.rs == OP_CP0_RFE)))
696                         continue;
697
698                 if (op->c.opcode == next->c.opcode) {
699                         /* The delay slot is the exact same opcode as the branch
700                          * opcode: this is effectively a NOP */
701                         next->c.opcode = 0;
702                         continue;
703                 }
704
705                 if (op == block->opcode_list) {
706                         /* If the first opcode is an 'impossible' branch, we
707                          * only keep the first two opcodes of the block (the
708                          * branch itself + its delay slot) */
709                         lightrec_free_opcode_list(block->state, next->next);
710                         next->next = NULL;
711                         block->nb_ops = 2;
712                 }
713
714                 op->flags |= LIGHTREC_EMULATE_BRANCH;
715         }
716
717         return 0;
718 }
719
720 static int lightrec_local_branches(struct block *block)
721 {
722         struct opcode *list, *target, *prev;
723         s32 offset;
724         int ret;
725
726         for (list = block->opcode_list; list; list = list->next) {
727                 if (list->flags & LIGHTREC_EMULATE_BRANCH)
728                         continue;
729
730                 switch (list->i.op) {
731                 case OP_BEQ:
732                 case OP_BNE:
733                 case OP_BLEZ:
734                 case OP_BGTZ:
735                 case OP_REGIMM:
736                 case OP_META_BEQZ:
737                 case OP_META_BNEZ:
738                         offset = list->offset + 1 + (s16)list->i.imm;
739                         if (offset >= 0 && offset < block->nb_ops)
740                                 break;
741                 default: /* fall-through */
742                         continue;
743                 }
744
745                 pr_debug("Found local branch to offset 0x%x\n", offset << 2);
746
747                 for (target = block->opcode_list, prev = NULL;
748                      target; prev = target, target = target->next) {
749                         if (target->offset != offset ||
750                             target->j.op == OP_META_SYNC)
751                                 continue;
752
753                         if (target->flags & LIGHTREC_EMULATE_BRANCH) {
754                                 pr_debug("Branch target must be emulated"
755                                          " - skip\n");
756                                 break;
757                         }
758
759                         if (prev && has_delay_slot(prev->c)) {
760                                 pr_debug("Branch target is a delay slot"
761                                          " - skip\n");
762                                 break;
763                         }
764
765                         if (prev && prev->j.op != OP_META_SYNC) {
766                                 pr_debug("Adding sync before offset "
767                                          "0x%x\n", offset << 2);
768                                 ret = lightrec_add_sync(block, prev);
769                                 if (ret)
770                                         return ret;
771
772                                 prev->next->offset = target->offset;
773                         }
774
775                         list->flags |= LIGHTREC_LOCAL_BRANCH;
776                         break;
777                 }
778         }
779
780         return 0;
781 }
782
783 bool has_delay_slot(union code op)
784 {
785         switch (op.i.op) {
786         case OP_SPECIAL:
787                 switch (op.r.op) {
788                 case OP_SPECIAL_JR:
789                 case OP_SPECIAL_JALR:
790                         return true;
791                 default:
792                         return false;
793                 }
794         case OP_J:
795         case OP_JAL:
796         case OP_BEQ:
797         case OP_BNE:
798         case OP_BLEZ:
799         case OP_BGTZ:
800         case OP_REGIMM:
801         case OP_META_BEQZ:
802         case OP_META_BNEZ:
803                 return true;
804         default:
805                 return false;
806         }
807 }
808
809 static int lightrec_add_unload(struct block *block, struct opcode *op, u8 reg)
810 {
811         return lightrec_add_meta(block, op, (union code){
812                                  .i.op = OP_META_REG_UNLOAD,
813                                  .i.rs = reg,
814                                  });
815 }
816
817 static int lightrec_early_unload(struct block *block)
818 {
819         struct opcode *list = block->opcode_list;
820         u8 i;
821
822         for (i = 1; i < 34; i++) {
823                 struct opcode *op, *last_r = NULL, *last_w = NULL;
824                 unsigned int last_r_id = 0, last_w_id = 0, id = 0;
825                 int ret;
826
827                 for (op = list; op->next; op = op->next, id++) {
828                         if (opcode_reads_register(op->c, i)) {
829                                 last_r = op;
830                                 last_r_id = id;
831                         }
832
833                         if (opcode_writes_register(op->c, i)) {
834                                 last_w = op;
835                                 last_w_id = id;
836                         }
837                 }
838
839                 if (last_w_id > last_r_id) {
840                         if (has_delay_slot(last_w->c) &&
841                             !(last_w->flags & LIGHTREC_NO_DS))
842                                 last_w = last_w->next;
843
844                         if (last_w->next) {
845                                 ret = lightrec_add_unload(block, last_w, i);
846                                 if (ret)
847                                         return ret;
848                         }
849                 } else if (last_r) {
850                         if (has_delay_slot(last_r->c) &&
851                             !(last_r->flags & LIGHTREC_NO_DS))
852                                 last_r = last_r->next;
853
854                         if (last_r->next) {
855                                 ret = lightrec_add_unload(block, last_r, i);
856                                 if (ret)
857                                         return ret;
858                         }
859                 }
860         }
861
862         return 0;
863 }
864
865 static int lightrec_flag_stores(struct block *block)
866 {
867         struct opcode *list;
868         u32 known = BIT(0);
869         u32 values[32] = { 0 };
870
871         for (list = block->opcode_list; list; list = list->next) {
872                 /* Register $zero is always, well, zero */
873                 known |= BIT(0);
874                 values[0] = 0;
875
876                 switch (list->i.op) {
877                 case OP_SB:
878                 case OP_SH:
879                 case OP_SW:
880                         /* Mark all store operations that target $sp or $gp
881                          * as not requiring code invalidation. This is based
882                          * on the heuristic that stores using one of these
883                          * registers as address will never hit a code page. */
884                         if (list->i.rs >= 28 && list->i.rs <= 29 &&
885                             !block->state->maps[PSX_MAP_KERNEL_USER_RAM].ops) {
886                                 pr_debug("Flaging opcode 0x%08x as not requiring invalidation\n",
887                                          list->opcode);
888                                 list->flags |= LIGHTREC_NO_INVALIDATE;
889                         }
890
891                         /* Detect writes whose destination address is inside the
892                          * current block, using constant propagation. When these
893                          * occur, we mark the blocks as not compilable. */
894                         if ((known & BIT(list->i.rs)) &&
895                             kunseg(values[list->i.rs]) >= kunseg(block->pc) &&
896                             kunseg(values[list->i.rs]) < (kunseg(block->pc) +
897                                                           block->nb_ops * 4)) {
898                                 pr_debug("Self-modifying block detected\n");
899                                 block->flags |= BLOCK_NEVER_COMPILE;
900                                 list->flags |= LIGHTREC_SMC;
901                         }
902                 default: /* fall-through */
903                         break;
904                 }
905
906                 known = lightrec_propagate_consts(list->c, known, values);
907         }
908
909         return 0;
910 }
911
912 static bool is_mult32(const struct block *block, const struct opcode *op)
913 {
914         const struct opcode *next, *last = NULL;
915         u32 offset;
916
917         for (op = op->next; op != last; op = op->next) {
918                 switch (op->i.op) {
919                 case OP_BEQ:
920                 case OP_BNE:
921                 case OP_BLEZ:
922                 case OP_BGTZ:
923                 case OP_REGIMM:
924                 case OP_META_BEQZ:
925                 case OP_META_BNEZ:
926                         /* TODO: handle backwards branches too */
927                         if ((op->flags & LIGHTREC_LOCAL_BRANCH) &&
928                             (s16)op->c.i.imm >= 0) {
929                                 offset = op->offset + 1 + (s16)op->c.i.imm;
930
931                                 for (next = op; next->offset != offset;
932                                      next = next->next);
933
934                                 if (!is_mult32(block, next))
935                                         return false;
936
937                                 last = next;
938                                 continue;
939                         } else {
940                                 return false;
941                         }
942                 case OP_SPECIAL:
943                         switch (op->r.op) {
944                         case OP_SPECIAL_MULT:
945                         case OP_SPECIAL_MULTU:
946                         case OP_SPECIAL_DIV:
947                         case OP_SPECIAL_DIVU:
948                         case OP_SPECIAL_MTHI:
949                                 return true;
950                         case OP_SPECIAL_JR:
951                                 return op->r.rs == 31 &&
952                                         ((op->flags & LIGHTREC_NO_DS) ||
953                                          !(op->next->i.op == OP_SPECIAL &&
954                                            op->next->r.op == OP_SPECIAL_MFHI));
955                         case OP_SPECIAL_JALR:
956                         case OP_SPECIAL_MFHI:
957                                 return false;
958                         default:
959                                 continue;
960                         }
961                 default:
962                         continue;
963                 }
964         }
965
966         return last != NULL;
967 }
968
969 static int lightrec_flag_mults(struct block *block)
970 {
971         struct opcode *list, *prev;
972
973         for (list = block->opcode_list, prev = NULL; list;
974              prev = list, list = list->next) {
975                 if (list->i.op != OP_SPECIAL)
976                         continue;
977
978                 switch (list->r.op) {
979                 case OP_SPECIAL_MULT:
980                 case OP_SPECIAL_MULTU:
981                         break;
982                 default:
983                         continue;
984                 }
985
986                 /* Don't support MULT(U) opcodes in delay slots */
987                 if (prev && has_delay_slot(prev->c))
988                         continue;
989
990                 if (is_mult32(block, list)) {
991                         pr_debug("Mark MULT(U) opcode at offset 0x%x as"
992                                  " 32-bit\n", list->offset << 2);
993                         list->flags |= LIGHTREC_MULT32;
994                 }
995         }
996
997         return 0;
998 }
999
1000 static int (*lightrec_optimizers[])(struct block *) = {
1001         &lightrec_detect_impossible_branches,
1002         &lightrec_transform_ops,
1003         &lightrec_local_branches,
1004         &lightrec_switch_delay_slots,
1005         &lightrec_flag_stores,
1006         &lightrec_flag_mults,
1007         &lightrec_early_unload,
1008 };
1009
1010 int lightrec_optimize(struct block *block)
1011 {
1012         unsigned int i;
1013
1014         for (i = 0; i < ARRAY_SIZE(lightrec_optimizers); i++) {
1015                 int ret = lightrec_optimizers[i](block);
1016
1017                 if (ret)
1018                         return ret;
1019         }
1020
1021         return 0;
1022 }