Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add missing ops #1463

Merged
merged 7 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions hugr-core/src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ pub enum FloatOps {
fabs,
fmul,
fdiv,
fpow,
ffloor,
fceil,
fround,
ftostring,
}

Expand All @@ -60,10 +62,10 @@ impl MakeOpDef for FloatOps {
feq | fne | flt | fgt | fle | fge => {
Signature::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T])
}
fmax | fmin | fadd | fsub | fmul | fdiv => {
fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
Signature::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE])
}
fneg | fabs | ffloor | fceil => Signature::new_endo(type_row![FLOAT64_TYPE]),
fneg | fabs | ffloor | fceil | fround => Signature::new_endo(type_row![FLOAT64_TYPE]),
ftostring => Signature::new(type_row![FLOAT64_TYPE], STRING_TYPE),
}
.into()
Expand All @@ -86,8 +88,10 @@ impl MakeOpDef for FloatOps {
fabs => "absolute value",
fmul => "multiplication",
fdiv => "division",
fpow => "exponentiation",
ffloor => "floor",
fceil => "ceiling",
fround => "round",
ftostring => "string representation",
}
.to_string()
Expand Down Expand Up @@ -133,6 +137,9 @@ impl MakeRegisteredOp for FloatOps {

#[cfg(test)]
mod test {
use cgmath::AbsDiffEq;
use rstest::rstest;

use super::*;

#[test]
Expand All @@ -144,4 +151,47 @@ mod test {
assert!(name.as_str().starts_with('f'));
}
}

#[rstest]
#[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😁

#[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
#[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
#[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
#[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
#[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
#[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
#[case::fround(FloatOps::fround, &[42.42], &[42.])]
fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
use crate::ops::Value;
use crate::std_extensions::arithmetic::float_types::ConstF64;

let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let res_val: f64 = res
.get(i)
.unwrap()
.1
.get_custom_value::<ConstF64>()
.expect("This function assumes all incoming constants are floats.")
.value();

assert!(
res_val.abs_diff_eq(expected, f64::EPSILON),
"expected {:?}, got {:?}",
expected,
res_val
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) {
use FloatOps::*;

match op {
fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)),
fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
def.set_constant_folder(BinaryFold::from_op(op))
}
feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)),
fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)),
fneg | fabs | ffloor | fceil | fround => def.set_constant_folder(UnaryFold::from_op(op)),
ftostring => def.set_constant_folder(ToStringFold::from_op(op)),
}
}
Expand Down Expand Up @@ -43,6 +45,7 @@ impl BinaryFold {
fsub => std::ops::Sub::sub,
fmul => std::ops::Mul::mul,
fdiv => std::ops::Div::div,
fpow => f64::powf,
_ => panic!("not binary op"),
}))
}
Expand Down Expand Up @@ -106,6 +109,7 @@ impl UnaryFold {
fabs => f64::abs,
ffloor => f64::floor,
fceil => f64::ceil,
fround => f64::round,
_ => panic!("not unary op."),
}))
}
Expand Down
74 changes: 67 additions & 7 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ pub enum IntOpDef {
idiv_s,
imod_checked_s,
imod_s,
ipow,
iabs,
iand,
ior,
Expand All @@ -98,6 +99,8 @@ pub enum IntOpDef {
ishr,
irotl,
irotr,
iu_to_s,
is_to_u,
itostring_u,
itostring_s,
}
Expand All @@ -116,12 +119,12 @@ impl MakeOpDef for IntOpDef {
let tv0 = int_tv(0);
match self {
iwiden_s | iwiden_u => CustomValidator::new(
int_polytype(2, vec![tv0.clone()], vec![int_tv(1)]),
int_polytype(2, vec![tv0], vec![int_tv(1)]),
IOValidator { f_ge_s: false },
)
.into(),
inarrow_s | inarrow_u => CustomValidator::new(
int_polytype(2, tv0.clone(), sum_ty_with_err(int_tv(1))),
int_polytype(2, tv0, sum_ty_with_err(int_tv(1))),
IOValidator { f_ge_s: true },
)
.into(),
Expand All @@ -130,10 +133,10 @@ impl MakeOpDef for IntOpDef {
ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => {
int_polytype(1, vec![tv0; 2], type_row![BOOL_T]).into()
}
imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor => {
imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => {
ibinop_sig().into()
}
ineg | iabs | inot => iunop_sig().into(),
ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(),
idivmod_checked_u | idivmod_checked_s => {
let intpair: TypeRowRV = vec![tv0; 2].into();
int_polytype(
Expand Down Expand Up @@ -173,8 +176,8 @@ impl MakeOpDef for IntOpDef {
iwiden_s => "widen a signed integer to a wider one with the same value",
inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible",
inarrow_s => "narrow a signed integer to a narrower one with the same value if possible",
itobool => "convert to bool (1 is true, 0 is false)",
ifrombool => "convert from bool (1 is true, 0 is false)",
itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
ieq => "equality test",
ine => "inequality test",
ilt_u => "\"less than\" as unsigned integers",
Expand Down Expand Up @@ -209,6 +212,7 @@ impl MakeOpDef for IntOpDef {
idiv_s => "as idivmod_s but discarding the second output",
imod_checked_s => "as idivmod_checked_s but discarding the first output",
imod_s => "as idivmod_s but discarding the first output",
ipow => "raise first input to the power of second input",
iabs => "convert signed to unsigned by taking absolute value",
iand => "bitwise AND",
ior => "bitwise OR",
Expand All @@ -222,6 +226,8 @@ impl MakeOpDef for IntOpDef {
(leftmost bits replace rightmost bits)",
irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \
(rightmost bits replace leftmost bits)",
is_to_u => "convert signed to unsigned by taking absolute value",
iu_to_s => "convert unsigned to signed by taking absolute value",
itostring_s => "convert a signed integer to its string representation",
itostring_u => "convert an unsigned integer to its string representation",
}.into()
Expand Down Expand Up @@ -366,6 +372,8 @@ fn sum_ty_with_err(t: Type) -> Type {

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::{
ops::{dataflow::DataflowOpTrait, ExtensionOp},
std_extensions::arithmetic::int_types::int_type,
Expand All @@ -378,7 +386,7 @@ mod test {
fn test_int_ops_extension() {
assert_eq!(EXTENSION.name() as &str, "arithmetic.int");
assert_eq!(EXTENSION.types().count(), 0);
assert_eq!(EXTENSION.operations().count(), 47);
assert_eq!(EXTENSION.operations().count(), 50);
for (name, _) in EXTENSION.operations() {
assert!(name.starts_with('i'));
}
Expand Down Expand Up @@ -450,4 +458,56 @@ mod test {
assert_eq!(ConcreteIntOp::from_op(&ext_op).unwrap(), o);
assert_eq!(IntOpDef::from_op(&ext_op).unwrap(), IntOpDef::itobool);
}

#[rstest]
#[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
#[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
#[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
#[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
#[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
#[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
#[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
#[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
#[should_panic(expected = "too large to be converted to signed")]
#[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u32::MAX as u64], &[], 5)]
#[should_panic(expected = "Cannot convert negative integer")]
#[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[(0u32.wrapping_sub(42)) as u64], &[], 5)]
fn int_fold(
#[case] op: ConcreteIntOp,
#[case] inputs: &[u64],
#[case] outputs: &[u64],
#[case] log_width: u8,
) {
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::ConstInt;

let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, &x)| {
(
i.into(),
Value::extension(ConstInt::new_u(log_width, x).unwrap()),
)
})
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, &expected) in outputs.iter().enumerate() {
let res_val: u64 = res
.get(i)
.unwrap()
.1
.get_custom_value::<ConstInt>()
.expect("This function assumes all incoming constants are floats.")
.value_u();

assert_eq!(res_val, expected);
}
}
}
74 changes: 74 additions & 0 deletions hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,36 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
},
),
},
IntOpDef::ipow => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
let [arg] = type_args else {
return None;
};
let logwidth: u8 = get_log_width(arg).ok()?;
let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?;
if n0.log_width() != logwidth || n1.log_width() != logwidth {
None
} else {
Some(vec![(
0.into(),
Value::extension(
ConstInt::new_u(
logwidth,
n0.value_u()
.overflowing_pow(
n1.value_u().try_into().unwrap_or(u32::MAX),
)
.0
& bitmask_from_logwidth(logwidth),
)
.unwrap(),
),
)])
}
},
),
},
IntOpDef::idivmod_checked_u => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
Expand Down Expand Up @@ -1154,6 +1184,50 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
},
),
},
IntOpDef::is_to_u => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
let [arg] = type_args else {
return None;
};
let logwidth: u8 = get_log_width(arg).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
if n0.log_width() != logwidth {
None
} else {
if n0.value_s() < 0 {
panic!(
"Cannot convert negative integer {} to unsigned.",
n0.value_s()
);
}
Some(vec![(0.into(), Value::extension(n0.clone()))])
}
},
),
},
IntOpDef::iu_to_s => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
let [arg] = type_args else {
return None;
};
let logwidth: u8 = get_log_width(arg).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
if n0.log_width() != logwidth {
None
} else {
if n0.value_s() < 0 {
panic!(
"Unsigned integer {} is too large to be converted to signed.",
n0.value_u()
);
}
Some(vec![(0.into(), Value::extension(n0.clone()))])
}
},
),
},
IntOpDef::itostring_u => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
Expand Down
Loading
Loading