-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FCVT: add conversion for FP16 and modify the CVT64 module to parameterize it
* The scalar float conversion function can be parameterized and trimmed in CVT64. * Add scalar IntToFP conversion for FP16.
- Loading branch information
Showing
15 changed files
with
1,450 additions
and
354 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package yunsuan.scalar | ||
|
||
import chisel3._ | ||
import chisel3.util._ | ||
import chisel3.util.experimental.decode._ | ||
import yunsuan.util._ | ||
import yunsuan.vector.VectorConvert.CVT64 | ||
|
||
// Scalar Int to Float Convert. | ||
/**
 * Port bundle of the scalar integer-to-float converter (see INT2FP).
 */
class I2fCvtIO extends Bundle {
  // 64-bit source operand.
  val src = Input(UInt(64.W))
  // Operation selector. INT2FP reads bit 3 as "source is 64-bit", bits 2:1 as the
  // index of the destination format in FPU.ftypes, and bit 0 as "source is signed".
  val opType = Input(UInt(5.W))
  // Rounding mode that INT2FP uses when rmInst is "b111".
  val rm = Input(UInt(3.W))
  // When low, INT2FP skips the conversion and reports all-zero flags.
  val wflags = Input(Bool())
  // Rounding mode supplied with the instruction; "b111" defers to rm.
  val rmInst = Input(UInt(3.W))

  // Result value (INT2FP drives it through FPU.box).
  val result = Output(UInt(64.W))
  // Exception flags belonging to result.
  val fflags = Output(UInt(5.W))
}
/**
 * Scalar integer-to-floating-point conversion with two register stages.
 *
 * Stage 1 extends the selected integer source to XLEN bits and captures it, together
 * with the decoded control bits, under regEnables(0). Stage 2 converts the captured
 * integer to every format in FPU.ftypes, selects the one addressed by opType(2,1) and
 * captures it under regEnables(1). The captured value is boxed with FPU.box on output.
 *
 * @param latency number of entries of the regEnables port; entries 0 and 1 are used
 * @param XLEN    width the source integer is extended to; FPU.box requires the value
 *                handed to it to be 64 bits wide
 */
class INT2FP(latency: Int, XLEN: Int) extends Module {
  val io = IO(new I2fCvtIO)
  // An instruction rounding mode of "b111" defers to the rm input.
  val rm = Mux(io.rmInst === "b111".U, io.rm, io.rmInst)
  val regEnables = IO(Input(Vec(latency, Bool())))
  dontTouch(regEnables)

  // stage1: decode opType and widen the source operand
  val in = io.src
  val wflags = io.wflags
  val typeIn = io.opType(3)
  val typeOut = io.opType(2, 1)
  val signIn = io.opType(0)

  // Sign-extends when signIn is set, zero-extends otherwise.
  private def extendToXlen(x: UInt): UInt =
    Mux(signIn, SignExt(x, XLEN), ZeroExt(x, XLEN))

  // Without wflags the source is captured as it is; otherwise either the full source
  // (typeIn set) or its low 32 bits are extended first.
  val intValue = RegEnable(
    Mux(wflags, Mux(typeIn, extendToXlen(in), extendToXlen(in(31, 0))), in),
    regEnables(0)
  )
  val typeInReg = RegEnable(typeIn, regEnables(0))
  val typeOutReg = RegEnable(typeOut, regEnables(0))
  val signInReg = RegEnable(signIn, regEnables(0))
  val wflagsReg = RegEnable(wflags, regEnables(0))
  val rmReg = RegEnable(rm, regEnables(0))

  // stage2: convert the captured integer
  val s2_typeInReg = typeInReg
  val s2_signInReg = signInReg
  val s2_typeOutReg = typeOutReg
  val s2_wflags = wflagsReg
  val s2_rmReg = rmReg

  val mux = Wire(new Bundle() {
    val data = UInt(XLEN.W)
    val exc = UInt(5.W)
  })

  // Defaults, overridden below when a conversion was requested.
  mux.data := intValue
  mux.exc := 0.U

  when(s2_wflags) {
    // One converter per entry of FPU.ftypes; the destination-format index picks the
    // pair of result and flags that is kept.
    val i2fResults = FPU.ftypes.map { t =>
      val i2f = Module(new IntToFP(t.expWidth, t.precision))
      i2f.io.sign := s2_signInReg
      i2f.io.long := s2_typeInReg
      i2f.io.int := intValue
      i2f.io.rm := s2_rmReg
      (i2f.io.result, i2f.io.fflags)
    }
    val (data, exc) = i2fResults.unzip
    mux.data := VecInit(data)(s2_typeOutReg)
    mux.exc := VecInit(exc)(s2_typeOutReg)
  }

  // stage 3: registered outputs
  val s3_out = RegEnable(mux, regEnables(1))
  val s3_tag = RegEnable(s2_typeOutReg, regEnables(1))

  io.fflags := s3_out.exc
  io.result := FPU.box(s3_out.data, s3_tag)
}
// Scalar Float to Int or Float Convert. | ||
/**
 * Port bundle of the scalar float conversion wrapper (see FPCVT).
 *
 * @param width bit width of the src and result ports
 */
class FpCvtIO(width: Int) extends Bundle {
  // Forwarded by FPCVT to the fire port of the conversion core.
  val fire = Input(Bool())
  // Source operand.
  val src = Input(UInt(width.W))
  // Operation selector; FPCVT reads bits 4:3 as the widen/narrow field.
  val opType = Input(UInt(8.W))
  // Element-width selector; FPCVT decodes it together with opType(4,3) into the
  // one-hot input and output widths.
  val sew = Input(UInt(2.W))
  // Rounding mode.
  val rm = Input(UInt(3.W))
  // Forwarded unchanged to the conversion core.
  val isFpToVecInst = Input(Bool())

  // Conversion result.
  val result = Output(UInt(width.W))
  // Exception flags belonging to result.
  val fflags = Output(UInt(5.W))
}
/**
 * Scalar float conversion wrapper around CVT64.
 *
 * Decodes the widen/narrow field opType(4,3) together with sew into one-hot input and
 * output widths, and forwards these plus all request ports to a CVT64 instance.
 *
 * @param xlen width of the src and result ports of the IO bundle
 */
class FPCVT(xlen: Int) extends Module {
  val io = IO(new FpCvtIO(xlen))
  val (opType, sew) = (io.opType, io.sew)
  val widen = opType(4, 3) // 0->single 1->widen 2->narrow => width of result

  // Decodes (widen ## sew) through the given rows; selectors that are not listed
  // yield "b0000". One-hot positions: b0001 = 8, b0010 = 16, b0100 = 32, b1000 = 64.
  private def widthDecode(rows: Seq[(String, String)]): UInt =
    chisel3.util.experimental.decode.decoder(
      widen ## sew,
      TruthTable(
        rows.map { case (select, oneHot) => BitPat(select) -> BitPat(oneHot) },
        BitPat("b0000")
      )
    )

  // input width 8, 16, 32, 64
  val input1H = Wire(UInt(4.W))
  input1H := widthDecode(Seq(
    "b00_01" -> "b0010", // 16
    "b00_10" -> "b0100", // 32
    "b00_11" -> "b1000", // 64

    "b01_00" -> "b0001", // 8
    "b01_01" -> "b0010", // 16
    "b01_10" -> "b0100", // 32

    "b10_00" -> "b0010", // 16
    "b10_01" -> "b0100", // 32
    "b10_10" -> "b1000", // 64

    "b11_01" -> "b0010", // f16->f64/i64/ui64
    "b11_11" -> "b1000"  // f64->f16
  ))

  // output width 8, 16, 32, 64
  val output1H = Wire(UInt(4.W))
  output1H := widthDecode(Seq(
    "b00_01" -> "b0010", // 16
    "b00_10" -> "b0100", // 32
    "b00_11" -> "b1000", // 64

    "b01_00" -> "b0010", // 16
    "b01_01" -> "b0100", // 32
    "b01_10" -> "b1000", // 64

    "b10_00" -> "b0001", // 8
    "b10_01" -> "b0010", // 16
    "b10_10" -> "b0100", // 32

    "b11_11" -> "b0010", // f64->f16
    "b11_01" -> "b1000"  // f16->f64/i64/ui64
  ))
  dontTouch(input1H)
  dontTouch(output1H)

  val fcvt = Module(new CVT64(64, false))
  fcvt.io.sew := io.sew
  fcvt.io.fire := io.fire
  fcvt.io.src := io.src
  fcvt.io.rm := io.rm // connected once; the original repeated this assignment
  fcvt.io.opType := io.opType
  fcvt.io.isFpToVecInst := io.isFpToVecInst
  fcvt.io.input1H := input1H
  fcvt.io.output1H := output1H

  io.fflags := fcvt.io.fflags
  io.result := fcvt.io.result
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package yunsuan.scalar | ||
|
||
import chisel3._ | ||
import chisel3.util._ | ||
|
||
/**
 * Floating-point format descriptors and boxing helpers shared by the scalar
 * conversion units.
 */
object FPU {

  /**
   * Describes one floating-point format.
   *
   * @param expWidth  number of exponent bits
   * @param precision value that sigWidth and len are derived from
   *                  (sigWidth = precision - 1, len = expWidth + precision)
   */
  case class FType(expWidth: Int, precision: Int) {
    val sigWidth = precision - 1
    val len = expWidth + precision
  }

  val f16 = FType(5, 11)
  val f32 = FType(8, 24)
  val f64 = FType(11, 53)

  // The position in this list is the type tag of the format.
  val ftypes = List(f16, f32, f64)

  val H = ftypes.indexOf(f16).U(log2Ceil(ftypes.length).W)
  val S = ftypes.indexOf(f32).U(log2Ceil(ftypes.length).W)
  val D = ftypes.indexOf(f64).U(log2Ceil(ftypes.length).W)

  /**
   * Boxes a value according to a run-time type tag.
   *
   * Tag D returns x unchanged and tag S keeps the low 32 bits under 32 one-bits; any
   * other tag keeps the low 16 bits under 48 one-bits.
   *
   * @param x       64-bit value (checked at elaboration time)
   * @param typeTag format tag, compared against D and S
   * @return the 64-bit boxed value
   */
  def box(x: UInt, typeTag: UInt): UInt = {
    require(x.getWidth == 64)
    Mux(typeTag === D, x, Mux(typeTag === S, Cat(~0.U(32.W), x(31, 0)), Cat(~0.U(48.W), x(15, 0))))
  }

  /**
   * Boxes a value for a format that is known at elaboration time.
   *
   * f16 is boxed the same way as by the tag-based overload, so every member of
   * ftypes is accepted. Any other format fails the elaboration-time assertion.
   *
   * @param x value to box; at least as wide as the bits extracted for the format
   * @param t destination format
   * @return the 64-bit boxed value
   */
  def box(x: UInt, t: FType): UInt = {
    if (t == f64) {
      x(63, 0)
    } else if (t == f32) {
      Cat(~0.U(32.W), x(31, 0))
    } else if (t == f16) {
      Cat(~0.U(48.W), x(15, 0))
    } else {
      assert(cond = false, "Unknown ftype!")
      0.U
    }
  }

}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package yunsuan.scalar | ||
|
||
import chisel3._ | ||
import chisel3.util._ | ||
import yunsuan.vector.VectorConvert.RoundingModle._ | ||
/**
 * Request ports of the integer pre-normalization stage (see IntToFP_prenorm).
 */
class IntToFP_prenorm_in extends Bundle {
  // Integer operand.
  val int = Input(UInt(64.W))
  // Set when the operand is to be treated as signed.
  val sign = Input(Bool())
  // Set for a 64-bit operand; when clear, bit 31 is the sign position.
  val long = Input(Bool())
}
|
||
/**
 * Result ports of the integer pre-normalization stage (see IntToFP_prenorm).
 */
class IntToFP_prenorm_out extends Bundle {
  // Low 63 bits of the left-shifted magnitude.
  val norm_int = Output(UInt(63.W))
  // Shift amount that was applied to the magnitude.
  val lzc = Output(UInt(6.W))
  // Set when the integer operand is all zeros.
  val is_zero = Output(Bool())
  // Sign of the operand: signed request and sign position set.
  val sign = Output(Bool())
}
|
||
/**
 * Pre-normalization stage of the integer-to-float conversion: takes the magnitude of
 * the operand and shifts it left by its leading-zero count.
 *
 * Different fp types can share this unit.
 */
class IntToFP_prenorm extends Module {

  val io = IO(new Bundle() {
    val in = new IntToFP_prenorm_in
    val out = new IntToFP_prenorm_out
  })

  val (in, signed_int, long_int) = (io.in.int, io.in.sign, io.in.long)

  // Negative only for a signed request whose sign position (bit 63 or bit 31) is set.
  val in_sign = signed_int && Mux(long_int, in(63), in(31))
  // Signed 32-bit operands are sign-extended to 64 bits before negation.
  val in_sext = Cat(Fill(32, in(31)), in(31, 0))
  val in_raw = Mux(signed_int && !long_int, in_sext, in)
  // Two's-complement negation yields the magnitude of a negative operand.
  val in_abs = Mux(in_sign, (~in_raw).asUInt + 1.U, in_raw)

  // NOTE(review): LZA presumably anticipates the leading zeros of a + b, i.e. of
  // -in_raw - 1, which may be one too few for the true magnitude; confirm against LZA.
  val lza = Module(new LZA(64))
  lza.io.a := 0.U
  lza.io.b := ~in_raw

  val pos_lzc = CLZ(in)
  val neg_lzc = CLZ(lza.io.f)

  val lzc = Mux(in_sign, neg_lzc, pos_lzc)

  // eg: 001010 => 001000
  // Bit 0 is set whenever no higher bit of lza.io.f is set.
  val one_mask = Cat((63 to 0 by -1).map {
    case 63 => lza.io.f(63)
    case 0  => !lza.io.f(63, 1).orR
    case i  => lza.io.f(i) && !lza.io.f(63, i + 1).orR
  })

  // For a negative operand, the anticipated count is off when the magnitude has no
  // bit at the position marked by one_mask.
  val lzc_error = in_sign && !(in_abs & one_mask).orR

  val in_shift_s1 = (in_abs << lzc)(62, 0)
  // One extra left shift compensates for a count that was off.
  val in_norm = Mux(lzc_error, Cat(in_shift_s1.tail(1), 0.U(1.W)), in_shift_s1)

  io.out.norm_int := in_norm
  io.out.lzc := lzc + lzc_error
  io.out.is_zero := in === 0.U
  io.out.sign := in_sign

}
|
||
/**
 * Post-normalization stage of the integer-to-float conversion: rounds the normalized
 * magnitude, builds the exponent and handles magnitudes that exceed the destination
 * format.
 *
 * @param expWidth  number of exponent bits of the destination format (at most 11,
 *                  the width the exponent is computed in)
 * @param precision value the result width (expWidth + precision) and the rounder
 *                  width (precision - 1) are derived from
 */
class IntToFP_postnorm(val expWidth: Int, val precision: Int) extends Module {
  val io = IO(new Bundle() {
    val in = Flipped(new IntToFP_prenorm_out)
    val rm = Input(UInt(3.W))
    val result = Output(UInt((expWidth + precision).W))
    val fflags = Output(UInt(5.W))
  })
  val (in, lzc, is_zero, in_sign, rm) =
    (io.in.norm_int, io.in.lzc, io.in.is_zero, io.in.sign, io.rm)

  // Computed in 11 bits for every format so that an exponent beyond the destination
  // range stays observable by the overflow check below.
  val exp_raw = (63 + FloatPoint.expBias(expWidth)).U(11.W) - lzc
  val sig_raw = in.head(precision - 1) // exclude hidden bit
  val round_bit = in.tail(precision - 1).head(1)
  val sticky_bit = in.tail(precision).orR

  val rounder = Module(new RoundingUnit(precision - 1))
  rounder.io.in := sig_raw
  rounder.io.roundIn := round_bit
  rounder.io.stickyIn := sticky_bit
  rounder.io.signIn := in_sign
  rounder.io.rm := rm

  // Rounding modes under which an overflowing magnitude is clamped to the largest
  // finite value instead of becoming infinity.
  val rmin =
    rm === RTZ || (in_sign && rm === RUP) || (!in_sign && rm === RDN)

  val nv, dz, of, uf, nx = Wire(Bool())
  val ix = rounder.io.inexact
  val fp_exp = Mux(is_zero, 0.U, exp_raw + rounder.io.cout)
  val fp_sig = rounder.io.out
  // Exponent above the largest normal exponent: the magnitude does not fit.
  val flow = fp_exp > ((1 << expWidth) - 2).U

  nv := false.B
  dz := false.B
  // Overflow is defined on the magnitude, so it is raised for either sign. An integer
  // source can never produce a tiny result, hence underflow is never raised.
  of := flow
  uf := false.B
  nx := flow || ix

  val max_finite = Cat(in_sign, FloatPoint.maxNormExp(expWidth).U(expWidth.W), ~0.U((precision - 1).W))
  val infinity = Cat(in_sign, ~0.U(expWidth.W), 0.U((precision - 1).W))
  // fp_exp is sliced to expWidth bits: concatenating all 11 bits would make the value
  // wider than result for narrower formats, and the implicit truncation on connect
  // would then drop the sign bit. Without overflow the removed upper bits are zero.
  val normal = Cat(in_sign, fp_exp(expWidth - 1, 0), fp_sig)

  io.result := Mux(flow, Mux(rmin, max_finite, infinity), normal)
  io.fflags := Cat(nv, dz, of, uf, nx)
}
|
||
/**
 * Integer-to-float converter: chains IntToFP_prenorm and IntToFP_postnorm.
 *
 * @param expWidth  exponent width handed to the post-normalization stage
 * @param precision precision handed to the post-normalization stage
 */
class IntToFP(val expWidth: Int, val precision: Int) extends Module {
  // Request ports of the pre-normalization stage plus rounding mode and results.
  val io = IO(new IntToFP_prenorm_in {
    val rm = Input(UInt(3.W))
    val result = Output(UInt((expWidth + precision).W))
    val fflags = Output(UInt(5.W))
  })

  val pre_norm = Module(new IntToFP_prenorm)
  val post_norm = Module(new IntToFP_postnorm(expWidth, precision))

  // Only the operand fields go to the first stage.
  pre_norm.io.in.int := io.int
  pre_norm.io.in.sign := io.sign
  pre_norm.io.in.long := io.long

  // The rounding mode joins at the second stage.
  post_norm.io.in := pre_norm.io.out
  post_norm.io.rm := io.rm

  io.result := post_norm.io.result
  io.fflags := post_norm.io.fflags
}
Oops, something went wrong.