diff --git a/src/book.rs b/src/book.rs index 6dd954b..f1dac1d 100644 --- a/src/book.rs +++ b/src/book.rs @@ -45,7 +45,7 @@ impl Book { let reader = BufReader::new(f); let mut book = Book::new(); for line in reader.lines() { - book.append(Record::parse(&line?)?)?; + book.append(line?.parse::()?)?; } Ok(book) } @@ -162,7 +162,7 @@ impl Book { } hands.push(hand); } - new_records.push(Record::new(rec.get_initial(), &hands, last_score.unwrap())); + new_records.push(Record::new(rec.get_initial(), &hands, last_score)); } new_records.dedup(); Book::from_records(&new_records) @@ -242,7 +242,7 @@ fn play_with_book( hands.push(hand); board = board.play_hand(hand).unwrap(); } - let record = Record::new(Board::initial_state(), &hands, board.score().into()); + let record = Record::new(Board::initial_state(), &hands, Some(board.score().into())); eprintln!("{}", record); book.lock().unwrap().append(record).unwrap(); } diff --git a/src/play.rs b/src/play.rs index 0ee320e..8fd688a 100644 --- a/src/play.rs +++ b/src/play.rs @@ -8,7 +8,6 @@ use crate::engine::search::*; use crate::engine::table::*; use crate::engine::think::*; use crate::setup::*; -use crate::train::*; use clap::ArgMatches; use rand::prelude::*; use rayon::prelude::*; @@ -377,11 +376,6 @@ pub fn codingame(_matches: &ArgMatches) -> Result<(), Box solve_with_move(board.board, &mut solve_obj, &sub_solver, None) }; solve_obj.cache_gen += 1; - match best { - Hand::Play(pos) => { - println!("{}", pos_to_str(pos).to_ascii_lowercase()); - } - _ => panic!(), - } + println!("{}", format!("{}", best).to_ascii_lowercase()); } } diff --git a/src/record.rs b/src/record.rs index cabdc66..6bfb428 100644 --- a/src/record.rs +++ b/src/record.rs @@ -1,18 +1,36 @@ +#[cfg(test)] +mod test; use crate::engine::board::*; use crate::engine::hand::*; use anyhow::Result; use std::fmt::*; +use std::fs::File; +use std::io::{BufRead, BufReader, Read}; +use std::path::Path; use std::str::FromStr; +use thiserror::Error; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Record { initial_board: Board, hands: Vec, - final_score: i16, + final_score: Option, +} + +#[derive(Error, Debug)] +#[error("Score is not registered")] +pub struct ScoreIsNotRegistered {} + +#[derive(Error, Debug)] +pub enum ParseRecordError { + #[error("Failed to parse hand")] + FailedToParseHand, + #[error("invalid hand")] + InvalidHand, } impl Record { - pub fn new(initial_board: Board, hands: &[Hand], final_score: i16) -> Record { + pub fn new(initial_board: Board, hands: &[Hand], final_score: Option) -> Record { Record { initial_board, hands: hands.to_vec(), @@ -20,24 +38,6 @@ impl Record { } } - pub fn parse(record_str: &str) -> Result { - let mut hands = Vec::new(); - let mut board = Board::initial_state(); - let splitted = record_str.split_ascii_whitespace().collect::>(); - let l = splitted[0].len(); - for i in 0..(l / 2) { - let h = Hand::from_str(&record_str[(2 * i)..(2 * i + 2)])?; - hands.push(h); - board = board.play_hand(h).ok_or(UnmovableError {})?; - } - let score = if let Some(score) = splitted.get(1) { - score.parse().unwrap() - } else { - board.score() as i16 - }; - Ok(Record::new(Board::initial_state(), &hands, score)) - } - pub fn get_initial(&self) -> Board { self.initial_board } @@ -45,10 +45,11 @@ impl Record { pub fn timeline(&self) -> Result> { let mut board = self.initial_board; let mut res = Vec::new(); + let final_score = self.final_score.ok_or(ScoreIsNotRegistered {})?; let mut score = if self.hands.len() % 2 == 0 { - self.final_score + final_score } else { - -self.final_score + -final_score }; for &h in &self.hands { res.push((board, h, score)); @@ -65,7 +66,80 @@ impl Display for Record { for h in &self.hands { write!(f, "{}", h)?; } - write!(f, " {}", self.final_score)?; + if let Some(final_score) = self.final_score { + write!(f, " {}", final_score)?; + } Ok(()) } } + +impl FromStr for Record { + type Err = ParseRecordError; + + fn from_str(record_str: &str) -> Result { + let mut hands = Vec::new(); + let mut board = Board::initial_state(); + let splitted = record_str.split_ascii_whitespace().collect::>(); + let l = splitted[0].len(); + for i in 0..(l / 2) { + let h = splitted[0][(2 * i)..(2 * i + 2)] + .parse::() + .or(Err(ParseRecordError::FailedToParseHand))?; + board = match board.play_hand(h) { + Some(next) => next, + None => { + let passed = board.pass().ok_or(ParseRecordError::InvalidHand)?; + match passed.play_hand(h) { + Some(next) => { + hands.push(Hand::Pass); + next + } + None => return Err(ParseRecordError::InvalidHand.into()), + } + } + }; + hands.push(h); + } + let score = if let Some(score) = splitted.get(1) { + score.parse().ok() + } else if board.is_gameover() { + Some(board.score() as i16) + } else { + None + }; + Ok(Record::new(Board::initial_state(), &hands, score)) + } +} + +pub struct LoadRecords { + reader: BufReader, + buffer: String, + remain: usize, +} + +impl Iterator for LoadRecords { + type Item = Result; + fn next(&mut self) -> Option { + if self.remain > 0 { + self.remain -= 1; + self.reader.read_line(&mut self.buffer).ok()?; + return Some(self.buffer.parse::().map_err(|e| e.into())); + } + None + } +} + +pub fn load_records(path: &Path) -> Result> { + let file = File::open(path)?; + let mut reader = BufReader::new(file); + let mut buffer = String::new(); + + reader.read_line(&mut buffer)?; + let remain = buffer.trim().parse()?; + + Ok(LoadRecords { + reader, + buffer, + remain, + }) +} diff --git a/src/record/test.rs b/src/record/test.rs new file mode 100644 index 0000000..559635e --- /dev/null +++ b/src/record/test.rs @@ -0,0 +1,22 @@ +extern crate test; +use super::*; + +#[test] +fn test_parse_record() { + let record = "f5d6c3d3 0".parse::().unwrap(); + assert_eq!(record.initial_board, Board::initial_state()); + let timeline = record.timeline().unwrap(); + assert_eq!(timeline.len(), 5); + assert_eq!(timeline[0].1, "f5".parse::().unwrap()); + assert_eq!(timeline[1].1, "d6".parse::().unwrap()); + assert_eq!(timeline[2].1, "c3".parse::().unwrap()); + assert_eq!(timeline[3].1, "d3".parse::().unwrap()); + assert_eq!(timeline[4].1, Hand::Pass); +} + +#[test] +fn test_parse_record_with_pass() { + let record_with_pass = "f5f6d3g5h5h4f7h6psc5 0".parse::().unwrap(); + let record_without_pass = "f5f6d3g5h5h4f7h6c5 0".parse::().unwrap(); + assert_eq!(record_with_pass, record_without_pass); +} diff --git a/src/train.rs b/src/train.rs index 704e4fa..f0183e6 100644 --- a/src/train.rs +++ b/src/train.rs @@ -5,6 +5,7 @@ use crate::engine::hand::*; use crate::engine::pattern_eval::*; use crate::engine::table::*; use crate::engine::think::*; +use crate::record::*; use crate::sparse_mat::*; use clap::ArgMatches; use rayon::prelude::*; @@ -16,118 +17,16 @@ use std::path::Path; use std::str; use std::sync::Arc; -// parse pos string [A-H][1-8] -fn parse_pos(s: &[u8]) -> Option { - const CODE_1: u8 = '1' as u32 as u8; - const CODE_8: u8 = '8' as u32 as u8; - const CODE_A: u8 = 'A' as u32 as u8; - const CODE_H: u8 = 'H' as u32 as u8; - if s.len() != 2 { - None - } else if s[0] < CODE_A || s[0] > CODE_H { - None - } else if s[1] < CODE_1 || s[1] > CODE_8 { - None - } else { - Some(((s[0] - CODE_A) + (s[1] - CODE_1) * 8) as usize) - } -} - -pub fn parse_record(line: &str) -> Vec { - let mut result = Vec::new(); - for chunk in line.as_bytes().chunks(2) { - if chunk == "ps".as_bytes() { - continue; - } - match parse_pos(chunk) { - Some(pos) => result.push(pos), - None => { - return result; - } - } - } - result -} - -pub fn step_by_pos(board: &Board, pos: usize) -> Option { - match board.play(pos) { - Some(next) => Some(next), - None => { - if !board.mobility().is_empty() { - None - } else { - board.pass_unchecked().play(pos) - } - } - } -} - -pub fn collect_boards(record: &[usize]) -> Option> { - let mut board = Board { - player: 0x0000_0008_1000_0000, - opponent: 0x0000_0010_0800_0000, - }; - let mut boards = Vec::with_capacity(70); // enough large - for &pos in record { - boards.push(board); - board = match step_by_pos(&board, pos) { - Some(next) => next, - None => { - return None; - } - }; - } - if !board.is_gameover() { - return None; - } - boards.push(board); - Some(boards) -} - -fn boards_from_record_impl(board: Board, record: &[usize]) -> (Vec<(Board, i8, Hand)>, i8) { - match record.first() { - Some(&first) => { - let ((mut boards, score), hand) = if board.mobility_bits() == 0 { - ( - boards_from_record_impl(board.pass_unchecked(), record), - Hand::Pass, - ) - } else { - ( - boards_from_record_impl(step_by_pos(&board, first).unwrap(), &record[1..]), - Hand::Play(first), - ) - }; - boards.insert(0, (board, -score, hand)); - (boards, -score) - } - None => (vec![(board, board.score(), Hand::Pass)], board.score()), - } -} - -fn boards_from_record(line: &str) -> Vec<(Board, i8, Hand)> { - let record = parse_record(line); - let board = Board::initial_state(); - boards_from_record_impl(board, &record).0 -} - pub fn clean_record(matches: &ArgMatches) { let input_path = matches.get_one::("INPUT").unwrap(); let output_path = matches.get_one::("OUTPUT").unwrap(); - let in_f = File::open(input_path).unwrap(); - let mut reader = BufReader::new(in_f); - - let mut input_line = String::new(); - reader.read_line(&mut input_line).unwrap(); - let num_records = input_line.trim().parse().unwrap(); let mut result = Vec::new(); - for _i in 0..num_records { - let mut input_line = String::new(); - reader.read_line(&mut input_line).unwrap(); - let record = parse_record(&input_line); - if let Some(_boards) = collect_boards(&record) { - result.push(input_line); + for record in load_records(Path::new(input_path)).unwrap() { + if let Ok(record) = record { + if let Ok(_timeline) = record.timeline() { + result.push(record); + } } } @@ -135,39 +34,21 @@ pub fn clean_record(matches: &ArgMatches) { let mut writer = BufWriter::new(out_f); writeln!(writer, "{}", result.len()).unwrap(); - for line in result { - write!(writer, "{}", line).unwrap(); + for record in result { + write!(writer, "{}", record).unwrap(); } } -pub fn pos_to_str(pos: usize) -> String { - let row = pos / 8; - let col = pos % 8; - let first = (col as u8) + b'A'; - let second = (row as u8) + b'1'; - let mut result = String::new(); - result.push(first as char); - result.push(second as char); - result -} - pub fn gen_dataset(matches: &ArgMatches) { let input_path = matches.get_one::("INPUT").unwrap(); let output_path = matches.get_one::("OUTPUT").unwrap(); let max_output = *matches.get_one::("MAX_OUT").unwrap(); eprintln!("Parse input..."); - let in_f = File::open(input_path).unwrap(); - let mut reader = BufReader::new(in_f); - - let mut input_line = String::new(); - reader.read_line(&mut input_line).unwrap(); - let num_records = input_line.trim().parse().unwrap(); let mut boards_with_results = Vec::new(); - for _i in 0..num_records { - let mut input_line = String::new(); - reader.read_line(&mut input_line).unwrap(); - boards_with_results.append(&mut boards_from_record(&input_line)); + for record in load_records(Path::new(input_path)).unwrap() { + let mut timeline = record.unwrap().timeline().unwrap(); + boards_with_results.append(&mut timeline); } eprintln!("Total board count = {}", boards_with_results.len()); @@ -182,7 +63,7 @@ pub fn gen_dataset(matches: &ArgMatches) { min(boards_with_results.len(), max_output) ) .unwrap(); - for (idx, (board, score, hand)) in boards_with_results.iter().enumerate() { + for (idx, (board, hand, score)) in boards_with_results.iter().enumerate() { if idx >= max_output { break; }