From df5e1618ed7807b19ffdbf7f08a82488f8f0e811 Mon Sep 17 00:00:00 2001 From: zmx <122964395@qq.com> Date: Fri, 19 Jul 2024 17:37:02 +0800 Subject: [PATCH] FCVT:add conversion for FP16 and modified CVT64 module to parameterize it *The scalar float conversion function can be parameterized and trimmed in the CVT64 *Add scalar IntToFP conversion for FP16 --- src/main/scala/yunsuan/package.scala | 11 +- src/main/scala/yunsuan/scalar/Convert.scala | 160 ++++ src/main/scala/yunsuan/scalar/FPU.scala | 40 + src/main/scala/yunsuan/scalar/IntToFP.scala | 122 +++ .../scala/yunsuan/scalar/RoundingUnit.scala | 57 ++ src/main/scala/yunsuan/scalar/utils.scala | 76 ++ .../yunsuan/vector/VectorConvert/CVT64.scala | 794 ++++++++++-------- .../yunsuan/vector/VectorConvert/VCVT.scala | 2 +- src/test/csrc/golden_model/gm_common.cpp | 210 ++++- .../golden_model/scalar_float_convert.cpp | 167 ++++ src/test/csrc/include/gm_common.h | 8 + src/test/csrc/include/test_driver.h | 2 +- src/test/csrc/include/vpu_constant.h | 66 +- src/test/csrc/test_driver.cpp | 48 +- src/test/scala/top/VectorSimTop.scala | 41 +- 15 files changed, 1450 insertions(+), 354 deletions(-) create mode 100644 src/main/scala/yunsuan/scalar/Convert.scala create mode 100644 src/main/scala/yunsuan/scalar/FPU.scala create mode 100644 src/main/scala/yunsuan/scalar/IntToFP.scala create mode 100644 src/main/scala/yunsuan/scalar/RoundingUnit.scala create mode 100644 src/main/scala/yunsuan/scalar/utils.scala create mode 100644 src/test/csrc/golden_model/scalar_float_convert.cpp diff --git a/src/main/scala/yunsuan/package.scala b/src/main/scala/yunsuan/package.scala index 535ced4..e16c551 100644 --- a/src/main/scala/yunsuan/package.scala +++ b/src/main/scala/yunsuan/package.scala @@ -530,10 +530,8 @@ object VfcvtType { def vfcvt_fxv = "b01_000011".U(8.W) def vfcvt_rtz_xufv = "b10_000110".U(8.W) def vfcvt_rtz_xfv = "b10_000111".U(8.W) - def vfrsqrt7 = "b11_100000".U(8.W) def vfrec7 = "b11_100001".U(8.W) - def vfwcvt_xufv = 
"b10_001000".U(8.W) def vfwcvt_xfv = "b10_001001".U(8.W) def vfwcvt_fxuv = "b01_001010".U(8.W) @@ -541,7 +539,6 @@ object VfcvtType { def vfwcvt_ffv = "b11_001100".U(8.W) def vfwcvt_rtz_xufv = "b10_001110".U(8.W) def vfwcvt_rtz_xfv = "b10_001111".U(8.W) - def vfncvt_xufw = "b10_010000".U(8.W) def vfncvt_xfw = "b10_010001".U(8.W) def vfncvt_fxuw = "b01_010010".U(8.W) @@ -550,6 +547,14 @@ object VfcvtType { def vfncvt_rod_ffw = "b11_010101".U(8.W) def vfncvt_rtz_xufw = "b10_010110".U(8.W) def vfncvt_rtz_xfw = "b10_010111".U(8.W) + def fcvt_h_s = "b11_010000".U(8.W) + def fcvt_s_h = "b11_001000".U(8.W) + def fcvt_h_d = "b11_011000".U(8.W) + def fcvt_d_h = "b11_011000".U(8.W) + def fcvt_w_h = "b10_001001".U(8.W) + def fcvt_wu_h = "b10_001000".U(8.W) + def fcvt_l_h = "b10_011001".U(8.W) + def fcvt_lu_h = "b10_011000".U(8.W) } diff --git a/src/main/scala/yunsuan/scalar/Convert.scala b/src/main/scala/yunsuan/scalar/Convert.scala new file mode 100644 index 0000000..530d0b1 --- /dev/null +++ b/src/main/scala/yunsuan/scalar/Convert.scala @@ -0,0 +1,160 @@ +package yunsuan.scalar + +import chisel3._ +import chisel3.util._ +import chisel3.util.experimental.decode._ +import yunsuan.util._ +import yunsuan.vector.VectorConvert.CVT64 + +// Scalar Int to Float Convert. 
+class I2fCvtIO extends Bundle{ // Port bundle of the scalar int-to-float converter INT2FP.
+  val src = Input(UInt(64.W)) // integer operand; for a 32-bit source only bits 31:0 are extended when wflags is set
+  val opType = Input(UInt(5.W)) // INT2FP reads [3] = 64-bit source, [2:1] = index into FPU.ftypes, [0] = signed source; bit 4 is not read there
+  val rm = Input(UInt(3.W)) // rounding mode used only when rmInst is "b111"
+  val wflags = Input(Bool()) // high: convert src; low: forward src unconverted with all flags 0
+  val rmInst = Input(UInt(3.W)) // rounding mode that takes priority; the value "b111" defers to rm
+
+  val result = Output(UInt(64.W)) // FPU.box of the stage-3 data for the selected format
+  val fflags = Output(UInt(5.W)) // flags of the selected IntToFP, or 0 when wflags was low
+}
+class INT2FP(latency: Int, XLEN: Int) extends Module { // Scalar int-to-float converter; latency sizes regEnables, XLEN is the width the operand is extended to
+  val io = IO(new I2fCvtIO)
+  val rm = Mux(io.rmInst === "b111".U, io.rm, io.rmInst) // effective rounding mode
+  val regEnables = IO(Input(Vec(latency, Bool()))) // per-stage register enables; indices 0 and 1 are used below, so latency must be at least 2
+  dontTouch(regEnables)
+  // stage1: decode opType, extend the operand to XLEN and latch it with the control bits on regEnables(0)
+  val in = io.src
+  val wflags = io.wflags
+  val typeIn = io.opType(3) // 1: 64-bit source, 0: 32-bit source
+  val typeOut = io.opType(2,1) // destination format, used as an index into FPU.ftypes
+  val signIn = io.opType(0) // 1: sign-extend the source, 0: zero-extend it
+  val intValue = RegEnable(Mux(wflags,
+    Mux(typeIn,
+      Mux(!signIn, ZeroExt(in, XLEN), SignExt(in, XLEN)), // 64-bit source
+      Mux(!signIn, ZeroExt(in(31, 0), XLEN), SignExt(in(31, 0), XLEN)) // 32-bit source: low word only
+    ),
+    in // wflags low: keep the raw operand
+  ), regEnables(0))
+  val typeInReg = RegEnable(typeIn, regEnables(0))
+  val typeOutReg = RegEnable(typeOut, regEnables(0))
+  val signInReg = RegEnable(signIn, regEnables(0))
+  val wflagsReg = RegEnable(wflags, regEnables(0))
+  val rmReg = RegEnable(rm, regEnables(0))
+
+  // stage2: combinational conversion on the stage-1 registers (the s2_* vals are aliases, not extra registers)
+  val s2_typeInReg = typeInReg
+  val s2_signInReg = signInReg
+  val s2_typeOutReg = typeOutReg
+  val s2_wflags = wflagsReg
+  val s2_rmReg = rmReg
+
+  val mux = Wire(new Bundle() { // stage-2 result before the output register
+    val data = UInt(XLEN.W)
+    val exc = UInt(5.W)
+  })
+
+  mux.data := intValue // default: forward the latched operand
+  mux.exc := 0.U // default: no exception flags
+
+  when(s2_wflags){ // conversion requested: override the defaults above
+    val i2fResults = for(t <- FPU.ftypes) yield { // one IntToFP instance per format listed in FPU.ftypes
+      val i2f = Module(new IntToFP(t.expWidth, t.precision))
+      i2f.io.sign := s2_signInReg
+      i2f.io.long := s2_typeInReg
+      i2f.io.int := intValue
+      i2f.io.rm := s2_rmReg
+      (i2f.io.result, i2f.io.fflags)
+    }
+    val (data, exc) = i2fResults.unzip
+    mux.data := VecInit(data)(s2_typeOutReg) // pick the result of the requested format
+    mux.exc := VecInit(exc)(s2_typeOutReg) // and the matching flags
+  }
+
+  // stage 3: latch data/flags and the format tag on regEnables(1), then box the result
+  val s3_out = RegEnable(mux, regEnables(1))
+  val s3_tag = RegEnable(s2_typeOutReg, regEnables(1))
+
+  io.fflags := s3_out.exc
+  io.result := FPU.box(s3_out.data, s3_tag) // FPU.box requires a 64-bit value, so XLEN is expected to be 64
+}
+// Scalar Float to Int or Float Convert.
+class FpCvtIO(width: Int) extends Bundle { + val fire = Input(Bool()) + val src = Input(UInt(width.W)) + val opType = Input(UInt(8.W)) + val sew = Input(UInt(2.W)) + val rm = Input(UInt(3.W)) + val isFpToVecInst = Input(Bool()) + + val result = Output(UInt(width.W)) + val fflags = Output(UInt(5.W)) +} +class FPCVT(xlen :Int) extends Module{ + val io = IO(new FpCvtIO(xlen)) + val (opType, sew) = (io.opType, io.sew) + val widen = opType(4, 3) // 0->single 1->widen 2->norrow => width of result + // input width 8, 16, 32, 64 + val input1H = Wire(UInt(4.W)) + input1H := chisel3.util.experimental.decode.decoder( + widen ## sew, + TruthTable( + Seq( + BitPat("b00_01") -> BitPat("b0010"), // 16 + BitPat("b00_10") -> BitPat("b0100"), // 32 + BitPat("b00_11") -> BitPat("b1000"), // 64 + + BitPat("b01_00") -> BitPat("b0001"), // 8 + BitPat("b01_01") -> BitPat("b0010"), // 16 + BitPat("b01_10") -> BitPat("b0100"), // 32 + + BitPat("b10_00") -> BitPat("b0010"), // 16 + BitPat("b10_01") -> BitPat("b0100"), // 32 + BitPat("b10_10") -> BitPat("b1000"), // 64 + + BitPat("b11_01") -> BitPat("b0010"), // f16->f64/i64/ui64 + BitPat("b11_11") -> BitPat("b1000"), // f64->f16 + ), + BitPat("b0000") + ) + ) + // output width 8, 16, 32, 64 + val output1H = Wire(UInt(4.W)) + output1H := chisel3.util.experimental.decode.decoder( + widen ## sew, + TruthTable( + Seq( + BitPat("b00_01") -> BitPat("b0010"), // 16 + BitPat("b00_10") -> BitPat("b0100"), // 32 + BitPat("b00_11") -> BitPat("b1000"), // 64 + + BitPat("b01_00") -> BitPat("b0010"), // 16 + BitPat("b01_01") -> BitPat("b0100"), // 32 + BitPat("b01_10") -> BitPat("b1000"), // 64 + + BitPat("b10_00") -> BitPat("b0001"), // 8 + BitPat("b10_01") -> BitPat("b0010"), // 16 + BitPat("b10_10") -> BitPat("b0100"), // 32 + + BitPat("b11_11") -> BitPat("b0010"), // f64->f16 + BitPat("b11_01") -> BitPat("b1000"), // f16->f64/i64/ui64 + ), + BitPat("b0000") + ) + ) + dontTouch(input1H) + dontTouch(output1H) + val fcvt = Module(new CVT64(64, false)) 
+ fcvt.io.sew := io.sew + fcvt.io.fire := io.fire + fcvt.io.src := io.src + fcvt.io.rm := io.rm + fcvt.io.opType := io.opType + fcvt.io.rm := io.rm + fcvt.io.isFpToVecInst := io.isFpToVecInst + fcvt.io.input1H := input1H + fcvt.io.output1H := output1H + + io.fflags := fcvt.io.fflags + io.result := fcvt.io.result + +} diff --git a/src/main/scala/yunsuan/scalar/FPU.scala b/src/main/scala/yunsuan/scalar/FPU.scala new file mode 100644 index 0000000..c5a17f2 --- /dev/null +++ b/src/main/scala/yunsuan/scalar/FPU.scala @@ -0,0 +1,40 @@ +package yunsuan.scalar + +import chisel3._ +import chisel3.util._ + +object FPU { + + case class FType(expWidth: Int, precision: Int) { + val sigWidth = precision - 1 + val len = expWidth + precision + } + + val f16 = FType(5, 11) + val f32 = FType(8, 24) + val f64 = FType(11, 53) + + val ftypes = List(f16, f32, f64) + + val H = ftypes.indexOf(f16).U(log2Ceil(ftypes.length).W) + val S = ftypes.indexOf(f32).U(log2Ceil(ftypes.length).W) + val D = ftypes.indexOf(f64).U(log2Ceil(ftypes.length).W) + + + def box(x: UInt, typeTag: UInt): UInt = { + require(x.getWidth == 64) + Mux(typeTag === D, x, Mux(typeTag === S, Cat(~0.U(32.W), x(31, 0)), Cat(~0.U(48.W), x(15, 0)))) + } + + def box(x: UInt, t: FType): UInt = { + if(t == f32){ + Cat(~0.U(32.W), x(31, 0)) + } else if(t == f64){ + x(63, 0) + } else { + assert(cond = false, "Unknown ftype!") + 0.U + } + } + +} diff --git a/src/main/scala/yunsuan/scalar/IntToFP.scala b/src/main/scala/yunsuan/scalar/IntToFP.scala new file mode 100644 index 0000000..78ecbb0 --- /dev/null +++ b/src/main/scala/yunsuan/scalar/IntToFP.scala @@ -0,0 +1,122 @@ +package yunsuan.scalar + +import chisel3._ +import chisel3.util._ +import yunsuan.vector.VectorConvert.RoundingModle._ +class IntToFP_prenorm_in extends Bundle { + val int = Input(UInt(64.W)) + val sign = Input(Bool()) + val long = Input(Bool()) +} + +class IntToFP_prenorm_out extends Bundle { + val norm_int = Output(UInt(63.W)) + val lzc = Output(UInt(6.W)) + val 
is_zero = Output(Bool()) + val sign = Output(Bool()) +} + +/** + * different fp types can share this unit + */ +class IntToFP_prenorm extends Module { + + val io = IO(new Bundle() { + val in = new IntToFP_prenorm_in + val out = new IntToFP_prenorm_out + }) + + val (in, signed_int, long_int) = (io.in.int, io.in.sign, io.in.long) + + val in_sign = signed_int && Mux(long_int, in(63), in(31)) + val in_sext = Cat(Fill(32, in(31)), in(31, 0)) + val in_raw = Mux(signed_int && !long_int, in_sext, in) + val in_abs = Mux(in_sign, (~in_raw).asUInt + 1.U, in_raw) + + val lza = Module(new LZA(64)) + lza.io.a := 0.U + lza.io.b := ~in_raw + + val pos_lzc = CLZ(in) + val neg_lzc = CLZ(lza.io.f) + + val lzc = Mux(in_sign, neg_lzc, pos_lzc) + + // eg: 001010 => 001000 + val one_mask = Cat((0 until 64).reverseMap { + case i @ 63 => lza.io.f(i) + case i @ 0 => !lza.io.f(63, i + 1).orR + case i => lza.io.f(i) && !lza.io.f(63, i + 1).orR + }) + + val lzc_error = Mux(in_sign, !(in_abs & one_mask).orR, false.B) + + val in_shift_s1 = (in_abs << lzc)(62, 0) + val in_norm = Mux(lzc_error, Cat(in_shift_s1.tail(1), 0.U(1.W)), in_shift_s1) + + io.out.norm_int := in_norm + io.out.lzc := lzc + lzc_error + io.out.is_zero := in === 0.U + io.out.sign := in_sign + +} + +class IntToFP_postnorm(val expWidth: Int, val precision: Int) extends Module { + val io = IO(new Bundle() { + val in = Flipped(new IntToFP_prenorm_out) + val rm = Input(UInt(3.W)) + val result = Output(UInt((expWidth + precision).W)) + val fflags = Output(UInt(5.W)) + }) + val (in, lzc, is_zero, in_sign, rm) = + (io.in.norm_int, io.in.lzc, io.in.is_zero, io.in.sign, io.rm) + + val exp_raw = (63 + FloatPoint.expBias(expWidth)).U(11.W) - lzc + val sig_raw = in.head(precision - 1) // exclude hidden bit + val round_bit = in.tail(precision - 1).head(1) + val sticky_bit = in.tail(precision).orR + + val rounder = Module(new RoundingUnit(precision - 1)) + rounder.io.in := sig_raw + rounder.io.roundIn := round_bit + rounder.io.stickyIn := 
sticky_bit + rounder.io.signIn := in_sign + rounder.io.rm := rm + + val rmin = + io.rm === RTZ || (in_sign && io.rm === RUP) || (!in_sign && io.rm === RDN) + + val nv, dz, of, uf, nx = Wire(Bool()) + val ix = rounder.io.inexact + val fp_exp = Mux(is_zero, 0.U, exp_raw + rounder.io.cout) + val fp_sig = rounder.io.out + val flow = fp_exp>((1< ((r && s) || (r && !s && g)), + RTZ -> false.B, + RUP -> (inexact & !io.signIn), + RDN -> (inexact & io.signIn), + RMM -> r + ) + ) + val out_r_up = io.in + 1.U + io.out := Mux(r_up, out_r_up, io.in) + io.inexact := inexact + // r_up && io.in === 111...1 + io.cout := r_up && io.in.andR + io.r_up := r_up +} + +object RoundingUnit { + def apply(in: UInt, rm: UInt, sign: Bool, width: Int): RoundingUnit = { + require(in.getWidth >= width) + val in_pad = if(in.getWidth < width + 2) padd_tail(in, width + 2) else in + val rounder = Module(new RoundingUnit(width)) + rounder.io.in := in_pad.head(width) + rounder.io.roundIn := in_pad.tail(width).head(1).asBool + rounder.io.stickyIn := in_pad.tail(width + 1).orR + rounder.io.rm := rm + rounder.io.signIn := sign + rounder + } + def padd_tail(x: UInt, w: Int): UInt = Cat(x, 0.U((w - x.getWidth).W)) + def is_rmin(rm: UInt, sign: Bool): Bool = { + rm === RTZ || (rm === RDN && !sign) || (rm === RUP && sign) + } +} + diff --git a/src/main/scala/yunsuan/scalar/utils.scala b/src/main/scala/yunsuan/scalar/utils.scala new file mode 100644 index 0000000..9df9bb8 --- /dev/null +++ b/src/main/scala/yunsuan/scalar/utils.scala @@ -0,0 +1,76 @@ +package yunsuan.scalar + +import chisel3._ +import chisel3.util._ + + +object FloatPoint { + def expBias(expWidth: Int): BigInt = { + (BigInt(1) << (expWidth - 1)) - 1 + } + def maxNormExp(expWidth: Int): BigInt = { + (BigInt(1) << expWidth) - 2 + } + +} +object SignExt { + def apply(a: UInt, len: Int): UInt = { + val aLen = a.getWidth + val signBit = a(aLen-1) + if (aLen >= len) a(len-1,0) else Cat(Fill(len - aLen, signBit), a) + } +} + +object ZeroExt { + def 
apply(a: UInt, len: Int): UInt = { + val aLen = a.getWidth + if (aLen >= len) a(len-1,0) else Cat(0.U((len - aLen).W), a) + } +} + +class LzaIO(val len: Int) extends Bundle { + val a, b = Input(UInt(len.W)) + val f = Output(UInt(len.W)) +} + +class LZA(len: Int) extends Module { + val io = IO(new LzaIO(len)) + + val (a, b) = (io.a, io.b) + + val p, k, f = Wire(Vec(len, Bool())) + for (i <- 0 until len) { + p(i) := a(i) ^ b(i) + k(i) := (!a(i)) & (!b(i)) + if (i == 0) { + f(i) := false.B + } else { + f(i) := p(i) ^ (!k(i - 1)) + } + } + io.f := Cat(f.reverse) +} + +class CLZ(len: Int, zero: Boolean) extends Module { + + val inWidth = len + val outWidth = (inWidth - 1).U.getWidth + + val io = IO(new Bundle() { + val in = Input(UInt(inWidth.W)) + val out = Output(UInt(outWidth.W)) + }) + + io.out := PriorityEncoder(io.in.asBools.reverse) +} + +object CLZ { + def apply(value: UInt): UInt = { + val clz = Module(new CLZ(value.getWidth, true)) + clz.io.in := value + clz.io.out + } + def apply(xs: Seq[Bool]): UInt = { + apply(Cat(xs.reverse)) + } +} \ No newline at end of file diff --git a/src/main/scala/yunsuan/vector/VectorConvert/CVT64.scala b/src/main/scala/yunsuan/vector/VectorConvert/CVT64.scala index ab08fb9..d7b605d 100644 --- a/src/main/scala/yunsuan/vector/VectorConvert/CVT64.scala +++ b/src/main/scala/yunsuan/vector/VectorConvert/CVT64.scala @@ -6,35 +6,48 @@ import yunsuan.vector.VectorConvert.util._ import yunsuan.vector.VectorConvert.utils._ import yunsuan.vector.VectorConvert.RoundingModle._ import yunsuan.util._ - -class CVT64(width: Int = 64) extends CVT(width){ - - //parameter - val fpParamMap = Seq(f16, f32, f64) - val biasDeltaMap = Seq(f32.bias - f16.bias, f64.bias - f32.bias) - val intParamMap = (0 to 3).map(i => (1 << i) * 8) - val widthExpAdder = 13 // 13bits is enough - - // input - val (fire, src, sew, opType, rmNext, input1H, output1H, isFpToVecInst) = - (io.fire, io.src, io.sew, io.opType, io.rm, io.input1H, io.output1H, io.isFpToVecInst) +class 
CVT64(width: Int = 64,mode: Boolean) extends CVT(width){ + val (fire, src, sew, opType, rm, input1H, output1H, isFpToVecInst) = + (io.fire, io.src, io.sew, io.opType, io.rm, io.input1H, io.output1H, io.isFpToVecInst) val fireReg = GatedValidRegNext(fire) - // control for cycle 0 - val isWiden = !opType(4) && opType(3) - val isNarrow = opType(4) && !opType(3) - val inIsFpNext = opType.head(1).asBool val outIsFpNext = opType.tail(1).head(1).asBool val hasSignIntNext = opType(0).asBool + val inIsFpNext = opType.head(1).asBool + val isWiden = !opType(4) && opType(3) + val isNarrow = opType(4) && !opType(3) + val outIsF16 = outIsFpNext && output1H(1) + val outIsF64 = outIsFpNext && output1H(3) + val isCrossHigh = opType(4) && opType(3) && outIsF64 + val isCrossLow = opType(4) && opType(3) && outIsF16 + val isEstimate7Next = opType(5) + + val (isInt2FpNext, isFpWidenNext, isFpNarrowNext, isFp2IntNext, isFpCrossHighNext, isFpCrossLowNext) = + (!inIsFpNext, inIsFpNext && outIsFpNext && isWiden, inIsFpNext && outIsFpNext && isNarrow, + !outIsFpNext, inIsFpNext && outIsFpNext && isCrossHigh, inIsFpNext && outIsFpNext && isCrossLow) + + val isInt2Fp = RegEnable(isInt2FpNext, false.B, fire) + val isFpWiden = RegEnable(isFpWidenNext, false.B, fire) + val isFpNarrow = RegEnable(isFpNarrowNext, false.B, fire) + val isEstimate7 = RegEnable(isEstimate7Next, false.B, fire) + val isFp2Int = RegEnable(isFp2IntNext, false.B, fire) + val isFpCrossHigh = RegEnable(isFpCrossHighNext, false.B, fire) + val isFpCrossLow = RegEnable(isFpCrossLowNext, false.B, fire) + val isFPsrc = isFpWiden || isFpNarrow || isFpCrossHigh || isFpCrossLow || isFp2Int val s0_outIsF64 = outIsFpNext && output1H(3) + val s0_outIsF32 = outIsFpNext && output1H(2) + val s0_outIsF16 = outIsFpNext && output1H(1) val s0_outIsU32 = !outIsFpNext && output1H(2) && !hasSignIntNext val s0_outIsS32 = !outIsFpNext && output1H(2) && hasSignIntNext val s0_outIsU64 = !outIsFpNext && output1H(3) && !hasSignIntNext val s0_outIsS64 = 
!outIsFpNext && output1H(3) && hasSignIntNext val s0_fpCanonicalNAN = isFpToVecInst & inIsFpNext & (input1H(1) & !src.head(48).andR | input1H(2) & !src.head(32).andR) + val s1_outIsFP = RegEnable(outIsFpNext, fire) val s1_outIsF64 = RegEnable(s0_outIsF64, fire) + val s1_outIsF32 = RegEnable(s0_outIsF32, fire) + val s1_outIsF16 = RegEnable(s0_outIsF16, fire) val s1_outIsU32 = RegEnable(s0_outIsU32, fire) val s1_outIsS32 = RegEnable(s0_outIsS32, fire) val s1_outIsU64 = RegEnable(s0_outIsU64, fire) @@ -42,232 +55,214 @@ class CVT64(width: Int = 64) extends CVT(width){ val s1_fpCanonicalNAN = RegEnable(s0_fpCanonicalNAN, fire) val s2_outIsF64 = RegEnable(s1_outIsF64, fireReg) - val s2_outIsU32 = RegEnable(s1_outIsU32, fireReg) - val s2_outIsS32 = RegEnable(s1_outIsS32, fireReg) - val s2_outIsU64 = RegEnable(s1_outIsU64, fireReg) - val s2_outIsS64 = RegEnable(s1_outIsS64, fireReg) + val s2_outIsFP = RegEnable(s1_outIsFP, fireReg) val s2_fpCanonicalNAN = RegEnable(s1_fpCanonicalNAN, fireReg) + //inst FPTOINT and FPTOFP module + val fpcvt = Module(new FP_INCVT) + fpcvt.io.fire := fire + fpcvt.io.src := src + fpcvt.io.rm := rm + fpcvt.io.opType := opType + fpcvt.io.input1H := input1H + fpcvt.io.output1H := output1H + fpcvt.io.isFpToVecInst := isFpToVecInst - val int1HSrcNext = input1H - val float1HSrcNext = input1H.head(3)//exclude f8 + val s1_resultForfpCanonicalNAN = Mux1H( + Seq(s1_outIsF64, s1_outIsF32, s1_outIsF16, s1_outIsU32 || s1_outIsU64, s1_outIsS32, s1_outIsS64), + Seq(~0.U((f64.expWidth+1).W) ## 0.U((f64.fracWidth-1).W), + ~0.U((f32.expWidth+1).W) ## 0.U((f32.fracWidth-1).W), + ~0.U((f16.expWidth+1).W) ## 0.U((f16.fracWidth-1).W), + ~0.U(64.W), + ~0.U(31.W), + ~0.U(63.W)) + ) + val s2_resultForfpCanonicalNAN = RegEnable(s1_resultForfpCanonicalNAN, fireReg) + if(mode){//inst INTTOFP and ESTMATE module + val int2fp = Module(new INT2FP) + int2fp.io.fire := fire + int2fp.io.src := src + int2fp.io.rm := rm + int2fp.io.opType := opType + int2fp.io.input1H := input1H 
+ int2fp.io.output1H := output1H + val estmate7 = Module(new Estimate7) + estmate7.io.fire := fire + estmate7.io.src := src + estmate7.io.rm := rm + estmate7.io.opType := opType + estmate7.io.input1H := input1H + estmate7.io.output1H := output1H + //result + val type1H = Cat(isInt2FpNext, isFPsrc, isEstimate7Next).asBools.reverse + val result = Mux1H(type1H, Seq( + int2fp.io.result, + fpcvt.io.result, + estmate7.io.result + )) + val fflags = Mux1H(type1H, Seq( + int2fp.io.fflags, + fpcvt.io.fflags, + estmate7.io.fflags + )) + io.result := Mux(s2_fpCanonicalNAN, s2_resultForfpCanonicalNAN, result) + io.fflags := Mux(s2_fpCanonicalNAN && !s2_outIsFP, "b10000".U, fflags) + }else{ + val result = fpcvt.io.result + val fflags = fpcvt.io.fflags + io.result := Mux(s2_fpCanonicalNAN, s2_resultForfpCanonicalNAN, result) + io.fflags := Mux(s2_fpCanonicalNAN && !s2_outIsFP, "b10000".U, fflags) + } +} +class CVT_IO extends Bundle{ + val fire = Input(Bool()) + val src = Input(UInt(64.W)) + val opType = Input(UInt(8.W)) + val rm = Input(UInt(3.W)) + val input1H = Input(UInt(4.W)) + val output1H = Input(UInt(4.W)) + val isFpToVecInst = Input(Bool()) + val result = Output(UInt(64.W)) + val fflags = Output(UInt(5.W)) +} +class INTCVT_IO extends Bundle{ + val fire = Input(Bool()) + val src = Input(UInt(64.W)) + val opType = Input(UInt(8.W)) + val rm = Input(UInt(3.W)) + val input1H = Input(UInt(4.W)) + val output1H = Input(UInt(4.W)) + val result = Output(UInt(64.W)) + val fflags = Output(UInt(5.W)) +} +class FP_INCVT extends Module { + val io = IO(new CVT_IO) + //parameter + val fpParamMap = Seq(f16, f32, f64) + val biasDeltaMap = Seq(f32.bias - f16.bias, f64.bias - f32.bias, f64.bias - f16.bias) + val intParamMap = (0 to 3).map(i => (1 << i) * 8) + val widthExpAdder = 13 // 13bits is enough + //input + val (fire, src, opType, rmNext, input1H, output1H, isFpToVecInst) = + (io.fire, io.src, io.opType, io.rm, io.input1H, io.output1H, io.isFpToVecInst) + val fireReg = 
GatedValidRegNext(fire) - val int1HOutNext = output1H + val isWiden = !opType(4) && opType(3) + val isNarrow = opType(4) && !opType(3) + val outIsFpNext = opType.tail(1).head(1).asBool + val outIsF16 = outIsFpNext && output1H(1) + val outIsF64 = outIsFpNext && output1H(3) + val isCrossHigh = opType(4) && opType(3) && outIsF64 + val isCrossLow = opType(4) && opType(3) && outIsF16 + val hasSignIntNext = opType(0).asBool + val float1HSrcNext = input1H.head(3)//exclude f8 val float1HOutNext = output1H.head(3)//exclude f8 + //fp input extend val srcMap = (0 to 3).map(i => src((1 << i) * 8 - 1, 0)) - val intMap = srcMap.map(int => intExtend(int, hasSignIntNext && int.head(1).asBool)) - val floatMap = srcMap.zipWithIndex.map{case (float,i) => floatExtend(float, i)}.drop(1) - val input = Mux(inIsFpNext, - Mux1H(float1HSrcNext, floatMap), - Mux1H(int1HSrcNext, intMap) - ) - - val signSrcNext = input.head(1).asBool - - // src is int - val absIntSrcNext = Wire(UInt(64.W)) //cycle0 - absIntSrcNext := Mux(signSrcNext, (~input.tail(1)).asUInt + 1.U, input.tail(1)) - val isZeroIntSrcNext = !absIntSrcNext.orR - - /** src is float contral path - * special: +/- INF, NaN, qNaN, SNaN, 0, Great NF, canonical NaN - */ + val input = Mux1H(float1HSrcNext, floatMap) val expSrcNext = input.tail(1).head(f64.expWidth) val fracSrc = input.tail(f64.expWidth+1).head(f64.fracWidth) + val signSrcNext = input.head(1).asBool val decodeFloatSrc = Mux1H(float1HSrcNext, fpParamMap.map(fp => - VecInit(expSrcNext(fp.expWidth-1,0).orR, expSrcNext(fp.expWidth-1,0).andR, fracSrc.head(fp.fracWidth).orR).asUInt - ) + VecInit(expSrcNext(fp.expWidth-1,0).orR, expSrcNext(fp.expWidth-1,0).andR, fracSrc.head(fp.fracWidth).orR).asUInt + ) ) - val (expNotZeroSrcNext, expIsOnesSrcNext, fracNotZeroSrcNext) = (decodeFloatSrc(0), decodeFloatSrc(1), decodeFloatSrc(2)) - val expIsZeroSrcNext = !expNotZeroSrcNext - val fracIsZeroSrcNext = !fracNotZeroSrcNext - val isSubnormalSrcNext = expIsZeroSrcNext && fracNotZeroSrcNext 
- val isnormalSrcNext = !expIsOnesSrcNext && !expIsZeroSrcNext - val isInfSrcNext = expIsOnesSrcNext && fracIsZeroSrcNext - val isZeroSrcNext = expIsZeroSrcNext && fracIsZeroSrcNext val isNaNSrcNext = expIsOnesSrcNext && fracNotZeroSrcNext + val isZeroSrcNext = !expNotZeroSrcNext && !fracNotZeroSrcNext + val isSubnormalSrcNext = !expNotZeroSrcNext && fracNotZeroSrcNext + val isnormalSrcNext = !expIsOnesSrcNext && expNotZeroSrcNext + val isInfSrcNext = expIsOnesSrcNext && !fracNotZeroSrcNext val isSNaNSrcNext = isNaNSrcNext && !fracSrc.head(1) - val isQNaNSrcNext = isNaNSrcNext && fracSrc.head(1).asBool - // for sqrt7/rec7 - val isEstimate7Next = opType(5) - val isRecNext = opType(5) && opType(0) + val (isFpWidenNext, isFpNarrowNext, isFp2IntNext, isFpCrossHighNext, isFpCrossLowNext) = + (outIsFpNext && isWiden, outIsFpNext && isNarrow, !outIsFpNext, + outIsFpNext && isCrossHigh, outIsFpNext && isCrossLow) - val decodeFloatSrcRec = Mux1H(float1HSrcNext, - fpParamMap.map(fp => expSrcNext(fp.expWidth - 1, 0)).zip(fpParamMap.map(fp => fp.expWidth)).map { case (exp, expWidth) => - VecInit( - exp.head(expWidth-1).andR && !exp(0), - exp.head(expWidth-2).andR && !exp(1) && exp(0) - ).asUInt - } - ) - - val (isNormalRec0Next, isNormalRec1Next) = (decodeFloatSrcRec(0), decodeFloatSrcRec(1)) - val isNormalRec2Next = expNotZeroSrcNext && !expIsOnesSrcNext && !isNormalRec0Next && !isNormalRec1Next - val isSubnormalRec0Next = isSubnormalSrcNext && fracSrc.head(1).asBool - val isSubnormalRec1Next = isSubnormalSrcNext && !fracSrc.head(1) && fracSrc.tail(1).head(1).asBool - val isSubnormalRec2Next = isSubnormalSrcNext && !fracSrc.head(2).orR - - // type int->fp, fp->fp widen, fp->fp Narrow, fp->int - val (isInt2FpNext, isFpWidenNext, isFpNarrowNext, isFp2IntNext) = - (!inIsFpNext, inIsFpNext && outIsFpNext && isWiden, inIsFpNext && outIsFpNext && isNarrow, !outIsFpNext) - - //contral sign to cycle1 - val expNotZeroSrc = RegEnable(expNotZeroSrcNext, false.B, fire) + //s1 val 
expIsOnesSrc = RegEnable(expIsOnesSrcNext, false.B, fire) val fracNotZeroSrc = RegEnable(fracNotZeroSrcNext, false.B, fire) - val expIsZeroSrc = RegEnable(expIsZeroSrcNext, false.B, fire) - val fracIsZeroSrc = RegEnable(fracIsZeroSrcNext, false.B, fire) - val isSubnormalSrc = RegEnable(isSubnormalSrcNext, false.B, fire) - val isnormalSrc = RegEnable(isnormalSrcNext, false.B, fire) val isInfSrc = RegEnable(isInfSrcNext, false.B, fire) val isZeroSrc = RegEnable(isZeroSrcNext, false.B, fire) - val isNaNSrc = RegEnable(isNaNSrcNext, false.B, fire) + val isSubnormalSrc = RegEnable(isSubnormalSrcNext, false.B, fire) + val isnormalSrc = RegEnable(isnormalSrcNext, false.B, fire) val isSNaNSrc = RegEnable(isSNaNSrcNext, false.B, fire) - val isQNaNSrc = RegEnable(isQNaNSrcNext, false.B, fire) - val isNormalRec0 = RegEnable(isNormalRec0Next, false.B, fire) - val isNormalRec1 = RegEnable(isNormalRec1Next, false.B, fire) - val isNormalRec2 = RegEnable(isNormalRec2Next, false.B, fire) - val isSubnormalRec0 = RegEnable(isSubnormalRec0Next, false.B, fire) - val isSubnormalRec1 = RegEnable(isSubnormalRec1Next, false.B, fire) - val isSubnormalRec2 = RegEnable(isSubnormalRec2Next, false.B, fire) - val isRec = RegEnable(isRecNext, false.B, fire) - - val isInt2Fp = RegEnable(isInt2FpNext, false.B, fire) val isFpWiden = RegEnable(isFpWidenNext, false.B, fire) val isFpNarrow = RegEnable(isFpNarrowNext, false.B, fire) - val isEstimate7 = RegEnable(isEstimate7Next, false.B, fire) val isFp2Int = RegEnable(isFp2IntNext, false.B, fire) + val isFpCrossHigh = RegEnable(isFpCrossHighNext, false.B, fire) + val isFpCrossLow = RegEnable(isFpCrossLowNext, false.B, fire) + val isNaNSrc = RegEnable(isNaNSrcNext, false.B, fire) + val s0_fpCanonicalNAN = isFpToVecInst & (input1H(1) & !src.head(48).andR | input1H(2) & !src.head(32).andR) + val s1_fpCanonicalNAN = RegEnable(s0_fpCanonicalNAN, fire) // for fpnarrow sub val trunSticky = RegEnable(fracSrc.tail(f32.fracWidth).orR, false.B, fire) - val signSrc 
= RegEnable(signSrcNext, false.B, fire) val rm = RegEnable(rmNext, 0.U(3.W), fire) - val hasSignInt = RegEnable(hasSignIntNext, false.B, fire) - val isZeroIntSrc = RegEnable(isZeroIntSrcNext, false.B, fire) val signNonNan = !isNaNSrc && signSrc - - /** critical path - * share: - * 1.count leading zero Max is 64 - * 2.adder: exp - * 3.shift left/right(UInt) - * 4.rounding module: +1(exp & roundInput) - * 5.Mux/Mux1H - * - * general step: - * step1: clz + adder -> compute really exp - * step2: shift left/right -> put the first one to the correct position - * step3: rounding - * step4: select result and fflags by mux1H - * - * pipe: - * cycle0: adder:64bits -> 13bits adder -> sl 6bits/rl 7bits - * cycle1: adder:64bits -> adder: 64bits -> Mux/Mux1H - * cycle2: result/fflags - * | exp adder | - * int->fp: abs(adder) | -> sl -> rounding(adder) -> Mux/Mux1H | - * fpwiden: fpdecode | -> Mux/Mux1H | - * fpNarrow(nor): fpdecode | -> sl -> rounding(adder) --\ | - * fpNarrow(sub): fpdecode -> exp adder2 | -> sr -> rounding(adder) -> Mux/Mux1H | - * estimate7: fpdecode | -> decoder -> Mux | - * fp-> int: fpdecode -> exp adder2 | -> sr -> rounding(adder) -> ~+1 -> Mux/Mux1H | - * | -> result & fflags - */ - - // for cycle1 val output1HReg = RegEnable(output1H, 0.U(4.W), fire) val float1HOut = Wire(UInt(3.W)) float1HOut := output1HReg.head(3) val int1HOut = Wire(UInt(4.W)) int1HOut := output1HReg - //for cycle2 -> output + //output val nv, dz, of, uf, nx = Wire(Bool()) //cycle1 val fflagsNext = Wire(UInt(5.W)) val fflags = RegEnable(fflagsNext, 0.U(5.W), fireReg) val resultNext = Wire(UInt(64.W)) val result = RegEnable(resultNext, 0.U(64.W), fireReg) - /** clz - * for: int->fp, fp->fp widen, estimate7, reuse clz according to fracSrc << (64 - f64.fracWidth) - * cycle: 0 - */ - val clzIn = Mux(inIsFpNext, fracSrc<<(64 - f64.fracWidth), absIntSrcNext).asUInt - val leadZerosNext = CLZ(clzIn) - - /** exp adder - * for: all exp compute - * cycle: 1 - */ - - val type1H = 
Cat(isInt2FpNext, isFpWidenNext, isFpNarrowNext, isEstimate7Next, isFp2IntNext).asBools.reverse + //exp val expAdderIn0Next = Wire(UInt(widthExpAdder.W)) //13bits is enough val expAdderIn1Next = Wire(UInt(widthExpAdder.W)) val expAdderIn0 = RegEnable(expAdderIn0Next, fire) val expAdderIn1 = RegEnable(expAdderIn1Next, fire) + val exp = Wire(UInt(widthExpAdder.W)) + exp := expAdderIn0 + expAdderIn1 - val biasDelta = Mux1H(float1HOutNext.tail(1), biasDeltaMap.map(delta => delta.U)) + val leadZerosNext = CLZ((fracSrc<<(64 - f64.fracWidth)).asUInt) + val type1H = Cat(isFpWidenNext, isFpCrossHighNext, isFpNarrowNext || isFpCrossLowNext,isFp2IntNext).asBools.reverse + expAdderIn0Next := Mux1H(type1H, Seq( + Mux1H(float1HOutNext.head(2), biasDeltaMap.take(2).map(delta => delta.U)), + biasDeltaMap(2).U, + Mux(isSubnormalSrcNext, false.B ## 1.U, false.B ## expSrcNext), + Mux(isSubnormalSrcNext, false.B ## 1.U, false.B ## expSrcNext) + ) + ) + val biasDelta = Mux1H(float1HOutNext.tail(1), biasDeltaMap.take(2).map(delta => delta.U)) val bias = Mux1H(float1HSrcNext, fpParamMap.map(fp => fp.bias.U)) val minusExp = extend((~(false.B ## Mux1H( - Cat(isInt2FpNext || isFpWidenNext, isFpNarrowNext, isEstimate7Next, isFp2IntNext).asBools.reverse, + Cat(isFpWidenNext || isFpCrossHighNext, isFpNarrowNext, isFpCrossLowNext, isFp2IntNext).asBools.reverse, Seq( leadZerosNext, biasDelta, - expSrcNext, + biasDeltaMap(2).U, bias )))).asUInt + 1.U, widthExpAdder).asUInt - - expAdderIn0Next := Mux1H(type1H, Seq( - Mux1H(float1HOutNext, fpParamMap.map(fp => (fp.bias + 63).U)), - Mux1H(float1HOutNext.head(2), biasDeltaMap.map(delta => delta.U)), - Mux(isSubnormalSrcNext, false.B ## 1.U, false.B ## expSrcNext), - Mux1H(float1HOutNext, fpParamMap.map(fp => Mux(isRecNext, (2 * fp.bias - 1).U, (3 * fp.bias - 1).U))), - Mux(isSubnormalSrcNext, false.B ## 1.U, false.B ## expSrcNext) - ) - ) - expAdderIn1Next := Mux1H( - Cat(isInt2FpNext || isFpNarrowNext || isFp2IntNext, isFpWidenNext, 
isEstimate7Next).asBools.reverse, + Cat(isFpNarrowNext || isFp2IntNext || isFpCrossLowNext, isFpWidenNext || isFpCrossHighNext).asBools.reverse, Seq( minusExp, Mux(isSubnormalSrcNext, minusExp, expSrcNext), - Mux(isSubnormalSrcNext, leadZerosNext, minusExp), ) ) - val exp = Wire(UInt(widthExpAdder.W)) - exp := expAdderIn0 + expAdderIn1 - - // for estimate7 - val expNormaled = Mux(isSubnormalSrcNext, leadZerosNext(0), expSrcNext(0)) //only the last bit is needed - val expNormaled0 = RegEnable(expNormaled(0), false.B, fire) - - /** shift left - * for: int->fp, fp->fp widen, estimate7, reuse shift left according to fracSrc << (64 - f64.fracWidth) - * cycle: 1 - * - */ + //frac val fracSrcLeftNext = Wire(UInt(64.W)) fracSrcLeftNext := fracSrc << (64 - f64.fracWidth) - val inIsFp = RegEnable(inIsFpNext, false.B, fire) - val fracSrcLeft = RegEnable(fracSrcLeftNext, 0.U(64.W), fire) - val absIntSrc = RegEnable(absIntSrcNext, fire) val leadZeros = RegEnable(leadZerosNext, fire) - + val fracSrcLeft = RegEnable(fracSrcLeftNext, 0.U(64.W), fire) val shiftLeft = Wire(UInt(64.W)) - shiftLeft := (Mux(inIsFp, fracSrcLeft, absIntSrc).asUInt << 1) << leadZeros //cycle1 - // for estimate7 & fp->fp widen + shiftLeft := (fracSrcLeft.asUInt << 1) << leadZeros //cycle1 val fracNormaled = Wire(UInt(64.W)) fracNormaled := Mux(isSubnormalSrc, shiftLeft, fracSrcLeft) //cycle1 - /** shift right * for: fp->fp Narrow, fp->int * cycle: 1 @@ -278,8 +273,8 @@ class CVT64(width: Int = 64) extends CVT(width){ val fracValueSrc = (expNotZeroSrcNext && !expIsOnesSrcNext) ## fracSrc val shamtInNext = fracValueSrc ## 0.U(11.W) ## false.B //fp Narrow & fp->int val shamtWidth = Mux(!outIsFpNext, Mux1H(float1HSrcNext, fpParamMap.map(fp => (63+fp.bias).U)), - Mux1H(float1HOutNext.tail(1), biasDeltaMap.map(delta => (delta + 1).U)) - ) - expSrcNext + Mux(isCrossLow, (biasDeltaMap(2) + 1).U, Mux1H(float1HOutNext.tail(1), biasDeltaMap.take(2).map(delta => (delta + 1).U))) + ) - expSrcNext val shamtNext = 
Mux(shamtWidth >= 65.U, 65.U, shamtWidth) val shamtIn = RegEnable(shamtInNext, fire) @@ -289,14 +284,12 @@ class CVT64(width: Int = 64) extends CVT(width){ val sticky = Wire(Bool()) inRounder := inRounderTmp sticky := stickyTmp - - /** rounder * for: int->fp, fp-fp Narrow, fp->int * cycle: 1 */ val rounderMapIn = Wire(UInt(64.W)) - rounderMapIn := Mux(isFpNarrow, fracSrcLeft, shiftLeft) + rounderMapIn := Mux(isFpNarrow || isFpCrossLow, fracSrcLeft, shiftLeft) val rounderMap = fpParamMap.map(fp => Seq( @@ -310,10 +303,7 @@ class CVT64(width: Int = 64) extends CVT(width){ val (rounderInputMap, rounerInMap, rounderStikyMap, isOnesRounderInputMap) = { (rounderMap(0), rounderMap(1), rounderMap(2), rounderMap(3)) } - val rounderInput = Mux(isFp2Int, inRounder.head(64), Mux1H(float1HOut, rounderInputMap)) - - val rounder = Module(new RoundingUnit(64)) rounder.io.in := rounderInput rounder.io.roundIn := Mux(isFp2Int, inRounder(0), Mux1H(float1HOut, rounerInMap)) @@ -324,21 +314,13 @@ class CVT64(width: Int = 64) extends CVT(width){ // from rounder val nxRounded = rounder.io.inexact val upRounded = rounder.io.r_up - - /** after rounding - * for all exclude estimate7 & fp->fp widen - * cycle: 1 - */ val expIncrease = exp + 1.U val rounderInputIncrease = rounderInput + 1.U - // for fp2int // 8bit: => u64, i64, u32, i32, u16, i16, u8, i8 val hasSignInt1HOut = int1HOut.asBools.map(oh => Seq(oh && !hasSignInt, oh && hasSignInt)).flatten val isOnesRounderInputMapFp2Int = intParamMap.map(intType => Seq(intType, intType - 1)).flatten.map(intType => rounderInput.tail(64 - intType).andR) - - // for all val cout = upRounded && Mux(isFp2Int, Mux1H(hasSignInt1HOut, isOnesRounderInputMapFp2Int), Mux1H(float1HOut, isOnesRounderInputMap) @@ -346,52 +328,10 @@ class CVT64(width: Int = 64) extends CVT(width){ val expRounded = Wire(UInt(f64.expWidth.W)) expRounded := Mux(cout, expIncrease, exp) val fracRounded = Mux(upRounded, rounderInputIncrease, rounderInput) - val rmin = rm === RTZ || 
(signSrc && rm === RUP) || (!signSrc && rm === RDN) //cycle1 - - - /** Mux/Mux1H - * cycle: 1 - */ - when(isInt2Fp){ - /** int->fp any int/uint-> any fp - */ - // Mux(cout, exp > FP.maxExp -1, exp > FP.maxExp) - val ofRounded = !exp.head(1).asBool && Mux1H(float1HOut, - fpParamMap.map(fp => Mux(cout, - exp(fp.expWidth - 1, 1).andR || exp(exp.getWidth - 2, fp.expWidth).orR, - exp(fp.expWidth - 1, 0).andR || exp(exp.getWidth - 2, fp.expWidth).orR) - ) - ) - - nv := false.B - dz := false.B - of := ofRounded - uf := false.B - nx := ofRounded || nxRounded - - val result1H = Cat( - ofRounded && rmin, - ofRounded && !rmin, - isZeroIntSrc, - !ofRounded && !isZeroIntSrc - ) - - def int2FpResultMapGen(fp: FloatFormat): Seq[UInt] = { - VecInit((0 to 3).map { - case 0 => signSrc ## fp.maxExp.U(fp.expWidth.W) ## ~0.U(fp.fracWidth.W) //GNF - case 1 => signSrc ## ~0.U(fp.expWidth.W) ## 0.U(fp.fracWidth.W) // INF - case 2 => signSrc ## 0.U((fp.width - 1).W) // 0 - case 3 => signSrc ## expRounded(fp.expWidth-1, 0) ## fracRounded(fp.fracWidth-1, 0) // normal - }) - } - - val int2FpResultMap: Seq[UInt] = fpParamMap.map(fp => Mux1H(result1H.asBools.reverse, int2FpResultMapGen(fp))) - resultNext := Mux1H(float1HOut, int2FpResultMap) - - }.elsewhen(isFpWiden){ - /** fp -> fp widen + when(isFpWiden || isFpCrossHigh){ + /** fp -> fp widen/ fp16 -> fp64 cross high */ def fpWidenResultMapGen(fp: FloatFormat): Seq[UInt] = { VecInit((0 to 2).map { @@ -400,13 +340,11 @@ class CVT64(width: Int = 64) extends CVT(width){ case 2 => signNonNan ## exp(fp.expWidth - 1, 0) ## fracNormaled.head(fp.fracWidth) }) } - val result1H = Cat( expIsOnesSrc, isZeroSrc, isSubnormalSrc || isnormalSrc ) - nv := isSNaNSrc && !s1_fpCanonicalNAN dz := false.B of := false.B @@ -415,15 +353,11 @@ class CVT64(width: Int = 64) extends CVT(width){ val fpwidenResultMap: Seq[UInt] = Seq(f32, f64).map(fp => Mux1H(result1H.asBools.reverse, fpWidenResultMapGen(fp))) resultNext := Mux1H(float1HOut.head(2), fpwidenResultMap) - - 
}.elsewhen(isFpNarrow){ - /** fp -> fp Narrow + }.elsewhen(isFpNarrow || isFpCrossLow){ + /** fp -> fp Narrow / fp64 -> fp16 cross low * note: IEEE754 uf:exp in (-b^emin, b^emin), after rounding(RiscV!!!) * note: IEEE754 uf:exp in (-b^emin, b^emin), before rounding(other) */ - - /**dest is normal - */ // Mux(cout, exp > FP.maxExp -1, exp > FP.maxExp) val ofRounded = !exp.head(1).asBool && Mux1H(float1HOut, fpParamMap.map(fp => Mux(cout, @@ -431,11 +365,9 @@ class CVT64(width: Int = 64) extends CVT(width){ exp(fp.expWidth - 1, 0).andR || exp(exp.getWidth - 2, fp.expWidth).orR) ) ) - //val ufExpRounded = Mux(cout, interExp < 0.S, interExp < 1.S) val ufExpRounded = Mux(cout, exp.head(1).asBool, exp.head(1).asBool || !exp.orR) val nxOfRounded = nxRounded || ofRounded - /** dest is Subnormal * dest: 1-toBias, src: srcExp - srcBias * src->dest :exp = srcExp - srcBias + toBias @@ -452,7 +384,7 @@ class CVT64(width: Int = 64) extends CVT(width){ subFrac.tail(fp.fracWidth+1).head(1), //1+toFracWidth +1 => drop head & tail trunSticky || shiftSticky || subFrac.tail(fp.fracWidth+2).orR, subFrac.tail(1).head(fp.fracWidth).andR - ) + ) ).transpose val (subRounderInputMap, subRounerInMap, subRounderStikyMap, subIsOnesRounderInputMap) = { @@ -466,25 +398,22 @@ class CVT64(width: Int = 64) extends CVT(width){ subRounder.io.stickyIn := Mux1H(float1HOut.tail(1), subRounderStikyMap) subRounder.io.signIn := signSrc subRounder.io.rm := rm - // from roundingUnit val subNxRounded = subRounder.io.inexact val subUpRounded = subRounder.io.r_up - // out of roundingUint subFracRounded := Mux(subUpRounded, subRounderInput + 1.U, subRounderInput) val subCout = subUpRounded && Mux1H(float1HOut.tail(1), subIsOnesRounderInputMap).asBool subExpRounded := Mux(subCout, 1.U, 0.U) - - nv := isSNaNSrc + nv := isSNaNSrc && !s1_fpCanonicalNAN dz := false.B - of := !expIsOnesSrc && ofRounded - uf := !expIsOnesSrc && maybeSub && ufExpRounded && subNxRounded + of := !expIsOnesSrc && ofRounded && 
!s1_fpCanonicalNAN + uf := !expIsOnesSrc && maybeSub && ufExpRounded && subNxRounded && !s1_fpCanonicalNAN nx := !expIsOnesSrc && ( (!maybeSub && nxOfRounded) || (maybeSub && subNxRounded) - ) + ) && !s1_fpCanonicalNAN val result1H = Cat( expIsOnesSrc, @@ -493,7 +422,6 @@ class CVT64(width: Int = 64) extends CVT(width){ !expIsOnesSrc && !maybeSub && !ofRounded, !expIsOnesSrc && maybeSub ) - def fpNarrowResultMapGen(fp: FloatFormat): Seq[UInt] ={ VecInit((0 to 4).map { case 0 => signNonNan ## ~0.U(fp.expWidth.W) ## fracNotZeroSrc ## 0.U((fp.fracWidth - 1).W) // INF or NaN->QNAN @@ -503,75 +431,15 @@ class CVT64(width: Int = 64) extends CVT(width){ case 4 => signNonNan ## subExpRounded(fp.expWidth - 1, 0) ## subFracRounded(fp.fracWidth - 1, 0) //sub or uf }) } - val fpNarrowResultMap: Seq[UInt] = Seq(f16, f32).map(fp => Mux1H(result1H.asBools.reverse, fpNarrowResultMapGen(fp))) resultNext := Mux1H(float1HOut.tail(1), fpNarrowResultMap) - }.elsewhen(isEstimate7) { - /** Estimate7: sqrt7 & rec7 - */ - - val rsqrt7Table = Module(new Rsqrt7Table) - rsqrt7Table.src := expNormaled0 ## fracNormaled.head(6) - val rec7Table = Module(new Rec7Table) - rec7Table.src := fracNormaled.head(7) - val fracEstimate = Mux(isRec, rec7Table.out, rsqrt7Table.out) - - nv := Mux(isRec, isSNaNSrc, (signSrc && !isZeroSrc && !isQNaNSrc) | isSNaNSrc) - dz := isZeroSrc - of := isRec && isSubnormalRec2 - uf := false.B - nx := of - - def recResultMapGen(fp: FloatFormat): Seq[UInt] = { - VecInit((0 to 6).map { - case 0 => false.B ## ~0.U(fp.expWidth.W) ## true.B ## 0.U((fp.fracWidth - 1).W) //can - case 1 => signSrc ## 0.U((fp.width - 1).W) //0 - case 2 => signSrc ## ~0.U(fp.expWidth.W) ## 0.U(fp.fracWidth.W) //INF - case 3 => signSrc ## 0.U(fp.expWidth.W) ## 1.U(2.W) ## fracEstimate ## 0.U((fp.fracWidth - 2 - 7).W) - case 4 => signSrc ## 0.U(fp.expWidth.W) ## 1.U(1.W) ## fracEstimate ## 0.U((fp.fracWidth - 1 - 7).W) - case 5 => signSrc ## exp(fp.expWidth - 1, 0) ## fracEstimate ## 0.U((fp.fracWidth 
- 7).W) - case 6 => signSrc ## fp.maxExp.U(fp.expWidth.W) ## ~0.U(fp.fracWidth.W) //GNF - }) - } - val recResult1H = Cat( - isNaNSrc, - isInfSrc, - isZeroSrc || isSubnormalRec2 && !rmin, - isNormalRec0, - isNormalRec1, - isNormalRec2 || isSubnormalRec0 || isSubnormalRec1, - isSubnormalRec2 && rmin - ) - val recResultMap: Seq[UInt] = fpParamMap.map(fp => Mux1H(recResult1H.asBools.reverse, recResultMapGen(fp))) - - def sqrtResultMapGen(fp: FloatFormat): Seq[UInt] = { - VecInit((0 to 3).map { - case 0 => false.B ## ~0.U(fp.expWidth.W) ## true.B ## 0.U((fp.fracWidth - 1).W) - case 1 => signSrc ## ~0.U(fp.expWidth.W) ## 0.U(fp.fracWidth.W) - case 2 => signSrc ## exp(fp.expWidth, 1) ## fracEstimate ## 0.U((fp.fracWidth - 7).W) // exp/2 => >>1 - case 3 => 0.U(fp.width.W) - }) - } - val sqrtResult1H = Cat( - signSrc & !isZeroSrc | isNaNSrc, - isZeroSrc, - !signSrc & !isZeroSrc & !expIsOnesSrc, - !signSrc & isInfSrc, - ) - val sqrtResultMap: Seq[UInt] = fpParamMap.map(fp => Mux1H(sqrtResult1H.asBools.reverse, sqrtResultMapGen(fp))) - resultNext := Mux(isRec, Mux1H(float1HOut, recResultMap), Mux1H(float1HOut, sqrtResultMap)) - - }.otherwise {//5 - // !outIsFp + }.otherwise{ /** out is int, any fp->any int/uint * drop the shift left! 
- * todo: detail refactor exclude */ val resultRounded = fracRounded val isZeroRounded = !resultRounded.orR - val normalResult = Mux(signSrc && resultRounded.orR, (~resultRounded).asUInt + 1.U, resultRounded) //exclude 0 - // i=log2(intType) val ofExpRounded = !exp.head(1) && Mux1H(int1HOut, (3 to 6).map(i => @@ -584,10 +452,8 @@ class CVT64(width: Int = 64) extends CVT(width){ ) ) ) - val excludeFrac = Mux1H(int1HOut, intParamMap.map(intType => resultRounded(intType - 1) && !resultRounded(intType - 2, 0).orR)) // 10000***000 - // i=log2(intType) val excludeExp = Mux1H(int1HOut, (3 to 6).map(i => !exp.head(exp.getWidth - i).orR && @@ -597,11 +463,9 @@ class CVT64(width: Int = 64) extends CVT(width){ ) ) ) - val toUnv = ofExpRounded || expIsOnesSrc || signSrc && !(isZeroSrc || isZeroRounded && !ofExpRounded) //exclude 0 & -0 after rounding val toUnx = !toUnv && nxRounded - val toInv = ofExpRounded && !(signSrc && excludeExp && excludeFrac) || expIsOnesSrc //nv has included inf & nan val toInx = !toInv && nxRounded @@ -610,36 +474,300 @@ class CVT64(width: Int = 64) extends CVT(width){ of := false.B uf := false.B nx := Mux(hasSignInt, toInx, toUnx) - - val result1H = Cat( (!hasSignInt && !toUnv) || (hasSignInt && !toInv), //toUnv include nan & inf !hasSignInt && toUnv && (isNaNSrc || !signSrc && (isInfSrc || ofExpRounded)), !hasSignInt && toUnv && signSrc && !isNaNSrc, hasSignInt && toInv ) - resultNext := Mux1H(result1H.asBools.reverse, Seq( - normalResult, - (~0.U(64.W)).asUInt, - 0.U(64.W), - Mux1H(int1HOut, intParamMap.map(intType => signNonNan ## Fill(intType - 1, !signNonNan))) - ) + normalResult, + (~0.U(64.W)).asUInt, + 0.U(64.W), + Mux1H(int1HOut, intParamMap.map(intType => signNonNan ## Fill(intType - 1, !signNonNan))) + ) ) } + fflagsNext := Cat(nv, dz, of, uf, nx) + io.result := result + io.fflags := fflags +} +class INT2FP extends Module{ + val io = IO(new INTCVT_IO) + //parameter + val fpParamMap = Seq(f16, f32, f64) + val biasDeltaMap = Seq(f32.bias - 
f16.bias, f64.bias - f32.bias, f64.bias - f16.bias) + val widthExpAdder = 13 // 13bits is enough + //input + val (fire, src, opType, rmNext, input1H, output1H) = + (io.fire, io.src, io.opType, io.rm, io.input1H, io.output1H) + val fireReg = GatedValidRegNext(fire) + val hasSignIntNext = opType(0).asBool + val int1HSrcNext = input1H + val float1HOutNext = output1H.head(3)//exclude f8 + val output1HReg = RegEnable(output1H, 0.U(4.W), fire) + val float1HOut = Wire(UInt(3.W)) + float1HOut := output1HReg.head(3) + val srcMap = (0 to 3).map(i => src((1 << i) * 8 - 1, 0)) + val intMap = srcMap.map(int => intExtend(int, hasSignIntNext && int.head(1).asBool)) + val input = Mux1H(int1HSrcNext, intMap) + val signSrcNext = input.head(1).asBool + val signSrc = RegEnable(signSrcNext, false.B, fire) + val rm = RegEnable(rmNext, 0.U(3.W), fire) + // src is int + val absIntSrcNext = Wire(UInt(64.W)) //cycle0 + absIntSrcNext := Mux(signSrcNext, (~input.tail(1)).asUInt + 1.U, input.tail(1)) + val isZeroIntSrcNext = !absIntSrcNext.orR + val isZeroIntSrc = RegEnable(isZeroIntSrcNext, false.B, fire) + //CLZ + val clzIn = absIntSrcNext.asUInt + val leadZerosNext = CLZ(clzIn) + //exp + val expAdderIn0Next = Wire(UInt(widthExpAdder.W)) //13bits is enough + val expAdderIn1Next = Wire(UInt(widthExpAdder.W)) + val expAdderIn0 = RegEnable(expAdderIn0Next, fire) + val expAdderIn1 = RegEnable(expAdderIn1Next, fire) + val minusExp = extend((~(false.B ## leadZerosNext)).asUInt + + 1.U, widthExpAdder).asUInt + expAdderIn0Next := Mux1H(float1HOutNext, fpParamMap.map(fp => (fp.bias + 63).U)) + expAdderIn1Next := minusExp + val exp = Wire(UInt(widthExpAdder.W)) + exp := expAdderIn0 + expAdderIn1 + //frac + val absIntSrc = RegEnable(absIntSrcNext, fire) + val leadZeros = RegEnable(leadZerosNext, fire) + val shiftLeft = Wire(UInt(64.W)) + shiftLeft := (absIntSrc.asUInt << 1) << leadZeros //cycle1 + //round + val rounderMapIn = Wire(UInt(64.W)) + rounderMapIn := shiftLeft + val rounderMap = + 
fpParamMap.map(fp => Seq( + rounderMapIn.head(fp.fracWidth), + rounderMapIn.tail(fp.fracWidth).head(1), + rounderMapIn.tail(fp.fracWidth + 1).orR, + rounderMapIn.head(fp.fracWidth).andR + ) + ).transpose + val (rounderInputMap, rounerInMap, rounderStikyMap, isOnesRounderInputMap) = { + (rounderMap(0), rounderMap(1), rounderMap(2), rounderMap(3)) + } + val rounderInput = Mux1H(float1HOut, rounderInputMap) + val rounder = Module(new RoundingUnit(64)) + rounder.io.in := rounderInput + rounder.io.roundIn := Mux1H(float1HOut, rounerInMap) + rounder.io.stickyIn := Mux1H(float1HOut, rounderStikyMap) + rounder.io.signIn := signSrc + rounder.io.rm := rm + val expIncrease = exp + 1.U + val rounderInputIncrease = rounderInput + 1.U + // from rounder + val nxRounded = rounder.io.inexact + val upRounded = rounder.io.r_up + val cout = upRounded && Mux1H(float1HOut, isOnesRounderInputMap).asBool + val expRounded = Wire(UInt(f64.expWidth.W)) + expRounded := Mux(cout, expIncrease, exp) + val fracRounded = Mux(upRounded, rounderInputIncrease, rounderInput) + val rmin = + rm === RTZ || (signSrc && rm === RUP) || (!signSrc && rm === RDN) //cycle1 + /** int->fp any int/uint-> any fp + */ + // Mux(cout, exp > FP.maxExp -1, exp > FP.maxExp) + val ofRounded = !exp.head(1).asBool && Mux1H(float1HOut, + fpParamMap.map(fp => Mux(cout, + exp(fp.expWidth - 1, 1).andR || exp(exp.getWidth - 2, fp.expWidth).orR, + exp(fp.expWidth - 1, 0).andR || exp(exp.getWidth - 2, fp.expWidth).orR) + ) + ) + val nv, dz, of, uf, nx = Wire(Bool()) //cycle1 + val fflagsNext = Wire(UInt(5.W)) + val fflags = RegEnable(fflagsNext, 0.U(5.W), fireReg) + val resultNext = Wire(UInt(64.W)) + val result = RegEnable(resultNext, 0.U(64.W), fireReg) + nv := false.B + dz := false.B + of := ofRounded + uf := false.B + nx := ofRounded || nxRounded + + val result1H = Cat( + ofRounded && rmin, + ofRounded && !rmin, + isZeroIntSrc, + !ofRounded && !isZeroIntSrc + ) + def int2FpResultMapGen(fp: FloatFormat): Seq[UInt] = { + 
VecInit((0 to 3).map { + case 0 => signSrc ## fp.maxExp.U(fp.expWidth.W) ## ~0.U(fp.fracWidth.W) //GNF + case 1 => signSrc ## ~0.U(fp.expWidth.W) ## 0.U(fp.fracWidth.W) // INF + case 2 => signSrc ## 0.U((fp.width - 1).W) // 0 + case 3 => signSrc ## expRounded(fp.expWidth-1, 0) ## fracRounded(fp.fracWidth-1, 0) // normal + }) + } + val int2FpResultMap: Seq[UInt] = fpParamMap.map(fp => Mux1H(result1H.asBools.reverse, int2FpResultMapGen(fp))) + resultNext := Mux1H(float1HOut, int2FpResultMap) + //output fflagsNext := Cat(nv, dz, of, uf, nx) + io.result := result + io.fflags := fflags +} +class Estimate7 extends Module{ + /** Estimate7: sqrt7 & rec7 + */ + val io = IO(new INTCVT_IO) + //parameter + val fpParamMap = Seq(f16, f32, f64) + val biasDeltaMap = Seq(f32.bias - f16.bias, f64.bias - f32.bias, f64.bias - f16.bias) + val widthExpAdder = 13 // 13bits is enough + //input + val (fire, src, opType, rmNext, input1H, output1H) = + (io.fire, io.src, io.opType, io.rm, io.input1H, io.output1H) + val fireReg = GatedValidRegNext(fire) + val int1HSrcNext = input1H + val float1HSrcNext = input1H.head(3)//exclude f8 + val int1HOutNext = output1H + val float1HOutNext = output1H.head(3)//exclude f8 + val srcMap = (0 to 3).map(i => src((1 << i) * 8 - 1, 0)) + val floatMap = srcMap.zipWithIndex.map{case (float,i) => floatExtend(float, i)}.drop(1) + val input = Mux1H(float1HSrcNext, floatMap) + val signSrcNext = input.head(1).asBool + val isEstimate7Next = opType(5) + val isRecNext = opType(5) && opType(0) + val rm = RegEnable(rmNext, 0.U(3.W), fire) + val expSrcNext = input.tail(1).head(f64.expWidth) + val fracSrc = input.tail(f64.expWidth+1).head(f64.fracWidth) + val decodeFloatSrc = Mux1H(float1HSrcNext, fpParamMap.map(fp => + VecInit(expSrcNext(fp.expWidth-1,0).orR, expSrcNext(fp.expWidth-1,0).andR, fracSrc.head(fp.fracWidth).orR).asUInt + ) + ) + val (expNotZeroSrcNext, expIsOnesSrcNext, fracNotZeroSrcNext) = (decodeFloatSrc(0), decodeFloatSrc(1), decodeFloatSrc(2)) + val 
isSubnormalSrcNext = !expNotZeroSrcNext && fracNotZeroSrcNext + val isZeroSrcNext = !expNotZeroSrcNext && !fracNotZeroSrcNext + val isInfSrcNext = expIsOnesSrcNext && !fracNotZeroSrcNext + val isNaNSrcNext = expIsOnesSrcNext && fracNotZeroSrcNext + val isSNaNSrcNext = isNaNSrcNext && !fracSrc.head(1) + val isQNaNSrcNext = isNaNSrcNext && fracSrc.head(1).asBool + val isSubnormalRec2Next = isSubnormalSrcNext && !fracSrc.head(2).orR - val s1_resultForfpCanonicalNAN = Mux1H( - Seq(s1_outIsF64, s1_outIsU32, s1_outIsS32, s1_outIsU64, s1_outIsS64), - Seq(~0.U((f64.expWidth+1).W) ## 0.U((f64.fracWidth-1).W), - ~0.U(32.W), - ~0.U(31.W), - ~0.U(64.W), - ~0.U(63.W)) + val expIsOnesSrc = RegEnable(expIsOnesSrcNext, false.B, fire) + val isSubnormalSrc = RegEnable(isSubnormalSrcNext, false.B, fire) + val isRec = RegEnable(isRecNext, false.B, fire) + val isSNaNSrc = RegEnable(isSNaNSrcNext, false.B, fire) + val signSrc = RegEnable(signSrcNext, false.B, fire) + val isZeroSrc = RegEnable(isZeroSrcNext, false.B, fire) + val isQNaNSrc = RegEnable(isQNaNSrcNext, false.B, fire) + val isSubnormalRec2 = RegEnable(isSubnormalRec2Next, false.B, fire) + val isInfSrc = RegEnable(isInfSrcNext, false.B, fire) + val isNaNSrc = RegEnable(isNaNSrcNext, false.B, fire) + + val decodeFloatSrcRec = Mux1H(float1HSrcNext, + fpParamMap.map(fp => expSrcNext(fp.expWidth - 1, 0)).zip(fpParamMap.map(fp => fp.expWidth)).map { case (exp, expWidth) => + VecInit( + exp.head(expWidth-1).andR && !exp(0), + exp.head(expWidth-2).andR && !exp(1) && exp(0) + ).asUInt + } ) - val s2_resultForfpCanonicalNAN = RegEnable(s1_resultForfpCanonicalNAN, fireReg) - io.result := Mux(s2_fpCanonicalNAN, s2_resultForfpCanonicalNAN, result) - io.fflags := Mux(s2_fpCanonicalNAN && !s2_outIsF64, "b10000".U, fflags) + val (isNormalRec0Next, isNormalRec1Next) = (decodeFloatSrcRec(0), decodeFloatSrcRec(1)) + val isNormalRec2Next = expNotZeroSrcNext && !expIsOnesSrcNext && !isNormalRec0Next && !isNormalRec1Next + val isSubnormalRec0Next 
= isSubnormalSrcNext && fracSrc.head(1).asBool + val isSubnormalRec1Next = isSubnormalSrcNext && !fracSrc.head(1) && fracSrc.tail(1).head(1).asBool + + val isNormalRec0 = RegEnable(isNormalRec0Next, false.B, fire) + val isNormalRec1 = RegEnable(isNormalRec1Next, false.B, fire) + val isNormalRec2 = RegEnable(isNormalRec2Next, false.B, fire) + val isSubnormalRec0 = RegEnable(isSubnormalRec0Next, false.B, fire) + val isSubnormalRec1 = RegEnable(isSubnormalRec1Next, false.B, fire) + + val output1HReg = RegEnable(output1H, 0.U(4.W), fire) + val float1HOut = Wire(UInt(3.W)) + float1HOut := output1HReg.head(3) + val nv, dz, of, uf, nx = Wire(Bool()) //cycle1 + val fflagsNext = Wire(UInt(5.W)) + val fflags = RegEnable(fflagsNext, 0.U(5.W), fireReg) + val resultNext = Wire(UInt(64.W)) + val result = RegEnable(resultNext, 0.U(64.W), fireReg) + + val clzIn = (fracSrc<<(64 - f64.fracWidth)).asUInt + val leadZerosNext = CLZ(clzIn) + val rmin = + rm === RTZ || (signSrc && rm === RUP) || (!signSrc && rm === RDN) //cycle1 + //exp + val expAdderIn0Next = Wire(UInt(widthExpAdder.W)) //13bits is enough + val expAdderIn1Next = Wire(UInt(widthExpAdder.W)) + val expAdderIn0 = RegEnable(expAdderIn0Next, fire) + val expAdderIn1 = RegEnable(expAdderIn1Next, fire) + val minusExp = extend((~(false.B ## expSrcNext)).asUInt + 1.U, widthExpAdder).asUInt + expAdderIn0Next := Mux1H(float1HOutNext, fpParamMap.map(fp => Mux(isRecNext, (2 * fp.bias - 1).U, (3 * fp.bias - 1).U))) + expAdderIn1Next := Mux(isSubnormalSrcNext, leadZerosNext, minusExp) + val exp = Wire(UInt(widthExpAdder.W)) + exp := expAdderIn0 + expAdderIn1 + + val expNormaled = Mux(isSubnormalSrcNext, leadZerosNext(0), expSrcNext(0)) //only the last bit is needed + val expNormaled0 = RegEnable(expNormaled(0), false.B, fire) + val fracSrcLeftNext = Wire(UInt(64.W)) + fracSrcLeftNext := fracSrc << (64 - f64.fracWidth) + val fracSrcLeft = RegEnable(fracSrcLeftNext, 0.U(64.W), fire) + val leadZeros = RegEnable(leadZerosNext, fire) + + val 
shiftLeft = Wire(UInt(64.W)) + shiftLeft := (fracSrcLeft.asUInt << 1) << leadZeros //cycle1 + val fracNormaled = Wire(UInt(64.W)) + fracNormaled := Mux(isSubnormalSrc, shiftLeft, fracSrcLeft) //cycle1 + val rsqrt7Table = Module(new Rsqrt7Table) + rsqrt7Table.src := expNormaled0 ## fracNormaled.head(6) + val rec7Table = Module(new Rec7Table) + rec7Table.src := fracNormaled.head(7) + val fracEstimate = Mux(isRec, rec7Table.out, rsqrt7Table.out) + + nv := Mux(isRec, isSNaNSrc, (signSrc && !isZeroSrc && !isQNaNSrc) | isSNaNSrc) + dz := isZeroSrc + of := isRec && isSubnormalRec2 + uf := false.B + nx := of + def recResultMapGen(fp: FloatFormat): Seq[UInt] = { + VecInit((0 to 6).map { + case 0 => false.B ## ~0.U(fp.expWidth.W) ## true.B ## 0.U((fp.fracWidth - 1).W) //can + case 1 => signSrc ## 0.U((fp.width - 1).W) //0 + case 2 => signSrc ## ~0.U(fp.expWidth.W) ## 0.U(fp.fracWidth.W) //INF + case 3 => signSrc ## 0.U(fp.expWidth.W) ## 1.U(2.W) ## fracEstimate ## 0.U((fp.fracWidth - 2 - 7).W) + case 4 => signSrc ## 0.U(fp.expWidth.W) ## 1.U(1.W) ## fracEstimate ## 0.U((fp.fracWidth - 1 - 7).W) + case 5 => signSrc ## exp(fp.expWidth - 1, 0) ## fracEstimate ## 0.U((fp.fracWidth - 7).W) + case 6 => signSrc ## fp.maxExp.U(fp.expWidth.W) ## ~0.U(fp.fracWidth.W) //GNF + }) + } + val recResult1H = Cat( + isNaNSrc, + isInfSrc, + isZeroSrc || isSubnormalRec2 && !rmin, + isNormalRec0, + isNormalRec1, + isNormalRec2 || isSubnormalRec0 || isSubnormalRec1, + isSubnormalRec2 && rmin + ) + val recResultMap: Seq[UInt] = fpParamMap.map(fp => Mux1H(recResult1H.asBools.reverse, recResultMapGen(fp))) + def sqrtResultMapGen(fp: FloatFormat): Seq[UInt] = { + VecInit((0 to 3).map { + case 0 => false.B ## ~0.U(fp.expWidth.W) ## true.B ## 0.U((fp.fracWidth - 1).W) + case 1 => signSrc ## ~0.U(fp.expWidth.W) ## 0.U(fp.fracWidth.W) + case 2 => signSrc ## exp(fp.expWidth, 1) ## fracEstimate ## 0.U((fp.fracWidth - 7).W) // exp/2 => >>1 + case 3 => 0.U(fp.width.W) + }) + } + val sqrtResult1H = Cat( + 
signSrc & !isZeroSrc | isNaNSrc, + isZeroSrc, + !signSrc & !isZeroSrc & !expIsOnesSrc, + !signSrc & isInfSrc, + ) + val sqrtResultMap: Seq[UInt] = fpParamMap.map(fp => Mux1H(sqrtResult1H.asBools.reverse, sqrtResultMapGen(fp))) + resultNext := Mux(isRec, Mux1H(float1HOut, recResultMap), Mux1H(float1HOut, sqrtResultMap)) + + fflagsNext := Cat(nv, dz, of, uf, nx) + io.result := result + io.fflags := fflags } + + diff --git a/src/main/scala/yunsuan/vector/VectorConvert/VCVT.scala b/src/main/scala/yunsuan/vector/VectorConvert/VCVT.scala index bb2db97..ca4734f 100644 --- a/src/main/scala/yunsuan/vector/VectorConvert/VCVT.scala +++ b/src/main/scala/yunsuan/vector/VectorConvert/VCVT.scala @@ -25,7 +25,7 @@ class VCVT(width: Int) extends Module{ val vcvtImpl = width match { case 16 => Module(new CVT16(16)) case 32 => Module(new CVT32(32)) - case 64 => Module(new CVT64(64)) + case 64 => Module(new CVT64(64, true)) } io <> vcvtImpl.io } diff --git a/src/test/csrc/golden_model/gm_common.cpp b/src/test/csrc/golden_model/gm_common.cpp index ebfd0b8..5e3368f 100644 --- a/src/test/csrc/golden_model/gm_common.cpp +++ b/src/test/csrc/golden_model/gm_common.cpp @@ -13,11 +13,14 @@ VecOutput VPUGoldenModel::get_expected_output(VecInput input) { int half_number = number >> 1; int result_shift_len = 8 << sew; int widenNorrow = (input.fuOpType >> 3) & 0X3; + int i2f_inputType = (input.fuOpType >> 3) & 0X1; + int i2f_number = (128 / 8) >> (i2f_inputType+2); + int i2f_half_number = i2f_number >> 1; + int i2f_outputType = (input.fuOpType >> 1) & 0X3; softfloat_detectTininess = softfloat_tininess_afterRounding; uint64_t mask = 0; VecOutput output; ElementOutput output_part[number]; - if (input.fuType == VFloatCvt){ if(widenNorrow == 1){ //widen @@ -72,12 +75,122 @@ VecOutput VPUGoldenModel::get_expected_output(VecInput input) { } } } + }else if(input.fuType == FloatCvtF2X){ + half_number = 1; + if(widenNorrow == 1){ //widen + // half_number = half_number >> 1; + result_shift_len = 
result_shift_len << 1; + for(int i = 0; i < number; i++) { + ElementInput element = select_element(input, i); + switch (sew) { + case 1: output_part[i] = calculation_e16(element); mask = 0xFFFFFFFF; break; //fp16->fp32/int32/uint32 + case 2: output_part[i] = calculation_e32(element); mask = 0xFFFFFFFFFFFFFFFF; break;//fp32->fp64/int64/uint64 + default: + printf("VPU Golden Modle, bad sew %d\n", input.sew); + exit(1); + } + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + }else if(widenNorrow == 2){//narrow + // half_number = half_number >> 1; + for(int i = 0; i < number/2; i++) { + ElementInput element = select_element(input, i); + switch (sew) { + case 0: output_part[i] = calculation_e16(element); mask = 0xFF; break; + case 1: output_part[i] = calculation_e32(element); mask = 0xFFFF; break; //fp32->fp16 + case 2: output_part[i] = calculation_e64(element); mask = 0xFFFFFFFF; break; //fp64->fp32/i32/ui32 + default: + printf("VPU Golden Modle, bad sew %d\n", input.sew); + exit(1); + } + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + }else if(widenNorrow == 3){ + if(sew == 1){//cross high fp16->fp64/int64/uint64 + for(int i = 0; i < number; i++) { + ElementInput element = select_element(input, i); + output_part[i] = calculation_e16(element); + mask = 0xFFFFFFFFFFFFFFFF; + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + }else if(sew == 3){//corss low fp64->fp16 + for(int i = 0; i < number; i++) { + ElementInput element = select_element(input, i); + output_part[i] = calculation_e64(element); + mask = 0xFFFF; + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + } + }else if(widenNorrow == 0){ // single + 
for(int i = 0; i < number; i++) { + ElementInput element = select_element(input, i); + switch (sew) { + case 2: output_part[i] = calculation_e32(element); mask = 0xFFFFFFFF; break; //fp32->i32/u32 + case 3: output_part[i] = calculation_e64(element); mask = 0xFFFFFFFFFFFFFFFF; break;//fp64->i64/u64 + default: + printf("VPU Golden Modle, bad sew %d\n", input.sew); + exit(1); + } + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + } + }else if(input.fuType == FloatCvtI2F){ + if(i2f_inputType == 1){// input:i64/ui64 + for(int i = 0; i < i2f_number; i++){ + ElementInput element = select_element(input, i); + output_part[i] = calculation_e64(element); + switch (i2f_outputType) { + case 0: mask = 0xFFFFFFFFFFFF0000; half_number = 1; break; //64->16 + case 1: mask = 0xFFFFFFFF00000000; half_number = 1; break; // 64->32 + case 2: mask = 0x0000000000000000; half_number = 1; break; // 64->64 + default: + printf("VPU Golden Modle, bad i2f_outputType %d\n", i2f_outputType); + exit(1); + } + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + }else{// input:i32/ui32 + for(int i = 0; i < i2f_number; i++){ + ElementInput element = select_element(input, i); + output_part[i] = calculation_e32(element); + switch (i2f_outputType) { + case 0: mask = 0xFFFFFFFFFFFF0000; half_number = 2; break; // 32->16 + case 1: mask = 0xFFFFFFFF00000000; half_number = 2; break; // 32->32 + case 2: mask = 0x0000000000000000; half_number = 1; break; // 32->64 + default: + printf("VPU Golden Modle, bad i2f_outputType %d\n", i2f_outputType); + exit(1); + } + if (output_part[i].fflags > 0x1f) { + printf("Bad fflags of %x, check golden model e8 %d\n", output_part[i].fflags, i); + exit(1); + } + } + } } else{ for(int i = 0; i < number; i++) { ElementInput element = select_element(input, i); switch (sew) { - case 0: output_part[i] = 
calculation_e8(element); mask = 0xFF; break; + case 0: output_part[i] = calculation_e8(element); mask = 0xFF; break; case 1: output_part[i] = calculation_e16(element); mask = 0xFFFF; break; case 2: output_part[i] = calculation_e32(element); mask = 0xFFFFFFFF; break; case 3: output_part[i] = calculation_e64(element); mask = 0xFFFFFFFFFFFFFFFF; break; @@ -99,7 +212,7 @@ VecOutput VPUGoldenModel::get_expected_output(VecInput input) { if(input.fuType == VIntegerDivider) { output.result[i] += (uint64_t)(output_part[i*half_number+j].result&mask) << (j*result_shift_len); output.fflags[i] += (uint32_t)output_part[i*half_number+j].fflags << j; - } else if(input.fuType == VFloatCvt){ + }else if(input.fuType == VFloatCvt){ if(widenNorrow == 1){//widen output.result[i] += ((uint64_t)output_part[(i<<1)*half_number+j].result&mask) << (j*result_shift_len); output.fflags[i] += (uint32_t)output_part[(i<<1)*half_number+j].fflags << (j*5); @@ -107,6 +220,57 @@ VecOutput VPUGoldenModel::get_expected_output(VecInput input) { output.result[i] += ((uint64_t)output_part[i*half_number+j].result&mask) << (j*result_shift_len); output.fflags[i] += (uint32_t)output_part[i*half_number+j].fflags << (j*5); } + }else if(input.fuType == FloatCvtF2X){ + if(widenNorrow == 1){//widen + if(sew == 1){//fp16->fp32/int32/uint32 + output.result[i] = ((uint64_t)output_part[(i<<2)*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[(i<<2)*half_number+j].fflags ; + if(output.result[i] == 0x00000000FFFFFFFF) + output.result[i] = 0xFFFFFFFFFFFFFFFF; + }else if(sew == 2){//fp32->fp64/int64/uint64 + output.result[i] = ((uint64_t)output_part[(i<<1)*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[(i<<1)*half_number+j].fflags ; + } + }else if(widenNorrow == 2){//norrow + if(sew == 1){//fp32->fp16 + output.result[i] = ((uint64_t)output_part[(i<<1)*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[(i<<1)*half_number+j].fflags ; + }else if(sew == 
2){//fp64->fp32/i32/ui32 + output.result[i] = ((uint64_t)output_part[i*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[i*half_number+j].fflags ; + if(output.result[i] == 0x00000000FFFFFFFF) + output.result[i] = 0xFFFFFFFFFFFFFFFF; + } + }else if(widenNorrow == 0) {//single + if(sew == 2){//fp32->i32/u32 + output.result[i] = ((uint64_t)output_part[(i<<1)*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[(i<<1)*half_number+j].fflags ; + if(output.result[i] == 0x00000000FFFFFFFF) + output.result[i] = 0xFFFFFFFFFFFFFFFF; + }else if(sew == 3){//fp64->i64/u64 + output.result[i] = ((uint64_t)output_part[i*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[i*half_number+j].fflags ; + } + }else if(widenNorrow == 3){//cross + if(sew == 1){//cross high 16->64 + output.result[i] = ((uint64_t)output_part[(i<<2)*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[(i<<2)*half_number+j].fflags ; + }else if(sew == 3){//cross low fp64->fp16 + output.result[i] = ((uint64_t)output_part[i*half_number+j].result) ; + output.fflags[i] = (uint32_t)output_part[i*half_number+j].fflags ; + } + } + }else if(input.fuType == FloatCvtI2F){ + if(i2f_inputType == 0 && i2f_outputType == 2){//widen + output.result[i] = ((uint64_t)output_part[(i<<1)*half_number].result|mask); + output.fflags[i] = (uint32_t)output_part[(i<<1)*half_number].fflags; + }else if(i2f_inputType == 1 && i2f_outputType == 0){//cross low + output.result[i] = ((uint64_t)output_part[i*half_number].result|mask); + output.fflags[i] = (uint32_t)output_part[i*half_number].fflags; + }else {//single or norrow + output.result[i] = ((uint64_t)output_part[i*half_number].result|mask); + output.fflags[i] = (uint32_t)output_part[i*half_number].fflags; + } }else { output.result[i] += ((uint64_t)output_part[i*half_number+j].result) << (j*result_shift_len); output.fflags[i] += (uint32_t)output_part[i*half_number+j].fflags << (j*5); @@ -174,7 +338,7 @@ ElementInput 
VPUGoldenModel::select_element(VecInput input, int idx) { printf("VPU Golden Modle, not support widen fuType %d\n", input.fuType); exit(1); } - }else if((input.fuType == VFloatCvt) && (((input.fuOpType >>3) & 0X3) == 2)){ //cvt norrow select 2sew + }else if((input.fuType == VFloatCvt) && (((input.fuOpType >>3) & 0X3) == 2) ){ //cvt norrow select 2sew switch (sew) { case 0: element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input16->src1[idx]; @@ -195,7 +359,43 @@ ElementInput VPUGoldenModel::select_element(VecInput input, int idx) { printf("VPU Golden Modle, bad sew %d\n", input.sew); exit(1); } - }else { + }else if((input.fuType == FloatCvtF2X) && (((input.fuOpType >>3) & 0X3) == 2) ){ //cvt norrow select 2sew + switch (sew) { + case 0: + element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input16->src1[idx]; + element.src2 = input.is_frs1 ? (uint64_t)input64->src2[0] : (uint64_t)input16->src2[idx]; + element.src3 = (uint64_t)input16->src3[idx]; + break; + case 1: + element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input32->src1[idx]; + element.src2 = input.is_frs1 ? (uint64_t)input64->src2[0] : (uint64_t)input32->src2[idx]; + element.src3 = (uint64_t)input32->src3[idx]; + break; + case 2: + element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input64->src1[idx]; + element.src2 = input.is_frs1 ? (uint64_t)input64->src2[0] : (uint64_t)input64->src2[idx]; + element.src3 = (uint64_t)input64->src3[idx]; + break; + default: + printf("VPU Golden Modle, bad sew %d\n", input.sew); + exit(1); + } + }else if((input.fuType == FloatCvtF2X) && (((input.fuOpType>>3) & 0x3) == 3) && (sew == 3)){//f64->f16 + element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input64->src1[idx]; + element.src2 = input.is_frs1 ? 
(uint64_t)input64->src2[0] : (uint64_t)input64->src2[idx]; + element.src3 = (uint64_t)input64->src3[idx]; + }else if(input.fuType == FloatCvtI2F){ + if(((input.fuOpType>>3) & 0x1) == 1){ + element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input64->src1[idx]; + element.src2 = input.is_frs1 ? (uint64_t)input64->src2[0] : (uint64_t)input64->src2[idx]; + element.src3 = (uint64_t)input64->src3[idx]; + }else { + element.src1 = input.is_frs2 ? (uint64_t)input64->src1[0] : (uint64_t)input32->src1[idx]; + element.src2 = input.is_frs1 ? (uint64_t)input64->src2[0] : (uint64_t)input32->src2[idx]; + element.src3 = (uint64_t)input32->src3[idx]; + } + } + else { switch (sew) { case 0: element.src1 = (uint64_t)input8->src1[idx]; diff --git a/src/test/csrc/golden_model/scalar_float_convert.cpp b/src/test/csrc/golden_model/scalar_float_convert.cpp new file mode 100644 index 0000000..cb6a3a7 --- /dev/null +++ b/src/test/csrc/golden_model/scalar_float_convert.cpp @@ -0,0 +1,167 @@ +#include "../include/gm_common.h" +#include "../include/vfpu_functions.h" +#include +#include + +#define BOX_MASK_FP16 0xFFFFFFFFFFFF0000 +#define BOX_MASK_FP32 0xFFFFFFFF00000000 +#define defaultNaNF16UI 0x7E00 +#define defaultNaNF32UI 0x7FC00000 + +static inline uint64_t unboxf16(uint64_t r) { + return (r & BOX_MASK_FP16) == BOX_MASK_FP16 + ? (r & ~BOX_MASK_FP16) : defaultNaNF16UI; +} +static inline uint64_t unboxf32(uint64_t r) { + return (r & BOX_MASK_FP32) == BOX_MASK_FP32 + ? 
(r & ~BOX_MASK_FP32) : defaultNaNF32UI; +} + +static inline float32_t my_f16_to_f32 (float16_t a) { + return f16_to_f32(a); +} +static inline float64_t my_f16_to_f64 (float16_t a) { + return f16_to_f64(a); +} +static inline float16_t my_f32_to_f16 (float32_t a) { + return f32_to_f16(a); +} +static inline float64_t my_f32_to_f64 (float32_t a) { + return f32_to_f64(a); +} +static inline float32_t rtlToF32(uint64_t r) { + float32_t f = { .v = (uint32_t)unboxf32(r) }; + return f; +} +static inline float16_t rtlToF16(uint64_t r) { + float16_t f = { .v = (uint16_t)unboxf16(r) }; + return f; +} + +ElementOutput SGMFloatCvt::calculation_e8(ElementInput input) { + fp_set_rm(input.rm); + fp_clear_exception(); + ElementOutput output; +// switch(input.fuOpType) { +// // widen 8->16 +// case VFWCVT_FXUV: //ui8 -> f16 +// output.result = ui32_to_f16((uint32_t)input.src1).v; break; +// case VFWCVT_FXV: //i8 -> f16 +// output.result = i32_to_f16((int32_t)(int8_t)input.src1).v; break; //todo +// default: +// printf("VFConvert Unsupported fuOpType %d\n", input.fuOpType); +// exit(1); +// } + +// output.fflags = softfloat_exceptionFlags & 0x1f; +// if (verbose) { display_calculation(typeid(this).name(), __func__, input, output); } + return output; +} + + +ElementOutput SGMFloatCvt::calculation_e16(ElementInput input) { + fp_set_rm(input.rm); + fp_clear_exception(); + float16_t fsrc1 = rtlToF16(input.src1); + ElementOutput output; + switch(input.fuOpType) { + //scalar + case FCVT_D_H: //f16->f64 + output.result = my_f16_to_f64(rtlToF16(input.src1)).v; break; + case FCVT_S_H: //f16 -> f32 + output.result = my_f16_to_f32(rtlToF16(input.src1)).v; break; + case FCVT_L_H: //f16->i64 + output.result = f16_to_i64(rtlToF16(input.src1), softfloat_roundingMode, true); break; + case FCVT_LU_H: //f16->ui64 + output.result = f16_to_ui64(rtlToF16(input.src1), softfloat_roundingMode, true); break; + case FCVT_W_H: //f16->i32 + output.result = f16_to_i32(rtlToF16(input.src1), softfloat_roundingMode, 
true); break; + case FCVT_WU_H: //f16->ui32 + output.result = f16_to_ui32(rtlToF16(input.src1), softfloat_roundingMode, true); break; + default: + printf("SFConvert Unsupported fuOpType %d\n", input.fuOpType); + exit(1); + } + output.fflags = softfloat_exceptionFlags & 0x1f; + if (verbose) { display_calculation(typeid(this).name(), __func__, input, output); } + return output; +} + +ElementOutput SGMFloatCvt::calculation_e32(ElementInput input) { + fp_set_rm(input.rm); + fp_clear_exception(); + ElementOutput output; + switch(input.fuOpType) { + //saclar + case FCVT_H_W: // i32 ->f16 + output.result = i32_to_f16((uint32_t)input.src1).v; break; + case FCVT_H_WU: // ui32 ->f16 + output.result = ui32_to_f16((uint32_t)input.src1).v; break; + case FCVT_S_W: // i32 ->f32 + output.result = i32_to_f32((uint32_t)input.src1).v; break; + case FCVT_S_WU: // ui32 ->f32 + output.result = ui32_to_f32((uint32_t)input.src1).v; break; + case FCVT_D_W: // i32 ->f64 + output.result = i32_to_f64((uint32_t)input.src1).v; break; + case FCVT_D_WU: // ui32 ->f64 + output.result = ui32_to_f64((uint32_t)input.src1).v; break; + case FCVT_H_S: // f32 ->f16 + output.result = my_f32_to_f16(rtlToF32(input.src1)).v; break; + case FCVT_D_S: //f32 -> f64 + output.result = my_f32_to_f64(rtlToF32(input.src1)).v; break; + case FCVT_W_S: //f32 -> i32 + output.result = f32_to_i32(rtlToF32(input.src1), softfloat_roundingMode, true); break; + case FCVT_WU_S: //f32 -> ui32 + output.result = f32_to_ui32(rtlToF32(input.src1), softfloat_roundingMode, true); break; + case FCVT_L_S: //f32 -> i64 + output.result = f32_to_i64(rtlToF32(input.src1), softfloat_roundingMode, true); break; + case FCVT_LU_S: //f32 -> ui64 + output.result = f32_to_ui64(rtlToF32(input.src1), softfloat_roundingMode, true); break; + default: + printf("SFConvert Unsupported fuOpType %d\n", input.fuOpType); + exit(1); + } + + output.fflags = softfloat_exceptionFlags & 0x1f; + if (verbose) { display_calculation(typeid(this).name(), __func__, 
input, output); } + return output; +} + +ElementOutput SGMFloatCvt::calculation_e64(ElementInput input) { + fp_set_rm(input.rm); + fp_clear_exception(); + ElementOutput output; + switch(input.fuOpType) { + case FCVT_H_D: //f64 -> f16 + output.result = f64_to_f16(i2f64((uint64_t)input.src1)).v; break; + case FCVT_S_D: //f64 -> f32 + output.result = f64_to_f32(i2f64((uint64_t)input.src1)).v; break; + case FCVT_H_L: // i64 ->f16 + output.result = i64_to_f16((uint64_t)input.src1).v; break; + case FCVT_H_LU: // ui64 ->f16 + output.result = ui64_to_f16((uint64_t)input.src1).v; break; + case FCVT_S_L: // i64 ->f32 + output.result = i64_to_f32((uint64_t)input.src1).v; break; + case FCVT_S_LU: // ui64 ->f32 + output.result = ui64_to_f32((uint64_t)input.src1).v; break; + case FCVT_D_L: // i64 ->f64 + output.result = i64_to_f64((uint64_t)input.src1).v; break; + case FCVT_D_LU: // ui64 ->f64 + output.result = ui64_to_f64((uint64_t)input.src1).v; break; + case FCVT_L_D: //f64 -> i64 + output.result = f64_to_i64(i2f64((uint64_t)input.src1), softfloat_roundingMode, true); break; + case FCVT_LU_D: //f64 -> ui64 + output.result = f64_to_ui64(i2f64((uint64_t)input.src1), softfloat_roundingMode, true); break; + case FCVT_W_D: //f64 -> i32 + output.result = f64_to_i32(i2f64((uint64_t)input.src1), softfloat_roundingMode, true); break; + case FCVT_WU_D: //f64 -> ui32 + output.result = f64_to_ui32(i2f64((uint64_t)input.src1), softfloat_roundingMode, true); break; + default: + printf("SFConvert Unsupported fuOpType %d\n", input.fuOpType); + exit(1); + } + + output.fflags = softfloat_exceptionFlags & 0x1f; + if (verbose) { display_calculation(typeid(this).name(), __func__, input, output); } + return output; +} \ No newline at end of file diff --git a/src/test/csrc/include/gm_common.h b/src/test/csrc/include/gm_common.h index 4d53fa5..c0da348 100644 --- a/src/test/csrc/include/gm_common.h +++ b/src/test/csrc/include/gm_common.h @@ -55,6 +55,14 @@ class VGMFloatCvt : public VGMFloatBase { 
virtual ElementOutput calculation_e64(ElementInput input); }; +// scalar cvt +class SGMFloatCvt : public VGMFloatBase { + virtual ElementOutput calculation_e8(ElementInput input); + virtual ElementOutput calculation_e16(ElementInput input); + virtual ElementOutput calculation_e32(ElementInput input); + virtual ElementOutput calculation_e64(ElementInput input); +}; + class VGMFloatFMA : public VGMFloatBase { virtual ElementOutput calculation_e16(ElementInput input); diff --git a/src/test/csrc/include/test_driver.h b/src/test/csrc/include/test_driver.h index 7db077c..9c7f7f6 100644 --- a/src/test/csrc/include/test_driver.h +++ b/src/test/csrc/include/test_driver.h @@ -43,7 +43,7 @@ class TestDriver { VGMIntegerALUF vialuF; VGMIntegerDividier vid; VGMFloatCvt vcvt; - + SGMFloatCvt scvt; public: TestDriver(); diff --git a/src/test/csrc/include/vpu_constant.h b/src/test/csrc/include/vpu_constant.h index fadfb90..5dd1efd 100644 --- a/src/test/csrc/include/vpu_constant.h +++ b/src/test/csrc/include/vpu_constant.h @@ -22,11 +22,21 @@ extern "C"{ #define VIntegerALUV2 (5) #define VIntegerDivider (6) #define VFloatCvt (7) +#define FloatCvtF2X (8) //f->i/ui/f +#define FloatCvtI2F (9) //i/ui->f // #define ALL_FUTYPES {VFloatAdder,VFloatFMA,VFloatDivider,VIntegerALU,VPermutation,VIntegerALUV2,VIntegerDivider,VFloatCvt} //will be delated -#define FU_NUM 6 -#define ALL_FUTYPES {VFloatFMA,VFloatDivider,VIntegerALU,VPermutation,VIntegerDivider,VFloatCvt} +// #define FU_NUM 7 +// #define ALL_FUTYPES {VFloatFMA,VFloatDivider,VIntegerALU,VPermutation,VIntegerDivider,VFloatCvt,FloatCvt} +// #define FU_NUM 1 +// #define ALL_FUTYPES {VFloatCvt} +// #define FU_NUM 1 +// #define ALL_FUTYPES {FloatCvtF2X} +#define FU_NUM 2 +#define ALL_FUTYPES {FloatCvtF2X,FloatCvtI2F} +// #define FU_NUM 1 +// #define ALL_FUTYPES {FloatCvtI2F} #define INT_ROUNDING(result, xrm, gb) \ do { \ @@ -288,6 +298,44 @@ extern "C"{ #define VFRSQRT7 (binstoi("11100000")) #define VFREC7 (binstoi("11100001")) + 
//FloatCvtF2X + //sew == 1 + #define FCVT_S_H (binstoi("11001000")) + #define FCVT_D_H (binstoi("11011000")) + #define FCVT_W_H (binstoi("10001001")) + #define FCVT_WU_H (binstoi("10001000")) + #define FCVT_L_H (binstoi("10011001")) + #define FCVT_LU_H (binstoi("10011000")) + #define FCVT_H_S (binstoi("11010000")) + //sew == 2 + #define FCVT_W_S (binstoi("10000001")) + #define FCVT_WU_S (binstoi("10000000")) + #define FCVT_L_S (binstoi("10001001")) + #define FCVT_LU_S (binstoi("10001000")) + #define FCVT_W_D (binstoi("10010001")) + #define FCVT_WU_D (binstoi("10010000")) + #define FCVT_S_D (binstoi("11010100")) + #define FCVT_D_S (binstoi("11001100")) + //sew == 3 + #define FCVT_H_D (binstoi("11011000")) + #define FCVT_L_D (binstoi("10000001")) + #define FCVT_LU_D (binstoi("10000000")) + + //FloatCvtI2F + #define FCVT_H_WU (binstoi("00000000")) + #define FCVT_H_W (binstoi("00000001")) + #define FCVT_H_LU (binstoi("00001000")) + #define FCVT_H_L (binstoi("00001001")) + + #define FCVT_S_WU (binstoi("00000010")) + #define FCVT_S_W (binstoi("00000011")) + #define FCVT_S_LU (binstoi("00001010")) + #define FCVT_S_L (binstoi("00001011")) + + #define FCVT_D_WU (binstoi("00000100")) + #define FCVT_D_W (binstoi("00000101")) + #define FCVT_D_LU (binstoi("00001100")) + #define FCVT_D_L (binstoi("00001101")) #define VFCVT_ALL_OPTYPES {VFCVT_XUFV, VFCVT_XFV, VFCVT_FXUV, VFCVT_FXV, VFCVT_RTZ_XUFV, VFCVT_RTZ_XFV, \ @@ -306,6 +354,20 @@ extern "C"{ #define VFCVT_64_NUM 8 #define VFCVT_64_OPTYPES {VFCVT_XUFV,VFCVT_XFV,VFCVT_FXUV,VFCVT_FXV,VFCVT_RTZ_XUFV,VFCVT_RTZ_XFV,VFRSQRT7,VFREC7} + //F2X + //sew == 1 + #define FCVT_16_NUM 7 + #define FCVT_16_OPTYPES {FCVT_H_S,FCVT_S_H,FCVT_D_H,FCVT_W_H,FCVT_WU_H,FCVT_L_H,FCVT_LU_H} + //sew == 2 + #define FCVT_32_NUM 8 + #define FCVT_32_OPTYPES {FCVT_W_S,FCVT_WU_S,FCVT_D_S,FCVT_L_S,FCVT_LU_S,FCVT_S_D,FCVT_W_D,FCVT_WU_D} + //sew == 3 + #define FCVT_64_NUM 3 + #define FCVT_64_OPTYPES {FCVT_H_D,FCVT_L_D,FCVT_LU_D} + + //I2F + #define I2FCVT_64_NUM 
12 + #define I2FCVT_64_OPTYPES {FCVT_H_WU,FCVT_H_W,FCVT_H_LU,FCVT_H_L,FCVT_S_WU,FCVT_S_W,FCVT_S_LU,FCVT_S_L,FCVT_D_WU,FCVT_D_W,FCVT_D_LU,FCVT_D_L} // pre-compile stoi constexpr uint8_t binstoi(const char str[]) { diff --git a/src/test/csrc/test_driver.cpp b/src/test/csrc/test_driver.cpp index ff96d93..1f25cc1 100644 --- a/src/test/csrc/test_driver.cpp +++ b/src/test/csrc/test_driver.cpp @@ -82,11 +82,6 @@ uint8_t TestDriver::gen_random_optype() { break; } case VFloatCvt:{ - // uint8_t vid_all_optype[VFCVT_NUM] = VFCVT_ALL_OPTYPES; - // return vid_all_optype[rand() % VFCVT_NUM]; - // uint8_t vid_all_optype[2] = {VFRSQRT7, VFREC7}; - // return vid_all_optype[rand() % 2]; - // break; if (input.sew == 0) { uint8_t vfcvt_8_optype[VFCVT_8_NUM] = VFCVT_8_OPTYPES; return vfcvt_8_optype[rand() % VFCVT_8_NUM]; @@ -105,6 +100,26 @@ uint8_t TestDriver::gen_random_optype() { break; } } + case FloatCvtF2X:{ + if(input.sew == 1){ + uint8_t fcvt_16_optype[FCVT_16_NUM] = FCVT_16_OPTYPES; + return fcvt_16_optype[rand() % FCVT_16_NUM]; + break; + }else if(input.sew == 2){ + uint8_t fcvt_32_optype[FCVT_32_NUM] = FCVT_32_OPTYPES; + return fcvt_32_optype[rand() % FCVT_32_NUM]; + break; + }else if(input.sew == 3){ + uint8_t fcvt_64_optype[FCVT_64_NUM] = FCVT_64_OPTYPES; + return fcvt_64_optype[rand() % FCVT_64_NUM]; + break; + } + } + case FloatCvtI2F:{ + uint8_t i2fcvt_64_optype[I2FCVT_64_NUM] = I2FCVT_64_OPTYPES; + return i2fcvt_64_optype[rand() % I2FCVT_64_NUM]; + break; + } default: printf("Unsupported FuType %d\n", input.fuType); exit(1); @@ -119,6 +134,8 @@ uint8_t TestDriver::gen_random_sew() { case VIntegerALU: return rand()%4; break; case VPermutation: return rand()%4; break; case VFloatCvt: return rand()%4; break; + case FloatCvtF2X: return (rand()%3)+1 ; break; + case FloatCvtI2F: return 0 ; break; default: return (rand()%3)+1; break; } } @@ -357,7 +374,6 @@ void TestDriver::get_random_input() { if (!test_type.pick_fuType) { input.fuType = gen_random_futype(ALL_FUTYPES); } 
else { input.fuType = test_type.fuType; } - if(input.fuType == VFloatCvt){ input.sew = gen_random_sew(); input.is_frs1 = false; @@ -365,6 +381,20 @@ void TestDriver::get_random_input() { input.widen = false; if (!test_type.pick_fuOpType) { input.fuOpType = gen_random_optype(); } else { input.fuOpType = test_type.fuOpType; } + }else if(input.fuType == FloatCvtF2X){ + input.sew = gen_random_sew(); + input.is_frs1 = false; + input.is_frs2 = false; + input.widen = false; + if (!test_type.pick_fuOpType) { input.fuOpType = gen_random_optype(); } + else { input.fuOpType = test_type.fuOpType; } + }else if(input.fuType == FloatCvtI2F){ + input.sew = gen_random_sew(); + input.is_frs1 = false; + input.is_frs2 = false; + input.widen = false; + if (!test_type.pick_fuOpType) { input.fuOpType = gen_random_optype(); } + else { input.fuOpType = test_type.fuOpType; } }else{ if (!test_type.pick_fuOpType) { input.fuOpType = gen_random_optype(); } else { input.fuOpType = test_type.fuOpType; } @@ -428,6 +458,12 @@ void TestDriver::get_expected_output() { case VFloatCvt: if (verbose) { printf("FuType:%d, choose VFloatCvt %d\n", input.fuType, VFloatCvt); } expect_output = vcvt.get_expected_output(input); return; + case FloatCvtF2X: + if (verbose) { printf("FuType:%d, choose FloatCvtF2X %d\n", input.fuType, FloatCvtF2X); } + expect_output = scvt.get_expected_output(input); return; + case FloatCvtI2F: + if (verbose) { printf("FuType:%d, choose FloatCvtI2F %d\n", input.fuType, FloatCvtI2F); } + expect_output = scvt.get_expected_output(input); return; default: printf("Unsupported FuType %d\n", input.fuType); exit(1); diff --git a/src/test/scala/top/VectorSimTop.scala b/src/test/scala/top/VectorSimTop.scala index 979af3e..9d12b91 100644 --- a/src/test/scala/top/VectorSimTop.scala +++ b/src/test/scala/top/VectorSimTop.scala @@ -6,6 +6,8 @@ import chisel3.util._ import yunsuan.util._ import yunsuan.vector.VectorConvert.VectorCvt import yunsuan.vector._ +import yunsuan.scalar.INT2FP +import 
yunsuan.scalar.FPCVT trait VSPParameter { val VLEN : Int = 128 @@ -29,9 +31,11 @@ object VPUTestFuType { // only use in test, difftest with xs def viaf = "b0000_0101".U(8.W) def vid = "b0000_0110".U(8.W) def vcvt= "b0000_0111".U(8.W) + def fcvtf2x= "b0000_1000".U(8.W) + def fcvti2f= "b0000_1001".U(8.W) def unknown(typ: UInt) = { - (typ > 7.U) + (typ > 9.U) } } @@ -104,7 +108,9 @@ class SimTop() extends VPUTestModule { VPUTestFuType.vperm -> VPERM_latency.U, VPUTestFuType.viaf -> VIAF_latency.U, VPUTestFuType.vid -> VID_latency.U, - VPUTestFuType.vcvt -> VCVT_latency.U + VPUTestFuType.vcvt -> VCVT_latency.U, + VPUTestFuType.fcvtf2x -> VCVT_latency.U, + VPUTestFuType.fcvti2f -> VCVT_latency.U )) // fuType --> latency, spec case for div assert(!VPUTestFuType.unknown(io.in.bits.fuType)) } @@ -135,6 +141,8 @@ class SimTop() extends VPUTestModule { val vid_result = Wire(new VSTOutputIO) val vid_result_valid = Wire(Bool()) val vcvt_result = Wire(new VSTOutputIO) + val i2f_result = Wire(new VSTOutputIO) + val fpcvt_result = Wire(new VSTOutputIO) when (io.in.fire() || io.out.fire()) { vfd_result_valid.map(_ := false.B) } @@ -148,6 +156,8 @@ class SimTop() extends VPUTestModule { val vfd = Module(new VectorFloatDivider) val via = Module(new VectorIntAdder) val vcvt = Module(new VectorCvt(XLEN)) + val i2fcvt = Module(new INT2FP(2, XLEN)) + val fpcvt = Module(new FPCVT(XLEN)) require(vfa.io.fp_a.getWidth == XLEN) vfa.io.fire := busy @@ -266,6 +276,29 @@ class SimTop() extends VPUTestModule { vcvt_result.vxsat := 0.U vcvt_result.result(i) := vcvt.io.result vcvt_result.fflags(i) := vcvt.io.fflags + + // i2fcvt + i2fcvt.regEnables(0) := true.B + i2fcvt.regEnables(1) := true.B + i2fcvt.io.wflags := busy + i2fcvt.io.opType := opcode(4,0) + i2fcvt.io.rm := rm + i2fcvt.io.rmInst := 7.U + i2fcvt.io.src := src1 + i2f_result.vxsat := 0.U + i2f_result.result(i) := i2fcvt.io.result + i2f_result.fflags(i) := i2fcvt.io.fflags + + //fpcvt + fpcvt.io.fire := busy + fpcvt.io.sew := sew + 
fpcvt.io.opType := opcode + fpcvt.io.rm := rm + fpcvt.io.src := src1 + fpcvt.io.isFpToVecInst := true.B + fpcvt_result.vxsat := 0.U + fpcvt_result.result(i) := fpcvt.io.result + fpcvt_result.fflags(i) := fpcvt.io.fflags } val vperm = Module(new VPermTop) @@ -352,7 +385,9 @@ class SimTop() extends VPUTestModule { VPUTestFuType.vperm -> vperm_result, VPUTestFuType.viaf -> viaf_result, VPUTestFuType.vid -> vid_result, - VPUTestFuType.vcvt -> vcvt_result + VPUTestFuType.vcvt -> vcvt_result, + VPUTestFuType.fcvtf2x -> fpcvt_result, + VPUTestFuType.fcvti2f -> i2f_result )) }