Skip to content

Commit

Permalink
Add 3-arg Select and go over... (#1210)
Browse files Browse the repository at this point in the history
* Add 3-argument Select, by adding a parameter inside Structure's filter function.
* Go over Structure class: remove duplication of filter() and add the count functionality needed by Select[]
* Update Select doctests for 3 arg form. Move error checking to pytest where it belongs.

`Select[]` count-parameter issue is mentioned in Mathics3/Mathics3-Rubi#2
  • Loading branch information
rocky authored Dec 11, 2024
1 parent 00f55a6 commit ed5e212
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 124 deletions.
3 changes: 0 additions & 3 deletions mathics/builtin/atomic/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,6 @@ class _StringFind(Builtin, ABC):

messages = {
"srep": "`1` is not a valid string replacement rule.",
"innf": (
"Non-negative integer or Infinity expected at " "position `1` in `2`."
),
}

def _find(py_stri, py_rules, py_n, flags):
Expand Down
63 changes: 46 additions & 17 deletions mathics/builtin/list/eol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from itertools import chain

from mathics.builtin.box.layout import RowBox
from mathics.core.atoms import Integer, Integer0, Integer1, String
from mathics.core.atoms import Integer, Integer0, Integer1, Integer3, Integer4, String
from mathics.core.attributes import (
A_HOLD_FIRST,
A_HOLD_REST,
Expand All @@ -29,7 +29,7 @@
PartError,
PartRangeError,
)
from mathics.core.expression import Expression
from mathics.core.expression import Expression, ExpressionInfinity
from mathics.core.list import ListExpression
from mathics.core.rules import Rule
from mathics.core.symbols import Atom, Symbol, SymbolNull, SymbolTrue
Expand All @@ -41,6 +41,7 @@
SymbolKey,
SymbolMakeBoxes,
SymbolMissing,
SymbolSelect,
SymbolSequence,
SymbolSet,
)
Expand Down Expand Up @@ -433,7 +434,6 @@ class DeleteCases(Builtin):

messages = {
"level": "Level specification `1` is not of the form n, {n}, or {m, n}.",
"innf": "Non-negative integer or Infinity expected at position 4 in `1`",
}
summary_text = "delete all occurrences of a pattern"

Expand All @@ -451,14 +451,15 @@ def eval_ls_n(self, items, pattern, levelspec, n, evaluation):

levelspec = python_levelspec(levelspec)

if n is SymbolInfinity:
if n is SymbolInfinity or ExpressionInfinity == n:
n = -1
elif n.get_head_name() == "System`Integer":
n = n.get_int_value()
elif isinstance(n, Integer):
n = n.value
if n < 0:
evaluation.message(
"DeleteCases",
"innf",
Integer4,
Expression(SymbolDeleteCases, items, pattern, levelspec, n),
)
else:
Expand Down Expand Up @@ -1500,28 +1501,55 @@ class Select(Builtin):
<url>:WMA link:https://reference.wolfram.com/language/ref/Select.html</url>
<dl>
<dt>'Select[{$e1$, $e2$, ...}, $f$]'
<dd>returns a list of the elements $ei$ for which $f$[$ei$] returns 'True'.
<dt>'Select[{$e1$, $e2$, ...}, $crit$]'
<dd>returns a list of the elements $ei$ for which $crit$[$ei$] is 'True'.
<dt>'Select[{$e1$, $e2$, ...}, $crit$, n]'
<dd>returns a list of the first $n$ elements $ei$ for which $crit$[$ei$] is 'True'.
</dl>
Find numbers greater than zero:
>> Select[{-3, 0, 1, 3, a}, #>0&]
= {1, 3}
Get a list of even numbers up to 10:
>> Select[Range[10], EvenQ]
= {2, 4, 6, 8, 10}
Find numbers that are greater than zero in a list:
>> Select[{-3, 0, 10, 3, a}, #>0&]
= {10, 3}
Find the first number that is list greater than zero in a list:
>> Select[{-3, 0, 10, 3, a}, #>0&, 1]
= {10}
'Select' works on an expression with any head:
>> Select[f[a, 2, 3], NumberQ]
= f[2, 3]
>> Select[a, True]
: Nonatomic expression expected.
= Select[a, True]
"""

summary_text = "pick elements according to a criterion"

def eval(self, items, expr, evaluation):
def eval(self, items, expr, evaluation: Evaluation):
"Select[items_, expr_]"

return self.eval_with_n(items, expr, SymbolInfinity, evaluation)

def eval_with_n(self, items, expr, n, evaluation: Evaluation):
"Select[items_, expr_, n_]"

count_is_valid = True
if n is SymbolInfinity or ExpressionInfinity == n:
count = None
elif isinstance(n, Integer):
count = n.value
if count < 0:
count_is_valid = False
else:
count_is_valid = False

if not count_is_valid:
evaluation.message(
"Select", "innf", Integer3, Expression(SymbolSelect, items, expr, n)
)
return

if isinstance(items, Atom):
evaluation.message("Select", "normal")
return
Expand All @@ -1530,7 +1558,7 @@ def cond(element):
test = Expression(expr, element)
return test.evaluate(evaluation) is SymbolTrue

return items.filter(items.head, cond, evaluation)
return items.filter(items.head, cond, evaluation, count=count)


class Span(InfixOperator):
Expand Down Expand Up @@ -1616,6 +1644,7 @@ class UpTo(Builtin):
</dl>
"""

# TODO: is there as way we can use general's innf?
messages = {
"innf": "Expected non-negative integer or infinity at position 1 in ``.",
"argx": "UpTo expects 1 argument, `1` arguments were given.",
Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class General(Builtin):
"Single or list of non-negative integers expected at " "position `1`."
),
"indet": "Indeterminate expression `1` encountered.",
"innf": "Non-negative integer or Infinity expected at position `1`.",
"innf": "Non-negative integer or Infinity expected at position `1` in `2`",
"int": "Integer expected.",
"intp": "Positive integer expected.",
"intnn": "Non-negative integer expected.",
Expand Down
23 changes: 15 additions & 8 deletions mathics/builtin/patterns/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,21 @@

from typing import Optional as OptionalType

from mathics.core.atoms import Integer, Integer0, Integer2, Number
from mathics.core.atoms import Integer, Integer0, Integer2, Integer3, Number
from mathics.core.attributes import A_HOLD_REST, A_PROTECTED, A_SEQUENCE_HOLD
from mathics.core.builtin import AtomBuiltin, Builtin, InfixOperator, PatternError
from mathics.core.element import BaseElement
from mathics.core.evaluation import Evaluation
from mathics.core.exceptions import InvalidLevelspecError
from mathics.core.expression import Expression
from mathics.core.expression import Expression, ExpressionInfinity
from mathics.core.list import ListExpression
from mathics.core.symbols import SymbolTrue
from mathics.core.systemsymbols import SymbolInfinity, SymbolRule, SymbolRuleDelayed
from mathics.core.systemsymbols import (
SymbolInfinity,
SymbolReplaceList,
SymbolRule,
SymbolRuleDelayed,
)
from mathics.eval.rules import (
Dispatch,
create_rules,
Expand Down Expand Up @@ -341,9 +346,6 @@ class ReplaceList(Builtin):
Like in 'ReplaceAll', $rules$ can be a nested list:
>> ReplaceList[{a, b, c}, {{{___, x__, ___} -> {x}}, {{a, b, c} -> t}}, 2]
= {{{a}, {a, b}}, {t}}
>> ReplaceList[expr, {}, -1]
: Non-negative integer or Infinity expected at position 3.
= ReplaceList[expr, {}, -1]
Possible matches for a sum:
>> ReplaceList[a + b + c, x_ + y_ -> {x, y}]
Expand All @@ -369,12 +371,17 @@ def eval(
# default argument, when it is passed explictly, e.g.
# ReplaceList[expr, {}, Infinity], then Infinity
# comes in as DirectedInfinity[1].
if maxidx == SymbolInfinity:
if maxidx == SymbolInfinity or ExpressionInfinity == maxidx:
max_count = None
else:
max_count = maxidx.get_int_value()
if max_count is None or max_count < 0:
evaluation.message("ReplaceList", "innf", 3)
evaluation.message(
"ReplaceList",
"innf",
Integer3,
Expression(SymbolReplaceList, expr, rules, maxidx),
)
return None
try:
rules, ret = create_rules(
Expand Down
8 changes: 4 additions & 4 deletions mathics/builtin/scipy_utils/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from mathics.core.convert.function import expression_to_callable_and_args
from mathics.core.element import BaseElement
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.systemsymbols import SymbolAutomatic, SymbolFailed, SymbolInfinity
from mathics.core.expression import Expression, ExpressionInfinity
from mathics.core.systemsymbols import SymbolAutomatic, SymbolFailed
from mathics.eval.nevaluator import eval_N

if not check_requires_list(["scipy", "numpy"]):
Expand All @@ -30,7 +30,7 @@ def get_tolerance_and_maxit(opts: dict, scale: float, evaluation: Evaluation):
acc_goal = eval_N(acc_goal, evaluation)
if acc_goal is SymbolAutomatic:
acc_goal = Real(12.0)
elif acc_goal is SymbolInfinity:
elif ExpressionInfinity == acc_goal:
acc_goal = None
elif not isinstance(acc_goal, Number):
acc_goal = None
Expand All @@ -40,7 +40,7 @@ def get_tolerance_and_maxit(opts: dict, scale: float, evaluation: Evaluation):
prec_goal = eval_N(prec_goal, evaluation)
if prec_goal is SymbolAutomatic:
prec_goal = Real(12.0)
elif prec_goal is SymbolInfinity:
elif ExpressionInfinity == prec_goal:
prec_goal = None
elif not isinstance(prec_goal, Number):
prec_goal = None
Expand Down
4 changes: 2 additions & 2 deletions mathics/builtin/string/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
mathics_split,
to_regex,
)
from mathics.core.atoms import Integer, Integer1, String
from mathics.core.atoms import Integer, Integer1, Integer3, String
from mathics.core.attributes import (
A_FLAT,
A_LISTABLE,
Expand Down Expand Up @@ -437,7 +437,7 @@ def eval_n(self, string, patt, n, evaluation: Evaluation, options: dict):
else:
py_n = n.get_int_value()
if py_n is None or py_n < 0:
evaluation.message("StringPosition", "innf", expr, Integer(3))
evaluation.message("StringPosition", "innf", expr, Integer3)
return

# check options
Expand Down
1 change: 1 addition & 0 deletions mathics/core/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def is_zero(self) -> bool:
Integer1 = Integer(1)
Integer2 = Integer(2)
Integer3 = Integer(3)
Integer4 = Integer(4)
Integer310 = Integer(310)
Integer10 = Integer(10)
IntegerM1 = Integer(-1)
Expand Down
9 changes: 6 additions & 3 deletions mathics/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import sympy

from mathics.core.atoms import Integer, String
from mathics.core.atoms import Integer, Integer1, String
from mathics.core.attributes import (
A_FLAT,
A_HOLD_ALL,
Expand Down Expand Up @@ -625,9 +625,9 @@ def evaluate_elements(self, evaluation) -> "Expression":
head = head.evaluate_elements(evaluation)
return Expression(head, *elements)

def filter(self, head, cond, evaluation):
def filter(self, head, cond, evaluation: Evaluation, count: Optional[int] = None):
# faster equivalent to: Expression(head, [element in self.elements if cond(element)])
return structure(head, self, evaluation).filter(self, cond)
return structure(head, self, evaluation).filter(self, cond, count)

# FIXME: go over and preserve elements_properties.
def flatten_pattern_sequence(self, evaluation):
Expand Down Expand Up @@ -2068,3 +2068,6 @@ def convert_expression_elements(

def string_list(head, elements, evaluation):
return atom_list_constructor(evaluation, head, "String")(elements)


ExpressionInfinity = Expression(SymbolDirectedInfinity, Integer1)
41 changes: 26 additions & 15 deletions mathics/core/structure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-

from abc import ABC
from typing import Optional

class Structure:

class Structure(ABC):
"""
Structure helps implementations make the ExpressionCache not invalidate across simple commands
such as Take[], Most[], etc. without this, constant reevaluation of lists happens, which results
Expand All @@ -11,18 +14,32 @@ class Structure:
"""

def __call__(self, elements):
# create an Expression with the given list "elements" as elements.
# NOTE: the caller guarantees that "elements" only contains items that are from "origins".
"""Return an Expression with the given list "elements" as elements.
Tthe caller guarantees that "elements" only contains items that are from "origins
"""
raise NotImplementedError

def filter(self, expr, cond):
# create an Expression with a subset of "expr".elements (picked out by the filter "cond").
# NOTE: the caller guarantees that "expr" is from "origins".
raise NotImplementedError
def filter(self, expr, condition, count: Optional[int] = None):
"""
Returns self type consisting of `expr` filtered by `condition`.
If `count` is not None, return at most `count` elements.
"""
if count is None:
result = [element for element in expr.elements if condition(element)]
else:
result = []
for element in expr.elements:
if condition(element):
result.append(element)
count -= 1
if count == 0:
break

return self(result)

def slice(self, expr, py_slice):
# create an Expression, using the given slice of "expr".elements as elements.
# NOTE: the caller guarantees that "expr" is from "origins".
"""create an Expression, using the given slice of "expr".elements as elements.
The caller guarantees that "expr" is from "origins"."""
raise NotImplementedError


Expand Down Expand Up @@ -52,9 +69,6 @@ def __call__(self, elements):
# from mathics.core.convert.expression import to_expression_with_specialization
# return to_expression_with_specialization(self._head, *new_elements)

def filter(self, expr, cond):
return self([element for element in expr.elements if cond(element)])

def slice(self, expr, py_slice):
elements = expr.elements
lower, upper, step = py_slice.indices(len(elements))
Expand Down Expand Up @@ -82,9 +96,6 @@ def __call__(self, elements):
expr._cache = self._cache.reordered()
return expr

def filter(self, expr, cond):
return self([element for element in expr.elements if cond(element)])

def slice(self, expr, py_slice):
elements = expr.elements
lower, upper, step = py_slice.indices(len(elements))
Expand Down
2 changes: 2 additions & 0 deletions mathics/core/systemsymbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
SymbolRealSign = Symbol("System`RealSign")
SymbolRepeated = Symbol("System`Repeated")
SymbolRepeatedNull = Symbol("System`RepeatedNull")
SymbolReplaceList = Symbol("System`ReplaceList")
SymbolReturn = Symbol("System`Return")
SymbolReverse = Symbol("System`Reverse")
SymbolRight = Symbol("System`Right")
Expand All @@ -225,6 +226,7 @@
SymbolRule = Symbol("System`Rule")
SymbolRuleDelayed = Symbol("System`RuleDelayed")
SymbolSameQ = Symbol("System`SameQ")
SymbolSelect = Symbol("System`Select")
SymbolSequence = Symbol("System`Sequence")
SymbolSeries = Symbol("System`Series")
SymbolSeriesData = Symbol("System`SeriesData")
Expand Down
Loading

0 comments on commit ed5e212

Please sign in to comment.