Skip to content

Commit c9a6a34

Browse files
committed
Refactor the What4 theorems module to unify and simplify the
handling of relative vs absolute indexing for stream values.
1 parent 7ac7646 commit c9a6a34

File tree

1 file changed

+111
-158
lines changed
  • copilot-theorem/src/Copilot/Theorem

1 file changed

+111
-158
lines changed

copilot-theorem/src/Copilot/Theorem/What4.hs

+111-158
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ import qualified Control.Monad.Fail as Fail
8484
import Control.Monad.State
8585
import qualified Data.BitVector.Sized as BV
8686
import Data.Foldable (foldrM)
87-
import Data.List (elemIndex)
87+
import Data.List (elemIndex, genericLength, genericIndex)
8888
import Data.Maybe (fromJust)
8989
import qualified Data.Map as Map
9090
import Data.Parameterized.Classes
@@ -150,19 +150,19 @@ prove solver spec = do
150150
-- Define TransM action for proving properties. Doing this in TransM rather
151151
-- than IO allows us to reuse the state for each property.
152152
let proveProperties = forM (CS.specProperties spec) $ \pr -> do
153-
let bufLen (CS.Stream _ buf _ _) = length buf
153+
let bufLen (CS.Stream _ buf _ _) = genericLength buf
154154
maxBufLen = maximum (0 : (bufLen <$> CS.specStreams spec))
155155
prefix <- forM [0 .. maxBufLen - 1] $ \k -> do
156-
XBool p <- translateExprAt sym k (CS.propertyExpr pr)
156+
XBool p <- translateExpr sym (CS.propertyExpr pr) (AbsoluteOffset k)
157157
return p
158158

159159
-- translate the induction hypothesis for all values up to maxBufLen in the past
160-
ind_hyps <- forM [1 .. maxBufLen] $ \k -> do
161-
XBool hyp <- translateExpr sym (negate k) (CS.propertyExpr pr)
160+
ind_hyps <- forM [0 .. maxBufLen-1] $ \k -> do
161+
XBool hyp <- translateExpr sym (CS.propertyExpr pr) (RelativeOffset k)
162162
return hyp
163163

164164
-- translate the predicate for the "current" value
165-
XBool p <- translateExpr sym 0 (CS.propertyExpr pr)
165+
XBool p <- translateExpr sym (CS.propertyExpr pr) (RelativeOffset maxBufLen)
166166

167167
-- compute the predicate (ind_hyps ==> p)
168168
p' <- liftIO $ foldrM (WI.impliesPred sym) p ind_hyps
@@ -246,9 +246,9 @@ computePrestate ::
246246
computePrestate sym spec =
247247
do xs <- forM (CS.specStreams spec) $
248248
\CS.Stream{ CS.streamId = nm, CS.streamExprType = tp, CS.streamBuffer = buf } ->
249-
do let buflen = length buf
250-
let idxes = reverse $ map negate $ [ 1 .. buflen ]
251-
vs <- mapM (getStreamConstant sym nm) idxes
249+
do let buflen = genericLength buf
250+
let idxes = RelativeOffset <$> [ 0 .. buflen-1 ]
251+
vs <- mapM (getStreamValue sym nm) idxes
252252
return (nm, Some tp, vs)
253253
return (BisimulationProofState xs)
254254

@@ -258,12 +258,11 @@ computePoststate ::
258258
TransM t (BisimulationProofState t)
259259
computePoststate sym spec =
260260
do xs <- forM (CS.specStreams spec) $
261-
\CS.Stream{ CS.streamId = nm, CS.streamExprType = tp, CS.streamBuffer = buf, CS.streamExpr = ex } ->
262-
do let buflen = length buf
263-
let idxes = reverse $ map negate $ [ 1 .. buflen-1 ]
264-
vs <- mapM (getStreamConstant sym nm) idxes
265-
v0 <- translateExpr sym 0 ex
266-
return (nm, Some tp, vs ++ [v0])
261+
\CS.Stream{ CS.streamId = nm, CS.streamExprType = tp, CS.streamBuffer = buf } ->
262+
do let buflen = genericLength buf
263+
let idxes = RelativeOffset <$> [ 1 .. buflen ]
264+
vs <- mapM (getStreamValue sym nm) idxes
265+
return (nm, Some tp, vs)
267266
return (BisimulationProofState xs)
268267

269268
computeTriggerState ::
@@ -272,12 +271,12 @@ computeTriggerState ::
272271
TransM t [(CE.Name, WB.BoolExpr t, [(Some CT.Type, XExpr t)])]
273272
computeTriggerState sym spec = forM (CS.specTriggers spec) $
274273
\CS.Trigger{ CS.triggerName = nm, CS.triggerGuard = guard, CS.triggerArgs = args } ->
275-
do XBool guard' <- translateExpr sym 0 guard
274+
do XBool guard' <- translateExpr sym guard (RelativeOffset 0)
276275
args' <- mapM computeArg args
277276
return (nm, guard', args')
278277
where
279278
computeArg CE.UExpr{ CE.uExprType = tp, CE.uExprExpr = ex } =
280-
do v <- translateExpr sym 0 ex
279+
do v <- translateExpr sym ex (RelativeOffset 0)
281280
return (Some tp, v)
282281

283282
computeExternalInputs ::
@@ -286,12 +285,21 @@ computeExternalInputs ::
286285
computeExternalInputs sym =
287286
do exts <- Map.toList <$> gets mentionedExternals
288287
forM exts $ \(nm, Some tp) ->
289-
do v <- getExternConstant sym tp nm 0
288+
do v <- getExternConstant sym tp nm (RelativeOffset 0)
290289
return (nm, Some tp, v)
291290

292291
--------------------------------------------------------------------------------
293292
-- What4 translation
294293

294+
data StreamOffset
295+
= AbsoluteOffset !Integer
296+
| RelativeOffset !Integer
297+
deriving (Eq, Ord, Show)
298+
299+
addOffset :: StreamOffset -> CE.DropIdx -> StreamOffset
300+
addOffset (AbsoluteOffset i) j = AbsoluteOffset (i + toInteger j)
301+
addOffset (RelativeOffset i) j = RelativeOffset (i + toInteger j)
302+
295303
-- | the state for translating Copilot expressions into What4 expressions. As we
296304
-- translate, we generate fresh symbolic constants for external variables and
297305
-- for stream variables. We need to only generate one constant per variable, so
@@ -308,21 +316,11 @@ computeExternalInputs sym =
308316
data TransState t = TransState {
309317
-- | Map keeping track of all external variables encountered during translation.
310318
mentionedExternals :: Map.Map CE.Name (Some CT.Type),
311-
-- | Map of all external variables we encounter during translation. These are
312-
-- just fresh constants. The offset indicates how many timesteps in the past
313-
-- this constant represents for that stream.
314-
externVars :: Map.Map (CE.Name, Int) (XExpr t),
315-
-- | Map of external variables at specific indices (positive), rather than
316-
-- offset into the past. This is for interpreting streams at specific offsets.
317-
externVarsAt :: Map.Map (CE.Name, Int) (XExpr t),
318-
-- | Map from (stream id, negative offset) to fresh constant. These are all
319-
-- constants representing the values of a stream at some point in the past.
320-
-- The offset (ALWAYS NEGATIVE) indicates how many timesteps in the past
321-
-- this constant represents for that stream.
322-
streamConstants :: Map.Map (CE.Id, Int) (XExpr t),
323-
-- | Map from stream ids to the streams themselves. This value is never
324-
-- modified, but I didn't want to make this an RWS, so it's represented as a
325-
-- stateful value.
319+
320+
externVars :: Map.Map (CE.Name, StreamOffset) (XExpr t),
321+
322+
streamValues :: Map.Map (CE.Id, StreamOffset) (XExpr t),
323+
326324
streams :: Map.Map CE.Id CS.Stream,
327325
-- | Binary power operator, represented as an uninterpreted function.
328326
pow :: WB.ExprSymFn t
@@ -367,7 +365,14 @@ runTransM sym spec m =
367365
(\stream -> (CS.streamId stream, stream)) <$> CS.specStreams spec
368366
pow <- WI.freshTotalUninterpFn sym (WI.safeSymbol "pow") knownRepr knownRepr
369367
logb <- WI.freshTotalUninterpFn sym (WI.safeSymbol "logb") knownRepr knownRepr
370-
let st = TransState Map.empty Map.empty Map.empty Map.empty streamMap pow logb
368+
let st = TransState
369+
{ mentionedExternals = mempty
370+
, externVars = mempty
371+
, streamValues = mempty
372+
, streams = streamMap
373+
, pow = pow
374+
, logb = logb
375+
}
371376

372377
(res, _) <- runStateT (unTransM m) st
373378
return res
@@ -524,142 +529,90 @@ freshCPConstant sym nm tp = case tp of
524529
elts <- forM (CT.toValues stp) $ \(CT.Value ftp _) -> freshCPConstant sym "" ftp
525530
return $ XStruct elts
526531

527-
-- | Get the constant for a given stream id and some offset into the past. This
528-
-- should only be called with a strictly negative offset. When this function
529-
-- gets called for the first time for a given (streamId, offset) pair, it
530-
-- generates a fresh constant and stores it in an internal map. Thereafter, this
531-
-- function will just return that constant when called with the same pair.
532-
getStreamConstant :: WB.ExprBuilder t st fs -> CE.Id -> Int -> TransM t (XExpr t)
533-
getStreamConstant sym streamId offset = do
534-
scs <- gets streamConstants
535-
case Map.lookup (streamId, offset) scs of
536-
Just xe -> return xe
537-
Nothing -> do
538-
CS.Stream _ _ _ tp <- getStreamDef streamId
539-
let nm = show streamId ++ "_" ++ show offset
540-
xe <- liftIO $ freshCPConstant sym nm tp
541-
modify (\st -> st { streamConstants = Map.insert (streamId, offset) xe scs })
542-
return xe
543-
544-
-- | Get the constant for a given external variable and some offset into the
545-
-- past. This should only be called with a strictly negative offset. When this
546-
-- function gets called for the first time for a given (var, offset) pair, it
547-
-- generates a fresh constant and stores it in an internal map. Thereafter, this
548-
-- function will just return that constant when called with the same pair.
549-
getExternConstant :: WB.ExprBuilder t st fs
550-
-> CT.Type a
551-
-> CE.Name
552-
-> Int
553-
-> TransM t (XExpr t)
554-
getExternConstant sym tp var offset = do
555-
es <- gets externVars
556-
case Map.lookup (var, offset) es of
557-
Just xe -> return xe
558-
Nothing -> do
559-
xe <- liftIO $ freshCPConstant sym var tp
560-
modify (\st -> st { externVars = Map.insert (var, offset) xe es
561-
, mentionedExternals = Map.insert var (Some tp) (mentionedExternals st)
562-
} )
563-
return xe
564-
565-
-- | Get the constant for a given external variable at some specific timestep.
566-
getExternConstantAt :: WB.ExprBuilder t st fs
567-
-> CT.Type a
568-
-> CE.Name
569-
-> Int
570-
-> TransM t (XExpr t)
571-
getExternConstantAt sym tp var ix = do
572-
es <- gets externVarsAt
573-
case Map.lookup (var, ix) es of
574-
Just xe -> return xe
575-
Nothing -> do
576-
xe <- liftIO $ freshCPConstant sym var tp
577-
modify (\st -> st { externVarsAt = Map.insert (var, ix) xe es
578-
, mentionedExternals = Map.insert var (Some tp) (mentionedExternals st)
579-
} )
580-
return xe
581532

582-
-- | Retrieve a stream definition given its id.
583-
getStreamDef :: CE.Id -> TransM t CS.Stream
584-
getStreamDef streamId = fromJust <$> gets (Map.lookup streamId . streams)
533+
getStreamValue :: WB.ExprBuilder t st fs -> CE.Id -> StreamOffset -> TransM t (XExpr t)
534+
getStreamValue sym streamId offset =
535+
do svs <- gets streamValues
536+
case Map.lookup (streamId, offset) svs of
537+
Just xe -> return xe
538+
Nothing ->
539+
do streamDef <- getStreamDef streamId
540+
xe <- computeStreamValue sym streamDef offset
541+
modify (\st -> st{ streamValues = Map.insert (streamId, offset) xe (streamValues st) })
542+
return xe
543+
544+
computeStreamValue ::
545+
WB.ExprBuilder t st fs -> CS.Stream -> StreamOffset -> TransM t (XExpr t)
546+
computeStreamValue
547+
sym
548+
CS.Stream
549+
{ CS.streamId = id, CS.streamBuffer = buf, CS.streamExpr = ex, CS.streamExprType = tp }
550+
offset =
551+
case offset of
552+
AbsoluteOffset i
553+
| i < 0 -> fail ("Invalid absolute offset " ++ show i ++ " for stream " ++ show id)
554+
| i < len -> liftIO (translateConstExpr sym tp (genericIndex buf i))
555+
| otherwise -> translateExpr sym ex (AbsoluteOffset (i - len))
556+
RelativeOffset i
557+
| i < 0 -> fail ("Invalid relative offset " ++ show i ++ " for stream " ++ show id)
558+
| i < len -> let nm = "s" ++ show id ++ "_r" ++ show i
559+
in liftIO (freshCPConstant sym nm tp)
560+
| otherwise -> translateExpr sym ex (RelativeOffset (i - len))
585561

586-
-- | Translate an expression into a what4 representation. The int offset keeps
587-
-- track of how many timesteps into the past each variable is referring to.
588-
-- Initially the value should be zero, but when we translate a stream, the
589-
-- offset is recomputed based on the length of that stream's prefix (subtracted)
590-
-- and the drop index (added).
591-
translateExpr :: WB.ExprBuilder t st fs
592-
-> Int
593-
-- ^ number of timesteps in the past we are currently looking
594-
-- (must always be <= 0)
595-
-> CE.Expr a
596-
-> TransM t (XExpr t)
597-
translateExpr sym offset e = case e of
598-
CE.Const tp a -> liftIO $ translateConstExpr sym tp a
599-
CE.Drop _tp ix streamId ->
600-
do CS.Stream _ buf e _ <- getStreamDef streamId
601-
let newidx = offset + fromIntegral ix - length buf
602-
if newidx < 0 then
603-
-- If we are referencing a past value of this stream, just return an
604-
-- unconstrained constant.
605-
getStreamConstant sym streamId newidx
606-
else
607-
-- If we are referencing a current or future value of this stream, we need
608-
-- to translate the stream's expression, using an offset computed based on
609-
-- the current offset (negative or 0), the drop index (positive or 0), and
610-
-- the length of the stream's buffer (subtracted).
611-
translateExpr sym newidx e
562+
where
563+
len = genericLength buf
612564

613-
CE.Local _ _ _ _ _ -> error "translateExpr: Local unimplemented"
614-
CE.Var _ _ -> error "translateExpr: Var unimplemented"
565+
translateExpr :: WB.ExprBuilder t st fs -> CE.Expr a -> StreamOffset -> TransM t (XExpr t)
566+
translateExpr sym e offset = case e of
567+
CE.Const tp a -> liftIO $ translateConstExpr sym tp a
568+
CE.Drop _tp ix streamId -> getStreamValue sym streamId (addOffset offset ix)
615569
CE.ExternVar tp nm _prefix -> getExternConstant sym tp nm offset
616-
CE.Op1 op e -> liftIO . translateOp1 sym op =<< translateExpr sym offset e
570+
CE.Op1 op e -> liftIO . translateOp1 sym op =<< translateExpr sym e offset
617571
CE.Op2 op e1 e2 -> do
618-
xe1 <- translateExpr sym offset e1
619-
xe2 <- translateExpr sym offset e2
572+
xe1 <- translateExpr sym e1 offset
573+
xe2 <- translateExpr sym e2 offset
620574
powFn <- gets pow
621575
logbFn <- gets logb
622576
liftIO $ translateOp2 sym powFn logbFn op xe1 xe2
623577
CE.Op3 op e1 e2 e3 -> do
624-
xe1 <- translateExpr sym offset e1
625-
xe2 <- translateExpr sym offset e2
626-
xe3 <- translateExpr sym offset e3
578+
xe1 <- translateExpr sym e1 offset
579+
xe2 <- translateExpr sym e2 offset
580+
xe3 <- translateExpr sym e3 offset
627581
liftIO $ translateOp3 sym op xe1 xe2 xe3
628582
CE.Label _ _ _ -> error "translateExpr: Label unimplemented"
583+
CE.Local _ _ _ _ _ -> error "translateExpr: Local unimplemented"
584+
CE.Var _ _ -> error "translateExpr: Var unimplemented"
585+
586+
getExternConstant ::
587+
WB.ExprBuilder t st fs -> CT.Type a -> CE.Name -> StreamOffset -> TransM t (XExpr t)
588+
getExternConstant sym tp nm offset =
589+
do es <- gets externVars
590+
case Map.lookup (nm, offset) es of
591+
Just xe -> return xe
592+
Nothing -> do
593+
xe <- computeExternConstant sym tp nm offset
594+
modify (\st -> st { externVars = Map.insert (nm, offset) xe (externVars st)
595+
, mentionedExternals = Map.insert nm (Some tp) (mentionedExternals st)
596+
} )
597+
return xe
598+
599+
computeExternConstant ::
600+
WB.ExprBuilder t st fs -> CT.Type a -> CE.Name -> StreamOffset -> TransM t (XExpr t)
601+
computeExternConstant sym tp nm offset =
602+
case offset of
603+
AbsoluteOffset i
604+
| i < 0 -> fail ("Invalid absolute offset " ++ show i ++ " for external stream " ++ nm)
605+
| otherwise -> let nm' = nm ++ "_a" ++ show i
606+
in liftIO (freshCPConstant sym nm' tp)
607+
RelativeOffset i
608+
| i < 0 -> fail ("Invalid relative offset " ++ show i ++ " for external stream " ++ nm)
609+
| otherwise -> let nm' = nm ++ "_r" ++ show i
610+
in liftIO (freshCPConstant sym nm' tp)
611+
612+
-- | Retrieve a stream definition given its id.
613+
getStreamDef :: CE.Id -> TransM t CS.Stream
614+
getStreamDef streamId = fromJust <$> gets (Map.lookup streamId . streams)
629615

630-
-- | Translate an expression into a what4 representation at a /specific/
631-
-- timestep, rather than "at some indeterminate point in the future."
632-
translateExprAt :: WB.ExprBuilder t st fs
633-
-> Int
634-
-- ^ Index of timestep
635-
-> CE.Expr a
636-
-- ^ stream expression
637-
-> TransM t (XExpr t)
638-
translateExprAt sym k e = do
639-
case e of
640-
CE.Const tp a -> liftIO $ translateConstExpr sym tp a
641-
CE.Drop _tp ix streamId -> do
642-
CS.Stream _ buf e tp <- getStreamDef streamId
643-
if k' < length buf
644-
then liftIO $ translateConstExpr sym tp (buf !! k')
645-
else translateExprAt sym (k' - length buf) e
646-
where k' = k + fromIntegral ix
647-
CE.Local _ _ _ _ _ -> error "translateExpr: Local unimplemented"
648-
CE.Var _ _ -> error "translateExpr: Var unimplemented"
649-
CE.ExternVar tp nm _prefix -> getExternConstantAt sym tp nm k
650-
CE.Op1 op e -> liftIO . translateOp1 sym op =<< translateExprAt sym k e
651-
CE.Op2 op e1 e2 -> do
652-
xe1 <- translateExprAt sym k e1
653-
xe2 <- translateExprAt sym k e2
654-
powFn <- gets pow
655-
logbFn <- gets logb
656-
liftIO $ translateOp2 sym powFn logbFn op xe1 xe2
657-
CE.Op3 op e1 e2 e3 -> do
658-
xe1 <- translateExprAt sym k e1
659-
xe2 <- translateExprAt sym k e2
660-
xe3 <- translateExprAt sym k e3
661-
liftIO $ translateOp3 sym op xe1 xe2 xe3
662-
CE.Label _ _ _ -> error "translateExpr: Label unimplemented"
663616

664617
type BVOp1 w t = (KnownNat w, 1 <= w) => WB.BVExpr t w -> IO (WB.BVExpr t w)
665618

0 commit comments

Comments
 (0)