-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
140 lines (122 loc) · 4.56 KB
/
main.py
File metadata and controls
140 lines (122 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import sys
from arc_tools.logger import logger
import os
import json
from arc_tools.grid import Grid
from arc_tools.utils import debug_output
from train_tasks import *
from evaluation_tasks.tasks import *
show_count = 0
from collections import Counter, deque
from arc_tools.grid import SubGrid
from typing import Sequence
# Exception hook
import sys
import traceback
from types import TracebackType
def handle_exception(
type_: type[BaseException], value: BaseException, tb: TracebackType | None
) -> None:
traceback.print_exception(type_, value, tb)
sys.exit(1)
sys.excepthook = handle_exception
# =============
# Candidate solvers that find_task tries when the first training input and
# the first expected output have first rows of equal length. Change the `0`
# to `1` to swap in the alternate list; the unused list is never evaluated.
normal_task_fns = [
    check_fit,
    move_object_without_collision,
    repeat_reverse_grid,
] if 0 else [
    # row_col_color_data,
    # color_swap_and_move_to_corner,
    # dot_to_object,
    # rope_stretch,
    # fit_or_swap_fit
]

# Candidate solvers that find_task tries when those first rows differ in length.
jigsaw_task_fns = [
    # jigsaw_puzzle,
    row_col_color_data,  # can occur in normal task
]
def find_task(grids, expected_outputs, start_train_task_id=1):
    """Return the first candidate solver that reproduces every training pair.

    Candidates come from ``normal_task_fns`` when
    ``len(grids[0][0]) == len(expected_outputs[0][0])`` (presumably equal
    grid widths) and from ``jigsaw_task_fns`` otherwise. If a module-level
    ``actual_task_name`` is set (the optional command-line argument), only
    that solver is tried, and a failing training pair raises instead of
    moving on to the next candidate.

    Args:
        grids: training input grids (``Grid`` objects); each is ``.copy()``-ed
            before being handed to a solver.
        expected_outputs: raw expected outputs, parallel to ``grids``; each is
            wrapped in ``Grid`` before comparison.
        start_train_task_id: number given to the first pair in log messages
            and debug-output names.

    Returns:
        The matching solver function, or ``None`` when no candidate matches
        or there are no training pairs to check against.

    Raises:
        Exception: when ``actual_task_name`` is set and a training pair fails.
    """
    if not grids:
        # Nothing to verify a solver against (and grids[0] would IndexError).
        return None

    # ``actual_task_name`` is only assigned by the ``__main__`` entry point.
    # Look it up defensively so importing this module and calling
    # find_task/solve_task from elsewhere does not raise NameError.
    task_name = globals().get('actual_task_name')
    if task_name:
        task_fns = [globals()[task_name]]
    elif len(grids[0][0]) == len(expected_outputs[0][0]):
        task_fns = normal_task_fns
    else:
        task_fns = jigsaw_task_fns

    for task_fn in task_fns:
        logger.info(task_fn.__name__)
        pairs = zip(grids, expected_outputs)
        for task_id, (grid, raw_expected) in enumerate(pairs, start_train_task_id):
            expected_output = Grid(raw_expected)
            output = task_fn(grid.copy())
            if not output.compare(expected_output):
                debug_output(grid, expected_output, output, f'train_{task_id}_result')
                if task_name:
                    raise Exception(f'Train task {task_id} failed')
                break
            logger.info(f'Train task {task_id} passed')
        else:
            # No training pair failed: this solver is the answer.
            return task_fn
        logger.info('--------------------------------')
    return None
def solve_task(data):
    """Find a solver for an ARC task and apply it to each test input.

    The solver is chosen by ``find_task`` from the training pairs. When the
    module-level ``actual_task_name`` is set, ``find_task`` restricts the
    search to that solver. As a side effect, the first training output is
    written to ``reference_output.json`` in the current working directory.

    Args:
        data: parsed task JSON with ``'train'`` and ``'test'`` lists of
            ``{'input': ..., 'output': ...}`` dicts; a test ``'output'`` may
            be absent.

    Returns:
        A list with one ``{"attempt_1": ..., "attempt_2": ...}`` dict per test
        input. Both attempts hold the solver's output, or the unchanged input
        grid when no solver was found.

    Raises:
        Exception: when a test pair has an expected output and the solver's
            result does not match it.
    """
    start_train_task_id = 1
    start_test_task_id = 1
    train_pairs = data['train']
    test_pairs = data['test']
    logger.info(f"Number of train tasks: {len(train_pairs)}, Number of test tasks: {len(test_pairs)}")

    with open('reference_output.json', 'w') as f:
        json.dump(train_pairs[0]['output'], f)

    selected_train = train_pairs[start_train_task_id - 1:]
    grids = [Grid(pair['input']) for pair in selected_train]
    expected_outputs = [pair['output'] for pair in selected_train]
    task_fn = find_task(grids, expected_outputs, start_train_task_id)

    actual_outputs = []
    for task_idx in range(start_test_task_id - 1, len(test_pairs)):
        test_pair = test_pairs[task_idx]
        grid = Grid(test_pair['input'])
        if not task_fn:
            actual_outputs.append({"attempt_1": grid, "attempt_2": grid})
            continue
        output = task_fn(grid.copy())
        raw_expected = test_pair.get('output')
        # Hidden test sets omit 'output'. Only build and compare a Grid when
        # an expected answer is actually present (never Grid(None)).
        if raw_expected is not None:
            expected_output = Grid(raw_expected)
            if output.compare(expected_output):
                logger.info(f"Test task {task_idx + 1} passed")
            else:
                logger.info(f"Test task {task_idx + 1} failed")
                debug_output(grid, expected_output, output, f'test_{task_idx + 1}_result')
                raise Exception(f"Incorrect task {task_idx + 1}: {task_fn.__name__}, Expected: {expected_output}, Actual: {output}")
        actual_outputs.append({"attempt_1": output, "attempt_2": output})
    return actual_outputs
if __name__ == "__main__":
    # CLI: python main.py <task_hash> [task_name]
    if len(sys.argv) < 2:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit('Usage: python main.py <task_hash> [task_name]')
    task_hash = sys.argv[1]
    # Optional solver name. Assigned at module level on purpose: find_task
    # reads it as a global to restrict the search to this one solver.
    actual_task_name = sys.argv[2] if len(sys.argv) > 2 else None
    split = ['evaluation', 'training']
    for s in split:
        file = rf'../ARC-AGI-2/data/{s}/{task_hash}.json'
        if os.path.exists(file):
            break
    else:
        # Previously fell through and failed on the last path tried.
        sys.exit(f"Task file {task_hash}.json not found in ../ARC-AGI-2/data/ ({', '.join(split)})")
    with open(file, 'r') as task_file:
        data = json.load(task_file)
    solve_task(data)