@@ -116,6 +116,11 @@ def extend(self, regions: Iterable[Region]):
116
116
for region in regions :
117
117
self .region_cache [region .player ][region .name ] = region
118
118
119
+ def add_group (self , new_id : int ):
120
+ self .region_cache [new_id ] = {}
121
+ self .entrance_cache [new_id ] = {}
122
+ self .location_cache [new_id ] = {}
123
+
119
124
def __iter__ (self ) -> Iterator [Region ]:
120
125
for regions in self .region_cache .values ():
121
126
yield from regions .values ()
@@ -223,6 +228,7 @@ def add_group(self, name: str, game: str, players: Set[int] = frozenset()) -> Tu
223
228
return group_id , group
224
229
new_id : int = self .players + len (self .groups ) + 1
225
230
231
+ self .regions .add_group (new_id )
226
232
self .game [new_id ] = game
227
233
self .player_types [new_id ] = NetUtils .SlotType .group
228
234
world_type = AutoWorld .AutoWorldRegister .world_types [game ]
@@ -621,7 +627,7 @@ class CollectionState():
621
627
additional_copy_functions : List [Callable [[CollectionState , CollectionState ], CollectionState ]] = []
622
628
623
629
def __init__ (self , parent : MultiWorld , allow_partial_entrances : bool = False ):
624
- self .prog_items = {player : Counter () for player in parent .player_ids }
630
+ self .prog_items = {player : Counter () for player in parent .get_all_ids () }
625
631
self .multiworld = parent
626
632
self .reachable_regions = {player : set () for player in parent .get_all_ids ()}
627
633
self .blocked_connections = {player : set () for player in parent .get_all_ids ()}
@@ -656,10 +662,9 @@ def update_reachable_regions(self, player: int):
656
662
if new_region in rrp :
657
663
bc .remove (connection )
658
664
elif connection .can_reach (self ):
659
- if not self .allow_partial_entrances :
660
- assert new_region , f"tried to search through an Entrance \" { connection } \" with no Region"
661
- elif not new_region :
665
+ if self .allow_partial_entrances and not new_region :
662
666
continue
667
+ assert new_region , f"tried to search through an Entrance \" { connection } \" with no Region"
663
668
rrp .add (new_region )
664
669
bc .remove (connection )
665
670
bc .update (new_region .exits )
@@ -716,37 +721,43 @@ def sweep_for_events(self, key_only: bool = False, locations: Optional[Iterable[
716
721
assert isinstance (event .item , Item ), "tried to collect Event with no Item"
717
722
self .collect (event .item , True , event )
718
723
724
+ # item name related
719
725
def has (self , item : str , player : int , count : int = 1 ) -> bool :
720
726
return self .prog_items [player ][item ] >= count
721
727
722
- def has_all (self , items : Set [str ], player : int ) -> bool :
728
+ def has_all (self , items : Iterable [str ], player : int ) -> bool :
723
729
"""Returns True if each item name of items is in state at least once."""
724
730
return all (self .prog_items [player ][item ] for item in items )
725
731
726
- def has_any (self , items : Set [str ], player : int ) -> bool :
732
+ def has_any (self , items : Iterable [str ], player : int ) -> bool :
727
733
"""Returns True if at least one item name of items is in state at least once."""
728
734
return any (self .prog_items [player ][item ] for item in items )
729
735
730
736
def count (self , item : str , player : int ) -> int :
731
737
return self .prog_items [player ][item ]
732
738
739
+ def item_count (self , item : str , player : int ) -> int :
740
+ Utils .deprecate ("Use count instead." )
741
+ return self .count (item , player )
742
+
743
+ # item name group related
733
744
def has_group (self , item_name_group : str , player : int , count : int = 1 ) -> bool :
734
745
found : int = 0
746
+ player_prog_items = self .prog_items [player ]
735
747
for item_name in self .multiworld .worlds [player ].item_name_groups [item_name_group ]:
736
- found += self . prog_items [ player ] [item_name ]
748
+ found += player_prog_items [item_name ]
737
749
if found >= count :
738
750
return True
739
751
return False
740
752
741
753
def count_group (self , item_name_group : str , player : int ) -> int :
742
754
found : int = 0
755
+ player_prog_items = self .prog_items [player ]
743
756
for item_name in self .multiworld .worlds [player ].item_name_groups [item_name_group ]:
744
- found += self . prog_items [ player ] [item_name ]
757
+ found += player_prog_items [item_name ]
745
758
return found
746
759
747
- def item_count (self , item : str , player : int ) -> int :
748
- return self .prog_items [player ][item ]
749
-
760
+ # Item related
750
761
def collect (self , item : Item , event : bool = False , location : Optional [Location ] = None ) -> bool :
751
762
if location :
752
763
self .locations_checked .add (location )
@@ -774,7 +785,7 @@ def remove(self, item: Item):
774
785
775
786
776
787
class Entrance :
777
- class Type (IntEnum ):
788
+ class EntranceType (IntEnum ):
778
789
ONE_WAY = 1
779
790
TWO_WAY = 2
780
791
@@ -785,13 +796,13 @@ class Type(IntEnum):
785
796
parent_region : Optional [Region ]
786
797
connected_region : Optional [Region ] = None
787
798
er_group : str
788
- er_type : Type
799
+ er_type : EntranceType
789
800
# LttP specific, TODO: should make a LttPEntrance
790
801
addresses = None
791
802
target = None
792
803
793
- def __init__ (self , player : int , name : str = '' , parent : Region = None ,
794
- er_group : str = ' Default' , er_type : Type = Type .ONE_WAY ):
804
+ def __init__ (self , player : int , name : str = "" , parent : Region = None ,
805
+ er_group : str = " Default" , er_type : EntranceType = EntranceType .ONE_WAY ):
795
806
self .name = name
796
807
self .parent_region = parent
797
808
self .player = player
@@ -800,7 +811,7 @@ def __init__(self, player: int, name: str = '', parent: Region = None,
800
811
801
812
def can_reach (self , state : CollectionState ) -> bool :
802
813
if self .parent_region .can_reach (state ) and self .access_rule (state ):
803
- if not self .hide_path and not self in state .path :
814
+ if not self .hide_path and self not in state .path :
804
815
state .path [self ] = (self .name , state .path .get (self .parent_region , (self .parent_region .name , None )))
805
816
return True
806
817
@@ -812,7 +823,7 @@ def connect(self, region: Region, addresses: Any = None, target: Any = None) ->
812
823
self .addresses = addresses
813
824
region .entrances .append (self )
814
825
815
- def is_valid_source_transition (self , state : ERPlacementState ) -> bool :
826
+ def is_valid_source_transition (self , state : " ERPlacementState" ) -> bool :
816
827
"""
817
828
Determines whether this is a valid source transition, that is, whether the entrance
818
829
randomizer is allowed to pair it to place any other regions. By default, this is the
@@ -823,7 +834,7 @@ def is_valid_source_transition(self, state: ERPlacementState) -> bool:
823
834
"""
824
835
return self .can_reach (state .collection_state )
825
836
826
- def can_connect_to (self , other : Entrance , state : ERPlacementState ) -> bool :
837
+ def can_connect_to (self , other : Entrance , state : " ERPlacementState" ) -> bool :
827
838
"""
828
839
Determines whether a given Entrance is a valid target transition, that is, whether
829
840
the entrance randomizer is allowed to pair this Entrance to that Entrance.
0 commit comments