Skip to content

Commit

Permalink
FCVT:add conversion for FP16 and modified CVT64 module to parameteriz…
Browse files Browse the repository at this point in the history
…e it

* The scalar float conversion function in the CVT64 module can now be parameterized and trimmed.
* Add scalar IntToFP conversion for FP16.
  • Loading branch information
zmx2018 committed Aug 23, 2024
1 parent fdd7611 commit df5e161
Show file tree
Hide file tree
Showing 15 changed files with 1,450 additions and 354 deletions.
11 changes: 8 additions & 3 deletions src/main/scala/yunsuan/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -530,18 +530,15 @@ object VfcvtType {
def vfcvt_fxv = "b01_000011".U(8.W)
def vfcvt_rtz_xufv = "b10_000110".U(8.W)
def vfcvt_rtz_xfv = "b10_000111".U(8.W)

def vfrsqrt7 = "b11_100000".U(8.W)
def vfrec7 = "b11_100001".U(8.W)

def vfwcvt_xufv = "b10_001000".U(8.W)
def vfwcvt_xfv = "b10_001001".U(8.W)
def vfwcvt_fxuv = "b01_001010".U(8.W)
def vfwcvt_fxv = "b01_001011".U(8.W)
def vfwcvt_ffv = "b11_001100".U(8.W)
def vfwcvt_rtz_xufv = "b10_001110".U(8.W)
def vfwcvt_rtz_xfv = "b10_001111".U(8.W)

def vfncvt_xufw = "b10_010000".U(8.W)
def vfncvt_xfw = "b10_010001".U(8.W)
def vfncvt_fxuw = "b01_010010".U(8.W)
Expand All @@ -550,6 +547,14 @@ object VfcvtType {
def vfncvt_rod_ffw = "b11_010101".U(8.W)
def vfncvt_rtz_xufw = "b10_010110".U(8.W)
def vfncvt_rtz_xfw = "b10_010111".U(8.W)
def fcvt_h_s = "b11_010000".U(8.W)
def fcvt_s_h = "b11_001000".U(8.W)
def fcvt_h_d = "b11_011000".U(8.W)
def fcvt_d_h = "b11_011000".U(8.W)
def fcvt_w_h = "b10_001001".U(8.W)
def fcvt_wu_h = "b10_001000".U(8.W)
def fcvt_l_h = "b10_011001".U(8.W)
def fcvt_lu_h = "b10_011000".U(8.W)
}


Expand Down
160 changes: 160 additions & 0 deletions src/main/scala/yunsuan/scalar/Convert.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package yunsuan.scalar

import chisel3._
import chisel3.util._
import chisel3.util.experimental.decode._
import yunsuan.util._
import yunsuan.vector.VectorConvert.CVT64

// Scalar Int to Float Convert.
// IO bundle of the scalar integer-to-float converter (see INT2FP, which consumes it).
class I2fCvtIO extends Bundle{
// Source operand. INT2FP treats it as a 64-bit integer or uses only bits 31..0,
// depending on opType(3).
val src = Input(UInt(64.W))
// Operation selector as decoded by INT2FP:
//   bit 3    - input width (1 = all of src, 0 = src(31, 0))
//   bits 2:1 - index of the destination float type in FPU.ftypes
//   bit 0    - signedness of the integer (1 = sign-extend, 0 = zero-extend)
// Bit 4 is not read by INT2FP.
val opType = Input(UInt(5.W))
// Rounding mode used when rmInst is "b111".
val rm = Input(UInt(3.W))
// When set, INT2FP performs the conversion and produces flags; when clear,
// src is passed through as the data and the flags are zero.
val wflags = Input(Bool())
// Rounding mode field of the instruction; the value "b111" makes INT2FP use rm instead.
val rmInst = Input(UInt(3.W))

// Converted value, boxed by FPU.box according to the destination type.
val result = Output(UInt(64.W))
// Exception flags produced by the selected IntToFP unit (zero when wflags is clear).
val fflags = Output(UInt(5.W))
}
/**
 * Scalar integer-to-float converter with two register stages.
 *
 * Stage 1 extends the source integer (zero or sign extension, 32- or 64-bit input),
 * stage 2 runs one IntToFP unit per entry of FPU.ftypes and picks the requested
 * result, and the final register stage boxes the value with FPU.box.
 *
 * @param latency number of entries in `regEnables`; entries 0 and 1 are used as the
 *                enables of the two register stages, so it must be at least 2
 * @param XLEN    width of the internal integer data path; must be 64 because the
 *                IntToFP units take a 64-bit integer and FPU.box requires 64-bit data
 */
class INT2FP(latency: Int, XLEN: Int) extends Module {
  // Fail early with a readable message instead of an index error on regEnables(1)
  // or a width failure inside FPU.box.
  require(latency >= 2, s"INT2FP uses regEnables(0) and regEnables(1): latency must be >= 2, got $latency")
  require(XLEN == 64, s"INT2FP drives 64-bit IntToFP units and FPU.box: XLEN must be 64, got $XLEN")

  val io = IO(new I2fCvtIO)
  // "b111" in the instruction field selects the rounding mode supplied on io.rm.
  val rm = Mux(io.rmInst === "b111".U, io.rm, io.rmInst)
  val regEnables = IO(Input(Vec(latency, Bool())))
  dontTouch(regEnables)

  // stage1: decode the operation and extend the integer operand
  val in = io.src
  val wflags = io.wflags
  val typeIn = io.opType(3)     // 1: use all of src, 0: use src(31, 0)
  val typeOut = io.opType(2,1)  // index into FPU.ftypes
  val signIn = io.opType(0)     // 1: sign-extend, 0: zero-extend
  val intValue = RegEnable(Mux(wflags,
    Mux(typeIn,
      Mux(!signIn, ZeroExt(in, XLEN), SignExt(in, XLEN)),
      Mux(!signIn, ZeroExt(in(31, 0), XLEN), SignExt(in(31, 0), XLEN))
    ),
    in // no conversion requested: forward the raw source
  ), regEnables(0))
  val typeInReg = RegEnable(typeIn, regEnables(0))
  val typeOutReg = RegEnable(typeOut, regEnables(0))
  val signInReg = RegEnable(signIn, regEnables(0))
  val wflagsReg = RegEnable(wflags, regEnables(0))
  val rmReg = RegEnable(rm, regEnables(0))

  // stage2: convert with every supported float type and select the requested one
  val s2_typeInReg = typeInReg
  val s2_signInReg = signInReg
  val s2_typeOutReg = typeOutReg
  val s2_wflags = wflagsReg
  val s2_rmReg = rmReg

  val mux = Wire(new Bundle() {
    val data = UInt(XLEN.W)
    val exc = UInt(5.W)
  })

  // Defaults for the pass-through case (wflags clear): raw data, no flags.
  mux.data := intValue
  mux.exc := 0.U

  when(s2_wflags){
    val i2fResults = for(t <- FPU.ftypes) yield {
      val i2f = Module(new IntToFP(t.expWidth, t.precision))
      i2f.io.sign := s2_signInReg
      i2f.io.long := s2_typeInReg
      i2f.io.int := intValue
      i2f.io.rm := s2_rmReg
      (i2f.io.result, i2f.io.fflags)
    }
    val (data, exc) = i2fResults.unzip
    mux.data := VecInit(data)(s2_typeOutReg)
    mux.exc := VecInit(exc)(s2_typeOutReg)
  }

  // stage 3: register the selected result and box it to 64 bits
  val s3_out = RegEnable(mux, regEnables(1))
  val s3_tag = RegEnable(s2_typeOutReg, regEnables(1))

  io.fflags := s3_out.exc
  io.result := FPU.box(s3_out.data, s3_tag)
}
// Scalar Float to Int or Float Convert.
// IO bundle of the scalar float-to-int / float-to-float converter (see FPCVT).
// `width` sets the width of the source operand and of the result.
class FpCvtIO(width: Int) extends Bundle {
// Forwarded unchanged to CVT64.io.fire.
// NOTE(review): presumably marks a valid operation entering the pipeline - confirm in CVT64.
val fire = Input(Bool())
// Source operand.
val src = Input(UInt(width.W))
// Conversion opcode; FPCVT reads bits 4..3 to pick the input/output widths
// and forwards the whole field to CVT64.
val opType = Input(UInt(8.W))
// Element-width code; combined with opType(4, 3) by FPCVT to derive the
// one-hot input and output widths.
val sew = Input(UInt(2.W))
// Rounding mode, forwarded to CVT64.
val rm = Input(UInt(3.W))
// Forwarded unchanged to CVT64.io.isFpToVecInst.
// NOTE(review): meaning is defined by CVT64, which is not visible here.
val isFpToVecInst = Input(Bool())

// Conversion result taken from CVT64.
val result = Output(UInt(width.W))
// Exception flags taken from CVT64.
val fflags = Output(UInt(5.W))
}
/**
 * Scalar float-to-int / float-to-float conversion wrapper around CVT64.
 *
 * Decodes the one-hot input and output operand widths from `opType(4, 3)` and `sew`
 * and hands everything else straight to a CVT64 instance.
 *
 * @param xlen width of the source operand and of the result in FpCvtIO
 */
class FPCVT(xlen :Int) extends Module{
  val io = IO(new FpCvtIO(xlen))
  val (opType, sew) = (io.opType, io.sew)
  // 0 -> single width, 1 -> widen, 2 -> narrow, 3 -> f16 <-> 64-bit conversions
  val widen = opType(4, 3)
  // Both width tables are indexed by the same 4-bit selector.
  private val widthSel = widen ## sew

  // input width as one-hot: bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64
  val input1H = Wire(UInt(4.W))
  input1H := chisel3.util.experimental.decode.decoder(
    widthSel,
    TruthTable(
      Seq(
        BitPat("b00_01") -> BitPat("b0010"), // 16
        BitPat("b00_10") -> BitPat("b0100"), // 32
        BitPat("b00_11") -> BitPat("b1000"), // 64

        BitPat("b01_00") -> BitPat("b0001"), // 8
        BitPat("b01_01") -> BitPat("b0010"), // 16
        BitPat("b01_10") -> BitPat("b0100"), // 32

        BitPat("b10_00") -> BitPat("b0010"), // 16
        BitPat("b10_01") -> BitPat("b0100"), // 32
        BitPat("b10_10") -> BitPat("b1000"), // 64

        BitPat("b11_01") -> BitPat("b0010"), // f16->f64/i64/ui64
        BitPat("b11_11") -> BitPat("b1000")  // f64->f16
      ),
      BitPat("b0000") // unlisted combinations select no width
    )
  )

  // output width as one-hot, same bit assignment as input1H
  val output1H = Wire(UInt(4.W))
  output1H := chisel3.util.experimental.decode.decoder(
    widthSel,
    TruthTable(
      Seq(
        BitPat("b00_01") -> BitPat("b0010"), // 16
        BitPat("b00_10") -> BitPat("b0100"), // 32
        BitPat("b00_11") -> BitPat("b1000"), // 64

        BitPat("b01_00") -> BitPat("b0010"), // 16
        BitPat("b01_01") -> BitPat("b0100"), // 32
        BitPat("b01_10") -> BitPat("b1000"), // 64

        BitPat("b10_00") -> BitPat("b0001"), // 8
        BitPat("b10_01") -> BitPat("b0010"), // 16
        BitPat("b10_10") -> BitPat("b0100"), // 32

        BitPat("b11_11") -> BitPat("b0010"), // f64->f16
        BitPat("b11_01") -> BitPat("b1000")  // f16->f64/i64/ui64
      ),
      BitPat("b0000") // unlisted combinations select no width
    )
  )
  dontTouch(input1H)
  dontTouch(output1H)

  // NOTE(review): the converter is built with a literal 64 although the IO is sized
  // by xlen - confirm whether other xlen values are meant to be supported.
  val fcvt = Module(new CVT64(64, false))
  fcvt.io.sew := io.sew
  fcvt.io.fire := io.fire
  fcvt.io.src := io.src
  fcvt.io.rm := io.rm
  fcvt.io.opType := io.opType
  fcvt.io.isFpToVecInst := io.isFpToVecInst
  fcvt.io.input1H := input1H
  fcvt.io.output1H := output1H

  io.fflags := fcvt.io.fflags
  io.result := fcvt.io.result

}
40 changes: 40 additions & 0 deletions src/main/scala/yunsuan/scalar/FPU.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package yunsuan.scalar

import chisel3._
import chisel3.util._

/** Float format descriptors and boxing helpers shared by the scalar conversion units. */
object FPU {

  /**
   * Description of a binary floating-point format.
   *
   * @param expWidth  number of exponent bits
   * @param precision number of significand bits including the hidden bit
   */
  case class FType(expWidth: Int, precision: Int) {
    val sigWidth = precision - 1
    val len = expWidth + precision
  }

  val f16 = FType(5, 11)
  val f32 = FType(8, 24)
  val f64 = FType(11, 53)

  // The position in this list is the hardware type tag (see H, S and D below).
  val ftypes = List(f16, f32, f64)

  val H = ftypes.indexOf(f16).U(log2Ceil(ftypes.length).W)
  val S = ftypes.indexOf(f32).U(log2Ceil(ftypes.length).W)
  val D = ftypes.indexOf(f64).U(log2Ceil(ftypes.length).W)


  /**
   * Boxes a value into 64 bits according to a hardware type tag.
   *
   * A D tag returns `x` unchanged, an S tag keeps the low 32 bits and fills the upper
   * 32 with ones, and every other tag keeps the low 16 bits and fills the upper 48
   * with ones.
   *
   * @param x       64-bit value whose low bits hold the float
   * @param typeTag type tag, compared against D and S
   * @return the 64-bit boxed value
   */
  def box(x: UInt, typeTag: UInt): UInt = {
    require(x.getWidth == 64)
    Mux(typeTag === D, x, Mux(typeTag === S, Cat(~0.U(32.W), x(31, 0)), Cat(~0.U(48.W), x(15, 0))))
  }

  /**
   * Boxes a value into 64 bits for a float type known at elaboration time.
   *
   * Handles f16, f32 and f64 with the same layout as the tag-based overload;
   * any other FType fails elaboration.
   *
   * @param x value whose low bits hold the float
   * @param t float type of the value
   * @return the 64-bit boxed value
   */
  def box(x: UInt, t: FType): UInt = {
    if(t == f16){
      // Same half-precision layout as the tag-based overload.
      Cat(~0.U(48.W), x(15, 0))
    } else if(t == f32){
      Cat(~0.U(32.W), x(31, 0))
    } else if(t == f64){
      x(63, 0)
    } else {
      assert(cond = false, "Unknown ftype!")
      0.U
    }
  }

}
122 changes: 122 additions & 0 deletions src/main/scala/yunsuan/scalar/IntToFP.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package yunsuan.scalar

import chisel3._
import chisel3.util._
import yunsuan.vector.VectorConvert.RoundingModle._
// Inputs of the integer pre-normalization step (IntToFP_prenorm).
class IntToFP_prenorm_in extends Bundle {
// Integer operand.
val int = Input(UInt(64.W))
// Set when the operand is to be interpreted as signed.
val sign = Input(Bool())
// Set for a 64-bit operand (sign taken from bit 63); clear for 32-bit (sign from bit 31).
val long = Input(Bool())
}

// Outputs of IntToFP_prenorm, consumed by IntToFP_postnorm.
class IntToFP_prenorm_out extends Bundle {
// Magnitude shifted left by the leading-zero count, bits 62..0 of the shifted value.
val norm_int = Output(UInt(63.W))
// Leading-zero count applied to the magnitude, including the one-step correction.
val lzc = Output(UInt(6.W))
// Set when the integer operand is zero.
val is_zero = Output(Bool())
// Sign of the operand: set only for a signed operand whose sign bit is set.
val sign = Output(Bool())
}

/**
 * Integer pre-normalization: extracts the sign, takes the magnitude and shifts it
 * left until its leading one is at bit 63, reporting the shift amount.
 *
 * The result does not depend on the target float format, so different fp types
 * can share this unit.
 */
class IntToFP_prenorm extends Module {

  val io = IO(new Bundle() {
    val in = new IntToFP_prenorm_in
    val out = new IntToFP_prenorm_out
  })

  val (in, signed_int, long_int) = (io.in.int, io.in.sign, io.in.long)

  // Negative only when the operand is signed and its sign bit (63 or 31) is set.
  val in_sign = signed_int && Mux(long_int, in(63), in(31))
  val in_sext = Cat(Fill(32, in(31)), in(31, 0))
  // Signed 32-bit operands are sign-extended from bit 31; everything else is used as is.
  val in_raw = Mux(signed_int && !long_int, in_sext, in)
  // Two's-complement negation yields the magnitude of a negative operand.
  val in_abs = Mux(in_sign, (~in_raw).asUInt + 1.U, in_raw)

  // NOTE(review): LZA is defined elsewhere; presumably a leading-zero anticipator
  // for 0 + ~in_raw whose estimate may be one too small, which lzc_error corrects - confirm.
  val lza = Module(new LZA(64))
  lza.io.a := 0.U
  lza.io.b := ~in_raw

  val pos_lzc = CLZ(in)
  val neg_lzc = CLZ(lza.io.f)

  val lzc = Mux(in_sign, neg_lzc, pos_lzc)

  // Keep only the most significant set bit of lza.io.f, eg: 001010 => 001000.
  // Bit 0 is set whenever no higher bit of lza.io.f is set.
  // `.reverse.map` replaces `reverseMap`, which is deprecated in Scala 2.13.
  val one_mask = Cat((0 until 64).reverse.map {
    case i @ 63 => lza.io.f(i)
    case i @ 0 => !lza.io.f(63, i + 1).orR
    case i => lza.io.f(i) && !lza.io.f(63, i + 1).orR
  })

  // For negative operands: the magnitude has no one at the marked position,
  // so the count is one short.
  val lzc_error = Mux(in_sign, !(in_abs & one_mask).orR, false.B)

  val in_shift_s1 = (in_abs << lzc)(62, 0)
  // Shift by one more position when the count was one short.
  val in_norm = Mux(lzc_error, Cat(in_shift_s1.tail(1), 0.U(1.W)), in_shift_s1)

  io.out.norm_int := in_norm
  io.out.lzc := lzc + lzc_error
  io.out.is_zero := in === 0.U
  io.out.sign := in_sign

}

/**
 * Integer-to-float post-normalization: rounds the normalized magnitude to the
 * target precision, builds the biased exponent and packs the result.
 *
 * The exponent is computed in 11 bits for every format so that values too large for
 * a narrow format (e.g. f16) are detected as overflow instead of wrapping around.
 *
 * @param expWidth  number of exponent bits of the target format (at most 11)
 * @param precision number of significand bits of the target format, hidden bit included
 */
class IntToFP_postnorm(val expWidth: Int, val precision: Int) extends Module {
  require(expWidth <= 11, s"the exponent is computed in 11 bits: expWidth must be <= 11, got $expWidth")

  val io = IO(new Bundle() {
    val in = Flipped(new IntToFP_prenorm_out)
    val rm = Input(UInt(3.W))
    val result = Output(UInt((expWidth + precision).W))
    val fflags = Output(UInt(5.W))
  })
  val (in, lzc, is_zero, in_sign, rm) =
    (io.in.norm_int, io.in.lzc, io.in.is_zero, io.in.sign, io.rm)

  // A leading one at bit 63 means 2^63; every leading zero lowers the exponent by one.
  val exp_raw = (63 + FloatPoint.expBias(expWidth)).U(11.W) - lzc
  val sig_raw = in.head(precision - 1) // exclude hidden bit
  val round_bit = in.tail(precision - 1).head(1)
  val sticky_bit = in.tail(precision).orR

  val rounder = Module(new RoundingUnit(precision - 1))
  rounder.io.in := sig_raw
  rounder.io.roundIn := round_bit
  rounder.io.stickyIn := sticky_bit
  rounder.io.signIn := in_sign
  rounder.io.rm := rm

  // Rounding modes under which an overflowing value saturates to the largest
  // finite number instead of becoming infinity.
  val rmin =
    io.rm === RTZ || (in_sign && io.rm === RUP) || (!in_sign && io.rm === RDN)

  val nv, dz, of, uf, nx = Wire(Bool())
  val ix = rounder.io.inexact
  val fp_exp = Mux(is_zero, 0.U, exp_raw + rounder.io.cout)
  val fp_sig = rounder.io.out
  // The rounded exponent exceeds the largest finite exponent of the target format.
  val flow = fp_exp>((1<<expWidth)-2).U

  nv := false.B
  dz := false.B
  // Overflow depends on the magnitude only. An integer is never tiny, so this
  // conversion cannot underflow: a too-large negative value is an overflow as well.
  of := flow
  uf := false.B
  nx := flow || ix

  val overflow_result = Mux(rmin,
    Cat(in_sign, FloatPoint.maxNormExp(expWidth).U(expWidth.W), ~0.U((precision-1).W)), // largest finite
    Cat(in_sign, ~0.U(expWidth.W), 0.U((precision-1).W)))                               // infinity
  // fp_exp is 11 bits wide. Without the explicit slice the concatenation is wider than
  // io.result for formats with fewer exponent bits, and the connection would drop the
  // sign bit. When flow is clear the exponent fits in expWidth bits, so nothing is lost.
  // NOTE(review): assumes rounder.io.out is precision - 1 bits wide - confirm in RoundingUnit.
  val normal_result = Cat(in_sign, fp_exp(expWidth - 1, 0), fp_sig)

  io.result := Mux(flow, overflow_result, normal_result)
  io.fflags := Cat(nv, dz, of, uf, nx)
}

/**
 * Integer-to-float converter for one target format: chains the format-independent
 * pre-normalization step with the format-specific rounding and packing step.
 *
 * @param expWidth  number of exponent bits of the target format
 * @param precision number of significand bits of the target format, hidden bit included
 */
class IntToFP(val expWidth: Int, val precision: Int) extends Module {
  val io = IO(new IntToFP_prenorm_in {
    val rm = Input(UInt(3.W))
    val result = Output(UInt((expWidth + precision).W))
    val fflags = Output(UInt(5.W))
  })

  val pre_norm = Module(new IntToFP_prenorm)
  val post_norm = Module(new IntToFP_postnorm(expWidth, precision))

  // Results come from the post-normalization step.
  io.fflags := post_norm.io.fflags
  io.result := post_norm.io.result

  // Only the three operand fields of io feed the pre-normalization step.
  pre_norm.io.in.int := io.int
  pre_norm.io.in.sign := io.sign
  pre_norm.io.in.long := io.long

  // Hand the normalized operand and the rounding mode to the post-normalization step.
  post_norm.io.rm := io.rm
  post_norm.io.in.norm_int := pre_norm.io.out.norm_int
  post_norm.io.in.lzc := pre_norm.io.out.lzc
  post_norm.io.in.is_zero := pre_norm.io.out.is_zero
  post_norm.io.in.sign := pre_norm.io.out.sign
}
Loading

0 comments on commit df5e161

Please sign in to comment.