|
| 1 | +import argparse |
| 2 | +import datetime |
| 3 | +import os.path |
| 4 | +import sys |
| 5 | + |
| 6 | +from loda.lang import Program |
| 7 | +from loda.ml.keras.program_generation_rnn import load_model, Generator |
| 8 | + |
| 9 | + |
| 10 | +def eprint(*args, **kwargs): |
| 11 | + print(*args, file=sys.stderr, **kwargs) |
| 12 | + |
| 13 | + |
| 14 | +def generate_programs(generator, num_programs: int, use_line_format: bool, write_fn, verbose=0): |
| 15 | + for i in range(num_programs): |
| 16 | + p = generator() |
| 17 | + if use_line_format: |
| 18 | + p = "; ".join([str(op) for op in p.operations]) |
| 19 | + write_fn("{}\n".format(p)) |
| 20 | + if verbose > 0 and i % 10 == 0: |
| 21 | + ct = datetime.datetime.now() |
| 22 | + eprint(ct, generator.get_stats_info_str()) |
| 23 | + |
| 24 | + |
| 25 | +def generate(model_path: str, output_path=None, num_programs=100, format="asm", verbose=0): |
| 26 | + model = load_model(model_path) |
| 27 | + if verbose > 0: |
| 28 | + model.summary(print_fn=eprint) |
| 29 | + initial_program = Program() |
| 30 | + # initial_program.operations.append(Operation("mov $1,1")) |
| 31 | + num_lanes = 10 |
| 32 | + if num_programs >= 1000: |
| 33 | + num_lanes = 100 |
| 34 | + elif num_programs >= 10000: |
| 35 | + num_lanes = 1000 |
| 36 | + generator = Generator( |
| 37 | + model, initial_program=initial_program, num_lanes=100) |
| 38 | + use_line_format = (format == "line") |
| 39 | + if output_path: |
| 40 | + with open(output_path, "w") as file: |
| 41 | + generate_programs(generator, num_programs, |
| 42 | + use_line_format, file.write, verbose) |
| 43 | + else: |
| 44 | + generate_programs(generator, num_programs, |
| 45 | + use_line_format, sys.stdout.write, verbose) |
| 46 | + |
| 47 | + |
| 48 | +if __name__ == "__main__": |
| 49 | + parser = argparse.ArgumentParser() |
| 50 | + parser.add_argument("model", type=str) |
| 51 | + parser.add_argument( |
| 52 | + "-f", "--format", type=str, choices=["asm", "line"], help="output format of the generated programs") |
| 53 | + parser.add_argument( |
| 54 | + "-o", "--output", type=str, help="output file for writing the programs to") |
| 55 | + parser.add_argument( |
| 56 | + "-n", type=int, help="number of programs to generate", default=100) |
| 57 | + parser.add_argument('-v', '--verbose', action='count', default=0) |
| 58 | + args = parser.parse_args() |
| 59 | + generate(model_path=args.model, output_path=args.output, |
| 60 | + num_programs=args.n, format=args.format, verbose=args.verbose) |
0 commit comments