Skip to content

Commit

Permalink
feat: Add missing ops (#1463)
Browse files Browse the repository at this point in the history
Adds some missing operations:

- `fpow`
- `fround`
- `ipow`
- `iu_to_s` / `is_to_u`. These are almost noops, but some runtimes may
prefer to check and panic if the value is out-of-bounds.
- Mention that `ifrombool` / `itobool` only work with `i1` in their
description.
  • Loading branch information
aborgna-q authored Aug 27, 2024
1 parent 001e66a commit 841f450
Show file tree
Hide file tree
Showing 8 changed files with 687 additions and 15 deletions.
54 changes: 52 additions & 2 deletions hugr-core/src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ pub enum FloatOps {
fabs,
fmul,
fdiv,
fpow,
ffloor,
fceil,
fround,
ftostring,
}

Expand All @@ -60,10 +62,10 @@ impl MakeOpDef for FloatOps {
feq | fne | flt | fgt | fle | fge => {
Signature::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T])
}
fmax | fmin | fadd | fsub | fmul | fdiv => {
fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
Signature::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE])
}
fneg | fabs | ffloor | fceil => Signature::new_endo(type_row![FLOAT64_TYPE]),
fneg | fabs | ffloor | fceil | fround => Signature::new_endo(type_row![FLOAT64_TYPE]),
ftostring => Signature::new(type_row![FLOAT64_TYPE], STRING_TYPE),
}
.into()
Expand All @@ -86,8 +88,10 @@ impl MakeOpDef for FloatOps {
fabs => "absolute value",
fmul => "multiplication",
fdiv => "division",
fpow => "exponentiation",
ffloor => "floor",
fceil => "ceiling",
fround => "round",
ftostring => "string representation",
}
.to_string()
Expand Down Expand Up @@ -133,6 +137,9 @@ impl MakeRegisteredOp for FloatOps {

#[cfg(test)]
mod test {
use cgmath::AbsDiffEq;
use rstest::rstest;

use super::*;

#[test]
Expand All @@ -144,4 +151,47 @@ mod test {
assert!(name.as_str().starts_with('f'));
}
}

#[rstest]
#[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
#[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
#[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
#[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
#[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
#[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
#[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
#[case::fround(FloatOps::fround, &[42.42], &[42.])]
fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
use crate::ops::Value;
use crate::std_extensions::arithmetic::float_types::ConstF64;

let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let res_val: f64 = res
.get(i)
.unwrap()
.1
.get_custom_value::<ConstF64>()
.expect("This function assumes all incoming constants are floats.")
.value();

assert!(
res_val.abs_diff_eq(expected, f64::EPSILON),
"expected {:?}, got {:?}",
expected,
res_val
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) {
use FloatOps::*;

match op {
fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)),
fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
def.set_constant_folder(BinaryFold::from_op(op))
}
feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)),
fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)),
fneg | fabs | ffloor | fceil | fround => def.set_constant_folder(UnaryFold::from_op(op)),
ftostring => def.set_constant_folder(ToStringFold::from_op(op)),
}
}
Expand Down Expand Up @@ -43,6 +45,7 @@ impl BinaryFold {
fsub => std::ops::Sub::sub,
fmul => std::ops::Mul::mul,
fdiv => std::ops::Div::div,
fpow => f64::powf,
_ => panic!("not binary op"),
}))
}
Expand Down Expand Up @@ -106,6 +109,7 @@ impl UnaryFold {
fabs => f64::abs,
ffloor => f64::floor,
fceil => f64::ceil,
fround => f64::round,
_ => panic!("not unary op."),
}))
}
Expand Down
74 changes: 67 additions & 7 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ pub enum IntOpDef {
idiv_s,
imod_checked_s,
imod_s,
ipow,
iabs,
iand,
ior,
Expand All @@ -98,6 +99,8 @@ pub enum IntOpDef {
ishr,
irotl,
irotr,
iu_to_s,
is_to_u,
itostring_u,
itostring_s,
}
Expand All @@ -116,12 +119,12 @@ impl MakeOpDef for IntOpDef {
let tv0 = int_tv(0);
match self {
iwiden_s | iwiden_u => CustomValidator::new(
int_polytype(2, vec![tv0.clone()], vec![int_tv(1)]),
int_polytype(2, vec![tv0], vec![int_tv(1)]),
IOValidator { f_ge_s: false },
)
.into(),
inarrow_s | inarrow_u => CustomValidator::new(
int_polytype(2, tv0.clone(), sum_ty_with_err(int_tv(1))),
int_polytype(2, tv0, sum_ty_with_err(int_tv(1))),
IOValidator { f_ge_s: true },
)
.into(),
Expand All @@ -130,10 +133,10 @@ impl MakeOpDef for IntOpDef {
ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => {
int_polytype(1, vec![tv0; 2], type_row![BOOL_T]).into()
}
imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor => {
imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => {
ibinop_sig().into()
}
ineg | iabs | inot => iunop_sig().into(),
ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(),
idivmod_checked_u | idivmod_checked_s => {
let intpair: TypeRowRV = vec![tv0; 2].into();
int_polytype(
Expand Down Expand Up @@ -173,8 +176,8 @@ impl MakeOpDef for IntOpDef {
iwiden_s => "widen a signed integer to a wider one with the same value",
inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible",
inarrow_s => "narrow a signed integer to a narrower one with the same value if possible",
itobool => "convert to bool (1 is true, 0 is false)",
ifrombool => "convert from bool (1 is true, 0 is false)",
itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
ieq => "equality test",
ine => "inequality test",
ilt_u => "\"less than\" as unsigned integers",
Expand Down Expand Up @@ -209,6 +212,7 @@ impl MakeOpDef for IntOpDef {
idiv_s => "as idivmod_s but discarding the second output",
imod_checked_s => "as idivmod_checked_s but discarding the first output",
imod_s => "as idivmod_s but discarding the first output",
ipow => "raise first input to the power of second input",
iabs => "convert signed to unsigned by taking absolute value",
iand => "bitwise AND",
ior => "bitwise OR",
Expand All @@ -222,6 +226,8 @@ impl MakeOpDef for IntOpDef {
(leftmost bits replace rightmost bits)",
irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \
(rightmost bits replace leftmost bits)",
is_to_u => "convert signed to unsigned by taking absolute value",
iu_to_s => "convert unsigned to signed by taking absolute value",
itostring_s => "convert a signed integer to its string representation",
itostring_u => "convert an unsigned integer to its string representation",
}.into()
Expand Down Expand Up @@ -366,6 +372,8 @@ fn sum_ty_with_err(t: Type) -> Type {

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::{
ops::{dataflow::DataflowOpTrait, ExtensionOp},
std_extensions::arithmetic::int_types::int_type,
Expand All @@ -378,7 +386,7 @@ mod test {
fn test_int_ops_extension() {
assert_eq!(EXTENSION.name() as &str, "arithmetic.int");
assert_eq!(EXTENSION.types().count(), 0);
assert_eq!(EXTENSION.operations().count(), 47);
assert_eq!(EXTENSION.operations().count(), 50);
for (name, _) in EXTENSION.operations() {
assert!(name.starts_with('i'));
}
Expand Down Expand Up @@ -450,4 +458,56 @@ mod test {
assert_eq!(ConcreteIntOp::from_op(&ext_op).unwrap(), o);
assert_eq!(IntOpDef::from_op(&ext_op).unwrap(), IntOpDef::itobool);
}

#[rstest]
#[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
#[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
#[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
#[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
#[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
#[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
#[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
#[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
#[should_panic(expected = "too large to be converted to signed")]
#[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u32::MAX as u64], &[], 5)]
#[should_panic(expected = "Cannot convert negative integer")]
#[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[(0u32.wrapping_sub(42)) as u64], &[], 5)]
fn int_fold(
#[case] op: ConcreteIntOp,
#[case] inputs: &[u64],
#[case] outputs: &[u64],
#[case] log_width: u8,
) {
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::ConstInt;

let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, &x)| {
(
i.into(),
Value::extension(ConstInt::new_u(log_width, x).unwrap()),
)
})
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, &expected) in outputs.iter().enumerate() {
let res_val: u64 = res
.get(i)
.unwrap()
.1
.get_custom_value::<ConstInt>()
.expect("This function assumes all incoming constants are floats.")
.value_u();

assert_eq!(res_val, expected);
}
}
}
74 changes: 74 additions & 0 deletions hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,36 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
},
),
},
IntOpDef::ipow => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
let [arg] = type_args else {
return None;
};
let logwidth: u8 = get_log_width(arg).ok()?;
let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?;
if n0.log_width() != logwidth || n1.log_width() != logwidth {
None
} else {
Some(vec![(
0.into(),
Value::extension(
ConstInt::new_u(
logwidth,
n0.value_u()
.overflowing_pow(
n1.value_u().try_into().unwrap_or(u32::MAX),
)
.0
& bitmask_from_logwidth(logwidth),
)
.unwrap(),
),
)])
}
},
),
},
IntOpDef::idivmod_checked_u => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
Expand Down Expand Up @@ -1154,6 +1184,50 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
},
),
},
IntOpDef::is_to_u => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
let [arg] = type_args else {
return None;
};
let logwidth: u8 = get_log_width(arg).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
if n0.log_width() != logwidth {
None
} else {
if n0.value_s() < 0 {
panic!(
"Cannot convert negative integer {} to unsigned.",
n0.value_s()
);
}
Some(vec![(0.into(), Value::extension(n0.clone()))])
}
},
),
},
IntOpDef::iu_to_s => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
let [arg] = type_args else {
return None;
};
let logwidth: u8 = get_log_width(arg).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
if n0.log_width() != logwidth {
None
} else {
if n0.value_s() < 0 {
panic!(
"Unsigned integer {} is too large to be converted to signed.",
n0.value_u()
);
}
Some(vec![(0.into(), Value::extension(n0.clone()))])
}
},
),
},
IntOpDef::itostring_u => Folder {
folder: Box::new(
|type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult {
Expand Down
Loading

0 comments on commit 841f450

Please sign in to comment.